From 1051624878ca61db2209d82870065c6f3e5db5d0 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 13 Sep 2024 03:37:11 -0400 Subject: [PATCH 01/31] wip: ir redesign --- Cargo.lock | 53 ++ Cargo.toml | 1 + docs/design/ir.md | 839 ++++++++++++++++++++++++ hir-macros/src/lib.rs | 23 + hir-macros/src/op.rs | 268 ++++++++ hir2/Cargo.toml | 63 ++ hir2/src/core.rs | 33 + hir2/src/core/attribute.rs | 273 ++++++++ hir2/src/core/block.rs | 78 +++ hir2/src/core/component.rs | 0 hir2/src/core/context.rs | 94 +++ hir2/src/core/entity.rs | 679 +++++++++++++++++++ hir2/src/core/entity/list.rs | 540 +++++++++++++++ hir2/src/core/function.rs | 308 +++++++++ hir2/src/core/interface.rs | 0 hir2/src/core/module.rs | 43 ++ hir2/src/core/op.rs | 0 hir2/src/core/operation.rs | 336 ++++++++++ hir2/src/core/region.rs | 19 + hir2/src/core/symbol_table.rs | 54 ++ hir2/src/core/traits.rs | 30 + hir2/src/core/traits/multitrait.rs | 109 +++ hir2/src/core/types.rs | 1 + hir2/src/core/usable.rs | 35 + hir2/src/core/value.rs | 166 +++++ hir2/src/dialects/hir.rs | 0 hir2/src/lib.rs | 6 + hir2/src/ops/binary.rs | 60 ++ hir2/src/ops/call.rs | 6 + hir2/src/ops/cast.rs | 33 + hir2/src/ops/control_flow.rs | 50 ++ hir2/src/ops/global_value.rs | 35 + hir2/src/ops/inline_asm.rs | 0 hir2/src/ops/mem.rs | 38 ++ hir2/src/ops/mod.rs | 16 + hir2/src/ops/primop.rs | 22 + hir2/src/ops/ret.rs | 10 + hir2/src/ops/structured_control_flow.rs | 56 ++ hir2/src/ops/unary.rs | 37 ++ hir2/src/unsafe_ref.rs | 101 +++ 40 files changed, 4515 insertions(+) create mode 100644 docs/design/ir.md create mode 100644 hir-macros/src/op.rs create mode 100644 hir2/Cargo.toml create mode 100644 hir2/src/core.rs create mode 100644 hir2/src/core/attribute.rs create mode 100644 hir2/src/core/block.rs create mode 100644 hir2/src/core/component.rs create mode 100644 hir2/src/core/context.rs create mode 100644 hir2/src/core/entity.rs create mode 100644 hir2/src/core/entity/list.rs create mode 100644 hir2/src/core/function.rs create mode 100644 hir2/src/core/interface.rs create mode 100644 hir2/src/core/module.rs create mode 100644 hir2/src/core/op.rs create mode 100644 hir2/src/core/operation.rs create mode 100644 hir2/src/core/region.rs create mode 100644 hir2/src/core/symbol_table.rs create mode 100644 hir2/src/core/traits.rs create mode 100644 hir2/src/core/traits/multitrait.rs create mode 100644 hir2/src/core/types.rs create mode 100644 hir2/src/core/usable.rs create mode 100644 hir2/src/core/value.rs create mode 100644 hir2/src/dialects/hir.rs create mode 100644 hir2/src/lib.rs create mode 100644 hir2/src/ops/binary.rs create mode 100644 hir2/src/ops/call.rs create mode 100644 hir2/src/ops/cast.rs create mode 100644 hir2/src/ops/control_flow.rs create mode 100644 hir2/src/ops/global_value.rs create mode 100644 hir2/src/ops/inline_asm.rs create mode 100644 hir2/src/ops/mem.rs create mode 100644 hir2/src/ops/mod.rs create mode 100644 hir2/src/ops/primop.rs create mode 100644 hir2/src/ops/ret.rs create mode 100644 hir2/src/ops/structured_control_flow.rs create mode 100644 hir2/src/ops/unary.rs create mode 100644 hir2/src/unsafe_ref.rs diff --git a/Cargo.lock b/Cargo.lock index fa2b23a23..929fd3dfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -537,6 +537,15 @@ dependencies = [ "constant_time_eq", ] +[[package]] +name = "blink-alloc" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e669f146bb8b2327006ed94c69cf78c8ec81c100192564654230a40b4f091d82" +dependencies = [ + "allocator-api2", +] + [[package]] name = 
"block-buffer" version = "0.10.4" @@ -1361,6 +1370,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + [[package]] name = "ecdsa" version = "0.16.9" @@ -3404,6 +3419,44 @@ dependencies = [ "smallvec", ] +[[package]] +name = "midenc-hir2" +version = "0.0.6" +dependencies = [ + "anyhow", + "blink-alloc", + "cranelift-entity", + "derive_more", + "downcast-rs", + "either", + "indexmap 2.2.6", + "intrusive-collections", + "inventory", + "lalrpop", + "lalrpop-util", + "log", + "miden-assembly", + "miden-core", + "miden-thiserror", + "midenc-hir-macros", + "midenc-hir-symbol", + "midenc-hir-type", + "midenc-session", + "num-bigint", + "num-traits 0.2.19", + "paste", + "petgraph", + "pretty_assertions", + "rustc-demangle", + "rustc-hash 1.1.0", + "serde 1.0.208", + "serde_bytes", + "serde_repr", + "smallvec", + "typed-arena", + "unicode-width", +] + [[package]] name = "midenc-session" version = "0.0.6" diff --git a/Cargo.toml b/Cargo.toml index d1c632847..ba0cbd222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "codegen/*", "frontend-wasm", "hir", + "hir2", "hir-analysis", "hir-macros", "hir-symbol", diff --git a/docs/design/ir.md b/docs/design/ir.md new file mode 100644 index 000000000..6740b8109 --- /dev/null +++ b/docs/design/ir.md @@ -0,0 +1,839 @@ +# High-Level Intermediate Representation (HIR) + +This document describes the concepts, usage, and overall structure of the intermediate +representation used by `midenc`. + +## Introduction + +TODO + +## Concepts + +### Components + +A _component_ is a named entity that encapsulates one or more [_interfaces_](#interfaces), and comes +in two forms: + +* An _executable_ component, which has a statically-defined entrypoint, a function which initializes +and executes a program encapsulated by the component. +* A _library_ component, which exports one or more interfaces, and can be used as a dependency of +other components. + +We also commonly refer to executable components as _programs_, and library components as _libraries_, +which also correspond to the equivalent concepts in Miden Assembly. However, components are a more +general abstraction over programs and libraries, where the distinction is mostly one of intended +use and/or convention. + +Components can have zero or more dependencies, which are expressed in the form of interfaces that +they require instances of at runtime. Thus any component that provides the interface can be used +to satisfy the dependency. + +A component _instance_ refers to a component that has had all of its dependencies resolved +concretely, and is thus fully-defined. + +A component _definition_ specifies four things: + +1. The name of the component +2. The interfaces it imports +3. The interfaces it exports +4. The [_modules_](#modules) which implement the exported interfaces concretely + +### Interfaces + +An _interface_ is a named entity that describes one or more [_functions_](#functions) that it +exports. Conceptually, an _interface_ loosely corresponds to the notion of a module, in that both +a module and an interface define a namespace, in which one or more functions are exported. + +However, an _interface_, unlike a module, is abstract, and does not have any internal structure. +It is more like a trait, in that it abstractly represents a set of named behaviors implemented by +some [component](#components). 
+ +### Modules + +A module is primarily two things: + +1. A container for one or more functions belonging to a common namespace +2. A concrete implementation of one or more [interfaces](#interfaces) + +Functions within a module may be exported, and so a module always has an implicit interface +consisting of all of its exported functions. Functions which are _not_ exported, are only visible +within the module, and do not form a part of the implicit interface of a module. + +Module names are used to name the implicit interface of the module. Thus, within a component, both +imported interfaces, and the implicit interfaces of all modules it defines, can be used to resolve +function references in those modules. + +A module defines a symbol table, whose entries are the functions defined in that module. + +### Functions + +A function is a special type of [_operation_](#operations). It is special in the following ways: + +* A function has a [_symbol_](#symbols-and-symbol-tables), and thus declares an entry in the nearest +containing [_symbol table_](#symbols-and-symbol-tables). +* A function is _isolated from above_, i.e. the contents of the function cannot escape the function, +nor reference things outside the function, except via symbol table references. Thus entities such as +[_values_](#values) and [_blocks_](#blocks) are function-scoped, if not more narrowly scoped, in the +case of operations with nested [_regions_](#regions). + +A function has an arbitrary set of parameters and results, corresponding to its type signature. A +function also has the notion of an _application binary interface_ (ABI), which drives how code is +generated for both caller and callee. For example, a function may have a specific calling convention +as a whole, and specific parameters/results may have type-specific semantics declared, such as +whether to zero- or sign-extend the value if the input is of a smaller range. + +A function always consists of a single [_region_](#regions), called the _body_, with at least one +[_block_](#blocks), which is called the _entry block_. The block parameters of the entry block +always correspond to the function parameters, i.e. the arity and type of the block parameters must +match the function signature. + +Additionally, a function has an additional constraint on its body, which is that all blocks in the +region must end with one of a restricted set of _terminator_ operations: any branch operation, which +transfers control between blocks of the region; the `unreachable` operation, which will result in +aborting the program if executed; or the `return` operation, which must return the same arity and +type of values declared in the function signature. + +### Global Variables + +A global variable is a second special type of [_operation_](#operations), after +[_functions_](#functions): + +* A global variable has a [_symbol_](#symbols-and-symbol-tables), and declares an entry in the +nearest containing [_symbol table_](#symbols-and-symbol-tables). +* The _initializer_ of a global variable is, like function bodies, _isolated from above_. + +A global variable may have an _initializer_, a single region/single block body which is implicitly +executed to initialize the value of the global variable. The initializer must be statically +evaluatable, i.e. a "constant" expression. In most cases, this will simply be a constant value, but +some limited forms of constant expressions are permitted. + +### Symbols and Symbol Tables + +A _symbol_ is simply a named entity, e.g. 
a function `foo` is a symbol whose value is `foo`. +On their own, symbols aren't particularly useful. This is where the concept of a _symbol table_ +becomes important. + +A _symbol table_ is a collection of uniqued symbols belonging to the same namespace, i.e. every +symbol has a single entry in the symbol table, regardless of entity type. Thus, it is not permitted +to have both a function and a global variable with the same name, in the same symbol table. If such +a thing needed to be allowed, perhaps because the namespace for functions and global variables are +separate, then you would use a per-entity symbol table. + +For our purposes, a _module_ defines a symbol table, and both functions and global variables share +that table. We do not currently use symbol tables for anything else. + +### Operations + +An _operation_ is the most important entity in HIR, and the most abstract. In the _Regionalized +Value State Dependence Graph_ paper, the entire representation described there consists of various +types of operations. In HIR, we do not go quite that abstract, however we do take a fair amount of +inspiration from that paper, as well as from MLIR. + +Operations consist of the following pieces: + +* Zero or more [_regions_](#regions) and their constituent [_blocks_](#blocks) +* Zero or more [_operands_](#operands), i.e. arguments or inputs +* Zero or more _results_, or outputs, one of the two ways that [_values_](#values) can be introduced +* Zero or more [_successors_](#successors), in the case of operations which transfer control to +another block in the same region. +* Zero or more [_attributes_](#attributes), the semantics of which depend on the operation. +* Zero or more [_traits_](#traits), implemented by the operation. + +An operation always belongs to a _block_ when in use. + +As you can see, this is a highly flexible concept. It is capable of representing modules and +functions, as well as primitive instructions. It can represent both structured and unstructured +control-flow. There is very little in terms of an IR that _can't_ be represented using operations. + +However, in our case, we use operations for five specific concepts: + +* Functions (the first of two special ops) +* Global Variables (the second of two special ops) +* Structured Control Flow (if/then, do/while and for loops) +* Unstructured Control Flow (br, cond_br, switch, ret) +* Primitive Instructions (i.e. things which correspond to the target ISA, e.g. `add`, `call`, etc.) + +For the most part, the fact that functions and global variables are implemented using operations +is not particularly important. Instead, most operations you will interact with are of the other +three varieties. While we've broken them up into three categories, for the most part, they aren't +actually significantly different. The primary difference is that the unstructured control-flow ops +are valid _terminators_ for blocks, in multi-block regions, while the structured control-flow ops +are not, and only a few special cases of primitive ops are also valid terminators (namely the +`ret` and `unreachable` ops). For most primitive and structured control-flow ops, their behavior +appears very similar: they take some operands, perform some action, and possibly return some +results. + +### Regions + +A _region_ encapsulates a control-flow graph (CFG) of one or more [_basic blocks_](#blocks). 
In HIR, +the contents of a region are in _static single assignment_ (SSA) form, meaning that values may only +be defined once, definitions must [_dominate_](#dominance-relation) uses, and operations in the CFG +described by the region are executed one-by-one, from the entry block of the region, until control +exits the region (e.g. via `ret` or some other terminator instruction). + +The order of operations in the region closely corresponds to their scheduling order, though the +code generator may reschedule operations when it is safe - and more efficient - to do so. + +Operations in a region may introduce nested regions. For example, the body of a function consists +of a single region, and it might contain an `if` operation that defines two nested regions, one for +the true branch, and one for the false branch. Nested regions may access any [_values_](#values) in +an ancestor region, so long as those values dominate the operation that introduced the nested region. +The exception to this is operations which are _isolated from above_. The regions of such an +operation are not permitted to reference anything defined in an outer scope, except via +[_symbols_](#symbols-and-symbol-tables). For example, [_functions_](#functions) are an operation +which is isolated from above. + +The purpose of regions, is to allow for hierarchical/structured control flow operations. Without +them, representing structured control flow in the IR is difficult and error-prone, due to the +semantics of SSA CFGs, particularly with regards to analyses like dominance and loops. It is also +an important part of what makes [_operations_](#operations) such a powerful abstraction, as it +provides a way to generically represent the concept of something like a function body, without +needing to special-case them. + +A region must always consist of at least one block (the entry block), but not all regions allow +multiple blocks. When multiple blocks are present, it implies the presence of unstructured control +flow, as the only way to transfer control between blocks is by using unstructured control flow +operations, such as `br`, `cond_br`, or `switch`. Structured control flow operations such as `if`, +introduce nested regions consisting of only a single block, as all control flow within a structured +control flow op, must itself be structured. The specific rules for a region depend on the semantics +of the containing operation. + +### Blocks + +A _block_, or _basic block_, is a set of one or more [_operations_](#operations) in which there is +no control flow, except via the block _terminator_, i.e. the last operation in the block, which is +responsible for transferring control to another block, exiting the current region (e.g. returning +from a function body), or terminating program execution in some way (e.g. `unreachable`). + +A block may declare _block parameters_, the only other way to introduce [_values_](#values) into +the IR, aside from operation results. Predecessors of a block must ensure that they provide +arguments for all block parameters when transferring control to the block. + +Blocks always belong to a [_region_](#regions). The first block in a region is called the _entry +block_, and is special in that its block parameters (if any) correspond to whatever arguments +the region accepts. For example, the body of a function is a region, and the entry block in that +region must have a parameter list that exactly matches the arity and type of the parameters +declared in the function signature. 
In this way, the function parameters are materialized as +SSA values in the IR. + +### Values + +A _value_ represents terms in a program, temporaries created to store data as it flows through the +program. In HIR, which is in SSA form, values are immutable - once created they cannot be changed +nor destroyed. This property of values allows them to be reused, rather than recomputed, when the +operation that produced them contains no side-effects, i.e. invoking the operation with the same +inputs must produce the same outputs. This forms the basis of one of the ways in which SSA IRs can +optimize programs. + +!!! note + +> One way in which you can form an intuition for values in an SSA IR, is by thinking of them as +> registers in a virtual machine with no limit to the number of machine registers. This corresponds +> well to the fact that most values in an IR, are of a type which corresponds to something that can +> fit in a typical machine register (e.g. 32-bit or 64-bit values, sometimes larger). +> +> Values which cannot be held in actual machine registers, are usually managed in the form of heap +> or stack-allocated memory, with various operations used to allocate, copy/move, or extract smaller +> values from them. While not strictly required by the SSA representation, this is almost always +> effectively enforced by the instruction set, which will only consist of instructions whose +> operands and results are of a type that can be held in machine registers. + +Value _definitions_ (aka "defs") can be introduced in two ways: + +1. Block parameters. Most notably, the entry block for function bodies materializes the function +parameters as values via block parameters. Block parameters are also used at places in the CFG +where two definitions for a single value are joined together. For example, if the value assigned to +a variable in the source language is assigned conditionally, then in the IR, there will be a block +with a parameter corresponding to the value of that variable after it is assigned. All uses after +that point, would refer to that block parameter, rather than the value from a specific branch. +Similarly, loop-carried variables, such as an iteration count, are typically manifested as block +parameters of the block corresponding to the loop header. +2. Operation results. The most common way in which values are introduced. + +Values have _uses_ corresponding to operands or successor arguments (special operands which are used +to satisfy successor block parameters). As a result, values also have _users_, corresponding to the +specific operation and operand forming a _use_. + +All _uses_ of a value must be [_dominated_](#dominance-relation) by its _definition_. The IR is +invalid if this rule is ever violated. + +### Operands + +An _operand_ is a [_value_](#values) or [_immediate_](#immediates) used as an argument to an +operation. + +Beyond the semantics of any given operation, operand ordering is only significant in so far as it +is used as the order in which those items are expected to appear on the operand stack once lowered +to Miden Assembly. The earlier an operand appears in the list of operands for an operation, the +closer to the top of the operand stack it will appear. + +Similarly, the ordering of operation results also correlates to the operand stack order after +lowering. Specifically, the earlier a result appears in the result list, the closer to the top of +the operand stack it will appear after the operation executes. 
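+
+For example, here is a hedged sketch using the hypothetical `Add::build` API from the
+[Usage](#usage) section below (assuming `lhs`, `rhs`, `span`, and `context` are already in scope):
+
+```rust
+use midenc_hir::*;
+
+// `lhs` is operand 0 and `rhs` is operand 1, so once lowered to Miden Assembly,
+// `lhs` will be closest to the top of the operand stack when the op executes;
+// the op's single result will likewise be on top of the stack afterwards.
+let add_op = Add::build(Overflow::Checked, lhs, rhs, span, &mut context);
+```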
+ +### Immediates + +An _immediate_ is a literal value, typically of integral type, used as an operand. Not all +operations support immediates, but those that do, will typically use them to attempt to perform +optimizations only possible when there is static information available about the operands. For +example, multiplying any number by 2, will always produce an even number, so a sequence such as +`mul.2 is_odd` can be folded to `false` at compile-time, allowing further optimizations to occur. + +Immediates are separate from _constants_, in that immediates _are_ constants, but specifically +constants which are valid operand values. + +### Attributes + +An _attribute_ is (typically optional) metadata attached to an IR entity. In HIR, attributes can be +attached to functions, global variables, and operations. + +Attributes are stored as a set of arbitrary key-value data, where values can be one of four types: + +* `unit`, Attributes of this value type are usually "marker" attributes, i.e. they convey their +information simply by being present. +* `bool`, Attributes of this value type are somewhat similar to those of `unit` type, but by +carrying a boolean value, they can be used to convey both positive and negative meaning. For +example, you might want to support explicit inlining with `#[inline(true)]`, and prevent any form +of inlining with `#[inline(false)]`. Here, `unit` would be insufficient to describe both options +under a single attribute. +* `int`, Attributes of this value type are used to convey numeric metadata. For example, inliner +thresholds, or some other kind of per-operation limits. +* `string`, Attributes of this value type are used to convey arbitrary values. Most commonly +you might see this type with things that are enum-like, e.g. `#[cc(fast)]` to specify a particular +calling convention for a function. + +Some attributes are "first-class", in that they are defined as part of an operation. For example, +the calling convention of a function is an intrinsic attribute of a function, and feels like a +native part of the `Function` API - rather than having to look up the attribute, and cast the value +to a more natural Rust type, you can simply call `function.calling_convention()`. + +Attributes are not heavily used at this time, but are expected to serve more purposes in the future +as we increase the amount of information frontends need to convey to the compiler backend. + +### Traits + +A _trait_ defines some behavior that can be implemented by an operation. This allows operations +to be operated over generically in an analysis or rewrite, rather than having to handle every possible +concrete operation type. This makes passes less fragile to changes in the IR in general, and allows +the IR to be extended without having to update every single place where operations are handled. + +An operation can be cast to a specific trait that it implements, and trait instances can be +downcast to the concrete operation type if known. + +There are a handful of built-in traits, used to convey certain semantic information about the +operations they are attached to, and in particular, are used to validate those operations, for +example: + +* `IsolatedFromAbove`, a marker trait that indicates that regions of the operation it is attached to +cannot reference items from any parents, except via [_symbols_](#symbols-and-symbol-tables). 
+ +* `Terminator`, a marker trait for operations which are valid block terminators +* `ReturnLike`, a trait that describes behavior shared by instructions that exit from an enclosing +region, "returning" the results of executing that region. The most notable of these is `ret`, but +`yield` used by the structured control flow ops is also return-like in nature. +* `BranchOp`, a trait that describes behavior shared by all unstructured control-flow branch +instructions, e.g. `br`, `cond_br`, and `switch`. +* `ConstantLike`, a marker trait for operations that produce a constant value +* `Commutative`, a marker trait for binary operations that exhibit commutativity, i.e. the order of +the operands can be swapped without changing semantics. + +There are others as well, responsible for aiding in type checking, decorating operations with the +types of side effects they do (or do not) exhibit, and more. + +### Successors and Predecessors + +The concept of _predecessor_ and _successor_ corresponds to a parent/child relationship in a +control-flow graph (CFG), where edges in the graph are directed, and describe the order in which +control flows through the program. If a node $A$ transfers control to a node $B$ after it is +finished executing, then $A$ is a _predecessor_ of $B$, and $B$ is a _successor_ of $A$. + +Successors and predecessors can be looked at from two similar, but slightly different, perspectives: + +1. In terms of operations. In an SSA CFG, operations in a basic block are executed in order, and +thus the successor of an operation in the block, is the next operation to be executed in that block, +with the predecessor being the inverse of that relationship. At basic block boundaries, the +successor(s) of the _terminator_ operation, are the set of operations to which control can be +transferred. Likewise, the predecessor(s) of the first operation in a block, are the set of +terminators which can transfer control to the containing block. This is the most precise, but is not +quite as intuitive as the alternative. +2. In terms of blocks. The successor(s) of a basic block, are the set of blocks to which control may +be transferred when exiting the block. Likewise, the predecessor(s) of a block, are the set of blocks +which can transfer control to it. We are most frequently dealing with the concept of successors and +predecessors in terms of blocks, as it allows us to focus on the interesting parts of the CFG. For +example, the dominator tree and loop analyses, are constructed in terms of a block-oriented CFG, +since we can trivially derive dominance and loop information for individual ops from their +containing blocks. + +Typically, you will see successors as a pair of `(block_id, &[value_id])`, i.e. the block to which +control is transferred, and the set of values being passed as block arguments. On the other hand, +predecessors are most often a pair of `(block_id, terminator_op_id)`, i.e. the block from which +control originates, and the specific operation responsible. + +### Dominance Relation + +In an SSA IR, the concept of _dominance_ is of critical importance. Dominance is a property of the +relationship between two or more entities and their respective program points. For example, between +the use of a value as an operand for an operation, and the definition of that value; or between a +basic block and its successors. The dominance property is anti-symmetric, i.e. if $A$ dominates $B$, +then $B$ cannot dominate $A$, unless $A = B$. 
Put simply: + +> Given a control-flow graph $G$, and a node $A \in G$, then $\forall B \in G$, $A dom B$ if all +> paths to $B$ from the root of $G$, pass through $A$. +> +> Furthermore, $A$ _strictly_ dominates $B$, if $A \neq B$. + +An example of why dominance is an important property of a program, can be seen when considering the +meaning of a program like so (written in pseudocode): + +``` +if (...) { + var a = 1; +} + +foo(a) +``` + +Here, the definition of `a` does not dominate its usage in the call to `foo`. If the conditional +branch is ever false, `a` is never defined, nor initialized - so what should happen when we reach +the call to `foo`? + +In practice, of course, such a program is rarely possible to express in a high-level language, +however in a low-level CFG, it is possible to reference values which are defined somewhere in the +graph, but in such a way that is not _legal_ according to the "definitions must dominate uses" +rule of SSA CFGs. The dominance property is what we use to validate the correctness of the IR, as +well as evaluate the range of valid transformations that can be applied to the IR. For example, we +might determine that it is valid to move an expression into a specific `if/then` branch, because +it is only used in that branch - the dominance property is how we determine that there are paths +through the program in which the result of the expression is unused, as well as what program points +represent the nearest point to one of its uses that still dominates _all_ of the uses. + +There is another useful notion of dominance, called _post-dominance_, which can be described much +like the regular notion of dominance, except in terms of paths to the exit of the CFG, rather than +paths from the entry: + +> Given a control-flow graph $G$, and a node $A \in G$, then $\forall B \in G$, $A pdom B$ if all +> paths through $B$ that exit the CFG, must flow through $A$ first. +> +> Furthermore, $A$ _strictly_ post-dominates $B$ if $A \neq B$. + +The notion of post-dominance is important in determining the applicability of certain transformations, +in particular with loops. + + +## Structure + +The hierarchy of HIR looks like so: + + Component <- Imports <- Interface + | + v + Exports + | + v + Interface + | + v + Module --------- + | | + v v + Function/Op Global Variable + | + v + -- Region + | | + | v + | Block -> Value <-- + | | | | + | | v | + | | Operand | + | v | | + > Operation <- | + | | + v | + Result ----------- + +In short: + +* A _component_ imports dependencies in the form of _interfaces_ +* A _component_ exports at least one _interface_. +* An _interface_ is concretely implemented with a _module_. +* A _module_ contains _function_ and _global variable_ definitions, and imports them from the set +of interfaces available to the component. +* A _function_, as a type of _operation_, consists of a body _region_. +* A _region_ consists of at least one _block_, that may define _values_ in the form of block +parameters. +* A _block_ consists of _operations_, whose results introduce new _values_, and whose operands are +_values_ introduced either as block parameters or the results of previous operations. +* An _operation_ may contain nested _regions_, with associated blocks and operations. + +### Passes + +A _pass_ is effectively a fancy function with a specific signature and some constraints on its +semantics. There are three primary types of passes: [_analysis_](#analyses), +[_rewrite_](#rewrites), and [_conversion_](#conversions). 
These three pass types have different +signatures and semantics, but play symbiotic roles in the compilation pipeline. + +There are two abstractions over all passes, used and extended by the three types described above: + +* The `Pass` trait, which provides an abstraction suitable for describing any type of compiler pass +used by `midenc`. It primarily exists to allow writing pass-generic helpers. +* The `PassInfo` trait, which exists to provide a common interface for pass metadata, such as the +name, description, and command-line flag prefix. All passes must implement this trait. + +#### Analyses + +Analysis of the IR is expressed in the form of a specialized pass, an `Analysis`, and an +`AnalysisManager`, which is responsible for computing analyses on demand, caching them, and +invalidating the relevant parts of the cache when the IR changes. Analyses are expressed in terms +of a specific entity, such as `Function`, and are cached based on the unique identity of that +entity. + +An _analysis_ is responsible for computing some fact about the IR entity it is given. Facts +typically include things such as: computing dominance, identifying loops and their various component +parts, reachability, liveness, identifying unused (i.e. dead) code, and much more. + +To do this, analyses are given an immutable reference to the IR in question; a reference to the +current `AnalysisManager`, so that the results of other analyses can be consulted; and a reference +to the current compilation `Session` for access to configuration relevant to the analysis. + +Analysis results are computed as an instance of `Self`. This provides structure to the analysis +results, and provides a place to implement helpful functionality for querying the results. + +A well-written analysis should be based purely on the inputs given to `Analysis::analyze`, and +ideally be based on some formalism, so that properties of the analysis can be verified. Most of +the analyses in HIR today, are based on the formalisms underpinning _dataflow analysis_, e.g. +semi-join lattices. + +#### Rewrites + +_Rewrites_ are a type of pass which mutate the IR entity to which they are applied. They can be +chained (via the `chain` method on the `RewritePass` trait) to form rewrite pipelines, called a +`RewriteSet`. A `RewriteSet` manages executing each rewrite in the set, and coordinating with the +`AnalysisManager` between rewrites. + +Rewrites are given a mutable reference to the IR they should apply to; the current `AnalysisManager` +so that analyses can be consulted (or computed) to facilitate the rewrite; and the current +compilation `Session`, for configuration. + +Rewrites must leave the IR in a valid state. Rewrites are also responsible for indicating, via the +`AnalysisManager`, which analyses can be preserved after applying the rewrite. A rewrite that makes +no changes, should mark all analyses preserved, to avoid recomputing the analysis results the next +time they are requested. If an analysis is not explicitly preserved by a rewrite, it will be +invalidated by the containing `RewriteSet`. + +Rewrite passes written for a `Function`, can be adapted for application to a `Module`, using the +`ModuleRewriteAdapter`. This makes writing rewrites for use in the main compiler pipeline, as simple +as defining it for a `Function`, and then using the `ModuleRewriteAdapter` to add the rewrite to the +pipeline. 
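+
+As a rough sketch of how these pieces fit together (the `apply` signature, the
+`ModuleRewriteAdapter::new` constructor, and the `SimplifyBranches`/`SinkCode` rewrites shown here
+are assumptions for illustration, not the finalized API):
+
+```rust
+use midenc_hir::*;
+
+/// A hypothetical function-level rewrite.
+struct SimplifyBranches;
+
+impl RewritePass for SimplifyBranches {
+    fn apply(
+        &mut self,
+        function: &mut Function,
+        analyses: &mut AnalysisManager,
+        session: &Session,
+    ) -> anyhow::Result<()> {
+        // ...mutate `function`, consulting `analyses` and `session` configuration as
+        // needed, then record which analyses remain valid so they are not recomputed...
+        Ok(())
+    }
+}
+
+// Adapt the function-level rewrite for application to a `Module`, and chain it with
+// another rewrite to form a pipeline (a `RewriteSet`) usable in the compiler pipeline.
+let rewrites = ModuleRewriteAdapter::new(SimplifyBranches)
+    .chain(ModuleRewriteAdapter::new(SinkCode));
+```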
+ +A well-written rewrite pass should only use data available in the IR itself, or analyses in the +provided `AnalysisManager`, to drive application of the rewrite. Additionally, rewrites should +focus on a single transformation, and rely on chaining rewrites to orchestrate more complex +transformations composed of multiple stages. Lastly, the rewrite should ideally have a logical +proof of its safety, or failing that, a basis in some formalism that can be suitably analyzed and/or +tested. If a rewrite cannot be described in such a way, it will be very difficult to provide +guarantees about the code produced by the transformation. This makes it hard to be confident +in the rewrite (and by extension, the compiler), and impossible to verify. + +#### Conversions + +_Conversions_ are a type of pass which converts between intermediate representations, or more +abstractly, between _dialects_. + +The concept of a dialect is focused primarily on semantics, and less so on concrete representations. +For example, source and target dialects might share the same underlying IR, with the target dialect +having a more restricted set of _legal_ operations, possibly with stricter semantics. + +This brings us to the concept of _legalization_, i.e. converting all _illegal_ operations in the IR +into _legal_ equivalents. Each dialect defines the set of operations which it considers legal. The +concept of legality is mostly important when the same underlying IR is used across multiple dialects, +as is the case with HIR. + +There are two types of conversion passes in HIR: _dialect conversion_, and _translation_: + +* _Dialect conversion_ is used at various points in the compilation pipeline to simplify the IR +for later passes. For example, Miden Assembly has no way to represent multi-way branches, such +as implemented by `switch`. At a certain point, we switch to a dialect where `switch` is illegal, +so that further passes can be written as if `switch` doesn't exist. Yet another dialect later in +the pipeline, makes all unstructured control flow illegal, in preparation for translation to +Miden Assembly, which has no unstructured control flow operators. +* _Translation_ refers to a conversion from one IR to a completely different one. Currently, the +only translation we have, is the one responsible for translating from HIR to Miden Assembly. In +the future, we hope to also implement frontends as translations _to_ HIR, but that is not currently +the case. Translations are currently implemented as simple passes that take in some IR as input, +and produce _whatever_ as output. + +Dialect conversions are implemented in the form of generic _conversion infrastructure_. A dialect +conversion is described as a set of _conversion patterns_, which define what to do when a specific +operation is seen in the input IR; and the set of operations which are legal in the target dialect. +The conversion driver is responsible for visiting the IR, repetitively applying any matched +conversion patterns until a fixpoint is reached, i.e. no more patterns are matched, or a conversion +pattern fails to apply successfully. If any illegal operations are found which do not have +corresponding conversion patterns, then a legalization error is raised, and the conversion overall +fails. + +## Usage + +Let's get into the gritty details of how the compiler works, particularly in relation to HIR. + +The entry point for compilation, generally speaking, is the _driver_. 
The driver is +responsible for collecting compiler inputs, including configuration, from the user, instantiating a +`Session`, and then invoking the compiler frontend with those items. + +The first stage of compilation is _parsing_, in which each input is converted to HIR by an +appropriate frontend. For example, Wasm modules are loaded and translated to HIR using the Wasm +frontend. The purpose of this stage is to get all inputs into HIR form, for subsequent stages. The +only exception to this is MASM sources, which are assembled to MAST directly, and then set aside +until later in the pipeline when we go to assemble the final artifact. + +The second stage of compilation is _semantic analysis_. This is the point where we validate the +HIR we have so far, and ensure that there are no obvious issues that will cause compilation to +fail unexpectedly later on. In some cases, this stage is skipped, as we have already validated the +IR in the frontend. + +The third stage of compilation is _linking_. This is where we gather together all of the inputs, +as well as any compiler options that tell us what libraries to link against, and where to search +for them, and then ensure that there are no missing inputs, undefined symbols, or incompatible type +signatures. The output of this stage is a well-formed [_component_](#components). + +The fourth stage of compilation is _rewriting_, in which all of the rewrite passes we wish to +apply to the IR, are applied. You could also think of this stage as a combination of optimization +and preparing for codegen. + +The fifth stage of compilation is _codegen_, where HIR is translated to Miden Assembly. + +The final stage of compilation is _assembly_, where the Miden Assembly we produced, along with +any other Miden Assembly libraries, are assembled to MAST, and then packaged in the Miden package +format. + +The `Session` object contains all of the important compiler configuration and exists for the +duration of a compiler invocation. In addition to this, there is a `Context` object, which is used +to allocate IR entities contained within a single `Module`. The context of each `Module` is further +subdivided by `Function`, in the form of `FunctionContext`, from which all function-local IR +entities are allocated. This ensures that each `Function` gets its own set of values, blocks, and +regions, while sharing `Module`-wide entities, such as constants and symbols. + +Operations are allocated from the nearest `Context`, and use that context to allocate and access +IR entities used by the operation. + +The `Context` object is pervasive: it is needed just about anywhere that IR entities are used, to +allow accessing data associated with those entities. Most entity references are integer ids, which +uniquely identify the entity, but provide no access to them without going through the `Context`. + +This is a bit awkward, but is easier to work with in Rust. The alternative is to rely on +dynamically-checked interior mutability inside reference-counted allocations, e.g. `Rc<RefCell<T>>`. +Not only is this similarly pervasive in terms of how APIs are structured, but it loses some of +the performance benefits of allocating IR objects close together on the heap. + +The following is some example code in Rust, demonstrating how it might look to create a component +in HIR and work with it. + +```rust +use midenc_hir::*; + +/// Defining Ops + +/// `Add` is a binary integral arithmetic operator, i.e. 
`+` +/// +/// It consists of two operands, which must be of the same type, and produces a result that +/// is also of that type. +/// +/// It supports various types of overflow semantics, see `Overflow`. +pub struct Add { + op: Operation, +} +impl Add { + pub fn build(overflow: Overflow, lhs: Operand, rhs: Operand, span: SourceSpan, context: &mut Context) -> OpId { + let ctrl_ty = context.value_type(lhs); + let mut builder = OpBuilder::<Self>::new(context); + // Set the source span for this op + builder.with_span(span); + // We specify the concrete operand values + builder.with_operands([lhs, rhs]); + // We specify results in terms of types, and the builder will materialize values + // for the instruction results based on those types. + builder.with_results([ctrl_ty]); + // We must also specify operation attributes, e.g. overflow behavior + // + // Attribute values must implement `AttributeValue`, a trait which represents the encoding + // and decoding of a type as an attribute value. + builder.with_attribute("overflow", overflow); + // Add op traits that this type implements + builder.with_trait::<dyn Commutative>(); + builder.with_trait::<dyn SameTypeOperands>(); + builder.with_trait::<dyn Canonicalize>(); + // Instantiate the op + // + // NOTE: In order to use `OpBuilder`, `Self` must implement `Default` if it has any other + // fields than the underlying `Operation`. This is because the `OpBuilder` will construct + // a default instance of the type when allocating it, and according to Rust rules, any + // subsequent reference to the type requires that all fields were properly initialized. + builder.build() + } + + /// The `OpOperand` type abstracts over information about a given operand, such as whether it + /// is a value or immediate, and its type. It also contains the link for the operand in the + /// original `Value`'s use list. + pub fn lhs(&self) -> &OpOperand { + &self.op.operands[0] + } + + pub fn rhs(&self) -> &OpOperand { + &self.op.operands[1] + } + + /// The `OpResult` type abstracts over information about a given result, such as whether it + /// is a value or constant, and its type. + pub fn result(&self) -> &OpResult { + &self.op.results[0] + } + + /// Attributes of the op can be reified from the `AttributeSet` of the underlying `Operation`, + /// using `get_attribute`, which will use the `AttributeValue` implementation for the type to + /// reify it from the raw attribute data. + pub fn overflow(&self) -> Overflow { + self.op.get_attribute("overflow").unwrap_or_default() + } +} +/// All ops must implement this trait, but most implementations will look like this +impl Op for Add { + type Id = OpId; + + fn id(&self) -> Self::Id { self.op.key } + fn name(&self) -> &'static str { "add" } + fn as_operation(&self) -> &Operation { &self.op } + fn as_operation_mut(&mut self) -> &mut Operation { &mut self.op } +} +/// Marker trait used to indicate an op exhibits the commutativity property +impl Commutative for Add {} +/// Marker trait used to indicate that operands of this op should all be the same type +impl SameTypeOperands for Add {} +/// Canonicalization is optional, but encouraged when an operation has a canonical form. +/// +/// This is applied after transformations which introduce or modify an `Add` op, and ensures +/// that it is in canonical form. +/// +/// Canonicalizations ensure that pattern-based rewrites can be expressed in terms of the +/// canonical form, rather than needing to account for all possible variations. 
+impl Canonicalize for Add { + fn canonicalize(&mut self) { + // If `add` is given an immediate operand, always place it on the right-hand side + if self.lhs().is_immediate() { + if self.rhs().is_immediate() { + return; + } + self.op.operands.swap(0, 1); + } + } +} +/// Ops can optionally implement [PrettyPrint] and [PrettyParser], to allow for a less verbose +/// textual representation. If not implemented, the op will be printed using the generic format +/// driven by the underlying `Operation`. +impl formatter::PrettyPrint for Add { + fn render(&self) -> formatter::Document { + use formatter::*; + + let opcode = match self.overflow() { + Overflow::Unchecked => const_text("add"), + Overflow::Checked => const_text("add.checked"), + Overflow::Wrapping => const_text("add.wrapping"), + Overflow::Overflowing => const_text("add.overflowing"), + }; + opcode + const_text(" ") + display(self.lhs()) + const_text(", ") + display(self.rhs()) + } +} +impl parsing::PrettyParser for Add { + fn parse(s: &str, context: &mut Context) -> Result { + todo!("not shown here") + } +} + +/// Constructing Components + +// An interface can consist of functions, global variables, and potentially types in the future +let mut std = Interface::new("std"); +std.insert("math::u64/add_checked", FunctionType::new([Type::U64, Type::U64], [Type::U64])); + +let test = Interface::new("test"); +test.insert("run", FunctionType::new([], [Type::U32])); + +// A component is instantiated empty +let mut component = Component::new("test"); + +// You must then declare the interfaces that it imports and exports +component.import(std.clone()); +component.export(test.clone()); + +// Then, to actually define a component instance, you must define modules which implement +// the interfaces exported by the component +let mut module = Module::new("test"); +let mut run = module.define("run", FunctionType::new([], [Type::U32])); +//... build 'run' function + +// And add them to the component, like shown here. Here, `module` is added to the component as an +// implementation of the `test` interface +component.implement(test, module); + +// Modules can also be added to a component without using them to implement an interface, +// in which case they are only accessible from other modules in the same component, e.g.: +let foo = Module::new("foo"); +// Here, the 'foo' module will only be accessible from the 'test' module. +component.add(foo); + +// Lastly, during compilation, imports are resolved by linking against the components which +// implement them. The linker step identifies the concrete paths of each item that provides +// an imported symbol, and rewrites the generic interface path with the concrete path of the +// item it was resolved to. 
+ +/// Visiting IR + +// Visiting the CFG + +let entry = function.body().entry(); +let mut worklist = VecDeque::from_iter([entry]); + +while let Some(next) = worklist.pop_front() { + // Ignore blocks we've already visited + if !visited.insert(next) { + continue; + } + let terminator = context.block(next).last().unwrap(); + // Visit all successors after this block, if the terminator branches to another block + if let Some(branch) = terminator.downcast_ref::() { + worklist.extend(branch.successors().iter().map(|succ| succ.block)); + } + // Visit the operations in the block bottom-up + let mut current_op = terminator; + while let Some(op) = current_op.prev() { + current_op = op; + } +} + +// Visiting uses of a value + +let op = function.body().entry().first().unwrap(); +let result = op.results().first().unwrap(); +assert!(result.is_used()); +for user in result.uses() { + dbg!(user); +} + +// Applying a rewrite pattern to every match in a function body + +struct FoldAdd; +impl RewritePattern for FoldAdd { + fn matches(&self, op: &dyn Op) -> bool { + op.is::() + } + + fn apply(&mut self, op: &mut dyn Op) -> RewritePatternResult { + let add_op = op.downcast_mut::().unwrap(); + if let Some(lhs) = add_op.lhs().as_immediate() { + if let Some(rhs) = add_op.rhs().as_immediate() { + let result = lhs + rhs; + return Ok(RewriteAction::ReplaceAllUsesWith(add_op.result().value, result)); + } + } + Ok(RewriteAction::None) + } +} +``` diff --git a/hir-macros/src/lib.rs b/hir-macros/src/lib.rs index f1f1491fc..db4b85f59 100644 --- a/hir-macros/src/lib.rs +++ b/hir-macros/src/lib.rs @@ -1,5 +1,6 @@ extern crate proc_macro; +//mod op; mod spanned; use inflector::cases::kebabcase::to_kebab_case; @@ -25,6 +26,28 @@ pub fn derive_spanned(input: proc_macro::TokenStream) -> proc_macro::TokenStream } } +/// #[derive(Op)] +/// #[op(name = "select", interfaces(BranchOpInterface))] +/// pub struct Select { +/// #[operation] +/// op: Operation, +/// #[operand] +/// selector: OpOperand, +/// +/// } +/* +#[proc_macro_derive(Op, attributes(op, operation, operand, result, successor, region, interfaces))] +pub fn derive_op(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + // Parse into syntax tree + let derive = parse_macro_input!(input as DeriveInput); + let op = match op::Op::from_derive_input(derive) { + Ok(op) => op, + Err(err) => err.to_compile_error().into(), + }; + quote!(#op).into() +} + */ + #[proc_macro_derive(PassInfo)] pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStream { let derive_input = parse_macro_input!(item as DeriveInput); diff --git a/hir-macros/src/op.rs b/hir-macros/src/op.rs new file mode 100644 index 000000000..baf0f473c --- /dev/null +++ b/hir-macros/src/op.rs @@ -0,0 +1,268 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{punctuated::Punctuated, DeriveInput, Token}; + +pub struct Op { + generics: syn::Generics, + ident: syn::Ident, + fields: syn::Fields, + args: OpArgs, +} + +#[derive(Default)] +pub struct OpArgs { + pub code: Option, + pub severity: Option, + pub help: Option, + pub labels: Option, + pub source_code: Option, + pub url: Option, + pub forward: Option, + pub related: Option, + pub diagnostic_source: Option, +} + +pub enum OpArg { + Transparent, + Code(Code), + Severity(Severity), + Help(Help), + Url(Url), + Forward(Forward), +} + +impl Parse for OpArg { + fn parse(input: ParseStream) -> syn::Result { + let ident = input.fork().parse::()?; + if ident == "transparent" { + // consume the token + let _: syn::Ident = 
input.parse()?; + Ok(OpArg::Transparent) + } else if ident == "forward" { + Ok(OpArg::Forward(input.parse()?)) + } else if ident == "code" { + Ok(OpArg::Code(input.parse()?)) + } else if ident == "severity" { + Ok(OpArg::Severity(input.parse()?)) + } else if ident == "help" { + Ok(OpArg::Help(input.parse()?)) + } else if ident == "url" { + Ok(OpArg::Url(input.parse()?)) + } else { + Err(syn::Error::new(ident.span(), "Unrecognized diagnostic option")) + } + } +} + +impl OpArgs { + pub(crate) fn forward_or_override_enum( + &self, + variant: &syn::Ident, + which_fn: WhichFn, + mut f: impl FnMut(&ConcreteOpArgs) -> Option, + ) -> Option { + match self { + Self::Transparent(forward) => Some(forward.gen_enum_match_arm(variant, which_fn)), + Self::Concrete(concrete) => f(concrete).or_else(|| { + concrete + .forward + .as_ref() + .map(|forward| forward.gen_enum_match_arm(variant, which_fn)) + }), + } + } +} + +impl OpArgs { + fn parse( + _ident: &syn::Ident, + fields: &syn::Fields, + attrs: &[&syn::Attribute], + allow_transparent: bool, + ) -> syn::Result { + let mut errors = Vec::new(); + + let mut concrete = OpArgs::for_fields(fields)?; + for attr in attrs { + let args = attr.parse_args_with(Punctuated::::parse_terminated); + let args = match args { + Ok(args) => args, + Err(error) => { + errors.push(error); + continue; + } + }; + + concrete.add_args(attr, args, &mut errors); + } + + let combined_error = errors.into_iter().reduce(|mut lhs, rhs| { + lhs.combine(rhs); + lhs + }); + if let Some(error) = combined_error { + Err(error) + } else { + Ok(concrete) + } + } + + fn for_fields(fields: &syn::Fields) -> Result { + let labels = Labels::from_fields(fields)?; + let source_code = SourceCode::from_fields(fields)?; + let related = Related::from_fields(fields)?; + let help = Help::from_fields(fields)?; + let diagnostic_source = DiagnosticSource::from_fields(fields)?; + Ok(Self { + code: None, + help, + related, + severity: None, + labels, + url: None, + forward: None, + source_code, + diagnostic_source, + }) + } + + fn add_args( + &mut self, + attr: &syn::Attribute, + args: impl Iterator, + errors: &mut Vec, + ) { + for arg in args { + match arg { + OpArg::Transparent => { + errors.push(syn::Error::new_spanned(attr, "transparent not allowed")); + } + OpArg::Forward(to_field) => { + if self.forward.is_some() { + errors.push(syn::Error::new_spanned( + attr, + "forward has already been specified", + )); + } + self.forward = Some(to_field); + } + OpArg::Code(new_code) => { + if self.code.is_some() { + errors + .push(syn::Error::new_spanned(attr, "code has already been specified")); + } + self.code = Some(new_code); + } + OpArg::Severity(sev) => { + if self.severity.is_some() { + errors.push(syn::Error::new_spanned( + attr, + "severity has already been specified", + )); + } + self.severity = Some(sev); + } + OpArg::Help(hl) => { + if self.help.is_some() { + errors + .push(syn::Error::new_spanned(attr, "help has already been specified")); + } + self.help = Some(hl); + } + OpArg::Url(u) => { + if self.url.is_some() { + errors + .push(syn::Error::new_spanned(attr, "url has already been specified")); + } + self.url = Some(u); + } + } + } + } +} + +impl Op { + pub fn from_derive_input(input: DeriveInput) -> Result { + let input_attrs = input + .attrs + .iter() + .filter(|x| x.path().is_ident("operation")) + .collect::>(); + Ok(match input.data { + syn::Data::Struct(data_struct) => { + let args = OpArgs::parse(&input.ident, &data_struct.fields, &input_attrs, true)?; + + Op { + fields: data_struct.fields, + 
ident: input.ident, + generics: input.generics, + args, + } + } + syn::Data::Enum(_) | syn::Data::Union(_) => { + return Err(syn::Error::new( + input.ident.span(), + "Can't derive Op for enums or unions", + )) + } + }) + } + + pub fn gen(&self) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = &self.generics.split_for_impl(); + let concrete = &self.args; + let forward = |which| concrete.forward.as_ref().map(|fwd| fwd.gen_struct_method(which)); + let code_body = concrete + .code + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Code)); + let help_body = concrete + .help + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::Help)); + let sev_body = concrete + .severity + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Severity)); + let rel_body = concrete + .related + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Related)); + let url_body = concrete + .url + .as_ref() + .and_then(|x| x.gen_struct(ident, fields)) + .or_else(|| forward(WhichFn::Url)); + let labels_body = concrete + .labels + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::Labels)); + let src_body = concrete + .source_code + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::SourceCode)); + let diagnostic_source = concrete + .diagnostic_source + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::DiagnosticSource)); + quote! { + impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { + #code_body + #help_body + #sev_body + #rel_body + #url_body + #labels_body + #src_body + #diagnostic_source + } + } + } +} diff --git a/hir2/Cargo.toml b/hir2/Cargo.toml new file mode 100644 index 000000000..6563c30c4 --- /dev/null +++ b/hir2/Cargo.toml @@ -0,0 +1,63 @@ +[package] +name = "midenc-hir2" +description = "High-level Intermediate Representation for Miden Assembly" +version.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +categories.workspace = true +keywords.workspace = true +license.workspace = true +readme.workspace = true +edition.workspace = true + +[features] +default = ["std"] +std = ["rustc-demangle/std"] +serde = [ + "dep:serde", + "dep:serde_repr", + "dep:serde_bytes", + "midenc-hir-symbol/serde", +] + +[build-dependencies] +lalrpop = { version = "0.20", default-features = false } + +[dependencies] +anyhow.workspace = true +blink-alloc = { version = "0.3", default-features = false, features = [ + "alloc", + "nightly", +] } +either.workspace = true +cranelift-entity.workspace = true +downcast-rs = { version = "1.2", default-features = false } +intrusive-collections.workspace = true +inventory.workspace = true +lalrpop-util = "0.20" +log.workspace = true +miden-core.workspace = true +miden-assembly.workspace = true +midenc-hir-symbol.workspace = true +midenc-hir-type.workspace = true +midenc-hir-macros.workspace = true +midenc-session.workspace = true +num-bigint = "0.4" +num-traits = "0.2" +petgraph.workspace = true +paste.workspace = true +rustc-hash.workspace = true +rustc-demangle = "0.1.19" +serde = { workspace = true, optional = true } +serde_repr = { workspace = true, optional = true } +serde_bytes = { workspace = true, optional = true } +smallvec.workspace = true +thiserror.workspace = true +typed-arena = "2.0" +unicode-width = { version = "0.1", features = ["no_std"] } +derive_more.workspace = true +indexmap.workspace = true + 
+[dev-dependencies] +pretty_assertions = "1.0" diff --git a/hir2/src/core.rs b/hir2/src/core.rs new file mode 100644 index 000000000..cfc7084e3 --- /dev/null +++ b/hir2/src/core.rs @@ -0,0 +1,33 @@ +mod attribute; +mod block; +mod component; +mod context; +mod entity; +mod function; +mod interface; +mod module; +mod op; +mod operation; +mod region; +mod symbol_table; +mod traits; +mod types; +mod usable; +mod value; + +pub use self::{ + block::{Block, BlockCursor, BlockCursorMut, BlockList, BlockOperand}, + entity::{ + Entity, EntityCursor, EntityCursorMut, EntityHandle, EntityId, EntityIter, EntityList, + EntityMut, EntityRef, TrackedEntityHandle, + }, + function::{ + AbiParam, ArgumentExtension, ArgumentPurpose, CallConv, Function, FunctionIdent, Signature, + }, + module::Module, + region::{Region, RegionCursor, RegionCursorMut, RegionList}, + symbol_table::{Symbol, SymbolTable}, + types::*, + usable::Usable, + value::{BlockArgument, OpOperand, OpResult, Value, ValueKind}, +}; diff --git a/hir2/src/core/attribute.rs b/hir2/src/core/attribute.rs new file mode 100644 index 000000000..05d62b820 --- /dev/null +++ b/hir2/src/core/attribute.rs @@ -0,0 +1,273 @@ +use alloc::collections::BTreeMap; +use core::{borrow::Borrow, fmt}; + +use midenc_hir_symbol::Symbol; + +pub mod attributes { + use midenc_hir_symbol::symbols; + + use super::*; + + /// This attribute indicates that the decorated function is the entrypoint + /// for its containing program, regardless of what module it is defined in. + pub const ENTRYPOINT: Attribute = Attribute { + name: symbols::Entrypoint, + value: AttributeValue::Unit, + }; +} + +/// An [AttributeSet] is a uniqued collection of attributes associated with some IR entity +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +pub struct AttributeSet(BTreeMap); +impl FromIterator for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator, + { + let mut map = BTreeMap::default(); + for attr in attrs.into_iter() { + map.insert(attr.name, attr.value); + } + Self(map) + } +} +impl FromIterator<(Symbol, AttributeValue)> for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator, + { + let mut map = BTreeMap::default(); + for (name, value) in attrs.into_iter() { + map.insert(name, value); + } + Self(map) + } +} +impl AttributeSet { + /// Get a new, empty [AttributeSet] + pub fn new() -> Self { + Self::default() + } + + /// Insert a new [Attribute] in this set by `name` and `value` + pub fn insert(&mut self, name: impl Into, value: impl Into) { + self.0.insert(name.into(), value.into()); + } + + /// Adds `attr` to this set + pub fn set(&mut self, attr: Attribute) { + self.0.insert(attr.name, attr.value); + } + + /// Remove an [Attribute] by name from this set + pub fn remove(&mut self, name: &Q) + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.remove(name); + } + + /// Determine if the named [Attribute] is present in this set + pub fn has(&self, key: &Q) -> bool + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.contains_key(key) + } + + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get(&self, key: &Q) -> Option<&AttributeValue> + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.get(key) + } + + /// Get the value associated with the named [Attribute] as a boolean, or `None`. 
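+    ///
+    /// A minimal sketch of typical usage (the `attrs` binding here is illustrative):
+    ///
+    /// ```rust,ignore
+    /// let mut attrs = AttributeSet::new();
+    /// attrs.insert(Symbol::intern("inline"), true);
+    /// assert_eq!(attrs.get_bool(&Symbol::intern("inline")), Some(true));
+    /// // The value is a boolean, so reading it as an integer yields `None`
+    /// assert_eq!(attrs.get_int(&Symbol::intern("inline")), None);
+    /// ```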
+ pub fn get_bool(&self, key: &Q) -> Option + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.get(key).and_then(|v| v.as_bool()) + } + + /// Get the value associated with the named [Attribute] as an integer, or `None`. + pub fn get_int(&self, key: &Q) -> Option + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.get(key).and_then(|v| v.as_int()) + } + + /// Get the value associated with the named [Attribute] as a [Symbol], or `None`. + pub fn get_symbol(&self, key: &Q) -> Option + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + self.0.get(key).and_then(|v| v.as_symbol()) + } + + /// Iterate over each [Attribute] in this set + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|(k, v)| Attribute { + name: *k, + value: *v, + }) + } +} + +/// An [Attribute] associates some data with a well-known identifier (name). +/// +/// Attributes are used for representing metadata that helps guide compilation, +/// but which is not part of the code itself. For example, `cfg` flags in Rust +/// are an example of something which you could represent using an [Attribute]. +/// They can also be used to store documentation, source locations, and more. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Attribute { + /// The name of this attribute + pub name: Symbol, + /// The value associated with this attribute + pub value: AttributeValue, +} +impl Attribute { + pub fn new(name: impl Into, value: impl Into) -> Self { + Self { + name: name.into(), + value: value.into(), + } + } +} +impl fmt::Display for Attribute { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.value { + AttributeValue::Unit => write!(f, "#[{}]", self.name.as_str()), + value => write!(f, "#[{}({value})]", &self.name), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum AttributeValue { + /// No concrete value (i.e. 
presence of the attribute is significant) + Unit, + /// A boolean value + Bool(bool), + /// A signed integer + Int(isize), + /// An interned string + String(Symbol), +} +impl AttributeValue { + pub fn as_bool(&self) -> Option { + match self { + Self::Bool(value) => Some(*value), + _ => None, + } + } + + pub fn as_int(&self) -> Option { + match self { + Self::Int(value) => Some(*value), + _ => None, + } + } + + pub fn as_symbol(&self) -> Option { + match self { + Self::String(value) => Some(*value), + _ => None, + } + } +} +impl fmt::Display for AttributeValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Unit => f.write_str("()"), + Self::Bool(value) => write!(f, "{value}"), + Self::Int(value) => write!(f, "{value}"), + Self::String(value) => write!(f, "\"{}\"", value.as_str().escape_default()), + } + } +} +impl From<()> for AttributeValue { + fn from(_: ()) -> Self { + Self::Unit + } +} +impl From for AttributeValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} +impl From for AttributeValue { + fn from(value: isize) -> Self { + Self::Int(value) + } +} +impl From<&str> for AttributeValue { + fn from(value: &str) -> Self { + Self::String(Symbol::intern(value)) + } +} +impl From for AttributeValue { + fn from(value: String) -> Self { + Self::String(Symbol::intern(value.as_str())) + } +} +impl From for AttributeValue { + fn from(value: u8) -> Self { + Self::Int(value as isize) + } +} +impl From for AttributeValue { + fn from(value: i8) -> Self { + Self::Int(value as isize) + } +} +impl From for AttributeValue { + fn from(value: u16) -> Self { + Self::Int(value as isize) + } +} +impl From for AttributeValue { + fn from(value: i16) -> Self { + Self::Int(value as isize) + } +} +impl From for AttributeValue { + fn from(value: u32) -> Self { + Self::Int(value as isize) + } +} +impl From for AttributeValue { + fn from(value: i32) -> Self { + Self::Int(value as isize) + } +} +impl TryFrom for AttributeValue { + type Error = core::num::TryFromIntError; + + fn try_from(value: usize) -> Result { + Ok(Self::Int(value.try_into()?)) + } +} +impl TryFrom for AttributeValue { + type Error = core::num::TryFromIntError; + + fn try_from(value: u64) -> Result { + Ok(Self::Int(value.try_into()?)) + } +} +impl TryFrom for AttributeValue { + type Error = core::num::TryFromIntError; + + fn try_from(value: i64) -> Result { + Ok(Self::Int(value.try_into()?)) + } +} diff --git a/hir2/src/core/block.rs b/hir2/src/core/block.rs new file mode 100644 index 000000000..5c6422008 --- /dev/null +++ b/hir2/src/core/block.rs @@ -0,0 +1,78 @@ +use super::{ + BlockArgument, EntityCursor, EntityCursorMut, EntityHandle, EntityIter, EntityList, OpList, + Operation, Usable, +}; + +/// An intrusive, doubly-linked list of [Block] +pub type BlockList = EntityList; +pub type BlockCursor<'a> = EntityCursor<'a, Block>; +pub type BlockCursorMut<'a> = EntityCursorMut<'a, Block>; + +pub struct Block { + /// The set of uses of this block + uses: BlockOperandList, + /// The region this block is attached to. 
+ /// + /// If `link.is_linked() == true`, this will always be set to a valid pointer + region: Option>, + /// The list of [Operation]s that comprise this block + ops: OpList, + /// The parameter list for this block + arguments: Vec, +} +impl Usable for Block { + type Use = BlockOperand; + + fn is_used(&self) -> bool { + !self.uses.is_empty() + } + + fn uses(&self) -> BlockOperandIter<'_> { + self.uses.iter() + } + + fn first_use(&self) -> BlockOperandCursor<'_> { + self.uses.front() + } + + fn first_use_mut(&mut self) -> BlockOperandCursorMut<'_> { + self.uses.front_mut() + } +} +impl Block { + #[inline(always)] + pub fn has_predecessors(&self) -> bool { + self.is_used() + } + + #[inline(always)] + pub fn predecessors(&self) -> BlockOperandIter<'_> { + self.uses() + } +} + +/// An intrusive, doubly-linked list of [BlockOperand] +pub type BlockOperandList = EntityList; +pub type BlockOperandCursor<'a> = EntityCursor<'a, BlockOperand>; +pub type BlockOperandCursorMut<'a> = EntityCursorMut<'a, BlockOperand>; +pub type BlockOperandIter<'a> = EntityIter<'a, BlockOperand>; + +/// A [BlockOperand] represents a use of a [Block] by an [Operation] +pub struct BlockOperand { + /// The block value + pub block: Block, + /// The owner of this operand, i.e. the operation it is an operand of + pub owner: EntityHandle, + /// The index of this operand in the set of block operands of the operation + pub index: u8, +} +impl BlockOperand { + #[inline] + pub fn new(block: Block, owner: EntityHandle, index: u8) -> Self { + Self { + block, + owner, + index, + } + } +} diff --git a/hir2/src/core/component.rs b/hir2/src/core/component.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/core/context.rs b/hir2/src/core/context.rs new file mode 100644 index 000000000..52452f208 --- /dev/null +++ b/hir2/src/core/context.rs @@ -0,0 +1,94 @@ +use core::{ + cell::{Cell, UnsafeCell}, + fmt, + mem::MaybeUninit, + ptr::NonNull, +}; + +use blink_alloc::Blink; +use cranelift_entity::{PrimaryMap, SecondaryMap}; + +use super::{ + entity::{EntityObj, TrackedEntityObj}, + *, +}; +use crate::UnsafeRef; + +pub struct Context { + pub allocator: Blink, + pub blocks: PrimaryMap, + pub values: PrimaryMap, + pub constants: ConstantPool, +} + +impl Context { + pub fn new() -> Self { + let allocator = Blink::new(); + Self { + allocator, + blocks: PrimaryMap::new(), + values: PrimaryMap::new(), + constants: Default::default(), + } + } + + /// Allocate a new uninitialized entity of type `T` + /// + /// In general, you can probably prefer [Context::alloc] instead, but for use cases where you + /// need to allocate the space for `T` first, and then perform initialization, this can be + /// used. + pub fn alloc_uninit(&self) -> EntityHandle> { + let entity = self.allocator.uninit::>(); + unsafe { EntityHandle::new_uninit(NonNull::new_unchecked(entity)) } + } + + /// Allocate a new uninitialized entity of type `T`, which needs to be tracked in an intrusive + /// doubly-linked list. + /// + /// In general, you can probably prefer [Context::alloc_tracked] instead, but for use cases + /// where you need to allocate the space for `T` first, and then perform initialization, + /// this can be used. + pub fn alloc_uninit_tracked(&self) -> TrackedEntityHandle> { + let entity = self.allocator.uninit::>(); + unsafe { TrackedEntityHandle::new_uninit(NonNull::new_unchecked(entity)) } + } + + /// Allocate a new `EntityHandle`. 
+ /// + /// [EntityHandle] is a smart-pointer type for IR entities, which behaves like a ref-counted + /// pointer with dynamically-checked borrow checking rules. It is designed to play well with + /// entities allocated from a [Context], and with the somewhat cyclical nature of the IR. + pub fn alloc(&self, value: T) -> EntityHandle { + let entity = self.allocator.put(EntityObj::new(value)); + unsafe { EntityHandle::new(NonNull::new_unchecked(entity)) } + } + + /// Allocate a new `TrackedEntityHandle`. + /// + /// [TrackedEntityHandle] is like [EntityHandle], except that it is specially designed for + /// entities which are meant to be tracked in intrusive linked lists. For example, the blocks + /// in a region, or the ops in a block. It does this without requiring the entity to know about + /// the link at all, while still making it possible to access the link from the entity. + pub fn alloc_tracked(&self, value: T) -> TrackedEntityHandle { + let entity = self.allocator.put(TrackedEntityObj::new(value)); + unsafe { TrackedEntityHandle::new(NonNull::new_unchecked(entity)) } + } + + pub fn create_op(&mut self, mut op: T) -> OpId { + let key = self.ops.next_key(); + let op = self.allocator.put(op); + let ptr = op as *mut T; + { + let operation = op.as_operation_mut(); + operation.key = key; + operation.vtable.set_data_ptr(ptr); + } + let op = unsafe { NonNull::new_unchecked(op) }; + self.ops.push(op.cast()); + key + } + + pub fn op(&self, id: OpId) -> &dyn Op { + self.ops[id].as_ref() + } +} diff --git a/hir2/src/core/entity.rs b/hir2/src/core/entity.rs new file mode 100644 index 000000000..e731a5444 --- /dev/null +++ b/hir2/src/core/entity.rs @@ -0,0 +1,679 @@ +mod list; + +use core::{ + cell::{Cell, UnsafeCell}, + fmt, + mem::MaybeUninit, + ptr::NonNull, +}; + +pub use self::list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}; + +pub trait Entity { + type Id: EntityId; + + fn id(&self) -> Self::Key; + unsafe fn set_id(&self, id: Self::Key); +} + +pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash { + fn as_usize(&self) -> usize; + unsafe fn from_usize(raw: usize) -> Self; +} + +/// An error raised when an aliasing violation is detected in the use of [EntityHandle] +#[non_exhaustive] +pub struct AliasingViolationError { + #[cfg(debug_assertions)] + location: &'static core::panic::Location<'static>, + kind: AliasingViolationKind, +} + +#[derive(Debug)] +enum AliasingViolationKind { + /// Attempted to create an immutable alias for an entity that was mutably borrowed + Immutable, + /// Attempted to create a mutable alias for an entity that was immutably borrowed + Mutable, +} + +impl fmt::Display for AliasingViolationKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Immutable => f.write_str("already mutably borrowed"), + Self::Mutable => f.write_str("already borrowed"), + } + } +} + +impl fmt::Debug for AliasingViolationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("AliasingViolationError"); + builder.field("kind", &self.kind); + #[cfg(debug_assertions)] + builder.field("location", &self.location); + builder.finish() + } +} +impl fmt::Display for AliasingViolationError { + #[cfg(debug_assertions)] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} in file '{}' at line {} and column {}", + &self.kind, + self.location.file(), + self.location.line(), + self.location.column() + ) + } + + #[cfg(not(debug_assertions))] + fn fmt(&self, 
f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", &self.kind) + } +} + +/// An [EntityHandle] is a smart-pointer type for IR entities allocated in a [Context]. +/// +/// Unlike regular references, no reference to the underlying `T` is constructed until one is +/// needed, at which point the borrow (whether mutable or immutable) is dynamically checked to +/// ensure that it is valid according to Rust's aliasing rules. +/// +/// As a result, an [EntityHandle] is not considered an alias, and it is possible to acquire a +/// mutable reference to the underlying data even while other copies of the handle exist. Any +/// attempt to construct invalid aliases (immutable reference while a mutable reference exists, or +/// vice versa), will result in a runtime panic. +/// +/// This is a tradeoff, as we do not get compile-time guarantees that such panics will not occur, +/// but in exchange we get a much more flexible and powerful IR structure. +pub struct EntityHandle { + inner: NonNull>, +} +impl Clone for EntityHandle { + fn clone(&self) -> Self { + Self { inner: self.inner } + } +} +impl EntityHandle { + /// Create a new [EntityHandle] from a raw pointer to the underlying [EntityObj]. + /// + /// # SAFETY + /// + /// [EntityHandle] is designed to operate like an owned smart-pointer type, ala `Rc`. As a + /// result, it expects that the underlying data _never moves_ after it is allocated, for as + /// long as any outstanding [EntityHandle]s exist that might be used to access that data. + /// + /// Additionally, it is expected that all accesses to the underlying data flow through an + /// [EntityHandle], as it is the foundation on which the soundness of [EntityHandle] is built. + /// You must ensure that there no other references to the underlying data exist, or can be + /// created, _except_ via [EntityHandle]. + /// + /// You should generally not be using this API, as it is meant solely for constructing an + /// [EntityHandle] immediately after allocating the underlying [EntityObj]. + pub(crate) unsafe fn new(ptr: NonNull>) -> Self { + Self { inner } + } + + /// Get a dynamically-checked immutable reference to the underlying `T` + pub fn get(&self) -> EntityRef<'_, T> { + unsafe { + let obj = self.inner.as_ref(); + obj.borrow() + } + } + + /// Get a dynamically-checked mutable reference to the underlying `T` + pub fn get_mut(&mut self) -> EntityMut<'_, T> { + unsafe { + let obj = self.inner.as_ref(); + obj.borrow_mut() + } + } + + /// Convert this handle into a raw pointer to the underlying entity. + /// + /// This should only be used in situations where the returned pointer will not be used to + /// actually access the underlying entity. Use [get] or [get_mut] for that. [EntityHandle] + /// ensures that Rust's aliasing rules are not violated when using it, but if you use the + /// returned pointer to do so, no such guarantee is provided, and undefined behavior can + /// result. + /// + /// # SAFETY + /// + /// The returned pointer _must_ not be used to create a reference to the underlying entity + /// unless you can guarantee that such a reference does not violate Rust's aliasing rules. + /// + /// Do not use the pointer to create a mutable reference if other references exist, and do + /// not use the pointer to create an immutable reference if a mutable reference exists or + /// might be created while the immutable reference lives. 
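+    ///
+    /// A sketch of a legitimate use is pointer identity, which never dereferences the
+    /// underlying entity (assuming `handle` is an existing `EntityHandle<T>`):
+    ///
+    /// ```rust,ignore
+    /// let a = handle.clone().into_raw();
+    /// let b = handle.into_raw();
+    /// assert_eq!(a, b, "both handles refer to the same allocation");
+    /// ```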
+ pub fn into_raw(self) -> NonNull { + unsafe { NonNull::new_unchecked(self.inner.as_ref().as_ptr()) } + } +} + +impl EntityHandle> { + /// Create an [EntityHandle] for an entity which may not be fully initialized. + /// + /// # SAFETY + /// + /// The safety rules are much the same as [EntityHandle::new], with the main difference + /// being that the `T` does not have to be initialized yet. No references to the `T` will + /// be created directly until [EntityHandle::assume_init] is called. + pub(crate) unsafe fn new_uninit(ptr: NonNull>>) -> Self { + Self { inner: ptr } + } + + /// Converts to `EntityHandle`. + /// + /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the + /// value really is in an initialized state. Calling this when the content is not yet fully + /// initialized causes immediate undefined behavior. + pub unsafe fn assume_init(self) -> EntityHandle { + EntityHandle { + inner: self.inner.cast(), + } + } +} + +/// A [TrackedEntityHandle] is like [EntityHandle], except it provides built-in support for +/// adding the entity to an [intrusive_collections::LinkedList] that doesn't require constructing +/// a reference to the entity itself, and thus potentially causing an aliasing violation. Instead, +/// the link is stored as part of the underlying allocation, but separate from the entity. +pub struct TrackedEntityHandle { + inner: NonNull>, +} +impl Clone for TrackedEntityHandle { + fn clone(&self) -> Self { + Self { inner: self.inner } + } +} +impl TrackedEntityHandle { + /// Create a new [TrackedEntityHandle] from a raw pointer to the underlying [TrackedEntityObj]. + /// + /// # SAFETY + /// + /// This function has the same requirements around safety as [EntityHandle::new]. + pub(crate) unsafe fn new(ptr: NonNull>) -> Self { + Self { inner } + } + + /// Get a dynamically-checked immutable reference to the underlying `T` + pub fn get(&self) -> EntityRef<'_, T> { + unsafe { + let obj = self.inner.as_ref(); + obj.entity.borrow() + } + } + + /// Get a dynamically-checked mutable reference to the underlying `T` + pub fn get_mut(&mut self) -> EntityMut<'_, T> { + unsafe { + let obj = self.inner.as_ref(); + obj.entity.borrow_mut() + } + } + + /// Convert this handle into a raw pointer to the underlying entity. + /// + /// This should only be used in situations where the returned pointer will not be used to + /// actually access the underlying entity. Use [get] or [get_mut] for that. [EntityHandle] + /// ensures that Rust's aliasing rules are not violated when using it, but if you use the + /// returned pointer to do so, no such guarantee is provided, and undefined behavior can + /// result. + /// + /// # SAFETY + /// + /// The returned pointer _must_ not be used to create a reference to the underlying entity + /// unless you can guarantee that such a reference does not violate Rust's aliasing rules. + /// + /// Do not use the pointer to create a mutable reference if other references exist, and do + /// not use the pointer to create an immutable reference if a mutable reference exists or + /// might be created while the immutable reference lives. + pub fn into_raw(self) -> NonNull { + unsafe { NonNull::new_unchecked(self.inner.as_ref().entity.as_ptr()) } + } +} + +impl TrackedEntityHandle> { + /// Create a [TrackedEntityHandle] for an entity which may not be fully initialized. 
+ /// + /// # SAFETY + /// + /// The safety rules are much the same as [TrackedEntityHandle::new], with the main difference + /// being that the `T` does not have to be initialized yet. No references to the `T` will + /// be created directly until [TrackedEntityHandle::assume_init] is called. + pub(crate) unsafe fn new_uninit(ptr: NonNull>>) -> Self { + Self { inner: ptr } + } + + /// Converts to `TrackedEntityHandle`. + /// + /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the + /// value really is in an initialized state. Calling this when the content is not yet fully + /// initialized causes immediate undefined behavior. + pub unsafe fn assume_init(self) -> TrackedEntityHandle { + TrackedEntityHandle { + inner: self.inner.cast(), + } + } +} + +unsafe impl intrusive_collections::PointerOps for TrackedEntityHandle { + type Pointer = TrackedEntityHandle; + type Value = EntityObj; + + unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { + assert!(!value.is_null()); + let offset = core::mem::offset_of!(TrackedEntityObj, entity); + let ptr = value.cast_mut().byte_sub(offset).cast::>(); + debug_assert!(ptr.is_aligned()); + TrackedEntityHandle::new(NonNull::new_unchecked(ptr)) + } + + fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { + let ptr = ptr.into_raw().as_ptr().cast_const(); + let offset = core::mem::offset_of!(EntityObj, cell); + unsafe { ptr.byte_sub(offset).cast() } + } +} + +/// An adapter for storing any `Entity` impl in a [intrusive_collections::LinkedList] +#[derive(Default, Copy, Clone)] +pub struct EntityAdapter(core::marker::PhantomData); +impl EntityAdapter { + pub const fn new() -> Self { + Self(core::marker::PhantomData) + } +} + +unsafe impl intrusive_collections::Adapter for EntityAdapter { + type LinkOps = intrusive_collections::linked_list::LinkOps; + type PointerOps = intrusive_collections::DefaultPointerOps>; + + unsafe fn get_value( + &self, + link: ::LinkPtr, + ) -> *const ::Value { + let offset = core::mem::offset_of!(TrackedEntityObj, link); + let ptr = link.as_ptr().cast_const().byte_sub(offset); + let offset = core::mem::offset_of!(TrackedEntityObj, entity); + ptr.byte_add(offset) + } + + unsafe fn get_link( + &self, + value: *const ::Value, + ) -> ::LinkPtr { + let offset = core::mem::offset_of!(TrackedEntityObj, entity); + let ptr = value.byte_sub(offset); + let offset = core::mem::offset_of!(TrackedEntityObj, link); + let ptr = ptr.byte_add(offset); + NonNull::new_unchecked(ptr.cast_mut()) + } + + fn link_ops(&self) -> &Self::LinkOps { + &intrusive_collections::linked_list::LinkOps + } + + fn link_ops_mut(&mut self) -> &mut Self::LinkOps { + &mut intrusive_collections::linked_list::LinkOps + } + + fn pointer_ops(&self) -> &Self::PointerOps { + const OPS: intrusive_collections::DefaultPointerOps>; + + &OPS + } +} + +/// A guard that ensures a reference to an IR entity cannot be mutably aliased +pub struct EntityRef<'b, T: ?Sized + 'b> { + value: NonNull, + borrow: BorrowRef<'b>, +} +impl core::ops::Deref for EntityRef<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: the value is accessible as long as we hold our borrow. 
+ unsafe { self.value.as_ref() } + } +} +impl<'b, T: ?Sized> EntityRef<'b, T> { + #[must_use] + #[inline] + pub fn clone(orig: &Self) -> Self { + Self { + value: orig.value, + borrow: orig.borrow.clone(), + } + } + + #[inline] + pub fn map(orig: Self, f: F) -> EntityRef<'b, U> + where + F: FnOnce(&T) -> &U, + { + EntityRef { + value: NonNull::from(f(&*orig)), + borrow: orig.borrow, + } + } +} + +impl<'b, T, U> core::ops::CoerceUnsized> for EntityRef<'b, T> +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} + +impl fmt::Display for EntityRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +/// A guard that provides exclusive access to an IR entity +pub struct EntityMut<'a, T> { + value: NonNull, + borrow: BorrowRefMut<'b>, + _marker: core::marker::PhantomData<&'b mut T>, +} +impl Deref for EntityMut<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} +impl DerefMut for EntityMut<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_mut() } + } +} + +impl<'b, T, U> core::ops::CoerceUnsized> for EntityMut<'b, T> +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} + +impl fmt::Display for EntityMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +/// An [EntityObj] is a wrapper around IR objects that are allocated in a [Context]. +/// +/// It ensures that any [EntityHandle] which references the underlying entity, adheres to Rust's +/// aliasing rules. +pub struct EntityObj { + borrow: Cell, + #[cfg(debug_assertions)] + borrowed_at: Cell>>, + cell: UnsafeCell, +} + +/// A [TrackedEntityObj] is a wrapper around IR entities that are linked in to an +/// [intrusive_collections::LinkedList] for tracking of that entity. This permits the linked list +/// to be visited/mutated without borrowing the entities themselves, and thus risk violation of +/// the aliasing rules. +pub struct TrackedEntityObj { + link: intrusive_collections::linked_list::LinkedListLink, + entity: EntityObj, +} +impl TrackedEntityObj { + pub fn new(value: T) -> Self { + Self { + link: Default::default(), + entity: EntityObj::new(value), + } + } +} + +impl EntityObj { + pub fn new(value: T) -> Self { + Self { + borrow: Cell::new(BorrowFlag::UNUSED), + #[cfg(debug_assertions)] + borrowed_at: Cell::new(None), + cell: UnsafeCell::new(value), + } + } + + #[track_caller] + #[inline] + pub fn borrow(&self) -> EntityRef<'_, T> { + match self.try_borrow() { + Ok(b) => b, + Err(err) => panic_aliasing_violation(err), + } + } + + #[inline] + #[cfg_attr(debug_assertions, track_caller)] + pub fn try_borrow(&self) -> Result, AliasingViolationError> { + match BorrowRef::new(&self.borrow) { + Some(b) => { + #[cfg(debug_assertions)] + { + // `borrowed_at` is always the *first* active borrow + if b.borrow.get() == 1 { + self.borrowed_at.set(Some(core::panic::Location::caller())); + } + } + + // SAFETY: `BorrowRef` ensures that there is only immutable access to the value + // while borrowed. 
+ let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(EntityRef { value, borrow: b }) + } + None => Err(AliasingViolationError { + #[cfg(debug_assertions)] + location: self.borrowed_at.get().unwrap(), + kind: AliasingViolationKind::Immutable, + }), + } + } + + #[inline] + #[track_caller] + pub fn borrow_mut(&self) -> EntityMut<'_, T> { + match self.try_borrow_mut() { + Ok(b) => b, + Err(err) => panic_aliasing_violation(err), + } + } + + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow_mut(&self) -> Result, AliasingViolationError> { + match BorrowRefMut::new(&self.borrow) { + Some(b) => { + #[cfg(debug_assertions)] + { + self.borrowed_at.set(Some(core::panic::Location::caller())); + } + + // SAFETY: `BorrowRefMut` guarantees unique access. + let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(EntityMut { + value, + borrow: b, + _marker: PhantomData, + }) + } + None => Err(AliasingViolationError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(debug_assertions)] + location: self.borrowed_at.get().unwrap(), + kind: AliasingViolationKind::Mutable, + }), + } + } + + #[inline] + pub fn as_ptr(&self) -> *mut T { + self.cell.get() + } + + #[inline] + pub fn get_mut(&mut self) -> &mut T { + self.cell.get_mut() + } +} + +struct BorrowRef<'b> { + borrow: &'b Cell, +} +impl<'b> BorrowRef<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option { + let b = borrow.get().wrapping_add(1); + if !b.is_reading() { + // Incrementing borrow can result in a non-reading value (<= 0) in these cases: + // 1. It was < 0, i.e. there are writing borrows, so we can't allow a read borrow due to + // Rust's reference aliasing rules + // 2. It was isize::MAX (the max amount of reading borrows) and it overflowed into + // isize::MIN (the max amount of writing borrows) so we can't allow an additional + // read borrow because isize can't represent so many read borrows (this can only + // happen if you mem::forget more than a small constant amount of `EntityRef`s, which + // is not good practice) + None + } else { + // Incrementing borrow can result in a reading value (> 0) in these cases: + // 1. It was = 0, i.e. it wasn't borrowed, and we are taking the first read borrow + // 2. It was > 0 and < isize::MAX, i.e. there were read borrows, and isize is large + // enough to represent having one more read borrow + borrow.set(b); + Some(Self { borrow }) + } + } +} +impl Drop for BorrowRef<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_reading()); + self.borrow.set(borrow - 1); + } +} +impl Clone for BorrowRef<'_> { + #[inline] + fn clone(&self) -> Self { + // Since this Ref exists, we know the borrow flag + // is a reading borrow. + let borrow = self.borrow.get(); + debug_assert!(borrow.is_reading()); + // Prevent the borrow counter from overflowing into + // a writing borrow. 
+ assert!(borrow != BorrowFlag::MAX); + self.borrow.set(borrow + 1); + BorrowRef { + borrow: self.borrow, + } + } +} + +struct BorrowRefMut<'b> { + borrow: &'b Cell, +} +impl Drop for BorrowRefMut<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_writing()); + self.borrow.set(borrow + 1); + } +} +impl<'b> BorrowRefMut<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option { + // NOTE: Unlike BorrowRefMut::clone, new is called to create the initial + // mutable reference, and so there must currently be no existing + // references. Thus, while clone increments the mutable refcount, here + // we explicitly only allow going from UNUSED to UNUSED - 1. + match borrow.get() { + BorrowFlag::UNUSED => { + borrow.set(BorrowFlag::UNUSED - 1); + Some(Self { borrow }) + } + _ => None, + } + } + + // Clones a `BorrowRefMut`. + // + // This is only valid if each `BorrowRefMut` is used to track a mutable + // reference to a distinct, nonoverlapping range of the original object. + // This isn't in a Clone impl so that code doesn't call this implicitly. + #[inline] + fn clone(&self) -> Self { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_writing()); + // Prevent the borrow counter from underflowing. + assert!(borrow != BorrowFlag::MIN); + self.borrow.set(borrow - 1); + Self { + borrow: self.borrow, + } + } +} + +/// Positive values represent the number of outstanding immutable borrows, while negative values +/// represent the number of outstanding mutable borrows. Multiple mutable borrows can only be +/// active simultaneously if they refer to distinct, non-overlapping components of an entity. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +struct BorrowFlag(isize); +impl BorrowFlag { + const MAX: Self = Self(isize::MAX); + const MIN: Self = Self(isize::MIN); + const UNUSED: Self = Self(0); + + pub fn is_writing(&self) -> bool { + self.0 < Self::UNUSED.0 + } + + pub fn is_reading(&self) -> bool { + self.0 > Self::UNUSED.0 + } + + #[inline] + pub const fn wrapping_add(self, rhs: isize) -> Self { + Self(self.0.wrapping_add(rhs)) + } +} +impl core::ops::Add for BorrowFlag { + type Output = BorrowFlag; + + #[inline] + fn add(self, rhs: isize) -> Self::Output { + Self(self.0 + rhs) + } +} +impl core::ops::Sub for BorrowFlag { + type Output = BorrowFlag; + + #[inline] + fn sub(self, rhs: isize) -> Self::Output { + Self(self.0 - rhs) + } +} + +// This ensures the panicking code is outlined from `borrow` and `borrow_mut` for `EntityObj`. +#[cfg_attr(not(panic = "abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_aliasing_violation(err: AliasingViolationError) -> ! 
{ + panic!("{err:?}") +} diff --git a/hir2/src/core/entity/list.rs b/hir2/src/core/entity/list.rs new file mode 100644 index 000000000..d3b2cc5df --- /dev/null +++ b/hir2/src/core/entity/list.rs @@ -0,0 +1,540 @@ +use core::fmt; + +use super::{EntityAdapter, EntityRef, TrackedEntityHandle}; + +#[derive(Default)] +pub struct EntityList { + list: intrusive_collections::linked_list::LinkedList>, +} +impl EntityList { + /// Construct a new, empty [EntityList] + pub fn new() -> Self { + Self::default() + } + + /// Returns true if this list is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.list.is_empty() + } + + /// Returns the number of entities in this list + pub fn len(&self) -> usize { + let mut cursor = self.list.front(); + let mut usize = 0; + while !cursor.is_null() { + usize += 1; + cursor.move_next(); + } + usize + } + + /// Prepend `entity` to this list + pub fn push_front(&mut self, entity: TrackedEntityHandle) { + self.list.push_front(entity); + } + + /// Append `entity` to this list + pub fn push_back(&mut self, entity: TrackedEntityHandle) { + self.list.push_back(entity); + } + + /// Remove the entity at the front of the list, returning its [TrackedEntityHandle] + /// + /// Returns `None` if the list is empty. + pub fn pop_front(&mut self) -> Option> { + self.list.pop_back() + } + + /// Remove the entity at the back of the list, returning its [TrackedEntityHandle] + /// + /// Returns `None` if the list is empty. + pub fn pop_back(&mut self) -> Option> { + self.list.pop_back() + } + + /// Get an [EntityCursor] pointing to the first entity in the list, or the null object if + /// the list is empty + pub fn front(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.list.front(), + } + } + + /// Get an [EntityCursorMut] pointing to the first entity in the list, or the null object if + /// the list is empty + pub fn front_mut(&self) -> EntityCursorMut<'_, T> { + EntityCursorMut { + cursor: self.list.front_mut(), + } + } + + /// Get an [EntityCursor] pointing to the last entity in the list, or the null object if + /// the list is empty + pub fn back(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.list.back(), + } + } + + /// Get an [EntityCursorMut] pointing to the last entity in the list, or the null object if + /// the list is empty + pub fn back_mut(&self) -> EntityCursorMut<'_, T> { + EntityCursorMut { + cursor: self.list.back_mut(), + } + } + + /// Get an iterator over the entities in this list + /// + /// The iterator returned produces [EntityRef]s for each item in the list, with their lifetime + /// bound to the list itself, not the iterator. + pub fn iter(&self) -> EntityIter<'_, T> { + EntityIter { + cursor: self.list.cursor(), + started: false, + } + } + + /// Removes all items from this list. + /// + /// This will unlink all entities currently in the list, which requires iterating through all + /// elements in the list. If the entities may be used again, this ensures that their intrusive + /// link is properly unlinked. + pub fn clear(&mut self) { + self.list.clear(); + } + + /// Empties the list without properly unlinking the intrusive links of the items in the list. + /// + /// Since this does not unlink any objects, any attempts to link these objects into another + /// [EntityList] will fail but will not cause any memory unsafety. To unlink those objects + /// manually, you must call the `force_unlink` function on the link. 
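+    ///
+    /// A sketch of how this differs from [EntityList::clear] (assuming `handle` was
+    /// previously linked into `blocks`, and `other_blocks` is another list):
+    ///
+    /// ```rust,ignore
+    /// blocks.fast_clear();
+    /// // `handle`'s intrusive link still appears to be in use, so linking it into
+    /// // another list fails until the link is force-unlinked.
+    /// other_blocks.push_back(handle); // panics
+    /// ```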
+ pub fn fast_clear(&mut self) { + self.list.fast_clear(); + } + + /// Takes all the elements out of the [EntityList], leaving it empty. + /// + /// The taken elements are returned as a new [EntityList]. + pub fn take(&mut self) -> Self { + Self { + list: self.list.take(), + } + } +} + +impl fmt::Debug for EntityList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_list(); + for entity in self.iter() { + builder.entry(&entity); + } + builder.finish() + } +} + +impl FromIterator> for EntityList { + fn from_iter(iter: T) -> Self + where + T: IntoIterator>, + { + let mut list = EntityList::::default(); + for handle in iter { + list.push_back(handle); + } + list + } +} + +impl IntoIterator for EntityList { + type IntoIter = intrusive_collections::linked_list::IntoIter; + type Item = TrackedEntityHandle; + + fn into_iter(self) -> Self::IntoIter { + self.list.into_iter() + } +} + +impl<'a, T> IntoIterator for &'a EntityList { + type IntoIter = EntityIter<'a, T>; + type Item = EntityRef<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// A cursor which provides read-only access to an [EntityList]. +pub struct EntityCursor<'a, T> { + cursor: intrusive_collections::linked_list::Cursor<'a, EntityAdapter>, +} +impl<'a, T> EntityCursor<'a, T> { + /// Returns true if this cursor is pointing to the null object + #[inline] + pub fn is_null(&self) -> bool { + self.cursor.is_null() + } + + /// Get a shared reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. + /// + /// NOTE: This returns an [EntityRef] whose lifetime is bound to the underlying [EntityList], + /// _not_ the [EntityCursor], since the cursor cannot mutate the list. + pub fn get(&self) -> Option> { + match self.cursor.get() { + Some(obj) => Some(obj.borrow()), + None => None, + } + } + + /// Get the [TrackedEntityHandle] corresponding to the entity under the cursor. + /// + /// Returns `None` if the cursor is pointing to the null object. + #[inline] + pub fn as_pointer(&self) -> Option> { + self.cursor.clone_pointer() + } + + /// Moves the cursor to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the front of the + /// [EntityList]. If it is pointing to the back of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_next(&mut self) { + self.cursor.move_next(); + } + + /// Moves the cursor to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the back of the + /// [EntityList]. If it is pointing to the front of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_prev(&mut self) { + self.cursor.move_prev(); + } + + /// Returns a cursor pointing to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to the + /// front of the [EntityList]. If it is pointing to the last entity of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_next(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_next(), + } + } + + /// Returns a cursor pointing to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to + /// the last entity in the [EntityList]. 
If it is pointing to the front of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_prev(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_prev(), + } + } +} + +/// A cursor which provides mutable access to an [EntityList]. +pub struct EntityCursorMut<'a, T> { + cursor: intrusive_collections::linked_list::CursorMut<'a, EntityAdapter>, +} +impl<'a, T> EntityCursorMut<'a, T> { + /// Returns true if this cursor is pointing to the null object + #[inline] + pub fn is_null(&self) -> bool { + self.cursor.is_null() + } + + /// Get a shared reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. + /// + /// NOTE: This binds the lifetime of the [EntityRef] to the cursor, to ensure that the cursor + /// is frozen while the entity is being borrowed. This ensures that only one reference at a + /// time is being handed out by this cursor. + pub fn get(&self) -> Option> { + match self.cursor.get() { + Some(obj) => Some(obj.borrow()), + None => None, + } + } + + /// Get a mutable reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. + /// + /// Not only does this mutably borrow the cursor, the lifetime of the [EntityMut] is bound to + /// that of the cursor, which means it cannot outlive the cursor, and also prevents the cursor + /// from being accessed in any way until the mutable reference is dropped. This makes it + /// impossible to try and alias the underlying entity using the cursor. + pub fn get_mut(&mut self) -> Option> { + match self.cursor.get() { + Some(obj) => Some(obj.borrow_mut()), + None => None, + } + } + + /// Returns a read-only cursor pointing to the current element. + /// + /// The lifetime of the returned [EntityCursor] is bound to that of the [EntityCursorMut], which + /// means it cannot outlive the [EntityCursorMut] and that the [EntityCursorMut] is frozen for + /// the lifetime of the [EntityCursor]. + pub fn as_cursor(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.as_cursor(), + } + } + + /// Get the [TrackedEntityHandle] corresponding to the entity under the cursor. + /// + /// Returns `None` if the cursor is pointing to the null object. + #[inline] + pub fn as_pointer(&self) -> Option> { + self.cursor.clone_pointer() + } + + /// Moves the cursor to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the front of the + /// [EntityList]. If it is pointing to the back of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_next(&mut self) { + self.cursor.move_next(); + } + + /// Moves the cursor to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the back of the + /// [EntityList]. If it is pointing to the front of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_prev(&mut self) { + self.cursor.move_prev(); + } + + /// Returns a cursor pointing to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to the + /// front of the [EntityList]. If it is pointing to the last entity of the [EntityList] then + /// this will return a null cursor. 
+ #[inline] + pub fn peek_next(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_next(), + } + } + + /// Returns a cursor pointing to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to + /// the last entity in the [EntityList]. If it is pointing to the front of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_prev(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_prev(), + } + } + + /// Removes the current entity from the [EntityList]. + /// + /// A pointer to the element that was removed is returned, and the cursor is moved to point to + /// the next element in the [Entitylist]. + /// + /// If the cursor is currently pointing to the null object then nothing is removed and `None` is + /// returned. + #[inline] + pub fn remove(&mut self) -> Option> { + self.cursor.remove() + } + + /// Removes the current entity from the [EntityList] and inserts another one in its place. + /// + /// A pointer to the entity that was removed is returned, and the cursor is modified to point to + /// the newly added entity. + /// + /// If the cursor is currently pointing to the null object then `Err` is returned containing the + /// entity we failed to insert. + /// + /// # Panics + /// Panics if the new entity is already linked to a different intrusive collection. + #[inline] + pub fn replace_with( + &mut self, + value: TrackedEntityHandle, + ) -> Result, TrackedEntityHandle> { + self.cursor.replace_with(value) + } + + /// Inserts a new entity into the [EntityList], after the current cursor position. + /// + /// If the cursor is pointing at the null object then the entity is inserted at the start of the + /// underlying [EntityList]. + /// + /// # Panics + /// + /// Panics if the entity is already linked to a different [EntityList] + #[inline] + pub fn insert_after(&mut self, value: TrackedEntityHandle) { + self.cursor.insert_after(value) + } + + /// Inserts a new entity into the [EntityList], before the current cursor position. + /// + /// If the cursor is pointing at the null object then the entity is inserted at the end of the + /// underlying [EntityList]. + /// + /// # Panics + /// + /// Panics if the entity is already linked to a different [EntityList] + #[inline] + pub fn insert_before(&mut self, value: TrackedEntityHandle) { + self.cursor.insert_before(value) + } + + /// This splices `list` into the underlying list of `self` by inserting the elements of `list` + /// after the current cursor position. + /// + /// For example, let's say we have the following list and cursor position: + /// + /// ```text,ignore + /// [A, B, C] + /// ^-- cursor + /// ``` + /// + /// Splicing a new list, `[D, E, F]` after the cursor would result in: + /// + /// ```text,ignore + /// [A, B, D, E, F, C] + /// ^-- cursor + /// ``` + /// + /// If the cursor is pointing at the null object, then `list` is appended to the start of the + /// underlying [EntityList] for this cursor. + #[inline] + pub fn splice_after(&mut self, list: EntityList) { + self.cursor.splice_after(list.list) + } + + /// This splices `list` into the underlying list of `self` by inserting the elements of `list` + /// before the current cursor position. 
+ /// + /// For example, let's say we have the following list and cursor position: + /// + /// ```text,ignore + /// [A, B, C] + /// ^-- cursor + /// ``` + /// + /// Splicing a new list, `[D, E, F]` before the cursor would result in: + /// + /// ```text,ignore + /// [A, D, E, F, B, C] + /// ^-- cursor + /// ``` + /// + /// If the cursor is pointing at the null object, then `list` is appended to the end of the + /// underlying [EntityList] for this cursor. + #[inline] + pub fn splice_before(&mut self, list: EntityList) { + self.cursor.splice_before(list.list) + } + + /// Splits the list into two after the current cursor position. + /// + /// This will return a new list consisting of everything after the cursor, with the original + /// list retaining everything before. + /// + /// If the cursor is pointing at the null object then the entire contents of the [EntityList] + /// are moved. + pub fn split_after(&mut self) -> EntityList { + let list = self.cursor.split_after(); + EntityList { list } + } + + /// Splits the list into two before the current cursor position. + /// + /// This will return a new list consisting of everything before the cursor, with the original + /// list retaining everything after. + /// + /// If the cursor is pointing at the null object then the entire contents of the [EntityList] + /// are moved. + pub fn split_before(&mut self) -> EntityList { + let list = self.cursor.split_before(); + EntityList { list } + } + + /// Consumes this cursor, and returns a reference to the entity that the cursor is currently + /// pointing to. + /// + /// Unlike [get], the returned reference’s lifetime is tied to [EntityList]’s lifetime. + /// + /// This returns `None` if the cursor is currently pointing to the null object. + /// + /// NOTE: This function will panic if there are any outstanding mutable borrows of the + /// underlying entity. + pub fn into_ref(self) -> Option> { + match self.cursor.get() { + Some(obj) => Some(obj.borrow()), + None => None, + } + } + + /// Consumes this cursor, and returns a mutable reference to the entity that the cursor is + /// currently pointing to. + /// + /// Unlike [get_mut], the returned reference’s lifetime is tied to the [EntityList]’s lifetime. + /// + /// This returns `None` if the cursor is currently pointing to the null object. + /// + /// NOTE: This function will panic if there are any outstanding borrows of the underlying entity + pub fn into_mut(self) -> Option> { + match self.cursor.get() { + Some(obj) => Some(obj.borrow_mut()), + None => None, + } + } +} + +pub struct EntityIter<'a, T> { + cursor: EntityCursor<'a, T>, + started: bool, +} +impl<'a, T> core::iter::FusedIterator for EntityIter<'a, T> {} +impl<'a, T> Iterator for EntityIter<'a, T> { + type Item = EntityRef<'a, T>; + + fn next(&mut self) -> Option { + // If we haven't started iterating yet, then we're on the null cursor, so move to the + // front of the list now that we have started iterating. + if !self.started { + self.started = true; + self.cursor.move_next(); + } + let item = self.cursor.get()?; + self.cursor.move_next(); + Some(item) + } +} +impl<'a, T> DoubleEndedIterator for EntityIter<'a, T> { + fn next_back(&mut self) -> Option { + // If we haven't started iterating yet, then we're on the null cursor, so move to the + // back of the list now that we have started iterating. 
+ if !self.started { + self.started = true; + self.cursor.move_prev(); + } + let item = self.cursor.get()?; + self.cursor.move_prev(); + Some(item) + } +} diff --git a/hir2/src/core/function.rs b/hir2/src/core/function.rs new file mode 100644 index 000000000..7bd377d4e --- /dev/null +++ b/hir2/src/core/function.rs @@ -0,0 +1,308 @@ +use super::{Operation, Symbol}; +use crate::Spanned; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct FunctionIdent { + module: midenc_hir_symbol::Symbol, + function: midenc_hir_symbol::Symbol, +} + +#[derive(Spanned)] +pub struct Function { + #[span] + op: Operation, + id: FunctionIdent, + signature: Signature, +} +impl Symbol for Function { + type Id = midenc_hir_symbol::Symbol; + + fn id(&self) -> Self::Id { + self.id.function + } +} + +struct Function + +/// Represents the calling convention of a function. +/// +/// Calling conventions are part of a program's ABI (Application Binary Interface), and +/// they define things such how arguments are passed to a function, how results are returned, +/// etc. In essence, the contract between caller and callee is described by the calling convention +/// of a function. +/// +/// Importantly, it is perfectly normal to mix calling conventions. For example, the public +/// API for a C library will use whatever calling convention is used by C on the target +/// platform (for Miden, that would be `SystemV`). However, internally that library may use +/// the `Fast` calling convention to allow the compiler to optimize more effectively calls +/// from the public API to private functions. In short, choose a calling convention that is +/// well-suited for a given function, to the extent that other constraints don't impose a choice +/// on you. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) +)] +#[repr(u8)] +pub enum CallConv { + /// This calling convention is what I like to call "chef's choice" - the + /// compiler chooses it's own convention that optimizes for call performance. + /// + /// As a result of this, it is not permitted to use this convention in externally + /// linked functions, as the convention is unstable, and the compiler can't ensure + /// that the caller in another translation unit will use the correct convention. + Fast, + /// The standard calling convention used for C on most platforms + #[default] + SystemV, + /// A function which is using the WebAssembly Component Model "Canonical ABI". + Wasm, + /// A function with this calling convention must be called using + /// the `syscall` instruction. Attempts to call it with any other + /// call instruction will cause a validation error. The one exception + /// to this rule is when calling another function with the `Kernel` + /// convention that is defined in the same module, which can use the + /// standard `call` instruction. + /// + /// Kernel functions may only be defined in a kernel [Module]. + /// + /// In all other respects, this calling convention is the same as `SystemV` + Kernel, +} +impl fmt::Display for CallConv { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Fast => f.write_str("fast"), + Self::SystemV => f.write_str("C"), + Self::Wasm => f.write_str("wasm"), + Self::Kernel => f.write_str("kernel"), + } + } +} + +/// Represents whether an argument or return value has a special purpose in +/// the calling convention of a function. 
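+///
+/// A sketch of how a struct-return parameter is declared (assuming `ptr_ty` is some
+/// pointer [Type]); [AbiParam::sret] tags the parameter with
+/// [ArgumentPurpose::StructReturn]:
+///
+/// ```rust,ignore
+/// let sret = AbiParam::sret(ptr_ty);
+/// assert_eq!(sret.purpose, ArgumentPurpose::StructReturn);
+/// ```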
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) +)] +#[repr(u8)] +pub enum ArgumentPurpose { + /// No special purpose, the argument is passed/returned by value + #[default] + Default, + /// Used for platforms where the calling convention expects return values of + /// a certain size to be written to a pointer passed in by the caller. + StructReturn, +} +impl fmt::Display for ArgumentPurpose { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Default => f.write_str("default"), + Self::StructReturn => f.write_str("sret"), + } + } +} + +/// Represents how to extend a small integer value to native machine integer width. +/// +/// For Miden, native integrals are unsigned 64-bit field elements, but it is typically +/// going to be the case that we are targeting the subset of Miden Assembly where integrals +/// are unsigned 32-bit integers with a standard twos-complement binary representation. +/// +/// It is for the latter scenario that argument extension is really relevant. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) +)] +#[repr(u8)] +pub enum ArgumentExtension { + /// Do not perform any extension, high bits have undefined contents + #[default] + None, + /// Zero-extend the value + Zext, + /// Sign-extend the value + Sext, +} +impl fmt::Display for ArgumentExtension { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => f.write_str("none"), + Self::Zext => f.write_str("zext"), + Self::Sext => f.write_str("sext"), + } + } +} + +/// Describes a function parameter or result. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct AbiParam { + /// The type associated with this value + pub ty: Type, + /// The special purpose, if any, of this parameter or result + pub purpose: ArgumentPurpose, + /// The desired approach to extending the size of this value to + /// a larger bit width, if applicable. + pub extension: ArgumentExtension, +} +impl AbiParam { + pub fn new(ty: Type) -> Self { + Self { + ty, + purpose: ArgumentPurpose::default(), + extension: ArgumentExtension::default(), + } + } + + pub fn sret(ty: Type) -> Self { + assert!(ty.is_pointer(), "sret parameters must be pointers"); + Self { + ty, + purpose: ArgumentPurpose::StructReturn, + extension: ArgumentExtension::default(), + } + } +} +impl formatter::PrettyPrint for AbiParam { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + let mut doc = const_text("(") + const_text("param") + const_text(" "); + if !matches!(self.purpose, ArgumentPurpose::Default) { + doc += const_text("(") + display(self.purpose) + const_text(")") + const_text(" "); + } + if !matches!(self.extension, ArgumentExtension::None) { + doc += const_text("(") + display(self.extension) + const_text(")") + const_text(" "); + } + doc + text(format!("{}", &self.ty)) + const_text(")") + } +} + +/// A [Signature] represents the type, ABI, and linkage of a function. +/// +/// A function signature provides us with all of the necessary detail to correctly +/// validate and emit code for a function, whether from the perspective of a caller, +/// or the callee. 
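+///
+/// A minimal sketch of constructing one (assuming `int_ty` is some concrete [Type]):
+///
+/// ```rust,ignore
+/// let sig = Signature::new([AbiParam::new(int_ty.clone())], [AbiParam::new(int_ty)]);
+/// // Defaults: external linkage, SystemV calling convention
+/// assert!(sig.is_public());
+/// assert_eq!(sig.cc, CallConv::SystemV);
+/// assert_eq!(sig.arity(), 1);
+/// ```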
+#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Signature { + /// The arguments expected by this function + pub params: Vec, + /// The results returned by this function + pub results: Vec, + /// The calling convention that applies to this function + pub cc: CallConv, + /// The linkage that should be used for this function + pub linkage: Linkage, +} +impl Signature { + /// Create a new signature with the given parameter and result types, + /// for a public function using the `SystemV` calling convention + pub fn new, R: IntoIterator>( + params: P, + results: R, + ) -> Self { + Self { + params: params.into_iter().collect(), + results: results.into_iter().collect(), + cc: CallConv::SystemV, + linkage: Linkage::External, + } + } + + /// Returns true if this function is externally visible + pub fn is_public(&self) -> bool { + matches!(self.linkage, Linkage::External) + } + + /// Returns true if this function is only visible within it's containing module + pub fn is_private(&self) -> bool { + matches!(self.linkage, Linkage::Internal) + } + + /// Returns true if this function is a kernel function + pub fn is_kernel(&self) -> bool { + matches!(self.cc, CallConv::Kernel) + } + + /// Returns the number of arguments expected by this function + pub fn arity(&self) -> usize { + self.params().len() + } + + /// Returns a slice containing the parameters for this function + pub fn params(&self) -> &[AbiParam] { + self.params.as_slice() + } + + /// Returns the parameter at `index`, if present + #[inline] + pub fn param(&self, index: usize) -> Option<&AbiParam> { + self.params.get(index) + } + + /// Returns a slice containing the results of this function + pub fn results(&self) -> &[AbiParam] { + match self.results.as_slice() { + [AbiParam { ty: Type::Unit, .. }] => &[], + [AbiParam { + ty: Type::Never, .. 
+ }] => &[], + results => results, + } + } +} +impl Eq for Signature {} +impl PartialEq for Signature { + fn eq(&self, other: &Self) -> bool { + self.linkage == other.linkage + && self.cc == other.cc + && self.params.len() == other.params.len() + && self.results.len() == other.results.len() + } +} +impl formatter::PrettyPrint for Signature { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + let cc = if matches!(self.cc, CallConv::SystemV) { + None + } else { + Some( + const_text("(") + + const_text("cc") + + const_text(" ") + + display(self.cc) + + const_text(")"), + ) + }; + + let params = self.params.iter().fold(cc.unwrap_or(Document::Empty), |acc, param| { + if acc.is_empty() { + param.render() + } else { + acc + const_text(" ") + param.render() + } + }); + + if self.results.is_empty() { + params + } else { + let open = const_text("(") + const_text("result"); + let results = self + .results + .iter() + .fold(open, |acc, e| acc + const_text(" ") + text(format!("{}", &e.ty))) + + const_text(")"); + if matches!(params, Document::Empty) { + results + } else { + params + const_text(" ") + results + } + } + } +} diff --git a/hir2/src/core/interface.rs b/hir2/src/core/interface.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/core/module.rs b/hir2/src/core/module.rs new file mode 100644 index 000000000..bd62388b3 --- /dev/null +++ b/hir2/src/core/module.rs @@ -0,0 +1,43 @@ +use std::{ + any::{Any, TypeId}, + collections::BTreeMap, +}; + +use super::{Function, FunctionIdent, Symbol, SymbolTable}; +use crate::UnsafeRef; + +pub struct Module { + name: midenc_hir_symbol::Symbol, + functions: BTreeMap>, +} +impl SymbolTable for Module { + type Key = midenc_hir_symbol::Symbol; + + fn get(&self, id: &Self::Key) -> Option> + where + T: Symbol, + { + if TypeId::of::() == TypeId::of::() { + self.functions.get(id).copied().map(|unsafe_ref| { + let ptr = unsafe_ref.into_raw(); + UnsafeRef::new(ptr.cast()) + }) + } else { + None + } + } + + fn insert(&self, entry: UnsafeRef) -> bool + where + T: Symbol, + { + todo!() + } + + fn remove(&self, id: &Self::Key) -> Option> + where + T: Symbol, + { + todo!() + } +} diff --git a/hir2/src/core/op.rs b/hir2/src/core/op.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/core/operation.rs b/hir2/src/core/operation.rs new file mode 100644 index 000000000..bdf6d2d20 --- /dev/null +++ b/hir2/src/core/operation.rs @@ -0,0 +1,336 @@ +use core::{ + any::{Any, TypeId}, + mem, + ptr::{NonNull, Pointee}, +}; + +use cranelift_entity::{packed_option::ReservedValue, EntityRef}; +use downcast_rs::{impl_downcast, Downcast}; +use intrusive_collections::{ + container_of, intrusive_adapter, + linked_list::{LinkOps, LinkedListOps}, + LinkedListLink, UnsafeRef, +}; +use smallvec::SmallVec; + +use super::*; + +pub type OpList = EntityList; +pub type OpCursor<'a> = EntityCursor<'a, Operation>; +pub type OpCursorMut<'a> = EntityCursorMut<'a, Operation>; + +/// An [OpSuccessor] is a BlockOperand + OpOperands for that block, attached to an Operation +struct OpSuccessor { + block: TrackedEntityHandle, + args: SmallVec<[TrackedEntityHandle; 1]>, +} + +// TODO: We need a safe way to construct arbitrary Ops imperatively: +// +// * Allocate an uninit instance of T +// * Initialize the Operartion field of T with the empty Operation data +// * Use the primary builder methods to mutate Operation fields +// * Use generated methods on Op-specific builders to mutate Op fields +// * At the end, convert uninit T to init T, return 
handle to caller +// +// Problems: +// +// * How do we default-initialize an instance of T for this purpose +// * If we use MaybeUninit, how do we compute field offsets for the Operation field +// * Generated methods can compute offsets, but how do we generate the specialized builders? +pub struct OperationBuilder<'a, T> { + context: &'a Context, + op: Operation, + _marker: core::marker::PhantomData, +} +impl OperationBuilder { + pub fn new(context: &'a Context) -> Self { + let op = Operation::uninit::(); + let handle = context.alloc_uninit_tracked(op); + Self { + context, + op, + _marker: core::marker::PhantomData, + } + } + + pub fn build(self) -> TrackedEntityHandle { + todo!() + } +} + +#[derive(Spanned)] +pub struct Operation { + /// In order to support upcasting from [Operation] to its concrete [Op] type, as well as + /// casting to any of the operation traits it implements, we need our own vtable that lets + /// us track the individual vtables for each type and trait we need to cast to for this + /// instance. + pub(crate) vtable: traits::MultiTraitVtable, + #[span] + pub span: SourceSpan, + /// Attributes that apply to this operation + pub attrs: AttributeSet, + /// The containing block of this operation + /// + /// Is set to `None` if this operation is detached + pub block: Option>, + /// The set of operands for this operation + /// + /// NOTE: If the op supports immediate operands, the storage for the immediates is handled + /// by the op, rather than here. Additionally, the semantics of the immediate operands are + /// determined by the op, e.g. whether the immediate operands are always applied first, or + /// what they are used for. + pub operands: SmallVec<[TrackedEntityHandle; 1]>, + /// The set of values produced by this operation. + pub results: SmallVec<[Value; 1]>, + /// If this operation represents control flow, this field stores the set of successors, + /// and successor operands. 
+ pub successors: SmallVec<[OpSuccessor; 1]>, + /// The set of regions belonging to this operation, if any + pub regions: RegionList, +} +impl AsRef for Operation { + fn as_ref(&self) -> &dyn Op { + self.vtable.downcast_trait().unwrap() + } +} +impl AsMut for Operation { + fn as_mut(&mut self) -> &mut dyn Op { + self.vtable.downcast_trait_mut().unwrap() + } +} +impl Operation { + fn uninit() -> Self { + use crate::traits::MultiTraitVtable; + + let mut vtable = MultiTraitVtable::new::(); + vtable.register_trait::(); + + Self { + vtable, + span: Default::default(), + attrs: Default::default(), + block: Default::default(), + operands: Default::default(), + results: Default::default(), + successors: Default::default(), + regions: Default::default(), + } + } +} + +/// Traits/Casts +impl Operation { + /// Returns true if the concrete type of this operation is `T` + #[inline] + pub fn is(&self) -> bool { + self.vtable.is::() + } + + /// Returns true if this operation implements `Trait` + #[inline] + pub fn implements(&self) -> bool + where + Trait: ?Sized + Pointee + 'static, + { + self.vtable.implements::() + } + + /// Attempt to downcast to the concrete [Op] type of this operation + pub fn downcast_ref(&self) -> Option<&T> { + self.vtable.downcast_ref::() + } + + /// Attempt to downcast to the concrete [Op] type of this operation + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.vtable.downcast_mut::() + } +} + +/// Attributes +impl Operation { + /// Return the value associated with attribute `name` for this function + pub fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.get(name) + } + + /// Return true if this function has an attributed named `name` + pub fn has_attribute(&self, name: &Q) -> bool + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.has(name) + } + + /// Set the attribute `name` with `value` for this function. 
+ pub fn set_attribute(&mut self, name: impl Into, value: impl Into) { + self.attrs.insert(name, value); + } + + /// Remove any attribute with the given name from this function + pub fn remove_attribute(&mut self, name: &Q) + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.remove(name); + } +} + +/// Navigation +impl Operation { + pub fn prev(&self) -> Option { + unsafe { + let current = core::ptr::NonNull::new_unchecked(&self.link); + LinkOps.prev(current).map(Self::link_to_key) + } + } + + pub fn next(&self) -> Option { + unsafe { + let current = core::ptr::NonNull::new_unchecked(&self.link); + LinkOps.next(current).map(Self::link_to_key) + } + } + + #[inline] + unsafe fn link_to_key(link: NonNull) -> OpId { + let link = link.as_ref(); + let operation = container_of!(link, Operation, link); + let key_offset = mem::offset_of!(Operation, key); + let prev_key = operation.byte_add(key_offset as isize) as *const OpId; + *prev_key + } +} + +/// Operands +impl Operation { + pub fn replaces_uses_of_with(&mut self, from: Value, to: Value) { + if from == to { + return; + } + + for operand in self.operands.iter_mut() { + if operand == &from { + *operand = to; + } + } + } +} + +pub trait Op: Downcast { + type Id: Copy + PartialEq + Eq + PartialOrd + Ord; + + fn id(&self) -> Self::Id; + fn name(&self) -> &'static str; + fn parent(&self) -> Option { + let parent = self.as_operation().parent; + if parent.is_reserved_value() { + None + } else { + Some(parent) + } + } + fn prev(&self) -> Option { + self.as_operation().prev() + } + fn next(&self) -> Option { + self.as_operation().next() + } + fn parent_block(&self) -> Option { + let block = self.as_operation().block; + if block.is_reserved_value() { + None + } else { + Some(block) + } + } + fn regions(&self) -> &[RegionId] { + self.as_operation().regions.as_slice() + } + fn operands(&self) -> &ValueList { + &self.as_operation().operands + } + fn results(&self) -> &ValueList { + &self.as_operation().results + } + fn successors(&self) -> &[Successor] { + self.as_operation().successors.as_slice() + } + fn as_operation(&self) -> &Operation; + fn as_operation_mut(&mut self) -> &mut Operation; +} + +impl_downcast!(Op assoc Id where Id: Copy + PartialEq + Eq + PartialOrd + Ord); + +impl miden_assembly::Spanned for dyn Op { + fn span(&self) -> SourceSpan { + self.as_operation().span + } +} + +pub trait OpExt { + /// Return the value associated with attribute `name` for this function + fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; + + /// Return true if this function has an attributed named `name` + fn has_attribute(&self, name: &Q) -> bool + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; + + /// Set the attribute `name` with `value` for this function. 
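As a sketch of how the casting machinery and the `Op` trait are intended to fit together (illustrative only; the function name is hypothetical): `GlobalValueOp`, defined later in this patch, is a concrete `Op`, and the marker traits from `core/traits.rs` can be queried through the same vtable, provided they were registered for the op.

```rust
// Sketch: recover the concrete op behind a type-erased `Operation`, or query
// one of the marker traits registered in its `MultiTraitVtable`.
fn inspect(op: &Operation) {
    if let Some(gv) = op.downcast_ref::<GlobalValueOp>() {
        // Concrete access, e.g. "global.load" / "global.symbol" / "global.iadd".
        let _name = gv.name();
    }
    if op.implements::<dyn Terminator>() {
        // Treat this operation as ending its containing block.
    }
    // The attribute helpers (`get_attribute`, `set_attribute`, ...) live on
    // `Operation`, and the blanket `impl<T: Op> OpExt for T` re-exposes them
    // on every concrete op type.
}
```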
+ fn set_attribute(&mut self, name: impl Into, value: impl Into); + + /// Remove any attribute with the given name from this function + fn remove_attribute(&mut self, name: &Q) + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; +} + +impl OpExt for T { + /// Return the value associated with attribute `name` for this function + #[inline] + fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation().get_attribute(name) + } + + /// Return true if this function has an attributed named `name` + #[inline] + fn has_attribute(&self, name: &Q) -> bool + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation().has_attribute(name) + } + + /// Set the attribute `name` with `value` for this function. + #[inline] + fn set_attribute(&mut self, name: impl Into, value: impl Into) { + self.as_operation_mut().insert(name, value); + } + + /// Remove any attribute with the given name from this function + #[inline] + fn remove_attribute(&mut self, name: &Q) + where + Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation_mut().remove(name); + } +} diff --git a/hir2/src/core/region.rs b/hir2/src/core/region.rs new file mode 100644 index 000000000..d8b07cf54 --- /dev/null +++ b/hir2/src/core/region.rs @@ -0,0 +1,19 @@ +use super::{BlockList, EntityCursor, EntityCursorMut, EntityHandle, EntityList}; + +/// An intrusive, doubly-linked list of [Region]s +pub type RegionList = EntityList; + +/// A cursor in a [RegionList] +pub type RegionCursor<'a> = EntityCursor<'a, Region>; + +/// A mutable cursor in a [RegionList] +pub type RegionCursorMut<'a> = EntityCursorMut<'a, Region>; + +pub struct Region { + /// The operation this region is attached to. + /// + /// If `link.is_linked() == true`, this will always be set to a valid pointer + owner: Option>, + /// The list of [Block]s that comprise this region + body: BlockList, +} diff --git a/hir2/src/core/symbol_table.rs b/hir2/src/core/symbol_table.rs new file mode 100644 index 000000000..e122d1557 --- /dev/null +++ b/hir2/src/core/symbol_table.rs @@ -0,0 +1,54 @@ +use core::any::Any; + +use crate::UnsafeRef; + +/// A [SymbolTable] is an IR entity which contains other IR entities, called _symbols_, each of +/// which has a name, aka symbol, that uniquely identifies it amongst all other entities in the +/// same [SymbolTable]. +/// +/// The symbols in a [SymbolTable] do not need to all refer to the same entity type, however the +/// concrete value type of the symbol itself, e.g. `String`, must be the same. This is enforced +/// in the way that the [SymbolTable] and [Symbol] traits interact. A [SymbolTable] has an +/// associated `Key` type, and a [Symbol] has an associated `Id` type - only types whose `Id` +/// type matches the `Key` type of the [SymbolTable], can be stored in that table. +pub trait SymbolTable { + /// The unique key type associated with entries in this symbol table + type Key; + + /// Check if `id` is associated with an entry of type `T` in this table + fn has_symbol_of_type(&self, id: &Self::Key) -> bool + where + T: Symbol, + { + self.get::(id) + } + + /// Get the entry for `id` in this table + fn get(&self, id: &Self::Key) -> Option> + where + T: Symbol; + + /// Insert `entry` in the symbol table. + /// + /// Returns `true` if successful, or `false` if an entry already exists + fn insert(&self, entry: UnsafeRef) -> bool + where + T: Symbol; + + /// Remove the symbol `id`, and return the entry if one was present. 
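A small sketch of how a concrete symbol table is meant to be queried: `Module` implements `SymbolTable` with `Key = midenc_hir_symbol::Symbol` earlier in this patch, and `Function`'s `Symbol` impl uses the same id type. Note that the diff rendering has dropped most generic parameters, so the `get::<T>(&Key) -> Option<UnsafeRef<T>>` shape is inferred rather than shown verbatim.

```rust
// Hypothetical helper: look up a function by name in a module's symbol table.
fn find_function(
    module: &Module,
    name: midenc_hir_symbol::Symbol,
) -> Option<UnsafeRef<Function>> {
    // Only entries whose `Symbol::Id` matches `Module`'s key type can be stored,
    // so a `Function` can be recovered directly by its name.
    module.get::<Function>(&name)
}
```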
+ fn remove(&self, id: &Self::Key) -> Option> + where + T: Symbol; +} + +/// A [Symbol] is an IR entity with an associated _symbol_, or name, which is expected to be unique +/// amongst all other symbols in the same namespace. +/// +/// For example, functions are named, and are expected to be unique within the same module, +/// otherwise it would not be possible to unambiguously refer to a function by name. Likewise +/// with modules in a program, etc. +pub trait Symbol: Any { + type Id: Copy + Clone + PartialEq + Eq + PartialOrd + Ord; + + fn id(&self) -> Self::Id; +} diff --git a/hir2/src/core/traits.rs b/hir2/src/core/traits.rs new file mode 100644 index 000000000..8414aa318 --- /dev/null +++ b/hir2/src/core/traits.rs @@ -0,0 +1,30 @@ +mod multitrait; + +pub(crate) use self::multitrait::MultiTraitVtable; + +/// Marker trait for commutative ops, e.g. `X op Y == Y op X` +pub trait Commutative {} + +/// Marker trait for constant-like ops +pub trait ConstantLike {} + +/// Marker trait for ops with side effects +pub trait HasSideEffects {} + +/// Marker trait for ops which read memory +pub trait MemoryRead {} + +/// Marker trait for ops which write memory +pub trait MemoryWrite {} + +/// Marker trait for return-like ops +pub trait ReturnLike {} + +/// All operands of the given op are the same type +pub trait SameTypeOperands {} + +/// Marker trait for ops whose regions contain only a single block +pub trait SingleBlock {} + +/// Marker trait for ops which can terminate a block +pub trait Terminator {} diff --git a/hir2/src/core/traits/multitrait.rs b/hir2/src/core/traits/multitrait.rs new file mode 100644 index 000000000..c7a3be62e --- /dev/null +++ b/hir2/src/core/traits/multitrait.rs @@ -0,0 +1,109 @@ +use core::{ + any::{Any, TypeId}, + ptr::{null, null_mut}, +}; + +pub(crate) struct MultiTraitVtable { + pub(crate) data: *mut (), + pub(crate) type_id: TypeId, + pub(crate) traits: Vec<(TypeId, *const ())>, +} +impl MultiTraitVtable { + pub fn new() -> Self { + let type_id = TypeId::of::(); + let (any_type, any_vtable) = { + let ptr = null::().cast::(); + let (_, vtable) = ptr.to_raw_parts(); + (TypeId::of::(), vtable) + }; + + Self { + data: null_mut(), + type_id, + traits: vec![(any_type, any_vtable)], + } + } + + pub fn set_data_ptr(&mut self, ptr: *mut T) { + let type_id = TypeId::of::(); + assert_eq!(self.type_id, type_id); + self.data = data.cast(); + } + + pub fn register_trait(&mut self) + where + Trait: ?Sized + Pointee + 'static, + { + let (type_id, vtable) = { + let ptr = null::().cast::(); + let (_, vtable) = ptr.to_raw_parts(); + (TypeId::of::(), vtable) + }; + if self.traits.iter().any(|(tid, _)| tid == &type_id) { + return; + } + self.traits.push((type_id, vtable)); + self.traits.sort_by_key(|(tid, _)| tid); + } + + #[inline] + pub fn is(&self) -> bool { + self.type_id == TypeId::of::() + } + + pub fn implements(&self) -> bool + where + Trait: ?Sized + Pointee + 'static, + { + let type_id = TypeId::of::(); + self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).is_ok() + } + + #[inline] + pub fn downcast_ref(&self) -> Option<&T> { + if self.is::() { + Some(unsafe { self.downcast_reF_unchecked() }) + } else { + None + } + } + + #[inline(always)] + unsafe fn downcast_ref_unchecked(&self) -> &T { + core::ptr::from_raw_parts(self.data, ()) + } + + #[inline] + pub fn downcast_mut(&mut self) -> Option<&mut T> { + if self.is::() { + Some(unsafe { self.downcast_mut_unchecked() }) + } else { + None + } + } + + #[inline(always)] + unsafe fn downcast_mut_unchecked(&mut self) -> 
&mut T { + core::ptr::from_raw_parts(self.data, ()) + } + + pub fn downcast_trait(&self) -> Option<&Trait> + where + Trait: ?Sized + Pointee + 'static, + { + self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).map(|index| { + let vtable = self.traits[index].1; + core::ptr::from_raw_parts::(self.data, vtable) + }) + } + + pub fn downcast_trait_mut(&mut self) -> Option<&mut Trait> + where + Trait: ?Sized + Pointee + 'static, + { + self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).map(|index| { + let vtable = self.traits[index].1; + core::ptr::from_raw_parts_mut::(self.data, vtable) + }) + } +} diff --git a/hir2/src/core/types.rs b/hir2/src/core/types.rs new file mode 100644 index 000000000..aba5bea0a --- /dev/null +++ b/hir2/src/core/types.rs @@ -0,0 +1 @@ +pub use midenc_hir_type::*; diff --git a/hir2/src/core/usable.rs b/hir2/src/core/usable.rs new file mode 100644 index 000000000..98afa56d6 --- /dev/null +++ b/hir2/src/core/usable.rs @@ -0,0 +1,35 @@ +use super::{entity::EntityIter, EntityCursor, EntityCursorMut}; + +/// The [Usable] trait is implemented for IR entities which are _defined_ and _used_, and as a +/// result, require a data structure called the _use-def list_. +/// +/// A _definition_ of an IR entity, is a unique instantiation of that entity, the result of which +/// is different from all other definitions, even if the data associated with that definition is +/// the same as another definition. For example, SSA values are defined as either block arguments +/// or operation results, and a given value can only be defined once. +/// +/// A _use_ represents a unique reference to a _definition_ of some IR entity. Each use is unique, +/// and can be used to obtain not only the _user_ of the reference, but the location of that use +/// within the user. Uses are tracked in a _use list_, also called the _use-def list_, which +/// associates all uses to the definition, or _def_, that they reference. For example, operations +/// in HIR _use_ SSA values defined previously in the program. +/// +/// A _user_ does not have to be of the same IR type as the _definition_, and the type representing +/// the _use_ is typically different than both, and represents the type of relationship between the +/// two. For example, an `OpOperand` represents a single use of a `Value` by an `Op`. The entity +/// being defined is a `Value`, the entity using that definition is an `Op`, and the data associated +/// with each use is represented by `OpOperand`. +pub trait Usable { + /// The type associated with each unique use, e.g. `OpOperand` + type Use; + + /// Returns true if this definition is used + fn is_used(&self) -> bool; + /// Get an iterator over the uses of this definition + fn uses(&self) -> EntityIter<'_, Self::Use>; + /// Get a cursor positioned on the first use of this definition, or the null cursor if unused. + fn first_use(&self) -> EntityCursor<'_, Self::Use>; + /// Get a mutable cursor positioned on the first use of this definition, or the null cursor if + /// unused. 
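For example, a use-def walk over a value might look like the sketch below. This assumes `EntityIter` yields `&OpOperand` items (the iterator internals live in `core/entity.rs`), and uses the `OpOperand` fields declared in `core/value.rs` later in this patch.

```rust
// Sketch: visit every (owner op, operand index) pair that uses `value`.
fn report_uses(value: &ValueImpl) {
    if !value.is_used() {
        return;
    }
    for operand in value.uses() {
        // Each use records the operation that owns it and the position of the
        // operand within that operation's operand list.
        let _owner = &operand.owner;
        let _index = operand.index;
    }
}
```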
+ fn first_use_mut(&mut self) -> EntityCursorMut<'_, Self::Use>; +} diff --git a/hir2/src/core/value.rs b/hir2/src/core/value.rs new file mode 100644 index 000000000..76d3b5a37 --- /dev/null +++ b/hir2/src/core/value.rs @@ -0,0 +1,166 @@ +use core::{fmt, ptr}; + +use super::{Block, EntityCursor, EntityCursorMut, EntityIter, EntityList, Type, Usable}; +use crate::{SourceSpan, Spanned}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[repr(u8)] +pub enum ValueKind { + Result, + BlockArgument, +} + +#[derive(Spanned)] +pub struct ValueImpl { + kind: ValueKind, + ty: Type, + #[span] + span: SourceSpan, + uses: OpOperandList, +} +impl ValueImpl { + #[inline(always)] + pub const fn kind(&self) -> ValueKind { + self.kind + } + + pub fn is_result(&self) -> bool { + matches!(self, ValueKind::Result) + } + + pub fn is_block_argument(&self) -> bool { + matches!(self, ValueKind::BlockArgument) + } + + #[inline(always)] + pub fn ty(&self) -> &Type { + &self.ty + } + + #[inline(always)] + pub fn set_type(&mut self, ty: Type) { + self.ty = ty; + } +} +impl Usable for ValueImpl { + type Use = OpOperand; + + #[inline] + fn is_used(&self) -> bool { + !self.uses.is_empty() + } + + #[inline] + fn uses(&self) -> OpOperandIter<'_> { + self.uses.iter() + } + + #[inline] + fn first_use(&self) -> OpOperandCursor<'_> { + self.uses.front() + } + + #[inline] + fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { + self.uses.front_mut() + } +} +impl fmt::Debug for ValueImpl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ValueImpl") + .field("kind", &self.kind) + .field("ty", &self.ty) + .field("uses", &self.uses) + .finish() + } +} + +pub type Value = EntityHandle; + +#[derive(Spanned)] +pub struct BlockArgument { + #[span] + value: ValueImpl, + owner: EntityHandle, + index: u8, +} +impl Usable for BlockArgument { + type Use = OpOperand; + + #[inline] + fn is_used(&self) -> bool { + self.value.is_used() + } + + #[inline] + fn uses(&self) -> OpOperandIter<'_> { + self.value.uses() + } + + #[inline] + fn first_use(&self) -> OpOperandCursor<'_> { + self.value.first_use() + } + + #[inline] + fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { + self.value.first_use_mut() + } +} + +/// An [OpResult] represents the definition of a [Value] by the result of an [Operation] +#[derive(Spanned)] +pub struct OpResult { + #[span] + value: ValueImpl, + owner: EntityHandle, + index: u8, +} +impl Usable for OpResult { + type Use = OpOperand; + + #[inline] + fn is_used(&self) -> bool { + self.value.is_used() + } + + #[inline] + fn uses(&self) -> OpOperandIter<'_> { + self.value.uses() + } + + #[inline] + fn first_use(&self) -> OpOperandCursor<'_> { + self.value.first_use() + } + + #[inline] + fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { + self.value.first_use_mut() + } +} + +pub type OpOperandList = EntityList; +pub type OpOperandIter<'a> = EntityIter<'a, OpOperand>; +pub type OpOperandCursor<'a> = EntityCursor<'a, OpOperand>; +pub type OpOperandCursorMut<'a> = EntityCursorMut<'a, OpOperand>; + +/// An [OpOperand] represents a use of a [Value] by an [Operation] +pub struct OpOperand { + /// The operand value + pub value: Value, + /// The owner of this operand, i.e. 
the operation it is an operand of + pub owner: EntityHandle, + /// The index of this operand in the operand list of an operation + pub index: u8, +} +impl OpOperand { + #[inline] + pub fn new(value: Value, owner: EntityHandle, index: u8) -> Self { + Self { + value, + owner, + index, + } + } +} diff --git a/hir2/src/dialects/hir.rs b/hir2/src/dialects/hir.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs new file mode 100644 index 000000000..8c7aca80a --- /dev/null +++ b/hir2/src/lib.rs @@ -0,0 +1,6 @@ +mod core; +mod unsafe_ref; + +pub use miden_assembly::{SourceSpan, Spanned}; + +pub use self::{core::*, unsafe_ref::UnsafeRef}; diff --git a/hir2/src/ops/binary.rs b/hir2/src/ops/binary.rs new file mode 100644 index 000000000..576f49763 --- /dev/null +++ b/hir2/src/ops/binary.rs @@ -0,0 +1,60 @@ +use crate::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum BinaryOpcode { + Add(Overflow), + Sub(Overflow), + Mul(Overflow), + Div, + Mod, + DivMod, + Exp(Overflow), + And, + Band, + Or, + Bor, + Xor, + Bxor, + Shl, + Shr, + Rotl, + Rotr, + Eq, + Neq, + Gt, + Gte, + Lt, + Lte, + Min, + Max, +} +impl BinaryOpcode { + pub fn is_commutative(&self) -> bool { + matches!( + self, + Self::Add + | Self::Mul + | Self::Min + | Self::Max + | Self::Eq + | Self::Neq + | Self::And + | Self::Band + | Self::Or + | Self::Bor + | Self::Xor + | Self::Bxor + ) + } +} + +pub struct BinaryOp { + pub op: Operation, + pub opcode: BinaryOpcode, +} + +pub struct BinaryOpImm { + pub op: Operation, + pub opcode: BinaryOpcode, + pub imm: Immediate, +} diff --git a/hir2/src/ops/call.rs b/hir2/src/ops/call.rs new file mode 100644 index 000000000..16983dd4f --- /dev/null +++ b/hir2/src/ops/call.rs @@ -0,0 +1,6 @@ +use crate::*; + +pub struct Call { + pub op: Operation, + pub callee: FunctionIdent, +} diff --git a/hir2/src/ops/cast.rs b/hir2/src/ops/cast.rs new file mode 100644 index 000000000..a8141d7ad --- /dev/null +++ b/hir2/src/ops/cast.rs @@ -0,0 +1,33 @@ +use crate::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CastKind { + /// Reinterpret the bits of the operand as the target type, without any consideration for + /// the original meaning of those bits. + /// + /// For example, transmuting `u32::MAX` to `i32`, produces a value of `-1`, because the input + /// value overflows when interpreted as a signed integer. + Transmute, + /// Like `Transmute`, but the input operand is checked to verify that it is a valid value + /// of both the source and target types. + /// + /// For example, a checked cast of `u32::MAX` to `i32` would assert, because the input value + /// cannot be represented as an `i32` due to overflow. + Checked, + /// Convert the input value to the target type, by zero-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Zext, + /// Convert the input value to the target type, by sign-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Sext, + /// Convert the input value to the target type, by truncating the excess bits. A cast of this + /// type must be a narrowing cast, i.e. from a larger bitwidth to a smaller one. 
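The cast kinds documented above map onto familiar integer behavior. The standalone snippet below is plain Rust, independent of the IR, and works through the `u32::MAX` example from the docs:

```rust
fn main() {
    let x = u32::MAX;
    // Transmute: reinterpret the bits; read as a signed 32-bit value they are -1.
    assert_eq!(x as i32, -1);
    // Checked: the same conversion is rejected, because the value overflows `i32`.
    assert_eq!(i32::try_from(x).ok(), None);
    // Zext vs. Sext: widening an 8-bit value to 32 bits.
    assert_eq!(0xFFu8 as u32, 255); // zero-extend: high bits are zero
    assert_eq!((-1i8) as i32, -1); // sign-extend: high bits copy the sign bit
    // Trunc: narrowing keeps only the low 32 bits.
    assert_eq!(0x1_2345_6789_u64 as u32, 0x2345_6789);
}
```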
+ Trunc, +} + +pub struct Cast { + pub op: Operation, + pub kind: CastKind, +} diff --git a/hir2/src/ops/control_flow.rs b/hir2/src/ops/control_flow.rs new file mode 100644 index 000000000..13b80c8e5 --- /dev/null +++ b/hir2/src/ops/control_flow.rs @@ -0,0 +1,50 @@ +use smallvec::SmallVec; + +use crate::*; + +pub struct Br { + pub op: Operation, +} +impl Br { + pub fn dest(&self) -> &Successor { + &self.op.successors[0] + } +} + +pub struct CondBr { + pub op: Operation, +} +impl CondBr { + pub fn condition(&self) -> Value { + todo!() + } + + pub fn then_dest(&self) -> &Successor { + &self.op.successors[0] + } + + pub fn else_dest(&self) -> &Successor { + &self.op.successors[1] + } +} + +pub struct Switch { + pub op: Operation, + pub cases: SmallVec<[u32; 4]>, + pub default_successor: usize, +} +impl Switch { + pub fn selector(&self) -> Value { + todo!() + } + + pub fn default_dest(&self) -> &Successor { + &self.op.successors[self.default_successor] + } +} + +#[derive(Debug, Clone)] +pub struct SwitchCase { + pub value: u32, + pub successor: Successor, +} diff --git a/hir2/src/ops/global_value.rs b/hir2/src/ops/global_value.rs new file mode 100644 index 000000000..2b53fe926 --- /dev/null +++ b/hir2/src/ops/global_value.rs @@ -0,0 +1,35 @@ +use crate::*; + +#[derive(Debug, Clone)] +pub struct GlobalValueOp { + pub id: GlobalValue, + pub data: GlobalValueData, + pub op: Operation, +} + +impl Op for GlobalValueOp { + type Id = GlobalValue; + + #[inline(always)] + fn id(&self) -> Self::Id { + self.id + } + + fn name(&self) -> &'static str { + match self.data { + GlobalValueData::Symbol { .. } => "global.symbol", + GlobalValueData::Load { .. } => "global.load", + GlobalValueData::IAddImm { .. } => "global.iadd", + } + } + + #[inline(always)] + fn as_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } +} diff --git a/hir2/src/ops/inline_asm.rs b/hir2/src/ops/inline_asm.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/ops/mem.rs b/hir2/src/ops/mem.rs new file mode 100644 index 000000000..9f810e220 --- /dev/null +++ b/hir2/src/ops/mem.rs @@ -0,0 +1,38 @@ +use crate::*; + +pub struct Store { + pub op: Operation, +} +impl Store { + pub fn addr(&self) -> Value { + todo!() + } + + pub fn value(&self) -> Value { + todo!() + } +} + +pub struct StoreLocal { + pub op: Operation, + pub local: LocalId, +} +impl StoreLocal { + pub fn value(&self) -> Value { + todo!() + } +} + +pub struct Load { + pub op: Operation, +} +impl Load { + pub fn addr(&self) -> Value { + todo!() + } +} + +pub struct LoadLocal { + pub op: Operation, + pub local: LocalId, +} diff --git a/hir2/src/ops/mod.rs b/hir2/src/ops/mod.rs new file mode 100644 index 000000000..5b8c4344c --- /dev/null +++ b/hir2/src/ops/mod.rs @@ -0,0 +1,16 @@ +mod binary; +mod call; +mod cast; +mod control_flow; +mod global_value; +mod inline_asm; +mod mem; +mod primop; +mod ret; +mod structured_control_flow; +mod unary; + +pub use self::{ + binary::*, call::*, cast::*, control_flow::*, global_value::*, inline_asm::*, mem::*, + primop::*, ret::*, structured_control_flow::*, unary::*, +}; diff --git a/hir2/src/ops/primop.rs b/hir2/src/ops/primop.rs new file mode 100644 index 000000000..a1437d829 --- /dev/null +++ b/hir2/src/ops/primop.rs @@ -0,0 +1,22 @@ +use crate::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum PrimOpcode { + MemGrow, + MemSize, + MemSet, + MemCpy, +} + +pub struct PrimOp { + pub op: Operation, +} + +pub struct 
PrimOpImm { + pub op: Operation, + pub imm: Immediate, +} + +pub struct Unreachable { + pub op: Operation, +} diff --git a/hir2/src/ops/ret.rs b/hir2/src/ops/ret.rs new file mode 100644 index 000000000..c290ee36b --- /dev/null +++ b/hir2/src/ops/ret.rs @@ -0,0 +1,10 @@ +use crate::*; + +pub struct Ret { + pub op: Operation, +} + +pub struct RetImm { + pub op: Operation, + pub imm: Immediate, +} diff --git a/hir2/src/ops/structured_control_flow.rs b/hir2/src/ops/structured_control_flow.rs new file mode 100644 index 000000000..a4add842d --- /dev/null +++ b/hir2/src/ops/structured_control_flow.rs @@ -0,0 +1,56 @@ +use crate::*; + +pub struct If { + pub op: Operation, +} +impl If { + pub fn condition(&self) -> Value { + todo!() + } + + pub fn then_dest(&self) -> &Successor { + todo!() + } + + pub fn else_dest(&self) -> &Successor { + todo!() + } +} + +/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. +/// +/// The "before" region's entry block parameters correspond to the operands expected by the +/// operation, and can be used to compute the condition that determines whether the "after" body +/// is executed or not, or simply forwarded to the "after" region. The "before" region must +/// terminate with a [Condition] operation, which will be evaluated to determine whether or not +/// to continue the loop. +/// +/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, +/// whose operands must be of the same arity and type as the "before" region's argument list. In +/// this way, the "after" body can feed back input to the "before" body to determine whether to +/// continue the loop. +pub struct While { + pub op: Operation, +} +impl While { + pub fn before_region(&self) -> RegionId { + self.op.regions[0] + } + + pub fn after_region(&self) -> RegionId { + self.op.regions[1] + } +} + +pub struct Condition { + pub op: Operation, +} +impl Condition { + pub fn condition(&self) -> Value { + todo!() + } +} + +pub struct Yield { + pub op: Operation, +} diff --git a/hir2/src/ops/unary.rs b/hir2/src/ops/unary.rs new file mode 100644 index 000000000..623b040cd --- /dev/null +++ b/hir2/src/ops/unary.rs @@ -0,0 +1,37 @@ +use crate::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum UnaryOpcode { + PtrToInt, + IntToPtr, + Cast, + Bitcast, + Trunc, + Zext, + Sext, + Test, + Neg, + Inv, + Incr, + Ilog2, + Pow2, + Popcnt, + Clz, + Ctz, + Clo, + Cto, + Not, + Bnot, + IsOdd, +} + +pub struct UnaryOp { + pub op: Operation, + pub opcode: UnaryOpcode, +} + +pub struct UnaryOpImm { + pub op: Operation, + pub opcode: UnaryOpcode, + pub imm: Immediate, +} diff --git a/hir2/src/unsafe_ref.rs b/hir2/src/unsafe_ref.rs new file mode 100644 index 000000000..d510489f8 --- /dev/null +++ b/hir2/src/unsafe_ref.rs @@ -0,0 +1,101 @@ +use core::ptr::NonNull; + +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct UnsafeRef(NonNull); + +impl UnsafeRef { + /// Construct a new [UnsafeRef] from a non-null pointer to `T` + pub fn new(ptr: NonNull) -> Self { + Self(ptr) + } + + /// Get the underlying raw pointer for this [UnsafeRef] + #[inline(always)] + pub const fn into_raw(self) -> NonNull { + self.0 + } + + /// Construct an [UnsafeRef] from a [Box] + pub fn from_box(ptr: Box) -> Self { + Self(unsafe { NonNull::new_unchecked(Box::into_raw(ptr)) }) + } + + /// Convert this [UnsafeRef] back into the [Box] it was derived from. 
+ /// + /// # Safety + /// + /// The following must be upheld by the caller: + /// + /// * This [UnsafeRef] _MUST_ have been created via [UnsafeRef::from_box] + /// * There _MUST NOT_ be any other [UnsafeRef] pointing to the same allocation + /// * `T` must be the same type as the original [Box] was allocated with + pub unsafe fn into_box(self) -> Box { + Box::from_raw(self.0.as_ptr()) + } +} + +impl core::ops::Deref for UnsafeRef { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl AsRef for UnsafeRef { + fn as_ref(&self) -> &T { + unsafe { self.0.as_ref() } + } +} + +impl AsRef for UnsafeRef +where + T: core::marker::Unsize + ?Sized, + U: ?Sized, +{ + fn as_ref(&self) -> &U { + unsafe { self.0.as_ref() as &U } + } +} + +impl core::borrow::Borrow for UnsafeRef { + fn borrow(&self) -> &T { + unsafe { self.0.as_ref() } + } +} + +impl core::ops::CoerceUnsized> for UnsafeRef +where + T: core::marker::Unsize + ?Sized, + U: ?Sized, +{ +} + +impl core::ops::DispatchFromDyn> for UnsafeRef +where + T: core::marker::Unsize + ?Sized, + U: ?Sized, +{ +} + +unsafe impl Send for UnsafeRef {} + +unsafe impl Sync for UnsafeRef {} + +unsafe impl intrusive_collections::PointerOps + for intrusive_collections::DefaultPointerOps> +{ + type Pointer = UnsafeRef; + type Value = T; + + unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { + let value = NonNull::new(value.cast_mut()).expect("expected non-null node pointer"); + UnsafeRef::new(value) + } + + fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { + ptr.into_raw().as_ptr().cast_const() + } +} From 117dfcf1434a2368bcd1cbab9a31c28e6efb6d7f Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Thu, 19 Sep 2024 17:49:42 -0400 Subject: [PATCH 02/31] wip: basic structure landed, starting refinement --- Cargo.lock | 45 +- Cargo.toml | 1 + hir-symbol/src/symbols.toml | 3 + hir2/Cargo.toml | 2 + hir2/src/attributes.rs | 5 + hir2/src/attributes/call_conv.rs | 57 ++ hir2/src/attributes/linkage.rs | 66 ++ hir2/src/attributes/overflow.rs | 60 ++ hir2/src/core.rs | 39 +- hir2/src/core/attribute.rs | 266 +++----- hir2/src/core/block.rs | 163 ++++- hir2/src/core/context.rs | 130 ++-- hir2/src/core/dialect.rs | 61 ++ hir2/src/core/entity.rs | 641 ++++++++++++------ hir2/src/core/entity/list.rs | 305 +++++++-- hir2/src/core/function.rs | 71 +- hir2/src/core/ident.rs | 230 +++++++ hir2/src/core/immediates.rs | 792 ++++++++++++++++++++++ hir2/src/core/module.rs | 68 +- hir2/src/core/op.rs | 116 ++++ hir2/src/core/operation.rs | 403 ++++++----- hir2/src/core/region.rs | 28 +- hir2/src/core/symbol_table.rs | 24 +- hir2/src/core/traits.rs | 170 ++++- hir2/src/core/traits/multitrait.rs | 144 ++-- hir2/src/core/usable.rs | 12 +- hir2/src/core/value.rs | 323 +++++---- hir2/src/core/verifier.rs | 213 ++++++ hir2/src/demangle.rs | 13 + hir2/src/derive.rs | 866 ++++++++++++++++++++++++ hir2/src/dialects.rs | 1 + hir2/src/dialects/hir.rs | 14 + hir2/src/dialects/hir/ops.rs | 9 + hir2/src/dialects/hir/ops/binary.rs | 129 ++++ hir2/src/dialects/hir/ops/cast.rs | 126 ++++ hir2/src/dialects/hir/ops/control.rs | 120 ++++ hir2/src/dialects/hir/ops/invoke.rs | 24 + hir2/src/dialects/hir/ops/mem.rs | 25 + hir2/src/dialects/hir/ops/primop.rs | 51 ++ hir2/src/dialects/hir/ops/unary.rs | 65 ++ hir2/src/formatter.rs | 17 + hir2/src/lib.rs | 28 +- hir2/src/ops/binary.rs | 60 -- hir2/src/ops/call.rs | 6 - hir2/src/ops/cast.rs | 33 - hir2/src/ops/control_flow.rs | 50 -- hir2/src/ops/global_value.rs | 35 - 
hir2/src/ops/inline_asm.rs | 0 hir2/src/ops/mem.rs | 38 -- hir2/src/ops/mod.rs | 16 - hir2/src/ops/primop.rs | 22 - hir2/src/ops/ret.rs | 10 - hir2/src/ops/structured_control_flow.rs | 56 -- hir2/src/ops/unary.rs | 37 - hir2/src/unsafe_ref.rs | 101 --- 55 files changed, 4969 insertions(+), 1421 deletions(-) create mode 100644 hir2/src/attributes.rs create mode 100644 hir2/src/attributes/call_conv.rs create mode 100644 hir2/src/attributes/linkage.rs create mode 100644 hir2/src/attributes/overflow.rs create mode 100644 hir2/src/core/dialect.rs create mode 100644 hir2/src/core/ident.rs create mode 100644 hir2/src/core/immediates.rs create mode 100644 hir2/src/core/verifier.rs create mode 100644 hir2/src/demangle.rs create mode 100644 hir2/src/derive.rs create mode 100644 hir2/src/dialects.rs create mode 100644 hir2/src/dialects/hir/ops.rs create mode 100644 hir2/src/dialects/hir/ops/binary.rs create mode 100644 hir2/src/dialects/hir/ops/cast.rs create mode 100644 hir2/src/dialects/hir/ops/control.rs create mode 100644 hir2/src/dialects/hir/ops/invoke.rs create mode 100644 hir2/src/dialects/hir/ops/mem.rs create mode 100644 hir2/src/dialects/hir/ops/primop.rs create mode 100644 hir2/src/dialects/hir/ops/unary.rs create mode 100644 hir2/src/formatter.rs delete mode 100644 hir2/src/ops/binary.rs delete mode 100644 hir2/src/ops/call.rs delete mode 100644 hir2/src/ops/cast.rs delete mode 100644 hir2/src/ops/control_flow.rs delete mode 100644 hir2/src/ops/global_value.rs delete mode 100644 hir2/src/ops/inline_asm.rs delete mode 100644 hir2/src/ops/mem.rs delete mode 100644 hir2/src/ops/mod.rs delete mode 100644 hir2/src/ops/primop.rs delete mode 100644 hir2/src/ops/ret.rs delete mode 100644 hir2/src/ops/structured_control_flow.rs delete mode 100644 hir2/src/ops/unary.rs delete mode 100644 hir2/src/unsafe_ref.rs diff --git a/Cargo.lock b/Cargo.lock index 929fd3dfa..e60c5773e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,9 +158,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "anymap2" @@ -170,9 +170,9 @@ checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" [[package]] name = "arrayref" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" @@ -392,9 +392,9 @@ dependencies = [ [[package]] name = "auth-git2" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51bd0e4592409df8631ca807716dc1e5caafae5d01ce0157c966c71c7e49c3c" +checksum = "3810b5af212b013fe7302b12d86616c6c39a48e18f2e4b812a5a9e5710213791" dependencies = [ "dirs", "git2", @@ -835,9 +835,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.18" +version = "1.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "2d74707dde2ba56f86ae90effb3b43ddd369504387e718014de010cec7959800" dependencies = [ "jobserver", "libc", @@ -2333,9 +2333,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -3037,9 +3037,9 @@ dependencies = [ [[package]] name = "miden-crypto" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6fad06fc3af260ed3c4235821daa2132813d993f96d446856036ae97e9606dd" +checksum = "8a69f8362ca496a79c88cf8e5b9b349bf9c6ed49fa867d0548e670afc1f3fca5" dependencies = [ "blake3", "cc", @@ -3145,9 +3145,9 @@ dependencies = [ [[package]] name = "miden-processor" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01e7b212b152b69373e89b069a18cb01742ef2c3f9c328e7b24c44e44f022e52" +checksum = "04a128e20400086c9a985f4e5702e438ba781338fb0bdf9acff16d996c640087" dependencies = [ "miden-air", "miden-core", @@ -3429,7 +3429,8 @@ dependencies = [ "derive_more", "downcast-rs", "either", - "indexmap 2.2.6", + "hashbrown 0.14.5", + "indexmap 2.5.0", "intrusive-collections", "inventory", "lalrpop", @@ -3449,7 +3450,7 @@ dependencies = [ "pretty_assertions", "rustc-demangle", "rustc-hash 1.1.0", - "serde 1.0.208", + "serde 1.0.210", "serde_bytes", "serde_repr", "smallvec", @@ -4188,9 +4189,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "pretty_assertions" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ "diff", "yansi", @@ -5984,9 +5985,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-truncate" @@ -7100,9 +7101,9 @@ dependencies = [ [[package]] name = "yansi" -version = "0.5.1" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "zbus" diff --git a/Cargo.toml b/Cargo.toml index ba0cbd222..460a9d1f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ cranelift-bforest = "0.108" env_logger = "0.11" either = { version = "1.10", default-features = false } expect-test = "1.4.1" +hashbrown = { version = "0.14", features = ["nightly"] } Inflector = "0.11" intrusive-collections = "0.9" inventory = "0.3" diff --git a/hir-symbol/src/symbols.toml b/hir-symbol/src/symbols.toml index bababf11f..e6b40f87a 100644 --- a/hir-symbol/src/symbols.toml +++ b/hir-symbol/src/symbols.toml @@ -88,3 +88,6 @@ treeify = { value = "treeify" } [attributes] entrypoint = {} + +[dialects] +hir = {} diff --git a/hir2/Cargo.toml b/hir2/Cargo.toml index 6563c30c4..4c352c9be 100644 --- a/hir2/Cargo.toml +++ b/hir2/Cargo.toml @@ -20,6 +20,7 @@ serde = [ "dep:serde_bytes", "midenc-hir-symbol/serde", ] +debug_refcell = [] [build-dependencies] lalrpop = { version = "0.20", default-features = false } @@ -33,6 +34,7 @@ 
blink-alloc = { version = "0.3", default-features = false, features = [ either.workspace = true cranelift-entity.workspace = true downcast-rs = { version = "1.2", default-features = false } +hashbrown.workspace = true intrusive-collections.workspace = true inventory.workspace = true lalrpop-util = "0.20" diff --git a/hir2/src/attributes.rs b/hir2/src/attributes.rs new file mode 100644 index 000000000..5a469e47a --- /dev/null +++ b/hir2/src/attributes.rs @@ -0,0 +1,5 @@ +mod call_conv; +mod linkage; +mod overflow; + +pub use self::{call_conv::CallConv, linkage::Linkage, overflow::Overflow}; diff --git a/hir2/src/attributes/call_conv.rs b/hir2/src/attributes/call_conv.rs new file mode 100644 index 000000000..283bbf29d --- /dev/null +++ b/hir2/src/attributes/call_conv.rs @@ -0,0 +1,57 @@ +use core::fmt; + +/// Represents the calling convention of a function. +/// +/// Calling conventions are part of a program's ABI (Application Binary Interface), and +/// they define things such how arguments are passed to a function, how results are returned, +/// etc. In essence, the contract between caller and callee is described by the calling convention +/// of a function. +/// +/// Importantly, it is perfectly normal to mix calling conventions. For example, the public +/// API for a C library will use whatever calling convention is used by C on the target +/// platform (for Miden, that would be `SystemV`). However, internally that library may use +/// the `Fast` calling convention to allow the compiler to optimize more effectively calls +/// from the public API to private functions. In short, choose a calling convention that is +/// well-suited for a given function, to the extent that other constraints don't impose a choice +/// on you. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) +)] +#[repr(u8)] +pub enum CallConv { + /// This calling convention is what I like to call "chef's choice" - the + /// compiler chooses it's own convention that optimizes for call performance. + /// + /// As a result of this, it is not permitted to use this convention in externally + /// linked functions, as the convention is unstable, and the compiler can't ensure + /// that the caller in another translation unit will use the correct convention. + Fast, + /// The standard calling convention used for C on most platforms + #[default] + SystemV, + /// A function which is using the WebAssembly Component Model "Canonical ABI". + Wasm, + /// A function with this calling convention must be called using + /// the `syscall` instruction. Attempts to call it with any other + /// call instruction will cause a validation error. The one exception + /// to this rule is when calling another function with the `Kernel` + /// convention that is defined in the same module, which can use the + /// standard `call` instruction. + /// + /// Kernel functions may only be defined in a kernel [Module]. 
+ /// + /// In all other respects, this calling convention is the same as `SystemV` + Kernel, +} +impl fmt::Display for CallConv { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Fast => f.write_str("fast"), + Self::SystemV => f.write_str("C"), + Self::Wasm => f.write_str("wasm"), + Self::Kernel => f.write_str("kernel"), + } + } +} diff --git a/hir2/src/attributes/linkage.rs b/hir2/src/attributes/linkage.rs new file mode 100644 index 000000000..d1d3c1277 --- /dev/null +++ b/hir2/src/attributes/linkage.rs @@ -0,0 +1,66 @@ +use core::fmt; + +/// The policy to apply to a global variable (or function) when linking +/// together a program during code generation. +/// +/// Miden doesn't (currently) have a notion of a symbol table for things like global variables. +/// At runtime, there are not actually symbols at all in any familiar sense, instead functions, +/// being the only entities with a formal identity in MASM, are either inlined at all their call +/// sites, or are referenced by the hash of their MAST root, to be unhashed at runtime if the call +/// is executed. +/// +/// Because of this, and because we cannot perform linking ourselves (we must emit separate modules, +/// and leave it up to the VM to link them into the MAST), there are limits to what we can do in +/// terms of linking function symbols. We essentially just validate that given a set of modules in +/// a [Program], that there are no invalid references across modules to symbols which either don't +/// exist, or which exist, but have internal linkage. +/// +/// However, with global variables, we have a bit more freedom, as it is a concept that we are +/// completely inventing from whole cloth without explicit support from the VM or Miden Assembly. +/// In short, when we compile a [Program] to MASM, we first gather together all of the global +/// variables into a program-wide table, merging and garbage collecting as appropriate, and updating +/// all references to them in each module. This global variable table is then assumed to be laid out +/// in memory starting at the base of the linear memory address space in the same order, with +/// appropriate padding to ensure accesses are aligned. Then, when emitting MASM instructions which +/// reference global values, we use the layout information to derive the address where that global +/// value is allocated. +/// +/// This has some downsides however, the biggest of which is that we can't prevent someone from +/// loading modules generated from a [Program] with either their own hand-written modules, or +/// even with modules from another [Program]. In such cases, assumptions about the allocation of +/// linear memory from different sets of modules will almost certainly lead to undefined behavior. +/// In the future, we hope to have a better solution to this problem, preferably one involving +/// native support from the Miden VM itself. For now though, we're working with what we've got. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Deserialize_repr, serde_repr::Serialize_repr) +)] +#[repr(u8)] +pub enum Linkage { + /// This symbol is only visible in the containing module. + /// + /// Internal symbols may be renamed to avoid collisions + /// + /// Unreferenced internal symbols can be discarded at link time. + Internal, + /// This symbol will be linked using the "one definition rule", i.e. symbols with + /// the same name, type, and linkage will be merged into a single definition. 
+ /// + /// Unlike `internal` linkage, unreferenced `odr` symbols cannot be discarded. + /// + /// NOTE: `odr` symbols cannot satisfy external symbol references + Odr, + /// This symbol is visible externally, and can be used to resolve external symbol references. + #[default] + External, +} +impl fmt::Display for Linkage { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Internal => f.write_str("internal"), + Self::Odr => f.write_str("odr"), + Self::External => f.write_str("external"), + } + } +} diff --git a/hir2/src/attributes/overflow.rs b/hir2/src/attributes/overflow.rs new file mode 100644 index 000000000..b7073e479 --- /dev/null +++ b/hir2/src/attributes/overflow.rs @@ -0,0 +1,60 @@ +use core::fmt; + +use crate::define_attr_type; + +/// This enumeration represents the various ways in which arithmetic operations +/// can be configured to behave when either the operands or results over/underflow +/// the range of the integral type. +/// +/// Always check the documentation of the specific instruction involved to see if there +/// are any specific differences in how this enum is interpreted compared to the default +/// meaning of each variant. +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)] +pub enum Overflow { + /// Typically, this means the operation is performed using the equivalent field element + /// operation, rather than a dedicated operation for the given type. Because of this, the + /// result of the operation may exceed that of the integral type expected, but this will + /// not be caught right away. + /// + /// It is the callers responsibility to ensure that resulting value is in range. + #[default] + Unchecked, + /// The operation will trap if the operands, or the result, is not valid for the range of the + /// integral type involved, e.g. u32. + Checked, + /// The operation will wrap around, depending on the range of the integral type. For example, + /// given a u32 value, this is done by applying `mod 2^32` to the result. + Wrapping, + /// The result of the operation will be computed as in [Wrapping], however in addition to the + /// result, this variant also pushes a value on the stack which represents whether or not the + /// operation over/underflowed; either 1 if over/underflow occurred, or 0 otherwise. 
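These variants correspond closely to Rust's own integer arithmetic helpers, which makes the intended semantics easy to demonstrate. The snippet below is plain Rust (not IR code); `Unchecked` has no direct analogue here, since it corresponds to performing the operation on field elements and leaving range checks to the caller.

```rust
fn main() {
    let x = u32::MAX;
    // Checked: the operation fails (the IR op would trap) when the result
    // does not fit in the integral type.
    assert_eq!(x.checked_add(1), None);
    // Wrapping: the result is reduced mod 2^32.
    assert_eq!(x.wrapping_add(1), 0);
    // Overflowing: the wrapped result plus a flag signalling that overflow occurred.
    assert_eq!(x.overflowing_add(1), (0, true));
}
```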
+ Overflowing, +} +impl Overflow { + /// Returns true if overflow is unchecked + pub fn is_unchecked(&self) -> bool { + matches!(self, Self::Unchecked) + } + + /// Returns true if overflow will cause a trap + pub fn is_checked(&self) -> bool { + matches!(self, Self::Checked) + } + + /// Returns true if overflow will add an extra boolean on top of the stack + pub fn is_overflowing(&self) -> bool { + matches!(self, Self::Overflowing) + } +} +impl fmt::Display for Overflow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Unchecked => f.write_str("unchecked"), + Self::Checked => f.write_str("checked"), + Self::Wrapping => f.write_str("wrapping"), + Self::Overflowing => f.write_str("overflow"), + } + } +} + +define_attr_type!(Overflow); diff --git a/hir2/src/core.rs b/hir2/src/core.rs index cfc7084e3..b869bbfd4 100644 --- a/hir2/src/core.rs +++ b/hir2/src/core.rs @@ -2,32 +2,53 @@ mod attribute; mod block; mod component; mod context; +mod dialect; mod entity; mod function; +mod ident; +mod immediates; mod interface; mod module; mod op; mod operation; mod region; mod symbol_table; -mod traits; +pub mod traits; mod types; mod usable; mod value; +pub(crate) mod verifier; + +pub use midenc_hir_symbol as interner; +pub use midenc_session::diagnostics::{Report, SourceSpan, Spanned}; pub use self::{ - block::{Block, BlockCursor, BlockCursorMut, BlockList, BlockOperand}, - entity::{ - Entity, EntityCursor, EntityCursorMut, EntityHandle, EntityId, EntityIter, EntityList, - EntityMut, EntityRef, TrackedEntityHandle, + attribute::{attributes::*, Attribute, AttributeSet, AttributeValue}, + block::{ + Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, + BlockRef, }, - function::{ - AbiParam, ArgumentExtension, ArgumentPurpose, CallConv, Function, FunctionIdent, Signature, + context::Context, + dialect::{Dialect, DialectName}, + entity::{ + Entity, EntityCursor, EntityCursorMut, EntityId, EntityIter, EntityList, EntityMut, + EntityRef, RawEntityRef, UnsafeEntityRef, UnsafeIntrusiveEntityRef, }, + function::{AbiParam, ArgumentExtension, ArgumentPurpose, Function, Signature}, + ident::{FunctionIdent, Ident}, + immediates::{Felt, FieldElement, Immediate, StarkField}, module::Module, - region::{Region, RegionCursor, RegionCursorMut, RegionList}, + op::{Op, OpExt}, + operation::{ + OpCursor, OpCursorMut, OpList, OpSuccessor, Operation, OperationBuilder, OperationName, + OperationRef, + }, + region::{Region, RegionCursor, RegionCursorMut, RegionList, RegionRef}, symbol_table::{Symbol, SymbolTable}, types::*, usable::Usable, - value::{BlockArgument, OpOperand, OpResult, Value, ValueKind}, + value::{ + BlockArgument, BlockArgumentRef, OpOperand, OpResult, OpResultRef, Value, ValueId, ValueRef, + }, + verifier::{OpVerifier, Verify}, }; diff --git a/hir2/src/core/attribute.rs b/hir2/src/core/attribute.rs index 05d62b820..d39a47d95 100644 --- a/hir2/src/core/attribute.rs +++ b/hir2/src/core/attribute.rs @@ -1,7 +1,7 @@ use alloc::collections::BTreeMap; -use core::{borrow::Borrow, fmt}; +use core::{any::Any, borrow::Borrow, fmt}; -use midenc_hir_symbol::Symbol; +use super::interner::Symbol; pub mod attributes { use midenc_hir_symbol::symbols; @@ -12,13 +12,13 @@ pub mod attributes { /// for its containing program, regardless of what module it is defined in. 
pub const ENTRYPOINT: Attribute = Attribute { name: symbols::Entrypoint, - value: AttributeValue::Unit, + value: None, }; } /// An [AttributeSet] is a uniqued collection of attributes associated with some IR entity -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] -pub struct AttributeSet(BTreeMap); +#[derive(Debug, Default)] +pub struct AttributeSet(Vec); impl FromIterator for AttributeSet { fn from_iter(attrs: T) -> Self where @@ -28,19 +28,19 @@ impl FromIterator for AttributeSet { for attr in attrs.into_iter() { map.insert(attr.name, attr.value); } - Self(map) + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) } } -impl FromIterator<(Symbol, AttributeValue)> for AttributeSet { +impl FromIterator<(Symbol, Option>)> for AttributeSet { fn from_iter(attrs: T) -> Self where - T: IntoIterator, + T: IntoIterator>)>, { let mut map = BTreeMap::default(); for (name, value) in attrs.into_iter() { map.insert(name, value); } - Self(map) + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) } } impl AttributeSet { @@ -50,13 +50,37 @@ impl AttributeSet { } /// Insert a new [Attribute] in this set by `name` and `value` - pub fn insert(&mut self, name: impl Into, value: impl Into) { - self.0.insert(name.into(), value.into()); + pub fn insert(&mut self, name: impl Into, value: Option) { + let name = name.into(); + match self.0.binary_search_by_key(&name, |attr| attr.name) { + Ok(index) => { + self.0[index].value = value.map(|v| Box::new(v) as Box); + } + Err(index) => { + let value = value.map(|v| Box::new(v) as Box); + if index == self.0.len() { + self.0.push(Attribute { name, value }); + } else { + self.0.insert(index, Attribute { name, value }); + } + } + } } /// Adds `attr` to this set pub fn set(&mut self, attr: Attribute) { - self.0.insert(attr.name, attr.value); + match self.0.binary_search_by_key(&attr.name, |attr| attr.name) { + Ok(index) => { + self.0[index].value = attr.value; + } + Err(index) => { + if index == self.0.len() { + self.0.push(attr); + } else { + self.0.insert(index, attr); + } + } + } } /// Remove an [Attribute] by name from this set @@ -65,7 +89,16 @@ impl AttributeSet { Symbol: Borrow, Q: Ord + ?Sized, { - self.0.remove(name); + let name = name.borrow(); + match self.0.binary_search_by(|attr| name.cmp(attr.name.borrow()).reverse()) { + Ok(index) if index + 1 == self.0.len() => { + self.0.pop(); + } + Ok(index) => { + self.0.remove(index); + } + Err(_) => (), + } } /// Determine if the named [Attribute] is present in this set @@ -74,51 +107,36 @@ impl AttributeSet { Symbol: Borrow, Q: Ord + ?Sized, { - self.0.contains_key(key) + let key = key.borrow(); + self.0.binary_search_by(|attr| key.cmp(attr.name.borrow()).reverse()).is_ok() } /// Get the [AttributeValue] associated with the named [Attribute] - pub fn get(&self, key: &Q) -> Option<&AttributeValue> - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - self.0.get(key) - } - - /// Get the value associated with the named [Attribute] as a boolean, or `None`. - pub fn get_bool(&self, key: &Q) -> Option - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - self.0.get(key).and_then(|v| v.as_bool()) - } - - /// Get the value associated with the named [Attribute] as an integer, or `None`. 
- pub fn get_int(&self, key: &Q) -> Option + pub fn get_any(&self, key: &Q) -> Option<&dyn AttributeValue> where Symbol: Borrow, Q: Ord + ?Sized, { - self.0.get(key).and_then(|v| v.as_int()) + let key = key.borrow(); + match self.0.binary_search_by(|attr| key.cmp(attr.name.borrow())) { + Ok(index) => self.0[index].value.as_deref(), + Err(_) => None, + } } - /// Get the value associated with the named [Attribute] as a [Symbol], or `None`. - pub fn get_symbol(&self, key: &Q) -> Option + /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`. + pub fn get(&self, key: &Q) -> Option<&V> where Symbol: Borrow, Q: Ord + ?Sized, + V: AttributeValue, { - self.0.get(key).and_then(|v| v.as_symbol()) + self.get_any(key).and_then(|v| v.downcast_ref::()) } /// Iterate over each [Attribute] in this set - pub fn iter(&self) -> impl Iterator + '_ { - self.0.iter().map(|(k, v)| Attribute { - name: *k, - value: *v, - }) + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() } } @@ -128,146 +146,82 @@ impl AttributeSet { /// but which is not part of the code itself. For example, `cfg` flags in Rust /// are an example of something which you could represent using an [Attribute]. /// They can also be used to store documentation, source locations, and more. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug)] pub struct Attribute { /// The name of this attribute pub name: Symbol, /// The value associated with this attribute - pub value: AttributeValue, + pub value: Option>, } impl Attribute { - pub fn new(name: impl Into, value: impl Into) -> Self { + pub fn new(name: impl Into, value: Option) -> Self { Self { name: name.into(), - value: value.into(), - } - } -} -impl fmt::Display for Attribute { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.value { - AttributeValue::Unit => write!(f, "#[{}]", self.name.as_str()), - value => write!(f, "#[{}({value})]", &self.name), - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum AttributeValue { - /// No concrete value (i.e. 
presence of the attribute is significant) - Unit, - /// A boolean value - Bool(bool), - /// A signed integer - Int(isize), - /// An interned string - String(Symbol), -} -impl AttributeValue { - pub fn as_bool(&self) -> Option { - match self { - Self::Bool(value) => Some(*value), - _ => None, + value: value.map(|v| Box::new(v) as Box), } } - pub fn as_int(&self) -> Option { - match self { - Self::Int(value) => Some(*value), - _ => None, - } + pub fn value(&self) -> Option<&dyn AttributeValue> { + self.value.as_deref() } - pub fn as_symbol(&self) -> Option { - match self { - Self::String(value) => Some(*value), - _ => None, + pub fn value_as(&self) -> Option<&V> + where + V: AttributeValue, + { + match self.value.as_deref() { + Some(value) => value.downcast_ref::(), + None => None, } } } -impl fmt::Display for AttributeValue { +impl fmt::Display for Attribute { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Unit => f.write_str("()"), - Self::Bool(value) => write!(f, "{value}"), - Self::Int(value) => write!(f, "{value}"), - Self::String(value) => write!(f, "\"{}\"", value.as_str().escape_default()), + match self.value.as_deref() { + None => write!(f, "#[{}]", self.name.as_str()), + Some(value) => write!(f, "#[{}({value})]", &self.name), } } } -impl From<()> for AttributeValue { - fn from(_: ()) -> Self { - Self::Unit - } -} -impl From for AttributeValue { - fn from(value: bool) -> Self { - Self::Bool(value) - } -} -impl From for AttributeValue { - fn from(value: isize) -> Self { - Self::Int(value) - } -} -impl From<&str> for AttributeValue { - fn from(value: &str) -> Self { - Self::String(Symbol::intern(value)) - } -} -impl From for AttributeValue { - fn from(value: String) -> Self { - Self::String(Symbol::intern(value.as_str())) - } -} -impl From for AttributeValue { - fn from(value: u8) -> Self { - Self::Int(value as isize) - } -} -impl From for AttributeValue { - fn from(value: i8) -> Self { - Self::Int(value as isize) - } -} -impl From for AttributeValue { - fn from(value: u16) -> Self { - Self::Int(value as isize) - } -} -impl From for AttributeValue { - fn from(value: i16) -> Self { - Self::Int(value as isize) - } -} -impl From for AttributeValue { - fn from(value: u32) -> Self { - Self::Int(value as isize) - } + +pub trait AttributeValue: Any + fmt::Debug + fmt::Display + 'static { + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; } -impl From for AttributeValue { - fn from(value: i32) -> Self { - Self::Int(value as isize) + +impl dyn AttributeValue { + pub fn is(&self) -> bool { + self.as_any().is::() } -} -impl TryFrom for AttributeValue { - type Error = core::num::TryFromIntError; - fn try_from(value: usize) -> Result { - Ok(Self::Int(value.try_into()?)) + pub fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() } -} -impl TryFrom for AttributeValue { - type Error = core::num::TryFromIntError; - fn try_from(value: u64) -> Result { - Ok(Self::Int(value.try_into()?)) + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.as_any_mut().downcast_mut::() } } -impl TryFrom for AttributeValue { - type Error = core::num::TryFromIntError; - fn try_from(value: i64) -> Result { - Ok(Self::Int(value.try_into()?)) - } +#[macro_export] +macro_rules! 
define_attr_type { + ($T:ty) => { + impl $crate::AttributeValue for $T { + #[inline(always)] + fn as_any(&self) -> &dyn core::any::Any { + self as &dyn core::any::Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn core::any::Any { + self as &mut dyn core::any::Any + } + } + }; } + +define_attr_type!(bool); +define_attr_type!(isize); +define_attr_type!(Symbol); +define_attr_type!(super::Immediate); +define_attr_type!(super::Type); diff --git a/hir2/src/core/block.rs b/hir2/src/core/block.rs index 5c6422008..abc0ca9fe 100644 --- a/hir2/src/core/block.rs +++ b/hir2/src/core/block.rs @@ -1,24 +1,74 @@ -use super::{ - BlockArgument, EntityCursor, EntityCursorMut, EntityHandle, EntityIter, EntityList, OpList, - Operation, Usable, -}; +use core::fmt; +use super::*; + +pub type BlockRef = UnsafeIntrusiveEntityRef; /// An intrusive, doubly-linked list of [Block] pub type BlockList = EntityList; pub type BlockCursor<'a> = EntityCursor<'a, Block>; pub type BlockCursorMut<'a> = EntityCursorMut<'a, Block>; +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct BlockId(u32); +impl BlockId { + pub const fn from_u32(id: u32) -> Self { + Self(id) + } + + pub const fn as_u32(&self) -> u32 { + self.0 + } +} +impl EntityId for BlockId { + #[inline(always)] + fn as_usize(&self) -> usize { + self.0 as usize + } +} +impl fmt::Debug for BlockId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "block{}", &self.0) + } +} +impl fmt::Display for BlockId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "block{}", &self.0) + } +} + pub struct Block { + /// The unique id of this block + id: BlockId, /// The set of uses of this block uses: BlockOperandList, /// The region this block is attached to. 
/// - /// If `link.is_linked() == true`, this will always be set to a valid pointer - region: Option>, + /// This will always be set if this block is attached to a region + region: Option, /// The list of [Operation]s that comprise this block - ops: OpList, + body: OpList, /// The parameter list for this block - arguments: Vec, + arguments: Vec, +} +impl fmt::Debug for Block { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Block") + .field("id", &self.id) + .field_with("region", |f| match self.region.as_ref() { + None => f.write_str("None"), + Some(r) => write!(f, "Some({r:p})"), + }) + .field("arguments", &self.arguments) + .finish_non_exhaustive() + } +} +impl Entity for Block { + type Id = BlockId; + + fn id(&self) -> Self::Id { + self.id + } } impl Usable for Block { type Use = BlockOperand; @@ -27,7 +77,17 @@ impl Usable for Block { !self.uses.is_empty() } - fn uses(&self) -> BlockOperandIter<'_> { + #[inline(always)] + fn uses(&self) -> &BlockOperandList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut BlockOperandList { + &mut self.uses + } + + fn iter_uses(&self) -> BlockOperandIter<'_> { self.uses.iter() } @@ -38,19 +98,83 @@ impl Usable for Block { fn first_use_mut(&mut self) -> BlockOperandCursorMut<'_> { self.uses.front_mut() } + + fn insert_use(&mut self, user: BlockOperandRef) { + self.uses.push_back(user); + } } impl Block { + pub fn new(id: BlockId) -> Self { + Self { + id, + uses: Default::default(), + region: None, + body: Default::default(), + arguments: Default::default(), + } + } + + #[inline] + pub fn has_arguments(&self) -> bool { + !self.arguments.is_empty() + } + + #[inline] + pub fn num_arguments(&self) -> usize { + self.arguments.len() + } + + #[inline(always)] + pub fn arguments(&self) -> &[BlockArgumentRef] { + self.arguments.as_slice() + } + + #[inline(always)] + pub fn arguments_mut(&mut self) -> &mut Vec { + &mut self.arguments + } + + #[inline] + pub fn get_argument(&self, index: usize) -> BlockArgumentRef { + self.arguments[index].clone() + } + + /// Get a handle to the containing [Region] of this block, if it is attached to one + pub fn parent(&self) -> Option { + self.region.clone() + } + + /// Get a handle to the containing [Operation] of this block, if it is attached to one + pub fn parent_op(&self) -> Option { + self.region.as_ref().and_then(|region| region.borrow().parent()) + } + + /// Get the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body(&self) -> &OpList { + &self.body + } + + /// Get a mutable reference to the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body_mut(&mut self) -> &mut OpList { + &mut self.body + } + + /// Returns true if this block has predecessors #[inline(always)] pub fn has_predecessors(&self) -> bool { self.is_used() } + /// Get an iterator over the predecessors of this block #[inline(always)] pub fn predecessors(&self) -> BlockOperandIter<'_> { - self.uses() + self.iter_uses() } } +pub type BlockOperandRef = UnsafeIntrusiveEntityRef; /// An intrusive, doubly-linked list of [BlockOperand] pub type BlockOperandList = EntityList; pub type BlockOperandCursor<'a> = EntityCursor<'a, BlockOperand>; @@ -60,19 +184,32 @@ pub type BlockOperandIter<'a> = EntityIter<'a, BlockOperand>; /// A [BlockOperand] represents a use of a [Block] by an [Operation] pub struct BlockOperand { /// The block value - pub block: Block, + pub block: BlockRef, /// The owner of this operand, i.e. 
the operation it is an operand of - pub owner: EntityHandle, + pub owner: OperationRef, /// The index of this operand in the set of block operands of the operation pub index: u8, } impl BlockOperand { #[inline] - pub fn new(block: Block, owner: EntityHandle, index: u8) -> Self { + pub fn new(block: BlockRef, owner: OperationRef, index: u8) -> Self { Self { block, owner, index, } } + + pub fn block_id(&self) -> BlockId { + self.block.borrow().id + } +} +impl fmt::Debug for BlockOperand { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BlockOperand") + .field("block", &self.block.borrow().id()) + .field_with("owner", |f| write!(f, "{:p}", &self.owner)) + .field("index", &self.index) + .finish() + } } diff --git a/hir2/src/core/context.rs b/hir2/src/core/context.rs index 52452f208..181cc4f24 100644 --- a/hir2/src/core/context.rs +++ b/hir2/src/core/context.rs @@ -1,45 +1,90 @@ -use core::{ - cell::{Cell, UnsafeCell}, - fmt, - mem::MaybeUninit, - ptr::NonNull, -}; +use alloc::rc::Rc; +use core::{cell::Cell, mem::MaybeUninit}; use blink_alloc::Blink; -use cranelift_entity::{PrimaryMap, SecondaryMap}; +use midenc_session::Session; -use super::{ - entity::{EntityObj, TrackedEntityObj}, - *, -}; -use crate::UnsafeRef; +use super::*; pub struct Context { - pub allocator: Blink, - pub blocks: PrimaryMap, - pub values: PrimaryMap, - pub constants: ConstantPool, + pub session: Rc, + allocator: Rc, + next_block_id: Cell, + next_value_id: Cell, + //pub constants: ConstantPool, +} + +impl Default for Context { + fn default() -> Self { + use alloc::sync::Arc; + + use midenc_session::diagnostics::DefaultSourceManager; + + let target_dir = std::env::current_dir().unwrap(); + let options = midenc_session::Options::default(); + let source_manager = Arc::new(DefaultSourceManager::default()); + let session = + Rc::new(Session::new([], None, None, target_dir, options, None, source_manager)); + Self::new(session) + } } impl Context { - pub fn new() -> Self { - let allocator = Blink::new(); + pub fn new(session: Rc) -> Self { + let allocator = Rc::new(Blink::new()); Self { + session, allocator, - blocks: PrimaryMap::new(), - values: PrimaryMap::new(), - constants: Default::default(), + next_block_id: Cell::new(0), + next_value_id: Cell::new(0), + //constants: Default::default(), } } + /// Create a new, detached and empty [Block] with no parameters + pub fn create_block(&self) -> BlockRef { + let block = Block::new(self.alloc_block_id()); + self.alloc_tracked(block) + } + + /// Create a new, detached and empty [Block], with parameters corresponding to the given types + pub fn create_block_with_params(&self, tys: I) -> BlockRef + where + I: IntoIterator, + { + let block = Block::new(self.alloc_block_id()); + let mut block = self.alloc_tracked(block); + let owner = block.clone(); + let args = tys.into_iter().enumerate().map(|(index, ty)| { + let id = self.alloc_value_id(); + let arg = BlockArgument::new( + id, + ty, + owner.clone(), + index.try_into().expect("too many block arguments"), + ); + self.alloc(arg) + }); + block.borrow_mut().arguments_mut().extend(args); + block + } + + /// Create a new [OpResult] with the given type, owner, and index + /// + /// NOTE: This does not attach the result to the operation, it is expected that the caller will + /// do so. 
+ pub fn make_result(&self, ty: Type, owner: OperationRef, index: u8) -> OpResultRef { + let id = self.alloc_value_id(); + self.alloc(OpResult::new(id, ty, owner, index)) + } + /// Allocate a new uninitialized entity of type `T` /// /// In general, you can probably prefer [Context::alloc] instead, but for use cases where you /// need to allocate the space for `T` first, and then perform initialization, this can be /// used. - pub fn alloc_uninit(&self) -> EntityHandle> { - let entity = self.allocator.uninit::>(); - unsafe { EntityHandle::new_uninit(NonNull::new_unchecked(entity)) } + pub fn alloc_uninit(&self) -> UnsafeEntityRef> { + UnsafeEntityRef::new_uninit(&self.allocator) } /// Allocate a new uninitialized entity of type `T`, which needs to be tracked in an intrusive @@ -48,9 +93,8 @@ impl Context { /// In general, you can probably prefer [Context::alloc_tracked] instead, but for use cases /// where you need to allocate the space for `T` first, and then perform initialization, /// this can be used. - pub fn alloc_uninit_tracked(&self) -> TrackedEntityHandle> { - let entity = self.allocator.uninit::>(); - unsafe { TrackedEntityHandle::new_uninit(NonNull::new_unchecked(entity)) } + pub fn alloc_uninit_tracked(&self) -> UnsafeIntrusiveEntityRef> { + UnsafeIntrusiveEntityRef::new_uninit(&self.allocator) } /// Allocate a new `EntityHandle`. @@ -58,9 +102,8 @@ impl Context { /// [EntityHandle] is a smart-pointer type for IR entities, which behaves like a ref-counted /// pointer with dynamically-checked borrow checking rules. It is designed to play well with /// entities allocated from a [Context], and with the somewhat cyclical nature of the IR. - pub fn alloc(&self, value: T) -> EntityHandle { - let entity = self.allocator.put(EntityObj::new(value)); - unsafe { EntityHandle::new(NonNull::new_unchecked(entity)) } + pub fn alloc(&self, value: T) -> UnsafeEntityRef { + UnsafeEntityRef::new(value, &self.allocator) } /// Allocate a new `TrackedEntityHandle`. @@ -69,26 +112,19 @@ impl Context { /// entities which are meant to be tracked in intrusive linked lists. For example, the blocks /// in a region, or the ops in a block. It does this without requiring the entity to know about /// the link at all, while still making it possible to access the link from the entity. 
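+    ///
+    /// A rough sketch of the intended usage: `create_block` allocates via this method, and the
+    /// resulting handle can then be linked into an intrusive list without borrowing the block
+    /// itself (the surrounding setup here is hypothetical):
+    ///
+    /// ```rust,ignore
+    /// let context = Context::default();
+    /// let block = context.create_block();
+    /// let mut blocks = BlockList::default();
+    /// blocks.push_back(block.clone());
+    /// assert!(block.is_linked());
+    /// ```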
- pub fn alloc_tracked(&self, value: T) -> TrackedEntityHandle { - let entity = self.allocator.put(TrackedEntityObj::new(value)); - unsafe { TrackedEntityHandle::new(NonNull::new_unchecked(entity)) } + pub fn alloc_tracked(&self, value: T) -> UnsafeIntrusiveEntityRef { + UnsafeIntrusiveEntityRef::new(value, &self.allocator) } - pub fn create_op(&mut self, mut op: T) -> OpId { - let key = self.ops.next_key(); - let op = self.allocator.put(op); - let ptr = op as *mut T; - { - let operation = op.as_operation_mut(); - operation.key = key; - operation.vtable.set_data_ptr(ptr); - } - let op = unsafe { NonNull::new_unchecked(op) }; - self.ops.push(op.cast()); - key + fn alloc_block_id(&self) -> BlockId { + let id = self.next_block_id.get(); + self.next_block_id.set(id + 1); + BlockId::from_u32(id) } - pub fn op(&self, id: OpId) -> &dyn Op { - self.ops[id].as_ref() + fn alloc_value_id(&self) -> ValueId { + let id = self.next_value_id.get(); + self.next_value_id.set(id + 1); + ValueId::from_u32(id) } } diff --git a/hir2/src/core/dialect.rs b/hir2/src/core/dialect.rs new file mode 100644 index 000000000..daffc8dc1 --- /dev/null +++ b/hir2/src/core/dialect.rs @@ -0,0 +1,61 @@ +use core::ops::Deref; + +pub trait Dialect { + const INIT: Self; + + fn name(&self) -> DialectName; +} + +/// A strongly-typed symbol representing the name of a [Dialect]. +/// +/// Dialect names should be in lowercase ASCII format, though this is not enforced. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct DialectName(::midenc_hir_symbol::Symbol); +impl DialectName { + pub fn new(name: S) -> Self + where + S: Into<::midenc_hir_symbol::Symbol>, + { + Self(name.into()) + } + + pub const fn from_symbol(name: ::midenc_hir_symbol::Symbol) -> Self { + Self(name) + } +} +impl core::fmt::Debug for DialectName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.0.as_str()) + } +} +impl core::fmt::Display for DialectName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.0.as_str()) + } +} +impl From<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn from(value: ::midenc_hir_symbol::Symbol) -> Self { + Self(value) + } +} +impl From for ::midenc_hir_symbol::Symbol { + #[inline(always)] + fn from(value: DialectName) -> Self { + value.0 + } +} +impl Deref for DialectName { + type Target = ::midenc_hir_symbol::Symbol; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl AsRef<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn as_ref(&self) -> &::midenc_hir_symbol::Symbol { + &self.0 + } +} diff --git a/hir2/src/core/entity.rs b/hir2/src/core/entity.rs index e731a5444..85a8c2f47 100644 --- a/hir2/src/core/entity.rs +++ b/hir2/src/core/entity.rs @@ -1,24 +1,26 @@ mod list; +use alloc::alloc::{AllocError, Layout}; use core::{ + any::Any, cell::{Cell, UnsafeCell}, fmt, + hash::Hash, mem::MaybeUninit, + ops::{Deref, DerefMut}, ptr::NonNull, }; pub use self::list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}; -pub trait Entity { +pub trait Entity: Any { type Id: EntityId; - fn id(&self) -> Self::Key; - unsafe fn set_id(&self, id: Self::Key); + fn id(&self) -> Self::Id; } pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash { fn as_usize(&self) -> usize; - unsafe fn from_usize(raw: usize) -> Self; } /// An error raised when an aliasing violation is detected in the use of [EntityHandle] @@ -74,28 +76,48 @@ impl fmt::Display for 
AliasingViolationError { } } -/// An [EntityHandle] is a smart-pointer type for IR entities allocated in a [Context]. +pub type UnsafeEntityRef = RawEntityRef; + +pub type UnsafeIntrusiveEntityRef = RawEntityRef; + +/// A [RawEntityRef] is an unsafe smart pointer type for IR entities allocated in a [Context]. +/// +/// Along with the type of entity referenced, it can be instantiated with extra metadata of any +/// type. For example, [UnsafeIntrusiveEntityRef] stores an intrusive link in the entity metadata, +/// so that the entity can be added to an intrusive linked list without the entity needing to +/// know about the link - and without violating aliasing rules when navigating the list. /// /// Unlike regular references, no reference to the underlying `T` is constructed until one is /// needed, at which point the borrow (whether mutable or immutable) is dynamically checked to /// ensure that it is valid according to Rust's aliasing rules. /// -/// As a result, an [EntityHandle] is not considered an alias, and it is possible to acquire a +/// As a result, a [RawEntityRef] is not considered an alias, and it is possible to acquire a /// mutable reference to the underlying data even while other copies of the handle exist. Any /// attempt to construct invalid aliases (immutable reference while a mutable reference exists, or /// vice versa), will result in a runtime panic. /// /// This is a tradeoff, as we do not get compile-time guarantees that such panics will not occur, /// but in exchange we get a much more flexible and powerful IR structure. -pub struct EntityHandle { - inner: NonNull>, +/// +/// # SAFETY +/// +/// Unlike most smart-pointer types, e.g. `Rc`, [RAwEntityRef] does not provide any protection +/// against the underlying allocation being deallocated (i.e. the arena it points into is dropped). +/// This is by design, as the type is meant to be stored in objects inside the arena, and +/// _not_ dropped when the arena is dropped. This requires care when using it however, to ensure +/// that no [RawEntityRef] lives longer than the arena that allocated it. +/// +/// For a safe entity reference, see [EntityRef], which binds a [RawEntityRef] to the lifetime +/// of the arena. +pub struct RawEntityRef { + inner: NonNull>, } -impl Clone for EntityHandle { +impl Clone for RawEntityRef { fn clone(&self) -> Self { Self { inner: self.inner } } } -impl EntityHandle { +impl RawEntityRef { /// Create a new [EntityHandle] from a raw pointer to the underlying [EntityObj]. /// /// # SAFETY @@ -111,118 +133,97 @@ impl EntityHandle { /// /// You should generally not be using this API, as it is meant solely for constructing an /// [EntityHandle] immediately after allocating the underlying [EntityObj]. - pub(crate) unsafe fn new(ptr: NonNull>) -> Self { + #[inline] + unsafe fn from_inner(inner: NonNull>) -> Self { Self { inner } } - /// Get a dynamically-checked immutable reference to the underlying `T` - pub fn get(&self) -> EntityRef<'_, T> { - unsafe { - let obj = self.inner.as_ref(); - obj.borrow() - } + #[inline] + unsafe fn from_ptr(ptr: *mut RawEntityMetadata) -> Self { + debug_assert!(!ptr.is_null()); + Self::from_inner(NonNull::new_unchecked(ptr)) } - /// Get a dynamically-checked mutable reference to the underlying `T` - pub fn get_mut(&mut self) -> EntityMut<'_, T> { - unsafe { - let obj = self.inner.as_ref(); - obj.borrow_mut() - } + #[inline] + fn into_inner(this: Self) -> NonNull> { + this.inner } +} - /// Convert this handle into a raw pointer to the underlying entity. 
- /// - /// This should only be used in situations where the returned pointer will not be used to - /// actually access the underlying entity. Use [get] or [get_mut] for that. [EntityHandle] - /// ensures that Rust's aliasing rules are not violated when using it, but if you use the - /// returned pointer to do so, no such guarantee is provided, and undefined behavior can - /// result. +impl RawEntityRef { + /// Create a new [RawEntityRef] by allocating `value` with `metadata` in the given arena + /// allocator. /// /// # SAFETY /// - /// The returned pointer _must_ not be used to create a reference to the underlying entity - /// unless you can guarantee that such a reference does not violate Rust's aliasing rules. - /// - /// Do not use the pointer to create a mutable reference if other references exist, and do - /// not use the pointer to create an immutable reference if a mutable reference exists or - /// might be created while the immutable reference lives. - pub fn into_raw(self) -> NonNull { - unsafe { NonNull::new_unchecked(self.inner.as_ref().as_ptr()) } + /// The resulting [RawEntityRef] must not outlive the arena. This is not enforced statically, + /// it is up to the caller to uphold the invariants of this type. + pub fn new_with_metadata(value: T, metadata: Metadata, arena: &blink_alloc::Blink) -> Self { + unsafe { + Self::from_inner(NonNull::new_unchecked( + arena.put(RawEntityMetadata::new(value, metadata)), + )) + } } -} -impl EntityHandle> { - /// Create an [EntityHandle] for an entity which may not be fully initialized. + /// Create a [RawEntityRef] for an entity which may not be fully initialized, using the provided + /// arena. /// /// # SAFETY /// - /// The safety rules are much the same as [EntityHandle::new], with the main difference + /// The safety rules are much the same as [RawEntityRef::new], with the main difference /// being that the `T` does not have to be initialized yet. No references to the `T` will - /// be created directly until [EntityHandle::assume_init] is called. - pub(crate) unsafe fn new_uninit(ptr: NonNull>>) -> Self { - Self { inner: ptr } - } - - /// Converts to `EntityHandle`. - /// - /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the - /// value really is in an initialized state. Calling this when the content is not yet fully - /// initialized causes immediate undefined behavior. - pub unsafe fn assume_init(self) -> EntityHandle { - EntityHandle { - inner: self.inner.cast(), + /// be created directly until [RawEntityRef::assume_init] is called. + pub fn new_uninit_with_metadata( + metadata: Metadata, + arena: &blink_alloc::Blink, + ) -> RawEntityRef, Metadata> { + unsafe { + RawEntityRef::from_ptr(RawEntityRef::allocate_for_layout( + metadata, + Layout::new::(), + |layout| arena.allocator().allocate(layout), + <*mut u8>::cast, + )) } } } -/// A [TrackedEntityHandle] is like [EntityHandle], except it provides built-in support for -/// adding the entity to an [intrusive_collections::LinkedList] that doesn't require constructing -/// a reference to the entity itself, and thus potentially causing an aliasing violation. Instead, -/// the link is stored as part of the underlying allocation, but separate from the entity. -pub struct TrackedEntityHandle { - inner: NonNull>, -} -impl Clone for TrackedEntityHandle { - fn clone(&self) -> Self { - Self { inner: self.inner } - } -} -impl TrackedEntityHandle { - /// Create a new [TrackedEntityHandle] from a raw pointer to the underlying [TrackedEntityObj]. 
- /// - /// # SAFETY - /// - /// This function has the same requirements around safety as [EntityHandle::new]. - pub(crate) unsafe fn new(ptr: NonNull>) -> Self { - Self { inner } +impl RawEntityRef { + pub fn new(value: T, arena: &blink_alloc::Blink) -> Self { + RawEntityRef::new_with_metadata(value, (), arena) } - /// Get a dynamically-checked immutable reference to the underlying `T` - pub fn get(&self) -> EntityRef<'_, T> { - unsafe { - let obj = self.inner.as_ref(); - obj.entity.borrow() - } + pub fn new_uninit(arena: &blink_alloc::Blink) -> RawEntityRef, ()> { + RawEntityRef::new_uninit_with_metadata((), arena) } +} - /// Get a dynamically-checked mutable reference to the underlying `T` - pub fn get_mut(&mut self) -> EntityMut<'_, T> { - unsafe { - let obj = self.inner.as_ref(); - obj.entity.borrow_mut() - } +impl RawEntityRef, Metadata> { + /// Converts to `RawEntityRef`. + /// + /// # Safety + /// + /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the + /// value really is in an initialized state. Calling this when the content is not yet fully + /// initialized causes immediate undefined behavior. + #[inline] + pub unsafe fn assume_init(self) -> RawEntityRef { + let ptr = Self::into_inner(self); + unsafe { RawEntityRef::from_inner(ptr.cast()) } } +} +impl RawEntityRef { /// Convert this handle into a raw pointer to the underlying entity. /// /// This should only be used in situations where the returned pointer will not be used to - /// actually access the underlying entity. Use [get] or [get_mut] for that. [EntityHandle] + /// actually access the underlying entity. Use [get] or [get_mut] for that. [RawEntityRef] /// ensures that Rust's aliasing rules are not violated when using it, but if you use the /// returned pointer to do so, no such guarantee is provided, and undefined behavior can /// result. /// - /// # SAFETY + /// # Safety /// /// The returned pointer _must_ not be used to create a reference to the underlying entity /// unless you can guarantee that such a reference does not violate Rust's aliasing rules. @@ -230,100 +231,158 @@ impl TrackedEntityHandle { /// Do not use the pointer to create a mutable reference if other references exist, and do /// not use the pointer to create an immutable reference if a mutable reference exists or /// might be created while the immutable reference lives. - pub fn into_raw(self) -> NonNull { - unsafe { NonNull::new_unchecked(self.inner.as_ref().entity.as_ptr()) } + pub fn into_raw(this: Self) -> *const T { + Self::as_ptr(&this) + } + + pub fn as_ptr(this: &Self) -> *const T { + let ptr: *mut RawEntityMetadata = NonNull::as_ptr(this.inner); + + // SAFETY: This cannot go through Deref::deref or RawEntityRef::inner because this + // is required to retain raw/mut provenance such that e.g. `get_mut` can write through + // the pointer after the RawEntityRef is recovered through `from_raw` + let ptr = unsafe { core::ptr::addr_of_mut!((*ptr).entity.cell) }; + UnsafeCell::raw_get(ptr).cast_const() } -} -impl TrackedEntityHandle> { - /// Create a [TrackedEntityHandle] for an entity which may not be fully initialized. + /// Convert a pointer returned by [RawEntityRef::into_raw] back into a [RawEntityRef]. /// - /// # SAFETY + /// # Safety /// - /// The safety rules are much the same as [TrackedEntityHandle::new], with the main difference - /// being that the `T` does not have to be initialized yet. No references to the `T` will - /// be created directly until [TrackedEntityHandle::assume_init] is called. 
- pub(crate) unsafe fn new_uninit(ptr: NonNull>>) -> Self { - Self { inner: ptr } + /// * It is _only_ valid to call this method on a pointer returned by [RawEntityRef::into_raw]. + /// * The pointer must be a valid pointer for `T` + pub unsafe fn from_raw(ptr: *const T) -> Self { + let offset = unsafe { RawEntityMetadata::::data_offset(ptr) }; + + // Reverse the offset to find the original EntityObj + let entity_ptr = unsafe { ptr.byte_sub(offset) as *mut RawEntityMetadata }; + + unsafe { Self::from_ptr(entity_ptr) } } - /// Converts to `TrackedEntityHandle`. - /// - /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the - /// value really is in an initialized state. Calling this when the content is not yet fully - /// initialized causes immediate undefined behavior. - pub unsafe fn assume_init(self) -> TrackedEntityHandle { - TrackedEntityHandle { - inner: self.inner.cast(), - } + /// Get a dynamically-checked immutable reference to the underlying `T` + pub fn borrow(&self) -> EntityRef<'_, T> { + let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); + unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow() } } -} -unsafe impl intrusive_collections::PointerOps for TrackedEntityHandle { - type Pointer = TrackedEntityHandle; - type Value = EntityObj; + /// Get a dynamically-checked mutable reference to the underlying `T` + pub fn borrow_mut(&mut self) -> EntityMut<'_, T> { + let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); + unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow_mut() } + } - unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { - assert!(!value.is_null()); - let offset = core::mem::offset_of!(TrackedEntityObj, entity); - let ptr = value.cast_mut().byte_sub(offset).cast::>(); - debug_assert!(ptr.is_aligned()); - TrackedEntityHandle::new(NonNull::new_unchecked(ptr)) + /// Try to get a dynamically-checked mutable reference to the underlying `T` + /// + /// Returns `None` if the entity is already borrowed + pub fn try_borrow_mut(&mut self) -> Option> { + let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); + unsafe { (*core::ptr::addr_of!((*ptr).entity)).try_borrow_mut().ok() } } - fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { - let ptr = ptr.into_raw().as_ptr().cast_const(); - let offset = core::mem::offset_of!(EntityObj, cell); - unsafe { ptr.byte_sub(offset).cast() } + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + core::ptr::addr_eq(this.inner.as_ptr(), other.inner.as_ptr()) } -} -/// An adapter for storing any `Entity` impl in a [intrusive_collections::LinkedList] -#[derive(Default, Copy, Clone)] -pub struct EntityAdapter(core::marker::PhantomData); -impl EntityAdapter { - pub const fn new() -> Self { - Self(core::marker::PhantomData) + unsafe fn allocate_for_layout( + metadata: Metadata, + value_layout: Layout, + allocate: F, + mem_to_metadata: F2, + ) -> *mut RawEntityMetadata + where + F: FnOnce(Layout) -> Result, AllocError>, + F2: FnOnce(*mut u8) -> *mut RawEntityMetadata, + { + use alloc::alloc::handle_alloc_error; + + let layout = raw_entity_metadata_layout_for_value_layout::(value_layout); + unsafe { + RawEntityRef::try_allocate_for_layout(metadata, value_layout, allocate, mem_to_metadata) + .unwrap_or_else(|_| handle_alloc_error(layout)) + } } -} -unsafe impl intrusive_collections::Adapter for EntityAdapter { - type LinkOps = intrusive_collections::linked_list::LinkOps; - type PointerOps = intrusive_collections::DefaultPointerOps>; + #[inline] + unsafe 
fn try_allocate_for_layout( + metadata: Metadata, + value_layout: Layout, + allocate: F, + mem_to_metadata: F2, + ) -> Result<*mut RawEntityMetadata, AllocError> + where + F: FnOnce(Layout) -> Result, AllocError>, + F2: FnOnce(*mut u8) -> *mut RawEntityMetadata, + { + let layout = raw_entity_metadata_layout_for_value_layout::(value_layout); + let ptr = allocate(layout)?; + let inner = mem_to_metadata(ptr.as_non_null_ptr().as_ptr()); + unsafe { + debug_assert_eq!(Layout::for_value_raw(inner), layout); + + core::ptr::addr_of_mut!((*inner).metadata).write(metadata); + core::ptr::addr_of_mut!((*inner).entity.borrow).write(Cell::new(BorrowFlag::UNUSED)); + #[cfg(debug_assertions)] + core::ptr::addr_of_mut!((*inner).entity.borrowed_at).write(Cell::new(None)); + } - unsafe fn get_value( - &self, - link: ::LinkPtr, - ) -> *const ::Value { - let offset = core::mem::offset_of!(TrackedEntityObj, link); - let ptr = link.as_ptr().cast_const().byte_sub(offset); - let offset = core::mem::offset_of!(TrackedEntityObj, entity); - ptr.byte_add(offset) + Ok(inner) } +} - unsafe fn get_link( - &self, - value: *const ::Value, - ) -> ::LinkPtr { - let offset = core::mem::offset_of!(TrackedEntityObj, entity); - let ptr = value.byte_sub(offset); - let offset = core::mem::offset_of!(TrackedEntityObj, link); - let ptr = ptr.byte_add(offset); - NonNull::new_unchecked(ptr.cast_mut()) +impl RawEntityRef { + /// Returns true if the underlying value is a `T` + #[inline] + pub fn is(self) -> bool { + self.borrow().is::() } - fn link_ops(&self) -> &Self::LinkOps { - &intrusive_collections::linked_list::LinkOps + /// Casts this reference to the concrete type `T`, if the underlying value is a `T`. + /// + /// If the cast is not valid for this reference, `Err` is returned containing the original value. + #[inline] + pub fn downcast(self) -> Result, Self> { + if self.borrow().is::() { + unsafe { Ok(Self::downcast_unchecked(self)) } + } else { + Err(self) + } } - fn link_ops_mut(&mut self) -> &mut Self::LinkOps { - &mut intrusive_collections::linked_list::LinkOps + /// Casts this reference to the concrete type `T` without checking that the cast is valid. + /// + /// # Safety + /// + /// The referenced value must be of type `T`. Calling this method with the incorrect type is + /// _undefined behavior_. 
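+    ///
+    /// A hypothetical sketch of guarding the unchecked cast (names and generic parameters are
+    /// illustrative only; prefer the checked `downcast` when the concrete type is not known):
+    ///
+    /// ```rust,ignore
+    /// if entity.clone().is::<Block>() {
+    ///     // SAFETY: the concrete type was just checked above.
+    ///     let block = unsafe { entity.downcast_unchecked::<Block>() };
+    ///     // ... use the strongly-typed handle ...
+    /// }
+    /// ```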
+ #[inline] + pub unsafe fn downcast_unchecked(self) -> RawEntityRef { + unsafe { + let ptr = RawEntityRef::into_inner(self); + RawEntityRef::from_inner(ptr.cast()) + } } +} - fn pointer_ops(&self) -> &Self::PointerOps { - const OPS: intrusive_collections::DefaultPointerOps>; +impl core::ops::CoerceUnsized> + for RawEntityRef +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} + +impl fmt::Pointer for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Pointer::fmt(&Self::as_ptr(self), f) + } +} - &OPS +impl fmt::Debug for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.borrow(), f) } } @@ -341,15 +400,6 @@ impl core::ops::Deref for EntityRef<'_, T> { } } impl<'b, T: ?Sized> EntityRef<'b, T> { - #[must_use] - #[inline] - pub fn clone(orig: &Self) -> Self { - Self { - value: orig.value, - borrow: orig.borrow.clone(), - } - } - #[inline] pub fn map(orig: Self, f: F) -> EntityRef<'b, U> where @@ -369,18 +419,119 @@ where { } +impl fmt::Debug for EntityRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} impl fmt::Display for EntityRef<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } +impl Eq for EntityRef<'_, T> {} +impl PartialEq for EntityRef<'_, T> { + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} +impl PartialOrd for EntityRef<'_, T> { + fn partial_cmp(&self, other: &Self) -> Option { + (**self).partial_cmp(&**other) + } + + fn ge(&self, other: &Self) -> bool { + **self >= **other + } + + fn gt(&self, other: &Self) -> bool { + **self > **other + } + + fn le(&self, other: &Self) -> bool { + **self <= **other + } + + fn lt(&self, other: &Self) -> bool { + **self < **other + } +} +impl Ord for EntityRef<'_, T> { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + (**self).cmp(&**other) + } +} +impl Hash for EntityRef<'_, T> { + fn hash(&self, state: &mut H) { + (**self).hash(state); + } +} /// A guard that provides exclusive access to an IR entity -pub struct EntityMut<'a, T> { +pub struct EntityMut<'b, T: ?Sized> { + /// The raw pointer to the underlying data + /// + /// This is a pointer rather than a `&'b mut T` to avoid `noalias` violations, because a + /// `EntityMut` argument doesn't hold exclusivity for its whole scope, only until it drops. value: NonNull, + /// This value provides the drop glue for tracking that the underlying allocation is + /// mutably borrowed, but it is otherwise not read. + #[allow(unused)] borrow: BorrowRefMut<'b>, + /// `NonNull` is covariant over `T`, so we need to reintroduce invariance via phantom data _marker: core::marker::PhantomData<&'b mut T>, } +impl<'b, T: ?Sized> EntityMut<'b, T> { + /// Splits an `EntityMut` into multiple `EntityMut`s for different components of the borrowed + /// data. + /// + /// The underlying entity will remain mutably borrowed until both returned `EntityMut`s go out + /// of scope. + /// + /// The entity is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as `EntityMut::map_split(...)`, so as + /// to avoid conflicting with any method of the same name accessible via the `Deref` impl. 
+ /// + /// # Examples + /// + /// ```rust + /// use crate::*; + /// use blink_alloc::Blink; + /// + /// let alloc = Blink::default(); + /// let entity = UnsafeEntityRef::new([1, 2, 3, 4], &alloc); + /// let borrow = entity.get_mut(); + /// let (mut begin, mut end) = EntityMut::map_split(borrow, |slice| slice.split_at_mut(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// begin.copy_from_slice(&[4, 3]); + /// end.copy_from_slice(&[2, 1]); + /// ``` + #[inline] + pub fn map_split( + mut orig: Self, + f: F, + ) -> (EntityMut<'b, U>, EntityMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let borrow = orig.borrow.clone(); + let (a, b) = f(&mut *orig); + ( + EntityMut { + value: NonNull::from(a), + borrow, + _marker: core::marker::PhantomData, + }, + EntityMut { + value: NonNull::from(b), + borrow: orig.borrow, + _marker: core::marker::PhantomData, + }, + ) + } +} impl Deref for EntityMut<'_, T> { type Target = T; @@ -405,41 +556,133 @@ where { } +impl fmt::Debug for EntityMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} impl fmt::Display for EntityMut<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } } +impl Eq for EntityMut<'_, T> {} +impl PartialEq for EntityMut<'_, T> { + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} +impl PartialOrd for EntityMut<'_, T> { + fn partial_cmp(&self, other: &Self) -> Option { + (**self).partial_cmp(&**other) + } -/// An [EntityObj] is a wrapper around IR objects that are allocated in a [Context]. -/// -/// It ensures that any [EntityHandle] which references the underlying entity, adheres to Rust's -/// aliasing rules. -pub struct EntityObj { - borrow: Cell, - #[cfg(debug_assertions)] - borrowed_at: Cell>>, - cell: UnsafeCell, + fn ge(&self, other: &Self) -> bool { + **self >= **other + } + + fn gt(&self, other: &Self) -> bool { + **self > **other + } + + fn le(&self, other: &Self) -> bool { + **self <= **other + } + + fn lt(&self, other: &Self) -> bool { + **self < **other + } +} +impl Ord for EntityMut<'_, T> { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + (**self).cmp(&**other) + } +} +impl Hash for EntityMut<'_, T> { + fn hash(&self, state: &mut H) { + (**self).hash(state); + } } -/// A [TrackedEntityObj] is a wrapper around IR entities that are linked in to an -/// [intrusive_collections::LinkedList] for tracking of that entity. This permits the linked list -/// to be visited/mutated without borrowing the entities themselves, and thus risk violation of -/// the aliasing rules. -pub struct TrackedEntityObj { - link: intrusive_collections::linked_list::LinkedListLink, - entity: EntityObj, +// This type wraps the entity data with extra metadata we want to associate with the entity, but +// separately from it, so that pointers to the metadata do not cause aliasing violations if the +// entity itself is borrowed. +// +// The kind of metadata stored here is unconstrained, but in practice should be limited to things +// that you _need_ to be able to access from a `RawEntityRef`, without aliasing the entity. For now +// the main reason we use this is for the intrusive link used to store entities in an intrusive +// linked list. We don't want traversing the intrusive list to require borrowing the entity, only +// the link, unless we explicitly want to borrow the entity, thus we use the metadata field here +// to hold the link. 
+// +// This has to be `pub` for implementing the traits required for the intrusive collections +// integration, but its internals are hidden outside this module, and we hide it from the generated +// docs as well. +#[repr(C)] +#[doc(hidden)] +pub struct RawEntityMetadata { + metadata: Metadata, + entity: RawEntity, } -impl TrackedEntityObj { - pub fn new(value: T) -> Self { +impl RawEntityMetadata { + pub(crate) fn new(value: T, metadata: Metadata) -> Self { Self { - link: Default::default(), - entity: EntityObj::new(value), + metadata, + entity: RawEntity::new(value), } } } +impl RawEntityMetadata { + #[inline] + const fn metadata_offset() -> usize { + core::mem::offset_of!(RawEntityMetadata<(), Metadata>, metadata) + } + + /// Get the offset within a `RawEntityMetadata` for the payload behind a pointer. + /// + /// # Safety + /// + /// The pointer must point to (and have valid metadata for) a previously valid instance of T, but + /// the T is allowed to be dropped. + unsafe fn data_offset(ptr: *const T) -> usize { + use core::mem::align_of_val_raw; -impl EntityObj { + // Align the unsized value to the end of the RawEntityMetadata. + // Because RawEntityMetadata/RawEntity is repr(C), it will always be the last field in memory. + // + // SAFETY: since the only unsized types possible are slices, trait objects, and extern types, + // the input safety requirement is currently enough to satisfy the requirements of + // align_of_val_raw; but this is an implementation detail of the language that is unstable + unsafe { RawEntityMetadata::<(), Metadata>::data_offset_align(align_of_val_raw(ptr)) } + } + + #[inline] + fn data_offset_align(align: usize) -> usize { + let layout = Layout::new::>(); + layout.size() + layout.padding_needed_for(align) + } +} + +fn raw_entity_metadata_layout_for_value_layout(layout: Layout) -> Layout { + Layout::new::>() + .extend(layout) + .unwrap() + .0 + .pad_to_align() +} + +/// A [RawEntity] wraps an entity to be allocated in a [Context], and provides dynamic borrow- +/// checking functionality for [UnsafeEntityRef], thereby protecting the entity by ensuring that +/// all accesses adhere to Rust's aliasing rules. +#[repr(C)] +struct RawEntity { + borrow: Cell, + #[cfg(debug_assertions)] + borrowed_at: Cell>>, + cell: UnsafeCell, +} + +impl RawEntity { pub fn new(value: T) -> Self { Self { borrow: Cell::new(BorrowFlag::UNUSED), @@ -448,7 +691,9 @@ impl EntityObj { cell: UnsafeCell::new(value), } } +} +impl RawEntity { #[track_caller] #[inline] pub fn borrow(&self) -> EntityRef<'_, T> { @@ -466,14 +711,14 @@ impl EntityObj { #[cfg(debug_assertions)] { // `borrowed_at` is always the *first* active borrow - if b.borrow.get() == 1 { + if b.borrow.get() == BorrowFlag(1) { self.borrowed_at.set(Some(core::panic::Location::caller())); } } // SAFETY: `BorrowRef` ensures that there is only immutable access to the value // while borrowed. - let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + let value = unsafe { NonNull::new_unchecked(self.cell.get()) }; Ok(EntityRef { value, borrow: b }) } None => Err(AliasingViolationError { @@ -504,11 +749,11 @@ impl EntityObj { } // SAFETY: `BorrowRefMut` guarantees unique access. 
- let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + let value = unsafe { NonNull::new_unchecked(self.cell.get()) }; Ok(EntityMut { value, borrow: b, - _marker: PhantomData, + _marker: core::marker::PhantomData, }) } None => Err(AliasingViolationError { @@ -520,16 +765,6 @@ impl EntityObj { }), } } - - #[inline] - pub fn as_ptr(&self) -> *mut T { - self.cell.get() - } - - #[inline] - pub fn get_mut(&mut self) -> &mut T { - self.cell.get_mut() - } } struct BorrowRef<'b> { diff --git a/hir2/src/core/entity/list.rs b/hir2/src/core/entity/list.rs index d3b2cc5df..5cb5e5ee0 100644 --- a/hir2/src/core/entity/list.rs +++ b/hir2/src/core/entity/list.rs @@ -1,11 +1,17 @@ -use core::fmt; +use core::{fmt, mem::MaybeUninit, ptr::NonNull}; -use super::{EntityAdapter, EntityRef, TrackedEntityHandle}; +use super::{EntityMut, EntityRef, RawEntityMetadata, RawEntityRef, UnsafeIntrusiveEntityRef}; -#[derive(Default)] pub struct EntityList { list: intrusive_collections::linked_list::LinkedList>, } +impl Default for EntityList { + fn default() -> Self { + Self { + list: Default::default(), + } + } +} impl EntityList { /// Construct a new, empty [EntityList] pub fn new() -> Self { @@ -30,29 +36,43 @@ impl EntityList { } /// Prepend `entity` to this list - pub fn push_front(&mut self, entity: TrackedEntityHandle) { + pub fn push_front(&mut self, entity: UnsafeIntrusiveEntityRef) { self.list.push_front(entity); } /// Append `entity` to this list - pub fn push_back(&mut self, entity: TrackedEntityHandle) { + pub fn push_back(&mut self, entity: UnsafeIntrusiveEntityRef) { self.list.push_back(entity); } /// Remove the entity at the front of the list, returning its [TrackedEntityHandle] /// /// Returns `None` if the list is empty. - pub fn pop_front(&mut self) -> Option> { + pub fn pop_front(&mut self) -> Option> { self.list.pop_back() } /// Remove the entity at the back of the list, returning its [TrackedEntityHandle] /// /// Returns `None` if the list is empty. - pub fn pop_back(&mut self) -> Option> { + pub fn pop_back(&mut self) -> Option> { self.list.pop_back() } + #[doc(hidden)] + pub fn cursor(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.list.cursor(), + } + } + + #[doc(hidden)] + pub fn cursor_mut(&mut self) -> EntityCursorMut<'_, T> { + EntityCursorMut { + cursor: self.list.cursor_mut(), + } + } + /// Get an [EntityCursor] pointing to the first entity in the list, or the null object if /// the list is empty pub fn front(&self) -> EntityCursor<'_, T> { @@ -63,7 +83,7 @@ impl EntityList { /// Get an [EntityCursorMut] pointing to the first entity in the list, or the null object if /// the list is empty - pub fn front_mut(&self) -> EntityCursorMut<'_, T> { + pub fn front_mut(&mut self) -> EntityCursorMut<'_, T> { EntityCursorMut { cursor: self.list.front_mut(), } @@ -79,7 +99,7 @@ impl EntityList { /// Get an [EntityCursorMut] pointing to the last entity in the list, or the null object if /// the list is empty - pub fn back_mut(&self) -> EntityCursorMut<'_, T> { + pub fn back_mut(&mut self) -> EntityCursorMut<'_, T> { EntityCursorMut { cursor: self.list.back_mut(), } @@ -91,7 +111,7 @@ impl EntityList { /// bound to the list itself, not the iterator. pub fn iter(&self) -> EntityIter<'_, T> { EntityIter { - cursor: self.list.cursor(), + cursor: self.cursor(), started: false, } } @@ -122,6 +142,41 @@ impl EntityList { list: self.list.take(), } } + + /// Get a cursor to the item pointed to by `ptr`. 
+ /// + /// # Safety + /// + /// This function may only be called when it is known that `ptr` refers to an entity which is + /// linked into this list. This operation will panic if the entity is not linked into any list, + /// and may result in undefined behavior if the operation is linked into a different list. + pub unsafe fn cursor_from_ptr(&self, ptr: UnsafeIntrusiveEntityRef) -> EntityCursor<'_, T> { + unsafe { + let raw = UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr(); + EntityCursor { + cursor: self.list.cursor_from_ptr(raw), + } + } + } + + /// Get a mutable cursor to the item pointed to by `ptr`. + /// + /// # Safety + /// + /// This function may only be called when it is known that `ptr` refers to an entity which is + /// linked into this list. This operation will panic if the entity is not linked into any list, + /// and may result in undefined behavior if the operation is linked into a different list. + pub unsafe fn cursor_mut_from_ptr( + &mut self, + ptr: UnsafeIntrusiveEntityRef, + ) -> EntityCursorMut<'_, T> { + let raw = UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr(); + unsafe { + EntityCursorMut { + cursor: self.list.cursor_mut_from_ptr(raw), + } + } + } } impl fmt::Debug for EntityList { @@ -134,10 +189,10 @@ impl fmt::Debug for EntityList { } } -impl FromIterator> for EntityList { - fn from_iter(iter: T) -> Self +impl FromIterator> for EntityList { + fn from_iter(iter: I) -> Self where - T: IntoIterator>, + I: IntoIterator>, { let mut list = EntityList::::default(); for handle in iter { @@ -148,8 +203,8 @@ impl FromIterator> for EntityList { } impl IntoIterator for EntityList { - type IntoIter = intrusive_collections::linked_list::IntoIter; - type Item = TrackedEntityHandle; + type IntoIter = intrusive_collections::linked_list::IntoIter>; + type Item = UnsafeIntrusiveEntityRef; fn into_iter(self) -> Self::IntoIter { self.list.into_iter() @@ -183,17 +238,14 @@ impl<'a, T> EntityCursor<'a, T> { /// NOTE: This returns an [EntityRef] whose lifetime is bound to the underlying [EntityList], /// _not_ the [EntityCursor], since the cursor cannot mutate the list. pub fn get(&self) -> Option> { - match self.cursor.get() { - Some(obj) => Some(obj.borrow()), - None => None, - } + self.cursor.get().map(|obj| obj.entity.borrow()) } /// Get the [TrackedEntityHandle] corresponding to the entity under the cursor. /// /// Returns `None` if the cursor is pointing to the null object. #[inline] - pub fn as_pointer(&self) -> Option> { + pub fn as_pointer(&self) -> Option> { self.cursor.clone_pointer() } @@ -261,10 +313,7 @@ impl<'a, T> EntityCursorMut<'a, T> { /// is frozen while the entity is being borrowed. This ensures that only one reference at a /// time is being handed out by this cursor. pub fn get(&self) -> Option> { - match self.cursor.get() { - Some(obj) => Some(obj.borrow()), - None => None, - } + self.cursor.get().map(|obj| obj.entity.borrow()) } /// Get a mutable reference to the entity under the cursor. @@ -276,10 +325,7 @@ impl<'a, T> EntityCursorMut<'a, T> { /// from being accessed in any way until the mutable reference is dropped. This makes it /// impossible to try and alias the underlying entity using the cursor. pub fn get_mut(&mut self) -> Option> { - match self.cursor.get() { - Some(obj) => Some(obj.borrow_mut()), - None => None, - } + self.cursor.get().map(|obj| obj.entity.borrow_mut()) } /// Returns a read-only cursor pointing to the current element. 
@@ -297,8 +343,8 @@ impl<'a, T> EntityCursorMut<'a, T> { /// /// Returns `None` if the cursor is pointing to the null object. #[inline] - pub fn as_pointer(&self) -> Option> { - self.cursor.clone_pointer() + pub fn as_pointer(&self) -> Option> { + self.cursor.as_cursor().clone_pointer() } /// Moves the cursor to the next element of the [EntityList]. @@ -353,7 +399,7 @@ impl<'a, T> EntityCursorMut<'a, T> { /// If the cursor is currently pointing to the null object then nothing is removed and `None` is /// returned. #[inline] - pub fn remove(&mut self) -> Option> { + pub fn remove(&mut self) -> Option> { self.cursor.remove() } @@ -370,8 +416,8 @@ impl<'a, T> EntityCursorMut<'a, T> { #[inline] pub fn replace_with( &mut self, - value: TrackedEntityHandle, - ) -> Result, TrackedEntityHandle> { + value: UnsafeIntrusiveEntityRef, + ) -> Result, UnsafeIntrusiveEntityRef> { self.cursor.replace_with(value) } @@ -384,7 +430,7 @@ impl<'a, T> EntityCursorMut<'a, T> { /// /// Panics if the entity is already linked to a different [EntityList] #[inline] - pub fn insert_after(&mut self, value: TrackedEntityHandle) { + pub fn insert_after(&mut self, value: UnsafeIntrusiveEntityRef) { self.cursor.insert_after(value) } @@ -397,7 +443,7 @@ impl<'a, T> EntityCursorMut<'a, T> { /// /// Panics if the entity is already linked to a different [EntityList] #[inline] - pub fn insert_before(&mut self, value: TrackedEntityHandle) { + pub fn insert_before(&mut self, value: UnsafeIntrusiveEntityRef) { self.cursor.insert_before(value) } @@ -472,37 +518,6 @@ impl<'a, T> EntityCursorMut<'a, T> { let list = self.cursor.split_before(); EntityList { list } } - - /// Consumes this cursor, and returns a reference to the entity that the cursor is currently - /// pointing to. - /// - /// Unlike [get], the returned reference’s lifetime is tied to [EntityList]’s lifetime. - /// - /// This returns `None` if the cursor is currently pointing to the null object. - /// - /// NOTE: This function will panic if there are any outstanding mutable borrows of the - /// underlying entity. - pub fn into_ref(self) -> Option> { - match self.cursor.get() { - Some(obj) => Some(obj.borrow()), - None => None, - } - } - - /// Consumes this cursor, and returns a mutable reference to the entity that the cursor is - /// currently pointing to. - /// - /// Unlike [get_mut], the returned reference’s lifetime is tied to the [EntityList]’s lifetime. - /// - /// This returns `None` if the cursor is currently pointing to the null object. - /// - /// NOTE: This function will panic if there are any outstanding borrows of the underlying entity - pub fn into_mut(self) -> Option> { - match self.cursor.get() { - Some(obj) => Some(obj.borrow_mut()), - None => None, - } - } } pub struct EntityIter<'a, T> { @@ -538,3 +553,161 @@ impl<'a, T> DoubleEndedIterator for EntityIter<'a, T> { Some(item) } } + +type IntrusiveLink = intrusive_collections::LinkedListLink; + +impl RawEntityRef { + /// Create a new [UnsafeIntrusiveEntityRef] by allocating `value` in `arena` + /// + /// # SAFETY + /// + /// This function has the same requirements around safety as [RawEntityRef::new]. 
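+    ///
+    /// A rough sketch of how intrusively-tracked handles are used; the `arena` here is assumed
+    /// to be the `blink_alloc::Blink` owned by the surrounding [Context]:
+    ///
+    /// ```rust,ignore
+    /// let a = UnsafeIntrusiveEntityRef::new(Block::new(BlockId::from_u32(0)), &arena);
+    /// let b = UnsafeIntrusiveEntityRef::new(Block::new(BlockId::from_u32(1)), &arena);
+    /// let mut list = EntityList::default();
+    /// list.push_back(a.clone());
+    /// list.push_back(b);
+    /// // The links live in the entity metadata, so traversal does not borrow the blocks.
+    /// assert!(a.is_linked());
+    /// let next = a.next(); // handle to the entity after `a` in the list, i.e. `b`
+    /// ```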
+ pub fn new(value: T, arena: &blink_alloc::Blink) -> Self { + RawEntityRef::new_with_metadata(value, IntrusiveLink::new(), arena) + } + + pub fn new_uninit(arena: &blink_alloc::Blink) -> RawEntityRef, IntrusiveLink> { + RawEntityRef::new_uninit_with_metadata(IntrusiveLink::new(), arena) + } +} +impl RawEntityRef { + /// Returns true if this entity is linked into an intrusive list + pub fn is_linked(&self) -> bool { + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast::(); + current.as_ref().is_linked() + } + } + + /// Get the previous entity in the list of `T` containing the current entity + /// + /// For example, in a list of `Operation` in a `Block`, this would return the handle of the + /// previous operation in the block, or `None` if there are no other ops before this one. + pub fn prev(&self) -> Option { + use intrusive_collections::linked_list::{LinkOps, LinkedListOps}; + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast(); + LinkOps.prev(current).map(|link_ptr| Self::from_link_ptr(link_ptr)) + } + } + + /// Get the next entity in the list of `T` containing the current entity + /// + /// For example, in a list of `Operation` in a `Block`, this would return the handle of the + /// next operation in the block, or `None` if there are no other ops after this one. + pub fn next(&self) -> Option { + use intrusive_collections::linked_list::{LinkOps, LinkedListOps}; + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast(); + LinkOps.next(current).map(|link_ptr| Self::from_link_ptr(link_ptr)) + } + } + + #[inline] + unsafe fn from_link_ptr(link: NonNull) -> Self { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let ptr = link.byte_sub(offset).cast::>(); + Self { inner: ptr } + } +} + +#[doc(hidden)] +pub struct DefaultPointerOps(core::marker::PhantomData); +impl Copy for DefaultPointerOps {} +impl Clone for DefaultPointerOps { + fn clone(&self) -> Self { + *self + } +} +impl Default for DefaultPointerOps { + fn default() -> Self { + Self::new() + } +} +impl DefaultPointerOps { + const fn new() -> Self { + Self(core::marker::PhantomData) + } +} + +unsafe impl intrusive_collections::PointerOps + for DefaultPointerOps> +{ + type Pointer = UnsafeIntrusiveEntityRef; + type Value = RawEntityMetadata; + + #[inline] + unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { + debug_assert!(!value.is_null() && value.is_aligned()); + UnsafeIntrusiveEntityRef::from_ptr(value.cast_mut()) + } + + #[inline] + fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { + UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr().cast_const() + } +} + +/// An adapter for storing any `Entity` impl in a [intrusive_collections::LinkedList] +pub struct EntityAdapter { + link_ops: intrusive_collections::linked_list::LinkOps, + ptr_ops: DefaultPointerOps>, + marker: core::marker::PhantomData, +} +impl Copy for EntityAdapter {} +impl Clone for EntityAdapter { + fn clone(&self) -> Self { + *self + } +} +impl Default for EntityAdapter { + fn default() -> Self { + Self::new() + } +} +impl EntityAdapter { + pub const fn new() -> Self { + Self { + link_ops: intrusive_collections::linked_list::LinkOps, + ptr_ops: DefaultPointerOps::new(), + marker: core::marker::PhantomData, + } + } +} + +unsafe impl intrusive_collections::Adapter for EntityAdapter { + type LinkOps = 
intrusive_collections::linked_list::LinkOps; + type PointerOps = DefaultPointerOps>; + + unsafe fn get_value( + &self, + link: ::LinkPtr, + ) -> *const ::Value { + let raw_entity_ref = UnsafeIntrusiveEntityRef::::from_link_ptr(link); + raw_entity_ref.inner.as_ptr().cast_const() + } + + unsafe fn get_link( + &self, + value: *const ::Value, + ) -> ::LinkPtr { + let raw_entity_ref = UnsafeIntrusiveEntityRef::from_ptr(value.cast_mut()); + let offset = RawEntityMetadata::::metadata_offset(); + raw_entity_ref.inner.byte_add(offset).cast() + } + + fn link_ops(&self) -> &Self::LinkOps { + &self.link_ops + } + + fn link_ops_mut(&mut self) -> &mut Self::LinkOps { + &mut self.link_ops + } + + fn pointer_ops(&self) -> &Self::PointerOps { + &self.ptr_ops + } +} diff --git a/hir2/src/core/function.rs b/hir2/src/core/function.rs index 7bd377d4e..6afd9012c 100644 --- a/hir2/src/core/function.rs +++ b/hir2/src/core/function.rs @@ -1,11 +1,7 @@ -use super::{Operation, Symbol}; -use crate::Spanned; +use core::fmt; -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct FunctionIdent { - module: midenc_hir_symbol::Symbol, - function: midenc_hir_symbol::Symbol, -} +use super::*; +use crate::{formatter, CallConv, Linkage}; #[derive(Spanned)] pub struct Function { @@ -15,68 +11,15 @@ pub struct Function { signature: Signature, } impl Symbol for Function { - type Id = midenc_hir_symbol::Symbol; + type Id = Ident; fn id(&self) -> Self::Id { self.id.function } } - -struct Function - -/// Represents the calling convention of a function. -/// -/// Calling conventions are part of a program's ABI (Application Binary Interface), and -/// they define things such how arguments are passed to a function, how results are returned, -/// etc. In essence, the contract between caller and callee is described by the calling convention -/// of a function. -/// -/// Importantly, it is perfectly normal to mix calling conventions. For example, the public -/// API for a C library will use whatever calling convention is used by C on the target -/// platform (for Miden, that would be `SystemV`). However, internally that library may use -/// the `Fast` calling convention to allow the compiler to optimize more effectively calls -/// from the public API to private functions. In short, choose a calling convention that is -/// well-suited for a given function, to the extent that other constraints don't impose a choice -/// on you. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] -#[cfg_attr( - feature = "serde", - derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) -)] -#[repr(u8)] -pub enum CallConv { - /// This calling convention is what I like to call "chef's choice" - the - /// compiler chooses it's own convention that optimizes for call performance. - /// - /// As a result of this, it is not permitted to use this convention in externally - /// linked functions, as the convention is unstable, and the compiler can't ensure - /// that the caller in another translation unit will use the correct convention. - Fast, - /// The standard calling convention used for C on most platforms - #[default] - SystemV, - /// A function which is using the WebAssembly Component Model "Canonical ABI". - Wasm, - /// A function with this calling convention must be called using - /// the `syscall` instruction. Attempts to call it with any other - /// call instruction will cause a validation error. 
The one exception - /// to this rule is when calling another function with the `Kernel` - /// convention that is defined in the same module, which can use the - /// standard `call` instruction. - /// - /// Kernel functions may only be defined in a kernel [Module]. - /// - /// In all other respects, this calling convention is the same as `SystemV` - Kernel, -} -impl fmt::Display for CallConv { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Fast => f.write_str("fast"), - Self::SystemV => f.write_str("C"), - Self::Wasm => f.write_str("wasm"), - Self::Kernel => f.write_str("kernel"), - } +impl Function { + pub fn signature(&self) -> &Signature { + &self.signature } } diff --git a/hir2/src/core/ident.rs b/hir2/src/core/ident.rs new file mode 100644 index 000000000..8304f80c8 --- /dev/null +++ b/hir2/src/core/ident.rs @@ -0,0 +1,230 @@ +use core::{ + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, + str::FromStr, +}; + +use anyhow::anyhow; + +use super::{ + interner::{symbols, Symbol}, + SourceSpan, Spanned, +}; +use crate::formatter::{self, PrettyPrint}; + +/// Represents a globally-unique module/function name pair, with corresponding source spans. +#[derive(Copy, Clone, PartialEq, Eq, Hash, Spanned)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct FunctionIdent { + pub module: Ident, + #[span] + pub function: Ident, +} +impl FunctionIdent { + pub fn display(&self) -> impl fmt::Display + '_ { + use crate::formatter::*; + + flatten( + const_text(self.module.as_str()) + + const_text("::") + + const_text(self.function.as_str()), + ) + } +} +impl FromStr for FunctionIdent { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.rsplit_once("::") { + Some((ns, id)) => { + let module = Ident::with_empty_span(Symbol::intern(ns)); + let function = Ident::with_empty_span(Symbol::intern(id)); + Ok(Self { module, function }) + } + None => Err(anyhow!( + "invalid function name, expected fully-qualified identifier, e.g. \ + 'std::math::u64::checked_add'" + )), + } + } +} +impl fmt::Debug for FunctionIdent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("FunctionIdent") + .field("module", &self.module.name) + .field("function", &self.function.name) + .finish() + } +} +impl fmt::Display for FunctionIdent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl PrettyPrint for FunctionIdent { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + flatten(const_text("(") + display(self.module) + const_text(" ") + display(self.function)) + } +} +impl PartialOrd for FunctionIdent { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for FunctionIdent { + fn cmp(&self, other: &Self) -> Ordering { + self.module.cmp(&other.module).then(self.function.cmp(&other.function)) + } +} + +/// Represents an identifier in the IR. 
+/// +/// An identifier is some string, along with an associated source span +#[derive(Copy, Clone, Eq, Spanned)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(into = "Symbol", from = "Symbol"))] +pub struct Ident { + pub name: Symbol, + #[span] + pub span: SourceSpan, +} +impl Default for Ident { + fn default() -> Self { + Self { + name: symbols::Empty, + span: SourceSpan::UNKNOWN, + } + } +} +impl FromStr for Ident { + type Err = core::convert::Infallible; + + fn from_str(name: &str) -> Result { + Ok(Self::from(name)) + } +} +impl<'a> From<&'a str> for Ident { + fn from(name: &'a str) -> Self { + Self::with_empty_span(Symbol::intern(name)) + } +} +impl From for Ident { + #[inline] + fn from(sym: Symbol) -> Self { + Self::with_empty_span(sym) + } +} +impl From for Symbol { + #[inline] + fn from(id: Ident) -> Self { + id.as_symbol() + } +} +impl Ident { + #[inline] + pub const fn new(name: Symbol, span: SourceSpan) -> Ident { + Ident { name, span } + } + + #[inline] + pub const fn with_empty_span(name: Symbol) -> Ident { + Ident::new(name, SourceSpan::UNKNOWN) + } + + #[inline] + pub fn as_str(self) -> &'static str { + self.name.as_str() + } + + #[inline(always)] + pub fn as_symbol(self) -> Symbol { + self.name + } + + // An identifier can be unquoted if is composed of any sequence of printable + // ASCII characters, except whitespace, quotation marks, comma, semicolon, or brackets + pub fn requires_quoting(&self) -> bool { + self.as_str().contains(|c| match c { + c if c.is_ascii_control() => true, + ' ' | '\'' | '"' | ',' | ';' | '[' | ']' => true, + c if c.is_ascii_graphic() => false, + _ => true, + }) + } +} +impl AsRef for Ident { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} +impl alloc::borrow::Borrow for Ident { + #[inline] + fn borrow(&self) -> &Symbol { + &self.name + } +} +impl alloc::borrow::Borrow for Ident { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} +impl Ord for Ident { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.as_str().cmp(other.as_str()) + } +} +impl PartialOrd for Ident { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl PartialEq for Ident { + #[inline] + fn eq(&self, rhs: &Self) -> bool { + self.name == rhs.name + } +} +impl PartialEq for Ident { + #[inline] + fn eq(&self, rhs: &Symbol) -> bool { + self.name.eq(rhs) + } +} +impl PartialEq for Ident { + fn eq(&self, rhs: &str) -> bool { + self.name.as_str() == rhs + } +} +impl Hash for Ident { + fn hash(&self, state: &mut H) { + self.name.hash(state); + } +} +impl fmt::Debug for Ident { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl fmt::Display for Ident { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl PrettyPrint for Ident { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + if self.requires_quoting() { + text(format!("\"{}\"", self.as_str().escape_default())) + } else { + text(format!("#{}", self.as_str())) + } + } +} diff --git a/hir2/src/core/immediates.rs b/hir2/src/core/immediates.rs new file mode 100644 index 000000000..8f8806485 --- /dev/null +++ b/hir2/src/core/immediates.rs @@ -0,0 +1,792 @@ +use core::{ + fmt, + hash::{Hash, Hasher}, +}; + +pub use miden_core::{Felt, FieldElement, StarkField}; + +use super::Type; + +#[derive(Debug, Copy, Clone)] +pub enum Immediate { + I1(bool), + U8(u8), + I8(i8), + U16(u16), + 
I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + U128(u128), + I128(i128), + F64(f64), + Felt(Felt), +} +impl Immediate { + pub fn ty(&self) -> Type { + match self { + Self::I1(_) => Type::I1, + Self::U8(_) => Type::U8, + Self::I8(_) => Type::I8, + Self::U16(_) => Type::U16, + Self::I16(_) => Type::I16, + Self::U32(_) => Type::U32, + Self::I32(_) => Type::I32, + Self::U64(_) => Type::U64, + Self::I64(_) => Type::I64, + Self::U128(_) => Type::U128, + Self::I128(_) => Type::I128, + Self::F64(_) => Type::F64, + Self::Felt(_) => Type::Felt, + } + } + + /// Returns true if this immediate is a non-negative value + pub fn is_non_negative(&self) -> bool { + match self { + Self::I1(i) => *i, + Self::I8(i) => *i > 0, + Self::U8(i) => *i > 0, + Self::I16(i) => *i > 0, + Self::U16(i) => *i > 0, + Self::I32(i) => *i > 0, + Self::U32(i) => *i > 0, + Self::I64(i) => *i > 0, + Self::U64(i) => *i > 0, + Self::U128(i) => *i > 0, + Self::I128(i) => *i > 0, + Self::F64(f) => f.is_sign_positive(), + Self::Felt(_) => true, + } + } + + /// Returns true if this immediate can represent negative values + pub fn is_signed(&self) -> bool { + matches!( + self, + Self::I8(_) | Self::I16(_) | Self::I32(_) | Self::I64(_) | Self::I128(_) | Self::F64(_) + ) + } + + /// Returns true if this immediate can only represent non-negative values + pub fn is_unsigned(&self) -> bool { + matches!( + self, + Self::I1(_) + | Self::U8(_) + | Self::U16(_) + | Self::U32(_) + | Self::U64(_) + | Self::U128(_) + | Self::Felt(_) + ) + } + + /// Returns true if this immediate is an odd integer, otherwise false + /// + /// If the immediate is not an integer, returns `None` + pub fn is_odd(&self) -> Option { + match self { + Self::I1(b) => Some(*b), + Self::U8(i) => Some(*i % 2 == 0), + Self::I8(i) => Some(*i % 2 == 0), + Self::U16(i) => Some(*i % 2 == 0), + Self::I16(i) => Some(*i % 2 == 0), + Self::U32(i) => Some(*i % 2 == 0), + Self::I32(i) => Some(*i % 2 == 0), + Self::U64(i) => Some(*i % 2 == 0), + Self::I64(i) => Some(*i % 2 == 0), + Self::Felt(i) => Some(i.as_int() % 2 == 0), + Self::U128(i) => Some(*i % 2 == 0), + Self::I128(i) => Some(*i % 2 == 0), + Self::F64(_) => None, + } + } + + /// Returns true if this immediate is a non-zero integer, otherwise false + /// + /// If the immediate is not an integer, returns `None` + pub fn as_bool(self) -> Option { + match self { + Self::I1(b) => Some(b), + Self::U8(i) => Some(i != 0), + Self::I8(i) => Some(i != 0), + Self::U16(i) => Some(i != 0), + Self::I16(i) => Some(i != 0), + Self::U32(i) => Some(i != 0), + Self::I32(i) => Some(i != 0), + Self::U64(i) => Some(i != 0), + Self::I64(i) => Some(i != 0), + Self::Felt(i) => Some(i.as_int() != 0), + Self::U128(i) => Some(i != 0), + Self::I128(i) => Some(i != 0), + Self::F64(_) => None, + } + } + + /// Attempts to convert this value to a u32 + pub fn as_u32(self) -> Option { + match self { + Self::I1(b) => Some(b as u32), + Self::U8(b) => Some(b as u32), + Self::I8(b) if b >= 0 => Some(b as u32), + Self::I8(_) => None, + Self::U16(b) => Some(b as u32), + Self::I16(b) if b >= 0 => Some(b as u32), + Self::I16(_) => None, + Self::U32(b) => Some(b), + Self::I32(b) if b >= 0 => Some(b as u32), + Self::I32(_) => None, + Self::U64(b) => u32::try_from(b).ok(), + Self::I64(b) if b >= 0 => u32::try_from(b as u64).ok(), + Self::I64(_) => None, + Self::Felt(i) => u32::try_from(i.as_int()).ok(), + Self::U128(b) if b <= (u32::MAX as u64 as u128) => Some(b as u32), + Self::U128(_) => None, + Self::I128(b) if b >= 0 && b <= (u32::MAX as u64 as i128) => 
Some(b as u32), + Self::I128(_) => None, + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to i32 + pub fn as_i32(self) -> Option { + match self { + Self::I1(b) => Some(b as i32), + Self::U8(i) => Some(i as i32), + Self::I8(i) => Some(i as i32), + Self::U16(i) => Some(i as i32), + Self::I16(i) => Some(i as i32), + Self::U32(i) => i.try_into().ok(), + Self::I32(i) => Some(i), + Self::U64(i) => i.try_into().ok(), + Self::I64(i) => i.try_into().ok(), + Self::Felt(i) => i.as_int().try_into().ok(), + Self::U128(i) if i <= (i32::MAX as u32 as u128) => Some(i as u32 as i32), + Self::U128(_) => None, + Self::I128(i) if i >= (i32::MIN as i128) && i <= (i32::MAX as i128) => Some(i as i32), + Self::I128(_) => None, + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to a field element + pub fn as_felt(self) -> Option { + match self { + Self::I1(b) => Some(Felt::new(b as u64)), + Self::U8(b) => Some(Felt::new(b as u64)), + Self::I8(b) => u64::try_from(b).ok().map(Felt::new), + Self::U16(b) => Some(Felt::new(b as u64)), + Self::I16(b) => u64::try_from(b).ok().map(Felt::new), + Self::U32(b) => Some(Felt::new(b as u64)), + Self::I32(b) => u64::try_from(b).ok().map(Felt::new), + Self::U64(b) => Some(Felt::new(b)), + Self::I64(b) => u64::try_from(b).ok().map(Felt::new), + Self::Felt(i) => Some(i), + Self::U128(b) => u64::try_from(b).ok().map(Felt::new), + Self::I128(b) => u64::try_from(b).ok().map(Felt::new), + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to u64 + pub fn as_u64(self) -> Option { + match self { + Self::I1(b) => Some(b as u64), + Self::U8(i) => Some(i as u64), + Self::I8(i) if i >= 0 => Some(i as u64), + Self::I8(_) => None, + Self::U16(i) => Some(i as u64), + Self::I16(i) if i >= 0 => Some(i as u16 as u64), + Self::I16(_) => None, + Self::U32(i) => Some(i as u64), + Self::I32(i) if i >= 0 => Some(i as u32 as u64), + Self::I32(_) => None, + Self::U64(i) => Some(i), + Self::I64(i) if i >= 0 => Some(i as u64), + Self::I64(_) => None, + Self::Felt(i) => Some(i.as_int()), + Self::U128(i) => (i).try_into().ok(), + Self::I128(i) if i >= 0 => (i).try_into().ok(), + Self::I128(_) => None, + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to i64 + pub fn as_i64(self) -> Option { + match self { + Self::I1(b) => Some(b as i64), + Self::U8(i) => Some(i as i64), + Self::I8(i) => Some(i as i64), + Self::U16(i) => Some(i as i64), + Self::I16(i) => Some(i as i64), + Self::U32(i) => Some(i as i64), + Self::I32(i) => Some(i as i64), + Self::U64(i) => (i).try_into().ok(), + Self::I64(i) => Some(i), + Self::Felt(i) => i.as_int().try_into().ok(), + Self::U128(i) if i <= i64::MAX as u128 => Some(i as u64 as i64), + Self::U128(_) => None, + Self::I128(i) => (i).try_into().ok(), + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to u128 + pub fn as_u128(self) -> Option { + match self { + Self::I1(b) => Some(b as u128), + Self::U8(i) => Some(i as u128), + Self::I8(i) if i >= 0 => Some(i as u128), + Self::I8(_) => None, + Self::U16(i) => Some(i as u128), + Self::I16(i) if i >= 0 => Some(i as u16 as u128), + Self::I16(_) => None, + Self::U32(i) => Some(i as u128), + Self::I32(i) if i >= 0 => Some(i as u32 as u128), + Self::I32(_) => None, + Self::U64(i) => Some(i as u128), + Self::I64(i) if i >= 0 => Some(i as u128), + Self::I64(_) => None, + Self::Felt(i) => Some(i.as_int() as u128), + 
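+            // The remaining conversions: `u128` is returned as-is, `i128` must be
+            // non-negative, and floats go through the checked `FloatToInt` conversion
+            // defined later in this file.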
Self::U128(i) => Some(i), + Self::I128(i) if i >= 0 => (i).try_into().ok(), + Self::I128(_) => None, + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } + + /// Attempts to convert this value to i128 + pub fn as_i128(self) -> Option { + match self { + Self::I1(b) => Some(b as i128), + Self::U8(i) => Some(i as i128), + Self::I8(i) => Some(i as i128), + Self::U16(i) => Some(i as i128), + Self::I16(i) => Some(i as i128), + Self::U32(i) => Some(i as i128), + Self::I32(i) => Some(i as i128), + Self::U64(i) => Some(i as i128), + Self::I64(i) => Some(i as i128), + Self::Felt(i) => Some(i.as_int() as i128), + Self::U128(i) if i <= i128::MAX as u128 => Some(i as i128), + Self::U128(_) => None, + Self::I128(i) => Some(i), + Self::F64(f) => FloatToInt::::to_int(f).ok(), + } + } +} +impl fmt::Display for Immediate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::I1(i) => write!(f, "{}", i), + Self::U8(i) => write!(f, "{}", i), + Self::I8(i) => write!(f, "{}", i), + Self::U16(i) => write!(f, "{}", i), + Self::I16(i) => write!(f, "{}", i), + Self::U32(i) => write!(f, "{}", i), + Self::I32(i) => write!(f, "{}", i), + Self::U64(i) => write!(f, "{}", i), + Self::I64(i) => write!(f, "{}", i), + Self::U128(i) => write!(f, "{}", i), + Self::I128(i) => write!(f, "{}", i), + Self::F64(n) => write!(f, "{}", n), + Self::Felt(i) => write!(f, "{}", i), + } + } +} +impl Hash for Immediate { + fn hash(&self, state: &mut H) { + let d = std::mem::discriminant(self); + d.hash(state); + match self { + Self::I1(i) => i.hash(state), + Self::U8(i) => i.hash(state), + Self::I8(i) => i.hash(state), + Self::U16(i) => i.hash(state), + Self::I16(i) => i.hash(state), + Self::U32(i) => i.hash(state), + Self::I32(i) => i.hash(state), + Self::U64(i) => i.hash(state), + Self::I64(i) => i.hash(state), + Self::U128(i) => i.hash(state), + Self::I128(i) => i.hash(state), + Self::F64(f) => { + let bytes = f.to_be_bytes(); + bytes.hash(state) + } + Self::Felt(i) => i.as_int().hash(state), + } + } +} +impl Eq for Immediate {} +impl PartialEq for Immediate { + fn eq(&self, other: &Self) -> bool { + match (*self, *other) { + (Self::I8(x), Self::I8(y)) => x == y, + (Self::U16(x), Self::U16(y)) => x == y, + (Self::I16(x), Self::I16(y)) => x == y, + (Self::U32(x), Self::U32(y)) => x == y, + (Self::I32(x), Self::I32(y)) => x == y, + (Self::U64(x), Self::U64(y)) => x == y, + (Self::I64(x), Self::I64(y)) => x == y, + (Self::U128(x), Self::U128(y)) => x == y, + (Self::I128(x), Self::I128(y)) => x == y, + (Self::F64(x), Self::F64(y)) => x == y, + (Self::Felt(x), Self::Felt(y)) => x == y, + _ => false, + } + } +} +impl PartialEq for Immediate { + fn eq(&self, other: &isize) -> bool { + let y = *other; + match *self { + Self::I1(x) => x == (y == 1), + Self::U8(_) if y < 0 => false, + Self::U8(x) => x as isize == y, + Self::I8(x) => x as isize == y, + Self::U16(_) if y < 0 => false, + Self::U16(x) => x as isize == y, + Self::I16(x) => x as isize == y, + Self::U32(_) if y < 0 => false, + Self::U32(x) => x as isize == y, + Self::I32(x) => x as isize == y, + Self::U64(_) if y < 0 => false, + Self::U64(x) => x == y as i64 as u64, + Self::I64(x) => x == y as i64, + Self::U128(_) if y < 0 => false, + Self::U128(x) => x == y as i128 as u128, + Self::I128(x) => x == y as i128, + Self::F64(_) => false, + Self::Felt(_) if y < 0 => false, + Self::Felt(x) => x.as_int() == y as i64 as u64, + } + } +} +impl PartialOrd for Immediate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for 
Immediate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + use std::cmp::Ordering; + + match (self, other) { + // Floats require special treatment + (Self::F64(x), Self::F64(y)) => x.total_cmp(y), + // Here we're attempting to compare against any integer immediate, + // so we must attempt to convert the float to the largest possible + // integer representation, i128, and then promote the integer immediate + // to i128 for comparison + // + // If the float is not an integer value, truncate it and compare, then + // adjust the result to account for the truncation + (Self::F64(x), y) => { + let y = y + .as_i128() + .expect("expected rhs to be an integer capable of fitting in an i128"); + if let Ok(x) = FloatToInt::::to_int(*x) { + x.cmp(&y) + } else { + let is_positive = x.is_sign_positive(); + if let Ok(x) = FloatToInt::::to_int((*x).trunc()) { + // Edge case for equality: the float must be bigger due to truncation + match x.cmp(&y) { + Ordering::Equal if is_positive => Ordering::Greater, + Ordering::Equal => Ordering::Less, + o => o, + } + } else { + // The float is larger than i128 can represent, the sign tells us in what + // direction + if is_positive { + Ordering::Greater + } else { + Ordering::Less + } + } + } + } + (x, y @ Self::F64(_)) => y.cmp(x).reverse(), + // u128 immediates require separate treatment + (Self::U128(x), Self::U128(y)) => x.cmp(y), + (Self::U128(x), y) => { + let y = y.as_u128().expect("expected rhs to be an integer in the range of u128"); + x.cmp(&y) + } + (x, Self::U128(y)) => { + let x = x.as_u128().expect("expected lhs to be an integer in the range of u128"); + x.cmp(y) + } + // i128 immediates require separate treatment + (Self::I128(x), Self::I128(y)) => x.cmp(y), + // We're only comparing against values here which are u64, i64, or smaller than 64-bits + (Self::I128(x), y) => { + let y = y.as_i128().expect("expected rhs to be an integer smaller than i128"); + x.cmp(&y) + } + (x, Self::I128(y)) => { + let x = x.as_i128().expect("expected lhs to be an integer smaller than i128"); + x.cmp(y) + } + // u64 immediates may not fit in an i64 + (Self::U64(x), Self::U64(y)) => x.cmp(y), + // We're only comparing against values here which are i64, or smaller than 64-bits + (Self::U64(x), y) => { + let y = + y.as_i64().expect("expected rhs to be an integer capable of fitting in an i64") + as u64; + x.cmp(&y) + } + (x, Self::U64(y)) => { + let x = + x.as_i64().expect("expected lhs to be an integer capable of fitting in an i64") + as u64; + x.cmp(y) + } + // All immediates at this point are i64 or smaller + (x, y) => { + let x = + x.as_i64().expect("expected lhs to be an integer capable of fitting in an i64"); + let y = + y.as_i64().expect("expected rhs to be an integer capable of fitting in an i64"); + x.cmp(&y) + } + } + } +} +impl From for Type { + #[inline] + fn from(imm: Immediate) -> Self { + imm.ty() + } +} +impl From<&Immediate> for Type { + #[inline(always)] + fn from(imm: &Immediate) -> Self { + imm.ty() + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: bool) -> Self { + Self::I1(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: i8) -> Self { + Self::I8(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: u8) -> Self { + Self::U8(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: i16) -> Self { + Self::I16(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: u16) -> Self { + Self::U16(value) + } +} +impl From for Immediate { + 
#[inline(always)] + fn from(value: i32) -> Self { + Self::I32(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: u32) -> Self { + Self::U32(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: i64) -> Self { + Self::I64(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: u64) -> Self { + Self::U64(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: u128) -> Self { + Self::U128(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: i128) -> Self { + Self::I128(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: f64) -> Self { + Self::F64(value) + } +} +impl From for Immediate { + #[inline(always)] + fn from(value: char) -> Self { + Self::I32(value as u32 as i32) + } +} + +trait FloatToInt: Sized { + const ZERO: T; + + fn upper_bound() -> Self; + fn lower_bound() -> Self; + fn to_int(self) -> Result; + unsafe fn to_int_unchecked(self) -> T; +} +impl FloatToInt for f64 { + const ZERO: i8 = 0; + + fn upper_bound() -> Self { + f64::from(i8::MAX) + 1.0 + } + + fn lower_bound() -> Self { + f64::from(i8::MIN) - 1.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> i8 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: u8 = 0; + + fn upper_bound() -> Self { + f64::from(u8::MAX) + 1.0 + } + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> u8 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: i16 = 0; + + fn upper_bound() -> Self { + f64::from(i16::MAX) + 1.0 + } + + fn lower_bound() -> Self { + f64::from(i16::MIN) - 1.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> i16 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: u16 = 0; + + fn upper_bound() -> Self { + f64::from(u16::MAX) + 1.0 + } + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> u16 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: i32 = 0; + + fn upper_bound() -> Self { + f64::from(i32::MAX) + 1.0 + } + + fn lower_bound() -> Self { + f64::from(i32::MIN) - 1.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> i32 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: u32 = 0; + + fn upper_bound() -> Self { + f64::from(u32::MAX) + 1.0 + } + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> u32 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: i64 = 0; + + fn upper_bound() -> Self { + 63.0f64.exp2() + } + + fn lower_bound() -> Self { + (63.0f64.exp2() * -1.0) - 1.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> i64 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: u64 = 0; + + fn upper_bound() -> Self { + 64.0f64.exp2() + } + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> u64 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: Felt = Felt::ZERO; + + fn upper_bound() -> Self { + 64.0f64.exp2() - 32.0f64.exp2() + 1.0 + 
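+        // 2^64 - 2^32 + 1 is the Felt field modulus p (Miden's field); `to_int` treats this
+        // as an exclusive bound, so only float values strictly below p convert to a Felt.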
} + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self).map(Felt::new) + } + + unsafe fn to_int_unchecked(self) -> Felt { + Felt::new(f64::to_int_unchecked::(self)) + } +} +impl FloatToInt for f64 { + const ZERO: u128 = 0; + + fn upper_bound() -> Self { + 128.0f64.exp2() + } + + fn lower_bound() -> Self { + 0.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> u128 { + f64::to_int_unchecked(self) + } +} +impl FloatToInt for f64 { + const ZERO: i128 = 0; + + fn upper_bound() -> Self { + f64::from(i128::BITS - 1).exp2() + } + + fn lower_bound() -> Self { + (f64::from(i128::BITS - 1) * -1.0).exp2() - 1.0 + } + + fn to_int(self) -> Result { + float_to_int(self) + } + + unsafe fn to_int_unchecked(self) -> i128 { + f64::to_int_unchecked(self) + } +} + +fn float_to_int(f: f64) -> Result +where + I: Copy, + f64: FloatToInt, +{ + use std::num::FpCategory; + match f.classify() { + FpCategory::Nan | FpCategory::Infinite | FpCategory::Subnormal => Err(()), + FpCategory::Zero => Ok(>::ZERO), + FpCategory::Normal => { + if f == f.trunc() + && f > >::lower_bound() + && f < >::upper_bound() + { + // SAFETY: We know that x must be integral, and within the bounds of its type + Ok(unsafe { >::to_int_unchecked(f) }) + } else { + Err(()) + } + } + } +} diff --git a/hir2/src/core/module.rs b/hir2/src/core/module.rs index bd62388b3..738a7e875 100644 --- a/hir2/src/core/module.rs +++ b/hir2/src/core/module.rs @@ -1,43 +1,49 @@ -use std::{ - any::{Any, TypeId}, - collections::BTreeMap, -}; +use alloc::collections::BTreeMap; -use super::{Function, FunctionIdent, Symbol, SymbolTable}; -use crate::UnsafeRef; +use super::{EntityList, Function, Ident, Symbol, SymbolTable, UnsafeIntrusiveEntityRef}; pub struct Module { - name: midenc_hir_symbol::Symbol, - functions: BTreeMap>, + name: Ident, + functions: EntityList, + registry: BTreeMap>, +} +impl Module { + pub const fn name(&self) -> Ident { + self.name + } + + pub fn functions(&self) -> &EntityList { + &self.functions + } + + pub fn functions_mut(&mut self) -> &mut EntityList { + &mut self.functions + } } impl SymbolTable for Module { - type Key = midenc_hir_symbol::Symbol; + type Entry = UnsafeIntrusiveEntityRef; + type Key = Ident; - fn get(&self, id: &Self::Key) -> Option> - where - T: Symbol, - { - if TypeId::of::() == TypeId::of::() { - self.functions.get(id).copied().map(|unsafe_ref| { - let ptr = unsafe_ref.into_raw(); - UnsafeRef::new(ptr.cast()) - }) - } else { - None - } + fn get(&self, id: &Self::Key) -> Option { + self.registry.get(id).cloned() } - fn insert(&self, entry: UnsafeRef) -> bool - where - T: Symbol, - { - todo!() + fn insert(&mut self, entry: Self::Entry) -> bool { + let id = entry.borrow().id(); + if self.registry.contains_key(&id) { + return false; + } + self.registry.insert(id, entry.clone()); + self.functions.push_back(entry); + true } - fn remove(&self, id: &Self::Key) -> Option> - where - T: Symbol, - { - todo!() + fn remove(&mut self, id: &Self::Key) -> Option { + if let Some(ptr) = self.registry.remove(id) { + let mut cursor = unsafe { self.functions.cursor_mut_from_ptr(ptr) }; + cursor.remove() + } else { + None + } } } diff --git a/hir2/src/core/op.rs b/hir2/src/core/op.rs index e69de29bb..136fd6392 100644 --- a/hir2/src/core/op.rs +++ b/hir2/src/core/op.rs @@ -0,0 +1,116 @@ +use downcast_rs::{impl_downcast, Downcast}; + +use super::*; + +pub trait Op: Downcast + OpVerifier { + /// The name of this operation's opcode + /// + /// The opcode must 
be distinct from all other opcodes in the same dialect + fn name(&self) -> OperationName; + fn as_operation(&self) -> &Operation; + fn as_operation_mut(&mut self) -> &mut Operation; + + fn parent(&self) -> Option { + self.as_operation().parent() + } + fn parent_region(&self) -> Option { + self.as_operation().parent_region() + } + fn parent_op(&self) -> Option { + self.as_operation().parent_op() + } + fn regions(&self) -> &RegionList { + &self.as_operation().regions + } + fn operands(&self) -> &[OpOperand] { + self.as_operation().operands.as_slice() + } + fn results(&self) -> &[OpResultRef] { + self.as_operation().results.as_slice() + } + fn successors(&self) -> &[OpSuccessor] { + self.as_operation().successors.as_slice() + } +} + +impl_downcast!(Op); + +impl Spanned for dyn Op { + fn span(&self) -> SourceSpan { + self.as_operation().span + } +} + +pub trait OpExt { + /// Return the value associated with attribute `name` for this function + fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; + + /// Return true if this function has an attributed named `name` + fn has_attribute(&self, name: &Q) -> bool + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; + + /// Set the attribute `name` with `value` for this function. + fn set_attribute( + &mut self, + name: impl Into, + value: Option, + ); + + /// Remove any attribute with the given name from this function + fn remove_attribute(&mut self, name: &Q) + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized; + + /// Returns a handle to the nearest containing [Operation] of type `T` for this operation, if it + /// is attached to one + fn nearest_parent_op(&self) -> Option>; +} + +impl OpExt for T { + #[inline] + fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation().get_attribute(name) + } + + #[inline] + fn has_attribute(&self, name: &Q) -> bool + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation().has_attribute(name) + } + + #[inline] + fn set_attribute( + &mut self, + name: impl Into, + value: Option, + ) { + self.as_operation_mut().set_attribute(name, value); + } + + #[inline] + fn remove_attribute(&mut self, name: &Q) + where + interner::Symbol: std::borrow::Borrow, + Q: Ord + ?Sized, + { + self.as_operation_mut().remove_attribute(name); + } + + #[inline] + fn nearest_parent_op(&self) -> Option> { + self.as_operation().nearest_parent_op() + } +} diff --git a/hir2/src/core/operation.rs b/hir2/src/core/operation.rs index bdf6d2d20..a097e5b53 100644 --- a/hir2/src/core/operation.rs +++ b/hir2/src/core/operation.rs @@ -1,28 +1,58 @@ use core::{ - any::{Any, TypeId}, - mem, - ptr::{NonNull, Pointee}, + fmt, + marker::Unsize, + ptr::{DynMetadata, Pointee}, }; -use cranelift_entity::{packed_option::ReservedValue, EntityRef}; -use downcast_rs::{impl_downcast, Downcast}; -use intrusive_collections::{ - container_of, intrusive_adapter, - linked_list::{LinkOps, LinkedListOps}, - LinkedListLink, UnsafeRef, -}; use smallvec::SmallVec; use super::*; +pub type OperationRef = UnsafeIntrusiveEntityRef; pub type OpList = EntityList; pub type OpCursor<'a> = EntityCursor<'a, Operation>; pub type OpCursorMut<'a> = EntityCursorMut<'a, Operation>; +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OperationName { + pub dialect: DialectName, + pub name: interner::Symbol, +} +impl OperationName { 
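+    /// Create a new operation name from its dialect and opcode name.
+    ///
+    /// For example (illustrative): given a dialect named `hir` and the opcode `add`, the
+    /// resulting name displays as `hir.add` (see the `Display` impl below).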
+ pub fn new(dialect: DialectName, name: S) -> Self + where + S: Into, + { + Self { + dialect, + name: name.into(), + } + } +} +impl fmt::Debug for OperationName { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} +impl fmt::Display for OperationName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}.{}", &self.dialect, &self.name) + } +} + /// An [OpSuccessor] is a BlockOperand + OpOperands for that block, attached to an Operation -struct OpSuccessor { - block: TrackedEntityHandle, - args: SmallVec<[TrackedEntityHandle; 1]>, +pub struct OpSuccessor { + pub block: BlockOperandRef, + pub args: SmallVec<[OpOperand; 1]>, +} +impl fmt::Debug for OpSuccessor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpSuccessor") + .field("block", &self.block.borrow().block_id()) + .field("args", &self.args) + .finish() + } } // TODO: We need a safe way to construct arbitrary Ops imperatively: @@ -40,13 +70,24 @@ struct OpSuccessor { // * Generated methods can compute offsets, but how do we generate the specialized builders? pub struct OperationBuilder<'a, T> { context: &'a Context, - op: Operation, + op: UnsafeIntrusiveEntityRef, _marker: core::marker::PhantomData, } -impl OperationBuilder { - pub fn new(context: &'a Context) -> Self { - let op = Operation::uninit::(); - let handle = context.alloc_uninit_tracked(op); +impl<'a, T: Op> OperationBuilder<'a, T> { + pub fn new(context: &'a Context, op: T) -> Self { + let mut op = context.alloc_tracked(op); + + // SAFETY: Setting the data pointer of the multi-trait vtable must ensure + // that it points to the concrete type of the allocation, which we can guarantee here, + // having just allocated it. Until the data pointer is set, casts using the vtable are + // undefined behavior, so by never allowing the uninitialized vtable to be accessed, + // we can ensure the multi-trait impl is safe + unsafe { + let data_ptr = UnsafeIntrusiveEntityRef::as_ptr(&op); + let mut op_mut = op.borrow_mut(); + op_mut.as_operation_mut().vtable.set_data_ptr(data_ptr.cast_mut()); + } + Self { context, op, @@ -54,8 +95,77 @@ impl OperationBuilder { } } - pub fn build(self) -> TrackedEntityHandle { - todo!() + /// Register this op as an implementation of `Trait`. + /// + /// This is enforced statically by the type system, as well as dynamically via verification. + /// + /// This must be called for any trait that you wish to be able to cast the type-erased + /// [Operation] to later, or if you wish to get a `dyn Trait` reference from a `dyn Op` + /// reference. + /// + /// If `Trait` has a verifier implementation, it will be automatically applied when calling + /// [Operation::verify]. 
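+    ///
+    /// Illustrative usage (the concrete op type and its construction are assumptions, not
+    /// part of this patch):
+    ///
+    /// ```ignore
+    /// let mut builder = OperationBuilder::new(&context, MyAddOp::default());
+    /// builder.implement::<dyn BinaryOp>();
+    /// builder.implement::<dyn Commutative>();
+    /// ```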
+ pub fn implement(&mut self) + where + Trait: ?Sized + Pointee> + 'static, + T: Unsize + verifier::Verifier + 'static, + { + let mut op = self.op.borrow_mut(); + let operation = op.as_operation_mut(); + operation.vtable.register_trait::(); + } + + /// Set attribute `name` on this op to `value` + pub fn with_attr(&mut self, name: &'static str, value: A) + where + A: AttributeValue, + { + let mut op = self.op.borrow_mut(); + op.as_operation_mut().attrs.insert(interner::Symbol::intern(name), Some(value)); + } + + /// Set the operands given to this op + pub fn with_operands(&mut self, operands: I) + where + I: IntoIterator, + { + let mut op = self.op.borrow_mut(); + // TODO: Verify the safety of this conversion + let owner = unsafe { + let ptr = op.as_operation() as *const Operation; + UnsafeIntrusiveEntityRef::from_raw(ptr) + }; + let operands = operands.into_iter().enumerate().map(|(index, value)| { + self.context + .alloc_tracked(value::OpOperandImpl::new(value, owner.clone(), index as u8)) + }); + let op_mut = op.as_operation_mut(); + op_mut.operands.clear(); + op_mut.operands.extend(operands); + } + + /// Allocate `n` results for this op, of unknown type, to be filled in later + pub fn with_results(&mut self, n: usize) { + let mut op = self.op.borrow_mut(); + let owner = unsafe { + let ptr = op.as_operation() as *const Operation; + UnsafeIntrusiveEntityRef::from_raw(ptr) + }; + let results = + (0..n).map(|idx| self.context.make_result(Type::Unknown, owner.clone(), idx as u8)); + let op_mut = op.as_operation_mut(); + op_mut.results.clear(); + op_mut.results.extend(results); + } + + /// Consume this builder, verify the op, and return a handle to it, or an error if validation + /// failed. + pub fn build(self) -> Result, Report> { + { + let op = self.op.borrow(); + op.as_operation().verify(self.context)?; + } + Ok(self.op) } } @@ -73,22 +183,33 @@ pub struct Operation { /// The containing block of this operation /// /// Is set to `None` if this operation is detached - pub block: Option>, + pub block: Option, /// The set of operands for this operation /// /// NOTE: If the op supports immediate operands, the storage for the immediates is handled /// by the op, rather than here. Additionally, the semantics of the immediate operands are /// determined by the op, e.g. whether the immediate operands are always applied first, or /// what they are used for. - pub operands: SmallVec<[TrackedEntityHandle; 1]>, + pub operands: SmallVec<[OpOperand; 1]>, /// The set of values produced by this operation. - pub results: SmallVec<[Value; 1]>, + pub results: SmallVec<[OpResultRef; 1]>, /// If this operation represents control flow, this field stores the set of successors, /// and successor operands. 
pub successors: SmallVec<[OpSuccessor; 1]>, /// The set of regions belonging to this operation, if any pub regions: RegionList, } +impl fmt::Debug for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Operation") + .field("attrs", &self.attrs) + .field("block", &self.block.as_ref().map(|b| b.borrow().id())) + .field("operands", &self.operands) + .field("results", &self.results) + .field("successors", &self.successors) + .finish_non_exhaustive() + } +} impl AsRef for Operation { fn as_ref(&self) -> &dyn Op { self.vtable.downcast_trait().unwrap() @@ -100,8 +221,8 @@ impl AsMut for Operation { } } impl Operation { - fn uninit() -> Self { - use crate::traits::MultiTraitVtable; + pub fn uninit() -> Self { + use super::traits::MultiTraitVtable; let mut vtable = MultiTraitVtable::new::(); vtable.register_trait::(); @@ -119,6 +240,14 @@ impl Operation { } } +/// Verification +impl Operation { + pub fn verify(&self, context: &Context) -> Result<(), Report> { + let dyn_op: &dyn Op = self.as_ref(); + dyn_op.verify(context) + } +} + /// Traits/Casts impl Operation { /// Returns true if the concrete type of this operation is `T` @@ -131,7 +260,7 @@ impl Operation { #[inline] pub fn implements(&self) -> bool where - Trait: ?Sized + Pointee + 'static, + Trait: ?Sized + Pointee> + 'static, { self.vtable.implements::() } @@ -145,37 +274,57 @@ impl Operation { pub fn downcast_mut(&mut self) -> Option<&mut T> { self.vtable.downcast_mut::() } + + /// Attempt to cast this operation reference to an implementation of `Trait` + pub fn as_trait(&self) -> Option<&Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + self.vtable.downcast_trait() + } + + /// Attempt to cast this operation reference to an implementation of `Trait` + pub fn as_trait_mut(&mut self) -> Option<&mut Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + self.vtable.downcast_trait_mut() + } } /// Attributes impl Operation { /// Return the value associated with attribute `name` for this function - pub fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> + pub fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> where - Symbol: std::borrow::Borrow, + interner::Symbol: std::borrow::Borrow, Q: Ord + ?Sized, { - self.attrs.get(name) + self.attrs.get_any(name) } /// Return true if this function has an attributed named `name` pub fn has_attribute(&self, name: &Q) -> bool where - Symbol: std::borrow::Borrow, + interner::Symbol: std::borrow::Borrow, Q: Ord + ?Sized, { self.attrs.has(name) } /// Set the attribute `name` with `value` for this function. 
- pub fn set_attribute(&mut self, name: impl Into, value: impl Into) { + pub fn set_attribute( + &mut self, + name: impl Into, + value: Option, + ) { self.attrs.insert(name, value); } /// Remove any attribute with the given name from this function pub fn remove_attribute(&mut self, name: &Q) where - Symbol: std::borrow::Borrow, + interner::Symbol: std::borrow::Borrow, Q: Ord + ?Sized, { self.attrs.remove(name); @@ -184,153 +333,101 @@ impl Operation { /// Navigation impl Operation { - pub fn prev(&self) -> Option { - unsafe { - let current = core::ptr::NonNull::new_unchecked(&self.link); - LinkOps.prev(current).map(Self::link_to_key) - } + /// Returns a handle to the containing [Block] of this operation, if it is attached to one + pub fn parent(&self) -> Option { + self.block.clone() } - pub fn next(&self) -> Option { - unsafe { - let current = core::ptr::NonNull::new_unchecked(&self.link); - LinkOps.next(current).map(Self::link_to_key) - } + /// Returns a handle to the containing [Region] of this operation, if it is attached to one + pub fn parent_region(&self) -> Option { + self.block.as_ref().and_then(|block| block.borrow().parent()) } - #[inline] - unsafe fn link_to_key(link: NonNull) -> OpId { - let link = link.as_ref(); - let operation = container_of!(link, Operation, link); - let key_offset = mem::offset_of!(Operation, key); - let prev_key = operation.byte_add(key_offset as isize) as *const OpId; - *prev_key + /// Returns a handle to the nearest containing [Operation] of this operation, if it is attached + /// to one + pub fn parent_op(&self) -> Option { + self.block.as_ref().and_then(|block| block.borrow().parent_op()) } -} -/// Operands -impl Operation { - pub fn replaces_uses_of_with(&mut self, from: Value, to: Value) { - if from == to { - return; - } - - for operand in self.operands.iter_mut() { - if operand == &from { - *operand = to; + /// Returns a handle to the nearest containing [Operation] of type `T` for this operation, if it + /// is attached to one + pub fn nearest_parent_op(&self) -> Option> { + let mut parent = self.parent_op(); + while let Some(op) = parent.take() { + let entity_ref = op.borrow(); + parent = entity_ref.parent_op(); + if let Some(t_ref) = entity_ref.downcast_ref::() { + return Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(t_ref) }); } } + None } } -pub trait Op: Downcast { - type Id: Copy + PartialEq + Eq + PartialOrd + Ord; - - fn id(&self) -> Self::Id; - fn name(&self) -> &'static str; - fn parent(&self) -> Option { - let parent = self.as_operation().parent; - if parent.is_reserved_value() { - None - } else { - Some(parent) - } - } - fn prev(&self) -> Option { - self.as_operation().prev() - } - fn next(&self) -> Option { - self.as_operation().next() - } - fn parent_block(&self) -> Option { - let block = self.as_operation().block; - if block.is_reserved_value() { - None - } else { - Some(block) - } - } - fn regions(&self) -> &[RegionId] { - self.as_operation().regions.as_slice() - } - fn operands(&self) -> &ValueList { - &self.as_operation().operands - } - fn results(&self) -> &ValueList { - &self.as_operation().results - } - fn successors(&self) -> &[Successor] { - self.as_operation().successors.as_slice() +/// Regions +impl Operation { + #[inline] + pub fn has_regions(&self) -> bool { + !self.regions.is_empty() } - fn as_operation(&self) -> &Operation; - fn as_operation_mut(&mut self) -> &mut Operation; -} -impl_downcast!(Op assoc Id where Id: Copy + PartialEq + Eq + PartialOrd + Ord); - -impl miden_assembly::Spanned for dyn Op { - fn 
span(&self) -> SourceSpan { - self.as_operation().span + #[inline] + pub fn num_regions(&self) -> usize { + self.regions.len() } -} -pub trait OpExt { - /// Return the value associated with attribute `name` for this function - fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; - - /// Return true if this function has an attributed named `name` - fn has_attribute(&self, name: &Q) -> bool - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; - - /// Set the attribute `name` with `value` for this function. - fn set_attribute(&mut self, name: impl Into, value: impl Into); + #[inline(always)] + pub fn regions(&self) -> &RegionList { + &self.regions + } - /// Remove any attribute with the given name from this function - fn remove_attribute(&mut self, name: &Q) - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; + #[inline(always)] + pub fn regions_mut(&mut self) -> &mut RegionList { + &mut self.regions + } } -impl OpExt for T { - /// Return the value associated with attribute `name` for this function +/// Operands +impl Operation { #[inline] - fn get_attribute(&self, name: &Q) -> Option<&AttributeValue> - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { - self.as_operation().get_attribute(name) + pub fn has_operands(&self) -> bool { + !self.operands.is_empty() } - /// Return true if this function has an attributed named `name` #[inline] - fn has_attribute(&self, name: &Q) -> bool - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { - self.as_operation().has_attribute(name) + pub fn num_operands(&self) -> usize { + self.operands.len() } - /// Set the attribute `name` with `value` for this function. #[inline] - fn set_attribute(&mut self, name: impl Into, value: impl Into) { - self.as_operation_mut().insert(name, value); + pub fn operands(&self) -> &[OpOperand] { + self.operands.as_slice() } - /// Remove any attribute with the given name from this function - #[inline] - fn remove_attribute(&mut self, name: &Q) - where - Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { - self.as_operation_mut().remove(name); + pub fn replaces_uses_of_with(&mut self, mut from: ValueRef, mut to: ValueRef) { + if ValueRef::ptr_eq(&from, &to) { + return; + } + + let from_id = from.borrow().id(); + if from_id == to.borrow().id() { + return; + } + + for mut operand in self.operands.iter().cloned() { + if operand.borrow().value.borrow().id() == from_id { + debug_assert!(operand.is_linked()); + // Remove the operand from `from` + { + let mut from_mut = from.borrow_mut(); + let from_uses = from_mut.uses_mut(); + let mut cursor = unsafe { from_uses.cursor_mut_from_ptr(operand.clone()) }; + cursor.remove(); + } + // Add the operand to `to` + operand.borrow_mut().value = to.clone(); + to.borrow_mut().insert_use(operand); + } + } } } diff --git a/hir2/src/core/region.rs b/hir2/src/core/region.rs index d8b07cf54..c28896175 100644 --- a/hir2/src/core/region.rs +++ b/hir2/src/core/region.rs @@ -1,11 +1,10 @@ -use super::{BlockList, EntityCursor, EntityCursorMut, EntityHandle, EntityList}; +use super::*; +pub type RegionRef = UnsafeIntrusiveEntityRef; /// An intrusive, doubly-linked list of [Region]s pub type RegionList = EntityList; - /// A cursor in a [RegionList] pub type RegionCursor<'a> = EntityCursor<'a, Region>; - /// A mutable cursor in a [RegionList] pub type RegionCursorMut<'a> = EntityCursorMut<'a, Region>; @@ -13,7 +12,28 @@ pub struct Region { /// The operation this region is attached to. 
/// /// If `link.is_linked() == true`, this will always be set to a valid pointer - owner: Option>, + owner: Option, /// The list of [Block]s that comprise this region body: BlockList, } +impl Region { + /// Get the defining [Operation] for this region, if the region is attached to one. + pub fn parent(&self) -> Option { + self.owner.clone() + } + + /// Get a handle to the entry block for this region + pub fn entry(&self) -> EntityRef<'_, Block> { + self.body.front().get().unwrap() + } + + /// Get the list of blocks comprising the body of this region + pub fn body(&self) -> &BlockList { + &self.body + } + + /// Get a mutable reference to the list of blocks comprising the body of this region + pub fn body_mut(&mut self) -> &mut BlockList { + &mut self.body + } +} diff --git a/hir2/src/core/symbol_table.rs b/hir2/src/core/symbol_table.rs index e122d1557..39add5b0a 100644 --- a/hir2/src/core/symbol_table.rs +++ b/hir2/src/core/symbol_table.rs @@ -1,7 +1,5 @@ use core::any::Any; -use crate::UnsafeRef; - /// A [SymbolTable] is an IR entity which contains other IR entities, called _symbols_, each of /// which has a name, aka symbol, that uniquely identifies it amongst all other entities in the /// same [SymbolTable]. @@ -14,31 +12,19 @@ use crate::UnsafeRef; pub trait SymbolTable { /// The unique key type associated with entries in this symbol table type Key; - - /// Check if `id` is associated with an entry of type `T` in this table - fn has_symbol_of_type(&self, id: &Self::Key) -> bool - where - T: Symbol, - { - self.get::(id) - } + /// The value type of an entry in the symbol table + type Entry; /// Get the entry for `id` in this table - fn get(&self, id: &Self::Key) -> Option> - where - T: Symbol; + fn get(&self, id: &Self::Key) -> Option; /// Insert `entry` in the symbol table. /// /// Returns `true` if successful, or `false` if an entry already exists - fn insert(&self, entry: UnsafeRef) -> bool - where - T: Symbol; + fn insert(&mut self, entry: Self::Entry) -> bool; /// Remove the symbol `id`, and return the entry if one was present. - fn remove(&self, id: &Self::Key) -> Option> - where - T: Symbol; + fn remove(&mut self, id: &Self::Key) -> Option; } /// A [Symbol] is an IR entity with an associated _symbol_, or name, which is expected to be unique diff --git a/hir2/src/core/traits.rs b/hir2/src/core/traits.rs index 8414aa318..4bff9295b 100644 --- a/hir2/src/core/traits.rs +++ b/hir2/src/core/traits.rs @@ -1,6 +1,9 @@ mod multitrait; +use midenc_session::diagnostics::Severity; + pub(crate) use self::multitrait::MultiTraitVtable; +use crate::{derive, Context, Operation, Report, Spanned}; /// Marker trait for commutative ops, e.g. `X op Y == Y op X` pub trait Commutative {} @@ -20,11 +23,166 @@ pub trait MemoryWrite {} /// Marker trait for return-like ops pub trait ReturnLike {} -/// All operands of the given op are the same type -pub trait SameTypeOperands {} +/// Op is a terminator (i.e. it can be used to terminate a block) +pub trait Terminator {} -/// Marker trait for ops whose regions contain only a single block -pub trait SingleBlock {} +/// Marker trait for idemptoent ops, i.e. `op op X == op X (unary) / X op X == X (binary)` +pub trait Idempotent {} -/// Marker trait for ops which can terminate a block -pub trait Terminator {} +/// Marker trait for ops that exhibit the property `op op X == X` +pub trait Involution {} + +/// Marker trait for ops which are not permitted to access values defined above them +pub trait IsolatedFromAbove {} + +derive! 
{ + /// Marker trait for unary ops, i.e. those which take a single operand + pub trait UnaryOp {} + + verify { + fn is_unary_op(op: &Operation, context: &Context) -> Result<(), Report> { + if op.num_operands() == 1 { + Ok(()) + } else { + Err( + context.session.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 1, got {}", op.num_operands())) + .with_help("this operator implements 'UnaryOp', which requires it to have exactly one operand") + .into_report() + ) + } + } + } +} + +derive! { + /// Marker trait for binary ops, i.e. those which take two operands + pub trait BinaryOp {} + + verify { + fn is_binary_op(op: &Operation, context: &Context) -> Result<(), Report> { + if op.num_operands() == 2 { + Ok(()) + } else { + Err( + context.session.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 2, got {}", op.num_operands())) + .with_help("this operator implements 'BinaryOp', which requires it to have exactly two operands") + .into_report() + ) + } + } + } +} + +derive! { + /// Op expects all operands to be of the same type + pub trait SameTypeOperands {} + + verify { + fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { + if let Some((first_operand, operands)) = op.operands().split_first() { + let (expected_ty, set_by) = { + let operand = first_operand.borrow(); + let value = operand.value(); + (value.ty().clone(), value.span()) + }; + for operand in operands { + let operand = operand.borrow(); + let value = operand.value(); + let value_ty = value.ty(); + if value_ty != &expected_ty { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation expects all operands to be of the same type" + ) + .with_secondary_label( + set_by, + "inferred the expected type from this value" + ) + .with_secondary_label( + value.span(), + "which differs from this value" + ) + .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) + .into_report() + ); + } + } + } + + Ok(()) + } + } +} + +derive! { + /// Op expects all operands and results to be of the same type + /// + /// TODO(pauls): Implement verification for this. Ideally we could require `SameTypeOperands` + /// as a super trait, check the operands using its implementation, and then check the results + /// separately + pub trait SameOperandsAndResultType {} +} + +derive! { + /// Op's regions have no arguments + pub trait NoRegionArguments {} + + verify { + fn no_region_arguments(op: &Operation, context: &Context) -> Result<(), Report> { + for region in op.regions().iter() { + if region.entry().has_arguments() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation does not permit regions with arguments, but one was found" + ) + .into_report()); + } + } + + Ok(()) + } + } +} + +derive! 
{ + /// Op's regions have a single block + pub trait SingleBlock {} + + verify { + fn has_only_single_block_regions(op: &Operation, context: &Context) -> Result<(), Report> { + for region in op.regions().iter() { + if region.body().iter().count() > 1 { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation requires single-block regions, but regions with multiple \ + blocks were found", + ) + .into_report()); + } + } + + Ok(()) + } + } +} diff --git a/hir2/src/core/traits/multitrait.rs b/hir2/src/core/traits/multitrait.rs index c7a3be62e..c71d98f37 100644 --- a/hir2/src/core/traits/multitrait.rs +++ b/hir2/src/core/traits/multitrait.rs @@ -1,49 +1,105 @@ use core::{ any::{Any, TypeId}, - ptr::{null, null_mut}, + marker::Unsize, + ptr::{null, null_mut, DynMetadata, Pointee}, }; +struct TraitImpl { + /// The [TypeId] of the trait type, used as a unique key for [TraitImpl]s + type_id: TypeId, + /// Type-erased dyn metadata containing the trait vtable pointer for the concrete type + /// + /// This is transmuted to the correct trait type when reifying a `&dyn Trait` reference, + /// which is safe as `DynMetadata` is always the same size for all types. + metadata: DynMetadata, +} +impl TraitImpl { + fn new() -> Self + where + T: Any + Unsize, + Trait: ?Sized + Pointee> + 'static, + { + let type_id = TypeId::of::(); + let ptr = null::() as *const Trait; + let (_, metadata) = ptr.to_raw_parts(); + Self { + type_id, + metadata: unsafe { + core::mem::transmute::, DynMetadata>(metadata) + }, + } + } + + unsafe fn metadata_unchecked(&self) -> DynMetadata + where + Trait: ?Sized + Pointee> + 'static, + { + debug_assert!(self.type_id == TypeId::of::()); + core::mem::transmute(self.metadata) + } +} +impl Eq for TraitImpl {} +impl PartialEq for TraitImpl { + fn eq(&self, other: &Self) -> bool { + self.type_id == other.type_id + } +} +impl PartialEq for TraitImpl { + fn eq(&self, other: &TypeId) -> bool { + self.type_id.eq(other) + } +} +impl PartialOrd for TraitImpl { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.type_id.cmp(&other.type_id)) + } +} +impl PartialOrd for TraitImpl { + fn partial_cmp(&self, other: &TypeId) -> Option { + Some(self.type_id.cmp(other)) + } +} +impl Ord for TraitImpl { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.type_id.cmp(&other.type_id) + } +} + pub(crate) struct MultiTraitVtable { - pub(crate) data: *mut (), - pub(crate) type_id: TypeId, - pub(crate) traits: Vec<(TypeId, *const ())>, + data: *mut (), + type_id: TypeId, + traits: Vec, } impl MultiTraitVtable { pub fn new() -> Self { let type_id = TypeId::of::(); - let (any_type, any_vtable) = { - let ptr = null::().cast::(); - let (_, vtable) = ptr.to_raw_parts(); - (TypeId::of::(), vtable) - }; + let any_impl = TraitImpl::new::(); Self { data: null_mut(), type_id, - traits: vec![(any_type, any_vtable)], + traits: vec![any_impl], } } - pub fn set_data_ptr(&mut self, ptr: *mut T) { - let type_id = TypeId::of::(); - assert_eq!(self.type_id, type_id); - self.data = data.cast(); + pub(crate) unsafe fn set_data_ptr(&mut self, ptr: *mut T) { + assert!(!ptr.is_null()); + assert!(ptr.is_aligned()); + assert!(self.is::()); + self.data = ptr.cast(); } - pub fn register_trait(&mut self) + pub fn register_trait(&mut self) where - Trait: ?Sized + Pointee + 'static, + T: Any + Unsize + 'static, + Trait: ?Sized + Pointee> + 'static, { - let (type_id, vtable) = { - let ptr = null::().cast::(); - 
let (_, vtable) = ptr.to_raw_parts(); - (TypeId::of::(), vtable) - }; - if self.traits.iter().any(|(tid, _)| tid == &type_id) { - return; + let trait_impl = TraitImpl::new::(); + match self.traits.binary_search(&trait_impl) { + Ok(_) => (), + Err(index) if index + 1 == self.traits.len() => self.traits.push(trait_impl), + Err(index) => self.traits.insert(index, trait_impl), } - self.traits.push((type_id, vtable)); - self.traits.sort_by_key(|(tid, _)| tid); } #[inline] @@ -53,16 +109,16 @@ impl MultiTraitVtable { pub fn implements(&self) -> bool where - Trait: ?Sized + Pointee + 'static, + Trait: ?Sized + Pointee> + 'static, { let type_id = TypeId::of::(); - self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).is_ok() + self.traits.binary_search_by(|ti| ti.type_id.cmp(&type_id)).is_ok() } #[inline] pub fn downcast_ref(&self) -> Option<&T> { if self.is::() { - Some(unsafe { self.downcast_reF_unchecked() }) + Some(unsafe { self.downcast_ref_unchecked() }) } else { None } @@ -70,7 +126,7 @@ impl MultiTraitVtable { #[inline(always)] unsafe fn downcast_ref_unchecked(&self) -> &T { - core::ptr::from_raw_parts(self.data, ()) + &*core::ptr::from_raw_parts(self.data.cast::(), ()) } #[inline] @@ -84,26 +140,34 @@ impl MultiTraitVtable { #[inline(always)] unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { - core::ptr::from_raw_parts(self.data, ()) + &mut *core::ptr::from_raw_parts_mut(self.data.cast::(), ()) } pub fn downcast_trait(&self) -> Option<&Trait> where - Trait: ?Sized + Pointee + 'static, + Trait: ?Sized + Pointee> + 'static, { - self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).map(|index| { - let vtable = self.traits[index].1; - core::ptr::from_raw_parts::(self.data, vtable) - }) + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &*core::ptr::from_raw_parts(self.data, metadata) }) } pub fn downcast_trait_mut(&mut self) -> Option<&mut Trait> where - Trait: ?Sized + Pointee + 'static, + Trait: ?Sized + Pointee> + 'static, { - self.traits.binary_search_by(|(tid, _)| tid.cmp(&type_id)).map(|index| { - let vtable = self.traits[index].1; - core::ptr::from_raw_parts_mut::(self.data, vtable) - }) + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &mut *core::ptr::from_raw_parts_mut(self.data, metadata) }) + } + + fn get(&self) -> Option<&TraitImpl> { + let type_id = TypeId::of::(); + self.traits + .binary_search_by(|ti| ti.type_id.cmp(&type_id)) + .ok() + .map(|index| &self.traits[index]) } } diff --git a/hir2/src/core/usable.rs b/hir2/src/core/usable.rs index 98afa56d6..ce0ff554a 100644 --- a/hir2/src/core/usable.rs +++ b/hir2/src/core/usable.rs @@ -1,4 +1,6 @@ -use super::{entity::EntityIter, EntityCursor, EntityCursorMut}; +use super::{ + entity::EntityIter, EntityCursor, EntityCursorMut, EntityList, UnsafeIntrusiveEntityRef, +}; /// The [Usable] trait is implemented for IR entities which are _defined_ and _used_, and as a /// result, require a data structure called the _use-def list_. 
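A quick aside on the `TraitImpl`/`MultiTraitVtable` changes above before continuing with `Usable`: the core trick is capturing `DynMetadata` for a concrete type's trait impl up front, then rebuilding a fat `&dyn Trait` reference from a type-erased data pointer plus that stored metadata. A minimal standalone sketch of the same trick, using the nightly `ptr_metadata` feature this crate already enables; the `u32`/`dyn Debug` pairing is just for illustration:

```rust
#![feature(ptr_metadata)]
use core::fmt::Debug;
use core::ptr::{self, DynMetadata};

fn main() {
    // Capture the vtable metadata for `u32 as dyn Debug` from a null pointer,
    // as `TraitImpl::new` does above; no data pointer is required for this step.
    let meta: DynMetadata<dyn Debug> = ptr::metadata(ptr::null::<u32>() as *const dyn Debug);

    // Later, rebuild a fat `&dyn Debug` from a type-erased data pointer plus
    // the stored metadata, as `downcast_trait` does.
    let value: u32 = 42;
    let data = &value as *const u32 as *const ();
    let raw: *const dyn Debug = ptr::from_raw_parts(data, meta);
    let debug: &dyn Debug = unsafe { &*raw };
    assert_eq!(format!("{debug:?}"), "42");
}
```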
@@ -25,11 +27,17 @@ pub trait Usable { /// Returns true if this definition is used fn is_used(&self) -> bool; + /// Get a list of uses of this definition + fn uses(&self) -> &EntityList; + /// Get a mutable list of uses of this definition + fn uses_mut(&mut self) -> &mut EntityList; /// Get an iterator over the uses of this definition - fn uses(&self) -> EntityIter<'_, Self::Use>; + fn iter_uses(&self) -> EntityIter<'_, Self::Use>; /// Get a cursor positioned on the first use of this definition, or the null cursor if unused. fn first_use(&self) -> EntityCursor<'_, Self::Use>; /// Get a mutable cursor positioned on the first use of this definition, or the null cursor if /// unused. fn first_use_mut(&mut self) -> EntityCursorMut<'_, Self::Use>; + /// Add `user` to the set of uses of this definition + fn insert_use(&mut self, user: UnsafeIntrusiveEntityRef); } diff --git a/hir2/src/core/value.rs b/hir2/src/core/value.rs index 76d3b5a37..c8ebe55f1 100644 --- a/hir2/src/core/value.rs +++ b/hir2/src/core/value.rs @@ -1,166 +1,253 @@ -use core::{fmt, ptr}; +use core::fmt; -use super::{Block, EntityCursor, EntityCursorMut, EntityIter, EntityList, Type, Usable}; -use crate::{SourceSpan, Spanned}; +use super::*; -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum ValueKind { - Result, - BlockArgument, -} +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct ValueId(u32); +impl ValueId { + pub const fn from_u32(id: u32) -> Self { + Self(id) + } -#[derive(Spanned)] -pub struct ValueImpl { - kind: ValueKind, - ty: Type, - #[span] - span: SourceSpan, - uses: OpOperandList, + pub const fn as_u32(&self) -> u32 { + self.0 + } } -impl ValueImpl { +impl EntityId for ValueId { #[inline(always)] - pub const fn kind(&self) -> ValueKind { - self.kind - } - - pub fn is_result(&self) -> bool { - matches!(self, ValueKind::Result) + fn as_usize(&self) -> usize { + self.0 as usize } - - pub fn is_block_argument(&self) -> bool { - matches!(self, ValueKind::BlockArgument) +} +impl fmt::Debug for ValueId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}", &self.0) } - - #[inline(always)] - pub fn ty(&self) -> &Type { - &self.ty +} +impl fmt::Display for ValueId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}", &self.0) } +} - #[inline(always)] - pub fn set_type(&mut self, ty: Type) { - self.ty = ty; - } +pub trait Value: Entity + Spanned + Usable + fmt::Debug { + fn ty(&self) -> &Type; + fn set_type(&mut self, ty: Type); } -impl Usable for ValueImpl { - type Use = OpOperand; - #[inline] - fn is_used(&self) -> bool { - !self.uses.is_empty() - } +macro_rules! 
value_impl { + ( + $(#[$outer:meta])* + $vis:vis struct $ValueKind:ident { + $( + $(*[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } - #[inline] - fn uses(&self) -> OpOperandIter<'_> { - self.uses.iter() - } + $($t:tt)* + ) => { + $(#[$outer])* + #[derive(Spanned)] + $vis struct $ValueKind { + id: ValueId, + #[span] + span: SourceSpan, + ty: Type, + uses: OpOperandList, + $( + $(#[$inner $($args)*])* + $Field: $FieldTy + ),* + } - #[inline] - fn first_use(&self) -> OpOperandCursor<'_> { - self.uses.front() - } + impl $ValueKind { + pub fn new( + id: ValueId, + ty: Type, + $( + $Field: $FieldTy + ),* + ) -> Self { + Self { + id, + ty, + span: Default::default(), + uses: Default::default(), + $( + $Field + ),* + } + } + } - #[inline] - fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { - self.uses.front_mut() - } -} -impl fmt::Debug for ValueImpl { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("ValueImpl") - .field("kind", &self.kind) - .field("ty", &self.ty) - .field("uses", &self.uses) - .finish() - } -} + impl Value for $ValueKind { + fn ty(&self) -> &Type { + &self.ty + } -pub type Value = EntityHandle; + fn set_type(&mut self, ty: Type) { + self.ty = ty; + } + } -#[derive(Spanned)] -pub struct BlockArgument { - #[span] - value: ValueImpl, - owner: EntityHandle, - index: u8, -} -impl Usable for BlockArgument { - type Use = OpOperand; + impl Entity for $ValueKind { + type Id = ValueId; - #[inline] - fn is_used(&self) -> bool { - self.value.is_used() - } + #[inline(always)] + fn id(&self) -> Self::Id { + self.id + } + } - #[inline] - fn uses(&self) -> OpOperandIter<'_> { - self.value.uses() - } + impl Usable for $ValueKind { + type Use = OpOperandImpl; + + #[inline] + fn is_used(&self) -> bool { + !self.uses.is_empty() + } + + #[inline(always)] + fn uses(&self) -> &OpOperandList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut OpOperandList { + &mut self.uses + } + + #[inline] + fn iter_uses(&self) -> OpOperandIter<'_> { + self.uses.iter() + } + + #[inline] + fn first_use(&self) -> OpOperandCursor<'_> { + self.uses.front() + } + + #[inline] + fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { + self.uses.front_mut() + } + + fn insert_use(&mut self, user: OpOperand) { + self.uses.push_back(user); + } + } - #[inline] - fn first_use(&self) -> OpOperandCursor<'_> { - self.value.first_use() - } + impl fmt::Debug for $ValueKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut builder = f.debug_struct(stringify!($ValueKind)); + builder + .field("id", &self.id) + .field("ty", &self.ty) + .field("uses", &self.uses); - #[inline] - fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { - self.value.first_use_mut() + $( + builder.field(stringify!($Field), &self.$Field); + )* + + builder.finish_non_exhaustive() + } + } + + $($t)* } } -/// An [OpResult] represents the definition of a [Value] by the result of an [Operation] -#[derive(Spanned)] -pub struct OpResult { - #[span] - value: ValueImpl, - owner: EntityHandle, - index: u8, -} -impl Usable for OpResult { - type Use = OpOperand; +pub type ValueRef = UnsafeEntityRef; +pub type BlockArgumentRef = UnsafeEntityRef; +pub type OpResultRef = UnsafeEntityRef; - #[inline] - fn is_used(&self) -> bool { - self.value.is_used() +value_impl!( + /// A [BlockArgument] represents the definition of a [Value] by a block parameter + pub struct BlockArgument { + owner: BlockRef, + index: u8, } +); - #[inline] - fn uses(&self) -> OpOperandIter<'_> { - 
self.value.uses() +value_impl!( + /// An [OpResult] represents the definition of a [Value] by the result of an [Operation] + pub struct OpResult { + owner: OperationRef, + index: u8, } +); - #[inline] - fn first_use(&self) -> OpOperandCursor<'_> { - self.value.first_use() +impl BlockArgument { + pub fn owner(&self) -> BlockRef { + self.owner.clone() } - #[inline] - fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { - self.value.first_use_mut() + pub fn index(&self) -> usize { + self.index as usize } } -pub type OpOperandList = EntityList; -pub type OpOperandIter<'a> = EntityIter<'a, OpOperand>; -pub type OpOperandCursor<'a> = EntityCursor<'a, OpOperand>; -pub type OpOperandCursorMut<'a> = EntityCursorMut<'a, OpOperand>; +impl OpResult { + pub fn owner(&self) -> OperationRef { + self.owner.clone() + } + + pub fn index(&self) -> usize { + self.index as usize + } +} + +pub type OpOperand = UnsafeIntrusiveEntityRef; +pub type OpOperandList = EntityList; +pub type OpOperandIter<'a> = EntityIter<'a, OpOperandImpl>; +pub type OpOperandCursor<'a> = EntityCursor<'a, OpOperandImpl>; +pub type OpOperandCursorMut<'a> = EntityCursorMut<'a, OpOperandImpl>; /// An [OpOperand] represents a use of a [Value] by an [Operation] -pub struct OpOperand { +pub struct OpOperandImpl { /// The operand value - pub value: Value, + pub value: ValueRef, /// The owner of this operand, i.e. the operation it is an operand of - pub owner: EntityHandle, + pub owner: OperationRef, /// The index of this operand in the operand list of an operation pub index: u8, } -impl OpOperand { +impl OpOperandImpl { #[inline] - pub fn new(value: Value, owner: EntityHandle, index: u8) -> Self { + pub fn new(value: ValueRef, owner: OperationRef, index: u8) -> Self { Self { value, owner, index, } } + + pub fn value(&self) -> EntityRef<'_, dyn Value> { + self.value.borrow() + } +} +impl fmt::Debug for OpOperandImpl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + #[allow(unused)] + struct ValueInfo<'a> { + id: ValueId, + ty: &'a Type, + } + + let value = self.value.borrow(); + let id = value.id(); + let ty = value.ty(); + f.debug_struct("OpOperand") + .field("index", &self.index) + .field("value", &ValueInfo { id, ty }) + .finish_non_exhaustive() + } +} + +pub enum OpOperandValue { + Value(ValueRef), + Immediate(Immediate), } diff --git a/hir2/src/core/verifier.rs b/hir2/src/core/verifier.rs new file mode 100644 index 000000000..9fbaad483 --- /dev/null +++ b/hir2/src/core/verifier.rs @@ -0,0 +1,213 @@ +use super::{Context, Report}; + +/// The `OpVerifier` trait is expected to be implemented by all [Op] impls as a prequisite. +/// +/// The actual implementation is typically generated as part of deriving [Op]. +pub trait OpVerifier { + fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// The `Verify` trait represents verification logic associated with implementations of some trait. +/// +/// This is specifically used for automatically deriving verification checks for [Op] impls that +/// implement traits that imply constraints on the representation or behavior of that op. +/// +/// For example, if some [Op] derives an op trait like `SingleBlock`, this information is recorded +/// in the underlying [Operation] metadata, so that we can recover a trait object reference for the +/// trait when needed. However, just deriving the trait is not sufficient to guarantee that the op +/// actually adheres to the implicit constraints and behavior of that trait. 
For example, +/// `SingleBlock` implies that the implementing op contains only regions that consist of a single +/// [Block]. This cannot be checked statically. The first step to addressing this though, is to +/// reify the implicit validation rules as explicit checks - hence this trait. +/// +/// So we've established that some op traits, such as `SingleBlock` mentioned above, have implicit +/// validation rules, and we can implement [Verify] to make the implicit validation rules of such +/// traits explicit - but how do we ensure that when an op derives an op trait, that the [Verify] +/// impl is also derived, _and_ that it is called when the op is verified? +/// +/// The answer lies in the use of some tricky type-level code to accomplish the following goals: +/// +/// * Do not emit useless checks for op traits that have no verification rules +/// * Do not require storing data in each instance of an [Op] just to verify a trait +/// * Do not require emitting a bunch of redundant type checks for information we know statically +/// * Be able to automatically derive all of the verification machinery along with the op traits +/// +/// The way this works is as follows: +/// +/// * We `impl Verify for T where T: Op` for every trait `Trait` with validation rules +/// * A blanket impl of [HasVerifier] exists for all `T: Verify`. This is a market trait used +/// in conjunction with specialization. See the trait docs for more details on its purpose. +/// * The [Verifier] trait provides a default vacuous impl for all `Trait` and `T` pairs. However, +/// we also provided a specialized [Verifier] impl for all `T: Verify` using the +/// `HasVerifier` marker. The specialized impl applies the underlying `Verify` impl. +/// * When deriving the op traits for an `Op` impl, we generate a hidden type that encodes all of +/// the op traits implemented by the op. We then generate an `OpVerifier` impl for the op, which +/// uses the hidden type we generated to reify the `Verifier` impl for each trait. The +/// `OpVerifier` implementation uses const eval to strip out all of the vacuous verifier impls, +/// leaving behind just the "real" verification rules specific to the traits implemented by that +/// op. +/// * The `OpVerifier` impl is object-safe, and is in fact a required super-trait of `Op` to ensure +/// that verification is part of defining an `Op`, but also to ensure that `verify` is a method +/// of `Op`, and that we can cast an `Operation` to `&dyn OpVerifier` and call `verify` on that. +/// +/// As a result of all this, we end up with highly-specialized verifiers for each op, with no +/// dynamic dispatch, and automatically maintained as part of the `Op` definition. When a new +/// op trait is derived, the verifier for the op is automatically updated to verify the new trait. +pub trait Verify { + /// In cases where verification may be disabled via runtime configuration, or based on + /// dynamic properties of the type, this method can be overridden and used to signal to + /// the verification driver that verification should be skipped on this item. + #[inline(always)] + #[allow(unused_variables)] + fn should_verify(&self, context: &Context) -> bool { + true + } + /// Apply this verifier, but only if [Verify::should_verify] returns true. + #[inline] + fn maybe_verify(&self, context: &Context) -> Result<(), Report> { + if self.should_verify(context) { + self.verify(context) + } else { + Ok(()) + } + } + /// Apply this verifier to the current item. 
+ fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// A marker trait used for verifier specialization. +/// +/// # Safety +/// +/// In order for the `#[rustc_unsafe_specialization_marker]` attribute to be used safely and +/// correctly, the following rules must hold: +/// +/// * No associated items +/// * No impls with lifetime constraints, as specialization will ignore them +/// +/// For our use case, which is specializing verification for a given type and trait combination, +/// by optimizing out verification-related code for type combinations which have no verifier, these +/// are easy rules to uphold. +/// +/// However, we must ensure that we continue to uphold these rules moving forward. +#[rustc_unsafe_specialization_marker] +unsafe trait HasVerifier: Verify {} + +// While at first glance, it appears we would be using this to specialize on the fact that a type +// _has_ a verifier, we're actually using this to specialize on the _absence_ of a verifier. See +// `Verifier` for more information. +unsafe impl HasVerifier for T where T: Verify {} + +/// The `Verifier` trait is used to derive a verifier for a given trait and concrete type. +/// +/// It does this by providing a default implementation for all combinations of `Trait` and `T`, +/// which always succeeds, and then specializing that implementation for `T: HasVerifier`. +/// +/// This has the effect of making all traits "verifiable", but only actually doing any verification +/// for types which implement `Verify`. +/// +/// We go a step further and actually set things up so that `rustc` can eliminate all of the dead +/// code when verification is vacuous. See the `trait_verifier` function for details on how that +/// is used in practice. +/// +/// NOTE: Because this trait provides a default blanket impl for all `T`, you should avoid bringing +/// it into scope unless absolutely needed. It is virtually always preferred to explicitly invoke +/// this trait using turbofish syntax, so as to avoid conflict with the [Verify] trait, and to +/// avoid polluting the namespace for all types in scope. +pub trait Verifier { + /// An implementation of `Verifier` sets this flag to true when its implementation is vacuous, + /// i.e. it always succeeds and is not dependent on runtime context. + /// + /// The default implementation of this trait sets this to `true`, since without a verifier for + /// the type, verification always succeeds. However, we can specialize on the presence of + /// a verifier and set this to `false`, which will result in all of the verification logic + /// being applied. + /// + /// ## Example Usage + /// + /// Shown below is an example of how one can use const eval to eliminate dead code branches + /// in verifier selection, so that the resulting implementation is specialized and able to + /// have more optimizations applied as a result. 
+ /// + /// ```rust,ignore + /// #[inline(always)] + /// fn noop(&T, &Context) -> Result<(), Report> { Ok(()) } + /// let verify_fn = const { + /// if >::VACUOUS { + /// noop + /// } else { + /// >::maybe_verify + /// } + /// }; + /// verify_fn(op, context) + /// ``` + const VACUOUS: bool; + + /// Checks if this verifier is applicable for the current item + fn should_verify(&self, context: &Context) -> bool; + /// Applies the verifier for this item, if [Verifier::should_verify] returns `true` + fn maybe_verify(&self, context: &Context) -> Result<(), Report>; + /// Applies the verifier for this item + fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// The default blanket impl for all types and traits +impl Verifier for T { + default const VACUOUS: bool = true; + + #[inline(always)] + default fn should_verify(&self, _context: &Context) -> bool { + false + } + + #[inline(always)] + default fn maybe_verify(&self, _context: &Context) -> Result<(), Report> { + Ok(()) + } + + #[inline(always)] + default fn verify(&self, _context: &Context) -> Result<(), Report> { + Ok(()) + } +} + +/// THe specialized impl for types which implement `Verify` +impl Verifier for T +where + T: HasVerifier, +{ + const VACUOUS: bool = false; + + #[inline] + fn should_verify(&self, context: &Context) -> bool { + >::should_verify(self, context) + } + + #[inline(always)] + fn maybe_verify(&self, context: &Context) -> Result<(), Report> { + >::maybe_verify(self, context) + } + + #[inline] + fn verify(&self, context: &Context) -> Result<(), Report> { + >::verify(self, context) + } +} + +#[cfg(test)] +mod tests { + use core::hint::black_box; + + use super::*; + use crate::{traits::SingleBlock, Operation}; + + struct Vacuous; + + /// In this test, we're validating that a type that trivially verifies specializes as vacuous, + /// and that a type we know has a "real" verifier, specializes as _not_ vacuous + #[test] + fn verifier_specialization_concrete() { + assert!(black_box(>::VACUOUS)); + assert!(black_box(!>::VACUOUS)); + } +} diff --git a/hir2/src/demangle.rs b/hir2/src/demangle.rs new file mode 100644 index 000000000..98493de1c --- /dev/null +++ b/hir2/src/demangle.rs @@ -0,0 +1,13 @@ +/// Demangle `name`, where `name` was mangled using Rust's mangling scheme +#[inline] +pub fn demangle>(name: S) -> String { + demangle_impl(name.as_ref()) +} + +fn demangle_impl(name: &str) -> String { + let mut input = name.as_bytes(); + let mut demangled = Vec::with_capacity(input.len() * 2); + rustc_demangle::demangle_stream(&mut input, &mut demangled, /* include_hash= */ false) + .expect("failed to write demangled identifier"); + String::from_utf8(demangled).expect("demangled identifier contains invalid utf-8") +} diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs new file mode 100644 index 000000000..e81fc641c --- /dev/null +++ b/hir2/src/derive.rs @@ -0,0 +1,866 @@ +use crate::Operation; + +/// This macro is used to generate the boilerplate for [Op] implementations. +#[macro_export] +macro_rules! derive { + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn:ident($op:ident: &$OperationPath:path, $ctx:ident: &$ContextPath:path) -> $VerifyResult:ty $verify:block + )+ + } + + $($t:tt)* + ) => { + $crate::__derive_op_trait! 
{ + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn($op: &$OperationPath, $ctx: &$ContextPath) -> $VerifyResult $verify + )* + } + } + + $($t)* + }; + + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + $($t:tt)* + ) => { + $crate::__derive_op_trait! { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem:item + )* + } + } + + $($t)* + }; + + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident : Op { + $( + $(#[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } + + $($t:tt)* + ) => { + $crate::__derive_op!( + $(#[$outer])* + #[derive($crate::Spanned)] + $vis struct $Op { + $( + $(#[$inner $($args)*])* + $Field: $FieldTy + ),* + } + ); + + $($t)* + }; + + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident : Op implements $OpTrait:ident { + $( + $(#[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } + + $($t:tt)* + ) => { + $crate::__derive_op!( + $(#[$outer])* + $vis struct $Op { + $( + $(#[$inner $($args)*])* + $Field: $FieldTy, + )* + } + + implement $OpTrait; + ); + + $($t)* + }; + + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident : Op implements $OpTrait1:ident $(, $OpTraitRest:ident)* { + $( + $(#[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } + + $($t:tt)* + ) => { + $crate::__derive_op!( + $(#[$outer])* + $vis struct $Op { + $( + $(#[$inner $($args)*])* + $Field: $FieldTy, + )* + } + + implement $OpTrait1 + $(, implement $OpTraitRest)*; + ); + + $($t)* + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op_trait { + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn:ident($op:ident: &$OperationPath:path, $ctx:ident: &$ContextPath:path) -> $VerifyResult:ty $verify:block + )+ + } + ) => { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem + )* + } + + impl $crate::Verify for T { + #[inline] + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + <$crate::Operation as $crate::Verify>::verify(self.as_operation(), context) + } + } + + impl $crate::Verify for $crate::Operation { + fn should_verify(&self, _context: &$crate::Context) -> bool { + self.implements::() + } + + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + $( + #[inline] + fn $verify_fn($op: &$OperationPath, $ctx: &$ContextPath) -> $VerifyResult $verify + )* + + $( + $verify_fn(self, context)?; + )* + + Ok(()) + } + } + }; + + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + ) => { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem + )* + } + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op { + // Entry + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident { + $( + $(#[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } + + $(implement $OpTrait:ident),*; + ) => { + $crate::__derive_op! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + $( + { + unprocessed: [$(#[$inner $($args)*])*], + field: $Field, + field_type: $FieldTy, + } + )* + ], + processed: { + dialect: [], + traits: [$(implement $OpTrait),*], + attrs: [], + operands_count: [0usize], + operands: [], + results_count: [0usize], + results: [], + } + } + }; + + // Handle duplicate `dialect` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[dialect] + $($attrs_rest:tt)* + ], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [$(dialect_processed:tt)+], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + compile_error!("unexpected duplicate dialect attr: got '{}', but '{}' was previously seen", stringify!($Dialect), stringify!($dialect_processed)); + }; + + // Handle `dialect` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[dialect] + $($attrs_rest:tt)* + ], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op! { + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + dialect: [dialect $FieldTy], + traits: [$(implement $OpTrait),*], + attrs: [$($attrs_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `operand` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[operand $($args:tt)*] + $($attrs_rest:tt)* + ], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [$($dialect_processed:tt)*], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + dialect: [$($dialect_processed)*], + traits: [$(implement $OpTrait),*], + attrs: [$($attrs_processed)*], + operands_count: [1usize + $operands_count], + operands: [operand $Field at $operands_count $($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `result` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[result $($args:tt)*] + $($attrs_rest:tt)* + ], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [$($dialect_processed:tt)*], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op! { + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + dialect: [$($dialect_processed)*], + traits: [$(implement $OpTrait),*], + attrs: [$($attrs_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [1usize + $results_count], + results: [result $Field at $results_count $($results_processed)*], + } + } + }; + + // Handle `attr` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[attr $($args:tt)*] + $($attrs_rest:tt)* + ], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [$($dialect_processed:tt)*], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op! { + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + dialect: [$($dialect_processed)*], + traits: [$(implement $OpTrait),*], + attrs: [attr $Field: $FieldTy $($attrs_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle end of unprocessed attributes + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + dialect: [$($dialect_processed:tt)*], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + $($fields_rest)* + ], + processed: { + dialect: [$($dialect_processed)*], + traits: [$(implement $OpTrait),*], + attrs: [$($attrs_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle end of unprocessed fields + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [], + processed: { + dialect: [$($dialect_processed:tt)*], + traits: [$(implement $OpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_impl!( + $(#[$outer])* + $vis struct $Op; + + $($dialect_processed)*; + $(implement $OpTrait),*; + $($attrs_processed)*; + $($operands_processed)*; + $($results_processed)*; + ); + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op_impl { + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + dialect $Dialect:ty; + $(implement $OpTrait:ident),*; + $(attr $AttrField:ident: $AttrTy:ty)*; + $(operand $Operand:ident at $OperandIdx:expr)*; + $(result $Result:ident at $ResultIdx:expr)*; + + ) => { + $(#[$outer])* + #[derive(Spanned)] + $vis struct $Op { + #[span] + op: $crate::Operation, + } + + #[allow(unused)] + impl $Op { + /// Get a new, uninitialized instance of this op + pub fn uninit() -> Self { + Self { + op: $crate::Operation::uninit::(), + } + } + + pub fn create( + context: &$crate::Context + $( + , $Operand: $crate::ValueRef + )* + $( + , $AttrField: $AttrTy + )* + ) -> Result<$crate::UnsafeIntrusiveEntityRef<$Op>, $crate::Report> { + let mut builder = $crate::OperationBuilder::::new(context, Self::uninit()); + $( + builder.implement::(); + )* + $( + builder.with_attr(stringify!($AttrField), $AttrField); + )* + builder.with_operands([$($Operand),*]); + let num_results = const { + let results: &[usize] = &[$($ResultIdx),*]; + results.len() + }; + builder.with_results(num_results); + builder.build() + } + + $( + fn $AttrField(&self) -> $AttrTy { + let sym = stringify!($AttrField); + let value = self.op.get_attribute(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap(); + value.downcast_ref::<$AttrTy>().unwrap().clone() + } + )* + + $( + fn $Operand(&self) -> $crate::OpOperand { + self.operands()[$OperandIdx].clone() + } + )* + + $( + fn $Result(&self) -> $crate::ValueRef { + self.results()[$ResultIdx].clone() + } + )* + } + + impl AsRef<$crate::Operation> for $Op { + #[inline(always)] + fn as_ref(&self) -> &$crate::Operation { + &self.op + } + } + + impl AsMut<$crate::Operation> for $Op { + #[inline(always)] + fn as_mut(&mut self) -> &mut $crate::Operation { + &mut self.op + } + } + + __derive_op_name!($Op); + + impl $crate::Op for $Op { + fn name(&self) -> $crate::OperationName { + const DIALECT: $Dialect = <$Dialect as $crate::Dialect>::INIT; + let dialect = <$Dialect as $crate::Dialect>::name(&DIALECT); + paste::paste! { + $crate::OperationName::new(dialect, *[<__ $Op _NAME>]) + } + } + + #[inline(always)] + fn as_operation(&self) -> &$crate::Operation { + &self.op + } + + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut $crate::Operation { + &mut self.op + } + } + + __derive_op_traits!($Op, $($OpTrait),*); + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op_name { + ($Op:ident) => { + paste::paste! 
{ + #[allow(non_upper_case_globals)] + static [<__ $Op _NAME>]: ::std::sync::LazyLock<::midenc_hir_symbol::Symbol> = ::std::sync::LazyLock::new(|| { + // CondBrOp => CondBr => cond_br + // Add => add + let type_name = stringify!($Op); + let type_name = type_name.strip_suffix("Op").unwrap_or(type_name); + let mut buf = ::alloc::string::String::with_capacity(type_name.len()); + let mut word_started_at = None; + for (i, c) in type_name.char_indices() { + if c.is_ascii_uppercase() { + if word_started_at.is_some() { + buf.push('_'); + buf.push(c.to_ascii_lowercase()); + } else { + word_started_at = Some(i); + buf.push(c.to_ascii_lowercase()); + } + } else if word_started_at.is_none() { + word_started_at = Some(i); + buf.push(c); + } else { + buf.push(c); + } + } + ::midenc_hir_symbol::Symbol::intern(buf) + }); + } + } +} + +/// This macro emits the trait derivations and specialized verifier for a given [Op] impl. +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op_traits { + ($T:ty) => { + impl $crate::OpVerifier for $T { + #[inline(always)] + fn verify(&self, _context: &$crate::Context) -> Result<(), $crate::Report> { + Ok(()) + } + } + }; + + ($T:ty, $($Trait:ident),+) => { + $( + impl $Trait for $T {} + )* + + impl $crate::OpVerifier for $T { + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + #[allow(unused_parens)] + type OpVerifierImpl<'a> = $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $Trait),*)>; + #[allow(unused_parens)] + impl<'a> $crate::OpVerifier for $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $Trait),*)> + where + $( + $T: $crate::verifier::Verifier + ),* + { + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + let op = self.downcast_ref::<$T>().unwrap(); + $( + if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { + <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; + } + )* + + Ok(()) + } + } + + let op = self.as_operation(); + let verifier = OpVerifierImpl::new(op); + verifier.verify(context) + } + } + } +} + +/// This type represents the concrete set of derived traits for some op `T`, paired with a +/// type-erased [Operation] reference for an instance of that op. +/// +/// This is used for two purposes: +/// +/// 1. To generate a specialized [OpVerifier] for `T` which contains all of the type and +/// trait-specific validation logic for that `T`. +/// 2. To apply the specialized verifier for `T` using the wrapped [Operation] reference. +#[doc(hidden)] +pub struct DeriveVerifier<'a, T, Derived: ?Sized> { + op: &'a Operation, + _t: core::marker::PhantomData, + _derived: core::marker::PhantomData, +} +impl<'a, T, Derived: ?Sized> DeriveVerifier<'a, T, Derived> { + #[doc(hidden)] + pub const fn new(op: &'a Operation) -> Self { + Self { + op, + _t: core::marker::PhantomData, + _derived: core::marker::PhantomData, + } + } +} +impl<'a, T, Derived: ?Sized> core::ops::Deref for DeriveVerifier<'a, T, Derived> { + type Target = Operation; + + fn deref(&self) -> &Self::Target { + self.op + } +} + +#[cfg(test)] +mod tests { + use core::fmt; + + use crate::{ + define_attr_type, dialects::hir::HirDialect, traits::*, Context, Op, Operation, Report, + SourceSpan, Spanned, + }; + + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + enum Overflow { + #[allow(unused)] + None, + Wrapping, + #[allow(unused)] + Overflowing, + } + impl fmt::Display for Overflow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(self, f) + } + } + define_attr_type!(Overflow); + + derive! 
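For reference, the name-derivation logic embedded in `__derive_op_name!` above (strip a trailing `Op`, then convert CamelCase to snake_case) can be sketched as a plain function. This mirrors my reading of the macro and is for illustration only, not code that exists in the crate:

```rust
fn derived_op_name(type_name: &str) -> String {
    // CondBrOp => CondBr => cond_br, Add => add
    let type_name = type_name.strip_suffix("Op").unwrap_or(type_name);
    let mut buf = String::with_capacity(type_name.len());
    for (i, c) in type_name.char_indices() {
        if c.is_ascii_uppercase() {
            // Every uppercase letter after the first character starts a new word.
            if i > 0 {
                buf.push('_');
            }
            buf.push(c.to_ascii_lowercase());
        } else {
            buf.push(c);
        }
    }
    buf
}

fn main() {
    assert_eq!(derived_op_name("CondBrOp"), "cond_br");
    assert_eq!(derived_op_name("Add"), "add");
}
```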
{ + /// An example op implementation to make sure all of the type machinery works + struct AddOp : Op implements SingleBlock, SameTypeOperands, ArithmeticOp { + #[dialect] + dialect: HirDialect, + #[attr] + overflow: Overflow, + #[operand] + lhs: OpOperand, + #[operand] + rhs: OpOperand, + } + } + + derive! { + /// A marker trait for arithmetic ops + trait ArithmeticOp {} + + verify { + fn is_binary_op(op: &Operation, ctx: &Context) -> Result<(), Report> { + if op.num_operands() == 2 { + Ok(()) + } else { + Err( + ctx.session.diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 2, got {}", op.num_operands())) + .with_help("this operator implements 'ArithmeticOp' which requires ops to be binary") + .into_report() + ) + } + } + } + } + + #[test] + fn test_derived_op() { + use crate::Type; + + let context = Context::default(); + let block = context.create_block_with_params([Type::U32, Type::I64]); + let block = block.borrow(); + let lhs = block.get_argument(0); + let rhs = block.get_argument(1); + let op = AddOp::create(&context, rhs, lhs, Overflow::Wrapping); + let op = op.expect("failed to create AddOp"); + let op = op.borrow(); + assert!(op.as_operation().implements::()); + assert!(core::hint::black_box( + !>::VACUOUS + )); + } +} diff --git a/hir2/src/dialects.rs b/hir2/src/dialects.rs new file mode 100644 index 000000000..85fe52210 --- /dev/null +++ b/hir2/src/dialects.rs @@ -0,0 +1 @@ +pub mod hir; diff --git a/hir2/src/dialects/hir.rs b/hir2/src/dialects/hir.rs index e69de29bb..976e9ef46 100644 --- a/hir2/src/dialects/hir.rs +++ b/hir2/src/dialects/hir.rs @@ -0,0 +1,14 @@ +mod ops; + +pub use self::ops::*; +use crate::{interner, Dialect, DialectName}; + +#[derive(Default, Debug)] +pub struct HirDialect; +impl Dialect for HirDialect { + const INIT: Self = HirDialect; + + fn name(&self) -> DialectName { + DialectName::from_symbol(interner::symbols::Hir) + } +} diff --git a/hir2/src/dialects/hir/ops.rs b/hir2/src/dialects/hir/ops.rs new file mode 100644 index 000000000..a3656ce55 --- /dev/null +++ b/hir2/src/dialects/hir/ops.rs @@ -0,0 +1,9 @@ +mod binary; +mod cast; +mod control; +mod invoke; +mod mem; +mod primop; +mod unary; + +pub use self::{binary::*, cast::*, control::*, invoke::*, mem::*, primop::*, unary::*}; diff --git a/hir2/src/dialects/hir/ops/binary.rs b/hir2/src/dialects/hir/ops/binary.rs new file mode 100644 index 000000000..2fcd705c0 --- /dev/null +++ b/hir2/src/dialects/hir/ops/binary.rs @@ -0,0 +1,129 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +macro_rules! derive_binary_op_with_overflow { + ($Op:ident) => { + derive! { + pub struct $Op: Op implements BinaryOp { + #[dialect] + dialect: HirDialect, + #[operand] + lhs: OpOperandRef, + #[operand] + rhs: OpOperandRef, + #[result] + result: OpResultRef, + #[attr] + overflow: Overflow, + } + } + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive! { + pub struct $Op: Op implements BinaryOp, $OpTrait $(, $OpTraitRest)* { + #[dialect] + dialect: HirDialect, + #[operand] + lhs: OpOperandRef, + #[operand] + rhs: OpOperandRef, + #[result] + result: OpResultRef, + #[attr] + overflow: Overflow, + } + } + }; +} + +macro_rules! derive_binary_op { + ($Op:ident) => { + derive! 
{ + pub struct $Op: Op implements BinaryOp { + #[dialect] + dialect: HirDialect, + #[operand] + lhs: OpOperandRef, + #[operand] + rhs: OpOperandRef, + #[result] + result: OpResultRef, + } + } + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive! { + pub struct $Op: Op implements BinaryOp, $OpTrait $(, $OpTraitRest)* { + #[dialect] + dialect: HirDialect, + #[operand] + lhs: OpOperandRef, + #[operand] + rhs: OpOperandRef, + #[result] + result: OpResultRef, + } + } + }; +} + +macro_rules! derive_binary_logical_op { + ($Op:ident) => { + derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, Commutative); + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, Commutative, $OpTrait $(, $OpTraitRest)*); + }; +} + +macro_rules! derive_binary_bitwise_op { + ($Op:ident) => { + derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType); + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + }; +} + +macro_rules! derive_binary_comparison_op { + ($Op:ident) => { + derive_binary_op!($Op implements SameTypeOperands); + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op implements SameTypeOperands, $OpTrait $(, $OpTraitRest)*); + }; +} + +derive_binary_op_with_overflow!(Add implements Commutative, SameTypeOperands); +derive_binary_op_with_overflow!(Sub implements SameTypeOperands); +derive_binary_op_with_overflow!(Mul implements Commutative, SameTypeOperands); +derive_binary_op_with_overflow!(Exp); + +derive_binary_op!(Div implements SameTypeOperands, SameOperandsAndResultType); +derive_binary_op!(Mod implements SameTypeOperands, SameOperandsAndResultType); +derive_binary_op!(DivMod implements SameTypeOperands, SameOperandsAndResultType); + +derive_binary_logical_op!(And); +derive_binary_logical_op!(Or); +derive_binary_logical_op!(Xor); + +derive_binary_bitwise_op!(Band implements Commutative); +derive_binary_bitwise_op!(Bor implements Commutative); +derive_binary_bitwise_op!(Bxor implements Commutative); +derive_binary_op!(Shl); +derive_binary_op!(Shr); +derive_binary_op!(Rotl); +derive_binary_op!(Rotr); + +derive_binary_comparison_op!(Eq implements Commutative); +derive_binary_comparison_op!(Neq implements Commutative); +derive_binary_comparison_op!(Gt); +derive_binary_comparison_op!(Gte); +derive_binary_comparison_op!(Lt); +derive_binary_comparison_op!(Lte); +derive_binary_comparison_op!(Min implements Commutative); +derive_binary_comparison_op!(Max implements Commutative); diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs new file mode 100644 index 000000000..494240bcf --- /dev/null +++ b/hir2/src/dialects/hir/ops/cast.rs @@ -0,0 +1,126 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +// TODO(pauls): Implement support in `derive!` for expressing: +// +// * Doc comments +// * Type constraints +// * Additional verification rules +/* +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CastKind { + /// Reinterpret the bits of the operand as the target type, without any consideration for + /// the original meaning of those bits. + /// + /// For example, transmuting `u32::MAX` to `i32`, produces a value of `-1`, because the input + /// value overflows when interpreted as a signed integer. 
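For intuition, the cast kinds described in the commented-out `CastKind` enum above map roughly onto ordinary Rust integer conversions. These are plain Rust expressions, not the dialect ops themselves, and `try_from` only approximates the "checked" variant (it reports an error rather than asserting):

```rust
fn main() {
    // Transmute: reinterpret the bits with no range check; u32::MAX reads back as -1i32.
    assert_eq!(u32::MAX as i32, -1);
    // Checked: same reinterpretation, but values that do not fit are rejected.
    assert!(i32::try_from(u32::MAX).is_err());
    // Zext: widen an unsigned value by zero-extension.
    assert_eq!(0xFFu8 as u32, 255);
    // Sext: widen a signed value by sign-extension.
    assert_eq!(-1i8 as i32, -1);
    // Trunc: narrow by dropping the high bits.
    assert_eq!(0x1_0000_0001u64 as u32, 1);
}
```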
+ Transmute, + /// Like `Transmute`, but the input operand is checked to verify that it is a valid value + /// of both the source and target types. + /// + /// For example, a checked cast of `u32::MAX` to `i32` would assert, because the input value + /// cannot be represented as an `i32` due to overflow. + Checked, + /// Convert the input value to the target type, by zero-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Zext, + /// Convert the input value to the target type, by sign-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Sext, + /// Convert the input value to the target type, by truncating the excess bits. A cast of this + /// type must be a narrowing cast, i.e. from a larger bitwidth to a smaller one. + Trunc, +} + */ + +derive! { + pub struct PtrToInt : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct IntToPtr : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct Cast : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct Bitcast : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct Trunc : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct Zext : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct Sext : Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[attr] + ty: Type, + #[operand] + operand: OpOperand, + #[result] + result: OpResult, + } +} diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs new file mode 100644 index 000000000..e6da63afe --- /dev/null +++ b/hir2/src/dialects/hir/ops/control.rs @@ -0,0 +1,120 @@ +use smallvec::SmallVec; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +derive! { + pub struct Ret : Op implements Terminator { + #[dialect] + dialect: HirDialect, + #[operand] + value: OpOperand, + } +} + +// TODO(pauls): RetImm + +// TODO(pauls): Implement support for: +// +// * `#[successor]` to represent a single `Successor` of this op +derive! { + pub struct Br : Op implements Terminator { + #[dialect] + dialect: HirDialect, + #[successor] + target: Successor, + } +} + +derive! { + pub struct CondBr : Op implements Terminator { + #[dialect] + dialect: HirDialect, + #[operand] + condition: OpOperand, + #[successor] + then_dest: Successor, + #[successor] + else_dest: Successor, + } +} + +// TODO(pauls): Implement support for: +// +// * `SuccessorInterface` for custom types which represent a `Successor` +// * `#[successors]` to represent variadic successors of an op +// * `#[successors(interface)]` to indicate that the successor info should be obtained from this field via `SuccessorInterface` +derive! 
{ + pub struct Switch : Op implements Terminator { + #[dialect] + dialect: HirDialect, + #[operand] + selector: OpOperand, + #[successors(delegated)] + cases: SmallVec<[SwitchCase; 2]>, + #[successor] + fallback: Successor, + } +} + +// TODO(pauls): Implement `SuccessorInterface` for this type +#[derive(Debug, Clone)] +pub struct SwitchCase { + pub value: u32, + pub successor: Successor, +} + +// TODO(pauls): Implement: +// +// * `region` attribute +derive! { + pub struct If : Op implements SingleBlock, NoRegionArguments { + #[dialect] + dialect: HirDialect, + #[operand] + condition: OpOperand, + #[region] + then_body: Region, + #[region] + else_body: Region, + } +} + +/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. +/// +/// The "before" region's entry block parameters correspond to the operands expected by the +/// operation, and can be used to compute the condition that determines whether the "after" body +/// is executed or not, or simply forwarded to the "after" region. The "before" region must +/// terminate with a [Condition] operation, which will be evaluated to determine whether or not +/// to continue the loop. +/// +/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, +/// whose operands must be of the same arity and type as the "before" region's argument list. In +/// this way, the "after" body can feed back input to the "before" body to determine whether to +/// continue the loop. + +derive! { + pub struct While : Op implements SingleBlock { + #[dialect] + dialect: HirDialect, + #[region] + before: Region, + #[region] + after: Region, + } +} + +derive! { + pub struct Condition : Op implements Terminator, ReturnLike { + #[dialect] + dialect: HirDialect, + #[operand] + value: OpOperand, + } +} + +derive! { + pub struct Yield : Op implements Terminator, ReturnLike { + #[dialect] + dialect: HirDialect, + } +} diff --git a/hir2/src/dialects/hir/ops/invoke.rs b/hir2/src/dialects/hir/ops/invoke.rs new file mode 100644 index 000000000..2800740ea --- /dev/null +++ b/hir2/src/dialects/hir/ops/invoke.rs @@ -0,0 +1,24 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +// TODO(pauls): Implement support for: +// +// * Inferring op constraints from callee signature +derive! { + pub struct Exec : Op implements CallInterface { + #[dialect] + dialect: HirDialect, + #[attr] + callee: FunctionIdent, + } +} + +derive! { + pub struct ExecIndirect : Op implements CallInterface { + #[dialect] + dialect: HirDialect, + #[attr] + signature: Signature, + #[operand] + callee: OpOperand, + } +} diff --git a/hir2/src/dialects/hir/ops/mem.rs b/hir2/src/dialects/hir/ops/mem.rs new file mode 100644 index 000000000..131d93231 --- /dev/null +++ b/hir2/src/dialects/hir/ops/mem.rs @@ -0,0 +1,25 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +derive! { + pub struct Store : Op implements HasSideEffects, MemoryWrite { + #[dialect] + dialect: HirDialect, + #[operand] + addr: OpOperand, + #[operand] + value: OpOperand, + } +} + +// TODO(pauls): StoreLocal + +derive! { + pub struct Load : Op implements HasSideEffects, MemoryRead { + #[dialect] + dialect: HirDialect, + #[operand] + addr: OpOperand, + } +} + +// TODO(pauls): LoadLocal diff --git a/hir2/src/dialects/hir/ops/primop.rs b/hir2/src/dialects/hir/ops/primop.rs new file mode 100644 index 000000000..8cf786a1f --- /dev/null +++ b/hir2/src/dialects/hir/ops/primop.rs @@ -0,0 +1,51 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +derive! 
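Referring back to the `While` op documented above: its two-region structure corresponds roughly to the following ordinary Rust control flow. This is an analogy only, not HIR syntax or an in-crate API:

```rust
fn count_down(mut n: u32) -> u32 {
    loop {
        // "before" region: receives the loop-carried value `n`, computes the
        // condition, and terminates in `Condition`.
        let keep_going = n > 0;
        if !keep_going {
            // `Condition` evaluated to false: exit, forwarding the carried values.
            break n;
        }
        // "after" region: the loop body, which terminates in `Yield`, feeding
        // the next value of `n` back to the "before" region.
        n -= 1;
    }
}

fn main() {
    assert_eq!(count_down(3), 0);
}
```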
{ + pub struct MemGrow : Op implements HasSideEffects, MemoryRead, MemoryWrite { + #[dialect] + dialect: HirDialect, + #[operand] + pages: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct MemSize : Op implements HasSideEffects, MemoryRead { + #[dialect] + dialect: HirDialect, + #[result] + result: OpResult, + } +} + +derive! { + pub struct MemSet : Op implements HasSideEffects, MemoryWrite { + #[dialect] + dialect: HirDialect, + #[operand] + addr: OpOperand, + #[operand] + count: OpOperand, + #[operand] + value: OpOperand, + #[result] + result: OpResult, + } +} + +derive! { + pub struct MemCpy : Op implements HasSideEffects, MemoryRead, MemoryWrite { + #[dialect] + dialect: HirDialect, + #[operand] + source: OpOperand, + #[operand] + destination: OpOperand, + #[operand] + count: OpOperand, + #[result] + result: OpResult, + } +} diff --git a/hir2/src/dialects/hir/ops/unary.rs b/hir2/src/dialects/hir/ops/unary.rs new file mode 100644 index 000000000..ad1fb7b7f --- /dev/null +++ b/hir2/src/dialects/hir/ops/unary.rs @@ -0,0 +1,65 @@ +use crate::{dialects::hir::HirDialect, traits::*, *}; + +macro_rules! derive_unary_op { + ($Op:ident) => { + derive! { + pub struct $Op: Op implements UnaryOp { + #[dialect] + dialect: HirDialect, + #[operand] + operand: OpOperandRef, + #[result] + result: OpResultRef, + } + } + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive! { + pub struct $Op: Op implements UnaryOp, $OpTrait $(, $OpTraitRest)* { + #[dialect] + dialect: HirDialect, + #[operand] + operand: OpOperandRef, + #[result] + result: OpResultRef, + } + } + }; +} + +macro_rules! derive_unary_logical_op { + ($Op:ident) => { + derive_unary_op!($Op implements SameOperandsAndResultType); + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_unary_op!($Op implements SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + }; +} + +macro_rules! 
derive_unary_bitwise_op { + ($Op:ident) => { + derive_unary_op!($Op implements SameOperandsAndResultType); + }; + + ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_unary_op!($Op implements SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + }; +} + +derive_unary_op!(Neg implements SameOperandsAndResultType); +derive_unary_op!(Inv implements SameOperandsAndResultType); +derive_unary_op!(Incr implements SameOperandsAndResultType); +derive_unary_op!(Ilog2 implements SameOperandsAndResultType); +derive_unary_op!(Pow2 implements SameOperandsAndResultType); + +derive_unary_logical_op!(Not); +derive_unary_logical_op!(IsOdd); + +derive_unary_bitwise_op!(Bnot); +derive_unary_bitwise_op!(Popcnt); +derive_unary_bitwise_op!(Clz); +derive_unary_bitwise_op!(Ctz); +derive_unary_bitwise_op!(Clo); +derive_unary_bitwise_op!(Cto); diff --git a/hir2/src/formatter.rs b/hir2/src/formatter.rs new file mode 100644 index 000000000..08eab2ae5 --- /dev/null +++ b/hir2/src/formatter.rs @@ -0,0 +1,17 @@ +use core::fmt; + +pub use miden_core::{ + prettier::*, + utils::{DisplayHex, ToHex}, +}; + +pub struct DisplayIndent(pub usize); +impl fmt::Display for DisplayIndent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + const INDENT: &str = " "; + for _ in 0..self.0 { + f.write_str(INDENT)?; + } + Ok(()) + } +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 8c7aca80a..425d7975e 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -1,6 +1,26 @@ -mod core; -mod unsafe_ref; +#![feature(allocator_api)] +#![feature(alloc_layout_extra)] +#![feature(coerce_unsized)] +#![feature(unsize)] +#![feature(ptr_metadata)] +#![feature(layout_for_ptr)] +#![feature(slice_ptr_get)] +#![feature(specialization)] +#![feature(rustc_attrs)] +#![feature(debug_closure_helpers)] +#![allow(incomplete_features)] +#![allow(internal_features)] + +extern crate alloc; -pub use miden_assembly::{SourceSpan, Spanned}; +#[cfg(feature = "std")] +extern crate std; + +mod attributes; +mod core; +pub mod demangle; +pub mod derive; +pub mod dialects; +pub mod formatter; -pub use self::{core::*, unsafe_ref::UnsafeRef}; +pub use self::{attributes::*, core::*}; diff --git a/hir2/src/ops/binary.rs b/hir2/src/ops/binary.rs deleted file mode 100644 index 576f49763..000000000 --- a/hir2/src/ops/binary.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::*; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum BinaryOpcode { - Add(Overflow), - Sub(Overflow), - Mul(Overflow), - Div, - Mod, - DivMod, - Exp(Overflow), - And, - Band, - Or, - Bor, - Xor, - Bxor, - Shl, - Shr, - Rotl, - Rotr, - Eq, - Neq, - Gt, - Gte, - Lt, - Lte, - Min, - Max, -} -impl BinaryOpcode { - pub fn is_commutative(&self) -> bool { - matches!( - self, - Self::Add - | Self::Mul - | Self::Min - | Self::Max - | Self::Eq - | Self::Neq - | Self::And - | Self::Band - | Self::Or - | Self::Bor - | Self::Xor - | Self::Bxor - ) - } -} - -pub struct BinaryOp { - pub op: Operation, - pub opcode: BinaryOpcode, -} - -pub struct BinaryOpImm { - pub op: Operation, - pub opcode: BinaryOpcode, - pub imm: Immediate, -} diff --git a/hir2/src/ops/call.rs b/hir2/src/ops/call.rs deleted file mode 100644 index 16983dd4f..000000000 --- a/hir2/src/ops/call.rs +++ /dev/null @@ -1,6 +0,0 @@ -use crate::*; - -pub struct Call { - pub op: Operation, - pub callee: FunctionIdent, -} diff --git a/hir2/src/ops/cast.rs b/hir2/src/ops/cast.rs deleted file mode 100644 index a8141d7ad..000000000 --- a/hir2/src/ops/cast.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::*; - -#[derive(Debug, 
Copy, Clone, PartialEq, Eq)] -pub enum CastKind { - /// Reinterpret the bits of the operand as the target type, without any consideration for - /// the original meaning of those bits. - /// - /// For example, transmuting `u32::MAX` to `i32`, produces a value of `-1`, because the input - /// value overflows when interpreted as a signed integer. - Transmute, - /// Like `Transmute`, but the input operand is checked to verify that it is a valid value - /// of both the source and target types. - /// - /// For example, a checked cast of `u32::MAX` to `i32` would assert, because the input value - /// cannot be represented as an `i32` due to overflow. - Checked, - /// Convert the input value to the target type, by zero-extending the value to the target - /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to - /// a larger one. - Zext, - /// Convert the input value to the target type, by sign-extending the value to the target - /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to - /// a larger one. - Sext, - /// Convert the input value to the target type, by truncating the excess bits. A cast of this - /// type must be a narrowing cast, i.e. from a larger bitwidth to a smaller one. - Trunc, -} - -pub struct Cast { - pub op: Operation, - pub kind: CastKind, -} diff --git a/hir2/src/ops/control_flow.rs b/hir2/src/ops/control_flow.rs deleted file mode 100644 index 13b80c8e5..000000000 --- a/hir2/src/ops/control_flow.rs +++ /dev/null @@ -1,50 +0,0 @@ -use smallvec::SmallVec; - -use crate::*; - -pub struct Br { - pub op: Operation, -} -impl Br { - pub fn dest(&self) -> &Successor { - &self.op.successors[0] - } -} - -pub struct CondBr { - pub op: Operation, -} -impl CondBr { - pub fn condition(&self) -> Value { - todo!() - } - - pub fn then_dest(&self) -> &Successor { - &self.op.successors[0] - } - - pub fn else_dest(&self) -> &Successor { - &self.op.successors[1] - } -} - -pub struct Switch { - pub op: Operation, - pub cases: SmallVec<[u32; 4]>, - pub default_successor: usize, -} -impl Switch { - pub fn selector(&self) -> Value { - todo!() - } - - pub fn default_dest(&self) -> &Successor { - &self.op.successors[self.default_successor] - } -} - -#[derive(Debug, Clone)] -pub struct SwitchCase { - pub value: u32, - pub successor: Successor, -} diff --git a/hir2/src/ops/global_value.rs b/hir2/src/ops/global_value.rs deleted file mode 100644 index 2b53fe926..000000000 --- a/hir2/src/ops/global_value.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::*; - -#[derive(Debug, Clone)] -pub struct GlobalValueOp { - pub id: GlobalValue, - pub data: GlobalValueData, - pub op: Operation, -} - -impl Op for GlobalValueOp { - type Id = GlobalValue; - - #[inline(always)] - fn id(&self) -> Self::Id { - self.id - } - - fn name(&self) -> &'static str { - match self.data { - GlobalValueData::Symbol { .. } => "global.symbol", - GlobalValueData::Load { .. } => "global.load", - GlobalValueData::IAddImm { .. 
} => "global.iadd", - } - } - - #[inline(always)] - fn as_operation(&self) -> &Operation { - &self.op - } - - #[inline(always)] - fn as_operation_mut(&mut self) -> &mut Operation { - &mut self.op - } -} diff --git a/hir2/src/ops/inline_asm.rs b/hir2/src/ops/inline_asm.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/hir2/src/ops/mem.rs b/hir2/src/ops/mem.rs deleted file mode 100644 index 9f810e220..000000000 --- a/hir2/src/ops/mem.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::*; - -pub struct Store { - pub op: Operation, -} -impl Store { - pub fn addr(&self) -> Value { - todo!() - } - - pub fn value(&self) -> Value { - todo!() - } -} - -pub struct StoreLocal { - pub op: Operation, - pub local: LocalId, -} -impl StoreLocal { - pub fn value(&self) -> Value { - todo!() - } -} - -pub struct Load { - pub op: Operation, -} -impl Load { - pub fn addr(&self) -> Value { - todo!() - } -} - -pub struct LoadLocal { - pub op: Operation, - pub local: LocalId, -} diff --git a/hir2/src/ops/mod.rs b/hir2/src/ops/mod.rs deleted file mode 100644 index 5b8c4344c..000000000 --- a/hir2/src/ops/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -mod binary; -mod call; -mod cast; -mod control_flow; -mod global_value; -mod inline_asm; -mod mem; -mod primop; -mod ret; -mod structured_control_flow; -mod unary; - -pub use self::{ - binary::*, call::*, cast::*, control_flow::*, global_value::*, inline_asm::*, mem::*, - primop::*, ret::*, structured_control_flow::*, unary::*, -}; diff --git a/hir2/src/ops/primop.rs b/hir2/src/ops/primop.rs deleted file mode 100644 index a1437d829..000000000 --- a/hir2/src/ops/primop.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::*; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum PrimOpcode { - MemGrow, - MemSize, - MemSet, - MemCpy, -} - -pub struct PrimOp { - pub op: Operation, -} - -pub struct PrimOpImm { - pub op: Operation, - pub imm: Immediate, -} - -pub struct Unreachable { - pub op: Operation, -} diff --git a/hir2/src/ops/ret.rs b/hir2/src/ops/ret.rs deleted file mode 100644 index c290ee36b..000000000 --- a/hir2/src/ops/ret.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::*; - -pub struct Ret { - pub op: Operation, -} - -pub struct RetImm { - pub op: Operation, - pub imm: Immediate, -} diff --git a/hir2/src/ops/structured_control_flow.rs b/hir2/src/ops/structured_control_flow.rs deleted file mode 100644 index a4add842d..000000000 --- a/hir2/src/ops/structured_control_flow.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::*; - -pub struct If { - pub op: Operation, -} -impl If { - pub fn condition(&self) -> Value { - todo!() - } - - pub fn then_dest(&self) -> &Successor { - todo!() - } - - pub fn else_dest(&self) -> &Successor { - todo!() - } -} - -/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. -/// -/// The "before" region's entry block parameters correspond to the operands expected by the -/// operation, and can be used to compute the condition that determines whether the "after" body -/// is executed or not, or simply forwarded to the "after" region. The "before" region must -/// terminate with a [Condition] operation, which will be evaluated to determine whether or not -/// to continue the loop. -/// -/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, -/// whose operands must be of the same arity and type as the "before" region's argument list. In -/// this way, the "after" body can feed back input to the "before" body to determine whether to -/// continue the loop. 
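To make the two-region structure described above concrete, here is an illustrative source-level sketch of what a `While` models (a sketch only; `carried` stands in for the loop-carried block arguments, and the concrete type and bound are placeholders):

    fn while_semantics(mut carried: u32) -> u32 {
        loop {
            // "before" region: compute the condition from the carried values,
            // terminating with `Condition`
            let keep_going = carried < 10;
            if !keep_going {
                break;
            }
            // "after" region: the loop body, terminating with `Yield`, whose
            // operands become the next iteration's carried values
            carried += 1;
        }
        carried
    }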
-pub struct While { - pub op: Operation, -} -impl While { - pub fn before_region(&self) -> RegionId { - self.op.regions[0] - } - - pub fn after_region(&self) -> RegionId { - self.op.regions[1] - } -} - -pub struct Condition { - pub op: Operation, -} -impl Condition { - pub fn condition(&self) -> Value { - todo!() - } -} - -pub struct Yield { - pub op: Operation, -} diff --git a/hir2/src/ops/unary.rs b/hir2/src/ops/unary.rs deleted file mode 100644 index 623b040cd..000000000 --- a/hir2/src/ops/unary.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::*; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum UnaryOpcode { - PtrToInt, - IntToPtr, - Cast, - Bitcast, - Trunc, - Zext, - Sext, - Test, - Neg, - Inv, - Incr, - Ilog2, - Pow2, - Popcnt, - Clz, - Ctz, - Clo, - Cto, - Not, - Bnot, - IsOdd, -} - -pub struct UnaryOp { - pub op: Operation, - pub opcode: UnaryOpcode, -} - -pub struct UnaryOpImm { - pub op: Operation, - pub opcode: UnaryOpcode, - pub imm: Immediate, -} diff --git a/hir2/src/unsafe_ref.rs b/hir2/src/unsafe_ref.rs deleted file mode 100644 index d510489f8..000000000 --- a/hir2/src/unsafe_ref.rs +++ /dev/null @@ -1,101 +0,0 @@ -use core::ptr::NonNull; - -#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct UnsafeRef(NonNull); - -impl UnsafeRef { - /// Construct a new [UnsafeRef] from a non-null pointer to `T` - pub fn new(ptr: NonNull) -> Self { - Self(ptr) - } - - /// Get the underlying raw pointer for this [UnsafeRef] - #[inline(always)] - pub const fn into_raw(self) -> NonNull { - self.0 - } - - /// Construct an [UnsafeRef] from a [Box] - pub fn from_box(ptr: Box) -> Self { - Self(unsafe { NonNull::new_unchecked(Box::into_raw(ptr)) }) - } - - /// Convert this [UnsafeRef] back into the [Box] it was derived from. - /// - /// # Safety - /// - /// The following must be upheld by the caller: - /// - /// * This [UnsafeRef] _MUST_ have been created via [UnsafeRef::from_box] - /// * There _MUST NOT_ be any other [UnsafeRef] pointing to the same allocation - /// * `T` must be the same type as the original [Box] was allocated with - pub unsafe fn into_box(self) -> Box { - Box::from_raw(self.0.as_ptr()) - } -} - -impl core::ops::Deref for UnsafeRef { - type Target = T; - - #[inline] - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl AsRef for UnsafeRef { - fn as_ref(&self) -> &T { - unsafe { self.0.as_ref() } - } -} - -impl AsRef for UnsafeRef -where - T: core::marker::Unsize + ?Sized, - U: ?Sized, -{ - fn as_ref(&self) -> &U { - unsafe { self.0.as_ref() as &U } - } -} - -impl core::borrow::Borrow for UnsafeRef { - fn borrow(&self) -> &T { - unsafe { self.0.as_ref() } - } -} - -impl core::ops::CoerceUnsized> for UnsafeRef -where - T: core::marker::Unsize + ?Sized, - U: ?Sized, -{ -} - -impl core::ops::DispatchFromDyn> for UnsafeRef -where - T: core::marker::Unsize + ?Sized, - U: ?Sized, -{ -} - -unsafe impl Send for UnsafeRef {} - -unsafe impl Sync for UnsafeRef {} - -unsafe impl intrusive_collections::PointerOps - for intrusive_collections::DefaultPointerOps> -{ - type Pointer = UnsafeRef; - type Value = T; - - unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { - let value = NonNull::new(value.cast_mut()).expect("expected non-null node pointer"); - UnsafeRef::new(value) - } - - fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { - ptr.into_raw().as_ptr().cast_const() - } -} From 3e48bb10702ef18c22e637000884905073ed16dc Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Thu, 19 Sep 2024 17:57:45 -0400 Subject: [PATCH 
03/31] wip: move hir2/src/core to hir2/src/ir --- hir2/src/derive.rs | 10 ++++++++++ hir2/src/dialects/hir/ops/cast.rs | 5 ----- hir2/src/dialects/hir/ops/control.rs | 11 ----------- hir2/src/{core.rs => ir.rs} | 0 hir2/src/{core => ir}/attribute.rs | 0 hir2/src/{core => ir}/block.rs | 0 hir2/src/{core => ir}/component.rs | 0 hir2/src/{core => ir}/context.rs | 0 hir2/src/{core => ir}/dialect.rs | 0 hir2/src/{core => ir}/entity.rs | 0 hir2/src/{core => ir}/entity/list.rs | 0 hir2/src/{core => ir}/function.rs | 0 hir2/src/{core => ir}/ident.rs | 0 hir2/src/{core => ir}/immediates.rs | 0 hir2/src/{core => ir}/interface.rs | 0 hir2/src/{core => ir}/module.rs | 0 hir2/src/{core => ir}/op.rs | 0 hir2/src/{core => ir}/operation.rs | 0 hir2/src/{core => ir}/region.rs | 0 hir2/src/{core => ir}/symbol_table.rs | 0 hir2/src/{core => ir}/traits.rs | 0 hir2/src/{core => ir}/traits/multitrait.rs | 0 hir2/src/{core => ir}/types.rs | 0 hir2/src/{core => ir}/usable.rs | 0 hir2/src/{core => ir}/value.rs | 0 hir2/src/{core => ir}/verifier.rs | 0 hir2/src/lib.rs | 4 ++-- 27 files changed, 12 insertions(+), 18 deletions(-) rename hir2/src/{core.rs => ir.rs} (100%) rename hir2/src/{core => ir}/attribute.rs (100%) rename hir2/src/{core => ir}/block.rs (100%) rename hir2/src/{core => ir}/component.rs (100%) rename hir2/src/{core => ir}/context.rs (100%) rename hir2/src/{core => ir}/dialect.rs (100%) rename hir2/src/{core => ir}/entity.rs (100%) rename hir2/src/{core => ir}/entity/list.rs (100%) rename hir2/src/{core => ir}/function.rs (100%) rename hir2/src/{core => ir}/ident.rs (100%) rename hir2/src/{core => ir}/immediates.rs (100%) rename hir2/src/{core => ir}/interface.rs (100%) rename hir2/src/{core => ir}/module.rs (100%) rename hir2/src/{core => ir}/op.rs (100%) rename hir2/src/{core => ir}/operation.rs (100%) rename hir2/src/{core => ir}/region.rs (100%) rename hir2/src/{core => ir}/symbol_table.rs (100%) rename hir2/src/{core => ir}/traits.rs (100%) rename hir2/src/{core => ir}/traits/multitrait.rs (100%) rename hir2/src/{core => ir}/types.rs (100%) rename hir2/src/{core => ir}/usable.rs (100%) rename hir2/src/{core => ir}/value.rs (100%) rename hir2/src/{core => ir}/verifier.rs (100%) diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs index e81fc641c..75a336dcb 100644 --- a/hir2/src/derive.rs +++ b/hir2/src/derive.rs @@ -1,6 +1,16 @@ use crate::Operation; /// This macro is used to generate the boilerplate for [Op] implementations. +/// +/// TODO(pauls): +/// +/// * Implement `#[region]` support +/// * Implement `#[successor]` support +/// * Implement `#[successors]` support for variadic successors +/// * Implement `#[successors(interface)]` to access successors through `SuccessorInterface` +/// * Support doc comments +/// * Implement type constraints/inference +/// * Implement `verify` blocks for custom verification rules #[macro_export] macro_rules! 
derive { ( diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs index 494240bcf..829d47350 100644 --- a/hir2/src/dialects/hir/ops/cast.rs +++ b/hir2/src/dialects/hir/ops/cast.rs @@ -1,10 +1,5 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; -// TODO(pauls): Implement support in `derive!` for expressing: -// -// * Doc comments -// * Type constraints -// * Additional verification rules /* #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum CastKind { diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs index e6da63afe..f8cc57d5c 100644 --- a/hir2/src/dialects/hir/ops/control.rs +++ b/hir2/src/dialects/hir/ops/control.rs @@ -13,9 +13,6 @@ derive! { // TODO(pauls): RetImm -// TODO(pauls): Implement support for: -// -// * `#[successor]` to represent a single `Successor` of this op derive! { pub struct Br : Op implements Terminator { #[dialect] @@ -38,11 +35,6 @@ derive! { } } -// TODO(pauls): Implement support for: -// -// * `SuccessorInterface` for custom types which represent a `Successor` -// * `#[successors]` to represent variadic successors of an op -// * `#[successors(interface)]` to indicate that the successor info should be obtained from this field via `SuccessorInterface` derive! { pub struct Switch : Op implements Terminator { #[dialect] @@ -63,9 +55,6 @@ pub struct SwitchCase { pub successor: Successor, } -// TODO(pauls): Implement: -// -// * `region` attribute derive! { pub struct If : Op implements SingleBlock, NoRegionArguments { #[dialect] diff --git a/hir2/src/core.rs b/hir2/src/ir.rs similarity index 100% rename from hir2/src/core.rs rename to hir2/src/ir.rs diff --git a/hir2/src/core/attribute.rs b/hir2/src/ir/attribute.rs similarity index 100% rename from hir2/src/core/attribute.rs rename to hir2/src/ir/attribute.rs diff --git a/hir2/src/core/block.rs b/hir2/src/ir/block.rs similarity index 100% rename from hir2/src/core/block.rs rename to hir2/src/ir/block.rs diff --git a/hir2/src/core/component.rs b/hir2/src/ir/component.rs similarity index 100% rename from hir2/src/core/component.rs rename to hir2/src/ir/component.rs diff --git a/hir2/src/core/context.rs b/hir2/src/ir/context.rs similarity index 100% rename from hir2/src/core/context.rs rename to hir2/src/ir/context.rs diff --git a/hir2/src/core/dialect.rs b/hir2/src/ir/dialect.rs similarity index 100% rename from hir2/src/core/dialect.rs rename to hir2/src/ir/dialect.rs diff --git a/hir2/src/core/entity.rs b/hir2/src/ir/entity.rs similarity index 100% rename from hir2/src/core/entity.rs rename to hir2/src/ir/entity.rs diff --git a/hir2/src/core/entity/list.rs b/hir2/src/ir/entity/list.rs similarity index 100% rename from hir2/src/core/entity/list.rs rename to hir2/src/ir/entity/list.rs diff --git a/hir2/src/core/function.rs b/hir2/src/ir/function.rs similarity index 100% rename from hir2/src/core/function.rs rename to hir2/src/ir/function.rs diff --git a/hir2/src/core/ident.rs b/hir2/src/ir/ident.rs similarity index 100% rename from hir2/src/core/ident.rs rename to hir2/src/ir/ident.rs diff --git a/hir2/src/core/immediates.rs b/hir2/src/ir/immediates.rs similarity index 100% rename from hir2/src/core/immediates.rs rename to hir2/src/ir/immediates.rs diff --git a/hir2/src/core/interface.rs b/hir2/src/ir/interface.rs similarity index 100% rename from hir2/src/core/interface.rs rename to hir2/src/ir/interface.rs diff --git a/hir2/src/core/module.rs b/hir2/src/ir/module.rs similarity index 100% rename from hir2/src/core/module.rs rename to 
hir2/src/ir/module.rs diff --git a/hir2/src/core/op.rs b/hir2/src/ir/op.rs similarity index 100% rename from hir2/src/core/op.rs rename to hir2/src/ir/op.rs diff --git a/hir2/src/core/operation.rs b/hir2/src/ir/operation.rs similarity index 100% rename from hir2/src/core/operation.rs rename to hir2/src/ir/operation.rs diff --git a/hir2/src/core/region.rs b/hir2/src/ir/region.rs similarity index 100% rename from hir2/src/core/region.rs rename to hir2/src/ir/region.rs diff --git a/hir2/src/core/symbol_table.rs b/hir2/src/ir/symbol_table.rs similarity index 100% rename from hir2/src/core/symbol_table.rs rename to hir2/src/ir/symbol_table.rs diff --git a/hir2/src/core/traits.rs b/hir2/src/ir/traits.rs similarity index 100% rename from hir2/src/core/traits.rs rename to hir2/src/ir/traits.rs diff --git a/hir2/src/core/traits/multitrait.rs b/hir2/src/ir/traits/multitrait.rs similarity index 100% rename from hir2/src/core/traits/multitrait.rs rename to hir2/src/ir/traits/multitrait.rs diff --git a/hir2/src/core/types.rs b/hir2/src/ir/types.rs similarity index 100% rename from hir2/src/core/types.rs rename to hir2/src/ir/types.rs diff --git a/hir2/src/core/usable.rs b/hir2/src/ir/usable.rs similarity index 100% rename from hir2/src/core/usable.rs rename to hir2/src/ir/usable.rs diff --git a/hir2/src/core/value.rs b/hir2/src/ir/value.rs similarity index 100% rename from hir2/src/core/value.rs rename to hir2/src/ir/value.rs diff --git a/hir2/src/core/verifier.rs b/hir2/src/ir/verifier.rs similarity index 100% rename from hir2/src/core/verifier.rs rename to hir2/src/ir/verifier.rs diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 425d7975e..b5a574538 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -17,10 +17,10 @@ extern crate alloc; extern crate std; mod attributes; -mod core; pub mod demangle; pub mod derive; pub mod dialects; pub mod formatter; +mod ir; -pub use self::{attributes::*, core::*}; +pub use self::{attributes::*, ir::*}; From b8ee446bbc17ecb7b87795fa7e57bf039fec9479 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 20 Sep 2024 16:08:55 -0400 Subject: [PATCH 04/31] feat: support no-std builds of midenc_hir_symbol --- Cargo.lock | 3 + Cargo.toml | 1 + hir-symbol/Cargo.toml | 7 +- hir-symbol/src/lib.rs | 99 ++++++++++---- hir-symbol/src/sync.rs | 20 +++ hir-symbol/src/sync/lazy_lock.rs | 187 ++++++++++++++++++++++++++ hir-symbol/src/sync/rw_lock.rs | 220 +++++++++++++++++++++++++++++++ 7 files changed, 509 insertions(+), 28 deletions(-) create mode 100644 hir-symbol/src/sync.rs create mode 100644 hir-symbol/src/sync/lazy_lock.rs create mode 100644 hir-symbol/src/sync/rw_lock.rs diff --git a/Cargo.lock b/Cargo.lock index e60c5773e..a6dad288f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3390,6 +3390,9 @@ name = "midenc-hir-symbol" version = "0.0.6" dependencies = [ "Inflector", + "compact_str", + "lock_api", + "parking_lot", "rustc-hash 1.1.0", "serde 1.0.210", "toml 0.8.19", diff --git a/Cargo.toml b/Cargo.toml index 460a9d1f1..02a397543 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ clap = { version = "4.1", default-features = false, features = [ ] } cranelift-entity = "0.108" cranelift-bforest = "0.108" +compact_str = { version = "0.8", default-features = false } env_logger = "0.11" either = { version = "1.10", default-features = false } expect-test = "1.4.1" diff --git a/hir-symbol/Cargo.toml b/hir-symbol/Cargo.toml index 6274ede9d..14cd4bf8a 100644 --- a/hir-symbol/Cargo.toml +++ b/hir-symbol/Cargo.toml @@ -12,10 +12,15 @@ readme.workspace = true 
edition.workspace = true [features] -default = [] +default = ["std"] +std = ["dep:parking_lot"] serde = ["dep:serde"] +compact_str = ["dep:compact_str"] [dependencies] +compact_str = { workspace = true, optional = true } +lock_api = "0.4" +parking_lot = { version = "0.12", optional = true } serde = { workspace = true, optional = true } [build-dependencies] diff --git a/hir-symbol/src/lib.rs b/hir-symbol/src/lib.rs index 7d24dc6f6..7b94862a3 100644 --- a/hir-symbol/src/lib.rs +++ b/hir-symbol/src/lib.rs @@ -1,26 +1,29 @@ -use core::{fmt, mem, ops::Deref, str}; -use std::{ +#![no_std] + +extern crate alloc; +#[cfg(feature = "std")] +extern crate std; + +mod sync; + +use alloc::{ + boxed::Box, collections::BTreeMap, - sync::{OnceLock, RwLock}, + string::{String, ToString}, + vec::Vec, }; - -static SYMBOL_TABLE: OnceLock = OnceLock::new(); +use core::{fmt, mem, ops::Deref, str}; pub mod symbols { include!(env!("SYMBOLS_RS")); } +static SYMBOL_TABLE: sync::LazyLock = sync::LazyLock::new(SymbolTable::default); + +#[derive(Default)] struct SymbolTable { - interner: RwLock, + interner: sync::RwLock, } -impl SymbolTable { - pub fn new() -> Self { - Self { - interner: RwLock::new(Interner::new()), - } - } -} -unsafe impl Sync for SymbolTable {} /// A symbol is an interned string. #[derive(Clone, Copy, PartialEq, Eq, Hash)] @@ -68,8 +71,8 @@ impl Symbol { } /// Maps a string to its interned representation. - pub fn intern>(string: S) -> Self { - let string = string.into(); + pub fn intern(string: impl ToString) -> Self { + let string = string.to_string(); with_interner(|interner| interner.intern(string)) } @@ -122,6 +125,36 @@ impl> PartialEq for Symbol { self.as_str() == other.deref() } } +impl From<&'static str> for Symbol { + fn from(s: &'static str) -> Self { + with_interner(|interner| interner.insert(s)) + } +} +impl From for Symbol { + fn from(s: String) -> Self { + Self::intern(s) + } +} +impl From> for Symbol { + fn from(s: Box) -> Self { + Self::intern(s) + } +} +impl From> for Symbol { + fn from(s: alloc::borrow::Cow<'static, str>) -> Self { + use alloc::borrow::Cow; + match s { + Cow::Borrowed(s) => s.into(), + Cow::Owned(s) => Self::intern(s), + } + } +} +#[cfg(feature = "compact_str")] +impl From for Symbol { + fn from(s: compact_str::CompactString) -> Self { + Self::intern(s.into_string()) + } +} #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] struct SymbolIndex(u32); @@ -159,22 +192,26 @@ impl From for usize { } } -#[derive(Default)] struct Interner { pub names: BTreeMap<&'static str, Symbol>, pub strings: Vec<&'static str>, } -impl Interner { - pub fn new() -> Self { - let mut this = Interner::default(); +impl Default for Interner { + fn default() -> Self { + let mut this = Self { + names: BTreeMap::default(), + strings: Vec::with_capacity(symbols::__SYMBOLS.len()), + }; for (sym, s) in symbols::__SYMBOLS { this.names.insert(s, *sym); this.strings.push(s); } this } +} +impl Interner { pub fn intern(&mut self, string: String) -> Symbol { if let Some(&name) = self.names.get(string.as_str()) { return name; @@ -189,7 +226,17 @@ impl Interner { name } - pub fn get(&self, symbol: Symbol) -> &str { + pub fn insert(&mut self, s: &'static str) -> Symbol { + if let Some(&name) = self.names.get(s) { + return name; + } + let name = Symbol::new(self.strings.len() as u32); + self.strings.push(s); + self.names.insert(s, name); + name + } + + pub fn get(&self, symbol: Symbol) -> &'static str { self.strings[symbol.0.as_usize()] } } @@ -197,14 +244,12 @@ impl Interner { // If an 
interner exists, return it. Otherwise, prepare a fresh one.
 #[inline]
 fn with_interner<T, F: FnOnce(&mut Interner) -> T>(f: F) -> T {
-    let table = SYMBOL_TABLE.get_or_init(SymbolTable::new);
-    let mut r = table.interner.write().unwrap();
-    f(&mut r)
+    let mut table = SYMBOL_TABLE.interner.write();
+    f(&mut table)
 }
 
 #[inline]
 fn with_read_only_interner<T, F: FnOnce(&Interner) -> T>(f: F) -> T {
-    let table = SYMBOL_TABLE.get_or_init(SymbolTable::new);
-    let r = table.interner.read().unwrap();
-    f(&r)
+    let table = SYMBOL_TABLE.interner.read();
+    f(&table)
 }
diff --git a/hir-symbol/src/sync.rs b/hir-symbol/src/sync.rs
new file mode 100644
index 000000000..ffcab049a
--- /dev/null
+++ b/hir-symbol/src/sync.rs
@@ -0,0 +1,20 @@
+#[cfg(all(not(feature = "std"), target_family = "wasm"))]
+mod lazy_lock;
+#[cfg(all(not(feature = "std"), target_family = "wasm"))]
+mod rw_lock;
+
+#[cfg(feature = "std")]
+pub use std::sync::LazyLock;
+
+#[cfg(feature = "std")]
+#[allow(unused)]
+pub use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
+
+#[cfg(all(not(feature = "std"), target_family = "wasm"))]
+pub use self::lazy_lock::RacyLock as LazyLock;
+#[cfg(all(not(feature = "std"), target_family = "wasm"))]
+#[allow(unused)]
+pub use self::rw_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
+
+#[cfg(all(not(feature = "std"), not(target_family = "wasm")))]
+compile_error!("no_std builds of this crate are only supported on wasm targets");
diff --git a/hir-symbol/src/sync/lazy_lock.rs b/hir-symbol/src/sync/lazy_lock.rs
new file mode 100644
index 000000000..157cd41ed
--- /dev/null
+++ b/hir-symbol/src/sync/lazy_lock.rs
@@ -0,0 +1,187 @@
+use alloc::boxed::Box;
+use core::{
+    fmt,
+    ops::Deref,
+    ptr,
+    sync::atomic::{AtomicPtr, Ordering},
+};
+
+/// Thread-safe, non-blocking, lazily evaluated lock with the same interface
+/// as [`std::sync::LazyLock`].
+///
+/// Concurrent threads will race to set the value atomically, and memory allocated by losing threads
+/// will be dropped immediately after they fail to set the pointer.
+///
+/// The underlying implementation is based on `once_cell::race::OnceBox` which relies on
+/// [`core::sync::atomic::AtomicPtr`] to ensure that the data race results in a single successful
+/// write to the relevant pointer, namely the first write.
+/// See .
+///
+/// Performs lazy evaluation and can be used for statics.
+pub struct RacyLock<T, F = fn() -> T>
+where
+    F: Fn() -> T,
+{
+    inner: AtomicPtr<T>,
+    f: F,
+}
+
+impl<T, F> RacyLock<T, F>
+where
+    F: Fn() -> T,
+{
+    /// Creates a new lazy, racy value with the given initializing function.
+    pub const fn new(f: F) -> Self {
+        Self {
+            inner: AtomicPtr::new(ptr::null_mut()),
+            f,
+        }
+    }
+
+    /// Forces the evaluation of the locked value and returns a reference to
+    /// the result. This is equivalent to the [`Self::deref`].
+    ///
+    /// There is no blocking involved in this operation. Instead, concurrent
+    /// threads will race to set the underlying pointer. Memory allocated by
+    /// losing threads will be dropped immediately after they fail to set the pointer.
+    ///
+    /// This function's interface is designed around [`std::sync::LazyLock::force`] but
+    /// the implementation is derived from `once_cell::race::OnceBox::get_or_try_init`.
+    pub fn force(this: &RacyLock<T, F>) -> &T {
+        let mut ptr = this.inner.load(Ordering::Acquire);
+
+        // Pointer is not yet set, attempt to set it ourselves.
+        if ptr.is_null() {
+            // Execute the initialization function and allocate.
+            let val = (this.f)();
+            ptr = Box::into_raw(Box::new(val));
+
+            // Attempt atomic store.
+            let exchange = this.inner.compare_exchange(
+                ptr::null_mut(),
+                ptr,
+                Ordering::AcqRel,
+                Ordering::Acquire,
+            );
+
+            // Pointer already set, load.
+            if let Err(old) = exchange {
+                drop(unsafe { Box::from_raw(ptr) });
+                ptr = old;
+            }
+        }
+
+        unsafe { &*ptr }
+    }
+}
+
+impl<T: Default> Default for RacyLock<T> {
+    /// Creates a new lock that will evaluate the underlying value based on `T::default`.
+    #[inline]
+    fn default() -> RacyLock<T> {
+        RacyLock::new(T::default)
+    }
+}
+
+impl<T, F> fmt::Debug for RacyLock<T, F>
+where
+    T: fmt::Debug,
+    F: Fn() -> T,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "RacyLock({:?})", self.inner.load(Ordering::Relaxed))
+    }
+}
+
+impl<T, F> Deref for RacyLock<T, F>
+where
+    F: Fn() -> T,
+{
+    type Target = T;
+
+    /// Either sets or retrieves the value, and dereferences it.
+    ///
+    /// See [`Self::force`] for more details.
+    #[inline]
+    fn deref(&self) -> &T {
+        RacyLock::force(self)
+    }
+}
+
+impl<T, F> Drop for RacyLock<T, F>
+where
+    F: Fn() -> T,
+{
+    /// Drops the underlying pointer.
+    fn drop(&mut self) {
+        let ptr = *self.inner.get_mut();
+        if !ptr.is_null() {
+            // SAFETY: for any given value of `ptr`, we are guaranteed to have at most a single
+            // instance of `RacyLock` holding that value. Hence, synchronizing threads
+            // in `drop()` is not necessary, and we are guaranteed never to double-free.
+            // In short, since `RacyLock` doesn't implement `Clone`, the only scenario
+            // where there can be multiple instances of `RacyLock` across multiple threads
+            // referring to the same `ptr` value is when `RacyLock` is used in a static variable.
+            drop(unsafe { Box::from_raw(ptr) });
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::vec::Vec;
+
+    use super::*;
+
+    #[test]
+    fn deref_default() {
+        // Lock a copy type and validate default value.
+        let lock: RacyLock<i32> = RacyLock::default();
+        assert_eq!(*lock, 0);
+    }
+
+    #[test]
+    fn deref_copy() {
+        // Lock a copy type and validate value.
+        let lock = RacyLock::new(|| 42);
+        assert_eq!(*lock, 42);
+    }
+
+    #[test]
+    fn deref_clone() {
+        // Lock a no copy type.
+        let lock = RacyLock::new(|| Vec::from([1, 2, 3]));
+
+        // Use the value so that the compiler forces us to clone.
+        let mut v = lock.clone();
+        v.push(4);
+
+        // Validate the value.
+        assert_eq!(v, Vec::from([1, 2, 3, 4]));
+    }
+
+    #[test]
+    fn deref_static() {
+        // Create a static lock.
+        static VEC: RacyLock<Vec<i32>> = RacyLock::new(|| Vec::from([1, 2, 3]));
+
+        // Validate that the address of the value does not change.
+        let addr = &*VEC as *const Vec<i32>;
+        for _ in 0..5 {
+            assert_eq!(*VEC, [1, 2, 3]);
+            assert_eq!(addr, &(*VEC) as *const Vec<i32>)
+        }
+    }
+
+    #[test]
+    fn type_inference() {
+        // Check that we can infer `T` from closure's type.
+        let _ = RacyLock::new(|| ());
+    }
+
+    #[test]
+    fn is_sync_send() {
+        fn assert_traits<T: Sync + Send>() {}
+        assert_traits::<RacyLock<Vec<i32>>>();
+    }
+}
diff --git a/hir-symbol/src/sync/rw_lock.rs b/hir-symbol/src/sync/rw_lock.rs
new file mode 100644
index 000000000..66b61bf0a
--- /dev/null
+++ b/hir-symbol/src/sync/rw_lock.rs
@@ -0,0 +1,220 @@
+use core::{
+    hint,
+    sync::atomic::{AtomicUsize, Ordering},
+};
+
+use lock_api::RawRwLock;
+
+/// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible
+///
+/// See [lock_api::RwLock] for usage.
+pub type RwLock<T> = lock_api::RwLock<Spinlock, T>;
+
+/// See [lock_api::RwLockReadGuard] for usage.
+pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>;
+
+/// See [lock_api::RwLockWriteGuard] for usage.
+pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>;
+
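As a brief usage sketch for the aliases defined above (assuming nothing beyond the standard `lock_api` methods they expose; the element type and values below are placeholders):

    fn rwlock_usage(lock: &RwLock<Vec<u32>>) -> usize {
        {
            // Shared access: each active reader bumps the lock state by 2.
            let guard = lock.read();
            let _ = guard.len();
        } // dropping the guard releases the shared lock
        // Exclusive access: while held, the state is the `usize::MAX` sentinel.
        lock.write().push(1);
        lock.read().len()
    }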
+/// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock]
+///
+/// This is fundamentally a spinlock, in that blocking operations on the lock will spin until
+/// they succeed in acquiring/releasing the lock.
+///
+/// To achieve the ability to share the underlying data with multiple readers, or hold
+/// exclusive access for one writer, the lock state is based on a "locked" count, where shared
+/// access increments the count by an even number, and acquiring exclusive access relies on the
+/// use of the lowest order bit to stop further shared acquisition, and indicate that the lock
+/// is exclusively held (the difference between the two is irrelevant from the perspective of
+/// a thread attempting to acquire the lock, but internally the state uses `usize::MAX` as the
+/// "exclusively locked" sentinel).
+///
+/// This mechanism gets us the following:
+///
+/// * Whether the lock has been acquired (shared or exclusive)
+/// * Whether the lock is being exclusively acquired
+/// * How many times the lock has been acquired
+/// * Whether the acquisition(s) are exclusive or shared
+///
+/// Further implementation details, such as how we manage draining readers once an attempt to
+/// exclusively acquire the lock occurs, are described below.
+///
+/// NOTE: This is a simple implementation, meant for use in no-std environments; there are much
+/// more robust/performant implementations available when OS primitives can be used.
+pub struct Spinlock {
+    /// The state of the lock, primarily representing the acquisition count, but relying on
+    /// the distinction between even and odd values to indicate whether or not exclusive access
+    /// is being acquired.
+    state: AtomicUsize,
+    /// A counter used to wake a parked writer once the last shared lock is released during
+    /// acquisition of an exclusive lock. The actual count is not actually important, and
+    /// simply wraps around on overflow, but what is important is that when the value changes,
+    /// the writer will wake and resume attempting to acquire the exclusive lock.
+    writer_wake_counter: AtomicUsize,
+}
+
+impl Default for Spinlock {
+    #[inline(always)]
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Spinlock {
+    pub const fn new() -> Self {
+        Self {
+            state: AtomicUsize::new(0),
+            writer_wake_counter: AtomicUsize::new(0),
+        }
+    }
+}
+
+unsafe impl RawRwLock for Spinlock {
+    type GuardMarker = lock_api::GuardSend;
+
+    // This is intentional on the part of the [RawRwLock] API, basically a hack to provide
+    // initial values as static items.
+ #[allow(clippy::declare_interior_mutable_const)] + const INIT: Spinlock = Spinlock::new(); + + /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired + fn lock_shared(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // If the exclusive bit is unset, attempt to acquire a read lock + if s & 1 == 0 { + match self.state.compare_exchange_weak( + s, + s + 2, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + // Someone else beat us to the punch and acquired a lock + Err(e) => s = e, + } + } + // If an exclusive lock is held/being acquired, loop until the lock state changes + // at which point, try to acquire the lock again + if s & 1 == 1 { + loop { + let next = self.state.load(Ordering::Relaxed); + if s == next { + hint::spin_loop(); + continue; + } else { + s = next; + break; + } + } + } + } + } + + /// The operation invoked when calling `RwLock::try_read`, returns whether or not the + /// lock was acquired + fn try_lock_shared(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s & 1 == 0 { + self.state + .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockReadGuard` + unsafe fn unlock_shared(&self) { + if self.state.fetch_sub(2, Ordering::Release) == 3 { + // The lock is being exclusively acquired, and we're the last shared acquisition + // to be released, so wake the writer by incrementing the wake counter + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } + } + + /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired + fn lock_exclusive(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // Attempt to acquire the lock immediately, or complete acquistion of the lock + // if we're continuing the loop after acquiring the exclusive bit. If another + // thread acquired it first, we race to be the first thread to acquire it once + // released, by busy looping here. + if s <= 1 { + match self.state.compare_exchange( + s, + usize::MAX, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(e) => { + s = e; + hint::spin_loop(); + continue; + } + } + } + + // Only shared locks have been acquired, attempt to acquire the exclusive bit, + // which will prevent further shared locks from being acquired. It does not + // in and of itself grant us exclusive access however. + if s & 1 == 0 { + if let Err(e) = + self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed) + { + // The lock state has changed before we could acquire the exclusive bit, + // update our view of the lock state and try again + s = e; + continue; + } + } + + // We've acquired the exclusive bit, now we need to busy wait until all shared + // acquisitions are released. + let w = self.writer_wake_counter.load(Ordering::Acquire); + s = self.state.load(Ordering::Relaxed); + + // "Park" the thread here (by busy looping), until the release of the last shared + // lock, which is communicated to us by it incrementing the wake counter. + if s >= 2 { + while self.writer_wake_counter.load(Ordering::Acquire) == w { + hint::spin_loop(); + } + s = self.state.load(Ordering::Relaxed); + } + + // All shared locks have been released, go back to the top and try to complete + // acquisition of exclusive access. 
+ } + } + + /// The operation invoked when calling `RwLock::try_write`, returns whether or not the + /// lock was acquired + fn try_lock_exclusive(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s <= 1 { + self.state + .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockWriteGuard` + unsafe fn unlock_exclusive(&self) { + // Infallible, as we hold an exclusive lock + // + // Note the use of `Release` ordering here, which ensures any loads of the lock state + // by other threads, are ordered after this store. + self.state.store(0, Ordering::Release); + // This fetch_add isn't important for signaling purposes, however it serves a key + // purpose, in that it imposes a memory ordering on any loads of this field that + // have an `Acquire` ordering, i.e. they will read the value stored here. Without + // a `Release` store, loads/stores of this field could be reordered relative to + // each other. + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } +} From db1aca0d09dc040f3f126ae7bb6f8e524882370b Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 23 Sep 2024 01:33:42 -0400 Subject: [PATCH 05/31] wip: implement hir dialect ops, flesh out remaining core ir infra --- Cargo.lock | 8 +- hir-symbol/src/lib.rs | 12 + hir-type/src/lib.rs | 5 + hir2/Cargo.toml | 4 +- hir2/src/attributes.rs | 4 +- hir2/src/attributes/linkage.rs | 66 -- hir2/src/attributes/visibility.rs | 48 ++ hir2/src/derive.rs | 914 +++++++++++++++++++++++---- hir2/src/dialects/hir/ops/binary.rs | 64 +- hir2/src/dialects/hir/ops/cast.rs | 28 +- hir2/src/dialects/hir/ops/control.rs | 73 ++- hir2/src/dialects/hir/ops/invoke.rs | 48 +- hir2/src/dialects/hir/ops/mem.rs | 8 +- hir2/src/dialects/hir/ops/primop.rs | 16 +- hir2/src/dialects/hir/ops/unary.rs | 28 +- hir2/src/ir.rs | 25 +- hir2/src/ir/attribute.rs | 23 + hir2/src/ir/block.rs | 41 +- hir2/src/ir/context.rs | 13 + hir2/src/ir/dialect.rs | 3 + hir2/src/ir/entity.rs | 28 + hir2/src/ir/entity/list.rs | 18 + hir2/src/ir/function.rs | 183 +++++- hir2/src/ir/ident.rs | 7 +- hir2/src/ir/insert.rs | 100 +++ hir2/src/ir/module.rs | 106 +++- hir2/src/ir/op.rs | 40 +- hir2/src/ir/operands.rs | 564 +++++++++++++++++ hir2/src/ir/operation.rs | 343 +++++----- hir2/src/ir/operation/builder.rs | 168 +++++ hir2/src/ir/operation/name.rs | 38 ++ hir2/src/ir/region.rs | 33 +- hir2/src/ir/successor.rs | 23 + hir2/src/ir/symbol_table.rs | 448 ++++++++++++- hir2/src/ir/traits.rs | 83 +-- hir2/src/ir/traits/callable.rs | 117 ++++ hir2/src/ir/traits/multitrait.rs | 6 + hir2/src/ir/traits/types.rs | 489 ++++++++++++++ hir2/src/ir/usable.rs | 26 +- hir2/src/ir/value.rs | 94 +-- hir2/src/ir/verifier.rs | 15 +- hir2/src/ir/visit.rs | 162 +++++ hir2/src/lib.rs | 21 + 43 files changed, 3871 insertions(+), 672 deletions(-) delete mode 100644 hir2/src/attributes/linkage.rs create mode 100644 hir2/src/attributes/visibility.rs create mode 100644 hir2/src/ir/insert.rs create mode 100644 hir2/src/ir/operands.rs create mode 100644 hir2/src/ir/operation/builder.rs create mode 100644 hir2/src/ir/operation/name.rs create mode 100644 hir2/src/ir/successor.rs create mode 100644 hir2/src/ir/traits/callable.rs create mode 100644 hir2/src/ir/traits/types.rs create mode 100644 hir2/src/ir/visit.rs diff --git a/Cargo.lock b/Cargo.lock index a6dad288f..052381346 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1370,12 +1370,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = 
"downcast-rs" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" - [[package]] name = "ecdsa" version = "0.16.9" @@ -3428,9 +3422,9 @@ version = "0.0.6" dependencies = [ "anyhow", "blink-alloc", + "compact_str", "cranelift-entity", "derive_more", - "downcast-rs", "either", "hashbrown 0.14.5", "indexmap 2.5.0", diff --git a/hir-symbol/src/lib.rs b/hir-symbol/src/lib.rs index 7b94862a3..40d5a8254 100644 --- a/hir-symbol/src/lib.rs +++ b/hir-symbol/src/lib.rs @@ -120,6 +120,18 @@ impl Ord for Symbol { self.as_str().cmp(other.as_str()) } } +impl AsRef for Symbol { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} +impl core::borrow::Borrow for Symbol { + #[inline(always)] + fn borrow(&self) -> &str { + self.as_str() + } +} impl> PartialEq for Symbol { fn eq(&self, other: &T) -> bool { self.as_str() == other.deref() diff --git a/hir-type/src/lib.rs b/hir-type/src/lib.rs index c19c93c2d..978950677 100644 --- a/hir-type/src/lib.rs +++ b/hir-type/src/lib.rs @@ -217,6 +217,11 @@ impl Type { matches!(self, Self::Array(_, _)) } + #[inline] + pub fn is_list(&self) -> bool { + matches!(self, Self::List(_)) + } + /// Returns true if `self` and `other` are compatible operand types for a binary operator, e.g. /// `add` /// diff --git a/hir2/Cargo.toml b/hir2/Cargo.toml index 4c352c9be..72dfa1ea4 100644 --- a/hir2/Cargo.toml +++ b/hir2/Cargo.toml @@ -13,7 +13,7 @@ edition.workspace = true [features] default = ["std"] -std = ["rustc-demangle/std"] +std = ["rustc-demangle/std", "compact_str/std"] serde = [ "dep:serde", "dep:serde_repr", @@ -33,7 +33,7 @@ blink-alloc = { version = "0.3", default-features = false, features = [ ] } either.workspace = true cranelift-entity.workspace = true -downcast-rs = { version = "1.2", default-features = false } +compact_str.workspace = true hashbrown.workspace = true intrusive-collections.workspace = true inventory.workspace = true diff --git a/hir2/src/attributes.rs b/hir2/src/attributes.rs index 5a469e47a..c24d79ee0 100644 --- a/hir2/src/attributes.rs +++ b/hir2/src/attributes.rs @@ -1,5 +1,5 @@ mod call_conv; -mod linkage; mod overflow; +mod visibility; -pub use self::{call_conv::CallConv, linkage::Linkage, overflow::Overflow}; +pub use self::{call_conv::CallConv, overflow::Overflow, visibility::Visibility}; diff --git a/hir2/src/attributes/linkage.rs b/hir2/src/attributes/linkage.rs deleted file mode 100644 index d1d3c1277..000000000 --- a/hir2/src/attributes/linkage.rs +++ /dev/null @@ -1,66 +0,0 @@ -use core::fmt; - -/// The policy to apply to a global variable (or function) when linking -/// together a program during code generation. -/// -/// Miden doesn't (currently) have a notion of a symbol table for things like global variables. -/// At runtime, there are not actually symbols at all in any familiar sense, instead functions, -/// being the only entities with a formal identity in MASM, are either inlined at all their call -/// sites, or are referenced by the hash of their MAST root, to be unhashed at runtime if the call -/// is executed. -/// -/// Because of this, and because we cannot perform linking ourselves (we must emit separate modules, -/// and leave it up to the VM to link them into the MAST), there are limits to what we can do in -/// terms of linking function symbols. 
We essentially just validate that given a set of modules in -/// a [Program], that there are no invalid references across modules to symbols which either don't -/// exist, or which exist, but have internal linkage. -/// -/// However, with global variables, we have a bit more freedom, as it is a concept that we are -/// completely inventing from whole cloth without explicit support from the VM or Miden Assembly. -/// In short, when we compile a [Program] to MASM, we first gather together all of the global -/// variables into a program-wide table, merging and garbage collecting as appropriate, and updating -/// all references to them in each module. This global variable table is then assumed to be laid out -/// in memory starting at the base of the linear memory address space in the same order, with -/// appropriate padding to ensure accesses are aligned. Then, when emitting MASM instructions which -/// reference global values, we use the layout information to derive the address where that global -/// value is allocated. -/// -/// This has some downsides however, the biggest of which is that we can't prevent someone from -/// loading modules generated from a [Program] with either their own hand-written modules, or -/// even with modules from another [Program]. In such cases, assumptions about the allocation of -/// linear memory from different sets of modules will almost certainly lead to undefined behavior. -/// In the future, we hope to have a better solution to this problem, preferably one involving -/// native support from the Miden VM itself. For now though, we're working with what we've got. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] -#[cfg_attr( - feature = "serde", - derive(serde_repr::Deserialize_repr, serde_repr::Serialize_repr) -)] -#[repr(u8)] -pub enum Linkage { - /// This symbol is only visible in the containing module. - /// - /// Internal symbols may be renamed to avoid collisions - /// - /// Unreferenced internal symbols can be discarded at link time. - Internal, - /// This symbol will be linked using the "one definition rule", i.e. symbols with - /// the same name, type, and linkage will be merged into a single definition. - /// - /// Unlike `internal` linkage, unreferenced `odr` symbols cannot be discarded. - /// - /// NOTE: `odr` symbols cannot satisfy external symbol references - Odr, - /// This symbol is visible externally, and can be used to resolve external symbol references. - #[default] - External, -} -impl fmt::Display for Linkage { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Internal => f.write_str("internal"), - Self::Odr => f.write_str("odr"), - Self::External => f.write_str("external"), - } - } -} diff --git a/hir2/src/attributes/visibility.rs b/hir2/src/attributes/visibility.rs new file mode 100644 index 000000000..2d8c7e188 --- /dev/null +++ b/hir2/src/attributes/visibility.rs @@ -0,0 +1,48 @@ +use core::{fmt, str::FromStr}; + +/// The types of visibility that a [Symbol] may have +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum Visibility { + /// The symbol is public and may be referenced anywhere internal or external to the visible + /// references in the IR. + /// + /// Public visibility implies that we cannot remove the symbol even if we are unaware of any + /// references, and no other constraints apply, as we must assume that the symbol has references + /// we don't know about. 
+ #[default] + Public, + /// The symbol is private and may only be referenced by ops local to operations within the + /// current symbol table. + /// + /// Private visibility implies that we know all uses of the symbol, and that those uses must + /// all exist within the current symbol table. + Private, + /// The symbol is public, but may only be referenced by symbol tables in the current compilation + /// graph, thus retaining the ability to observe all uses, and optimize based on that + /// information. + /// + /// Nested visibility implies that we know all uses of the symbol, but that there may be uses + /// in other symbol tables in addition to the current one. + Nested, +} +impl fmt::Display for Visibility { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Public => f.write_str("public"), + Self::Private => f.write_str("private"), + Self::Nested => f.write_str("nested"), + } + } +} +impl FromStr for Visibility { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "public" => Ok(Self::Public), + "private" => Ok(Self::Private), + "nested" => Ok(Self::Nested), + _ => Err(()), + } + } +} diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs index 75a336dcb..d6ef93136 100644 --- a/hir2/src/derive.rs +++ b/hir2/src/derive.rs @@ -4,13 +4,19 @@ use crate::Operation; /// /// TODO(pauls): /// -/// * Implement `#[region]` support -/// * Implement `#[successor]` support -/// * Implement `#[successors]` support for variadic successors -/// * Implement `#[successors(interface)]` to access successors through `SuccessorInterface` /// * Support doc comments /// * Implement type constraints/inference /// * Implement `verify` blocks for custom verification rules +/// * FIX: Currently #[operands] simply adds boilerplate for creating an operation with those +/// operands, but it does not create methods to access them, and simply adds them in with the +/// other operands. We should figure out how to store operands in such a way that multiple operand +/// groups can be maintained even when adding/removing operands later. +/// * FIX: Currently #[successors] adds a field to the struct to store whatever custom type is used +/// to represent the successors, but these successors are not reachable from the Operation backing +/// the op, and as a result, any successor operations acting from the Operation and not the Op may +/// cause the two to converge. Like the #[operands] issue above, we need to store the actual +/// successor in the Operation, and provide some way to map between the two, OR change how we +/// represent successors to allow storing arbitrary successor-like types in the Operation #[macro_export] macro_rules! derive { ( @@ -78,32 +84,8 @@ macro_rules! derive { )* } - $($t:tt)* - ) => { - $crate::__derive_op!( - $(#[$outer])* - #[derive($crate::Spanned)] - $vis struct $Op { - $( - $(#[$inner $($args)*])* - $Field: $FieldTy - ),* - } - ); - - $($t)* - }; - - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident : Op implements $OpTrait:ident { - $( - $(#[$inner:ident $($args:tt)*])* - $Field:ident: $FieldTy:ty, - )* - } - - $($t:tt)* + $(derives $DerivedOpTrait:ident $(, $MoreDerivedTraits:ident)*;)* + $(implements $ImplementedOpTrait:ident $(, $MoreImplementedTraits:ident)*;)* ) => { $crate::__derive_op!( $(#[$outer])* @@ -114,37 +96,19 @@ macro_rules! 
derive { )* } - implement $OpTrait; - ); - - $($t)* - }; - - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident : Op implements $OpTrait1:ident $(, $OpTraitRest:ident)* { $( - $(#[$inner:ident $($args:tt)*])* - $Field:ident: $FieldTy:ty, + derives $DerivedOpTrait + $( + derives $MoreDerivedTraits + )* )* - } - - $($t:tt)* - ) => { - $crate::__derive_op!( - $(#[$outer])* - $vis struct $Op { + $( + implements $ImplementedOpTrait $( - $(#[$inner $($args)*])* - $Field: $FieldTy, + implements $MoreImplementedTraits )* - } - - implement $OpTrait1 - $(, implement $OpTraitRest)*; + )* ); - - $($t)* }; } @@ -229,9 +193,10 @@ macro_rules! __derive_op { )* } - $(implement $OpTrait:ident),*; + $(derives $DerivedOpTrait:ident)* + $(implements $ImplementedOpTrait:ident)* ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -239,15 +204,25 @@ macro_rules! __derive_op { $( { unprocessed: [$(#[$inner $($args)*])*], + ignore: [], field: $Field, field_type: $FieldTy, } )* ], processed: { + fields: [], dialect: [], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [], + regions_count: [0usize], + regions: [], + successor_groups_count: [0usize], + successor_groups: [], + successors_count: [0usize], + successors: [], + operand_groups_count: [0usize], + operand_groups: [], operands_count: [0usize], operands: [], results_count: [0usize], @@ -255,7 +230,11 @@ macro_rules! __derive_op { } } }; +} +#[doc(hidden)] +#[macro_export] +macro_rules! __derive_op_processor { // Handle duplicate `dialect` attr ( $(#[$outer:meta])* @@ -267,15 +246,25 @@ macro_rules! __derive_op { #[dialect] $($attrs_rest:tt)* ], + ignore: [$($IgnoredReason:tt)*], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$(dialect_processed:tt)+], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], @@ -296,22 +285,32 @@ macro_rules! 
__derive_op { #[dialect] $($attrs_rest:tt)* ], + ignore: [$($IgnoredReason:tt)*], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], results: [$($results_processed:tt)*], } ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -320,15 +319,241 @@ macro_rules! __derive_op { unprocessed: [ $($attrs_rest)* ], + ignore: [dialect $($IgnoredReason)*], field: $Field, field_type: $FieldTy, } $($fields_rest)* ], processed: { + fields: [$($extra_fields_processed)*], dialect: [dialect $FieldTy], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `region` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[region $($args:tt)*] + $($attrs_rest:tt)* + ], + ignore: [$($IgnoredReason:tt)*], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + ignore: [region $($IgnoredReason)*], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + fields: [$($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [1usize + $regions_count], + regions: [region $Field at $regions_count $($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `successor` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[successor $($args:tt)*] + $($attrs_rest:tt)* + ], + ignore: [$($IgnoredReason:tt)*], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + ignore: [successor $($IgnoredReason)*], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + fields: [$($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [1usize + $succ_count], + successors: [successor $Field at $succ_count $($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `successors` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[successors $($args:tt)*] + $($attrs_rest:tt)* + ], + ignore: [$($IgnoredReason:tt)*], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! { + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + ignore: [successors $($IgnoredReason)*], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + fields: [$($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [1usize + $succ_groups_count], + successor_groups: [successors $Field : $FieldTy $($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], operands_count: [$operands_count], operands: [$($operands_processed)*], results_count: [$results_count], @@ -348,22 +573,32 @@ macro_rules! 
__derive_op { #[operand $($args:tt)*] $($attrs_rest:tt)* ], + ignore: [$($IgnoredReason:tt)*], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$($dialect_processed:tt)*], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], results: [$($results_processed:tt)*], } ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -372,15 +607,25 @@ macro_rules! __derive_op { unprocessed: [ $($attrs_rest)* ], + ignore: [operand $($IgnoredReason)*], field: $Field, field_type: $FieldTy, } $($fields_rest)* ], processed: { + fields: [$($extra_fields_processed)*], dialect: [$($dialect_processed)*], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], operands_count: [1usize + $operands_count], operands: [operand $Field at $operands_count $($operands_processed)*], results_count: [$results_count], @@ -389,6 +634,78 @@ macro_rules! __derive_op { } }; + // Handle `operands` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[operands $($args:tt)*] + $($attrs_rest:tt)* + ], + ignore: [$($IgnoredReason:tt)*], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + ignore: [operands $($IgnoredReason)*], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + fields: [$($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [1usize + $operand_groups_count], + operand_groups: [operands $Field at $operand_groups_count $($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + // Handle `result` attr ( $(#[$outer:meta])* @@ -400,22 +717,32 @@ macro_rules! __derive_op { #[result $($args:tt)*] $($attrs_rest:tt)* ], + ignore: [$($IgnoredReason:tt)*], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$($dialect_processed:tt)*], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], results: [$($results_processed:tt)*], } ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -424,15 +751,25 @@ macro_rules! __derive_op { unprocessed: [ $($attrs_rest)* ], + ignore: [result $($IgnoredReason)*], field: $Field, field_type: $FieldTy, } $($fields_rest)* ], processed: { + fields: [$($extra_fields_processed)*], dialect: [$($dialect_processed)*], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], operands_count: [$operands_count], operands: [$($operands_processed)*], results_count: [1usize + $results_count], @@ -452,22 +789,32 @@ macro_rules! 
__derive_op { #[attr $($args:tt)*] $($attrs_rest:tt)* ], + ignore: [$($IgnoredReason:tt)*], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$($dialect_processed:tt)*], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], results: [$($results_processed:tt)*], } ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -476,15 +823,158 @@ macro_rules! __derive_op { unprocessed: [ $($attrs_rest)* ], + ignore: [attr $($IgnoredReason)*], field: $Field, field_type: $FieldTy, } $($fields_rest)* ], processed: { + fields: [$($extra_fields_processed)*], dialect: [$($dialect_processed)*], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [attr $Field: $FieldTy $($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle `doc` attr + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [ + #[doc $($args:tt)*] + $($attrs_rest:tt)* + ], + ignore: [$($IgnoredReason:tt)*], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! 
{ + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + { + unprocessed: [ + $($attrs_rest)* + ], + ignore: [$($IgnoredReason)*], + field: $Field, + field_type: $FieldTy, + } + $($fields_rest)* + ], + processed: { + fields: [$($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], + operands_count: [$operands_count], + operands: [$($operands_processed)*], + results_count: [$results_count], + results: [$($results_processed)*], + } + } + }; + + // Handle end of unprocessed attributes (ignore=false) + ( + $(#[$outer:meta])* + $vis:vis struct $Op:ident; + + unprocessed: [ + { + unprocessed: [], + ignore: [], + field: $Field:ident, + field_type: $FieldTy:ty, + } + $($fields_rest:tt)* + ], + processed: { + fields: [$($extra_fields_processed:tt)*], + dialect: [$($dialect_processed:tt)*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], + attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], + operands_count: [$operands_count:expr], + operands: [$($operands_processed:tt)*], + results_count: [$results_count:expr], + results: [$($results_processed:tt)*], + } + ) => { + $crate::__derive_op_processor! { + $(#[$outer])* + $vis struct $Op; + + unprocessed: [ + $($fields_rest)* + ], + processed: { + fields: [field $Field: $FieldTy $($extra_fields_processed)*], + dialect: [$($dialect_processed)*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], + attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], operands_count: [$operands_count], operands: [$($operands_processed)*], results_count: [$results_count], @@ -493,7 +983,7 @@ macro_rules! __derive_op { } }; - // Handle end of unprocessed attributes + // Handle end of unprocessed attributes (ignore=true) ( $(#[$outer:meta])* $vis:vis struct $Op:ident; @@ -501,22 +991,32 @@ macro_rules! 
__derive_op { unprocessed: [ { unprocessed: [], + ignore: [$($IgnoredReason:tt)+], field: $Field:ident, field_type: $FieldTy:ty, } $($fields_rest:tt)* ], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$($dialect_processed:tt)*], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], results: [$($results_processed:tt)*], } ) => { - $crate::__derive_op! { + $crate::__derive_op_processor! { $(#[$outer])* $vis struct $Op; @@ -524,9 +1024,18 @@ macro_rules! __derive_op { $($fields_rest)* ], processed: { + fields: [$($extra_fields_processed)*], dialect: [$($dialect_processed)*], - traits: [$(implement $OpTrait),*], + traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], attrs: [$($attrs_processed)*], + regions_count: [$regions_count], + regions: [$($regions_processed)*], + successor_groups_count: [$succ_groups_count], + successor_groups: [$($succ_groups_processed)*], + successors_count: [$succ_count], + successors: [$($succ_processed)*], + operand_groups_count: [$operand_groups_count], + operand_groups: [$($operand_groups_processed)*], operands_count: [$operands_count], operands: [$($operands_processed)*], results_count: [$results_count], @@ -542,9 +1051,18 @@ macro_rules! __derive_op { unprocessed: [], processed: { + fields: [$($extra_fields_processed:tt)*], dialect: [$($dialect_processed:tt)*], - traits: [$(implement $OpTrait:ident),*], + traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], attrs: [$($attrs_processed:tt)*], + regions_count: [$regions_count:expr], + regions: [$($regions_processed:tt)*], + successor_groups_count: [$succ_groups_count:expr], + successor_groups: [$($succ_groups_processed:tt)*], + successors_count: [$succ_count:expr], + successors: [$($succ_processed:tt)*], + operand_groups_count: [$operand_groups_count:expr], + operand_groups: [$($operand_groups_processed:tt)*], operands_count: [$operands_count:expr], operands: [$($operands_processed:tt)*], results_count: [$results_count:expr], @@ -556,8 +1074,15 @@ macro_rules! __derive_op { $vis struct $Op; $($dialect_processed)*; - $(implement $OpTrait),*; + $($extra_fields_processed)*; + $(derives $DerivedOpTrait)*; + $(implements $ImplementedOpTrait)*; $($attrs_processed)*; + regions $regions_count; + $($regions_processed)*; + $($succ_groups_processed)*; + $($succ_processed)*; + $($operand_groups_processed)*; $($operands_processed)*; $($results_processed)*; ); @@ -572,25 +1097,48 @@ macro_rules! 
__derive_op_impl { $vis:vis struct $Op:ident; dialect $Dialect:ty; - $(implement $OpTrait:ident),*; + $(field $Field:ident: $FieldTy:ty)*; + $(derives $DerivedOpTrait:ident)*; + $(implements $ImplementedOpTrait:ident)*; $(attr $AttrField:ident: $AttrTy:ty)*; + regions $NumRegions:expr; + $(region $RegionField:ident at $RegionIdx:expr)*; + $(successors $SuccGroupField:ident: $SuccGroupTy:ty)*; + $(successor $SuccField:ident at $SuccIdx:expr)*; + $(operands $OperandGroupField:ident at $OperandGroupIdx:expr)*; $(operand $Operand:ident at $OperandIdx:expr)*; $(result $Result:ident at $ResultIdx:expr)*; ) => { $(#[$outer])* - #[derive(Spanned)] $vis struct $Op { - #[span] op: $crate::Operation, + $( + $Field: $FieldTy, + )* + $( + $SuccGroupField: $SuccGroupTy, + )* + } + impl ::midenc_session::diagnostics::Spanned for $Op { + fn span(&self) -> ::midenc_session::diagnostics::SourceSpan { + self.op.span() + } } #[allow(unused)] impl $Op { /// Get a new, uninitialized instance of this op - pub fn uninit() -> Self { + pub fn uninit($($Field: $FieldTy),*) -> Self { + let mut op = $crate::Operation::uninit::(); Self { - op: $crate::Operation::uninit::(), + op, + $( + $Field, + )* + $( + $SuccGroupField: Default::default(), + )* } } @@ -599,18 +1147,51 @@ macro_rules! __derive_op_impl { $( , $Operand: $crate::ValueRef )* + $( + , $OperandGroupField: impl IntoIterator + )* + $( + , $Field: $FieldTy + )* $( , $AttrField: $AttrTy )* + $( + , $SuccGroupField: $SuccGroupTy + )* + $( + , $SuccField: $crate::OpSuccessor + )* ) -> Result<$crate::UnsafeIntrusiveEntityRef<$Op>, $crate::Report> { - let mut builder = $crate::OperationBuilder::::new(context, Self::uninit()); + let mut this = Self::uninit($($Field),*); + $( + this.$SuccGroupField = $SuccGroupField.clone(); + )* + + let mut builder = $crate::OperationBuilder::::new(context, this); + $( + builder.implement::(); + )* $( - builder.implement::(); + builder.implement::(); )* $( builder.with_attr(stringify!($AttrField), $AttrField); )* builder.with_operands([$($Operand),*]); + $( + builder.with_operands_in_group($OperandGroupIdx, $OperandGroupField); + )* + $( + #[doc = stringify!($RegionField)] + builder.create_region(); + )* + $( + builder.with_successors($SuccGroupField); + )* + $( + builder.with_successor($SuccField); + )* let num_results = const { let results: &[usize] = &[$($ResultIdx),*]; results.len() @@ -620,22 +1201,117 @@ macro_rules! __derive_op_impl { } $( - fn $AttrField(&self) -> $AttrTy { + #[inline] + fn $Field(&self) -> &$FieldTy { + &self.$Field + } + + paste::paste! { + #[inline] + fn [<$Field _mut>](&mut self) -> &mut $FieldTy { + &mut self.$Field + } + + #[doc = concat!("Set the value of ", stringify!($Field))] + #[inline] + fn [](&mut self, $Field: $FieldTy) { + self.$Field = $Field; + } + } + )* + + $( + fn $AttrField(&self) -> &$AttrTy { let sym = stringify!($AttrField); - let value = self.op.get_attribute(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap(); - value.downcast_ref::<$AttrTy>().unwrap().clone() + self.op.get_typed_attribute::<$AttrTy, _>(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap() + } + + paste::paste! 
{ + fn [<$AttrField _mut>](&mut self) -> &mut $AttrTy { + let sym = stringify!($AttrField); + self.op.get_typed_attribute_mut::<$AttrTy, _>(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap() + } + + fn [](&mut self, value: $AttrTy) { + let sym = stringify!($AttrField); + self.op.set_attribute(::midenc_hir_symbol::Symbol::intern(sym), Some(value)); + } } )* $( - fn $Operand(&self) -> $crate::OpOperand { - self.operands()[$OperandIdx].clone() + fn $RegionField(&self) -> $crate::EntityRef<'_, $crate::Region> { + self.op.region($RegionIdx) + } + + paste::paste! { + fn [<$RegionField _mut>](&mut self) -> $crate::EntityMut<'_, $crate::Region> { + self.op.region_mut($RegionIdx) + } } )* $( - fn $Result(&self) -> $crate::ValueRef { - self.results()[$ResultIdx].clone() + #[inline] + fn $SuccGroupField(&self) -> &$SuccGroupTy { + &self.$SuccGroupField + } + + paste::paste! { + #[inline] + fn [<$SuccGroupField _mut>](&mut self) -> &mut $SuccGroupTy { + &mut self.$SuccGroupField + } + } + )* + + $( + #[inline] + fn $SuccField(&self) -> &$crate::OpSuccessor { + &self.successors()[$SuccIdx] + } + + paste::paste! { + #[inline] + fn [<$SuccField _mut>](&mut self) -> &mut $crate::OpSuccessor { + &mut self.successors_mut()[$SuccIdx] + } + } + )* + + $( + fn $Operand(&self) -> $crate::EntityRef<'_, $crate::OpOperandImpl> { + self.op.operands()[$OperandIdx].borrow() + } + + paste::paste! { + fn [<$Operand _mut>](&mut self) -> $crate::EntityMut<'_, $crate::OpOperandImpl> { + self.op.operands_mut()[$OperandIdx].borrow_mut() + } + } + )* + + $( + fn $OperandGroupField(&self) -> $crate::OpOperandRange<'_> { + self.op.operands().group($OperandGroupIdx) + } + + paste::paste! { + fn [<$OperandGroupField _mut>](&mut self) -> $crate::OpOperandRangeMut<'_> { + self.op.operands_mut().group_mut($OperandGroupIdx) + } + } + )* + + $( + fn $Result(&self) -> $crate::EntityRef<'_, dyn $crate::Value> { + self.results()[$ResultIdx].borrow() + } + + paste::paste! { + fn [<$Result _mut>](&mut self) -> $crate::EntityMut<'_, dyn $crate::Value> { + self.op.results_mut()[$ResultIdx].borrow_mut() + } } )* } @@ -654,14 +1330,12 @@ macro_rules! __derive_op_impl { } } - __derive_op_name!($Op); + $crate::__derive_op_name!($Op, $Dialect); impl $crate::Op for $Op { fn name(&self) -> $crate::OperationName { - const DIALECT: $Dialect = <$Dialect as $crate::Dialect>::INIT; - let dialect = <$Dialect as $crate::Dialect>::name(&DIALECT); paste::paste! { - $crate::OperationName::new(dialect, *[<__ $Op _NAME>]) + *[<__ $Op _NAME>] } } @@ -676,17 +1350,19 @@ macro_rules! __derive_op_impl { } } - __derive_op_traits!($Op, $($OpTrait),*); + $crate::__derive_op_traits!($Op $(, derive $DerivedOpTrait)* $(, implement $ImplementedOpTrait)*); }; } #[doc(hidden)] #[macro_export] macro_rules! __derive_op_name { - ($Op:ident) => { + ($Op:ident, $Dialect:ty) => { paste::paste! { #[allow(non_upper_case_globals)] - static [<__ $Op _NAME>]: ::std::sync::LazyLock<::midenc_hir_symbol::Symbol> = ::std::sync::LazyLock::new(|| { + static [<__ $Op _NAME>]: ::std::sync::LazyLock<$crate::OperationName> = ::std::sync::LazyLock::new(|| { + const DIALECT: $Dialect = <$Dialect as $crate::Dialect>::INIT; + // CondBrOp => CondBr => cond_br // Add => add let type_name = stringify!($Op); @@ -709,7 +1385,9 @@ macro_rules! 
__derive_op_name { buf.push(c); } } - ::midenc_hir_symbol::Symbol::intern(buf) + let name = ::midenc_hir_symbol::Symbol::intern(buf); + let dialect = <$Dialect as $crate::Dialect>::name(&DIALECT); + $crate::OperationName::new(dialect, name) }); } } @@ -728,27 +1406,35 @@ macro_rules! __derive_op_traits { } }; - ($T:ty, $($Trait:ident),+) => { + ($T:ty $(, derive $DeriveTrait:ident)* $(, implement $ImplementTrait:ident)*) => { $( - impl $Trait for $T {} + impl $DeriveTrait for $T {} )* impl $crate::OpVerifier for $T { fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { #[allow(unused_parens)] - type OpVerifierImpl<'a> = $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $Trait),*)>; + type OpVerifierImpl<'a> = $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $DeriveTrait,)* $(&'a dyn $ImplementTrait),*)>; #[allow(unused_parens)] - impl<'a> $crate::OpVerifier for $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $Trait),*)> + impl<'a> $crate::OpVerifier for $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $DeriveTrait,)* $(&'a dyn $ImplementTrait),*)> where $( - $T: $crate::verifier::Verifier - ),* + $T: $crate::verifier::Verifier, + )* + $( + $T: $crate::verifier::Verifier, + )* { fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { let op = self.downcast_ref::<$T>().unwrap(); $( - if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { - <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; + if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { + <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; + } + )* + $( + if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { + <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; } )* @@ -756,8 +1442,7 @@ macro_rules! __derive_op_traits { } } - let op = self.as_operation(); - let verifier = OpVerifierImpl::new(op); + let verifier = OpVerifierImpl::new(&self.op); verifier.verify(context) } } @@ -802,7 +1487,7 @@ mod tests { use crate::{ define_attr_type, dialects::hir::HirDialect, traits::*, Context, Op, Operation, Report, - SourceSpan, Spanned, + Spanned, }; #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -822,7 +1507,7 @@ mod tests { derive! { /// An example op implementation to make sure all of the type machinery works - struct AddOp : Op implements SingleBlock, SameTypeOperands, ArithmeticOp { + struct AddOp : Op { #[dialect] dialect: HirDialect, #[attr] @@ -832,8 +1517,13 @@ mod tests { #[operand] rhs: OpOperand, } + + derives SingleBlock, SameTypeOperands; + implements ArithmeticOp; } + impl ArithmeticOp for AddOp {} + derive! { /// A marker trait for arithmetic ops trait ArithmeticOp {} diff --git a/hir2/src/dialects/hir/ops/binary.rs b/hir2/src/dialects/hir/ops/binary.rs index 2fcd705c0..e7b5c2b8e 100644 --- a/hir2/src/dialects/hir/ops/binary.rs +++ b/hir2/src/dialects/hir/ops/binary.rs @@ -3,7 +3,7 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; macro_rules! derive_binary_op_with_overflow { ($Op:ident) => { derive! { - pub struct $Op: Op implements BinaryOp { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -15,12 +15,14 @@ macro_rules! derive_binary_op_with_overflow { #[attr] overflow: Overflow, } + + derives BinaryOp; } }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { derive! 
{ - pub struct $Op: Op implements BinaryOp, $OpTrait $(, $OpTraitRest)* { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -32,6 +34,8 @@ macro_rules! derive_binary_op_with_overflow { #[attr] overflow: Overflow, } + + derives BinaryOp, $OpTrait $(, $OpTraitRest)*; } }; } @@ -39,7 +43,7 @@ macro_rules! derive_binary_op_with_overflow { macro_rules! derive_binary_op { ($Op:ident) => { derive! { - pub struct $Op: Op implements BinaryOp { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -49,12 +53,14 @@ macro_rules! derive_binary_op { #[result] result: OpResultRef, } + + derives BinaryOp; } }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { derive! { - pub struct $Op: Op implements BinaryOp, $OpTrait $(, $OpTraitRest)* { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -64,66 +70,68 @@ macro_rules! derive_binary_op { #[result] result: OpResultRef, } + + derives BinaryOp, $OpTrait $(, $OpTraitRest)*; } }; } macro_rules! derive_binary_logical_op { ($Op:ident) => { - derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, Commutative); + derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, Commutative); }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, Commutative, $OpTrait $(, $OpTraitRest)*); + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, Commutative, $OpTrait $(, $OpTraitRest)*); }; } macro_rules! derive_binary_bitwise_op { ($Op:ident) => { - derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType); + derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType); }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op implements SameTypeOperands, SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); }; } macro_rules! 
derive_binary_comparison_op { ($Op:ident) => { - derive_binary_op!($Op implements SameTypeOperands); + derive_binary_op!($Op derives SameTypeOperands); }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op implements SameTypeOperands, $OpTrait $(, $OpTraitRest)*); + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { + derive_binary_op!($Op derives SameTypeOperands, $OpTrait $(, $OpTraitRest)*); }; } -derive_binary_op_with_overflow!(Add implements Commutative, SameTypeOperands); -derive_binary_op_with_overflow!(Sub implements SameTypeOperands); -derive_binary_op_with_overflow!(Mul implements Commutative, SameTypeOperands); +derive_binary_op_with_overflow!(Add derives Commutative, SameTypeOperands); +derive_binary_op_with_overflow!(Sub derives SameTypeOperands); +derive_binary_op_with_overflow!(Mul derives Commutative, SameTypeOperands); derive_binary_op_with_overflow!(Exp); -derive_binary_op!(Div implements SameTypeOperands, SameOperandsAndResultType); -derive_binary_op!(Mod implements SameTypeOperands, SameOperandsAndResultType); -derive_binary_op!(DivMod implements SameTypeOperands, SameOperandsAndResultType); +derive_binary_op!(Div derives SameTypeOperands, SameOperandsAndResultType); +derive_binary_op!(Mod derives SameTypeOperands, SameOperandsAndResultType); +derive_binary_op!(DivMod derives SameTypeOperands, SameOperandsAndResultType); derive_binary_logical_op!(And); derive_binary_logical_op!(Or); derive_binary_logical_op!(Xor); -derive_binary_bitwise_op!(Band implements Commutative); -derive_binary_bitwise_op!(Bor implements Commutative); -derive_binary_bitwise_op!(Bxor implements Commutative); +derive_binary_bitwise_op!(Band derives Commutative); +derive_binary_bitwise_op!(Bor derives Commutative); +derive_binary_bitwise_op!(Bxor derives Commutative); derive_binary_op!(Shl); derive_binary_op!(Shr); derive_binary_op!(Rotl); derive_binary_op!(Rotr); -derive_binary_comparison_op!(Eq implements Commutative); -derive_binary_comparison_op!(Neq implements Commutative); +derive_binary_comparison_op!(Eq derives Commutative); +derive_binary_comparison_op!(Neq derives Commutative); derive_binary_comparison_op!(Gt); derive_binary_comparison_op!(Gte); derive_binary_comparison_op!(Lt); derive_binary_comparison_op!(Lte); -derive_binary_comparison_op!(Min implements Commutative); -derive_binary_comparison_op!(Max implements Commutative); +derive_binary_comparison_op!(Min derives Commutative); +derive_binary_comparison_op!(Max derives Commutative); diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs index 829d47350..3d51841f0 100644 --- a/hir2/src/dialects/hir/ops/cast.rs +++ b/hir2/src/dialects/hir/ops/cast.rs @@ -30,7 +30,7 @@ pub enum CastKind { */ derive! { - pub struct PtrToInt : Op implements UnaryOp { + pub struct PtrToInt : Op { #[dialect] dialect: HirDialect, #[attr] @@ -40,10 +40,12 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct IntToPtr : Op implements UnaryOp { + pub struct IntToPtr : Op { #[dialect] dialect: HirDialect, #[attr] @@ -53,10 +55,12 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct Cast : Op implements UnaryOp { + pub struct Cast : Op { #[dialect] dialect: HirDialect, #[attr] @@ -66,10 +70,12 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct Bitcast : Op implements UnaryOp { + pub struct Bitcast : Op { #[dialect] dialect: HirDialect, #[attr] @@ -79,10 +85,12 @@ derive! 
{ #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct Trunc : Op implements UnaryOp { + pub struct Trunc : Op { #[dialect] dialect: HirDialect, #[attr] @@ -92,10 +100,12 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct Zext : Op implements UnaryOp { + pub struct Zext : Op { #[dialect] dialect: HirDialect, #[attr] @@ -105,10 +115,12 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } derive! { - pub struct Sext : Op implements UnaryOp { + pub struct Sext : Op { #[dialect] dialect: HirDialect, #[attr] @@ -118,4 +130,6 @@ derive! { #[result] result: OpResult, } + + derives UnaryOp; } diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs index f8cc57d5c..71a2fd871 100644 --- a/hir2/src/dialects/hir/ops/control.rs +++ b/hir2/src/dialects/hir/ops/control.rs @@ -3,27 +3,31 @@ use smallvec::SmallVec; use crate::{dialects::hir::HirDialect, traits::*, *}; derive! { - pub struct Ret : Op implements Terminator { + pub struct Ret : Op { #[dialect] dialect: HirDialect, #[operand] value: OpOperand, } + + derives Terminator; } // TODO(pauls): RetImm derive! { - pub struct Br : Op implements Terminator { + pub struct Br : Op { #[dialect] dialect: HirDialect, #[successor] target: Successor, } + + derives Terminator; } derive! { - pub struct CondBr : Op implements Terminator { + pub struct CondBr : Op { #[dialect] dialect: HirDialect, #[operand] @@ -33,10 +37,12 @@ derive! { #[successor] else_dest: Successor, } + + derives Terminator; } derive! { - pub struct Switch : Op implements Terminator { + pub struct Switch : Op { #[dialect] dialect: HirDialect, #[operand] @@ -46,17 +52,33 @@ derive! { #[successor] fallback: Successor, } + + derives Terminator; } // TODO(pauls): Implement `SuccessorInterface` for this type #[derive(Debug, Clone)] pub struct SwitchCase { pub value: u32, - pub successor: Successor, + pub successor: OpSuccessor, +} + +impl From for OpSuccessor { + #[inline] + fn from(value: SwitchCase) -> Self { + value.successor + } +} + +impl From<&SwitchCase> for OpSuccessor { + #[inline] + fn from(value: &SwitchCase) -> Self { + value.successor.clone() + } } derive! { - pub struct If : Op implements SingleBlock, NoRegionArguments { + pub struct If : Op { #[dialect] dialect: HirDialect, #[operand] @@ -66,23 +88,24 @@ derive! { #[region] else_body: Region, } -} -/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. -/// -/// The "before" region's entry block parameters correspond to the operands expected by the -/// operation, and can be used to compute the condition that determines whether the "after" body -/// is executed or not, or simply forwarded to the "after" region. The "before" region must -/// terminate with a [Condition] operation, which will be evaluated to determine whether or not -/// to continue the loop. -/// -/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, -/// whose operands must be of the same arity and type as the "before" region's argument list. In -/// this way, the "after" body can feed back input to the "before" body to determine whether to -/// continue the loop. + derives SingleBlock, NoRegionArguments; +} derive! { - pub struct While : Op implements SingleBlock { + /// A while is a loop structure composed of two regions: a "before" region, and an "after" region. 
+ /// + /// The "before" region's entry block parameters correspond to the operands expected by the + /// operation, and can be used to compute the condition that determines whether the "after" body + /// is executed or not, or simply forwarded to the "after" region. The "before" region must + /// terminate with a [Condition] operation, which will be evaluated to determine whether or not + /// to continue the loop. + /// + /// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, + /// whose operands must be of the same arity and type as the "before" region's argument list. In + /// this way, the "after" body can feed back input to the "before" body to determine whether to + /// continue the loop. + pub struct While : Op { #[dialect] dialect: HirDialect, #[region] @@ -90,20 +113,26 @@ derive! { #[region] after: Region, } + + derives SingleBlock; } derive! { - pub struct Condition : Op implements Terminator, ReturnLike { + pub struct Condition : Op { #[dialect] dialect: HirDialect, #[operand] value: OpOperand, } + + derives Terminator, ReturnLike; } derive! { - pub struct Yield : Op implements Terminator, ReturnLike { + pub struct Yield : Op { #[dialect] dialect: HirDialect, } + + derives Terminator, ReturnLike; } diff --git a/hir2/src/dialects/hir/ops/invoke.rs b/hir2/src/dialects/hir/ops/invoke.rs index 2800740ea..b326ee510 100644 --- a/hir2/src/dialects/hir/ops/invoke.rs +++ b/hir2/src/dialects/hir/ops/invoke.rs @@ -4,16 +4,20 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; // // * Inferring op constraints from callee signature derive! { - pub struct Exec : Op implements CallInterface { + pub struct Exec : Op { #[dialect] dialect: HirDialect, #[attr] - callee: FunctionIdent, + callee: SymbolNameAttr, + #[operands] + arguments: Vec, } + + implements CallOpInterface; } derive! { - pub struct ExecIndirect : Op implements CallInterface { + pub struct ExecIndirect : Op { #[dialect] dialect: HirDialect, #[attr] @@ -22,3 +26,41 @@ derive! { callee: OpOperand, } } + +impl CallOpInterface for Exec { + #[inline(always)] + fn callable_for_callee(&self) -> Callable { + self.callee().into() + } + + fn set_callee(&mut self, callable: Callable) { + let callee = callable.unwrap_symbol_name(); + *self.callee_mut() = callee; + } + + #[inline(always)] + fn arguments(&self) -> OpOperandRange<'_> { + self.operands().group(0) + } + + #[inline(always)] + fn arguments_mut(&mut self) -> OpOperandRangeMut<'_> { + self.operands_mut().group_mut(0) + } + + fn resolve(&self) -> Option { + let callee = self.callee(); + if callee.has_parent() { + todo!() + } + let module = self.as_operation().nearest_symbol_table()?; + let module = module.borrow(); + let symbol_table = module.as_trait::().unwrap(); + symbol_table.get(callee.name) + } + + fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option { + let callee = self.callee(); + symbols.get(callee.name) + } +} diff --git a/hir2/src/dialects/hir/ops/mem.rs b/hir2/src/dialects/hir/ops/mem.rs index 131d93231..da053079d 100644 --- a/hir2/src/dialects/hir/ops/mem.rs +++ b/hir2/src/dialects/hir/ops/mem.rs @@ -1,7 +1,7 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; derive! { - pub struct Store : Op implements HasSideEffects, MemoryWrite { + pub struct Store : Op { #[dialect] dialect: HirDialect, #[operand] @@ -9,17 +9,21 @@ derive! { #[operand] value: OpOperand, } + + derives HasSideEffects, MemoryWrite; } // TODO(pauls): StoreLocal derive! 
{ - pub struct Load : Op implements HasSideEffects, MemoryRead { + pub struct Load : Op { #[dialect] dialect: HirDialect, #[operand] addr: OpOperand, } + + derives HasSideEffects, MemoryRead; } // TODO(pauls): LoadLocal diff --git a/hir2/src/dialects/hir/ops/primop.rs b/hir2/src/dialects/hir/ops/primop.rs index 8cf786a1f..744f50d42 100644 --- a/hir2/src/dialects/hir/ops/primop.rs +++ b/hir2/src/dialects/hir/ops/primop.rs @@ -1,7 +1,7 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; derive! { - pub struct MemGrow : Op implements HasSideEffects, MemoryRead, MemoryWrite { + pub struct MemGrow : Op { #[dialect] dialect: HirDialect, #[operand] @@ -9,19 +9,23 @@ derive! { #[result] result: OpResult, } + + derives HasSideEffects, MemoryRead, MemoryWrite; } derive! { - pub struct MemSize : Op implements HasSideEffects, MemoryRead { + pub struct MemSize : Op { #[dialect] dialect: HirDialect, #[result] result: OpResult, } + + derives HasSideEffects, MemoryRead; } derive! { - pub struct MemSet : Op implements HasSideEffects, MemoryWrite { + pub struct MemSet : Op { #[dialect] dialect: HirDialect, #[operand] @@ -33,10 +37,12 @@ derive! { #[result] result: OpResult, } + + derives HasSideEffects, MemoryWrite; } derive! { - pub struct MemCpy : Op implements HasSideEffects, MemoryRead, MemoryWrite { + pub struct MemCpy : Op { #[dialect] dialect: HirDialect, #[operand] @@ -48,4 +54,6 @@ derive! { #[result] result: OpResult, } + + derives HasSideEffects, MemoryRead, MemoryWrite; } diff --git a/hir2/src/dialects/hir/ops/unary.rs b/hir2/src/dialects/hir/ops/unary.rs index ad1fb7b7f..e097923f2 100644 --- a/hir2/src/dialects/hir/ops/unary.rs +++ b/hir2/src/dialects/hir/ops/unary.rs @@ -3,7 +3,7 @@ use crate::{dialects::hir::HirDialect, traits::*, *}; macro_rules! derive_unary_op { ($Op:ident) => { derive! { - pub struct $Op: Op implements UnaryOp { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -11,12 +11,14 @@ macro_rules! derive_unary_op { #[result] result: OpResultRef, } + + derives UnaryOp; } }; - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { + ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { derive! { - pub struct $Op: Op implements UnaryOp, $OpTrait $(, $OpTraitRest)* { + pub struct $Op: Op { #[dialect] dialect: HirDialect, #[operand] @@ -24,35 +26,37 @@ macro_rules! derive_unary_op { #[result] result: OpResultRef, } + + derives UnaryOp, $OpTrait $(, $OpTraitRest)*; } }; } macro_rules! derive_unary_logical_op { ($Op:ident) => { - derive_unary_op!($Op implements SameOperandsAndResultType); + derive_unary_op!($Op derives SameOperandsAndResultType); }; ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_unary_op!($Op implements SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + derive_unary_op!($Op derives SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); }; } macro_rules! 
derive_unary_bitwise_op { ($Op:ident) => { - derive_unary_op!($Op implements SameOperandsAndResultType); + derive_unary_op!($Op derives SameOperandsAndResultType); }; ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_unary_op!($Op implements SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); + derive_unary_op!($Op derives SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); }; } -derive_unary_op!(Neg implements SameOperandsAndResultType); -derive_unary_op!(Inv implements SameOperandsAndResultType); -derive_unary_op!(Incr implements SameOperandsAndResultType); -derive_unary_op!(Ilog2 implements SameOperandsAndResultType); -derive_unary_op!(Pow2 implements SameOperandsAndResultType); +derive_unary_op!(Neg derives SameOperandsAndResultType); +derive_unary_op!(Inv derives SameOperandsAndResultType); +derive_unary_op!(Incr derives SameOperandsAndResultType); +derive_unary_op!(Ilog2 derives SameOperandsAndResultType); +derive_unary_op!(Pow2 derives SameOperandsAndResultType); derive_unary_logical_op!(Not); derive_unary_logical_op!(IsOdd); diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index b869bbfd4..84a4f3921 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -7,17 +7,21 @@ mod entity; mod function; mod ident; mod immediates; +mod insert; mod interface; mod module; mod op; +mod operands; mod operation; mod region; -mod symbol_table; +mod successor; +pub(crate) mod symbol_table; pub mod traits; mod types; mod usable; mod value; pub(crate) mod verifier; +mod visit; pub use midenc_hir_symbol as interner; pub use midenc_session::diagnostics::{Report, SourceSpan, Spanned}; @@ -37,18 +41,25 @@ pub use self::{ function::{AbiParam, ArgumentExtension, ArgumentPurpose, Function, Signature}, ident::{FunctionIdent, Ident}, immediates::{Felt, FieldElement, Immediate, StarkField}, + insert::{Insert, InsertionPoint, ProgramPoint}, module::Module, op::{Op, OpExt}, + operands::{ + OpOperand, OpOperandImpl, OpOperandList, OpOperandRange, OpOperandRangeMut, + OpOperandStorage, + }, operation::{ - OpCursor, OpCursorMut, OpList, OpSuccessor, Operation, OperationBuilder, OperationName, - OperationRef, + OpCursor, OpCursorMut, OpList, Operation, OperationBuilder, OperationName, OperationRef, }, region::{Region, RegionCursor, RegionCursorMut, RegionList, RegionRef}, - symbol_table::{Symbol, SymbolTable}, + successor::OpSuccessor, + symbol_table::{ + Symbol, SymbolName, SymbolNameAttr, SymbolNameComponent, SymbolRef, SymbolTable, SymbolUse, + SymbolUseCursor, SymbolUseCursorMut, SymbolUseIter, SymbolUseList, SymbolUseRef, + }, types::*, usable::Usable, - value::{ - BlockArgument, BlockArgumentRef, OpOperand, OpResult, OpResultRef, Value, ValueId, ValueRef, - }, + value::{BlockArgument, BlockArgumentRef, OpResult, OpResultRef, Value, ValueId, ValueRef}, verifier::{OpVerifier, Verify}, + visit::{OpVisitor, OperationVisitor, Searcher, SymbolVisitor, Visitor}, }; diff --git a/hir2/src/ir/attribute.rs b/hir2/src/ir/attribute.rs index d39a47d95..aaf7831e4 100644 --- a/hir2/src/ir/attribute.rs +++ b/hir2/src/ir/attribute.rs @@ -124,6 +124,19 @@ impl AttributeSet { } } + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get_any_mut(&mut self, key: &Q) -> Option<&mut dyn AttributeValue> + where + Symbol: Borrow, + Q: Ord + ?Sized, + { + let key = key.borrow(); + match self.0.binary_search_by(|attr| key.cmp(attr.name.borrow())) { + Ok(index) => self.0[index].value.as_deref_mut(), + Err(_) => None, + } + } + /// Get the value associated with the named [Attribute] as a 
value of type `V`, or `None`.
     pub fn get<Q, V>(&self, key: &Q) -> Option<&V>
     where
@@ -134,6 +147,16 @@ impl AttributeSet {
         self.get_any(key).and_then(|v| v.downcast_ref::<V>())
     }
 
+    /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`.
+    pub fn get_mut<Q, V>(&mut self, key: &Q) -> Option<&mut V>
+    where
+        Symbol: Borrow<Q>,
+        Q: Ord + ?Sized,
+        V: AttributeValue,
+    {
+        self.get_any_mut(key).and_then(|v| v.downcast_mut::<V>())
+    }
+
     /// Iterate over each [Attribute] in this set
     pub fn iter(&self) -> impl Iterator<Item = &Attribute> + '_ {
         self.0.iter()
diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs
index abc0ca9fe..8479b195b 100644
--- a/hir2/src/ir/block.rs
+++ b/hir2/src/ir/block.rs
@@ -2,12 +2,16 @@ use core::fmt;
 
 use super::*;
 
+/// A pointer to a [Block]
 pub type BlockRef = UnsafeIntrusiveEntityRef<Block>;
 /// An intrusive, doubly-linked list of [Block]
 pub type BlockList = EntityList<Block>;
+/// A cursor into a [BlockList]
 pub type BlockCursor<'a> = EntityCursor<'a, Block>;
+/// A mutable cursor into a [BlockList]
 pub type BlockCursorMut<'a> = EntityCursorMut<'a, Block>;
 
+/// The unique identifier for a [Block]
 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 #[repr(transparent)]
 pub struct BlockId(u32);
@@ -37,6 +41,21 @@ impl fmt::Display for BlockId {
     }
 }
 
+/// Represents a basic block in the IR.
+///
+/// Basic blocks are used in SSA regions to provide the structure of the control-flow graph.
+/// Operations within a basic block appear in the order they will be executed.
+///
+/// A block must have a [traits::Terminator], an operation which transfers control to another block
+/// in the same region, or out of the containing operation (e.g. returning from a function).
+///
+/// Blocks have _predecessors_ and _successors_, representing the inbound and outbound edges
+/// (respectively) formed by operations that transfer control between blocks. A block can have
+/// zero or more predecessors and/or successors. A well-formed region will generally only have a
+/// single block (the entry block) with no predecessors (i.e. no unreachable blocks), and no blocks
+/// with both multiple predecessors _and_ multiple successors (i.e. no critical edges). It is valid
+/// to have both unreachable blocks and critical edges in the IR, but they must be removed during
+/// the course of compilation. 
pub struct Block { /// The unique id of this block id: BlockId, @@ -73,10 +92,6 @@ impl Entity for Block { impl Usable for Block { type Use = BlockOperand; - fn is_used(&self) -> bool { - !self.uses.is_empty() - } - #[inline(always)] fn uses(&self) -> &BlockOperandList { &self.uses @@ -86,22 +101,6 @@ impl Usable for Block { fn uses_mut(&mut self) -> &mut BlockOperandList { &mut self.uses } - - fn iter_uses(&self) -> BlockOperandIter<'_> { - self.uses.iter() - } - - fn first_use(&self) -> BlockOperandCursor<'_> { - self.uses.front() - } - - fn first_use_mut(&mut self) -> BlockOperandCursorMut<'_> { - self.uses.front_mut() - } - - fn insert_use(&mut self, user: BlockOperandRef) { - self.uses.push_back(user); - } } impl Block { pub fn new(id: BlockId) -> Self { @@ -177,7 +176,9 @@ impl Block { pub type BlockOperandRef = UnsafeIntrusiveEntityRef; /// An intrusive, doubly-linked list of [BlockOperand] pub type BlockOperandList = EntityList; +#[allow(unused)] pub type BlockOperandCursor<'a> = EntityCursor<'a, BlockOperand>; +#[allow(unused)] pub type BlockOperandCursorMut<'a> = EntityCursorMut<'a, BlockOperand>; pub type BlockOperandIter<'a> = EntityIter<'a, BlockOperand>; diff --git a/hir2/src/ir/context.rs b/hir2/src/ir/context.rs index 181cc4f24..ba940f98c 100644 --- a/hir2/src/ir/context.rs +++ b/hir2/src/ir/context.rs @@ -6,6 +6,18 @@ use midenc_session::Session; use super::*; +/// Represents the shared state of the IR, used during a compilation session. +/// +/// The primary purpose(s) of the context are: +/// +/// * Provide storage/memory for all allocated IR entities for the lifetime of the session. +/// * Provide unique value and block identifiers for printing the IR +/// * Provide a uniqued constant pool +/// * Provide configuration used during compilation +/// +/// # Safety +/// +/// The [Context] _must_ live as long as any reference to an IR entity may be dereferenced. pub struct Context { pub session: Rc, allocator: Rc, @@ -30,6 +42,7 @@ impl Default for Context { } impl Context { + /// Create a new [Context] for the given [Session] pub fn new(session: Rc) -> Self { let allocator = Rc::new(Blink::new()); Self { diff --git a/hir2/src/ir/dialect.rs b/hir2/src/ir/dialect.rs index daffc8dc1..11463fd9b 100644 --- a/hir2/src/ir/dialect.rs +++ b/hir2/src/ir/dialect.rs @@ -1,5 +1,8 @@ use core::ops::Deref; +/// A [Dialect] represents a collection of IR entities that are used in conjunction with one +/// another. Multiple dialects can co-exist _or_ be mutually exclusive. Converting between dialects +/// is the job of the conversion infrastructure, using a process called _legalization_. pub trait Dialect { const INIT: Self; diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs index 85a8c2f47..2d41c7bab 100644 --- a/hir2/src/ir/entity.rs +++ b/hir2/src/ir/entity.rs @@ -13,12 +13,16 @@ use core::{ pub use self::list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}; +/// A trait implemented by an IR entity that has a unique identifier +/// +/// Currently, this is used only for [Value]s and [Block]s. 
pub trait Entity: Any { type Id: EntityId; fn id(&self) -> Self::Id; } +/// A trait that must be implemented by the unique identifier for an [Entity] pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash { fn as_usize(&self) -> usize; } @@ -76,8 +80,10 @@ impl fmt::Display for AliasingViolationError { } } +/// A raw pointer to an IR entity that has no associated metadata pub type UnsafeEntityRef = RawEntityRef; +/// A raw pointer to an IR entity that has an intrusive linked-list link as its metadata pub type UnsafeIntrusiveEntityRef = RawEntityRef; /// A [RawEntityRef] is an unsafe smart pointer type for IR entities allocated in a [Context]. @@ -386,6 +392,18 @@ impl fmt::Debug for RawEntityRef } } +impl Eq for RawEntityRef {} +impl PartialEq for RawEntityRef { + fn eq(&self, other: &Self) -> bool { + Self::ptr_eq(self, other) + } +} +impl core::hash::Hash for RawEntityRef { + fn hash(&self, state: &mut H) { + self.inner.hash(state); + } +} + /// A guard that ensures a reference to an IR entity cannot be mutably aliased pub struct EntityRef<'b, T: ?Sized + 'b> { value: NonNull, @@ -633,6 +651,16 @@ impl RawEntityMetadata { } } impl RawEntityMetadata { + pub(self) fn borrow(&self) -> EntityRef<'_, T> { + let ptr = self as *const Self; + unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow() } + } + + pub(self) fn borrow_mut(&self) -> EntityMut<'_, T> { + let ptr = (self as *const Self).cast_mut(); + unsafe { (*core::ptr::addr_of_mut!((*ptr).entity)).borrow_mut() } + } + #[inline] const fn metadata_offset() -> usize { core::mem::offset_of!(RawEntityMetadata<(), Metadata>, metadata) diff --git a/hir2/src/ir/entity/list.rs b/hir2/src/ir/entity/list.rs index 5cb5e5ee0..7c38fe5ef 100644 --- a/hir2/src/ir/entity/list.rs +++ b/hir2/src/ir/entity/list.rs @@ -249,6 +249,12 @@ impl<'a, T> EntityCursor<'a, T> { self.cursor.clone_pointer() } + /// Consume the cursor and convert it into a borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow(self) -> Option> { + self.cursor.get().map(|item| item.borrow()) + } + /// Moves the cursor to the next element of the [EntityList]. /// /// If the cursor is pointing to the null object then this will move it to the front of the @@ -347,6 +353,18 @@ impl<'a, T> EntityCursorMut<'a, T> { self.cursor.as_cursor().clone_pointer() } + /// Consume the cursor and convert it into a borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow(self) -> Option> { + self.cursor.into_ref().map(|item| item.borrow()) + } + + /// Consume the cursor and convert it into a mutable borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow_mut(self) -> Option> { + self.cursor.into_ref().map(|item| item.borrow_mut()) + } + /// Moves the cursor to the next element of the [EntityList]. /// /// If the cursor is pointing to the null object then this will move it to the front of the diff --git a/hir2/src/ir/function.rs b/hir2/src/ir/function.rs index 6afd9012c..498728ef8 100644 --- a/hir2/src/ir/function.rs +++ b/hir2/src/ir/function.rs @@ -1,25 +1,128 @@ use core::fmt; use super::*; -use crate::{formatter, CallConv, Linkage}; +use crate::{ + derive, + dialects::hir::HirDialect, + formatter, + traits::{CallableOpInterface, SingleRegion}, + CallConv, Symbol, SymbolName, SymbolUse, SymbolUseIter, SymbolUseList, Visibility, +}; -#[derive(Spanned)] -pub struct Function { - #[span] - op: Operation, - id: FunctionIdent, - signature: Signature, +trait UsableSymbol = Usable; + +derive! 
{ + pub struct Function: Op { + #[dialect] + dialect: HirDialect, + #[region] + body: RegionRef, + #[attr] + name: Ident, + #[attr] + signature: Signature, + /// The uses of this function as a symbol + uses: SymbolUseList, + } + + derives SingleRegion; + implements UsableSymbol, Symbol, CallableOpInterface; +} + +impl Usable for Function { + type Use = SymbolUse; + + #[inline(always)] + fn uses(&self) -> &EntityList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut EntityList { + &mut self.uses + } } + impl Symbol for Function { - type Id = Ident; + #[inline(always)] + fn as_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn name(&self) -> SymbolName { + Self::name(self).as_symbol() + } + + /// Set the name of this symbol + fn set_name(&mut self, name: SymbolName) { + let mut id = *self.name(); + id.name = name; + Function::set_name(self, id) + } + + /// Get the visibility of this symbol + fn visibility(&self) -> Visibility { + self.signature().visibility + } + + /// Returns true if this symbol has private visibility + #[inline] + fn is_private(&self) -> bool { + self.signature().is_private() + } + + /// Returns true if this symbol has public visibility + #[inline] + fn is_public(&self) -> bool { + self.signature().is_public() + } + + /// Sets the visibility of this symbol + fn set_visibility(&mut self, visibility: Visibility) { + self.signature_mut().visibility = visibility; + } - fn id(&self) -> Self::Id { - self.id.function + /// Get all of the uses of this symbol that are nested within `from` + fn symbol_uses(&self, from: OperationRef) -> SymbolUseIter { + todo!() + } + + /// Return true if there are no uses of this symbol nested within `from` + fn symbol_uses_known_empty(&self, from: OperationRef) -> SymbolUseIter { + todo!() + } + + /// Attempt to replace all uses of this symbol nested within `from`, with the provided replacement + fn replace_all_uses(&self, replacement: SymbolRef, from: OperationRef) -> Result<(), Report> { + todo!() + } + + /// Returns true if this operation is a declaration, rather than a definition, of a symbol + /// + /// The default implementation assumes that all operations are definitions + #[inline] + fn is_declaration(&self) -> bool { + self.body().is_empty() } } -impl Function { - pub fn signature(&self) -> &Signature { - &self.signature + +impl CallableOpInterface for Function { + fn get_callable_region(&self) -> Option { + if self.is_declaration() { + None + } else { + self.regions().front().as_pointer() + } + } + + #[inline] + fn signature(&self) -> &Signature { + Function::signature(self) } } @@ -125,6 +228,20 @@ impl formatter::PrettyPrint for AbiParam { } } +impl fmt::Display for AbiParam { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_map(); + builder.entry(&"ty", &format_args!("{}", &self.ty)); + if !matches!(self.purpose, ArgumentPurpose::Default) { + builder.entry(&"purpose", &format_args!("{}", &self.purpose)); + } + if !matches!(self.extension, ArgumentExtension::None) { + builder.entry(&"extension", &format_args!("{}", &self.extension)); + } + builder.finish() + } +} + /// A [Signature] represents the type, ABI, and linkage of a function. 
 ///
 /// A function signature provides us with all of the necessary detail to correctly
@@ -139,9 +256,37 @@ pub struct Signature {
     pub results: Vec<AbiParam>,
     /// The calling convention that applies to this function
     pub cc: CallConv,
-    /// The linkage that should be used for this function
-    pub linkage: Linkage,
+    /// The linkage/visibility that should be used for this function
+    pub visibility: Visibility,
+}
+
+crate::define_attr_type!(Signature);
+
+impl fmt::Display for Signature {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_map()
+            .key(&"params")
+            .value_with(|f| {
+                let mut builder = f.debug_list();
+                for param in self.params.iter() {
+                    builder.entry(&format_args!("{param}"));
+                }
+                builder.finish()
+            })
+            .key(&"results")
+            .value_with(|f| {
+                let mut builder = f.debug_list();
+                for result in self.results.iter() {
+                    builder.entry(&format_args!("{result}"));
+                }
+                builder.finish()
+            })
+            .entry(&"cc", &format_args!("{}", &self.cc))
+            .entry(&"visibility", &format_args!("{}", &self.visibility))
+            .finish()
+    }
+}
+
 impl Signature {
     /// Create a new signature with the given parameter and result types,
     /// for a public function using the `SystemV` calling convention
@@ -153,18 +298,18 @@ impl Signature {
             params: params.into_iter().collect(),
             results: results.into_iter().collect(),
             cc: CallConv::SystemV,
-            linkage: Linkage::External,
+            visibility: Visibility::Public,
         }
     }
 
     /// Returns true if this function is externally visible
     pub fn is_public(&self) -> bool {
-        matches!(self.linkage, Linkage::External)
+        matches!(self.visibility, Visibility::Public)
    }
 
     /// Returns true if this function is only visible within its containing module
     pub fn is_private(&self) -> bool {
-        matches!(self.linkage, Linkage::Internal)
+        matches!(self.visibility, Visibility::Private)
    }
 
     /// Returns true if this function is a kernel function
@@ -202,7 +347,7 @@ impl Signature {
 impl Eq for Signature {}
 impl PartialEq for Signature {
     fn eq(&self, other: &Self) -> bool {
-        self.linkage == other.linkage
+        self.visibility == other.visibility
             && self.cc == other.cc
             && self.params.len() == other.params.len()
             && self.results.len() == other.results.len()
diff --git a/hir2/src/ir/ident.rs b/hir2/src/ir/ident.rs
index 8304f80c8..2ec6c8125 100644
--- a/hir2/src/ir/ident.rs
+++ b/hir2/src/ir/ident.rs
@@ -11,7 +11,10 @@ use super::{
     interner::{symbols, Symbol},
     SourceSpan, Spanned,
 };
-use crate::formatter::{self, PrettyPrint};
+use crate::{
+    define_attr_type,
+    formatter::{self, PrettyPrint},
+};
 
 /// Represents a globally-unique module/function name pair, with corresponding source spans.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Spanned)] @@ -21,6 +24,7 @@ pub struct FunctionIdent { #[span] pub function: Ident, } +define_attr_type!(FunctionIdent); impl FunctionIdent { pub fn display(&self) -> impl fmt::Display + '_ { use crate::formatter::*; @@ -91,6 +95,7 @@ pub struct Ident { #[span] pub span: SourceSpan, } +define_attr_type!(Ident); impl Default for Ident { fn default() -> Self { Self { diff --git a/hir2/src/ir/insert.rs b/hir2/src/ir/insert.rs new file mode 100644 index 000000000..c3899c009 --- /dev/null +++ b/hir2/src/ir/insert.rs @@ -0,0 +1,100 @@ +use core::fmt; + +use crate::{BlockRef, OperationRef}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Insert { + Before, + After, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct InsertionPoint { + pub at: ProgramPoint, + pub action: Insert, +} +impl InsertionPoint { + #[inline] + pub const fn new(at: ProgramPoint, action: Insert) -> Self { + Self { at, action } + } + + #[inline] + pub const fn before(at: ProgramPoint) -> Self { + Self { + at, + action: Insert::Before, + } + } + + #[inline] + pub const fn after(at: ProgramPoint) -> Self { + Self { + at, + action: Insert::After, + } + } + + pub fn block(&self) -> BlockRef { + self.at.block().expect("cannot insert relative to detached operation") + } +} + +/// A `ProgramPoint` represents a position in a function where the live range of an SSA value can +/// begin or end. It can be either: +/// +/// 1. An instruction or +/// 2. A block header. +/// +/// This corresponds more or less to the lines in the textual form of the IR. +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum ProgramPoint { + /// An operation + Op(OperationRef), + /// A block header. + Block(BlockRef), +} +impl ProgramPoint { + /// Get the operation we know is inside. + pub fn unwrap_op(self) -> OperationRef { + use crate::Entity; + match self { + Self::Op(x) => x, + Self::Block(x) => panic!("expected operation, but got {}", x.borrow().id()), + } + } + + /// Get the block associated with this program point + /// + /// Returns `None` if the program point is a detached operation. + pub fn block(&self) -> Option { + match self { + Self::Op(op) => op.borrow().parent(), + Self::Block(block) => Some(block.clone()), + } + } +} +impl From for ProgramPoint { + fn from(op: OperationRef) -> Self { + Self::Op(op) + } +} +impl From for ProgramPoint { + fn from(block: BlockRef) -> Self { + Self::Block(block) + } +} +impl fmt::Display for ProgramPoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use crate::Entity; + match self { + Self::Op(x) => write!(f, "{}", x.borrow().name()), + Self::Block(x) => write!(f, "{}", x.borrow().id()), + } + } +} +impl fmt::Debug for ProgramPoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("ProgramPoint").field_with(|f| write!(f, "{}", self)).finish() + } +} diff --git a/hir2/src/ir/module.rs b/hir2/src/ir/module.rs index 738a7e875..a3e915bff 100644 --- a/hir2/src/ir/module.rs +++ b/hir2/src/ir/module.rs @@ -1,49 +1,99 @@ use alloc::collections::BTreeMap; -use super::{EntityList, Function, Ident, Symbol, SymbolTable, UnsafeIntrusiveEntityRef}; +use crate::{ + derive, + dialects::hir::HirDialect, + traits::{NoRegionArguments, SingleBlock, SingleRegion}, + Ident, InsertionPoint, Operation, Report, SymbolName, SymbolRef, SymbolTable, +}; -pub struct Module { - name: Ident, - functions: EntityList, - registry: BTreeMap>, -} -impl Module { - pub const fn name(&self) -> Ident { - self.name +derive! 
{ + pub struct Module : Op { + #[dialect] + dialect: HirDialect, + #[attr] + name: Ident, + #[region] + body: RegionRef, + registry: BTreeMap, } - pub fn functions(&self) -> &EntityList { - &self.functions + derives SingleRegion, SingleBlock, NoRegionArguments; + implements SymbolTable; +} + +impl SymbolTable for Module { + #[inline(always)] + fn as_operation(&self) -> &Operation { + &self.op } - pub fn functions_mut(&mut self) -> &mut EntityList { - &mut self.functions + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut Operation { + &mut self.op } -} -impl SymbolTable for Module { - type Entry = UnsafeIntrusiveEntityRef; - type Key = Ident; - fn get(&self, id: &Self::Key) -> Option { - self.registry.get(id).cloned() + fn get(&self, name: SymbolName) -> Option { + self.registry.get(&name).cloned() } - fn insert(&mut self, entry: Self::Entry) -> bool { - let id = entry.borrow().id(); - if self.registry.contains_key(&id) { + //TODO(pauls): Insert symbol ref in module body + fn insert_new(&mut self, entry: SymbolRef, ip: Option) -> bool { + let symbol = entry.borrow(); + let name = symbol.name(); + if self.registry.contains_key(&name) { return false; } - self.registry.insert(id, entry.clone()); - self.functions.push_back(entry); + drop(symbol); + self.registry.insert(name, entry); true } - fn remove(&mut self, id: &Self::Key) -> Option { - if let Some(ptr) = self.registry.remove(id) { - let mut cursor = unsafe { self.functions.cursor_mut_from_ptr(ptr) }; - cursor.remove() + //TODO(pauls): Insert symbol ref in module body + fn insert(&mut self, mut entry: SymbolRef, ip: Option) -> SymbolName { + let mut symbol = entry.borrow_mut(); + let mut name = symbol.name(); + if self.registry.contains_key(&name) { + // Unique the symbol name + let mut counter = 0; + name = super::symbol_table::generate_symbol_name(name, &mut counter, |name| { + self.registry.contains_key(name) + }); + symbol.set_name(name); + } + drop(symbol); + self.registry.insert(name, entry); + name + } + + fn remove(&mut self, name: SymbolName) -> Option { + if let Some(ptr) = self.registry.remove(&name) { + let op = ptr.borrow().as_operation_ref(); + let mut body = self.body_mut(); + let mut entry = body.entry_mut(); + let mut cursor = unsafe { entry.body_mut().cursor_mut_from_ptr(op) }; + cursor.remove().map(|_| ptr) } else { None } } + + fn rename(&mut self, from: SymbolName, to: SymbolName) -> Result<(), Report> { + if let Some(symbol) = self.registry.get_mut(&from) { + let mut sym = symbol.borrow_mut(); + sym.set_name(to); + let uses = sym.uses_mut(); + let mut cursor = uses.front_mut(); + while let Some(mut next_use) = cursor.get_mut() { + next_use.symbol.name = to; + } + + Ok(()) + } else { + Err(Report::msg(format!( + "unable to rename '{from}': no such symbol in '{}'", + self.name().as_str() + ))) + } + } } diff --git a/hir2/src/ir/op.rs b/hir2/src/ir/op.rs index 136fd6392..440a911a7 100644 --- a/hir2/src/ir/op.rs +++ b/hir2/src/ir/op.rs @@ -1,8 +1,8 @@ -use downcast_rs::{impl_downcast, Downcast}; +use core::any::Any; use super::*; -pub trait Op: Downcast + OpVerifier { +pub trait Op: Any + OpVerifier { /// The name of this operation's opcode /// /// The opcode must be distinct from all other opcodes in the same dialect @@ -20,21 +20,43 @@ pub trait Op: Downcast + OpVerifier { self.as_operation().parent_op() } fn regions(&self) -> &RegionList { - &self.as_operation().regions + self.as_operation().regions() } - fn operands(&self) -> &[OpOperand] { - self.as_operation().operands.as_slice() + fn regions_mut(&mut 
self) -> &mut RegionList { + self.as_operation_mut().regions_mut() + } + fn region(&self, index: usize) -> EntityRef<'_, Region> { + self.as_operation().region(index) + } + fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> { + self.as_operation_mut().region_mut(index) + } + fn has_operands(&self) -> bool { + self.as_operation().has_operands() + } + fn num_operands(&self) -> usize { + self.as_operation().num_operands() + } + fn operands(&self) -> &OpOperandStorage { + self.as_operation().operands() + } + fn operands_mut(&mut self) -> &mut OpOperandStorage { + self.as_operation_mut().operands_mut() } fn results(&self) -> &[OpResultRef] { - self.as_operation().results.as_slice() + self.as_operation().results() + } + fn results_mut(&mut self) -> &mut [OpResultRef] { + self.as_operation_mut().results_mut() } fn successors(&self) -> &[OpSuccessor] { - self.as_operation().successors.as_slice() + self.as_operation().successors() + } + fn successors_mut(&mut self) -> &mut [OpSuccessor] { + self.as_operation_mut().successors_mut() } } -impl_downcast!(Op); - impl Spanned for dyn Op { fn span(&self) -> SourceSpan { self.as_operation().span diff --git a/hir2/src/ir/operands.rs b/hir2/src/ir/operands.rs new file mode 100644 index 000000000..48dd5c15b --- /dev/null +++ b/hir2/src/ir/operands.rs @@ -0,0 +1,564 @@ +use core::{fmt, num::NonZeroU16}; + +use smallvec::{smallvec, SmallVec}; + +use crate::{EntityRef, OperationRef, Type, UnsafeIntrusiveEntityRef, Value, ValueId, ValueRef}; + +pub type OpOperand = UnsafeIntrusiveEntityRef; +pub type OpOperandList = crate::EntityList; +#[allow(unused)] +pub type OpOperandIter<'a> = crate::EntityIter<'a, OpOperandImpl>; +#[allow(unused)] +pub type OpOperandCursor<'a> = crate::EntityCursor<'a, OpOperandImpl>; +#[allow(unused)] +pub type OpOperandCursorMut<'a> = crate::EntityCursorMut<'a, OpOperandImpl>; + +/// An [OpOperand] represents a use of a [Value] by an [Operation] +pub struct OpOperandImpl { + /// The operand value + pub value: ValueRef, + /// The owner of this operand, i.e. 
the operation it is an operand of + pub owner: OperationRef, + /// The index of this operand in the operand list of an operation + pub index: u8, +} +impl OpOperandImpl { + #[inline] + pub fn new(value: ValueRef, owner: OperationRef, index: u8) -> Self { + Self { + value, + owner, + index, + } + } + + pub fn value(&self) -> EntityRef<'_, dyn Value> { + self.value.borrow() + } + + pub fn unlink(&mut self) { + let ptr = unsafe { OpOperand::from_raw(self as *mut Self) }; + let mut value = self.value.borrow_mut(); + let uses = value.uses_mut(); + unsafe { + let mut cursor = uses.cursor_mut_from_ptr(ptr); + cursor.remove(); + } + } +} +impl fmt::Debug for OpOperandImpl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + #[allow(unused)] + struct ValueInfo<'a> { + id: ValueId, + ty: &'a Type, + } + + let value = self.value.borrow(); + let id = value.id(); + let ty = value.ty(); + f.debug_struct("OpOperand") + .field("index", &self.index) + .field("value", &ValueInfo { id, ty }) + .finish_non_exhaustive() + } +} + +#[derive(Default, Copy, Clone)] +struct OpOperandGroup(Option); +impl OpOperandGroup { + const START_MASK: u16 = u8::MAX as u16; + + fn new(start: usize, len: usize) -> Self { + if len == 0 { + return Self::default(); + } + + let start = u16::try_from(start).expect("too many operands"); + let len = u16::try_from(len).expect("operand group too large"); + let group = start | (len << 8); + + Self(Some(unsafe { NonZeroU16::new_unchecked(group) })) + } + + #[allow(unused)] + #[inline] + pub fn start(&self) -> Option { + Some((self.0?.get() & Self::START_MASK) as usize) + } + + #[inline] + pub fn end(&self) -> Option { + self.as_range().map(|range| range.end) + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_none() + } + + #[allow(unused)] + #[inline] + pub fn len(&self) -> usize { + self.0.as_ref().map(|group| (group.get() >> 8) as usize).unwrap_or(0) + } + + pub fn as_range(&self) -> Option> { + let raw = self.0?.get(); + let start = (raw & Self::START_MASK) as usize; + let len = (raw >> 8) as usize; + Some(start..(start + len)) + } + + pub fn increase_size(&mut self, size: usize) { + let group = self.0.as_mut().expect("expected non-empty group"); + let raw = group.get(); + let size = u16::try_from(size).expect("too many operands"); + let start = raw & Self::START_MASK; + let len = (raw >> 8) + size; + assert!(len <= u8::MAX as u16, "operand group is too large"); + *group = unsafe { NonZeroU16::new_unchecked(start | (len << 8)) }; + } + + pub fn decrease_size(&mut self, size: usize) { + let group = self.0.as_mut().expect("expected non-empty group"); + let raw = group.get(); + let size = u16::try_from(size).expect("too many operands"); + let len = (raw >> 8) - size; + if len > 0 { + let start = raw & Self::START_MASK; + *group = unsafe { NonZeroU16::new_unchecked(start | (len << 8)) }; + } else { + self.0 = None; + } + } + + pub fn shift_start(&mut self, offset: isize) { + let offset = i16::try_from(offset).expect("offset too large"); + if let Some(group) = self.0.as_mut() { + let raw = group.get(); + let mut start = raw & Self::START_MASK; + if offset >= 0 { + start += offset as u16; + } else { + start -= offset.unsigned_abs(); + } + assert!(start <= Self::START_MASK, "too many operands"); + // Clear previous start value + let raw = raw & !Self::START_MASK; + *group = unsafe { NonZeroU16::new_unchecked(raw | start) }; + } + } +} + +pub struct OpOperandStorage { + operands: SmallVec<[OpOperand; 1]>, + groups: SmallVec<[OpOperandGroup; 2]>, +} 
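The bit-packing used by `OpOperandGroup` above is compact but easy to misread, so here is a minimal, self-contained sketch of the same scheme: the group's starting index lives in the low byte and its length in the high byte of a `NonZeroU16`, so `None` can represent an empty group at no extra cost. The `pack`/`unpack` helpers are illustrative only and are not part of the patch.

```rust
use core::num::NonZeroU16;

// Mirror of the encoding in `OpOperandGroup`: start in the low byte,
// length in the high byte, `None` for an empty group.
fn pack(start: usize, len: usize) -> Option<NonZeroU16> {
    if len == 0 {
        return None;
    }
    let start = u16::try_from(start).expect("too many operands");
    let len = u16::try_from(len).expect("operand group too large");
    assert!(start <= 0xff && len <= 0xff, "start/len must each fit in one byte");
    NonZeroU16::new(start | (len << 8))
}

// Recover the (start, len) pair from the packed representation.
fn unpack(group: NonZeroU16) -> (usize, usize) {
    let raw = group.get();
    ((raw & 0x00ff) as usize, (raw >> 8) as usize)
}

fn main() {
    let g = pack(3, 2).expect("non-empty group");
    assert_eq!(unpack(g), (3, 2));
    // The niche optimization keeps the optional group at two bytes.
    assert_eq!(core::mem::size_of::<Option<NonZeroU16>>(), 2);
    println!("0x{:04x} -> {:?}", g.get(), unpack(g));
}
```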
+impl fmt::Debug for OpOperandStorage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpOperandStorage") + .field_with("groups", |f| { + let mut builder = f.debug_list(); + for group in self.groups.iter() { + match group.as_range() { + Some(range) => { + let operands = &self.operands[range.clone()]; + builder.entry_with(|f| { + f.debug_map() + .entry(&"range", &range) + .entry(&"operands", &operands) + .finish() + }); + } + None => { + builder.entry(&""); + } + } + } + builder.finish() + }) + .finish() + } +} +impl Default for OpOperandStorage { + fn default() -> Self { + Self { + operands: Default::default(), + groups: smallvec![OpOperandGroup::default()], + } + } +} +impl OpOperandStorage { + #[inline] + pub fn is_empty(&self) -> bool { + self.operands.is_empty() + } + + #[inline] + pub fn len(&self) -> usize { + self.operands.len() + } + + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { + self.operands.iter() + } + + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, OpOperand> { + self.operands.iter_mut() + } + + /// Push operand to the last operand group + pub fn push_operand(&mut self, mut operand: OpOperand) { + let index = self.operands.len() as u8; + operand.borrow_mut().index = index; + self.operands.push(operand); + let group = self.groups.last_mut().unwrap(); + if group.is_empty() { + *group = OpOperandGroup::new(self.operands.len(), 1); + return; + } + group.increase_size(1); + } + + /// Push operand to the specified group + pub fn push_operand_to_group(&mut self, group: usize, operand: OpOperand) { + if self.groups.len() <= group { + self.groups.resize(group + 1, OpOperandGroup::default()); + } + let mut group = self.group_mut(group); + group.push(operand); + } + + /// Create operand group with index `group`, allocating any intervening groups if missing + pub fn push_operands_to_group(&mut self, group: usize, operands: I) + where + I: IntoIterator, + { + if self.groups.len() <= group { + self.groups.resize(group + 1, OpOperandGroup::default()); + } + let mut group = self.group_mut(group); + group.extend(operands); + } + + /// Push multiple operands to the last operand group + pub fn extend(&mut self, operands: I) + where + I: IntoIterator, + { + let mut group = self.group_mut(self.groups.len() - 1); + group.extend(operands); + } + + pub fn clear(&mut self) { + for mut operand in self.operands.drain(..) 
{ + let mut operand = operand.borrow_mut(); + operand.unlink(); + } + self.groups.clear(); + self.groups.push(OpOperandGroup::default()); + } + + /// Get all the operands + pub fn all(&self) -> OpOperandRange<'_> { + OpOperandRange { + range: 0..self.operands.len(), + operands: self.operands.as_slice(), + } + } + + /// Get operands for the specified group + pub fn group(&self, group: usize) -> OpOperandRange<'_> { + OpOperandRange { + range: self.groups[group].as_range().unwrap_or(0..0), + operands: self.operands.as_slice(), + } + } + + /// Get operands for the specified group + pub fn group_mut(&mut self, group: usize) -> OpOperandRangeMut<'_> { + let range = self.groups[group].as_range(); + OpOperandRangeMut { + group, + range, + groups: &mut self.groups, + operands: &mut self.operands, + } + } +} +impl core::ops::Index for OpOperandStorage { + type Output = OpOperand; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.operands[index] + } +} +impl core::ops::IndexMut for OpOperandStorage { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.operands[index] + } +} + +/// A reference to a range of operands in [OpOperandStorage] +pub struct OpOperandRange<'a> { + range: core::ops::Range, + operands: &'a [OpOperand], +} +impl<'a> OpOperandRange<'a> { + #[inline] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } + + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + #[inline] + pub fn as_slice(&self) -> &[OpOperand] { + &self.operands[self.range.start..self.range.end] + } + + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { + self.as_slice().iter() + } + + #[inline] + pub fn get(&self, index: usize) -> Option<&OpOperand> { + self.as_slice().get(index) + } +} +impl<'a> core::ops::Index for OpOperandRange<'a> { + type Output = OpOperand; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.as_slice()[index] + } +} + +/// A mutable range of operands in [OpOperandStorage] +/// +/// Operands outside the range are not modified, however the range itself can have its size change, +/// which as a result will shift other operands around. Any other groups in [OpOperandStorage] will +/// be updated to reflect such changes, so in general this should be transparent. 
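The doc comment above describes the group bookkeeping abstractly. The following standalone sketch models it with plain `Vec` and `Range` values (the `extend_group` helper and these types are illustrative stand-ins, not the actual `OpOperandStorage` API): growing one group inserts new items at that group's end and shifts every later group by the number of items inserted, so the other groups keep referring to the same operands.

```rust
use std::ops::Range;

// Simplified model of growing a group in the middle of a flat operand list.
fn extend_group(
    items: &mut Vec<&'static str>,
    groups: &mut [Range<usize>],
    group: usize,
    new: &[&'static str],
) {
    let at = groups[group].end;
    // Insert the new items at the end of the chosen group.
    for (i, &item) in new.iter().enumerate() {
        items.insert(at + i, item);
    }
    groups[group].end += new.len();
    // Shift every later group by the number of inserted items.
    for later in groups[group + 1..].iter_mut() {
        later.start += new.len();
        later.end += new.len();
    }
}

fn main() {
    // Two groups over one flat operand list: [a b | c]
    let mut items = vec!["a", "b", "c"];
    let mut groups = vec![0..2, 2..3];
    extend_group(&mut items, &mut groups, 0, &["x"]);
    assert_eq!(items, ["a", "b", "x", "c"]);
    assert_eq!(groups, vec![0..3, 3..4]);
    println!("{items:?} {groups:?}");
}
```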
+pub struct OpOperandRangeMut<'a> { + group: usize, + range: Option>, + groups: &'a mut [OpOperandGroup], + operands: &'a mut SmallVec<[OpOperand; 1]>, +} +impl<'a> OpOperandRangeMut<'a> { + #[inline] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } + + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + #[inline] + pub fn push(&mut self, operand: OpOperand) { + self.extend([operand]); + } + + pub fn extend(&mut self, operands: I) + where + I: IntoIterator, + { + // Handle edge case where group is the last group + let is_last = self.groups.len() == self.group + 1; + let is_empty = self.range.is_none(); + + if is_last && is_empty { + let prev_len = self.operands.len(); + self.operands.extend(operands.into_iter().enumerate().map(|(i, mut operand)| { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index = (prev_len + i) as u8; + drop(operand_mut); + operand + })); + let num_inserted = self.operands.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group] = OpOperandGroup::new(self.operands.len(), num_inserted); + self.range = self.groups[self.group].as_range(); + } else if is_last { + self.extend_last(operands); + } else { + self.extend_within(operands); + } + } + + fn extend_last(&mut self, operands: I) + where + I: IntoIterator, + { + let prev_len = self.operands.len(); + self.operands.extend(operands.into_iter().enumerate().map(|(i, mut operand)| { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index = (prev_len + i) as u8; + drop(operand_mut); + operand + })); + let num_inserted = self.operands.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group].increase_size(num_inserted); + self.range = self.groups[self.group].as_range(); + } + + fn extend_within(&mut self, operands: I) + where + I: IntoIterator, + { + let prev_len = self.operands.len(); + let num_inserted; + + match self.range.as_mut() { + Some(range) => { + let start = range.end; + self.operands.insert_many( + range.end, + operands.into_iter().enumerate().map(|(i, mut operand)| { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index = (start + i) as u8; + drop(operand_mut); + operand + }), + ); + num_inserted = self.operands.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group].increase_size(num_inserted); + range.end += num_inserted; + } + None => { + let start = self.groups[..self.group] + .iter() + .rev() + .filter_map(OpOperandGroup::end) + .next() + .unwrap_or(0); + self.operands.insert_many( + start, + operands.into_iter().enumerate().map(|(i, mut operand)| { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index = (start + i) as u8; + drop(operand_mut); + operand + }), + ); + num_inserted = self.operands.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group] = OpOperandGroup::new(start, num_inserted); + self.range = self.groups[self.group].as_range(); + } + } + + // Shift groups + for group in self.groups[(self.group + 1)..].iter_mut() { + if group.is_empty() { + continue; + } + group.shift_start(num_inserted as isize); + } + + // Shift operand indices + let shifted = self.range.as_ref().unwrap().end; + for operand in self.operands[shifted..].iter_mut() { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index += 1; + } + } + + pub fn pop(&mut self) -> Option { + let range = self.range.as_mut()?; + let index = range.end; + range.end -= 1; + if (*range).is_empty() { + self.range = None; + 
} + self.groups[self.group].decrease_size(1); + let mut removed = self.operands.remove(index); + { + let mut operand_mut = removed.borrow_mut(); + operand_mut.unlink(); + } + + // Shift groups + for group in self.groups[(self.group + 1)..].iter_mut() { + if group.is_empty() { + continue; + } + group.shift_start(-1); + } + + // Shift operand indices + for operand in self.operands[index..].iter_mut() { + let mut operand_mut = operand.borrow_mut(); + operand_mut.index -= 1; + } + + Some(removed) + } + + #[inline] + pub fn as_slice(&self) -> &[OpOperand] { + &self.operands[self.range.clone().unwrap_or(0..0)] + } + + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [OpOperand] { + &mut self.operands[self.range.clone().unwrap_or(0..0)] + } + + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { + self.as_slice().iter() + } + + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, OpOperand> { + self.as_slice_mut().iter_mut() + } + + #[inline] + pub fn get(&self, index: usize) -> Option<&OpOperand> { + self.as_slice().get(index) + } + + #[inline] + pub fn get_mut(&mut self, index: usize) -> Option<&mut OpOperand> { + self.as_slice_mut().get_mut(index) + } +} +impl<'a> core::ops::Index for OpOperandRangeMut<'a> { + type Output = OpOperand; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.as_slice()[index] + } +} +impl<'a> core::ops::IndexMut for OpOperandRangeMut<'a> { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.as_slice_mut()[index] + } +} diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index a097e5b53..57cc5bcab 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -1,11 +1,14 @@ +mod builder; +mod name; + use core::{ fmt, - marker::Unsize, ptr::{DynMetadata, Pointee}, }; use smallvec::SmallVec; +pub use self::{builder::OperationBuilder, name::OperationName}; use super::*; pub type OperationRef = UnsafeIntrusiveEntityRef; @@ -13,162 +16,55 @@ pub type OpList = EntityList; pub type OpCursor<'a> = EntityCursor<'a, Operation>; pub type OpCursorMut<'a> = EntityCursorMut<'a, Operation>; -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct OperationName { - pub dialect: DialectName, - pub name: interner::Symbol, -} -impl OperationName { - pub fn new(dialect: DialectName, name: S) -> Self - where - S: Into, - { - Self { - dialect, - name: name.into(), - } - } -} -impl fmt::Debug for OperationName { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(self, f) - } -} -impl fmt::Display for OperationName { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}.{}", &self.dialect, &self.name) - } -} - -/// An [OpSuccessor] is a BlockOperand + OpOperands for that block, attached to an Operation -pub struct OpSuccessor { - pub block: BlockOperandRef, - pub args: SmallVec<[OpOperand; 1]>, -} -impl fmt::Debug for OpSuccessor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpSuccessor") - .field("block", &self.block.borrow().block_id()) - .field("args", &self.args) - .finish() - } -} - -// TODO: We need a safe way to construct arbitrary Ops imperatively: -// -// * Allocate an uninit instance of T -// * Initialize the Operartion field of T with the empty Operation data -// * Use the primary builder methods to mutate Operation fields -// * Use generated methods on Op-specific builders to mutate Op fields -// * At the end, convert uninit T to init T, return handle to 
caller -// -// Problems: -// -// * How do we default-initialize an instance of T for this purpose -// * If we use MaybeUninit, how do we compute field offsets for the Operation field -// * Generated methods can compute offsets, but how do we generate the specialized builders? -pub struct OperationBuilder<'a, T> { - context: &'a Context, - op: UnsafeIntrusiveEntityRef, - _marker: core::marker::PhantomData, -} -impl<'a, T: Op> OperationBuilder<'a, T> { - pub fn new(context: &'a Context, op: T) -> Self { - let mut op = context.alloc_tracked(op); - - // SAFETY: Setting the data pointer of the multi-trait vtable must ensure - // that it points to the concrete type of the allocation, which we can guarantee here, - // having just allocated it. Until the data pointer is set, casts using the vtable are - // undefined behavior, so by never allowing the uninitialized vtable to be accessed, - // we can ensure the multi-trait impl is safe - unsafe { - let data_ptr = UnsafeIntrusiveEntityRef::as_ptr(&op); - let mut op_mut = op.borrow_mut(); - op_mut.as_operation_mut().vtable.set_data_ptr(data_ptr.cast_mut()); - } - - Self { - context, - op, - _marker: core::marker::PhantomData, - } - } - - /// Register this op as an implementation of `Trait`. - /// - /// This is enforced statically by the type system, as well as dynamically via verification. - /// - /// This must be called for any trait that you wish to be able to cast the type-erased - /// [Operation] to later, or if you wish to get a `dyn Trait` reference from a `dyn Op` - /// reference. - /// - /// If `Trait` has a verifier implementation, it will be automatically applied when calling - /// [Operation::verify]. - pub fn implement(&mut self) - where - Trait: ?Sized + Pointee> + 'static, - T: Unsize + verifier::Verifier + 'static, - { - let mut op = self.op.borrow_mut(); - let operation = op.as_operation_mut(); - operation.vtable.register_trait::(); - } - - /// Set attribute `name` on this op to `value` - pub fn with_attr(&mut self, name: &'static str, value: A) - where - A: AttributeValue, - { - let mut op = self.op.borrow_mut(); - op.as_operation_mut().attrs.insert(interner::Symbol::intern(name), Some(value)); - } - - /// Set the operands given to this op - pub fn with_operands(&mut self, operands: I) - where - I: IntoIterator, - { - let mut op = self.op.borrow_mut(); - // TODO: Verify the safety of this conversion - let owner = unsafe { - let ptr = op.as_operation() as *const Operation; - UnsafeIntrusiveEntityRef::from_raw(ptr) - }; - let operands = operands.into_iter().enumerate().map(|(index, value)| { - self.context - .alloc_tracked(value::OpOperandImpl::new(value, owner.clone(), index as u8)) - }); - let op_mut = op.as_operation_mut(); - op_mut.operands.clear(); - op_mut.operands.extend(operands); - } - - /// Allocate `n` results for this op, of unknown type, to be filled in later - pub fn with_results(&mut self, n: usize) { - let mut op = self.op.borrow_mut(); - let owner = unsafe { - let ptr = op.as_operation() as *const Operation; - UnsafeIntrusiveEntityRef::from_raw(ptr) - }; - let results = - (0..n).map(|idx| self.context.make_result(Type::Unknown, owner.clone(), idx as u8)); - let op_mut = op.as_operation_mut(); - op_mut.results.clear(); - op_mut.results.extend(results); - } - - /// Consume this builder, verify the op, and return a handle to it, or an error if validation - /// failed. 
-    pub fn build(self) -> Result<UnsafeIntrusiveEntityRef<T>, Report> {
-        {
-            let op = self.op.borrow();
-            op.as_operation().verify(self.context)?;
-        }
-        Ok(self.op)
-    }
-}
-
+/// The [Operation] struct provides the common foundation for all [Op] implementations.
+///
+/// It provides:
+///
+/// * Support for casting between the concrete operation type `T`, `dyn Op`, the underlying
+///   `Operation`, and any of the operation traits that the op implements. Not only can the casts
+///   be performed, but an [Operation] can be queried to see if it implements a specific trait at
+///   runtime to conditionally perform some behavior. This makes working with operations in the IR
+///   very flexible and allows for adding or modifying operations without needing to change most of
+///   the compiler, which predominantly works on operation traits rather than concrete ops.
+/// * Storage for all IR entities attached to an operation, e.g. operands, results, nested regions,
+///   attributes, etc.
+/// * Navigation of the IR graph; navigate up to the containing block/region/op, down to nested
+///   regions/blocks/ops, or next/previous sibling operations in the same block. Additionally, you
+///   can navigate directly to the definitions of operands used, to users of results produced, and
+///   to successor blocks.
+/// * Many utility functions related to working with operations, many of which are also accessible
+///   via the [Op] trait, so that working with an [Op] or an [Operation] is largely
+///   indistinguishable.
+///
+/// All [Op] implementations can be cast to the underlying [Operation], but most of the
+/// functionality is re-exported via default implementations of methods on the [Op] trait. The main
+/// benefit is avoiding any potential overhead of casting when going through the trait, rather than
+/// calling the underlying [Operation] method directly.
+///
+/// # Safety
+///
+/// [Operation] is implemented as part of a larger structure that relies on assumptions which depend
+/// on IR entities being allocated via [Context], i.e. the arena. Those allocations produce an
+/// [UnsafeIntrusiveEntityRef] or [UnsafeEntityRef], which allocate the pointee type inside a struct
+/// that provides metadata about the pointee that can be accessed without aliasing the pointee
+/// itself - in particular, links for intrusive collections. This is important, because while these
+/// pointer types are a bit like raw pointers in that they lack any lifetime information, and are
+/// thus unsafe to dereference in general, they _do_ ensure that the pointee can be safely reified
+/// as a reference without violating Rust's borrow checking rules, i.e. they are dynamically
+/// borrow-checked.
+///
+/// The reason why we are able to generally treat these "unsafe" references as safe is that we
+/// require all IR entities to be allocated via [Context]. This makes it essential to keep the
+/// context around in order to work with the IR, and effectively guarantees that no [RawEntityRef]
+/// will be dereferenced after the context is dropped. This is not a guarantee provided by the
+/// compiler, however, but one that is imposed in practice, as attempting to work with the IR in
+/// any capacity without a [Context] is almost impossible. We must ensure, however, that we work
+/// within this set of rules to uphold the safety guarantees.
+/// +/// This "fragility" is a tradeoff - we get the performance characteristics of an arena-allocated +/// IR, with the flexibility and power of using pointers rather than indexes as handles, while also +/// maintaining the safety guarantees of Rust's borrowing system. The downside is that we can't just +/// allocate IR entities wherever we want and use them the same way. #[derive(Spanned)] pub struct Operation { /// In order to support upcasting from [Operation] to its concrete [Op] type, as well as @@ -190,7 +86,7 @@ pub struct Operation { /// by the op, rather than here. Additionally, the semantics of the immediate operands are /// determined by the op, e.g. whether the immediate operands are always applied first, or /// what they are used for. - pub operands: SmallVec<[OpOperand; 1]>, + pub operands: OpOperandStorage, /// The set of values produced by this operation. pub results: SmallVec<[OpResultRef; 1]>, /// If this operation represents control flow, this field stores the set of successors, @@ -202,6 +98,7 @@ pub struct Operation { impl fmt::Debug for Operation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Operation") + .field_with("name", |f| write!(f, "{}", &self.name())) .field("attrs", &self.attrs) .field("block", &self.block.as_ref().map(|b| b.borrow().id())) .field("operands", &self.operands) @@ -220,6 +117,8 @@ impl AsMut for Operation { self.vtable.downcast_trait_mut().unwrap() } } + +/// Construction impl Operation { pub fn uninit() -> Self { use super::traits::MultiTraitVtable; @@ -240,6 +139,13 @@ impl Operation { } } +/// Metadata +impl Operation { + pub fn name(&self) -> OperationName { + AsRef::::as_ref(self).name() + } +} + /// Verification impl Operation { pub fn verify(&self, context: &Context) -> Result<(), Report> { @@ -250,6 +156,13 @@ impl Operation { /// Traits/Casts impl Operation { + #[inline(always)] + pub fn as_operation_ref(&self) -> OperationRef { + // SAFETY: This is safe under the assumption that we always allocate Operations using the + // arena, i.e. it is a child of a RawEntityMetadata structure. + unsafe { OperationRef::from_raw(self) } + } + /// Returns true if the concrete type of this operation is `T` #[inline] pub fn is(&self) -> bool { @@ -297,12 +210,43 @@ impl Operation { /// Return the value associated with attribute `name` for this function pub fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> where - interner::Symbol: std::borrow::Borrow, + interner::Symbol: core::borrow::Borrow, Q: Ord + ?Sized, { self.attrs.get_any(name) } + /// Return the value associated with attribute `name` for this function + pub fn get_attribute_mut(&mut self, name: &Q) -> Option<&mut dyn AttributeValue> + where + interner::Symbol: core::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.get_any_mut(name) + } + + /// Return the value associated with attribute `name` for this function, as its concrete type + /// `T`, _if_ the attribute by that name, is of that type. + pub fn get_typed_attribute(&self, name: &Q) -> Option<&T> + where + T: AttributeValue, + interner::Symbol: core::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.get(name) + } + + /// Return the value associated with attribute `name` for this function, as its concrete type + /// `T`, _if_ the attribute by that name, is of that type. 
+ pub fn get_typed_attribute_mut(&mut self, name: &Q) -> Option<&mut T> + where + T: AttributeValue, + interner::Symbol: core::borrow::Borrow, + Q: Ord + ?Sized, + { + self.attrs.get_mut(name) + } + /// Return true if this function has an attributed named `name` pub fn has_attribute(&self, name: &Q) -> bool where @@ -385,6 +329,55 @@ impl Operation { pub fn regions_mut(&mut self) -> &mut RegionList { &mut self.regions } + + pub fn region(&self, index: usize) -> EntityRef<'_, Region> { + let mut cursor = self.regions.front(); + let mut count = 0; + while !cursor.is_null() { + if index == count { + return cursor.into_borrow().unwrap(); + } + cursor.move_next(); + count += 1; + } + panic!("invalid region index {index}: out of bounds"); + } + + pub fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> { + let mut cursor = self.regions.front_mut(); + let mut count = 0; + while !cursor.is_null() { + if index == count { + return cursor.into_borrow_mut().unwrap(); + } + cursor.move_next(); + count += 1; + } + panic!("invalid region index {index}: out of bounds"); + } +} + +/// Successors +impl Operation { + #[inline] + pub fn has_successors(&self) -> bool { + !self.successors.is_empty() + } + + #[inline] + pub fn num_successors(&self) -> usize { + self.successors.len() + } + + #[inline(always)] + pub fn successors(&self) -> &[OpSuccessor] { + &self.successors + } + + #[inline(always)] + pub fn successors_mut(&mut self) -> &mut [OpSuccessor] { + &mut self.successors + } } /// Operands @@ -400,8 +393,13 @@ impl Operation { } #[inline] - pub fn operands(&self) -> &[OpOperand] { - self.operands.as_slice() + pub fn operands(&self) -> &OpOperandStorage { + &self.operands + } + + #[inline] + pub fn operands_mut(&mut self) -> &mut OpOperandStorage { + &mut self.operands } pub fn replaces_uses_of_with(&mut self, mut from: ValueRef, mut to: ValueRef) { @@ -431,3 +429,26 @@ impl Operation { } } } + +/// Results +impl Operation { + #[inline] + pub fn has_results(&self) -> bool { + !self.results.is_empty() + } + + #[inline] + pub fn num_results(&self) -> usize { + self.results.len() + } + + #[inline] + pub fn results(&self) -> &[OpResultRef] { + self.results.as_slice() + } + + #[inline] + pub fn results_mut(&mut self) -> &mut [OpResultRef] { + self.results.as_mut_slice() + } +} diff --git a/hir2/src/ir/operation/builder.rs b/hir2/src/ir/operation/builder.rs new file mode 100644 index 000000000..a1a7af922 --- /dev/null +++ b/hir2/src/ir/operation/builder.rs @@ -0,0 +1,168 @@ +use core::{ + marker::Unsize, + ptr::{DynMetadata, Pointee}, +}; + +use super::{Operation, OperationRef}; +use crate::{ + verifier, AttributeValue, Context, Op, OpOperandImpl, OpSuccessor, Region, Report, Type, + UnsafeIntrusiveEntityRef, ValueRef, +}; + +// TODO: We need a safe way to construct arbitrary Ops imperatively: +// +// * Allocate an uninit instance of T +// * Initialize the Operartion field of T with the empty Operation data +// * Use the primary builder methods to mutate Operation fields +// * Use generated methods on Op-specific builders to mutate Op fields +// * At the end, convert uninit T to init T, return handle to caller +// +// Problems: +// +// * How do we default-initialize an instance of T for this purpose +// * If we use MaybeUninit, how do we compute field offsets for the Operation field +// * Generated methods can compute offsets, but how do we generate the specialized builders? 
+pub struct OperationBuilder<'a, T> { + context: &'a Context, + op: UnsafeIntrusiveEntityRef, + _marker: core::marker::PhantomData, +} +impl<'a, T: Op> OperationBuilder<'a, T> { + pub fn new(context: &'a Context, op: T) -> Self { + let mut op = context.alloc_tracked(op); + + // SAFETY: Setting the data pointer of the multi-trait vtable must ensure + // that it points to the concrete type of the allocation, which we can guarantee here, + // having just allocated it. Until the data pointer is set, casts using the vtable are + // undefined behavior, so by never allowing the uninitialized vtable to be accessed, + // we can ensure the multi-trait impl is safe + unsafe { + let data_ptr = UnsafeIntrusiveEntityRef::as_ptr(&op); + let mut op_mut = op.borrow_mut(); + op_mut.as_operation_mut().vtable.set_data_ptr(data_ptr.cast_mut()); + } + + Self { + context, + op, + _marker: core::marker::PhantomData, + } + } + + /// Register this op as an implementation of `Trait`. + /// + /// This is enforced statically by the type system, as well as dynamically via verification. + /// + /// This must be called for any trait that you wish to be able to cast the type-erased + /// [Operation] to later, or if you wish to get a `dyn Trait` reference from a `dyn Op` + /// reference. + /// + /// If `Trait` has a verifier implementation, it will be automatically applied when calling + /// [Operation::verify]. + pub fn implement(&mut self) + where + Trait: ?Sized + Pointee> + 'static, + T: Unsize + verifier::Verifier + 'static, + { + let mut op = self.op.borrow_mut(); + let operation = op.as_operation_mut(); + operation.vtable.register_trait::(); + } + + /// Set attribute `name` on this op to `value` + pub fn with_attr(&mut self, name: &'static str, value: A) + where + A: AttributeValue, + { + let mut op = self.op.borrow_mut(); + op.as_operation_mut().attrs.insert(name, Some(value)); + } + + /// Add a new [Region] to this operation. + /// + /// NOTE: You must ensure this is called _after_ [Self::with_operands], and [Self::implements] + /// if the op implements the [traits::NoRegionArguments] trait. Otherwise, the inserted region + /// may not be valid for this op. 
+ pub fn create_region(&mut self) { + let mut region = Region::default(); + unsafe { + region.set_owner(Some(self.as_operation_ref())); + } + let region = self.context.alloc_tracked(region); + let mut op = self.op.borrow_mut(); + op.as_operation_mut().regions.push_back(region); + } + + pub fn with_successor(&mut self, succ: OpSuccessor) { + todo!() + } + + pub fn with_successors(&mut self, succs: I) + where + S: Into, + I: IntoIterator, + { + todo!() + } + + /// Set the operands given to this op + pub fn with_operands(&mut self, operands: I) + where + I: IntoIterator, + { + // TODO: Verify the safety of this conversion + let owner = self.as_operation_ref(); + let mut op = self.op.borrow_mut(); + let operands = operands.into_iter().enumerate().map(|(index, value)| { + self.context + .alloc_tracked(OpOperandImpl::new(value, owner.clone(), index as u8)) + }); + let op_mut = op.as_operation_mut(); + op_mut.operands.clear(); + op_mut.operands.extend(operands); + } + + pub fn with_operands_in_group(&mut self, group: usize, operands: I) + where + I: IntoIterator, + { + let owner = self.as_operation_ref(); + let mut op = self.op.borrow_mut(); + let operands = operands.into_iter().enumerate().map(|(index, value)| { + self.context + .alloc_tracked(OpOperandImpl::new(value, owner.clone(), index as u8)) + }); + let op_operands = op.operands_mut(); + op_operands.push_operands_to_group(group, operands); + } + + /// Allocate `n` results for this op, of unknown type, to be filled in later + pub fn with_results(&mut self, n: usize) { + let owner = self.as_operation_ref(); + let mut op = self.op.borrow_mut(); + let results = + (0..n).map(|idx| self.context.make_result(Type::Unknown, owner.clone(), idx as u8)); + let op_mut = op.as_operation_mut(); + op_mut.results.clear(); + op_mut.results.extend(results); + } + + /// Consume this builder, verify the op, and return a handle to it, or an error if validation + /// failed. + pub fn build(self) -> Result, Report> { + { + let op = self.op.borrow(); + op.as_operation().verify(self.context)?; + } + Ok(self.op) + } + + #[inline] + fn as_operation_ref(&self) -> OperationRef { + let op = self.op.borrow(); + unsafe { + let ptr = op.as_operation() as *const Operation; + OperationRef::from_raw(ptr) + } + } +} diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs new file mode 100644 index 000000000..fd462b375 --- /dev/null +++ b/hir2/src/ir/operation/name.rs @@ -0,0 +1,38 @@ +use core::fmt; + +use crate::{interner, DialectName}; + +/// The operation name, or mnemonic, that uniquely identifies an operation. +/// +/// The operation name consists of its dialect name, and the opcode name within the dialect. +/// +/// No two operation names can share the same fully-qualified operation name. 
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OperationName { + /// The dialect of this operation + pub dialect: DialectName, + /// The opcode name for this operation + pub name: interner::Symbol, +} +impl OperationName { + pub fn new(dialect: DialectName, name: S) -> Self + where + S: Into, + { + Self { + dialect, + name: name.into(), + } + } +} +impl fmt::Debug for OperationName { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} +impl fmt::Display for OperationName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}.{}", &self.dialect, &self.name) + } +} diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index c28896175..f47499b9a 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -8,6 +8,7 @@ pub type RegionCursor<'a> = EntityCursor<'a, Region>; /// A mutable cursor in a [RegionList] pub type RegionCursorMut<'a> = EntityCursorMut<'a, Region>; +#[derive(Default)] pub struct Region { /// The operation this region is attached to. /// @@ -17,14 +18,44 @@ pub struct Region { body: BlockList, } impl Region { + /// Returns true if this region is empty (has no blocks) + pub fn is_empty(&self) -> bool { + self.body.is_empty() + } + /// Get the defining [Operation] for this region, if the region is attached to one. pub fn parent(&self) -> Option { self.owner.clone() } + /// Set the owner of this region. + /// + /// Returns the previous owner. + /// + /// # Safety + /// + /// It is dangerous to set this field unless doing so as part of allocating the [Region] or + /// moving the [Region] from one op to another. If it is set to a different entity than actually + /// owns the region, it will result in undefined behavior or panics when we attempt to access + /// the owner via the region. + /// + /// You must ensure that the owner given _actually_ owns the region. Similarly, if you are + /// unsetting the owner, you must ensure that no entity _thinks_ it owns this region. 
+ pub unsafe fn set_owner(&mut self, owner: Option) -> Option { + match owner { + None => self.owner.take(), + Some(owner) => self.owner.replace(owner), + } + } + /// Get a handle to the entry block for this region pub fn entry(&self) -> EntityRef<'_, Block> { - self.body.front().get().unwrap() + self.body.front().into_borrow().unwrap() + } + + /// Get a mutable handle to the entry block for this region + pub fn entry_mut(&mut self) -> EntityMut<'_, Block> { + self.body.front_mut().into_borrow_mut().unwrap() } /// Get the list of blocks comprising the body of this region diff --git a/hir2/src/ir/successor.rs b/hir2/src/ir/successor.rs new file mode 100644 index 000000000..16c405cf3 --- /dev/null +++ b/hir2/src/ir/successor.rs @@ -0,0 +1,23 @@ +use core::fmt; + +use crate::{BlockOperandRef, OpOperand}; + +/// TODO: +/// +/// * Replace usage of OpSuccessor with BlockOperand +/// * Store OpSuccessor operands in OpOperandStorage in groups per BlockOperand + +/// An [OpSuccessor] is a BlockOperand + OpOperands for that block, attached to an Operation +#[derive(Clone)] +pub struct OpSuccessor { + pub block: BlockOperandRef, + pub args: smallvec::SmallVec<[OpOperand; 1]>, +} +impl fmt::Debug for OpSuccessor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpSuccessor") + .field("block", &self.block.borrow().block_id()) + .field("args", &self.args) + .finish() + } +} diff --git a/hir2/src/ir/symbol_table.rs b/hir2/src/ir/symbol_table.rs index 39add5b0a..988141c91 100644 --- a/hir2/src/ir/symbol_table.rs +++ b/hir2/src/ir/symbol_table.rs @@ -1,4 +1,159 @@ -use core::any::Any; +use alloc::collections::VecDeque; +use core::fmt; + +use crate::{ + define_attr_type, interner, InsertionPoint, Op, Operation, OperationRef, Report, Searcher, + UnsafeIntrusiveEntityRef, Usable, Visibility, +}; + +/// Represents the name of a [Symbol] in its local [SymbolTable] +pub type SymbolName = interner::Symbol; + +#[derive(Debug, Copy, Clone)] +pub struct SymbolNameAttr { + /// The path through the abstract symbol space to the containing symbol table + /// + /// It is assumed that all symbol tables are also symbols themselves, and thus the path to + /// `name` is formed from the names of all parent symbol tables, in hierarchical order. + /// + /// For example, consider a program consisting of a single component `@test_component`, + /// containing a module `@foo`, which in turn contains a function `@a`. The `path` for `@a` + /// would be `@test_component::@foo`, and `name` would be `@a`. + /// + /// If set to `interner::symbols::Empty`, the symbol `name` is in the global namespace. + /// + /// If set to any other value, then we recover the components of the path by splitting the + /// value on `::`. If not present, the path represents a single namespace. If multiple parts + /// are present, then each part represents a nested namespace starting from the global one. 
+ pub path: SymbolName, + /// The name of the symbol + pub name: SymbolName, +} +define_attr_type!(SymbolNameAttr); +impl SymbolNameAttr { + #[inline(always)] + pub const fn name(&self) -> SymbolName { + self.name + } + + #[inline(always)] + pub const fn path(&self) -> SymbolName { + self.path + } + + /// Returns true if this symbol name is fully-qualified + pub fn is_absolute(&self) -> bool { + self.path.as_str().starts_with("::") + } + + #[inline] + pub fn has_parent(&self) -> bool { + self.path != interner::symbols::Empty + } + + pub fn components(&self) -> impl Iterator { + SymbolNameComponents::new(self.path, self.name) + } +} +impl fmt::Display for SymbolNameAttr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.has_parent() { + write!(f, "{}::{}", &self.path, &self.name) + } else { + f.write_str(self.name.as_str()) + } + } +} +impl Eq for SymbolNameAttr {} +impl PartialEq for SymbolNameAttr { + fn eq(&self, other: &Self) -> bool { + self.path == other.path && self.name == other.name + } +} +impl PartialOrd for SymbolNameAttr { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for SymbolNameAttr { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.path.cmp(&other.path).then_with(|| self.name.cmp(&other.name)) + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum SymbolNameComponent { + /// A component that signals the path is relative to the root symbol table + Root, + /// A component of the symbol name path + Component(SymbolName), + /// The name of the symbol in its local symbol table + Leaf(SymbolName), +} + +struct SymbolNameComponents { + parts: VecDeque<&'static str>, + name: SymbolName, + done: bool, +} +impl SymbolNameComponents { + fn new(path: SymbolName, name: SymbolName) -> Self { + let mut parts = VecDeque::default(); + if path == interner::symbols::Empty { + return Self { + parts, + name, + done: false, + }; + } + let mut split = path.as_str().split("::"); + let start = split.next().unwrap(); + if start.is_empty() { + parts.push_back("::"); + } + + while let Some(part) = split.next() { + if part.is_empty() { + if let Some(part2) = split.next() { + if part2.is_empty() { + parts.push_back("::"); + } else { + parts.push_back(part2); + } + } else { + break; + } + } else { + parts.push_back(part); + } + } + + Self { + parts, + name, + done: false, + } + } +} +impl core::iter::FusedIterator for SymbolNameComponents {} +impl Iterator for SymbolNameComponents { + type Item = SymbolNameComponent; + + fn next(&mut self) -> Option { + if self.done { + return None; + } + if let Some(part) = self.parts.pop_front() { + if part == "::" { + return Some(SymbolNameComponent::Root); + } + return Some(SymbolNameComponent::Component(part.into())); + } + self.done = true; + Some(SymbolNameComponent::Leaf(self.name)) + } +} /// A [SymbolTable] is an IR entity which contains other IR entities, called _symbols_, each of /// which has a name, aka symbol, that uniquely identifies it amongst all other entities in the @@ -10,21 +165,48 @@ use core::any::Any; /// associated `Key` type, and a [Symbol] has an associated `Id` type - only types whose `Id` /// type matches the `Key` type of the [SymbolTable], can be stored in that table. 
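// Illustrative sketch, not part of this patch: hypothetical usage of the reworked
// `SymbolTable` API defined below, assuming `module` is something implementing `SymbolTable`
// and "memcpy" is a made-up symbol name:
//
//     let name = SymbolName::intern("memcpy");
//     if module.get(name).is_some() {
//         // Rename the symbol and rewrite all uses of it nested within this table.
//         module.rename(name, SymbolName::intern("intrinsic_memcpy")).expect("rename failed");
//     }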
pub trait SymbolTable { - /// The unique key type associated with entries in this symbol table - type Key; - /// The value type of an entry in the symbol table - type Entry; + /// Get a reference to the underlying [Operation] + fn as_operation(&self) -> &Operation; + + /// Get a mutable reference to the underlying [Operation] + fn as_operation_mut(&mut self) -> &mut Operation; - /// Get the entry for `id` in this table - fn get(&self, id: &Self::Key) -> Option; + /// Get the entry for `name` in this table + fn get(&self, name: SymbolName) -> Option; - /// Insert `entry` in the symbol table. + /// Insert `entry` in the symbol table, but only if no other symbol with the same name exists. + /// + /// If provided, the symbol will be inserted at the given insertion point in the body of the + /// symbol table operation. + /// + /// This function will panic if the symbol is attached to another symbol table. /// - /// Returns `true` if successful, or `false` if an entry already exists - fn insert(&mut self, entry: Self::Entry) -> bool; + /// Returns `true` if successful, `false` if the symbol is already defined + fn insert_new(&mut self, entry: SymbolRef, ip: Option) -> bool; - /// Remove the symbol `id`, and return the entry if one was present. - fn remove(&mut self, id: &Self::Key) -> Option; + /// Like [SymbolTable::insert_new], except the symbol is renamed to avoid collisions. + /// + /// Returns the name of the symbol after insertion. + fn insert(&mut self, entry: SymbolRef, ip: Option) -> SymbolName; + + /// Remove the symbol `name`, and return the entry if one was present. + fn remove(&mut self, name: SymbolName) -> Option; + + /// Renames the symbol named `from`, as `to`, as well as all uses of that symbol. + /// + /// Returns `Err` if unable to update all uses. + fn rename(&mut self, from: SymbolName, to: SymbolName) -> Result<(), Report>; +} + +impl dyn SymbolTable { + /// Look up a symbol with the given name and concrete type, returning `None` if no such symbol + /// exists + pub fn find(&self, name: SymbolName) -> Option> { + let op = self.get(name)?; + let op = op.borrow(); + let op = op.as_operation().downcast_ref::()?; + Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(op) }) + } } /// A [Symbol] is an IR entity with an associated _symbol_, or name, which is expected to be unique @@ -33,8 +215,244 @@ pub trait SymbolTable { /// For example, functions are named, and are expected to be unique within the same module, /// otherwise it would not be possible to unambiguously refer to a function by name. Likewise /// with modules in a program, etc. 
-pub trait Symbol: Any { - type Id: Copy + Clone + PartialEq + Eq + PartialOrd + Ord; +pub trait Symbol: Usable + 'static { + fn as_operation(&self) -> &Operation; + fn as_operation_mut(&mut self) -> &mut Operation; + /// Get the name of this symbol + fn name(&self) -> SymbolName; + /// Set the name of this symbol + fn set_name(&mut self, name: SymbolName); + /// Get the visibility of this symbol + fn visibility(&self) -> Visibility; + /// Returns true if this symbol has private visibility + fn is_private(&self) -> bool; + /// Returns true if this symbol has public visibility + fn is_public(&self) -> bool; + /// Sets the visibility of this symbol + fn set_visibility(&mut self, visibility: Visibility); + /// Sets the visibility of this symbol to private + fn set_private(&mut self) { + self.set_visibility(Visibility::Private); + } + /// Sets the visibility of this symbol to nested + fn set_nested(&mut self) { + self.set_visibility(Visibility::Nested); + } + /// Sets the visibility of this symbol to public + fn set_public(&mut self) { + self.set_visibility(Visibility::Public); + } + /// Get all of the uses of this symbol that are nested within `from` + fn symbol_uses(&self, from: OperationRef) -> SymbolUseIter; + /// Return true if there are no uses of this symbol nested within `from` + fn symbol_uses_known_empty(&self, from: OperationRef) -> SymbolUseIter; + /// Attempt to replace all uses of this symbol nested within `from`, with the provided replacement + fn replace_all_uses(&self, replacement: SymbolRef, from: OperationRef) -> Result<(), Report>; + /// Returns true if this operation can be discarded if it has no remaining symbol uses + /// + /// By default, if the visibility is non-public, a symbol is considered discardable + fn can_discard_when_unused(&self) -> bool { + !self.is_public() + } + /// Returns true if this operation is a declaration, rather than a definition, of a symbol + /// + /// The default implementation assumes that all operations are definitions + fn is_declaration(&self) -> bool { + false + } +} + +impl dyn Symbol { + pub fn is(&self) -> bool { + let op = self.as_operation(); + op.is::() + } + + pub fn downcast_ref(&self) -> Option<&T> { + let op = self.as_operation(); + op.downcast_ref::() + } + + pub fn downcast_mut(&mut self) -> Option<&mut T> { + let op = self.as_operation_mut(); + op.downcast_mut::() + } + + /// Get an [OperationRef] for the operation underlying this symbol + /// + /// NOTE: This relies on the assumption that all ops are allocated via the arena, and that all + /// [Symbol] implementations are ops. + pub fn as_operation_ref(&self) -> OperationRef { + unsafe { OperationRef::from_raw(self.as_operation()) } + } +} + +impl Operation { + /// Returns true if this operation implements [Symbol] + #[inline] + pub fn is_symbol(&self) -> bool { + self.implements::() + } + + /// Get this operation as a [Symbol], if this operation implements the trait. + #[inline] + pub fn as_symbol(&self) -> Option<&dyn Symbol> { + self.as_trait::() + } + + /// Returns the nearest [SymbolTable] from this operation. + /// + /// Returns `None` if no parent of this operation is a valid symbol table. 
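    // Illustrative, not part of this patch: typical symbol resolution from some op nested in a
    // module, where `op` is an `&Operation` and "memcpy" is a hypothetical symbol name:
    //
    //     // Closest enclosing op implementing `SymbolTable`, e.g. the containing module...
    //     let table = op.nearest_symbol_table();
    //     // ...or resolve a name directly against that nearest enclosing symbol table.
    //     let callee = op.nearest_symbol(SymbolName::intern("memcpy"));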
+ pub fn nearest_symbol_table(&self) -> Option { + let mut parent = self.parent_op(); + while let Some(parent_op) = parent.take() { + let op = parent_op.borrow(); + if op.implements::() { + drop(op); + return Some(parent_op); + } + parent = op.parent_op(); + } + None + } + + /// Returns the operation registered with the given symbol name within the closest symbol table + /// including `self`. + /// + /// Returns `None` if the symbol is not found. + pub fn nearest_symbol(&self, symbol: SymbolName) -> Option { + if let Some(sym) = self.as_symbol() { + if sym.name() == symbol { + return Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(sym) }); + } + } + let symbol_table_op = self.nearest_symbol_table()?; + let op = symbol_table_op.borrow(); + let symbol_table = op.as_trait::().unwrap(); + symbol_table.get(symbol) + } + + /// Walks all symbol table operations nested within this operation, including itself. + /// + /// For each symbol table operation, the provided callback is invoked with the op and a boolean + /// signifying if the symbols within that symbol table can be treated as if all uses within the + /// IR are visible to the caller. + pub fn walk_symbol_tables(&self, all_symbol_uses_visible: bool, mut callback: F) + where + F: FnMut(&dyn Symbol, bool), + { + use core::ops::ControlFlow; + + let visitor = move |op: &dyn Symbol| { + callback(op, all_symbol_uses_visible); + ControlFlow::<()>::Continue(()) + }; + + let op = self.as_operation_ref(); + let mut searcher = Searcher::new(op, visitor); + + searcher.visit(); + } +} + +pub type SymbolRef = UnsafeIntrusiveEntityRef; + +impl crate::Verify for T +where + T: Op + Symbol, +{ + fn verify(&self, context: &super::Context) -> Result<(), Report> { + verify_symbol(self, context) + } +} + +impl crate::Verify for Operation { + fn should_verify(&self, _context: &super::Context) -> bool { + self.implements::() + } + + fn verify(&self, context: &super::Context) -> Result<(), Report> { + verify_symbol( + self.as_trait::() + .expect("this operation does not implement the `Symbol` trait"), + context, + ) + } +} + +fn verify_symbol(symbol: &dyn Symbol, context: &super::Context) -> Result<(), Report> { + use midenc_session::diagnostics::{Severity, Spanned}; + + // Symbols must either have no parent, or be an immediate child of a SymbolTable + let op = symbol.as_operation(); + let parent = op.parent_op(); + if !parent.is_none_or(|parent| parent.borrow().implements::()) { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), "expected parent of this operation to be a symbol table") + .with_help("required due to this operation implementing the 'Symbol' trait") + .into_report()); + } + Ok(()) +} + +pub type SymbolUseRef = UnsafeIntrusiveEntityRef; +pub type SymbolUseList = crate::EntityList; +pub type SymbolUseIter<'a> = crate::EntityIter<'a, SymbolUse>; +pub type SymbolUseCursor<'a> = crate::EntityCursor<'a, SymbolUse>; +pub type SymbolUseCursorMut<'a> = crate::EntityCursorMut<'a, SymbolUse>; + +/// An [OpOperand] represents a use of a [Value] by an [Operation] +pub struct SymbolUse { + /// The user of the symbol + pub owner: OperationRef, + /// The symbol used + pub symbol: SymbolNameAttr, +} +impl SymbolUse { + #[inline] + pub fn new(owner: OperationRef, symbol: SymbolNameAttr) -> Self { + Self { owner, symbol } + } +} +impl fmt::Debug for SymbolUse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SymbolUse") + .field("symbol", 
&self.symbol) + .finish_non_exhaustive() + } +} + +/// Generate a unique symbol name. +/// +/// Iteratively increase `counter` and use it as a suffix for symbol names until `is_unique` does +/// not detect any conflict. +pub fn generate_symbol_name(name: SymbolName, counter: &mut usize, is_unique: F) -> SymbolName +where + F: Fn(&str) -> bool, +{ + use core::fmt::Write; + + use crate::SmallStr; + + if is_unique(name.as_str()) { + return name; + } + + let base_len = name.as_str().len(); + let mut buf = SmallStr::with_capacity(base_len + 2); + buf.push_str(name.as_str()); + loop { + *counter += 1; + buf.truncate(base_len); + buf.push('_'); + write!(&mut buf, "{counter}").unwrap(); - fn id(&self) -> Self::Id; + if is_unique(buf.as_str()) { + break SymbolName::intern(buf); + } + } } diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs index 4bff9295b..39e441d99 100644 --- a/hir2/src/ir/traits.rs +++ b/hir2/src/ir/traits.rs @@ -1,8 +1,11 @@ +mod callable; mod multitrait; +mod types; use midenc_session::diagnostics::Severity; pub(crate) use self::multitrait::MultiTraitVtable; +pub use self::{callable::*, types::*}; use crate::{derive, Context, Operation, Report, Spanned}; /// Marker trait for commutative ops, e.g. `X op Y == Y op X` @@ -79,61 +82,6 @@ derive! { } } -derive! { - /// Op expects all operands to be of the same type - pub trait SameTypeOperands {} - - verify { - fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { - if let Some((first_operand, operands)) = op.operands().split_first() { - let (expected_ty, set_by) = { - let operand = first_operand.borrow(); - let value = operand.value(); - (value.ty().clone(), value.span()) - }; - for operand in operands { - let operand = operand.borrow(); - let value = operand.value(); - let value_ty = value.ty(); - if value_ty != &expected_ty { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operation") - .with_primary_label( - op.span(), - "this operation expects all operands to be of the same type" - ) - .with_secondary_label( - set_by, - "inferred the expected type from this value" - ) - .with_secondary_label( - value.span(), - "which differs from this value" - ) - .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) - .into_report() - ); - } - } - } - - Ok(()) - } - } -} - -derive! { - /// Op expects all operands and results to be of the same type - /// - /// TODO(pauls): Implement verification for this. Ideally we could require `SameTypeOperands` - /// as a super trait, check the operands using its implementation, and then check the results - /// separately - pub trait SameOperandsAndResultType {} -} - derive! { /// Op's regions have no arguments pub trait NoRegionArguments {} @@ -186,3 +134,28 @@ derive! { } } } + +derive! 
{ + /// Op has a single region + pub trait SingleRegion {} + + verify { + fn has_exactly_one_region(op: &Operation, context: &Context) -> Result<(), Report> { + let num_regions = op.num_regions(); + if num_regions != 1 { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + format!("this operation requires exactly one region, but got {num_regions}") + ) + .into_report()); + } + + Ok(()) + } + } +} diff --git a/hir2/src/ir/traits/callable.rs b/hir2/src/ir/traits/callable.rs new file mode 100644 index 000000000..f45278680 --- /dev/null +++ b/hir2/src/ir/traits/callable.rs @@ -0,0 +1,117 @@ +use crate::{ + EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Signature, SymbolNameAttr, SymbolRef, + Value, ValueRef, +}; + +/// A call-like operation is one that transfers control from one function to another. +/// +/// These operations may be traditional static calls, e.g. `call @foo`, or indirect calls, e.g. +/// `call_indirect v1`. An operation that uses this interface cannot _also_ implement the +/// `CallableOpInterface`. +pub trait CallOpInterface { + /// Get the callee of this operation. + /// + /// A callee is either a symbol, or a reference to an SSA value. + fn callable_for_callee(&self) -> Callable; + /// Sets the callee for this operation. + fn set_callee(&mut self, callable: Callable); + /// Get the operands of this operation that are used as arguments for the callee + fn arguments(&self) -> OpOperandRange<'_>; + /// Get a mutable reference to the operands of this operation that are used as arguments for the + /// callee + fn arguments_mut(&mut self) -> OpOperandRangeMut<'_>; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved, using the provided symbol table. + /// + /// This method is used to perform callee resolution using a cached symbol table, rather than + /// traversing the operation hierarchy looking for symbol tables to try resolving with. + fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved. + fn resolve(&self) -> Option; +} + +/// A callable operation is one who represents a potential function, and may be a target for a call- +/// like operation (i.e. implementations of `CallOpInterface`). These operations may be traditional +/// function ops (i.e. `Function`), as well as function reference-producing operations, such as an +/// op that creates closures, or captures a function by reference. +/// +/// These operations may only contain a single region. +pub trait CallableOpInterface { + /// Returns the region on the current operation that is callable. + /// + /// This may return `None` in the case of an external callable object, e.g. an externally- + /// defined function reference. + fn get_callable_region(&self) -> Option; + /// Returns the signature of the callable + fn signature(&self) -> &Signature; +} + +/// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a +/// [CallOpInterface] implementation. +/// +/// Symbols are not SSA values, but there are situations where we want to treat them as one, such +/// as indirect calls. 
Abstracting over whether the callable is a symbol or an SSA value allows us +/// to focus on the call semantics, rather than the difference between the type types of value. +#[derive(Debug, Clone)] +pub enum Callable { + Symbol(SymbolNameAttr), + Value(ValueRef), +} +impl From<&SymbolNameAttr> for Callable { + fn from(value: &SymbolNameAttr) -> Self { + Self::Symbol(*value) + } +} +impl From for Callable { + fn from(value: SymbolNameAttr) -> Self { + Self::Symbol(value) + } +} +impl From for Callable { + fn from(value: ValueRef) -> Self { + Self::Value(value) + } +} +impl Callable { + #[inline(always)] + pub fn new(callable: impl Into) -> Self { + callable.into() + } + + pub fn is_symbol(&self) -> bool { + matches!(self, Self::Symbol(_)) + } + + pub fn is_value(&self) -> bool { + matches!(self, Self::Value(_)) + } + + pub fn as_symbol_name(&self) -> Option<&SymbolNameAttr> { + match self { + Self::Symbol(ref name) => Some(name), + _ => None, + } + } + + pub fn as_value(&self) -> Option> { + match self { + Self::Value(ref value_ref) => Some(value_ref.borrow()), + _ => None, + } + } + + pub fn unwrap_symbol_name(self) -> SymbolNameAttr { + match self { + Self::Symbol(name) => name, + Self::Value(value_ref) => panic!("expected symbol, got {}", value_ref.borrow().id()), + } + } + + pub fn unwrap_value_ref(self) -> ValueRef { + match self { + Self::Value(value) => value, + Self::Symbol(ref name) => panic!("expected value, got {name}"), + } + } +} diff --git a/hir2/src/ir/traits/multitrait.rs b/hir2/src/ir/traits/multitrait.rs index c71d98f37..d1b352620 100644 --- a/hir2/src/ir/traits/multitrait.rs +++ b/hir2/src/ir/traits/multitrait.rs @@ -82,6 +82,12 @@ impl MultiTraitVtable { } } + #[allow(unused)] + #[inline] + pub const fn data_ptr(&self) -> *mut () { + self.data + } + pub(crate) unsafe fn set_data_ptr(&mut self, ptr: *mut T) { assert!(!ptr.is_null()); assert!(ptr.is_aligned()); diff --git a/hir2/src/ir/traits/types.rs b/hir2/src/ir/traits/types.rs new file mode 100644 index 000000000..6166f4675 --- /dev/null +++ b/hir2/src/ir/traits/types.rs @@ -0,0 +1,489 @@ +use core::fmt; + +use midenc_session::diagnostics::Severity; + +use crate::{derive, Context, Operation, Report, Spanned}; + +/// OpInterface to compute the return type(s) of an operation. +pub trait InferTypeOpInterface { + /// Run type inference for this op's results, using the current state, and apply any changes. + /// + /// Returns an error if unable to infer types, or if some type constraint is violated. + fn infer_types(&mut self) -> Result<(), Report>; +} + +derive! 
{ + /// Op expects all operands to be of the same type + pub trait SameTypeOperands {} + + verify { + fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { + let mut operands = op.operands().iter(); + if let Some(first_operand) = operands.next() { + let (expected_ty, set_by) = { + let operand = first_operand.borrow(); + let value = operand.value(); + (value.ty().clone(), value.span()) + }; + + for operand in operands { + let operand = operand.borrow(); + let value = operand.value(); + let value_ty = value.ty(); + if value_ty != &expected_ty { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation expects all operands to be of the same type" + ) + .with_secondary_label( + set_by, + "inferred the expected type from this value" + ) + .with_secondary_label( + value.span(), + "which differs from this value" + ) + .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) + .into_report() + ); + } + } + } + + Ok(()) + } + } +} + +derive! { + /// Op expects all operands and results to be of the same type + /// + /// TODO(pauls): Implement verification for this. Ideally we could require `SameTypeOperands` + /// as a super trait, check the operands using its implementation, and then check the results + /// separately + pub trait SameOperandsAndResultType {} +} + +/// An operation trait that indicates it expects a variable number of operands, matching the given +/// type constraint, i.e. zero or more of the base type. +pub trait Variadic {} + +impl crate::Verify> for T +where + T: crate::Op + Variadic, +{ + fn verify(&self, context: &Context) -> Result<(), Report> { + self.as_operation().verify(context) + } +} +impl crate::Verify> for Operation { + fn should_verify(&self, _context: &Context) -> bool { + self.implements::>() + } + + fn verify(&self, context: &Context) -> Result<(), Report> { + for operand in self.operands().iter() { + let operand = operand.borrow(); + let value = operand.value(); + let ty = value.ty(); + if ::matches(ty) { + continue; + } else { + let description = ::description(); + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand") + .with_primary_label( + value.span(), + format!("expected operand type to be {description}, but got {ty}"), + ) + .into_report()); + } + } + + Ok(()) + } +} + +pub trait TypeConstraint: 'static { + fn description() -> impl fmt::Display; + fn matches(ty: &crate::Type) -> bool; + fn check(ty: &crate::Type) -> Result<(), Report> { + if Self::matches(ty) { + Ok(()) + } else { + let expected = Self::description(); + Err(Report::msg(format!("expected {expected}, got '{ty}'"))) + } + } +} + +/// A type that can be constructed as a [crate::Type] +pub trait BuildableTypeConstraint: TypeConstraint { + fn build() -> crate::Type; +} + +macro_rules! 
type_constraint { + ($Constraint:ident, $description:literal, $matcher:literal) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches(_ty: &$crate::Type) -> bool { + $matcher + } + } + }; + + ($Constraint:ident, $description:literal, $matcher:path) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches(ty: &$crate::Type) -> bool { + $matcher(ty) + } + } + }; +} + +type_constraint!(AnyType, "any type", true); +// TODO(pauls): Extend Type with new Function variant, we'll use that to represent function handles +//type_constraint!(AnyFunction, "a function type", crate::Type::is_function); +type_constraint!(AnyList, "any list type", crate::Type::is_list); +type_constraint!(AnyArray, "any array type", crate::Type::is_array); +type_constraint!(AnyStruct, "any struct type", crate::Type::is_struct); +type_constraint!(AnyPointer, "a pointer type", crate::Type::is_pointer); +type_constraint!(AnyInteger, "an integral type", crate::Type::is_integer); +type_constraint!(AnySignedInteger, "a signed integral type", crate::Type::is_signed_integer); +type_constraint!( + AnyUnsignedInteger, + "an unsigned integral type", + crate::Type::is_unsigned_integer +); +type_constraint!(IntFelt, "a field element", crate::Type::is_felt); + +/// A signless 8-bit integer +pub type Int8 = SizedInt<8>; +/// A signed 8-bit integer +pub type SInt8 = And>; +/// An unsigned 8-bit integer +pub type UInt8 = And>; + +/// A signless 16-bit integer +pub type Int16 = SizedInt<16>; +/// A signed 16-bit integer +pub type SInt16 = And>; +/// An unsigned 16-bit integer +pub type UInt16 = And>; + +/// A signless 32-bit integer +pub type Int32 = SizedInt<32>; +/// A signed 16-bit integer +pub type SInt32 = And>; +/// An unsigned 16-bit integer +pub type UInt32 = And>; + +/// A signless 64-bit integer +pub type Int64 = SizedInt<64>; +/// A signed 64-bit integer +pub type SInt64 = And>; +/// An unsigned 64-bit integer +pub type UInt64 = And>; + +impl BuildableTypeConstraint for IntFelt { + fn build() -> crate::Type { + crate::Type::Felt + } +} +impl BuildableTypeConstraint for UInt8 { + fn build() -> crate::Type { + crate::Type::U8 + } +} +impl BuildableTypeConstraint for SInt8 { + fn build() -> crate::Type { + crate::Type::I8 + } +} +impl BuildableTypeConstraint for UInt16 { + fn build() -> crate::Type { + crate::Type::U16 + } +} +impl BuildableTypeConstraint for SInt16 { + fn build() -> crate::Type { + crate::Type::I16 + } +} +impl BuildableTypeConstraint for UInt32 { + fn build() -> crate::Type { + crate::Type::U32 + } +} +impl BuildableTypeConstraint for SInt32 { + fn build() -> crate::Type { + crate::Type::I32 + } +} +impl BuildableTypeConstraint for UInt64 { + fn build() -> crate::Type { + crate::Type::U64 + } +} +impl BuildableTypeConstraint for SInt64 { + fn build() -> crate::Type { + crate::Type::I64 + } +} + +/// Represents a fixed-width integer of `N` bits +pub struct SizedInt(core::marker::PhantomData<[(); N]>); +impl Copy for SizedInt {} +impl Clone for SizedInt { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for SizedInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display 
for SizedInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{N}-bit integral type") + } +} +impl TypeConstraint for SizedInt { + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + ty.is_integer() + } +} +impl BuildableTypeConstraint for SizedInt<8> { + fn build() -> crate::Type { + crate::Type::I8 + } +} +impl BuildableTypeConstraint for SizedInt<16> { + fn build() -> crate::Type { + crate::Type::I16 + } +} +impl BuildableTypeConstraint for SizedInt<32> { + fn build() -> crate::Type { + crate::Type::I32 + } +} +impl BuildableTypeConstraint for SizedInt<64> { + fn build() -> crate::Type { + crate::Type::I64 + } +} + +/// A type constraint for pointer values +pub struct PointerOf(core::marker::PhantomData); +impl Copy for PointerOf {} +impl Clone for PointerOf { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for PointerOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display for PointerOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let pointee = ::description(); + write!(f, "a pointer to {pointee}") + } +} +impl TypeConstraint for PointerOf { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + ty.pointee().is_some_and(|pointee| ::matches(pointee)) + } +} +impl BuildableTypeConstraint for PointerOf { + fn build() -> crate::Type { + let pointee = Box::new(::build()); + crate::Type::Ptr(pointee) + } +} + +/// A type constraint for array values +pub struct AnyArrayOf(core::marker::PhantomData); +impl Copy for AnyArrayOf {} +impl Clone for AnyArrayOf { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for AnyArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display for AnyArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let element = ::description(); + write!(f, "an array of {element}") + } +} +impl TypeConstraint for AnyArrayOf { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + match ty { + crate::Type::Array(ref elem, _) => ::matches(elem), + _ => false, + } + } +} + +/// A type constraint for array values +pub struct ArrayOf(core::marker::PhantomData<[T; N]>); +impl Copy for ArrayOf {} +impl Clone for ArrayOf { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for ArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display for ArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let element = ::description(); + write!(f, "an array of {N} {element}") + } +} +impl TypeConstraint for ArrayOf { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + match ty { + crate::Type::Array(ref elem, arity) if *arity == N => { + ::matches(elem) + } + _ => false, + } + } +} +impl BuildableTypeConstraint for ArrayOf { + fn build() -> crate::Type { + let element = Box::new(::build()); + crate::Type::Array(element, N) + } +} + +/// Represents a conjunction of two constraints as a concrete value +pub struct And { + _left: core::marker::PhantomData, + _right: core::marker::PhantomData, +} +impl Copy for And {} +impl 
Clone for And { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for And { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl TypeConstraint for And { + fn description() -> impl fmt::Display { + struct Both { + left: L, + right: R, + } + impl fmt::Display for Both { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "both {} and {}", &self.left, &self.right) + } + } + let left = ::description(); + let right = ::description(); + Both { left, right } + } + + #[inline] + fn matches(ty: &crate::Type) -> bool { + ::matches(ty) && ::matches(ty) + } +} + +/// Represents a disjunction of two constraints as a concrete value +pub struct Or { + _left: core::marker::PhantomData, + _right: core::marker::PhantomData, +} +impl Copy for Or {} +impl Clone for Or { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for Or { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl TypeConstraint for Or { + fn description() -> impl fmt::Display { + struct Either { + left: L, + right: R, + } + impl fmt::Display for Either { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "either {} or {}", &self.left, &self.right) + } + } + let left = ::description(); + let right = ::description(); + Either { left, right } + } + + #[inline] + fn matches(ty: &crate::Type) -> bool { + ::matches(ty) || ::matches(ty) + } +} diff --git a/hir2/src/ir/usable.rs b/hir2/src/ir/usable.rs index ce0ff554a..315c32e10 100644 --- a/hir2/src/ir/usable.rs +++ b/hir2/src/ir/usable.rs @@ -25,19 +25,33 @@ pub trait Usable { /// The type associated with each unique use, e.g. `OpOperand` type Use; - /// Returns true if this definition is used - fn is_used(&self) -> bool; /// Get a list of uses of this definition fn uses(&self) -> &EntityList; /// Get a mutable list of uses of this definition fn uses_mut(&mut self) -> &mut EntityList; + + /// Returns true if this definition is used + #[inline] + fn is_used(&self) -> bool { + !self.uses().is_empty() + } /// Get an iterator over the uses of this definition - fn iter_uses(&self) -> EntityIter<'_, Self::Use>; + #[inline] + fn iter_uses(&self) -> EntityIter<'_, Self::Use> { + self.uses().iter() + } /// Get a cursor positioned on the first use of this definition, or the null cursor if unused. - fn first_use(&self) -> EntityCursor<'_, Self::Use>; + fn first_use(&self) -> EntityCursor<'_, Self::Use> { + self.uses().front() + } /// Get a mutable cursor positioned on the first use of this definition, or the null cursor if /// unused. - fn first_use_mut(&mut self) -> EntityCursorMut<'_, Self::Use>; + #[inline] + fn first_use_mut(&mut self) -> EntityCursorMut<'_, Self::Use> { + self.uses_mut().front_mut() + } /// Add `user` to the set of uses of this definition - fn insert_use(&mut self, user: UnsafeIntrusiveEntityRef); + fn insert_use(&mut self, user: UnsafeIntrusiveEntityRef) { + self.uses_mut().push_back(user); + } } diff --git a/hir2/src/ir/value.rs b/hir2/src/ir/value.rs index c8ebe55f1..e812c111a 100644 --- a/hir2/src/ir/value.rs +++ b/hir2/src/ir/value.rs @@ -2,6 +2,7 @@ use core::fmt; use super::*; +/// A unique identifier for a [Value] in the IR #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct ValueId(u32); @@ -31,11 +32,20 @@ impl fmt::Display for ValueId { } } +/// Represents an SSA value in the IR. 
+/// +/// The data underlying a [Value] represents a _definition_, and thus implements [Usable]. The users +/// of a [Value] are operands (see [OpOperandImpl]). Operands are associated with an operation. Thus +/// the graph formed of the edges between values and operations via operands forms the data-flow +/// graph of the program. pub trait Value: Entity + Spanned + Usable + fmt::Debug { + /// Get the type of this value fn ty(&self) -> &Type; + /// Set the type of this value fn set_type(&mut self, ty: Type); } +/// Generates the boilerplate for a concrete [Value] type. macro_rules! value_impl { ( $(#[$outer:meta])* @@ -104,11 +114,6 @@ macro_rules! value_impl { impl Usable for $ValueKind { type Use = OpOperandImpl; - #[inline] - fn is_used(&self) -> bool { - !self.uses.is_empty() - } - #[inline(always)] fn uses(&self) -> &OpOperandList { &self.uses @@ -118,25 +123,6 @@ macro_rules! value_impl { fn uses_mut(&mut self) -> &mut OpOperandList { &mut self.uses } - - #[inline] - fn iter_uses(&self) -> OpOperandIter<'_> { - self.uses.iter() - } - - #[inline] - fn first_use(&self) -> OpOperandCursor<'_> { - self.uses.front() - } - - #[inline] - fn first_use_mut(&mut self) -> OpOperandCursorMut<'_> { - self.uses.front_mut() - } - - fn insert_use(&mut self, user: OpOperand) { - self.uses.push_back(user); - } } impl fmt::Debug for $ValueKind { @@ -159,8 +145,11 @@ macro_rules! value_impl { } } +/// A pointer to a [Value] pub type ValueRef = UnsafeEntityRef; +/// A pointer to a [BlockArgument] pub type BlockArgumentRef = UnsafeEntityRef; +/// A pointer to a [OpResult] pub type OpResultRef = UnsafeEntityRef; value_impl!( @@ -180,74 +169,25 @@ value_impl!( ); impl BlockArgument { + /// Get the [Block] to which this [BlockArgument] belongs pub fn owner(&self) -> BlockRef { self.owner.clone() } + /// Get the index of this argument in the argument list of the owning [Block] pub fn index(&self) -> usize { self.index as usize } } impl OpResult { + /// Get the [Operation] to which this [OpResult] belongs pub fn owner(&self) -> OperationRef { self.owner.clone() } + /// Get the index of this result in the result list of the owning [Operation] pub fn index(&self) -> usize { self.index as usize } } - -pub type OpOperand = UnsafeIntrusiveEntityRef; -pub type OpOperandList = EntityList; -pub type OpOperandIter<'a> = EntityIter<'a, OpOperandImpl>; -pub type OpOperandCursor<'a> = EntityCursor<'a, OpOperandImpl>; -pub type OpOperandCursorMut<'a> = EntityCursorMut<'a, OpOperandImpl>; - -/// An [OpOperand] represents a use of a [Value] by an [Operation] -pub struct OpOperandImpl { - /// The operand value - pub value: ValueRef, - /// The owner of this operand, i.e. 
the operation it is an operand of - pub owner: OperationRef, - /// The index of this operand in the operand list of an operation - pub index: u8, -} -impl OpOperandImpl { - #[inline] - pub fn new(value: ValueRef, owner: OperationRef, index: u8) -> Self { - Self { - value, - owner, - index, - } - } - - pub fn value(&self) -> EntityRef<'_, dyn Value> { - self.value.borrow() - } -} -impl fmt::Debug for OpOperandImpl { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - #[derive(Debug)] - #[allow(unused)] - struct ValueInfo<'a> { - id: ValueId, - ty: &'a Type, - } - - let value = self.value.borrow(); - let id = value.id(); - let ty = value.ty(); - f.debug_struct("OpOperand") - .field("index", &self.index) - .field("value", &ValueInfo { id, ty }) - .finish_non_exhaustive() - } -} - -pub enum OpOperandValue { - Value(ValueRef), - Immediate(Immediate), -} diff --git a/hir2/src/ir/verifier.rs b/hir2/src/ir/verifier.rs index 9fbaad483..e3bcb2599 100644 --- a/hir2/src/ir/verifier.rs +++ b/hir2/src/ir/verifier.rs @@ -34,8 +34,8 @@ pub trait OpVerifier { /// /// The way this works is as follows: /// -/// * We `impl Verify for T where T: Op` for every trait `Trait` with validation rules -/// * A blanket impl of [HasVerifier] exists for all `T: Verify`. This is a market trait used +/// * We `impl Verify for T where T: Op` for every trait `Trait` with validation rules. +/// * A blanket impl of [HasVerifier] exists for all `T: Verify`. This is a marker trait used /// in conjunction with specialization. See the trait docs for more details on its purpose. /// * The [Verifier] trait provides a default vacuous impl for all `Trait` and `T` pairs. However, /// we also provided a specialized [Verifier] impl for all `T: Verify` using the @@ -94,8 +94,9 @@ pub trait Verify { unsafe trait HasVerifier: Verify {} // While at first glance, it appears we would be using this to specialize on the fact that a type -// _has_ a verifier, we're actually using this to specialize on the _absence_ of a verifier. See -// `Verifier` for more information. +// _has_ a verifier, which is strictly-speaking true, the actual goal we're aiming to acheive is +// to be able to identify the _absence_ of a verifier, so that we can eliminate the boilerplate for +// verifying that trait. See `Verifier` for more information. unsafe impl HasVerifier for T where T: Verify {} /// The `Verifier` trait is used to derive a verifier for a given trait and concrete type. @@ -107,8 +108,10 @@ unsafe impl HasVerifier for T where T: Verify {} /// for types which implement `Verify`. /// /// We go a step further and actually set things up so that `rustc` can eliminate all of the dead -/// code when verification is vacuous. See the `trait_verifier` function for details on how that -/// is used in practice. +/// code when verification is vacuous. This is done by using const eval in the hidden type generated +/// for an [Op] impls [OpVerifier] implementation, which will wrap verification in a const-evaluated +/// check of the `VACUOUS` associated const. It can also be used directly, but the general idea +/// behind all of this is that we don't need to directly touch any of this, it's all generated. /// /// NOTE: Because this trait provides a default blanket impl for all `T`, you should avoid bringing /// it into scope unless absolutely needed. 
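The `Verify`/`Verifier`/`HasVerifier` machinery described above hinges on specialization. The
following is a minimal, self-contained model of that pattern written for this document only;
none of these names or signatures are claimed to match the crate's, and it requires a nightly
toolchain with the `specialization` feature (which the crate already enables):

#![feature(specialization)]
#![allow(incomplete_features)]

/// Types that opt in to verification.
trait Verify {
    fn verify(&self) -> Result<(), String>;
}

/// Verification entry point available for every type.
trait Verifier {
    fn maybe_verify(&self) -> Result<(), String>;
}

/// Blanket, vacuous implementation: a type with no `Verify` impl has nothing to check, so the
/// call compiles down to a no-op.
impl<T> Verifier for T {
    default fn maybe_verify(&self) -> Result<(), String> {
        Ok(())
    }
}

/// Specialized implementation: types that do implement `Verify` have their checks run.
impl<T: Verify> Verifier for T {
    fn maybe_verify(&self) -> Result<(), String> {
        self.verify()
    }
}

struct Unchecked;
struct Checked;
impl Verify for Checked {
    fn verify(&self) -> Result<(), String> {
        Err("always fails, purely for demonstration".to_string())
    }
}

fn main() {
    assert!(Unchecked.maybe_verify().is_ok()); // falls through to the vacuous blanket impl
    assert!(Checked.maybe_verify().is_err()); // dispatches to the specialized impl
}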
It is virtually always preferred to explicitly invoke diff --git a/hir2/src/ir/visit.rs b/hir2/src/ir/visit.rs new file mode 100644 index 000000000..b0b09656a --- /dev/null +++ b/hir2/src/ir/visit.rs @@ -0,0 +1,162 @@ +use alloc::collections::VecDeque; +pub use core::ops::ControlFlow; + +use crate::{BlockRef, Op, Operation, OperationRef, Symbol}; + +/// A generic trait that describes visitors for all kinds +pub trait Visitor { + /// The type of output produced by visiting an item. + type Output; + + /// The function which is applied to each `T` as it is visited. + fn visit(&mut self, current: &T) -> ControlFlow; +} + +/// We can automatically convert any closure of appropriate type to a `Visitor` +impl Visitor for F +where + F: FnMut(&T) -> ControlFlow, +{ + type Output = U; + + #[inline] + fn visit(&mut self, op: &T) -> ControlFlow { + self(op) + } +} + +/// Represents a visitor over [Operation] +pub trait OperationVisitor: Visitor {} +impl OperationVisitor for V where V: Visitor {} + +/// Represents a visitor over [Op] of type `T` +pub trait OpVisitor: Visitor {} +impl OpVisitor for V where V: Visitor {} + +/// Represents a visitor over [Symbol] +pub trait SymbolVisitor: Visitor {} +impl SymbolVisitor for V where V: Visitor {} + +/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation]. +/// +/// It traverses the objects reachable from the root as follows: +/// +/// * The root operation is visited first +/// * Then for each region of the root, the entry block is visited top to bottom, enqueing any nested +/// blocks of those operations to be visited after all blocks of region have been visited. When the +/// entry block has been visited, the process is repeated for the remaining blocks of the region. +/// * When all regions of the root have been visited, and no more blocks remain in the queue, the +/// traversal is complete +/// +/// This traversal is _not_ in control flow order, _or_ data flow order, so you should not rely on +/// the order in which operations are visited for your [Visitor] implementation. 
+pub struct Searcher { + visitor: V, + queue: VecDeque, + current: Option, + started: bool, + _marker: core::marker::PhantomData, +} +impl> Searcher { + pub fn new(root: OperationRef, visitor: V) -> Self { + Self { + visitor, + queue: VecDeque::default(), + current: Some(root), + started: false, + _marker: core::marker::PhantomData, + } + } + + #[inline] + fn next(&mut self) -> Option { + visit_next(&mut self.started, &mut self.current, &mut self.queue) + } +} + +impl Searcher { + pub fn visit(&mut self) -> ControlFlow<>::Output> { + while let Some(op) = self.next() { + let op = op.borrow(); + self.visitor.visit(&op)?; + } + + ControlFlow::Continue(()) + } +} + +impl> Searcher { + pub fn visit(&mut self) -> ControlFlow<>::Output> { + while let Some(op) = self.next() { + let op = op.borrow(); + if let Some(op) = op.downcast_ref::() { + self.visitor.visit(op)?; + } + } + + ControlFlow::Continue(()) + } +} + +impl Searcher { + pub fn visit(&mut self) -> ControlFlow<>::Output> { + while let Some(op) = self.next() { + let op = op.borrow(); + if let Some(op) = op.as_symbol() { + self.visitor.visit(op)?; + } + } + + ControlFlow::Continue(()) + } +} + +/// Outlined implementation of the traversal performed by `Searcher` +#[inline(never)] +fn visit_next( + started: &mut bool, + current: &mut Option, + queue: &mut VecDeque, +) -> Option { + if !*started { + *started = true; + let curr = current.take()?; + // When just starting, we're at the root, so we descend into the operation, rather + // than visiting its next sibling. + { + let op = curr.borrow(); + for region in op.regions().iter() { + let mut cursor = region.body().front(); + if current.is_none() { + let entry = cursor.as_pointer().expect("invalid region: has no entry block"); + let entry = entry.borrow(); + let next = entry.body().front().as_pointer(); + *current = next; + cursor.move_next(); + } + while let Some(block) = cursor.as_pointer() { + queue.push_back(block); + cursor.move_next(); + } + } + } + return Some(curr); + } + + // Here, we've already visited the root operation, so one of the following is true: + // + // * `current` is `None`, so pop the next block from the queue, if there are no more blocks, + // then we're done visiting and can return `None`. If there is a block, then we set + // `current` to the first operation in that block, and retry + // * `current` is `Some`, so obtain the next value of `current` by obtaining the next + // sibling operation of the current operation. 
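    // Pop queued blocks until one actually contains an operation; a block with an empty body
    // leaves `current` as `None` and is simply skipped.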
+ while current.is_none() { + let block = queue.pop_front()?; + let block = block.borrow(); + *current = block.body().front().as_pointer(); + } + + let next = current.as_ref().and_then(|curr| curr.next()); + + core::mem::replace(current, next) +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index b5a574538..553f0fc88 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -8,6 +8,8 @@ #![feature(specialization)] #![feature(rustc_attrs)] #![feature(debug_closure_helpers)] +#![feature(trait_alias)] +#![feature(is_none_or)] #![allow(incomplete_features)] #![allow(internal_features)] @@ -16,6 +18,10 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; +pub use compact_str::{ + CompactString as SmallStr, CompactStringExt as SmallStrExt, ToCompactString as ToSmallStr, +}; + mod attributes; pub mod demangle; pub mod derive; @@ -24,3 +30,18 @@ pub mod formatter; mod ir; pub use self::{attributes::*, ir::*}; + +// TODO(pauls): The following is a rough list of what needs to be implemented for the IR +// refactoring to be complete and usable in place of the old IR (some are optional): +// +// * constants and constant-like ops +// * global variables and global ops +// * Builders (i.e. component builder, interface builder, module builder, function builder, last is most important) +// NOTE: The underlying builder infra is basically done, so layering on the high-level builders is pretty simple +// * canonicalization (optional) +// * visitors (partially complete, need CFG and DFG walkers as well though, largely variations on the existing infra) +// * pattern matching/rewrites (needed for legalization/conversion) +// * dataflow analysis framework (required to replace old analyses) +// * linking/global symbol resolution (required to replace old linker, partially implemented via symbols/symbol tables already) +// * legalization/dialect conversion (required to convert between unstructured and structured control flow dialects at minimum) +// * lowering From 0057afcdeacbed02d7b166140ba3ec0371824fee Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:35:46 -0400 Subject: [PATCH 06/31] feat: implement #[operation] macro --- Cargo.lock | 1 + hir-macros/Cargo.toml | 1 + hir-macros/src/lib.rs | 81 +- hir-macros/src/operation.rs | 2665 +++++++++++++++++++++++++++++++++++ 4 files changed, 2729 insertions(+), 19 deletions(-) create mode 100644 hir-macros/src/operation.rs diff --git a/Cargo.lock b/Cargo.lock index 052381346..749bed599 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3374,6 +3374,7 @@ name = "midenc-hir-macros" version = "0.0.6" dependencies = [ "Inflector", + "darling", "proc-macro2", "quote", "syn 2.0.77", diff --git a/hir-macros/Cargo.toml b/hir-macros/Cargo.toml index c65f18312..c9311110c 100644 --- a/hir-macros/Cargo.toml +++ b/hir-macros/Cargo.toml @@ -17,6 +17,7 @@ proc-macro = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +darling = { version = "0.20", features = ["diagnostics"] } Inflector.workspace = true proc-macro2 = "1.0" quote = "1.0" diff --git a/hir-macros/src/lib.rs b/hir-macros/src/lib.rs index db4b85f59..f6ca222c4 100644 --- a/hir-macros/src/lib.rs +++ b/hir-macros/src/lib.rs @@ -1,10 +1,10 @@ extern crate proc_macro; -//mod op; +mod operation; mod spanned; use inflector::cases::kebabcase::to_kebab_case; -use quote::quote; +use quote::{format_ident, quote}; use syn::{parse_macro_input, spanned::Spanned, Data, DeriveInput, Error, Ident, Token}; #[proc_macro_derive(Spanned, 
attributes(span))] @@ -26,27 +26,70 @@ pub fn derive_spanned(input: proc_macro::TokenStream) -> proc_macro::TokenStream } } -/// #[derive(Op)] -/// #[op(name = "select", interfaces(BranchOpInterface))] -/// pub struct Select { -/// #[operation] -/// op: Operation, +/// #[operation( +/// dialect = HirDialect, +/// traits(Terminator), +/// implements(BranchOpInterface), +/// )] +/// pub struct Switch { /// #[operand] -/// selector: OpOperand, +/// selector: UInt32, +/// #[successors(keyed)] +/// cases: SwitchArm, +/// #[successor] +/// fallback: Successor, +/// } /// +/// pub struct Call { +/// #[attr] +/// callee: Symbol, +/// #[operands] +/// arguments: Vec, +/// #[results] +/// results: Vec, /// } -/* -#[proc_macro_derive(Op, attributes(op, operation, operand, result, successor, region, interfaces))] -pub fn derive_op(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - // Parse into syntax tree - let derive = parse_macro_input!(input as DeriveInput); - let op = match op::Op::from_derive_input(derive) { - Ok(op) => op, - Err(err) => err.to_compile_error().into(), +/// +/// #[operation] +/// pub struct If { +/// #[operand] +/// condition: Bool, +/// #[region] +/// then_region: RegionRef, +/// #[region] +/// else_region: RegionRef, +/// } +#[proc_macro_attribute] +pub fn operation( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = proc_macro2::TokenStream::from(attr); + let mut input = syn::parse_macro_input!(item as syn::ItemStruct); + let span = input.span(); + + // Reconstruct the input so we can treat this like a derive macro + // + // We can't _actually_ use derive, because we need to modify the item itself. + input.attrs.push(syn::Attribute { + pound_token: syn::token::Pound(span), + style: syn::AttrStyle::Outer, + bracket_token: syn::token::Bracket(span), + meta: syn::Meta::List(syn::MetaList { + path: syn::parse_str("operation").unwrap(), + delimiter: syn::MacroDelimiter::Paren(syn::token::Paren(span)), + tokens: attr, + }), + }); + + let input = syn::parse_quote! { + #input }; - quote!(#op).into() + + match operation::derive_operation(input) { + Ok(token_stream) => proc_macro::TokenStream::from(token_stream), + Err(err) => err.write_errors().into(), + } } - */ #[proc_macro_derive(PassInfo)] pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -59,7 +102,7 @@ pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStrea let pass_name = to_kebab_case(&name); let pass_name_lit = syn::Lit::Str(syn::LitStr::new(&pass_name, id.span())); - let doc_ident = syn::Ident::new("doc", derive_span); + let doc_ident = format_ident!("doc", span = derive_span); let docs = derive_input .attrs .iter() diff --git a/hir-macros/src/operation.rs b/hir-macros/src/operation.rs new file mode 100644 index 000000000..08c46b5e8 --- /dev/null +++ b/hir-macros/src/operation.rs @@ -0,0 +1,2665 @@ +use std::rc::Rc; + +use darling::{ + util::{Flag, SpannedValue}, + Error, FromDeriveInput, FromField, FromMeta, +}; +use inflector::Inflector; +use quote::{format_ident, quote, ToTokens}; +use syn::{spanned::Spanned, Ident, Token}; + +pub fn derive_operation(input: syn::DeriveInput) -> darling::Result { + let op = OpDefinition::from_derive_input(&input)?; + + Ok(op.into_token_stream()) +} + +/// This struct represents the fully parsed and prepared definition of an operation, along with all +/// of its associated items, trait impls, etc. 
+pub struct OpDefinition { + /// The span of the original item decorated with `#[operation]` + span: proc_macro2::Span, + /// The name of the dialect type corresponding to the dialect this op belongs to + dialect: Ident, + /// The type name of the concrete `Op` implementation, i.e. the item with `#[operation]` on it + name: Ident, + /// The name of the operation in the textual form of the IR, e.g. `Add` would be `add`. + opcode: Ident, + /// The set of paths corresponding to the op traits we need to generate impls for + traits: darling::util::PathList, + /// The set of paths corresponding to the op traits manually implemented by this op + implements: darling::util::PathList, + /// The named regions declared for this op + regions: Vec, + /// The named attributes declared for this op + attrs: Vec, + /// The named operands, and operand groups, declared for this op + /// + /// Sequential individually named operands are collected into an "unnamed" operand group, i.e. + /// the group is not named, only the individual operands. Conversely, each "named" operand group + /// can refer to the group by name, but not the individual operands. + operands: Vec, + /// The named results of this operation + /// + /// An operation can have no results, one or more individually named results, or a single named + /// result group, but not a combination. + results: Option, + /// The named successors, and successor groups, declared for this op. + /// + /// This is represented almost identically to `operands`, except we also support successor + /// groups with "keyed" items represented by an implementation of the `KeyedSuccessor` trait. + /// Keyed successor groups are handled a bit differently than "normal" successor groups in terms + /// of the types expected by the op builder for this type. + successors: Vec, + symbols: Vec, + /// The struct definition + op: syn::ItemStruct, + /// The implementation of `{Op}Builder` for this op. + op_builder_impl: OpBuilderImpl, + /// The implementation of `OpVerifier` for this op. 
+ op_verifier_impl: OpVerifierImpl, +} +impl OpDefinition { + /// Initialize an [OpDefinition] from the parsed [Operation] received as input + fn from_operation(span: proc_macro2::Span, op: &mut Operation) -> darling::Result { + let dialect = op.dialect.clone(); + let name = op.ident.clone(); + let opcode = op.name.clone().unwrap_or_else(|| { + let name = name.to_string().to_snake_case(); + let name = name.strip_suffix("Op").unwrap_or(name.as_str()); + format_ident!("{name}", span = name.span()) + }); + let traits = core::mem::take(&mut op.traits); + let implements = core::mem::take(&mut op.implements); + + let fields = core::mem::replace( + &mut op.data, + darling::ast::Data::Struct(darling::ast::Fields::new( + darling::ast::Style::Struct, + vec![], + )), + ) + .take_struct() + .unwrap(); + + let mut named_fields = syn::punctuated::Punctuated::::new(); + // Add the `op` field (which holds the underlying Operation) + named_fields.push(syn::Field { + attrs: vec![], + vis: syn::Visibility::Inherited, + mutability: syn::FieldMutability::None, + ident: Some(format_ident!("op")), + colon_token: Some(syn::token::Colon(span)), + ty: make_type("::midenc_hir2::Operation"), + }); + + let op = syn::ItemStruct { + attrs: core::mem::take(&mut op.attrs), + vis: op.vis.clone(), + struct_token: syn::token::Struct(span), + ident: name.clone(), + generics: core::mem::take(&mut op.generics), + fields: syn::Fields::Named(syn::FieldsNamed { + brace_token: syn::token::Brace(span), + named: named_fields, + }), + semi_token: None, + }; + + let op_builder_impl = OpBuilderImpl::empty(name.clone()); + let op_verifier_impl = + OpVerifierImpl::new(name.clone(), traits.clone(), implements.clone()); + + let mut this = Self { + span, + dialect, + name, + opcode, + traits, + implements, + regions: vec![], + attrs: vec![], + operands: vec![], + results: None, + successors: vec![], + symbols: vec![], + op, + op_builder_impl, + op_verifier_impl, + }; + + this.hydrate(fields)?; + + Ok(this) + } + + fn hydrate(&mut self, fields: darling::ast::Fields) -> darling::Result<()> { + let named_fields = match &mut self.op.fields { + syn::Fields::Named(syn::FieldsNamed { ref mut named, .. }) => named, + _ => unreachable!(), + }; + let mut create_params = vec![]; + let (_, mut fields) = fields.split(); + + // Compute the absolute ordering of op parameters as follows: + // + // * By default, the ordering is implied by the order of field declarations in the struct + // * A field can be decorated with #[order(N)], where `N` is an absolute index + // * If all fields have an explicit order, then the sort following that order is used + // * If a mix of fields have explicit ordering, so as to acheive a particular struct layout, + // then the implicit order given to a field ensures that it appears after the highest + // ordered field which comes before it in the struct. For example, if I have the following + // pseudo-struct definition: `{ #[order(2)] a, b, #[order(1)] c, d }`, then the actual + // order of the parameters corresponding to those fields will be `c`, `a`, `b`, `d`. This + // is due to the fact that a.) `b` is assigned an index of `3` because it is the next + // available index following `2`, which was assigned to `a` before it in the struct, and + // 2.) `d` is assigned an index of `4`, as it is the next highest available index after + // `2`, which is the highest explicitly ordered field that is defined before it in the + // struct. 
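        // Tracing the pseudo-struct example above through the ordering code below: `a` keeps
        // its explicit index 2, `b` is assigned 3 (the next free index after 2), `c` keeps its
        // explicit index 1, and `d` is assigned 4, so the final sort yields `c`, `a`, `b`, `d`.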
+ let mut assigned_highwater = 0; + let mut highwater = 0; + let mut claimed_indices = fields.iter().filter_map(|f| f.attrs.order).collect::>(); + claimed_indices.sort(); + claimed_indices.dedup(); + for field in fields.iter_mut() { + match field.attrs.order { + // If this order precedes a previous #[order] field, skip it + Some(order) if highwater > order => continue, + Some(order) => { + // Move high water mark to `order` + highwater = order; + } + None => { + // Find the next unused index > `highwater` && `assigned_highwater` + assigned_highwater = core::cmp::max(assigned_highwater, highwater); + let mut next_index = assigned_highwater + 1; + while claimed_indices.contains(&next_index) { + next_index += 1; + } + assigned_highwater = next_index; + field.attrs.order = Some(next_index); + } + } + } + fields.sort_by_key(|field| field.attrs.order); + + for field in fields { + let field_name = field.ident.clone().unwrap(); + let field_span = field_name.span(); + let field_ty = field.ty.clone(); + + let op_field_ty = field.attrs.pseudo_type(); + match op_field_ty.as_deref() { + // Forwarded field + None => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::CustomField(field_name.clone(), field_ty), + r#default: field.attrs.default.is_present(), + }); + named_fields.push(syn::Field { + attrs: field.attrs.forwarded, + vis: field.vis, + mutability: syn::FieldMutability::None, + ident: Some(field_name), + colon_token: Some(syn::token::Colon(field_span)), + ty: field.ty, + }); + continue; + } + Some(OperationFieldType::Attr) => { + let attr = OpAttribute { + name: field_name, + ty: field_ty, + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Attr(attr.clone()), + r#default: field.attrs.default.is_present(), + }); + self.attrs.push(attr); + } + Some(OperationFieldType::Operand) => { + let operand = Operand { + name: field_name.clone(), + constraint: field_ty, + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Operand(operand.clone()), + r#default: field.attrs.default.is_present(), + }); + match self.operands.last_mut() { + None => { + self.operands.push(OpOperandGroup::Unnamed(vec![operand])); + } + Some(OpOperandGroup::Unnamed(ref mut operands)) => { + operands.push(operand); + } + Some(OpOperandGroup::Named(..)) => { + // Start a new group + self.operands.push(OpOperandGroup::Unnamed(vec![operand])); + } + } + } + Some(OperationFieldType::Operands) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::OperandGroup( + field_name.clone(), + field_ty.clone(), + ), + r#default: field.attrs.default.is_present(), + }); + self.operands.push(OpOperandGroup::Named(field_name, field_ty)); + } + Some(OperationFieldType::Result) => { + let result = OpResult { + name: field_name.clone(), + constraint: field_ty, + }; + match self.results.as_mut() { + None => { + self.results = Some(OpResultGroup::Unnamed(vec![result])); + } + Some(OpResultGroup::Unnamed(ref mut results)) => { + results.push(result); + } + Some(OpResultGroup::Named(..)) => { + return Err(Error::custom("#[result] and #[results] cannot be mixed") + .with_span(&field_name)); + } + } + } + Some(OperationFieldType::Results) => match self.results.as_mut() { + None => { + self.results = Some(OpResultGroup::Named(field_name, field_ty)); + } + Some(OpResultGroup::Unnamed(_)) => { + return Err(Error::custom("#[result] and #[results] cannot be mixed") + .with_span(&field_name)); + } + Some(OpResultGroup::Named(..)) => { + return Err(Error::custom("#[results] may only appear 
on a single field") + .with_span(&field_name)); + } + }, + Some(OperationFieldType::Region) => { + self.regions.push(field_name); + } + Some(OperationFieldType::Successor) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Successor(field_name.clone()), + r#default: field.attrs.default.is_present(), + }); + match self.successors.last_mut() { + None => { + self.successors.push(SuccessorGroup::Unnamed(vec![field_name])); + } + Some(SuccessorGroup::Unnamed(ref mut ids)) => { + ids.push(field_name); + } + Some(SuccessorGroup::Named(_) | SuccessorGroup::Keyed(..)) => { + // Start a new group + self.successors.push(SuccessorGroup::Unnamed(vec![field_name])); + } + } + } + Some(OperationFieldType::Successors(SuccessorsType::Default)) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::SuccessorGroupNamed(field_name.clone()), + r#default: field.attrs.default.is_present(), + }); + self.successors.push(SuccessorGroup::Named(field_name)); + } + Some(OperationFieldType::Successors(SuccessorsType::Keyed)) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::SuccessorGroupKeyed( + field_name.clone(), + field_ty.clone(), + ), + r#default: field.attrs.default.is_present(), + }); + self.successors.push(SuccessorGroup::Keyed(field_name, field_ty)); + } + Some(OperationFieldType::Symbol(None)) => { + let symbol = Symbol { + name: field_name, + ty: SymbolType::Concrete(field_ty), + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Symbol(symbol.clone()), + r#default: field.attrs.default.is_present(), + }); + self.symbols.push(symbol); + } + Some(OperationFieldType::Symbol(Some(ty))) => { + let symbol = Symbol { + name: field_name, + ty: ty.clone(), + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Symbol(symbol.clone()), + r#default: field.attrs.default.is_present(), + }); + self.symbols.push(symbol); + } + } + } + + self.op_builder_impl.set_create_params(&self.op.generics, create_params); + + Ok(()) + } +} +impl FromDeriveInput for OpDefinition { + fn from_derive_input(input: &syn::DeriveInput) -> darling::Result { + let span = input.span(); + let mut operation = Operation::from_derive_input(input)?; + Self::from_operation(span, &mut operation) + } +} + +struct OpCreateFn<'a> { + op: &'a OpDefinition, + generics: syn::Generics, +} +impl<'a> OpCreateFn<'a> { + pub fn new(op: &'a OpDefinition) -> Self { + // Op::create generic parameters + let generics = syn::Generics { + lt_token: Some(syn::token::Lt(op.span)), + params: syn::punctuated::Punctuated::from_iter( + [syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap()] + .into_iter() + .chain(op.op_builder_impl.buildable_op_impl.generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(op.span)), + where_clause: op.op_builder_impl.buildable_op_impl.generics.where_clause.clone(), + }; + + Self { op, generics } + } +} + +struct WithAttrs<'a>(&'a OpDefinition); +impl quote::ToTokens for WithAttrs<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::Attr(OpAttribute { name, .. }) = ¶m.param_ty { + let field_name = syn::Lit::Str(syn::LitStr::new(&format!("{name}"), name.span())); + tokens.extend(quote! 
{ + op_builder.with_attr(#field_name, #name); + }); + } + } + } +} + +struct WithSymbols<'a>(&'a OpDefinition); +impl quote::ToTokens for WithSymbols<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::Symbol(Symbol { name, ty }) = ¶m.param_ty { + let field_name = syn::Lit::Str(syn::LitStr::new(&format!("{name}"), name.span())); + match ty { + SymbolType::Any | SymbolType::Concrete(_) | SymbolType::Trait(_) => { + tokens.extend(quote! { + op_builder.with_symbol(#field_name, #name); + }); + } + SymbolType::Callable => { + tokens.extend(quote! { + op_builder.with_callable_symbol(#field_name, #name); + }); + } + } + } + } + } +} + +struct WithOperands<'a>(&'a OpDefinition); +impl quote::ToTokens for WithOperands<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, group) in self.0.operands.iter().enumerate() { + match group { + OpOperandGroup::Unnamed(operands) => { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + operands[0].name.span(), + )); + let operand_name = operands.iter().map(|o| &o.name).collect::>(); + let operand_constraint = operands.iter().map(|o| &o.constraint); + let constraint_violation = operands.iter().map(|o| { + syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for '{}'", &o.name), + o.name.span(), + )) + }); + tokens.extend(quote! { + #( + { + let value = #operand_name.borrow(); + let value_ty = value.ty(); + if !<#operand_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#operand_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operand") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + )* + op_builder.with_operands_in_group(#group_index, [#(#operand_name),*]); + }); + } + OpOperandGroup::Named(group_name, group_constraint) => { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + group_name.span(), + )); + let constraint_violation = syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for operand in '{group_name}'"), + group_name.span(), + )); + tokens.extend(quote! 
{ + let #group_name = #group_name.into_iter().collect::<::alloc::vec::Vec<_>>(); + for operand in #group_name.iter() { + let value = operand.borrow(); + let value_ty = value.ty(); + if !<#group_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#group_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operand") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + op_builder.with_operands_in_group(#group_index, #group_name); + }); + } + } + } + } +} + +struct InitializeCustomFields<'a>(&'a OpDefinition); +impl quote::ToTokens for InitializeCustomFields<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::CustomField(id, ..) = ¶m.param_ty { + tokens.extend(quote! { + core::ptr::addr_of_mut!((*__ptr).#id).write(#id); + }); + } + } + } +} + +struct WithResults<'a>(&'a OpDefinition); +impl quote::ToTokens for WithResults<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self.0.results.as_ref() { + None => (), + Some(OpResultGroup::Unnamed(results)) => { + let num_results = syn::Lit::Int(syn::LitInt::new( + &format!("{}usize", results.len()), + results[0].name.span(), + )); + tokens.extend(quote! { + op_builder.with_results(#num_results); + }); + } + // Named result groups can have an arbitrary number of results + Some(OpResultGroup::Named(..)) => (), + } + } +} + +struct WithSuccessors<'a>(&'a OpDefinition); +impl quote::ToTokens for WithSuccessors<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for group in self.0.successors.iter() { + match group { + SuccessorGroup::Unnamed(successors) => { + let successor_args = successors.iter().map(|s| format_ident!("{s}_args")); + tokens.extend(quote! { + op_builder.with_successors([ + #(( + #successors, + #successor_args.into_iter().collect::<::alloc::vec::Vec<_>>(), + ),)* + ]); + }); + } + SuccessorGroup::Named(name) => { + tokens.extend(quote! { + op_builder.with_successors(#name); + }); + } + SuccessorGroup::Keyed(name, _) => { + tokens.extend(quote! { + op_builder.with_keyed_successors(#name); + }); + } + } + } + } +} + +struct BuildOp<'a>(&'a OpDefinition); +impl quote::ToTokens for BuildOp<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self.0.results.as_ref() { + None => { + tokens.extend(quote! { + op_builder.build() + }); + } + Some(group) => { + let verify_result_constraints = match group { + OpResultGroup::Unnamed(results) => { + let verify_result = results.iter().map(|result| { + let result_name = &result.name; + let result_constraint = &result.constraint; + let constraint_violation = syn::Lit::Str(syn::LitStr::new(&format!("type constraint violation for result '{result_name}'"), result_name.span())); + quote! 
{ + { + let op_result = op.#result_name(); + let value_ty = op_result.ty(); + if !<#result_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#result_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(op_result.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + } + }); + quote! { + #( + #verify_result + )* + } + } + OpResultGroup::Named(name, constraint) => { + let constraint_violation = syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for result in '{name}'"), + name.span(), + )); + quote! { + { + let results = op.#name(); + for result in results.iter() { + let value = result.borrow(); + let value_ty = value.ty(); + if !<#constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + } + } + } + }; + + tokens.extend(quote! { + let op = op_builder.build()?; + + { + let op = op.borrow(); + #verify_result_constraints + } + + Ok(op) + }) + } + } + } +} + +impl quote::ToTokens for OpCreateFn<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let dialect = &self.op.dialect; + let (impl_generics, _, where_clause) = self.generics.split_for_impl(); + let param_names = + self.op.op_builder_impl.create_params.iter().flat_map(OpCreateParam::bindings); + let param_types = self + .op + .op_builder_impl + .create_params + .iter() + .flat_map(OpCreateParam::binding_types); + let traits = &self.op.traits; + let implements = &self.op.implements; + let initialize_custom_fields = InitializeCustomFields(self.op); + let with_symbols = WithSymbols(self.op); + let with_attrs = WithAttrs(self.op); + let with_operands = WithOperands(self.op); + let with_results = WithResults(self.op); + let with_regions = self.op.regions.iter().map(|_| { + quote! { + op_builder.create_region(); + } + }); + let with_successors = WithSuccessors(self.op); + let build_op = BuildOp(self.op); + + tokens.extend(quote! { + /// Manually construct a new [#op_ident] + /// + /// It is generally preferable to use [`::midenc_hir2::Builder::create`] instead. 
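+            //
+            // What follows is the generated constructor. It registers this op's `OperationName`
+            // (including trait metadata) with its dialect, allocates uninitialized storage for
+            // the concrete op through the context, writes the inner `Operation` and any
+            // user-defined fields in place, and then drives an `OperationBuilder` through
+            // attributes, symbols, operands, regions, successors, and results before the final
+            // `build()`/verification step.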
+ pub fn create #impl_generics( + builder: &mut B, + span: ::midenc_session::diagnostics::SourceSpan, + #( + #param_names: #param_types, + )* + ) -> Result<::midenc_hir2::UnsafeIntrusiveEntityRef, ::midenc_session::diagnostics::Report> + #where_clause + { + use ::midenc_hir2::{Builder, Op}; + let mut __this = { + let __operation_name = { + let context = builder.context(); + let dialect = context.get_or_register_dialect::<#dialect>(); + let opcode = ::name(); + dialect.get_or_register_op( + opcode, + |dialect_name, opcode| { + ::midenc_hir2::OperationName::new::( + dialect_name, + opcode, + [ + ::midenc_hir2::traits::TraitInfo::new::(), + ::midenc_hir2::traits::TraitInfo::new::(), + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + ] + ) + } + ) + }; + let __context = builder.context_rc(); + let mut __op = __context.alloc_uninit_tracked::(); + unsafe { + { + let mut __uninit = __op.borrow_mut(); + let __ptr = (*__uninit).as_mut_ptr(); + let __offset = core::mem::offset_of!(Self, op); + let __op_ptr = core::ptr::addr_of_mut!((*__ptr).op); + __op_ptr.write(::midenc_hir2::Operation::uninit::(__context, __operation_name, __offset)); + #initialize_custom_fields + } + let mut __this = ::midenc_hir2::UnsafeIntrusiveEntityRef::assume_init(__op); + __this.borrow_mut().set_span(span); + __this + } + }; + + let mut op_builder = ::midenc_hir2::OperationBuilder::new(builder, __this); + #with_attrs + #with_symbols + #with_operands + #( + #with_regions + )* + #with_successors + #with_results + + // Finalize construction of this op, verifying it + #build_op + } + }); + } +} + +impl quote::ToTokens for OpDefinition { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op_ident = &self.name; + let (impl_generics, ty_generics, where_clause) = self.op.generics.split_for_impl(); + + // struct $Op + self.op.to_tokens(tokens); + + // impl Spanned + tokens.extend(quote! { + impl #impl_generics ::midenc_session::diagnostics::Spanned for #op_ident #ty_generics #where_clause { + fn span(&self) -> ::midenc_session::diagnostics::SourceSpan { + self.op.span() + } + } + }); + + // impl AsRef/AsMut + tokens.extend(quote! { + impl #impl_generics AsRef<::midenc_hir2::Operation> for #op_ident #ty_generics #where_clause { + #[inline(always)] + fn as_ref(&self) -> &::midenc_hir2::Operation { + &self.op + } + } + + impl #impl_generics AsMut<::midenc_hir2::Operation> for #op_ident #ty_generics #where_clause { + #[inline(always)] + fn as_mut(&mut self) -> &mut ::midenc_hir2::Operation { + &mut self.op + } + } + }); + + // impl Op + // impl OpRegistration + let opcode = &self.opcode; + let opcode_str = syn::Lit::Str(syn::LitStr::new(&opcode.to_string(), opcode.span())); + tokens.extend(quote! 
{ + impl #impl_generics ::midenc_hir2::Op for #op_ident #ty_generics #where_clause { + #[inline] + fn name(&self) -> ::midenc_hir2::OperationName { + self.op.name() + } + + #[inline(always)] + fn as_operation(&self) -> &::midenc_hir2::Operation { + &self.op + } + + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut ::midenc_hir2::Operation { + &mut self.op + } + } + + impl #impl_generics ::midenc_hir2::OpRegistration for #op_ident #ty_generics #where_clause { + fn name() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#opcode_str) + } + } + }); + + // impl $OpBuilder + // impl BuildableOp + self.op_builder_impl.to_tokens(tokens); + + // impl $Op + { + let create_fn = OpCreateFn::new(self); + let custom_field_fns = OpCustomFieldFns(self); + let attr_fns = OpAttrFns(self); + let symbol_fns = OpSymbolFns(self); + let operand_fns = OpOperandFns(self); + let result_fns = OpResultFns(self); + let region_fns = OpRegionFns(self); + let successor_fns = OpSuccessorFns(self); + tokens.extend(quote! { + /// Construction + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #create_fn + } + + /// User-defined Fields + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #custom_field_fns + } + + /// Attributes + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #attr_fns + } + + /// Symbols + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #symbol_fns + } + + /// Operands + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #operand_fns + } + + /// Results + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #result_fns + } + + /// Regions + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #region_fns + } + + /// Successors + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #successor_fns + } + }); + } + + // impl $DerivedTrait + for derived_trait in self.traits.iter() { + tokens.extend(quote! { + impl #impl_generics #derived_trait for #op_ident #ty_generics #where_clause {} + }); + } + + // impl OpVerifier + self.op_verifier_impl.to_tokens(tokens); + } +} + +struct OpCustomFieldFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpCustomFieldFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // User-defined fields + for field in self.0.op.fields.iter() { + let field_name = field.ident.as_ref().unwrap(); + let field_name_mut = format_ident!("{field_name}_mut"); + let set_field_name = format_ident!("set_{field_name}"); + let field_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of `{field_name}`"), + field_name.span(), + )); + let field_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of `{field_name}`"), + field_name.span(), + )); + let set_field_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of `{field_name}`"), + field_name.span(), + )); + let field_ty = &field.ty; + tokens.extend(quote! 
{ + #[doc = #field_doc] + #[inline] + pub fn #field_name(&self) -> &#field_ty { + &self.#field_name + } + + #[doc = #field_mut_doc] + #[inline] + pub fn #field_name_mut(&mut self) -> &mut #field_ty { + &mut self.#field_name + } + + #[doc = #set_field_doc] + #[inline] + pub fn #set_field_name(&mut self, #field_name: #field_ty) { + self.#field_name = #field_name; + } + }); + } + } +} + +struct OpSymbolFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpSymbolFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Symbols + for Symbol { + name: ref symbol, + ty: ref symbol_kind, + } in self.0.symbols.iter() + { + let span = symbol.span(); + let symbol_str = syn::Lit::Str(syn::LitStr::new(&symbol.to_string(), span)); + let symbol_mut = format_ident!("{symbol}_mut"); + let set_symbol = format_ident!("set_{symbol}"); + let set_symbol_unchecked = format_ident!("set_{symbol}_unchecked"); + let symbol_symbol = format_ident!("{symbol}_symbol"); + let symbol_symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get the symbol under which the `{symbol}` attribute is stored"), + span, + )); + let symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of the `{symbol}` attribute."), + span, + )); + let symbol_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of the `{symbol}` attribute."), + span, + )); + let set_symbol_doc_lines = [ + syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of the `{symbol}` symbol."), + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new( + " Returns `Err` if the symbol cannot be resolved in the nearest symbol table.", + span, + )), + ]; + let set_symbol_unchecked_doc_lines = [ + syn::Lit::Str(syn::LitStr::new( + &format!( + " Set the value of the `{symbol}` symbol without attempting to resolve it." + ), + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new( + " Because this does not resolve the given symbol, the caller is responsible \ + for updating the symbol use list.", + span, + )), + ]; + + tokens.extend(quote! { + #[doc = #symbol_symbol_doc] + #[inline(always)] + pub fn #symbol_symbol() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#symbol_str) + } + + #[doc = #symbol_doc] + pub fn #symbol(&self) -> &::midenc_hir2::SymbolNameAttr { + self.op.get_typed_attribute(&Self::#symbol_symbol()).unwrap() + } + + #[doc = #symbol_mut_doc] + pub fn #symbol_mut(&mut self) -> &mut ::midenc_hir2::SymbolNameAttr { + self.op.get_typed_attribute_mut(&Self::#symbol_symbol()).unwrap() + } + + #( + #[doc = #set_symbol_unchecked_doc_lines] + )* + pub fn #set_symbol_unchecked(&mut self, value: ::midenc_hir2::SymbolNameAttr) { + self.op.set_attribute(Self::#symbol_symbol(), Some(value)); + } + }); + + let is_concrete_ty = match symbol_kind { + SymbolType::Concrete(ref ty) => [quote! 
{ + // The way we check the type depends on whether `symbol` is a reference to `self` + let (data_ptr, _) = ::midenc_hir2::SymbolRef::as_ptr(&symbol).to_raw_parts(); + if core::ptr::addr_eq(data_ptr, (self as *const Self as *const ())) { + if !self.op.is::<#ty>() { + return Err(::midenc_hir2::InvalidSymbolRefError::InvalidType { + symbol: span, + expected: stringify!(#ty), + got: self.op.name(), + }); + } + } else { + if !symbol.borrow().is::<#ty>() { + return Err(::midenc_hir2::InvalidSymbolRefError::InvalidType { + symbol: span, + expected: stringify!(#ty), + got: symbol.as_symbol_operation().name(), + }); + } + } + }], + _ => [quote! {}], + }; + + match symbol_kind { + SymbolType::Any | SymbolType::Trait(_) | SymbolType::Concrete(_) => { + tokens.extend(quote! { + #( + #[doc = #set_symbol_doc_lines] + )* + pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::AsSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { + let symbol = symbol.as_symbol_ref(); + #(#is_concrete_ty)* + self.op.set_symbol_attribute(Self::#symbol_symbol(), symbol); + + Ok(()) + } + }); + } + SymbolType::Callable => { + tokens.extend(quote! { + #( + #[doc = #set_symbol_doc_lines] + )* + pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::traits::AsCallableSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { + let symbol = symbol.as_callable_symbol_ref(); + let (data_ptr, _) = ::midenc_hir2::SymbolRef::as_ptr(&symbol).to_raw_parts(); + if core::ptr::addr_eq(data_ptr, (self as *const Self as *const ())) { + if !self.op.implements::() { + return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { + symbol: self.span(), + }); + } + } else { + let symbol = symbol.borrow(); + let symbol_op = symbol.as_symbol_operation(); + if !symbol_op.implements::() { + return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { + symbol: symbol_op.span(), + }); + } + } + self.op.set_symbol_attribute(Self::#symbol_symbol(), symbol.clone()); + + Ok(()) + } + }); + } + } + } + } +} + +struct OpAttrFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpAttrFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Attributes + for OpAttribute { + name: ref attr, + ty: ref attr_ty, + } in self.0.attrs.iter() + { + let attr_str = syn::Lit::Str(syn::LitStr::new(&attr.to_string(), attr.span())); + let attr_mut = format_ident!("{attr}_mut"); + let set_attr = format_ident!("set_{attr}"); + let attr_symbol = format_ident!("{attr}_symbol"); + let attr_symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get the symbol under which the `{attr}` attribute is stored"), + attr.span(), + )); + let attr_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of the `{attr}` attribute."), + attr.span(), + )); + let attr_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of the `{attr}` attribute."), + attr.span(), + )); + let set_attr_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of the `{attr}` attribute."), + attr.span(), + )); + tokens.extend(quote! 
{ + #[doc = #attr_symbol_doc] + #[inline(always)] + pub fn #attr_symbol() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#attr_str) + } + + #[doc = #attr_doc] + pub fn #attr(&self) -> &#attr_ty { + self.op.get_typed_attribute::<#attr_ty, _>(&Self::#attr_symbol()).unwrap() + } + + #[doc = #attr_mut_doc] + pub fn #attr_mut(&mut self) -> &mut #attr_ty { + self.op.get_typed_attribute_mut::<#attr_ty, _>(&Self::#attr_symbol()).unwrap() + } + + #[doc = #set_attr_doc] + pub fn #set_attr(&mut self, value: impl Into<#attr_ty>) { + self.op.set_attribute(Self::#attr_symbol(), Some(value.into())); + } + }); + } + } +} + +struct OpOperandFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpOperandFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, operand_group) in self.0.operands.iter().enumerate() { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + proc_macro2::Span::call_site(), + )); + match operand_group { + // Operands + OpOperandGroup::Unnamed(operands) => { + for ( + operand_index, + Operand { + name: ref operand, .. + }, + ) in operands.iter().enumerate() + { + let operand_index = syn::Lit::Int(syn::LitInt::new( + &format!("{operand_index}usize"), + proc_macro2::Span::call_site(), + )); + let operand_mut = format_ident!("{operand}_mut"); + let operand_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{operand}` operand."), + operand.span(), + )); + let operand_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{operand}` operand."), + operand.span(), + )); + tokens.extend(quote!{ + #[doc = #operand_doc] + #[inline] + pub fn #operand(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::OpOperandImpl> { + self.op.operands().group(#group_index)[#operand_index].borrow() + } + + #[doc = #operand_mut_doc] + #[inline] + pub fn #operand_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::OpOperandImpl> { + self.op.operands_mut().group_mut(#group_index)[#operand_index].borrow_mut() + } + }); + } + } + // User-defined operand groups + OpOperandGroup::Named(group_name, _) => { + let group_name_mut = format_ident!("{group_name}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group_name}` operand group."), + group_name.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group_name}` operand group."), + group_name.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group_name(&self) -> ::midenc_hir2::OpOperandRange<'_> { + self.op.operands().group(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_name_mut(&mut self) -> ::midenc_hir2::OpOperandRangeMut<'_> { + self.op.operands_mut().group_mut(#group_index) + } + }); + } + } + } + } +} + +struct OpResultFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpResultFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + if let Some(group) = self.0.results.as_ref() { + match group { + OpResultGroup::Unnamed(results) => { + for ( + index, + OpResult { + name: ref result, .. 
+ }, + ) in results.iter().enumerate() + { + let index = syn::Lit::Int(syn::LitInt::new( + &format!("{index}usize"), + result.span(), + )); + let result_mut = format_ident!("{result}_mut"); + let result_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{result}` result."), + result.span(), + )); + let result_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{result}` result."), + result.span(), + )); + tokens.extend(quote!{ + #[doc = #result_doc] + #[inline] + pub fn #result(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::OpResult> { + self.op.results()[#index].borrow() + } + + #[doc = #result_mut_doc] + #[inline] + pub fn #result_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::OpResult> { + self.op.results_mut()[#index].borrow_mut() + } + }); + } + } + OpResultGroup::Named(group, _) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` result group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` result group."), + group.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir::OpResultRange<'_> { + self.results().group(0) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir::OpResultRangeMut<'_> { + self.op.results_mut().group_mut(0) + } + }); + } + } + } + } +} + +struct OpRegionFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpRegionFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Regions + for (index, region) in self.0.regions.iter().enumerate() { + let index = syn::Lit::Int(syn::LitInt::new(&format!("{index}usize"), region.span())); + let region_mut = format_ident!("{region}_mut"); + let region_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{region}` region."), + region.span(), + )); + let region_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{region}` region."), + region.span(), + )); + tokens.extend(quote! { + #[doc = #region_doc] + #[inline] + pub fn #region(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::Region> { + self.op.region(#index) + } + + #[doc = #region_mut_doc] + #[inline] + pub fn #region_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::Region> { + self.op.region_mut(#index) + } + }); + } + } +} + +struct OpSuccessorFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpSuccessorFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, group) in self.0.successors.iter().enumerate() { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + proc_macro2::Span::call_site(), + )); + match group { + // Successors + SuccessorGroup::Unnamed(successors) => { + for (index, successor) in successors.iter().enumerate() { + let index = syn::Lit::Int(syn::LitInt::new( + &format!("{index}usize"), + proc_macro2::Span::call_site(), + )); + let successor_mut = format_ident!("{successor}_mut"); + let successor_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{successor}` successor."), + successor.span(), + )); + let successor_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{successor}` successor."), + successor.span(), + )); + tokens.extend(quote! 
{ + #[doc = #successor_doc] + #[inline] + pub fn #successor(&self) -> ::midenc_hir2::OpSuccessor<'_> { + self.op.successor_in_group(#group_index, #index) + } + + #[doc = #successor_mut_doc] + #[inline] + pub fn #successor_mut(&mut self) -> ::midenc_hir2::OpSuccessorMut<'_> { + self.op.successor_in_group_mut(#group_index, #index) + } + }); + } + } + // Variadic successor groups + SuccessorGroup::Named(group) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` successor group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` successor group."), + group.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir2::OpSuccessorRange<'_> { + self.op.successor_group(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir2::OpSuccessorRangeMut<'_> { + self.op.successor_group(#group_index) + } + }); + } + // User-defined successor groups + SuccessorGroup::Keyed(group, group_ty) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` successor group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` successor group."), + group.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir2::KeyedSuccessorRange<'_, #group_ty> { + self.op.keyed_successor_group::<#group_ty>(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir2::KeyedSuccessorRangeMut<'_, #group_ty> { + self.op.keyed_successor_group_mut::<#group_ty>(#group_index) + } + }); + } + } + } + } +} + +/// Represents a field decorated with `#[attr]` +/// +/// The type associated with an `#[attr]` field represents the concrete value type of the attribute, +/// and thus must implement the `AttributeValue` trait. +#[derive(Debug, Clone)] +pub struct OpAttribute { + /// The attribute name + pub name: Ident, + /// The value type of the attribute + pub ty: syn::Type, +} + +/// An abstraction over named vs unnamed groups of some IR entity +pub enum EntityGroup { + /// An unnamed group consisting of individual named items + Unnamed(Vec), + /// A named group consisting of unnamed items + Named(Ident, syn::Type), +} + +/// A type representing a type constraint applied to a `Value` impl +pub type Constraint = syn::Type; + +#[derive(Debug, Clone)] +pub struct Operand { + pub name: Ident, + pub constraint: Constraint, +} + +pub type OpOperandGroup = EntityGroup; + +#[derive(Debug, Clone)] +pub struct OpResult { + pub name: Ident, + pub constraint: Constraint, +} + +pub type OpResultGroup = EntityGroup; + +#[derive(Debug)] +pub enum SuccessorGroup { + /// An unnamed group consisting of individual named successors + Unnamed(Vec), + /// A named group consisting of unnamed successors + Named(Ident), + /// A named group consisting of unnamed successors with an associated key + Keyed(Ident, syn::Type), +} + +/// Represents the generated `$OpBuilder` type used to create instances of `$Op` +/// +/// The implementation of the type requires us to know the type signature specific to this op, +/// so that we can emit an implementation matching that signature. 
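+///
+/// As a rough sketch of the intended call pattern (the `MyOp`, `lhs`, and `rhs` names below are
+/// purely illustrative, not taken from a real dialect), the generated builder is obtained via
+/// `BuildableOp::builder` and then invoked like a function, since it implements `FnOnce`:
+///
+/// ```ignore
+/// // `builder` is any `midenc_hir2::Builder` and `span` a `SourceSpan`; `MyOp` is assumed to be
+/// // a struct annotated with `#[operation]` whose only parameters are two operands.
+/// let my_op_builder = <MyOp as BuildableOp<_>>::builder(&mut builder, span);
+/// // Calling the builder forwards to `MyOp::create(builder, span, lhs, rhs)` and verifies the op.
+/// let my_op = my_op_builder(lhs, rhs)?;
+/// ```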
+pub struct OpBuilderImpl { + /// The `$Op` we're building + op: Ident, + /// The `$OpBuilder` type name + name: Ident, + /// The doc string for `$OpBuilder` + doc: DocString, + /// The doc string for `$OpBuilder::new` + new_doc: DocString, + /// The set of parameters expected by `$Op::create` + /// + /// The order of these parameters is determined by: + /// + /// 1. The `order = N` property of the corresponding attribute type, e.g. `#[attr(order = 1)]` + /// 2. The default "kind" ordering of: symbols, required user-defined fields, operands, successors, attributes + /// 3. The order of appearance of the fields in the struct + create_params: Rc<[OpCreateParam]>, + /// The implementation of the `BuildableOp` trait for `$Op` via `$OpBuilder` + buildable_op_impl: BuildableOpImpl, + /// The implementation of the `FnOnce` trait for `$OpBuilder` + fn_once_impl: OpBuilderFnOnceImpl, +} +impl OpBuilderImpl { + pub fn empty(op: Ident) -> Self { + let name = format_ident!("{}Builder", &op); + let doc = DocString::new( + op.span(), + format!( + " A specialized builder for [{op}], which is used by calling it like a function." + ), + ); + let new_doc = DocString::new( + op.span(), + format!( + " Get a new [{name}] from the provided [::midenc_hir2::Builder] impl and span." + ), + ); + let create_params = Rc::<[OpCreateParam]>::from([]); + let buildable_op_impl = BuildableOpImpl { + op: op.clone(), + op_builder: name.clone(), + op_generics: Default::default(), + generics: Default::default(), + required_generics: None, + params: Rc::clone(&create_params), + }; + let fn_once_impl = OpBuilderFnOnceImpl { + op: op.clone(), + op_builder: name.clone(), + generics: Default::default(), + required_generics: None, + params: Rc::clone(&create_params), + }; + Self { + op, + name, + doc, + new_doc, + create_params, + buildable_op_impl, + fn_once_impl, + } + } + + pub fn set_create_params(&mut self, op_generics: &syn::Generics, params: Vec) { + let span = self.op.span(); + + let create_params = Rc::from(params.into_boxed_slice()); + self.create_params = Rc::clone(&create_params); + + let has_required_variant = self.create_params.iter().any(|param| param.default); + + // BuildableOp generic parameters + self.buildable_op_impl.params = Rc::clone(&create_params); + self.buildable_op_impl.op_generics = op_generics.clone(); + self.buildable_op_impl.required_generics = if has_required_variant { + Some(syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + op_generics.params.iter().cloned().chain( + self.create_params.iter().flat_map(OpCreateParam::required_generic_types), + ), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: op_generics.where_clause.clone(), + }) + } else { + None + }; + self.buildable_op_impl.generics = syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + op_generics + .params + .iter() + .cloned() + .chain(self.create_params.iter().flat_map(OpCreateParam::generic_types)), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: op_generics.where_clause.clone(), + }; + + // FnOnce generic parameters + self.fn_once_impl.params = create_params; + self.fn_once_impl.required_generics = + self.buildable_op_impl.required_generics.as_ref().map( + |buildable_op_impl_required_generics| syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + [ + syn::GenericParam::Lifetime(syn::LifetimeParam { + attrs: vec![], + lifetime: 
syn::Lifetime::new("'a", proc_macro2::Span::call_site()), + colon_token: None, + bounds: Default::default(), + }), + syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap(), + ] + .into_iter() + .chain(buildable_op_impl_required_generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: buildable_op_impl_required_generics.where_clause.clone(), + }, + ); + self.fn_once_impl.generics = syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + [ + syn::GenericParam::Lifetime(syn::LifetimeParam { + attrs: vec![], + lifetime: syn::Lifetime::new("'a", proc_macro2::Span::call_site()), + colon_token: None, + bounds: Default::default(), + }), + syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap(), + ] + .into_iter() + .chain(self.buildable_op_impl.generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: self.buildable_op_impl.generics.where_clause.clone(), + }; + } +} +impl quote::ToTokens for OpBuilderImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Emit `$OpBuilder` + tokens.extend({ + let op_builder = &self.name; + let op_builder_doc = &self.doc; + let op_builder_new_doc = &self.new_doc; + quote! { + #op_builder_doc + pub struct #op_builder <'a, B: ?Sized> { + builder: &'a mut B, + span: ::midenc_session::diagnostics::SourceSpan, + } + + impl<'a, B> #op_builder <'a, B> + where + B: ?Sized + ::midenc_hir2::Builder, + { + #op_builder_new_doc + #[inline(always)] + pub fn new(builder: &'a mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self { + Self { + builder, + span, + } + } + } + } + }); + + // Emit `impl BuildableOp for $OpBuilder` + self.buildable_op_impl.to_tokens(tokens); + + // Emit `impl FnOnce for $OpBuilder` + self.fn_once_impl.to_tokens(tokens); + } +} + +pub struct BuildableOpImpl { + op: Ident, + op_builder: Ident, + op_generics: syn::Generics, + generics: syn::Generics, + required_generics: Option, + params: Rc<[OpCreateParam]>, +} +impl quote::ToTokens for BuildableOpImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + let op_builder = &self.op_builder; + + // Minimal builder (specify only required parameters) + // + // NOTE: This is only emitted if there are `default` parameters + if let Some(required_generics) = self.required_generics.as_ref() { + let required_params = + self.params.iter().flat_map(OpCreateParam::required_binding_types); + let (_, required_ty_generics, _) = self.op_generics.split_for_impl(); + let (required_impl_generics, _, required_where_clause) = + required_generics.split_for_impl(); + let required_params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(required_params), + }; + let quoted = quote! 
{ + impl #required_impl_generics ::midenc_hir2::BuildableOp<#required_params_ty> for #op #required_ty_generics #required_where_clause { + type Builder<'a, T: ?Sized + ::midenc_hir2::Builder + 'a> = #op_builder <'a, T>; + + #[inline(always)] + fn builder<'b, B>(builder: &'b mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self::Builder<'b, B> + where + B: ?Sized + ::midenc_hir2::Builder + 'b, + { + #op_builder { + builder, + span, + } + } + } + }; + tokens.extend(quoted); + } + + // Maximal builder (specify all parameters) + let params = self.params.iter().flat_map(OpCreateParam::binding_types); + let (_, ty_generics, _) = self.op_generics.split_for_impl(); + let (impl_generics, _, where_clause) = self.generics.split_for_impl(); + let params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(params), + }; + let quoted = quote! { + impl #impl_generics ::midenc_hir2::BuildableOp<#params_ty> for #op #ty_generics #where_clause { + type Builder<'a, T: ?Sized + ::midenc_hir2::Builder + 'a> = #op_builder <'a, T>; + + #[inline(always)] + fn builder<'b, B>(builder: &'b mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self::Builder<'b, B> + where + B: ?Sized + ::midenc_hir2::Builder + 'b, + { + #op_builder { + builder, + span, + } + } + } + }; + tokens.extend(quoted); + } +} + +pub struct OpBuilderFnOnceImpl { + op: Ident, + op_builder: Ident, + generics: syn::Generics, + required_generics: Option, + params: Rc<[OpCreateParam]>, +} +impl quote::ToTokens for OpBuilderFnOnceImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + let op_builder = &self.op_builder; + + let create_param_names = + self.params.iter().flat_map(OpCreateParam::bindings).collect::>(); + + // Minimal builder (specify only required parameters) + // + // NOTE: This is only emitted if there are `default` parameters + if let Some(required_generics) = self.required_generics.as_ref() { + let required_param_names = + self.params.iter().flat_map(OpCreateParam::required_bindings); + let defaulted_param_names = self.params.iter().flat_map(|param| { + if param.default { + param.bindings() + } else { + vec![] + } + }); + let required_param_types = + self.params.iter().flat_map(OpCreateParam::required_binding_types); + let (required_impl_generics, _required_ty_generics, required_where_clause) = + required_generics.split_for_impl(); + let required_params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(required_param_types), + }; + let required_params_bound = syn::PatTuple { + attrs: Default::default(), + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter( + required_param_names.into_iter().map(|id| { + syn::Pat::Ident(syn::PatIdent { + attrs: Default::default(), + by_ref: None, + mutability: None, + ident: id, + subpat: None, + }) + }), + ), + }; + tokens.extend(quote! 
{ + impl #required_impl_generics ::core::ops::FnOnce<#required_params_ty> for #op_builder<'a, B> #required_where_clause { + type Output = Result<::midenc_hir2::UnsafeIntrusiveEntityRef<#op>, ::midenc_session::diagnostics::Report>; + + #[inline] + extern "rust-call" fn call_once(self, args: #required_params_ty) -> Self::Output { + let #required_params_bound = args; + #( + let #defaulted_param_names = Default::default(); + )* + <#op>::create(self.builder, self.span, #(#create_param_names),*) + } + } + }); + } + + // Maximal builder (specify all parameters) + let param_types = self.params.iter().flat_map(OpCreateParam::binding_types); + let (impl_generics, _ty_generics, where_clause) = self.generics.split_for_impl(); + let params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(param_types), + }; + let params_bound = syn::PatTuple { + attrs: Default::default(), + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(create_param_names.iter().map(|id| { + syn::Pat::Ident(syn::PatIdent { + attrs: Default::default(), + by_ref: None, + mutability: None, + ident: id.clone(), + subpat: None, + }) + })), + }; + tokens.extend(quote! { + impl #impl_generics ::core::ops::FnOnce<#params_ty> for #op_builder<'a, B> #where_clause { + type Output = Result<::midenc_hir2::UnsafeIntrusiveEntityRef<#op>, ::midenc_session::diagnostics::Report>; + + #[inline] + extern "rust-call" fn call_once(self, args: #params_ty) -> Self::Output { + let #params_bound = args; + <#op>::create(self.builder, self.span, #(#create_param_names),*) + } + } + }); + } +} + +pub struct OpVerifierImpl { + op: Ident, + traits: darling::util::PathList, + implements: darling::util::PathList, +} +impl OpVerifierImpl { + pub fn new( + op: Ident, + traits: darling::util::PathList, + implements: darling::util::PathList, + ) -> Self { + Self { + op, + traits, + implements, + } + } +} +impl quote::ToTokens for OpVerifierImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + if self.traits.is_empty() && self.implements.is_empty() { + tokens.extend(quote! { + /// No-op verifier implementation generated via `#[operation]` derive + /// + /// This implementation was chosen as no op traits were indicated as being derived _or_ + /// manually implemented by this type. 
+ impl ::midenc_hir2::OpVerifier for #op { + #[inline(always)] + fn verify(&self, _context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + Ok(()) + } + } + }); + return; + } + + let op_verifier_doc_lines = { + let span = self.op.span(); + let mut lines = vec![ + syn::Lit::Str(syn::LitStr::new( + " Generated verifier implementation via `#[operation]` attribute", + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new(" Traits verified by this implementation:", span)), + syn::Lit::Str(syn::LitStr::new("", span)), + ]; + for derived_trait in self.traits.iter() { + lines.push(syn::Lit::Str(syn::LitStr::new( + &format!(" * [{}]", derived_trait.get_ident().unwrap()), + span, + ))); + } + for implemented_trait in self.implements.iter() { + lines.push(syn::Lit::Str(syn::LitStr::new( + &format!(" * [{}]", implemented_trait.get_ident().unwrap()), + span, + ))); + } + lines.push(syn::Lit::Str(syn::LitStr::new("", span))); + lines.push(syn::Lit::Str(syn::LitStr::new( + " Use `cargo-expand` to view the generated code if you suspect verification is \ + broken.", + span, + ))); + lines + }; + + let derived_traits = &self.traits; + let implemented_traits = &self.implements; + tokens.extend(quote! { + #( + #[doc = #op_verifier_doc_lines] + )* + impl ::midenc_hir2::OpVerifier for #op { + fn verify(&self, context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + /// Type alias for the generated concrete verifier type + #[allow(unused_parens)] + type OpVerifierImpl<'a, T> = ::midenc_hir2::derive::DeriveVerifier<'a, T, (#(&'a dyn #derived_traits,)* #(&'a dyn #implemented_traits),*)>; + + #[allow(unused_parens)] + impl<'a> ::midenc_hir2::OpVerifier for OpVerifierImpl<'a, #op> + where + #( + #op: ::midenc_hir2::verifier::Verifier, + )* + #( + #op: ::midenc_hir2::verifier::Verifier, + )* + { + #[inline] + fn verify(&self, context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + let op = self.downcast_ref::<#op>().unwrap(); + #( + if const { !<#op as ::midenc_hir2::verifier::Verifier>::VACUOUS } { + <#op as ::midenc_hir2::verifier::Verifier>::maybe_verify(op, context)?; + } + )* + #( + if const { !<#op as ::midenc_hir2::verifier::Verifier>::VACUOUS } { + <#op as ::midenc_hir2::verifier::Verifier>::maybe_verify(op, context)?; + } + )* + + Ok(()) + } + } + + let verifier = OpVerifierImpl::<#op>::new(&self.op); + verifier.verify(context) + } + } + }); + } +} + +/// Represents the parsed struct definition for the operation we wish to define +/// +/// Only named structs are allowed at this time. +#[derive(Debug, FromDeriveInput)] +#[darling( + attributes(operation), + supports(struct_named), + forward_attrs(doc, cfg, allow, derive) +)] +pub struct Operation { + ident: Ident, + vis: syn::Visibility, + generics: syn::Generics, + attrs: Vec, + data: darling::ast::Data<(), OperationField>, + dialect: Ident, + #[darling(default)] + name: Option, + #[darling(default)] + traits: darling::util::PathList, + #[darling(default)] + implements: darling::util::PathList, +} + +/// Represents a field in the input struct +#[derive(Debug, FromField)] +#[darling(forward_attrs( + doc, cfg, allow, attr, operand, operands, region, successor, successors, result, results, + default, order, symbol +))] +pub struct OperationField { + /// The name of this field. 
+ /// + /// This will always be `Some`, as we do not support any types other than structs + ident: Option, + /// The visibility assigned to this field + vis: syn::Visibility, + /// The type assigned to this field + ty: syn::Type, + /// The processed attributes of this field + #[darling(with = OperationFieldAttrs::new)] + attrs: OperationFieldAttrs, +} + +#[derive(Default, Debug)] +pub struct OperationFieldAttrs { + /// Attributes we don't care about, and are forwarding along untouched + forwarded: Vec, + /// Whether or not to create instances of this op using the `Default` impl for this field + r#default: Flag, + /// Whether or not to assign an explicit order to this field. + /// + /// Once an explicit order has been assigned to a field, all subsequent fields must either have + /// an explicit order, or they will be assigned the next largest unallocated index in the order. + order: Option, + /// Was this an `#[attr]` field? + attr: Flag, + /// Was this an `#[operand]` field? + operand: Flag, + /// Was this an `#[operands]` field? + operands: Flag, + /// Was this a `#[result]` field? + result: Flag, + /// Was this a `#[results]` field? + results: Flag, + /// Was this a `#[region]` field? + region: Flag, + /// Was this a `#[successor]` field? + successor: Flag, + /// Was this a `#[successors]` field? + successors: Option>, + /// Was this a `#[symbol]` field? + symbol: Option>>, +} + +impl OperationFieldAttrs { + pub fn new(attrs: Vec) -> darling::Result { + let mut result = Self::default(); + let mut prev_decorator = None; + for attr in attrs { + if let Some(name) = attr.path().get_ident().map(|id| id.to_string()) { + match name.as_str() { + "attr" => { + if let Some(prev) = prev_decorator.replace("attr") { + return Err(Error::custom(format!( + "#[attr] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.attr = Flag::from_meta(&attr.meta).unwrap(); + } + "operand" => { + if let Some(prev) = prev_decorator.replace("operand") { + return Err(Error::custom(format!( + "#[operand] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.operand = Flag::from_meta(&attr.meta).unwrap(); + } + "operands" => { + if let Some(prev) = prev_decorator.replace("operands") { + return Err(Error::custom(format!( + "#[operands] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.operands = Flag::from_meta(&attr.meta).unwrap(); + } + "result" => { + if let Some(prev) = prev_decorator.replace("result") { + return Err(Error::custom(format!( + "#[result] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.result = Flag::from_meta(&attr.meta).unwrap(); + } + "results" => { + if let Some(prev) = prev_decorator.replace("results") { + return Err(Error::custom(format!( + "#[results] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.results = Flag::from_meta(&attr.meta).unwrap(); + } + "region" => { + if let Some(prev) = prev_decorator.replace("region") { + return Err(Error::custom(format!( + "#[region] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.region = Flag::from_meta(&attr.meta).unwrap(); + } + "successor" => { + if let Some(prev) = prev_decorator.replace("successor") { + return Err(Error::custom(format!( + "#[successor] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.successor = Flag::from_meta(&attr.meta).unwrap(); + } + "successors" => { + if let Some(prev) = 
prev_decorator.replace("successors") { + return Err(Error::custom(format!( + "#[successors] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + let span = attr.span(); + let mut succ_ty = SuccessorsType::Default; + match attr.parse_nested_meta(|meta| { + if meta.path.is_ident("keyed") { + succ_ty = SuccessorsType::Keyed; + Ok(()) + } else { + Err(meta.error(format!( + "invalid #[successors] decorator: unrecognized key '{}'", + meta.path.get_ident().unwrap() + ))) + } + }) { + Ok(_) => { + result.successors = Some(SpannedValue::new(succ_ty, span)); + } + Err(err) => { + return Err(Error::from(err)); + } + } + } + "symbol" => { + if let Some(prev) = prev_decorator.replace("symbol") { + return Err(Error::custom(format!( + "#[symbol] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + let span = attr.span(); + let mut symbol_ty = None; + match &attr.meta { + // A bare #[symbol], nothing to do + syn::Meta::Path(_) => (), + syn::Meta::List(ref list) => { + list.parse_nested_meta(|meta| { + if meta.path.is_ident("callable") { + symbol_ty = Some(SymbolType::Callable); + Ok(()) + } else if meta.path.is_ident("any") { + symbol_ty = Some(SymbolType::Any); + Ok(()) + } else if meta.path.is_ident("bounds") { + let symbol_bound = meta + .input + .parse::() + .map_err(Error::from)?; + symbol_ty = Some(symbol_bound.into()); + Ok(()) + } else { + Err(meta.error(format!( + "invalid #[symbol] decorator: unrecognized key '{}'", + meta.path.get_ident().unwrap() + ))) + } + }) + .map_err(Error::from)?; + } + meta @ syn::Meta::NameValue(_) => { + return Err(Error::custom( + "invalid #[symbol] decorator: invalid format, expected either \ + bare 'symbol' or a meta list", + ) + .with_span(meta)); + } + } + result.symbol = Some(SpannedValue::new(symbol_ty, span)); + } + "default" => { + result.default = Flag::present(); + } + "order" => { + result.order = Some( + attr.parse_args::() + .map_err(Error::from) + .and_then(|n| n.base10_parse::().map_err(Error::from))?, + ); + } + _ => { + result.forwarded.push(attr); + } + } + } else { + result.forwarded.push(attr); + } + } + + Ok(result) + } +} + +impl OperationFieldAttrs { + pub fn pseudo_type(&self) -> Option> { + use darling::util::SpannedValue; + if self.attr.is_present() { + Some(SpannedValue::new(OperationFieldType::Attr, self.attr.span())) + } else if self.operand.is_present() { + Some(SpannedValue::new(OperationFieldType::Operand, self.operand.span())) + } else if self.operands.is_present() { + Some(SpannedValue::new(OperationFieldType::Operands, self.operands.span())) + } else if self.result.is_present() { + Some(SpannedValue::new(OperationFieldType::Result, self.result.span())) + } else if self.results.is_present() { + Some(SpannedValue::new(OperationFieldType::Results, self.results.span())) + } else if self.region.is_present() { + Some(SpannedValue::new(OperationFieldType::Region, self.region.span())) + } else if self.successor.is_present() { + Some(SpannedValue::new(OperationFieldType::Successor, self.successor.span())) + } else if self.successors.is_some() { + self.successors.map(|succ| succ.map_ref(|s| OperationFieldType::Successors(*s))) + } else if self.symbol.is_some() { + self.symbol + .as_ref() + .map(|sym| sym.map_ref(|sym| OperationFieldType::Symbol(sym.clone()))) + } else { + None + } + } +} + +#[derive(Debug, Clone)] +pub enum OperationFieldType { + /// An operation attribute + Attr, + /// A named operand + Operand, + /// A named variadic operand group (zero or more operands) + Operands, + 
/// A named result + Result, + /// A named variadic result group (zero or more results) + Results, + /// A named region + Region, + /// A named successor + Successor, + /// A named variadic successor group (zero or more successors) + Successors(SuccessorsType), + /// A symbol operand + /// + /// Symbols are handled differently than regular operands, as they are not SSA values, and + /// are tracked using a different use/def graph than normal values. + /// + /// If the symbol type is `None`, it implies we should use the concrete field type as the + /// expected symbol type. Otherwise, use the provided symbol type to derive bounds for that + /// field. + Symbol(Option), +} +impl core::fmt::Display for OperationFieldType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Attr => f.write_str("attr"), + Self::Operand => f.write_str("operand"), + Self::Operands => f.write_str("operands"), + Self::Result => f.write_str("result"), + Self::Results => f.write_str("results"), + Self::Region => f.write_str("region"), + Self::Successor => f.write_str("successor"), + Self::Successors(SuccessorsType::Default) => f.write_str("successors"), + Self::Successors(SuccessorsType::Keyed) => f.write_str("successors(keyed)"), + Self::Symbol(None) => f.write_str("symbol"), + Self::Symbol(Some(SymbolType::Any)) => f.write_str("symbol(any)"), + Self::Symbol(Some(SymbolType::Callable)) => f.write_str("symbol(callable)"), + Self::Symbol(Some(SymbolType::Concrete(_))) => write!(f, "symbol(concrete)"), + Self::Symbol(Some(SymbolType::Trait(_))) => write!(f, "symbol(trait)"), + } + } +} + +/// The type of successor group +#[derive(Default, Debug, darling::FromMeta, Copy, Clone)] +#[darling(default)] +pub enum SuccessorsType { + /// The default successor type consists of a `BlockRef` and an iterable of `ValueRef` + #[default] + Default, + /// A keyed successor is a custom type that implements the `KeyedSuccessor` trait + Keyed, +} + +/// Represents parameter information for `$Op::create` and the associated builder infrastructure. +#[derive(Debug)] +pub struct OpCreateParam { + /// The actual parameter type and payload + param_ty: OpCreateParamType, + /// Is this value initialized using `Default::default` when `Op::create` is called? + r#default: bool, +} + +#[derive(Debug)] +pub enum OpCreateParamType { + Attr(OpAttribute), + Operand(Operand), + #[allow(dead_code)] + OperandGroup(Ident, syn::Type), + CustomField(Ident, syn::Type), + Successor(Ident), + SuccessorGroupNamed(Ident), + SuccessorGroupKeyed(Ident, syn::Type), + Symbol(Symbol), +} +impl OpCreateParam { + /// Returns the names of all bindings implied by this parameter. + pub fn bindings(&self) -> Vec { + match &self.param_ty { + OpCreateParamType::Attr(OpAttribute { name, .. }) + | OpCreateParamType::CustomField(name, _) + | OpCreateParamType::Operand(Operand { name, .. }) + | OpCreateParamType::OperandGroup(name, _) + | OpCreateParamType::SuccessorGroupNamed(name) + | OpCreateParamType::SuccessorGroupKeyed(name, _) + | OpCreateParamType::Symbol(Symbol { name, .. }) => vec![name.clone()], + OpCreateParamType::Successor(name) => { + vec![name.clone(), format_ident!("{}_args", name)] + } + } + } + + /// Returns the names of all required (i.e. non-defaulted) bindings implied by this parameter. 
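// Illustrative note (the field name is hypothetical): a `#[successor]` field named `target`
// contributes two builder parameters, `target: BlockRef` and `target_args: TTargetArgs`,
// where `TTargetArgs` is a generic parameter bounded by `IntoIterator` over the values passed
// as block arguments to that successor (see `binding_types` and `generic_types` below).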
+ pub fn required_bindings(&self) -> Vec { + if self.default { + return vec![]; + } + self.bindings() + } + + /// Returns the types assigned to the bindings returned by [Self::bindings] + pub fn binding_types(&self) -> Vec { + match &self.param_ty { + OpCreateParamType::Attr(OpAttribute { ty, .. }) + | OpCreateParamType::CustomField(_, ty) => { + vec![ty.clone()] + } + OpCreateParamType::Operand(_) => vec![make_type("::midenc_hir2::ValueRef")], + OpCreateParamType::OperandGroup(group_name, _) + | OpCreateParamType::SuccessorGroupNamed(group_name) + | OpCreateParamType::SuccessorGroupKeyed(group_name, _) => { + vec![make_type(format!("T{}", group_name.to_string().to_pascal_case()))] + } + OpCreateParamType::Successor(name) => vec![ + make_type("::midenc_hir2::BlockRef"), + make_type(format!("T{}Args", name.to_string().to_pascal_case())), + ], + OpCreateParamType::Symbol(Symbol { name, ty }) => match ty { + SymbolType::Any | SymbolType::Callable | SymbolType::Trait(_) => { + vec![make_type(format!("T{}", name.to_string().to_pascal_case()))] + } + SymbolType::Concrete(ty) => vec![ty.clone()], + }, + } + } + + /// Returns the types assigned to the bindings returned by [Self::required_bindings] + pub fn required_binding_types(&self) -> Vec { + if self.default { + return vec![]; + } + self.binding_types() + } + + /// Returns the generic type parameters bound for use by the types in [Self::binding_typess] + pub fn generic_types(&self) -> Vec { + match &self.param_ty { + OpCreateParamType::OperandGroup(name, _) => { + let value_iter_bound: syn::TypeParamBound = + syn::parse_str("IntoIterator").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::Successor(name) => { + let value_iter_bound: syn::TypeParamBound = + syn::parse_str("IntoIterator").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}Args", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::SuccessorGroupNamed(name) => { + let value_iter_bound: syn::TypeParamBound = syn::parse_str( + "IntoIterator)>", + ) + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::SuccessorGroupKeyed(name, ty) => { + let item_name = name.to_string().to_pascal_case(); + let iterator_ty = format_ident!("T{item_name}", span = name.span()); + vec![syn::parse_quote! 
{ + #iterator_ty: IntoIterator + }] + } + OpCreateParamType::Symbol(Symbol { name, ty }) => match ty { + SymbolType::Any => { + let as_symbol_ref_bound = + syn::parse_str::("::midenc_hir2::AsSymbolRef") + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([as_symbol_ref_bound]), + eq_token: None, + r#default: None, + })] + } + SymbolType::Callable => { + let as_callable_symbol_ref_bound = syn::parse_str::( + "::midenc_hir2::traits::AsCallableSymbolRef", + ) + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([ + as_callable_symbol_ref_bound, + ]), + eq_token: None, + r#default: None, + })] + } + SymbolType::Concrete(_) => vec![], + SymbolType::Trait(bounds) => { + let as_symbol_ref_bound = syn::parse_str("::midenc_hir2::AsSymbolRef").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter( + [as_symbol_ref_bound].into_iter().chain(bounds.iter().cloned()), + ), + eq_token: None, + r#default: None, + })] + } + }, + _ => vec![], + } + } + + /// Returns the generic type parameters bound for use by the types in [Self::required_binding_typess] + pub fn required_generic_types(&self) -> Vec { + if self.default { + return vec![]; + } + self.generic_types() + } +} + +/// A symbol value +#[derive(Debug, Clone)] +pub struct Symbol { + pub name: Ident, + pub ty: SymbolType, +} + +/// Represents the type of a symbol +#[derive(Debug, Clone)] +pub enum SymbolType { + /// Any `Symbol` implementation can be used + Any, + /// Any `Symbol + CallableOpInterface` implementation can be used + Callable, + /// Only the specific concrete type can be used, it must implement the `Symbol` trait + Concrete(syn::Type), + /// Any implementation of the provided trait can be used. + /// + /// The given trait type _must_ have `Symbol` as a supertrait. 
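// Illustration of how the `#[symbol]` decorator maps to these variants (field and trait names
// are hypothetical): a bare `#[symbol]` leaves the symbol type as `None`, so the field's
// concrete type is used as the expected symbol type; `#[symbol(any)]` and `#[symbol(callable)]`
// select the `Any` and `Callable` variants; and `#[symbol(bounds = MyTrait)]` parses a
// `SymbolTraitBound` and produces `SymbolType::Trait` with `MyTrait` as the required bound.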
+ Trait(syn::punctuated::Punctuated), +} + +struct SymbolTraitBound { + _eq_token: Token![=], + bounds: syn::punctuated::Punctuated, +} +impl syn::parse::Parse for SymbolTraitBound { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + if !lookahead.peek(Token![=]) { + return Err(lookahead.error()); + } + + let _eq_token = input.parse::()?; + let bounds = syn::punctuated::Punctuated::parse_separated_nonempty(input)?; + + Ok(Self { _eq_token, bounds }) + } +} +impl From for SymbolType { + #[inline] + fn from(value: SymbolTraitBound) -> Self { + SymbolType::Trait(value.bounds) + } +} + +pub struct DocString { + span: proc_macro2::Span, + doc: String, +} +impl DocString { + pub fn new(span: proc_macro2::Span, doc: String) -> Self { + Self { span, doc } + } +} +impl quote::ToTokens for DocString { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let attr = syn::Attribute { + pound_token: syn::token::Pound(self.span), + style: syn::AttrStyle::Outer, + bracket_token: syn::token::Bracket(self.span), + meta: syn::Meta::NameValue(syn::MetaNameValue { + path: attr_path("doc"), + eq_token: syn::token::Eq(self.span), + value: syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new(&self.doc, self.span)), + }), + }), + }; + + attr.to_tokens(tokens); + } +} + +#[derive(Copy, Clone)] +enum PathStyle { + Default, + Absolute, +} + +fn make_type(s: impl AsRef) -> syn::Type { + let s = s.as_ref(); + let path = type_path(s); + syn::Type::Path(syn::TypePath { qself: None, path }) +} + +fn type_path(s: impl AsRef) -> syn::Path { + let s = s.as_ref(); + let (s, style) = if let Some(s) = s.strip_prefix("::") { + (s, PathStyle::Absolute) + } else { + (s, PathStyle::Default) + }; + let parts = s.split("::"); + make_path(parts, style) +} + +fn attr_path(s: impl AsRef) -> syn::Path { + make_path([s.as_ref()], PathStyle::Default) +} + +fn make_path<'a>(parts: impl IntoIterator, style: PathStyle) -> syn::Path { + use proc_macro2::Span; + + syn::Path { + leading_colon: match style { + PathStyle::Default => None, + PathStyle::Absolute => Some(syn::token::PathSep(Span::call_site())), + }, + segments: syn::punctuated::Punctuated::from_iter(parts.into_iter().map(|part| { + syn::PathSegment { + ident: format_ident!("{}", part), + arguments: syn::PathArguments::None, + } + })), + } +} + +#[cfg(test)] +mod tests { + #![allow(dead_code)] + + #[test] + fn operation_impl_test() { + let item_input: syn::DeriveInput = syn::parse_quote! 
{ + /// Two's complement sum + #[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface), + )] + pub struct Add { + /// The left-hand operand + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, + } + }; + + let output = super::derive_operation(item_input); + match output { + Ok(output) => { + let formatted = format_output(&output.to_string()); + println!("{formatted}"); + } + Err(err) => { + panic!("{err}"); + } + } + } + + fn format_output(input: &str) -> String { + use std::{ + io::{Read, Write}, + process::{Command, Stdio}, + }; + + let mut child = Command::new("rustfmt") + .args(["+nightly", "--edition", "2024"]) + .args([ + "--config", + "unstable_features=true,normalize_doc_attributes=true,\ + use_field_init_shorthand=true,condense_wildcard_suffixes=true,\ + format_strings=true,group_imports=StdExternalCrate,imports_granularity=Crate", + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("failed to spawn 'rustfmt'"); + + { + let mut stdin = child.stdin.take().unwrap(); + stdin.write_all(input.as_bytes()).expect("failed to write input to 'rustfmt'"); + } + let mut buf = String::new(); + let mut stdout = child.stdout.take().unwrap(); + stdout.read_to_string(&mut buf).expect("failed to read output from 'rustfmt'"); + match child.wait() { + Ok(status) => { + if status.success() { + buf + } else { + let mut stderr = child.stderr.take().unwrap(); + let mut err_buf = String::new(); + let _ = stderr.read_to_string(&mut err_buf).ok(); + panic!( + "command 'rustfmt' failed with status {:?}\n\nReason: {}", + status.code(), + if err_buf.is_empty() { + "" + } else { + err_buf.as_str() + }, + ); + } + } + Err(err) => panic!("command 'rustfmt' failed with {err}"), + } + } +} From 95379c6cbbb18f016ac8582a11131f67eb031c0b Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:36:27 -0400 Subject: [PATCH 07/31] fix: expose sync module from hir-symbol --- hir-symbol/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hir-symbol/src/lib.rs b/hir-symbol/src/lib.rs index 40d5a8254..749be982b 100644 --- a/hir-symbol/src/lib.rs +++ b/hir-symbol/src/lib.rs @@ -4,7 +4,7 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; -mod sync; +pub mod sync; use alloc::{ boxed::Box, From 1146d982be322ab6e7f5c63f4e0bb083f0c9490c Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:36:52 -0400 Subject: [PATCH 08/31] feat: implement pretty-print trait for Symbol/Type --- Cargo.lock | 2 ++ Cargo.toml | 1 + hir-symbol/Cargo.toml | 1 + hir-symbol/src/lib.rs | 8 ++++++++ hir-type/Cargo.toml | 1 + hir-type/src/lib.rs | 9 +++++++++ 6 files changed, 22 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 749bed599..571e5aab8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3387,6 +3387,7 @@ dependencies = [ "Inflector", "compact_str", "lock_api", + "miden-formatting", "parking_lot", "rustc-hash 1.1.0", "serde 1.0.210", @@ -3412,6 +3413,7 @@ dependencies = [ name = "midenc-hir-type" version = "0.0.6" dependencies = [ + "miden-formatting", "serde 1.0.210", "serde_repr", "smallvec", diff --git a/Cargo.toml b/Cargo.toml index 02a397543..15e62c400 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ derive_more = "0.99" indexmap = "2.2" miden-assembly = { version = "0.10.3" } miden-core = { version = "0.10.3" } +miden-formatting = { version = "0.1", 
default-features = false } miden-parsing = "0.1" miden-processor = { version = "0.10.3" } miden-stdlib = { version = "0.10.3", features = ["with-debug-info"] } diff --git a/hir-symbol/Cargo.toml b/hir-symbol/Cargo.toml index 14cd4bf8a..d4d4f5a1e 100644 --- a/hir-symbol/Cargo.toml +++ b/hir-symbol/Cargo.toml @@ -20,6 +20,7 @@ compact_str = ["dep:compact_str"] [dependencies] compact_str = { workspace = true, optional = true } lock_api = "0.4" +miden-formatting.workspace = true parking_lot = { version = "0.12", optional = true } serde = { workspace = true, optional = true } diff --git a/hir-symbol/src/lib.rs b/hir-symbol/src/lib.rs index 749be982b..a1bc3cfb1 100644 --- a/hir-symbol/src/lib.rs +++ b/hir-symbol/src/lib.rs @@ -14,6 +14,8 @@ use alloc::{ }; use core::{fmt, mem, ops::Deref, str}; +use miden_formatting::prettier::PrettyPrint; + pub mod symbols { include!(env!("SYMBOLS_RS")); } @@ -110,6 +112,12 @@ impl fmt::Display for Symbol { fmt::Display::fmt(&self.as_str(), f) } } +impl PrettyPrint for Symbol { + fn render(&self) -> miden_formatting::prettier::Document { + use miden_formatting::prettier::*; + const_text(self.as_str()) + } +} impl PartialOrd for Symbol { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) diff --git a/hir-type/Cargo.toml b/hir-type/Cargo.toml index bee94a9d2..7706f8b3c 100644 --- a/hir-type/Cargo.toml +++ b/hir-type/Cargo.toml @@ -16,6 +16,7 @@ default = ["serde"] serde = ["dep:serde", "dep:serde_repr"] [dependencies] +miden-formatting.workspace = true smallvec.workspace = true serde = { workspace = true, optional = true } serde_repr = { workspace = true, optional = true } diff --git a/hir-type/src/lib.rs b/hir-type/src/lib.rs index 978950677..7fc802c08 100644 --- a/hir-type/src/lib.rs +++ b/hir-type/src/lib.rs @@ -7,6 +7,8 @@ mod layout; use alloc::{boxed::Box, vec::Vec}; use core::{fmt, num::NonZeroU16, str::FromStr}; +use miden_formatting::prettier::PrettyPrint; + pub use self::layout::Alignable; /// Represents the type of a value @@ -333,6 +335,13 @@ impl fmt::Display for Type { } } } +impl PrettyPrint for Type { + fn render(&self) -> miden_formatting::prettier::Document { + use miden_formatting::prettier::*; + + display(self) + } +} /// This represents metadata about how a structured type will be represented in memory #[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)] From 7e37cf1e691c69e9022baf3da4d46de417992714 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:39:33 -0400 Subject: [PATCH 09/31] wip: implement useful static casting traits --- hir2/src/any.rs | 269 ++++++++++++++++++++++++++++++++++++++++++++++++ hir2/src/lib.rs | 1 + 2 files changed, 270 insertions(+) create mode 100644 hir2/src/any.rs diff --git a/hir2/src/any.rs b/hir2/src/any.rs new file mode 100644 index 000000000..3a3fd4296 --- /dev/null +++ b/hir2/src/any.rs @@ -0,0 +1,269 @@ +use core::{any::Any, marker::Unsize}; + +pub trait AsAny: Any + Unsize { + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self + } + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + #[inline(always)] + fn into_any(self: Box) -> Box { + self + } +} + +impl> AsAny for T {} + +/// # Safety +/// +/// This trait is not safe (or possible) to implement manually. 
+/// +/// It is automatically derived for all `T: Unsize` +pub unsafe trait Is: Unsize {} +unsafe impl Is for T where T: ?Sized + Unsize {} + +pub trait IsObjOf {} +impl IsObjOf for Trait +where + T: ?Sized + Is, + Trait: ?Sized, +{ +} + +#[allow(unused)] +pub trait DowncastRef: IsObjOf { + fn downcast_ref(&self) -> Option<&To>; + fn downcast_mut(&mut self) -> Option<&mut To>; +} +impl DowncastRef for From +where + From: ?Sized, + To: ?Sized + DowncastFromRef, +{ + #[inline(always)] + fn downcast_ref(&self) -> Option<&To> { + To::downcast_from_ref(self) + } + + #[inline(always)] + fn downcast_mut(&mut self) -> Option<&mut To> { + To::downcast_from_mut(self) + } +} + +pub trait DowncastFromRef: Is { + fn downcast_from_ref(from: &From) -> Option<&Self>; + fn downcast_from_mut(from: &mut From) -> Option<&mut Self>; +} +impl DowncastFromRef for To +where + From: ?Sized + AsAny, + To: Is + 'static, +{ + #[inline] + fn downcast_from_ref(from: &From) -> Option<&Self> { + from.as_any().downcast_ref() + } + + #[inline] + fn downcast_from_mut(from: &mut From) -> Option<&mut Self> { + from.as_any_mut().downcast_mut() + } +} + +#[allow(unused)] +pub trait Downcast: DowncastRef + Is { + fn downcast(self: Box) -> Result, Box>; +} +impl Downcast for From +where + From: ?Sized + DowncastRef + Is, + To: ?Sized + DowncastFrom, + Obj: ?Sized, +{ + #[inline] + fn downcast(self: Box) -> Result, Box> { + To::downcast_from(self) + } +} + +pub trait DowncastFrom: DowncastFromRef +where + From: ?Sized + Is, + Obj: ?Sized, +{ + fn downcast_from(from: Box) -> Result, Box>; +} +impl DowncastFrom for To +where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, +{ + fn downcast_from(from: Box) -> Result, Box> { + if !from.as_any().is::() { + Ok(from.into_any().downcast().unwrap()) + } else { + Err(from) + } + } +} + +pub trait Upcast: TryUpcastRef { + fn upcast_ref(&self) -> &To; + fn upcast_mut(&mut self) -> &mut To; + fn upcast(self: Box) -> Box; +} +impl Upcast for From +where + From: ?Sized + Is, + To: ?Sized, +{ + #[inline(always)] + fn upcast_ref(&self) -> &To { + self + } + + #[inline(always)] + fn upcast_mut(&mut self) -> &mut To { + self + } + + #[inline(always)] + fn upcast(self: Box) -> Box { + self + } +} + +pub trait TryUpcastRef: Is +where + To: ?Sized, +{ + #[allow(unused)] + fn is_of(&self) -> bool; + fn try_upcast_ref(&self) -> Option<&To>; + fn try_upcast_mut(&mut self) -> Option<&mut To>; +} +impl TryUpcastRef for From +where + From: Upcast + ?Sized, + To: ?Sized, +{ + #[inline(always)] + fn is_of(&self) -> bool { + true + } + + #[inline(always)] + fn try_upcast_ref(&self) -> Option<&To> { + Some(self.upcast_ref()) + } + + #[inline(always)] + fn try_upcast_mut(&mut self) -> Option<&mut To> { + Some(self.upcast_mut()) + } +} + +pub trait TryUpcast: Is + TryUpcastRef +where + To: ?Sized, + Obj: ?Sized, +{ + #[inline(always)] + fn try_upcast(self: Box) -> Result, Box> { + Err(self) + } +} +impl TryUpcast for From +where + From: Is + Upcast + ?Sized, + To: ?Sized, + Obj: ?Sized, +{ + #[inline] + fn try_upcast(self: Box) -> Result, Box> { + Ok(self.upcast()) + } +} + +/// Upcasts a type into a trait object +#[allow(unused)] +pub trait UpcastFrom +where + From: ?Sized, +{ + fn upcast_from_ref(from: &From) -> &Self; + fn upcast_from_mut(from: &mut From) -> &mut Self; + fn upcast_from(from: Box) -> Box; +} + +impl UpcastFrom for To +where + From: Upcast + ?Sized, + To: ?Sized, +{ + #[inline] + fn upcast_from_ref(from: &From) -> &Self { + from.upcast_ref() + } + + #[inline] 
+ fn upcast_from_mut(from: &mut From) -> &mut Self { + from.upcast_mut() + } + + #[inline] + fn upcast_from(from: Box) -> Box { + from.upcast() + } +} + +#[allow(unused)] +pub trait TryUpcastFromRef: IsObjOf +where + From: ?Sized, +{ + fn try_upcast_from_ref(from: &From) -> Option<&Self>; + fn try_upcast_from_mut(from: &mut From) -> Option<&mut Self>; +} + +impl TryUpcastFromRef for To +where + From: TryUpcastRef + ?Sized, + To: ?Sized, +{ + #[inline] + fn try_upcast_from_ref(from: &From) -> Option<&Self> { + from.try_upcast_ref() + } + + #[inline] + fn try_upcast_from_mut(from: &mut From) -> Option<&mut Self> { + from.try_upcast_mut() + } +} + +#[allow(unused)] +pub trait TryUpcastFrom: TryUpcastFromRef +where + From: Is + ?Sized, + Obj: ?Sized, +{ + fn try_upcast_from(from: Box) -> Result, Box>; +} + +impl TryUpcastFrom for To +where + From: TryUpcast + ?Sized, + To: ?Sized, + Obj: ?Sized, +{ + #[inline] + fn try_upcast_from(from: Box) -> Result, Box> { + from.try_upcast() + } +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 553f0fc88..696216177 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -22,6 +22,7 @@ pub use compact_str::{ CompactString as SmallStr, CompactStringExt as SmallStrExt, ToCompactString as ToSmallStr, }; +mod any; mod attributes; pub mod demangle; pub mod derive; From 03cb6f509df8fbd3ef9294e1f5e7444b10450254 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:41:27 -0400 Subject: [PATCH 10/31] wip: implement support for more attributes, with pretty printing --- hir2/src/attributes/overflow.rs | 6 + hir2/src/attributes/visibility.rs | 37 +++++- hir2/src/ir/attribute.rs | 210 +++++++++++++++++++++++++++++- hir2/src/ir/immediates.rs | 8 +- 4 files changed, 254 insertions(+), 7 deletions(-) diff --git a/hir2/src/attributes/overflow.rs b/hir2/src/attributes/overflow.rs index b7073e479..7822ae3b2 100644 --- a/hir2/src/attributes/overflow.rs +++ b/hir2/src/attributes/overflow.rs @@ -56,5 +56,11 @@ impl fmt::Display for Overflow { } } } +impl crate::formatter::PrettyPrint for Overflow { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + display(self) + } +} define_attr_type!(Overflow); diff --git a/hir2/src/attributes/visibility.rs b/hir2/src/attributes/visibility.rs index 2d8c7e188..637fd2eda 100644 --- a/hir2/src/attributes/visibility.rs +++ b/hir2/src/attributes/visibility.rs @@ -1,5 +1,7 @@ use core::{fmt, str::FromStr}; +use crate::define_attr_type; + /// The types of visibility that a [Symbol] may have #[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum Visibility { @@ -21,16 +23,43 @@ pub enum Visibility { /// graph, thus retaining the ability to observe all uses, and optimize based on that /// information. /// - /// Nested visibility implies that we know all uses of the symbol, but that there may be uses + /// Internal visibility implies that we know all uses of the symbol, but that there may be uses /// in other symbol tables in addition to the current one. 
- Nested, + Internal, +} +define_attr_type!(Visibility); +impl Visibility { + #[inline] + pub fn is_public(&self) -> bool { + matches!(self, Self::Public) + } + + #[inline] + pub fn is_private(&self) -> bool { + matches!(self, Self::Private) + } + + #[inline] + pub fn is_internal(&self) -> bool { + matches!(self, Self::Internal) + } +} +impl crate::formatter::PrettyPrint for Visibility { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + match self { + Self::Public => const_text("public"), + Self::Private => const_text("private"), + Self::Internal => const_text("internal"), + } + } } impl fmt::Display for Visibility { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Public => f.write_str("public"), Self::Private => f.write_str("private"), - Self::Nested => f.write_str("nested"), + Self::Internal => f.write_str("internal"), } } } @@ -41,7 +70,7 @@ impl FromStr for Visibility { match s { "public" => Ok(Self::Public), "private" => Ok(Self::Private), - "nested" => Ok(Self::Nested), + "internal" => Ok(Self::Internal), _ => Err(()), } } diff --git a/hir2/src/ir/attribute.rs b/hir2/src/ir/attribute.rs index aaf7831e4..2266ac176 100644 --- a/hir2/src/ir/attribute.rs +++ b/hir2/src/ir/attribute.rs @@ -200,14 +200,14 @@ impl Attribute { } impl fmt::Display for Attribute { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.value.as_deref() { + match self.value.as_deref().map(|v| v.render()) { None => write!(f, "#[{}]", self.name.as_str()), Some(value) => write!(f, "#[{}({value})]", &self.name), } } } -pub trait AttributeValue: Any + fmt::Debug + fmt::Display + 'static { +pub trait AttributeValue: Any + fmt::Debug + crate::formatter::PrettyPrint + 'static { fn as_any(&self) -> &dyn Any; fn as_any_mut(&mut self) -> &mut dyn Any; } @@ -226,6 +226,202 @@ impl dyn AttributeValue { } } +pub struct SetAttr { + values: Vec, +} +impl Default for SetAttr { + fn default() -> Self { + Self { + values: Default::default(), + } + } +} +impl SetAttr +where + K: Ord + Clone, +{ + pub fn insert(&mut self, key: K) -> bool { + match self.values.binary_search_by(|k| key.cmp(k)) { + Ok(index) => { + self.values[index] = key; + false + } + Err(index) => { + self.values.insert(index, key); + true + } + } + } + + pub fn contains(&self, key: &K) -> bool { + self.values.binary_search_by(|k| key.cmp(k)).is_ok() + } + + pub fn iter(&self) -> core::slice::Iter<'_, K> { + self.values.iter() + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|k| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index)), + Err(_) => None, + } + } +} +impl Eq for SetAttr where K: Eq {} +impl PartialEq for SetAttr +where + K: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for SetAttr +where + K: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_set().entries(self.values.iter()).finish() + } +} +impl crate::formatter::PrettyPrint for SetAttr +where + K: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl AttributeValue for SetAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + 'static, +{ + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } +} + +#[derive(Clone)] +pub struct DictAttr { + values: Vec<(K, V)>, +} 
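// Minimal usage sketch of the sorted-vec map semantics implemented below, assuming `DictAttr`
// is generic over a key type and a value type; the concrete types here are arbitrary choices
// for illustration, not something `DictAttr` requires:
#[cfg(test)]
#[test]
fn dict_attr_usage_sketch() {
    let mut attr: DictAttr<&'static str, u32> = DictAttr::default();
    attr.insert("b", 2);
    attr.insert("a", 1);
    // Inserting an existing key overwrites its value in place
    attr.insert("b", 3);
    assert_eq!(attr.get("a"), Some(&1));
    assert_eq!(attr.get("b"), Some(&3));
    assert_eq!(attr.remove("c"), None);
    // Iteration is in key order, since entries are kept sorted by key
    assert!(attr.iter().map(|(k, _)| *k).eq(["a", "b"]));
}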
+impl Default for DictAttr { + fn default() -> Self { + Self { values: vec![] } + } +} +impl DictAttr +where + K: Ord, + V: Clone, +{ + pub fn insert(&mut self, key: K, value: V) { + match self.values.binary_search_by(|(k, _)| key.cmp(k)) { + Ok(index) => { + self.values[index].1 = value; + } + Err(index) => { + self.values.insert(index, (key, value)); + } + } + } + + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow, + Q: ?Sized + Ord, + { + self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())).is_ok() + } + + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(&self.values[index].1), + Err(_) => None, + } + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index).1), + Err(_) => None, + } + } + + pub fn iter(&self) -> core::slice::Iter<'_, (K, V)> { + self.values.iter() + } +} +impl Eq for DictAttr +where + K: Eq, + V: Eq, +{ +} +impl PartialEq for DictAttr +where + K: PartialEq, + V: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for DictAttr +where + K: fmt::Debug, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.values.iter().map(|entry| (&entry.0, &entry.1))) + .finish() + } +} +impl crate::formatter::PrettyPrint for DictAttr +where + K: crate::formatter::PrettyPrint, + V: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl AttributeValue for DictAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + 'static, + V: fmt::Debug + crate::formatter::PrettyPrint + 'static, +{ + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } +} + #[macro_export] macro_rules! define_attr_type { ($T:ty) => { @@ -244,6 +440,16 @@ macro_rules! 
define_attr_type { } define_attr_type!(bool); +define_attr_type!(u8); +define_attr_type!(i8); +define_attr_type!(u16); +define_attr_type!(i16); +define_attr_type!(u32); +define_attr_type!(core::num::NonZeroU32); +define_attr_type!(i32); +define_attr_type!(u64); +define_attr_type!(i64); +define_attr_type!(usize); define_attr_type!(isize); define_attr_type!(Symbol); define_attr_type!(super::Immediate); diff --git a/hir2/src/ir/immediates.rs b/hir2/src/ir/immediates.rs index 8f8806485..57d4795f5 100644 --- a/hir2/src/ir/immediates.rs +++ b/hir2/src/ir/immediates.rs @@ -5,7 +5,7 @@ use core::{ pub use miden_core::{Felt, FieldElement, StarkField}; -use super::Type; +use crate::{formatter::PrettyPrint, Type}; #[derive(Debug, Copy, Clone)] pub enum Immediate { @@ -297,6 +297,12 @@ impl fmt::Display for Immediate { } } } +impl PrettyPrint for Immediate { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + display(self) + } +} impl Hash for Immediate { fn hash(&self, state: &mut H) { let d = std::mem::discriminant(self); From 08004231f173c4d404d8d2821e5fc7f3665cabd0 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:42:16 -0400 Subject: [PATCH 11/31] wip: implement generic pretty printer for ir --- hir2/src/ir/print.rs | 194 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 hir2/src/ir/print.rs diff --git a/hir2/src/ir/print.rs b/hir2/src/ir/print.rs new file mode 100644 index 000000000..c465da32c --- /dev/null +++ b/hir2/src/ir/print.rs @@ -0,0 +1,194 @@ +use core::fmt; + +use super::{Context, Operation}; +use crate::{ + formatter::PrettyPrint, + traits::{CallableOpInterface, SingleBlock, SingleRegion}, + Entity, Value, +}; + +pub struct OpPrintingFlags; + +/// The `OpPrinter` trait is expected to be implemented by all [Op] impls as a prequisite. +/// +/// The actual implementation is typically generated as part of deriving [Op]. +pub trait OpPrinter { + fn print( + &self, + flags: &OpPrintingFlags, + context: &Context, + f: &mut fmt::Formatter, + ) -> fmt::Result; +} + +impl OpPrinter for Operation { + #[inline] + fn print( + &self, + _flags: &OpPrintingFlags, + _context: &Context, + f: &mut fmt::Formatter, + ) -> fmt::Result { + write!(f, "{}", self.render()) + } +} + +/// The generic format for printed operations is: +/// +/// <%result..> = .(%operand : , ..) : #.. { +/// // Region +/// ^(<%block_argument...>): +/// // Block +/// }; +/// +/// Special handling is provided for SingleRegionSingleBlock and CallableOpInterface ops: +/// +/// * SingleRegionSingleBlock ops with no operands will have the block header elided +/// * CallableOpInterface ops with no operands will be printed differently, using their +/// symbol and signature, as shown below: +/// +/// . @() -> #.. { +/// ... 
+/// } +impl PrettyPrint for Operation { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + let is_single_region_single_block = + self.implements::() && self.implements::(); + let is_callable_op = self.implements::(); + let is_symbol = self.is_symbol(); + let no_operands = self.operands().is_empty(); + + let results = self.results(); + let mut doc = if !results.is_empty() { + let results = results.iter().enumerate().fold(Document::Empty, |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + display(result.borrow().id()) + } else { + doc + display(result.borrow().id()) + } + }); + results + const_text(" = ") + } else { + Document::Empty + }; + doc += display(self.name()); + let doc = if is_callable_op && is_symbol && no_operands { + let name = self.as_symbol().unwrap().name(); + let callable = self.as_trait::().unwrap(); + let signature = callable.signature(); + let mut doc = doc + display(signature.visibility) + text(format!(" @{}", name)); + if let Some(body) = callable.get_callable_region() { + let body = body.borrow(); + let entry = body.entry(); + doc += entry.arguments().iter().enumerate().fold( + const_text("("), + |doc, (i, param)| { + let param = param.borrow(); + let doc = if i > 0 { doc + const_text(", ") } else { doc }; + doc + display(param.id()) + const_text(": ") + display(param.ty()) + }, + ) + const_text(")"); + if !signature.results.is_empty() { + doc += signature.results().iter().enumerate().fold( + const_text(" -> "), + |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + display(&result.ty) + } else { + doc + display(&result.ty) + } + }, + ); + } + } else { + doc += signature.render() + } + doc + } else { + let operands = self.operands(); + let doc = if !operands.is_empty() { + operands.iter().enumerate().fold(doc + const_text("("), |doc, (i, operand)| { + let operand = operand.borrow(); + let value = operand.value(); + if i > 0 { + doc + const_text(", ") + + display(value.id()) + + const_text(": ") + + display(value.ty()) + } else { + doc + display(value.id()) + const_text(": ") + display(value.ty()) + } + }) + const_text(")") + } else { + doc + }; + if !results.is_empty() { + let results = + results.iter().enumerate().fold(Document::Empty, |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + text(format!("{}", result.borrow().ty())) + } else { + doc + text(format!("{}", result.borrow().ty())) + } + }); + doc + const_text(" : ") + results + } else { + doc + } + }; + + let doc = self.attrs.iter().enumerate().fold(doc, |doc, (i, attr)| { + let doc = if i > 0 { doc + const_text(" ") } else { doc }; + if let Some(value) = attr.value() { + doc + const_text("#[") + + display(attr.name) + + const_text(" = ") + + value.render() + + const_text("]") + } else { + doc + text(format!("#[{}]", &attr.name)) + } + }); + + if self.has_regions() { + self.regions.iter().fold(doc, |doc, region| { + let blocks = region.body().iter().fold(Document::Empty, |doc, block| { + let ops = + block.body().iter().fold(Document::Empty, |doc, op| doc + op.render()); + if is_single_region_single_block && no_operands { + doc + indent(4, nl() + ops) + nl() + } else { + let block_args = block.arguments().iter().enumerate().fold( + Document::Empty, + |doc, (i, arg)| { + if i > 0 { + doc + const_text(", ") + arg.borrow().render() + } else { + doc + arg.borrow().render() + } + }, + ); + let block_args = if block_args.is_empty() { + block_args + } else { + const_text("(") + block_args + const_text(")") + }; + doc + indent( + 4, + text(format!("^{}", 
block.id())) + + block_args + + const_text(":") + + nl() + + ops, + ) + nl() + } + }); + doc + indent(4, const_text(" {") + nl() + blocks) + nl() + const_text("}") + }) + const_text(";") + } else { + doc + const_text(";") + } + } +} From 5ad94accdf899a7b2bcb6afdf721f290ca2a89ce Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:44:54 -0400 Subject: [PATCH 12/31] wip: improve ergonomics of entity refs, implement support for generic grouped entity storage --- hir2/src/ir/entity.rs | 225 ++++++++++-- hir2/src/ir/entity/group.rs | 217 +++++++++++ hir2/src/ir/entity/list.rs | 2 +- hir2/src/ir/entity/storage.rs | 674 ++++++++++++++++++++++++++++++++++ 4 files changed, 1082 insertions(+), 36 deletions(-) create mode 100644 hir2/src/ir/entity/group.rs create mode 100644 hir2/src/ir/entity/storage.rs diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs index 2d41c7bab..6539b4864 100644 --- a/hir2/src/ir/entity.rs +++ b/hir2/src/ir/entity.rs @@ -1,4 +1,6 @@ +mod group; mod list; +mod storage; use alloc::alloc::{AllocError, Layout}; use core::{ @@ -11,7 +13,12 @@ use core::{ ptr::NonNull, }; -pub use self::list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}; +pub use self::{ + group::EntityGroup, + list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}, + storage::{EntityRange, EntityRangeMut, EntityStorage}, +}; +use crate::any::*; /// A trait implemented by an IR entity that has a unique identifier /// @@ -22,6 +29,24 @@ pub trait Entity: Any { fn id(&self) -> Self::Id; } +/// A trait implemented by an IR entity that can be stored in [EntityStorage]. +pub trait StorableEntity { + /// Get the absolute index of this entity in its container. + fn index(&self) -> usize; + /// Set the absolute index of this entity in its container. + /// + /// # Safety + /// + /// This is intended to be called only by the [EntityStorage] implementation, as it is + /// responsible for maintaining indices of all items it is storing. However, entities commonly + /// want to know their own index in storage, so this trait allows them to conceptually own the + /// index, but delegate maintenance to [EntityStorage]. + unsafe fn set_index(&mut self, index: usize); + /// Called when this entity is removed from [EntityStorage] + #[inline(always)] + fn unlink(&mut self) {} +} + /// A trait that must be implemented by the unique identifier for an [Entity] pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash { fn as_usize(&self) -> usize; @@ -214,7 +239,7 @@ impl RawEntityRef, Metadata> { /// value really is in an initialized state. Calling this when the content is not yet fully /// initialized causes immediate undefined behavior. 
#[inline] - pub unsafe fn assume_init(self) -> RawEntityRef { + pub unsafe fn assume_init(self) -> RawEntityRef { let ptr = Self::into_inner(self); unsafe { RawEntityRef::from_inner(ptr.cast()) } } @@ -267,13 +292,15 @@ impl RawEntityRef { } /// Get a dynamically-checked immutable reference to the underlying `T` - pub fn borrow(&self) -> EntityRef<'_, T> { + #[track_caller] + pub fn borrow<'a, 'b: 'a>(&'a self) -> EntityRef<'b, T> { let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow() } } /// Get a dynamically-checked mutable reference to the underlying `T` - pub fn borrow_mut(&mut self) -> EntityMut<'_, T> { + #[track_caller] + pub fn borrow_mut<'a, 'b: 'a>(&'a mut self) -> EntityMut<'b, T> { let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow_mut() } } @@ -281,7 +308,7 @@ impl RawEntityRef { /// Try to get a dynamically-checked mutable reference to the underlying `T` /// /// Returns `None` if the entity is already borrowed - pub fn try_borrow_mut(&mut self) -> Option> { + pub fn try_borrow_mut<'a, 'b: 'a>(&'a mut self) -> Option> { let ptr: *mut RawEntityMetadata = NonNull::as_ptr(self.inner); unsafe { (*core::ptr::addr_of!((*ptr).entity)).try_borrow_mut().ok() } } @@ -336,38 +363,130 @@ impl RawEntityRef { } } -impl RawEntityRef { - /// Returns true if the underlying value is a `T` +impl RawEntityRef { + /// Casts this reference to the concrete type `T`, if the underlying value is a `T`. + /// + /// If the cast is not valid for this reference, `Err` is returned containing the original value. #[inline] - pub fn is(self) -> bool { - self.borrow().is::() + pub fn try_downcast( + self, + ) -> Result, RawEntityRef> + where + To: DowncastFromRef + 'static, + From: Is + AsAny + 'static, + Obj: ?Sized, + { + RawEntityRef::::try_downcast_from(self) } /// Casts this reference to the concrete type `T`, if the underlying value is a `T`. /// /// If the cast is not valid for this reference, `Err` is returned containing the original value. #[inline] - pub fn downcast(self) -> Result, Self> { - if self.borrow().is::() { - unsafe { Ok(Self::downcast_unchecked(self)) } - } else { - Err(self) - } + pub fn try_downcast_ref(&self) -> Option> + where + To: DowncastFromRef + 'static, + From: Is + AsAny + 'static, + Obj: ?Sized, + { + RawEntityRef::::try_downcast_from_ref(self) } - /// Casts this reference to the concrete type `T` without checking that the cast is valid. + /// Casts this reference to the concrete type `T`, if the underlying value is a `T`. /// - /// # Safety + /// Panics if the cast is not valid for this reference. + #[inline] + #[track_caller] + pub fn downcast(self) -> RawEntityRef + where + To: DowncastFromRef + 'static, + From: Is + AsAny + 'static, + Obj: ?Sized, + { + RawEntityRef::::downcast_from(self) + } + + /// Casts this reference to the concrete type `T`, if the underlying value is a `T`. /// - /// The referenced value must be of type `T`. Calling this method with the incorrect type is - /// _undefined behavior_. + /// Panics if the cast is not valid for this reference. 
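// Usage sketch (type names are hypothetical): given a type-erased `UnsafeEntityRef` to a
// `dyn Op` value that is actually an `AddOp`, `downcast` recovers the concrete
// `UnsafeEntityRef` to `AddOp`, panicking on a type mismatch, while `try_downcast` returns
// the original reference in the `Err` case instead of panicking.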
#[inline] - pub unsafe fn downcast_unchecked(self) -> RawEntityRef { - unsafe { - let ptr = RawEntityRef::into_inner(self); - RawEntityRef::from_inner(ptr.cast()) + #[track_caller] + pub fn downcast_ref(&self) -> RawEntityRef + where + To: DowncastFromRef + 'static, + From: Is + AsAny + 'static, + Obj: ?Sized, + { + RawEntityRef::::downcast_from_ref(self) + } +} + +impl RawEntityRef { + pub fn try_downcast_from( + from: RawEntityRef, + ) -> Result> + where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, + { + let borrow = from.borrow(); + if let Some(to) = borrow.as_any().downcast_ref() { + Ok(unsafe { RawEntityRef::from_raw(to) }) + } else { + Err(from) } } + + pub fn try_downcast_from_ref(from: &RawEntityRef) -> Option + where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, + { + let borrow = from.borrow(); + if let Some(to) = borrow.as_any().downcast_ref() { + Some(unsafe { RawEntityRef::from_raw(to) }) + } else { + None + } + } + + #[track_caller] + pub fn downcast_from(from: RawEntityRef) -> Self + where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, + { + let borrow = from.borrow(); + unsafe { RawEntityRef::from_raw(borrow.as_any().downcast_ref().expect("invalid cast")) } + } + + #[track_caller] + pub fn downcast_from_ref(from: &RawEntityRef) -> Self + where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, + { + let borrow = from.borrow(); + unsafe { RawEntityRef::from_raw(borrow.as_any().downcast_ref().expect("invalid cast")) } + } +} + +impl RawEntityRef { + /// Casts this reference to the an unsized type `Trait`, if `From` implements `Trait` + /// + /// If the cast is not valid for this reference, `Err` is returned containing the original value. 
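// Note: unlike the downcasts above, this conversion cannot fail; the `Unsize` bound on the
// source type statically guarantees that it can be coerced to the target trait object.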
+ #[inline] + pub fn upcast(self) -> RawEntityRef + where + To: ?Sized, + From: core::marker::Unsize + AsAny + 'static, + { + unsafe { RawEntityRef::::from_inner(self.inner) } + } } impl core::ops::CoerceUnsized> @@ -377,7 +496,17 @@ where U: ?Sized, { } - +impl Eq for RawEntityRef {} +impl PartialEq for RawEntityRef { + fn eq(&self, other: &Self) -> bool { + Self::ptr_eq(self, other) + } +} +impl core::hash::Hash for RawEntityRef { + fn hash(&self, state: &mut H) { + self.inner.hash(state); + } +} impl fmt::Pointer for RawEntityRef { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -391,16 +520,30 @@ impl fmt::Debug for RawEntityRef fmt::Debug::fmt(&self.borrow(), f) } } - -impl Eq for RawEntityRef {} -impl PartialEq for RawEntityRef { - fn eq(&self, other: &Self) -> bool { - Self::ptr_eq(self, other) +impl crate::formatter::PrettyPrint + for RawEntityRef +{ + #[inline] + fn render(&self) -> crate::formatter::Document { + self.borrow().render() } } -impl core::hash::Hash for RawEntityRef { - fn hash(&self, state: &mut H) { - self.inner.hash(state); +impl StorableEntity for RawEntityRef { + #[inline] + fn index(&self) -> usize { + self.borrow().index() + } + + #[inline] + unsafe fn set_index(&mut self, index: usize) { + unsafe { + self.borrow_mut().set_index(index); + } + } + + #[inline] + fn unlink(&mut self) { + self.borrow_mut().unlink() } } @@ -447,6 +590,12 @@ impl fmt::Display for EntityRef<'_, T> { (**self).fmt(f) } } +impl crate::formatter::PrettyPrint for EntityRef<'_, T> { + #[inline] + fn render(&self) -> crate::formatter::Document { + (**self).render() + } +} impl Eq for EntityRef<'_, T> {} impl PartialEq for EntityRef<'_, T> { fn eq(&self, other: &Self) -> bool { @@ -514,12 +663,12 @@ impl<'b, T: ?Sized> EntityMut<'b, T> { /// # Examples /// /// ```rust - /// use crate::*; + /// use midenc_hir2::*; /// use blink_alloc::Blink; /// /// let alloc = Blink::default(); - /// let entity = UnsafeEntityRef::new([1, 2, 3, 4], &alloc); - /// let borrow = entity.get_mut(); + /// let mut entity = UnsafeEntityRef::new([1, 2, 3, 4], &alloc); + /// let borrow = entity.borrow_mut(); /// let (mut begin, mut end) = EntityMut::map_split(borrow, |slice| slice.split_at_mut(2)); /// assert_eq!(*begin, [1, 2]); /// assert_eq!(*end, [3, 4]); @@ -584,6 +733,12 @@ impl fmt::Display for EntityMut<'_, T> { (**self).fmt(f) } } +impl crate::formatter::PrettyPrint for EntityMut<'_, T> { + #[inline] + fn render(&self) -> crate::formatter::Document { + (**self).render() + } +} impl Eq for EntityMut<'_, T> {} impl PartialEq for EntityMut<'_, T> { fn eq(&self, other: &Self) -> bool { diff --git a/hir2/src/ir/entity/group.rs b/hir2/src/ir/entity/group.rs new file mode 100644 index 000000000..f3e8bb69c --- /dev/null +++ b/hir2/src/ir/entity/group.rs @@ -0,0 +1,217 @@ +use core::fmt; + +/// Represents size and range information for a contiguous grouping of entities in a vector. +/// +/// This is used so that individual groups can be grown or shrunk, while maintaining stability +/// of references to items in other groups. 
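// Layout note: both fields are packed into a single `u16`, with the group's start index in the
// low byte and its length in the high byte. For example, `EntityGroup::new(10, 4)` stores
// `10 | (4 << 8)` = 0x040A, covering the range 10..14; as a consequence, both the start index
// and the length are effectively limited to `u8::MAX`.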
+#[derive(Default, Copy, Clone)] +pub struct EntityGroup(u16); +impl fmt::Debug for EntityGroup { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EntityGroup") + .field("range", &self.as_range()) + .field("len", &self.len()) + .finish() + } +} +impl EntityGroup { + const START_MASK: u16 = u8::MAX as u16; + + /// Create a new group of size `len`, starting at index `start` + pub fn new(start: usize, len: usize) -> Self { + let start = u16::try_from(start).expect("too many items"); + let len = u16::try_from(len).expect("group too large"); + let group = start | (len << 8); + + Self(group) + } + + /// Get the start index in the containing vector + #[inline] + pub fn start(&self) -> usize { + (self.0 & Self::START_MASK) as usize + } + + /// Get the end index (exclusive) in the containing vector + #[inline] + pub fn end(&self) -> usize { + self.start() + self.len() + } + + /// Returns true if this group is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get the number of items in this group + #[inline] + pub fn len(&self) -> usize { + (self.0 >> 8) as usize + } + + /// Get the [core::ops::Range] equivalent of this group + pub fn as_range(&self) -> core::ops::Range { + let start = self.start(); + let len = self.len(); + start..(start + len) + } + + /// Increase the size of this group by `n` items + /// + /// Panics if `n` overflows `u16::MAX`, or if the resulting size overflows `u8::MAX` + pub fn grow(&mut self, n: usize) { + let n = u16::try_from(n).expect("group is too large"); + let start = self.0 & Self::START_MASK; + let len = (self.0 >> 8) + n; + assert!(len <= u8::MAX as u16, "group is too large"); + self.0 = start | (len << 8); + } + + /// Decrease the size of this group by `n` items + /// + /// Panics if `n` overflows `u16::MAX`, or if `n` is greater than the number of remaining items. 
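// Note: shrinking by more than the remaining number of items saturates to an empty group
// rather than panicking (see the `saturating_sub` below and the `entity_group_shrink` test).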
+ pub fn shrink(&mut self, n: usize) { + let n = u16::try_from(n).expect("cannot shrink by a size larger than the max group size"); + let start = self.0 & Self::START_MASK; + let len = (self.0 >> 8).saturating_sub(n); + self.0 = start | (len << 8); + } + + /// Shift the position of this group by `offset` + pub fn shift_start(&mut self, offset: isize) { + let offset = i16::try_from(offset).expect("offset too large"); + let mut start = self.0 & Self::START_MASK; + if offset >= 0 { + start += offset as u16; + } else { + start -= offset.unsigned_abs(); + } + assert!(start <= Self::START_MASK, "group offset cannot be larger than u8::MAX"); + self.0 &= !Self::START_MASK; + self.0 |= start; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn entity_group_empty() { + let group = EntityGroup::new(0, 0); + assert_eq!(group.start(), 0); + assert_eq!(group.end(), 0); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 0..0); + + let group = EntityGroup::new(101, 0); + assert_eq!(group.start(), 101); + assert_eq!(group.end(), 101); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 101..101); + } + + #[test] + fn entity_group_non_empty() { + let group = EntityGroup::new(0, 1); + assert_eq!(group.start(), 0); + assert_eq!(group.end(), 1); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 0..1); + + let group = EntityGroup::new(255, 255); + assert_eq!(group.start(), 255); + assert_eq!(group.end(), 510); + assert_eq!(group.len(), 255); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 255..510); + } + + #[test] + fn entity_group_grow() { + let mut group = EntityGroup::new(10, 0); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.grow(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 11); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..11); + + group.grow(3); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 14); + assert_eq!(group.len(), 4); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..14); + } + + #[test] + fn entity_group_shrink() { + let mut group = EntityGroup::new(10, 4); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 14); + assert_eq!(group.len(), 4); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..14); + + group.shrink(3); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 11); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..11); + + group.shrink(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.shrink(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + } + + #[test] + fn entity_group_shift_start() { + let mut group = EntityGroup::new(10, 0); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.shift_start(10); + assert_eq!(group.start(), 20); + assert_eq!(group.end(), 20); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 20..20); + + group.shift_start(-5); 
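// A negative offset moves the group's start back toward the front of the vector: 20 - 5 = 15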
+ assert_eq!(group.start(), 15); + assert_eq!(group.end(), 15); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 15..15); + } +} diff --git a/hir2/src/ir/entity/list.rs b/hir2/src/ir/entity/list.rs index 7c38fe5ef..f3a8540fe 100644 --- a/hir2/src/ir/entity/list.rs +++ b/hir2/src/ir/entity/list.rs @@ -572,7 +572,7 @@ impl<'a, T> DoubleEndedIterator for EntityIter<'a, T> { } } -type IntrusiveLink = intrusive_collections::LinkedListLink; +pub type IntrusiveLink = intrusive_collections::LinkedListLink; impl RawEntityRef { /// Create a new [UnsafeIntrusiveEntityRef] by allocating `value` in `arena` diff --git a/hir2/src/ir/entity/storage.rs b/hir2/src/ir/entity/storage.rs new file mode 100644 index 000000000..fed8d6cc2 --- /dev/null +++ b/hir2/src/ir/entity/storage.rs @@ -0,0 +1,674 @@ +use core::fmt; + +use smallvec::{smallvec, SmallVec}; + +use super::{EntityGroup, StorableEntity}; + +/// [EntityStorage] provides an abstraction over storing IR entities in an [crate::Operation]. +/// +/// Specifically, it provides an abstraction for storing IR entities in a flat vector, while +/// retaining the ability to semantically group the entities, access them by group or individually, +/// and grow or shrink the group or overall set. +/// +/// The implementation expects the types stored in it to implement the [StorableEntity] trait, which +/// provides it the ability to ensure the entity is kept up to date with its position in the +/// set. Additionally, it ensures that removing an entity will unlink that entity from any +/// dependents or dependencies that it needs to maintain links for. +/// +/// Users can control the number of entities stored inline via the `INLINE` const parameter. By +/// default, only a single entity is stored inline, but sometimes more may be desired if you know +/// that a particular entity always has a particular cardinality. +pub struct EntityStorage { + /// The items being stored + items: SmallVec<[T; INLINE]>, + /// The semantic grouping information for this instance. + /// + /// There is always at least one group, and more can be explicitly added/removed. + groups: SmallVec<[EntityGroup; 2]>, +} + +impl fmt::Debug for EntityStorage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(core::any::type_name::()) + .field_with("groups", |f| { + let mut builder = f.debug_list(); + for group in self.groups.iter() { + let range = group.as_range(); + let items = &self.items[range.clone()]; + builder.entry_with(|f| { + f.debug_map().entry(&"range", &range).entry(&"items", &items).finish() + }); + } + builder.finish() + }) + .finish() + } +} + +impl Default for EntityStorage { + fn default() -> Self { + Self { + items: Default::default(), + groups: smallvec![EntityGroup::default()], + } + } +} + +impl EntityStorage { + /// Returns true if there are no items in storage. + #[inline] + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Returns the total number of items in storage. + #[inline] + pub fn len(&self) -> usize { + self.items.len() + } + + /// Returns the number of groups with allocated storage. 
+ #[inline] + pub fn num_groups(&self) -> usize { + self.groups.len() + } + + /// Get an iterator over all of the items in storage + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.items.iter() + } + + /// Get a mutable iterator over all of the items in storage + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { + self.items.iter_mut() + } +} + +impl EntityStorage { + /// Push an item to the last group + pub fn push(&mut self, mut item: T) { + let index = self.items.len(); + unsafe { + item.set_index(index); + } + self.items.push(item); + let group = self.groups.last_mut().unwrap(); + group.grow(1); + } + + /// Extend the last group with `items` + pub fn extend(&mut self, items: I) + where + I: IntoIterator, + { + let mut group = self.group_mut(self.groups.len() - 1); + group.extend(items); + } + + /// Push `items` as a new group, and return the group index + #[inline] + pub fn push_group(&mut self, items: impl IntoIterator) -> usize { + let group = self.groups.len(); + self.extend_group(group, items); + group + } + + /// Push `item` to the specified group + pub fn push_to_group(&mut self, group: usize, item: T) { + if self.groups.len() <= group { + let next_offset = self.groups.last().map(|group| group.as_range().end).unwrap_or(0); + self.groups.resize(group + 1, EntityGroup::new(next_offset, 0)); + } + let mut group = self.group_mut(group); + group.push(item); + } + + /// Pushes `items` to the given group, creating it if necessary, and allocating any intervening + /// implied groups if they have not been created it. + pub fn extend_group(&mut self, group: usize, items: I) + where + I: IntoIterator, + { + if self.groups.len() <= group { + let next_offset = self.groups.last().map(|group| group.as_range().end).unwrap_or(0); + self.groups.resize(group + 1, EntityGroup::new(next_offset, 0)); + } + let mut group = self.group_mut(group); + group.extend(items); + } + + /// Clear all items in storage + pub fn clear(&mut self) { + for mut item in self.items.drain(..) 
+            item.unlink();
+        }
+        self.groups.clear();
+        self.groups.push(EntityGroup::default());
+    }
+
+    /// Get all the items in storage
+    pub fn all(&self) -> EntityRange<'_, T> {
+        EntityRange {
+            range: 0..self.items.len(),
+            items: self.items.as_slice(),
+        }
+    }
+
+    /// Get an [EntityRange] covering items in the specified group
+    pub fn group(&self, group: usize) -> EntityRange<'_, T> {
+        EntityRange {
+            range: self.groups[group].as_range(),
+            items: self.items.as_slice(),
+        }
+    }
+
+    /// Get an [EntityRangeMut] covering items in the specified group
+    pub fn group_mut(&mut self, group: usize) -> EntityRangeMut<'_, T, INLINE> {
+        let range = self.groups[group].as_range();
+        EntityRangeMut {
+            group,
+            range,
+            groups: &mut self.groups,
+            items: &mut self.items,
+        }
+    }
+}
+impl<T, const INLINE: usize> core::ops::Index<usize> for EntityStorage<T, INLINE> {
+    type Output = T;
+
+    #[inline]
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.items[index]
+    }
+}
+impl<T, const INLINE: usize> core::ops::IndexMut<usize> for EntityStorage<T, INLINE> {
+    #[inline]
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        &mut self.items[index]
+    }
+}
+
+/// A reference to a range of items in [EntityStorage]
+pub struct EntityRange<'a, T> {
+    range: core::ops::Range<usize>,
+    items: &'a [T],
+}
+impl<'a, T> EntityRange<'a, T> {
+    /// Returns true if this range is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.as_slice().is_empty()
+    }
+
+    /// Returns the size of this range
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.as_slice().len()
+    }
+
+    /// Get this range as a slice
+    #[inline]
+    pub fn as_slice(&self) -> &[T] {
+        if self.range.is_empty() {
+            &[]
+        } else {
+            &self.items[self.range.start..self.range.end]
+        }
+    }
+
+    /// Get an iterator over the items in this range
+    #[inline]
+    pub fn iter(&self) -> core::slice::Iter<'_, T> {
+        self.as_slice().iter()
+    }
+
+    /// Get an item at the specified index relative to this range, or `None` if the index is out of bounds.
+    #[inline]
+    pub fn get(&self, index: usize) -> Option<&T> {
+        self.as_slice().get(index)
+    }
+}
+impl<'a, T> core::ops::Index<usize> for EntityRange<'a, T> {
+    type Output = T;
+
+    #[inline]
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.as_slice()[index]
+    }
+}
+
+/// A mutable range of items in [EntityStorage]
+///
+/// Items outside the range are not modified; however, the range itself can change size, which
+/// will shift other items around. Any other groups in [EntityStorage] will be updated to
+/// reflect such changes, so in general this should be transparent.
+pub struct EntityRangeMut<'a, T, const INLINE: usize = 1> {
+    group: usize,
+    range: core::ops::Range<usize>,
+    groups: &'a mut [EntityGroup],
+    items: &'a mut SmallVec<[T; INLINE]>,
+}
+impl<'a, T, const INLINE: usize> EntityRangeMut<'a, T, INLINE> {
+    /// Returns true if this range is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.as_slice().is_empty()
+    }
+
+    /// Get the number of items covered by this range
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.as_slice().len()
+    }
+
+    /// Get this range as a slice
+    #[inline]
+    pub fn as_slice(&self) -> &[T] {
+        if self.range.is_empty() {
+            &[]
+        } else {
+            &self.items[self.range.start..self.range.end]
+        }
+    }
+
+    /// Get this range as a mutable slice
+    #[inline]
+    pub fn as_slice_mut(&mut self) -> &mut [T] {
+        if self.range.is_empty() {
+            &mut []
+        } else {
+            &mut self.items[self.range.start..self.range.end]
+        }
+    }
+
+    /// Get an iterator over the items covered by this range
+    #[inline]
+    pub fn iter(&self) -> core::slice::Iter<'_, T> {
+        self.as_slice().iter()
+    }
+
+    /// Get a mutable iterator over the items covered by this range
+    #[inline]
+    pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> {
+        self.as_slice_mut().iter_mut()
+    }
+
+    /// Get a reference to the item at `index`, relative to the start of this range.
+    #[inline]
+    pub fn get(&self, index: usize) -> Option<&T> {
+        self.as_slice().get(index)
+    }
+
+    /// Get a mutable reference to the item at `index`, relative to the start of this range.
+    #[inline]
+    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
+        self.as_slice_mut().get_mut(index)
+    }
+}
+
+impl<'a, T: StorableEntity, const INLINE: usize> EntityRangeMut<'a, T, INLINE> {
+    /// Append `item` to storage at the end of this range
+    #[inline]
+    pub fn push(&mut self, item: T) {
+        self.extend([item]);
+    }
+
+    /// Append `items` to storage at the end of this range
+    pub fn extend<I>(&mut self, operands: I)
+    where
+        I: IntoIterator<Item = T>,
+    {
+        // Handle edge case where group is the last group
+        let is_last = self.groups.len() == self.group + 1;
+        if is_last {
+            self.extend_last(operands);
+        } else {
+            self.extend_within(operands);
+        }
+    }
+
+    fn extend_last<I>(&mut self, items: I)
+    where
+        I: IntoIterator<Item = T>,
+    {
+        let prev_len = self.items.len();
+        self.items.extend(items.into_iter().enumerate().map(|(i, mut item)| {
+            unsafe {
+                item.set_index(prev_len + i);
+            }
+            item
+        }));
+        let num_inserted = self.items.len().abs_diff(prev_len);
+        if num_inserted == 0 {
+            return;
+        }
+        self.groups[self.group].grow(num_inserted);
+        self.range = self.groups[self.group].as_range();
+    }
+
+    fn extend_within<I>(&mut self, items: I)
+    where
+        I: IntoIterator<Item = T>,
+    {
+        let prev_len = self.items.len();
+        let start = self.range.end;
+        self.items.insert_many(
+            start,
+            items.into_iter().enumerate().map(|(i, mut item)| {
+                unsafe {
+                    item.set_index(start + i);
+                }
+                item
+            }),
+        );
+        let num_inserted = self.items.len().abs_diff(prev_len);
+        if num_inserted == 0 {
+            return;
+        }
+        self.groups[self.group].grow(num_inserted);
+        self.range = self.groups[self.group].as_range();
+
+        // Shift groups
+        for group in self.groups[(self.group + 1)..].iter_mut() {
+            group.shift_start(num_inserted as isize);
+        }
+
+        // Shift item indices
+        let shifted = self.range.end;
+        for (offset, item) in self.items[shifted..].iter_mut().enumerate() {
+            unsafe {
+                item.set_index(shifted + offset);
+            }
+        }
+    }
+
+    /// Remove the last item from this group, or `None` if empty
+    pub fn pop(&mut self) -> Option<T> {
+        if self.range.is_empty() {
+            return None;
+        }
+        let index = self.range.end - 1;
+        self.range.end = index;
+        self.groups[self.group].shrink(1);
+        let mut removed = self.items.remove(index);
+        {
+            removed.unlink();
+        }
+
+        // Shift groups
+        let next_group = self.group + 1;
+        if next_group < self.groups.len() {
+            for group in self.groups[next_group..].iter_mut() {
+                group.shift_start(-1);
+            }
+        }
+
+        // Shift item indices
+        let next_item = index;
+        if next_item < self.items.len() {
+            for (offset, item) in self.items[next_item..].iter_mut().enumerate() {
+                unsafe {
+                    item.set_index(next_item + offset);
+                }
+            }
+        }
+
+        Some(removed)
+    }
+}
+impl<'a, T, const INLINE: usize> core::ops::Index<usize> for EntityRangeMut<'a, T, INLINE> {
+    type Output = T;
+
+    #[inline]
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.as_slice()[index]
+    }
+}
+impl<'a, T, const INLINE: usize> core::ops::IndexMut<usize> for EntityRangeMut<'a, T, INLINE> {
+    #[inline]
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        &mut self.as_slice_mut()[index]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
+    struct Item {
+        index: usize,
+        value: usize,
+    }
+    impl Item {
+        pub fn new(value: usize) -> Self {
+            Self { index: 0, value }
+        }
+    }
+    impl StorableEntity for Item {
+        fn index(&self) -> usize {
+            self.index
+        }
+
+        unsafe fn set_index(&mut self, index: usize) {
+            self.index = index;
+        }
+
+        fn unlink(&mut self) {}
+    }
+
+    type ItemStorage = EntityStorage<Item, 1>;
+    #[allow(unused)]
+    type ItemRange<'a> = EntityRange<'a, Item>;
+    #[allow(unused)]
+    type ItemRangeMut<'a> = EntityRangeMut<'a, Item, 1>;
+
+    #[test]
+    fn entity_storage_empty_operations() {
+        let mut storage = ItemStorage::default();
+
+        // No items, but always have at least one group
+        assert_eq!(storage.len(), 0);
+        assert!(storage.is_empty());
+        assert_eq!(storage.num_groups(), 1);
+
+        {
+            let range = storage.all();
+            assert_eq!(range.len(), 0);
+            assert!(range.is_empty());
+            assert_eq!(range.as_slice(), &[]);
+            assert_eq!(range.iter().next(), None);
+        }
+
+        // No items, two groups
+        let group = storage.push_group(None);
+        assert_eq!(group, 1);
+        assert_eq!(storage.num_groups(), 2);
+        assert_eq!(storage.len(), 0);
+        assert!(storage.is_empty());
+
+        {
+            let range = storage.group(0);
+            assert_eq!(range.len(), 0);
+            assert!(range.is_empty());
+            assert_eq!(range.as_slice(), &[]);
+            assert_eq!(range.iter().next(), None);
+        }
+    }
+
+    #[test]
+    fn entity_storage_push_to_empty_group_entity_range() {
+        let mut storage = ItemStorage::default();
+
+        // Get group as mutable range
+        let mut group_range = storage.group_mut(0);
+
+        // Verify handling of empty group in EntityRangeMut
+        assert_eq!(group_range.len(), 0);
+        assert!(group_range.is_empty());
+        assert_eq!(group_range.as_slice(), &[]);
+        assert_eq!(group_range.iter().next(), None);
+
+        // Push items to range
+        group_range.push(Item::new(0));
+        group_range.push(Item::new(1));
+
+        // Verify range reflects changes
+        assert_eq!(group_range.len(), 2);
+        assert!(!group_range.is_empty());
+        assert_eq!(
+            group_range.as_slice(),
+            &[Item { index: 0, value: 0 }, Item { index: 1, value: 1 }]
+        );
+        assert_eq!(group_range.iter().next(), Some(&Item { index: 0, value: 0 }));
+    }
+
+    #[test]
+    fn entity_storage_pop_from_non_empty_group_entity_range() {
+        let mut storage = ItemStorage::default();
+
+        assert_eq!(storage.num_groups(), 1);
+        storage.push_to_group(0, Item::new(0));
+        assert_eq!(storage.len(), 1);
+        assert!(!storage.is_empty());
+
+        // Get group as mutable range
+        let mut group_range = storage.group_mut(0);
+
assert_eq!(group_range.len(), 1); + assert!(!group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[Item { index: 0, value: 0 }]); + assert_eq!(group_range.iter().next(), Some(&Item { index: 0, value: 0 })); + + // Pop item from range + let item = group_range.pop(); + assert_eq!(item, Some(Item { index: 0, value: 0 })); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + assert_eq!(group_range.range.clone(), 0..0); + + // Pop from empty range should have no effect + let item = group_range.pop(); + assert_eq!(item, None); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + assert_eq!(group_range.range.clone(), 0..0); + } + + #[test] + fn entity_storage_push_to_empty_group_entity_range_before_other_groups() { + let mut storage = ItemStorage::default(); + + storage.extend_group(0, [Item::new(0), Item::new(1)]); + let group1 = storage.push_group(None); + let group2 = storage.push_group(None); + let group3 = storage.push_group([Item::new(4), Item::new(5)]); + + assert!(!storage.is_empty()); + assert_eq!(storage.len(), 4); + assert_eq!(storage.num_groups(), 4); + + assert_eq!(storage.group(0).range.clone(), 0..2); + assert_eq!(storage.group(1).range.clone(), 2..2); + assert_eq!(storage.group(2).range.clone(), 2..2); + assert_eq!(storage.group(3).range.clone(), 2..4); + + // Insert items into first non-empty group + { + let mut group_range = storage.group_mut(group1); + + // Verify handling of empty group in EntityRangeMut + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + + // Push items to range + group_range.push(Item::new(2)); + group_range.push(Item::new(3)); + + // Verify range reflects changes + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 2, value: 2 }, Item { index: 3, value: 3 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 2, value: 2 })); + } + + // The subsequent empty group should still be empty, but at a new offset + let group_range = storage.group(group2); + assert_eq!(group_range.range.clone(), 4..4); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + + // The trailing non-empty group should have updated offsets + let group_range = storage.group(group3); + assert_eq!(group_range.range.clone(), 4..6); + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 4, value: 4 }, Item { index: 5, value: 5 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 4, value: 4 })); + } + + #[test] + fn entity_storage_pop_from_non_empty_group_entity_range_before_other_groups() { + let mut storage = ItemStorage::default(); + + storage.extend_group(0, [Item::new(0), Item::new(1)]); + let group1 = storage.push_group(None); + let group2 = storage.push_group(None); + let group3 = storage.push_group([Item::new(4), Item::new(5)]); + + assert!(!storage.is_empty()); + assert_eq!(storage.len(), 4); + assert_eq!(storage.num_groups(), 4); + + assert_eq!(storage.group(0).range.clone(), 0..2); + assert_eq!(storage.group(1).range.clone(), 2..2); + 
assert_eq!(storage.group(2).range.clone(), 2..2); + assert_eq!(storage.group(3).range.clone(), 2..4); + + // Pop from group0 + { + let mut group_range = storage.group_mut(0); + let item = group_range.pop(); + assert_eq!(item, Some(Item { index: 1, value: 1 })); + assert_eq!(group_range.len(), 1); + assert!(!group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[Item { index: 0, value: 0 }]); + } + + // The subsequent empty group(s) should still be empty, but at a new offset + for group_index in [group1, group2] { + let group_range = storage.group(group_index); + assert_eq!(group_range.range.clone(), 1..1); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + } + + // The trailing non-empty group should have updated offsets + let group_range = storage.group(group3); + assert_eq!(group_range.range.clone(), 1..3); + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 1, value: 4 }, Item { index: 2, value: 5 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 1, value: 4 })); + } +} From 9238669c2e140b75b55b559909da7e7dd0930216 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 4 Oct 2024 16:45:45 -0400 Subject: [PATCH 13/31] wip: implement key infrastructure for operations, patterns, builders This commit moves away from `derive!` for operation definitions, to the new `#[operation]` macro, and implements a number of changes/improvements to the `Operation` type and its related APIs. In particular with this commit, the builder infrastructure for operations was finalized and started to be tested with "real" ops. Validation was further improved by typing operands with type constraints that are validated when an op is verified. The handling of successors, operands, and results was improved and unified around a shared abstraction. The pattern and pattern rewriter infrastructure was sketched out as well, and I anticipate landing those next, along with tests for all of it. Once patterns are done, the next two items of note are the data analysis framework, and the conversion framework. With those done, we're ready to move to the new IR. 
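For a sense of the new surface area, the following is a condensed sketch of defining and
building an op with the new infrastructure, adapted from the `AddOp` test included in this
patch. The variable names, the span, and the exact generic arguments passed to `create` are
illustrative approximations of that test, not additional API:

    // Op definition via the new `#[operation]` macro (the `#[order(n)]` operand
    // attributes used in the test are omitted here for brevity).
    #[operation(
        dialect = HirDialect,
        traits(ArithmeticOp, BinaryOp, Commutative, SingleBlock, SameTypeOperands),
        implements(InferTypeOpInterface)
    )]
    struct AddOp {
        #[attr]
        overflow: Overflow,
        #[operand]
        lhs: AnyInteger,
        #[operand]
        rhs: AnyInteger,
        #[result]
        result: AnyInteger,
    }

    // Building an instance: `create` returns a callable that takes the op's operands
    // and attributes, so the type constraints on operands (here, AnyInteger) are
    // checked when the op is built and verified. `context`, `block`, `lhs`, and `rhs`
    // come from the surrounding test setup; the turbofish on `create` is assumed.
    let mut builder = context.builder();
    builder.set_insertion_point_to_end(block);
    let op_builder = builder.create::<AddOp, _>(SourceSpan::default());
    let add = op_builder(lhs, rhs, Overflow::Wrapping)?;
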
--- hir2/src/derive.rs | 1432 ++------------------ hir2/src/dialects/hir.rs | 76 +- hir2/src/dialects/hir/builders.rs | 3 + hir2/src/dialects/hir/builders/function.rs | 858 ++++++++++++ hir2/src/dialects/hir/ops.rs | 7 +- hir2/src/dialects/hir/ops/assertions.rs | 55 + hir2/src/dialects/hir/ops/binary.rs | 741 ++++++++-- hir2/src/dialects/hir/ops/cast.rs | 170 ++- hir2/src/dialects/hir/ops/control.rs | 237 ++-- hir2/src/dialects/hir/ops/invoke.rs | 44 +- hir2/src/dialects/hir/ops/mem.rs | 71 +- hir2/src/dialects/hir/ops/primop.rs | 90 +- hir2/src/dialects/hir/ops/ternary.rs | 45 + hir2/src/dialects/hir/ops/unary.rs | 198 ++- hir2/src/ir.rs | 39 +- hir2/src/ir/block.rs | 78 ++ hir2/src/ir/builder.rs | 318 +++++ hir2/src/ir/context.rs | 98 +- hir2/src/ir/dialect.rs | 58 +- hir2/src/ir/function.rs | 135 +- hir2/src/ir/module.rs | 202 ++- hir2/src/ir/op.rs | 36 +- hir2/src/ir/operands.rs | 525 +------ hir2/src/ir/operation.rs | 542 +++++++- hir2/src/ir/operation/builder.rs | 284 ++-- hir2/src/ir/operation/name.rs | 175 ++- hir2/src/ir/region.rs | 4 + hir2/src/ir/successor.rs | 218 ++- hir2/src/ir/symbol_table.rs | 252 +++- hir2/src/ir/traits.rs | 52 +- hir2/src/ir/traits/callable.rs | 24 +- hir2/src/ir/traits/info.rs | 83 ++ hir2/src/ir/traits/multitrait.rs | 179 --- hir2/src/ir/traits/types.rs | 45 +- hir2/src/ir/value.rs | 110 +- hir2/src/ir/verifier.rs | 58 +- hir2/src/ir/visit.rs | 522 +++++-- hir2/src/lib.rs | 21 +- hir2/src/patterns.rs | 11 + hir2/src/patterns/applicator.rs | 179 +++ hir2/src/patterns/pattern.rs | 226 +++ hir2/src/patterns/pattern_set.rs | 110 ++ hir2/src/patterns/rewriter.rs | 523 +++++++ 43 files changed, 6217 insertions(+), 2917 deletions(-) create mode 100644 hir2/src/dialects/hir/builders.rs create mode 100644 hir2/src/dialects/hir/builders/function.rs create mode 100644 hir2/src/dialects/hir/ops/assertions.rs create mode 100644 hir2/src/dialects/hir/ops/ternary.rs create mode 100644 hir2/src/ir/builder.rs create mode 100644 hir2/src/ir/traits/info.rs delete mode 100644 hir2/src/ir/traits/multitrait.rs create mode 100644 hir2/src/patterns.rs create mode 100644 hir2/src/patterns/applicator.rs create mode 100644 hir2/src/patterns/pattern.rs create mode 100644 hir2/src/patterns/pattern_set.rs create mode 100644 hir2/src/patterns/rewriter.rs diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs index d6ef93136..b8431f3a4 100644 --- a/hir2/src/derive.rs +++ b/hir2/src/derive.rs @@ -1,22 +1,8 @@ +pub use midenc_hir_macros::operation; + use crate::Operation; -/// This macro is used to generate the boilerplate for [Op] implementations. -/// -/// TODO(pauls): -/// -/// * Support doc comments -/// * Implement type constraints/inference -/// * Implement `verify` blocks for custom verification rules -/// * FIX: Currently #[operands] simply adds boilerplate for creating an operation with those -/// operands, but it does not create methods to access them, and simply adds them in with the -/// other operands. We should figure out how to store operands in such a way that multiple operand -/// groups can be maintained even when adding/removing operands later. -/// * FIX: Currently #[successors] adds a field to the struct to store whatever custom type is used -/// to represent the successors, but these successors are not reachable from the Operation backing -/// the op, and as a result, any successor operations acting from the Operation and not the Op may -/// cause the two to converge. 
Like the #[operands] issue above, we need to store the actual -/// successor in the Operation, and provide some way to map between the two, OR change how we -/// represent successors to allow storing arbitrary successor-like types in the Operation +/// This macro is used to generate the boilerplate for operation trait implementations. #[macro_export] macro_rules! derive { ( @@ -74,42 +60,6 @@ macro_rules! derive { $($t)* }; - - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident : Op { - $( - $(#[$inner:ident $($args:tt)*])* - $Field:ident: $FieldTy:ty, - )* - } - - $(derives $DerivedOpTrait:ident $(, $MoreDerivedTraits:ident)*;)* - $(implements $ImplementedOpTrait:ident $(, $MoreImplementedTraits:ident)*;)* - ) => { - $crate::__derive_op!( - $(#[$outer])* - $vis struct $Op { - $( - $(#[$inner $($args)*])* - $Field: $FieldTy, - )* - } - - $( - derives $DerivedOpTrait - $( - derives $MoreDerivedTraits - )* - )* - $( - implements $ImplementedOpTrait - $( - implements $MoreImplementedTraits - )* - )* - ); - }; } #[doc(hidden)] @@ -180,1275 +130,6 @@ macro_rules! __derive_op_trait { }; } -#[doc(hidden)] -#[macro_export] -macro_rules! __derive_op { - // Entry - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident { - $( - $(#[$inner:ident $($args:tt)*])* - $Field:ident: $FieldTy:ty, - )* - } - - $(derives $DerivedOpTrait:ident)* - $(implements $ImplementedOpTrait:ident)* - ) => { - $crate::__derive_op_processor! { - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - $( - { - unprocessed: [$(#[$inner $($args)*])*], - ignore: [], - field: $Field, - field_type: $FieldTy, - } - )* - ], - processed: { - fields: [], - dialect: [], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [], - regions_count: [0usize], - regions: [], - successor_groups_count: [0usize], - successor_groups: [], - successors_count: [0usize], - successors: [], - operand_groups_count: [0usize], - operand_groups: [], - operands_count: [0usize], - operands: [], - results_count: [0usize], - results: [], - } - } - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! 
__derive_op_processor { - // Handle duplicate `dialect` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[dialect] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$(dialect_processed:tt)+], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - compile_error!("unexpected duplicate dialect attr: got '{}', but '{}' was previously seen", stringify!($Dialect), stringify!($dialect_processed)); - }; - - // Handle `dialect` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[dialect] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [dialect $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [dialect $FieldTy], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `region` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[region $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [region $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [1usize + $regions_count], - regions: [region $Field at $regions_count $($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `successor` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[successor $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [successor $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [1usize + $succ_count], - successors: [successor $Field at $succ_count $($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `successors` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[successors $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [successors $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [1usize + $succ_groups_count], - successor_groups: [successors $Field : $FieldTy $($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `operand` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[operand $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [operand $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [1usize + $operands_count], - operands: [operand $Field at $operands_count $($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `operands` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[operands $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [operands $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [1usize + $operand_groups_count], - operand_groups: [operands $Field at $operand_groups_count $($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `result` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[result $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [result $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [1usize + $results_count], - results: [result $Field at $results_count $($results_processed)*], - } - } - }; - - // Handle `attr` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[attr $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [attr $($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [attr $Field: $FieldTy $($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle `doc` attr - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [ - #[doc $($args:tt)*] - $($attrs_rest:tt)* - ], - ignore: [$($IgnoredReason:tt)*], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - { - unprocessed: [ - $($attrs_rest)* - ], - ignore: [$($IgnoredReason)*], - field: $Field, - field_type: $FieldTy, - } - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle end of unprocessed attributes (ignore=false) - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [], - ignore: [], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - $($fields_rest)* - ], - processed: { - fields: [field $Field: $FieldTy $($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle end of unprocessed attributes (ignore=true) - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [ - { - unprocessed: [], - ignore: [$($IgnoredReason:tt)+], - field: $Field:ident, - field_type: $FieldTy:ty, - } - $($fields_rest:tt)* - ], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_processor! 
{ - $(#[$outer])* - $vis struct $Op; - - unprocessed: [ - $($fields_rest)* - ], - processed: { - fields: [$($extra_fields_processed)*], - dialect: [$($dialect_processed)*], - traits: [$(derives $DerivedOpTrait),* $(implements $ImplementedOpTrait),*], - attrs: [$($attrs_processed)*], - regions_count: [$regions_count], - regions: [$($regions_processed)*], - successor_groups_count: [$succ_groups_count], - successor_groups: [$($succ_groups_processed)*], - successors_count: [$succ_count], - successors: [$($succ_processed)*], - operand_groups_count: [$operand_groups_count], - operand_groups: [$($operand_groups_processed)*], - operands_count: [$operands_count], - operands: [$($operands_processed)*], - results_count: [$results_count], - results: [$($results_processed)*], - } - } - }; - - // Handle end of unprocessed fields - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - unprocessed: [], - processed: { - fields: [$($extra_fields_processed:tt)*], - dialect: [$($dialect_processed:tt)*], - traits: [$(derives $DerivedOpTrait:ident),* $(implements $ImplementedOpTrait:ident),*], - attrs: [$($attrs_processed:tt)*], - regions_count: [$regions_count:expr], - regions: [$($regions_processed:tt)*], - successor_groups_count: [$succ_groups_count:expr], - successor_groups: [$($succ_groups_processed:tt)*], - successors_count: [$succ_count:expr], - successors: [$($succ_processed:tt)*], - operand_groups_count: [$operand_groups_count:expr], - operand_groups: [$($operand_groups_processed:tt)*], - operands_count: [$operands_count:expr], - operands: [$($operands_processed:tt)*], - results_count: [$results_count:expr], - results: [$($results_processed:tt)*], - } - ) => { - $crate::__derive_op_impl!( - $(#[$outer])* - $vis struct $Op; - - $($dialect_processed)*; - $($extra_fields_processed)*; - $(derives $DerivedOpTrait)*; - $(implements $ImplementedOpTrait)*; - $($attrs_processed)*; - regions $regions_count; - $($regions_processed)*; - $($succ_groups_processed)*; - $($succ_processed)*; - $($operand_groups_processed)*; - $($operands_processed)*; - $($results_processed)*; - ); - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! 
__derive_op_impl { - ( - $(#[$outer:meta])* - $vis:vis struct $Op:ident; - - dialect $Dialect:ty; - $(field $Field:ident: $FieldTy:ty)*; - $(derives $DerivedOpTrait:ident)*; - $(implements $ImplementedOpTrait:ident)*; - $(attr $AttrField:ident: $AttrTy:ty)*; - regions $NumRegions:expr; - $(region $RegionField:ident at $RegionIdx:expr)*; - $(successors $SuccGroupField:ident: $SuccGroupTy:ty)*; - $(successor $SuccField:ident at $SuccIdx:expr)*; - $(operands $OperandGroupField:ident at $OperandGroupIdx:expr)*; - $(operand $Operand:ident at $OperandIdx:expr)*; - $(result $Result:ident at $ResultIdx:expr)*; - - ) => { - $(#[$outer])* - $vis struct $Op { - op: $crate::Operation, - $( - $Field: $FieldTy, - )* - $( - $SuccGroupField: $SuccGroupTy, - )* - } - impl ::midenc_session::diagnostics::Spanned for $Op { - fn span(&self) -> ::midenc_session::diagnostics::SourceSpan { - self.op.span() - } - } - - #[allow(unused)] - impl $Op { - /// Get a new, uninitialized instance of this op - pub fn uninit($($Field: $FieldTy),*) -> Self { - let mut op = $crate::Operation::uninit::(); - Self { - op, - $( - $Field, - )* - $( - $SuccGroupField: Default::default(), - )* - } - } - - pub fn create( - context: &$crate::Context - $( - , $Operand: $crate::ValueRef - )* - $( - , $OperandGroupField: impl IntoIterator - )* - $( - , $Field: $FieldTy - )* - $( - , $AttrField: $AttrTy - )* - $( - , $SuccGroupField: $SuccGroupTy - )* - $( - , $SuccField: $crate::OpSuccessor - )* - ) -> Result<$crate::UnsafeIntrusiveEntityRef<$Op>, $crate::Report> { - let mut this = Self::uninit($($Field),*); - $( - this.$SuccGroupField = $SuccGroupField.clone(); - )* - - let mut builder = $crate::OperationBuilder::::new(context, this); - $( - builder.implement::(); - )* - $( - builder.implement::(); - )* - $( - builder.with_attr(stringify!($AttrField), $AttrField); - )* - builder.with_operands([$($Operand),*]); - $( - builder.with_operands_in_group($OperandGroupIdx, $OperandGroupField); - )* - $( - #[doc = stringify!($RegionField)] - builder.create_region(); - )* - $( - builder.with_successors($SuccGroupField); - )* - $( - builder.with_successor($SuccField); - )* - let num_results = const { - let results: &[usize] = &[$($ResultIdx),*]; - results.len() - }; - builder.with_results(num_results); - builder.build() - } - - $( - #[inline] - fn $Field(&self) -> &$FieldTy { - &self.$Field - } - - paste::paste! { - #[inline] - fn [<$Field _mut>](&mut self) -> &mut $FieldTy { - &mut self.$Field - } - - #[doc = concat!("Set the value of ", stringify!($Field))] - #[inline] - fn [](&mut self, $Field: $FieldTy) { - self.$Field = $Field; - } - } - )* - - $( - fn $AttrField(&self) -> &$AttrTy { - let sym = stringify!($AttrField); - self.op.get_typed_attribute::<$AttrTy, _>(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap() - } - - paste::paste! { - fn [<$AttrField _mut>](&mut self) -> &mut $AttrTy { - let sym = stringify!($AttrField); - self.op.get_typed_attribute_mut::<$AttrTy, _>(&::midenc_hir_symbol::Symbol::intern(sym)).unwrap() - } - - fn [](&mut self, value: $AttrTy) { - let sym = stringify!($AttrField); - self.op.set_attribute(::midenc_hir_symbol::Symbol::intern(sym), Some(value)); - } - } - )* - - $( - fn $RegionField(&self) -> $crate::EntityRef<'_, $crate::Region> { - self.op.region($RegionIdx) - } - - paste::paste! 
{ - fn [<$RegionField _mut>](&mut self) -> $crate::EntityMut<'_, $crate::Region> { - self.op.region_mut($RegionIdx) - } - } - )* - - $( - #[inline] - fn $SuccGroupField(&self) -> &$SuccGroupTy { - &self.$SuccGroupField - } - - paste::paste! { - #[inline] - fn [<$SuccGroupField _mut>](&mut self) -> &mut $SuccGroupTy { - &mut self.$SuccGroupField - } - } - )* - - $( - #[inline] - fn $SuccField(&self) -> &$crate::OpSuccessor { - &self.successors()[$SuccIdx] - } - - paste::paste! { - #[inline] - fn [<$SuccField _mut>](&mut self) -> &mut $crate::OpSuccessor { - &mut self.successors_mut()[$SuccIdx] - } - } - )* - - $( - fn $Operand(&self) -> $crate::EntityRef<'_, $crate::OpOperandImpl> { - self.op.operands()[$OperandIdx].borrow() - } - - paste::paste! { - fn [<$Operand _mut>](&mut self) -> $crate::EntityMut<'_, $crate::OpOperandImpl> { - self.op.operands_mut()[$OperandIdx].borrow_mut() - } - } - )* - - $( - fn $OperandGroupField(&self) -> $crate::OpOperandRange<'_> { - self.op.operands().group($OperandGroupIdx) - } - - paste::paste! { - fn [<$OperandGroupField _mut>](&mut self) -> $crate::OpOperandRangeMut<'_> { - self.op.operands_mut().group_mut($OperandGroupIdx) - } - } - )* - - $( - fn $Result(&self) -> $crate::EntityRef<'_, dyn $crate::Value> { - self.results()[$ResultIdx].borrow() - } - - paste::paste! { - fn [<$Result _mut>](&mut self) -> $crate::EntityMut<'_, dyn $crate::Value> { - self.op.results_mut()[$ResultIdx].borrow_mut() - } - } - )* - } - - impl AsRef<$crate::Operation> for $Op { - #[inline(always)] - fn as_ref(&self) -> &$crate::Operation { - &self.op - } - } - - impl AsMut<$crate::Operation> for $Op { - #[inline(always)] - fn as_mut(&mut self) -> &mut $crate::Operation { - &mut self.op - } - } - - $crate::__derive_op_name!($Op, $Dialect); - - impl $crate::Op for $Op { - fn name(&self) -> $crate::OperationName { - paste::paste! { - *[<__ $Op _NAME>] - } - } - - #[inline(always)] - fn as_operation(&self) -> &$crate::Operation { - &self.op - } - - #[inline(always)] - fn as_operation_mut(&mut self) -> &mut $crate::Operation { - &mut self.op - } - } - - $crate::__derive_op_traits!($Op $(, derive $DerivedOpTrait)* $(, implement $ImplementedOpTrait)*); - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __derive_op_name { - ($Op:ident, $Dialect:ty) => { - paste::paste! { - #[allow(non_upper_case_globals)] - static [<__ $Op _NAME>]: ::std::sync::LazyLock<$crate::OperationName> = ::std::sync::LazyLock::new(|| { - const DIALECT: $Dialect = <$Dialect as $crate::Dialect>::INIT; - - // CondBrOp => CondBr => cond_br - // Add => add - let type_name = stringify!($Op); - let type_name = type_name.strip_suffix("Op").unwrap_or(type_name); - let mut buf = ::alloc::string::String::with_capacity(type_name.len()); - let mut word_started_at = None; - for (i, c) in type_name.char_indices() { - if c.is_ascii_uppercase() { - if word_started_at.is_some() { - buf.push('_'); - buf.push(c.to_ascii_lowercase()); - } else { - word_started_at = Some(i); - buf.push(c.to_ascii_lowercase()); - } - } else if word_started_at.is_none() { - word_started_at = Some(i); - buf.push(c); - } else { - buf.push(c); - } - } - let name = ::midenc_hir_symbol::Symbol::intern(buf); - let dialect = <$Dialect as $crate::Dialect>::name(&DIALECT); - $crate::OperationName::new(dialect, name) - }); - } - } -} - -/// This macro emits the trait derivations and specialized verifier for a given [Op] impl. -#[doc(hidden)] -#[macro_export] -macro_rules! 
__derive_op_traits { - ($T:ty) => { - impl $crate::OpVerifier for $T { - #[inline(always)] - fn verify(&self, _context: &$crate::Context) -> Result<(), $crate::Report> { - Ok(()) - } - } - }; - - ($T:ty $(, derive $DeriveTrait:ident)* $(, implement $ImplementTrait:ident)*) => { - $( - impl $DeriveTrait for $T {} - )* - - impl $crate::OpVerifier for $T { - fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { - #[allow(unused_parens)] - type OpVerifierImpl<'a> = $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $DeriveTrait,)* $(&'a dyn $ImplementTrait),*)>; - #[allow(unused_parens)] - impl<'a> $crate::OpVerifier for $crate::derive::DeriveVerifier<'a, $T, ($(&'a dyn $DeriveTrait,)* $(&'a dyn $ImplementTrait),*)> - where - $( - $T: $crate::verifier::Verifier, - )* - $( - $T: $crate::verifier::Verifier, - )* - { - fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { - let op = self.downcast_ref::<$T>().unwrap(); - $( - if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { - <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; - } - )* - $( - if const { !<$T as $crate::verifier::Verifier>::VACUOUS } { - <$T as $crate::verifier::Verifier>::maybe_verify(op, context)?; - } - )* - - Ok(()) - } - } - - let verifier = OpVerifierImpl::new(&self.op); - verifier.verify(context) - } - } - } -} - /// This type represents the concrete set of derived traits for some op `T`, paired with a /// type-erased [Operation] reference for an instance of that op. /// @@ -1483,11 +164,13 @@ impl<'a, T, Derived: ?Sized> core::ops::Deref for DeriveVerifier<'a, T, Derived> #[cfg(test)] mod tests { + use alloc::rc::Rc; use core::fmt; + use super::operation; use crate::{ - define_attr_type, dialects::hir::HirDialect, traits::*, Context, Op, Operation, Report, - Spanned, + define_attr_type, dialects::hir::HirDialect, formatter, traits::*, Builder, Context, Op, + Operation, Report, Spanned, Value, }; #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -1503,26 +186,50 @@ mod tests { fmt::Debug::fmt(self, f) } } - define_attr_type!(Overflow); - - derive! { - /// An example op implementation to make sure all of the type machinery works - struct AddOp : Op { - #[dialect] - dialect: HirDialect, - #[attr] - overflow: Overflow, - #[operand] - lhs: OpOperand, - #[operand] - rhs: OpOperand, + impl formatter::PrettyPrint for Overflow { + fn render(&self) -> formatter::Document { + use formatter::*; + display(self) } + } + define_attr_type!(Overflow); - derives SingleBlock, SameTypeOperands; - implements ArithmeticOp; + /// An example op implementation to make sure all of the type machinery works + #[operation( + dialect = HirDialect, + traits(ArithmeticOp, BinaryOp, Commutative, SingleBlock, SameTypeOperands), + implements(InferTypeOpInterface) + )] + struct AddOp { + #[attr] + overflow: Overflow, + #[operand] + #[order(0)] + lhs: AnyInteger, + #[operand] + #[order(1)] + rhs: AnyInteger, + #[result] + result: AnyInteger, } - impl ArithmeticOp for AddOp {} + impl InferTypeOpInterface for AddOp { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty(); + { + let rhs = self.rhs(); + let rhs = rhs.value(); + let rhs_ty = rhs.ty(); + if &lhs != rhs_ty { + return Err(Report::msg(format!( + "lhs and rhs types do not match: expected '{lhs}', got '{rhs_ty}'" + ))); + } + } + self.result_mut().set_type(lhs); + Ok(()) + } + } derive! 
{ /// A marker trait for arithmetic ops @@ -1547,15 +254,21 @@ mod tests { } #[test] - fn test_derived_op() { - use crate::Type; - - let context = Context::default(); - let block = context.create_block_with_params([Type::U32, Type::I64]); - let block = block.borrow(); - let lhs = block.get_argument(0); - let rhs = block.get_argument(1); - let op = AddOp::create(&context, rhs, lhs, Overflow::Wrapping); + fn derived_op_builder_test() { + use crate::{SourceSpan, Type}; + + let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::U32, Type::U32]); + let (lhs, rhs) = { + let block = block.borrow(); + let lhs = block.get_argument(0).upcast::(); + let rhs = block.get_argument(1).upcast::(); + (lhs, rhs) + }; + let mut builder = context.builder(); + builder.set_insertion_point_to_end(block); + let op_builder = builder.create::(SourceSpan::default()); + let op = op_builder(lhs, rhs, Overflow::Wrapping); let op = op.expect("failed to create AddOp"); let op = op.borrow(); assert!(op.as_operation().implements::()); @@ -1563,4 +276,25 @@ mod tests { !>::VACUOUS )); } + + #[test] + #[should_panic = "lhs and rhs types do not match: expected 'u32', got 'i64'"] + fn derived_op_verifier_test() { + use crate::{SourceSpan, Type}; + + let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::U32, Type::I64]); + let (lhs, invalid_rhs) = { + let block = block.borrow(); + let lhs = block.get_argument(0).upcast::(); + let rhs = block.get_argument(1).upcast::(); + (lhs, rhs) + }; + let mut builder = context.builder(); + builder.set_insertion_point_to_end(block); + // Try to create instance of AddOp with mismatched operand types + let op_builder = builder.create::(SourceSpan::default()); + let op = op_builder(lhs, invalid_rhs, Overflow::Wrapping); + let _op = op.unwrap(); + } } diff --git a/hir2/src/dialects/hir.rs b/hir2/src/dialects/hir.rs index 976e9ef46..fae91c19b 100644 --- a/hir2/src/dialects/hir.rs +++ b/hir2/src/dialects/hir.rs @@ -1,14 +1,78 @@ +mod builders; mod ops; -pub use self::ops::*; -use crate::{interner, Dialect, DialectName}; +use alloc::rc::Rc; +use core::cell::{Cell, RefCell}; -#[derive(Default, Debug)] -pub struct HirDialect; -impl Dialect for HirDialect { - const INIT: Self = HirDialect; +pub use self::{ + builders::{DefaultInstBuilder, FunctionBuilder}, + ops::*, +}; +use crate::{interner, Dialect, DialectName, DialectRegistration, OperationName}; + +#[derive(Default)] +pub struct HirDialect { + registered_ops: RefCell>, + registered_op_cache: Cell>>, +} + +impl core::fmt::Debug for HirDialect { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("HirDialect") + .field_with("registered_ops", |f| { + f.debug_set().entries(self.registered_ops.borrow().iter()).finish() + }) + .finish_non_exhaustive() + } +} + +impl HirDialect { + #[inline] + pub fn num_registered(&self) -> usize { + self.registered_ops.borrow().len() + } +} +impl Dialect for HirDialect { + #[inline] fn name(&self) -> DialectName { DialectName::from_symbol(interner::symbols::Hir) } + + fn registered_ops(&self) -> Rc<[OperationName]> { + let registered = unsafe { (*self.registered_op_cache.as_ptr()).clone() }; + if registered.as_ref().is_some_and(|ops| self.num_registered() == ops.len()) { + registered.unwrap() + } else { + let registered = self.registered_ops.borrow(); + let ops = Rc::from(registered.clone().into_boxed_slice()); + self.registered_op_cache.set(Some(Rc::clone(&ops))); + ops + } + } + + fn 
get_or_register_op( + &self, + opcode: midenc_hir_symbol::Symbol, + register: fn(DialectName, midenc_hir_symbol::Symbol) -> crate::OperationName, + ) -> crate::OperationName { + let mut registered = self.registered_ops.borrow_mut(); + match registered.binary_search_by_key(&opcode, |op| op.name()) { + Ok(index) => registered[index].clone(), + Err(index) => { + let name = register(self.name(), opcode); + registered.insert(index, name.clone()); + name + } + } + } +} + +impl DialectRegistration for HirDialect { + const NAMESPACE: &'static str = "hir"; + + #[inline] + fn init() -> Self { + Self::default() + } } diff --git a/hir2/src/dialects/hir/builders.rs b/hir2/src/dialects/hir/builders.rs new file mode 100644 index 000000000..697d0e2d7 --- /dev/null +++ b/hir2/src/dialects/hir/builders.rs @@ -0,0 +1,3 @@ +mod function; + +pub use self::function::*; diff --git a/hir2/src/dialects/hir/builders/function.rs b/hir2/src/dialects/hir/builders/function.rs new file mode 100644 index 000000000..b560156f6 --- /dev/null +++ b/hir2/src/dialects/hir/builders/function.rs @@ -0,0 +1,858 @@ +use self::traits::AsCallableSymbolRef; +use crate::*; + +pub struct FunctionBuilder<'f> { + pub func: &'f mut Function, + builder: OpBuilder, +} +impl<'f> FunctionBuilder<'f> { + pub fn new(func: &'f mut Function) -> Self { + let context = func.as_operation().context_rc(); + let mut builder = OpBuilder::new(context); + builder.set_insertion_point_to_end(func.last_block()); + + Self { func, builder } + } + + pub fn at(func: &'f mut Function, ip: InsertionPoint) -> Self { + let context = func.as_operation().context_rc(); + let mut builder = OpBuilder::new(context); + builder.set_insertion_point(ip); + + Self { func, builder } + } + + pub fn body_region(&self) -> RegionRef { + unsafe { RegionRef::from_raw(&*self.func.body()) } + } + + pub fn entry_block(&self) -> BlockRef { + self.func.entry_block() + } + + #[inline] + pub fn current_block(&self) -> BlockRef { + self.builder.insertion_block().expect("builder has no insertion point set") + } + + #[inline] + pub fn switch_to_block(&mut self, block: BlockRef) { + self.builder.set_insertion_point_to_end(block); + } + + pub fn create_block(&mut self) -> BlockRef { + self.builder.create_block(self.body_region(), None, None) + } + + pub fn detach_block(&mut self, mut block: BlockRef) { + assert_ne!( + block, + self.current_block(), + "cannot remove block the builder is currently inserting in" + ); + assert_eq!( + block.borrow().parent().map(|p| RegionRef::as_ptr(&p)), + Some(&*self.func.body() as *const Region), + "cannot detach a block that does not belong to this function" + ); + let mut body = self.func.body_mut(); + unsafe { + body.body_mut().cursor_mut_from_ptr(block.clone()).remove(); + } + block.borrow_mut().uses_mut().clear(); + } + + pub fn append_block_param(&mut self, block: BlockRef, ty: Type, span: SourceSpan) -> ValueRef { + self.builder.context().append_block_argument(block, ty, span) + } + + pub fn ins<'a, 'b: 'a>(&'b mut self) -> DefaultInstBuilder<'a> { + DefaultInstBuilder::new(self.func, &mut self.builder) + } +} + +pub struct DefaultInstBuilder<'f> { + func: &'f mut Function, + builder: &'f mut OpBuilder, +} +impl<'f> DefaultInstBuilder<'f> { + pub(crate) fn new(func: &'f mut Function, builder: &'f mut OpBuilder) -> Self { + Self { func, builder } + } +} +impl<'f> InstBuilderBase<'f> for DefaultInstBuilder<'f> { + fn builder_parts(&mut self) -> (&mut Function, &mut OpBuilder) { + (self.func, self.builder) + } + + fn builder(&self) -> &OpBuilder { + 
self.builder + } + + fn builder_mut(&mut self) -> &mut OpBuilder { + self.builder + } +} + +pub trait InstBuilderBase<'f>: Sized { + fn builder(&self) -> &OpBuilder; + fn builder_mut(&mut self) -> &mut OpBuilder; + fn builder_parts(&mut self) -> (&mut Function, &mut OpBuilder); + /// Get a default instruction builder using the dataflow graph and insertion point of the + /// current builder + fn ins<'a, 'b: 'a>(&'b mut self) -> DefaultInstBuilder<'a> { + let (func, builder) = self.builder_parts(); + DefaultInstBuilder::new(func, builder) + } +} + +pub trait InstBuilder<'f>: InstBuilderBase<'f> { + fn assert( + mut self, + value: ValueRef, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = + self.builder_mut().create::(span); + op_builder(value) + } + + fn assert_with_error( + mut self, + value: ValueRef, + code: u32, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = + self.builder_mut().create::(span); + op_builder(value, code) + } + + fn assertz( + mut self, + value: ValueRef, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = + self.builder_mut().create::(span); + op_builder(value) + } + + fn assertz_with_error( + mut self, + value: ValueRef, + code: u32, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self + .builder_mut() + .create::(span); + op_builder(value, code) + } + + fn assert_eq( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder(lhs, rhs) + } + + fn assert_eq_imm( + mut self, + lhs: ValueRef, + rhs: Immediate, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder(lhs, rhs) + } + + //signed_integer_literal!(1, bool); + //integer_literal!(8); + //integer_literal!(16); + //integer_literal!(32); + //integer_literal!(64); + //integer_literal!(128); + + /* + fn felt(self, i: Felt, span: SourceSpan) -> Value { + into_first_result!(self.UnaryImm(Opcode::ImmFelt, Type::Felt, Immediate::Felt(i), span)) + } + + fn f64(self, f: f64, span: SourceSpan) -> Value { + into_first_result!(self.UnaryImm(Opcode::ImmF64, Type::F64, Immediate::F64(f), span)) + } + + fn character(self, c: char, span: SourceSpan) -> Value { + self.i32((c as u32) as i32, span) + } + */ + + /// Grow the global heap by `num_pages` pages, in 64kb units. + /// + /// Returns the previous size (in pages) of the heap, or -1 if the heap could not be grown. + fn mem_grow(mut self, num_pages: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(num_pages)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Return the size of the global heap in pages, where each page is 64kb. + fn mem_size(mut self, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder()?; + Ok(op.borrow().result().as_value_ref()) + } + + /* + /// Get a [GlobalValue] which represents the address of a global variable whose symbol is `name` + /// + /// On it's own, this does nothing, you must use the resulting [GlobalValue] with a builder + /// that expects one as an argument, or use `global_value` to obtain a [Value] from it. + fn symbol>(self, name: S, span: SourceSpan) -> GlobalValue { + self.symbol_relative(name, 0, span) + } + + /// Same semantics as `symbol`, but applies a constant offset to the address of the given + /// symbol. 
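// --- Illustrative sketch (not part of the original patch) -------------------
// Rough example of how the assertion and heap intrinsics defined above might
// be emitted through a `FunctionBuilder`. The exact turbofish arguments to
// `OpBuilder::create` are elided in the surrounding code, so the call shapes
// below are assumptions; `n` and `expected` stand for `ValueRef`s already in
// scope (e.g. block arguments), and `function` is a `Function` being built.
//
//     let mut fb = FunctionBuilder::new(&mut function);
//     let span = SourceSpan::default();
//     // grow the heap by `n` 64kb pages, keeping the previous size
//     let old_pages = fb.ins().mem_grow(n, span)?;
//     // compare against the expected size and trap with code 42 on mismatch
//     let same = fb.ins().eq(old_pages, expected, span)?;
//     fb.ins().assert_with_error(same, 42, span)?;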
+ /// + /// If the offset is zero, this is equivalent to `symbol` + fn symbol_relative>( + mut self, + name: S, + offset: i32, + span: SourceSpan, + ) -> GlobalValue { + self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset, + }) + } + + /// Get the address of a global variable whose symbol is `name` + /// + /// The type of the pointer produced is given as `ty`. It is up to the caller + /// to ensure that loading memory from that pointer is valid for the provided + /// type. + fn symbol_addr>(self, name: S, ty: Type, span: SourceSpan) -> Value { + self.symbol_relative_addr(name, 0, ty, span) + } + + /// Same semantics as `symbol_addr`, but applies a constant offset to the address of the given + /// symbol. + /// + /// If the offset is zero, this is equivalent to `symbol_addr` + fn symbol_relative_addr>( + mut self, + name: S, + offset: i32, + ty: Type, + span: SourceSpan, + ) -> Value { + assert!(ty.is_pointer(), "expected pointer type, got '{}'", &ty); + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset, + }); + into_first_result!(self.Global(gv, ty, span)) + } + + /// Loads a value of type `ty` from the global variable whose symbol is `name`. + /// + /// NOTE: There is no requirement that the memory contents at the given symbol + /// contain a valid value of type `ty`. That is left entirely up the caller to + /// guarantee at a higher level. + fn load_symbol>(self, name: S, ty: Type, span: SourceSpan) -> Value { + self.load_symbol_relative(name, ty, 0, span) + } + + /// Same semantics as `load_symbol`, but a constant offset is applied to the address before + /// issuing the load. + fn load_symbol_relative>( + mut self, + name: S, + ty: Type, + offset: i32, + span: SourceSpan, + ) -> Value { + let base = self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset: 0, + }); + self.load_global_relative(base, ty, offset, span) + } + + /// Loads a value of type `ty` from the address represented by `addr` + /// + /// NOTE: There is no requirement that the memory contents at the given symbol + /// contain a valid value of type `ty`. That is left entirely up the caller to + /// guarantee at a higher level. + fn load_global(self, addr: GlobalValue, ty: Type, span: SourceSpan) -> Value { + self.load_global_relative(addr, ty, 0, span) + } + + /// Same semantics as `load_global_relative`, but a constant offset is applied to the address + /// before issuing the load. + fn load_global_relative( + mut self, + base: GlobalValue, + ty: Type, + offset: i32, + span: SourceSpan, + ) -> Value { + if let GlobalValueData::Load { + ty: ref base_ty, .. 
+ } = self.data_flow_graph().global_value(base) + { + // If the base global is a load, the target address cannot be computed until runtime, + // so expand this to the appropriate sequence of instructions to do so in that case + assert!(base_ty.is_pointer(), "expected global value to have pointer type"); + let base_ty = base_ty.clone(); + let base = self.ins().load_global(base, base_ty.clone(), span); + let addr = self.ins().ptrtoint(base, Type::U32, span); + let offset_addr = if offset >= 0 { + self.ins().add_imm_checked(addr, Immediate::U32(offset as u32), span) + } else { + self.ins().sub_imm_checked(addr, Immediate::U32(offset.unsigned_abs()), span) + }; + let ptr = self.ins().inttoptr(offset_addr, base_ty, span); + self.load(ptr, span) + } else { + // The global address can be computed statically + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::Load { + base, + offset, + ty: ty.clone(), + }); + into_first_result!(self.Global(gv, ty, span)) + } + } + + /// Computes an address relative to the pointer produced by `base`, by applying an offset + /// given by multiplying `offset` * the size in bytes of `unit_ty`. + /// + /// The type of the pointer produced is the same as the type of the pointer given by `base` + /// + /// This is useful in some scenarios where `load_global_relative` is not, namely when computing + /// the effective address of an element of an array stored in a global variable. + fn global_addr_offset( + mut self, + base: GlobalValue, + offset: i32, + unit_ty: Type, + span: SourceSpan, + ) -> Value { + if let GlobalValueData::Load { + ty: ref base_ty, .. + } = self.data_flow_graph().global_value(base) + { + // If the base global is a load, the target address cannot be computed until runtime, + // so expand this to the appropriate sequence of instructions to do so in that case + assert!(base_ty.is_pointer(), "expected global value to have pointer type"); + let base_ty = base_ty.clone(); + let base = self.ins().load_global(base, base_ty.clone(), span); + let addr = self.ins().ptrtoint(base, Type::U32, span); + let unit_size: i32 = unit_ty + .size_in_bytes() + .try_into() + .expect("invalid type: size is larger than 2^32"); + let computed_offset = unit_size * offset; + let offset_addr = if computed_offset >= 0 { + self.ins().add_imm_checked(addr, Immediate::U32(offset as u32), span) + } else { + self.ins().sub_imm_checked(addr, Immediate::U32(offset.unsigned_abs()), span) + }; + let ptr = self.ins().inttoptr(offset_addr, base_ty, span); + self.load(ptr, span) + } else { + // The global address can be computed statically + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::IAddImm { + base, + offset, + ty: unit_ty.clone(), + }); + let ty = self.data_flow_graph().global_type(gv); + into_first_result!(self.Global(gv, ty, span)) + } + } + + /// Loads a value of the type pointed to by the given pointer, on to the stack + /// + /// NOTE: This function will panic if `ptr` is not a pointer typed value + fn load(self, addr: Value, span: SourceSpan) -> Value { + let ty = require_pointee!(self, addr).clone(); + let data = Instruction::Load(LoadOp { + op: Opcode::Load, + addr, + ty: ty.clone(), + }); + into_first_result!(self.build(data, Type::Ptr(Box::new(ty)), span)) + } + + /// Loads a value from the given temporary (local variable), of the type associated with that + /// local. 
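// --- Illustrative sketch (not part of the original patch) -------------------
// Worked example of the address computation described for `global_addr_offset`
// above: the byte offset is `offset * size_in_bytes(unit_ty)`, so with
// `unit_ty = Type::U32` (4 bytes):
//
//     global_addr_offset(base, 3, Type::U32, span)   // => base + 12 bytes
//     global_addr_offset(base, -2, Type::U32, span)  // => base -  8 bytes
//
// When `base` is itself a runtime load, the same computation is emitted as a
// ptrtoint / add-or-sub / inttoptr sequence rather than being folded into a
// static `GlobalValueData::IAddImm`.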
+ fn load_local(self, local: LocalId, span: SourceSpan) -> Value { + let data = Instruction::LocalVar(LocalVarOp { + op: Opcode::Load, + local, + args: ValueList::default(), + }); + let ty = self.data_flow_graph().local_type(local).clone(); + into_first_result!(self.build(data, Type::Ptr(Box::new(ty)), span)) + } + + /// Stores `value` to the address given by `ptr` + /// + /// NOTE: This function will panic if the pointer and pointee types do not match + fn store(mut self, ptr: Value, value: Value, span: SourceSpan) -> Inst { + let pointee_ty = require_pointee!(self, ptr); + let value_ty = self.data_flow_graph().value_type(value); + assert_eq!(pointee_ty, value_ty, "expected value to be a {}, got {}", pointee_ty, value_ty); + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + vlist.extend([ptr, value], &mut dfg.value_lists); + } + self.PrimOp(Opcode::Store, Type::Unit, vlist, span).0 + } + + /// Stores `value` to the given temporary (local variable). + /// + /// NOTE: This function will panic if the type of `value` does not match the type of the local + /// variable. + fn store_local(mut self, local: LocalId, value: Value, span: SourceSpan) -> Inst { + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + let local_ty = dfg.local_type(local); + let value_ty = dfg.value_type(value); + assert_eq!(local_ty, value_ty, "expected value to be a {}, got {}", local_ty, value_ty); + vlist.push(value, &mut dfg.value_lists); + } + let data = Instruction::LocalVar(LocalVarOp { + op: Opcode::Store, + local, + args: vlist, + }); + self.build(data, Type::Unit, span).0 + } + */ + + /// Writes `count` copies of `value` to memory starting at address `dst`. + /// + /// Each copy of `value` will be written to memory starting at the next aligned address from + /// the previous copy. This instruction will trap if the input address does not meet the + /// minimum alignment requirements of the type. + fn memset( + mut self, + dst: ValueRef, + count: ValueRef, + value: ValueRef, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder(dst, count, value) + } + + /// Copies `count` values from the memory at address `src`, to the memory at address `dst`. + /// + /// The unit size for `count` is determined by the `src` pointer type, i.e. a pointer to u8 + /// will copy one `count` bytes, a pointer to u16 will copy `count * 2` bytes, and so on. + /// + /// NOTE: The source and destination pointer types must match, or this function will panic. + fn memcpy( + mut self, + src: ValueRef, + dst: ValueRef, + count: ValueRef, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder(src, dst, count) + } + + /// This is a cast operation that permits performing arithmetic on pointer values + /// by casting a pointer to a specified integral type. + fn ptrtoint(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// This is the inverse of `ptrtoint`, used to recover a pointer that was + /// previously cast to an integer type. It may also be used to cast arbitrary + /// integer values to pointers. + /// + /// In both cases, use of the resulting pointer must not violate the semantics + /// of the higher level language being represented in Miden IR. 
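// --- Illustrative sketch (not part of the original patch) -------------------
// Hypothetical round trip through `ptrtoint`/`inttoptr`, e.g. to perform
// byte-level pointer arithmetic. Only the builder methods defined in this
// trait are relied on; the `Type` constructors shown are assumptions.
//
//     let addr = fb.ins().ptrtoint(ptr, Type::U32, span)?;                          // *T  -> u32
//     let byte_ptr = fb.ins().inttoptr(addr, Type::Ptr(Box::new(Type::U8)), span)?; // u32 -> *u8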
+ fn inttoptr(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /* + /// This is an intrinsic which derives a new pointer from an existing pointer to an aggregate. + /// + /// In short, this represents the common need to calculate a new pointer from an existing + /// pointer, but without losing provenance of the original pointer. It is specifically + /// intended for use in obtaining a pointer to an element/field of an array/struct, of the + /// correct type, given a well typed pointer to the aggregate. + /// + /// This function will panic if the pointer is not to an aggregate type + /// + /// The new pointer is derived by statically navigating the structure of the pointee type, using + /// `offsets` to guide the traversal. Initially, the first offset is relative to the original + /// pointer, where `0` refers to the base/first field of the object. The second offset is then + /// relative to the base of the object selected by the first offset, and so on. Offsets must + /// remain in bounds, any attempt to index outside a type's boundaries will result in a + /// panic. + fn getelementptr(mut self, ptr: ValueRef, mut indices: &[usize], span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + op_builder(arg, ty) + } */ + + /// Cast `arg` to a value of type `ty` + /// + /// NOTE: This is only supported for integral types currently, and the types must be of the same + /// size in bytes, i.e. i32 -> u32 or vice versa. + /// + /// The intention of bitcasts is to reinterpret a value with different semantics, with no + /// validation that is typically implied by casting from one type to another. + fn bitcast(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Cast `arg` to a value of type `ty` + /// + /// NOTE: This is only valid for numeric to numeric, or pointer to pointer casts. + /// For numeric to pointer, or pointer to numeric casts, use `inttoptr` and `ptrtoint` + /// respectively. + fn cast(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Truncates an integral value as necessary to fit in `ty`. + /// + /// NOTE: Truncating a value into a larger type has undefined behavior, it is + /// equivalent to extending a value without doing anything with the new high-order + /// bits of the resulting value. + fn trunc(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Extends an integer into a larger integeral type, by zero-extending the value, + /// i.e. the new high-order bits of the resulting value will be all zero. + /// + /// NOTE: This function will panic if `ty` is smaller than `arg`. + /// + /// If `arg` is the same type as `ty`, `arg` is returned as-is + fn zext(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Extends an integer into a larger integeral type, by sign-extending the value, + /// i.e. 
the new high-order bits of the resulting value will all match the sign bit. + /// + /// NOTE: This function will panic if `ty` is smaller than `arg`. + /// + /// If `arg` is the same type as `ty`, `arg` is returned as-is + fn sext(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /* + binary_int_op_with_overflow!(add, Opcode::Add); + binary_int_op_with_overflow!(sub, Opcode::Sub); + binary_int_op_with_overflow!(mul, Opcode::Mul); + checked_binary_int_op!(div, Opcode::Div); + binary_int_op!(min, Opcode::Min); + binary_int_op!(max, Opcode::Max); + checked_binary_int_op!(r#mod, Opcode::Mod); + checked_binary_int_op!(divmod, Opcode::DivMod); + binary_int_op!(exp, Opcode::Exp); + binary_boolean_op!(and, Opcode::And); + binary_int_op!(band, Opcode::Band); + binary_boolean_op!(or, Opcode::Or); + binary_int_op!(bor, Opcode::Bor); + binary_boolean_op!(xor, Opcode::Xor); + binary_int_op!(bxor, Opcode::Bxor); + unary_int_op!(neg, Opcode::Neg); + unary_int_op!(inv, Opcode::Inv); + unary_int_op_with_overflow!(incr, Opcode::Incr); + unary_int_op!(ilog2, Opcode::Ilog2); + unary_int_op!(pow2, Opcode::Pow2); + unary_boolean_op!(not, Opcode::Not); + unary_int_op!(bnot, Opcode::Bnot); + unary_int_op!(popcnt, Opcode::Popcnt); + unary_int_op!(clz, Opcode::Clz); + unary_int_op!(ctz, Opcode::Ctz); + unary_int_op!(clo, Opcode::Clo); + unary_int_op!(cto, Opcode::Cto); + */ + + fn rotl(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn rotr(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn shl(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn shr(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn eq(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn neq(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn gt(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn gte(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn lt(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn lte(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder 
= self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + #[allow(clippy::wrong_self_convention)] + fn is_odd(mut self, value: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(value)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn exec( + mut self, + callee: C, + args: A, + span: SourceSpan, + ) -> Result, Report> + where + C: AsCallableSymbolRef, + A: IntoIterator, + { + let op_builder = self.builder_mut().create::(span); + op_builder(callee, args) + } + + /* + fn call(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst { + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + assert!( + dfg.get_import(&callee).is_some(), + "must import callee ({}) before calling it", + &callee + ); + vlist.extend(args.iter().copied(), &mut dfg.value_lists); + } + self.Call(Opcode::Call, callee, vlist, span).0 + } + */ + + fn select( + mut self, + cond: ValueRef, + a: ValueRef, + b: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(cond, a, b)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn br( + mut self, + block: BlockRef, + args: A, + span: SourceSpan, + ) -> Result, Report> + where + A: IntoIterator, + { + let op_builder = self.builder_mut().create::(span); + op_builder(block, args) + } + + fn cond_br( + mut self, + cond: ValueRef, + then_dest: BlockRef, + then_args: T, + else_dest: BlockRef, + else_args: F, + span: SourceSpan, + ) -> Result, Report> + where + T: IntoIterator, + F: IntoIterator, + { + let op_builder = + self.builder_mut().create::(span); + op_builder(cond, then_dest, then_args, else_dest, else_args) + } + + /* + fn switch(self, arg: ValueRef, span: SourceSpan) -> SwitchBuilder<'f, Self> { + require_integer!(self, arg, Type::U32); + SwitchBuilder::new(self, arg, span) + } + */ + + fn ret( + mut self, + returning: Option, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self + .builder_mut() + .create:: as IntoIterator>::IntoIter,)>( + span, + ); + op_builder(returning) + } + + fn ret_imm( + mut self, + arg: Immediate, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder(arg) + } + + fn unreachable( + mut self, + span: SourceSpan, + ) -> Result, Report> { + let op_builder = self.builder_mut().create::(span); + op_builder() + } + + /* + fn inline_asm( + self, + args: &[Value], + results: impl IntoIterator, + span: SourceSpan, + ) -> MasmBuilder { + MasmBuilder::new(self, args, results.into_iter().collect(), span) + } + */ +} + +impl<'f, T: InstBuilderBase<'f>> InstBuilder<'f> for T {} + +/* +/// An instruction builder for `switch`, to ensure it is validated during construction +pub struct SwitchBuilder<'f, T: InstBuilder<'f>> { + builder: T, + arg: ValueRef, + span: SourceSpan, + arms: Vec, + _marker: core::marker::PhantomData<&'f Function>, +} +impl<'f, T: InstBuilder<'f>> SwitchBuilder<'f, T> { + fn new(builder: T, arg: ValueRef, span: SourceSpan) -> Self { + Self { + builder, + arg, + span, + arms: Default::default(), + _marker: core::marker::PhantomData, + } + } + + /// Specify to what block a specific discriminant value should be dispatched + pub fn case(mut self, discriminant: u32, target: Block, args: &[Value]) -> Self { + assert_eq!( + self.arms + .iter() + .find(|arm| arm.value == discriminant) + .map(|arm| arm.successor.destination), + None, 
+ "duplicate switch case value '{discriminant}': already matched" + ); + let mut vlist = ValueList::default(); + { + let pool = &mut self.builder.data_flow_graph_mut().value_lists; + vlist.extend(args.iter().copied(), pool); + } + let arm = SwitchArm { + value: discriminant, + successor: Successor { + destination: target, + args: vlist, + }, + }; + self.arms.push(arm); + self + } + + /// Build the `switch` by specifying the fallback destination if none of the arms match + pub fn or_else(mut self, target: Block, args: &[Value]) -> Inst { + let mut vlist = ValueList::default(); + { + let pool = &mut self.builder.data_flow_graph_mut().value_lists; + vlist.extend(args.iter().copied(), pool); + } + let fallback = Successor { + destination: target, + args: vlist, + }; + self.builder.Switch(self.arg, self.arms, fallback, self.span).0 + } +} + */ diff --git a/hir2/src/dialects/hir/ops.rs b/hir2/src/dialects/hir/ops.rs index a3656ce55..a4749119f 100644 --- a/hir2/src/dialects/hir/ops.rs +++ b/hir2/src/dialects/hir/ops.rs @@ -1,9 +1,14 @@ +mod assertions; mod binary; mod cast; mod control; mod invoke; mod mem; mod primop; +mod ternary; mod unary; -pub use self::{binary::*, cast::*, control::*, invoke::*, mem::*, primop::*, unary::*}; +pub use self::{ + assertions::*, binary::*, cast::*, control::*, invoke::*, mem::*, primop::*, ternary::*, + unary::*, +}; diff --git a/hir2/src/dialects/hir/ops/assertions.rs b/hir2/src/dialects/hir/ops/assertions.rs new file mode 100644 index 000000000..0c704349f --- /dev/null +++ b/hir2/src/dialects/hir/ops/assertions.rs @@ -0,0 +1,55 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct Assert { + #[operand] + value: Bool, + #[attr] + #[default] + code: u32, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct Assertz { + #[operand] + value: Bool, + #[attr] + #[default] + code: u32, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, Commutative, SameTypeOperands) +)] +pub struct AssertEq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct AssertEqImm { + #[operand] + lhs: AnyInteger, + #[attr] + rhs: Immediate, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, Terminator) +)] +pub struct Unreachable {} diff --git a/hir2/src/dialects/hir/ops/binary.rs b/hir2/src/dialects/hir/ops/binary.rs index e7b5c2b8e..8582197ca 100644 --- a/hir2/src/dialects/hir/ops/binary.rs +++ b/hir2/src/dialects/hir/ops/binary.rs @@ -1,137 +1,654 @@ -use crate::{dialects::hir::HirDialect, traits::*, *}; - -macro_rules! derive_binary_op_with_overflow { - ($Op:ident) => { - derive! { - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - lhs: OpOperandRef, - #[operand] - rhs: OpOperandRef, - #[result] - result: OpResultRef, - #[attr] - overflow: Overflow, - } +use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; - derives BinaryOp; - } - }; - - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive! 
{ - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - lhs: OpOperandRef, - #[operand] - rhs: OpOperandRef, - #[result] - result: OpResultRef, - #[attr] - overflow: Overflow, +/// Two's complement sum +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Add { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +impl InferTypeOpInterface for Add { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); } + } + self.result_mut().set_type(lhs); + Ok(()) + } +} + +/// Two's complement sum with overflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct AddOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} - derives BinaryOp, $OpTrait $(, $OpTraitRest)*; +impl InferTypeOpInterface for AddOverflowing { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); + } } - }; -} - -macro_rules! derive_binary_op { - ($Op:ident) => { - derive! 
{ - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - lhs: OpOperandRef, - #[operand] - rhs: OpOperandRef, - #[result] - result: OpResultRef, + self.result_mut().set_type(lhs); + Ok(()) + } +} + +/// Two's complement difference (subtraction) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Sub { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +impl InferTypeOpInterface for Sub { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); } + } + self.result_mut().set_type(lhs); + Ok(()) + } +} + +/// Two's complement difference (subtraction) with underflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct SubOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} - derives BinaryOp; +impl InferTypeOpInterface for SubOverflowing { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); + } } - }; - - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive! 
{ - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - lhs: OpOperandRef, - #[operand] - rhs: OpOperandRef, - #[result] - result: OpResultRef, + self.result_mut().set_type(lhs); + Ok(()) + } +} + +/// Two's complement product +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Mul { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +impl InferTypeOpInterface for Mul { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); } + } + self.result_mut().set_type(lhs); + Ok(()) + } +} + +/// Two's complement product with overflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) + )] +pub struct MulOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} - derives BinaryOp, $OpTrait $(, $OpTraitRest)*; +impl InferTypeOpInterface for MulOverflowing { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.lhs().ty().clone(); + { + let rhs = self.rhs(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); + } } - }; + self.result_mut().set_type(lhs); + Ok(()) + } } -macro_rules! derive_binary_logical_op { - ($Op:ident) => { - derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, Commutative); - }; +/// Exponentiation for field elements +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Exp { + #[operand] + lhs: IntFelt, + #[operand] + rhs: IntFelt, + #[result] + result: IntFelt, +} - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, Commutative, $OpTrait $(, $OpTraitRest)*); - }; +/// Unsigned integer division, traps on division by zero +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Div { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, } -macro_rules! 
derive_binary_bitwise_op { - ($Op:ident) => { - derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType); - }; +/// Signed integer division, traps on division by zero or dividing the minimum signed value by -1 +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Sdiv { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op derives SameTypeOperands, SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); - }; +/// Unsigned integer Euclidean modulo, traps on division by zero +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Mod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, } -macro_rules! derive_binary_comparison_op { - ($Op:ident) => { - derive_binary_op!($Op derives SameTypeOperands); - }; +/// Signed integer Euclidean modulo, traps on division by zero +/// +/// The result has the same sign as the dividend (lhs) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Smod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_binary_op!($Op derives SameTypeOperands, $OpTrait $(, $OpTraitRest)*); - }; +/// Combined unsigned integer Euclidean division and remainder (modulo). +/// +/// Traps on division by zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Divmod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + remainder: AnyInteger, + #[result] + quotient: AnyInteger, } -derive_binary_op_with_overflow!(Add derives Commutative, SameTypeOperands); -derive_binary_op_with_overflow!(Sub derives SameTypeOperands); -derive_binary_op_with_overflow!(Mul derives Commutative, SameTypeOperands); -derive_binary_op_with_overflow!(Exp); +/// Combined signed integer Euclidean division and remainder (modulo). +/// +/// Traps on division by zero. +/// +/// The remainder has the same sign as the dividend (lhs) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Sdivmod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + remainder: AnyInteger, + #[result] + quotient: AnyInteger, +} -derive_binary_op!(Div derives SameTypeOperands, SameOperandsAndResultType); -derive_binary_op!(Mod derives SameTypeOperands, SameOperandsAndResultType); -derive_binary_op!(DivMod derives SameTypeOperands, SameOperandsAndResultType); +/// Logical AND +/// +/// Operands must be boolean. +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct And { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} -derive_binary_logical_op!(And); -derive_binary_logical_op!(Or); -derive_binary_logical_op!(Xor); +/// Logical OR +/// +/// Operands must be boolean. 
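// --- Illustrative sketch (not part of the original patch) -------------------
// Concrete values for the signed division/remainder rules stated above, where
// the remainder takes the sign of the dividend `lhs`:
//
//     sdiv(-7, 2)    => -3
//     smod(-7, 2)    => -1          // -7 == 2 * -3 + (-1)
//     sdivmod(-7, 2) => (-1, -3)    // (remainder, quotient), matching the result order above
//     sdiv(i32::MIN, -1)            // traps: the quotient is not representable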
+#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Or { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} -derive_binary_bitwise_op!(Band derives Commutative); -derive_binary_bitwise_op!(Bor derives Commutative); -derive_binary_bitwise_op!(Bxor derives Commutative); -derive_binary_op!(Shl); -derive_binary_op!(Shr); -derive_binary_op!(Rotl); -derive_binary_op!(Rotr); +/// Logical XOR +/// +/// Operands must be boolean. +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Xor { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} + +/// Bitwise AND +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Band { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +/// Bitwise OR +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Bor { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +/// Bitwise XOR +/// +/// Operands must be boolean. +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Bxor { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +/// Bitwise shift-left +/// +/// Shifts larger than the bitwidth of the value will be wrapped to zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + )] +pub struct Shl { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +/// Bitwise (logical) shift-right +/// +/// Shifts larger than the bitwidth of the value will effectively truncate the value to zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + )] +pub struct Shr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +/// Arithmetic (signed) shift-right +/// +/// The result of shifts larger than the bitwidth of the value depend on the sign of the value; +/// for positive values, it rounds to zero; for negative values, it rounds to MIN. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + )] +pub struct Ashr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +/// Bitwise rotate-left +/// +/// The rotation count must be < the bitwidth of the value type. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + )] +pub struct Rotl { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +/// Bitwise rotate-right +/// +/// The rotation count must be < the bitwidth of the value type. 
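// --- Illustrative sketch (not part of the original patch) -------------------
// Example distinguishing the logical (`Shr`) and arithmetic (`Ashr`) right
// shifts defined above, using an 8-bit value for brevity; the same rule
// applies at any width:
//
//     x = 0b1000_1111              // as u8: 143, as i8: -113
//     shr(x, 4)  = 0b0000_1000     //  8: zeros shifted in from the left
//     ashr(x, 4) = 0b1111_1000     // -8: sign bit replicated; -113 >> 4 rounds toward MIN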
+#[operation( + dialect = HirDialect, + traits(BinaryOp), + )] +pub struct Rotr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +/// Equality comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + )] +pub struct Eq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Inequality comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + )] +pub struct Neq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Greater-than comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + )] +pub struct Gt { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Greater-than-or-equal comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + )] +pub struct Gte { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Less-than comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + )] +pub struct Lt { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Less-than-or-equal comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + )] +pub struct Lte { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +/// Select minimum value +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Min { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} -derive_binary_comparison_op!(Eq derives Commutative); -derive_binary_comparison_op!(Neq derives Commutative); -derive_binary_comparison_op!(Gt); -derive_binary_comparison_op!(Gte); -derive_binary_comparison_op!(Lt); -derive_binary_comparison_op!(Lte); -derive_binary_comparison_op!(Min derives Commutative); -derive_binary_comparison_op!(Max derives Commutative); +/// Select maximum value +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + )] +pub struct Max { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs index 3d51841f0..06017b501 100644 --- a/hir2/src/dialects/hir/ops/cast.rs +++ b/hir2/src/dialects/hir/ops/cast.rs @@ -1,3 +1,5 @@ +use midenc_hir_macros::operation; + use crate::{dialects::hir::HirDialect, traits::*, *}; /* @@ -29,107 +31,93 @@ pub enum CastKind { } */ -derive! { - pub struct PtrToInt : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct PtrToInt { + #[operand] + operand: AnyPointer, + #[attr] + ty: Type, + #[result] + result: AnyInteger, } -derive! 
{ - pub struct IntToPtr : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct IntToPtr { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyPointer, } -derive! { - pub struct Cast : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Cast { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyInteger, } -derive! { - pub struct Bitcast : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Bitcast { + #[operand] + operand: AnyPointerOrInteger, + #[attr] + ty: Type, + #[result] + result: AnyPointerOrInteger, } -derive! { - pub struct Trunc : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Trunc { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyInteger, } -derive! { - pub struct Zext : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Zext { + #[operand] + operand: AnyUnsignedInteger, + #[attr] + ty: Type, + #[result] + result: AnyUnsignedInteger, } -derive! { - pub struct Sext : Op { - #[dialect] - dialect: HirDialect, - #[attr] - ty: Type, - #[operand] - operand: OpOperand, - #[result] - result: OpResult, - } - - derives UnaryOp; +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Sext { + #[operand] + operand: AnySignedInteger, + #[attr] + ty: Type, + #[result] + result: AnySignedInteger, } diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs index 71a2fd871..f41ceb73b 100644 --- a/hir2/src/dialects/hir/ops/control.rs +++ b/hir2/src/dialects/hir/ops/control.rs @@ -1,138 +1,165 @@ -use smallvec::SmallVec; +use midenc_hir_macros::operation; use crate::{dialects::hir::HirDialect, traits::*, *}; -derive! { - pub struct Ret : Op { - #[dialect] - dialect: HirDialect, - #[operand] - value: OpOperand, - } - - derives Terminator; +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct Ret { + #[operands] + values: AnyType, } -// TODO(pauls): RetImm - -derive! { - pub struct Br : Op { - #[dialect] - dialect: HirDialect, - #[successor] - target: Successor, - } - - derives Terminator; +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct RetImm { + #[attr] + value: Immediate, } -derive! { - pub struct CondBr : Op { - #[dialect] - dialect: HirDialect, - #[operand] - condition: OpOperand, - #[successor] - then_dest: Successor, - #[successor] - else_dest: Successor, - } - - derives Terminator; +#[operation( + dialect = HirDialect, + traits(Terminator) +)] +pub struct Br { + #[successor] + target: Successor, } -derive! 
{ - pub struct Switch : Op { - #[dialect] - dialect: HirDialect, - #[operand] - selector: OpOperand, - #[successors(delegated)] - cases: SmallVec<[SwitchCase; 2]>, - #[successor] - fallback: Successor, - } +#[operation( + dialect = HirDialect, + traits(Terminator) +)] +pub struct CondBr { + #[operand] + condition: Bool, + #[successor] + then_dest: Successor, + #[successor] + else_dest: Successor, +} - derives Terminator; +#[operation( + dialect = HirDialect, + traits(Terminator) +)] +pub struct Switch { + #[operand] + selector: UInt32, + #[successors(keyed)] + cases: SwitchCase, + #[successor] + fallback: Successor, } // TODO(pauls): Implement `SuccessorInterface` for this type #[derive(Debug, Clone)] pub struct SwitchCase { pub value: u32, - pub successor: OpSuccessor, + pub successor: BlockRef, + pub arguments: Vec, } -impl From for OpSuccessor { - #[inline] - fn from(value: SwitchCase) -> Self { - value.successor - } +pub struct SwitchCaseRef<'a> { + pub value: u32, + pub successor: BlockOperandRef, + pub arguments: OpOperandRange<'a>, } -impl From<&SwitchCase> for OpSuccessor { - #[inline] - fn from(value: &SwitchCase) -> Self { - value.successor.clone() - } +pub struct SwitchCaseMut<'a> { + pub value: u32, + pub successor: BlockOperandRef, + pub arguments: OpOperandRangeMut<'a>, } -derive! { - pub struct If : Op { - #[dialect] - dialect: HirDialect, - #[operand] - condition: OpOperand, - #[region] - then_body: Region, - #[region] - else_body: Region, - } +impl KeyedSuccessor for SwitchCase { + type Key = u32; + type Repr<'a> = SwitchCaseRef<'a>; + type ReprMut<'a> = SwitchCaseMut<'a>; - derives SingleBlock, NoRegionArguments; -} + fn key(&self) -> &Self::Key { + &self.value + } -derive! { - /// A while is a loop structure composed of two regions: a "before" region, and an "after" region. - /// - /// The "before" region's entry block parameters correspond to the operands expected by the - /// operation, and can be used to compute the condition that determines whether the "after" body - /// is executed or not, or simply forwarded to the "after" region. The "before" region must - /// terminate with a [Condition] operation, which will be evaluated to determine whether or not - /// to continue the loop. - /// - /// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, - /// whose operands must be of the same arity and type as the "before" region's argument list. In - /// this way, the "after" body can feed back input to the "before" body to determine whether to - /// continue the loop. - pub struct While : Op { - #[dialect] - dialect: HirDialect, - #[region] - before: Region, - #[region] - after: Region, + fn into_parts(self) -> (Self::Key, BlockRef, Vec) { + (self.value, self.successor, self.arguments) } - derives SingleBlock; -} + fn into_repr( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRange<'_>, + ) -> Self::Repr<'_> { + SwitchCaseRef { + value: key, + successor: block, + arguments: operands, + } + } -derive! 
{ - pub struct Condition : Op { - #[dialect] - dialect: HirDialect, - #[operand] - value: OpOperand, + fn into_repr_mut( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRangeMut<'_>, + ) -> Self::ReprMut<'_> { + SwitchCaseMut { + value: key, + successor: block, + arguments: operands, + } } +} - derives Terminator, ReturnLike; +#[operation( + dialect = HirDialect, + traits(SingleBlock, NoRegionArguments) +)] +pub struct If { + #[operand] + condition: Bool, + #[region] + then_body: Region, + #[region] + else_body: Region, } -derive! { - pub struct Yield : Op { - #[dialect] - dialect: HirDialect, - } +/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. +/// +/// The "before" region's entry block parameters correspond to the operands expected by the +/// operation, and can be used to compute the condition that determines whether the "after" body +/// is executed or not, or simply forwarded to the "after" region. The "before" region must +/// terminate with a [Condition] operation, which will be evaluated to determine whether or not +/// to continue the loop. +/// +/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, +/// whose operands must be of the same arity and type as the "before" region's argument list. In +/// this way, the "after" body can feed back input to the "before" body to determine whether to +/// continue the loop. +#[operation( + dialect = HirDialect, + traits(SingleBlock) +)] +pub struct While { + #[region] + before: Region, + #[region] + after: Region, +} - derives Terminator, ReturnLike; +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct Condition { + #[operand] + value: Bool, } + +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct Yield {} diff --git a/hir2/src/dialects/hir/ops/invoke.rs b/hir2/src/dialects/hir/ops/invoke.rs index b326ee510..cb3851513 100644 --- a/hir2/src/dialects/hir/ops/invoke.rs +++ b/hir2/src/dialects/hir/ops/invoke.rs @@ -1,32 +1,34 @@ +use midenc_hir_macros::operation; + use crate::{dialects::hir::HirDialect, traits::*, *}; // TODO(pauls): Implement support for: // // * Inferring op constraints from callee signature -derive! { - pub struct Exec : Op { - #[dialect] - dialect: HirDialect, - #[attr] - callee: SymbolNameAttr, - #[operands] - arguments: Vec, - } - - implements CallOpInterface; +#[operation( + dialect = HirDialect, + implements(CallOpInterface) +)] +pub struct Exec { + #[symbol(callable)] + callee: SymbolNameAttr, + #[operands] + arguments: AnyType, } -derive! { - pub struct ExecIndirect : Op { - #[dialect] - dialect: HirDialect, - #[attr] - signature: Signature, - #[operand] - callee: OpOperand, - } +/* +#[operation( + dialect = HirDialect, + implements(CallOpInterface) +)] +pub struct ExecIndirect { + #[attr] + signature: Signature, + /// TODO(pauls): Change this to FunctionType + #[operand] + callee: AnyType, } - + */ impl CallOpInterface for Exec { #[inline(always)] fn callable_for_callee(&self) -> Callable { diff --git a/hir2/src/dialects/hir/ops/mem.rs b/hir2/src/dialects/hir/ops/mem.rs index da053079d..1650e8aa4 100644 --- a/hir2/src/dialects/hir/ops/mem.rs +++ b/hir2/src/dialects/hir/ops/mem.rs @@ -1,29 +1,62 @@ -use crate::{dialects::hir::HirDialect, traits::*, *}; +use midenc_hir_macros::operation; -derive! 
{ - pub struct Store : Op { - #[dialect] - dialect: HirDialect, - #[operand] - addr: OpOperand, - #[operand] - value: OpOperand, - } +use crate::{dialects::hir::HirDialect, traits::*, *}; - derives HasSideEffects, MemoryWrite; +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryWrite) +)] +pub struct Store { + #[operand] + addr: AnyPointer, + #[operand] + value: AnyType, } // TODO(pauls): StoreLocal -derive! { - pub struct Load : Op { - #[dialect] - dialect: HirDialect, - #[operand] - addr: OpOperand, - } +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead), + implements(InferTypeOpInterface) +)] +pub struct Load { + #[operand] + addr: AnyPointer, + #[result] + result: AnyType, +} - derives HasSideEffects, MemoryRead; +impl InferTypeOpInterface for Load { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + let span = self.span(); + let pointee = { + let addr = self.addr(); + let addr_value = addr.value(); + addr_value.ty().pointee().cloned() + }; + match pointee { + Some(pointee) => { + self.result_mut().set_type(pointee); + Ok(()) + } + None => { + let addr = self.addr(); + let addr_value = addr.value(); + let addr_ty = addr_value.ty(); + Err(context + .session + .diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operand for 'load'") + .with_primary_label( + span, + format!("invalid 'addr' operand, expected pointer, got '{addr_ty}'"), + ) + .into_report()) + } + } + } } // TODO(pauls): LoadLocal diff --git a/hir2/src/dialects/hir/ops/primop.rs b/hir2/src/dialects/hir/ops/primop.rs index 744f50d42..859d89894 100644 --- a/hir2/src/dialects/hir/ops/primop.rs +++ b/hir2/src/dialects/hir/ops/primop.rs @@ -1,59 +1,49 @@ -use crate::{dialects::hir::HirDialect, traits::*, *}; +use midenc_hir_macros::operation; -derive! { - pub struct MemGrow : Op { - #[dialect] - dialect: HirDialect, - #[operand] - pages: OpOperand, - #[result] - result: OpResult, - } +use crate::{dialects::hir::HirDialect, traits::*, *}; - derives HasSideEffects, MemoryRead, MemoryWrite; +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead, MemoryWrite, SameOperandsAndResultType) +)] +pub struct MemGrow { + #[operand] + pages: UInt32, + #[result] + result: UInt32, } -derive! { - pub struct MemSize : Op { - #[dialect] - dialect: HirDialect, - #[result] - result: OpResult, - } - - derives HasSideEffects, MemoryRead; +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead) +)] +pub struct MemSize { + #[result] + result: UInt32, } -derive! { - pub struct MemSet : Op { - #[dialect] - dialect: HirDialect, - #[operand] - addr: OpOperand, - #[operand] - count: OpOperand, - #[operand] - value: OpOperand, - #[result] - result: OpResult, - } - - derives HasSideEffects, MemoryWrite; +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryWrite) +)] +pub struct MemSet { + #[operand] + addr: AnyPointer, + #[operand] + count: UInt32, + #[operand] + value: AnyType, } -derive! 
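The `Load::infer_return_types` implementation above encodes a single rule: the result type is the pointee of the `addr` operand's pointer type, and a non-pointer operand is a diagnostic error. Below is a standalone sketch of that rule, with a toy `Ty` enum standing in for the real `Type` and a plain `String` in place of a `Report`.

```rust
/// `Ty` is a toy stand-in for the IR `Type`; errors are plain strings instead
/// of diagnostics attached to a `Report`.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    U32,
    Felt,
    Ptr(Box<Ty>),
}

impl Ty {
    fn pointee(&self) -> Option<&Ty> {
        match self {
            Ty::Ptr(inner) => Some(inner),
            _ => None,
        }
    }
}

/// The inference rule used by `Load`: result type = pointee of the address type.
fn infer_load_result(addr_ty: &Ty) -> Result<Ty, String> {
    addr_ty
        .pointee()
        .cloned()
        .ok_or_else(|| format!("invalid 'addr' operand, expected pointer, got '{addr_ty:?}'"))
}

fn main() {
    assert_eq!(infer_load_result(&Ty::Ptr(Box::new(Ty::Felt))), Ok(Ty::Felt));
    assert!(infer_load_result(&Ty::U32).is_err());
}
```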
{ - pub struct MemCpy : Op { - #[dialect] - dialect: HirDialect, - #[operand] - source: OpOperand, - #[operand] - destination: OpOperand, - #[operand] - count: OpOperand, - #[result] - result: OpResult, - } - - derives HasSideEffects, MemoryRead, MemoryWrite; +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead, MemoryWrite) +)] +pub struct MemCpy { + #[operand] + source: AnyPointer, + #[operand] + destination: AnyPointer, + #[operand] + count: UInt32, } diff --git a/hir2/src/dialects/hir/ops/ternary.rs b/hir2/src/dialects/hir/ops/ternary.rs new file mode 100644 index 000000000..1cac4ffc1 --- /dev/null +++ b/hir2/src/dialects/hir/ops/ternary.rs @@ -0,0 +1,45 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +/// Choose a value based on a boolean condition +#[operation( + dialect = HirDialect, + implements(InferTypeOpInterface) +)] +pub struct Select { + #[operand] + cond: Bool, + #[operand] + first: AnyInteger, + #[operand] + second: AnyInteger, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Select { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + let span = self.span(); + let lhs = self.first().ty().clone(); + { + let rhs = self.second(); + if lhs != rhs.ty() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand types") + .with_primary_label(span, "operands of this operation are not compatible") + .with_secondary_label( + rhs.span(), + format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), + ) + .into_report()); + } + } + self.result_mut().set_type(lhs); + Ok(()) + } +} diff --git a/hir2/src/dialects/hir/ops/unary.rs b/hir2/src/dialects/hir/ops/unary.rs index e097923f2..785bab848 100644 --- a/hir2/src/dialects/hir/ops/unary.rs +++ b/hir2/src/dialects/hir/ops/unary.rs @@ -1,69 +1,157 @@ -use crate::{dialects::hir::HirDialect, traits::*, *}; +use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; -macro_rules! derive_unary_op { - ($Op:ident) => { - derive! { - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - operand: OpOperandRef, - #[result] - result: OpResultRef, - } +/// Increment +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Incr { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} + +/// Negation +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Neg { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} - derives UnaryOp; - } - }; +/// Modular inverse +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Inv { + #[operand] + operand: IntFelt, + #[result] + result: IntFelt, +} + +/// log2(operand) +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Ilog2 { + #[operand] + operand: IntFelt, + #[result] + result: IntFelt, +} - ($Op:ident derives $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive! 
{ - pub struct $Op: Op { - #[dialect] - dialect: HirDialect, - #[operand] - operand: OpOperandRef, - #[result] - result: OpResultRef, - } +/// pow2(operand) +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Pow2 { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} - derives UnaryOp, $OpTrait $(, $OpTraitRest)*; - } - }; +/// Logical NOT +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Not { + #[operand] + operand: Bool, + #[result] + result: Bool, } -macro_rules! derive_unary_logical_op { - ($Op:ident) => { - derive_unary_op!($Op derives SameOperandsAndResultType); - }; +/// Bitwise NOT +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Bnot { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_unary_op!($Op derives SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); - }; +/// is_odd(operand) +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct IsOdd { + #[operand] + operand: AnyInteger, + #[result] + result: Bool, } -macro_rules! derive_unary_bitwise_op { - ($Op:ident) => { - derive_unary_op!($Op derives SameOperandsAndResultType); - }; +/// Count of non-zero bits (population count) +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Popcnt { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} - ($Op:ident implements $OpTrait:ident $(, $OpTraitRest:ident)*) => { - derive_unary_op!($Op derives SameOperandsAndResultType, $OpTrait $(, $OpTraitRest)*); - }; +/// Count Leading Zeros +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Clz { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, } -derive_unary_op!(Neg derives SameOperandsAndResultType); -derive_unary_op!(Inv derives SameOperandsAndResultType); -derive_unary_op!(Incr derives SameOperandsAndResultType); -derive_unary_op!(Ilog2 derives SameOperandsAndResultType); -derive_unary_op!(Pow2 derives SameOperandsAndResultType); +/// Count Trailing Zeros +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Ctz { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} -derive_unary_logical_op!(Not); -derive_unary_logical_op!(IsOdd); +/// Count Leading Ones +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Clo { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} -derive_unary_bitwise_op!(Bnot); -derive_unary_bitwise_op!(Popcnt); -derive_unary_bitwise_op!(Clz); -derive_unary_bitwise_op!(Ctz); -derive_unary_bitwise_op!(Clo); -derive_unary_bitwise_op!(Cto); +/// Count Trailing Ones +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Cto { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index 84a4f3921..e18763c0e 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -1,5 +1,6 @@ mod attribute; mod block; +mod builder; mod component; mod context; mod dialect; @@ -13,6 +14,7 @@ mod module; mod op; mod operands; mod operation; +mod print; mod region; mod successor; pub(crate) mod symbol_table; @@ -20,30 +22,32 @@ pub mod traits; mod types; mod usable; mod value; -pub(crate) mod verifier; +pub mod verifier; mod visit; pub use midenc_hir_symbol as interner; -pub use midenc_session::diagnostics::{Report, SourceSpan, Spanned}; 
+pub use midenc_session::diagnostics::{Report, SourceSpan, Span, Spanned}; pub use self::{ - attribute::{attributes::*, Attribute, AttributeSet, AttributeValue}, + attribute::{attributes::*, Attribute, AttributeSet, AttributeValue, DictAttr, SetAttr}, block::{ Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, BlockRef, }, + builder::{Builder, Listener, ListenerType, OpBuilder}, context::Context, - dialect::{Dialect, DialectName}, + dialect::{Dialect, DialectName, DialectRegistration}, entity::{ - Entity, EntityCursor, EntityCursorMut, EntityId, EntityIter, EntityList, EntityMut, - EntityRef, RawEntityRef, UnsafeEntityRef, UnsafeIntrusiveEntityRef, + Entity, EntityCursor, EntityCursorMut, EntityGroup, EntityId, EntityIter, EntityList, + EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, RawEntityRef, + StorableEntity, UnsafeEntityRef, UnsafeIntrusiveEntityRef, }, function::{AbiParam, ArgumentExtension, ArgumentPurpose, Function, Signature}, ident::{FunctionIdent, Ident}, immediates::{Felt, FieldElement, Immediate, StarkField}, insert::{Insert, InsertionPoint, ProgramPoint}, module::Module, - op::{Op, OpExt}, + op::{BuildableOp, Op, OpExt, OpRegistration}, operands::{ OpOperand, OpOperandImpl, OpOperandList, OpOperandRange, OpOperandRangeMut, OpOperandStorage, @@ -51,15 +55,28 @@ pub use self::{ operation::{ OpCursor, OpCursorMut, OpList, Operation, OperationBuilder, OperationName, OperationRef, }, + print::OpPrinter, region::{Region, RegionCursor, RegionCursorMut, RegionList, RegionRef}, - successor::OpSuccessor, + successor::{ + KeyedSuccessor, KeyedSuccessorRange, KeyedSuccessorRangeMut, OpSuccessor, OpSuccessorMut, + OpSuccessorRange, OpSuccessorRangeMut, OpSuccessorStorage, SuccessorInfo, SuccessorWithKey, + SuccessorWithKeyMut, + }, symbol_table::{ - Symbol, SymbolName, SymbolNameAttr, SymbolNameComponent, SymbolRef, SymbolTable, SymbolUse, + AsSymbolRef, InvalidSymbolRefError, Symbol, SymbolName, SymbolNameAttr, + SymbolNameComponent, SymbolNameComponents, SymbolRef, SymbolTable, SymbolUse, SymbolUseCursor, SymbolUseCursorMut, SymbolUseIter, SymbolUseList, SymbolUseRef, + SymbolUsesIter, }, types::*, usable::Usable, - value::{BlockArgument, BlockArgumentRef, OpResult, OpResultRef, Value, ValueId, ValueRef}, + value::{ + BlockArgument, BlockArgumentRef, OpResult, OpResultRange, OpResultRangeMut, OpResultRef, + OpResultStorage, Value, ValueId, ValueRef, + }, verifier::{OpVerifier, Verify}, - visit::{OpVisitor, OperationVisitor, Searcher, SymbolVisitor, Visitor}, + visit::{ + OpVisitor, OperationVisitor, Searcher, SymbolVisitor, Visitor, WalkOrder, WalkResult, + WalkStage, Walkable, + }, }; diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs index 8479b195b..36a403375 100644 --- a/hir2/src/ir/block.rs +++ b/hir2/src/ir/block.rs @@ -138,6 +138,59 @@ impl Block { self.arguments[index].clone() } + /// Insert this block after `after` in its containing region. + /// + /// Panics if this block is already attached to a region, or if `after` is not attached. 
+ pub fn insert_after(&mut self, after: BlockRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + let mut region = + after.borrow().parent().expect("'after' block is not attached to a region"); + { + let mut region = region.borrow_mut(); + let region_body = region.body_mut(); + let mut cursor = unsafe { region_body.cursor_mut_from_ptr(after) }; + cursor.insert_after(unsafe { BlockRef::from_raw(self) }); + } + self.region = Some(region); + } + + /// Insert this block before `before` in its containing region. + /// + /// Panics if this block is already attached to a region, or if `before` is not attached. + pub fn insert_before(&mut self, before: BlockRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + let mut region = + before.borrow().parent().expect("'before' block is not attached to a region"); + { + let mut region = region.borrow_mut(); + let region_body = region.body_mut(); + let mut cursor = unsafe { region_body.cursor_mut_from_ptr(before) }; + cursor.insert_before(unsafe { BlockRef::from_raw(self) }); + } + self.region = Some(region); + } + + /// Insert this block at the end of `region`. + /// + /// Panics if this block is already attached to a region. + pub fn insert_at_end(&mut self, mut region: RegionRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + { + let mut region = region.borrow_mut(); + region.body_mut().push_back(unsafe { BlockRef::from_raw(self) }); + } + self.region = Some(region); + } + /// Get a handle to the containing [Region] of this block, if it is attached to one pub fn parent(&self) -> Option { self.region.clone() @@ -171,6 +224,10 @@ impl Block { pub fn predecessors(&self) -> BlockOperandIter<'_> { self.iter_uses() } + + pub fn drop_all_defined_value_uses(&mut self) { + todo!() + } } pub type BlockOperandRef = UnsafeIntrusiveEntityRef; @@ -214,3 +271,24 @@ impl fmt::Debug for BlockOperand { .finish() } } +impl StorableEntity for BlockOperand { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many successors"); + } + + /// Remove this use of `block` + fn unlink(&mut self) { + let owner = unsafe { BlockOperandRef::from_raw(self) }; + let mut block = self.block.borrow_mut(); + let uses = block.uses_mut(); + unsafe { + let mut cursor = uses.cursor_mut_from_ptr(owner); + cursor.remove(); + } + } +} diff --git a/hir2/src/ir/builder.rs b/hir2/src/ir/builder.rs new file mode 100644 index 000000000..916e1cf57 --- /dev/null +++ b/hir2/src/ir/builder.rs @@ -0,0 +1,318 @@ +use alloc::rc::Rc; + +use crate::{ + BlockArgument, BlockRef, BuildableOp, Context, InsertionPoint, OperationRef, ProgramPoint, + RegionRef, SourceSpan, Type, Value, +}; + +/// The [Builder] trait encompasses all of the functionality needed to construct and insert blocks +/// and operations into the IR. 
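The insertion methods above all share the same contract: the block being inserted must be detached, and the sibling (or region) it is inserted relative to must already be attached. The following toy model uses a `Vec<String>` in place of the intrusive block list purely to illustrate that contract; none of the real `BlockRef`/cursor machinery appears here.

```rust
/// A region's block list, modeled as a Vec of block names instead of the
/// intrusive `BlockList`.
#[derive(Default)]
struct Region {
    blocks: Vec<String>,
}

impl Region {
    fn insert_at_end(&mut self, block: String) {
        assert!(!self.blocks.contains(&block), "block is already attached to a region");
        self.blocks.push(block);
    }

    fn insert_after(&mut self, after: &str, block: String) {
        assert!(!self.blocks.contains(&block), "block is already attached to a region");
        let index = self
            .blocks
            .iter()
            .position(|b| b.as_str() == after)
            .expect("'after' block is not attached to a region");
        self.blocks.insert(index + 1, block);
    }
}

fn main() {
    let mut region = Region::default();
    region.insert_at_end("entry".to_string());
    region.insert_at_end("exit".to_string());
    region.insert_after("entry", "body".to_string());
    assert_eq!(region.blocks, ["entry", "body", "exit"]);
}
```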
+pub trait Builder: Listener { + fn context(&self) -> &Context; + fn context_rc(&self) -> Rc; + /// Returns the current insertion point of the builder + fn insertion_point(&self) -> Option<&InsertionPoint>; + /// Clears the current insertion point + fn clear_insertion_point(&mut self) -> Option; + /// Restores the current insertion point to `ip` + fn restore_insertion_point(&mut self, ip: Option); + /// Sets the current insertion point to `ip` + fn set_insertion_point(&mut self, ip: InsertionPoint); + + /// Sets the insertion point to the specified program point, causing subsequent insertions to + /// be placed before it. + #[inline] + fn set_insertion_point_before(&mut self, pp: ProgramPoint) { + self.set_insertion_point(InsertionPoint::before(pp)); + } + + /// Sets the insertion point to the specified program point, causing subsequent insertions to + /// be placed after it. + #[inline] + fn set_insertion_point_after(&mut self, pp: ProgramPoint) { + self.set_insertion_point(InsertionPoint::after(pp)); + } + + /// Sets the insertion point to the node after the specified value is defined. + /// + /// If value has a defining operation, this sets the insertion point after that operation, so + /// that all insertions are placed following the definition. + /// + /// Otherwise, a value must be a block argument, so the insertion point is placed at the start + /// of the block, causing insertions to be placed starting at the front of the block. + fn set_insertion_point_after_value(&mut self, value: &dyn Value) { + let pp = if let Some(op) = value.get_defining_op() { + ProgramPoint::Op(op) + } else { + let block_argument = value.downcast_ref::().unwrap(); + ProgramPoint::Block(block_argument.owner()) + }; + self.set_insertion_point_after(pp); + } + + /// Sets the current insertion point to the start of `block`. + /// + /// Operations inserted will be placed starting at the beginning of the block. + #[inline] + fn set_insertion_point_to_start(&mut self, block: BlockRef) { + self.set_insertion_point_before(block.into()); + } + + /// Sets the current insertion point to the end of `block`. + /// + /// Operations inserted will be placed starting at the end of the block. + #[inline] + fn set_insertion_point_to_end(&mut self, block: BlockRef) { + self.set_insertion_point_after(block.into()); + } + + /// Return the block the current insertion point belongs to. + /// + /// NOTE: The insertion point is not necessarily at the end of the block. + /// + /// Returns `None` if the insertion point is unset, or is pointing at an operation which is + /// detached from a block. + fn insertion_block(&self) -> Option { + self.insertion_point().and_then(|ip| ip.at.block()) + } + + /// Add a new block with `args` arguments, and set the insertion point to the end of it. + /// + /// The block is inserted at the provided insertion point `ip`, or at the end of `parent` if + /// not. + /// + /// Panics if `ip` is in a different region than `parent`, or if the position it refers to is no + /// longer valid. + fn create_block
<P>
( + &mut self, + parent: RegionRef, + ip: Option, + args: P, + ) -> BlockRef + where + P: IntoIterator, + { + let mut block = self.context().create_block_with_params(args); + if let Some(InsertionPoint { at, action }) = ip { + let at = at.block().expect("invalid insertion point"); + let region = at.borrow().parent().unwrap(); + assert!( + RegionRef::ptr_eq(&parent, ®ion), + "insertion point region differs from 'parent'" + ); + + match action { + crate::Insert::Before => block.borrow_mut().insert_before(at), + crate::Insert::After => block.borrow_mut().insert_after(at), + } + } else { + block.borrow_mut().insert_at_end(parent); + } + + self.notify_block_inserted(block.clone(), None, None); + + block + } + + /// Add a new block with `args` arguments, and set the insertion point to the end of it. + /// + /// The block is inserted before `before`. + fn create_block_before
<P>
(&mut self, before: BlockRef, args: P) -> BlockRef + where + P: IntoIterator, + { + let mut block = self.context().create_block_with_params(args); + block.borrow_mut().insert_before(before); + self.notify_block_inserted(block.clone(), None, None); + block + } + + /// Insert `op` at the current insertion point + /// + /// This function will panic if no insertion point is set. + fn insert(&mut self, mut op: OperationRef) { + let InsertionPoint { at, action } = + self.insertion_point().expect("insertion point is unset").clone(); + match at { + ProgramPoint::Block(block) => match action { + crate::Insert::Before => op.borrow_mut().insert_at_start(block), + crate::Insert::After => op.borrow_mut().insert_at_end(block), + }, + ProgramPoint::Op(other_op) => match action { + crate::Insert::Before => op.borrow_mut().insert_before(other_op), + crate::Insert::After => op.borrow_mut().insert_after(other_op), + }, + } + self.notify_operation_inserted(op, None); + } + + /// Returns a specialized builder for a concrete [Op], `T`, which can be called like a closure + /// with the arguments required to create an instance of the specified operation. + /// + /// # How it works + /// + /// The set of arguments which are valid for the specialized builder returned by `create`, are + /// determined by what implementations of the [BuildableOp] trait exist for `T`. The specific + /// impl that is chosen will depend on the types of the arguments given to it. Typically, there + /// should only be one implementation, or if there are multiple, they should not overlap in + /// ways that may confuse type inference, or you will be forced to specify the full type of the + /// argument pack. + /// + /// This mechanism for constructing ops using arbitrary arguments is essentially a workaround + /// for the lack of variadic generics in Rust, and isn't quite as nice as what you can acheive + /// in C++ with varidadic templates and `std::forward` and such, but is close enough so that + /// the ergonomics are still a significant improvement over the alternative approaches. + /// + /// The nice thing about this is that we can generate all of the boilerplate, and hide all of + /// the sensitive/unsafe parts of initializing operations. Alternative approaches require + /// exposing more unsafe APIs for use by builders, whereas this approach can conceal those + /// details within this crate. 
+ /// + /// ## Example + /// + /// ```text,ignore + /// // Get an OpBuilder + /// let builder = context.builder(); + /// // Obtain a builder for AddOp + /// let add_builder = builder.create::(span); + /// // Consume the builder by creating the op with the given arguments + /// let add = add_builder(lhs, rhs, Overflow::Wrapping).expect("invalid add op"); + /// ``` + /// + /// Or, simplified/collapsed: + /// + /// ```text,ignore + /// let builder = context.builder(); + /// let add = builder.create::(span)(lhs, rhs, Overflow::Wrapping) + /// .expect("invalid add op"); + /// ``` + #[inline(always)] + fn create(&mut self, span: SourceSpan) -> >::Builder<'_, Self> + where + Args: core::marker::Tuple, + T: BuildableOp, + { + >::builder(self, span) + } +} + +pub struct OpBuilder { + context: Rc, + listener: Option>, + ip: Option, +} + +impl OpBuilder { + pub fn new(context: Rc) -> Self { + Self { + context, + listener: None, + ip: None, + } + } + + /// Sets the listener of this builder to `listener` + pub fn with_listener(&mut self, listener: impl Listener) -> &mut Self { + self.listener = Some(Box::new(listener)); + self + } +} + +impl Listener for OpBuilder { + fn kind(&self) -> ListenerType { + self.listener.as_ref().map(|l| l.kind()).unwrap_or(ListenerType::Builder) + } + + fn notify_block_inserted( + &mut self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + if let Some(listener) = self.listener.as_deref_mut() { + listener.notify_block_inserted(block, prev, ip); + } + } + + fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { + if let Some(listener) = self.listener.as_deref_mut() { + listener.notify_operation_inserted(op, prev); + } + } +} + +impl Builder for OpBuilder { + #[inline(always)] + fn context(&self) -> &Context { + self.context.as_ref() + } + + #[inline(always)] + fn context_rc(&self) -> Rc { + self.context.clone() + } + + #[inline(always)] + fn insertion_point(&self) -> Option<&InsertionPoint> { + self.ip.as_ref() + } + + #[inline] + fn clear_insertion_point(&mut self) -> Option { + self.ip.take() + } + + #[inline] + fn restore_insertion_point(&mut self, ip: Option) { + self.ip = ip; + } + + #[inline(always)] + fn set_insertion_point(&mut self, ip: InsertionPoint) { + self.ip = Some(ip); + } +} + +#[derive(Debug, Copy, Clone)] +pub enum ListenerType { + Builder, + Rewriter, +} + +pub trait Listener: 'static { + fn kind(&self) -> ListenerType; + /// Notify the listener that the specified operation was inserted. + /// + /// * If the operation was moved, then `prev` is the previous location of the op + /// * If the operation was unlinked before it was inserted, then `prev` is `None` + fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option); + /// Notify the listener that the specified block was inserted. + /// + /// * If the block was moved, then `prev` and `ip` represent the previous location of the block. 
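The `create` mechanism described above dispatches op construction on the type of the argument pack. A heavily reduced, standalone model of that idea is sketched below: each `impl BuildableWith<Args>` plays the role of one `BuildableOp<Args>` impl, and a generic `create` function picks the impl from the argument types. The `Add` struct, its fields, and the tuple shapes are assumptions for illustration; the real trait additionally threads a `Builder` through and returns a builder closure rather than the op itself.

```rust
// A reduced model of "builder selected by argument pack type"; not the real
// BuildableOp trait, which returns a builder closure tied to a Builder.
trait BuildableWith<Args>: Sized {
    fn build(args: Args) -> Self;
}

#[derive(Debug, PartialEq)]
struct Add {
    lhs: u32,
    rhs: u32,
    wrapping: bool,
}

// One "argument pack" accepted by the Add builder.
impl BuildableWith<(u32, u32, bool)> for Add {
    fn build((lhs, rhs, wrapping): (u32, u32, bool)) -> Self {
        Add { lhs, rhs, wrapping }
    }
}

// A second pack with a default overflow behavior, selected purely by the
// types of the arguments passed to `create`.
impl BuildableWith<(u32, u32)> for Add {
    fn build((lhs, rhs): (u32, u32)) -> Self {
        Add { lhs, rhs, wrapping: false }
    }
}

fn create<T, Args>(args: Args) -> T
where
    T: BuildableWith<Args>,
{
    T::build(args)
}

fn main() {
    let a: Add = create((1u32, 2u32, true));
    let b: Add = create((1u32, 2u32));
    assert_eq!(a, Add { lhs: 1, rhs: 2, wrapping: true });
    assert!(!b.wrapping);
}
```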
+ /// * If the block was unlinked before it was inserted, then `prev` and `ip` are `None` + fn notify_block_inserted( + &mut self, + block: BlockRef, + prev: Option, + ip: Option, + ); +} + +pub struct InsertionGuard<'a> { + builder: &'a mut OpBuilder, + ip: Option, +} +impl<'a> InsertionGuard<'a> { + #[allow(unused)] + pub fn new(builder: &'a mut OpBuilder, ip: InsertionPoint) -> Self { + Self { + builder, + ip: Some(ip), + } + } +} +impl Drop for InsertionGuard<'_> { + fn drop(&mut self) { + self.builder.restore_insertion_point(self.ip.take()); + } +} diff --git a/hir2/src/ir/context.rs b/hir2/src/ir/context.rs index ba940f98c..9a01bfc5b 100644 --- a/hir2/src/ir/context.rs +++ b/hir2/src/ir/context.rs @@ -1,5 +1,8 @@ -use alloc::rc::Rc; -use core::{cell::Cell, mem::MaybeUninit}; +use alloc::{collections::BTreeMap, rc::Rc}; +use core::{ + cell::{Cell, RefCell}, + mem::MaybeUninit, +}; use blink_alloc::Blink; use midenc_session::Session; @@ -21,6 +24,7 @@ use super::*; pub struct Context { pub session: Rc, allocator: Rc, + registered_dialects: RefCell>>, next_block_id: Cell, next_value_id: Cell, //pub constants: ConstantPool, @@ -48,12 +52,39 @@ impl Context { Self { session, allocator, + registered_dialects: Default::default(), next_block_id: Cell::new(0), next_value_id: Cell::new(0), //constants: Default::default(), } } + pub fn registered_dialects( + &self, + ) -> core::cell::Ref<'_, BTreeMap>> { + self.registered_dialects.borrow() + } + + pub fn get_or_register_dialect(&self) -> Rc { + use alloc::collections::btree_map::Entry; + + let mut registered_dialects = self.registered_dialects.borrow_mut(); + let dialect_name = DialectName::new(T::NAMESPACE); + match registered_dialects.entry(dialect_name) { + Entry::Occupied(entry) => Rc::clone(entry.get()), + Entry::Vacant(entry) => { + let dialect = Rc::new(T::init()) as Rc; + entry.insert(Rc::clone(&dialect)); + dialect + } + } + } + + /// Get a new [OpBuilder] for this context + pub fn builder(self: Rc) -> OpBuilder { + OpBuilder::new(Rc::clone(&self)) + } + /// Create a new, detached and empty [Block] with no parameters pub fn create_block(&self) -> BlockRef { let block = Block::new(self.alloc_block_id()); @@ -71,6 +102,7 @@ impl Context { let args = tys.into_iter().enumerate().map(|(index, ty)| { let id = self.alloc_value_id(); let arg = BlockArgument::new( + SourceSpan::default(), id, ty, owner.clone(), @@ -82,13 +114,71 @@ impl Context { block } + /// Append a new [BlockArgument] to `block`, with the given type and source location + /// + /// Returns the block argument as a `dyn Value` reference + pub fn append_block_argument( + &self, + mut block: BlockRef, + ty: Type, + span: SourceSpan, + ) -> ValueRef { + let next_index = block.borrow().num_arguments(); + let id = self.alloc_value_id(); + let arg = BlockArgument::new( + span, + id, + ty, + block.clone(), + next_index.try_into().expect("too many block arguments"), + ); + let arg = self.alloc(arg); + block.borrow_mut().arguments_mut().push(arg.clone()); + arg.upcast() + } + + /// Create a new [OpOperand] with the given value, owner, and index. + /// + /// NOTE: This inserts the operand as a user of `value`, but does _not_ add the operand to + /// `owner`'s operand storage, the caller is expected to do that. This makes this function a + /// more useful primitive. 
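`Context::get_or_register_dialect` (together with `DialectRegistration`) amounts to a lazily populated, name-keyed registry of trait objects behind interior mutability. Here is a minimal self-contained sketch of that pattern, with stand-in `Dialect`/`HirDialect` types rather than the crate's traits.

```rust
use std::{cell::RefCell, collections::BTreeMap, rc::Rc};

// Stand-ins for the real Dialect trait and HirDialect type.
trait Dialect {
    fn name(&self) -> &'static str;
}

struct HirDialect;
impl Dialect for HirDialect {
    fn name(&self) -> &'static str {
        "hir"
    }
}

#[derive(Default)]
struct Registry {
    // RefCell allows registration through `&self`, mirroring the Context.
    dialects: RefCell<BTreeMap<&'static str, Rc<dyn Dialect>>>,
}

impl Registry {
    fn get_or_register(&self, name: &'static str, init: fn() -> Rc<dyn Dialect>) -> Rc<dyn Dialect> {
        use std::collections::btree_map::Entry;
        match self.dialects.borrow_mut().entry(name) {
            Entry::Occupied(entry) => Rc::clone(entry.get()),
            Entry::Vacant(entry) => Rc::clone(entry.insert(init())),
        }
    }
}

fn main() {
    let registry = Registry::default();
    let a = registry.get_or_register("hir", || -> Rc<dyn Dialect> { Rc::new(HirDialect) });
    let b = registry.get_or_register("hir", || -> Rc<dyn Dialect> { Rc::new(HirDialect) });
    // The second lookup returns the instance registered by the first call.
    assert!(Rc::ptr_eq(&a, &b));
    assert_eq!(a.name(), "hir");
}
```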
+ pub fn make_operand(&self, mut value: ValueRef, owner: OperationRef, index: u8) -> OpOperand { + let op_operand = self.alloc_tracked(OpOperandImpl::new(value.clone(), owner, index)); + let mut value = value.borrow_mut(); + value.insert_use(op_operand.clone()); + op_operand + } + + /// Create a new [BlockOperand] with the given block, owner, and index. + /// + /// NOTE: This inserts the block operand as a user of `block`, but does _not_ add the block + /// operand to `owner`'s successor storage, the caller is expected to do that. This makes this + /// function a more useful primitive. + pub fn make_block_operand( + &self, + mut block: BlockRef, + owner: OperationRef, + index: u8, + ) -> BlockOperandRef { + let block_operand = self.alloc_tracked(BlockOperand::new(block.clone(), owner, index)); + let mut block = block.borrow_mut(); + block.insert_use(block_operand.clone()); + block_operand + } + /// Create a new [OpResult] with the given type, owner, and index /// /// NOTE: This does not attach the result to the operation, it is expected that the caller will /// do so. - pub fn make_result(&self, ty: Type, owner: OperationRef, index: u8) -> OpResultRef { + pub fn make_result( + &self, + span: SourceSpan, + ty: Type, + owner: OperationRef, + index: u8, + ) -> OpResultRef { let id = self.alloc_value_id(); - self.alloc(OpResult::new(id, ty, owner, index)) + self.alloc(OpResult::new(span, id, ty, owner, index)) } /// Allocate a new uninitialized entity of type `T` diff --git a/hir2/src/ir/dialect.rs b/hir2/src/ir/dialect.rs index 11463fd9b..78346f8e5 100644 --- a/hir2/src/ir/dialect.rs +++ b/hir2/src/ir/dialect.rs @@ -1,12 +1,48 @@ -use core::ops::Deref; +use alloc::rc::Rc; +use core::{borrow::Borrow, ops::Deref}; + +use crate::{AsAny, OperationName}; /// A [Dialect] represents a collection of IR entities that are used in conjunction with one /// another. Multiple dialects can co-exist _or_ be mutually exclusive. Converting between dialects /// is the job of the conversion infrastructure, using a process called _legalization_. pub trait Dialect { - const INIT: Self; - + /// Get the name(space) of this dialect fn name(&self) -> DialectName; + /// Get the set of registered operations associated with this dialect + fn registered_ops(&self) -> Rc<[OperationName]>; + /// Get the registered [OperationName] for an op `opcode`, or register it with `register`. + /// + /// Registering an operation with the dialect allows various parts of the IR to introspect the + /// set of operations which belong to a given dialect namespace. + fn get_or_register_op( + &self, + opcode: ::midenc_hir_symbol::Symbol, + register: fn(DialectName, ::midenc_hir_symbol::Symbol) -> OperationName, + ) -> OperationName; +} + +/// A [DialectRegistration] must be implemented for any implementation of [Dialect], to allow the +/// dialect to be registered with a [crate::Context] and instantiated on demand when building ops +/// in the IR. +/// +/// This is not part of the [Dialect] trait itself, as that trait must be object safe, and this +/// trait is _not_ object safe. +pub trait DialectRegistration: AsAny + Dialect { + /// The namespace of the dialect to register + /// + /// A dialect namespace serves both as a way to namespace the operations of that dialect, as + /// well as a way to uniquely name/identify the dialect itself. Thus, no two dialects can have + /// the same namespace at the same time. 
+ const NAMESPACE: &'static str; + + /// Initialize an instance of this dialect to be stored (uniqued) in the current + /// [crate::Context]. + /// + /// A dialect will only ever be initialized once per context. A dialect must use interior + /// mutability to satisfy the requirements of the [Dialect] trait, and to allow the context to + /// store the returned instance in a reference-counted smart pointer. + fn init() -> Self; } /// A strongly-typed symbol representing the name of a [Dialect]. @@ -25,6 +61,10 @@ impl DialectName { pub const fn from_symbol(name: ::midenc_hir_symbol::Symbol) -> Self { Self(name) } + + pub const fn as_symbol(&self) -> ::midenc_hir_symbol::Symbol { + self.0 + } } impl core::fmt::Debug for DialectName { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -62,3 +102,15 @@ impl AsRef<::midenc_hir_symbol::Symbol> for DialectName { &self.0 } } +impl Borrow<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn borrow(&self) -> &::midenc_hir_symbol::Symbol { + &self.0 + } +} +impl Borrow for DialectName { + #[inline(always)] + fn borrow(&self) -> &str { + self.0.as_str() + } +} diff --git a/hir2/src/ir/function.rs b/hir2/src/ir/function.rs index 498728ef8..49a7a9725 100644 --- a/hir2/src/ir/function.rs +++ b/hir2/src/ir/function.rs @@ -2,31 +2,58 @@ use core::fmt; use super::*; use crate::{ - derive, + derive::operation, dialects::hir::HirDialect, formatter, - traits::{CallableOpInterface, SingleRegion}, - CallConv, Symbol, SymbolName, SymbolUse, SymbolUseIter, SymbolUseList, Visibility, + traits::{ + CallableOpInterface, IsolatedFromAbove, RegionKind, RegionKindInterface, SingleRegion, + }, + CallConv, Symbol, SymbolName, SymbolUse, SymbolUseList, SymbolUsesIter, Visibility, }; trait UsableSymbol = Usable; -derive! 
{ - pub struct Function: Op { - #[dialect] - dialect: HirDialect, - #[region] - body: RegionRef, - #[attr] - name: Ident, - #[attr] - signature: Signature, - /// The uses of this function as a symbol - uses: SymbolUseList, - } - - derives SingleRegion; - implements UsableSymbol, Symbol, CallableOpInterface; +#[operation( + dialect = HirDialect, + traits(SingleRegion, IsolatedFromAbove), + implements( + UsableSymbol, + Symbol, + CallableOpInterface, + RegionKindInterface + ) +)] +pub struct Function { + #[region] + body: RegionRef, + #[attr] + name: Ident, + #[attr] + signature: Signature, + /// The uses of this function as a symbol + uses: SymbolUseList, +} + +impl Function { + #[inline] + pub fn entry_block(&self) -> BlockRef { + unsafe { BlockRef::from_raw(&*self.body().entry()) } + } + + pub fn last_block(&self) -> BlockRef { + self.body() + .body() + .back() + .as_pointer() + .expect("cannot access blocks of a function declaration") + } +} + +impl RegionKindInterface for Function { + #[inline(always)] + fn kind(&self) -> RegionKind { + RegionKind::SSA + } } impl Usable for Function { @@ -45,12 +72,12 @@ impl Usable for Function { impl Symbol for Function { #[inline(always)] - fn as_operation(&self) -> &Operation { + fn as_symbol_operation(&self) -> &Operation { &self.op } #[inline(always)] - fn as_operation_mut(&mut self) -> &mut Operation { + fn as_symbol_operation_mut(&mut self) -> &mut Operation { &mut self.op } @@ -58,48 +85,56 @@ impl Symbol for Function { Self::name(self).as_symbol() } - /// Set the name of this symbol fn set_name(&mut self, name: SymbolName) { - let mut id = *self.name(); + let id = self.name_mut(); id.name = name; - Function::set_name(self, id) } - /// Get the visibility of this symbol fn visibility(&self) -> Visibility { self.signature().visibility } - /// Returns true if this symbol has private visibility - #[inline] - fn is_private(&self) -> bool { - self.signature().is_private() - } - - /// Returns true if this symbol has public visibility - #[inline] - fn is_public(&self) -> bool { - self.signature().is_public() - } - - /// Sets the visibility of this symbol fn set_visibility(&mut self, visibility: Visibility) { self.signature_mut().visibility = visibility; } - /// Get all of the uses of this symbol that are nested within `from` - fn symbol_uses(&self, from: OperationRef) -> SymbolUseIter { - todo!() - } - - /// Return true if there are no uses of this symbol nested within `from` - fn symbol_uses_known_empty(&self, from: OperationRef) -> SymbolUseIter { - todo!() - } + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { + SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { + if OperationRef::ptr_eq(&from, &user.owner) + || from.borrow().is_proper_ancestor_of(user.owner.clone()) + { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else { + None + } + })) + } + + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report> { + for symbol_use in self.symbol_uses(from) { + let (mut owner, attr_name) = { + let user = symbol_use.borrow(); + (user.owner.clone(), user.symbol) + }; + let mut owner = owner.borrow_mut(); + // Unlink previously used symbol + { + let current_symbol = owner + .get_typed_attribute_mut::(&attr_name) + .expect("stale symbol user"); + unsafe { + self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); + } + } + // Link replacement symbol + owner.set_symbol_attribute(attr_name, replacement.clone()); + } - /// Attempt to replace all uses of this symbol nested 
within `from`, with the provided replacement - fn replace_all_uses(&self, replacement: SymbolRef, from: OperationRef) -> Result<(), Report> { - todo!() + Ok(()) } /// Returns true if this operation is a declaration, rather than a definition, of a symbol diff --git a/hir2/src/ir/module.rs b/hir2/src/ir/module.rs index a3e915bff..9cc4ee2a9 100644 --- a/hir2/src/ir/module.rs +++ b/hir2/src/ir/module.rs @@ -1,35 +1,141 @@ use alloc::collections::BTreeMap; use crate::{ - derive, + derive::operation, dialects::hir::HirDialect, - traits::{NoRegionArguments, SingleBlock, SingleRegion}, - Ident, InsertionPoint, Operation, Report, SymbolName, SymbolRef, SymbolTable, + symbol_table::SymbolUsesIter, + traits::{ + GraphRegionNoTerminator, HasOnlyGraphRegion, IsolatedFromAbove, NoRegionArguments, + NoTerminator, RegionKind, RegionKindInterface, SingleBlock, SingleRegion, + }, + Ident, InsertionPoint, Operation, OperationRef, Report, Symbol, SymbolName, SymbolNameAttr, + SymbolRef, SymbolTable, SymbolUseList, SymbolUseRef, Usable, Visibility, }; -derive! { - pub struct Module : Op { - #[dialect] - dialect: HirDialect, - #[attr] - name: Ident, - #[region] - body: RegionRef, - registry: BTreeMap, +#[operation( + dialect = HirDialect, + traits( + SingleRegion, + SingleBlock, + NoRegionArguments, + NoTerminator, + HasOnlyGraphRegion, + GraphRegionNoTerminator, + IsolatedFromAbove, + ), + implements(RegionKindInterface, SymbolTable, Symbol) +)] +pub struct Module { + #[attr] + name: Ident, + #[attr] + #[default] + visibility: Visibility, + #[region] + body: RegionRef, + #[default] + registry: BTreeMap, + #[default] + uses: SymbolUseList, +} + +impl RegionKindInterface for Module { + #[inline(always)] + fn kind(&self) -> RegionKind { + RegionKind::Graph + } +} + +impl Usable for Module { + type Use = crate::SymbolUse; + + #[inline(always)] + fn uses(&self) -> &SymbolUseList { + &self.uses } - derives SingleRegion, SingleBlock, NoRegionArguments; - implements SymbolTable; + #[inline(always)] + fn uses_mut(&mut self) -> &mut SymbolUseList { + &mut self.uses + } +} + +impl Symbol for Module { + #[inline(always)] + fn as_symbol_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_symbol_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn name(&self) -> SymbolName { + Module::name(self).as_symbol() + } + + fn set_name(&mut self, name: SymbolName) { + let id = self.name_mut(); + id.name = name; + } + + fn visibility(&self) -> Visibility { + *Module::visibility(self) + } + + fn set_visibility(&mut self, visibility: Visibility) { + *self.visibility_mut() = visibility; + } + + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { + SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { + if OperationRef::ptr_eq(&from, &user.owner) { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else if from.borrow().is_proper_ancestor_of(user.owner.clone()) { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else { + None + } + })) + } + + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report> { + for symbol_use in self.symbol_uses(from) { + let (mut owner, attr_name) = { + let user = symbol_use.borrow(); + (user.owner.clone(), user.symbol) + }; + let mut owner = owner.borrow_mut(); + // Unlink previously used symbol + { + let current_symbol = owner + .get_typed_attribute_mut::(&attr_name) + .expect("stale symbol user"); + unsafe { + self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); + } + } + 
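Both `Function::symbol_uses` and `Module::symbol_uses` filter the use list with the same visibility rule: a use is kept when its owning operation either is `from` or is properly nested inside it. The toy model below captures that rule, using integer op ids and parent links in place of `OperationRef`.

```rust
// Integer ids and parent links stand in for OperationRef and the real op tree.
#[derive(Clone, Copy, PartialEq)]
struct OpId(usize);

struct Ops {
    parent: Vec<Option<OpId>>,
}

impl Ops {
    fn is_proper_ancestor_of(&self, ancestor: OpId, mut op: OpId) -> bool {
        while let Some(parent) = self.parent[op.0] {
            if parent == ancestor {
                return true;
            }
            op = parent;
        }
        false
    }

    /// The filter applied by `symbol_uses(from)`: keep the use if its owner is
    /// `from` itself, or is properly nested inside `from`.
    fn is_use_visible_from(&self, from: OpId, use_owner: OpId) -> bool {
        from == use_owner || self.is_proper_ancestor_of(from, use_owner)
    }
}

fn main() {
    // module(0) > function(1) > exec(2); a sibling module(3)
    let ops = Ops { parent: vec![None, Some(OpId(0)), Some(OpId(1)), None] };
    assert!(ops.is_use_visible_from(OpId(0), OpId(2)));
    assert!(!ops.is_use_visible_from(OpId(3), OpId(2)));
}
```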
// Link replacement symbol + owner.set_symbol_attribute(attr_name, replacement.clone()); + } + + Ok(()) + } } impl SymbolTable for Module { #[inline(always)] - fn as_operation(&self) -> &Operation { + fn as_symbol_table_operation(&self) -> &Operation { &self.op } #[inline(always)] - fn as_operation_mut(&mut self) -> &mut Operation { + fn as_symbol_table_operation_mut(&mut self) -> &mut Operation { &mut self.op } @@ -37,32 +143,44 @@ impl SymbolTable for Module { self.registry.get(&name).cloned() } - //TODO(pauls): Insert symbol ref in module body fn insert_new(&mut self, entry: SymbolRef, ip: Option) -> bool { - let symbol = entry.borrow(); - let name = symbol.name(); - if self.registry.contains_key(&name) { - return false; + use crate::{BlockRef, Builder, OpBuilder}; + let op = { + let symbol = entry.borrow(); + let name = symbol.name(); + if self.registry.contains_key(&name) { + return false; + } + let op = symbol.as_operation_ref(); + drop(symbol); + self.registry.insert(name, entry.clone()); + op + }; + let mut builder = OpBuilder::new(self.op.context_rc()); + if let Some(ip) = ip { + builder.set_insertion_point(ip); + } else { + builder.set_insertion_point_to_end(unsafe { BlockRef::from_raw(&*self.body().entry()) }) } - drop(symbol); - self.registry.insert(name, entry); + builder.insert(op); true } - //TODO(pauls): Insert symbol ref in module body fn insert(&mut self, mut entry: SymbolRef, ip: Option) -> SymbolName { - let mut symbol = entry.borrow_mut(); - let mut name = symbol.name(); - if self.registry.contains_key(&name) { - // Unique the symbol name - let mut counter = 0; - name = super::symbol_table::generate_symbol_name(name, &mut counter, |name| { - self.registry.contains_key(name) - }); - symbol.set_name(name); - } - drop(symbol); - self.registry.insert(name, entry); + let name = { + let mut symbol = entry.borrow_mut(); + let mut name = symbol.name(); + if self.registry.contains_key(&name) { + // Unique the symbol name + let mut counter = 0; + name = crate::symbol_table::generate_symbol_name(name, &mut counter, |name| { + self.registry.contains_key(name) + }); + symbol.set_name(name); + } + name + }; + self.insert_new(entry, ip); name } @@ -84,8 +202,16 @@ impl SymbolTable for Module { sym.set_name(to); let uses = sym.uses_mut(); let mut cursor = uses.front_mut(); - while let Some(mut next_use) = cursor.get_mut() { - next_use.symbol.name = to; + while let Some(mut next_use) = cursor.as_pointer() { + { + let mut next_use = next_use.borrow_mut(); + let mut op = next_use.owner.borrow_mut(); + let symbol_name = op + .get_typed_attribute_mut::(&next_use.symbol) + .expect("stale symbol user"); + symbol_name.name = to; + } + cursor.move_next(); } Ok(()) diff --git a/hir2/src/ir/op.rs b/hir2/src/ir/op.rs index 440a911a7..c8ce1b1a3 100644 --- a/hir2/src/ir/op.rs +++ b/hir2/src/ir/op.rs @@ -1,8 +1,21 @@ -use core::any::Any; - use super::*; +use crate::any::AsAny; + +pub trait OpRegistration: Op { + fn name() -> ::midenc_hir_symbol::Symbol; +} + +pub trait BuildableOp: Op { + type Builder<'a, T>: FnOnce, crate::Report>> + + 'a + where + T: ?Sized + Builder + 'a; + fn builder<'b, B>(builder: &'b mut B, span: SourceSpan) -> Self::Builder<'b, B> + where + B: ?Sized + Builder + 'b; +} -pub trait Op: Any + OpVerifier { +pub trait Op: AsAny + OpVerifier { /// The name of this operation's opcode /// /// The opcode must be distinct from all other opcodes in the same dialect @@ -10,6 +23,9 @@ pub trait Op: Any + OpVerifier { fn as_operation(&self) -> &Operation; fn as_operation_mut(&mut self) -> 
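`Module::insert` resolves name collisions by asking `generate_symbol_name` for a fresh name, retrying counter-based candidates until the registry's `contains_key` predicate no longer matches. The sketch below captures that behavior; the exact suffix format is an assumption for illustration, and a `BTreeSet<String>` stands in for the module's symbol registry.

```rust
use std::collections::BTreeSet;

/// Produce a name not present in `taken`, appending a counter on collision.
/// The `{base}_{counter}` format is assumed here purely for illustration.
fn generate_symbol_name(base: &str, taken: &BTreeSet<String>) -> String {
    if !taken.contains(base) {
        return base.to_string();
    }
    let mut counter = 0usize;
    loop {
        let candidate = format!("{base}_{counter}");
        if !taken.contains(&candidate) {
            return candidate;
        }
        counter += 1;
    }
}

fn main() {
    let taken: BTreeSet<String> = ["f".to_string(), "f_0".to_string()].into();
    assert_eq!(generate_symbol_name("f", &taken), "f_1");
    assert_eq!(generate_symbol_name("g", &taken), "g");
}
```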
&mut Operation; + fn set_span(&mut self, span: SourceSpan) { + self.as_operation_mut().set_span(span); + } fn parent(&self) -> Option { self.as_operation().parent() } @@ -31,6 +47,12 @@ pub trait Op: Any + OpVerifier { fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> { self.as_operation_mut().region_mut(index) } + fn has_successors(&self) -> bool { + self.as_operation().has_successors() + } + fn num_successors(&self) -> usize { + self.as_operation().num_successors() + } fn has_operands(&self) -> bool { self.as_operation().has_operands() } @@ -43,16 +65,16 @@ pub trait Op: Any + OpVerifier { fn operands_mut(&mut self) -> &mut OpOperandStorage { self.as_operation_mut().operands_mut() } - fn results(&self) -> &[OpResultRef] { + fn results(&self) -> &OpResultStorage { self.as_operation().results() } - fn results_mut(&mut self) -> &mut [OpResultRef] { + fn results_mut(&mut self) -> &mut OpResultStorage { self.as_operation_mut().results_mut() } - fn successors(&self) -> &[OpSuccessor] { + fn successors(&self) -> &OpSuccessorStorage { self.as_operation().successors() } - fn successors_mut(&mut self) -> &mut [OpSuccessor] { + fn successors_mut(&mut self) -> &mut OpSuccessorStorage { self.as_operation_mut().successors_mut() } } diff --git a/hir2/src/ir/operands.rs b/hir2/src/ir/operands.rs index 48dd5c15b..15d1c46f2 100644 --- a/hir2/src/ir/operands.rs +++ b/hir2/src/ir/operands.rs @@ -1,6 +1,4 @@ -use core::{fmt, num::NonZeroU16}; - -use smallvec::{smallvec, SmallVec}; +use core::fmt; use crate::{EntityRef, OperationRef, Type, UnsafeIntrusiveEntityRef, Value, ValueId, ValueRef}; @@ -36,14 +34,12 @@ impl OpOperandImpl { self.value.borrow() } - pub fn unlink(&mut self) { - let ptr = unsafe { OpOperand::from_raw(self as *mut Self) }; - let mut value = self.value.borrow_mut(); - let uses = value.uses_mut(); - unsafe { - let mut cursor = uses.cursor_mut_from_ptr(ptr); - cursor.remove(); - } + pub fn owner(&self) -> EntityRef<'_, crate::Operation> { + self.owner.borrow() + } + + pub fn ty(&self) -> crate::Type { + self.value().ty().clone() } } impl fmt::Debug for OpOperandImpl { @@ -64,501 +60,32 @@ impl fmt::Debug for OpOperandImpl { .finish_non_exhaustive() } } - -#[derive(Default, Copy, Clone)] -struct OpOperandGroup(Option); -impl OpOperandGroup { - const START_MASK: u16 = u8::MAX as u16; - - fn new(start: usize, len: usize) -> Self { - if len == 0 { - return Self::default(); - } - - let start = u16::try_from(start).expect("too many operands"); - let len = u16::try_from(len).expect("operand group too large"); - let group = start | (len << 8); - - Self(Some(unsafe { NonZeroU16::new_unchecked(group) })) - } - - #[allow(unused)] - #[inline] - pub fn start(&self) -> Option { - Some((self.0?.get() & Self::START_MASK) as usize) - } - - #[inline] - pub fn end(&self) -> Option { - self.as_range().map(|range| range.end) - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.0.is_none() - } - - #[allow(unused)] - #[inline] - pub fn len(&self) -> usize { - self.0.as_ref().map(|group| (group.get() >> 8) as usize).unwrap_or(0) - } - - pub fn as_range(&self) -> Option> { - let raw = self.0?.get(); - let start = (raw & Self::START_MASK) as usize; - let len = (raw >> 8) as usize; - Some(start..(start + len)) - } - - pub fn increase_size(&mut self, size: usize) { - let group = self.0.as_mut().expect("expected non-empty group"); - let raw = group.get(); - let size = u16::try_from(size).expect("too many operands"); - let start = raw & Self::START_MASK; - let len = (raw >> 8) + size; - assert!(len 
<= u8::MAX as u16, "operand group is too large"); - *group = unsafe { NonZeroU16::new_unchecked(start | (len << 8)) }; - } - - pub fn decrease_size(&mut self, size: usize) { - let group = self.0.as_mut().expect("expected non-empty group"); - let raw = group.get(); - let size = u16::try_from(size).expect("too many operands"); - let len = (raw >> 8) - size; - if len > 0 { - let start = raw & Self::START_MASK; - *group = unsafe { NonZeroU16::new_unchecked(start | (len << 8)) }; - } else { - self.0 = None; - } - } - - pub fn shift_start(&mut self, offset: isize) { - let offset = i16::try_from(offset).expect("offset too large"); - if let Some(group) = self.0.as_mut() { - let raw = group.get(); - let mut start = raw & Self::START_MASK; - if offset >= 0 { - start += offset as u16; - } else { - start -= offset.unsigned_abs(); - } - assert!(start <= Self::START_MASK, "too many operands"); - // Clear previous start value - let raw = raw & !Self::START_MASK; - *group = unsafe { NonZeroU16::new_unchecked(raw | start) }; - } - } -} - -pub struct OpOperandStorage { - operands: SmallVec<[OpOperand; 1]>, - groups: SmallVec<[OpOperandGroup; 2]>, -} -impl fmt::Debug for OpOperandStorage { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpOperandStorage") - .field_with("groups", |f| { - let mut builder = f.debug_list(); - for group in self.groups.iter() { - match group.as_range() { - Some(range) => { - let operands = &self.operands[range.clone()]; - builder.entry_with(|f| { - f.debug_map() - .entry(&"range", &range) - .entry(&"operands", &operands) - .finish() - }); - } - None => { - builder.entry(&""); - } - } - } - builder.finish() - }) - .finish() - } -} -impl Default for OpOperandStorage { - fn default() -> Self { - Self { - operands: Default::default(), - groups: smallvec![OpOperandGroup::default()], - } +impl crate::Spanned for OpOperandImpl { + fn span(&self) -> crate::SourceSpan { + self.value.borrow().span() } } -impl OpOperandStorage { - #[inline] - pub fn is_empty(&self) -> bool { - self.operands.is_empty() - } - - #[inline] - pub fn len(&self) -> usize { - self.operands.len() - } - - #[inline] - pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { - self.operands.iter() - } - - #[inline] - pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, OpOperand> { - self.operands.iter_mut() - } - - /// Push operand to the last operand group - pub fn push_operand(&mut self, mut operand: OpOperand) { - let index = self.operands.len() as u8; - operand.borrow_mut().index = index; - self.operands.push(operand); - let group = self.groups.last_mut().unwrap(); - if group.is_empty() { - *group = OpOperandGroup::new(self.operands.len(), 1); - return; - } - group.increase_size(1); - } - - /// Push operand to the specified group - pub fn push_operand_to_group(&mut self, group: usize, operand: OpOperand) { - if self.groups.len() <= group { - self.groups.resize(group + 1, OpOperandGroup::default()); - } - let mut group = self.group_mut(group); - group.push(operand); - } - - /// Create operand group with index `group`, allocating any intervening groups if missing - pub fn push_operands_to_group(&mut self, group: usize, operands: I) - where - I: IntoIterator, - { - if self.groups.len() <= group { - self.groups.resize(group + 1, OpOperandGroup::default()); - } - let mut group = self.group_mut(group); - group.extend(operands); - } - - /// Push multiple operands to the last operand group - pub fn extend(&mut self, operands: I) - where - I: IntoIterator, - { - let mut group = 
self.group_mut(self.groups.len() - 1); - group.extend(operands); +impl crate::StorableEntity for OpOperandImpl { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize } - pub fn clear(&mut self) { - for mut operand in self.operands.drain(..) { - let mut operand = operand.borrow_mut(); - operand.unlink(); - } - self.groups.clear(); - self.groups.push(OpOperandGroup::default()); - } - - /// Get all the operands - pub fn all(&self) -> OpOperandRange<'_> { - OpOperandRange { - range: 0..self.operands.len(), - operands: self.operands.as_slice(), - } - } - - /// Get operands for the specified group - pub fn group(&self, group: usize) -> OpOperandRange<'_> { - OpOperandRange { - range: self.groups[group].as_range().unwrap_or(0..0), - operands: self.operands.as_slice(), - } + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many operands"); } - /// Get operands for the specified group - pub fn group_mut(&mut self, group: usize) -> OpOperandRangeMut<'_> { - let range = self.groups[group].as_range(); - OpOperandRangeMut { - group, - range, - groups: &mut self.groups, - operands: &mut self.operands, - } - } -} -impl core::ops::Index for OpOperandStorage { - type Output = OpOperand; - - #[inline] - fn index(&self, index: usize) -> &Self::Output { - &self.operands[index] - } -} -impl core::ops::IndexMut for OpOperandStorage { - #[inline] - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.operands[index] - } -} - -/// A reference to a range of operands in [OpOperandStorage] -pub struct OpOperandRange<'a> { - range: core::ops::Range, - operands: &'a [OpOperand], -} -impl<'a> OpOperandRange<'a> { - #[inline] - pub fn is_empty(&self) -> bool { - self.as_slice().is_empty() - } - - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - #[inline] - pub fn as_slice(&self) -> &[OpOperand] { - &self.operands[self.range.start..self.range.end] - } - - #[inline] - pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { - self.as_slice().iter() - } - - #[inline] - pub fn get(&self, index: usize) -> Option<&OpOperand> { - self.as_slice().get(index) - } -} -impl<'a> core::ops::Index for OpOperandRange<'a> { - type Output = OpOperand; - - #[inline] - fn index(&self, index: usize) -> &Self::Output { - &self.as_slice()[index] - } -} - -/// A mutable range of operands in [OpOperandStorage] -/// -/// Operands outside the range are not modified, however the range itself can have its size change, -/// which as a result will shift other operands around. Any other groups in [OpOperandStorage] will -/// be updated to reflect such changes, so in general this should be transparent. 
-pub struct OpOperandRangeMut<'a> { - group: usize, - range: Option>, - groups: &'a mut [OpOperandGroup], - operands: &'a mut SmallVec<[OpOperand; 1]>, -} -impl<'a> OpOperandRangeMut<'a> { - #[inline] - pub fn is_empty(&self) -> bool { - self.as_slice().is_empty() - } - - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } - - #[inline] - pub fn push(&mut self, operand: OpOperand) { - self.extend([operand]); - } - - pub fn extend(&mut self, operands: I) - where - I: IntoIterator, - { - // Handle edge case where group is the last group - let is_last = self.groups.len() == self.group + 1; - let is_empty = self.range.is_none(); - - if is_last && is_empty { - let prev_len = self.operands.len(); - self.operands.extend(operands.into_iter().enumerate().map(|(i, mut operand)| { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index = (prev_len + i) as u8; - drop(operand_mut); - operand - })); - let num_inserted = self.operands.len().abs_diff(prev_len); - if num_inserted == 0 { - return; - } - self.groups[self.group] = OpOperandGroup::new(self.operands.len(), num_inserted); - self.range = self.groups[self.group].as_range(); - } else if is_last { - self.extend_last(operands); - } else { - self.extend_within(operands); - } - } - - fn extend_last(&mut self, operands: I) - where - I: IntoIterator, - { - let prev_len = self.operands.len(); - self.operands.extend(operands.into_iter().enumerate().map(|(i, mut operand)| { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index = (prev_len + i) as u8; - drop(operand_mut); - operand - })); - let num_inserted = self.operands.len().abs_diff(prev_len); - if num_inserted == 0 { - return; - } - self.groups[self.group].increase_size(num_inserted); - self.range = self.groups[self.group].as_range(); - } - - fn extend_within(&mut self, operands: I) - where - I: IntoIterator, - { - let prev_len = self.operands.len(); - let num_inserted; - - match self.range.as_mut() { - Some(range) => { - let start = range.end; - self.operands.insert_many( - range.end, - operands.into_iter().enumerate().map(|(i, mut operand)| { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index = (start + i) as u8; - drop(operand_mut); - operand - }), - ); - num_inserted = self.operands.len().abs_diff(prev_len); - if num_inserted == 0 { - return; - } - self.groups[self.group].increase_size(num_inserted); - range.end += num_inserted; - } - None => { - let start = self.groups[..self.group] - .iter() - .rev() - .filter_map(OpOperandGroup::end) - .next() - .unwrap_or(0); - self.operands.insert_many( - start, - operands.into_iter().enumerate().map(|(i, mut operand)| { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index = (start + i) as u8; - drop(operand_mut); - operand - }), - ); - num_inserted = self.operands.len().abs_diff(prev_len); - if num_inserted == 0 { - return; - } - self.groups[self.group] = OpOperandGroup::new(start, num_inserted); - self.range = self.groups[self.group].as_range(); - } - } - - // Shift groups - for group in self.groups[(self.group + 1)..].iter_mut() { - if group.is_empty() { - continue; - } - group.shift_start(num_inserted as isize); - } - - // Shift operand indices - let shifted = self.range.as_ref().unwrap().end; - for operand in self.operands[shifted..].iter_mut() { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index += 1; - } - } - - pub fn pop(&mut self) -> Option { - let range = self.range.as_mut()?; - let index = range.end; - range.end -= 1; - if (*range).is_empty() { - self.range = None; - 
} - self.groups[self.group].decrease_size(1); - let mut removed = self.operands.remove(index); - { - let mut operand_mut = removed.borrow_mut(); - operand_mut.unlink(); - } - - // Shift groups - for group in self.groups[(self.group + 1)..].iter_mut() { - if group.is_empty() { - continue; - } - group.shift_start(-1); - } - - // Shift operand indices - for operand in self.operands[index..].iter_mut() { - let mut operand_mut = operand.borrow_mut(); - operand_mut.index -= 1; + fn unlink(&mut self) { + let ptr = unsafe { OpOperand::from_raw(self as *mut Self) }; + let mut value = self.value.borrow_mut(); + let uses = value.uses_mut(); + unsafe { + let mut cursor = uses.cursor_mut_from_ptr(ptr); + cursor.remove(); } - - Some(removed) - } - - #[inline] - pub fn as_slice(&self) -> &[OpOperand] { - &self.operands[self.range.clone().unwrap_or(0..0)] - } - - #[inline] - pub fn as_slice_mut(&mut self) -> &mut [OpOperand] { - &mut self.operands[self.range.clone().unwrap_or(0..0)] - } - - #[inline] - pub fn iter(&self) -> core::slice::Iter<'_, OpOperand> { - self.as_slice().iter() - } - - #[inline] - pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, OpOperand> { - self.as_slice_mut().iter_mut() - } - - #[inline] - pub fn get(&self, index: usize) -> Option<&OpOperand> { - self.as_slice().get(index) - } - - #[inline] - pub fn get_mut(&mut self, index: usize) -> Option<&mut OpOperand> { - self.as_slice_mut().get_mut(index) } } -impl<'a> core::ops::Index for OpOperandRangeMut<'a> { - type Output = OpOperand; - #[inline] - fn index(&self, index: usize) -> &Self::Output { - &self.as_slice()[index] - } -} -impl<'a> core::ops::IndexMut for OpOperandRangeMut<'a> { - #[inline] - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.as_slice_mut()[index] - } -} +pub type OpOperandStorage = crate::EntityStorage; +pub type OpOperandRange<'a> = crate::EntityRange<'a, OpOperand>; +pub type OpOperandRangeMut<'a> = crate::EntityRangeMut<'a, OpOperand, 1>; diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index 57cc5bcab..eacd097a3 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -1,13 +1,12 @@ mod builder; mod name; +use alloc::rc::Rc; use core::{ fmt, - ptr::{DynMetadata, Pointee}, + ptr::{DynMetadata, NonNull, Pointee}, }; -use smallvec::SmallVec; - pub use self::{builder::OperationBuilder, name::OperationName}; use super::*; @@ -67,11 +66,16 @@ pub type OpCursorMut<'a> = EntityCursorMut<'a, Operation>; /// allocate IR entities wherever we want and use them the same way. #[derive(Spanned)] pub struct Operation { - /// In order to support upcasting from [Operation] to its concrete [Op] type, as well as - /// casting to any of the operation traits it implements, we need our own vtable that lets - /// us track the individual vtables for each type and trait we need to cast to for this - /// instance. - pub(crate) vtable: traits::MultiTraitVtable, + /// The [Context] in which this [Operation] was allocated. + context: NonNull, + /// The dialect and opcode name for this operation, as well as trait implementation metadata + name: OperationName, + /// The offset of the field containing this struct inside the concrete [Op] it represents. + /// + /// This is required in order to be able to perform casts from [Operation]. An [Operation] + /// cannot be constructed without providing it to the `uninit` function, and callers of that + /// function are required to ensure that it is correct. 
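// Illustrative sketch (not part of the patch): how the `offset` field described above is meant
// to be used to recover the concrete op from an embedded `Operation`. `MyOp` and `container_of`
// are hypothetical; real ops get their offset from the `#[operation]` macro machinery.
struct MyOp {
    op: Operation,
    imm: u32,
}

// The offset would be computed once, when the op is allocated, e.g.:
// let offset = core::mem::offset_of!(MyOp, op);

// ...and later used to walk back from the `Operation` field to the start of the containing struct.
unsafe fn container_of(op: &Operation, offset: usize) -> *const () {
    unsafe { (op as *const Operation).byte_sub(offset).cast() }
}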
+ offset: usize, #[span] pub span: SourceSpan, /// Attributes that apply to this operation @@ -88,10 +92,10 @@ pub struct Operation { /// what they are used for. pub operands: OpOperandStorage, /// The set of values produced by this operation. - pub results: SmallVec<[OpResultRef; 1]>, + pub results: OpResultStorage, /// If this operation represents control flow, this field stores the set of successors, /// and successor operands. - pub successors: SmallVec<[OpSuccessor; 1]>, + pub successors: OpSuccessorStorage, /// The set of regions belonging to this operation, if any pub regions: RegionList, } @@ -109,25 +113,25 @@ impl fmt::Debug for Operation { } impl AsRef for Operation { fn as_ref(&self) -> &dyn Op { - self.vtable.downcast_trait().unwrap() + self.name.upcast(self.container()).unwrap() } } impl AsMut for Operation { fn as_mut(&mut self) -> &mut dyn Op { - self.vtable.downcast_trait_mut().unwrap() + self.name.upcast_mut(self.container().cast_mut()).unwrap() } } /// Construction impl Operation { - pub fn uninit() -> Self { - use super::traits::MultiTraitVtable; - - let mut vtable = MultiTraitVtable::new::(); - vtable.register_trait::(); + #[doc(hidden)] + pub unsafe fn uninit(context: Rc, name: OperationName, offset: usize) -> Self { + assert!(name.is::()); Self { - vtable, + context: unsafe { NonNull::new_unchecked(Rc::as_ptr(&context).cast_mut()) }, + name, + offset, span: Default::default(), attrs: Default::default(), block: Default::default(), @@ -141,21 +145,76 @@ impl Operation { /// Metadata impl Operation { + /// Get the name of this operation + /// + /// An operation name consists of both its dialect, and its opcode. pub fn name(&self) -> OperationName { - AsRef::::as_ref(self).name() + //AsRef::::as_ref(self).name() + self.name.clone() + } + + /// Set the source location associated with this operation + #[inline] + pub fn set_span(&mut self, span: SourceSpan) { + self.span = span; + } + + /// Get a borrowed reference to the owning [Context] of this operation + #[inline(always)] + pub fn context(&self) -> &Context { + // SAFETY: This is safe so long as this operation is allocated in a Context, since the + // Context by definition outlives the allocation. + unsafe { self.context.as_ref() } + } + + /// Get a owned reference to the owning [Context] of this operation + pub fn context_rc(&self) -> Rc { + // SAFETY: This is safe so long as this operation is allocated in a Context, since the + // Context by definition outlives the allocation. + // + // Additionally, constructing the Rc from a raw pointer is safe here, as the pointer was + // obtained using `Rc::as_ptr`, so the only requirement to call `Rc::from_raw` is to + // increment the strong count, as `as_ptr` does not preserve the count for the reference + // held by this operation. Incrementing the count first is required to manufacture new + // clones of the `Rc` safely. + unsafe { + let ptr = self.context.as_ptr().cast_const(); + Rc::increment_strong_count(ptr); + Rc::from_raw(ptr) + } } } /// Verification impl Operation { + /// Run any verifiers for this operation pub fn verify(&self, context: &Context) -> Result<(), Report> { let dyn_op: &dyn Op = self.as_ref(); dyn_op.verify(context) } + + /// Run any verifiers for this operation, and all of its nested operations, recursively. + /// + /// The verification is performed in post-order, so that when the verifier(s) for `self` are + /// run, it is known that all of its children have successfully verified. 
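// Illustrative sketch (not part of the patch): verifying an op and everything nested under it,
// assuming `op` is an OperationRef obtained elsewhere and `context` is the owning Context.
fn verify_all(op: OperationRef, context: &Context) -> Result<(), Report> {
    // `recursively_verify` walks nested ops in post-order, so children are checked before parents.
    op.borrow().recursively_verify(context)
}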
+ pub fn recursively_verify(&self, context: &Context) -> Result<(), Report> { + self.postwalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + op.verify(context).into() + }) + .into_result() + } } /// Traits/Casts impl Operation { + pub(super) const fn container(&self) -> *const () { + unsafe { + let ptr = self as *const Self; + ptr.byte_sub(self.offset).cast() + } + } + #[inline(always)] pub fn as_operation_ref(&self) -> OperationRef { // SAFETY: This is safe under the assumption that we always allocate Operations using the @@ -166,7 +225,7 @@ impl Operation { /// Returns true if the concrete type of this operation is `T` #[inline] pub fn is(&self) -> bool { - self.vtable.is::() + self.name.is::() } /// Returns true if this operation implements `Trait` @@ -175,17 +234,17 @@ impl Operation { where Trait: ?Sized + Pointee> + 'static, { - self.vtable.implements::() + self.name.implements::() } /// Attempt to downcast to the concrete [Op] type of this operation pub fn downcast_ref(&self) -> Option<&T> { - self.vtable.downcast_ref::() + self.name.downcast_ref::(self.container()) } /// Attempt to downcast to the concrete [Op] type of this operation pub fn downcast_mut(&mut self) -> Option<&mut T> { - self.vtable.downcast_mut::() + self.name.downcast_mut::(self.container().cast_mut()) } /// Attempt to cast this operation reference to an implementation of `Trait` @@ -193,7 +252,7 @@ impl Operation { where Trait: ?Sized + Pointee> + 'static, { - self.vtable.downcast_trait() + self.name.upcast(self.container()) } /// Attempt to cast this operation reference to an implementation of `Trait` @@ -201,12 +260,24 @@ impl Operation { where Trait: ?Sized + Pointee> + 'static, { - self.vtable.downcast_trait_mut() + self.name.upcast_mut(self.container().cast_mut()) } } /// Attributes impl Operation { + /// Get the underlying attribute set for this operation + #[inline(always)] + pub fn attributes(&self) -> &AttributeSet { + &self.attrs + } + + /// Get a mutable reference to the underlying attribute set for this operation + #[inline(always)] + pub fn attributes_mut(&mut self) -> &mut AttributeSet { + &mut self.attrs + } + /// Return the value associated with attribute `name` for this function pub fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> where @@ -275,6 +346,58 @@ impl Operation { } } +/// Symbol Attributes +impl Operation { + pub fn set_symbol_attribute( + &mut self, + name: impl Into, + symbol: impl AsSymbolRef, + ) { + let name = name.into(); + let mut symbol = symbol.as_symbol_ref(); + + // Store the underlying attribute value + let user = self.context().alloc_tracked(SymbolUse { + owner: self.as_operation_ref(), + symbol: name, + }); + if self.has_attribute(&name) { + let attr = self.get_typed_attribute_mut::(&name).unwrap(); + let symbol = symbol.borrow(); + assert!( + !attr.user.is_linked(), + "attempted to replace symbol use without unlinking the previously used symbol \ + first" + ); + attr.user = user.clone(); + attr.name = symbol.name(); + attr.path = symbol.components().into_path(true); + } else { + let attr = { + let symbol = symbol.borrow(); + let name = symbol.name(); + let path = symbol.components().into_path(true); + SymbolNameAttr { + name, + path, + user: user.clone(), + } + }; + self.set_attribute(name, Some(attr)); + } + + // Add `self` as a user of `symbol`, unless `self` is `symbol` + let (data_ptr, _) = SymbolRef::as_ptr(&symbol).to_raw_parts(); + if core::ptr::addr_eq(data_ptr, self.container()) { + return; + } + + let mut symbol = 
symbol.borrow_mut(); + let symbol_uses = symbol.uses_mut(); + symbol_uses.push_back(user); + } +} + /// Navigation impl Operation { /// Returns a handle to the containing [Block] of this operation, if it is attached to one @@ -310,26 +433,36 @@ impl Operation { /// Regions impl Operation { + /// Returns true if this operation has any regions #[inline] pub fn has_regions(&self) -> bool { !self.regions.is_empty() } + /// Returns the number of regions owned by this operation. + /// + /// NOTE: This does not include regions of nested operations, just those directly attached + /// to this operation. #[inline] pub fn num_regions(&self) -> usize { self.regions.len() } + /// Get a reference to the region list for this operation #[inline(always)] pub fn regions(&self) -> &RegionList { &self.regions } + /// Get a mutable reference to the region list for this operation #[inline(always)] pub fn regions_mut(&mut self) -> &mut RegionList { &mut self.regions } + /// Get a reference to a specific region, given its index. + /// + /// This function will panic if the index is invalid. pub fn region(&self, index: usize) -> EntityRef<'_, Region> { let mut cursor = self.regions.front(); let mut count = 0; @@ -343,6 +476,9 @@ impl Operation { panic!("invalid region index {index}: out of bounds"); } + /// Get a mutable reference to a specific region, given its index. + /// + /// This function will panic if the index is invalid. pub fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> { let mut cursor = self.regions.front_mut(); let mut count = 0; @@ -359,49 +495,134 @@ impl Operation { /// Successors impl Operation { + /// Returns true if this operation has any successor blocks #[inline] pub fn has_successors(&self) -> bool { !self.successors.is_empty() } + /// Returns the number of successor blocks this operation may transfer control to #[inline] pub fn num_successors(&self) -> usize { self.successors.len() } + /// Get a reference to the successors of this operation #[inline(always)] - pub fn successors(&self) -> &[OpSuccessor] { + pub fn successors(&self) -> &OpSuccessorStorage { &self.successors } + /// Get a mutable reference to the successors of this operation #[inline(always)] - pub fn successors_mut(&mut self) -> &mut [OpSuccessor] { + pub fn successors_mut(&mut self) -> &mut OpSuccessorStorage { &mut self.successors } + + /// Get a reference to the successor group at `index` + #[inline] + pub fn successor_group(&self, index: usize) -> OpSuccessorRange<'_> { + self.successors.group(index) + } + + /// Get a mutable reference to the successor group at `index` + #[inline] + pub fn successor_group_mut(&mut self, index: usize) -> OpSuccessorRangeMut<'_> { + self.successors.group_mut(index) + } + + /// Get a reference to the keyed successor group at `index` + #[inline] + pub fn keyed_successor_group(&self, index: usize) -> KeyedSuccessorRange<'_, T> + where + T: KeyedSuccessor, + { + let range = self.successors.group(index); + KeyedSuccessorRange::new(range, &self.operands) + } + + /// Get a mutable reference to the keyed successor group at `index` + #[inline] + pub fn keyed_successor_group_mut(&mut self, index: usize) -> KeyedSuccessorRangeMut<'_, T> + where + T: KeyedSuccessor, + { + let range = self.successors.group_mut(index); + KeyedSuccessorRangeMut::new(range, &mut self.operands) + } + + /// Get a reference to the successor at `index` in the group at `group_index` + #[inline] + pub fn successor_in_group(&self, group_index: usize, index: usize) -> OpSuccessor<'_> { + let info = 
&self.successors.group(group_index)[index]; + OpSuccessor { + dest: info.block.clone(), + arguments: self.operands.group(info.operand_group as usize), + } + } + + /// Get a mutable reference to the successor at `index` in the group at `group_index` + #[inline] + pub fn successor_in_group_mut( + &mut self, + group_index: usize, + index: usize, + ) -> OpSuccessorMut<'_> { + let info = &self.successors.group(group_index)[index]; + OpSuccessorMut { + dest: info.block.clone(), + arguments: self.operands.group_mut(info.operand_group as usize), + } + } + + /// Get a reference to the successor at `index` + #[inline] + pub fn successor(&self, index: usize) -> OpSuccessor<'_> { + let info = &self.successors[index]; + OpSuccessor { + dest: info.block.clone(), + arguments: self.operands.group(info.operand_group as usize), + } + } + + /// Get a mutable reference to the successor at `index` + #[inline] + pub fn successor_mut(&mut self, index: usize) -> OpSuccessorMut<'_> { + let info = self.successors[index].clone(); + OpSuccessorMut { + dest: info.block, + arguments: self.operands.group_mut(info.operand_group as usize), + } + } } /// Operands impl Operation { + /// Returns true if this operation has at least one operand #[inline] pub fn has_operands(&self) -> bool { !self.operands.is_empty() } + /// Returns the number of operands given to this operation #[inline] pub fn num_operands(&self) -> usize { self.operands.len() } + /// Get a reference to the operand storage for this operation #[inline] pub fn operands(&self) -> &OpOperandStorage { &self.operands } + /// Get a mutable reference to the operand storage for this operation #[inline] pub fn operands_mut(&mut self) -> &mut OpOperandStorage { &mut self.operands } + /// TODO: Remove in favor of [OpBuilder] pub fn replaces_uses_of_with(&mut self, mut from: ValueRef, mut to: ValueRef) { if ValueRef::ptr_eq(&from, &to) { return; @@ -432,23 +653,282 @@ impl Operation { /// Results impl Operation { + /// Returns true if this operation produces any results #[inline] pub fn has_results(&self) -> bool { !self.results.is_empty() } + /// Returns the number of results produced by this operation #[inline] pub fn num_results(&self) -> usize { self.results.len() } + /// Get a reference to the result set of this operation #[inline] - pub fn results(&self) -> &[OpResultRef] { - self.results.as_slice() + pub fn results(&self) -> &OpResultStorage { + &self.results } + /// Get a mutable reference to the result set of this operation #[inline] - pub fn results_mut(&mut self) -> &mut [OpResultRef] { - self.results.as_mut_slice() + pub fn results_mut(&mut self) -> &mut OpResultStorage { + &mut self.results + } +} + +/// Insertion +impl Operation { + pub fn insert_at_start(&mut self, mut block: BlockRef) { + assert!( + self.block.is_none(), + "cannot insert operation that is already attached to another block" + ); + { + let mut block = block.borrow_mut(); + block.body_mut().push_front(unsafe { OperationRef::from_raw(self) }); + } + self.block = Some(block); + } + + pub fn insert_at_end(&mut self, mut block: BlockRef) { + assert!( + self.block.is_none(), + "cannot insert operation that is already attached to another block" + ); + { + let mut block = block.borrow_mut(); + block.body_mut().push_back(unsafe { OperationRef::from_raw(self) }); + } + self.block = Some(block); + } + + pub fn insert_before(&mut self, before: OperationRef) { + assert!( + self.block.is_none(), + "cannot insert operation that is already attached to another block" + ); + let mut block = + 
before.borrow().parent().expect("'before' block is not attached to a block"); + { + let mut block = block.borrow_mut(); + let block_body = block.body_mut(); + let mut cursor = unsafe { block_body.cursor_mut_from_ptr(before) }; + cursor.insert_before(unsafe { OperationRef::from_raw(self) }); + } + self.block = Some(block); + } + + pub fn insert_after(&mut self, after: OperationRef) { + assert!( + self.block.is_none(), + "cannot insert operation that is already attached to another block" + ); + let mut block = after.borrow().parent().expect("'after' block is not attached to a block"); + { + let mut block = block.borrow_mut(); + let block_body = block.body_mut(); + let mut cursor = unsafe { block_body.cursor_mut_from_ptr(after) }; + cursor.insert_after(unsafe { OperationRef::from_raw(self) }); + } + self.block = Some(block); + } +} + +/// Movement +impl Operation { + /// Remove this operation (and its descendants) from its containing block, and delete them + #[inline] + pub fn erase(&mut self) { + // We don't delete entities currently, so for now this is just an alias for `remove` + self.remove() + } + + /// Remove the operation from its parent block, but don't delete it. + pub fn remove(&mut self) { + let Some(mut parent) = self.block.take() else { + return; + }; + let mut block = parent.borrow_mut(); + let body = block.body_mut(); + let mut cursor = unsafe { body.cursor_mut_from_ptr(OperationRef::from_raw(self)) }; + cursor.remove(); + } + + /// Unlink this operation from its current block and insert it right before `ip`, which may + /// be in the same or another block in the same function. + pub fn move_before(&mut self, ip: ProgramPoint) { + self.remove(); + match ip { + ProgramPoint::Op(other) => { + self.insert_before(other); + } + ProgramPoint::Block(block) => { + self.insert_at_start(block); + } + } + } + + /// Unlink this operation from its current block and insert it right after `ip`, which may + /// be in the same or another block in the same function. + pub fn move_after(&mut self, ip: ProgramPoint) { + self.remove(); + match ip { + ProgramPoint::Op(other) => { + self.insert_after(other); + } + ProgramPoint::Block(block) => { + self.insert_at_end(block); + } + } + } + + /// This drops all operand uses from this operation, which is used to break cyclic dependencies + /// between references when they are to be deleted + pub fn drop_all_references(&mut self) { + self.operands.clear(); + + { + let mut region_cursor = self.regions.front_mut(); + while let Some(mut region) = region_cursor.as_pointer() { + region.borrow_mut().drop_all_references(); + region_cursor.move_next(); + } + } + + self.successors.clear(); + } + + /// This drops all uses of any values defined by this operation or its nested regions, + /// wherever they are located. 
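// Illustrative sketch (not part of the patch): re-attaching an operation relative to a
// ProgramPoint, assuming `op` is a mutable borrow of an Operation and `target`/`block` are
// handles obtained elsewhere.
fn sink_to_end_of_block(op: &mut Operation, block: BlockRef) {
    // Detach from the current block (if any) and append at the end of `block`.
    op.move_after(ProgramPoint::Block(block));
}

fn hoist_before(op: &mut Operation, target: OperationRef) {
    // Detach and re-insert immediately before `target`, possibly in a different block.
    op.move_before(ProgramPoint::Op(target));
}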
+ pub fn drop_all_defined_value_uses(&mut self) { + for result in self.results.iter_mut() { + let mut res = result.borrow_mut(); + res.uses_mut().clear(); + } + + let mut regions = self.regions.front_mut(); + while let Some(mut region) = regions.as_pointer() { + let mut region = region.borrow_mut(); + let blocks = region.body_mut(); + let mut cursor = blocks.front_mut(); + while let Some(mut block) = cursor.as_pointer() { + block.borrow_mut().drop_all_defined_value_uses(); + cursor.move_next(); + } + regions.move_next(); + } + } +} + +/// Ordering +impl Operation { + /// Returns true if this operation is a proper ancestor of `other` + pub fn is_proper_ancestor_of(&self, other: OperationRef) -> bool { + let this = self.as_operation_ref(); + let mut next = other.borrow().parent_op(); + while let Some(other) = next.take() { + if OperationRef::ptr_eq(&this, &other) { + return true; + } + } + false + } + + /// Given an operation `other` that is within the same parent block, return whether the current + /// operation is before it in the operation list. + /// + /// NOTE: This function has an average complexity of O(1), but worst case may take O(N) where + /// N is the number of operations within the parent block. + pub fn is_before_in_block(&self, _other: OperationRef) -> bool { + /* + let block = self.block().expect("operations without parent blocks have no order"); + let other = other.borrow(); + assert!(other.block().is_some_and(|other_block| BlockRef::ptr_eq(&block, other_block)), "expected both operations to have the same parent block"); + // If the order of the block is already invalid, directly recompute the parent + let block = block.borrow(); + if !block.is_op_order_valid() { + block.recompute_op_order(); + } else { + // Update the order of either operation if necessary. + self.update_order_if_necessary(); + other.update_order_if_necessary(); + } + + self.order < other.order + */ + todo!() + } + + /// Update the order index of this operation of this operation if necessary, + /// potentially recomputing the order of the parent block. + fn update_order_if_necessary(&self) { + /* + assert!(self.block.is_some(), "expected valid parent"); + + let this = self.as_operation_ref(); + + // If the order is valid for this operation there is nothing to do. + let block = self.block.as_ref().unwrap().borrow(); + if self.has_valid_order() || block.body().iter().count() == 1 { + return; + } + + let back = block.body().back().as_pointer(); + let front = block.body().front().as_pointer(); + assert!(!OperationRef::ptr_eq(&front, &back)); + + // If the operation is at the end of the block. + if Operation::ptr_eq(&this, &back) { + let prev = self.get_prev(); + if !prev.borrow().has_valid_order() { + return block.recompute_op_order(); + } + + // Add the stride to the previous operation. + self.order = prev.order + Self::ORDER_STRIDE; + return; + } + + // If this is the first operation try to use the next operation to compute the + // ordering. + if Operation::ptr_eq(&this, &front) { + let next = self.get_next(); + if !next.has_valid_order() { + return block.recompute_op_order(); + } + // There is no order to give this operation. + if next.order == 0 { + return block.recompute_op_order(); + } + + // If we can't use the stride, just take the middle value left. This is safe + // because we know there is at least one valid index to assign to. 
+ if next.order <= Self::ORDER_STRIDE { + self.order = next.order / 2; + } else { + self.order = Self::ORDER_STRIDE; + } + return; + } + + // Otherwise, this operation is between two others. Place this operation in + // the middle of the previous and next if possible. + let prev = self.get_prev(); + let next = self.get_next(); + if !prev.has_valid_order() || !next.has_valid_order() { + return block.recompute_op_order(); + } + let prev_order = prev.order; + let next_order = next.order; + + // Check to see if there is a valid order between the two. + if prev_order + 1 == next_order { + return block.recompute_op_order(); + } + self.order = prev_order + ((next_order - prev_order) / 2); + */ + todo!() } } diff --git a/hir2/src/ir/operation/builder.rs b/hir2/src/ir/operation/builder.rs index a1a7af922..90923b817 100644 --- a/hir2/src/ir/operation/builder.rs +++ b/hir2/src/ir/operation/builder.rs @@ -1,81 +1,77 @@ -use core::{ - marker::Unsize, - ptr::{DynMetadata, Pointee}, -}; - -use super::{Operation, OperationRef}; use crate::{ - verifier, AttributeValue, Context, Op, OpOperandImpl, OpSuccessor, Region, Report, Type, - UnsafeIntrusiveEntityRef, ValueRef, + traits::{AsCallableSymbolRef, Terminator}, + AsSymbolRef, AttributeValue, BlockRef, Builder, KeyedSuccessor, Op, OpBuilder, OperationRef, + Region, Report, Spanned, SuccessorInfo, Type, UnsafeIntrusiveEntityRef, ValueRef, }; -// TODO: We need a safe way to construct arbitrary Ops imperatively: -// -// * Allocate an uninit instance of T -// * Initialize the Operartion field of T with the empty Operation data -// * Use the primary builder methods to mutate Operation fields -// * Use generated methods on Op-specific builders to mutate Op fields -// * At the end, convert uninit T to init T, return handle to caller -// -// Problems: -// -// * How do we default-initialize an instance of T for this purpose -// * If we use MaybeUninit, how do we compute field offsets for the Operation field -// * Generated methods can compute offsets, but how do we generate the specialized builders? -pub struct OperationBuilder<'a, T> { - context: &'a Context, - op: UnsafeIntrusiveEntityRef, +/// The [OperationBuilder] is a primitive for imperatively constructing an [Operation]. +/// +/// Currently, this is primarily used by our `#[operation]` macro infrastructure, to finalize +/// construction of the underlying [Operation] of an [Op] implementation, after both have been +/// allocated and initialized with only basic metadata. This builder is then used to add all of +/// the data under the op, e.g. operands, results, attributes, etc. Once complete, verification is +/// run on the constructed op. +/// +/// Using this directly is possible, see [OperationBuilder::new] for details. You may also find it +/// useful to examine the expansion of the `#[operation]` macro for existing ops to understand what goes +/// on behind the scenes for most ops. +pub struct OperationBuilder<'a, T, B: ?Sized = OpBuilder> { + builder: &'a mut B, + op: OperationRef, _marker: core::marker::PhantomData, } -impl<'a, T: Op> OperationBuilder<'a, T> { - pub fn new(context: &'a Context, op: T) -> Self { - let mut op = context.alloc_tracked(op); - - // SAFETY: Setting the data pointer of the multi-trait vtable must ensure - // that it points to the concrete type of the allocation, which we can guarantee here, - // having just allocated it. 
Until the data pointer is set, casts using the vtable are - // undefined behavior, so by never allowing the uninitialized vtable to be accessed, - // we can ensure the multi-trait impl is safe - unsafe { - let data_ptr = UnsafeIntrusiveEntityRef::as_ptr(&op); - let mut op_mut = op.borrow_mut(); - op_mut.as_operation_mut().vtable.set_data_ptr(data_ptr.cast_mut()); - } - +impl<'a, T, B> OperationBuilder<'a, T, B> +where + T: Op, + B: ?Sized + Builder, +{ + /// Create a new [OperationBuilder] for `op` using the provided [Builder]. + /// + /// The [Operation] underlying `op` must have been initialized correctly: + /// + /// * Allocated via the same context as `builder` + /// * Initialized via [crate::Operation::uninit] + /// * All op traits implemented by `T` must have been registered with its [OperationName] + /// * All fields of `T` must have been initialized to actual or default values. This builder + /// will invoke verification at the end, and if `T` is not correctly initialized, it will + /// result in undefined behavior. + pub fn new(builder: &'a mut B, op: UnsafeIntrusiveEntityRef) -> Self { + let op = unsafe { UnsafeIntrusiveEntityRef::from_raw(op.borrow().as_operation()) }; Self { - context, + builder, op, _marker: core::marker::PhantomData, } } - /// Register this op as an implementation of `Trait`. - /// - /// This is enforced statically by the type system, as well as dynamically via verification. - /// - /// This must be called for any trait that you wish to be able to cast the type-erased - /// [Operation] to later, or if you wish to get a `dyn Trait` reference from a `dyn Op` - /// reference. - /// - /// If `Trait` has a verifier implementation, it will be automatically applied when calling - /// [Operation::verify]. - pub fn implement(&mut self) - where - Trait: ?Sized + Pointee> + 'static, - T: Unsize + verifier::Verifier + 'static, - { - let mut op = self.op.borrow_mut(); - let operation = op.as_operation_mut(); - operation.vtable.register_trait::(); - } - /// Set attribute `name` on this op to `value` + #[inline] pub fn with_attr(&mut self, name: &'static str, value: A) where A: AttributeValue, { - let mut op = self.op.borrow_mut(); - op.as_operation_mut().attrs.insert(name, Some(value)); + self.op.borrow_mut().set_attribute(name, Some(value)); + } + + /// Set symbol `attr_name` on this op to `symbol`. + /// + /// Symbol references are stored as attributes, and have similar semantics to operands, i.e. + /// they require tracking uses. + #[inline] + pub fn with_symbol(&mut self, attr_name: &'static str, symbol: impl AsSymbolRef) { + self.op.borrow_mut().set_symbol_attribute(attr_name, symbol); + } + + /// Like [with_symbol], but further constrains the range of valid input symbols to those which + /// are valid [CallableOpInterface] implementations. + #[inline] + pub fn with_callable_symbol( + &mut self, + attr_name: &'static str, + callable: impl AsCallableSymbolRef, + ) { + let callable = callable.as_callable_symbol_ref(); + self.op.borrow_mut().set_symbol_attribute(attr_name, callable); } /// Add a new [Region] to this operation. 
@@ -86,83 +82,165 @@ impl<'a, T: Op> OperationBuilder<'a, T> { pub fn create_region(&mut self) { let mut region = Region::default(); unsafe { - region.set_owner(Some(self.as_operation_ref())); + region.set_owner(Some(self.op.clone())); } - let region = self.context.alloc_tracked(region); + let region = self.builder.context().alloc_tracked(region); let mut op = self.op.borrow_mut(); - op.as_operation_mut().regions.push_back(region); + op.regions.push_back(region); } - pub fn with_successor(&mut self, succ: OpSuccessor) { - todo!() + pub fn with_successor( + &mut self, + dest: BlockRef, + arguments: impl IntoIterator, + ) { + let owner = self.op.clone(); + // Insert operand group for this successor + let mut op = self.op.borrow_mut(); + let operand_group = + op.operands.push_group(arguments.into_iter().enumerate().map(|(index, arg)| { + self.builder.context().make_operand(arg, owner.clone(), index as u8) + })); + // Record SuccessorInfo for this successor in the op + let succ_index = u8::try_from(op.successors.len()).expect("too many successors"); + let successor = self.builder.context().make_block_operand(dest.clone(), owner, succ_index); + op.successors.push_group([SuccessorInfo { + block: successor, + key: None, + operand_group: operand_group.try_into().expect("too many operand groups"), + }]); + } + + pub fn with_successors(&mut self, succs: I) + where + I: IntoIterator)>, + { + let owner = self.op.clone(); + let mut op = self.op.borrow_mut(); + let mut group = vec![]; + for (i, (block, args)) in succs.into_iter().enumerate() { + let block = self.builder.context().make_block_operand(block, owner.clone(), i as u8); + let operands = args + .into_iter() + .map(|value_ref| self.builder.context().make_operand(value_ref, owner.clone(), 0)); + let operand_group = op.operands.push_group(operands); + group.push(SuccessorInfo { + block, + key: None, + operand_group: operand_group.try_into().expect("too many operand groups"), + }); + } + op.successors.push_group(group); } - pub fn with_successors(&mut self, succs: I) + pub fn with_keyed_successors(&mut self, succs: I) where - S: Into, + S: KeyedSuccessor, I: IntoIterator, { - todo!() + let owner = self.op.clone(); + let mut op = self.op.borrow_mut(); + let mut group = vec![]; + for (i, successor) in succs.into_iter().enumerate() { + let (key, block, args) = successor.into_parts(); + let block = self.builder.context().make_block_operand(block, owner.clone(), i as u8); + let operands = args + .into_iter() + .map(|value_ref| self.builder.context().make_operand(value_ref, owner.clone(), 0)); + let operand_group = op.operands.push_group(operands); + let key = Box::new(key); + let key = unsafe { core::ptr::NonNull::new_unchecked(Box::into_raw(key)) }; + group.push(SuccessorInfo { + block, + key: Some(key.cast()), + operand_group: operand_group.try_into().expect("too many operand groups"), + }); + } + op.successors.push_group(group); } - /// Set the operands given to this op + /// Append operands to the set of operands given to this op so far. 
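// Illustrative sketch (not part of the patch): using the successor helpers defined above while
// building a branch-like op. `CondBr` is a hypothetical terminator op, and the successor
// argument item type is assumed to be ValueRef; each call to `with_successor` creates its own
// operand group and a matching SuccessorInfo entry.
fn attach_branch_targets(
    b: &mut OperationBuilder<'_, CondBr>,
    then_dest: BlockRef,
    then_args: Vec<ValueRef>,
    else_dest: BlockRef,
    else_args: Vec<ValueRef>,
) {
    b.with_successor(then_dest, then_args);
    b.with_successor(else_dest, else_args);
}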
pub fn with_operands(&mut self, operands: I) where I: IntoIterator, { - // TODO: Verify the safety of this conversion - let owner = self.as_operation_ref(); - let mut op = self.op.borrow_mut(); + let owner = self.op.clone(); let operands = operands.into_iter().enumerate().map(|(index, value)| { - self.context - .alloc_tracked(OpOperandImpl::new(value, owner.clone(), index as u8)) + self.builder.context().make_operand(value, owner.clone(), index as u8) }); - let op_mut = op.as_operation_mut(); - op_mut.operands.clear(); - op_mut.operands.extend(operands); + let mut op = self.op.borrow_mut(); + op.operands.extend(operands); } + /// Append operands to the set of operands in operand group `group` pub fn with_operands_in_group(&mut self, group: usize, operands: I) where I: IntoIterator, { - let owner = self.as_operation_ref(); - let mut op = self.op.borrow_mut(); + let owner = self.op.clone(); let operands = operands.into_iter().enumerate().map(|(index, value)| { - self.context - .alloc_tracked(OpOperandImpl::new(value, owner.clone(), index as u8)) + self.builder.context().make_operand(value, owner.clone(), index as u8) }); - let op_operands = op.operands_mut(); - op_operands.push_operands_to_group(group, operands); + let mut op = self.op.borrow_mut(); + op.operands.extend_group(group, operands); } /// Allocate `n` results for this op, of unknown type, to be filled in later pub fn with_results(&mut self, n: usize) { - let owner = self.as_operation_ref(); + let span = self.op.borrow().span; + let owner = self.op.clone(); + let results = (0..n).map(|idx| { + self.builder + .context() + .make_result(span, Type::Unknown, owner.clone(), idx as u8) + }); let mut op = self.op.borrow_mut(); - let results = - (0..n).map(|idx| self.context.make_result(Type::Unknown, owner.clone(), idx as u8)); - let op_mut = op.as_operation_mut(); - op_mut.results.clear(); - op_mut.results.extend(results); + op.results.clear(); + op.results.extend(results); } /// Consume this builder, verify the op, and return a handle to it, or an error if validation /// failed. 
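// Illustrative sketch (not part of the patch): driving OperationBuilder by hand. In practice the
// `#[operation]` macro generates this; `MyOp` is hypothetical, and `raw` is assumed to be a
// freshly allocated op whose underlying Operation was set up via Operation::uninit.
fn finish_my_op(
    builder: &mut OpBuilder,
    raw: UnsafeIntrusiveEntityRef<MyOp>,
    lhs: ValueRef,
    rhs: ValueRef,
) -> Result<UnsafeIntrusiveEntityRef<MyOp>, Report> {
    let mut b = OperationBuilder::new(builder, raw);
    // Operands become operand group 0, in the order given.
    b.with_operands([lhs, rhs]);
    // Reserve a single result; its type can be inferred or filled in later.
    b.with_results(1);
    // Verifies the op and inserts it at the builder's current insertion point.
    b.build()
}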
- pub fn build(self) -> Result, Report> { + pub fn build(mut self) -> Result, Report> { + let op = { + let mut op = self.op.borrow_mut(); + + // Infer result types and apply any associated validation + if let Some(interface) = op.as_trait_mut::() { + interface.infer_return_types(self.builder.context())?; + } + + // Verify things that would require negative trait impls + if !op.implements::() && op.has_successors() { + return Err(self + .builder + .context() + .session + .diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation has successors, but does not implement the 'Terminator' \ + trait", + ) + .with_help("operations with successors must implement the 'Terminator' trait") + .into_report()); + } + + unsafe { UnsafeIntrusiveEntityRef::from_raw(op.container().cast()) } + }; + + // Run op-specific verification { - let op = self.op.borrow(); - op.as_operation().verify(self.context)?; + let op: super::EntityRef = op.borrow(); + //let op = op.borrow(); + op.verify(self.builder.context())?; } - Ok(self.op) - } - #[inline] - fn as_operation_ref(&self) -> OperationRef { - let op = self.op.borrow(); - unsafe { - let ptr = op.as_operation() as *const Operation; - OperationRef::from_raw(ptr) - } + // Insert op at current insertion point + self.builder.insert(self.op); + + Ok(op) } } diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs index fd462b375..89210fc80 100644 --- a/hir2/src/ir/operation/name.rs +++ b/hir2/src/ir/operation/name.rs @@ -1,29 +1,139 @@ -use core::fmt; +use alloc::rc::Rc; +use core::{ + any::TypeId, + fmt, + ptr::{DynMetadata, Pointee}, +}; -use crate::{interner, DialectName}; +use crate::{interner, traits::TraitInfo, DialectName, Op}; /// The operation name, or mnemonic, that uniquely identifies an operation. /// /// The operation name consists of its dialect name, and the opcode name within the dialect. /// /// No two operation names can share the same fully-qualified operation name. -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct OperationName { +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OperationName(Rc); + +struct OperationInfo { /// The dialect of this operation - pub dialect: DialectName, + dialect: DialectName, /// The opcode name for this operation - pub name: interner::Symbol, + name: interner::Symbol, + /// The type id of the concrete type that implements this operation + type_id: TypeId, + /// Details of the traits implemented by this operation, used to answer questions about what + /// traits are implemented, as well as reconstruct `&dyn Trait` references given a pointer to + /// the data of a specific operation instance. + traits: Box<[TraitInfo]>, } + impl OperationName { - pub fn new(dialect: DialectName, name: S) -> Self + pub fn new(dialect: DialectName, name: S, traits: T) -> Self where + O: crate::Op, S: Into, + T: IntoIterator, { - Self { - dialect, - name: name.into(), + let type_id = TypeId::of::(); + let mut traits = traits.into_iter().collect::>(); + traits.sort_by_key(|ti| *ti.type_id()); + let traits = traits.into_boxed_slice(); + let info = Rc::new(OperationInfo::new(dialect, name.into(), type_id, traits)); + Self(info) + } + + /// Returns the dialect name of this operation + pub fn dialect(&self) -> DialectName { + self.0.dialect + } + + /// Returns the namespace to which this operation name belongs (i.e. 
dialect name) + pub fn namespace(&self) -> interner::Symbol { + self.0.dialect.as_symbol() + } + + /// Returns the name/opcode of this operation + pub fn name(&self) -> interner::Symbol { + self.0.name + } + + /// Returns true if `T` is the concrete type that implements this operation + pub fn is(&self) -> bool { + TypeId::of::() == self.0.type_id + } + + /// Returns true if this operation implements `Trait` + pub fn implements(&self) -> bool + where + Trait: ?Sized + Pointee> + 'static, + { + let type_id = TypeId::of::(); + self.0.traits.binary_search_by(|ti| ti.type_id().cmp(&type_id)).is_ok() + } + + /// Returns true if this operation implements `trait`, where `trait` is the `TypeId` of a + /// `dyn Trait` type. + pub fn implements_trait_id(&self, trait_id: &TypeId) -> bool { + self.0.traits.binary_search_by(|ti| ti.type_id().cmp(trait_id)).is_ok() + } + + #[inline] + pub(super) fn downcast_ref(&self, ptr: *const ()) -> Option<&T> { + if self.is::() { + Some(unsafe { self.downcast_ref_unchecked(ptr) }) + } else { + None + } + } + + #[inline(always)] + unsafe fn downcast_ref_unchecked(&self, ptr: *const ()) -> &T { + &*core::ptr::from_raw_parts(ptr.cast::(), ()) + } + + #[inline] + pub(super) fn downcast_mut(&mut self, ptr: *mut ()) -> Option<&mut T> { + if self.is::() { + Some(unsafe { self.downcast_mut_unchecked(ptr) }) + } else { + None } } + + #[inline(always)] + unsafe fn downcast_mut_unchecked(&mut self, ptr: *mut ()) -> &mut T { + &mut *core::ptr::from_raw_parts_mut(ptr.cast::(), ()) + } + + pub(super) fn upcast(&self, ptr: *const ()) -> Option<&Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &*core::ptr::from_raw_parts(ptr, metadata) }) + } + + pub(super) fn upcast_mut(&mut self, ptr: *mut ()) -> Option<&mut Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &mut *core::ptr::from_raw_parts_mut(ptr, metadata) }) + } + + fn get(&self) -> Option<&TraitInfo> { + let type_id = TypeId::of::(); + self.0 + .traits + .binary_search_by(|ti| ti.type_id().cmp(&type_id)) + .ok() + .map(|index| &self.0.traits[index]) + } } impl fmt::Debug for OperationName { #[inline] @@ -33,6 +143,49 @@ impl fmt::Debug for OperationName { } impl fmt::Display for OperationName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}.{}", &self.dialect, &self.name) + write!(f, "{}.{}", &self.namespace(), &self.name()) + } +} + +impl OperationInfo { + pub fn new( + dialect: DialectName, + name: interner::Symbol, + type_id: TypeId, + traits: Box<[TraitInfo]>, + ) -> Self { + Self { + dialect, + name, + type_id, + traits, + } + } +} + +impl Eq for OperationInfo {} +impl PartialEq for OperationInfo { + fn eq(&self, other: &Self) -> bool { + self.dialect == other.dialect && self.name == other.name && self.type_id == other.type_id + } +} +impl PartialOrd for OperationInfo { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for OperationInfo { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.dialect + .cmp(&other.dialect) + .then_with(|| self.name.cmp(&other.name)) + .then_with(|| self.type_id.cmp(&other.type_id)) + } +} +impl core::hash::Hash for OperationInfo { + fn hash(&self, state: &mut H) { + self.dialect.hash(state); + self.name.hash(state); + self.type_id.hash(state); } } diff --git 
a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index f47499b9a..21df63877 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -67,4 +67,8 @@ impl Region { pub fn body_mut(&mut self) -> &mut BlockList { &mut self.body } + + pub fn drop_all_references(&mut self) { + todo!() + } } diff --git a/hir2/src/ir/successor.rs b/hir2/src/ir/successor.rs index 16c405cf3..4c50f3dd6 100644 --- a/hir2/src/ir/successor.rs +++ b/hir2/src/ir/successor.rs @@ -1,23 +1,219 @@ use core::fmt; -use crate::{BlockOperandRef, OpOperand}; +use super::OpOperandStorage; +use crate::{AttributeValue, BlockOperandRef, BlockRef, OpOperandRange, OpOperandRangeMut}; -/// TODO: +pub type OpSuccessorStorage = crate::EntityStorage; +pub type OpSuccessorRange<'a> = crate::EntityRange<'a, SuccessorInfo>; +pub type OpSuccessorRangeMut<'a> = crate::EntityRangeMut<'a, SuccessorInfo, 0>; + +/// This trait represents successor-like values for operations, with support for control-flow +/// predicated on a "key", a sentinel value that must match in order for the successor block to be +/// taken. +/// +/// The ability to associate a successor with a user-defined key, is intended for modeling things +/// such as [crate::dialects::hir::Switch], which has one or more successors which are guarded by +/// an integer value that is matched against the input, or selector, value. Most importantly, doing +/// so in a way that keeps everything in sync as the IR is modified. /// -/// * Replace usage of OpSuccessor with BlockOperand -/// * Store OpSuccessor operands in OpOperandStorage in groups per BlockOperand +/// When used as a successor argument to an operation, each successor gets its own operand group, +/// and if it has an associated key, keyed successors are stored in a special attribute which tracks +/// each key and its associated successor index. This allows requesting the successor details and +/// getting back the correct key, destination, and operands. +pub trait KeyedSuccessor { + /// The type of key this successor + type Key: AttributeValue + Clone + Eq; + /// The type of value which will represent a reference to this successor. + /// + /// You should use [OpSuccessor] if this successor is not keyed. + type Repr<'a>: 'a; + /// The type of value which will represent a mutable reference to this successor. + /// + /// You should use [OpSuccessorMut] if this successor is not keyed. + type ReprMut<'a>: 'a; + + /// The (optional) value of the key for this successor. + /// + /// Keys must be valid attribute values, as they will be encoded in the operation attributes. + /// + /// If `None` is returned, this successor is to be treated like a regular successor argument, + /// i.e. a destination block and associated operands. If a key is returned, the key must be + /// unique across the set of keyed successors. 
+ fn key(&self) -> &Self::Key; + /// Convert this value into the raw parts comprising the successor information: + /// + /// * The (optional) key under which this successor is selected + /// * The destination block + /// * The destination operands + fn into_parts(self) -> (Self::Key, BlockRef, Vec); + fn into_repr( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRange<'_>, + ) -> Self::Repr<'_>; + fn into_repr_mut( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRangeMut<'_>, + ) -> Self::ReprMut<'_>; +} -/// An [OpSuccessor] is a BlockOperand + OpOperands for that block, attached to an Operation -#[derive(Clone)] -pub struct OpSuccessor { +/// This struct tracks successor metadata needed by [crate::Operation] +#[derive(Debug, Clone)] +pub struct SuccessorInfo { pub block: BlockOperandRef, - pub args: smallvec::SmallVec<[OpOperand; 1]>, + pub(crate) key: Option>, + pub(crate) operand_group: u8, } -impl fmt::Debug for OpSuccessor { +impl crate::StorableEntity for SuccessorInfo { + #[inline(always)] + fn index(&self) -> usize { + self.block.index() + } + + #[inline(always)] + unsafe fn set_index(&mut self, index: usize) { + self.block.set_index(index); + } + + #[inline(always)] + fn unlink(&mut self) { + self.block.unlink(); + } +} + +/// An [OpSuccessor] is a BlockOperand + OpOperandRange for that block +pub struct OpSuccessor<'a> { + pub dest: BlockOperandRef, + pub arguments: OpOperandRange<'a>, +} +impl fmt::Debug for OpSuccessor<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("OpSuccessor") - .field("block", &self.block.borrow().block_id()) - .field("args", &self.args) + .field("block", &self.dest.borrow().block_id()) + .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) .finish() } } + +/// An [OpSuccessorMut] is a BlockOperand + OpOperandRangeMut for that block +pub struct OpSuccessorMut<'a> { + pub dest: BlockOperandRef, + pub arguments: OpOperandRangeMut<'a>, +} +impl fmt::Debug for OpSuccessorMut<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpSuccessorMut") + .field("block", &self.dest.borrow().block_id()) + .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .finish() + } +} + +pub struct KeyedSuccessorRange<'a, T> { + range: OpSuccessorRange<'a>, + operands: &'a OpOperandStorage, + _marker: core::marker::PhantomData, +} +impl<'a, T> KeyedSuccessorRange<'a, T> { + pub fn new(range: OpSuccessorRange<'a>, operands: &'a OpOperandStorage) -> Self { + Self { + range, + operands, + _marker: core::marker::PhantomData, + } + } + + pub fn get(&self, index: usize) -> Option> { + self.range.get(index).map(|info| { + let operands = self.operands.group(info.operand_group as usize); + SuccessorWithKey { + info, + operands, + _marker: core::marker::PhantomData, + } + }) + } +} + +pub struct KeyedSuccessorRangeMut<'a, T> { + range: OpSuccessorRangeMut<'a>, + operands: &'a mut OpOperandStorage, + _marker: core::marker::PhantomData, +} +impl<'a, T> KeyedSuccessorRangeMut<'a, T> { + pub fn new(range: OpSuccessorRangeMut<'a>, operands: &'a mut OpOperandStorage) -> Self { + Self { + range, + operands, + _marker: core::marker::PhantomData, + } + } + + pub fn get(&self, index: usize) -> Option> { + self.range.get(index).map(|info| { + let operands = self.operands.group(info.operand_group as usize); + SuccessorWithKey { + info, + operands, + _marker: core::marker::PhantomData, + } + }) + } + + pub fn get_mut(&mut self, index: usize) 
-> Option> { + self.range.get_mut(index).map(|info| { + let operands = self.operands.group_mut(info.operand_group as usize); + SuccessorWithKeyMut { + info, + operands, + _marker: core::marker::PhantomData, + } + }) + } +} + +pub struct SuccessorWithKey<'a, T> { + info: &'a SuccessorInfo, + operands: OpOperandRange<'a>, + _marker: core::marker::PhantomData, +} +impl<'a, T> SuccessorWithKey<'a, T> { + pub fn key(&self) -> Option<&T> { + self.info.key.map(|ptr| unsafe { &*(ptr.as_ptr() as *mut T) }) + } + + pub fn block(&self) -> BlockRef { + self.info.block.borrow().block.clone() + } + + #[inline(always)] + pub fn arguments(&self) -> &OpOperandRange<'a> { + &self.operands + } +} + +pub struct SuccessorWithKeyMut<'a, T> { + info: &'a SuccessorInfo, + operands: OpOperandRangeMut<'a>, + _marker: core::marker::PhantomData, +} +impl<'a, T> SuccessorWithKeyMut<'a, T> { + pub fn key(&self) -> Option<&T> { + self.info.key.map(|ptr| unsafe { &*(ptr.as_ptr() as *mut T) }) + } + + pub fn block(&self) -> BlockRef { + self.info.block.borrow().block.clone() + } + + #[inline(always)] + pub fn arguments(&self) -> &OpOperandRangeMut<'a> { + &self.operands + } + + #[inline(always)] + pub fn arguments_mut(&mut self) -> &mut OpOperandRangeMut<'a> { + &mut self.operands + } +} diff --git a/hir2/src/ir/symbol_table.rs b/hir2/src/ir/symbol_table.rs index 988141c91..0aaed3ab2 100644 --- a/hir2/src/ir/symbol_table.rs +++ b/hir2/src/ir/symbol_table.rs @@ -1,16 +1,52 @@ use alloc::collections::VecDeque; use core::fmt; +use midenc_session::diagnostics::{miette, Diagnostic}; + use crate::{ - define_attr_type, interner, InsertionPoint, Op, Operation, OperationRef, Report, Searcher, - UnsafeIntrusiveEntityRef, Usable, Visibility, + define_attr_type, interner, InsertionPoint, Op, Operation, OperationRef, Report, + UnsafeIntrusiveEntityRef, Usable, Visibility, Walkable, }; /// Represents the name of a [Symbol] in its local [SymbolTable] pub type SymbolName = interner::Symbol; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, thiserror::Error, Diagnostic)] +pub enum InvalidSymbolRefError { + #[error("invalid symbol reference: no symbol table available")] + NoSymbolTable { + #[label("cannot resolve this symbol")] + symbol: crate::SourceSpan, + #[label( + "because this operation has no parent symbol table with which to resolve the reference" + )] + user: crate::SourceSpan, + }, + #[error("invalid symbol reference: undefined symbol")] + UnknownSymbol { + #[label("failed to resolve this symbol")] + symbol: crate::SourceSpan, + #[label("in the nearest symbol table from this operation")] + user: crate::SourceSpan, + }, + #[error("invalid symbol reference: undefined component '{component}' of symbol")] + UnknownSymbolComponent { + #[label("failed to resolve this symbol")] + symbol: crate::SourceSpan, + #[label("from the root symbol table of this operation")] + user: crate::SourceSpan, + component: &'static str, + }, + #[error("invalid symbol reference: expected callable")] + NotCallable { + #[label("expected this symbol to implement the CallableOpInterface")] + symbol: crate::SourceSpan, + }, +} + +#[derive(Debug, Clone)] pub struct SymbolNameAttr { + pub user: SymbolUseRef, /// The path through the abstract symbol space to the containing symbol table /// /// It is assumed that all symbol tables are also symbols themselves, and thus the path to @@ -51,7 +87,7 @@ impl SymbolNameAttr { self.path != interner::symbols::Empty } - pub fn components(&self) -> impl Iterator { + pub fn components(&self) -> SymbolNameComponents { 
SymbolNameComponents::new(self.path, self.name) } } @@ -64,6 +100,12 @@ impl fmt::Display for SymbolNameAttr { } } } +impl crate::formatter::PrettyPrint for SymbolNameAttr { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + display(self) + } +} impl Eq for SymbolNameAttr {} impl PartialEq for SymbolNameAttr { fn eq(&self, other: &Self) -> bool { @@ -92,7 +134,7 @@ pub enum SymbolNameComponent { Leaf(SymbolName), } -struct SymbolNameComponents { +pub struct SymbolNameComponents { parts: VecDeque<&'static str>, name: SymbolName, done: bool, @@ -135,6 +177,25 @@ impl SymbolNameComponents { done: false, } } + + /// Convert this iterator into a symbol name representing the path prefix of a [Symbol]. + /// + /// If `absolute` is set to true, then the resulting path will be prefixed with `::` + pub fn into_path(self, absolute: bool) -> SymbolName { + if self.parts.is_empty() { + return ::midenc_hir_symbol::symbols::Empty; + } + + let mut buf = + String::with_capacity(2usize + self.parts.iter().map(|p| p.len()).sum::()); + if absolute { + buf.push_str("::"); + } + for part in self.parts { + buf.push_str(part); + } + SymbolName::intern(buf) + } } impl core::iter::FusedIterator for SymbolNameComponents {} impl Iterator for SymbolNameComponents { @@ -155,6 +216,32 @@ impl Iterator for SymbolNameComponents { } } +/// A trait which allows multiple types to be coerced into a [SymbolRef]. +/// +/// This is primarily intended for use in operation builders. +pub trait AsSymbolRef { + fn as_symbol_ref(&self) -> SymbolRef; +} +impl AsSymbolRef for &T { + #[inline] + fn as_symbol_ref(&self) -> SymbolRef { + unsafe { SymbolRef::from_raw(*self as &dyn Symbol) } + } +} +impl AsSymbolRef for UnsafeIntrusiveEntityRef { + #[inline] + fn as_symbol_ref(&self) -> SymbolRef { + let t_ptr = Self::as_ptr(self); + unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) } + } +} +impl AsSymbolRef for SymbolRef { + #[inline(always)] + fn as_symbol_ref(&self) -> SymbolRef { + Self::clone(self) + } +} + /// A [SymbolTable] is an IR entity which contains other IR entities, called _symbols_, each of /// which has a name, aka symbol, that uniquely identifies it amongst all other entities in the /// same [SymbolTable]. @@ -166,10 +253,10 @@ impl Iterator for SymbolNameComponents { /// type matches the `Key` type of the [SymbolTable], can be stored in that table. pub trait SymbolTable { /// Get a reference to the underlying [Operation] - fn as_operation(&self) -> &Operation; + fn as_symbol_table_operation(&self) -> &Operation; /// Get a mutable reference to the underlying [Operation] - fn as_operation_mut(&mut self) -> &mut Operation; + fn as_symbol_table_operation_mut(&mut self) -> &mut Operation; /// Get the entry for `name` in this table fn get(&self, name: SymbolName) -> Option; @@ -204,7 +291,7 @@ impl dyn SymbolTable { pub fn find(&self, name: SymbolName) -> Option> { let op = self.get(name)?; let op = op.borrow(); - let op = op.as_operation().downcast_ref::()?; + let op = op.as_symbol_operation().downcast_ref::()?; Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(op) }) } } @@ -216,38 +303,67 @@ impl dyn SymbolTable { /// otherwise it would not be possible to unambiguously refer to a function by name. Likewise /// with modules in a program, etc. 
pub trait Symbol: Usable + 'static { - fn as_operation(&self) -> &Operation; - fn as_operation_mut(&mut self) -> &mut Operation; + fn as_symbol_operation(&self) -> &Operation; + fn as_symbol_operation_mut(&mut self) -> &mut Operation; /// Get the name of this symbol fn name(&self) -> SymbolName; + /// Get an iterator over the components of the fully-qualified path of this symbol. + fn components(&self) -> SymbolNameComponents { + let mut parts = VecDeque::default(); + if let Some(symbol_table) = self.root_symbol_table() { + let symbol_table = symbol_table.borrow(); + symbol_table.walk_symbol_tables(true, |symbol_table, _| { + if let Some(sym) = symbol_table.as_symbol_table_operation().as_symbol() { + parts.push_back(sym.name().as_str()); + } + }); + } + SymbolNameComponents { + parts, + name: self.name(), + done: false, + } + } /// Set the name of this symbol fn set_name(&mut self, name: SymbolName); /// Get the visibility of this symbol fn visibility(&self) -> Visibility; /// Returns true if this symbol has private visibility - fn is_private(&self) -> bool; + #[inline] + fn is_private(&self) -> bool { + self.visibility().is_private() + } /// Returns true if this symbol has public visibility - fn is_public(&self) -> bool; + #[inline] + fn is_public(&self) -> bool { + self.visibility().is_public() + } /// Sets the visibility of this symbol fn set_visibility(&mut self, visibility: Visibility); /// Sets the visibility of this symbol to private fn set_private(&mut self) { self.set_visibility(Visibility::Private); } - /// Sets the visibility of this symbol to nested - fn set_nested(&mut self) { - self.set_visibility(Visibility::Nested); + /// Sets the visibility of this symbol to internal + fn set_internal(&mut self) { + self.set_visibility(Visibility::Internal); } /// Sets the visibility of this symbol to public fn set_public(&mut self) { self.set_visibility(Visibility::Public); } /// Get all of the uses of this symbol that are nested within `from` - fn symbol_uses(&self, from: OperationRef) -> SymbolUseIter; + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter; /// Return true if there are no uses of this symbol nested within `from` - fn symbol_uses_known_empty(&self, from: OperationRef) -> SymbolUseIter; + fn symbol_uses_known_empty(&self, from: OperationRef) -> bool { + self.symbol_uses(from).is_empty() + } /// Attempt to replace all uses of this symbol nested within `from`, with the provided replacement - fn replace_all_uses(&self, replacement: SymbolRef, from: OperationRef) -> Result<(), Report>; + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report>; /// Returns true if this operation can be discarded if it has no remaining symbol uses /// /// By default, if the visibility is non-public, a symbol is considered discardable @@ -260,21 +376,29 @@ pub trait Symbol: Usable + 'static { fn is_declaration(&self) -> bool { false } + /// Return the root symbol table in which this symbol is contained, if one exists. + /// + /// The root symbol table does not necessarily know about this symbol, rather the symbol table + /// which "owns" this symbol may itself be a symbol that belongs to another symbol table. This + /// function traces this chain as far as it goes, and returns the highest ancestor in the tree. 
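// Illustrative sketch (not part of the patch): computing the fully-qualified path prefix and
// name of a symbol from its components, the same way set_symbol_attribute records them in a
// SymbolNameAttr.
fn qualified_name(sym: &dyn Symbol) -> (SymbolName, SymbolName) {
    // `true` requests an absolute path, i.e. one prefixed with `::`
    let path = sym.components().into_path(true);
    let name = sym.name();
    (path, name)
}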
+ fn root_symbol_table(&self) -> Option { + self.as_symbol_operation().root_symbol_table() + } } impl dyn Symbol { pub fn is(&self) -> bool { - let op = self.as_operation(); + let op = self.as_symbol_operation(); op.is::() } pub fn downcast_ref(&self) -> Option<&T> { - let op = self.as_operation(); + let op = self.as_symbol_operation(); op.downcast_ref::() } pub fn downcast_mut(&mut self) -> Option<&mut T> { - let op = self.as_operation_mut(); + let op = self.as_symbol_operation_mut(); op.downcast_mut::() } @@ -283,7 +407,7 @@ impl dyn Symbol { /// NOTE: This relies on the assumption that all ops are allocated via the arena, and that all /// [Symbol] implementations are ops. pub fn as_operation_ref(&self) -> OperationRef { - unsafe { OperationRef::from_raw(self.as_operation()) } + self.as_symbol_operation().as_operation_ref() } } @@ -300,6 +424,27 @@ impl Operation { self.as_trait::() } + /// Get this operation as a [SymbolTable], if this operation implements the trait. + #[inline] + pub fn as_symbol_table(&self) -> Option<&dyn SymbolTable> { + self.as_trait::() + } + + /// Return the root symbol table in which this symbol is contained, if one exists. + /// + /// The root symbol table does not necessarily know about this symbol, rather the symbol table + /// which "owns" this symbol may itself be a symbol that belongs to another symbol table. This + /// function traces this chain as far as it goes, and returns the highest ancestor in the tree. + pub fn root_symbol_table(&self) -> Option { + let mut parent = self.nearest_symbol_table(); + let mut found = None; + while let Some(nearest_symbol_table) = parent.take() { + found = Some(nearest_symbol_table.clone()); + parent = nearest_symbol_table.borrow().nearest_symbol_table(); + } + found + } + /// Returns the nearest [SymbolTable] from this operation. /// /// Returns `None` if no parent of this operation is a valid symbol table. @@ -339,19 +484,14 @@ impl Operation { /// IR are visible to the caller. 
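// Illustrative sketch (not part of the patch): recovering the fully-qualified path prefix of
// a symbol. `Symbol::components` yields the names of the enclosing symbol tables, outermost
// first, followed by the leaf name, and `into_path` rebuilds just that prefix, optionally as
// an absolute (`::`-prefixed) path.
fn qualified_prefix(sym: &dyn Symbol) -> SymbolName {
    sym.components().into_path(true)
}
// The same ancestry can also be inspected one level at a time via the underlying operation:
//
//     let op = sym.as_symbol_operation();
//     let nearest = op.nearest_symbol_table(); // e.g. the containing module
//     let root = op.root_symbol_table();       // e.g. the top-level component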
pub fn walk_symbol_tables(&self, all_symbol_uses_visible: bool, mut callback: F) where - F: FnMut(&dyn Symbol, bool), + F: FnMut(&dyn SymbolTable, bool), { - use core::ops::ControlFlow; - - let visitor = move |op: &dyn Symbol| { - callback(op, all_symbol_uses_visible); - ControlFlow::<()>::Continue(()) - }; - - let op = self.as_operation_ref(); - let mut searcher = Searcher::new(op, visitor); - - searcher.visit(); + self.prewalk(|op: OperationRef| { + let op = op.borrow(); + if let Some(sym) = op.as_symbol_table() { + callback(sym, all_symbol_uses_visible); + } + }); } } @@ -384,7 +524,7 @@ fn verify_symbol(symbol: &dyn Symbol, context: &super::Context) -> Result<(), Re use midenc_session::diagnostics::{Severity, Spanned}; // Symbols must either have no parent, or be an immediate child of a SymbolTable - let op = symbol.as_operation(); + let op = symbol.as_symbol_operation(); let parent = op.parent_op(); if !parent.is_none_or(|parent| parent.borrow().implements::()) { return Err(context @@ -405,23 +545,57 @@ pub type SymbolUseIter<'a> = crate::EntityIter<'a, SymbolUse>; pub type SymbolUseCursor<'a> = crate::EntityCursor<'a, SymbolUse>; pub type SymbolUseCursorMut<'a> = crate::EntityCursorMut<'a, SymbolUse>; +pub struct SymbolUsesIter { + items: VecDeque, +} +impl ExactSizeIterator for SymbolUsesIter { + #[inline(always)] + fn len(&self) -> usize { + self.items.len() + } +} +impl From> for SymbolUsesIter { + fn from(items: VecDeque) -> Self { + Self { items } + } +} +impl FromIterator for SymbolUsesIter { + fn from_iter>(iter: T) -> Self { + Self { + items: iter.into_iter().collect(), + } + } +} +impl core::iter::FusedIterator for SymbolUsesIter {} +impl Iterator for SymbolUsesIter { + type Item = SymbolUseRef; + + #[inline] + fn next(&mut self) -> Option { + self.items.pop_front() + } +} + /// An [OpOperand] represents a use of a [Value] by an [Operation] pub struct SymbolUse { /// The user of the symbol pub owner: OperationRef, - /// The symbol used - pub symbol: SymbolNameAttr, + /// The symbol attribute of the op that stores the symbol + pub symbol: crate::interner::Symbol, } impl SymbolUse { #[inline] - pub fn new(owner: OperationRef, symbol: SymbolNameAttr) -> Self { + pub fn new(owner: OperationRef, symbol: crate::interner::Symbol) -> Self { Self { owner, symbol } } } impl fmt::Debug for SymbolUse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let op = self.owner.borrow(); + let value = op.get_typed_attribute::(&self.symbol); f.debug_struct("SymbolUse") - .field("symbol", &self.symbol) + .field("attr", &self.symbol) + .field("symbol", &value) .finish_non_exhaustive() } } diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs index 39e441d99..36866e1c2 100644 --- a/hir2/src/ir/traits.rs +++ b/hir2/src/ir/traits.rs @@ -1,10 +1,10 @@ mod callable; -mod multitrait; +mod info; mod types; use midenc_session::diagnostics::Severity; -pub(crate) use self::multitrait::MultiTraitVtable; +pub(crate) use self::info::TraitInfo; pub use self::{callable::*, types::*}; use crate::{derive, Context, Operation, Report, Spanned}; @@ -29,6 +29,9 @@ pub trait ReturnLike {} /// Op is a terminator (i.e. it can be used to terminate a block) pub trait Terminator {} +/// Op's regions do not require blocks to end with a [Terminator] +pub trait NoTerminator {} + /// Marker trait for idemptoent ops, i.e. 
`op op X == op X (unary) / X op X == X (binary)` pub trait Idempotent {} @@ -38,6 +41,51 @@ pub trait Involution {} /// Marker trait for ops which are not permitted to access values defined above them pub trait IsolatedFromAbove {} +/// Marker trait for ops which have only regions of [`RegionKind::Graph`] +pub trait HasOnlyGraphRegion {} + +/// Op's regions are all single-block graph regions, that not require a terminator +/// +/// This trait _cannot_ be derived via `derive!` +pub trait GraphRegionNoTerminator: + NoTerminator + SingleBlock + RegionKindInterface + HasOnlyGraphRegion +{ +} + +/// Represents the types of regions that can be represented in the IR +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegionKind { + /// A graph region is one without control-flow semantics, i.e. dataflow between operations is + /// the only thing that dictates order, and operations can be conceptually executed in parallel + /// if the runtime supports it. + /// + /// As there is no control-flow in these regions, graph regions may only contain a single block. + Graph, + /// An SSA region is one where the strict control-flow semantics and properties of SSA (static + /// single assignment) form must be upheld. + /// + /// SSA regions must adhere to: + /// + /// * Values can only be defined once + /// * Definitions must dominate uses + /// * Ordering of operations in a block corresponds to execution order, i.e. operations earlier + /// in a block dominate those later in the block. + /// * Blocks must end with a terminator. + #[default] + SSA, +} + +/// An op interface that indicates what types of regions it holds +pub trait RegionKindInterface { + /// Get the [RegionKind] for this operation + fn kind(&self) -> RegionKind; + /// Returns true if the kind of this operation's regions requires SSA dominance + #[inline] + fn has_ssa_dominance(&self) -> bool { + matches!(self.kind(), RegionKind::SSA) + } +} + derive! { /// Marker trait for unary ops, i.e. those which take a single operand pub trait UnaryOp {} diff --git a/hir2/src/ir/traits/callable.rs b/hir2/src/ir/traits/callable.rs index f45278680..e99924e22 100644 --- a/hir2/src/ir/traits/callable.rs +++ b/hir2/src/ir/traits/callable.rs @@ -1,6 +1,6 @@ use crate::{ - EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Signature, SymbolNameAttr, SymbolRef, - Value, ValueRef, + EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Signature, Symbol, SymbolNameAttr, + SymbolRef, UnsafeIntrusiveEntityRef, Value, ValueRef, }; /// A call-like operation is one that transfers control from one function to another. @@ -47,6 +47,24 @@ pub trait CallableOpInterface { fn signature(&self) -> &Signature; } +#[doc(hidden)] +pub trait AsCallableSymbolRef { + fn as_callable_symbol_ref(&self) -> SymbolRef; +} +impl AsCallableSymbolRef for T { + #[inline(always)] + fn as_callable_symbol_ref(&self) -> SymbolRef { + unsafe { SymbolRef::from_raw(self as &dyn Symbol) } + } +} +impl AsCallableSymbolRef for UnsafeIntrusiveEntityRef { + #[inline(always)] + fn as_callable_symbol_ref(&self) -> SymbolRef { + let t_ptr = Self::as_ptr(self); + unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) } + } +} + /// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a /// [CallOpInterface] implementation. 
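// Illustrative sketch (not part of the patch): a hypothetical op, `DataflowGraphOp`, whose
// regions are graph regions and therefore require neither SSA dominance nor a terminator.
// The op type name is an assumption made only for this example.
impl RegionKindInterface for DataflowGraphOp {
    #[inline]
    fn kind(&self) -> RegionKind {
        RegionKind::Graph
    }
}
// If the op also opts in to the marker traits above, the blanket requirements of
// `GraphRegionNoTerminator` are satisfied and it can be implemented as a no-op:
//
//     impl GraphRegionNoTerminator for DataflowGraphOp {}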
/// @@ -60,7 +78,7 @@ pub enum Callable { } impl From<&SymbolNameAttr> for Callable { fn from(value: &SymbolNameAttr) -> Self { - Self::Symbol(*value) + Self::Symbol(value.clone()) } } impl From for Callable { diff --git a/hir2/src/ir/traits/info.rs b/hir2/src/ir/traits/info.rs new file mode 100644 index 000000000..1a4e0754f --- /dev/null +++ b/hir2/src/ir/traits/info.rs @@ -0,0 +1,83 @@ +use core::{ + any::{Any, TypeId}, + marker::Unsize, + ptr::{null, DynMetadata, Pointee}, +}; + +pub struct TraitInfo { + /// The [TypeId] of the trait type, used as a unique key for [TraitImpl]s + type_id: TypeId, + /// Type-erased dyn metadata containing the trait vtable pointer for the concrete type + /// + /// This is transmuted to the correct trait type when reifying a `&dyn Trait` reference, + /// which is safe as `DynMetadata` is always the same size for all types. + metadata: DynMetadata, +} +impl TraitInfo { + pub fn new() -> Self + where + T: Any + Unsize + crate::verifier::Verifier, + Trait: ?Sized + Pointee> + 'static, + { + let type_id = TypeId::of::(); + let ptr = null::() as *const Trait; + let (_, metadata) = ptr.to_raw_parts(); + Self { + type_id, + metadata: unsafe { + core::mem::transmute::, DynMetadata>(metadata) + }, + } + } + + #[inline(always)] + pub const fn type_id(&self) -> &TypeId { + &self.type_id + } + + /// Obtain the dyn metadata for `Trait` from this info. + /// + /// # Safety + /// + /// This is highly unsafe - you must guarantee that `Trait` is the same type as the one used + /// to create this `TraitInfo` instance. In debug mode, errors like this will be caught, but + /// in release builds, no checks are performed, and absolute havoc will result if you use this + /// incorrectly. + /// + /// It is intended _only_ for use by generated code which has all of the type information + /// available to it statically. It must be public so that operations can be defined in other + /// crates. + pub unsafe fn metadata_unchecked(&self) -> DynMetadata + where + Trait: ?Sized + Pointee> + 'static, + { + debug_assert!(self.type_id == TypeId::of::()); + core::mem::transmute(self.metadata) + } +} +impl Eq for TraitInfo {} +impl PartialEq for TraitInfo { + fn eq(&self, other: &Self) -> bool { + self.type_id == other.type_id + } +} +impl PartialEq for TraitInfo { + fn eq(&self, other: &TypeId) -> bool { + self.type_id.eq(other) + } +} +impl PartialOrd for TraitInfo { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.type_id.cmp(&other.type_id)) + } +} +impl PartialOrd for TraitInfo { + fn partial_cmp(&self, other: &TypeId) -> Option { + Some(self.type_id.cmp(other)) + } +} +impl Ord for TraitInfo { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.type_id.cmp(&other.type_id) + } +} diff --git a/hir2/src/ir/traits/multitrait.rs b/hir2/src/ir/traits/multitrait.rs deleted file mode 100644 index d1b352620..000000000 --- a/hir2/src/ir/traits/multitrait.rs +++ /dev/null @@ -1,179 +0,0 @@ -use core::{ - any::{Any, TypeId}, - marker::Unsize, - ptr::{null, null_mut, DynMetadata, Pointee}, -}; - -struct TraitImpl { - /// The [TypeId] of the trait type, used as a unique key for [TraitImpl]s - type_id: TypeId, - /// Type-erased dyn metadata containing the trait vtable pointer for the concrete type - /// - /// This is transmuted to the correct trait type when reifying a `&dyn Trait` reference, - /// which is safe as `DynMetadata` is always the same size for all types. 
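// Illustrative sketch (not part of the patch): reifying a trait object from the type-erased
// metadata stored in `TraitInfo`. The choice of `dyn UnaryOp` and the caller-supplied data
// pointer are assumptions for the example; in practice this is done by generated code that
// has full knowledge of the concrete op type.
unsafe fn reify_unary_op<'a>(data: *const (), info: &TraitInfo) -> &'a dyn UnaryOp {
    // Safety: the caller must guarantee that `data` points at the concrete type this
    // `TraitInfo` was created for, and that the trait matches; see `metadata_unchecked`.
    let metadata = info.metadata_unchecked::<dyn UnaryOp>();
    &*core::ptr::from_raw_parts(data, metadata)
}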
- metadata: DynMetadata, -} -impl TraitImpl { - fn new() -> Self - where - T: Any + Unsize, - Trait: ?Sized + Pointee> + 'static, - { - let type_id = TypeId::of::(); - let ptr = null::() as *const Trait; - let (_, metadata) = ptr.to_raw_parts(); - Self { - type_id, - metadata: unsafe { - core::mem::transmute::, DynMetadata>(metadata) - }, - } - } - - unsafe fn metadata_unchecked(&self) -> DynMetadata - where - Trait: ?Sized + Pointee> + 'static, - { - debug_assert!(self.type_id == TypeId::of::()); - core::mem::transmute(self.metadata) - } -} -impl Eq for TraitImpl {} -impl PartialEq for TraitImpl { - fn eq(&self, other: &Self) -> bool { - self.type_id == other.type_id - } -} -impl PartialEq for TraitImpl { - fn eq(&self, other: &TypeId) -> bool { - self.type_id.eq(other) - } -} -impl PartialOrd for TraitImpl { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.type_id.cmp(&other.type_id)) - } -} -impl PartialOrd for TraitImpl { - fn partial_cmp(&self, other: &TypeId) -> Option { - Some(self.type_id.cmp(other)) - } -} -impl Ord for TraitImpl { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.type_id.cmp(&other.type_id) - } -} - -pub(crate) struct MultiTraitVtable { - data: *mut (), - type_id: TypeId, - traits: Vec, -} -impl MultiTraitVtable { - pub fn new() -> Self { - let type_id = TypeId::of::(); - let any_impl = TraitImpl::new::(); - - Self { - data: null_mut(), - type_id, - traits: vec![any_impl], - } - } - - #[allow(unused)] - #[inline] - pub const fn data_ptr(&self) -> *mut () { - self.data - } - - pub(crate) unsafe fn set_data_ptr(&mut self, ptr: *mut T) { - assert!(!ptr.is_null()); - assert!(ptr.is_aligned()); - assert!(self.is::()); - self.data = ptr.cast(); - } - - pub fn register_trait(&mut self) - where - T: Any + Unsize + 'static, - Trait: ?Sized + Pointee> + 'static, - { - let trait_impl = TraitImpl::new::(); - match self.traits.binary_search(&trait_impl) { - Ok(_) => (), - Err(index) if index + 1 == self.traits.len() => self.traits.push(trait_impl), - Err(index) => self.traits.insert(index, trait_impl), - } - } - - #[inline] - pub fn is(&self) -> bool { - self.type_id == TypeId::of::() - } - - pub fn implements(&self) -> bool - where - Trait: ?Sized + Pointee> + 'static, - { - let type_id = TypeId::of::(); - self.traits.binary_search_by(|ti| ti.type_id.cmp(&type_id)).is_ok() - } - - #[inline] - pub fn downcast_ref(&self) -> Option<&T> { - if self.is::() { - Some(unsafe { self.downcast_ref_unchecked() }) - } else { - None - } - } - - #[inline(always)] - unsafe fn downcast_ref_unchecked(&self) -> &T { - &*core::ptr::from_raw_parts(self.data.cast::(), ()) - } - - #[inline] - pub fn downcast_mut(&mut self) -> Option<&mut T> { - if self.is::() { - Some(unsafe { self.downcast_mut_unchecked() }) - } else { - None - } - } - - #[inline(always)] - unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { - &mut *core::ptr::from_raw_parts_mut(self.data.cast::(), ()) - } - - pub fn downcast_trait(&self) -> Option<&Trait> - where - Trait: ?Sized + Pointee> + 'static, - { - let metadata = self - .get::() - .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; - Some(unsafe { &*core::ptr::from_raw_parts(self.data, metadata) }) - } - - pub fn downcast_trait_mut(&mut self) -> Option<&mut Trait> - where - Trait: ?Sized + Pointee> + 'static, - { - let metadata = self - .get::() - .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; - Some(unsafe { &mut *core::ptr::from_raw_parts_mut(self.data, metadata) }) - } - - fn get(&self) -> 
Option<&TraitImpl> { - let type_id = TypeId::of::(); - self.traits - .binary_search_by(|ti| ti.type_id.cmp(&type_id)) - .ok() - .map(|index| &self.traits[index]) - } -} diff --git a/hir2/src/ir/traits/types.rs b/hir2/src/ir/traits/types.rs index 6166f4675..493b0a71e 100644 --- a/hir2/src/ir/traits/types.rs +++ b/hir2/src/ir/traits/types.rs @@ -2,14 +2,19 @@ use core::fmt; use midenc_session::diagnostics::Severity; -use crate::{derive, Context, Operation, Report, Spanned}; +use crate::{derive, Context, Op, Operation, Report, Spanned, Type}; /// OpInterface to compute the return type(s) of an operation. -pub trait InferTypeOpInterface { +pub trait InferTypeOpInterface: Op { /// Run type inference for this op's results, using the current state, and apply any changes. /// /// Returns an error if unable to infer types, or if some type constraint is violated. - fn infer_types(&mut self) -> Result<(), Report>; + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report>; + + /// Return whether the set sets of types are compatible + fn are_compatible_return_types(&self, lhs: &[Type], rhs: &[Type]) -> bool { + lhs == rhs + } } derive! { @@ -115,14 +120,6 @@ impl crate::Verify> for Operation { pub trait TypeConstraint: 'static { fn description() -> impl fmt::Display; fn matches(ty: &crate::Type) -> bool; - fn check(ty: &crate::Type) -> Result<(), Report> { - if Self::matches(ty) { - Ok(()) - } else { - let expected = Self::description(); - Err(Report::msg(format!("expected {expected}, got '{ty}'"))) - } - } } /// A type that can be constructed as a [crate::Type] @@ -162,6 +159,22 @@ macro_rules! type_constraint { } } }; + + ($Constraint:ident, $description:literal, |$matcher_input:ident| $matcher:expr) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches($matcher_input: &$crate::Type) -> bool { + $matcher + } + } + }; } type_constraint!(AnyType, "any type", true); @@ -172,6 +185,8 @@ type_constraint!(AnyArray, "any array type", crate::Type::is_array); type_constraint!(AnyStruct, "any struct type", crate::Type::is_struct); type_constraint!(AnyPointer, "a pointer type", crate::Type::is_pointer); type_constraint!(AnyInteger, "an integral type", crate::Type::is_integer); +type_constraint!(AnyPointerOrInteger, "an integral or pointer type", |ty| ty.is_pointer() + || ty.is_integer()); type_constraint!(AnySignedInteger, "a signed integral type", crate::Type::is_signed_integer); type_constraint!( AnyUnsignedInteger, @@ -180,6 +195,9 @@ type_constraint!( ); type_constraint!(IntFelt, "a field element", crate::Type::is_felt); +/// A boolean +pub type Bool = SizedInt<1>; + /// A signless 8-bit integer pub type Int8 = SizedInt<8>; /// A signed 8-bit integer @@ -213,6 +231,11 @@ impl BuildableTypeConstraint for IntFelt { crate::Type::Felt } } +impl BuildableTypeConstraint for Bool { + fn build() -> crate::Type { + crate::Type::I1 + } +} impl BuildableTypeConstraint for UInt8 { fn build() -> crate::Type { crate::Type::U8 diff --git a/hir2/src/ir/value.rs b/hir2/src/ir/value.rs index e812c111a..c1376ca19 100644 --- a/hir2/src/ir/value.rs +++ b/hir2/src/ir/value.rs @@ -1,4 +1,4 @@ -use core::fmt; +use core::{any::Any, fmt}; use super::*; @@ -39,10 +39,35 @@ impl fmt::Display for ValueId { /// the graph formed of the edges between values and operations via operands forms the data-flow /// graph of the program. 
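// Illustrative sketch (not part of the patch): the closure arm of `type_constraint!` makes
// ad-hoc constraints easy to express, mirroring the `AnyPointerOrInteger` definition above.
// `AnyFeltOrInteger` is a hypothetical constraint introduced only for this example.
type_constraint!(AnyFeltOrInteger, "a felt or integral type", |ty| ty.is_felt()
    || ty.is_integer());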
pub trait Value: Entity + Spanned + Usable + fmt::Debug { + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + /// Set the source location of this value + fn set_span(&mut self, span: SourceSpan); /// Get the type of this value fn ty(&self) -> &Type; /// Set the type of this value fn set_type(&mut self, ty: Type); + /// Get the defining operation for this value, _if_ defined by an operation. + /// + /// Returns `None` if this value is defined by other means than an operation result. + fn get_defining_op(&self) -> Option; +} + +impl dyn Value { + #[inline] + pub fn is(&self) -> bool { + self.as_any().is::() + } + + #[inline] + pub fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() + } + + #[inline] + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.as_any_mut().downcast_mut::() + } } /// Generates the boilerplate for a concrete [Value] type. @@ -56,6 +81,8 @@ macro_rules! value_impl { )* } + fn get_defining_op(&$GetDefiningOpSelf:ident) -> Option $GetDefiningOp:block + $($t:tt)* ) => { $(#[$outer])* @@ -74,6 +101,7 @@ macro_rules! value_impl { impl $ValueKind { pub fn new( + span: SourceSpan, id: ValueId, ty: Type, $( @@ -83,7 +111,7 @@ macro_rules! value_impl { Self { id, ty, - span: Default::default(), + span, uses: Default::default(), $( $Field @@ -93,13 +121,27 @@ macro_rules! value_impl { } impl Value for $ValueKind { + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self + } + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } fn ty(&self) -> &Type { &self.ty } + fn set_span(&mut self, span: SourceSpan) { + self.span = span; + } + fn set_type(&mut self, ty: Type) { self.ty = ty; } + + fn get_defining_op(&$GetDefiningOpSelf) -> Option $GetDefiningOp } impl Entity for $ValueKind { @@ -158,6 +200,10 @@ value_impl!( owner: BlockRef, index: u8, } + + fn get_defining_op(&self) -> Option { + None + } ); value_impl!( @@ -166,6 +212,10 @@ value_impl!( owner: OperationRef, index: u8, } + + fn get_defining_op(&self) -> Option { + Some(self.owner.clone()) + } ); impl BlockArgument { @@ -180,6 +230,25 @@ impl BlockArgument { } } +impl crate::formatter::PrettyPrint for BlockArgument { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + text(format!("{}", self.id)) + const_text(": ") + self.ty.render() + } +} + +impl StorableEntity for BlockArgument { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many block arguments"); + } +} + impl OpResult { /// Get the [Operation] to which this [OpResult] belongs pub fn owner(&self) -> OperationRef { @@ -190,4 +259,41 @@ impl OpResult { pub fn index(&self) -> usize { self.index as usize } + + #[inline] + pub fn as_value_ref(&self) -> ValueRef { + unsafe { ValueRef::from_raw(self as &dyn Value) } + } } + +impl crate::formatter::PrettyPrint for OpResult { + #[inline] + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + display(self.id) + } +} + +impl StorableEntity for OpResult { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many op results"); + } + + /// Unlink all users of this result + /// + /// The users will still refer to this result, but the use list of this value will be empty + fn unlink(&mut self) { + let uses = self.uses_mut(); + uses.clear(); + } +} + +pub type 
OpResultStorage = crate::EntityStorage; +pub type OpResultRange<'a> = crate::EntityRange<'a, OpResultRef>; +pub type OpResultRangeMut<'a> = crate::EntityRangeMut<'a, OpResultRef, 1>; diff --git a/hir2/src/ir/verifier.rs b/hir2/src/ir/verifier.rs index e3bcb2599..9b5c2bcfe 100644 --- a/hir2/src/ir/verifier.rs +++ b/hir2/src/ir/verifier.rs @@ -132,17 +132,63 @@ pub trait Verifier { /// in verifier selection, so that the resulting implementation is specialized and able to /// have more optimizations applied as a result. /// - /// ```rust,ignore + /// ```rust + /// use midenc_hir2::{Context, Report, Verify, verifier::Verifier}; + /// + /// /// This trait is a marker for a type that should never validate, i.e. verifying it always + /// /// returns an error. + /// trait Nope {} + /// impl Verify for Any { + /// fn verify(&self, context: &Context) -> Result<(), Report> { + /// Err(Report::msg("nope")) + /// } + /// } + /// + /// /// We can't impl the `Verify` trait for all T outside of `midenc_hir2`, so we newtype all T + /// /// to do so, in effect, we can mimic the effect of implementing for all T using this type. + /// struct Any(core::marker::PhantomData); + /// impl Any { + /// fn new() -> Self { + /// Self(core::marker::PhantomData) + /// } + /// } + /// + /// /// This struct implements `Nope`, so it has an explicit verifier, which always fails + /// struct AlwaysRejected; + /// impl Nope for AlwaysRejected {} + /// + /// /// This struct doesn't implement `Nope`, so it gets a vacuous verifier, which always + /// /// succeeds. + /// struct AlwaysAccepted; + /// + /// /// Our vacuous verifier impl /// #[inline(always)] - /// fn noop(&T, &Context) -> Result<(), Report> { Ok(()) } - /// let verify_fn = const { - /// if >::VACUOUS { + /// fn noop(_: &Any, _: &Context) -> Result<(), Report> { Ok(()) } + /// + /// /// This block uses const-eval to select the verifier for Any statically + /// let always_accepted = const { + /// if as Verifier>::VACUOUS { /// noop /// } else { - /// >::maybe_verify + /// as Verifier>::maybe_verify /// } /// }; - /// verify_fn(op, context) + /// + /// /// This block uses const-eval to select the verifier for Any statically + /// let always_rejected = const { + /// if as Verifier>::VACUOUS { + /// noop + /// } else { + /// as Verifier>::maybe_verify + /// } + /// }; + /// + /// /// Verify that we got the correct impls. We can't verify that all of the abstraction was + /// /// eliminated, but from reviewing the assembly output, it appears that this is precisely + /// /// what happens. + /// let context = Context::default(); + /// assert!(always_accepted(&Any::new(), &context).is_ok()); + /// assert!(always_rejected(&Any::new(), &context).is_err()); /// ``` const VACUOUS: bool; diff --git a/hir2/src/ir/visit.rs b/hir2/src/ir/visit.rs index b0b09656a..0cef13b17 100644 --- a/hir2/src/ir/visit.rs +++ b/hir2/src/ir/visit.rs @@ -1,7 +1,9 @@ -use alloc::collections::VecDeque; pub use core::ops::ControlFlow; -use crate::{BlockRef, Op, Operation, OperationRef, Symbol}; +use crate::{ + Block, BlockRef, Op, Operation, OperationRef, Region, RegionRef, Report, Symbol, + UnsafeIntrusiveEntityRef, +}; /// A generic trait that describes visitors for all kinds pub trait Visitor { @@ -9,18 +11,18 @@ pub trait Visitor { type Output; /// The function which is applied to each `T` as it is visited. 
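// Illustrative sketch (not part of the patch): because `Visitor` is blanket-implemented for
// closures, an `OpVisitor` can be written inline and driven by a `Searcher`. The visitor
// below counts operations and interrupts the walk once a caller-provided limit is reached;
// exact generic parameters elided in the hunks above are left to inference here.
fn count_ops_up_to(root: OperationRef, limit: usize) -> usize {
    let mut count = 0usize;
    {
        let mut searcher = Searcher::new(root, |_op: &Operation| {
            count += 1;
            if count >= limit {
                // Stop the underlying pre-order walk early
                WalkResult::Break(())
            } else {
                WalkResult::<()>::Continue(())
            }
        });
        let _ = searcher.visit();
    }
    count
}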
- fn visit(&mut self, current: &T) -> ControlFlow; + fn visit(&mut self, current: &T) -> WalkResult; } /// We can automatically convert any closure of appropriate type to a `Visitor` impl Visitor for F where - F: FnMut(&T) -> ControlFlow, + F: FnMut(&T) -> WalkResult, { type Output = U; #[inline] - fn visit(&mut self, op: &T) -> ControlFlow { + fn visit(&mut self, op: &T) -> WalkResult { self(op) } } @@ -37,126 +39,458 @@ impl OpVisitor for V where V: Visitor {} pub trait SymbolVisitor: Visitor {} impl SymbolVisitor for V where V: Visitor {} -/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation]. +/// A result-like type used to control traversals of a [Walkable] entity. /// -/// It traverses the objects reachable from the root as follows: +/// It is comparable to [core::ops::ControlFlow], with an additional option to continue traversal, +/// but with a sibling, rather than visiting any further children of the current item. /// -/// * The root operation is visited first -/// * Then for each region of the root, the entry block is visited top to bottom, enqueing any nested -/// blocks of those operations to be visited after all blocks of region have been visited. When the -/// entry block has been visited, the process is repeated for the remaining blocks of the region. -/// * When all regions of the root have been visited, and no more blocks remain in the queue, the -/// traversal is complete +/// It is compatible with the `?` operator, however doing so will exit early on _both_ `Skip` and +/// `Break`, so you should be aware of that when using the try operator. +#[derive(Clone)] +#[must_use] +pub enum WalkResult { + /// Continue the traversal normally, optionally producing a value for the current item. + Continue(C), + /// Skip traversing the current item and any children that have not been visited yet, and + /// continue the traversal with the next sibling of the current item. + Skip, + /// Stop the traversal entirely, and optionally returning a value associated with the break. + // + /// This can be used to represent both errors, and the successful result of a search. 
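// Illustrative sketch (not part of the patch): using `WalkResult` to steer a traversal. The
// walk below records the name of every symbol it can see, but does not descend into nested
// symbol tables (e.g. nested modules), so only symbols visible at this level are collected.
fn collect_symbol_names(root: OperationRef) -> Vec<SymbolName> {
    let mut names = Vec::new();
    let _ = root.borrow().prewalk_interruptible(|op: OperationRef| {
        let current = op.borrow();
        if let Some(sym) = current.as_symbol() {
            names.push(sym.name());
            if current.as_symbol_table().is_some() {
                // Record the nested symbol table itself, but skip its unvisited children;
                // the walk resumes with the next sibling of `op`
                return WalkResult::Skip;
            }
        }
        // Keep walking into the regions/blocks/ops nested under `op`
        WalkResult::<()>::Continue(())
    });
    names
}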
+ Break(B), +} +impl WalkResult { + /// Returns true if the walk should continue + pub fn should_continue(&self) -> bool { + matches!(self, Self::Continue(_)) + } + + /// Returns true if the walk was interrupted + pub fn was_interrupted(&self) -> bool { + matches!(self, Self::Break(_)) + } + + /// Returns true if the walk was skipped + pub fn was_skipped(&self) -> bool { + matches!(self, Self::Skip) + } +} +impl WalkResult { + /// Convert this [WalkResult] into an equivalent [Result] + #[inline] + pub fn into_result(self) -> Result<(), B> { + match self { + Self::Break(err) => Err(err), + Self::Skip | Self::Continue(_) => Ok(()), + } + } +} +impl From> for WalkResult { + fn from(value: Result<(), B>) -> Self { + match value { + Ok(_) => WalkResult::Continue(()), + Err(err) => WalkResult::Break(err), + } + } +} +impl From> for Result<(), B> { + #[inline(always)] + fn from(value: WalkResult) -> Self { + value.into_result() + } +} +impl core::ops::FromResidual for WalkResult { + fn from_residual(residual: ::Residual) -> Self { + match residual { + WalkResult::Break(b) => WalkResult::Break(b), + _ => unreachable!(), + } + } +} +impl core::ops::Residual for WalkResult { + type TryType = WalkResult; +} +impl core::ops::Try for WalkResult { + type Output = C; + type Residual = WalkResult; + + #[inline] + fn from_output(output: Self::Output) -> Self { + WalkResult::Continue(output) + } + + #[inline] + fn branch(self) -> ControlFlow { + match self { + WalkResult::Continue(c) => ControlFlow::Continue(c), + WalkResult::Skip => ControlFlow::Break(WalkResult::Skip), + WalkResult::Break(b) => ControlFlow::Break(WalkResult::Break(b)), + } + } +} + +/// The traversal order for a walk of a region, block, or operation +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum WalkOrder { + PreOrder, + PostOrder, +} + +/// Encodes the current walk stage for generic walkers. /// -/// This traversal is _not_ in control flow order, _or_ data flow order, so you should not rely on -/// the order in which operations are visited for your [Visitor] implementation. -pub struct Searcher { - visitor: V, - queue: VecDeque, - current: Option, - started: bool, - _marker: core::marker::PhantomData, +/// When walking an operation, we can either choose a pre- or post-traversal walker which invokes +/// the callback on an operation before/after all its attached regions have been visited, or choose +/// a generic walker where the callback is invoked on the operation N+1 times, where N is the number +/// of regions attached to that operation. [WalkStage] encodes the current stage of the walk, i.e. +/// which regions have already been visited, and the callback accepts an additional argument for +/// the current stage. Such generic walkers that accept stage-aware callbacks are only applicable +/// when the callback operations on an operation (i.e. doesn't apply to callbacks on blocks or +/// regions). +#[derive(Clone, PartialEq, Eq)] +pub struct WalkStage { + /// The number of regions in the operation + num_regions: usize, + /// The next region to visit in the operation + next_region: usize, } -impl> Searcher { - pub fn new(root: OperationRef, visitor: V) -> Self { +impl WalkStage { + pub fn new(op: OperationRef) -> Self { + let op = op.borrow(); Self { - visitor, - queue: VecDeque::default(), - current: Some(root), - started: false, - _marker: core::marker::PhantomData, + num_regions: op.num_regions(), + next_region: 0, } } + /// Returns true if the parent operation is being visited before all regions. 
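// Illustrative sketch (not part of the patch): how a stage-aware generic walker might invoke
// its callback N+1 times for an op with N regions, once before each region and once after
// all of them. `visit_with_stages` is a hypothetical helper; region/block traversal is
// elided to keep the shape of the stage bookkeeping visible.
fn visit_with_stages(op: OperationRef, mut callback: impl FnMut(&WalkStage, OperationRef)) {
    let mut stage = WalkStage::new(op.clone());
    // Callback before any region has been visited
    callback(&stage, op.clone());
    let num_regions = op.borrow().num_regions();
    for _region in 0..num_regions {
        // ... visit the blocks and ops of `_region` here ...
        stage.advance();
        // Callback after `_region` (and before the next region, if any remain)
        callback(&stage, op.clone());
    }
    debug_assert!(stage.is_after_all_regions());
}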
+ #[inline] + pub fn is_before_all_regions(&self) -> bool { + self.next_region == 0 + } + + /// Returns true if the parent operation is being visited just before visiting `region` + #[inline] + pub fn is_before_region(&self, region: usize) -> bool { + self.next_region == region + } + + /// Returns true if the parent operation is being visited just after visiting `region` #[inline] - fn next(&mut self) -> Option { - visit_next(&mut self.started, &mut self.current, &mut self.queue) + pub fn is_after_region(&self, region: usize) -> bool { + self.next_region == region + 1 + } + + /// Returns true if the parent operation is being visited after all regions. + #[inline] + pub fn is_after_all_regions(&self) -> bool { + self.next_region == self.num_regions + } + + /// Advance the walk stage + #[inline] + pub fn advance(&mut self) { + self.next_region += 1; + } + + /// Returns the next region that will be visited + #[inline(always)] + pub const fn next_region(&self) -> usize { + self.next_region } } -impl Searcher { - pub fn visit(&mut self) -> ControlFlow<>::Output> { - while let Some(op) = self.next() { - let op = op.borrow(); - self.visitor.visit(&op)?; - } +/// A [Walkable] is an entity which can be traversed depth-first in either pre- or post-order +/// +/// An implementation of this trait specifies a type, `T`, corresponding to the type of item being +/// walked, while `Self` is the root entity, possibly of the same type, which may contain `T`. Thus +/// traversing from the root to all of the leaves, we will visit all reachable `T` nested within +/// `Self`, possibly including itself. +pub trait Walkable { + /// Walk all `T` in `self` in a specific order, applying the given callback to each. + /// + /// This is very similar to [Walkable::walk_interruptible], except the callback has no control + /// over the traversal, and must be infallible. + #[inline] + fn walk(&self, order: WalkOrder, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.walk_interruptible(order, |t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback + /// to each `T`. + #[inline] + fn prewalk(&self, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.prewalk_interruptible(|t| { + callback(t); - ControlFlow::Continue(()) + WalkResult::<()>::Continue(()) + }); } + + /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback + /// to each `T`. + #[inline] + fn postwalk(&self, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.postwalk_interruptible(|t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk `self` in the given order, visiting each `T` and applying the given callback to them. + /// + /// The given callback can control the traversal using the [WalkResult] it returns: + /// + /// * `WalkResult::Skip` will skip the walk of the current item and its nested elements that + /// have not been visited already, continuing with the next item. 
+ /// * `WalkResult::Break` will interrupt the walk, and no more items will be visited + /// * `WalkResult::Continue` will continue the walk + #[inline] + fn walk_interruptible(&self, order: WalkOrder, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult, + { + match order { + WalkOrder::PreOrder => self.prewalk_interruptible(callback), + WalkOrder::PostOrder => self.prewalk_interruptible(callback), + } + } + + /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback + /// to each `T`, and determining how to proceed based on the returned [WalkResult]. + fn prewalk_interruptible(&self, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; + + /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback + /// to each `T`, and determining how to proceed based on the returned [WalkResult]. + fn postwalk_interruptible(&self, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; } -impl> Searcher { - pub fn visit(&mut self) -> ControlFlow<>::Output> { - while let Some(op) = self.next() { - let op = op.borrow(); - if let Some(op) = op.downcast_ref::() { - self.visitor.visit(op)?; +/// Walking regions of an [Operation], and those of all nested operations +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(RegionRef) -> WalkResult, + { + let mut regions = self.regions().front(); + while let Some(region) = regions.as_pointer() { + regions.move_next(); + match callback(region.clone()) { + WalkResult::Continue(_) => { + let region = region.borrow(); + for block in region.body().iter() { + for op in block.body().iter() { + op.prewalk_interruptible(&mut callback)?; + } + } + } + WalkResult::Skip => continue, + result @ WalkResult::Break(_) => return result, } } - ControlFlow::Continue(()) + WalkResult::Continue(()) } -} -impl Searcher { - pub fn visit(&mut self) -> ControlFlow<>::Output> { - while let Some(op) = self.next() { - let op = op.borrow(); - if let Some(op) = op.as_symbol() { - self.visitor.visit(op)?; + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(RegionRef) -> WalkResult, + { + let mut regions = self.regions().front(); + while let Some(region) = regions.as_pointer() { + regions.move_next(); + { + let region = region.borrow(); + for block in region.body().iter() { + for op in block.body().iter() { + op.postwalk_interruptible(&mut callback)?; + } + } } + callback(region.clone())?; } - ControlFlow::Continue(()) - } -} - -/// Outlined implementation of the traversal performed by `Searcher` -#[inline(never)] -fn visit_next( - started: &mut bool, - current: &mut Option, - queue: &mut VecDeque, -) -> Option { - if !*started { - *started = true; - let curr = current.take()?; - // When just starting, we're at the root, so we descend into the operation, rather - // than visiting its next sibling. 
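// Illustrative sketch (not part of the patch): the documented difference between traversal
// orders. For an op `a` containing an op `b` in one of its regions, `WalkOrder::PreOrder`
// visits `a` then `b`, while `WalkOrder::PostOrder` visits `b` then `a`.
fn dump_ops(root: &Operation, order: WalkOrder) {
    root.walk(order, |op: OperationRef| {
        log::debug!("visiting '{}'", op.borrow().name());
    });
}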
- { - let op = curr.borrow(); - for region in op.regions().iter() { - let mut cursor = region.body().front(); - if current.is_none() { - let entry = cursor.as_pointer().expect("invalid region: has no entry block"); - let entry = entry.borrow(); - let next = entry.body().front().as_pointer(); - *current = next; - cursor.move_next(); + WalkResult::Continue(()) + } +} + +/// Walking blocks of an [Operation], and those of all nested operations +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(BlockRef) -> WalkResult, + { + for region in self.regions().iter() { + let mut blocks = region.body().front(); + while let Some(block) = blocks.as_pointer() { + blocks.move_next(); + match callback(block.clone()) { + WalkResult::Continue(_) => { + let block = block.borrow(); + for op in block.body().iter() { + op.prewalk_interruptible(&mut callback)?; + } + } + WalkResult::Skip => continue, + result @ WalkResult::Break(_) => return result, } - while let Some(block) = cursor.as_pointer() { - queue.push_back(block); - cursor.move_next(); + } + } + + WalkResult::Continue(()) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(BlockRef) -> WalkResult, + { + for region in self.regions().iter() { + let mut blocks = region.body().front(); + while let Some(block) = blocks.as_pointer() { + blocks.move_next(); + { + let block = block.borrow(); + for op in block.body().iter() { + op.postwalk_interruptible(&mut callback)?; + } } + callback(block.clone())?; + } + } + + WalkResult::Continue(()) + } +} + +/// Walking operations nested within an [Operation], including itself +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + prewalk_operation_interruptible(self, &mut callback) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + postwalk_operation_interruptible(self, &mut callback) + } +} + +fn prewalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + let result = callback(op.as_operation_ref()); + if !result.should_continue() { + return result; + } + + for region in op.regions().iter() { + for block in region.body().iter() { + let mut ops = block.body().front(); + while let Some(op) = ops.as_pointer() { + ops.move_next(); + let op = op.borrow(); + prewalk_operation_interruptible(&op, callback)?; } } - return Some(curr); } - // Here, we've already visited the root operation, so one of the following is true: - // - // * `current` is `None`, so pop the next block from the queue, if there are no more blocks, - // then we're done visiting and can return `None`. If there is a block, then we set - // `current` to the first operation in that block, and retry - // * `current` is `Some`, so obtain the next value of `current` by obtaining the next - // sibling operation of the current operation. 
- while current.is_none() { - let block = queue.pop_front()?; - let block = block.borrow(); - *current = block.body().front().as_pointer(); + WalkResult::Continue(()) +} + +fn postwalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + for region in op.regions().iter() { + for block in region.body().iter() { + let mut ops = block.body().front(); + while let Some(op) = ops.as_pointer() { + ops.move_next(); + let op = op.borrow(); + postwalk_operation_interruptible(&op, callback)?; + } + } + } + + callback(op.as_operation_ref()) +} + +/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation]. +/// +/// The searcher traverses the object graph in depth-first preorder, from operations to regions to +/// blocks to operations, etc. All nested items of an entity are visited before its siblings, i.e. +/// a region is fully visited before the next region of the same containing operation. +/// +/// This is effectively control-flow order, from an abstract interpretation perspective, i.e. an +/// actual program might only execute one region of a multi-region op, but this traversal will visit +/// all of them unless otherwise directed by a `WalkResult`. +pub struct Searcher { + visitor: V, + root: OperationRef, + _marker: core::marker::PhantomData, +} +impl> Searcher { + pub fn new(root: OperationRef, visitor: V) -> Self { + Self { + visitor, + root, + _marker: core::marker::PhantomData, + } } +} - let next = current.as_ref().and_then(|curr| curr.next()); +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + self.visitor.visit(&op) + }) + } +} + +impl> Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(op) = op.downcast_ref::() { + self.visitor.visit(op) + } else { + WalkResult::Continue(()) + } + }) + } +} - core::mem::replace(current, next) +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(sym) = op.as_symbol() { + self.visitor.visit(sym) + } else { + WalkResult::Continue(()) + } + }) + } } diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 696216177..c9e66a4b7 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -10,6 +10,13 @@ #![feature(debug_closure_helpers)] #![feature(trait_alias)] #![feature(is_none_or)] +#![feature(try_trait_v2)] +#![feature(try_trait_v2_residual)] +#![feature(tuple_trait)] +#![feature(fn_traits)] +#![feature(unboxed_closures)] +#![feature(const_type_id)] +#![feature(exact_size_is_empty)] #![allow(incomplete_features)] #![allow(internal_features)] @@ -18,6 +25,8 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; +extern crate self as midenc_hir2; + pub use compact_str::{ CompactString as SmallStr, CompactStringExt as SmallStrExt, ToCompactString as ToSmallStr, }; @@ -29,19 +38,23 @@ pub mod derive; pub mod dialects; pub mod formatter; mod ir; +mod patterns; -pub use self::{attributes::*, ir::*}; +pub use self::{any::AsAny, attributes::*, ir::*, patterns::*}; // TODO(pauls): The following is a rough list of what needs to be implemented for the IR // refactoring to be complete and usable in place of the old IR (some are optional): // // * constants and constant-like ops // * global variables and global ops +// * Need to implement 
InferTypeOpInterface for all applicable ops // * Builders (i.e. component builder, interface builder, module builder, function builder, last is most important) -// NOTE: The underlying builder infra is basically done, so layering on the high-level builders is pretty simple +// NOTE: The underlying builder infra is done, so layering on the high-level builders is pretty simple // * canonicalization (optional) -// * visitors (partially complete, need CFG and DFG walkers as well though, largely variations on the existing infra) -// * pattern matching/rewrites (needed for legalization/conversion) +// * pattern matching/rewrites (needed for legalization/conversion, mostly complete, see below) +// - Need to provide implementations of stubbed out rewriter methods +// - Need to implement the GreedyRewritePatternDriver +// - Need to implement matchers // * dataflow analysis framework (required to replace old analyses) // * linking/global symbol resolution (required to replace old linker, partially implemented via symbols/symbol tables already) // * legalization/dialect conversion (required to convert between unstructured and structured control flow dialects at minimum) diff --git a/hir2/src/patterns.rs b/hir2/src/patterns.rs new file mode 100644 index 000000000..7fa4d9c69 --- /dev/null +++ b/hir2/src/patterns.rs @@ -0,0 +1,11 @@ +mod applicator; +mod pattern; +mod pattern_set; +mod rewriter; + +pub use self::{ + applicator::PatternApplicator, + pattern::*, + pattern_set::{FrozenRewritePatternSet, RewritePatternSet}, + rewriter::*, +}; diff --git a/hir2/src/patterns/applicator.rs b/hir2/src/patterns/applicator.rs new file mode 100644 index 000000000..75fd4b517 --- /dev/null +++ b/hir2/src/patterns/applicator.rs @@ -0,0 +1,179 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use smallvec::SmallVec; + +use super::{FrozenRewritePatternSet, PatternBenefit, PatternRewriter, RewritePattern}; +use crate::{Builder, OperationName, OperationRef, Report}; + +/// This type manages the application of a group of rewrite patterns, with a user-provided cost model +pub struct PatternApplicator { + /// The list that owns the patterns used within this applicator + rewrite_patterns_set: Rc, + /// The set of patterns to match for each operation, stable sorted by benefit. + patterns: BTreeMap; 2]>>, + /// The set of patterns that may match against any operation type, stable sorted by benefit. + match_any_patterns: SmallVec<[Rc; 1]>, +} +impl PatternApplicator { + pub fn new(rewrite_patterns_set: Rc) -> Self { + Self { + rewrite_patterns_set, + patterns: Default::default(), + match_any_patterns: Default::default(), + } + } + + /// Apply a cost model to the patterns within this applicator. 
+ pub fn apply_cost_model(&mut self, model: CostModel) + where + CostModel: Fn(&dyn RewritePattern) -> PatternBenefit, + { + // Clear the results computed by the previous cost model + self.match_any_patterns.clear(); + self.patterns.clear(); + + // Filter out op-specific patterns with no benefit, and order by highest benefit first + let mut benefits = Vec::default(); + for (op, op_patterns) in self.rewrite_patterns_set.op_specific_patterns().iter() { + benefits + .extend(op_patterns.iter().filter_map(|p| filter_map_pattern_benefit(p, &model))); + benefits.sort_by_key(|(_, benefit)| *benefit); + self.patterns + .insert(op.clone(), benefits.drain(..).map(|(pat, _)| pat).collect()); + } + + // Filter out "match any" patterns with no benefit, and order by highest benefit first + benefits.extend( + self.rewrite_patterns_set + .any_op_patterns() + .iter() + .filter_map(|p| filter_map_pattern_benefit(p, &model)), + ); + benefits.sort_by_key(|(_, benefit)| *benefit); + self.match_any_patterns.extend(benefits.into_iter().map(|(pat, _)| pat)); + } + + /// Apply the default cost model that solely uses the pattern's static benefit + #[inline] + pub fn apply_default_cost_model(&mut self) { + self.apply_cost_model(|pattern| pattern.benefit()); + } + + /// Walk all of the patterns within the applicator. + pub fn walk_all_patterns(&self, mut callback: F) + where + F: FnMut(Rc), + { + for patterns in self.rewrite_patterns_set.op_specific_patterns().values() { + for pattern in patterns { + callback(Rc::clone(pattern)); + } + } + for pattern in self.rewrite_patterns_set.any_op_patterns() { + callback(Rc::clone(pattern)); + } + } + + pub fn match_and_rewrite( + &mut self, + op: OperationRef, + rewriter: &mut PatternRewriter, + can_apply: Option, + mut on_failure: Option, + mut on_success: Option, + ) -> Result<(), Report> + where + A: Fn(&dyn RewritePattern) -> bool, + F: FnMut(&dyn RewritePattern), + S: FnMut(&dyn RewritePattern) -> Result<(), Report>, + { + // Check to see if there are patterns matching this specific operation type. + let op_name = { + let op = op.borrow(); + op.name() + }; + let op_specific_patterns = self.patterns.get(&op_name).map(|p| p.as_slice()).unwrap_or(&[]); + + // Process the op-specific patterns and op-agnostic patterns in an interleaved fashion + let mut op_patterns = op_specific_patterns.iter().peekable(); + let mut any_op_patterns = self.match_any_patterns.iter().peekable(); + loop { + // Find the next pattern with the highest benefit + // + // 1. Start with the assumption that we'll use the next op-specific pattern + let mut best_pattern = op_patterns.peek().copied(); + // 2. But take the next op-agnostic pattern instead, IF: + // a. There are no more op-specific patterns + // b. The benefit of the op-agnostic pattern is higher than the op-specific pattern + if let Some(next_any_pattern) = any_op_patterns + .next_if(|p| best_pattern.is_none_or(|bp| bp.benefit() < p.benefit())) + { + best_pattern.replace(next_any_pattern); + } else { + // The op-specific pattern is best, so actually consume it from the iterator + best_pattern = op_patterns.next(); + } + + // Break if we have exhausted all patterns + let Some(best_pattern) = best_pattern else { + break; + }; + + // Can we apply this pattern? + let applicable = can_apply.as_ref().is_none_or(|can_apply| can_apply(&**best_pattern)); + if !applicable { + continue; + } + + // Try to match and rewrite this pattern. + // + // The patterns are sorted by benefit, so if we match we can immediately rewrite. 
+ rewriter.set_insertion_point_before(crate::ProgramPoint::Op(op.clone())); + + // TODO: Save the nearest parent IsolatedFromAbove op of this op for use in debug + // messages/rendering, as the rewrite may invalidate `op` + log::debug!("trying to match '{}'", best_pattern.name()); + + if best_pattern.match_and_rewrite(op.clone(), rewriter)? { + log::debug!("successfully matched pattern '{}'", best_pattern.name()); + if let Some(on_success) = on_success.as_mut() { + on_success(&**best_pattern)?; + } + break; + } else { + // Perform any necessary cleanup + log::debug!("failed to match pattern '{}'", best_pattern.name()); + if let Some(on_failure) = on_failure.as_mut() { + on_failure(&**best_pattern); + } + } + } + + Ok(()) + } +} + +fn filter_map_pattern_benefit( + pattern: &Rc, + cost_model: &CostModel, +) -> Option<(Rc, PatternBenefit)> +where + CostModel: Fn(&dyn RewritePattern) -> PatternBenefit, +{ + let benefit = if pattern.benefit().is_impossible_to_match() { + PatternBenefit::NONE + } else { + cost_model(&**pattern) + }; + if benefit.is_impossible_to_match() { + log::debug!( + "ignoring pattern '{}' ({}) because it is impossible to match or cannot lead to legal \ + IR (by cost model)", + pattern.name(), + pattern.kind(), + ); + None + } else { + Some((Rc::clone(pattern), benefit)) + } +} diff --git a/hir2/src/patterns/pattern.rs b/hir2/src/patterns/pattern.rs new file mode 100644 index 000000000..a7cc9b7d8 --- /dev/null +++ b/hir2/src/patterns/pattern.rs @@ -0,0 +1,226 @@ +use alloc::rc::Rc; +use core::{any::TypeId, fmt}; + +use smallvec::SmallVec; + +use super::PatternRewriter; +use crate::{interner, Context, OperationName, OperationRef, Report}; + +#[derive(Debug)] +pub enum PatternKind { + /// The pattern root matches any operation + Any, + /// The pattern root matches a specific named operation + Operation(OperationName), + /// The pattern root matches a specific trait + Trait(TypeId), +} +impl fmt::Display for PatternKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Any => f.write_str("for any"), + Self::Operation(name) => write!(f, "for operation '{name}'"), + Self::Trait(_) => write!(f, "for trait"), + } + } +} + +/// Represents the benefit a pattern has. +/// +/// More beneficial patterns are preferred over those with lesser benefit, while patterns with no +/// benefit whatsoever can be discarded. +/// +/// This is used to evaluate which patterns to apply, and in what order. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct PatternBenefit(Option); +impl PatternBenefit { + /// Represents a pattern which is the most beneficial + pub const MAX: Self = Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(u16::MAX) })); + /// Represents a pattern which is the least beneficial + pub const MIN: Self = Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(1) })); + /// Represents a pattern which can never match, and thus should be discarded + pub const NONE: Self = Self(None); + + /// Create a new [PatternBenefit] from a raw [u16] value. + /// + /// A value of `u16::MAX` is treated as impossible to match, while values from `0..=65534` range + /// from the least beneficial to the most beneficial. 
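// Illustrative sketch (not part of the patch): `PatternBenefit` orders higher benefits
// earlier, with "impossible to match" sorting last, which is what the applicator relies on
// when it sorts patterns by benefit.
fn pattern_benefit_ordering_sketch() {
    let high = PatternBenefit::new(10);
    let low = PatternBenefit::new(1);
    let none = PatternBenefit::NONE;
    assert!(high < low, "higher benefit sorts earlier");
    assert!(low < none, "impossible-to-match sorts last");
    assert!(none.is_impossible_to_match());
    // A raw value of u16::MAX is normalized to the impossible-to-match benefit
    assert_eq!(PatternBenefit::new(u16::MAX), PatternBenefit::NONE);
}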
+ pub fn new(benefit: u16) -> Self { + if benefit == u16::MAX { + Self(None) + } else { + Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(benefit + 1) })) + } + } + + /// Returns true if the pattern benefit indicates it can never match + #[inline] + pub fn is_impossible_to_match(&self) -> bool { + self.0.is_none() + } +} + +impl PartialOrd for PatternBenefit { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for PatternBenefit { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + use core::cmp::Ordering; + match (self.0, other.0) { + (None, None) => Ordering::Equal, + // Impossible to match is always last + (None, Some(_)) => Ordering::Greater, + (Some(_), None) => Ordering::Less, + // Benefits are ordered in reverse of integer order (higher benefit appears earlier) + (Some(a), Some(b)) => a.get().cmp(&b.get()).reverse(), + } + } +} + +/// A [Pattern] describes all of the data related to a pattern, but does not express any actual +/// pattern logic, i.e. it is solely used for metadata about a pattern. +pub struct Pattern { + #[allow(unused)] + context: Rc, + name: &'static str, + kind: PatternKind, + #[allow(unused)] + labels: SmallVec<[interner::Symbol; 1]>, + benefit: PatternBenefit, + has_bounded_recursion: bool, + generated_ops: SmallVec<[OperationName; 0]>, +} +impl Pattern { + /// Create a new [Pattern] from its component parts. + pub fn new( + context: Rc, + name: &'static str, + kind: PatternKind, + benefit: PatternBenefit, + ) -> Self { + Self { + context, + name, + kind, + labels: SmallVec::default(), + benefit, + has_bounded_recursion: false, + generated_ops: SmallVec::default(), + } + } + + /// A name used when printing diagnostics related to this pattern + #[inline(always)] + pub const fn name(&self) -> &'static str { + self.name + } + + /// The kind of value used to select candidate root operations for this pattern. + #[inline(always)] + pub const fn kind(&self) -> &PatternKind { + &self.kind + } + + /// Returns the benefit - the inverse of "cost" - of matching this pattern. + /// + /// The benefit of a [Pattern] is always static - rewrites that may have dynamic benefit can be + /// instantiated multiple times (different instances), for each benefit that they may return, + /// and be guarded by different match condition predicates. + #[inline(always)] + pub const fn benefit(&self) -> PatternBenefit { + self.benefit + } + + /// Return a list of operations that may be generated when rewriting an operation instance + /// with this pattern. + #[inline] + pub fn generated_ops(&self) -> &[OperationName] { + &self.generated_ops + } + + /// Return the root operation that this pattern matches. + /// + /// Patterns that can match multiple root types return `None` + pub fn get_root_operation(&self) -> Option { + match self.kind { + PatternKind::Operation(ref name) => Some(name.clone()), + _ => None, + } + } + + /// Return the trait id used to match the root operation of this pattern. + /// + /// If the pattern does not use a trait id for deciding the root match, this returns `None` + pub fn get_root_trait(&self) -> Option { + match self.kind { + PatternKind::Trait(type_id) => Some(type_id), + _ => None, + } + } + + /// Returns true if this pattern is known to result in recursive application, i.e. this pattern + /// may generate IR that also matches this pattern, but is known to bound the recursion. This + /// signals to the rewrite driver that it is safe to apply this pattern recursively to the + /// generated IR. 
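// Illustrative sketch (not part of the patch): a minimal implementation of the
// `RewritePattern` trait defined just below. `FoldAddZero` and the predicate
// `is_add_of_zero` are assumptions for this example; the point is the shape of the impl:
// metadata is carried by a `Pattern`, matching is read-only, and all IR mutation goes
// through the provided `PatternRewriter`.
struct FoldAddZero {
    info: Pattern,
}
impl FoldAddZero {
    fn new(context: Rc<Context>) -> Self {
        // A real pattern would usually use `PatternKind::Operation` with the registered
        // name of the op it rewrites; `Any` keeps this sketch self-contained.
        let info = Pattern::new(context, "fold-add-zero", PatternKind::Any, PatternBenefit::new(1));
        Self { info }
    }
}
impl RewritePattern for FoldAddZero {
    fn kind(&self) -> &PatternKind {
        self.info.kind()
    }

    fn benefit(&self) -> PatternBenefit {
        self.info.benefit()
    }

    fn has_bounded_rewrite_recursion(&self) -> bool {
        self.info.has_bounded_rewrite_recursion()
    }

    fn matches(&self, op: OperationRef) -> Result<bool, Report> {
        // Hypothetical predicate: does `op` compute `x + 0`?
        Ok(is_add_of_zero(&op))
    }

    fn rewrite(&self, op: OperationRef, rewriter: &mut PatternRewriter) {
        // Position the rewriter at the matched op; the concrete replacement call is left as
        // a comment because the `PatternRewriter` surface continues past this excerpt.
        rewriter.set_insertion_point_before(crate::ProgramPoint::Op(op.clone()));
        // e.g. (assumed API) rewriter.replace_op_with_values(op, [lhs]);
    }
}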
+ #[inline(always)] + pub const fn has_bounded_rewrite_recursion(&self) -> bool { + self.has_bounded_recursion + } + + /// Set whether or not this pattern has bounded rewrite recursion + #[inline(always)] + pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self { + self.has_bounded_recursion = yes; + self + } +} + +/// A [RewritePattern] represents two things: +/// +/// * A pattern which matches some IR that we're interested in, typically to replace with something +/// else. +/// * A rewrite which replaces IR that maches the pattern, with new IR, i.e. a DAG-to-DAG +/// replacement +/// +/// Implementations must provide `matches` and `rewrite` implementations, from which the +/// `match_and_rewrite` implementation is derived. +pub trait RewritePattern { + /// A name to use for this pattern in diagnostics + fn name(&self) -> &'static str { + core::any::type_name::() + } + /// The pattern used to match candidate root operations for this rewrite. + fn kind(&self) -> &PatternKind; + /// The estimated benefit of this pattern + fn benefit(&self) -> PatternBenefit; + /// Whether or not this rewrite pattern has bounded recursion + fn has_bounded_rewrite_recursion(&self) -> bool; + /// Rewrite the IR rooted at the specified operation with the result of this pattern, generating + /// any new operations with the specified builder. If an unexpected error is encountered, i.e. + /// an internal compiler error, it is emitted through the normal diagnostic system, and the IR + /// is left in a valid state. + fn rewrite(&self, op: OperationRef, rewriter: &mut PatternRewriter); + + /// Attempt to match this pattern against the IR rooted at the specified operation, + /// which is the same operation as [Pattern::kind]. + fn matches(&self, op: OperationRef) -> Result; + + /// Attempt to match this pattern against the IR rooted at the specified operation. If + /// matching is successful, the rewrite is automatically applied. + fn match_and_rewrite( + &self, + op: OperationRef, + rewriter: &mut PatternRewriter, + ) -> Result { + if self.matches(op.clone())? { + self.rewrite(op, rewriter); + + Ok(true) + } else { + Ok(false) + } + } +} diff --git a/hir2/src/patterns/pattern_set.rs b/hir2/src/patterns/pattern_set.rs new file mode 100644 index 000000000..e8375964b --- /dev/null +++ b/hir2/src/patterns/pattern_set.rs @@ -0,0 +1,110 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use smallvec::SmallVec; + +use super::*; +use crate::{Context, OperationName}; + +pub struct RewritePatternSet { + context: Rc, + patterns: Vec>, +} +impl RewritePatternSet { + pub fn new(context: Rc) -> Self { + Self { + context, + patterns: vec![], + } + } + + pub fn from_iter
<P>
(context: Rc, patterns: P) -> Self + where + P: IntoIterator>, + { + Self { + context, + patterns: patterns.into_iter().collect(), + } + } + + #[inline] + pub fn context(&self) -> Rc { + Rc::clone(&self.context) + } + + #[inline] + pub fn patterns(&self) -> &[Box] { + &self.patterns + } + + pub fn push(&mut self, pattern: impl RewritePattern + 'static) { + self.patterns.push(Box::new(pattern)); + } +} + +pub struct FrozenRewritePatternSet { + context: Rc, + patterns: Vec>, + op_specific_patterns: BTreeMap; 2]>>, + any_op_patterns: SmallVec<[Rc; 1]>, +} +impl FrozenRewritePatternSet { + pub fn new(patterns: RewritePatternSet) -> Self { + let RewritePatternSet { context, patterns } = patterns; + let mut this = Self { + context, + patterns: Default::default(), + op_specific_patterns: Default::default(), + any_op_patterns: Default::default(), + }; + + for pattern in patterns { + let pattern = Rc::::from(pattern); + match pattern.kind() { + PatternKind::Operation(name) => { + this.op_specific_patterns + .entry(name.clone()) + .or_default() + .push(Rc::clone(&pattern)); + this.patterns.push(pattern); + } + PatternKind::Trait(ref trait_id) => { + for dialect in this.context.registered_dialects().values() { + for op in dialect.registered_ops().iter() { + if op.implements_trait_id(&trait_id) { + this.op_specific_patterns + .entry(op.clone()) + .or_default() + .push(Rc::clone(&pattern)); + } + } + } + this.patterns.push(pattern); + } + PatternKind::Any => { + this.any_op_patterns.push(Rc::clone(&pattern)); + this.patterns.push(pattern); + } + } + } + + this + } + + #[inline] + pub fn patterns(&self) -> &[Rc] { + &self.patterns + } + + #[inline] + pub fn op_specific_patterns( + &self, + ) -> &BTreeMap; 2]>> { + &self.op_specific_patterns + } + + #[inline] + pub fn any_op_patterns(&self) -> &[Rc] { + &self.any_op_patterns + } +} diff --git a/hir2/src/patterns/rewriter.rs b/hir2/src/patterns/rewriter.rs new file mode 100644 index 000000000..88514e317 --- /dev/null +++ b/hir2/src/patterns/rewriter.rs @@ -0,0 +1,523 @@ +#![allow(unused)] +use alloc::rc::Rc; +use core::ops::{Deref, DerefMut}; + +use crate::{ + BlockRef, Builder, Context, InsertionPoint, Listener, ListenerType, OpBuilder, OpOperand, + OpResultRef, OperationRef, Pattern, RegionRef, Report, SourceSpan, Type, ValueRef, +}; + +/// A special type of `RewriterBase` that coordinates the application of a rewrite pattern on the +/// current IR being matched, providing a way to keep track of any mutations made. +/// +/// This type should be used to perform all necessary IR mutations within a rewrite pattern, as +/// the pattern driver may be tracking various state that would be invalidated when a mutation takes +/// place. 
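+///
+/// As a rough sketch (the `SimplifyOp` pattern and its `build_replacement` helper are
+/// hypothetical, not part of this patch), a rewrite routes every mutation through the
+/// rewriter rather than touching the IR directly:
+///
+/// ```rust,ignore
+/// impl RewritePattern for SimplifyOp {
+///     fn rewrite(&self, op: OperationRef, rewriter: &mut PatternRewriter) {
+///         // Build the replacement at the current insertion point, then apply it via
+///         // the rewriter so the driver (and any attached listener) is notified.
+///         let replacement = self.build_replacement(op.clone(), rewriter);
+///         rewriter.replace_op(op, replacement);
+///     }
+///
+///     // `kind`, `benefit`, `matches`, and the remaining methods are elided here.
+/// }
+/// ```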
+pub struct PatternRewriter { + rewriter: RewriterImpl, + recoverable: bool, +} +impl PatternRewriter { + pub fn new(builder: OpBuilder) -> Self { + Self { + rewriter: RewriterImpl::new(builder), + recoverable: false, + } + } + + #[inline] + pub const fn can_recover_from_rewrite_failure(&self) -> bool { + self.recoverable + } +} +impl Deref for PatternRewriter { + type Target = RewriterImpl; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.rewriter + } +} +impl DerefMut for PatternRewriter { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.rewriter + } +} + +pub struct RewriterImpl { + builder: OpBuilder, + listener: Option>, +} + +impl Listener for RewriterImpl { + fn kind(&self) -> ListenerType { + ListenerType::Rewriter + } + + fn notify_block_inserted( + &mut self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + if let Some(listener) = self.listener.as_deref_mut() { + listener.notify_block_inserted(block, prev, ip); + } else { + self.builder.notify_block_inserted(block, prev, ip); + } + } + + fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { + if let Some(listener) = self.listener.as_deref_mut() { + listener.notify_operation_inserted(op, prev); + } else { + self.builder.notify_operation_inserted(op, prev); + } + } +} + +impl Builder for RewriterImpl { + #[inline(always)] + fn context(&self) -> &Context { + self.builder.context() + } + + #[inline(always)] + fn context_rc(&self) -> Rc { + self.builder.context_rc() + } + + #[inline(always)] + fn insertion_point(&self) -> Option<&InsertionPoint> { + self.builder.insertion_point() + } + + #[inline(always)] + fn clear_insertion_point(&mut self) -> Option { + self.builder.clear_insertion_point() + } + + #[inline(always)] + fn restore_insertion_point(&mut self, ip: Option) { + self.builder.restore_insertion_point(ip); + } + + #[inline(always)] + fn set_insertion_point(&mut self, ip: InsertionPoint) { + self.builder.set_insertion_point(ip); + } + + #[inline] + fn create_block
<P>
( + &mut self, + parent: RegionRef, + ip: Option, + args: P, + ) -> BlockRef + where + P: IntoIterator, + { + self.builder.create_block(parent, ip, args) + } + + #[inline] + fn create_block_before
<P>
(&mut self, before: BlockRef, args: P) -> BlockRef + where + P: IntoIterator, + { + self.builder.create_block_before(before, args) + } + + #[inline] + fn insert(&mut self, op: OperationRef) { + self.builder.insert(op); + } +} + +impl RewriterImpl { + pub fn new(builder: OpBuilder) -> Self { + Self { + builder, + listener: None, + } + } + + pub fn with_listener(mut self, listener: impl RewriterListener) -> Self { + self.listener = Some(Box::new(listener)); + self + } + + /// Move the blocks that belong to `region` before the given insertion point in another region, + /// `ip`. The two regions must be different. The caller is responsible for creating or + /// updating the operation transferring flow of control to the region, and passing it the + /// correct block arguments. + pub fn inline_region_before(&mut self, region: RegionRef, ip: InsertionPoint) { + todo!() + } + + /// Replace the results of the given operation with the specified list of values (replacements). + /// + /// The result types of the given op and the replacements must match. The original op is erased. + pub fn replace_op_with_values(&mut self, op: OperationRef, values: V) + where + V: IntoIterator, + { + todo!() + } + + /// Replace the results of the given operation with the specified replacement op. + /// + /// The result types of the two ops must match. The original op is erased. + pub fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) { + todo!() + } + + /// This method erases an operation that is known to have no uses. + pub fn erase_op(&mut self, op: OperationRef) { + todo!() + } + + /// This method erases all operations in a block. + pub fn erase_block(&mut self, block: BlockRef) { + todo!() + } + + /// Inline the operations of block `src` before the given insertion point. + /// The source block will be deleted and must have no uses. The `args` values, if provided, are + /// used to replace the block arguments of `src`. + /// + /// If the source block is inserted at the end of the dest block, the dest block must have no + /// successors. Similarly, if the source block is inserted somewhere in the middle (or + /// beginning) of the dest block, the source block must have no successors. Otherwise, the + /// resulting IR would have unreachable operations. + pub fn inline_block_before( + &mut self, + src: BlockRef, + ip: InsertionPoint, + args: Option<&[ValueRef]>, + ) { + todo!() + } + + /// Inline the operations of block `src` into the end of block `dest`. The source block will be + /// deleted and must have no uses. The `args` values, if present, are used to replace the block + /// arguments of `src`. + /// + /// The dest block must have no successors. Otherwise, the resulting IR will have unreachable + /// operations. + pub fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: Option<&[ValueRef]>) { + todo!() + } + + /// Split the operations starting at `ip` (inclusive) out of the given block into a new block, + /// and return it. + pub fn split_block(&mut self, block: BlockRef, ip: InsertionPoint) -> BlockRef { + todo!() + } + + /// Unlink this operation from its current block and insert it right before `ip`, which + /// may be in the same or another block in the same function. + pub fn move_op_before(&mut self, op: OperationRef, ip: InsertionPoint) { + todo!() + } + + /// Unlink this operation from its current block and insert it right after `ip`, which may be + /// in the same or another block in the same function. 
+ pub fn move_op_after(&mut self, op: OperationRef, ip: InsertionPoint) { + todo!() + } + + /// Unlink this block and insert it right before `ip`. + pub fn move_block_before(&mut self, block: BlockRef, ip: InsertionPoint) { + todo!() + } + + /// This method is used to notify the rewriter that an in-place operation modification is about + /// to happen. + /// + /// The returned guard can be used to access the rewriter, as well as finalize or cancel the + /// in-place modification. + pub fn start_in_place_modification( + &mut self, + op: OperationRef, + ) -> InPlaceModificationGuard<'_> { + InPlaceModificationGuard::new(self, op) + } + + /// Performs an in-place modification of `root` using `callback`, taking care of notifying the + /// rewriter of progress and outcome of the modification. + pub fn modify_op_in_place(&mut self, root: OperationRef, callback: F) + where + F: Fn(InPlaceModificationGuard<'_>), + { + let guard = self.start_in_place_modification(root); + callback(guard); + } + + /// Find uses of `from` and replace them with `to`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + pub fn replace_all_uses_of_value_with(&mut self, from: ValueRef, to: ValueRef) { + todo!() + } + + /// Find uses of `from` and replace them with `to`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + pub fn replace_all_uses_of_block_with(&mut self, from: BlockRef, to: BlockRef) { + todo!() + } + + /// Find uses of `from` and replace them with `to`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + pub fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[ValueRef]) { + todo!() + } + + /// Find uses of `from` and replace them with `to`. + /// + /// Notifies the listener about every in-place modification (for every use that was replaced), + /// and that the `from` operation is about to be replaced. + pub fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[ValueRef]) { + todo!() + } + + /// Find uses of `from` and replace them with `to`. + /// + /// Notifies the listener about every in-place modification (for every use that was replaced), + /// and that the `from` operation is about to be replaced. + pub fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) { + todo!() + } + + /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// + /// Returns true if all uses were replaced, otherwise false. + pub fn maybe_replace_uses_of_value_with
<P>
( + &mut self, + from: ValueRef, + to: ValueRef, + predicate: P, + ) -> bool + where + P: Fn(OpOperand) -> bool, + { + todo!() + } + + /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// + /// Returns true if all uses were replaced, otherwise false. + pub fn maybe_replace_uses_with
<P>
( + &mut self, + from: &[ValueRef], + to: &[ValueRef], + predicate: P, + ) -> bool + where + P: Fn(OpOperand) -> bool, + { + todo!() + } + + /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// + /// Returns true if all uses were replaced, otherwise false. + pub fn maybe_replace_op_uses_with
<P>
( + &mut self, + from: OperationRef, + to: &[ValueRef], + predicate: P, + ) -> bool + where + P: Fn(OpOperand) -> bool, + { + todo!() + } + + /// Find uses of `from` within `block` and replace them with `to`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// + /// Returns true if all uses were replaced, otherwise false. + pub fn replace_op_uses_within_block( + &mut self, + from: OperationRef, + to: &[ValueRef], + block: BlockRef, + ) -> bool { + let parent_op = block.borrow().parent_op(); + self.maybe_replace_op_uses_with(from, to, |operand| { + let operand = operand.borrow(); + !parent_op + .as_ref() + .is_some_and(|op| op.borrow().is_proper_ancestor_of(operand.owner.clone())) + }) + } + + /// Find uses of `from` and replace them with `to`, except if the user is in `exceptions`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + pub fn replace_all_uses_except( + &mut self, + from: ValueRef, + to: ValueRef, + exceptions: &[OperationRef], + ) { + self.maybe_replace_uses_of_value_with(from, to, |operand| { + let operand = operand.borrow(); + !exceptions.contains(&operand.owner) + }); + } + + pub fn notify_match_failure(&mut self, span: SourceSpan, report: Report) { + if let Some(listener) = self.listener.as_mut() { + listener.notify_match_failure(span, report); + } + } +} + +/// Wraps an in-place modification of an [Operation] to ensure the rewriter is properly notified +/// about the progress and outcome of the in-place notification. +/// +/// This is a minor efficiency win, as it avoids creating a new operation, and removing the old one, +/// but also often allows simpler code in the client. +pub struct InPlaceModificationGuard<'a> { + rewriter: &'a mut RewriterImpl, + op: OperationRef, + canceled: bool, +} +impl<'a> InPlaceModificationGuard<'a> { + fn new(rewriter: &'a mut RewriterImpl, op: OperationRef) -> Self { + Self { + rewriter, + op, + canceled: false, + } + } + + #[inline] + pub fn rewriter(&mut self) -> &mut RewriterImpl { + self.rewriter + } + + #[inline] + pub fn op(&self) -> &OperationRef { + &self.op + } + + /// Cancels the pending in-place modification. + pub fn cancel(mut self) { + self.canceled = true; + } + + /// Signals the end of an in-place modification of the current operation. + pub fn finalize(self) {} +} +impl core::ops::Deref for InPlaceModificationGuard<'_> { + type Target = RewriterImpl; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.rewriter + } +} +impl core::ops::DerefMut for InPlaceModificationGuard<'_> { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.rewriter + } +} +impl Drop for InPlaceModificationGuard<'_> { + fn drop(&mut self) { + if self.canceled { + //self.rewriter.cancel_op_modification(self.op.clone()); + todo!("cancel op modification") + } else { + //self.rewriter.finalize_op_modification(self.op.clone()); + todo!("finalize op modification") + } + } +} + +pub trait RewriterListener: Listener { + /// Notify the listener that the specified block is about to be erased. + /// + /// At this point, the block has zero uses. + fn notify_block_erased(&mut self, block: BlockRef) {} + + /// Notify the listener that the specified operation was modified in-place. + fn notify_operation_modified(&mut self, op: OperationRef) {} + + /// Notify the listener that all uses of the specified operation's results are about to be + /// replaced with the results of another operation. 
This is called before the uses of the old + /// operation have been changed. + /// + /// By default, this function calls the "operation replaced with values" notification. + fn notify_operation_replaced(&mut self, op: OperationRef, replacement: OperationRef) { + let replacement = replacement.borrow(); + self.notify_operation_replaced_with_values(op, replacement.results().all().as_slice()); + } + + /// Notify the listener that all uses of the specified operation's results are about to be + /// replaced with the given range of values, potentially produced by other operations. This is + /// called before the uses of the operation have been changed. + fn notify_operation_replaced_with_values( + &mut self, + op: OperationRef, + replacement: &[OpResultRef], + ) { + } + + /// Notify the listener that the specified operation is about to be erased. At this point, the + /// operation has zero uses. + /// + /// NOTE: This notification is not triggered when unlinking an operation. + fn notify_operation_erased(&mut self, op: OperationRef) {} + + /// Notify the listener that the specified pattern is about to be applied at the specified root + /// operation. + fn notify_pattern_begin(&mut self, pattern: &Pattern, op: OperationRef) {} + + /// Notify the listener that a pattern application finished with the specified status. + /// + /// `Ok` indicates that the pattern was applied successfully. `Err` indicates that the pattern + /// could not be applied. The pattern may have communicated the reason for the failure with + /// `notify_match_failure` + fn notify_pattern_end(&mut self, pattern: &Pattern, status: Result<(), Report>) {} + + /// Notify the listener that the pattern failed to match, and provide a diagnostic explaining + /// the reason why the failure occurred. 
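+    ///
+    /// A minimal sketch of a listener that records these diagnostics (`FailureLog` is
+    /// hypothetical, and its base `Listener` impl is elided for brevity):
+    ///
+    /// ```rust,ignore
+    /// #[derive(Default)]
+    /// struct FailureLog {
+    ///     failures: Vec<(SourceSpan, Report)>,
+    /// }
+    /// // ... `impl Listener for FailureLog` elided ...
+    /// impl RewriterListener for FailureLog {
+    ///     fn notify_match_failure(&mut self, span: SourceSpan, reason: Report) {
+    ///         self.failures.push((span, reason));
+    ///     }
+    /// }
+    ///
+    /// // Attach it to a rewriter so pattern application can report failures to it:
+    /// let rewriter = RewriterImpl::new(builder).with_listener(FailureLog::default());
+    /// ```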
+ fn notify_match_failure(&mut self, span: SourceSpan, reason: Report) {} +} + +struct RewriterListenerBase { + kind: ListenerType, +} +impl Listener for RewriterListenerBase { + #[inline(always)] + fn kind(&self) -> ListenerType { + ListenerType::Rewriter + } + + fn notify_block_inserted( + &mut self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + todo!() + } + + fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { + todo!() + } +} From edf8a04f38a42bc757f75124379256abdcf806f9 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 5 Oct 2024 00:29:44 -0400 Subject: [PATCH 14/31] wip: make callables fundamental, move function/module to hir dialect --- hir-macros/src/operation.rs | 13 +- hir2/src/dialects/hir/builders/function.rs | 7 +- hir2/src/dialects/hir/ops.rs | 6 +- hir2/src/dialects/hir/ops/function.rs | 158 ++++++++++++++ hir2/src/{ir => dialects/hir/ops}/module.rs | 0 hir2/src/ir.rs | 6 +- hir2/src/ir/{function.rs => callable.rs} | 229 +++++++++----------- hir2/src/ir/operation/builder.rs | 6 +- hir2/src/ir/print.rs | 4 +- hir2/src/ir/traits.rs | 3 +- hir2/src/ir/traits/callable.rs | 135 ------------ 11 files changed, 283 insertions(+), 284 deletions(-) create mode 100644 hir2/src/dialects/hir/ops/function.rs rename hir2/src/{ir => dialects/hir/ops}/module.rs (100%) rename hir2/src/ir/{function.rs => callable.rs} (65%) delete mode 100644 hir2/src/ir/traits/callable.rs diff --git a/hir-macros/src/operation.rs b/hir-macros/src/operation.rs index 08c46b5e8..421aa09e6 100644 --- a/hir-macros/src/operation.rs +++ b/hir-macros/src/operation.rs @@ -1042,11 +1042,11 @@ impl quote::ToTokens for OpSymbolFns<'_> { #( #[doc = #set_symbol_doc_lines] )* - pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::traits::AsCallableSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { + pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::AsCallableSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { let symbol = symbol.as_callable_symbol_ref(); let (data_ptr, _) = ::midenc_hir2::SymbolRef::as_ptr(&symbol).to_raw_parts(); if core::ptr::addr_eq(data_ptr, (self as *const Self as *const ())) { - if !self.op.implements::() { + if !self.op.implements::() { return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { symbol: self.span(), }); @@ -1054,7 +1054,7 @@ impl quote::ToTokens for OpSymbolFns<'_> { } else { let symbol = symbol.borrow(); let symbol_op = symbol.as_symbol_operation(); - if !symbol_op.implements::() { + if !symbol_op.implements::() { return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { symbol: symbol_op.span(), }); @@ -2413,10 +2413,9 @@ impl OpCreateParam { })] } SymbolType::Callable => { - let as_callable_symbol_ref_bound = syn::parse_str::( - "::midenc_hir2::traits::AsCallableSymbolRef", - ) - .unwrap(); + let as_callable_symbol_ref_bound = + syn::parse_str::("::midenc_hir2::AsCallableSymbolRef") + .unwrap(); vec![syn::GenericParam::Type(syn::TypeParam { attrs: vec![], ident: format_ident!("T{}", name.to_string().to_pascal_case()), diff --git a/hir2/src/dialects/hir/builders/function.rs b/hir2/src/dialects/hir/builders/function.rs index b560156f6..b42b84b74 100644 --- a/hir2/src/dialects/hir/builders/function.rs +++ b/hir2/src/dialects/hir/builders/function.rs @@ -1,5 +1,8 @@ -use self::traits::AsCallableSymbolRef; -use crate::*; +use crate::{ + dialects::hir::*, AsCallableSymbolRef, BlockRef, Builder, Immediate, InsertionPoint, Op, + OpBuilder, Region, RegionRef, Report, SourceSpan, Type, 
UnsafeIntrusiveEntityRef, Usable, + ValueRef, +}; pub struct FunctionBuilder<'f> { pub func: &'f mut Function, diff --git a/hir2/src/dialects/hir/ops.rs b/hir2/src/dialects/hir/ops.rs index a4749119f..5eb25fa3c 100644 --- a/hir2/src/dialects/hir/ops.rs +++ b/hir2/src/dialects/hir/ops.rs @@ -2,13 +2,15 @@ mod assertions; mod binary; mod cast; mod control; +mod function; mod invoke; mod mem; +mod module; mod primop; mod ternary; mod unary; pub use self::{ - assertions::*, binary::*, cast::*, control::*, invoke::*, mem::*, primop::*, ternary::*, - unary::*, + assertions::*, binary::*, cast::*, control::*, function::*, invoke::*, mem::*, module::*, + primop::*, ternary::*, unary::*, }; diff --git a/hir2/src/dialects/hir/ops/function.rs b/hir2/src/dialects/hir/ops/function.rs new file mode 100644 index 000000000..2c92b039e --- /dev/null +++ b/hir2/src/dialects/hir/ops/function.rs @@ -0,0 +1,158 @@ +use crate::{ + derive::operation, + dialects::hir::HirDialect, + traits::{IsolatedFromAbove, RegionKind, RegionKindInterface, SingleRegion}, + BlockRef, CallableOpInterface, Ident, Operation, OperationRef, RegionRef, Report, Signature, + Symbol, SymbolName, SymbolNameAttr, SymbolRef, SymbolUse, SymbolUseList, SymbolUseRef, + SymbolUsesIter, Usable, Visibility, +}; + +trait UsableSymbol = Usable; + +#[operation( + dialect = HirDialect, + traits(SingleRegion, IsolatedFromAbove), + implements( + UsableSymbol, + Symbol, + CallableOpInterface, + RegionKindInterface + ) +)] +pub struct Function { + #[region] + body: RegionRef, + #[attr] + name: Ident, + #[attr] + signature: Signature, + /// The uses of this function as a symbol + uses: SymbolUseList, +} + +impl Function { + #[inline] + pub fn entry_block(&self) -> BlockRef { + unsafe { BlockRef::from_raw(&*self.body().entry()) } + } + + pub fn last_block(&self) -> BlockRef { + self.body() + .body() + .back() + .as_pointer() + .expect("cannot access blocks of a function declaration") + } +} + +impl RegionKindInterface for Function { + #[inline(always)] + fn kind(&self) -> RegionKind { + RegionKind::SSA + } +} + +impl Usable for Function { + type Use = SymbolUse; + + #[inline(always)] + fn uses(&self) -> &SymbolUseList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut SymbolUseList { + &mut self.uses + } +} + +impl Symbol for Function { + #[inline(always)] + fn as_symbol_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_symbol_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn name(&self) -> SymbolName { + Self::name(self).as_symbol() + } + + fn set_name(&mut self, name: SymbolName) { + let id = self.name_mut(); + id.name = name; + } + + fn visibility(&self) -> Visibility { + self.signature().visibility + } + + fn set_visibility(&mut self, visibility: Visibility) { + self.signature_mut().visibility = visibility; + } + + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { + SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { + if OperationRef::ptr_eq(&from, &user.owner) + || from.borrow().is_proper_ancestor_of(user.owner.clone()) + { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else { + None + } + })) + } + + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report> { + for symbol_use in self.symbol_uses(from) { + let (mut owner, attr_name) = { + let user = symbol_use.borrow(); + (user.owner.clone(), user.symbol) + }; + let mut owner = owner.borrow_mut(); + // Unlink previously used symbol + { + let 
current_symbol = owner + .get_typed_attribute_mut::(&attr_name) + .expect("stale symbol user"); + unsafe { + self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); + } + } + // Link replacement symbol + owner.set_symbol_attribute(attr_name, replacement.clone()); + } + + Ok(()) + } + + /// Returns true if this operation is a declaration, rather than a definition, of a symbol + /// + /// The default implementation assumes that all operations are definitions + #[inline] + fn is_declaration(&self) -> bool { + self.body().is_empty() + } +} + +impl CallableOpInterface for Function { + fn get_callable_region(&self) -> Option { + if self.is_declaration() { + None + } else { + self.op.regions().front().as_pointer() + } + } + + #[inline] + fn signature(&self) -> &Signature { + Function::signature(self) + } +} diff --git a/hir2/src/ir/module.rs b/hir2/src/dialects/hir/ops/module.rs similarity index 100% rename from hir2/src/ir/module.rs rename to hir2/src/dialects/hir/ops/module.rs diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index e18763c0e..c8db791bc 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -1,16 +1,15 @@ mod attribute; mod block; mod builder; +mod callable; mod component; mod context; mod dialect; mod entity; -mod function; mod ident; mod immediates; mod insert; mod interface; -mod module; mod op; mod operands; mod operation; @@ -35,6 +34,7 @@ pub use self::{ BlockRef, }, builder::{Builder, Listener, ListenerType, OpBuilder}, + callable::*, context::Context, dialect::{Dialect, DialectName, DialectRegistration}, entity::{ @@ -42,11 +42,9 @@ pub use self::{ EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, RawEntityRef, StorableEntity, UnsafeEntityRef, UnsafeIntrusiveEntityRef, }, - function::{AbiParam, ArgumentExtension, ArgumentPurpose, Function, Signature}, ident::{FunctionIdent, Ident}, immediates::{Felt, FieldElement, Immediate, StarkField}, insert::{Insert, InsertionPoint, ProgramPoint}, - module::Module, op::{BuildableOp, Op, OpExt, OpRegistration}, operands::{ OpOperand, OpOperandImpl, OpOperandList, OpOperandRange, OpOperandRangeMut, diff --git a/hir2/src/ir/function.rs b/hir2/src/ir/callable.rs similarity index 65% rename from hir2/src/ir/function.rs rename to hir2/src/ir/callable.rs index 49a7a9725..a9e9b67cf 100644 --- a/hir2/src/ir/function.rs +++ b/hir2/src/ir/callable.rs @@ -1,163 +1,138 @@ use core::fmt; -use super::*; use crate::{ - derive::operation, - dialects::hir::HirDialect, - formatter, - traits::{ - CallableOpInterface, IsolatedFromAbove, RegionKind, RegionKindInterface, SingleRegion, - }, - CallConv, Symbol, SymbolName, SymbolUse, SymbolUseList, SymbolUsesIter, Visibility, + formatter, CallConv, EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Symbol, + SymbolNameAttr, SymbolRef, Type, UnsafeIntrusiveEntityRef, Value, ValueRef, Visibility, }; -trait UsableSymbol = Usable; - -#[operation( - dialect = HirDialect, - traits(SingleRegion, IsolatedFromAbove), - implements( - UsableSymbol, - Symbol, - CallableOpInterface, - RegionKindInterface - ) -)] -pub struct Function { - #[region] - body: RegionRef, - #[attr] - name: Ident, - #[attr] - signature: Signature, - /// The uses of this function as a symbol - uses: SymbolUseList, +/// A call-like operation is one that transfers control from one function to another. +/// +/// These operations may be traditional static calls, e.g. `call @foo`, or indirect calls, e.g. +/// `call_indirect v1`. An operation that uses this interface cannot _also_ implement the +/// `CallableOpInterface`. 
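+///
+/// For example (an illustrative sketch; `call_op` stands for any operation implementing
+/// this interface and is not defined in this patch), a caller can branch on whether the
+/// callee is a symbol or an SSA value:
+///
+/// ```rust,ignore
+/// match call_op.callable_for_callee() {
+///     // Direct call: the callee is a symbol reference such as `@foo`.
+///     Callable::Symbol(name) => log::debug!("direct call to {name}"),
+///     // Indirect call: the callee is an SSA value, e.g. a function reference.
+///     Callable::Value(value) => log::debug!("indirect call through {}", value.borrow().id()),
+/// }
+/// ```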
+pub trait CallOpInterface { + /// Get the callee of this operation. + /// + /// A callee is either a symbol, or a reference to an SSA value. + fn callable_for_callee(&self) -> Callable; + /// Sets the callee for this operation. + fn set_callee(&mut self, callable: Callable); + /// Get the operands of this operation that are used as arguments for the callee + fn arguments(&self) -> OpOperandRange<'_>; + /// Get a mutable reference to the operands of this operation that are used as arguments for the + /// callee + fn arguments_mut(&mut self) -> OpOperandRangeMut<'_>; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved, using the provided symbol table. + /// + /// This method is used to perform callee resolution using a cached symbol table, rather than + /// traversing the operation hierarchy looking for symbol tables to try resolving with. + fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved. + fn resolve(&self) -> Option; } -impl Function { - #[inline] - pub fn entry_block(&self) -> BlockRef { - unsafe { BlockRef::from_raw(&*self.body().entry()) } - } - - pub fn last_block(&self) -> BlockRef { - self.body() - .body() - .back() - .as_pointer() - .expect("cannot access blocks of a function declaration") - } +/// A callable operation is one who represents a potential function, and may be a target for a call- +/// like operation (i.e. implementations of `CallOpInterface`). These operations may be traditional +/// function ops (i.e. `Function`), as well as function reference-producing operations, such as an +/// op that creates closures, or captures a function by reference. +/// +/// These operations may only contain a single region. +pub trait CallableOpInterface { + /// Returns the region on the current operation that is callable. + /// + /// This may return `None` in the case of an external callable object, e.g. an externally- + /// defined function reference. + fn get_callable_region(&self) -> Option; + /// Returns the signature of the callable + fn signature(&self) -> &Signature; } -impl RegionKindInterface for Function { - #[inline(always)] - fn kind(&self) -> RegionKind { - RegionKind::SSA - } +#[doc(hidden)] +pub trait AsCallableSymbolRef { + fn as_callable_symbol_ref(&self) -> SymbolRef; } - -impl Usable for Function { - type Use = SymbolUse; - - #[inline(always)] - fn uses(&self) -> &EntityList { - &self.uses - } - +impl AsCallableSymbolRef for T { #[inline(always)] - fn uses_mut(&mut self) -> &mut EntityList { - &mut self.uses + fn as_callable_symbol_ref(&self) -> SymbolRef { + unsafe { SymbolRef::from_raw(self as &dyn Symbol) } } } - -impl Symbol for Function { +impl AsCallableSymbolRef for UnsafeIntrusiveEntityRef { #[inline(always)] - fn as_symbol_operation(&self) -> &Operation { - &self.op + fn as_callable_symbol_ref(&self) -> SymbolRef { + let t_ptr = Self::as_ptr(self); + unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) } } +} - #[inline(always)] - fn as_symbol_operation_mut(&mut self) -> &mut Operation { - &mut self.op +/// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a +/// [CallOpInterface] implementation. +/// +/// Symbols are not SSA values, but there are situations where we want to treat them as one, such +/// as indirect calls. 
Abstracting over whether the callable is a symbol or an SSA value allows us +/// to focus on the call semantics, rather than the difference between the type types of value. +#[derive(Debug, Clone)] +pub enum Callable { + Symbol(SymbolNameAttr), + Value(ValueRef), +} +impl From<&SymbolNameAttr> for Callable { + fn from(value: &SymbolNameAttr) -> Self { + Self::Symbol(value.clone()) } - - fn name(&self) -> SymbolName { - Self::name(self).as_symbol() +} +impl From for Callable { + fn from(value: SymbolNameAttr) -> Self { + Self::Symbol(value) } - - fn set_name(&mut self, name: SymbolName) { - let id = self.name_mut(); - id.name = name; +} +impl From for Callable { + fn from(value: ValueRef) -> Self { + Self::Value(value) } - - fn visibility(&self) -> Visibility { - self.signature().visibility +} +impl Callable { + #[inline(always)] + pub fn new(callable: impl Into) -> Self { + callable.into() } - fn set_visibility(&mut self, visibility: Visibility) { - self.signature_mut().visibility = visibility; + pub fn is_symbol(&self) -> bool { + matches!(self, Self::Symbol(_)) } - fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { - SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { - if OperationRef::ptr_eq(&from, &user.owner) - || from.borrow().is_proper_ancestor_of(user.owner.clone()) - { - Some(unsafe { SymbolUseRef::from_raw(&*user) }) - } else { - None - } - })) + pub fn is_value(&self) -> bool { + matches!(self, Self::Value(_)) } - fn replace_all_uses( - &mut self, - replacement: SymbolRef, - from: OperationRef, - ) -> Result<(), Report> { - for symbol_use in self.symbol_uses(from) { - let (mut owner, attr_name) = { - let user = symbol_use.borrow(); - (user.owner.clone(), user.symbol) - }; - let mut owner = owner.borrow_mut(); - // Unlink previously used symbol - { - let current_symbol = owner - .get_typed_attribute_mut::(&attr_name) - .expect("stale symbol user"); - unsafe { - self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); - } - } - // Link replacement symbol - owner.set_symbol_attribute(attr_name, replacement.clone()); + pub fn as_symbol_name(&self) -> Option<&SymbolNameAttr> { + match self { + Self::Symbol(ref name) => Some(name), + _ => None, } - - Ok(()) } - /// Returns true if this operation is a declaration, rather than a definition, of a symbol - /// - /// The default implementation assumes that all operations are definitions - #[inline] - fn is_declaration(&self) -> bool { - self.body().is_empty() + pub fn as_value(&self) -> Option> { + match self { + Self::Value(ref value_ref) => Some(value_ref.borrow()), + _ => None, + } } -} -impl CallableOpInterface for Function { - fn get_callable_region(&self) -> Option { - if self.is_declaration() { - None - } else { - self.regions().front().as_pointer() + pub fn unwrap_symbol_name(self) -> SymbolNameAttr { + match self { + Self::Symbol(name) => name, + Self::Value(value_ref) => panic!("expected symbol, got {}", value_ref.borrow().id()), } } - #[inline] - fn signature(&self) -> &Signature { - Function::signature(self) + pub fn unwrap_value_ref(self) -> ValueRef { + match self { + Self::Value(value) => value, + Self::Symbol(ref name) => panic!("expected value, got {name}"), + } } } diff --git a/hir2/src/ir/operation/builder.rs b/hir2/src/ir/operation/builder.rs index 90923b817..2818ef68a 100644 --- a/hir2/src/ir/operation/builder.rs +++ b/hir2/src/ir/operation/builder.rs @@ -1,7 +1,7 @@ use crate::{ - traits::{AsCallableSymbolRef, Terminator}, - AsSymbolRef, AttributeValue, BlockRef, Builder, 
KeyedSuccessor, Op, OpBuilder, OperationRef, - Region, Report, Spanned, SuccessorInfo, Type, UnsafeIntrusiveEntityRef, ValueRef, + traits::Terminator, AsCallableSymbolRef, AsSymbolRef, AttributeValue, BlockRef, Builder, + KeyedSuccessor, Op, OpBuilder, OperationRef, Region, Report, Spanned, SuccessorInfo, Type, + UnsafeIntrusiveEntityRef, ValueRef, }; /// The [OperationBuilder] is a primitive for imperatively constructing an [Operation]. diff --git a/hir2/src/ir/print.rs b/hir2/src/ir/print.rs index c465da32c..1ba13193f 100644 --- a/hir2/src/ir/print.rs +++ b/hir2/src/ir/print.rs @@ -3,8 +3,8 @@ use core::fmt; use super::{Context, Operation}; use crate::{ formatter::PrettyPrint, - traits::{CallableOpInterface, SingleBlock, SingleRegion}, - Entity, Value, + traits::{SingleBlock, SingleRegion}, + CallableOpInterface, Entity, Value, }; pub struct OpPrintingFlags; diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs index 36866e1c2..b87af90ab 100644 --- a/hir2/src/ir/traits.rs +++ b/hir2/src/ir/traits.rs @@ -1,11 +1,10 @@ -mod callable; mod info; mod types; use midenc_session::diagnostics::Severity; pub(crate) use self::info::TraitInfo; -pub use self::{callable::*, types::*}; +pub use self::types::*; use crate::{derive, Context, Operation, Report, Spanned}; /// Marker trait for commutative ops, e.g. `X op Y == Y op X` diff --git a/hir2/src/ir/traits/callable.rs b/hir2/src/ir/traits/callable.rs deleted file mode 100644 index e99924e22..000000000 --- a/hir2/src/ir/traits/callable.rs +++ /dev/null @@ -1,135 +0,0 @@ -use crate::{ - EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Signature, Symbol, SymbolNameAttr, - SymbolRef, UnsafeIntrusiveEntityRef, Value, ValueRef, -}; - -/// A call-like operation is one that transfers control from one function to another. -/// -/// These operations may be traditional static calls, e.g. `call @foo`, or indirect calls, e.g. -/// `call_indirect v1`. An operation that uses this interface cannot _also_ implement the -/// `CallableOpInterface`. -pub trait CallOpInterface { - /// Get the callee of this operation. - /// - /// A callee is either a symbol, or a reference to an SSA value. - fn callable_for_callee(&self) -> Callable; - /// Sets the callee for this operation. - fn set_callee(&mut self, callable: Callable); - /// Get the operands of this operation that are used as arguments for the callee - fn arguments(&self) -> OpOperandRange<'_>; - /// Get a mutable reference to the operands of this operation that are used as arguments for the - /// callee - fn arguments_mut(&mut self) -> OpOperandRangeMut<'_>; - /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` - /// if a valid callable was not resolved, using the provided symbol table. - /// - /// This method is used to perform callee resolution using a cached symbol table, rather than - /// traversing the operation hierarchy looking for symbol tables to try resolving with. - fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option; - /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` - /// if a valid callable was not resolved. - fn resolve(&self) -> Option; -} - -/// A callable operation is one who represents a potential function, and may be a target for a call- -/// like operation (i.e. implementations of `CallOpInterface`). These operations may be traditional -/// function ops (i.e. 
`Function`), as well as function reference-producing operations, such as an -/// op that creates closures, or captures a function by reference. -/// -/// These operations may only contain a single region. -pub trait CallableOpInterface { - /// Returns the region on the current operation that is callable. - /// - /// This may return `None` in the case of an external callable object, e.g. an externally- - /// defined function reference. - fn get_callable_region(&self) -> Option; - /// Returns the signature of the callable - fn signature(&self) -> &Signature; -} - -#[doc(hidden)] -pub trait AsCallableSymbolRef { - fn as_callable_symbol_ref(&self) -> SymbolRef; -} -impl AsCallableSymbolRef for T { - #[inline(always)] - fn as_callable_symbol_ref(&self) -> SymbolRef { - unsafe { SymbolRef::from_raw(self as &dyn Symbol) } - } -} -impl AsCallableSymbolRef for UnsafeIntrusiveEntityRef { - #[inline(always)] - fn as_callable_symbol_ref(&self) -> SymbolRef { - let t_ptr = Self::as_ptr(self); - unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) } - } -} - -/// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a -/// [CallOpInterface] implementation. -/// -/// Symbols are not SSA values, but there are situations where we want to treat them as one, such -/// as indirect calls. Abstracting over whether the callable is a symbol or an SSA value allows us -/// to focus on the call semantics, rather than the difference between the type types of value. -#[derive(Debug, Clone)] -pub enum Callable { - Symbol(SymbolNameAttr), - Value(ValueRef), -} -impl From<&SymbolNameAttr> for Callable { - fn from(value: &SymbolNameAttr) -> Self { - Self::Symbol(value.clone()) - } -} -impl From for Callable { - fn from(value: SymbolNameAttr) -> Self { - Self::Symbol(value) - } -} -impl From for Callable { - fn from(value: ValueRef) -> Self { - Self::Value(value) - } -} -impl Callable { - #[inline(always)] - pub fn new(callable: impl Into) -> Self { - callable.into() - } - - pub fn is_symbol(&self) -> bool { - matches!(self, Self::Symbol(_)) - } - - pub fn is_value(&self) -> bool { - matches!(self, Self::Value(_)) - } - - pub fn as_symbol_name(&self) -> Option<&SymbolNameAttr> { - match self { - Self::Symbol(ref name) => Some(name), - _ => None, - } - } - - pub fn as_value(&self) -> Option> { - match self { - Self::Value(ref value_ref) => Some(value_ref.borrow()), - _ => None, - } - } - - pub fn unwrap_symbol_name(self) -> SymbolNameAttr { - match self { - Self::Symbol(name) => name, - Self::Value(value_ref) => panic!("expected symbol, got {}", value_ref.borrow().id()), - } - } - - pub fn unwrap_value_ref(self) -> ValueRef { - match self { - Self::Value(value) => value, - Self::Symbol(ref name) => panic!("expected value, got {name}"), - } - } -} From e8397915b258663163d3dfb991f491e2059ef503 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 19:48:33 -0400 Subject: [PATCH 15/31] wip: add constant materialization hook to dialect --- hir2/src/ir/dialect.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/hir2/src/ir/dialect.rs b/hir2/src/ir/dialect.rs index 78346f8e5..0cd062117 100644 --- a/hir2/src/ir/dialect.rs +++ b/hir2/src/ir/dialect.rs @@ -1,7 +1,7 @@ use alloc::rc::Rc; use core::{borrow::Borrow, ops::Deref}; -use crate::{AsAny, OperationName}; +use crate::{AsAny, AttributeValue, Builder, OperationName, OperationRef, SourceSpan, Type}; /// A [Dialect] represents a collection of IR entities that are used in conjunction with 
one /// another. Multiple dialects can co-exist _or_ be mutually exclusive. Converting between dialects @@ -20,6 +20,25 @@ pub trait Dialect { opcode: ::midenc_hir_symbol::Symbol, register: fn(DialectName, ::midenc_hir_symbol::Symbol) -> OperationName, ) -> OperationName; + + /// A hook to materialize a single constant operation from a given attribute value and type. + /// + /// This method should use the provided builder to create the operation without changing the + /// insertion point. The generated operation is expected to be constant-like, i.e. single result + /// zero operands, no side effects, etc. + /// + /// Returns `None` if a constant cannot be materialized for the given attribute. + #[allow(unused_variables)] + #[inline] + fn materialize_constant( + &self, + builder: &mut dyn Builder, + attr: Box, + ty: &Type, + span: SourceSpan, + ) -> Option { + None + } } /// A [DialectRegistration] must be implemented for any implementation of [Dialect], to allow the From 07cf501a367012d3c5fc7f43a8a19da7c211b7d4 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 19:50:18 -0400 Subject: [PATCH 16/31] chore: add some useful deps --- Cargo.lock | 42 ++++++++++++++++++++++++++++++++++++++++++ hir2/Cargo.toml | 3 +++ 2 files changed, 45 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 571e5aab8..51ffbb835 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,6 +524,18 @@ dependencies = [ "typenum", ] +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake3" version = "1.5.4" @@ -1658,6 +1670,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.3.30" @@ -3424,11 +3442,14 @@ name = "midenc-hir2" version = "0.0.6" dependencies = [ "anyhow", + "bitflags 2.6.0", + "bitvec", "blink-alloc", "compact_str", "cranelift-entity", "derive_more", "either", + "env_logger 0.11.5", "hashbrown 0.14.5", "indexmap 2.5.0", "intrusive-collections", @@ -4443,6 +4464,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -5493,6 +5520,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.10.1" @@ -7080,6 +7113,15 @@ dependencies = [ "wasmparser 0.216.0", ] +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xdg-home" version = "1.3.0" diff --git a/hir2/Cargo.toml b/hir2/Cargo.toml index 72dfa1ea4..8f680aacf 100644 --- a/hir2/Cargo.toml +++ b/hir2/Cargo.toml @@ -31,6 +31,8 @@ blink-alloc = { version = "0.3", default-features = false, features = [ "alloc", "nightly", ] } +bitvec = { version = "1.0", default-features = false, 
features = ["alloc"] } +bitflags.workspace = true either.workspace = true cranelift-entity.workspace = true compact_str.workspace = true @@ -63,3 +65,4 @@ indexmap.workspace = true [dev-dependencies] pretty_assertions = "1.0" +env_logger.workspace = true From c56be9b3532b45fbef6da3408ee113c2aab2d37f Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 19:57:23 -0400 Subject: [PATCH 17/31] wip: promote attributes to top level, add ability to clone and hash type-erased attribute values --- hir-macros/src/operation.rs | 8 +- hir2/src/attributes.rs | 481 ++++++++++++++++++++++++++++++ hir2/src/attributes/call_conv.rs | 2 +- hir2/src/attributes/overflow.rs | 2 +- hir2/src/attributes/visibility.rs | 2 +- hir2/src/derive.rs | 2 +- hir2/src/hash.rs | 97 ++++++ hir2/src/ir.rs | 2 - hir2/src/ir/attribute.rs | 456 ---------------------------- hir2/src/ir/callable.rs | 17 +- hir2/src/ir/op.rs | 35 +-- hir2/src/ir/operation.rs | 55 ++-- hir2/src/ir/symbol_table.rs | 8 +- 13 files changed, 625 insertions(+), 542 deletions(-) create mode 100644 hir2/src/hash.rs delete mode 100644 hir2/src/ir/attribute.rs diff --git a/hir-macros/src/operation.rs b/hir-macros/src/operation.rs index 421aa09e6..900814978 100644 --- a/hir-macros/src/operation.rs +++ b/hir-macros/src/operation.rs @@ -981,12 +981,12 @@ impl quote::ToTokens for OpSymbolFns<'_> { #[doc = #symbol_doc] pub fn #symbol(&self) -> &::midenc_hir2::SymbolNameAttr { - self.op.get_typed_attribute(&Self::#symbol_symbol()).unwrap() + self.op.get_typed_attribute(Self::#symbol_symbol()).unwrap() } #[doc = #symbol_mut_doc] pub fn #symbol_mut(&mut self) -> &mut ::midenc_hir2::SymbolNameAttr { - self.op.get_typed_attribute_mut(&Self::#symbol_symbol()).unwrap() + self.op.get_typed_attribute_mut(Self::#symbol_symbol()).unwrap() } #( @@ -1109,12 +1109,12 @@ impl quote::ToTokens for OpAttrFns<'_> { #[doc = #attr_doc] pub fn #attr(&self) -> &#attr_ty { - self.op.get_typed_attribute::<#attr_ty, _>(&Self::#attr_symbol()).unwrap() + self.op.get_typed_attribute::<#attr_ty>(Self::#attr_symbol()).unwrap() } #[doc = #attr_mut_doc] pub fn #attr_mut(&mut self) -> &mut #attr_ty { - self.op.get_typed_attribute_mut::<#attr_ty, _>(&Self::#attr_symbol()).unwrap() + self.op.get_typed_attribute_mut::<#attr_ty>(Self::#attr_symbol()).unwrap() } #[doc = #set_attr_doc] diff --git a/hir2/src/attributes.rs b/hir2/src/attributes.rs index c24d79ee0..269c36c63 100644 --- a/hir2/src/attributes.rs +++ b/hir2/src/attributes.rs @@ -2,4 +2,485 @@ mod call_conv; mod overflow; mod visibility; +use alloc::collections::BTreeMap; +use core::{any::Any, borrow::Borrow, fmt}; + pub use self::{call_conv::CallConv, overflow::Overflow, visibility::Visibility}; +use crate::interner::Symbol; + +pub mod markers { + use midenc_hir_symbol::symbols; + + use super::*; + + /// This attribute indicates that the decorated function is the entrypoint + /// for its containing program, regardless of what module it is defined in. 
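+    ///
+    /// A hypothetical sketch of how a marker like this can be stored and queried via the
+    /// `AttributeSet` defined below:
+    ///
+    /// ```rust,ignore
+    /// let mut attrs = AttributeSet::new();
+    /// // A marker attribute carries no value, so `value` is `None`.
+    /// attrs.insert(symbols::Entrypoint, None::<bool>);
+    /// assert!(attrs.has(symbols::Entrypoint));
+    /// // The attribute is present, but has no associated value.
+    /// assert!(attrs.get_any(symbols::Entrypoint).is_none());
+    /// ```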
+ pub const ENTRYPOINT: Attribute = Attribute { + name: symbols::Entrypoint, + value: None, + }; +} + +/// An [AttributeSet] is a uniqued collection of attributes associated with some IR entity +#[derive(Debug, Default, Hash)] +pub struct AttributeSet(Vec); +impl FromIterator for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator, + { + let mut map = BTreeMap::default(); + for attr in attrs.into_iter() { + map.insert(attr.name, attr.value); + } + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) + } +} +impl FromIterator<(Symbol, Option>)> for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator>)>, + { + let mut map = BTreeMap::default(); + for (name, value) in attrs.into_iter() { + map.insert(name, value); + } + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) + } +} +impl AttributeSet { + /// Get a new, empty [AttributeSet] + pub fn new() -> Self { + Self::default() + } + + /// Insert a new [Attribute] in this set by `name` and `value` + pub fn insert(&mut self, name: impl Into, value: Option) { + self.set(Attribute { + name: name.into(), + value: value.map(|v| Box::new(v) as Box), + }); + } + + /// Adds `attr` to this set + pub fn set(&mut self, attr: Attribute) { + match self.0.binary_search_by_key(&attr.name, |attr| attr.name) { + Ok(index) => { + self.0[index].value = attr.value; + } + Err(index) => { + if index == self.0.len() { + self.0.push(attr); + } else { + self.0.insert(index, attr); + } + } + } + } + + /// Remove an [Attribute] by name from this set + pub fn remove(&mut self, name: impl Into) { + let name = name.into(); + match self.0.binary_search_by_key(&name, |attr| attr.name) { + Ok(index) if index + 1 == self.0.len() => { + self.0.pop(); + } + Ok(index) => { + self.0.remove(index); + } + Err(_) => (), + } + } + + /// Determine if the named [Attribute] is present in this set + pub fn has(&self, key: impl Into) -> bool { + let key = key.into(); + self.0.binary_search_by_key(&key, |attr| attr.name).is_ok() + } + + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get_any(&self, key: impl Into) -> Option<&dyn AttributeValue> { + let key = key.into(); + match self.0.binary_search_by_key(&key, |attr| attr.name) { + Ok(index) => self.0[index].value.as_deref(), + Err(_) => None, + } + } + + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get_any_mut(&mut self, key: impl Into) -> Option<&mut dyn AttributeValue> { + let key = key.into(); + match self.0.binary_search_by_key(&key, |attr| attr.name) { + Ok(index) => self.0[index].value.as_deref_mut(), + Err(_) => None, + } + } + + /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`. + pub fn get(&self, key: impl Into) -> Option<&V> + where + V: AttributeValue, + { + self.get_any(key).and_then(|v| v.downcast_ref::()) + } + + /// Get the value associated with the named [Attribute] as a mutable value of type `V`, or + /// `None`. + pub fn get_mut(&mut self, key: impl Into) -> Option<&mut V> + where + V: AttributeValue, + { + self.get_any_mut(key).and_then(|v| v.downcast_mut::()) + } + + /// Iterate over each [Attribute] in this set + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() + } +} + +/// An [Attribute] associates some data with a well-known identifier (name). +/// +/// Attributes are used for representing metadata that helps guide compilation, +/// but which is not part of the code itself. 
For example, `cfg` flags in Rust +/// are an example of something which you could represent using an [Attribute]. +/// They can also be used to store documentation, source locations, and more. +#[derive(Debug, Hash)] +pub struct Attribute { + /// The name of this attribute + pub name: Symbol, + /// The value associated with this attribute + pub value: Option>, +} +impl Attribute { + pub fn new(name: impl Into, value: Option) -> Self { + Self { + name: name.into(), + value: value.map(|v| Box::new(v) as Box), + } + } + + pub fn value(&self) -> Option<&dyn AttributeValue> { + self.value.as_deref() + } + + pub fn value_as(&self) -> Option<&V> + where + V: AttributeValue, + { + match self.value.as_deref() { + Some(value) => value.downcast_ref::(), + None => None, + } + } +} +impl fmt::Display for Attribute { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.value.as_deref().map(|v| v.render()) { + None => write!(f, "#[{}]", self.name.as_str()), + Some(value) => write!(f, "#[{}({value})]", &self.name), + } + } +} + +pub trait AttributeValue: + Any + fmt::Debug + crate::formatter::PrettyPrint + crate::DynHash + 'static +{ + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn clone_value(&self) -> Box; +} + +impl dyn AttributeValue { + pub fn is(&self) -> bool { + self.as_any().is::() + } + + pub fn downcast(self: Box) -> Result, Box> { + if self.is::() { + let ptr = Box::into_raw(self); + Ok(unsafe { Box::from_raw(ptr.cast()) }) + } else { + Err(self) + } + } + + pub fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() + } + + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.as_any_mut().downcast_mut::() + } +} + +impl core::hash::Hash for dyn AttributeValue { + fn hash(&self, state: &mut H) { + use crate::DynHash; + + let hashable = self as &dyn DynHash; + hashable.dyn_hash(state); + } +} + +#[derive(Clone)] +pub struct SetAttr { + values: Vec, +} +impl Default for SetAttr { + fn default() -> Self { + Self { + values: Default::default(), + } + } +} +impl SetAttr +where + K: Ord + Clone, +{ + pub fn insert(&mut self, key: K) -> bool { + match self.values.binary_search_by(|k| key.cmp(k)) { + Ok(index) => { + self.values[index] = key; + false + } + Err(index) => { + self.values.insert(index, key); + true + } + } + } + + pub fn contains(&self, key: &K) -> bool { + self.values.binary_search_by(|k| key.cmp(k)).is_ok() + } + + pub fn iter(&self) -> core::slice::Iter<'_, K> { + self.values.iter() + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|k| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index)), + Err(_) => None, + } + } +} +impl Eq for SetAttr where K: Eq {} +impl PartialEq for SetAttr +where + K: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for SetAttr +where + K: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_set().entries(self.values.iter()).finish() + } +} +impl crate::formatter::PrettyPrint for SetAttr +where + K: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl core::hash::Hash for SetAttr +where + K: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + as core::hash::Hash>::hash(&self.values, state); + } +} +impl AttributeValue for SetAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, +{ + #[inline(always)] + 
fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } + + #[inline] + fn clone_value(&self) -> Box { + Box::new(self.clone()) + } +} + +#[derive(Clone)] +pub struct DictAttr { + values: Vec<(K, V)>, +} +impl Default for DictAttr { + fn default() -> Self { + Self { values: vec![] } + } +} +impl DictAttr +where + K: Ord, + V: Clone, +{ + pub fn insert(&mut self, key: K, value: V) { + match self.values.binary_search_by(|(k, _)| key.cmp(k)) { + Ok(index) => { + self.values[index].1 = value; + } + Err(index) => { + self.values.insert(index, (key, value)); + } + } + } + + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow, + Q: ?Sized + Ord, + { + self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())).is_ok() + } + + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(&self.values[index].1), + Err(_) => None, + } + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index).1), + Err(_) => None, + } + } + + pub fn iter(&self) -> core::slice::Iter<'_, (K, V)> { + self.values.iter() + } +} +impl Eq for DictAttr +where + K: Eq, + V: Eq, +{ +} +impl PartialEq for DictAttr +where + K: PartialEq, + V: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for DictAttr +where + K: fmt::Debug, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.values.iter().map(|entry| (&entry.0, &entry.1))) + .finish() + } +} +impl crate::formatter::PrettyPrint for DictAttr +where + K: crate::formatter::PrettyPrint, + V: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl core::hash::Hash for DictAttr +where + K: core::hash::Hash, + V: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + as core::hash::Hash>::hash(&self.values, state); + } +} +impl AttributeValue for DictAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, + V: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, +{ + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } + + #[inline] + fn clone_value(&self) -> Box { + Box::new(self.clone()) + } +} + +#[macro_export] +macro_rules! 
define_attr_type { + ($T:ty) => { + impl $crate::AttributeValue for $T { + #[inline(always)] + fn as_any(&self) -> &dyn core::any::Any { + self as &dyn core::any::Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn core::any::Any { + self as &mut dyn core::any::Any + } + + #[inline] + fn clone_value(&self) -> Box { + Box::new(self.clone()) + } + } + }; +} + +define_attr_type!(bool); +define_attr_type!(u8); +define_attr_type!(i8); +define_attr_type!(u16); +define_attr_type!(i16); +define_attr_type!(u32); +define_attr_type!(core::num::NonZeroU32); +define_attr_type!(i32); +define_attr_type!(u64); +define_attr_type!(i64); +define_attr_type!(usize); +define_attr_type!(isize); +define_attr_type!(Symbol); +define_attr_type!(super::Immediate); +define_attr_type!(super::Type); diff --git a/hir2/src/attributes/call_conv.rs b/hir2/src/attributes/call_conv.rs index 283bbf29d..abededc0e 100644 --- a/hir2/src/attributes/call_conv.rs +++ b/hir2/src/attributes/call_conv.rs @@ -14,7 +14,7 @@ use core::fmt; /// from the public API to private functions. In short, choose a calling convention that is /// well-suited for a given function, to the extent that other constraints don't impose a choice /// on you. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] #[cfg_attr( feature = "serde", derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) diff --git a/hir2/src/attributes/overflow.rs b/hir2/src/attributes/overflow.rs index 7822ae3b2..8bc7eae00 100644 --- a/hir2/src/attributes/overflow.rs +++ b/hir2/src/attributes/overflow.rs @@ -9,7 +9,7 @@ use crate::define_attr_type; /// Always check the documentation of the specific instruction involved to see if there /// are any specific differences in how this enum is interpreted compared to the default /// meaning of each variant. -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub enum Overflow { /// Typically, this means the operation is performed using the equivalent field element /// operation, rather than a dedicated operation for the given type. Because of this, the diff --git a/hir2/src/attributes/visibility.rs b/hir2/src/attributes/visibility.rs index 637fd2eda..1c23a8f8f 100644 --- a/hir2/src/attributes/visibility.rs +++ b/hir2/src/attributes/visibility.rs @@ -3,7 +3,7 @@ use core::{fmt, str::FromStr}; use crate::define_attr_type; /// The types of visibility that a [Symbol] may have -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Visibility { /// The symbol is public and may be referenced anywhere internal or external to the visible /// references in the IR. 
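// Illustrative sketch (not part of the patch): using the attribute machinery above to attach
// and read back a typed value. The import paths and the attribute name are assumptions; `u32`
// and `bool` are among the primitive types registered via `define_attr_type!`.
use core::hash::{Hash, Hasher};

use midenc_hir2::{interner::Symbol, Attribute};

fn attribute_example() {
    let attr = Attribute::new(Symbol::intern("overflow_bits"), Some(32u32));

    // The value is stored type-erased as `Box<dyn AttributeValue>`...
    assert!(attr.value().is_some());
    // ...and can be downcast back to its concrete type; a mismatched type simply yields `None`.
    assert_eq!(attr.value_as::<u32>(), Some(&32u32));
    assert!(attr.value_as::<bool>().is_none());

    // Because `AttributeValue: DynHash`, the erased value participates in `Attribute`'s
    // derived `Hash` implementation.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    attr.hash(&mut hasher);
    let _digest = hasher.finish();
}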
diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs index b8431f3a4..2ec48bca5 100644 --- a/hir2/src/derive.rs +++ b/hir2/src/derive.rs @@ -173,7 +173,7 @@ mod tests { Operation, Report, Spanned, Value, }; - #[derive(Debug, Copy, Clone, PartialEq, Eq)] + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] enum Overflow { #[allow(unused)] None, diff --git a/hir2/src/hash.rs b/hir2/src/hash.rs new file mode 100644 index 000000000..d2c0f1959 --- /dev/null +++ b/hir2/src/hash.rs @@ -0,0 +1,97 @@ +use core::hash::{Hash, Hasher}; + +/// A type-erased version of [core::hash::Hash] +pub trait DynHash { + fn dyn_hash(&self, hasher: &mut dyn Hasher); +} + +impl DynHash for H { + #[inline] + fn dyn_hash(&self, hasher: &mut dyn Hasher) { + let mut hasher = DynHasher(hasher); + ::hash(self, &mut hasher) + } +} + +pub struct DynHasher<'a>(&'a mut dyn Hasher); + +impl<'a> DynHasher<'a> { + pub fn new(hasher: &'a mut H) -> Self + where + H: Hasher, + { + Self(hasher) + } +} + +impl<'a> Hasher for DynHasher<'a> { + #[inline] + fn finish(&self) -> u64 { + self.0.finish() + } + + #[inline] + fn write(&mut self, bytes: &[u8]) { + self.0.write(bytes) + } + + #[inline] + fn write_u8(&mut self, i: u8) { + self.0.write_u8(i); + } + + #[inline] + fn write_i8(&mut self, i: i8) { + self.0.write_i8(i); + } + + #[inline] + fn write_u16(&mut self, i: u16) { + self.0.write_u16(i); + } + + #[inline] + fn write_i16(&mut self, i: i16) { + self.0.write_i16(i); + } + + #[inline] + fn write_u32(&mut self, i: u32) { + self.0.write_u32(i); + } + + #[inline] + fn write_i32(&mut self, i: i32) { + self.0.write_i32(i); + } + + #[inline] + fn write_u64(&mut self, i: u64) { + self.0.write_u64(i); + } + + #[inline] + fn write_i64(&mut self, i: i64) { + self.0.write_i64(i); + } + + #[inline] + fn write_u128(&mut self, i: u128) { + self.0.write_u128(i); + } + + #[inline] + fn write_i128(&mut self, i: i128) { + self.0.write_i128(i); + } + + #[inline] + fn write_usize(&mut self, i: usize) { + self.0.write_usize(i); + } + + #[inline] + fn write_isize(&mut self, i: isize) { + self.0.write_isize(i); + } +} diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index c8db791bc..f595b91f4 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -1,4 +1,3 @@ -mod attribute; mod block; mod builder; mod callable; @@ -28,7 +27,6 @@ pub use midenc_hir_symbol as interner; pub use midenc_session::diagnostics::{Report, SourceSpan, Span, Spanned}; pub use self::{ - attribute::{attributes::*, Attribute, AttributeSet, AttributeValue, DictAttr, SetAttr}, block::{ Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, BlockRef, diff --git a/hir2/src/ir/attribute.rs b/hir2/src/ir/attribute.rs deleted file mode 100644 index 2266ac176..000000000 --- a/hir2/src/ir/attribute.rs +++ /dev/null @@ -1,456 +0,0 @@ -use alloc::collections::BTreeMap; -use core::{any::Any, borrow::Borrow, fmt}; - -use super::interner::Symbol; - -pub mod attributes { - use midenc_hir_symbol::symbols; - - use super::*; - - /// This attribute indicates that the decorated function is the entrypoint - /// for its containing program, regardless of what module it is defined in. 
- pub const ENTRYPOINT: Attribute = Attribute { - name: symbols::Entrypoint, - value: None, - }; -} - -/// An [AttributeSet] is a uniqued collection of attributes associated with some IR entity -#[derive(Debug, Default)] -pub struct AttributeSet(Vec); -impl FromIterator for AttributeSet { - fn from_iter(attrs: T) -> Self - where - T: IntoIterator, - { - let mut map = BTreeMap::default(); - for attr in attrs.into_iter() { - map.insert(attr.name, attr.value); - } - Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) - } -} -impl FromIterator<(Symbol, Option>)> for AttributeSet { - fn from_iter(attrs: T) -> Self - where - T: IntoIterator>)>, - { - let mut map = BTreeMap::default(); - for (name, value) in attrs.into_iter() { - map.insert(name, value); - } - Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) - } -} -impl AttributeSet { - /// Get a new, empty [AttributeSet] - pub fn new() -> Self { - Self::default() - } - - /// Insert a new [Attribute] in this set by `name` and `value` - pub fn insert(&mut self, name: impl Into, value: Option) { - let name = name.into(); - match self.0.binary_search_by_key(&name, |attr| attr.name) { - Ok(index) => { - self.0[index].value = value.map(|v| Box::new(v) as Box); - } - Err(index) => { - let value = value.map(|v| Box::new(v) as Box); - if index == self.0.len() { - self.0.push(Attribute { name, value }); - } else { - self.0.insert(index, Attribute { name, value }); - } - } - } - } - - /// Adds `attr` to this set - pub fn set(&mut self, attr: Attribute) { - match self.0.binary_search_by_key(&attr.name, |attr| attr.name) { - Ok(index) => { - self.0[index].value = attr.value; - } - Err(index) => { - if index == self.0.len() { - self.0.push(attr); - } else { - self.0.insert(index, attr); - } - } - } - } - - /// Remove an [Attribute] by name from this set - pub fn remove(&mut self, name: &Q) - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - let name = name.borrow(); - match self.0.binary_search_by(|attr| name.cmp(attr.name.borrow()).reverse()) { - Ok(index) if index + 1 == self.0.len() => { - self.0.pop(); - } - Ok(index) => { - self.0.remove(index); - } - Err(_) => (), - } - } - - /// Determine if the named [Attribute] is present in this set - pub fn has(&self, key: &Q) -> bool - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - let key = key.borrow(); - self.0.binary_search_by(|attr| key.cmp(attr.name.borrow()).reverse()).is_ok() - } - - /// Get the [AttributeValue] associated with the named [Attribute] - pub fn get_any(&self, key: &Q) -> Option<&dyn AttributeValue> - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - let key = key.borrow(); - match self.0.binary_search_by(|attr| key.cmp(attr.name.borrow())) { - Ok(index) => self.0[index].value.as_deref(), - Err(_) => None, - } - } - - /// Get the [AttributeValue] associated with the named [Attribute] - pub fn get_any_mut(&mut self, key: &Q) -> Option<&mut dyn AttributeValue> - where - Symbol: Borrow, - Q: Ord + ?Sized, - { - let key = key.borrow(); - match self.0.binary_search_by(|attr| key.cmp(attr.name.borrow())) { - Ok(index) => self.0[index].value.as_deref_mut(), - Err(_) => None, - } - } - - /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`. - pub fn get(&self, key: &Q) -> Option<&V> - where - Symbol: Borrow, - Q: Ord + ?Sized, - V: AttributeValue, - { - self.get_any(key).and_then(|v| v.downcast_ref::()) - } - - /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`. 
- pub fn get_mut(&mut self, key: &Q) -> Option<&mut V> - where - Symbol: Borrow, - Q: Ord + ?Sized, - V: AttributeValue, - { - self.get_any_mut(key).and_then(|v| v.downcast_mut::()) - } - - /// Iterate over each [Attribute] in this set - pub fn iter(&self) -> impl Iterator + '_ { - self.0.iter() - } -} - -/// An [Attribute] associates some data with a well-known identifier (name). -/// -/// Attributes are used for representing metadata that helps guide compilation, -/// but which is not part of the code itself. For example, `cfg` flags in Rust -/// are an example of something which you could represent using an [Attribute]. -/// They can also be used to store documentation, source locations, and more. -#[derive(Debug)] -pub struct Attribute { - /// The name of this attribute - pub name: Symbol, - /// The value associated with this attribute - pub value: Option>, -} -impl Attribute { - pub fn new(name: impl Into, value: Option) -> Self { - Self { - name: name.into(), - value: value.map(|v| Box::new(v) as Box), - } - } - - pub fn value(&self) -> Option<&dyn AttributeValue> { - self.value.as_deref() - } - - pub fn value_as(&self) -> Option<&V> - where - V: AttributeValue, - { - match self.value.as_deref() { - Some(value) => value.downcast_ref::(), - None => None, - } - } -} -impl fmt::Display for Attribute { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.value.as_deref().map(|v| v.render()) { - None => write!(f, "#[{}]", self.name.as_str()), - Some(value) => write!(f, "#[{}({value})]", &self.name), - } - } -} - -pub trait AttributeValue: Any + fmt::Debug + crate::formatter::PrettyPrint + 'static { - fn as_any(&self) -> &dyn Any; - fn as_any_mut(&mut self) -> &mut dyn Any; -} - -impl dyn AttributeValue { - pub fn is(&self) -> bool { - self.as_any().is::() - } - - pub fn downcast_ref(&self) -> Option<&T> { - self.as_any().downcast_ref::() - } - - pub fn downcast_mut(&mut self) -> Option<&mut T> { - self.as_any_mut().downcast_mut::() - } -} - -pub struct SetAttr { - values: Vec, -} -impl Default for SetAttr { - fn default() -> Self { - Self { - values: Default::default(), - } - } -} -impl SetAttr -where - K: Ord + Clone, -{ - pub fn insert(&mut self, key: K) -> bool { - match self.values.binary_search_by(|k| key.cmp(k)) { - Ok(index) => { - self.values[index] = key; - false - } - Err(index) => { - self.values.insert(index, key); - true - } - } - } - - pub fn contains(&self, key: &K) -> bool { - self.values.binary_search_by(|k| key.cmp(k)).is_ok() - } - - pub fn iter(&self) -> core::slice::Iter<'_, K> { - self.values.iter() - } - - pub fn remove(&mut self, key: &Q) -> Option - where - K: Borrow, - Q: ?Sized + Ord, - { - match self.values.binary_search_by(|k| key.cmp(k.borrow())) { - Ok(index) => Some(self.values.remove(index)), - Err(_) => None, - } - } -} -impl Eq for SetAttr where K: Eq {} -impl PartialEq for SetAttr -where - K: PartialEq, -{ - fn eq(&self, other: &Self) -> bool { - self.values == other.values - } -} -impl fmt::Debug for SetAttr -where - K: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_set().entries(self.values.iter()).finish() - } -} -impl crate::formatter::PrettyPrint for SetAttr -where - K: crate::formatter::PrettyPrint, -{ - fn render(&self) -> crate::formatter::Document { - todo!() - } -} -impl AttributeValue for SetAttr -where - K: fmt::Debug + crate::formatter::PrettyPrint + 'static, -{ - #[inline(always)] - fn as_any(&self) -> &dyn Any { - self as &dyn Any - } - - #[inline(always)] - fn as_any_mut(&mut self) 
-> &mut dyn Any { - self as &mut dyn Any - } -} - -#[derive(Clone)] -pub struct DictAttr { - values: Vec<(K, V)>, -} -impl Default for DictAttr { - fn default() -> Self { - Self { values: vec![] } - } -} -impl DictAttr -where - K: Ord, - V: Clone, -{ - pub fn insert(&mut self, key: K, value: V) { - match self.values.binary_search_by(|(k, _)| key.cmp(k)) { - Ok(index) => { - self.values[index].1 = value; - } - Err(index) => { - self.values.insert(index, (key, value)); - } - } - } - - pub fn contains_key(&self, key: &Q) -> bool - where - K: Borrow, - Q: ?Sized + Ord, - { - self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())).is_ok() - } - - pub fn get(&self, key: &Q) -> Option<&V> - where - K: Borrow, - Q: ?Sized + Ord, - { - match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { - Ok(index) => Some(&self.values[index].1), - Err(_) => None, - } - } - - pub fn remove(&mut self, key: &Q) -> Option - where - K: Borrow, - Q: ?Sized + Ord, - { - match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { - Ok(index) => Some(self.values.remove(index).1), - Err(_) => None, - } - } - - pub fn iter(&self) -> core::slice::Iter<'_, (K, V)> { - self.values.iter() - } -} -impl Eq for DictAttr -where - K: Eq, - V: Eq, -{ -} -impl PartialEq for DictAttr -where - K: PartialEq, - V: PartialEq, -{ - fn eq(&self, other: &Self) -> bool { - self.values == other.values - } -} -impl fmt::Debug for DictAttr -where - K: fmt::Debug, - V: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_map() - .entries(self.values.iter().map(|entry| (&entry.0, &entry.1))) - .finish() - } -} -impl crate::formatter::PrettyPrint for DictAttr -where - K: crate::formatter::PrettyPrint, - V: crate::formatter::PrettyPrint, -{ - fn render(&self) -> crate::formatter::Document { - todo!() - } -} -impl AttributeValue for DictAttr -where - K: fmt::Debug + crate::formatter::PrettyPrint + 'static, - V: fmt::Debug + crate::formatter::PrettyPrint + 'static, -{ - #[inline(always)] - fn as_any(&self) -> &dyn Any { - self as &dyn Any - } - - #[inline(always)] - fn as_any_mut(&mut self) -> &mut dyn Any { - self as &mut dyn Any - } -} - -#[macro_export] -macro_rules! define_attr_type { - ($T:ty) => { - impl $crate::AttributeValue for $T { - #[inline(always)] - fn as_any(&self) -> &dyn core::any::Any { - self as &dyn core::any::Any - } - - #[inline(always)] - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self as &mut dyn core::any::Any - } - } - }; -} - -define_attr_type!(bool); -define_attr_type!(u8); -define_attr_type!(i8); -define_attr_type!(u16); -define_attr_type!(i16); -define_attr_type!(u32); -define_attr_type!(core::num::NonZeroU32); -define_attr_type!(i32); -define_attr_type!(u64); -define_attr_type!(i64); -define_attr_type!(usize); -define_attr_type!(isize); -define_attr_type!(Symbol); -define_attr_type!(super::Immediate); -define_attr_type!(super::Type); diff --git a/hir2/src/ir/callable.rs b/hir2/src/ir/callable.rs index a9e9b67cf..b2e6c3f33 100644 --- a/hir2/src/ir/callable.rs +++ b/hir2/src/ir/callable.rs @@ -138,7 +138,7 @@ impl Callable { /// Represents whether an argument or return value has a special purpose in /// the calling convention of a function. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] #[cfg_attr( feature = "serde", derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) @@ -168,7 +168,7 @@ impl fmt::Display for ArgumentPurpose { /// are unsigned 32-bit integers with a standard twos-complement binary representation. /// /// It is for the latter scenario that argument extension is really relevant. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] #[cfg_attr( feature = "serde", derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) @@ -194,7 +194,7 @@ impl fmt::Display for ArgumentExtension { } /// Describes a function parameter or result. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct AbiParam { /// The type associated with this value @@ -257,7 +257,7 @@ impl fmt::Display for AbiParam { /// A function signature provides us with all of the necessary detail to correctly /// validate and emit code for a function, whether from the perspective of a caller, /// or the callee. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Signature { /// The arguments expected by this function @@ -354,15 +354,6 @@ impl Signature { } } } -impl Eq for Signature {} -impl PartialEq for Signature { - fn eq(&self, other: &Self) -> bool { - self.visibility == other.visibility - && self.cc == other.cc - && self.params.len() == other.params.len() - && self.results.len() == other.results.len() - } -} impl formatter::PrettyPrint for Signature { fn render(&self) -> formatter::Document { use crate::formatter::*; diff --git a/hir2/src/ir/op.rs b/hir2/src/ir/op.rs index c8ce1b1a3..f28c3e2d6 100644 --- a/hir2/src/ir/op.rs +++ b/hir2/src/ir/op.rs @@ -1,5 +1,5 @@ use super::*; -use crate::any::AsAny; +use crate::{any::AsAny, AttributeValue}; pub trait OpRegistration: Op { fn name() -> ::midenc_hir_symbol::Symbol; @@ -87,16 +87,10 @@ impl Spanned for dyn Op { pub trait OpExt { /// Return the value associated with attribute `name` for this function - fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; + fn get_attribute(&self, name: impl Into) -> Option<&dyn AttributeValue>; /// Return true if this function has an attributed named `name` - fn has_attribute(&self, name: &Q) -> bool - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; + fn has_attribute(&self, name: impl Into) -> bool; /// Set the attribute `name` with `value` for this function. 
fn set_attribute( @@ -106,10 +100,7 @@ pub trait OpExt { ); /// Remove any attribute with the given name from this function - fn remove_attribute(&mut self, name: &Q) - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized; + fn remove_attribute(&mut self, name: impl Into); /// Returns a handle to the nearest containing [Operation] of type `T` for this operation, if it /// is attached to one @@ -118,20 +109,12 @@ pub trait OpExt { impl OpExt for T { #[inline] - fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { + fn get_attribute(&self, name: impl Into) -> Option<&dyn AttributeValue> { self.as_operation().get_attribute(name) } #[inline] - fn has_attribute(&self, name: &Q) -> bool - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { + fn has_attribute(&self, name: impl Into) -> bool { self.as_operation().has_attribute(name) } @@ -145,11 +128,7 @@ impl OpExt for T { } #[inline] - fn remove_attribute(&mut self, name: &Q) - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { + fn remove_attribute(&mut self, name: impl Into) { self.as_operation_mut().remove_attribute(name); } diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index eacd097a3..fa1083727 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -9,6 +9,7 @@ use core::{ pub use self::{builder::OperationBuilder, name::OperationName}; use super::*; +use crate::{AttributeSet, AttributeValue}; pub type OperationRef = UnsafeIntrusiveEntityRef; pub type OpList = EntityList; @@ -279,52 +280,42 @@ impl Operation { } /// Return the value associated with attribute `name` for this function - pub fn get_attribute(&self, name: &Q) -> Option<&dyn AttributeValue> - where - interner::Symbol: core::borrow::Borrow, - Q: Ord + ?Sized, - { - self.attrs.get_any(name) + pub fn get_attribute(&self, name: impl Into) -> Option<&dyn AttributeValue> { + self.attrs.get_any(name.into()) } /// Return the value associated with attribute `name` for this function - pub fn get_attribute_mut(&mut self, name: &Q) -> Option<&mut dyn AttributeValue> - where - interner::Symbol: core::borrow::Borrow, - Q: Ord + ?Sized, - { - self.attrs.get_any_mut(name) + pub fn get_attribute_mut( + &mut self, + name: impl Into, + ) -> Option<&mut dyn AttributeValue> { + self.attrs.get_any_mut(name.into()) } /// Return the value associated with attribute `name` for this function, as its concrete type /// `T`, _if_ the attribute by that name, is of that type. - pub fn get_typed_attribute(&self, name: &Q) -> Option<&T> + pub fn get_typed_attribute(&self, name: impl Into) -> Option<&T> where T: AttributeValue, - interner::Symbol: core::borrow::Borrow, - Q: Ord + ?Sized, { - self.attrs.get(name) + self.attrs.get(name.into()) } /// Return the value associated with attribute `name` for this function, as its concrete type /// `T`, _if_ the attribute by that name, is of that type. 
- pub fn get_typed_attribute_mut(&mut self, name: &Q) -> Option<&mut T> + pub fn get_typed_attribute_mut( + &mut self, + name: impl Into, + ) -> Option<&mut T> where T: AttributeValue, - interner::Symbol: core::borrow::Borrow, - Q: Ord + ?Sized, { - self.attrs.get_mut(name) + self.attrs.get_mut(name.into()) } /// Return true if this function has an attributed named `name` - pub fn has_attribute(&self, name: &Q) -> bool - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { - self.attrs.has(name) + pub fn has_attribute(&self, name: impl Into) -> bool { + self.attrs.has(name.into()) } /// Set the attribute `name` with `value` for this function. @@ -337,12 +328,8 @@ impl Operation { } /// Remove any attribute with the given name from this function - pub fn remove_attribute(&mut self, name: &Q) - where - interner::Symbol: std::borrow::Borrow, - Q: Ord + ?Sized, - { - self.attrs.remove(name); + pub fn remove_attribute(&mut self, name: impl Into) { + self.attrs.remove(name.into()); } } @@ -361,8 +348,8 @@ impl Operation { owner: self.as_operation_ref(), symbol: name, }); - if self.has_attribute(&name) { - let attr = self.get_typed_attribute_mut::(&name).unwrap(); + if self.has_attribute(name) { + let attr = self.get_typed_attribute_mut::(name).unwrap(); let symbol = symbol.borrow(); assert!( !attr.user.is_linked(), diff --git a/hir2/src/ir/symbol_table.rs b/hir2/src/ir/symbol_table.rs index 0aaed3ab2..8ee2bf030 100644 --- a/hir2/src/ir/symbol_table.rs +++ b/hir2/src/ir/symbol_table.rs @@ -123,6 +123,12 @@ impl Ord for SymbolNameAttr { self.path.cmp(&other.path).then_with(|| self.name.cmp(&other.name)) } } +impl core::hash::Hash for SymbolNameAttr { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.path.hash(state); + } +} #[derive(Copy, Clone, PartialEq, Eq)] pub enum SymbolNameComponent { @@ -592,7 +598,7 @@ impl SymbolUse { impl fmt::Debug for SymbolUse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let op = self.owner.borrow(); - let value = op.get_typed_attribute::(&self.symbol); + let value = op.get_typed_attribute::(self.symbol); f.debug_struct("SymbolUse") .field("attr", &self.symbol) .field("symbol", &value) From 54ccab62a01fe7b4218b228d4dc1ea466ec5c693 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:10:35 -0400 Subject: [PATCH 18/31] wip: implement a variety of useful apis on regions/blocks/ops/values --- hir-macros/src/operation.rs | 52 ++-- hir2/src/ir/block.rs | 361 ++++++++++++++++++++++++-- hir2/src/ir/builder.rs | 58 ++--- hir2/src/ir/context.rs | 4 + hir2/src/ir/entity.rs | 66 ++++- hir2/src/ir/entity/storage.rs | 68 +++++ hir2/src/ir/insert.rs | 168 ++++++++++-- hir2/src/ir/op.rs | 4 + hir2/src/ir/operands.rs | 13 +- hir2/src/ir/operation.rs | 428 ++++++++++++++++++++++++++----- hir2/src/ir/operation/builder.rs | 6 +- hir2/src/ir/operation/name.rs | 4 +- hir2/src/ir/region.rs | 169 ++++++++++-- hir2/src/ir/traits.rs | 17 ++ hir2/src/ir/traits/foldable.rs | 126 +++++++++ hir2/src/ir/value.rs | 161 ++++++++++-- 16 files changed, 1495 insertions(+), 210 deletions(-) create mode 100644 hir2/src/ir/traits/foldable.rs diff --git a/hir-macros/src/operation.rs b/hir-macros/src/operation.rs index 900814978..62094fe85 100644 --- a/hir-macros/src/operation.rs +++ b/hir-macros/src/operation.rs @@ -647,8 +647,6 @@ impl quote::ToTokens for OpCreateFn<'_> { .create_params .iter() .flat_map(OpCreateParam::binding_types); - let traits = &self.op.traits; - let implements = &self.op.implements; let initialize_custom_fields = 
InitializeCustomFields(self.op); let with_symbols = WithSymbols(self.op); let with_attrs = WithAttrs(self.op); @@ -680,26 +678,7 @@ impl quote::ToTokens for OpCreateFn<'_> { let __operation_name = { let context = builder.context(); let dialect = context.get_or_register_dialect::<#dialect>(); - let opcode = ::name(); - dialect.get_or_register_op( - opcode, - |dialect_name, opcode| { - ::midenc_hir2::OperationName::new::( - dialect_name, - opcode, - [ - ::midenc_hir2::traits::TraitInfo::new::(), - ::midenc_hir2::traits::TraitInfo::new::(), - #( - ::midenc_hir2::traits::TraitInfo::new::(), - )* - #( - ::midenc_hir2::traits::TraitInfo::new::(), - )* - ] - ) - } - ) + ::register_with(&*dialect) }; let __context = builder.context_rc(); let mut __op = __context.alloc_uninit_tracked::(); @@ -773,6 +752,8 @@ impl quote::ToTokens for OpDefinition { // impl OpRegistration let opcode = &self.opcode; let opcode_str = syn::Lit::Str(syn::LitStr::new(&opcode.to_string(), opcode.span())); + let traits = &self.traits; + let implements = &self.implements; tokens.extend(quote! { impl #impl_generics ::midenc_hir2::Op for #op_ident #ty_generics #where_clause { #[inline] @@ -795,6 +776,29 @@ impl quote::ToTokens for OpDefinition { fn name() -> ::midenc_hir_symbol::Symbol { ::midenc_hir_symbol::Symbol::intern(#opcode_str) } + + fn register_with(dialect: &dyn ::midenc_hir2::Dialect) -> ::midenc_hir2::OperationName { + let opcode = ::name(); + dialect.get_or_register_op( + opcode, + |dialect_name, opcode| { + ::midenc_hir2::OperationName::new::( + dialect_name, + opcode, + [ + ::midenc_hir2::traits::TraitInfo::new::(), + ::midenc_hir2::traits::TraitInfo::new::(), + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + ] + ) + } + ) + } } }); @@ -881,6 +885,10 @@ impl quote::ToTokens for OpCustomFieldFns<'_> { // User-defined fields for field in self.0.op.fields.iter() { let field_name = field.ident.as_ref().unwrap(); + // Do not generate field functions for custom fields with private visibility + if matches!(field.vis, syn::Visibility::Inherited) { + continue; + } let field_name_mut = format_ident!("{field_name}_mut"); let set_field_name = format_ident!("set_{field_name}"); let field_doc = syn::Lit::Str(syn::LitStr::new( diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs index 36a403375..1be0e2a6e 100644 --- a/hir2/src/ir/block.rs +++ b/hir2/src/ir/block.rs @@ -59,6 +59,8 @@ impl fmt::Display for BlockId { pub struct Block { /// The unique id of this block id: BlockId, + /// Flag that indicates whether the ops in this block have a valid ordering + valid_op_ordering: bool, /// The set of uses of this block uses: BlockOperandList, /// The region this block is attached to. 
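// Illustrative sketch (not part of the patch): with the new `OpRegistration::register_with`
// hook shown above, a builder resolves an `OperationName` in one call instead of registering
// the op's trait metadata inline at every generated `create` site. `HirDialect` and `AddOp`
// are assumed example types, and the import paths are assumptions.
use midenc_hir2::{Builder, OpRegistration, OperationName};

fn resolve_op_name<B: Builder>(builder: &B) -> OperationName {
    let context = builder.context();
    let dialect = context.get_or_register_dialect::<HirDialect>();
    // Registers `AddOp` (with its declared traits) on first use; thereafter the cached name
    // for this dialect/opcode pair is returned.
    <AddOp as OpRegistration>::register_with(&*dialect)
}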
@@ -78,17 +80,51 @@ impl fmt::Debug for Block { None => f.write_str("None"), Some(r) => write!(f, "Some({r:p})"), }) - .field("arguments", &self.arguments) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for arg in self.arguments.iter() { + list.entry_with(|f| f.write_fmt(format_args!("{}", &arg.borrow()))); + } + list.finish() + }) .finish_non_exhaustive() } } -impl Entity for Block { +impl Entity for Block {} +impl EntityWithId for Block { type Id = BlockId; fn id(&self) -> Self::Id { self.id } } +impl EntityWithParent for Block { + type Parent = Region; + + fn on_inserted_into_parent( + mut this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().region = Some(parent); + } + + fn on_removed_from_parent( + mut this: UnsafeIntrusiveEntityRef, + _parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().region = None; + } + + fn on_transfered_to_new_parent( + _from: UnsafeIntrusiveEntityRef, + to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + for mut transferred_block in transferred { + transferred_block.borrow_mut().region = Some(to.clone()); + } + } +} impl Usable for Block { type Use = BlockOperand; @@ -106,6 +142,7 @@ impl Block { pub fn new(id: BlockId) -> Self { Self { id, + valid_op_ordering: true, uses: Default::default(), region: None, body: Default::default(), @@ -113,6 +150,57 @@ impl Block { } } + #[inline] + pub fn as_block_ref(&self) -> BlockRef { + unsafe { BlockRef::from_raw(self) } + } + + /// Get a handle to the containing [Region] of this block, if it is attached to one + pub fn parent(&self) -> Option { + self.region.clone() + } + + /// Get a handle to the containing [Operation] of this block, if it is attached to one + pub fn parent_op(&self) -> Option { + self.region.as_ref().and_then(|region| region.borrow().parent()) + } + + /// Returns true if this block is the entry block for its containing region + pub fn is_entry_block(&self) -> bool { + if let Some(parent) = self.region.as_ref().map(|r| r.borrow()) { + core::ptr::addr_eq(&*parent.entry(), self) + } else { + false + } + } + + /// Get the first operation in the body of this block + #[inline] + pub fn front(&self) -> Option { + self.body.front().as_pointer() + } + + /// Get the last operation in the body of this block + #[inline] + pub fn back(&self) -> Option { + self.body.back().as_pointer() + } + + /// Get the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body(&self) -> &OpList { + &self.body + } + + /// Get a mutable reference to the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body_mut(&mut self) -> &mut OpList { + &mut self.body + } +} + +/// Arguments +impl Block { #[inline] pub fn has_arguments(&self) -> bool { !self.arguments.is_empty() @@ -138,6 +226,35 @@ impl Block { self.arguments[index].clone() } + /// Erase the block argument at `index` + /// + /// Panics if the argument still has uses. + pub fn erase_argument(&mut self, index: usize) { + assert!( + !self.arguments[index].borrow().is_used(), + "cannot erase block arguments with uses" + ); + self.arguments.remove(index); + } + + /// Erase every parameter of this block for which `should_erase` returns true. + /// + /// Panics if any argument to be erased still has uses. 
+ pub fn erase_arguments(&mut self, should_erase: F) + where + F: Fn(&BlockArgument) -> bool, + { + self.arguments.retain(|arg| { + let arg = arg.borrow(); + let keep = !should_erase(&arg); + assert!(keep || !arg.is_used(), "cannot erase block arguments with uses"); + keep + }); + } +} + +/// Placement +impl Block { /// Insert this block after `after` in its containing region. /// /// Panics if this block is already attached to a region, or if `after` is not attached. @@ -152,7 +269,7 @@ impl Block { let mut region = region.borrow_mut(); let region_body = region.body_mut(); let mut cursor = unsafe { region_body.cursor_mut_from_ptr(after) }; - cursor.insert_after(unsafe { BlockRef::from_raw(self) }); + cursor.insert_after(self.as_block_ref()); } self.region = Some(region); } @@ -171,7 +288,7 @@ impl Block { let mut region = region.borrow_mut(); let region_body = region.body_mut(); let mut cursor = unsafe { region_body.cursor_mut_from_ptr(before) }; - cursor.insert_before(unsafe { BlockRef::from_raw(self) }); + cursor.insert_before(self.as_block_ref()); } self.region = Some(region); } @@ -186,33 +303,86 @@ impl Block { ); { let mut region = region.borrow_mut(); - region.body_mut().push_back(unsafe { BlockRef::from_raw(self) }); + region.body_mut().push_back(self.as_block_ref()); } self.region = Some(region); } - /// Get a handle to the containing [Region] of this block, if it is attached to one - pub fn parent(&self) -> Option { - self.region.clone() + /// Unlink this block from its current region and insert it right before `before` + pub fn move_before(&mut self, before: BlockRef) { + self.unlink(); + self.insert_before(before); } - /// Get a handle to the containing [Operation] of this block, if it is attached to one - pub fn parent_op(&self) -> Option { - self.region.as_ref().and_then(|region| region.borrow().parent()) + /// Remove this block from its containing region + fn unlink(&mut self) { + if let Some(mut region) = self.region.take() { + let mut region = region.borrow_mut(); + unsafe { + let mut cursor = region.body_mut().cursor_mut_from_ptr(self.as_block_ref()); + cursor.remove(); + } + } } - /// Get the list of [Operation] comprising the body of this block - #[inline(always)] - pub fn body(&self) -> &OpList { - &self.body + /// Split this block into two blocks before the specified operation + /// + /// Note that all operations in the block prior to `before` stay as part of the original block, + /// and the rest are moved to the new block, including the old terminator. The original block is + /// thus left without a terminator. + /// + /// Returns the newly created block. + pub fn split_block(&mut self, before: OperationRef) -> BlockRef { + let this = self.as_block_ref(); + assert!( + BlockRef::ptr_eq( + &this, + &before.borrow().parent().expect("'before' op is not attached to a block") + ), + "cannot split block using an operation that does not belong to the block being split" + ); + + // We need the parent op so we can get access to the current Context, but this also tells us + // that this block is attached to a region and operation. 
+ let parent = self.parent_op().expect("block is not attached to an operation"); + // Create a new empty block + let mut new_block = parent.borrow().context().create_block(); + // Insert the block in the same region as `self`, immediately after `self` + let region = self.region.as_mut().unwrap(); + { + let mut region_mut = region.borrow_mut(); + let blocks = region_mut.body_mut(); + let mut cursor = unsafe { blocks.cursor_mut_from_ptr(this.clone()) }; + cursor.insert_after(new_block.clone()); + } + // Split the body of `self` at `before`, and splice everything after `before`, including + // `before` itself, into the new block we created. + let mut ops = { + let mut cursor = unsafe { self.body.cursor_mut_from_ptr(before) }; + cursor.split_before() + }; + // The split_before method returns the list containing all of the ops before the cursor, but + // we want the inverse, so we just swap the two lists. + core::mem::swap(&mut self.body, &mut ops); + // Visit all of the ops and notify them of the move + for op in ops.iter() { + Operation::on_inserted_into_parent(op.as_operation_ref(), new_block.clone()); + } + new_block.borrow_mut().body = ops; + new_block } - /// Get a mutable reference to the list of [Operation] comprising the body of this block - #[inline(always)] - pub fn body_mut(&mut self) -> &mut OpList { - &mut self.body + pub fn clear(&mut self) { + // Drop all references from within this block + self.drop_all_references(); + + // Drop all operations within this block + self.body_mut().clear(); } +} +/// Predecessors and Successors +impl Block { /// Returns true if this block has predecessors #[inline(always)] pub fn has_predecessors(&self) -> bool { @@ -225,8 +395,159 @@ impl Block { self.iter_uses() } + /// If this block has exactly one predecessor, return it, otherwise `None` + /// + /// NOTE: A predecessor block with multiple edges, e.g. a conditional branch that has this block + /// as the destination for both true/false branches is _not_ considered a single predecessor by + /// this function. + pub fn get_single_predecessor(&self) -> Option { + let front = self.uses.front(); + if front.is_null() { + return None; + } + let front = front.as_pointer().unwrap(); + let back = self.uses.back().as_pointer().unwrap(); + if BlockOperandRef::ptr_eq(&front, &back) { + Some(front.borrow().block.clone()) + } else { + None + } + } + + /// If this block has a unique predecessor, i.e. 
all incoming edges originate from one block, + /// return it, otherwise `None` + pub fn get_unique_predecessor(&self) -> Option { + let mut front = self.uses.front(); + let block_operand = front.get()?; + let block = block_operand.block.clone(); + loop { + front.move_next(); + if let Some(bo) = front.get() { + if !BlockRef::ptr_eq(&block, &bo.block) { + break None; + } + } else { + break Some(block); + } + } + } + + /// Returns true if this block has any successors + #[inline] + pub fn has_successors(&self) -> bool { + self.num_successors() > 0 + } + + /// Get the number of successors of this block in the CFG + pub fn num_successors(&self) -> usize { + self.terminator().map(|op| op.borrow().num_successors()).unwrap_or(0) + } + + /// Get the `index`th successor of this block's terminator operation + pub fn get_successor(&self, index: usize) -> BlockRef { + let op = self.terminator().expect("this block has no terminator"); + op.borrow().successor(index).dest.borrow().block.clone() + } + + /// This drops all operand uses from operations within this block, which is an essential step in + /// breaking cyclic dependences between references when they are to be deleted. + pub fn drop_all_references(&mut self) { + let mut cursor = self.body.front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_references(); + cursor.move_next(); + } + } + + /// This drops all uses of values defined in this block or in the blocks of nested regions + /// wherever the uses are located. pub fn drop_all_defined_value_uses(&mut self) { - todo!() + for arg in self.arguments.iter_mut() { + let mut arg = arg.borrow_mut(); + arg.uses_mut().clear(); + } + let mut cursor = self.body.front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_defined_value_uses(); + cursor.move_next(); + } + self.drop_all_uses(); + } + + /// Drop all uses of this block via [BlockOperand] + #[inline] + pub fn drop_all_uses(&mut self) { + self.uses_mut().clear(); + } + + #[inline(always)] + pub(super) const fn is_op_order_valid(&self) -> bool { + self.valid_op_ordering + } + + #[inline(always)] + pub(super) fn mark_op_order_valid(&mut self) { + self.valid_op_ordering = true; + } + + pub(super) fn invalidate_op_order(&mut self) { + // Validate the current ordering + assert!(self.verify_op_order()); + self.valid_op_ordering = false; + } + + /// Returns true if the current operation ordering in this block is valid + pub(super) fn verify_op_order(&self) -> bool { + // The order is already known to be invalid + if !self.valid_op_ordering { + return false; + } + + // The order is valid if there are less than 2 operations + if self.body.is_empty() + || OperationRef::ptr_eq( + &self.body.front().as_pointer().unwrap(), + &self.body.back().as_pointer().unwrap(), + ) + { + return true; + } + + let mut cursor = self.body.front(); + let mut prev = None; + while let Some(op) = cursor.as_pointer() { + cursor.move_next(); + if prev.is_none() { + prev = Some(op); + continue; + } + + // The previous operation must have a smaller order index than the next + let prev_order = prev.take().unwrap().borrow().order(); + let current_order = op.borrow().order().unwrap_or(u32::MAX); + if prev_order.is_some_and(|o| o >= current_order) { + return false; + } + prev = Some(op); + } + + true + } + + /// Get the terminator operation of this block, or `None` if the block does not have one. 
+ pub fn terminator(&self) -> Option { + if !self.has_terminator() { + None + } else { + self.body.back().as_pointer() + } + } + + /// Returns true if this block has a terminator + pub fn has_terminator(&self) -> bool { + use crate::traits::Terminator; + !self.body.is_empty() + && self.body.back().get().is_some_and(|op| op.implements::()) } } diff --git a/hir2/src/ir/builder.rs b/hir2/src/ir/builder.rs index 916e1cf57..1d5fb4e9a 100644 --- a/hir2/src/ir/builder.rs +++ b/hir2/src/ir/builder.rs @@ -78,74 +78,74 @@ pub trait Builder: Listener { /// Add a new block with `args` arguments, and set the insertion point to the end of it. /// - /// The block is inserted at the provided insertion point `ip`, or at the end of `parent` if + /// The block is inserted after the provided insertion point `ip`, or at the end of `parent` if /// not. /// /// Panics if `ip` is in a different region than `parent`, or if the position it refers to is no /// longer valid. - fn create_block
<P>
( - &mut self, - parent: RegionRef, - ip: Option, - args: P, - ) -> BlockRef - where - P: IntoIterator, - { - let mut block = self.context().create_block_with_params(args); - if let Some(InsertionPoint { at, action }) = ip { - let at = at.block().expect("invalid insertion point"); + fn create_block(&mut self, parent: RegionRef, ip: Option, args: &[Type]) -> BlockRef { + let mut block = self.context().create_block_with_params(args.iter().cloned()); + if let Some(at) = ip { let region = at.borrow().parent().unwrap(); assert!( RegionRef::ptr_eq(&parent, ®ion), "insertion point region differs from 'parent'" ); - match action { - crate::Insert::Before => block.borrow_mut().insert_before(at), - crate::Insert::After => block.borrow_mut().insert_after(at), - } + block.borrow_mut().insert_after(at); } else { block.borrow_mut().insert_at_end(parent); } self.notify_block_inserted(block.clone(), None, None); + self.set_insertion_point_to_end(block.clone()); + block } /// Add a new block with `args` arguments, and set the insertion point to the end of it. /// /// The block is inserted before `before`. - fn create_block_before
<P>
(&mut self, before: BlockRef, args: P) -> BlockRef - where - P: IntoIterator, - { - let mut block = self.context().create_block_with_params(args); + fn create_block_before(&mut self, before: BlockRef, args: &[Type]) -> BlockRef { + let mut block = self.context().create_block_with_params(args.iter().cloned()); block.borrow_mut().insert_before(before); self.notify_block_inserted(block.clone(), None, None); + self.set_insertion_point_to_end(block.clone()); block } - /// Insert `op` at the current insertion point + /// Insert `op` at the current insertion point. + /// + /// If the insertion point is inserting after the current operation, then after calling this + /// function, the insertion point will have been moved to the newly inserted operation. This + /// ensures that subsequent calls to `insert` will place operations in the block in the same + /// sequence as they were inserted. The other insertion point placements already have more or + /// less intuitive behavior, e.g. inserting _before_ the current operation multiple times will + /// result in operations being placed in the same sequence they were inserted, just before the + /// current op. /// /// This function will panic if no insertion point is set. fn insert(&mut self, mut op: OperationRef) { - let InsertionPoint { at, action } = - self.insertion_point().expect("insertion point is unset").clone(); - match at { - ProgramPoint::Block(block) => match action { + let ip = self.insertion_point().expect("insertion point is unset").clone(); + match ip.at { + ProgramPoint::Block(block) => match ip.placement { crate::Insert::Before => op.borrow_mut().insert_at_start(block), crate::Insert::After => op.borrow_mut().insert_at_end(block), }, - ProgramPoint::Op(other_op) => match action { + ProgramPoint::Op(other_op) => match ip.placement { crate::Insert::Before => op.borrow_mut().insert_before(other_op), - crate::Insert::After => op.borrow_mut().insert_after(other_op), + crate::Insert::After => { + op.borrow_mut().insert_after(other_op.clone()); + self.set_insertion_point_after(ProgramPoint::Op(other_op)); + } }, } self.notify_operation_inserted(op, None); } +} +pub trait BuilderExt: Builder { /// Returns a specialized builder for a concrete [Op], `T`, which can be called like a closure /// with the arguments required to create an instance of the specified operation. /// diff --git a/hir2/src/ir/context.rs b/hir2/src/ir/context.rs index 9a01bfc5b..cd702775f 100644 --- a/hir2/src/ir/context.rs +++ b/hir2/src/ir/context.rs @@ -65,6 +65,10 @@ impl Context { self.registered_dialects.borrow() } + pub fn get_registered_dialect(&self, dialect: &DialectName) -> Rc { + self.registered_dialects.borrow()[dialect].clone() + } + pub fn get_or_register_dialect(&self) -> Rc { use alloc::collections::btree_map::Entry; diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs index 6539b4864..33155316e 100644 --- a/hir2/src/ir/entity.rs +++ b/hir2/src/ir/entity.rs @@ -20,10 +20,51 @@ pub use self::{ }; use crate::any::*; -/// A trait implemented by an IR entity that has a unique identifier +/// A trait implemented by an IR entity +pub trait Entity: Any {} + +/// A trait implemented by an [Entity] that is a logical child of another entity type, and is stored +/// in the parent using an [EntityList]. +/// +/// This trait defines callbacks that are executed any time the entity is modified in relation to its +/// parent entity, i.e. inserted in a parent, removed from a parent, or moved from one to another. +/// +/// By default, these callbacks are no-ops. 
+pub trait EntityWithParent: Entity { + /// The parent entity that this entity logically belongs to. + type Parent: Entity; + + /// Invoked when this entity type is inserted into an intrusive list + #[allow(unused_variables)] + #[inline] + fn on_inserted_into_parent( + this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + } + /// Invoked when this entity type is removed from an intrusive list + #[allow(unused_variables)] + #[inline] + fn on_removed_from_parent( + this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + } + /// Invoked when a set of entities is moved from one intrusive list to another + #[allow(unused_variables)] + #[inline] + fn on_transfered_to_new_parent( + from: UnsafeIntrusiveEntityRef, + to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + } +} + +/// A trait implemented by an [Entity] that has a unique identifier /// /// Currently, this is used only for [Value]s and [Block]s. -pub trait Entity: Any { +pub trait EntityWithId: Entity { type Id: EntityId; fn id(&self) -> Self::Id; @@ -445,11 +486,7 @@ impl RawEntityRef { Obj: ?Sized, { let borrow = from.borrow(); - if let Some(to) = borrow.as_any().downcast_ref() { - Some(unsafe { RawEntityRef::from_raw(to) }) - } else { - None - } + borrow.as_any().downcast_ref().map(|to| unsafe { RawEntityRef::from_raw(to) }) } #[track_caller] @@ -498,10 +535,25 @@ where } impl Eq for RawEntityRef {} impl PartialEq for RawEntityRef { + #[inline] fn eq(&self, other: &Self) -> bool { Self::ptr_eq(self, other) } } +impl PartialOrd for RawEntityRef { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for RawEntityRef { + #[inline] + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + let a = self.inner.as_ptr() as *const () as usize; + let b = other.inner.as_ptr() as *const () as usize; + a.cmp(&b) + } +} impl core::hash::Hash for RawEntityRef { fn hash(&self, state: &mut H) { self.inner.hash(state); diff --git a/hir2/src/ir/entity/storage.rs b/hir2/src/ir/entity/storage.rs index fed8d6cc2..76a918cc6 100644 --- a/hir2/src/ir/entity/storage.rs +++ b/hir2/src/ir/entity/storage.rs @@ -191,11 +191,20 @@ impl core::ops::IndexMut for EntityStorage { range: core::ops::Range, items: &'a [T], } impl<'a, T> EntityRange<'a, T> { + /// Get an empty range + pub fn empty() -> Self { + Self { + range: 0..0, + items: &[], + } + } + /// Returns true if this range is empty #[inline] pub fn is_empty(&self) -> bool { @@ -238,6 +247,15 @@ impl<'a, T> core::ops::Index for EntityRange<'a, T> { &self.as_slice()[index] } } +impl<'a, T> IntoIterator for EntityRange<'a, T> { + type IntoIter = core::slice::Iter<'a, T>; + type Item = &'a T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.iter() + } +} /// A mutable range of items in [EntityStorage] /// @@ -263,6 +281,14 @@ impl<'a, T, const INLINE: usize> EntityRangeMut<'a, T, INLINE> { self.as_slice().len() } + /// Temporarily borrow this range immutably + pub fn as_immutable(&self) -> EntityRange<'_, T> { + EntityRange { + range: self.range.clone(), + items: self.items.as_slice(), + } + } + /// Get this range as a slice #[inline] pub fn as_slice(&self) -> &[T] { @@ -384,6 +410,39 @@ impl<'a, T: StorableEntity, const INLINE: usize> EntityRangeMut<'a, T, INLINE> { } } + /// Remove the item at `index` in this group, and return it. + /// + /// NOTE: This will panic if `index` is out of bounds of the group. 
+ pub fn erase(&mut self, index: usize) -> T { + assert!(self.range.len() > index, "index out of bounds"); + self.range.end -= 1; + self.groups[self.group].shrink(1); + let mut removed = self.items.remove(self.range.start + index); + { + removed.unlink(); + } + + // Shift groups + let next_group = self.group + 1; + if next_group < self.groups.len() { + for group in self.groups[next_group..].iter_mut() { + group.shift_start(-1); + } + } + + // Shift item indices + let next_item = index; + if next_item < self.items.len() { + for (offset, item) in self.items[next_item..].iter_mut().enumerate() { + unsafe { + item.set_index(next_item + offset); + } + } + } + + removed + } + /// Remove the last item from this group, or `None` if empty pub fn pop(&mut self) -> Option { if self.range.is_empty() { @@ -432,6 +491,15 @@ impl<'a, T, const INLINE: usize> core::ops::IndexMut for EntityRangeMut<' &mut self.as_slice_mut()[index] } } +impl<'a, T> IntoIterator for EntityRangeMut<'a, T> { + type IntoIter = core::slice::IterMut<'a, T>; + type Item = &'a mut T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} #[cfg(test)] mod tests { diff --git a/hir2/src/ir/insert.rs b/hir2/src/ir/insert.rs index c3899c009..72ec5361a 100644 --- a/hir2/src/ir/insert.rs +++ b/hir2/src/ir/insert.rs @@ -2,68 +2,202 @@ use core::fmt; use crate::{BlockRef, OperationRef}; +/// Represents the placement of inserted items relative to a [ProgramPoint] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Insert { + /// New items will be inserted before the current program point Before, + /// New items will be inserted after the current program point After, } +/// Represents a cursor within a region where new operations will be inserted. +/// +/// The `placement` field determines how new operations will be inserted relative to `at`: +/// +/// * If `at` is a block: +/// * `Insert::Before` will always inserts new operations as the first operation in the block, +/// i.e. every insert pushes to the front of the list of operations. +/// * `Insert::After` will always insert new operations at the end of the block, i.e. every insert +/// pushes to the back of the list of operations. +/// * If `at` is an operation: +/// * `Insert::Before` will always insert new operations directly preceding the `at` operation. +/// * `Insert::After` will always insert new operations directly following the `at` operation. +/// +/// If a builder/rewriter wishes to insert new operations starting at some point in the middle of a +/// block, but then move the insertion point forward as new operations are inserted, the builder +/// must call [move_next] (or [move_prev]) after each insertion. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct InsertionPoint { pub at: ProgramPoint, - pub action: Insert, + pub placement: Insert, } impl InsertionPoint { + /// Create a new insertion point with the specified placement, at the given program point. #[inline] - pub const fn new(at: ProgramPoint, action: Insert) -> Self { - Self { at, action } + pub const fn new(at: ProgramPoint, placement: Insert) -> Self { + Self { at, placement } } + /// Create a new insertion point at the given program point, which will place new operations + /// "before" that point. + /// + /// See [Insert::Before] for what the semantics of "before" means with regards to the different + /// kinds of program point. 
#[inline] - pub const fn before(at: ProgramPoint) -> Self { + pub fn before(at: impl Into) -> Self { Self { - at, - action: Insert::Before, + at: at.into(), + placement: Insert::Before, } } + /// Create a new insertion point at the given program point, which will place new operations + /// "after" that point. + /// + /// See [Insert::After] for what the semantics of "after" means with regards to the different + /// kinds of program point. #[inline] - pub const fn after(at: ProgramPoint) -> Self { + pub fn after(at: impl Into) -> Self { Self { - at, - action: Insert::After, + at: at.into(), + placement: Insert::After, + } + } + + /// Moves the insertion point to the previous operation relative to the current point. + /// + /// If there is no operation before the current point, this has no effect. + /// + /// If the current point is a [ProgramPoint::Block], and `self.placement` is `Insert::After`, + /// then this moves the insertion point to the operation immediately preceding the last + /// operation in the block, _if_ there are at least two operations in the block. In other words, + /// `self.at` becomes a [ProgramPoint::Op]. + pub fn move_prev(&mut self) { + match &mut self.at { + ProgramPoint::Op(ref mut current) => { + let prev = current.prev(); + if let Some(prev) = prev { + *current = prev; + } + } + ProgramPoint::Block(ref block) => { + if matches!(self.placement, Insert::After) { + if let Some(prev) = + block.borrow().body().back().as_pointer().and_then(|current| current.prev()) + { + self.at = ProgramPoint::Op(prev); + } + } + } + } + } + + /// Moves the insertion point to the next operation relative to the current point. + /// + /// If there is no operation after the current point, this has no effect. + /// + /// If the current point is a [ProgramPoint::Block], and `self.placement` is `Insert::Before`, + /// then this moves the insertion point to the operation immediately following the first + /// operation in the block, _if_ there are at least two operations in the block. In other words, + /// `self.at` becomes a [ProgramPoint::Op]. + pub fn move_next(&mut self) { + match &mut self.at { + ProgramPoint::Op(ref mut current) => { + let next = current.next(); + if let Some(next) = next { + *current = next; + } + } + ProgramPoint::Block(ref block) => { + if matches!(self.placement, Insert::Before) { + if let Some(next) = block + .borrow() + .body() + .front() + .as_pointer() + .and_then(|current| current.next()) + { + self.at = ProgramPoint::Op(next); + } + } + } } } + /// Get a pointer to the [crate::Operation] on which this insertion point is positioned. + /// + /// Returns `None` if the insertion point is positioned in an empty block. + pub fn op(&self) -> Option { + match self.at { + ProgramPoint::Op(ref op) => Some(op.clone()), + ProgramPoint::Block(ref block) => match self.placement { + Insert::Before => block.borrow().front(), + Insert::After => block.borrow().back(), + }, + } + } + + /// Get a pointer to the [crate::Block] in which this insertion point is positioned. + /// + /// Panics if the current program point is an operation detached from any block. 
pub fn block(&self) -> BlockRef { - self.at.block().expect("cannot insert relative to detached operation") + self.at + .block() + .expect("invalid insertion point: operation is detached from any block") + } + + /// Returns true if this insertion point is positioned at the end of the containing block + pub fn is_at_block_end(&self) -> bool { + let block = self.block().borrow(); + if block.body().is_empty() { + matches!(self.at, ProgramPoint::Block(_)) + } else if matches!(self.placement, Insert::Before) { + false + } else { + match &self.at { + ProgramPoint::Block(_) => true, + ProgramPoint::Op(ref op) => &block.back().unwrap() == op, + } + } } } -/// A `ProgramPoint` represents a position in a function where the live range of an SSA value can +/// A `ProgramPoint` represents a position in a region where the live range of an SSA value can /// begin or end. It can be either: /// -/// 1. An instruction or -/// 2. A block header. +/// 1. An operation +/// 2. A block /// /// This corresponds more or less to the lines in the textual form of the IR. #[derive(PartialEq, Eq, Clone, Hash)] pub enum ProgramPoint { /// An operation Op(OperationRef), - /// A block header. + /// A block Block(BlockRef), } impl ProgramPoint { - /// Get the operation we know is inside. + /// Unwrap this program point as an [OperationRef], or panic if this is not a [ProgramPoint::Op] pub fn unwrap_op(self) -> OperationRef { - use crate::Entity; + use crate::EntityWithId; match self { Self::Op(x) => x, Self::Block(x) => panic!("expected operation, but got {}", x.borrow().id()), } } + /// Get the closest operation associated with this program point + /// + /// If this program point refers to a block, this will return the last operation in the block, + /// or `None` if the block is empty. + pub fn op(&self) -> Option { + match self { + Self::Op(op) => Some(op.clone()), + Self::Block(block) => block.borrow().back(), + } + } + /// Get the block associated with this program point /// /// Returns `None` if the program point is a detached operation. 
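The before/after placement rules documented above are easiest to see in a small, self-contained model. The sketch below does not use the hir2 types: `Point`, `Placement`, and the `Vec<&str>` standing in for a block are illustrative stand-ins only, but the insertion indices mirror the documented semantics, including why a builder that wants a moving cursor must advance the point itself (`move_next`/`move_prev`).

```rust
// Standalone model of the documented placement rules; the types here are
// toy stand-ins, not the hir2 API.
#[derive(Clone, Copy)]
enum Placement {
    Before,
    After,
}

#[derive(Clone, Copy)]
enum Point {
    Block,     // insert relative to the block as a whole
    Op(usize), // insert relative to the op at this index
}

fn insert(block: &mut Vec<&'static str>, at: Point, placement: Placement, op: &'static str) {
    let index = match (at, placement) {
        // `Before` a block always pushes to the front...
        (Point::Block, Placement::Before) => 0,
        // ...and `After` a block always pushes to the back.
        (Point::Block, Placement::After) => block.len(),
        // Relative to an op, new items land directly before/after it.
        (Point::Op(i), Placement::Before) => i,
        (Point::Op(i), Placement::After) => i + 1,
    };
    block.insert(index, op);
}

fn main() {
    let mut block = vec!["a", "b", "c"];
    // Without advancing the point, repeated `After` inserts at op `a` stack up
    // in reverse order right behind it, mirroring the note about `move_next`.
    insert(&mut block, Point::Op(0), Placement::After, "x");
    insert(&mut block, Point::Op(0), Placement::After, "y");
    assert_eq!(block, ["a", "y", "x", "b", "c"]);

    // `Before` a block is always "push front".
    insert(&mut block, Point::Block, Placement::Before, "front");
    assert_eq!(block[0], "front");
}
```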
@@ -86,7 +220,7 @@ impl From for ProgramPoint { } impl fmt::Display for ProgramPoint { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use crate::Entity; + use crate::EntityWithId; match self { Self::Op(x) => write!(f, "{}", x.borrow().name()), Self::Block(x) => write!(f, "{}", x.borrow().id()), diff --git a/hir2/src/ir/op.rs b/hir2/src/ir/op.rs index f28c3e2d6..5406f7a12 100644 --- a/hir2/src/ir/op.rs +++ b/hir2/src/ir/op.rs @@ -3,6 +3,7 @@ use crate::{any::AsAny, AttributeValue}; pub trait OpRegistration: Op { fn name() -> ::midenc_hir_symbol::Symbol; + fn register_with(dialect: &dyn Dialect) -> OperationName; } pub trait BuildableOp: Op { @@ -35,6 +36,9 @@ pub trait Op: AsAny + OpVerifier { fn parent_op(&self) -> Option { self.as_operation().parent_op() } + fn num_regions(&self) -> usize { + self.as_operation().num_regions() + } fn regions(&self) -> &RegionList { self.as_operation().regions() } diff --git a/hir2/src/ir/operands.rs b/hir2/src/ir/operands.rs index 15d1c46f2..e2db1b83b 100644 --- a/hir2/src/ir/operands.rs +++ b/hir2/src/ir/operands.rs @@ -34,6 +34,16 @@ impl OpOperandImpl { self.value.borrow() } + #[inline] + pub fn as_value_ref(&self) -> ValueRef { + self.value.clone() + } + + #[inline] + pub fn as_operand_ref(&self) -> OpOperand { + unsafe { OpOperand::from_raw(self) } + } + pub fn owner(&self) -> EntityRef<'_, crate::Operation> { self.owner.borrow() } @@ -65,6 +75,7 @@ impl crate::Spanned for OpOperandImpl { self.value.borrow().span() } } +impl crate::Entity for OpOperandImpl {} impl crate::StorableEntity for OpOperandImpl { #[inline(always)] fn index(&self) -> usize { @@ -76,7 +87,7 @@ impl crate::StorableEntity for OpOperandImpl { } fn unlink(&mut self) { - let ptr = unsafe { OpOperand::from_raw(self as *mut Self) }; + let ptr = self.as_operand_ref(); let mut value = self.value.borrow_mut(); let uses = value.uses_mut(); unsafe { diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index fa1083727..298e970d7 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -5,8 +5,11 @@ use alloc::rc::Rc; use core::{ fmt, ptr::{DynMetadata, NonNull, Pointee}, + sync::atomic::AtomicU32, }; +use smallvec::SmallVec; + pub use self::{builder::OperationBuilder, name::OperationName}; use super::*; use crate::{AttributeSet, AttributeValue}; @@ -77,6 +80,13 @@ pub struct Operation { /// cannot be constructed without providing it to the `uninit` function, and callers of that /// function are required to ensure that it is correct. offset: usize, + /// The order of this operation in its containing block + /// + /// This is atomic to ensure that even if a mutable reference to this operation is held, loads + /// of this field cannot be elided, as the value can still be mutated at any time. In practice, + /// the only time this is ever written, is when all operations in a block have their orders + /// recomputed, or when a single operation is updating its own order. 
+ order: AtomicU32, #[span] pub span: SourceSpan, /// Attributes that apply to this operation @@ -104,6 +114,8 @@ impl fmt::Debug for Operation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Operation") .field_with("name", |f| write!(f, "{}", &self.name())) + .field("offset", &self.offset) + .field("order", &self.order) .field("attrs", &self.attrs) .field("block", &self.block.as_ref().map(|b| b.borrow().id())) .field("operands", &self.operands) @@ -112,17 +124,59 @@ impl fmt::Debug for Operation { .finish_non_exhaustive() } } + impl AsRef for Operation { fn as_ref(&self) -> &dyn Op { self.name.upcast(self.container()).unwrap() } } + impl AsMut for Operation { fn as_mut(&mut self) -> &mut dyn Op { self.name.upcast_mut(self.container().cast_mut()).unwrap() } } +impl Entity for Operation {} +impl EntityWithParent for Operation { + type Parent = Block; + + fn on_inserted_into_parent( + mut this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + let mut op = this.borrow_mut(); + op.block = Some(parent); + op.order.store(Self::INVALID_ORDER, std::sync::atomic::Ordering::Release); + } + + fn on_removed_from_parent( + mut this: UnsafeIntrusiveEntityRef, + _parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().block = None; + } + + fn on_transfered_to_new_parent( + from: UnsafeIntrusiveEntityRef, + mut to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + // Invalidate the ordering of the new parent block + to.borrow_mut().invalidate_op_order(); + + // If we are transferring operations within the same block, the block pointer doesn't + // need to be updated + if BlockRef::ptr_eq(&from, &to) { + return; + } + + for mut transferred_op in transferred { + transferred_op.borrow_mut().block = Some(to.clone()); + } + } +} + /// Construction impl Operation { #[doc(hidden)] @@ -133,6 +187,7 @@ impl Operation { context: unsafe { NonNull::new_unchecked(Rc::as_ptr(&context).cast_mut()) }, name, offset, + order: AtomicU32::new(0), span: Default::default(), attrs: Default::default(), block: Default::default(), @@ -150,10 +205,14 @@ impl Operation { /// /// An operation name consists of both its dialect, and its opcode. pub fn name(&self) -> OperationName { - //AsRef::::as_ref(self).name() self.name.clone() } + /// Get the dialect associated with this operation + pub fn dialect(&self) -> Rc { + self.context().get_registered_dialect(self.name.dialect()) + } + /// Set the source location associated with this operation #[inline] pub fn set_span(&mut self, span: SourceSpan) { @@ -609,31 +668,73 @@ impl Operation { &mut self.operands } - /// TODO: Remove in favor of [OpBuilder] + /// Replace the current operands of this operation with the ones provided in `operands`. 
+ pub fn set_operands(&mut self, operands: impl Iterator) { + self.operands.clear(); + let context = self.context_rc(); + let owner = self.as_operation_ref(); + self.operands.extend( + operands + .into_iter() + .enumerate() + .map(|(index, value)| context.make_operand(value, owner.clone(), index as u8)), + ); + } + + /// Replace any uses of `from` with `to` within this operation pub fn replaces_uses_of_with(&mut self, mut from: ValueRef, mut to: ValueRef) { if ValueRef::ptr_eq(&from, &to) { return; } - let from_id = from.borrow().id(); - if from_id == to.borrow().id() { - return; - } - - for mut operand in self.operands.iter().cloned() { - if operand.borrow().value.borrow().id() == from_id { - debug_assert!(operand.is_linked()); - // Remove the operand from `from` + for operand in self.operands.iter_mut() { + debug_assert!(operand.is_linked()); + if ValueRef::ptr_eq(&from, &operand.borrow().value) { + // Remove use of `from` by `operand` { let mut from_mut = from.borrow_mut(); let from_uses = from_mut.uses_mut(); let mut cursor = unsafe { from_uses.cursor_mut_from_ptr(operand.clone()) }; cursor.remove(); } - // Add the operand to `to` + // Add use of `to` by `operand` operand.borrow_mut().value = to.clone(); - to.borrow_mut().insert_use(operand); + to.borrow_mut().insert_use(operand.clone()); + } + } + } + + /// Replace all uses of this operation's results with `values` + /// + /// The number of results and the number of values in `values` must be exactly the same, + /// otherwise this function will panic. + pub fn replace_all_uses_with(&mut self, values: impl ExactSizeIterator) { + assert_eq!(self.num_results(), values.len()); + for (result, replacement) in self.results.iter_mut().zip(values) { + if ValueRef::ptr_eq(&result.clone().upcast(), &replacement) { + continue; + } + result.borrow_mut().replace_all_uses_with(replacement); + } + } + + /// Replace uses of this operation's results with `values`, for each use which, when provided + /// to the given callback, returns true. + /// + /// The number of results and the number of values in `values` must be exactly the same, + /// otherwise this function will panic. 
+ pub fn replace_uses_with_if(&mut self, values: V, should_replace: F) + where + V: ExactSizeIterator, + F: Fn(&OpOperandImpl) -> bool, + { + assert_eq!(self.num_results(), values.len()); + for (result, replacement) in self.results.iter_mut().zip(values) { + let mut result = result.clone().upcast(); + if ValueRef::ptr_eq(&result, &replacement) { + continue; } + result.borrow_mut().replace_uses_with_if(replacement, &should_replace); } } } @@ -663,6 +764,109 @@ impl Operation { pub fn results_mut(&mut self) -> &mut OpResultStorage { &mut self.results } + + /// Get a reference to the result at `index` among all results of this operation + #[inline] + pub fn get_result(&self, index: usize) -> &OpResultRef { + &self.results[index] + } + + /// Returns true if the results of this operation are used + pub fn is_used(&self) -> bool { + self.results.iter().any(|result| result.borrow().is_used()) + } + + /// Returns true if the results of this operation have exactly one user + pub fn has_exactly_one_use(&self) -> bool { + let mut used_by = None; + for result in self.results.iter() { + let result = result.borrow(); + if !result.is_used() { + continue; + } + + for used in result.iter_uses() { + if used_by.as_ref().is_some_and(|user| !OperationRef::eq(user, &used.owner)) { + // We found more than one user + return false; + } else if used_by.is_none() { + used_by = Some(used.owner.clone()); + } + } + } + + // If we reach here, and we have a `used_by` set, we have exactly one user + used_by.is_some() + } + + /// Returns true if the results of this operation are used outside of the given block + pub fn is_used_outside_of_block(&self, block: &BlockRef) -> bool { + self.results + .iter() + .any(|result| result.borrow().is_used_outside_of_block(block)) + } + + /// Returns true if this operation is unused and has no side effects that prevent it being erased + pub fn is_trivially_dead(&self) -> bool { + !self.is_used() && self.would_be_trivially_dead() + } + + /// Returns true if this operation would be dead if unused, and has no side effects that would + /// prevent erasing it. This is equivalent to checking `is_trivially_dead` if `self` is unused. + /// + /// NOTE: Terminators and symbols are never considered to be trivially dead by this function. + pub fn would_be_trivially_dead(&self) -> bool { + if self.implements::() || self.implements::() { + false + } else { + self.would_be_trivially_dead_even_if_terminator() + } + } + + /// Implementation of `would_be_trivially_dead` that also considers terminator operations as + /// dead if they have no side effects. This allows for marking region operations as trivially + /// dead without always being conservative about terminators. + pub fn would_be_trivially_dead_even_if_terminator(&self) -> bool { + // The set of operations to consider when checking for side effects + let mut effecting_ops = SmallVec::<[OperationRef; 1]>::from_iter([self.as_operation_ref()]); + while let Some(op) = effecting_ops.pop() { + let op = op.borrow(); + // If the operation has recursive effects, push all of the nested operations on to the + // stack to consider. + let has_recursive_effects = + op.implements::(); + if has_recursive_effects { + for region in op.regions() { + for block in region.body() { + let mut cursor = block.body().front(); + while let Some(op) = cursor.as_pointer() { + effecting_ops.push(op); + cursor.move_next(); + } + } + } + } + + // If the op has memory effects, try to characterize them to see if the op is trivially + // dead here. 
+ if op.implements::() + || op.implements::() + { + return false; + } + + // If there were no effect interfaces, we treat this op as conservatively having effects + if !op.implements::() + && !op.implements::() + { + return false; + } + } + + // If we get here, none of the operations had effects that prevented marking this operation + // as dead. + true + } } /// Insertion @@ -807,16 +1011,42 @@ impl Operation { regions.move_next(); } } + + /// Drop all uses of results of this operation + pub fn drop_all_uses(&mut self) { + for result in self.results.iter_mut() { + result.borrow_mut().uses_mut().clear(); + } + } } /// Ordering impl Operation { + /// This value represents an invalid index ordering for an operation within its containing block + const INVALID_ORDER: u32 = u32::MAX; + /// This value represents the stride to use when computing a new order for an operation + const ORDER_STRIDE: u32 = 5; + + /// Returns true if this operation is an ancestor of `other`. + /// + /// An operation is considered its own ancestor, use [Self::is_proper_ancestor_of] if you do not + /// want this behavior. + pub fn is_ancestor_of(&self, other: &OperationRef) -> bool { + let this = self.as_operation_ref(); + OperationRef::ptr_eq(&this, other) || Self::is_a_proper_ancestor_of_b(&this, other) + } + /// Returns true if this operation is a proper ancestor of `other` - pub fn is_proper_ancestor_of(&self, other: OperationRef) -> bool { + pub fn is_proper_ancestor_of(&self, other: &OperationRef) -> bool { let this = self.as_operation_ref(); - let mut next = other.borrow().parent_op(); - while let Some(other) = next.take() { - if OperationRef::ptr_eq(&this, &other) { + Self::is_a_proper_ancestor_of_b(&this, other) + } + + /// Returns true if operation `a` is a proper ancestor of operation `b` + fn is_a_proper_ancestor_of_b(a: &OperationRef, b: &OperationRef) -> bool { + let mut next = b.borrow().parent_op(); + while let Some(b) = next.take() { + if OperationRef::ptr_eq(a, &b) { return true; } } @@ -828,94 +1058,158 @@ impl Operation { /// /// NOTE: This function has an average complexity of O(1), but worst case may take O(N) where /// N is the number of operations within the parent block. - pub fn is_before_in_block(&self, _other: OperationRef) -> bool { - /* - let block = self.block().expect("operations without parent blocks have no order"); + pub fn is_before_in_block(&self, other: &OperationRef) -> bool { + use core::sync::atomic::Ordering; + + let block = self.block.clone().expect("operations without parent blocks have no order"); let other = other.borrow(); - assert!(other.block().is_some_and(|other_block| BlockRef::ptr_eq(&block, other_block)), "expected both operations to have the same parent block"); + assert!( + other + .block + .as_ref() + .is_some_and(|other_block| BlockRef::ptr_eq(&block, other_block)), + "expected both operations to have the same parent block" + ); + // If the order of the block is already invalid, directly recompute the parent - let block = block.borrow(); - if !block.is_op_order_valid() { - block.recompute_op_order(); + if !block.borrow().is_op_order_valid() { + Self::recompute_block_order(block); } else { // Update the order of either operation if necessary. self.update_order_if_necessary(); other.update_order_if_necessary(); } - self.order < other.order - */ - todo!() + self.order.load(Ordering::Relaxed) < other.order.load(Ordering::Relaxed) } /// Update the order index of this operation of this operation if necessary, /// potentially recomputing the order of the parent block. 
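The ordering machinery above (`INVALID_ORDER`, `ORDER_STRIDE`, midpoint assignment, block-wide renumbering) can be modeled in isolation. The sketch below is a simplified, eager version of the same idea using plain `u32` keys in a `Vec`; it is not the hir2 implementation, but it shows why an `is_before_in_block`-style query is an O(1) key comparison on average and only degrades to a renumbering pass when the gap between neighbouring keys is exhausted.

```rust
// Standalone model of stride-based ordering keys: insertions take the midpoint
// between neighbours, and renumbering only happens when a gap is exhausted.
const ORDER_STRIDE: u32 = 5;

fn renumber(orders: &mut [u32]) {
    let mut next = 0;
    for order in orders.iter_mut() {
        next += ORDER_STRIDE;
        *order = next;
    }
}

/// Insert a new key between positions `index - 1` and `index`, renumbering
/// the whole sequence only if there is no room left between the neighbours.
fn insert_at(orders: &mut Vec<u32>, index: usize) {
    let prev = if index == 0 { 0 } else { orders[index - 1] };
    let next = orders.get(index).copied().unwrap_or(prev + 2 * ORDER_STRIDE);
    if prev + 1 >= next {
        // No gap left: recompute all keys, then retry the insertion.
        renumber(orders);
        insert_at(orders, index);
        return;
    }
    orders.insert(index, prev + (next - prev) / 2);
}

fn main() {
    let mut orders = Vec::new();
    for i in 0..4 {
        insert_at(&mut orders, i); // keys become 5, 10, 15, 20
    }
    // An "is before in block" style query is now a plain key comparison:
    assert!(orders[1] < orders[2]);
    // Repeated insertion at the same spot eventually forces a renumber,
    // but comparisons stay O(1) on average.
    for _ in 0..8 {
        insert_at(&mut orders, 1);
    }
    assert!(orders.windows(2).all(|w| w[0] < w[1]));
}
```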
fn update_order_if_necessary(&self) { - /* - assert!(self.block.is_some(), "expected valid parent"); + use core::sync::atomic::Ordering; - let this = self.as_operation_ref(); + assert!(self.block.is_some(), "expected valid parent"); // If the order is valid for this operation there is nothing to do. - let block = self.block.as_ref().unwrap().borrow(); - if self.has_valid_order() || block.body().iter().count() == 1 { + let block = self.block.clone().unwrap(); + if self.has_valid_order() || block.borrow().body().iter().count() == 1 { return; } - let back = block.body().back().as_pointer(); - let front = block.body().front().as_pointer(); - assert!(!OperationRef::ptr_eq(&front, &back)); + let this = self.as_operation_ref(); + let prev = this.prev(); + let next = this.next(); + assert!(prev.is_some() || next.is_some(), "expected more than one operation in block"); // If the operation is at the end of the block. - if Operation::ptr_eq(&this, &back) { - let prev = self.get_prev(); - if !prev.borrow().has_valid_order() { - return block.recompute_op_order(); + if next.is_none() { + let prev = prev.unwrap(); + let prev = prev.borrow(); + let prev_order = prev.order.load(Ordering::Acquire); + if prev_order == Self::INVALID_ORDER { + return Self::recompute_block_order(block); } // Add the stride to the previous operation. - self.order = prev.order + Self::ORDER_STRIDE; + self.order.store(prev_order + Self::ORDER_STRIDE, Ordering::Release); return; } // If this is the first operation try to use the next operation to compute the // ordering. - if Operation::ptr_eq(&this, &front) { - let next = self.get_next(); - if !next.has_valid_order() { - return block.recompute_op_order(); - } - // There is no order to give this operation. - if next.order == 0 { - return block.recompute_op_order(); - } - - // If we can't use the stride, just take the middle value left. This is safe - // because we know there is at least one valid index to assign to. - if next.order <= Self::ORDER_STRIDE { - self.order = next.order / 2; - } else { - self.order = Self::ORDER_STRIDE; + if prev.is_none() { + let next = next.unwrap(); + let next = next.borrow(); + let next_order = next.order.load(Ordering::Acquire); + match next_order { + Self::INVALID_ORDER | 0 => { + return Self::recompute_block_order(block); + } + // If we can't use the stride, just take the middle value left. This is safe + // because we know there is at least one valid index to assign to. + order if order <= Self::ORDER_STRIDE => { + self.order.store(order / 2, Ordering::Release); + } + _ => { + self.order.store(Self::ORDER_STRIDE, Ordering::Release); + } } return; } // Otherwise, this operation is between two others. Place this operation in // the middle of the previous and next if possible. - let prev = self.get_prev(); - let next = self.get_next(); - if !prev.has_valid_order() || !next.has_valid_order() { - return block.recompute_op_order(); + let prev = prev.unwrap().borrow().order.load(Ordering::Acquire); + let next = next.unwrap().borrow().order.load(Ordering::Acquire); + if prev == Self::INVALID_ORDER || next == Self::INVALID_ORDER { + return Self::recompute_block_order(block); } - let prev_order = prev.order; - let next_order = next.order; // Check to see if there is a valid order between the two. 
- if prev_order + 1 == next_order { - return block.recompute_op_order(); + if prev + 1 == next { + return Self::recompute_block_order(block); + } + self.order.store(prev + ((next - prev) / 2), Ordering::Release); + } + + fn recompute_block_order(mut block: BlockRef) { + use core::sync::atomic::Ordering; + + let mut block = block.borrow_mut(); + let mut cursor = block.body().front(); + let mut index = 0; + while let Some(op) = cursor.as_pointer() { + index += Self::ORDER_STRIDE; + cursor.move_next(); + let ptr = OperationRef::as_ptr(&op); + unsafe { + let order_addr = core::ptr::addr_of!((*ptr).order); + (*order_addr).store(index, Ordering::Release); + } + } + + block.mark_op_order_valid(); + } + + /// Returns `None` if this operation has invalid ordering + #[inline] + pub(super) fn order(&self) -> Option { + use core::sync::atomic::Ordering; + match self.order.load(Ordering::Acquire) { + Self::INVALID_ORDER => None, + order => Some(order), + } + } + + /// Returns true if this operation has a valid order + #[inline(always)] + pub(super) fn has_valid_order(&self) -> bool { + self.order().is_some() + } +} + +impl crate::traits::Foldable for Operation { + fn fold(&self, results: &mut smallvec::SmallVec<[OpFoldResult; 1]>) -> FoldResult { + use crate::traits::Foldable; + + if let Some(foldable) = self.as_trait::() { + foldable.fold(results) + } else { + FoldResult::Failed + } + } + + fn fold_with<'operands>( + &self, + operands: &[Option>], + results: &mut smallvec::SmallVec<[OpFoldResult; 1]>, + ) -> FoldResult { + use crate::traits::Foldable; + + if let Some(foldable) = self.as_trait::() { + foldable.fold_with(operands, results) + } else { + FoldResult::Failed } - self.order = prev_order + ((next_order - prev_order) / 2); - */ - todo!() } } diff --git a/hir2/src/ir/operation/builder.rs b/hir2/src/ir/operation/builder.rs index 2818ef68a..8856b06a1 100644 --- a/hir2/src/ir/operation/builder.rs +++ b/hir2/src/ir/operation/builder.rs @@ -238,8 +238,10 @@ where op.verify(self.builder.context())?; } - // Insert op at current insertion point - self.builder.insert(self.op); + // Insert op at current insertion point, if set + if self.builder.insertion_point().is_some() { + self.builder.insert(self.op); + } Ok(op) } diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs index 89210fc80..ab2802b88 100644 --- a/hir2/src/ir/operation/name.rs +++ b/hir2/src/ir/operation/name.rs @@ -44,8 +44,8 @@ impl OperationName { } /// Returns the dialect name of this operation - pub fn dialect(&self) -> DialectName { - self.0.dialect + pub fn dialect(&self) -> &DialectName { + &self.0.dialect } /// Returns the namespace to which this operation name belongs (i.e. dialect name) diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index 21df63877..9d2edf6da 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -8,6 +8,21 @@ pub type RegionCursor<'a> = EntityCursor<'a, Region>; /// A mutable cursor in a [RegionList] pub type RegionCursorMut<'a> = EntityCursorMut<'a, Region>; +/// A region is a container for [Block], in one of two forms: +/// +/// * Graph-like, in which the region consists of a single block, and the order of operations in +/// that block does not dictate any specific control flow semantics. It is up to the containing +/// operation to define. +/// * SSA-form, in which the region consists of one or more blocks that must obey the usual rules +/// of SSA dominance, and where operations in a block reflect the order in which those operations +/// are to be executed. 
Values defined by an operation must dominate any uses of those values in +/// the region. +/// +/// The first block in a region is the _entry_ block, and its argument list corresponds to the +/// arguments expected by the region itself. +/// +/// A region is only valid when it is attached to an [Operation], whereas the inverse is not true, +/// i.e. an operation without a parent region is a top-level operation, e.g. `Module`. #[derive(Default)] pub struct Region { /// The operation this region is attached to. @@ -17,17 +32,142 @@ pub struct Region { /// The list of [Block]s that comprise this region body: BlockList, } + +impl Entity for Region {} +impl EntityWithParent for Region { + type Parent = Operation; + + fn on_inserted_into_parent( + mut this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().owner = Some(parent); + } + + fn on_removed_from_parent( + mut this: UnsafeIntrusiveEntityRef, + _parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().owner = None; + } + + fn on_transfered_to_new_parent( + _from: UnsafeIntrusiveEntityRef, + to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + for mut transferred_region in transferred { + transferred_region.borrow_mut().owner = Some(to.clone()); + } + } +} + +/// Blocks impl Region { /// Returns true if this region is empty (has no blocks) pub fn is_empty(&self) -> bool { self.body.is_empty() } + /// Get a handle to the entry block for this region + pub fn entry(&self) -> EntityRef<'_, Block> { + self.body.front().into_borrow().unwrap() + } + + /// Get a mutable handle to the entry block for this region + pub fn entry_mut(&mut self) -> EntityMut<'_, Block> { + self.body.front_mut().into_borrow_mut().unwrap() + } + + /// Get the [BlockRef] of the entry block of this region, if it has one + #[inline] + pub fn entry_block_ref(&self) -> Option { + self.body.front().as_pointer() + } + + /// Get the list of blocks comprising the body of this region + pub fn body(&self) -> &BlockList { + &self.body + } + + /// Get a mutable reference to the list of blocks comprising the body of this region + pub fn body_mut(&mut self) -> &mut BlockList { + &mut self.body + } +} + +/// Metadata +impl Region { + #[inline] + pub fn as_region_ref(&self) -> RegionRef { + unsafe { RegionRef::from_raw(self) } + } + + /// Returns true if this region is an ancestor of `other`, i.e. it contains it. + /// + /// NOTE: This returns true if `self == other`, see [Self::is_proper_ancestor] if you do not + /// want this behavior. + pub fn is_ancestor(&self, other: &RegionRef) -> bool { + let this = self.as_region_ref(); + &this == other || Self::is_proper_ancestor_of(&this, other) + } + + /// Returns true if this region is a proper ancestor of `other`, i.e. `other` is contained by it + /// + /// NOTE: This returns false if `self == other`, see [Self::is_ancestor] if you do not want this + /// behavior. 
+ pub fn is_proper_ancestor(&self, other: &RegionRef) -> bool { + let this = self.as_region_ref(); + Self::is_proper_ancestor_of(&this, other) + } + + fn is_proper_ancestor_of(this: &RegionRef, other: &RegionRef) -> bool { + if this == other { + return false; + } + + let mut parent = other.borrow().parent_region(); + while let Some(parent_region) = parent.take() { + if this == &parent_region { + return true; + } + parent = parent_region.borrow().parent_region(); + } + + false + } + + /// Returns true if this region may be a graph region without SSA dominance + pub fn may_be_graph_region(&self) -> bool { + if let Some(owner) = self.owner.as_ref() { + owner + .borrow() + .as_trait::() + .is_some_and(|rki| rki.has_graph_regions()) + } else { + true + } + } + + /// Returns true if this region has only one block + pub fn has_one_block(&self) -> bool { + !self.body.is_empty() + && BlockRef::ptr_eq( + &self.body.front().as_pointer().unwrap(), + &self.body.back().as_pointer().unwrap(), + ) + } + /// Get the defining [Operation] for this region, if the region is attached to one. pub fn parent(&self) -> Option { self.owner.clone() } + /// Get the region which contains the parent operation of this region, if there is one. + pub fn parent_region(&self) -> Option { + self.owner.as_ref().and_then(|op| op.borrow().parent_region()) + } + /// Set the owner of this region. /// /// Returns the previous owner. @@ -47,27 +187,24 @@ impl Region { Some(owner) => self.owner.replace(owner), } } +} - /// Get a handle to the entry block for this region - pub fn entry(&self) -> EntityRef<'_, Block> { - self.body.front().into_borrow().unwrap() - } - - /// Get a mutable handle to the entry block for this region - pub fn entry_mut(&mut self) -> EntityMut<'_, Block> { - self.body.front_mut().into_borrow_mut().unwrap() - } - - /// Get the list of blocks comprising the body of this region - pub fn body(&self) -> &BlockList { - &self.body +/// Mutation +impl Region { + /// Push `block` to the start of this region + #[inline] + pub fn push_front(&mut self, block: BlockRef) { + self.body.push_front(block); } - /// Get a mutable reference to the list of blocks comprising the body of this region - pub fn body_mut(&mut self) -> &mut BlockList { - &mut self.body + /// Push `block` to the end of this region + #[inline] + pub fn push_back(&mut self, block: BlockRef) { + self.body.push_back(block); } + /// Drop any references to blocks in this region - this is used to break cycles when cleaning + /// up regions. pub fn drop_all_references(&mut self) { todo!() } diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs index b87af90ab..d5359b499 100644 --- a/hir2/src/ir/traits.rs +++ b/hir2/src/ir/traits.rs @@ -1,3 +1,4 @@ +mod foldable; mod info; mod types; @@ -16,6 +17,17 @@ pub trait ConstantLike {} /// Marker trait for ops with side effects pub trait HasSideEffects {} +/// Marker trait for ops with recursive memory effects, i.e. the effects of the operation includes +/// the effects of operations nested within its regions. If the operation does not implement any +/// effect markers, e.g. `MemoryWrite`, then it can be assumed to have no memory effects itself. +pub trait HasRecursiveMemoryEffects {} + +/// Marker trait for ops which allocate memory +pub trait MemoryAlloc {} + +/// Marker trait for ops which free memory +pub trait MemoryFree {} + /// Marker trait for ops which read memory pub trait MemoryRead {} @@ -182,6 +194,8 @@ derive! { } } +// pub trait SingleBlockImplicitTerminator {} + derive! 
{ /// Op has a single region pub trait SingleRegion {} @@ -206,3 +220,6 @@ derive! { } } } + +// pub trait HasParent {} +// pub trait ParentOneOf<(T,...)> {} diff --git a/hir2/src/ir/traits/foldable.rs b/hir2/src/ir/traits/foldable.rs new file mode 100644 index 000000000..cef1a08b1 --- /dev/null +++ b/hir2/src/ir/traits/foldable.rs @@ -0,0 +1,126 @@ +use smallvec::SmallVec; + +use crate::{AttributeValue, ValueRef}; + +/// Represents the outcome of an attempt to fold an operation. +#[must_use] +pub enum FoldResult { + /// The operation was folded and erased, and the given fold results were returned + Ok(T), + /// The operation was modified in-place, but not erased. + InPlace, + /// The operation could not be folded + Failed, +} +impl FoldResult { + /// Returns true if folding was successful + #[inline] + pub fn is_ok(&self) -> bool { + matches!(self, Self::Ok(_) | Self::InPlace) + } + + /// Returns true if folding was unsuccessful + #[inline] + pub fn is_failed(&self) -> bool { + matches!(self, Self::Failed) + } + + /// Convert this result to an `Option` representing a successful outcome, where `None` indicates + /// an in-place fold, and `Some(T)` indicates that the operation was folded away. + /// + /// Panics with the given message if the fold attempt failed. + #[inline] + #[track_caller] + pub fn expect(self, message: &'static str) -> Option { + match self { + Self::Ok(out) => Some(out), + Self::InPlace => None, + Self::Failed => unwrap_failed_fold_result(message), + } + } +} + +#[cold] +#[track_caller] +#[inline(never)] +fn unwrap_failed_fold_result(message: &'static str) -> ! { + panic!("tried to unwrap failed fold result as successful: {message}") +} + +/// Represents a single result value of a folded operation. +#[derive(Debug)] +pub enum OpFoldResult { + /// The value is constant + Attribute(Box), + /// The value is a non-constant SSA value + Value(ValueRef), +} +impl OpFoldResult { + #[inline] + pub fn is_constant(&self) -> bool { + matches!(self, Self::Attribute(_)) + } +} +impl Eq for OpFoldResult {} +impl PartialEq for OpFoldResult { + fn eq(&self, other: &Self) -> bool { + use core::hash::{Hash, Hasher}; + + match (self, other) { + (Self::Attribute(lhs), Self::Attribute(rhs)) => { + if lhs.as_any().type_id() != rhs.as_any().type_id() { + return false; + } + let lhs_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + lhs.hash(&mut hasher); + hasher.finish() + }; + let rhs_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + rhs.hash(&mut hasher); + hasher.finish() + }; + lhs_hash == rhs_hash + } + (Self::Value(lhs), Self::Value(rhs)) => ValueRef::ptr_eq(lhs, rhs), + _ => false, + } + } +} +impl core::fmt::Display for OpFoldResult { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Attribute(ref attr) => attr.pretty_print(f), + Self::Value(ref value) => write!(f, "{}", value.borrow().id()), + } + } +} + +/// An operation that can be constant-folded must implement the folding logic via this trait. +/// +/// NOTE: Any `ConstantLike` operation must implement this trait as a no-op, i.e. returning the +/// value of the constant directly, as this is used by the pattern matching infrastructure to +/// extract the value of constant operations without knowing anything about the specific op. +pub trait Foldable { + /// Attempt to fold this operation using its current operand values. + /// + /// If folding was successful and the operation should be erased, `results` will contain the + /// folded results. 
See [FoldResult] for more details on what the various outcomes of folding + /// are. + fn fold(&self, results: &mut SmallVec<[OpFoldResult; 1]>) -> FoldResult; + + /// Attempt to fold this operation with the specified operand values. + /// + /// The elements in `operands` will correspond 1:1 with the operands of the operation, but will + /// be `None` if the value is non-constant. + /// + /// If folding was successful and the operation should be erased, `results` will contain the + /// folded results. See [FoldResult] for more details on what the various outcomes of folding + /// are. + fn fold_with( + &self, + operands: &[Option>], + results: &mut SmallVec<[OpFoldResult; 1]>, + ) -> FoldResult; +} diff --git a/hir2/src/ir/value.rs b/hir2/src/ir/value.rs index c1376ca19..2596f9f42 100644 --- a/hir2/src/ir/value.rs +++ b/hir2/src/ir/value.rs @@ -38,7 +38,9 @@ impl fmt::Display for ValueId { /// of a [Value] are operands (see [OpOperandImpl]). Operands are associated with an operation. Thus /// the graph formed of the edges between values and operations via operands forms the data-flow /// graph of the program. -pub trait Value: Entity + Spanned + Usable + fmt::Debug { +pub trait Value: + EntityWithId + Spanned + Usable + fmt::Debug +{ fn as_any(&self) -> &dyn Any; fn as_any_mut(&mut self) -> &mut dyn Any; /// Set the source location of this value @@ -51,6 +53,52 @@ pub trait Value: Entity + Spanned + Usable + /// /// Returns `None` if this value is defined by other means than an operation result. fn get_defining_op(&self) -> Option; + /// Get the region which contains the definition of this value + fn parent_region(&self) -> Option { + self.parent_block().and_then(|block| block.borrow().parent()) + } + /// Get the block which contains the definition of this value + fn parent_block(&self) -> Option; + /// Returns true if this value is used outside of the given block + fn is_used_outside_of_block(&self, block: &BlockRef) -> bool { + self.iter_uses().any(|user| { + user.owner.borrow().parent().is_some_and(|blk| !BlockRef::ptr_eq(&blk, block)) + }) + } + /// Replace all uses of `self` with `replacement` + fn replace_all_uses_with(&mut self, mut replacement: ValueRef) { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` + { + let mut user = user.borrow_mut(); + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } + /// Replace all uses of `self` with `replacement` unless the user is in `exceptions` + fn replace_all_uses_except(&mut self, mut replacement: ValueRef, exceptions: &[OperationRef]) { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` if user not in `exceptions` + { + let mut user = user.borrow_mut(); + if exceptions.contains(&user.owner) { + cursor.move_next(); + continue; + } + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } } impl dyn Value { @@ -68,6 +116,29 @@ impl dyn Value { pub fn downcast_mut(&mut self) -> Option<&mut T> { self.as_any_mut().downcast_mut::() } + + /// Replace all uses of `self` with `replacement` if `should_replace` returns true + pub fn replace_uses_with_if(&mut self, 
mut replacement: ValueRef, should_replace: F) + where + F: Fn(&OpOperandImpl) -> bool, + { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` if `should_replace` returns true + { + let mut user = user.borrow_mut(); + if !should_replace(&user) { + cursor.move_next(); + continue; + } + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } } /// Generates the boilerplate for a concrete [Value] type. @@ -75,14 +146,20 @@ macro_rules! value_impl { ( $(#[$outer:meta])* $vis:vis struct $ValueKind:ident { + $(#[doc $($owner_doc_args:tt)*])* + owner: $OwnerTy:ty, + $(#[doc $($index_doc_args:tt)*])* + index: u8, $( - $(*[$inner:ident $($args:tt)*])* + $(#[$inner:ident $($args:tt)*])* $Field:ident: $FieldTy:ty, )* } fn get_defining_op(&$GetDefiningOpSelf:ident) -> Option $GetDefiningOp:block + fn parent_block(&$ParentBlockSelf:ident) -> Option $ParentBlock:block + $($t:tt)* ) => { $(#[$outer])* @@ -93,6 +170,8 @@ macro_rules! value_impl { span: SourceSpan, ty: Type, uses: OpOperandList, + owner: $OwnerTy, + index: u8, $( $(#[$inner $($args)*])* $Field: $FieldTy @@ -104,6 +183,8 @@ macro_rules! value_impl { span: SourceSpan, id: ValueId, ty: Type, + owner: $OwnerTy, + index: u8, $( $Field: $FieldTy ),* @@ -113,11 +194,23 @@ macro_rules! value_impl { ty, span, uses: Default::default(), + owner, + index, $( $Field ),* } } + + $(#[doc $($owner_doc_args)*])* + pub fn owner(&self) -> $OwnerTy { + self.owner.clone() + } + + $(#[doc $($index_doc_args)*])* + pub fn index(&self) -> usize { + self.index as usize + } } impl Value for $ValueKind { @@ -142,9 +235,12 @@ macro_rules! value_impl { } fn get_defining_op(&$GetDefiningOpSelf) -> Option $GetDefiningOp + + fn parent_block(&$ParentBlockSelf) -> Option $ParentBlock } - impl Entity for $ValueKind { + impl Entity for $ValueKind {} + impl EntityWithId for $ValueKind { type Id = ValueId; #[inline(always)] @@ -167,13 +263,22 @@ macro_rules! 
value_impl { } } + impl fmt::Display for $ValueKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use crate::formatter::PrettyPrint; + + self.pretty_print(f) + } + } + impl fmt::Debug for $ValueKind { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct(stringify!($ValueKind)); builder .field("id", &self.id) .field("ty", &self.ty) - .field("uses", &self.uses); + .field("index", &self.index) + .field("is_used", &(!self.uses.is_empty())); $( builder.field(stringify!($Field), &self.$Field); @@ -197,36 +302,30 @@ pub type OpResultRef = UnsafeEntityRef; value_impl!( /// A [BlockArgument] represents the definition of a [Value] by a block parameter pub struct BlockArgument { + /// Get the [Block] to which this [BlockArgument] belongs owner: BlockRef, + /// Get the index of this argument in the argument list of the owning [Block] index: u8, } fn get_defining_op(&self) -> Option { None } -); -value_impl!( - /// An [OpResult] represents the definition of a [Value] by the result of an [Operation] - pub struct OpResult { - owner: OperationRef, - index: u8, - } - - fn get_defining_op(&self) -> Option { + fn parent_block(&self) -> Option { Some(self.owner.clone()) } ); impl BlockArgument { - /// Get the [Block] to which this [BlockArgument] belongs - pub fn owner(&self) -> BlockRef { - self.owner.clone() + #[inline] + pub fn as_value_ref(&self) -> ValueRef { + self.as_block_argument_ref().upcast() } - /// Get the index of this argument in the argument list of the owning [Block] - pub fn index(&self) -> usize { - self.index as usize + #[inline] + pub fn as_block_argument_ref(&self) -> BlockArgumentRef { + unsafe { BlockArgumentRef::from_raw(self) } } } @@ -234,7 +333,7 @@ impl crate::formatter::PrettyPrint for BlockArgument { fn render(&self) -> crate::formatter::Document { use crate::formatter::*; - text(format!("{}", self.id)) + const_text(": ") + self.ty.render() + display(self.id) + const_text(": ") + self.ty.render() } } @@ -249,17 +348,25 @@ impl StorableEntity for BlockArgument { } } -impl OpResult { - /// Get the [Operation] to which this [OpResult] belongs - pub fn owner(&self) -> OperationRef { - self.owner.clone() +value_impl!( + /// An [OpResult] represents the definition of a [Value] by the result of an [Operation] + pub struct OpResult { + /// Get the [Operation] to which this [OpResult] belongs + owner: OperationRef, + /// Get the index of this result in the result list of the owning [Operation] + index: u8, } - /// Get the index of this result in the result list of the owning [Operation] - pub fn index(&self) -> usize { - self.index as usize + fn get_defining_op(&self) -> Option { + Some(self.owner.clone()) + } + + fn parent_block(&self) -> Option { + self.owner.borrow().parent() } +); +impl OpResult { #[inline] pub fn as_value_ref(&self) -> ValueRef { unsafe { ValueRef::from_raw(self as &dyn Value) } From de12ced3efb135268a85b7876254dcddf722460c Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:15:07 -0400 Subject: [PATCH 19/31] wip: implement region/non-region branch op interfaces, region simplification, rework visitors, and implement low-level pattern matchers --- hir2/src/derive.rs | 4 +- hir2/src/ir.rs | 23 +- hir2/src/ir/builder.rs | 183 ++++- hir2/src/ir/operation.rs | 8 + hir2/src/ir/print.rs | 140 ++-- hir2/src/ir/region.rs | 214 ++++- hir2/src/ir/region/branch_point.rs | 50 ++ hir2/src/ir/region/interfaces.rs | 247 ++++++ hir2/src/ir/region/invocation_bounds.rs | 132 +++ hir2/src/ir/region/kind.rs | 22 + 
hir2/src/ir/region/successor.rs | 96 +++ hir2/src/ir/region/transforms.rs | 6 + .../src/ir/region/transforms/block_merging.rs | 263 ++++++ hir2/src/ir/region/transforms/dce.rs | 354 ++++++++ .../region/transforms/drop_redundant_args.rs | 151 ++++ hir2/src/ir/successor.rs | 220 +++++ hir2/src/ir/traits.rs | 136 +++- hir2/src/ir/visit.rs | 414 +--------- hir2/src/ir/visit/blocks.rs | 102 +++ hir2/src/ir/visit/searcher.rs | 61 ++ hir2/src/ir/visit/visitor.rs | 36 + hir2/src/ir/visit/walkable.rs | 404 +++++++++ hir2/src/lib.rs | 3 + hir2/src/matchers.rs | 3 + hir2/src/matchers/matcher.rs | 770 ++++++++++++++++++ 25 files changed, 3513 insertions(+), 529 deletions(-) create mode 100644 hir2/src/ir/region/branch_point.rs create mode 100644 hir2/src/ir/region/interfaces.rs create mode 100644 hir2/src/ir/region/invocation_bounds.rs create mode 100644 hir2/src/ir/region/kind.rs create mode 100644 hir2/src/ir/region/successor.rs create mode 100644 hir2/src/ir/region/transforms.rs create mode 100644 hir2/src/ir/region/transforms/block_merging.rs create mode 100644 hir2/src/ir/region/transforms/dce.rs create mode 100644 hir2/src/ir/region/transforms/drop_redundant_args.rs create mode 100644 hir2/src/ir/visit/blocks.rs create mode 100644 hir2/src/ir/visit/searcher.rs create mode 100644 hir2/src/ir/visit/visitor.rs create mode 100644 hir2/src/ir/visit/walkable.rs create mode 100644 hir2/src/matchers.rs create mode 100644 hir2/src/matchers/matcher.rs diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs index 2ec48bca5..8aeba6d68 100644 --- a/hir2/src/derive.rs +++ b/hir2/src/derive.rs @@ -169,8 +169,8 @@ mod tests { use super::operation; use crate::{ - define_attr_type, dialects::hir::HirDialect, formatter, traits::*, Builder, Context, Op, - Operation, Report, Spanned, Value, + define_attr_type, dialects::hir::HirDialect, formatter, traits::*, Builder, BuilderExt, + Context, Op, Operation, Report, Spanned, Value, }; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index f595b91f4..d2af639cb 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -31,14 +31,14 @@ pub use self::{ Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, BlockRef, }, - builder::{Builder, Listener, ListenerType, OpBuilder}, + builder::{Builder, BuilderExt, InsertionGuard, Listener, ListenerType, OpBuilder}, callable::*, context::Context, dialect::{Dialect, DialectName, DialectRegistration}, entity::{ Entity, EntityCursor, EntityCursorMut, EntityGroup, EntityId, EntityIter, EntityList, - EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, RawEntityRef, - StorableEntity, UnsafeEntityRef, UnsafeIntrusiveEntityRef, + EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, EntityWithId, + EntityWithParent, RawEntityRef, StorableEntity, UnsafeEntityRef, UnsafeIntrusiveEntityRef, }, ident::{FunctionIdent, Ident}, immediates::{Felt, FieldElement, Immediate, StarkField}, @@ -51,11 +51,17 @@ pub use self::{ operation::{ OpCursor, OpCursorMut, OpList, Operation, OperationBuilder, OperationName, OperationRef, }, - print::OpPrinter, - region::{Region, RegionCursor, RegionCursorMut, RegionList, RegionRef}, + print::{OpPrinter, OpPrintingFlags}, + region::{ + InvocationBounds, Region, RegionBranchOpInterface, RegionBranchPoint, + RegionBranchTerminatorOpInterface, RegionCursor, RegionCursorMut, RegionKind, + RegionKindInterface, RegionList, RegionRef, RegionSuccessor, RegionSuccessorInfo, + RegionSuccessorIter, RegionSuccessorMut, 
RegionTransformFailed, + }, successor::{ KeyedSuccessor, KeyedSuccessorRange, KeyedSuccessorRangeMut, OpSuccessor, OpSuccessorMut, - OpSuccessorRange, OpSuccessorRangeMut, OpSuccessorStorage, SuccessorInfo, SuccessorWithKey, + OpSuccessorRange, OpSuccessorRangeMut, OpSuccessorStorage, SuccessorInfo, SuccessorOperand, + SuccessorOperandRange, SuccessorOperandRangeMut, SuccessorOperands, SuccessorWithKey, SuccessorWithKeyMut, }, symbol_table::{ @@ -64,6 +70,7 @@ pub use self::{ SymbolUseCursor, SymbolUseCursorMut, SymbolUseIter, SymbolUseList, SymbolUseRef, SymbolUsesIter, }, + traits::{FoldResult, OpFoldResult}, types::*, usable::Usable, value::{ @@ -72,7 +79,7 @@ pub use self::{ }, verifier::{OpVerifier, Verify}, visit::{ - OpVisitor, OperationVisitor, Searcher, SymbolVisitor, Visitor, WalkOrder, WalkResult, - WalkStage, Walkable, + BlockIter, OpVisitor, OperationVisitor, PostOrderBlockIter, Searcher, SymbolVisitor, + Visitor, WalkOrder, WalkResult, WalkStage, Walkable, }, }; diff --git a/hir2/src/ir/builder.rs b/hir2/src/ir/builder.rs index 1d5fb4e9a..e8d1404c6 100644 --- a/hir2/src/ir/builder.rs +++ b/hir2/src/ir/builder.rs @@ -196,9 +196,11 @@ pub trait BuilderExt: Builder { } } -pub struct OpBuilder { +impl BuilderExt for B {} + +pub struct OpBuilder { context: Rc, - listener: Option>, + listener: Option, ip: Option, } @@ -210,38 +212,51 @@ impl OpBuilder { ip: None, } } +} +impl OpBuilder { /// Sets the listener of this builder to `listener` - pub fn with_listener(&mut self, listener: impl Listener) -> &mut Self { - self.listener = Some(Box::new(listener)); - self + pub fn with_listener(self, listener: L2) -> OpBuilder + where + L2: Listener, + { + OpBuilder { + context: self.context, + listener: Some(listener), + ip: self.ip, + } + } + + #[inline] + pub fn into_parts(self) -> (Rc, Option, Option) { + (self.context, self.listener, self.ip) } } -impl Listener for OpBuilder { +impl Listener for OpBuilder { fn kind(&self) -> ListenerType { self.listener.as_ref().map(|l| l.kind()).unwrap_or(ListenerType::Builder) } + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_inserted(op, prev); + } + } + fn notify_block_inserted( - &mut self, + &self, block: BlockRef, prev: Option, - ip: Option, + ip: Option, ) { - if let Some(listener) = self.listener.as_deref_mut() { + if let Some(listener) = self.listener.as_ref() { listener.notify_block_inserted(block, prev, ip); } } - - fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { - if let Some(listener) = self.listener.as_deref_mut() { - listener.notify_operation_inserted(op, prev); - } - } } -impl Builder for OpBuilder { +impl Builder for OpBuilder { #[inline(always)] fn context(&self) -> &Context { self.context.as_ref() @@ -279,40 +294,148 @@ pub enum ListenerType { Rewriter, } +#[allow(unused_variables)] pub trait Listener: 'static { fn kind(&self) -> ListenerType; /// Notify the listener that the specified operation was inserted. /// /// * If the operation was moved, then `prev` is the previous location of the op /// * If the operation was unlinked before it was inserted, then `prev` is `None` - fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option); + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) {} /// Notify the listener that the specified block was inserted. /// /// * If the block was moved, then `prev` and `ip` represent the previous location of the block. 
/// * If the block was unlinked before it was inserted, then `prev` and `ip` are `None` fn notify_block_inserted( - &mut self, + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + } +} + +impl Listener for Option { + fn kind(&self) -> ListenerType { + ListenerType::Builder + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + if let Some(listener) = self.as_ref() { + listener.notify_block_inserted(block, prev, ip); + } + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_inserted(op, prev); + } + } +} + +impl Listener for Box { + #[inline] + fn kind(&self) -> ListenerType { + (**self).kind() + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + (**self).notify_operation_inserted(op, prev) + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + (**self).notify_block_inserted(block, prev, ip) + } +} + +impl Listener for Rc { + #[inline] + fn kind(&self) -> ListenerType { + (**self).kind() + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + (**self).notify_operation_inserted(op, prev) + } + + fn notify_block_inserted( + &self, block: BlockRef, prev: Option, - ip: Option, - ); + ip: Option, + ) { + (**self).notify_block_inserted(block, prev, ip) + } } -pub struct InsertionGuard<'a> { - builder: &'a mut OpBuilder, +/// A listener of kind `Builder` that does nothing +pub struct NoopBuilderListener; +impl Listener for NoopBuilderListener { + #[inline] + fn kind(&self) -> ListenerType { + ListenerType::Builder + } +} + +/// This is used to allow [InsertionGuard] to be agnostic about the type of builder/rewriter it +/// wraps, while still performing the necessary insertion point restoration on drop. Without this, +/// we would be required to specify a `B: Builder` bound on the definition of [InsertionGuard]. 
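The restore-on-drop behaviour provided by `InsertionGuard` is a standard RAII pattern: capture the current insertion point when the guard is created, expose the wrapped builder through `Deref`/`DerefMut`, and put the saved point back in `Drop`. The sketch below is a minimal standalone model under a toy `Builder` trait; it deliberately omits the specialization machinery used here to stay agnostic over rewriters, and none of its names are the hir2 API.

```rust
// Minimal standalone sketch of the guard pattern: capture the insertion point
// on construction and restore it when the guard is dropped.
trait Builder {
    fn insertion_point(&self) -> Option<usize>;
    fn set_insertion_point(&mut self, ip: Option<usize>);
}

struct Guard<'a, B: ?Sized + Builder> {
    builder: &'a mut B,
    saved: Option<usize>,
}

impl<'a, B: ?Sized + Builder> Guard<'a, B> {
    fn new(builder: &'a mut B) -> Self {
        let saved = builder.insertion_point();
        Self { builder, saved }
    }
}

impl<'a, B: ?Sized + Builder> std::ops::Deref for Guard<'a, B> {
    type Target = B;
    fn deref(&self) -> &B {
        self.builder
    }
}

impl<'a, B: ?Sized + Builder> std::ops::DerefMut for Guard<'a, B> {
    fn deref_mut(&mut self) -> &mut B {
        self.builder
    }
}

impl<'a, B: ?Sized + Builder> Drop for Guard<'a, B> {
    fn drop(&mut self) {
        // Whatever the guarded code did, the original point is put back here.
        let saved = self.saved.take();
        self.builder.set_insertion_point(saved);
    }
}

struct ToyBuilder {
    ip: Option<usize>,
}

impl Builder for ToyBuilder {
    fn insertion_point(&self) -> Option<usize> {
        self.ip
    }
    fn set_insertion_point(&mut self, ip: Option<usize>) {
        self.ip = ip;
    }
}

fn main() {
    let mut builder = ToyBuilder { ip: Some(3) };
    {
        let mut guard = Guard::new(&mut builder);
        guard.set_insertion_point(Some(7)); // temporarily move the point
        assert_eq!(guard.insertion_point(), Some(7));
    } // guard dropped: the original point is restored
    assert_eq!(builder.insertion_point(), Some(3));
}
```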
+#[doc(hidden)] +#[allow(unused_variables)] +trait RestoreInsertionPointOnDrop { + fn restore_insertion_point_on_drop(&mut self, ip: Option); +} +impl RestoreInsertionPointOnDrop for InsertionGuard<'_, B> { + #[inline(always)] + default fn restore_insertion_point_on_drop(&mut self, _ip: Option) {} +} +impl RestoreInsertionPointOnDrop for InsertionGuard<'_, B> { + fn restore_insertion_point_on_drop(&mut self, ip: Option) { + self.builder.restore_insertion_point(ip); + } +} + +pub struct InsertionGuard<'a, B: ?Sized> { + builder: &'a mut B, ip: Option, } -impl<'a> InsertionGuard<'a> { +impl<'a, B> InsertionGuard<'a, B> +where + B: ?Sized + Builder, +{ #[allow(unused)] - pub fn new(builder: &'a mut OpBuilder, ip: InsertionPoint) -> Self { - Self { - builder, - ip: Some(ip), - } + pub fn new(builder: &'a mut B) -> Self { + let ip = builder.insertion_point().cloned(); + Self { builder, ip } + } +} +impl core::ops::Deref for InsertionGuard<'_, B> { + type Target = B; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.builder + } +} +impl core::ops::DerefMut for InsertionGuard<'_, B> { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.builder } } -impl Drop for InsertionGuard<'_> { +impl Drop for InsertionGuard<'_, B> { fn drop(&mut self) { - self.builder.restore_insertion_point(self.ip.take()); + let ip = self.ip.take(); + self.restore_insertion_point_on_drop(ip); } } diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index 298e970d7..bd189a1ea 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -640,6 +640,14 @@ impl Operation { arguments: self.operands.group_mut(info.operand_group as usize), } } + + /// Get an iterator over the successors of this operation + pub fn successor_iter(&self) -> impl DoubleEndedIterator> + '_ { + self.successors.iter().map(|info| OpSuccessor { + dest: info.block.clone(), + arguments: self.operands.group(info.operand_group as usize), + }) + } } /// Operands diff --git a/hir2/src/ir/print.rs b/hir2/src/ir/print.rs index 1ba13193f..c91e586b2 100644 --- a/hir2/src/ir/print.rs +++ b/hir2/src/ir/print.rs @@ -3,10 +3,12 @@ use core::fmt; use super::{Context, Operation}; use crate::{ formatter::PrettyPrint, + matchers::Matcher, traits::{SingleBlock, SingleRegion}, - CallableOpInterface, Entity, Value, + CallableOpInterface, EntityWithId, Value, }; +#[derive(Default)] pub struct OpPrintingFlags; /// The `OpPrinter` trait is expected to be implemented by all [Op] impls as a prequisite. @@ -33,6 +35,12 @@ impl OpPrinter for Operation { } } +impl fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.render()) + } +} + /// The generic format for printed operations is: /// /// <%result..> = .(%operand : , ..) : #.. 
{ @@ -73,7 +81,7 @@ impl PrettyPrint for Operation { } else { Document::Empty }; - doc += display(self.name()); + doc += display(self.name()) + const_text(" "); let doc = if is_callable_op && is_symbol && no_operands { let name = self.as_symbol().unwrap().name(); let callable = self.as_trait::().unwrap(); @@ -107,24 +115,27 @@ impl PrettyPrint for Operation { } doc } else { - let operands = self.operands(); - let doc = if !operands.is_empty() { - operands.iter().enumerate().fold(doc + const_text("("), |doc, (i, operand)| { - let operand = operand.borrow(); - let value = operand.value(); - if i > 0 { - doc + const_text(", ") - + display(value.id()) - + const_text(": ") - + display(value.ty()) - } else { - doc + display(value.id()) + const_text(": ") + display(value.ty()) - } - }) + const_text(")") + let mut is_constant = false; + let doc = if let Some(value) = crate::matchers::constant().matches(self) { + is_constant = true; + doc + value.render() } else { - doc + let operands = self.operands(); + if !operands.is_empty() { + operands.iter().enumerate().fold(doc, |doc, (i, operand)| { + let operand = operand.borrow(); + let value = operand.value(); + if i > 0 { + doc + const_text(", ") + display(value.id()) + } else { + doc + display(value.id()) + } + }) + } else { + doc + } }; - if !results.is_empty() { + let doc = if !results.is_empty() { let results = results.iter().enumerate().fold(Document::Empty, |doc, (i, result)| { if i > 0 { @@ -136,56 +147,73 @@ impl PrettyPrint for Operation { doc + const_text(" : ") + results } else { doc - } - }; + }; - let doc = self.attrs.iter().enumerate().fold(doc, |doc, (i, attr)| { - let doc = if i > 0 { doc + const_text(" ") } else { doc }; - if let Some(value) = attr.value() { - doc + const_text("#[") - + display(attr.name) - + const_text(" = ") - + value.render() - + const_text("]") + if is_constant { + doc } else { - doc + text(format!("#[{}]", &attr.name)) + self.attrs.iter().fold(doc, |doc, attr| { + let doc = doc + const_text(" "); + if let Some(value) = attr.value() { + doc + const_text("#[") + + display(attr.name) + + const_text(" = ") + + value.render() + + const_text("]") + } else { + doc + text(format!("#[{}]", &attr.name)) + } + }) } - }); + }; if self.has_regions() { self.regions.iter().fold(doc, |doc, region| { - let blocks = region.body().iter().fold(Document::Empty, |doc, block| { - let ops = - block.body().iter().fold(Document::Empty, |doc, op| doc + op.render()); - if is_single_region_single_block && no_operands { - doc + indent(4, nl() + ops) + nl() - } else { - let block_args = block.arguments().iter().enumerate().fold( + let blocks = region.body().iter().enumerate().fold( + Document::Empty, + |mut doc, (block_index, block)| { + if block_index > 0 { + doc += nl(); + } + let ops = block.body().iter().enumerate().fold( Document::Empty, - |doc, (i, arg)| { + |mut doc, (i, op)| { if i > 0 { - doc + const_text(", ") + arg.borrow().render() - } else { - doc + arg.borrow().render() + doc += nl(); } + doc + op.render() }, ); - let block_args = if block_args.is_empty() { - block_args + if is_single_region_single_block && no_operands { + doc + indent(4, nl() + ops) } else { - const_text("(") + block_args + const_text(")") - }; - doc + indent( - 4, - text(format!("^{}", block.id())) - + block_args - + const_text(":") - + nl() - + ops, - ) + nl() - } - }); - doc + indent(4, const_text(" {") + nl() + blocks) + nl() + const_text("}") + let block_args = block.arguments().iter().enumerate().fold( + Document::Empty, + |doc, (i, arg)| { + if i > 0 { + 
doc + const_text(", ") + arg.borrow().render() + } else { + doc + arg.borrow().render() + } + }, + ); + let block_args = if block_args.is_empty() { + block_args + } else { + const_text("(") + block_args + const_text(")") + }; + doc + indent( + 4, + text(format!("^{}", block.id())) + + block_args + + const_text(":") + + nl() + + ops, + ) + } + }, + ); + doc + const_text(" {") + nl() + blocks + nl() + const_text("}") }) + const_text(";") } else { doc + const_text(";") diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index 9d2edf6da..86ed702e0 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -1,4 +1,25 @@ +mod branch_point; +mod interfaces; +mod invocation_bounds; +mod kind; +mod successor; +mod transforms; + +use smallvec::SmallVec; + +pub use self::{ + branch_point::RegionBranchPoint, + interfaces::{ + RegionBranchOpInterface, RegionBranchTerminatorOpInterface, RegionKindInterface, + RegionSuccessorIter, + }, + invocation_bounds::InvocationBounds, + kind::RegionKind, + successor::{RegionSuccessor, RegionSuccessorInfo, RegionSuccessorMut}, + transforms::RegionTransformFailed, +}; use super::*; +use crate::RegionSimplificationLevel; pub type RegionRef = UnsafeIntrusiveEntityRef; /// An intrusive, doubly-linked list of [Region]s @@ -206,6 +227,197 @@ impl Region { /// Drop any references to blocks in this region - this is used to break cycles when cleaning /// up regions. pub fn drop_all_references(&mut self) { - todo!() + let mut cursor = self.body_mut().front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_references(); + cursor.move_next(); + } + } +} + +/// Values +impl Region { + /// Check if every value in `values` is defined above this region, i.e. they are defined in a + /// region which is a proper ancestor of `self`. + pub fn values_are_defined_above(&self, values: &[ValueRef]) -> bool { + let this = self.as_region_ref(); + for value in values { + if !value + .borrow() + .parent_region() + .is_some_and(|value_region| Self::is_proper_ancestor_of(&value_region, &this)) + { + return false; + } + } + true + } + + /// Replace all uses of `value` with `replacement`, within this region. + pub fn replace_all_uses_in_region_with(&mut self, _value: ValueRef, _replacement: ValueRef) { + todo!("RegionUtils.h") + } + + /// Visit each use of a value in this region (and its descendants), where that value was defined + /// in an ancestor of `limit`. + pub fn visit_used_values_defined_above(&self, _limit: &RegionRef, _callback: F) + where + F: FnMut(OpOperand), + { + todo!("RegionUtils.h") + } + + /// Visit each use of a value in any of the provided regions (or their descendants), where that + /// value was defined in an ancestor of that region. + pub fn visit_used_values_defined_above_any(_regions: &[RegionRef], _callback: F) + where + F: FnMut(OpOperand), + { + todo!("RegionUtils.h") + } + + /// Return a vector of values used in this region (and its descendants), and defined in an + /// ancestor of the `limit` region. + pub fn get_used_values_defined_above(&self, _limit: &RegionRef) -> SmallVec<[ValueRef; 1]> { + todo!("RegionUtils.h") + } + + /// Return a vector of values used in any of the provided regions, but defined in an ancestor. + pub fn get_used_values_defined_above_any(_regions: &[RegionRef]) -> SmallVec<[ValueRef; 1]> { + todo!("RegionUtils.h") + } + + /// Make this region isolated from above. + /// + /// * Capture the values that are defined above the region and used within it. 
+ /// * Append block arguments to the entry block that represent each captured value. + /// * Replace all uses of the captured values within the region, with the new block arguments + /// * `clone_into_region` is called with the defining op of a captured value. If it returns + /// true, it indicates that the op needs to be cloned into the region. As a result, the + /// operands of that operation become part of the captured value set (unless the operations + /// that define the operand values themselves are to be cloned). The cloned operations are + /// added to the entry block of the region. + /// + /// Returns the set of captured values. + pub fn make_isolated_from_above( + &mut self, + _rewriter: &mut R, + _clone_into_region: F, + ) -> SmallVec<[ValueRef; 1]> + where + R: crate::Rewriter, + F: Fn(&Operation) -> bool, + { + todo!("RegionUtils.h") + } +} + +/// Queries +impl Region { + pub fn find_common_ancestor(ops: &[OperationRef]) -> Option { + use bitvec::prelude::*; + + match ops.len() { + 0 => None, + 1 => unsafe { ops.get_unchecked(0) }.borrow().parent_region(), + num_ops => { + let (first, rest) = unsafe { ops.split_first().unwrap_unchecked() }; + let mut region = first.borrow().parent_region(); + let mut remaining_ops = bitvec![1; num_ops - 1]; + while let Some(r) = region.take() { + while let Some(index) = remaining_ops.first_one() { + // Is this op contained in `region`? + if r.borrow().find_ancestor_op(&rest[index]).is_some() { + unsafe { + remaining_ops.set_unchecked(index, false); + } + } + } + if remaining_ops.not_any() { + break; + } + region = r.borrow().parent_region(); + } + region + } + } + } + + /// Returns `block` if `block` lies in this region, or otherwise finds the ancestor of `block` + /// that lies in this region. + /// + /// Returns `None` if the latter fails. + pub fn find_ancestor_block(&self, block: &BlockRef) -> Option { + let this = self.as_region_ref(); + let mut current = Some(block.clone()); + while let Some(current_block) = current.take() { + let parent = current_block.borrow().parent()?; + if parent == this { + return Some(current_block); + } + current = + parent.borrow().owner.as_ref().and_then(|parent_op| parent_op.borrow().parent()); + } + current + } + + /// Returns `op` if `op` lies in this region, or otherwise finds the ancestor of `op` that lies + /// in this region. + /// + /// Returns `None` if the latter fails. + pub fn find_ancestor_op(&self, op: &OperationRef) -> Option { + let this = self.as_region_ref(); + let mut current = Some(op.clone()); + while let Some(current_op) = current.take() { + let parent = current_op.borrow().parent_region()?; + if parent == this { + return Some(current_op); + } + current = parent.borrow().parent(); + } + current + } +} + +/// Transforms +impl Region { + /// Run a set of structural simplifications over the regions in `regions`. + /// + /// This includes transformations like unreachable block elimination, dead argument elimination, + /// as well as some other DCE. + /// + /// This function returns `Ok` if any of the regions were simplified, `Err` otherwise. + /// + /// The provided rewriter is used to notify callers of operation and block deletion. + /// + /// The provided [RegionSimplificationLevel] will be used to determine whether to apply more + /// aggressive simplifications, namely block merging. Note that when block merging is enabled, + /// this can lead to merged blocks with extra arguments. 
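+    //
+    // Illustrative sketch only (not part of this change), assuming a `rewriter: &mut dyn Rewriter`
+    // and an op whose regions are to be simplified; a typical call site might look like:
+    //
+    //     let regions: SmallVec<[RegionRef; 2]> =
+    //         op.regions().iter().map(|r| r.as_region_ref()).collect();
+    //     let simplified =
+    //         Region::simplify_all(&regions, rewriter, RegionSimplificationLevel::Aggressive).is_ok();
+    //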
+ pub fn simplify_all( + regions: &[RegionRef], + rewriter: &mut dyn crate::Rewriter, + simplification_level: RegionSimplificationLevel, + ) -> Result<(), RegionTransformFailed> { + let merge_blocks = matches!(simplification_level, RegionSimplificationLevel::Aggressive); + + let eliminated_blocks = Self::erase_unreachable_blocks(regions, rewriter).is_ok(); + let eliminated_ops_or_args = Self::dead_code_elimination(regions, rewriter).is_ok(); + + let mut merged_identical_blocks = false; + let mut dropped_redundant_arguments = false; + if merge_blocks { + merged_identical_blocks = Self::merge_identical_blocks(regions, rewriter).is_ok(); + dropped_redundant_arguments = Self::drop_redundant_arguments(regions, rewriter).is_ok(); + } + + if eliminated_blocks + || eliminated_ops_or_args + || merged_identical_blocks + || dropped_redundant_arguments + { + Ok(()) + } else { + Err(RegionTransformFailed) + } } } diff --git a/hir2/src/ir/region/branch_point.rs b/hir2/src/ir/region/branch_point.rs new file mode 100644 index 000000000..f3419f4e4 --- /dev/null +++ b/hir2/src/ir/region/branch_point.rs @@ -0,0 +1,50 @@ +use core::fmt; + +use super::*; + +/// This type represents a point being branched from in the methods of `RegionBranchOpInterface`. +/// +/// One can branch from one of two different kinds of places: +/// +/// * The parent operation (i.e. the op implementing `RegionBranchOpInterface`). +/// * A region within the parent operation (where the parent implements `RegionBranchOpInterface`). +#[derive(Clone, PartialEq, Eq)] +pub enum RegionBranchPoint { + /// A branch from the current operation to one of its regions + Parent, + /// A branch from the given region, within a parent `RegionBranchOpInterface` op + Child(RegionRef), +} +impl fmt::Debug for RegionBranchPoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Parent => f.write_str("Parent"), + Self::Child(ref region) => { + f.debug_tuple("Child").field(&format_args!("{:p}", region)).finish() + } + } + } +} +impl RegionBranchPoint { + /// Returns true if branching from the parent op. + #[inline] + pub fn is_parent(&self) -> bool { + matches!(self, Self::Parent) + } + + /// Returns the region if branching from a region, otherwise `None`. + pub fn region(&self) -> Option { + match self { + Self::Child(ref region) => Some(region.clone()), + Self::Parent => None, + } + } +} +impl<'a> From> for RegionBranchPoint { + fn from(succ: RegionSuccessor<'a>) -> Self { + match succ.into_successor() { + None => Self::Parent, + Some(succ) => Self::Child(succ), + } + } +} diff --git a/hir2/src/ir/region/interfaces.rs b/hir2/src/ir/region/interfaces.rs new file mode 100644 index 000000000..009c3202b --- /dev/null +++ b/hir2/src/ir/region/interfaces.rs @@ -0,0 +1,247 @@ +use super::*; +use crate::{ + attributes::AttributeValue, traits::Terminator, Op, SuccessorOperandRange, + SuccessorOperandRangeMut, Type, +}; + +/// An op interface that indicates what types of regions it holds +pub trait RegionKindInterface { + /// Get the [RegionKind] for this operation + fn kind(&self) -> RegionKind; + /// Returns true if the kind of this operation's regions requires SSA dominance + #[inline] + fn has_ssa_dominance(&self) -> bool { + matches!(self.kind(), RegionKind::SSA) + } + #[inline] + fn has_graph_regions(&self) -> bool { + matches!(self.kind(), RegionKind::Graph) + } +} + +// TODO(pauls): Implement verifier +/// This interface provides information for region operations that exhibit branching behavior +/// between held regions. 
I.e., this interface allows for expressing control flow information for +/// region holding operations. +/// +/// This interface is meant to model well-defined cases of control-flow and value propagation, +/// where what occurs along control-flow edges is assumed to be side-effect free. +/// +/// A "region branch point" indicates a point from which a branch originates. It can indicate either +/// a region of this op or [RegionBranchPoint::Parent]. In the latter case, the branch originates +/// from outside of the op, i.e., when first executing this op. +/// +/// A "region successor" indicates the target of a branch. It can indicate either a region of this +/// op or this op. In the former case, the region successor is a region pointer and a range of block +/// arguments to which the "successor operands" are forwarded to. In the latter case, the control +/// flow leaves this op and the region successor is a range of results of this op to which the +/// successor operands are forwarded to. +/// +/// By default, successor operands and successor block arguments/successor results must have the +/// same type. `areTypesCompatible` can be implemented to allow non-equal types. +/// +/// ## Example +/// +/// ```hir,ignore +/// %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b) +/// -> tensor<5xf32> { +/// ... +/// scf.yield %c : tensor<5xf32> +/// } +/// ``` +/// +/// `scf.for` has one region. The region has two region successors: the region itself and the +/// `scf.for` op. `%b` is an entry successor operand. `%c` is a successor operand. `%a` is a +/// successor block argument. `%r` is a successor result. +pub trait RegionBranchOpInterface: Op { + /// Returns the operands of this operation that are forwarded to the region successor's block + /// arguments or this operation's results when branching to `point`. `point` is guaranteed to + /// be among the successors that are returned by `get_entry_succcessor_regions` or + /// `get_successor_regions(parent_op())`. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function returns the operand `%b` + /// of the `scf.for` op, regardless of the value of `point`, i.e. this op always forwards the + /// same operands, regardless of whether the loop has 0 or more iterations. + #[inline] + #[allow(unused_variables)] + fn get_entry_successor_operands(&self, point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + crate::SuccessorOperandRange::empty() + } + /// Returns the potential region successors when first executing the op. + /// + /// Unlike [get_successor_regions], this method also passes along the constant operands of this + /// op. Based on these, the implementation may filter out certain successors. By default, it + /// simply dispatches to `get_successor_regions`. `operands` contains an entry for every operand + /// of this op, with `None` representing if the operand is non-constant. + /// + /// NOTE: The control flow does not necessarily have to enter any region of this op. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function may return two region + /// successors: the single region of the `scf.for` op and the `scf.for` operation (that + /// implements this interface). If `%lb`, `%ub`, `%step` are constants and it can be determined + /// the loop does not have any iterations, this function may choose to return only this + /// operation. 
Similarly, if it can be determined that the loop has at least one iteration, this + /// function may choose to return only the region of the loop. + #[inline] + #[allow(unused_variables)] + fn get_entry_successor_regions( + &self, + operands: &[Option>], + ) -> RegionSuccessorIter<'_> { + self.get_successor_regions(RegionBranchPoint::Parent) + } + /// Returns the potential region successors when branching from `point`. + /// + /// These are the regions that may be selected during the flow of control. + /// + /// When `point` is [RegionBranchPoint::Parent], this function returns the region successors + /// when entering the operation. Otherwise, this method returns the successor regions when + /// branching from the region indicated by `point`. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function returns the region of the + /// `scf.for` and this operation for either region branch point (`parent` and the region of the + /// `scf.for`). An implementation may choose to filter out region successors when it is + /// statically known (e.g., by examining the operands of this op) that those successors are not + /// branched to. + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_>; + /// Returns a set of invocation bounds, representing the minimum and maximum number of times + /// this operation will invoke each attached region (assuming the regions yield normally, i.e. + /// do not abort or invoke an infinite loop). The minimum number of invocations is at least 0. + /// If the maximum number of invocations cannot be statically determined, then it will be set to + /// [InvocationBounds::unknown]. + /// + /// This function also passes along the constant operands of this op. `operands` contains an + /// entry for every operand of this op, with `None` representing if the operand is non-constant. + /// + /// This function may be called speculatively on operations where the provided operands are not + /// necessarily the same as the operation's current operands. This may occur in analyses that + /// wish to determine "what would be the region invocations if these were the operands?" + #[inline] + #[allow(unused_variables)] + fn get_region_invocation_bounds( + &self, + operands: &[Option>], + ) -> SmallVec<[InvocationBounds; 1]> { + use smallvec::smallvec; + + smallvec![InvocationBounds::Unknown; self.num_regions()] + } + /// This function is called to compare types along control-flow edges. + /// + /// By default, the types are check for exact equality. + #[inline] + fn are_types_compatible(&self, lhs: &Type, rhs: &Type) -> bool { + lhs == rhs + } + /// Returns `true` if control flow originating from the region at `index` may eventually branch + /// back to the same region, either from itself, or after passing through other regions first. + fn is_repetitive_region(&self, index: usize) -> bool; + /// Returns `true` if there is a loop in the region branching graph. + /// + /// Only reachable regions (starting from the entry region) are considered. + fn has_loop(&self) -> bool; +} + +// TODO(pauls): Implement verifier (should have no results and no successors) +/// This interface provides information for branching terminator operations in the presence of a +/// parent [RegionBranchOpInterface] implementation. It specifies which operands are passed to which +/// successor region. 
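+// Illustrative sketch only (not part of this change): a `yield`-like terminator might satisfy the
+// interface below by forwarding all of its operands, regardless of which successor is queried:
+//
+//     fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> {
+//         // Assumes operand group 0 holds the yielded values.
+//         SuccessorOperandRange::forward(self.as_operation().operands().group(0))
+//     }
+//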
+pub trait RegionBranchTerminatorOpInterface: Op + Terminator { + /// Get a range of operands corresponding to values that are semantically "returned" by passing + /// them to the region successor indicated by `point`. + fn get_successor_operands(&self, point: RegionBranchPoint) -> SuccessorOperandRange<'_>; + /// Get a mutable range of operands corresponding to values that are semantically "returned" by + /// passing them to the region successor indicated by `point`. + fn get_mutable_successor_operands( + &mut self, + point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_>; + /// Returns the potential region successors that are branched to after this terminator based on + /// the given constant operands. + /// + /// This method also passes along the constant operands of this op. `operands` contains an entry + /// for every operand of this op, with `None` representing non-constant values. + /// + /// The default implementation simply dispatches to the parent `RegionBranchOpInterface`'s + /// `get_successor_regions` implementation. + #[allow(unused_variables)] + fn get_successor_regions( + &self, + operands: &[Option>], + ) -> SmallVec<[RegionSuccessorInfo; 2]> { + let parent_region = + self.parent_region().expect("expected operation to have a parent region"); + let parent_op = + parent_region.borrow().parent().expect("expected operation to have a parent op"); + parent_op + .borrow() + .as_trait::() + .expect("invalid region terminator parent: must implement RegionBranchOpInterface") + .get_successor_regions(RegionBranchPoint::Child(parent_region)) + .into_successor_infos() + } +} + +pub struct RegionSuccessorIter<'a> { + op: &'a Operation, + successors: SmallVec<[RegionSuccessorInfo; 2]>, + index: usize, +} +impl<'a> RegionSuccessorIter<'a> { + pub fn new( + op: &'a Operation, + successors: impl IntoIterator, + ) -> Self { + Self { + op, + successors: SmallVec::from_iter(successors), + index: 0, + } + } + + pub fn empty(op: &'a Operation) -> Self { + Self { + op, + successors: Default::default(), + index: 0, + } + } + + pub fn get(&self, index: usize) -> Option> { + self.successors.get(index).map(|info| RegionSuccessor { + dest: info.successor.clone(), + arguments: self.op.operands().group(info.operand_group as usize), + }) + } + + pub fn into_successor_infos(self) -> SmallVec<[RegionSuccessorInfo; 2]> { + self.successors + } +} +impl core::iter::FusedIterator for RegionSuccessorIter<'_> {} +impl<'a> ExactSizeIterator for RegionSuccessorIter<'a> { + fn len(&self) -> usize { + self.successors.len() + } +} +impl<'a> Iterator for RegionSuccessorIter<'a> { + type Item = RegionSuccessor<'a>; + + fn next(&mut self) -> Option { + if self.index >= self.successors.len() { + return None; + } + + let info = &self.successors[self.index]; + Some(RegionSuccessor { + dest: info.successor.clone(), + arguments: self.op.operands().group(info.operand_group as usize), + }) + } +} diff --git a/hir2/src/ir/region/invocation_bounds.rs b/hir2/src/ir/region/invocation_bounds.rs new file mode 100644 index 000000000..574471747 --- /dev/null +++ b/hir2/src/ir/region/invocation_bounds.rs @@ -0,0 +1,132 @@ +use core::ops::{Bound, RangeBounds}; + +/// This type represents upper and lower bounds on the number of times a region of a +/// `RegionBranchOpInterface` op can be invoked. The lower bound is at least zero, but the upper +/// bound may not be known. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum InvocationBounds { + /// The region can be invoked an unknown number of times, possibly never. 
+ #[default] + Unknown, + /// The region can never be invoked + Never, + /// The region can be invoked exactly N times + Exact(u32), + /// The region can be invoked any number of times in the given range + Variable { min: u32, max: u32 }, + /// The region can be invoked at least N times, but an unknown number of times beyond that. + AtLeastN(u32), + /// The region can be invoked any number of times up to N + NoMoreThan(u32), +} +impl InvocationBounds { + #[inline] + pub fn new(bounds: impl Into) -> Self { + bounds.into() + } + + #[inline] + pub fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } + + pub fn min(&self) -> Bound<&u32> { + self.start_bound() + } + + pub fn max(&self) -> Bound<&u32> { + self.end_bound() + } +} +impl From for InvocationBounds { + fn from(value: u32) -> Self { + if value == 0 { + Self::Never + } else { + Self::Exact(value) + } + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::Range) -> Self { + if range.start == range.end { + Self::Never + } else if range.end == range.start + 1 { + Self::Exact(range.start) + } else { + assert!(range.start < range.end); + Self::Variable { + min: range.start, + max: range.end, + } + } + } +} +impl From> for InvocationBounds { + fn from(value: core::ops::RangeFrom) -> Self { + if value.start == 0 { + Self::Unknown + } else { + Self::AtLeastN(value.start) + } + } +} +impl From> for InvocationBounds { + fn from(value: core::ops::RangeTo) -> Self { + if value.end == 1 { + Self::Never + } else if value.end == u32::MAX { + Self::Unknown + } else { + Self::NoMoreThan(value.end - 1) + } + } +} +impl From for InvocationBounds { + fn from(_value: core::ops::RangeFull) -> Self { + Self::Unknown + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::RangeInclusive) -> Self { + let (start, end) = range.into_inner(); + if start == 0 && end == 0 { + Self::Never + } else if start == end { + Self::Exact(start) + } else { + Self::Variable { + min: start, + max: end + 1, + } + } + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::RangeToInclusive) -> Self { + if range.end == 0 { + Self::Never + } else { + Self::NoMoreThan(range.end) + } + } +} +impl RangeBounds for InvocationBounds { + fn start_bound(&self) -> Bound<&u32> { + match self { + Self::Unknown | Self::NoMoreThan(_) => Bound::Unbounded, + Self::Never => Bound::Excluded(&0), + Self::Exact(n) | Self::Variable { min: n, .. } => Bound::Included(n), + Self::AtLeastN(n) => Bound::Included(n), + } + } + + fn end_bound(&self) -> Bound<&u32> { + match self { + Self::Unknown | Self::AtLeastN(_) => Bound::Unbounded, + Self::Never => Bound::Excluded(&0), + Self::Exact(n) | Self::Variable { max: n, .. } => Bound::Excluded(n), + Self::NoMoreThan(n) => Bound::Excluded(n), + } + } +} diff --git a/hir2/src/ir/region/kind.rs b/hir2/src/ir/region/kind.rs new file mode 100644 index 000000000..83eb6da8e --- /dev/null +++ b/hir2/src/ir/region/kind.rs @@ -0,0 +1,22 @@ +/// Represents the types of regions that can be represented in the IR +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegionKind { + /// A graph region is one without control-flow semantics, i.e. dataflow between operations is + /// the only thing that dictates order, and operations can be conceptually executed in parallel + /// if the runtime supports it. + /// + /// As there is no control-flow in these regions, graph regions may only contain a single block. 
+ Graph, + /// An SSA region is one where the strict control-flow semantics and properties of SSA (static + /// single assignment) form must be upheld. + /// + /// SSA regions must adhere to: + /// + /// * Values can only be defined once + /// * Definitions must dominate uses + /// * Ordering of operations in a block corresponds to execution order, i.e. operations earlier + /// in a block dominate those later in the block. + /// * Blocks must end with a terminator. + #[default] + SSA, +} diff --git a/hir2/src/ir/region/successor.rs b/hir2/src/ir/region/successor.rs new file mode 100644 index 000000000..3d69b2f0a --- /dev/null +++ b/hir2/src/ir/region/successor.rs @@ -0,0 +1,96 @@ +use core::fmt; + +use super::*; +use crate::{OpOperandRange, OpOperandRangeMut}; + +/// This struct represents owned region successor metadata +#[derive(Clone)] +pub struct RegionSuccessorInfo { + pub successor: RegionBranchPoint, + #[allow(unused)] + pub(crate) key: Option>, + pub(crate) operand_group: u8, +} + +/// A [RegionSuccessor] represents the successor of a region. +/// +/// +/// A region successor can either be another region, or the parent operation. If the successor is a +/// region, this class represents the destination region, as well as a set of arguments from that +/// region that will be populated when control flows into the region. If the successor is the parent +/// operation, this class represents an optional set of results that will be populated when control +/// returns to the parent operation. +/// +/// This interface assumes that the values from the current region that are used to populate the +/// successor inputs are the operands of the return-like terminator operations in the blocks within +/// this region. +pub struct RegionSuccessor<'a> { + pub dest: RegionBranchPoint, + pub arguments: OpOperandRange<'a>, +} +impl<'a> RegionSuccessor<'a> { + /// Returns true if the successor is the parent op + pub fn is_parent(&self) -> bool { + self.dest.is_parent() + } + + pub fn successor(&self) -> Option { + self.dest.region() + } + + pub fn into_successor(self) -> Option { + self.dest.region() + } + + /// Return the inputs to the successor that are remapped by the exit values of the current + /// region. + pub fn successor_inputs(&self) -> &OpOperandRange<'a> { + &self.arguments + } +} +impl fmt::Debug for RegionSuccessor<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RegionSuccessor") + .field("dest", &self.dest) + .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .finish() + } +} + +/// The mutable version of [RegionSuccessor] +pub struct RegionSuccessorMut<'a> { + pub dest: RegionBranchPoint, + pub arguments: OpOperandRangeMut<'a>, +} +impl<'a> RegionSuccessorMut<'a> { + /// Returns true if the successor is the parent op + pub fn is_parent(&self) -> bool { + self.dest.is_parent() + } + + pub fn successor(&self) -> Option { + self.dest.region() + } + + pub fn into_successor(self) -> Option { + self.dest.region() + } + + /// Return the inputs to the successor that are remapped by the exit values of the current + /// region. 
+ pub fn successor_inputs(&mut self) -> &mut OpOperandRangeMut<'a> { + &mut self.arguments + } + + pub fn into_successor_inputs(self) -> OpOperandRangeMut<'a> { + self.arguments + } +} +impl fmt::Debug for RegionSuccessorMut<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RegionSuccessorMut") + .field("dest", &self.dest) + .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .finish() + } +} diff --git a/hir2/src/ir/region/transforms.rs b/hir2/src/ir/region/transforms.rs new file mode 100644 index 000000000..10f7d442d --- /dev/null +++ b/hir2/src/ir/region/transforms.rs @@ -0,0 +1,6 @@ +mod block_merging; +mod dce; +mod drop_redundant_args; + +#[derive(Debug, Copy, Clone)] +pub struct RegionTransformFailed; diff --git a/hir2/src/ir/region/transforms/block_merging.rs b/hir2/src/ir/region/transforms/block_merging.rs new file mode 100644 index 000000000..cf8bbc310 --- /dev/null +++ b/hir2/src/ir/region/transforms/block_merging.rs @@ -0,0 +1,263 @@ +#![allow(unused)] +use alloc::collections::BTreeMap; + +use super::RegionTransformFailed; +use crate::{ + BlockArgument, BlockRef, DynHash, OpResult, Operation, OperationRef, Region, RegionRef, + Rewriter, ValueRef, +}; + +bitflags::bitflags! { + struct EquivalenceFlags: u8 { + const IGNORE_LOCATIONS = 1; + } +} + +struct OpEquivalence { + flags: EquivalenceFlags, + operand_hasher: OperandHasher, + result_hasher: ResultHasher, +} + +type ValueHasher = Box; + +impl OpEquivalence { + pub fn new() -> Self { + Self { + flags: EquivalenceFlags::empty(), + operand_hasher: DefaultValueHasher, + result_hasher: DefaultValueHasher, + } + } +} +impl OpEquivalence { + #[inline] + pub fn with_flags(mut self, flags: EquivalenceFlags) -> Self { + self.flags.insert(flags); + self + } + + /// Ignore op operands when computing equivalence for operations + pub fn ignore_operands(self) -> OpEquivalence<(), ResultHasher> { + OpEquivalence { + flags: self.flags, + operand_hasher: (), + result_hasher: self.result_hasher, + } + } + + /// Ignore op results when computing equivalence for operations + pub fn ignore_results(self) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: self.operand_hasher, + result_hasher: (), + } + } + + /// Specify a custom hasher for op operands + pub fn with_operand_hasher( + self, + hasher: impl Fn(&ValueRef, &mut dyn core::hash::Hasher) + 'static, + ) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: Box::new(hasher), + result_hasher: self.result_hasher, + } + } + + /// Specify a custom hasher for op results + pub fn with_result_hasher( + self, + hasher: impl Fn(&ValueRef, &mut dyn core::hash::Hasher) + 'static, + ) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: self.operand_hasher, + result_hasher: Box::new(hasher), + } + } + + /// Compare if two operations are equivalent using the current equivalence configuration. + /// + /// This is equivalent to calling [compute_equivalence] with `are_values_equivalent` set to + /// `ValueRef::ptr_eq`, and `on_value_equivalence` to a no-op. + #[inline] + pub fn are_equivalent(&self, lhs: &OperationRef, rhs: &OperationRef) -> bool { + #[inline(always)] + fn noop(_: &ValueRef, _: &ValueRef) {} + + self.compute_equivalence(lhs, rhs, ValueRef::ptr_eq, noop) + } + + /// Compare if two operations (and their regions) are equivalent using the current equivalence + /// configuration. 
+ /// + /// * `are_values_equivalent` is a callback used to check if two values are equivalent. For + /// two operations to be equivalent, their operands must be the same SSA value, or this + /// callback must return `true`. + /// * `on_value_equivalence` is a callback to inform the caller that the analysis determined + /// that two values are equivalent. + /// + /// NOTE: Additional information regarding value equivalence can be injected into the analysis + /// via `are_values_equivalent`. Typically, callers may want values that were recorded as + /// equivalent via `on_value_equivalence` to be reflected in `are_values_equivalent`, but it + /// depends on the exact semantics desired by the caller. + pub fn compute_equivalence( + &self, + lhs: &OperationRef, + rhs: &OperationRef, + are_values_equivalent: VE, + on_value_equivalence: OVE, + ) -> bool + where + VE: Fn(&ValueRef, &ValueRef) -> bool, + OVE: FnMut(&ValueRef, &ValueRef), + { + todo!() + } + + /// Compare if two regions are equivalent using the current equivalence configuration. + /// + /// See [compute_equivalence] for more details. + pub fn compute_region_equivalence( + &self, + lhs: &RegionRef, + rhs: &RegionRef, + are_values_equivalent: VE, + on_value_equivalence: OVE, + ) -> bool + where + VE: Fn(&ValueRef, &ValueRef) -> bool, + OVE: FnMut(&ValueRef, &ValueRef), + { + todo!() + } + + /// Hashes an operation based on: + /// + /// * OperationName + /// * Attributes + /// * Result types + fn hash_operation(&self, op: &Operation, hasher: &mut impl core::hash::Hasher) { + use core::hash::Hash; + + use crate::Value; + + op.name().hash(hasher); + for attr in op.attributes().iter() { + attr.hash(hasher); + } + for result in op.results().iter() { + result.borrow().ty().hash(hasher); + } + } +} + +#[inline(always)] +pub fn ignore_value_equivalence(_lhs: &ValueRef, _rhs: &ValueRef) -> bool { + true +} + +struct DefaultValueHasher; +impl FnOnce<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + type Output = (); + + extern "rust-call" fn call_once( + self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} +impl FnMut<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + extern "rust-call" fn call_mut( + &mut self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} +impl Fn<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + extern "rust-call" fn call( + &self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} + +struct BlockEquivalenceData { + /// The block this data refers to + block: BlockRef, + /// The hash for this block + hash: u64, + /// A map of result producing operations to their relative orders within this block. The order + /// of an operation is the number of defined values that are produced within the block before + /// this operation. 
+ op_order_index: BTreeMap, +} +impl BlockEquivalenceData { + pub fn new(block: BlockRef) -> Self { + use core::hash::Hasher; + + let mut op_order_index = BTreeMap::default(); + + let b = block.borrow(); + let mut order = b.num_arguments() as u32; + let mut op_equivalence = OpEquivalence::new() + .with_flags(EquivalenceFlags::IGNORE_LOCATIONS) + .ignore_operands() + .ignore_results(); + + let mut hasher = rustc_hash::FxHasher::default(); + for op in b.body() { + let num_results = op.num_results() as u32; + if num_results > 0 { + op_order_index.insert(op.as_operation_ref(), order); + order += num_results; + } + op_equivalence.hash_operation(&op, &mut hasher); + } + + Self { + block, + hash: hasher.finish(), + op_order_index, + } + } + + fn get_order_of(&self, value: &ValueRef) -> usize { + let value = value.borrow(); + assert!(value.parent_block().unwrap() == self.block, "expected value of this block"); + + if let Some(block_arg) = value.downcast_ref::() { + return block_arg.index(); + } + + let result = value.downcast_ref::().unwrap(); + let order = + *self.op_order_index.get(&result.owner()).expect("expected op to have an order"); + result.index() + (order as usize) + } +} + +impl Region { + // TODO(pauls) + pub(in crate::ir::region) fn merge_identical_blocks( + _regions: &[RegionRef], + _rewriter: &mut dyn Rewriter, + ) -> Result<(), RegionTransformFailed> { + Err(RegionTransformFailed) + } +} diff --git a/hir2/src/ir/region/transforms/dce.rs b/hir2/src/ir/region/transforms/dce.rs new file mode 100644 index 000000000..84497cd77 --- /dev/null +++ b/hir2/src/ir/region/transforms/dce.rs @@ -0,0 +1,354 @@ +use alloc::collections::{BTreeSet, VecDeque}; + +use smallvec::SmallVec; + +use super::RegionTransformFailed; +use crate::{ + traits::{BranchOpInterface, Terminator}, + OpOperandImpl, OpResult, Operation, OperationRef, PostOrderBlockIter, Region, RegionRef, + Rewriter, SuccessorOperands, ValueRef, +}; + +/// Data structure used to track which values have already been proved live. +/// +/// Because operations can have multiple results, this data structure tracks liveness for both +/// values and operations to avoid having to look through all results when analyzing a use. +/// +/// This data structure essentially tracks the dataflow lattice. The set of values/ops proved live +/// increases monotonically to a fixed-point. +#[derive(Default)] +struct LiveMap { + values: BTreeSet, + ops: BTreeSet, + changed: bool, +} +impl LiveMap { + pub fn was_proven_live(&self, value: &ValueRef) -> bool { + // TODO(pauls): For results that are removable, e.g. for region based control flow, + // we could allow for these values to be tracked independently. + let val = value.borrow(); + if let Some(result) = val.downcast_ref::() { + self.ops.contains(&result.owner()) + } else { + self.values.contains(value) + } + } + + #[inline] + pub fn was_op_proven_live(&self, op: &OperationRef) -> bool { + self.ops.contains(op) + } + + pub fn set_proved_live(&mut self, value: ValueRef) { + // TODO(pauls): For results that are removable, e.g. for region based control flow, + // we could allow for these values to be tracked independently. 
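+        // Results are tracked through their defining op (one liveness entry per op), while block
+        // arguments and any other values are tracked individually.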
+ let val = value.borrow(); + if let Some(result) = val.downcast_ref::() { + self.changed |= self.ops.insert(result.owner()); + } else { + self.changed |= self.values.insert(value); + } + } + + pub fn set_op_proved_live(&mut self, op: OperationRef) { + self.changed |= self.ops.insert(op); + } + + #[inline(always)] + pub fn mark_unchanged(&mut self) { + self.changed = false; + } + + #[inline(always)] + pub const fn has_changed(&self) -> bool { + self.changed + } + + pub fn is_use_specially_known_dead(&self, user: &OpOperandImpl) -> bool { + // DCE generally treats all uses of an op as live if the op itself is considered live. + // However, for successor operands to terminators we need a finer-grained notion where we + // deduce liveness for operands individually. The reason for this is easiest to think about + // in terms of a classical phi node based SSA IR, where each successor operand is really an + // operand to a _separate_ phi node, rather than all operands to the branch itself as with + // the block argument representation that we use. + // + // And similarly, because each successor operand is really an operand to a phi node, rather + // than to the terminator op itself, a terminator op can't e.g. "print" the value of a + // successor operand. + let owner = &user.owner; + if owner.borrow().implements::() { + if let Some(branch_interface) = owner.borrow().as_trait::() { + if let Some(arg) = + branch_interface.get_successor_block_argument(user.index as usize) + { + return !self.was_proven_live(&arg.upcast()); + } + } + } + + false + } + + pub fn propagate_region_liveness(&mut self, region: &Region) { + if region.body().is_empty() { + return; + } + + for block in PostOrderBlockIter::new(region.body().front().as_pointer().unwrap()) { + // We process block arguments after the ops in the block, to promote faster convergence + // to a fixed point (we try to visit uses before defs). + let block = block.borrow(); + for op in block.body() { + self.propagate_liveness(&op); + } + + // We currently do not remove entry block arguments, so there is no need to track their + // liveness. + // + // TODO(pauls): We could track these and enable removing dead operands/arguments from + // region control flow operations in the future. + if block.is_entry_block() { + continue; + } + + for arg in block.arguments() { + let arg = arg.clone().upcast(); + if !self.was_proven_live(&arg) { + self.process_value(arg); + } + } + } + } + + pub fn propagate_liveness(&mut self, op: &Operation) { + // Recurse on any regions the op has + for region in op.regions() { + self.propagate_region_liveness(®ion); + } + + // We process terminator operations separately + if op.implements::() { + return self.propagate_terminator_liveness(op); + } + + // Don't reprocess live operations. 
+ if self.was_op_proven_live(&op.as_operation_ref()) { + return; + } + + // Process this op + if !op.would_be_trivially_dead() { + self.set_op_proved_live(op.as_operation_ref()); + } + + // If the op isn't intrinsically alive, check it's results + for result in op.results().iter() { + self.process_value(result.clone().upcast()); + } + } + + fn propagate_terminator_liveness(&mut self, op: &Operation) { + // Terminators are always live + self.set_op_proved_live(op.as_operation_ref()); + + // Check to see if we can reason about the successor operands instead + // + // If we can't reason about the operand to a successor, conservatively mark it as live + if let Some(branch_op) = op.as_trait::() { + let num_successors = op.num_successors(); + for succ_index in 0..num_successors { + let succ_operands = branch_op.get_successor_operands(succ_index); + for arg_index in 0..succ_operands.num_produced() { + let succ = op.successors()[succ_index].block.borrow().block.clone(); + let succ_arg = succ.borrow().get_argument(arg_index).upcast(); + self.set_proved_live(succ_arg); + } + } + } else { + let num_successors = op.num_successors(); + for succ_index in 0..num_successors { + let succ = op.successor(succ_index); + for arg in succ.arguments.iter() { + let arg = arg.borrow().as_value_ref(); + self.set_proved_live(arg); + } + } + } + } + + fn process_value(&mut self, value: ValueRef) { + let proved_live = value.borrow().iter_uses().any(|user| { + if self.is_use_specially_known_dead(&user) { + return false; + } + self.was_op_proven_live(&user.owner) + }); + if proved_live { + self.set_proved_live(value); + } + } +} + +impl Region { + pub fn dead_code_elimination( + regions: &[RegionRef], + rewriter: &mut dyn Rewriter, + ) -> Result<(), RegionTransformFailed> { + let mut live_map = LiveMap::default(); + loop { + live_map.mark_unchanged(); + + for region in regions { + live_map.propagate_region_liveness(®ion.borrow()); + } + + if !live_map.has_changed() { + break; + } + } + + Self::cleanup_dead_code(regions, rewriter, &live_map) + } + + /// Erase the unreachable blocks within the regions in `regions`. + /// + /// Returns `Ok` if any blocks were erased, `Err` otherwise. + pub fn erase_unreachable_blocks( + regions: &[RegionRef], + rewriter: &mut dyn crate::Rewriter, + ) -> Result<(), RegionTransformFailed> { + let mut erased_dead_blocks = false; + let mut reachable = BTreeSet::default(); + let mut worklist = VecDeque::from_iter(regions.iter().cloned()); + while let Some(mut region) = worklist.pop_front() { + let mut current_region = region.borrow_mut(); + let blocks = current_region.body_mut(); + if blocks.is_empty() { + continue; + } + + // If this is a single block region, just collect nested regions. + if blocks.front().as_pointer() == blocks.back().as_pointer() { + for op in blocks.front().get().unwrap().body() { + worklist.extend(op.regions().iter().map(|r| r.as_region_ref())); + } + continue; + } + + // Mark all reachable blocks. 
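+            // A post-order walk starting from the entry block visits exactly the blocks that are
+            // reachable from it; any block left unvisited is unreachable and can be erased.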
+ reachable.clear(); + let iter = PostOrderBlockIter::new(blocks.front().as_pointer().unwrap()); + reachable.extend(iter); + + // Collect all of the dead blocks and push the live regions on the worklist + let mut cursor = blocks.front_mut(); + cursor.move_next(); + while let Some(mut block) = cursor.as_pointer() { + if reachable.contains(&block) { + // Walk any regions within this block + for op in block.borrow().body() { + worklist.extend(op.regions().iter().map(|r| r.as_region_ref())); + } + continue; + } + + // The block is unreachable, erase it + block.borrow_mut().drop_all_defined_value_uses(); + rewriter.erase_block(block); + erased_dead_blocks = true; + } + } + + if erased_dead_blocks { + Ok(()) + } else { + Err(RegionTransformFailed) + } + } + + fn cleanup_dead_code( + regions: &[RegionRef], + rewriter: &mut dyn Rewriter, + live_map: &LiveMap, + ) -> Result<(), RegionTransformFailed> { + let mut erased_anything = false; + for region in regions { + let current_region = region.borrow(); + if current_region.body().is_empty() { + continue; + } + + let has_single_block = current_region.has_one_block(); + + // Delete every operation that is not live. Graph regions may have cycles in the use-def + // graph, so we must explicitly drop all uses from each operation as we erase it. + // Visiting the operations in post-order guarantees that in SSA CFG regions, value uses + // are removed before defs, which makes `drop_all_uses` a no-op. + let iter = PostOrderBlockIter::new(current_region.entry_block_ref().unwrap()); + for block in iter { + if !has_single_block { + Self::erase_terminator_successor_operands( + block.borrow().terminator().expect("expected block to have terminator"), + live_map, + ); + } + let mut next_op = block.borrow().body().back().as_pointer(); + while let Some(mut child_op) = next_op.take() { + next_op = child_op.prev(); + if !live_map.was_op_proven_live(&child_op) { + erased_anything = true; + child_op.borrow_mut().drop_all_uses(); + rewriter.erase_op(child_op); + } else { + let child_regions = child_op + .borrow() + .regions() + .iter() + .map(|r| r.as_region_ref()) + .collect::>(); + erased_anything |= + Self::cleanup_dead_code(&child_regions, rewriter, live_map).is_ok(); + } + } + } + + // Delete block arguments. + // + // The entry block has an unknown contract with their enclosing block, so leave it alone. 
+ let mut region = region.clone(); + let mut current_region = region.borrow_mut(); + let mut blocks = current_region.body_mut().front_mut(); + while let Some(mut block) = blocks.as_pointer() { + blocks.move_next(); + block + .borrow_mut() + .erase_arguments(|arg| !live_map.was_proven_live(&arg.as_value_ref())); + } + } + + if erased_anything { + Ok(()) + } else { + Err(RegionTransformFailed) + } + } + + fn erase_terminator_successor_operands(mut terminator: OperationRef, live_map: &LiveMap) { + let mut op = terminator.borrow_mut(); + if !op.implements::() { + return; + } + + // Iterate successors in reverse to minimize the amount of operand shifting + for succ_index in (0..op.num_successors()).rev() { + let mut succ = op.successor_mut(succ_index); + // Iterate arguments in reverse so that erasing an argument does not shift the others + let num_arguments = succ.arguments.len(); + for arg_index in (0..num_arguments).rev() { + if !live_map.was_proven_live(&succ.arguments[arg_index].borrow().as_value_ref()) { + succ.arguments.erase(arg_index); + } + } + } + } +} diff --git a/hir2/src/ir/region/transforms/drop_redundant_args.rs b/hir2/src/ir/region/transforms/drop_redundant_args.rs new file mode 100644 index 000000000..f31086ecd --- /dev/null +++ b/hir2/src/ir/region/transforms/drop_redundant_args.rs @@ -0,0 +1,151 @@ +use smallvec::SmallVec; + +use super::RegionTransformFailed; +use crate::{ + traits::BranchOpInterface, BlockArgumentRef, BlockRef, Region, RegionRef, Rewriter, + SuccessorOperands, Usable, +}; + +impl Region { + /// This optimization drops redundant argument to blocks. I.e., if a given argument to a block + /// receives the same value from each of the block predecessors, we can remove the argument from + /// the block and use directly the original value. 
+ /// + /// ## Example + /// + /// A simple example: + /// + /// ```hir,ignore + /// %cond = llvm.call @rand() : () -> i1 + /// %val0 = llvm.mlir.constant(1 : i64) : i64 + /// %val1 = llvm.mlir.constant(2 : i64) : i64 + /// %val2 = llvm.mlir.constant(3 : i64) : i64 + /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2 + /// : i64) + /// + /// ^bb1(%arg0 : i64, %arg1 : i64): + /// llvm.call @foo(%arg0, %arg1) + /// ``` + /// + /// That can be rewritten as: + /// + /// ```hir,ignore + /// %cond = llvm.call @rand() : () -> i1 + /// %val0 = llvm.mlir.constant(1 : i64) : i64 + /// %val1 = llvm.mlir.constant(2 : i64) : i64 + /// %val2 = llvm.mlir.constant(3 : i64) : i64 + /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64) + /// + /// ^bb1(%arg0 : i64): + /// llvm.call @foo(%val0, %arg0) + /// ``` + pub(in crate::ir::region) fn drop_redundant_arguments( + regions: &[RegionRef], + rewriter: &mut dyn Rewriter, + ) -> Result<(), RegionTransformFailed> { + let mut worklist = SmallVec::<[RegionRef; 1]>::from_iter(regions.iter().cloned()); + + let mut any_changed = false; + while let Some(region) = worklist.pop() { + // Add any nested regions to the worklist + let region = region.borrow(); + let mut blocks = region.body().front(); + while let Some(block) = blocks.as_pointer() { + blocks.move_next(); + + any_changed |= + Self::drop_redundant_block_arguments(block.clone(), rewriter).is_ok(); + + for op in block.borrow().body() { + let mut regions = op.regions().front(); + while let Some(region) = regions.as_pointer() { + worklist.push(region); + regions.move_next(); + } + } + } + } + + if any_changed { + Ok(()) + } else { + Err(RegionTransformFailed) + } + } + + /// If a block's argument is always the same across different invocations, then + /// drop the argument and use the value directly inside the block + fn drop_redundant_block_arguments( + mut block: BlockRef, + rewriter: &mut dyn Rewriter, + ) -> Result<(), RegionTransformFailed> { + let mut args_to_erase = SmallVec::<[usize; 4]>::default(); + + // Go through the arguments of the block. + let mut block_mut = block.borrow_mut(); + let block_args = + SmallVec::<[BlockArgumentRef; 2]>::from_iter(block_mut.arguments().iter().cloned()); + for (arg_index, block_arg) in block_args.into_iter().enumerate() { + let mut same_arg = true; + let mut common_value = None; + + // Go through the block predecessor and flag if they pass to the block different values + // for the same argument. + for pred in block_mut.predecessors() { + let pred_op = pred.owner.borrow(); + if let Some(branch_op) = pred_op.as_trait::() { + let succ_index = pred.index as usize; + let succ_operands = branch_op.get_successor_operands(succ_index); + let branch_operands = succ_operands.forwarded(); + let arg = branch_operands[arg_index].borrow().as_value_ref(); + if common_value.is_none() { + common_value = Some(arg); + continue; + } + if common_value.as_ref().is_some_and(|cv| cv != &arg) { + same_arg = false; + break; + } + } else { + same_arg = false; + break; + } + } + + // If they are passing the same value, drop the argument. + if let Some(common_value) = common_value { + if same_arg { + args_to_erase.push(arg_index); + + // Remove the argument from the block. + rewriter.replace_all_uses_of_value_with(block_arg, common_value); + } + } + } + + // Remove the arguments. + for arg_index in args_to_erase.iter().copied() { + block_mut.erase_argument(arg_index); + + // Remove the argument from the branch ops. 
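+            // Every predecessor terminator forwards one operand per block argument, so the operand
+            // at `arg_index` must also be erased from each predecessor's forwarded operand group.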
+ let mut preds = block_mut.uses_mut().front_mut(); + while let Some(mut pred) = preds.as_pointer() { + preds.move_next(); + + let mut pred = pred.borrow_mut(); + let mut pred_op = pred.owner.borrow_mut(); + if let Some(branch_op) = pred_op.as_trait_mut::() { + let succ_index = pred.index as usize; + let mut succ_operands = branch_op.get_successor_operands_mut(succ_index); + succ_operands.forwarded_mut().erase(arg_index); + } + } + } + + if !args_to_erase.is_empty() { + Ok(()) + } else { + Err(RegionTransformFailed) + } + } +} diff --git a/hir2/src/ir/successor.rs b/hir2/src/ir/successor.rs index 4c50f3dd6..68668595d 100644 --- a/hir2/src/ir/successor.rs +++ b/hir2/src/ir/successor.rs @@ -7,6 +7,201 @@ pub type OpSuccessorStorage = crate::EntityStorage; pub type OpSuccessorRange<'a> = crate::EntityRange<'a, SuccessorInfo>; pub type OpSuccessorRangeMut<'a> = crate::EntityRangeMut<'a, SuccessorInfo, 0>; +/// This trait represents common behavior shared by any range of successor operands. +pub trait SuccessorOperands { + /// Returns true if there are no operands in this set + fn is_empty(&self) -> bool { + self.num_produced() == 0 && self.len() == 0 + } + /// Returns the total number of operands in this set + fn len(&self) -> usize; + /// Returns the number of internally produced operands in this set + fn num_produced(&self) -> usize; + /// Returns true if the operand at `index` is internally produced + #[inline] + fn is_operand_produced(&self, index: usize) -> bool { + index < self.num_produced() + } + /// Get the range of forwarded operands + fn forwarded(&self) -> OpOperandRange<'_>; + /// Get a [SuccessorOperand] representing the operand at `index` + /// + /// Returns `None` if the index is out of range. + fn get(&self, index: usize) -> Option { + if self.is_operand_produced(index) { + Some(SuccessorOperand::Produced) + } else { + self.forwarded() + .get(index) + .map(|op_operand| SuccessorOperand::Forwarded(op_operand.borrow().as_value_ref())) + } + } + + /// Get a [SuccessorOperand] representing the operand at `index`. + /// + /// Panics if the index is out of range. + fn get_unchecked(&self, index: usize) -> SuccessorOperand { + if self.is_operand_produced(index) { + SuccessorOperand::Produced + } else { + SuccessorOperand::Forwarded(self.forwarded()[index].borrow().as_value_ref()) + } + } + + /// Gets the index of the forwarded operand which maps to the given successor block argument + /// index. + /// + /// Panics if the given block argument index does not correspond to a forwarded operand. + fn get_operand_index(&self, block_argument_index: usize) -> usize { + assert!( + self.is_operand_produced(block_argument_index), + "cannot map operands produced by the operation" + ); + let base_index = self.forwarded()[0].borrow().index as usize; + base_index + (block_argument_index - self.num_produced()) + } +} + +/// This type models how operands are forwarded to block arguments in control flow. It consists of a +/// number, denoting how many of the successor block arguments are produced by the operation, +/// followed by a range of operands that are forwarded. The produced operands are passed to the +/// first few block arguments of the successor, followed by the forwarded operands. It is +/// unsupported to pass them in a different order. 
+/// +/// An example operation with both of these concepts would be a branch-on-error operation, that +/// internally produces an error object on the error path: +/// +/// ```hir,ignore +/// invoke %function(%0) +/// label ^success ^error(%1 : i32) +/// +/// ^error(%e: !error, %arg0 : i32): +/// ... +/// ``` +/// +/// This operation would return an instance of [SuccessorOperands] with a produced operand count of +/// 1 (mapped to `%e` in the successor) and forwarded operands consisting of `%1` in the example +/// above (mapped to `%arg0` in the successor). +pub struct SuccessorOperandRange<'a> { + /// The explicit op operands which are to be passed along to the successor + forwarded: OpOperandRange<'a>, + /// The number of operands that are produced internally within the operation and which are to + /// be passed to the successor before any forwarded operands. + num_produced: usize, +} +impl<'a> SuccessorOperandRange<'a> { + /// Create an empty successor operand set + pub fn empty() -> Self { + Self { + forwarded: OpOperandRange::empty(), + num_produced: 0, + } + } + + /// Create a successor operand set consisting solely of forwarded op operands + #[inline] + pub const fn forward(forwarded: OpOperandRange<'a>) -> Self { + Self { + forwarded, + num_produced: 0, + } + } + + /// Create a successor operand set consisting solely of `num_produced` internally produced + /// results + pub fn produced(num_produced: usize) -> Self { + Self { + forwarded: OpOperandRange::empty(), + num_produced, + } + } + + /// Create a new successor operand set with the given number of internally produced results, + /// and forwarded op operands. + #[inline] + pub const fn new(num_produced: usize, forwarded: OpOperandRange<'a>) -> Self { + Self { + forwarded, + num_produced, + } + } +} +impl<'a> SuccessorOperands for SuccessorOperandRange<'a> { + #[inline] + fn len(&self) -> usize { + self.num_produced + self.forwarded.len() + } + + #[inline(always)] + fn num_produced(&self) -> usize { + self.num_produced + } + + fn forwarded(&self) -> OpOperandRange<'_> { + self.forwarded.clone() + } +} + +/// The mutable variant of [SuccessorOperandsRange]. +pub struct SuccessorOperandRangeMut<'a> { + /// The explicit op operands which are to be passed along to the successor + forwarded: OpOperandRangeMut<'a>, + /// The number of operands that are produced internally within the operation and which are to + /// be passed to the successor before any forwarded operands. + num_produced: usize, +} +impl<'a> SuccessorOperandRangeMut<'a> { + /// Create a successor operand set consisting solely of forwarded op operands + #[inline] + pub const fn forward(forwarded: OpOperandRangeMut<'a>) -> Self { + Self { + forwarded, + num_produced: 0, + } + } + + /// Create a new successor operand set with the given number of internally produced results, + /// and forwarded op operands. + #[inline] + pub const fn new(num_produced: usize, forwarded: OpOperandRangeMut<'a>) -> Self { + Self { + forwarded, + num_produced, + } + } + + #[inline(always)] + pub fn forwarded_mut(&mut self) -> &mut OpOperandRangeMut<'a> { + &mut self.forwarded + } +} +impl<'a> SuccessorOperands for SuccessorOperandRangeMut<'a> { + #[inline] + fn len(&self) -> usize { + self.num_produced + self.forwarded.len() + } + + #[inline(always)] + fn num_produced(&self) -> usize { + self.num_produced + } + + #[inline(always)] + fn forwarded(&self) -> OpOperandRange<'_> { + self.forwarded.as_immutable() + } +} + +/// Represents an operand in a [SuccessorOperands] set. 
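+// Illustrative sketch only (not part of this change): for the `invoke`-style example above, a
+// range with one internally produced operand (`%e`) and one forwarded operand (`%1`) behaves as
+// follows, where `forwarded` is a hypothetical `OpOperandRange` covering just `%1`:
+//
+//     let operands = SuccessorOperandRange::new(1, forwarded);
+//     assert_eq!(operands.num_produced(), 1);     // `%e` is produced by `invoke` itself
+//     assert_eq!(operands.len(), 2);              // one produced + one forwarded operand
+//     assert!(operands.is_operand_produced(0));   // successor argument 0 (`%e`)
+//     assert!(!operands.is_operand_produced(1));  // successor argument 1 (`%arg0`), fed by `%1`
+//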
+#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SuccessorOperand { + /// This operand is internally produced by the operation, and passed to the successor before + /// any forwarded op operands. + Produced, + /// This operand is a forwarded operand of the operation. + Forwarded(crate::ValueRef), +} + /// This trait represents successor-like values for operations, with support for control-flow /// predicated on a "key", a sentinel value that must match in order for the successor block to be /// taken. @@ -134,6 +329,13 @@ impl<'a, T> KeyedSuccessorRange<'a, T> { } }) } + + pub fn iter(&self) -> KeyedSuccessorRangeIter<'a, '_, T> { + KeyedSuccessorRangeIter { + range: self, + index: 0, + } + } } pub struct KeyedSuccessorRangeMut<'a, T> { @@ -173,6 +375,24 @@ impl<'a, T> KeyedSuccessorRangeMut<'a, T> { } } +pub struct KeyedSuccessorRangeIter<'a, 'b: 'a, T> { + range: &'b KeyedSuccessorRange<'a, T>, + index: usize, +} +impl<'a, 'b: 'a, T> Iterator for KeyedSuccessorRangeIter<'a, 'b, T> { + type Item = SuccessorWithKey<'b, T>; + + fn next(&mut self) -> Option { + if self.index >= self.range.range.len() { + return None; + } + + let idx = self.index; + self.index += 1; + self.range.get(idx) + } +} + pub struct SuccessorWithKey<'a, T> { info: &'a SuccessorInfo, operands: OpOperandRange<'a>, diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs index d5359b499..18920e1d2 100644 --- a/hir2/src/ir/traits.rs +++ b/hir2/src/ir/traits.rs @@ -5,8 +5,11 @@ mod types; use midenc_session::diagnostics::Severity; pub(crate) use self::info::TraitInfo; -pub use self::types::*; -use crate::{derive, Context, Operation, Report, Spanned}; +pub use self::{ + foldable::{FoldResult, Foldable, OpFoldResult}, + types::*, +}; +use crate::{derive, AttributeValue, Context, Operation, Report, Spanned}; /// Marker trait for commutative ops, e.g. `X op Y == Y op X` pub trait Commutative {} @@ -59,42 +62,117 @@ pub trait HasOnlyGraphRegion {} /// /// This trait _cannot_ be derived via `derive!` pub trait GraphRegionNoTerminator: - NoTerminator + SingleBlock + RegionKindInterface + HasOnlyGraphRegion + NoTerminator + SingleBlock + crate::RegionKindInterface + HasOnlyGraphRegion { } -/// Represents the types of regions that can be represented in the IR -#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] -pub enum RegionKind { - /// A graph region is one without control-flow semantics, i.e. dataflow between operations is - /// the only thing that dictates order, and operations can be conceptually executed in parallel - /// if the runtime supports it. +// TODO(pauls): Implement verifier +/// This interface provides information for branching terminator operations, i.e. terminator +/// operations with successors. +/// +/// This interface is meant to model well-defined cases of control-flow of value propagation, where +/// what occurs along control-flow edges is assumed to be side-effect free. For example, +/// corresponding successor operands and successor block arguments may have different types. In such +/// cases, `are_types_compatible` can be implemented to compare types along control-flow edges. By +/// default, type equality is used. +pub trait BranchOpInterface: crate::Op { + /// Returns the operands that correspond to the arguments of the successor at `index`. /// - /// As there is no control-flow in these regions, graph regions may only contain a single block. - Graph, - /// An SSA region is one where the strict control-flow semantics and properties of SSA (static - /// single assignment) form must be upheld. 
+ /// It consists of a number of operands that are internally produced by the operation, followed + /// by a range of operands that are forwarded. An example operation making use of produced + /// operands would be: /// - /// SSA regions must adhere to: + /// ```hir,ignore + /// invoke %function(%0) + /// label ^success ^error(%1 : i32) /// - /// * Values can only be defined once - /// * Definitions must dominate uses - /// * Ordering of operations in a block corresponds to execution order, i.e. operations earlier - /// in a block dominate those later in the block. - /// * Blocks must end with a terminator. - #[default] - SSA, -} + /// ^error(%e: !error, %arg0: i32): + /// ... + ///``` + /// + /// The operand that would map to the `^error`s `%e` operand is produced by the `invoke` + /// operation, while `%1` is a forwarded operand that maps to `%arg0` in the successor. + /// + /// Produced operands always map to the first few block arguments of the successor, followed by + /// the forwarded operands. Mapping them in any other order is not supported by the interface. + /// + /// By having the forwarded operands last allows users of the interface to append more forwarded + /// operands to the branch operation without interfering with other successor operands. + fn get_successor_operands(&self, index: usize) -> crate::SuccessorOperandRange<'_> { + let op = ::as_operation(self); + let operand_group = op.successors()[index].operand_group as usize; + crate::SuccessorOperandRange::forward(op.operands().group(operand_group)) + } + /// The mutable version of [Self::get_successor_operands]. + fn get_successor_operands_mut(&mut self, index: usize) -> crate::SuccessorOperandRangeMut<'_> { + let op = ::as_operation_mut(self); + let operand_group = op.successors()[index].operand_group as usize; + crate::SuccessorOperandRangeMut::forward(op.operands_mut().group_mut(operand_group)) + } + /// Returns the block argument of the successor corresponding to the operand at `operand_index`. + /// + /// Returns `None` if the specified operand is not a successor operand. + fn get_successor_block_argument( + &self, + operand_index: usize, + ) -> Option { + let op = ::as_operation(self); + let operand_groups = op.operands().num_groups(); + let mut next_index = 0usize; + for operand_group in 0..operand_groups { + let group_size = op.operands().group(operand_group).len(); + if (next_index..(next_index + group_size)).contains(&operand_index) { + let arg_index = operand_index - next_index; + // We found the operand group, now map that to a successor + let succ_info = + op.successors().iter().find(|s| operand_group == s.operand_group as usize)?; + return succ_info.block.borrow().block.borrow().arguments().get(arg_index).cloned(); + } + + next_index += group_size; + } -/// An op interface that indicates what types of regions it holds -pub trait RegionKindInterface { - /// Get the [RegionKind] for this operation - fn kind(&self) -> RegionKind; - /// Returns true if the kind of this operation's regions requires SSA dominance + None + } + /// Returns the successor that would be chosen with the given constant operands. + /// + /// Returns `None` if a single successor could not be chosen. #[inline] - fn has_ssa_dominance(&self) -> bool { - matches!(self.kind(), RegionKind::SSA) + #[allow(unused_variables)] + fn get_successor_for_operands( + &self, + operands: &[Box], + ) -> Option { + None } + /// This is called to compare types along control-flow edges. + /// + /// By default, types must be exactly equal to be compatible. 
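As a usage sketch, this is the predecessor-rewrite pattern that the block-argument erasure earlier in this patch builds on top of these accessors. The `dyn BranchOpInterface` type argument and the import paths are assumptions reconstructed from that transform, not verbatim API:

```rust
use midenc_hir2::{Block, BranchOpInterface};

/// For every predecessor terminator implementing `BranchOpInterface`, drop the forwarded
/// operand that fed the block argument being erased at `arg_index`.
fn erase_forwarded_operand(block: &mut Block, arg_index: usize) {
    let mut preds = block.uses_mut().front_mut();
    while let Some(pred) = preds.as_pointer() {
        preds.move_next();
        let mut pred = pred.borrow_mut();
        let mut pred_op = pred.owner.borrow_mut();
        if let Some(branch_op) = pred_op.as_trait_mut::<dyn BranchOpInterface>() {
            let succ_index = pred.index as usize;
            let mut succ_operands = branch_op.get_successor_operands_mut(succ_index);
            succ_operands.forwarded_mut().erase(arg_index);
        }
    }
}
```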
+ fn are_types_compatible(&self, lhs: &crate::Type, rhs: &crate::Type) -> bool { + lhs == rhs + } +} + +/// This interface provides information for select-like operations, i.e., operations that forward +/// specific operands to the output, depending on a binary condition. +/// +/// If the value of the condition is 1, then the `true` operand is returned, and the third operand +/// is ignored, even if it was poison. +/// +/// If the value of the condition is 0, then the `false` operand is returned, and the second operand +/// is ignored, even if it was poison. +/// +/// If the condition is poison, then poison is returned. +/// +/// Implementing operations can also accept shaped conditions, in which case the operation works +/// element-wise. +pub trait SelectLikeOpInterface { + /// Returns the operand that represents the boolean condition for this select-like op. + fn get_condition(&self) -> crate::ValueRef; + /// Returns the operand that would be chosen for a true condition. + fn get_true_value(&self) -> crate::ValueRef; + /// Returns the operand that would be chosen for a false condition. + fn get_false_value(&self) -> crate::ValueRef; } derive! { diff --git a/hir2/src/ir/visit.rs b/hir2/src/ir/visit.rs index 0cef13b17..64fe09412 100644 --- a/hir2/src/ir/visit.rs +++ b/hir2/src/ir/visit.rs @@ -1,43 +1,17 @@ +mod blocks; +mod searcher; +mod visitor; +mod walkable; + pub use core::ops::ControlFlow; -use crate::{ - Block, BlockRef, Op, Operation, OperationRef, Region, RegionRef, Report, Symbol, - UnsafeIntrusiveEntityRef, +pub use self::{ + blocks::{BlockIter, PostOrderBlockIter}, + searcher::Searcher, + visitor::{OpVisitor, OperationVisitor, SymbolVisitor, Visitor}, + walkable::{WalkOrder, WalkStage, Walkable}, }; - -/// A generic trait that describes visitors for all kinds -pub trait Visitor { - /// The type of output produced by visiting an item. - type Output; - - /// The function which is applied to each `T` as it is visited. - fn visit(&mut self, current: &T) -> WalkResult; -} - -/// We can automatically convert any closure of appropriate type to a `Visitor` -impl Visitor for F -where - F: FnMut(&T) -> WalkResult, -{ - type Output = U; - - #[inline] - fn visit(&mut self, op: &T) -> WalkResult { - self(op) - } -} - -/// Represents a visitor over [Operation] -pub trait OperationVisitor: Visitor {} -impl OperationVisitor for V where V: Visitor {} - -/// Represents a visitor over [Op] of type `T` -pub trait OpVisitor: Visitor {} -impl OpVisitor for V where V: Visitor {} - -/// Represents a visitor over [Symbol] -pub trait SymbolVisitor: Visitor {} -impl SymbolVisitor for V where V: Visitor {} +use crate::Report; /// A result-like type used to control traversals of a [Walkable] entity. /// @@ -128,369 +102,3 @@ impl core::ops::Try for WalkResult { } } } - -/// The traversal order for a walk of a region, block, or operation -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum WalkOrder { - PreOrder, - PostOrder, -} - -/// Encodes the current walk stage for generic walkers. -/// -/// When walking an operation, we can either choose a pre- or post-traversal walker which invokes -/// the callback on an operation before/after all its attached regions have been visited, or choose -/// a generic walker where the callback is invoked on the operation N+1 times, where N is the number -/// of regions attached to that operation. [WalkStage] encodes the current stage of the walk, i.e. 
-/// which regions have already been visited, and the callback accepts an additional argument for -/// the current stage. Such generic walkers that accept stage-aware callbacks are only applicable -/// when the callback operations on an operation (i.e. doesn't apply to callbacks on blocks or -/// regions). -#[derive(Clone, PartialEq, Eq)] -pub struct WalkStage { - /// The number of regions in the operation - num_regions: usize, - /// The next region to visit in the operation - next_region: usize, -} -impl WalkStage { - pub fn new(op: OperationRef) -> Self { - let op = op.borrow(); - Self { - num_regions: op.num_regions(), - next_region: 0, - } - } - - /// Returns true if the parent operation is being visited before all regions. - #[inline] - pub fn is_before_all_regions(&self) -> bool { - self.next_region == 0 - } - - /// Returns true if the parent operation is being visited just before visiting `region` - #[inline] - pub fn is_before_region(&self, region: usize) -> bool { - self.next_region == region - } - - /// Returns true if the parent operation is being visited just after visiting `region` - #[inline] - pub fn is_after_region(&self, region: usize) -> bool { - self.next_region == region + 1 - } - - /// Returns true if the parent operation is being visited after all regions. - #[inline] - pub fn is_after_all_regions(&self) -> bool { - self.next_region == self.num_regions - } - - /// Advance the walk stage - #[inline] - pub fn advance(&mut self) { - self.next_region += 1; - } - - /// Returns the next region that will be visited - #[inline(always)] - pub const fn next_region(&self) -> usize { - self.next_region - } -} - -/// A [Walkable] is an entity which can be traversed depth-first in either pre- or post-order -/// -/// An implementation of this trait specifies a type, `T`, corresponding to the type of item being -/// walked, while `Self` is the root entity, possibly of the same type, which may contain `T`. Thus -/// traversing from the root to all of the leaves, we will visit all reachable `T` nested within -/// `Self`, possibly including itself. -pub trait Walkable { - /// Walk all `T` in `self` in a specific order, applying the given callback to each. - /// - /// This is very similar to [Walkable::walk_interruptible], except the callback has no control - /// over the traversal, and must be infallible. - #[inline] - fn walk(&self, order: WalkOrder, mut callback: F) - where - F: FnMut(UnsafeIntrusiveEntityRef), - { - let _ = self.walk_interruptible(order, |t| { - callback(t); - - WalkResult::<()>::Continue(()) - }); - } - - /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback - /// to each `T`. - #[inline] - fn prewalk(&self, mut callback: F) - where - F: FnMut(UnsafeIntrusiveEntityRef), - { - let _ = self.prewalk_interruptible(|t| { - callback(t); - - WalkResult::<()>::Continue(()) - }); - } - - /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback - /// to each `T`. - #[inline] - fn postwalk(&self, mut callback: F) - where - F: FnMut(UnsafeIntrusiveEntityRef), - { - let _ = self.postwalk_interruptible(|t| { - callback(t); - - WalkResult::<()>::Continue(()) - }); - } - - /// Walk `self` in the given order, visiting each `T` and applying the given callback to them. 
- /// - /// The given callback can control the traversal using the [WalkResult] it returns: - /// - /// * `WalkResult::Skip` will skip the walk of the current item and its nested elements that - /// have not been visited already, continuing with the next item. - /// * `WalkResult::Break` will interrupt the walk, and no more items will be visited - /// * `WalkResult::Continue` will continue the walk - #[inline] - fn walk_interruptible(&self, order: WalkOrder, callback: F) -> WalkResult - where - F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult, - { - match order { - WalkOrder::PreOrder => self.prewalk_interruptible(callback), - WalkOrder::PostOrder => self.prewalk_interruptible(callback), - } - } - - /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback - /// to each `T`, and determining how to proceed based on the returned [WalkResult]. - fn prewalk_interruptible(&self, callback: F) -> WalkResult - where - F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; - - /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback - /// to each `T`, and determining how to proceed based on the returned [WalkResult]. - fn postwalk_interruptible(&self, callback: F) -> WalkResult - where - F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; -} - -/// Walking regions of an [Operation], and those of all nested operations -impl Walkable for Operation { - fn prewalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(RegionRef) -> WalkResult, - { - let mut regions = self.regions().front(); - while let Some(region) = regions.as_pointer() { - regions.move_next(); - match callback(region.clone()) { - WalkResult::Continue(_) => { - let region = region.borrow(); - for block in region.body().iter() { - for op in block.body().iter() { - op.prewalk_interruptible(&mut callback)?; - } - } - } - WalkResult::Skip => continue, - result @ WalkResult::Break(_) => return result, - } - } - - WalkResult::Continue(()) - } - - fn postwalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(RegionRef) -> WalkResult, - { - let mut regions = self.regions().front(); - while let Some(region) = regions.as_pointer() { - regions.move_next(); - { - let region = region.borrow(); - for block in region.body().iter() { - for op in block.body().iter() { - op.postwalk_interruptible(&mut callback)?; - } - } - } - callback(region.clone())?; - } - - WalkResult::Continue(()) - } -} - -/// Walking blocks of an [Operation], and those of all nested operations -impl Walkable for Operation { - fn prewalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(BlockRef) -> WalkResult, - { - for region in self.regions().iter() { - let mut blocks = region.body().front(); - while let Some(block) = blocks.as_pointer() { - blocks.move_next(); - match callback(block.clone()) { - WalkResult::Continue(_) => { - let block = block.borrow(); - for op in block.body().iter() { - op.prewalk_interruptible(&mut callback)?; - } - } - WalkResult::Skip => continue, - result @ WalkResult::Break(_) => return result, - } - } - } - - WalkResult::Continue(()) - } - - fn postwalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(BlockRef) -> WalkResult, - { - for region in self.regions().iter() { - let mut blocks = region.body().front(); - while let Some(block) = blocks.as_pointer() { - blocks.move_next(); - { - let block = block.borrow(); - for op in block.body().iter() { - op.postwalk_interruptible(&mut callback)?; - } - } - 
callback(block.clone())?; - } - } - - WalkResult::Continue(()) - } -} - -/// Walking operations nested within an [Operation], including itself -impl Walkable for Operation { - fn prewalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(OperationRef) -> WalkResult, - { - prewalk_operation_interruptible(self, &mut callback) - } - - fn postwalk_interruptible(&self, mut callback: F) -> WalkResult - where - F: FnMut(OperationRef) -> WalkResult, - { - postwalk_operation_interruptible(self, &mut callback) - } -} - -fn prewalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult -where - F: FnMut(OperationRef) -> WalkResult, -{ - let result = callback(op.as_operation_ref()); - if !result.should_continue() { - return result; - } - - for region in op.regions().iter() { - for block in region.body().iter() { - let mut ops = block.body().front(); - while let Some(op) = ops.as_pointer() { - ops.move_next(); - let op = op.borrow(); - prewalk_operation_interruptible(&op, callback)?; - } - } - } - - WalkResult::Continue(()) -} - -fn postwalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult -where - F: FnMut(OperationRef) -> WalkResult, -{ - for region in op.regions().iter() { - for block in region.body().iter() { - let mut ops = block.body().front(); - while let Some(op) = ops.as_pointer() { - ops.move_next(); - let op = op.borrow(); - postwalk_operation_interruptible(&op, callback)?; - } - } - } - - callback(op.as_operation_ref()) -} - -/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation]. -/// -/// The searcher traverses the object graph in depth-first preorder, from operations to regions to -/// blocks to operations, etc. All nested items of an entity are visited before its siblings, i.e. -/// a region is fully visited before the next region of the same containing operation. -/// -/// This is effectively control-flow order, from an abstract interpretation perspective, i.e. an -/// actual program might only execute one region of a multi-region op, but this traversal will visit -/// all of them unless otherwise directed by a `WalkResult`. -pub struct Searcher { - visitor: V, - root: OperationRef, - _marker: core::marker::PhantomData, -} -impl> Searcher { - pub fn new(root: OperationRef, visitor: V) -> Self { - Self { - visitor, - root, - _marker: core::marker::PhantomData, - } - } -} - -impl Searcher { - pub fn visit(&mut self) -> WalkResult<>::Output> { - self.root.borrow().prewalk_interruptible(|op: OperationRef| { - let op = op.borrow(); - self.visitor.visit(&op) - }) - } -} - -impl> Searcher { - pub fn visit(&mut self) -> WalkResult<>::Output> { - self.root.borrow().prewalk_interruptible(|op: OperationRef| { - let op = op.borrow(); - if let Some(op) = op.downcast_ref::() { - self.visitor.visit(op) - } else { - WalkResult::Continue(()) - } - }) - } -} - -impl Searcher { - pub fn visit(&mut self) -> WalkResult<>::Output> { - self.root.borrow().prewalk_interruptible(|op: OperationRef| { - let op = op.borrow(); - if let Some(sym) = op.as_symbol() { - self.visitor.visit(sym) - } else { - WalkResult::Continue(()) - } - }) - } -} diff --git a/hir2/src/ir/visit/blocks.rs b/hir2/src/ir/visit/blocks.rs new file mode 100644 index 000000000..54b7e69c2 --- /dev/null +++ b/hir2/src/ir/visit/blocks.rs @@ -0,0 +1,102 @@ +use alloc::collections::BTreeSet; + +use crate::BlockRef; + +#[allow(unused_variables)] +pub trait BlockVisitor { + /// Called when a block is first reached during a depth-first traversal, i.e. 
called in preorder + /// + /// If this function returns `false`, none of `block`'s children will be visited. This can be + /// used to prune the traversal, e.g. confining a visit to a specific loop in the CFG. + fn on_block_reached(&mut self, from: Option<&BlockRef>, block: &BlockRef) -> bool { + true + } + + /// Called when all children of a block have been visited by the depth-first traversal, i.e. + /// called in postorder. + fn on_block_visited(&mut self, block: &BlockRef) {} +} + +impl BlockVisitor for () {} + +#[repr(transparent)] +pub struct PostOrderBlockIter(BlockIter<()>); +impl PostOrderBlockIter { + #[inline] + pub fn new(root: BlockRef) -> Self { + Self(BlockIter::new(root, ())) + } +} +impl core::iter::FusedIterator for PostOrderBlockIter {} +impl Iterator for PostOrderBlockIter { + type Item = BlockRef; + + #[inline(always)] + fn next(&mut self) -> Option { + self.0.next() + } +} + +pub struct BlockIter { + visited: BTreeSet, + // First element is the basic block, second is the index of the next child to visit, third is the number of children + stack: Vec<(BlockRef, usize, usize)>, + visitor: V, +} + +impl BlockIter { + pub fn new(from: BlockRef, visitor: V) -> Self { + let mut this = Self { + visited: Default::default(), + stack: Default::default(), + visitor, + }; + this.insert_edge(None, from.clone()); + let num_successors = from.borrow().num_successors(); + this.stack.push((from, 0, num_successors)); + this.traverse_child(); + this + } + + /// Returns true if the target of the given edge should be visited. + /// + /// Called with `None` for `from` when adding the root node. + fn insert_edge(&mut self, from: Option, to: BlockRef) -> bool { + let should_visit = self.visitor.on_block_reached(from.as_ref(), &to); + let unvisited = self.visited.insert(to); + unvisited && should_visit + } + + fn traverse_child(&mut self) { + loop { + let Some((entry, index, max)) = self.stack.last_mut() else { + break; + }; + if index == max { + break; + } + let successor = entry.borrow().get_successor(*index); + *index += 1; + let entry = entry.clone(); + if self.insert_edge(Some(entry), successor.clone()) { + // If the block is not visited.. + let num_successors = successor.borrow().num_successors(); + self.stack.push((successor, 0, num_successors)); + } + } + } +} + +impl core::iter::FusedIterator for BlockIter {} +impl Iterator for BlockIter { + type Item = BlockRef; + + fn next(&mut self) -> Option { + let (next, ..) = self.stack.pop()?; + self.visitor.on_block_visited(&next); + if !self.stack.is_empty() { + self.traverse_child(); + } + Some(next) + } +} diff --git a/hir2/src/ir/visit/searcher.rs b/hir2/src/ir/visit/searcher.rs new file mode 100644 index 000000000..138c369b6 --- /dev/null +++ b/hir2/src/ir/visit/searcher.rs @@ -0,0 +1,61 @@ +use super::{OpVisitor, OperationVisitor, SymbolVisitor, Visitor, WalkResult, Walkable}; +use crate::{Op, Operation, OperationRef, Symbol}; + +/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation]. +/// +/// The searcher traverses the object graph in depth-first preorder, from operations to regions to +/// blocks to operations, etc. All nested items of an entity are visited before its siblings, i.e. +/// a region is fully visited before the next region of the same containing operation. +/// +/// This is effectively control-flow order, from an abstract interpretation perspective, i.e. 
an +/// actual program might only execute one region of a multi-region op, but this traversal will visit +/// all of them unless otherwise directed by a `WalkResult`. +pub struct Searcher { + visitor: V, + root: OperationRef, + _marker: core::marker::PhantomData, +} +impl> Searcher { + pub fn new(root: OperationRef, visitor: V) -> Self { + Self { + visitor, + root, + _marker: core::marker::PhantomData, + } + } +} + +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + self.visitor.visit(&op) + }) + } +} + +impl> Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(op) = op.downcast_ref::() { + self.visitor.visit(op) + } else { + WalkResult::Continue(()) + } + }) + } +} + +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(sym) = op.as_symbol() { + self.visitor.visit(sym) + } else { + WalkResult::Continue(()) + } + }) + } +} diff --git a/hir2/src/ir/visit/visitor.rs b/hir2/src/ir/visit/visitor.rs new file mode 100644 index 000000000..4a061d734 --- /dev/null +++ b/hir2/src/ir/visit/visitor.rs @@ -0,0 +1,36 @@ +use super::WalkResult; +use crate::{Op, Operation, Symbol}; + +/// A generic trait that describes visitors for all kinds +pub trait Visitor { + /// The type of output produced by visiting an item. + type Output; + + /// The function which is applied to each `T` as it is visited. + fn visit(&mut self, current: &T) -> WalkResult; +} + +/// We can automatically convert any closure of appropriate type to a `Visitor` +impl Visitor for F +where + F: FnMut(&T) -> WalkResult, +{ + type Output = U; + + #[inline] + fn visit(&mut self, op: &T) -> WalkResult { + self(op) + } +} + +/// Represents a visitor over [Operation] +pub trait OperationVisitor: Visitor {} +impl OperationVisitor for V where V: Visitor {} + +/// Represents a visitor over [Op] of type `T` +pub trait OpVisitor: Visitor {} +impl OpVisitor for V where V: Visitor {} + +/// Represents a visitor over [Symbol] +pub trait SymbolVisitor: Visitor {} +impl SymbolVisitor for V where V: Visitor {} diff --git a/hir2/src/ir/visit/walkable.rs b/hir2/src/ir/visit/walkable.rs new file mode 100644 index 000000000..a523a4c17 --- /dev/null +++ b/hir2/src/ir/visit/walkable.rs @@ -0,0 +1,404 @@ +use super::WalkResult; +use crate::{ + Block, BlockRef, Operation, OperationRef, Region, RegionRef, UnsafeIntrusiveEntityRef, +}; + +/// The traversal order for a walk of a region, block, or operation +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum WalkOrder { + PreOrder, + PostOrder, +} + +/// Encodes the current walk stage for generic walkers. +/// +/// When walking an operation, we can either choose a pre- or post-traversal walker which invokes +/// the callback on an operation before/after all its attached regions have been visited, or choose +/// a generic walker where the callback is invoked on the operation N+1 times, where N is the number +/// of regions attached to that operation. [WalkStage] encodes the current stage of the walk, i.e. +/// which regions have already been visited, and the callback accepts an additional argument for +/// the current stage. Such generic walkers that accept stage-aware callbacks are only applicable +/// when the callback operations on an operation (i.e. 
doesn't apply to callbacks on blocks or +/// regions). +#[derive(Clone, PartialEq, Eq)] +pub struct WalkStage { + /// The number of regions in the operation + num_regions: usize, + /// The next region to visit in the operation + next_region: usize, +} +impl WalkStage { + pub fn new(op: OperationRef) -> Self { + let op = op.borrow(); + Self { + num_regions: op.num_regions(), + next_region: 0, + } + } + + /// Returns true if the parent operation is being visited before all regions. + #[inline] + pub fn is_before_all_regions(&self) -> bool { + self.next_region == 0 + } + + /// Returns true if the parent operation is being visited just before visiting `region` + #[inline] + pub fn is_before_region(&self, region: usize) -> bool { + self.next_region == region + } + + /// Returns true if the parent operation is being visited just after visiting `region` + #[inline] + pub fn is_after_region(&self, region: usize) -> bool { + self.next_region == region + 1 + } + + /// Returns true if the parent operation is being visited after all regions. + #[inline] + pub fn is_after_all_regions(&self) -> bool { + self.next_region == self.num_regions + } + + /// Advance the walk stage + #[inline] + pub fn advance(&mut self) { + self.next_region += 1; + } + + /// Returns the next region that will be visited + #[inline(always)] + pub const fn next_region(&self) -> usize { + self.next_region + } +} + +/// A [Walkable] is an entity which can be traversed depth-first in either pre- or post-order +/// +/// An implementation of this trait specifies a type, `T`, corresponding to the type of item being +/// walked, while `Self` is the root entity, possibly of the same type, which may contain `T`. Thus +/// traversing from the root to all of the leaves, we will visit all reachable `T` nested within +/// `Self`, possibly including itself. +pub trait Walkable { + /// Walk all `T` in `self` in a specific order, applying the given callback to each. + /// + /// This is very similar to [Walkable::walk_interruptible], except the callback has no control + /// over the traversal, and must be infallible. + #[inline] + fn walk(&self, order: WalkOrder, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.walk_interruptible(order, |t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback + /// to each `T`. + #[inline] + fn prewalk(&self, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.prewalk_interruptible(|t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback + /// to each `T`. + #[inline] + fn postwalk(&self, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef), + { + let _ = self.postwalk_interruptible(|t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk `self` in the given order, visiting each `T` and applying the given callback to them. + /// + /// The given callback can control the traversal using the [WalkResult] it returns: + /// + /// * `WalkResult::Skip` will skip the walk of the current item and its nested elements that + /// have not been visited already, continuing with the next item. 
+ /// * `WalkResult::Break` will interrupt the walk, and no more items will be visited + /// * `WalkResult::Continue` will continue the walk + #[inline] + fn walk_interruptible(&self, order: WalkOrder, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult, + { + match order { + WalkOrder::PreOrder => self.prewalk_interruptible(callback), + WalkOrder::PostOrder => self.prewalk_interruptible(callback), + } + } + + /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback + /// to each `T`, and determining how to proceed based on the returned [WalkResult]. + fn prewalk_interruptible(&self, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; + + /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback + /// to each `T`, and determining how to proceed based on the returned [WalkResult]. + fn postwalk_interruptible(&self, callback: F) -> WalkResult + where + F: FnMut(UnsafeIntrusiveEntityRef) -> WalkResult; +} + +/// Walking operations nested within an [Operation], including itself +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + prewalk_operation_interruptible(self, &mut callback) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + postwalk_operation_interruptible(self, &mut callback) + } +} + +fn prewalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + let result = callback(op.as_operation_ref()); + if !result.should_continue() { + return result; + } + + for region in op.regions().iter() { + for block in region.body().iter() { + let mut ops = block.body().front(); + while let Some(op) = ops.as_pointer() { + ops.move_next(); + let op = op.borrow(); + prewalk_operation_interruptible(&op, callback)?; + } + } + } + + WalkResult::Continue(()) +} + +fn postwalk_operation_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + for region in op.regions().iter() { + for block in region.body().iter() { + let mut ops = block.body().front(); + while let Some(op) = ops.as_pointer() { + ops.move_next(); + let op = op.borrow(); + postwalk_operation_interruptible(&op, callback)?; + } + } + } + + callback(op.as_operation_ref()) +} + +/// Walking regions of an [Operation], and those of all nested operations +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(RegionRef) -> WalkResult, + { + prewalk_regions_interruptible(self, &mut callback) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(RegionRef) -> WalkResult, + { + postwalk_regions_interruptible(self, &mut callback) + } +} + +fn prewalk_regions_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(RegionRef) -> WalkResult, +{ + let mut regions = op.regions().front(); + while let Some(region) = regions.as_pointer() { + regions.move_next(); + match callback(region.clone()) { + WalkResult::Continue(_) => { + let region = region.borrow(); + for block in region.body().iter() { + for op in block.body().iter() { + prewalk_regions_interruptible(&op, callback)?; + } + } + } + WalkResult::Skip => continue, + result @ WalkResult::Break(_) => return result, + } + } + + WalkResult::Continue(()) 
+} + +fn postwalk_regions_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(RegionRef) -> WalkResult, +{ + let mut regions = op.regions().front(); + while let Some(region) = regions.as_pointer() { + regions.move_next(); + { + let region = region.borrow(); + for block in region.body().iter() { + for op in block.body().iter() { + postwalk_regions_interruptible(&op, callback)?; + } + } + } + callback(region)?; + } + + WalkResult::Continue(()) +} + +/// Walking operations nested within a [Region] +impl Walkable for Region { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + prewalk_region_operations_interruptible(self, &mut callback) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(OperationRef) -> WalkResult, + { + postwalk_region_operations_interruptible(self, &mut callback) + } +} + +fn prewalk_region_operations_interruptible(region: &Region, callback: &mut F) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + for block in region.body().iter() { + let mut cursor = block.body().front(); + while let Some(op) = cursor.as_pointer() { + cursor.move_next(); + match callback(op.clone()) { + WalkResult::Continue(_) => { + let op = op.borrow(); + for region in op.regions() { + prewalk_region_operations_interruptible(®ion, callback)?; + } + } + WalkResult::Skip => continue, + result @ WalkResult::Break(_) => return result, + } + } + } + + WalkResult::Continue(()) +} + +fn postwalk_region_operations_interruptible( + region: &Region, + callback: &mut F, +) -> WalkResult +where + F: FnMut(OperationRef) -> WalkResult, +{ + for block in region.body().iter() { + let mut cursor = block.body().front(); + while let Some(op) = cursor.as_pointer() { + cursor.move_next(); + { + let op = op.borrow(); + for region in op.regions() { + postwalk_region_operations_interruptible(®ion, callback)?; + } + } + callback(op)?; + } + } + + WalkResult::Continue(()) +} + +/// Walking blocks of an [Operation], and those of all nested operations +impl Walkable for Operation { + fn prewalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(BlockRef) -> WalkResult, + { + prewalk_blocks_interruptible(self, &mut callback) + } + + fn postwalk_interruptible(&self, mut callback: F) -> WalkResult + where + F: FnMut(BlockRef) -> WalkResult, + { + postwalk_blocks_interruptible(self, &mut callback) + } +} + +fn prewalk_blocks_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(BlockRef) -> WalkResult, +{ + for region in op.regions().iter() { + let mut blocks = region.body().front(); + while let Some(block) = blocks.as_pointer() { + blocks.move_next(); + match callback(block.clone()) { + WalkResult::Continue(_) => { + let block = block.borrow(); + for op in block.body().iter() { + prewalk_blocks_interruptible(&op, callback)?; + } + } + WalkResult::Skip => continue, + result @ WalkResult::Break(_) => return result, + } + } + } + + WalkResult::Continue(()) +} + +fn postwalk_blocks_interruptible(op: &Operation, callback: &mut F) -> WalkResult +where + F: FnMut(BlockRef) -> WalkResult, +{ + for region in op.regions().iter() { + let mut blocks = region.body().front(); + while let Some(block) = blocks.as_pointer() { + blocks.move_next(); + { + let block = block.borrow(); + for op in block.body().iter() { + postwalk_blocks_interruptible(&op, callback)?; + } + } + callback(block.clone())?; + } + } + + WalkResult::Continue(()) +} diff --git 
a/hir2/src/lib.rs b/hir2/src/lib.rs index c9e66a4b7..fa3acb1ea 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -9,6 +9,7 @@ #![feature(rustc_attrs)] #![feature(debug_closure_helpers)] #![feature(trait_alias)] +#![feature(trait_upcasting)] #![feature(is_none_or)] #![feature(try_trait_v2)] #![feature(try_trait_v2_residual)] @@ -37,7 +38,9 @@ pub mod demangle; pub mod derive; pub mod dialects; pub mod formatter; +mod hash; mod ir; +pub mod matchers; mod patterns; pub use self::{any::AsAny, attributes::*, ir::*, patterns::*}; diff --git a/hir2/src/matchers.rs b/hir2/src/matchers.rs new file mode 100644 index 000000000..1c3a0dc73 --- /dev/null +++ b/hir2/src/matchers.rs @@ -0,0 +1,3 @@ +mod matcher; + +pub use self::matcher::*; diff --git a/hir2/src/matchers/matcher.rs b/hir2/src/matchers/matcher.rs new file mode 100644 index 000000000..818d697fd --- /dev/null +++ b/hir2/src/matchers/matcher.rs @@ -0,0 +1,770 @@ +use core::ptr::{DynMetadata, Pointee}; + +use smallvec::SmallVec; + +use crate::{ + AttributeValue, Op, OpFoldResult, OpOperand, Operation, OperationRef, UnsafeIntrusiveEntityRef, + ValueRef, +}; + +/// [Matcher] is a pattern matching abstraction with support for expressing both matching and +/// capturing semantics. +/// +/// This is used to implement low-level pattern matching primitives for the IR for use in: +/// +/// * Folding +/// * Canonicalization +/// * Regionalized transformations and analyses +pub trait Matcher { + /// The value type produced as a result of a successful match + /// + /// Use `()` if this matcher does not capture any value, and simply signals whether or not + /// the pattern was matched. + type Matched; + + /// Check if `entity` is matched by this matcher, returning `Self::Matched` if successful. + fn matches(&self, entity: &T) -> Option; +} + +#[repr(transparent)] +pub struct MatchWith(pub F); +impl Matcher for MatchWith +where + F: Fn(&T) -> Option, +{ + type Matched = U; + + #[inline(always)] + fn matches(&self, entity: &T) -> Option { + (self.0)(entity) + } +} + +/// A match combinator representing the logical AND of two sub-matchers. +/// +/// Both patterns must match on the same IR entity, but only the matched value of `B` is returned, +/// i.e. the captured result of `A` is discarded. +/// +/// Returns the result of matching `B` if successful, otherwise `None` +pub struct AndMatcher { + a: A, + b: B, +} + +impl AndMatcher { + pub const fn new(a: A, b: B) -> Self { + Self { a, b } + } +} + +impl Matcher for AndMatcher +where + A: Matcher, + B: Matcher, +{ + type Matched = >::Matched; + + #[inline] + fn matches(&self, entity: &T) -> Option { + self.a.matches(entity).and_then(|_| self.b.matches(entity)) + } +} + +/// A match combinator representing a monadic bind of two patterns. +/// +/// In other words, given two patterns `A` and `B`: +/// +/// * `A` is matched, and if it fails, the entire match fails. +/// * `B` is then matched against the output of `A`, and if it fails, the entire match fails +/// * Both matches were successful, and the output of `B` is returned as the final result. +pub struct ChainMatcher { + a: A, + b: B, +} + +impl ChainMatcher { + pub const fn new(a: A, b: B) -> Self { + Self { a, b } + } +} + +impl Matcher for ChainMatcher +where + A: Matcher, + B: Matcher, +{ + type Matched = >::Matched; + + #[inline] + fn matches(&self, entity: &T) -> Option { + self.a.matches(entity).and_then(|matched| self.b.matches(&matched)) + } +} + +/// Matches operations which implement some trait `Trait`, capturing the match as a trait object. 
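For illustration, a minimal sketch of composing the primitives above, assuming the matcher types are exported under `midenc_hir2::matchers`; the predicate and helper name are made up for the example. The closure adapter `MatchWith` captures a value, and `ChainMatcher` feeds that capture into a second matcher:

```rust
use midenc_hir2::{
    matchers::{ChainMatcher, MatchWith, Matcher},
    Operation, OperationRef,
};

/// Match any operation that owns at least one region, then count its regions.
fn region_count(op: &Operation) -> Option<usize> {
    let has_regions = MatchWith(|op: &Operation| {
        if op.num_regions() > 0 {
            Some(op.as_operation_ref())
        } else {
            None
        }
    });
    let count = MatchWith(|op: &OperationRef| Some(op.borrow().num_regions()));
    ChainMatcher::new(has_regions, count).matches(op)
}
```

The same composition is exposed further below through the `match_both` and `match_chain` helpers, which is the form the fold-oriented matchers in this module use.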
+/// +/// NOTE: `Trait` must be an object-safe trait. +pub struct OpTraitMatcher { + _marker: core::marker::PhantomData, +} + +impl Default for OpTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl OpTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + /// Create a new [OpTraitMatcher] from the given matcher. + pub const fn new() -> Self { + Self { + _marker: core::marker::PhantomData, + } + } +} + +impl Matcher for OpTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + type Matched = UnsafeIntrusiveEntityRef; + + fn matches(&self, entity: &Operation) -> Option { + entity + .as_trait::() + .map(|op| unsafe { UnsafeIntrusiveEntityRef::from_raw(op) }) + } +} + +/// Matches operations which implement some trait `Trait`. +/// +/// Returns a type-erased operation ref, not a trait object like [OpTraitMatcher] +pub struct HasTraitMatcher { + _marker: core::marker::PhantomData, +} + +impl Default for HasTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl HasTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + /// Create a new [HasTraitMatcher] from the given matcher. + pub const fn new() -> Self { + Self { + _marker: core::marker::PhantomData, + } + } +} + +impl Matcher for HasTraitMatcher +where + Trait: ?Sized + Pointee> + 'static, +{ + type Matched = OperationRef; + + fn matches(&self, entity: &Operation) -> Option { + if !entity.implements::() { + return None; + } + Some(entity.as_operation_ref()) + } +} + +/// Matches any operation with an attribute named `name`, the value of which matches a matcher of +/// type `M`. +pub struct OpAttrMatcher { + name: &'static str, + matcher: M, +} + +impl OpAttrMatcher +where + M: Matcher, +{ + /// Create a new [OpAttrMatcher] from the given attribute name and matcher. + pub const fn new(name: &'static str, matcher: M) -> Self { + Self { name, matcher } + } +} + +impl Matcher for OpAttrMatcher +where + M: Matcher, +{ + type Matched = >::Matched; + + fn matches(&self, entity: &Operation) -> Option { + entity.get_attribute(self.name).and_then(|value| self.matcher.matches(value)) + } +} + +/// Matches any operation with an attribute `name` and concrete value type of `A`. +/// +/// Binds the value as its concrete type `A`. +pub type TypedOpAttrMatcher = OpAttrMatcher>; + +/// Matches and binds any attribute value whose concrete type is `A`. +pub struct TypedAttrMatcher(core::marker::PhantomData); +impl Default for TypedAttrMatcher { + #[inline(always)] + fn default() -> Self { + Self(core::marker::PhantomData) + } +} +impl Matcher for TypedAttrMatcher { + type Matched = A; + + #[inline] + fn matches(&self, entity: &dyn AttributeValue) -> Option { + entity.downcast_ref::().cloned() + } +} + +/// A matcher for operations that always succeeds, binding the operation reference in the process. +struct AnyOpMatcher; +impl Matcher for AnyOpMatcher { + type Matched = OperationRef; + + #[inline(always)] + fn matches(&self, entity: &Operation) -> Option { + Some(entity.as_operation_ref()) + } +} + +/// A matcher for operations whose concrete type is `T`, binding the op with a strongly-typed +/// reference. 
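A small sketch of the attribute matchers defined above, assuming `Immediate` implements `AttributeValue` as it does in the tests at the end of this module; the attribute name is hypothetical and only for illustration:

```rust
use midenc_hir2::{
    matchers::{Matcher, OpAttrMatcher, TypedAttrMatcher},
    Immediate, Operation,
};

/// Look up a (hypothetical) "min_stack_depth" attribute and bind it as an `Immediate`.
fn min_stack_depth(op: &Operation) -> Option<Immediate> {
    OpAttrMatcher::new("min_stack_depth", TypedAttrMatcher::<Immediate>::default()).matches(op)
}
```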
+struct OneOpMatcher(core::marker::PhantomData); +impl OneOpMatcher { + pub const fn new() -> Self { + Self(core::marker::PhantomData) + } +} +impl Matcher for OneOpMatcher { + type Matched = UnsafeIntrusiveEntityRef; + + #[inline(always)] + fn matches(&self, entity: &Operation) -> Option { + entity + .downcast_ref::() + .map(|op| unsafe { UnsafeIntrusiveEntityRef::from_raw(op) }) + } +} + +/// A matcher for values that always succeeds, binding the value reference in the process. +struct AnyValueMatcher; +impl Matcher for AnyValueMatcher { + type Matched = ValueRef; + + #[inline(always)] + fn matches(&self, entity: &ValueRef) -> Option { + Some(entity.clone()) + } +} + +/// A matcher that only succeeds if it matches exactly the provided value. +struct ExactValueMatcher(ValueRef); +impl Matcher for ExactValueMatcher { + type Matched = ValueRef; + + #[inline(always)] + fn matches(&self, entity: &ValueRef) -> Option { + if ValueRef::ptr_eq(&self.0, entity) { + Some(entity.clone()) + } else { + None + } + } +} + +/// A matcher for operations that implement [crate::traits::ConstantLike] +type ConstantOpMatcher = HasTraitMatcher; + +/// Like [ConstantOpMatcher], this matcher matches constant operations, but rather than binding +/// the operation itself, it binds the constant value produced by the operation. +#[derive(Default)] +struct ConstantOpBinder; +impl Matcher for ConstantOpBinder { + type Matched = Box; + + fn matches(&self, entity: &Operation) -> Option { + use crate::traits::Foldable; + + if !entity.implements::() { + return None; + } + + let mut out = SmallVec::default(); + entity.fold(&mut out).expect("expected constant-like op to be foldable"); + let Some(OpFoldResult::Attribute(value)) = out.pop() else { + return None; + }; + + Some(value) + } +} + +/// An extension of [ConstantOpBinder] which only matches constant values of type `T` +struct TypedConstantOpBinder(core::marker::PhantomData); +impl TypedConstantOpBinder { + pub const fn new() -> Self { + Self(core::marker::PhantomData) + } +} +impl Matcher for TypedConstantOpBinder { + type Matched = T; + + fn matches(&self, entity: &Operation) -> Option { + ConstantOpBinder.matches(entity).and_then(|value| { + if !value.is::() { + None + } else { + Some(unsafe { + let raw = Box::into_raw(value); + *Box::from_raw(raw as *mut T) + }) + } + }) + } +} + +/// Matches operations which implement [crate::traits::UnaryOp] and binds the operand. +#[derive(Default)] +struct UnaryOpBinder; +impl Matcher for UnaryOpBinder { + type Matched = OpOperand; + + fn matches(&self, entity: &Operation) -> Option { + if !entity.implements::() { + return None; + } + + Some(entity.operands()[0].borrow().as_operand_ref()) + } +} + +/// Matches operations which implement [crate::traits::BinaryOp] and binds both operands. +#[derive(Default)] +struct BinaryOpBinder; +impl Matcher for BinaryOpBinder { + type Matched = [OpOperand; 2]; + + fn matches(&self, entity: &Operation) -> Option { + if !entity.implements::() { + return None; + } + + let operands = entity.operands(); + let lhs = operands[0].borrow().as_operand_ref(); + let rhs = operands[1].borrow().as_operand_ref(); + + Some([lhs, rhs]) + } +} + +/// Converts the output of [UnaryOpBinder] to an OpFoldResult, by checking if the operand definition +/// is a constant-like op, and either binding the constant value, or the SSA value used as the +/// operand. +/// +/// This can be used to set up for folding. 
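To make the folding setup concrete, a sketch of the kind of check these binders enable, written against the public `binary_foldable` helper defined further below; the import paths are assumptions:

```rust
use midenc_hir2::{
    matchers::{binary_foldable, Matcher},
    AttributeValue, Operation,
};

/// Succeeds only when both operands of a binary op are defined by constant-like ops,
/// yielding the constant values ready to hand to a `Foldable` implementation.
fn constant_binary_operands(op: &Operation) -> Option<[Box<dyn AttributeValue>; 2]> {
    binary_foldable().matches(op)
}
```

When only some operands may be constant, the `binary_fold_results` helper below binds each operand as an `OpFoldResult` instead, leaving the decision to the folder.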
+struct FoldResultBinder; +impl Matcher for FoldResultBinder { + type Matched = OpFoldResult; + + fn matches(&self, operand: &OpOperand) -> Option { + let operand = operand.borrow(); + let maybe_constant = operand + .value() + .get_defining_op() + .and_then(|defining_op| constant().matches(&defining_op.borrow())); + if let Some(const_operand) = maybe_constant { + Some(OpFoldResult::Attribute(const_operand)) + } else { + Some(OpFoldResult::Value(operand.as_value_ref())) + } + } +} + +/// Converts the output of [BinaryOpBinder] to a pair of OpFoldResults, by checking if the operand +/// definitions are constant, and either binding the constant values, or the SSA values used by each +/// operand. +/// +/// This can be used to set up for folding. +struct BinaryFoldResultBinder; +impl Matcher<[OpOperand; 2]> for BinaryFoldResultBinder { + type Matched = [OpFoldResult; 2]; + + fn matches(&self, operands: &[OpOperand; 2]) -> Option { + let binder = FoldResultBinder; + + let lhs = binder.matches(&operands[0]).unwrap(); + let rhs = binder.matches(&operands[1]).unwrap(); + + Some([lhs, rhs]) + } +} + +/// Matches the operand of a unary op to determine if it is a candidate for folding. +/// +/// A successful match binds the constant value of the operand for use by the [Foldable] impl. +struct FoldableOperandBinder; +impl Matcher for FoldableOperandBinder { + type Matched = Box; + + fn matches(&self, operand: &OpOperand) -> Option { + let operand = operand.borrow(); + let defining_op = operand.value().get_defining_op()?; + constant().matches(&defining_op.borrow()) + } +} + +struct TypedFoldableOperandBinder(core::marker::PhantomData); +impl Default for TypedFoldableOperandBinder { + fn default() -> Self { + Self(core::marker::PhantomData) + } +} +impl Matcher for TypedFoldableOperandBinder { + type Matched = Box; + + fn matches(&self, operand: &OpOperand) -> Option { + FoldableOperandBinder + .matches(operand) + .and_then(|value| value.downcast::().ok()) + } +} + +/// Matches the operands of a binary op to determine if it is a candidate for folding. +/// +/// A successful match binds the constant value of the operands for use by the [Foldable] impl. +/// +/// NOTE: Both operands must be constant for this to match. Use [BinaryFoldResultBinder] if you +/// wish to let the [Foldable] impl decide what to do in the presence of mixed constant and non- +/// constant operands. +struct FoldableBinaryOpBinder; +impl Matcher<[OpOperand; 2]> for FoldableBinaryOpBinder { + type Matched = [Box; 2]; + + fn matches(&self, operands: &[OpOperand; 2]) -> Option { + let binder = FoldableOperandBinder; + let lhs = binder.matches(&operands[0])?; + let rhs = binder.matches(&operands[1])?; + + Some([lhs, rhs]) + } +} + +// Match Combinators + +/// Matches both `a` and `b`, or fails +pub const fn match_both( + a: A, + b: B, +) -> impl Matcher>::Matched> +where + A: Matcher, + B: Matcher, +{ + AndMatcher::new(a, b) +} + +/// Matches `a` and if successful, matches `b` against the output of `a`, or fails. +pub const fn match_chain( + a: A, + b: B, +) -> impl Matcher>::Matched> +where + A: Matcher, + B: Matcher, +{ + ChainMatcher::new(a, b) +} + +// Operation Matchers + +/// Matches any operation, i.e. 
it always matches +/// +/// Returns a type-erased operation reference +pub const fn match_any() -> impl Matcher { + AnyOpMatcher +} + +/// Matches any operation whose concrete type is `T` +/// +/// Returns a strongly-typed op reference +pub const fn match_op() -> impl Matcher> { + OneOpMatcher::::new() +} + +/// Matches any operation that implements [crate::traits::ConstantLike]. +/// +/// These operations return a single result, and must be pure (no side effects) +pub const fn constant_like() -> impl Matcher { + ConstantOpMatcher::new() +} + +// Constant Value Binders + +/// Matches any operation that implements [crate::traits::ConstantLike], and binds the constant +/// value as the result of the match. +pub const fn constant() -> impl Matcher> { + ConstantOpBinder +} + +/// Like [constant], but only matches if the constant value has the concrete type `T`. +/// +/// Typically, constant values will be [crate::Immediate], but any attribute value can be matched. +pub const fn constant_of() -> impl Matcher { + TypedConstantOpBinder::new() +} + +// Value Binders + +/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]), and binds its operand. +pub const fn unary() -> impl Matcher { + UnaryOpBinder +} + +/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]), and binds its operand +/// as an [OpFoldResult]. +/// +/// This is done by examining the defining op of the operand to determine if it is a constant, and +/// if so, it binds the constant value, rather than the SSA value. +/// +/// This can be used to setup for folding. +pub const fn unary_fold_result() -> impl Matcher { + match_chain(UnaryOpBinder, FoldResultBinder) +} + +/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]) whose operand is a +/// materialized constant, and thus a prime candidate for folding. +/// +/// The constant value is bound by this matcher, so it can be used immediately for folding. +pub const fn unary_foldable() -> impl Matcher> { + match_chain(UnaryOpBinder, FoldableOperandBinder) +} + +/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]), and binds its operands. +pub const fn binary() -> impl Matcher { + BinaryOpBinder +} + +/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]), and binds its operands +/// as [OpFoldResult]s. +/// +/// This is done by examining the defining op of the operands to determine if they are constant, and +/// if so, binds the constant value, rather than the SSA value. +/// +/// This can be used to setup for folding. +pub const fn binary_fold_results() -> impl Matcher { + match_chain(BinaryOpBinder, BinaryFoldResultBinder) +} + +/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]) whose operands are +/// both materialized constants, and thus a prime candidate for folding. +/// +/// The constant values are bound by this matcher, so they can be used immediately for folding. +pub const fn binary_foldable() -> impl Matcher; 2]> { + match_chain(BinaryOpBinder, FoldableBinaryOpBinder) +} + +// Value Matchers + +/// Matches any value, i.e. it always matches +pub const fn match_any_value() -> impl Matcher { + AnyValueMatcher +} + +/// Matches any instance of `value`, i.e. 
it requires an exact match +pub const fn match_value(value: ValueRef) -> impl Matcher { + ExactValueMatcher(value) +} + +pub const fn foldable_operand() -> impl Matcher> { + FoldableOperandBinder +} + +pub const fn foldable_operand_of() -> impl Matcher> +where + T: AttributeValue + Clone, +{ + TypedFoldableOperandBinder(core::marker::PhantomData) +} + +#[cfg(test)] +mod tests { + use alloc::rc::Rc; + + use super::*; + use crate::{ + dialects::hir::{InstBuilder, *}, + *, + }; + + #[test] + fn matcher_match_any_value() { + let context = Rc::new(Context::default()); + + let (lhs, rhs, sum) = setup(context.clone()); + + // All three values should `match_any_value` + for value in [&lhs, &rhs, &sum] { + assert_eq!(match_any_value().matches(value).as_ref(), Some(value)); + } + } + + #[test] + fn matcher_match_value() { + let context = Rc::new(Context::default()); + + let (lhs, rhs, sum) = setup(context.clone()); + + // All three values should match themselves via `match_value` + for value in [&lhs, &rhs, &sum] { + assert_eq!(match_value(value.clone()).matches(value).as_ref(), Some(value)); + } + } + + #[test] + fn matcher_match_any() { + let context = Rc::new(Context::default()); + + let (lhs, _rhs, sum) = setup(context.clone()); + + // We should be able to match `lhs` and `sum` ops using `match_any` + let lhs_op = lhs.borrow().get_defining_op().unwrap(); + let sum_op = sum.borrow().get_defining_op().unwrap(); + + for op in [&lhs_op, &sum_op] { + assert_eq!(match_any().matches(&op.borrow()).as_ref(), Some(op)); + } + } + + #[test] + fn matcher_match_op() { + let context = Rc::new(Context::default()); + + let (lhs, rhs, sum) = setup(context.clone()); + let lhs_op = lhs.borrow().get_defining_op().unwrap(); + let sum_op = sum.borrow().get_defining_op().unwrap(); + assert!(rhs.borrow().get_defining_op().is_none()); + + // Both `lhs` and `sum` ops should be matched as their respective operation types, and not + // as a different operation type + assert!(match_op::().matches(&lhs_op.borrow()).is_some()); + assert!(match_op::().matches(&sum_op.borrow()).is_none()); + assert!(match_op::().matches(&lhs_op.borrow()).is_none()); + assert!(match_op::().matches(&sum_op.borrow()).is_some()); + } + + #[test] + fn matcher_match_both() { + let context = Rc::new(Context::default()); + + let (lhs, _rhs, _sum) = setup(context.clone()); + let lhs_op = lhs.borrow().get_defining_op().unwrap(); + + // Ensure if the first matcher fails, then the whole match fails + assert!(match_both(match_op::(), constant_of::()) + .matches(&lhs_op.borrow()) + .is_none()); + // Ensure if the second matcher fails, then the whole match fails + assert!(match_both(constant_like(), constant_of::()) + .matches(&lhs_op.borrow()) + .is_none()); + // Ensure that if both matchers would succeed, then the whole match succeeds + assert!(match_both(constant_like(), constant_of::()) + .matches(&lhs_op.borrow()) + .is_some()); + } + + #[test] + fn matcher_match_chain() { + let context = Rc::new(Context::default()); + + let (_, rhs, sum) = setup(context.clone()); + let sum_op = sum.borrow().get_defining_op().unwrap(); + + let [lhs_fr, rhs_fr] = binary_fold_results() + .matches(&sum_op.borrow()) + .expect("expected to bind both operands of 'add'"); + assert_eq!(lhs_fr, OpFoldResult::Attribute(Box::new(Immediate::U32(1)))); + assert_eq!(rhs_fr, OpFoldResult::Value(rhs)); + } + + #[test] + fn matcher_constant_like() { + let context = Rc::new(Context::default()); + + let (lhs, _rhs, sum) = setup(context.clone()); + let lhs_op = 
lhs.borrow().get_defining_op().unwrap(); + let sum_op = sum.borrow().get_defining_op().unwrap(); + + // Only `lhs` should be matched by `constant_like` + assert!(constant_like().matches(&lhs_op.borrow()).is_some()); + assert!(constant_like().matches(&sum_op.borrow()).is_none()); + } + + #[test] + fn matcher_constant() { + let context = Rc::new(Context::default()); + + let (lhs, _rhs, sum) = setup(context.clone()); + let lhs_op = lhs.borrow().get_defining_op().unwrap(); + let sum_op = sum.borrow().get_defining_op().unwrap(); + + // Only `lhs` should produce a matching constant value + assert!(constant().matches(&lhs_op.borrow()).is_some()); + assert!(constant().matches(&sum_op.borrow()).is_none()); + } + + #[test] + fn matcher_constant_of() { + let context = Rc::new(Context::default()); + + let (lhs, _rhs, sum) = setup(context.clone()); + let lhs_op = lhs.borrow().get_defining_op().unwrap(); + let sum_op = sum.borrow().get_defining_op().unwrap(); + + // `lhs` should produce a matching constant value of the correct type and value + assert_eq!(constant_of::().matches(&lhs_op.borrow()), Some(Immediate::U32(1))); + assert!(constant_of::().matches(&sum_op.borrow()).is_none()); + } + + fn setup(context: Rc) -> (ValueRef, ValueRef, ValueRef) { + let mut builder = OpBuilder::new(Rc::clone(&context)); + + let mut function = { + let builder = builder.create::(SourceSpan::default()); + let id = Ident::new("test".into(), SourceSpan::default()); + let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]); + builder(id, signature).unwrap() + }; + + // Define function body + let mut func = function.borrow_mut(); + let mut builder = FunctionBuilder::new(&mut func); + let lhs = builder.ins().u32(1, SourceSpan::default()).unwrap(); + let block = builder.current_block(); + let rhs = block.borrow().arguments()[0].clone().upcast(); + let sum = builder.ins().add(lhs.clone(), rhs.clone(), SourceSpan::default()).unwrap(); + builder.ins().ret(Some(sum.clone()), SourceSpan::default()).unwrap(); + + (lhs, rhs, sum) + } +} From 72c4f12b05007d0afb742e9f8a44fdf27d904f27 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:16:27 -0400 Subject: [PATCH 20/31] wip: implement op folder --- hir2/src/folder.rs | 488 +++++++++++++++++++++++++++++++++++++++++++++ hir2/src/lib.rs | 13 +- 2 files changed, 500 insertions(+), 1 deletion(-) create mode 100644 hir2/src/folder.rs diff --git a/hir2/src/folder.rs b/hir2/src/folder.rs new file mode 100644 index 000000000..acbf39fee --- /dev/null +++ b/hir2/src/folder.rs @@ -0,0 +1,488 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use rustc_hash::FxHashMap; +use smallvec::{smallvec, SmallVec}; + +use crate::{ + matchers::Matcher, + traits::{ConstantLike, Foldable, IsolatedFromAbove}, + AttributeValue, BlockRef, Builder, Context, Dialect, FoldResult, OpFoldResult, OperationRef, + RegionRef, Rewriter, RewriterImpl, RewriterListener, SourceSpan, Spanned, Type, Value, + ValueRef, +}; + +/// Represents a constant value uniqued by dialect, value, and type. 
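+///
+/// Two uniqued constants are considered equal when they come from the same dialect, their
+/// attribute values hash identically, and they have the same result type.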
+struct UniquedConstant {
+    dialect: Rc<dyn Dialect>,
+    value: Box<dyn AttributeValue>,
+    ty: Type,
+}
+impl Eq for UniquedConstant {}
+impl PartialEq for UniquedConstant {
+    fn eq(&self, other: &Self) -> bool {
+        use core::hash::{Hash, Hasher};
+
+        let v1_hash = {
+            let mut hasher = rustc_hash::FxHasher::default();
+            self.value.hash(&mut hasher);
+            hasher.finish()
+        };
+        let v2_hash = {
+            let mut hasher = rustc_hash::FxHasher::default();
+            other.value.hash(&mut hasher);
+            hasher.finish()
+        };
+
+        self.dialect.name() == other.dialect.name() && v1_hash == v2_hash && self.ty == other.ty
+    }
+}
+impl UniquedConstant {
+    pub fn new(op: &OperationRef, value: Box<dyn AttributeValue>) -> Self {
+        let op = op.borrow();
+        let dialect = op.dialect();
+        let ty = op.results()[0].borrow().ty().clone();
+
+        Self { dialect, value, ty }
+    }
+}
+impl core::hash::Hash for UniquedConstant {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        self.dialect.name().hash(state);
+        self.value.hash(state);
+        self.ty.hash(state);
+    }
+}
+
+/// A map of uniqued constants, to their defining operations
+type ConstantMap = FxHashMap<UniquedConstant, OperationRef>;
+
+/// The [OperationFolder] is responsible for orchestrating operation folding, and the effects that
+/// folding an operation has on the containing region.
+///
+/// It handles the following details related to operation folding:
+///
+/// * Attempting to fold the operation itself
+/// * Materializing constants
+/// * Uniquing/de-duplicating materialized constants, including moving them up the CFG to ensure
+///   that new uses of a uniqued constant are dominated by the constant definition.
+/// * Removing folded operations (or cleaning up failed attempts), and notifying any listeners of
+///   those actions.
+pub struct OperationFolder {
+    rewriter: Box<dyn Rewriter>,
+    /// A mapping between an insertion region and the constants that have been created within it.
+    scopes: BTreeMap<RegionRef, ConstantMap>,
+    /// This map tracks all of the dialects that an operation is referenced by, since multiple
+    /// dialects may generate the same constant.
+    referenced_dialects: BTreeMap<OperationRef, SmallVec<[Rc<dyn Dialect>; 1]>>,
+    /// The location to use for folder-owned constants
+    erased_folded_location: SourceSpan,
+}
+impl OperationFolder {
+    pub fn new<L>(context: Rc<Context>, listener: L) -> Self
+    where
+        L: RewriterListener,
+    {
+        Self {
+            rewriter: Box::new(RewriterImpl::::new(context).with_listener(listener)),
+            scopes: Default::default(),
+            referenced_dialects: Default::default(),
+            erased_folded_location: SourceSpan::UNKNOWN,
+        }
+    }
+
+    /// Tries to perform folding on `op`, including unifying de-duplicated constants.
+    ///
+    /// If successful, replaces uses of `op`'s results with the folded results, and returns
+    /// a [FoldResult].
+    pub fn try_fold(&mut self, mut op: OperationRef) -> FoldResult {
+        // If this is a uniqued constant, return failure as we know that it has already been
+        // folded.
+        if self.is_folder_owned_constant(&op) {
+            // Check to see if we should rehoist, i.e. if a non-constant operation was inserted
+            // before this one.
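+            //
+            // Rehoisting keeps folder-owned constants at the front of their block, so that any
+            // new uses introduced by later folds remain dominated by the constant definition.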
+ let block = op.borrow().parent().unwrap(); + if block.borrow().front().unwrap() != op + && !self.is_folder_owned_constant(&op.prev().unwrap()) + { + let mut op = op.borrow_mut(); + op.move_before(crate::ProgramPoint::Block(block)); + op.set_span(self.erased_folded_location); + } + return FoldResult::Failed; + } + + // Try to fold the operation + let mut fold_results = SmallVec::default(); + match op.borrow_mut().fold(&mut fold_results) { + FoldResult::InPlace => { + // Folding API does not notify listeners, so we need to do so manually + self.rewriter.notify_operation_modified(op); + + FoldResult::InPlace + } + FoldResult::Ok(_) => { + assert!( + !fold_results.is_empty(), + "expected non-empty fold results from a successful fold" + ); + if let FoldResult::Ok(replacements) = + self.process_fold_results(op.clone(), &fold_results) + { + // Constant folding succeeded. Replace all of the result values and erase the + // operation. + self.notify_removal(op.clone()); + self.rewriter.replace_op_with_values(op, &replacements); + FoldResult::Ok(()) + } else { + FoldResult::Failed + } + } + failed @ FoldResult::Failed => failed, + } + } + + /// Try to process a set of fold results. + /// + /// Returns the folded values if successful. + fn process_fold_results( + &mut self, + op: OperationRef, + fold_results: &[OpFoldResult], + ) -> FoldResult> { + let borrowed_op = op.borrow(); + assert_eq!(fold_results.len(), borrowed_op.num_results()); + + // Create a builder to insert new operations into the entry block of the insertion region + let insert_region = get_insertion_region(borrowed_op.parent().unwrap()); + let entry = insert_region.borrow().entry_block_ref().unwrap(); + self.rewriter.set_insertion_point_to_start(entry); + + // Create the result constants and replace the results. + let dialect = borrowed_op.dialect(); + let mut out = SmallVec::default(); + for (op_result, fold_result) in borrowed_op.results().iter().zip(fold_results) { + match fold_result { + // Check if the result was an SSA value. + OpFoldResult::Value(value) => { + out.push(value.clone()); + continue; + } + // Check to see if there is a canonicalized version of this constant. + OpFoldResult::Attribute(attr_repl) => { + if let Some(mut const_op) = self.try_get_or_create_constant( + insert_region.clone(), + dialect.clone(), + attr_repl.clone_value(), + op_result.borrow().ty().clone(), + self.erased_folded_location, + ) { + // Ensure that this constant dominates the operation we are replacing. + // + // This may not automatically happen if the operation being folded was + // inserted before the constant within the insertion block. + let op_block = borrowed_op.parent().unwrap(); + if op_block == const_op.borrow().parent().unwrap() + && op_block.borrow().front().unwrap() != const_op + { + const_op.borrow_mut().move_before(crate::ProgramPoint::Block(op_block)); + } + out.push(const_op.borrow().get_result(0).borrow().as_value_ref()); + continue; + } + + // If materialization fails, clean up any operations generated for the previous + // results and return failure. + let inserted_before = self.rewriter.insertion_point().unwrap().op(); + if let Some(inserted_before) = inserted_before { + while let Some(inserted_op) = inserted_before.prev() { + self.notify_removal(inserted_op.clone()); + self.rewriter.erase_op(inserted_op); + } + } + + return FoldResult::Failed; + } + } + } + + FoldResult::Ok(out) + } + + /// Notifies that the given constant `op` should be removed from this folder's internal + /// bookkeeping. 
+    ///
+    /// NOTE: This method must be called if a constant op is to be deleted externally to this
+    /// folder. `op` must be constant.
+    pub fn notify_removal(&mut self, op: OperationRef) {
+        // Check to see if this operation is uniqued within the folder.
+        let Some(referenced_dialects) = self.referenced_dialects.get_mut(&op) else {
+            return;
+        };
+
+        let borrowed_op = op.borrow();
+
+        // Get the constant value for this operation; this is the value that was used to unique
+        // the operation internally.
+        let value = crate::matchers::constant().matches(&borrowed_op).unwrap();
+
+        // Get the constant map that this operation was uniqued in.
+        let insert_region = get_insertion_region(borrowed_op.parent().unwrap());
+        let uniqued_constants = self.scopes.get_mut(&insert_region).unwrap();
+
+        // Erase all of the references to this operation.
+        let ty = borrowed_op.results()[0].borrow().ty().clone();
+        for dialect in referenced_dialects.drain(..) {
+            let uniqued_constant = UniquedConstant {
+                dialect,
+                value: value.clone_value(),
+                ty: ty.clone(),
+            };
+            uniqued_constants.remove(&uniqued_constant);
+        }
+    }
+
+    /// Clear out any constants cached inside the folder.
+    pub fn clear(&mut self) {
+        self.scopes.clear();
+        self.referenced_dialects.clear();
+    }
+
+    /// Tries to fold a pre-existing constant operation.
+    ///
+    /// `value` represents the value of the constant, and can be optionally passed if the value is
+    /// already known (e.g. if the constant was discovered by a pattern match). This is purely an
+    /// optimization opportunity for callers that already know the value of the constant.
+    ///
+    /// Returns `false` if an existing constant for `op` already exists in the folder, in which case
+    /// `op` is replaced and erased. Otherwise, returns `true` and `op` is inserted into the folder
+    /// and hoisted if necessary.
+    pub fn insert_known_constant(
+        &mut self,
+        mut op: OperationRef,
+        value: Option<Box<dyn AttributeValue>>,
+    ) -> bool {
+        let block = op.borrow().parent().unwrap();
+
+        // If this is a constant we uniqued, we don't need to insert, but we can check to see if
+        // we should rehoist it.
+        if self.is_folder_owned_constant(&op) {
+            if block.borrow().front().unwrap() != op
+                && !self.is_folder_owned_constant(&op.prev().unwrap())
+            {
+                let mut op = op.borrow_mut();
+                op.move_before(crate::ProgramPoint::Block(block));
+                op.set_span(self.erased_folded_location);
+            }
+            return true;
+        }
+
+        // Get the constant value of the op if necessary.
+        let value = value.unwrap_or_else(|| {
+            crate::matchers::constant()
+                .matches(&op.borrow())
+                .expect("expected `op` to be a constant")
+        });
+
+        // Check for an existing constant operation for the attribute value.
+        let insert_region = get_insertion_region(block.clone());
+        let uniqued_constants = self.scopes.entry(insert_region.clone()).or_default();
+        let uniqued_constant = UniquedConstant::new(&op, value);
+        let mut is_new = false;
+        let mut folder_const_op = uniqued_constants
+            .entry(uniqued_constant)
+            .or_insert_with(|| {
+                is_new = true;
+                op.clone()
+            })
+            .clone();
+
+        // If there is an existing constant, replace `op`
+        if !is_new {
+            self.notify_removal(op.clone());
+            self.rewriter.replace_op(op, folder_const_op.clone());
+            folder_const_op.borrow_mut().set_span(self.erased_folded_location);
+            return false;
+        }
+
+        // Otherwise, we insert `op`. If `op` is in the insertion block and is either already at the
+        // front of the block, or the previous operation is already a constant we uniqued (i.e. one
+        // we inserted), then we don't need to do anything.
Otherwise, we move the constant to the + // insertion block. + let insert_block = insert_region.borrow().entry_block_ref().unwrap(); + if block != insert_block + || (insert_block.borrow().front().unwrap() != op + && !self.is_folder_owned_constant(&op.prev().unwrap())) + { + let mut op = op.borrow_mut(); + op.move_before(crate::ProgramPoint::Block(insert_block)); + op.set_span(self.erased_folded_location); + } + + let referenced_dialects = self.referenced_dialects.entry(op.clone()).or_default(); + let dialect = op.borrow().dialect(); + let dialect_name = dialect.name(); + if !referenced_dialects.iter().any(|d| d.name() == dialect_name) { + referenced_dialects.push(dialect); + } + + true + } + + /// Get or create a constant for use in the specified block. + /// + /// The constant may be created in a parent block. On success, this returns the result of the + /// constant operation, or `None` otherwise. + pub fn get_or_create_constant( + &mut self, + block: BlockRef, + dialect: Rc, + value: Box, + ty: Type, + ) -> Option { + // Find an insertion point for the constant. + let insert_region = get_insertion_region(block.clone()); + let entry = insert_region.borrow().entry_block_ref().unwrap(); + self.rewriter.set_insertion_point_to_start(entry); + + // Get the constant map for the insertion region of this operation. + // Use erased location since the op is being built at the front of the block. + let const_op = self.try_get_or_create_constant( + insert_region, + dialect, + value, + ty, + self.erased_folded_location, + )?; + Some(const_op.borrow().results()[0].borrow().as_value_ref()) + } + + /// Try to get or create a new constant entry. + /// + /// On success, this returns the constant operation, `None` otherwise + fn try_get_or_create_constant( + &mut self, + insert_region: RegionRef, + dialect: Rc, + value: Box, + ty: Type, + span: SourceSpan, + ) -> Option { + let uniqued_constants = self.scopes.entry(insert_region).or_default(); + let uniqued_constant = UniquedConstant { + dialect: dialect.clone(), + value, + ty, + }; + if let Some(mut const_op) = uniqued_constants.get(&uniqued_constant).cloned() { + { + let mut const_op = const_op.borrow_mut(); + if const_op.span() != span { + const_op.set_span(span); + } + } + return Some(const_op); + } + + // If one doesn't exist, try to materialize one. + let const_op = materialize_constant( + self.rewriter.as_mut(), + dialect.clone(), + uniqued_constant.value.clone_value(), + &uniqued_constant.ty, + span, + )?; + + // Check to see if the generated constant is in the expected dialect. 
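+        //
+        // Note that the materializer may produce a constant op owned by a different dialect than
+        // the one requested; in that case the constant is recorded under both dialects below.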
+ let new_dialect = const_op.borrow().dialect(); + if new_dialect.name() == dialect.name() { + self.referenced_dialects.entry(const_op.clone()).or_default().push(new_dialect); + return Some(const_op); + } + + // If it isn't, then we also need to make sure that the mapping for the new dialect is valid + let new_uniqued_constant = UniquedConstant { + dialect: new_dialect.clone(), + value: uniqued_constant.value.clone_value(), + ty: uniqued_constant.ty.clone(), + }; + let maybe_existing_op = uniqued_constants.get(&new_uniqued_constant).cloned(); + uniqued_constants.insert( + uniqued_constant, + maybe_existing_op.clone().unwrap_or_else(|| const_op.clone()), + ); + if let Some(mut existing_op) = maybe_existing_op { + self.notify_removal(const_op.clone()); + self.rewriter.erase_op(const_op); + self.referenced_dialects + .get_mut(&existing_op) + .unwrap() + .push(new_uniqued_constant.dialect.clone()); + let mut existing = existing_op.borrow_mut(); + if existing.span() != span { + existing.set_span(span); + } + Some(existing_op) + } else { + self.referenced_dialects + .insert(const_op.clone(), smallvec![dialect, new_dialect]); + uniqued_constants.insert(new_uniqued_constant, const_op.clone()); + Some(const_op) + } + } + + /// Returns true if the given operation is an already folded constant that is owned by this + /// folder. + #[inline(always)] + fn is_folder_owned_constant(&self, op: &OperationRef) -> bool { + self.referenced_dialects.contains_key(op) + } +} + +/// Materialize a constant for a given attribute and type. +/// +/// Returns a constant operation if successful, otherwise `None` +fn materialize_constant( + builder: &mut dyn Builder, + dialect: Rc, + value: Box, + ty: &Type, + span: SourceSpan, +) -> Option { + let ip = builder.insertion_point().cloned(); + + // Ask the dialect to materialize a constant operation for this value. + let const_op = dialect.materialize_constant(builder, value, ty, span)?; + assert_eq!(ip.as_ref(), builder.insertion_point()); + assert!(const_op.borrow().implements::()); + Some(const_op) +} + +/// Given the containing block of an operation, find the parent region that folded constants should +/// be inserted into. 
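+///
+/// In practice, this is the enclosing region owned by the nearest ancestor operation that is
+/// isolated from above, or by the top-level operation if there is no such ancestor.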
+fn get_insertion_region(insertion_block: BlockRef) -> RegionRef { + use crate::EntityWithId; + + let mut insertion_block = Some(insertion_block); + while let Some(block) = insertion_block.take() { + let parent_region = block.borrow().parent().unwrap_or_else(|| { + panic!("expected block {} to be attached to a region", block.borrow().id()) + }); + // Insert in this region for any of the following scenarios: + // + // * The parent is known to be isolated from above + // * The parent is a top-level operation + let parent_op = parent_region + .borrow() + .parent() + .expect("expected region to be attached to an operation"); + let parent = parent_op.borrow(); + let parent_block = parent.parent(); + if parent.implements::() || parent_block.is_none() { + return parent_region; + } + + insertion_block = parent_block; + } + + unreachable!("expected valid insertion region") +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index fa3acb1ea..6b77b8970 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -37,13 +37,24 @@ mod attributes; pub mod demangle; pub mod derive; pub mod dialects; +mod folder; pub mod formatter; mod hash; mod ir; pub mod matchers; mod patterns; -pub use self::{any::AsAny, attributes::*, ir::*, patterns::*}; +pub use self::{ + any::AsAny, + attributes::{ + markers::*, Attribute, AttributeSet, AttributeValue, CallConv, DictAttr, Overflow, SetAttr, + Visibility, + }, + folder::OperationFolder, + hash::{DynHash, DynHasher}, + ir::*, + patterns::*, +}; // TODO(pauls): The following is a rough list of what needs to be implemented for the IR // refactoring to be complete and usable in place of the old IR (some are optional): From b69c4ec187df18a4e14a16bd26c4b10e4496fbce Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:16:53 -0400 Subject: [PATCH 21/31] feat: implement some useful display helpers --- hir2/src/formatter.rs | 44 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/hir2/src/formatter.rs b/hir2/src/formatter.rs index 08eab2ae5..e2025a620 100644 --- a/hir2/src/formatter.rs +++ b/hir2/src/formatter.rs @@ -1,4 +1,4 @@ -use core::fmt; +use core::{cell::Cell, fmt}; pub use miden_core::{ prettier::*, @@ -15,3 +15,45 @@ impl fmt::Display for DisplayIndent { Ok(()) } } + +/// Render an iterator of `T`, comma-separated +pub struct DisplayValues(Cell>); +impl DisplayValues { + pub fn new(inner: T) -> Self { + Self(Cell::new(Some(inner))) + } +} +impl fmt::Display for DisplayValues +where + T: fmt::Display, + I: Iterator, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let iter = self.0.take().unwrap(); + for (i, item) in iter.enumerate() { + if i == 0 { + write!(f, "{}", item)?; + } else { + write!(f, ", {}", item)?; + } + } + Ok(()) + } +} + +/// Render an `Option` using the `Display` impl for `T` +pub struct DisplayOptional<'a, T>(pub Option<&'a T>); +impl<'a, T: fmt::Display> fmt::Display for DisplayOptional<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.0 { + None => f.write_str("None"), + Some(item) => write!(f, "Some({item})"), + } + } +} +impl<'a, T: fmt::Display> fmt::Debug for DisplayOptional<'a, T> { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} From 05e110058fafdb48aced38bcf70a0bf0f00e4f30 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:19:08 -0400 Subject: [PATCH 22/31] wip: implement hir dialect builder and op impls, particularly control ops --- 
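A rough usage sketch (illustrative only) of the builder methods added in this patch, written in
the same style as the existing matcher tests and assuming a `Function` value named `function`
already exists:

    let mut builder = FunctionBuilder::new(&mut function);
    let span = SourceSpan::default();
    let lhs = builder.ins().u32(1, span).unwrap();
    let block = builder.current_block();
    let rhs = block.borrow().arguments()[0].clone().upcast();
    // Checked addition traps on overflow...
    let sum = builder.ins().add(lhs.clone(), rhs.clone(), span).unwrap();
    // ...while the overflowing variant yields an (overflowed, wrapped_result) pair.
    let (_overflowed, _wrapped) = builder.ins().add_overflowing(lhs, rhs, span).unwrap();
    builder.ins().ret(Some(sum), span).unwrap();
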
hir2/src/dialects/hir.rs | 84 +++- hir2/src/dialects/hir/builders/function.rs | 364 +++++++++++++-- hir2/src/dialects/hir/ops.rs | 5 +- hir2/src/dialects/hir/ops/binary.rs | 486 +++++++++++---------- hir2/src/dialects/hir/ops/cast.rs | 56 +++ hir2/src/dialects/hir/ops/constants.rs | 41 ++ hir2/src/dialects/hir/ops/control.rs | 481 +++++++++++++++++++- hir2/src/dialects/hir/ops/function.rs | 42 +- hir2/src/dialects/hir/ops/invoke.rs | 50 ++- hir2/src/dialects/hir/ops/module.rs | 17 +- hir2/src/dialects/hir/ops/primop.rs | 14 + hir2/src/dialects/hir/ops/ternary.rs | 24 +- hir2/src/dialects/hir/ops/unary.rs | 49 +++ 13 files changed, 1394 insertions(+), 319 deletions(-) create mode 100644 hir2/src/dialects/hir/ops/constants.rs diff --git a/hir2/src/dialects/hir.rs b/hir2/src/dialects/hir.rs index fae91c19b..31cd846ec 100644 --- a/hir2/src/dialects/hir.rs +++ b/hir2/src/dialects/hir.rs @@ -5,10 +5,13 @@ use alloc::rc::Rc; use core::cell::{Cell, RefCell}; pub use self::{ - builders::{DefaultInstBuilder, FunctionBuilder}, + builders::{DefaultInstBuilder, FunctionBuilder, InstBuilder, InstBuilderBase}, ops::*, }; -use crate::{interner, Dialect, DialectName, DialectRegistration, OperationName}; +use crate::{ + interner, AttributeValue, Builder, BuilderExt, Dialect, DialectName, DialectRegistration, + Immediate, OperationName, OperationRef, SourceSpan, Type, +}; #[derive(Default)] pub struct HirDialect { @@ -66,6 +69,83 @@ impl Dialect for HirDialect { } } } + + fn materialize_constant( + &self, + builder: &mut dyn Builder, + attr: Box, + ty: &Type, + span: SourceSpan, + ) -> Option { + use crate::Op; + + // Save the current insertion point + let mut builder = crate::InsertionGuard::new(builder); + + // Only integer constants are supported for now + if !ty.is_integer() { + return None; + } + + // Currently, we expect folds to produce `Immediate`-valued attributes + if let Some(&imm) = attr.downcast_ref::() { + // If the immediate value is of the same type as the expected result type, we're ready + // to materialize the constant + let imm_ty = imm.ty(); + if &imm_ty == ty { + let op_builder = builder.create::(span); + return op_builder(imm) + .ok() + .map(|op| op.borrow().as_operation().as_operation_ref()); + } + + // The immediate value has a different type than expected, but we can coerce types, so + // long as the value fits in the target type + if imm_ty.size_in_bits() > ty.size_in_bits() { + return None; + } + + let imm = match ty { + Type::I8 => match imm { + Immediate::I1(value) => Immediate::I8(value as i8), + Immediate::U8(value) => Immediate::I8(i8::try_from(value).ok()?), + _ => return None, + }, + Type::U8 => match imm { + Immediate::I1(value) => Immediate::U8(value as u8), + Immediate::I8(value) => Immediate::U8(u8::try_from(value).ok()?), + _ => return None, + }, + Type::I16 => match imm { + Immediate::I1(value) => Immediate::I16(value as i16), + Immediate::I8(value) => Immediate::I16(value as i16), + Immediate::U8(value) => Immediate::I16(value.into()), + Immediate::U16(value) => Immediate::I16(i16::try_from(value).ok()?), + _ => return None, + }, + Type::U16 => match imm { + Immediate::I1(value) => Immediate::U16(value as u16), + Immediate::I8(value) => Immediate::U16(u16::try_from(value).ok()?), + Immediate::U8(value) => Immediate::U16(value as u16), + Immediate::I16(value) => Immediate::U16(u16::try_from(value).ok()?), + _ => return None, + }, + Type::I32 => Immediate::I32(imm.as_i32()?), + Type::U32 => Immediate::U32(imm.as_u32()?), + Type::I64 => Immediate::I64(imm.as_i64()?), + 
Type::U64 => Immediate::U64(imm.as_u64()?), + Type::I128 => Immediate::I128(imm.as_i128()?), + Type::U128 => Immediate::U128(imm.as_u128()?), + Type::Felt => Immediate::Felt(imm.as_felt()?), + ty => unimplemented!("unrecognized integral type '{ty}'"), + }; + + let op_builder = builder.create::(span); + return op_builder(imm).ok().map(|op| op.borrow().as_operation().as_operation_ref()); + } + + None + } } impl DialectRegistration for HirDialect { diff --git a/hir2/src/dialects/hir/builders/function.rs b/hir2/src/dialects/hir/builders/function.rs index b42b84b74..d5fefc69c 100644 --- a/hir2/src/dialects/hir/builders/function.rs +++ b/hir2/src/dialects/hir/builders/function.rs @@ -1,5 +1,5 @@ use crate::{ - dialects::hir::*, AsCallableSymbolRef, BlockRef, Builder, Immediate, InsertionPoint, Op, + dialects::hir::*, AsCallableSymbolRef, Block, BlockRef, Builder, Immediate, InsertionPoint, Op, OpBuilder, Region, RegionRef, Report, SourceSpan, Type, UnsafeIntrusiveEntityRef, Usable, ValueRef, }; @@ -10,9 +10,15 @@ pub struct FunctionBuilder<'f> { } impl<'f> FunctionBuilder<'f> { pub fn new(func: &'f mut Function) -> Self { + let current_block = if func.body().is_empty() { + func.create_entry_block() + } else { + func.last_block() + }; let context = func.as_operation().context_rc(); let mut builder = OpBuilder::new(context); - builder.set_insertion_point_to_end(func.last_block()); + + builder.set_insertion_point_to_end(current_block); Self { func, builder } } @@ -44,10 +50,12 @@ impl<'f> FunctionBuilder<'f> { } pub fn create_block(&mut self) -> BlockRef { - self.builder.create_block(self.body_region(), None, None) + self.builder.create_block(self.body_region(), None, &[]) } pub fn detach_block(&mut self, mut block: BlockRef) { + use crate::EntityWithParent; + assert_ne!( block, self.current_block(), @@ -63,6 +71,7 @@ impl<'f> FunctionBuilder<'f> { body.body_mut().cursor_mut_from_ptr(block.clone()).remove(); } block.borrow_mut().uses_mut().clear(); + Block::on_removed_from_parent(block, body.as_region_ref()); } pub fn append_block_param(&mut self, block: BlockRef, ty: Type, span: SourceSpan) -> ValueRef { @@ -173,6 +182,12 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> { op_builder(lhs, rhs) } + fn u32(mut self, value: u32, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let constant = op_builder(Immediate::U32(value))?; + Ok(constant.borrow().result().as_value_ref()) + } + //signed_integer_literal!(1, bool); //integer_literal!(8); //integer_literal!(16); @@ -577,35 +592,276 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> { Ok(op.borrow().result().as_value_ref()) } - /* - binary_int_op_with_overflow!(add, Opcode::Add); - binary_int_op_with_overflow!(sub, Opcode::Sub); - binary_int_op_with_overflow!(mul, Opcode::Mul); - checked_binary_int_op!(div, Opcode::Div); - binary_int_op!(min, Opcode::Min); - binary_int_op!(max, Opcode::Max); - checked_binary_int_op!(r#mod, Opcode::Mod); - checked_binary_int_op!(divmod, Opcode::DivMod); - binary_int_op!(exp, Opcode::Exp); - binary_boolean_op!(and, Opcode::And); - binary_int_op!(band, Opcode::Band); - binary_boolean_op!(or, Opcode::Or); - binary_int_op!(bor, Opcode::Bor); - binary_boolean_op!(xor, Opcode::Xor); - binary_int_op!(bxor, Opcode::Bxor); - unary_int_op!(neg, Opcode::Neg); - unary_int_op!(inv, Opcode::Inv); - unary_int_op_with_overflow!(incr, Opcode::Incr); - unary_int_op!(ilog2, Opcode::Ilog2); - unary_int_op!(pow2, Opcode::Pow2); - unary_boolean_op!(not, Opcode::Not); - unary_int_op!(bnot, Opcode::Bnot); - 
unary_int_op!(popcnt, Opcode::Popcnt); - unary_int_op!(clz, Opcode::Clz); - unary_int_op!(ctz, Opcode::Ctz); - unary_int_op!(clo, Opcode::Clo); - unary_int_op!(cto, Opcode::Cto); - */ + /// Two's complement addition which traps on overflow + fn add(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement addition. Behavior is undefined if the result overflows. + fn add_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement addition which wraps around on overflow, e.g. `wrapping_add` + fn add_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement addition which wraps around on overflow, but returns a boolean flag that + /// indicates whether or not the operation overflowed, followed by the wrapped result, e.g. + /// `overflowing_add` (but with the result types inverted compared to Rust's version). + fn add_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Two's complement subtraction which traps on under/overflow + fn sub(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement subtraction. Behavior is undefined if the result under/overflows. + fn sub_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement subtraction which wraps around on under/overflow, e.g. `wrapping_sub` + fn sub_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement subtraction which wraps around on overflow, but returns a boolean flag that + /// indicates whether or not the operation under/overflowed, followed by the wrapped result, + /// e.g. `overflowing_sub` (but with the result types inverted compared to Rust's version). 
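+    /// That is, the results are returned as `(overflowed, result)` rather than Rust's
+    /// `(result, overflowed)`.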
+ fn sub_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Two's complement multiplication which traps on overflow + fn mul(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement multiplication. Behavior is undefined if the result overflows. + fn mul_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement multiplication which wraps around on overflow, e.g. `wrapping_mul` + fn mul_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement multiplication which wraps around on overflow, but returns a boolean flag + /// that indicates whether or not the operation overflowed, followed by the wrapped result, + /// e.g. `overflowing_mul` (but with the result types inverted compared to Rust's version). + fn mul_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Integer division. Traps if `rhs` is zero. + fn div(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Integer Euclidean modulo. Traps if `rhs` is zero. + fn r#mod(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Combined integer Euclidean division and modulo. Traps if `rhs` is zero. 
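+    ///
+    /// Returns the pair `(quotient, remainder)`.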
+ fn divmod( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let quotient = op.quotient().as_value_ref(); + let remainder = op.remainder().as_value_ref(); + Ok((quotient, remainder)) + } + + /// Exponentiation + fn exp(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compute 2^n + fn pow2(mut self, n: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compute ilog2(n) + fn ilog2(mut self, n: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Modular inverse + fn inv(mut self, n: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unary negation + fn neg(mut self, n: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement unary increment by one which traps on overflow + fn incr(mut self, lhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical AND + fn and(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical OR + fn or(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical XOR + fn xor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical NOT + fn not(mut self, lhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise AND + fn band(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise OR + fn bor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise XOR + fn bxor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise NOT + fn bnot(mut self, lhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + } fn rotl(mut 
self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { let op_builder = self.builder_mut().create::(span); @@ -631,6 +887,36 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> { Ok(op.borrow().result().as_value_ref()) } + fn popcnt(mut self, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn clz(mut self, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn ctz(mut self, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn clo(mut self, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn cto(mut self, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + fn eq(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { let op_builder = self.builder_mut().create::(span); let op = op_builder(lhs, rhs)?; @@ -643,6 +929,20 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> { Ok(op.borrow().result().as_value_ref()) } + /// Compares two integers and returns the minimum value + fn min(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compares two integers and returns the maximum value + fn max(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { + let op_builder = self.builder_mut().create::(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + fn gt(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result { let op_builder = self.builder_mut().create::(span); let op = op_builder(lhs, rhs)?; diff --git a/hir2/src/dialects/hir/ops.rs b/hir2/src/dialects/hir/ops.rs index 5eb25fa3c..47701fd07 100644 --- a/hir2/src/dialects/hir/ops.rs +++ b/hir2/src/dialects/hir/ops.rs @@ -1,6 +1,7 @@ mod assertions; mod binary; mod cast; +mod constants; mod control; mod function; mod invoke; @@ -11,6 +12,6 @@ mod ternary; mod unary; pub use self::{ - assertions::*, binary::*, cast::*, control::*, function::*, invoke::*, mem::*, module::*, - primop::*, ternary::*, unary::*, + assertions::*, binary::*, cast::*, constants::*, control::*, function::*, invoke::*, mem::*, + module::*, primop::*, ternary::*, unary::*, }; diff --git a/hir2/src/dialects/hir/ops/binary.rs b/hir2/src/dialects/hir/ops/binary.rs index 8582197ca..4e24a0d0c 100644 --- a/hir2/src/dialects/hir/ops/binary.rs +++ b/hir2/src/dialects/hir/ops/binary.rs @@ -1,9 +1,57 @@ use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; +// Implement `derive(InferTypeOpInterface)` with `#[infer]` helper attribute: +// +// * `#[infer]` on a result field indicates its type should be inferred from the type of the first +// operand field +// * `#[infer(from = field)]` on a result field indicates its type should be inferred from +// the given field. 
The field is expected to implement `AsRef` +// * `#[infer(type = I1)]` on a field indicates that the field should always be inferred to have the given type +// * `#[infer(with = path::to::function)]` on a field indicates that the given function should be called to +// compute the inferred type for that field +macro_rules! infer_return_ty_for_binary_op { + ($Op:ty) => { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.result_mut().set_type(lhs); + Ok(()) + } + } + }; + + + ($Op:ty as $manually_specified_ty:expr) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type($manually_specified_ty); + Ok(()) + } + } + } + }; + + ($Op:ty, $($manually_specified_field_name:ident : $manually_specified_field_ty:expr),+) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.result_mut().set_type(lhs); + $( + self.[<$manually_specified_field_name _mut>]().set_type($manually_specified_field_ty); + )* + Ok(()) + } + } + } + }; +} + /// Two's complement sum #[operation( dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands), + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), implements(InferTypeOpInterface) )] pub struct Add { @@ -17,31 +65,7 @@ pub struct Add { overflow: Overflow, } -impl InferTypeOpInterface for Add { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(Add); /// Two's complement sum with overflow bit #[operation( @@ -60,31 +84,7 @@ pub struct AddOverflowing { result: AnyInteger, } -impl InferTypeOpInterface for AddOverflowing { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(AddOverflowing, overflowed: Type::I1); /// Two's complement difference (subtraction) #[operation( @@ -103,31 +103,7 @@ pub struct Sub { overflow: Overflow, } -impl InferTypeOpInterface for Sub { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != 
rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(Sub); /// Two's complement difference (subtraction) with underflow bit #[operation( @@ -146,31 +122,7 @@ pub struct SubOverflowing { result: AnyInteger, } -impl InferTypeOpInterface for SubOverflowing { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(SubOverflowing, overflowed: Type::I1); /// Two's complement product #[operation( @@ -189,38 +141,14 @@ pub struct Mul { overflow: Overflow, } -impl InferTypeOpInterface for Mul { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(Mul); /// Two's complement product with overflow bit #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands), - implements(InferTypeOpInterface) - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct MulOverflowing { #[operand] lhs: AnyInteger, @@ -232,37 +160,14 @@ pub struct MulOverflowing { result: AnyInteger, } -impl InferTypeOpInterface for MulOverflowing { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.lhs().ty().clone(); - { - let rhs = self.rhs(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); - Ok(()) - } -} +infer_return_ty_for_binary_op!(MulOverflowing, overflowed: Type::I1); /// Exponentiation for field elements #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, 
SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Exp { #[operand] lhs: IntFelt, @@ -272,11 +177,14 @@ pub struct Exp { result: IntFelt, } +infer_return_ty_for_binary_op!(Exp); + /// Unsigned integer division, traps on division by zero #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Div { #[operand] lhs: AnyInteger, @@ -286,11 +194,14 @@ pub struct Div { result: AnyInteger, } +infer_return_ty_for_binary_op!(Div); + /// Signed integer division, traps on division by zero or dividing the minimum signed value by -1 #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Sdiv { #[operand] lhs: AnyInteger, @@ -300,11 +211,14 @@ pub struct Sdiv { result: AnyInteger, } +infer_return_ty_for_binary_op!(Sdiv); + /// Unsigned integer Euclidean modulo, traps on division by zero #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Mod { #[operand] lhs: AnyInteger, @@ -314,13 +228,16 @@ pub struct Mod { result: AnyInteger, } +infer_return_ty_for_binary_op!(Mod); + /// Signed integer Euclidean modulo, traps on division by zero /// /// The result has the same sign as the dividend (lhs) #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Smod { #[operand] lhs: AnyInteger, @@ -330,13 +247,16 @@ pub struct Smod { result: AnyInteger, } +infer_return_ty_for_binary_op!(Smod); + /// Combined unsigned integer Euclidean division and remainder (modulo). /// /// Traps on division by zero. #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Divmod { #[operand] lhs: AnyInteger, @@ -348,15 +268,25 @@ pub struct Divmod { quotient: AnyInteger, } +impl InferTypeOpInterface for Divmod { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.remainder_mut().set_type(lhs.clone()); + self.quotient_mut().set_type(lhs); + Ok(()) + } +} + /// Combined signed integer Euclidean division and remainder (modulo). /// /// Traps on division by zero. 
/// /// The remainder has the same sign as the dividend (lhs) #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Sdivmod { #[operand] lhs: AnyInteger, @@ -368,13 +298,23 @@ pub struct Sdivmod { quotient: AnyInteger, } +impl InferTypeOpInterface for Sdivmod { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.remainder_mut().set_type(lhs.clone()); + self.quotient_mut().set_type(lhs); + Ok(()) + } +} + /// Logical AND /// /// Operands must be boolean. #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct And { #[operand] lhs: Bool, @@ -384,13 +324,16 @@ pub struct And { result: Bool, } +infer_return_ty_for_binary_op!(And); + /// Logical OR /// /// Operands must be boolean. #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Or { #[operand] lhs: Bool, @@ -400,13 +343,16 @@ pub struct Or { result: Bool, } +infer_return_ty_for_binary_op!(Or); + /// Logical XOR /// /// Operands must be boolean. #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Xor { #[operand] lhs: Bool, @@ -416,11 +362,14 @@ pub struct Xor { result: Bool, } +infer_return_ty_for_binary_op!(Xor); + /// Bitwise AND #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Band { #[operand] lhs: AnyInteger, @@ -430,11 +379,14 @@ pub struct Band { result: AnyInteger, } +infer_return_ty_for_binary_op!(Band); + /// Bitwise OR #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Bor { #[operand] lhs: AnyInteger, @@ -444,13 +396,16 @@ pub struct Bor { result: AnyInteger, } +infer_return_ty_for_binary_op!(Bor); + /// Bitwise XOR /// /// Operands must be boolean. #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Bxor { #[operand] lhs: AnyInteger, @@ -460,13 +415,16 @@ pub struct Bxor { result: AnyInteger, } +infer_return_ty_for_binary_op!(Bxor); + /// Bitwise shift-left /// /// Shifts larger than the bitwidth of the value will be wrapped to zero. 
#[operation( - dialect = HirDialect, - traits(BinaryOp), - )] + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] pub struct Shl { #[operand] lhs: AnyInteger, @@ -476,13 +434,40 @@ pub struct Shl { result: AnyInteger, } +infer_return_ty_for_binary_op!(Shl); + +/// Bitwise shift-left by immediate +/// +/// Shifts larger than the bitwidth of the value will be wrapped to zero. +#[operation( + dialect = HirDialect, + implements(InferTypeOpInterface) +)] +pub struct ShlImm { + #[operand] + lhs: AnyInteger, + #[attr] + shift: u32, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for ShlImm { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.lhs().ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + /// Bitwise (logical) shift-right /// /// Shifts larger than the bitwidth of the value will effectively truncate the value to zero. #[operation( - dialect = HirDialect, - traits(BinaryOp), - )] + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] pub struct Shr { #[operand] lhs: AnyInteger, @@ -492,14 +477,17 @@ pub struct Shr { result: AnyInteger, } +infer_return_ty_for_binary_op!(Shr); + /// Arithmetic (signed) shift-right /// /// The result of shifts larger than the bitwidth of the value depend on the sign of the value; /// for positive values, it rounds to zero; for negative values, it rounds to MIN. #[operation( - dialect = HirDialect, - traits(BinaryOp), - )] + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] pub struct Ashr { #[operand] lhs: AnyInteger, @@ -509,13 +497,16 @@ pub struct Ashr { result: AnyInteger, } +infer_return_ty_for_binary_op!(Ashr); + /// Bitwise rotate-left /// /// The rotation count must be < the bitwidth of the value type. #[operation( - dialect = HirDialect, - traits(BinaryOp), - )] + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] pub struct Rotl { #[operand] lhs: AnyInteger, @@ -525,13 +516,16 @@ pub struct Rotl { result: AnyInteger, } +infer_return_ty_for_binary_op!(Rotl); + /// Bitwise rotate-right /// /// The rotation count must be < the bitwidth of the value type. 
#[operation( - dialect = HirDialect, - traits(BinaryOp), - )] + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] pub struct Rotr { #[operand] lhs: AnyInteger, @@ -541,11 +535,14 @@ pub struct Rotr { result: AnyInteger, } +infer_return_ty_for_binary_op!(Rotr); + /// Equality comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Eq { #[operand] lhs: AnyInteger, @@ -555,11 +552,14 @@ pub struct Eq { result: Bool, } +infer_return_ty_for_binary_op!(Eq as Type::I1); + /// Inequality comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Neq { #[operand] lhs: AnyInteger, @@ -569,11 +569,14 @@ pub struct Neq { result: Bool, } +infer_return_ty_for_binary_op!(Neq as Type::I1); + /// Greater-than comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Gt { #[operand] lhs: AnyInteger, @@ -583,11 +586,14 @@ pub struct Gt { result: Bool, } +infer_return_ty_for_binary_op!(Gt as Type::I1); + /// Greater-than-or-equal comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Gte { #[operand] lhs: AnyInteger, @@ -597,11 +603,14 @@ pub struct Gte { result: Bool, } +infer_return_ty_for_binary_op!(Gte as Type::I1); + /// Less-than comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Lt { #[operand] lhs: AnyInteger, @@ -611,11 +620,14 @@ pub struct Lt { result: Bool, } +infer_return_ty_for_binary_op!(Lt as Type::I1); + /// Less-than-or-equal comparison #[operation( - dialect = HirDialect, - traits(BinaryOp, SameTypeOperands), - )] + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] pub struct Lte { #[operand] lhs: AnyInteger, @@ -625,11 +637,14 @@ pub struct Lte { result: Bool, } +infer_return_ty_for_binary_op!(Lte as Type::I1); + /// Select minimum value #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Min { #[operand] lhs: AnyInteger, @@ -639,11 +654,14 @@ pub struct Min { result: AnyInteger, } +infer_return_ty_for_binary_op!(Min); + /// Select maximum value #[operation( - dialect = HirDialect, - traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), - )] + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] pub struct Max { #[operand] lhs: AnyInteger, @@ -652,3 +670,5 @@ pub struct Max { #[result] result: AnyInteger, } + +infer_return_ty_for_binary_op!(Max); diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs index 06017b501..4a42720f5 100644 --- a/hir2/src/dialects/hir/ops/cast.rs +++ 
b/hir2/src/dialects/hir/ops/cast.rs @@ -44,6 +44,14 @@ pub struct PtrToInt { result: AnyInteger, } +impl InferTypeOpInterface for PtrToInt { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -57,6 +65,14 @@ pub struct IntToPtr { result: AnyPointer, } +impl InferTypeOpInterface for IntToPtr { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -70,6 +86,14 @@ pub struct Cast { result: AnyInteger, } +impl InferTypeOpInterface for Cast { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -83,6 +107,14 @@ pub struct Bitcast { result: AnyPointerOrInteger, } +impl InferTypeOpInterface for Bitcast { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -96,6 +128,14 @@ pub struct Trunc { result: AnyInteger, } +impl InferTypeOpInterface for Trunc { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -109,6 +149,14 @@ pub struct Zext { result: AnyUnsignedInteger, } +impl InferTypeOpInterface for Zext { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(UnaryOp) @@ -121,3 +169,11 @@ pub struct Sext { #[result] result: AnySignedInteger, } + +impl InferTypeOpInterface for Sext { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} diff --git a/hir2/src/dialects/hir/ops/constants.rs b/hir2/src/dialects/hir/ops/constants.rs new file mode 100644 index 000000000..c3681e47b --- /dev/null +++ b/hir2/src/dialects/hir/ops/constants.rs @@ -0,0 +1,41 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(ConstantLike), + implements(InferTypeOpInterface, Foldable) +)] +pub struct Constant { + #[attr] + value: Immediate, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Constant { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.value().ty(); + self.result_mut().set_type(ty); + + Ok(()) + } +} + +impl Foldable for Constant { + #[inline] + fn fold(&self, results: &mut smallvec::SmallVec<[OpFoldResult; 1]>) -> FoldResult { + results.push(OpFoldResult::Attribute(self.get_attribute("value").unwrap().clone_value())); + FoldResult::Ok(()) + } + + #[inline(always)] + fn fold_with( + &self, + _operands: &[Option>], + results: &mut smallvec::SmallVec<[OpFoldResult; 1]>, + ) -> FoldResult { + self.fold(results) + } +} diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs index f41ceb73b..bd3f31be6 100644 --- a/hir2/src/dialects/hir/ops/control.rs +++ 
b/hir2/src/dialects/hir/ops/control.rs @@ -1,7 +1,9 @@ use midenc_hir_macros::operation; +use smallvec::{smallvec, SmallVec}; use crate::{dialects::hir::HirDialect, traits::*, *}; +/// Returns from the enclosing function with the provided operands as its results. #[operation( dialect = HirDialect, traits(Terminator, ReturnLike) @@ -11,6 +13,7 @@ pub struct Ret { values: AnyType, } +/// Returns from the enclosing function with the provided immediate value as its result. #[operation( dialect = HirDialect, traits(Terminator, ReturnLike) @@ -20,18 +23,33 @@ pub struct RetImm { value: Immediate, } +/// An unstructured control flow primitive representing an unconditional branch to `target` #[operation( dialect = HirDialect, - traits(Terminator) + traits(Terminator), + implements(BranchOpInterface) )] pub struct Br { #[successor] target: Successor, } +impl BranchOpInterface for Br { + #[inline] + fn get_successor_for_operands( + &self, + _operands: &[Box], + ) -> Option { + Some(self.target().dest.borrow().block.clone()) + } +} + +/// An unstructured control flow primitive representing a conditional branch to either `then_dest` +/// or `else_dest` depending on the value of `condition`, a boolean value. #[operation( dialect = HirDialect, - traits(Terminator) + traits(Terminator), + implements(BranchOpInterface) )] pub struct CondBr { #[operand] @@ -41,10 +59,37 @@ pub struct CondBr { #[successor] else_dest: Successor, } +impl BranchOpInterface for CondBr { + fn get_successor_for_operands(&self, operands: &[Box]) -> Option { + let value = &*operands[0]; + let cond = if let Some(imm) = value.downcast_ref::() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::() { + *yes + } else { + panic!("expected boolean immediate for '{}' condition, got: {:?}", self.name(), value) + }; + Some(if cond { + self.then_dest().dest.borrow().block.clone() + } else { + self.else_dest().dest.borrow().block.clone() + }) + } +} + +/// An unstructured control flow primitive that represents a multi-way branch to one of multiple +/// branch targets, depending on the value of `selector`. +/// +/// If a specific selector value is matched by `cases`, the branch target corresponding to that +/// case is the one to which control is transferred. If no matching case is found for the selector, +/// then the `fallback` target is used instead. +/// +/// A `fallback` successor must always be provided. 
#[operation( dialect = HirDialect, - traits(Terminator) + traits(Terminator), + implements(BranchOpInterface) )] pub struct Switch { #[operand] @@ -55,7 +100,36 @@ pub struct Switch { fallback: Successor, } -// TODO(pauls): Implement `SuccessorInterface` for this type +impl BranchOpInterface for Switch { + #[inline] + fn get_successor_for_operands(&self, operands: &[Box]) -> Option { + let value = &*operands[0]; + let selector = if let Some(selector) = value.downcast_ref::() { + selector.as_u32().expect("invalid selector value for 'hir.switch'") + } else if let Some(selector) = value.downcast_ref::() { + *selector + } else if let Some(selector) = value.downcast_ref::() { + u32::try_from(*selector).expect("invalid selector value for 'hir.switch'") + } else if let Some(selector) = value.downcast_ref::() { + u32::try_from(*selector).expect("invalid selector value for 'hir.switch': out of range") + } else { + panic!("unsupported selector value type for '{}', got: {:?}", self.name(), value) + }; + + for switch_case in self.cases().iter() { + let key = switch_case.key().unwrap(); + if selector == key.value { + return Some(switch_case.block()); + } + } + + // If we reach here, no selector match was found, so use the fallback successor + Some(self.fallback().dest.borrow().block.clone()) + } +} + +/// Represents a single branch target by matching a specific selector value in a [Switch] +/// operation. #[derive(Debug, Clone)] pub struct SwitchCase { pub value: u32, @@ -63,12 +137,14 @@ pub struct SwitchCase { pub arguments: Vec, } +#[doc(hidden)] pub struct SwitchCaseRef<'a> { pub value: u32, pub successor: BlockOperandRef, pub arguments: OpOperandRange<'a>, } +#[doc(hidden)] pub struct SwitchCaseMut<'a> { pub value: u32, pub successor: BlockOperandRef, @@ -113,9 +189,21 @@ impl KeyedSuccessor for SwitchCase { } } +/// [If] is a structured control flow operation representing conditional execution. +/// +/// An [If] takes a single condition as an argument, which chooses between one of its two regions +/// based on the condition. If the condition is true, then the `then_body` region is executed, +/// otherwise `else_body`. 
+/// +/// Neither region allows any arguments, and both regions must be terminated with one of: +/// +/// * [Return] to return from the enclosing function directly +/// * [Unreachable] to abort execution +/// * [Yield] to return from the enclosing [If] #[operation( dialect = HirDialect, - traits(SingleBlock, NoRegionArguments) + traits(SingleBlock, NoRegionArguments), + implements(RegionBranchOpInterface) )] pub struct If { #[operand] @@ -126,6 +214,129 @@ pub struct If { else_body: Region, } +impl RegionBranchOpInterface for If { + fn get_entry_successor_regions( + &self, + operands: &[Option>], + ) -> RegionSuccessorIter<'_> { + match operands[0].as_deref() { + None => self.get_successor_regions(RegionBranchPoint::Parent), + Some(value) => { + let cond = if let Some(imm) = value.downcast_ref::() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::() { + *yes + } else { + panic!( + "expected boolean immediate for '{}' condition, got: {:?}", + self.name(), + value + ) + }; + + if cond { + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.then_body().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } else { + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.else_body().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } + } + } + } + + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> { + match point { + RegionBranchPoint::Parent => { + // Either branch is reachable on entry + RegionSuccessorIter::new( + self.as_operation(), + [ + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.then_body().as_region_ref()), + key: None, + operand_group: 0, + }, + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.else_body().as_region_ref()), + key: None, + operand_group: 0, + }, + ], + ) + } + RegionBranchPoint::Child(_) => { + // Only the parent If is reachable from then_body/else_body + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // TODO(pauls): Need to handle operand groups properly, as this group refers + // to the operand groups of If, but the results of the If come from the + // Yield contained in the If body + operand_group: 0, + }], + ) + } + } + } + + fn get_region_invocation_bounds( + &self, + operands: &[Option>], + ) -> SmallVec<[InvocationBounds; 1]> { + use smallvec::smallvec; + + match operands[0].as_deref() { + None => { + // Only one region is invoked, and no more than a single time + smallvec![InvocationBounds::NoMoreThan(1); 2] + } + Some(value) => { + let cond = if let Some(imm) = value.downcast_ref::() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::() { + *yes + } else { + panic!( + "expected boolean immediate for '{}' condition, got: {:?}", + self.name(), + value + ) + }; + if cond { + smallvec![InvocationBounds::Exact(1), InvocationBounds::Never] + } else { + smallvec![InvocationBounds::Never, InvocationBounds::Exact(1)] + } + } + } + } + + #[inline(always)] + fn is_repetitive_region(&self, _index: usize) -> bool { + false + } + + #[inline(always)] + fn has_loop(&self) -> bool { + false + } +} + /// A while is a loop structure composed of two regions: a "before" region, and an "after" region. 
/// /// The "before" region's entry block parameters correspond to the operands expected by the @@ -140,7 +351,8 @@ pub struct If { /// continue the loop. #[operation( dialect = HirDialect, - traits(SingleBlock) + traits(SingleBlock), + implements(RegionBranchOpInterface) )] pub struct While { #[region] @@ -149,17 +361,266 @@ pub struct While { after: Region, } +impl RegionBranchOpInterface for While { + #[inline] + fn get_entry_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + // Operands being forwarded to the `before` region from outside the loop + SuccessorOperandRange::forward(self.operands().all()) + } + + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> { + match point { + RegionBranchPoint::Parent => { + // The only successor region when branching from outside the While op is the + // `before` region. + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.before().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } + RegionBranchPoint::Child(region) => { + // When branching from `before`, the only successor is `after` or the While itself, + // otherwise, when branching from `after` the only successor is `before`. + let after_region = self.after().as_region_ref(); + if region == after_region { + // TODO(pauls): We should handle operands properly here - the While op itself + // does not have any operands for this transfer of control, that comes from the + // Yield op + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.before().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } else { + // TODO(pauls): We should handle operands properly here - the While op itself + // does not have any operands for this transfer of control, that comes from the + // Condition op + assert!( + region == self.before().as_region_ref(), + "unexpected region branch point" + ); + RegionSuccessorIter::new( + self.as_operation(), + [ + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + operand_group: 0, + }, + RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + operand_group: 0, + }, + ], + ) + } + } + } + } + + #[inline] + fn get_region_invocation_bounds( + &self, + _operands: &[Option>], + ) -> SmallVec<[InvocationBounds; 1]> { + smallvec![InvocationBounds::Unknown; self.num_regions()] + } + + #[inline(always)] + fn is_repetitive_region(&self, _index: usize) -> bool { + // Both regions are in the loop (`before` -> `after` -> `before` -> `after`) + true + } + + #[inline(always)] + fn has_loop(&self) -> bool { + true + } +} + +/// The [Condition] op is used in conjunction with [While] as the terminator of its `before` region. +/// +/// This op represents a choice between continuing the loop, or exiting the [While] loop and +/// continuing execution after the loop. +/// +/// NOTE: Attempting to use this op in any other context than the one described above is invalid, +/// and the implementation of various interfaces by this op will panic if that assumption is +/// violated. 
#[operation( dialect = HirDialect, - traits(Terminator, ReturnLike) + traits(Terminator, ReturnLike), + implements(RegionBranchTerminatorOpInterface) )] pub struct Condition { #[operand] - value: Bool, + condition: Bool, + #[operands] + forwarded: AnyType, +} + +impl RegionBranchTerminatorOpInterface for Condition { + #[inline] + fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + SuccessorOperandRange::forward(self.forwarded()) + } + + #[inline] + fn get_mutable_successor_operands( + &mut self, + _point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_> { + SuccessorOperandRangeMut::forward(self.forwarded_mut()) + } + + fn get_successor_regions( + &self, + operands: &[Option>], + ) -> SmallVec<[RegionSuccessorInfo; 2]> { + // A [While] loop has two regions: `before` (containing this op), and `after`, which this + // op branches to when the condition is true. If the condition is false, control is + // transferred back to the parent [While] operation, with the forwarded operands of the + // condition used as the results of the [While] operation. + // + // We can return a single statically-known region if we were given a constant condition + // value, otherwise we must return both possible regions. + let cond = operands[0].as_deref(); + match cond { + None => { + let after_region = self + .parent_op() + .unwrap() + .borrow() + .downcast_ref::() + .expect("expected `Condition` op to be a child of a `While` op") + .after() + .as_region_ref(); + // We can't know the condition until runtime, so both the parent `while` op and + // the `after` region could be successors + let if_false = RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // the `forwarded` operand group + operand_group: 1, + }; + let if_true = RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + // the `forwarded` operand group + operand_group: 1, + }; + smallvec![if_false, if_true] + } + Some(value) => { + // Extract the boolean value of the condition + let should_continue = if let Some(imm) = value.downcast_ref::() { + imm.as_bool().expect("invalid boolean immediate for 'hir.condition'") + } else if let Some(yes) = value.downcast_ref::() { + *yes + } else { + panic!("expected boolean immediate for 'hir.condition'") + }; + + // Choose the specific region successor implied by the condition + if should_continue { + // Proceed to the 'after' region + let after_region = self + .parent_op() + .unwrap() + .borrow() + .downcast_ref::() + .expect("expected `Condition` op to be a child of a `While` op") + .after() + .as_region_ref(); + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + // the `forwarded` operand group + operand_group: 1, + }] + } else { + // Break out to the parent 'while' operation + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // the `forwarded` operand group + operand_group: 1, + }] + } + } + } + } } +/// The [Yield] op is used in conjunction with [If] and [While] ops as a return-like terminator. +/// +/// * With [If], its regions must be terminated with either a [Yield] or an [Unreachable] op. +/// * With [While], a [Yield] is only valid in the `after` region, and the yielded operands must +/// match the region arguments of the `before` region. 
Thus to return values from the body of a +/// loop, one must first yield them from the `after` region to the `before` region using [Yield], +/// and then yield them from the `before` region by passing them as forwarded operands of the +/// [Condition] op. +/// +/// Any number of operands can be yielded at the same time. However, when [Yield] is used in +/// conjunction with [While], the arity and type of the operands must match the region arguments +/// of the `before` region. When used in conjunction with [If], both the `if_true` and `if_false` +/// regions must yield the same arity and types. #[operation( dialect = HirDialect, - traits(Terminator, ReturnLike) + traits(Terminator, ReturnLike), + implements(RegionBranchTerminatorOpInterface) )] -pub struct Yield {} +pub struct Yield { + #[operands] + yielded: AnyType, +} + +impl RegionBranchTerminatorOpInterface for Yield { + #[inline] + fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + SuccessorOperandRange::forward(self.yielded()) + } + + fn get_mutable_successor_operands( + &mut self, + _point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_> { + SuccessorOperandRangeMut::forward(self.yielded_mut()) + } + + fn get_successor_regions( + &self, + _operands: &[Option>], + ) -> SmallVec<[RegionSuccessorInfo; 2]> { + // Depending on the type of operation containing this yield, the set of successor regions + // is always known. + // + // * [While] may only have a yield to its `before` region + // * [If] may only yield to its parent + let parent_op = self.parent_op().unwrap(); + let parent_op = parent_op.borrow(); + if parent_op.is::<If>() { + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + operand_group: 0, + }] + } else if let Some(while_op) = parent_op.downcast_ref::<While>() { + let before_region = while_op.before().as_region_ref(); + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Child(before_region), + key: None, + operand_group: 0, + }] + } else { + panic!("unsupported parent operation for '{}': '{}'", self.name(), parent_op.name()) + } + } +} diff --git a/hir2/src/dialects/hir/ops/function.rs b/hir2/src/dialects/hir/ops/function.rs index 2c92b039e..0c474a4c7 100644 --- a/hir2/src/dialects/hir/ops/function.rs +++ b/hir2/src/dialects/hir/ops/function.rs @@ -1,10 +1,10 @@ use crate::{ derive::operation, dialects::hir::HirDialect, - traits::{IsolatedFromAbove, RegionKind, RegionKindInterface, SingleRegion}, - BlockRef, CallableOpInterface, Ident, Operation, OperationRef, RegionRef, Report, Signature, - Symbol, SymbolName, SymbolNameAttr, SymbolRef, SymbolUse, SymbolUseList, SymbolUseRef, - SymbolUsesIter, Usable, Visibility, + traits::{IsolatedFromAbove, SingleRegion}, + Block, BlockRef, CallableOpInterface, Ident, Op, Operation, OperationRef, RegionKind, + RegionKindInterface, RegionRef, Report, Signature, Symbol, SymbolName, SymbolNameAttr, + SymbolRef, SymbolUse, SymbolUseList, SymbolUseRef, SymbolUsesIter, Usable, Visibility, }; trait UsableSymbol = Usable; @@ -20,16 +20,42 @@ trait UsableSymbol = Usable; ) )] pub struct Function { - #[region] - body: RegionRef, #[attr] name: Ident, #[attr] signature: Signature, + #[region] + body: RegionRef, /// The uses of this function as a symbol + #[default] uses: SymbolUseList, } + +/// Builders +impl Function { + /// Convert this function from a declaration (no body) to a definition (has a body) by creating + /// the entry block based on the function signature.
+ /// + /// NOTE: The resulting function is _invalid_ until the block has a terminator inserted into it. + /// + /// This function will panic if an entry block has already been created + pub fn create_entry_block(&mut self) -> BlockRef { + use crate::EntityWithParent; + + assert!(self.body().is_empty(), "entry block already exists"); + let signature = self.signature(); + let block = self + .as_operation() + .context() + .create_block_with_params(signature.params().iter().map(|p| p.ty.clone())); + let mut body = self.body_mut(); + body.push_back(block.clone()); + Block::on_inserted_into_parent(block.clone(), body.as_region_ref()); + block + } +} + +/// Accessors impl Function { #[inline] pub fn entry_block(&self) -> BlockRef { @@ -97,7 +123,7 @@ impl Symbol for Function { fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { if OperationRef::ptr_eq(&from, &user.owner) - || from.borrow().is_proper_ancestor_of(user.owner.clone()) + || from.borrow().is_proper_ancestor_of(&user.owner) { Some(unsafe { SymbolUseRef::from_raw(&*user) }) } else { @@ -120,7 +146,7 @@ impl Symbol for Function { // Unlink previously used symbol { let current_symbol = owner - .get_typed_attribute_mut::(&attr_name) + .get_typed_attribute_mut::(attr_name) .expect("stale symbol user"); unsafe { self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); diff --git a/hir2/src/dialects/hir/ops/invoke.rs b/hir2/src/dialects/hir/ops/invoke.rs index cb3851513..192e14edf 100644 --- a/hir2/src/dialects/hir/ops/invoke.rs +++ b/hir2/src/dialects/hir/ops/invoke.rs @@ -2,9 +2,6 @@ use midenc_hir_macros::operation; use crate::{dialects::hir::HirDialect, traits::*, *}; -// TODO(pauls): Implement support for: -// -// * Inferring op constraints from callee signature #[operation( dialect = HirDialect, implements(CallOpInterface) @@ -16,6 +13,53 @@ pub struct Exec { arguments: AnyType, } +impl InferTypeOpInterface for Exec { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + + let span = self.span(); + let owner = self.as_operation().as_operation_ref(); + if let Some(symbol) = self.resolve() { + let symbol = symbol.borrow(); + if let Some(callable) = + symbol.as_symbol_operation().as_trait::() + { + let signature = callable.signature(); + for (i, result) in signature.results().iter().enumerate() { + let value = + context.make_result(span, result.ty.clone(), owner.clone(), i as u8); + self.op.results.push(value); + } + + Ok(()) + } else { + Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + span, + "invalid callee: does not implement CallableOpInterface", + ) + .with_secondary_label( + symbol.as_symbol_operation().span, + "symbol refers to this definition", + ) + .into_report()) + } + } else { + Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, "invalid callee: symbol is undefined") + .into_report()) + } + } +} + /* #[operation( dialect = HirDialect, diff --git a/hir2/src/dialects/hir/ops/module.rs b/hir2/src/dialects/hir/ops/module.rs index 9cc4ee2a9..eddaf27d5 100644 --- a/hir2/src/dialects/hir/ops/module.rs +++ b/hir2/src/dialects/hir/ops/module.rs @@ -6,10 +6,11 @@ use crate::{ symbol_table::SymbolUsesIter, traits::{ GraphRegionNoTerminator, HasOnlyGraphRegion, IsolatedFromAbove, NoRegionArguments, - NoTerminator, 
RegionKind, RegionKindInterface, SingleBlock, SingleRegion, + NoTerminator, SingleBlock, SingleRegion, }, - Ident, InsertionPoint, Operation, OperationRef, Report, Symbol, SymbolName, SymbolNameAttr, - SymbolRef, SymbolTable, SymbolUseList, SymbolUseRef, Usable, Visibility, + Ident, InsertionPoint, Operation, OperationRef, RegionKind, RegionKindInterface, Report, + Symbol, SymbolName, SymbolNameAttr, SymbolRef, SymbolTable, SymbolUseList, SymbolUseRef, + Usable, Visibility, }; #[operation( @@ -90,9 +91,9 @@ impl Symbol for Module { fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { - if OperationRef::ptr_eq(&from, &user.owner) { - Some(unsafe { SymbolUseRef::from_raw(&*user) }) - } else if from.borrow().is_proper_ancestor_of(user.owner.clone()) { + if OperationRef::ptr_eq(&from, &user.owner) + || from.borrow().is_proper_ancestor_of(&user.owner) + { Some(unsafe { SymbolUseRef::from_raw(&*user) }) } else { None @@ -114,7 +115,7 @@ impl Symbol for Module { // Unlink previously used symbol { let current_symbol = owner - .get_typed_attribute_mut::(&attr_name) + .get_typed_attribute_mut::(attr_name) .expect("stale symbol user"); unsafe { self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); @@ -207,7 +208,7 @@ impl SymbolTable for Module { let mut next_use = next_use.borrow_mut(); let mut op = next_use.owner.borrow_mut(); let symbol_name = op - .get_typed_attribute_mut::(&next_use.symbol) + .get_typed_attribute_mut::(next_use.symbol) .expect("stale symbol user"); symbol_name.name = to; } diff --git a/hir2/src/dialects/hir/ops/primop.rs b/hir2/src/dialects/hir/ops/primop.rs index 859d89894..df0401d74 100644 --- a/hir2/src/dialects/hir/ops/primop.rs +++ b/hir2/src/dialects/hir/ops/primop.rs @@ -13,6 +13,13 @@ pub struct MemGrow { result: UInt32, } +impl InferTypeOpInterface for MemGrow { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type(Type::I32); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(HasSideEffects, MemoryRead) @@ -22,6 +29,13 @@ pub struct MemSize { result: UInt32, } +impl InferTypeOpInterface for MemSize { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type(Type::I32); + Ok(()) + } +} + #[operation( dialect = HirDialect, traits(HasSideEffects, MemoryWrite) diff --git a/hir2/src/dialects/hir/ops/ternary.rs b/hir2/src/dialects/hir/ops/ternary.rs index 1cac4ffc1..90e4bcd9a 100644 --- a/hir2/src/dialects/hir/ops/ternary.rs +++ b/hir2/src/dialects/hir/ops/ternary.rs @@ -19,27 +19,9 @@ pub struct Select { } impl InferTypeOpInterface for Select { - fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { - use midenc_session::diagnostics::Severity; - let span = self.span(); - let lhs = self.first().ty().clone(); - { - let rhs = self.second(); - if lhs != rhs.ty() { - return Err(context - .session - .diagnostics - .diagnostic(Severity::Error) - .with_message("invalid operand types") - .with_primary_label(span, "operands of this operation are not compatible") - .with_secondary_label( - rhs.span(), - format!("expected this value to have type '{lhs}', but got '{}'", rhs.ty()), - ) - .into_report()); - } - } - self.result_mut().set_type(lhs); + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.first().ty().clone(); + self.result_mut().set_type(ty); Ok(()) } } diff --git a/hir2/src/dialects/hir/ops/unary.rs 
b/hir2/src/dialects/hir/ops/unary.rs index 785bab848..c27165d28 100644 --- a/hir2/src/dialects/hir/ops/unary.rs +++ b/hir2/src/dialects/hir/ops/unary.rs @@ -1,5 +1,28 @@ use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; +macro_rules! infer_return_ty_for_unary_op { + ($Op:ty) => { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.operand().ty().clone(); + self.result_mut().set_type(lhs); + Ok(()) + } + } + }; + + ($Op:ty as $manually_specified_ty:expr) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type($manually_specified_ty); + Ok(()) + } + } + } + }; +} + /// Increment #[operation ( dialect = HirDialect, @@ -12,6 +35,8 @@ pub struct Incr { result: AnyInteger, } +infer_return_ty_for_unary_op!(Incr); + /// Negation #[operation ( dialect = HirDialect, @@ -24,6 +49,8 @@ pub struct Neg { result: AnyInteger, } +infer_return_ty_for_unary_op!(Neg); + /// Modular inverse #[operation ( dialect = HirDialect, @@ -36,6 +63,8 @@ pub struct Inv { result: IntFelt, } +infer_return_ty_for_unary_op!(Inv); + /// log2(operand) #[operation ( dialect = HirDialect, @@ -48,6 +77,8 @@ pub struct Ilog2 { result: IntFelt, } +infer_return_ty_for_unary_op!(Ilog2); + /// pow2(operand) #[operation ( dialect = HirDialect, @@ -60,6 +91,8 @@ pub struct Pow2 { result: AnyInteger, } +infer_return_ty_for_unary_op!(Pow2); + /// Logical NOT #[operation ( dialect = HirDialect, @@ -72,6 +105,8 @@ pub struct Not { result: Bool, } +infer_return_ty_for_unary_op!(Not); + /// Bitwise NOT #[operation ( dialect = HirDialect, @@ -84,6 +119,8 @@ pub struct Bnot { result: AnyInteger, } +infer_return_ty_for_unary_op!(Bnot); + /// is_odd(operand) #[operation ( dialect = HirDialect, @@ -96,6 +133,8 @@ pub struct IsOdd { result: Bool, } +infer_return_ty_for_unary_op!(IsOdd as Type::I1); + /// Count of non-zero bits (population count) #[operation ( dialect = HirDialect, @@ -108,6 +147,8 @@ pub struct Popcnt { result: UInt32, } +infer_return_ty_for_unary_op!(Popcnt as Type::U32); + /// Count Leading Zeros #[operation ( dialect = HirDialect, @@ -120,6 +161,8 @@ pub struct Clz { result: UInt32, } +infer_return_ty_for_unary_op!(Clz as Type::U32); + /// Count Trailing Zeros #[operation ( dialect = HirDialect, @@ -132,6 +175,8 @@ pub struct Ctz { result: UInt32, } +infer_return_ty_for_unary_op!(Ctz as Type::U32); + /// Count Leading Ones #[operation ( dialect = HirDialect, @@ -144,6 +189,8 @@ pub struct Clo { result: UInt32, } +infer_return_ty_for_unary_op!(Clo as Type::U32); + /// Count Trailing Ones #[operation ( dialect = HirDialect, @@ -155,3 +202,5 @@ pub struct Cto { #[result] result: UInt32, } + +infer_return_ty_for_unary_op!(Cto as Type::U32); From f7aedcb4d3a2af99dcb0468bdefd87e05bcb898a Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Sat, 19 Oct 2024 20:20:02 -0400 Subject: [PATCH 23/31] feat: implement pattern rewriter infrastructure --- hir2/src/patterns.rs | 4 +- hir2/src/patterns/applicator.rs | 91 ++- hir2/src/patterns/driver.rs | 1011 ++++++++++++++++++++++++ hir2/src/patterns/pattern.rs | 270 +++++-- hir2/src/patterns/pattern_set.rs | 2 +- hir2/src/patterns/rewriter.rs | 1252 ++++++++++++++++++++++-------- 6 files changed, 2238 insertions(+), 392 deletions(-) create mode 100644 hir2/src/patterns/driver.rs diff --git a/hir2/src/patterns.rs b/hir2/src/patterns.rs index 7fa4d9c69..b70d11cc9 100644 --- 
a/hir2/src/patterns.rs +++ b/hir2/src/patterns.rs @@ -1,10 +1,12 @@ mod applicator; +mod driver; mod pattern; mod pattern_set; mod rewriter; pub use self::{ - applicator::PatternApplicator, + applicator::{PatternApplicationError, PatternApplicator}, + driver::*, pattern::*, pattern_set::{FrozenRewritePatternSet, RewritePatternSet}, rewriter::*, diff --git a/hir2/src/patterns/applicator.rs b/hir2/src/patterns/applicator.rs index 75fd4b517..dde6eb86a 100644 --- a/hir2/src/patterns/applicator.rs +++ b/hir2/src/patterns/applicator.rs @@ -2,8 +2,13 @@ use alloc::{collections::BTreeMap, rc::Rc}; use smallvec::SmallVec; -use super::{FrozenRewritePatternSet, PatternBenefit, PatternRewriter, RewritePattern}; -use crate::{Builder, OperationName, OperationRef, Report}; +use super::{FrozenRewritePatternSet, PatternBenefit, RewritePattern, Rewriter}; +use crate::{OperationName, OperationRef, Report}; + +pub enum PatternApplicationError { + NoMatchesFound, + Report(Report), +} /// This type manages the application of a group of rewrite patterns, with a user-provided cost model pub struct PatternApplicator { @@ -56,7 +61,8 @@ impl PatternApplicator { /// Apply the default cost model that solely uses the pattern's static benefit #[inline] pub fn apply_default_cost_model(&mut self) { - self.apply_cost_model(|pattern| pattern.benefit()); + log::debug!("applying default cost model"); + self.apply_cost_model(|pattern| *pattern.benefit()); } /// Walk all of the patterns within the applicator. @@ -74,18 +80,19 @@ impl PatternApplicator { } } - pub fn match_and_rewrite( + pub fn match_and_rewrite( &mut self, op: OperationRef, - rewriter: &mut PatternRewriter, - can_apply: Option, - mut on_failure: Option, - mut on_success: Option, - ) -> Result<(), Report> + rewriter: &mut R, + can_apply: A, + mut on_failure: F, + mut on_success: S, + ) -> Result<(), PatternApplicationError> where - A: Fn(&dyn RewritePattern) -> bool, - F: FnMut(&dyn RewritePattern), - S: FnMut(&dyn RewritePattern) -> Result<(), Report>, + A: for<'a> Fn(&'a dyn RewritePattern) -> bool, + F: for<'a> FnMut(&'a dyn RewritePattern), + S: for<'a> FnMut(&'a dyn RewritePattern) -> Result<(), Report>, + R: Rewriter, { // Check to see if there are patterns matching this specific operation type. 
let op_name = { @@ -94,9 +101,21 @@ impl PatternApplicator { }; let op_specific_patterns = self.patterns.get(&op_name).map(|p| p.as_slice()).unwrap_or(&[]); + if op_specific_patterns.is_empty() { + log::trace!("no op-specific patterns found for '{op_name}'"); + } else { + log::trace!( + "found {} op-specific patterns for '{op_name}'", + op_specific_patterns.len() + ); + } + + log::trace!("{} op-agnostic patterns available", self.match_any_patterns.len()); + // Process the op-specific patterns and op-agnostic patterns in an interleaved fashion let mut op_patterns = op_specific_patterns.iter().peekable(); let mut any_op_patterns = self.match_any_patterns.iter().peekable(); + let mut result = Err(PatternApplicationError::NoMatchesFound); loop { // Find the next pattern with the highest benefit // @@ -108,20 +127,39 @@ impl PatternApplicator { if let Some(next_any_pattern) = any_op_patterns .next_if(|p| best_pattern.is_none_or(|bp| bp.benefit() < p.benefit())) { + if let Some(best_pattern) = best_pattern { + log::trace!( + "selected op-agnostic pattern '{}' because its benefit is higher than the \ + next best op-specific pattern '{}'", + next_any_pattern.name(), + best_pattern.name() + ); + } else { + log::trace!( + "selected op-agnostic pattern '{}' because no op-specific pattern is \ + available", + next_any_pattern.name() + ); + } best_pattern.replace(next_any_pattern); } else { - // The op-specific pattern is best, so actually consume it from the iterator + // The op-specific pattern is best, if available, so actually consume it from the iterator + if let Some(best_pattern) = best_pattern { + log::trace!("selected op-specific pattern '{}'", best_pattern.name()); + } best_pattern = op_patterns.next(); } // Break if we have exhausted all patterns let Some(best_pattern) = best_pattern else { + log::trace!("all patterns have been exhausted"); break; }; // Can we apply this pattern? - let applicable = can_apply.as_ref().is_none_or(|can_apply| can_apply(&**best_pattern)); + let applicable = can_apply(&**best_pattern); if !applicable { + log::trace!("skipping pattern: can_apply returned false"); continue; } @@ -134,22 +172,27 @@ impl PatternApplicator { // messages/rendering, as the rewrite may invalidate `op` log::debug!("trying to match '{}'", best_pattern.name()); - if best_pattern.match_and_rewrite(op.clone(), rewriter)? 
{ - log::debug!("successfully matched pattern '{}'", best_pattern.name()); - if let Some(on_success) = on_success.as_mut() { - on_success(&**best_pattern)?; + match best_pattern.match_and_rewrite(op.clone(), rewriter) { + Ok(matched) => { + if matched { + log::trace!("pattern matched successfully"); + result = + on_success(&**best_pattern).map_err(PatternApplicationError::Report); + break; + } else { + log::trace!("failed to match pattern"); + on_failure(&**best_pattern); + } } - break; - } else { - // Perform any necessary cleanup - log::debug!("failed to match pattern '{}'", best_pattern.name()); - if let Some(on_failure) = on_failure.as_mut() { + Err(err) => { + log::error!("error occurred during match_and_rewrite: {err}"); + result = Err(PatternApplicationError::Report(err)); on_failure(&**best_pattern); } } } - Ok(()) + result } } diff --git a/hir2/src/patterns/driver.rs b/hir2/src/patterns/driver.rs new file mode 100644 index 000000000..236608de8 --- /dev/null +++ b/hir2/src/patterns/driver.rs @@ -0,0 +1,1011 @@ +use alloc::{collections::BTreeSet, rc::Rc}; +use core::cell::RefCell; + +use smallvec::SmallVec; + +use super::{ + ForwardingListener, FrozenRewritePatternSet, PatternApplicator, PatternRewriter, Rewriter, + RewriterListener, +}; +use crate::{ + traits::{ConstantLike, Foldable, IsolatedFromAbove}, + BlockRef, Builder, Context, InsertionGuard, Listener, OpFoldResult, OperationFolder, + OperationRef, ProgramPoint, Region, RegionRef, Report, RewritePattern, SourceSpan, Spanned, + Value, ValueRef, WalkResult, Walkable, +}; + +/// Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the +/// highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// This function also performs folding and simple dead-code elimination before attempting to match +/// any of the provided patterns. +/// +/// A region scope can be set using [GreedyRewriteConfig]. By default, the scope is set to the +/// specified region. Only in-scope ops are added to the worklist and only in-scope ops are allowed +/// to be modified by the patterns. +/// +/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no +/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR +/// was modified at all. +/// +/// NOTE: This function does not apply patterns to the region's parent operation. +pub fn apply_patterns_and_fold_region_greedily( + region: RegionRef, + patterns: Rc, + mut config: GreedyRewriteConfig, +) -> Result { + // The top-level operation must be known to be isolated from above to prevent performing + // canonicalizations on operations defined at or above the region containing 'op'. 
+ let context = { + let parent_op = region.borrow().parent().unwrap().borrow(); + assert!( + parent_op.implements::(), + "patterns can only be applied to operations which are isolated from above" + ); + parent_op.context_rc() + }; + + // Set scope if not specified + if config.scope.is_none() { + config.scope = Some(region.clone()); + } + + let mut driver = RegionPatternRewriteDriver::new(context, patterns, config, region); + let converged = driver.simplify(); + if converged.is_err() { + if let Some(max_iterations) = driver.driver.config.max_iterations { + log::trace!("pattern rewrite did not converge after scanning {max_iterations} times"); + } else { + log::trace!("pattern rewrite did not converge"); + } + } + converged +} + +/// Rewrite ops nested under the given operation, which must be isolated from above, by repeatedly +/// applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is +/// reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// Also performs folding and simple dead-code elimination before attempting to match any of the +/// provided patterns. +/// +/// This overload runs a separate greedy rewrite for each region of the specified op. A region +/// scope can be set in the configuration parameter. By default, the scope is set to the region of +/// the current greedy rewrite. Only in-scope ops are added to the worklist and only in-scope ops +/// and the specified op itself are allowed to be modified by the patterns. +/// +/// NOTE: The specified op may be modified, but it may not be removed by the patterns. +/// +/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no +/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR +/// was modified at all. +/// +/// NOTE: This function does not apply patterns to the given operation itself. +pub fn apply_patterns_and_fold_greedily( + op: OperationRef, + patterns: Rc, + config: GreedyRewriteConfig, +) -> Result { + let mut any_region_changed = false; + let mut failed = false; + let op = op.borrow(); + let mut cursor = op.regions().front(); + while let Some(region) = cursor.as_pointer() { + cursor.move_next(); + match apply_patterns_and_fold_region_greedily(region, patterns.clone(), config.clone()) { + Ok(region_changed) => { + any_region_changed |= region_changed; + } + Err(region_changed) => { + any_region_changed |= region_changed; + failed = true; + } + } + } + + if failed { + Err(any_region_changed) + } else { + Ok(any_region_changed) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum ApplyPatternsAndFoldEffect { + /// No effect, the IR remains unchanged + None, + /// The IR was modified + Changed, + /// The input IR was erased + Erased, +} + +pub type ApplyPatternsAndFoldResult = + Result; + +/// Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy +/// worklist driven manner until a fixpoint is reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// This function also performs folding and simple dead-code elimination before attempting to match +/// any of the provided patterns. 
+/// +/// Newly created ops and other pre-existing ops that use results of rewritten ops or supply +/// operands to such ops are also processed, unless such ops are excluded via `config.restrict`. +/// Any other ops remain unmodified (i.e., regardless of restrictions). +/// +/// In addition to op restrictions, a region scope can be specified. Only ops within the scope are +/// simplified. This is similar to [apply_patterns_and_fold_greedily], where only ops within the +/// given region/op are simplified by default. If no scope is specified, it is assumed to be the +/// first common enclosing region of the given ops. +/// +/// Note that ops in `ops` could be erased as result of folding, becoming dead, or via pattern +/// rewrites. If more far reaching simplification is desired, [apply_patterns_and_fold_greedily] +/// should be used. +/// +/// Returns `Ok(effect)` if the iterative process converged (i.e., fixpoint was reached) and no more +/// patterns can be matched. `effect` is set to `Changed` if the IR was modified, but at least one +/// operation was not erased. It is set to `Erased` if all of the input ops were erased. +pub fn apply_patterns_and_fold( + ops: &[OperationRef], + patterns: Rc, + mut config: GreedyRewriteConfig, +) -> ApplyPatternsAndFoldResult { + if ops.is_empty() { + return Ok(ApplyPatternsAndFoldEffect::None); + } + + // Determine scope of rewrite + if let Some(scope) = config.scope.as_ref() { + // If a scope was provided, make sure that all ops are in scope. + let all_ops_in_scope = ops.iter().all(|op| scope.borrow().find_ancestor_op(op).is_some()); + assert!(all_ops_in_scope, "ops must be within the specified scope"); + } else { + // Compute scope if none was provided. The scope will remain `None` if there is a top-level + // op among `ops`. + config.scope = Region::find_common_ancestor(ops); + } + + // Start the pattern driver + let max_rewrites = config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX); + let context = ops[0].borrow().context_rc(); + let mut driver = MultiOpPatternRewriteDriver::new(context, patterns, config, ops); + let converged = driver.simplify(ops); + let changed = match converged.as_ref() { + Ok(changed) | Err(changed) => *changed, + }; + let erased = driver.inner.surviving_ops.borrow().is_empty(); + let effect = if erased { + ApplyPatternsAndFoldEffect::Erased + } else if changed { + ApplyPatternsAndFoldEffect::Changed + } else { + ApplyPatternsAndFoldEffect::None + }; + if converged.is_ok() { + Ok(effect) + } else { + log::trace!("pattern rewrite did not converge after {max_rewrites} rewrites"); + Err(effect) + } +} + +/// This enum indicates which ops are put on the worklist during a greedy pattern rewrite +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum GreedyRewriteStrictness { + /// No restrictions on which ops are processed. + #[default] + Any, + /// Only pre-existing and newly created ops are processed. + /// + /// Pre-existing ops are those that were on the worklist at the very beginning. + ExistingAndNew, + /// Only pre-existing ops are processed. + /// + /// Pre-existing ops are those that were on the worklist at the very beginning. + Existing, +} + +/// This enum indicates the level of simplification to be applied to regions during a greedy +/// pattern rewrite. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegionSimplificationLevel { + /// Disable simplification. + None, + /// Perform basic simplifications (e.g. 
dead argument elimination) + #[default] + Normal, + /// Perform additional complex/expensive simplifications (e.g. block merging) + Aggressive, +} + +/// Configuration for [GreedyPatternRewriteDriver] +#[derive(Clone)] +pub struct GreedyRewriteConfig { + listener: Option>, + /// If set, only ops within the given region are added to the worklist. + /// + /// If no scope is specified, and no specific region is given when starting the greedy rewrite, + /// then the closest enclosing region of the initial list of operations is used. + scope: Option, + /// If set, specifies the maximum number of times the rewriter will iterate between applying + /// patterns and simplifying regions. + /// + /// NOTE: Only applicable when simplifying entire regions. + max_iterations: Option, + /// If set, specifies the maximum number of rewrites within an iteration. + max_rewrites: Option, + /// Perform control flow optimizations to the region tree after applying all patterns. + /// + /// NOTE: Only applicable when simplifying entire regions. + region_simplification: RegionSimplificationLevel, + /// The restrictions to apply, if any, to operations added to the worklist during the rewrite. + restrict: GreedyRewriteStrictness, + /// This flag specifies the order of initial traversal that populates the rewriter worklist. + /// + /// When true, operations are visited top-down, which is generally more efficient in terms of + /// compilation time. + /// + /// When false, the initial traversal of the region tree is bottom up on each block, which may + /// match larger patterns when given an ambiguous pattern set. + /// + /// NOTE: Only applicable when simplifying entire regions. + use_top_down_traversal: bool, +} +impl Default for GreedyRewriteConfig { + fn default() -> Self { + Self { + listener: None, + scope: None, + max_iterations: core::num::NonZeroU32::new(10), + max_rewrites: None, + region_simplification: Default::default(), + restrict: Default::default(), + use_top_down_traversal: false, + } + } +} +impl GreedyRewriteConfig { + pub fn new_with_listener(listener: impl RewriterListener) -> Self { + Self { + listener: Some(Rc::new(listener)), + ..Default::default() + } + } + + /// Scope rewrites to operations within `region` + pub fn with_scope(&mut self, region: RegionRef) -> &mut Self { + self.scope = Some(region); + self + } + + /// Set the maximum number of times the rewriter will iterate between applying patterns and + /// simplifying regions. + /// + /// If `0` is given, the number of iterations is unlimited. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_max_iterations(&mut self, max: u32) -> &mut Self { + self.max_iterations = core::num::NonZeroU32::new(max); + self + } + + /// Set the maximum number of rewrites per iteration. + /// + /// If `0` is given, the number of rewrites is unlimited. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_max_rewrites(&mut self, max: u32) -> &mut Self { + self.max_rewrites = core::num::NonZeroU32::new(max); + self + } + + /// Set the level of control flow optimizations to apply to the region tree. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_region_simplification_level( + &mut self, + level: RegionSimplificationLevel, + ) -> &mut Self { + self.region_simplification = level; + self + } + + /// Set the level of restriction to apply to operations added to the worklist during the rewrite. 
+ pub fn with_restrictions(&mut self, level: GreedyRewriteStrictness) -> &mut Self { + self.restrict = level; + self + } + + /// Specify whether or not to use a top-down traversal when initially adding operations to the + /// worklist. + pub fn with_top_down_traversal(&mut self, yes: bool) -> &mut Self { + self.use_top_down_traversal = yes; + self + } +} + +pub struct GreedyPatternRewriteDriver { + context: Rc<Context>, + worklist: RefCell, + config: GreedyRewriteConfig, + /// Not maintained when `config.restrict` is `GreedyRewriteStrictness::Any` + filtered_ops: RefCell<BTreeSet<OperationRef>>, + matcher: RefCell<PatternApplicator>, +} + +impl GreedyPatternRewriteDriver { + pub fn new( + context: Rc<Context>, + patterns: Rc<FrozenRewritePatternSet>, + config: GreedyRewriteConfig, + ) -> Self { + // Apply a simple cost model based solely on pattern benefit + let mut matcher = PatternApplicator::new(patterns); + matcher.apply_default_cost_model(); + + Self { + context, + worklist: Default::default(), + config, + filtered_ops: Default::default(), + matcher: RefCell::new(matcher), + } + } +} + +/// Worklist Management +impl GreedyPatternRewriteDriver { + /// Add the given operation to the worklist + pub fn add_single_op_to_worklist(&self, op: OperationRef) { + if matches!(self.config.restrict, GreedyRewriteStrictness::Any) + || self.filtered_ops.borrow().contains(&op) + { + log::trace!("adding single op '{}' to worklist", op.borrow().name()); + self.worklist.borrow_mut().push(op); + } else { + log::trace!( + "skipped adding single op '{}' to worklist due to strictness level", + op.borrow().name() + ); + } + } + + /// Add the given operation, and its ancestors, to the worklist + pub fn add_to_worklist(&self, op: OperationRef) { + // Gather potential ancestors while looking for a `scope` parent region + let mut ancestors = SmallVec::<[OperationRef; 8]>::default(); + let mut op = Some(op); + while let Some(ancestor_op) = op.take() { + let region = ancestor_op.borrow().parent_region(); + if self.config.scope.as_ref() == region.as_ref() { + ancestors.push(ancestor_op); + for op in ancestors { + self.add_single_op_to_worklist(op); + } + return; + } else { + log::trace!( + "gathering ancestors of '{}' for worklist", + ancestor_op.borrow().name() + ); + ancestors.push(ancestor_op); + } + if let Some(region) = region { + op = region.borrow().parent(); + } else { + log::trace!("reached top level op while searching for ancestors"); + } + } + } + + /// Process operations until the worklist is empty, or `config.max_rewrites` is reached. + /// + /// Returns true if the IR was changed. + pub fn process_worklist(self: Rc<Self>) -> bool { + log::debug!("starting processing of greedy pattern rewrite driver worklist"); + let mut rewriter = + PatternRewriter::new_with_listener(self.context.clone(), Rc::clone(&self)); + + let mut changed = false; + let mut num_rewrites = 0u32; + while self.config.max_rewrites.is_none_or(|max| num_rewrites < max.get()) { + let Some(op) = self.worklist.borrow_mut().pop() else { + // Worklist is empty, we've converged + log::debug!("processing worklist complete, rewrites have converged"); + return changed; + }; + + if self.process_worklist_item(&mut rewriter, op) { + changed = true; + num_rewrites += 1; + } + } + + log::debug!( + "processing worklist was canceled after {} rewrites without converging (reached max \ + rewrite limit)", + self.config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX) + ); + + changed + } + + /// Process a single operation from the worklist. + /// + /// Returns true if the IR was changed.
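For intuition, the worklist-driven loop described above can be sketched in isolation as follows; `usize` items stand in for operations and the closure stands in for `process_worklist_item`, so this is only a model of the shape of the algorithm, not the driver's actual API:

```rust
// Hypothetical model of a bounded, worklist-driven fixpoint rewrite.
fn process_worklist(
    mut worklist: Vec<usize>,
    max_rewrites: Option<u32>,
    mut process_item: impl FnMut(usize, &mut Vec<usize>) -> bool,
) -> bool {
    let mut changed = false;
    let mut num_rewrites = 0u32;
    while max_rewrites.map_or(true, |max| num_rewrites < max) {
        let Some(item) = worklist.pop() else {
            // Worklist is empty: the rewrite has converged
            return changed;
        };
        // A successful rewrite may push new or affected items back onto the worklist
        if process_item(item, &mut worklist) {
            changed = true;
            num_rewrites += 1;
        }
    }
    // The rewrite budget was exhausted before converging
    changed
}

fn main() {
    // Halve every even item until only odd items remain
    let changed = process_worklist(vec![8, 3, 6], None, |item, wl| {
        if item % 2 == 0 {
            wl.push(item / 2);
            true
        } else {
            false
        }
    });
    assert!(changed);
}
```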
+ fn process_worklist_item( + &self, + rewriter: &mut PatternRewriter>, + mut op_ref: OperationRef, + ) -> bool { + let op = op_ref.borrow_mut(); + log::trace!("processing operation '{}'", op.name()); + + // If the operation is trivially dead - remove it. + if op.is_trivially_dead() { + drop(op); + rewriter.erase_op(op_ref); + log::trace!("processing complete: operation is trivially dead"); + return true; + } + + // Try to fold this op, unless it is a constant op, as that would lead to an infinite + // folding loop, since the folded result would be immediately materialized as a constant + // op, and then revisited. + if !op.implements::() { + let mut results = SmallVec::<[OpFoldResult; 1]>::default(); + log::trace!("attempting to fold operation.."); + if op.fold(&mut results).is_ok() { + if results.is_empty() { + // Op was modified in-place + self.notify_operation_modified(op_ref.clone()); + log::trace!("operation was succesfully folded/modified in-place"); + return true; + } else { + log::trace!( + "operation was succesfully folded away, to be replaced with: {}", + crate::formatter::DisplayValues::new(results.iter()) + ); + } + + // Op results can be replaced with `results` + assert_eq!( + results.len(), + op.num_results(), + "folder produced incorrect number of results" + ); + let mut rewriter = InsertionGuard::new(&mut **rewriter); + rewriter.set_insertion_point_before(ProgramPoint::Op(op_ref.clone())); + + log::trace!("replacing op with fold results.."); + let mut replacements = SmallVec::<[ValueRef; 2]>::default(); + let mut materialization_succeeded = true; + for (fold_result, result_ty) in results + .into_iter() + .zip(op.results().all().iter().map(|r| r.borrow().ty().clone())) + { + match fold_result { + OpFoldResult::Value(value) => { + assert_eq!( + value.borrow().ty(), + &result_ty, + "folder produced value of incorrect type" + ); + replacements.push(value); + } + OpFoldResult::Attribute(attr) => { + // Materialize attributes as SSA values using a constant op + let span = op.span(); + log::trace!( + "materializing constant for value '{}' and type '{result_ty}'", + attr.render() + ); + let constant_op = op.dialect().materialize_constant( + &mut *rewriter, + attr, + &result_ty, + span, + ); + match constant_op { + None => { + log::trace!( + "materialization failed: cleaning up any materialized ops \ + for {} previous results", + replacements.len() + ); + // If materialization fails, clean up any operations generated for the previous results + let mut replacement_ops = + SmallVec::<[OperationRef; 2]>::default(); + for replacement in replacements.iter() { + let replacement = replacement.borrow(); + assert!( + !replacement.is_used(), + "folder reused existing op for one result, but \ + constant materialization failed for another result" + ); + let replacement_op = replacement.get_defining_op().unwrap(); + if replacement_ops.contains(&replacement_op) { + continue; + } + replacement_ops.push(replacement_op); + } + for replacement_op in replacement_ops { + rewriter.erase_op(replacement_op); + } + materialization_succeeded = false; + break; + } + Some(constant_op) => { + let const_op = constant_op.borrow(); + assert!( + const_op.implements::(), + "materialize_constant produced op that does not implement \ + ConstantLike" + ); + let result: ValueRef = + const_op.results().all()[0].clone().upcast(); + assert_eq!( + result.borrow().ty(), + &result_ty, + "materialize_constant produced incorrect result type" + ); + log::trace!( + "successfully materialized constant as {}", + 
+                                        result.borrow().id()
+                                    );
+                                    replacements.push(result);
+                                }
+                            }
+                        }
+                    }
+                }
+
+                if materialization_succeeded {
+                    log::trace!(
+                        "materialization of fold results was successful, performing replacement.."
+                    );
+                    drop(op);
+                    rewriter.replace_op_with_values(op_ref, &replacements);
+                    log::trace!(
+                        "fold succeeded: operation was replaced with materialized constants"
+                    );
+                    return true;
+                } else {
+                    log::trace!(
+                        "materialization of fold results failed, proceeding without folding"
+                    );
+                }
+            }
+        } else {
+            log::trace!("operation could not be folded");
+        }
+
+        // Try to match one of the patterns.
+        //
+        // The rewriter is automatically notified of any necessary changes, so there is nothing
+        // else to do here.
+        // TODO(pauls): if self.config.listener.is_some() {
+        //
+        // We need to trigger `notify_pattern_begin` in `can_apply`, and `notify_pattern_end`
+        // in `on_failure` and `on_success`, but we can't have multiple mutable aliases of
+        // the listener captured by these closures.
+        //
+        // This is another aspect of the listener infra that needs to be handled
+        log::trace!("attempting to match and rewrite one of the input patterns..");
+        let result = if let Some(listener) = self.config.listener.as_deref() {
+            let op_name = op.name();
+            let can_apply = |pattern: &dyn RewritePattern| {
+                log::trace!("applying pattern {} to op {}", pattern.name(), &op_name);
+                listener.notify_pattern_begin(pattern, op_ref.clone());
+                true
+            };
+            let on_failure = |pattern: &dyn RewritePattern| {
+                log::trace!("pattern failed to match");
+                listener.notify_pattern_end(pattern, false);
+            };
+            let on_success = |pattern: &dyn RewritePattern| {
+                log::trace!("pattern applied successfully");
+                listener.notify_pattern_end(pattern, true);
+                Ok(())
+            };
+            drop(op);
+            self.matcher.borrow_mut().match_and_rewrite(
+                op_ref.clone(),
+                &mut **rewriter,
+                can_apply,
+                on_failure,
+                on_success,
+            )
+        } else {
+            drop(op);
+            self.matcher.borrow_mut().match_and_rewrite(
+                op_ref.clone(),
+                &mut **rewriter,
+                |_| true,
+                |_| {},
+                |_| Ok(()),
+            )
+        };
+
+        match result {
+            Ok(_) => {
+                log::trace!("processing complete: pattern matched and operation was rewritten");
+                true
+            }
+            Err(crate::PatternApplicationError::NoMatchesFound) => {
+                log::debug!("processing complete: exhausted all patterns without finding a match");
+                false
+            }
+            Err(crate::PatternApplicationError::Report(report)) => {
+                log::debug!(
+                    "processing complete: error occurred during match and rewrite: {report}"
+                );
+                false
+            }
+        }
+    }
+
+    /// Look over the operands of the provided op for any defining operations that should be
+    /// re-added to the worklist. This function should be called when an operation is modified or
+    /// removed, as it may trigger further simplifications.
+    fn add_operands_to_worklist(&self, op: OperationRef) {
+        let current_op = op.borrow();
+        for operand in current_op.operands().all() {
+            // If this operand currently has at most 2 users, add its defining op to the worklist.
+            // After the op is deleted, then the operand will have at most 1 user left. If it has
+            // 0 users left, it can be deleted as well, and if it has 1 user left, there may be
+            // further canonicalization opportunities.
+            let operand = operand.borrow();
+            let Some(def_op) = operand.value().get_defining_op() else {
+                continue;
+            };
+
+            let mut other_user = None;
+            let mut has_more_than_two_uses = false;
+            for user in operand.value().iter_uses() {
+                if user.owner == op || other_user.as_ref().is_some_and(|ou| ou == &user.owner) {
+                    continue;
+                }
+                if other_user.is_none() {
+                    other_user = Some(user.owner.clone());
+                    continue;
+                }
+                has_more_than_two_uses = true;
+                break;
+            }
+            if !has_more_than_two_uses {
+                self.add_to_worklist(def_op);
+            }
+        }
+    }
+}
+
+/// Notifications
+impl Listener for GreedyPatternRewriteDriver {
+    fn kind(&self) -> crate::ListenerType {
+        crate::ListenerType::Rewriter
+    }
+
+    /// Notify the driver that the given block was inserted
+    fn notify_block_inserted(
+        &self,
+        block: crate::BlockRef,
+        prev: Option<RegionRef>,
+        ip: Option<BlockRef>,
+    ) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_block_inserted(block, prev, ip);
+        }
+    }
+
+    /// Notify the driver that the specified operation was inserted.
+    ///
+    /// Update the worklist as needed: the operation is enqueued depending on scope and strictness
+    fn notify_operation_inserted(&self, op: OperationRef, prev: Option<InsertionPoint>) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_inserted(op.clone(), prev.clone());
+        }
+        if matches!(self.config.restrict, GreedyRewriteStrictness::ExistingAndNew) {
+            self.filtered_ops.borrow_mut().insert(op.clone());
+        }
+        self.add_to_worklist(op);
+    }
+}
+impl RewriterListener for GreedyPatternRewriteDriver {
+    /// Notify the driver that the given block is about to be removed.
+    fn notify_block_erased(&self, block: BlockRef) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_block_erased(block);
+        }
+    }
+
+    /// Notify the driver that the specified operation may have been modified in-place. The
+    /// operation is added to the worklist.
+    fn notify_operation_modified(&self, op: OperationRef) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_modified(op.clone());
+        }
+        self.add_to_worklist(op);
+    }
+
+    /// Notify the driver that the specified operation was removed.
+    ///
+    /// Update the worklist as needed: the operation and its children are removed from the worklist
+    fn notify_operation_erased(&self, op: OperationRef) {
+        // Only ops that are within the configured scope are added to the worklist of the greedy
+        // pattern rewriter.
+        //
+        // A greedy pattern rewrite is not allowed to erase the parent op of the scope region, as
+        // that would break the worklist handling and some sanity checks.
+        if let Some(scope) = self.config.scope.as_ref() {
+            assert!(
+                scope.borrow().parent().is_some_and(|parent_op| parent_op != op),
+                "scope region must not be erased during greedy pattern rewrite"
+            );
+        }
+
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_erased(op.clone());
+        }
+
+        self.add_operands_to_worklist(op.clone());
+        self.worklist.borrow_mut().remove(&op);
+
+        if self.config.restrict != GreedyRewriteStrictness::Any {
+            self.filtered_ops.borrow_mut().remove(&op);
+        }
+    }
+
+    /// Notify the driver that the specified operation was replaced.
+    ///
+    /// Update the worklist as needed: new users are enqueued
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_replaced_with_values(op, replacement);
+        }
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_match_failure(span, reason);
+        }
+    }
+}
+
+pub struct RegionPatternRewriteDriver {
+    driver: Rc<GreedyPatternRewriteDriver>,
+    region: RegionRef,
+}
+impl RegionPatternRewriteDriver {
+    pub fn new(
+        context: Rc<Context>,
+        patterns: Rc<FrozenRewritePatternSet>,
+        config: GreedyRewriteConfig,
+        region: RegionRef,
+    ) -> Self {
+        use crate::Walkable;
+        let mut driver = GreedyPatternRewriteDriver::new(context, patterns, config);
+        // Populate strict mode ops, if applicable
+        if driver.config.restrict != GreedyRewriteStrictness::Any {
+            let filtered_ops = driver.filtered_ops.get_mut();
+            region.borrow().postwalk(|op| {
+                filtered_ops.insert(op);
+            });
+        }
+        Self {
+            driver: Rc::new(driver),
+            region,
+        }
+    }
+
+    /// Simplify ops inside `self.region`, and simplify the region itself.
+    ///
+    /// Returns `Ok(changed)` if the transformation converged, with `changed` indicating whether or
+    /// not the IR was changed. Otherwise, `Err(changed)` is returned.
+    pub fn simplify(&mut self) -> Result<bool, bool> {
+        use crate::matchers::Matcher;
+
+        let mut continue_rewrites = false;
+        let mut iteration = 0;
+
+        while self.driver.config.max_iterations.is_none_or(|max| iteration < max.get()) {
+            log::trace!("starting iteration {iteration} of region pattern rewrite driver");
+            iteration += 1;
+
+            // New iteration: start with an empty worklist
+            self.driver.worklist.borrow_mut().clear();
+
+            // `OperationFolder` CSE's constant ops (and may move them into parent regions to
+            // enable more aggressive CSE'ing).
+            let context = self.driver.context.clone();
+            let mut folder = OperationFolder::new(context, Rc::clone(&self.driver));
+            let mut insert_known_constant = |op: OperationRef| {
+                // Check for existing constants when populating the worklist. This avoids
+                // accidentally reversing the constant order during processing.
+                if let Some(const_value) = crate::matchers::constant().matches(&op.borrow()) {
+                    if !folder.insert_known_constant(op, Some(const_value)) {
+                        return true;
+                    }
+                }
+                false
+            };
+
+            if !self.driver.config.use_top_down_traversal {
+                // Add operations to the worklist in postorder.
+                log::trace!("adding operations in postorder");
+                self.region.borrow().postwalk(|op| {
+                    if !insert_known_constant(op.clone()) {
+                        self.driver.add_to_worklist(op);
+                    }
+                });
+            } else {
+                // Add all nested operations to the worklist in preorder.
+                log::trace!("adding operations in preorder");
+                self.region
+                    .borrow()
+                    .prewalk_interruptible(|op| {
+                        if !insert_known_constant(op.clone()) {
+                            self.driver.add_to_worklist(op);
+                            WalkResult::::Continue(())
+                        } else {
+                            WalkResult::Skip
+                        }
+                    })
+                    .into_result()
+                    .expect("unexpected error occurred while walking region");
+
+                // Reverse the list so our loop processes them in-order
+                self.driver.worklist.borrow_mut().reverse();
+            }
+
+            continue_rewrites = self.driver.clone().process_worklist();
+            log::trace!(
+                "processing of worklist for this iteration has completed, \
+                 changed={continue_rewrites}"
+            );
+
+            // After applying patterns, make sure that the CFG of each of the regions is kept up to
+            // date.
+ if self.driver.config.region_simplification != RegionSimplificationLevel::None { + let mut rewriter = PatternRewriter::new_with_listener( + self.driver.context.clone(), + Rc::clone(&self.driver), + ); + continue_rewrites |= Region::simplify_all( + &[self.region.clone()], + &mut *rewriter, + self.driver.config.region_simplification, + ) + .is_ok(); + } else { + log::debug!("region simplification was disabled, skipping simplification rewrites"); + } + + if !continue_rewrites { + log::trace!("region pattern rewrites have converged"); + break; + } + } + + // If `continue_rewrites` is false, then the rewrite converged, i.e. the IR wasn't changed + // in the last iteration. + if !continue_rewrites { + Ok(iteration > 1) + } else { + Err(iteration > 1) + } + } +} + +pub struct MultiOpPatternRewriteDriver { + driver: Rc, + inner: Rc, +} + +struct MultiOpPatternRewriteDriverImpl { + surviving_ops: RefCell>, +} + +impl MultiOpPatternRewriteDriver { + pub fn new( + context: Rc, + patterns: Rc, + mut config: GreedyRewriteConfig, + ops: &[OperationRef], + ) -> Self { + let surviving_ops = BTreeSet::from_iter(ops.iter().cloned()); + let inner = Rc::new(MultiOpPatternRewriteDriverImpl { + surviving_ops: RefCell::new(surviving_ops), + }); + let listener = Rc::new(ForwardingListener::new(config.listener.take(), Rc::clone(&inner))); + config.listener = Some(listener); + + let mut driver = GreedyPatternRewriteDriver::new(context.clone(), patterns, config); + if driver.config.restrict != GreedyRewriteStrictness::Any { + driver.filtered_ops.get_mut().extend(ops.iter().cloned()); + } + + Self { + driver: Rc::new(driver), + inner, + } + } + + pub fn simplify(&mut self, ops: &[OperationRef]) -> Result { + // Populate the initial worklist + for op in ops { + self.driver.add_single_op_to_worklist(op.clone()); + } + + // Process ops on the worklist + let changed = self.driver.clone().process_worklist(); + if self.driver.worklist.borrow().is_empty() { + Ok(changed) + } else { + Err(changed) + } + } +} + +impl Listener for MultiOpPatternRewriteDriverImpl { + fn kind(&self) -> crate::ListenerType { + crate::ListenerType::Rewriter + } +} +impl RewriterListener for MultiOpPatternRewriteDriverImpl { + fn notify_operation_erased(&self, op: OperationRef) { + self.surviving_ops.borrow_mut().remove(&op); + } +} + +#[derive(Default)] +struct Worklist(Vec); +impl Worklist { + /// Clear all operations from the worklist + #[inline] + pub fn clear(&mut self) { + self.0.clear() + } + + /// Returns true if the worklist is empty + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Push an operation to the end of the worklist, unless it is already in the worklist. 
+ pub fn push(&mut self, op: OperationRef) { + if self.0.contains(&op) { + return; + } + self.0.push(op); + } + + /// Pop the next operation from the worklist + #[inline] + pub fn pop(&mut self) -> Option { + self.0.pop() + } + + /// Remove `op` from the worklist + pub fn remove(&mut self, op: &OperationRef) { + if let Some(index) = self.0.iter().position(|o| o == op) { + self.0.remove(index); + } + } + + /// Reverse the worklist + pub fn reverse(&mut self) { + self.0.reverse(); + } +} diff --git a/hir2/src/patterns/pattern.rs b/hir2/src/patterns/pattern.rs index a7cc9b7d8..d767bcb0d 100644 --- a/hir2/src/patterns/pattern.rs +++ b/hir2/src/patterns/pattern.rs @@ -3,7 +3,7 @@ use core::{any::TypeId, fmt}; use smallvec::SmallVec; -use super::PatternRewriter; +use super::Rewriter; use crate::{interner, Context, OperationName, OperationRef, Report}; #[derive(Debug)] @@ -80,9 +80,60 @@ impl Ord for PatternBenefit { } } -/// A [Pattern] describes all of the data related to a pattern, but does not express any actual +pub trait Pattern { + fn info(&self) -> &PatternInfo; + /// A name used when printing diagnostics related to this pattern + #[inline(always)] + fn name(&self) -> &'static str { + self.info().name + } + /// The kind of value used to select candidate root operations for this pattern. + #[inline(always)] + fn kind(&self) -> &PatternKind { + &self.info().kind + } + /// Returns the benefit - the inverse of "cost" - of matching this pattern. + /// + /// The benefit of a [Pattern] is always static - rewrites that may have dynamic benefit can be + /// instantiated multiple times (different instances), for each benefit that they may return, + /// and be guarded by different match condition predicates. + #[inline(always)] + fn benefit(&self) -> &PatternBenefit { + &self.info().benefit + } + /// Returns true if this pattern is known to result in recursive application, i.e. this pattern + /// may generate IR that also matches this pattern, but is known to bound the recursion. This + /// signals to the rewrite driver that it is safe to apply this pattern recursively to the + /// generated IR. + #[inline(always)] + fn has_bounded_rewrite_recursion(&self) -> bool { + self.info().has_bounded_recursion + } + /// Return a list of operations that may be generated when rewriting an operation instance + /// with this pattern. + #[inline(always)] + fn generated_ops(&self) -> &[OperationName] { + &self.info().generated_ops + } + /// Return the root operation that this pattern matches. + /// + /// Patterns that can match multiple root types return `None` + #[inline(always)] + fn get_root_operation(&self) -> Option { + self.info().root_operation() + } + /// Return the trait id used to match the root operation of this pattern. + /// + /// If the pattern does not use a trait id for deciding the root match, this returns `None` + #[inline(always)] + fn get_root_trait(&self) -> Option { + self.info().get_root_trait() + } +} + +/// [PatternBase] describes all of the data related to a pattern, but does not express any actual /// pattern logic, i.e. it is solely used for metadata about a pattern. -pub struct Pattern { +pub struct PatternInfo { #[allow(unused)] context: Rc, name: &'static str, @@ -93,7 +144,8 @@ pub struct Pattern { has_bounded_recursion: bool, generated_ops: SmallVec<[OperationName; 0]>, } -impl Pattern { + +impl PatternInfo { /// Create a new [Pattern] from its component parts. 
pub fn new( context: Rc, @@ -112,39 +164,17 @@ impl Pattern { } } - /// A name used when printing diagnostics related to this pattern - #[inline(always)] - pub const fn name(&self) -> &'static str { - self.name - } - - /// The kind of value used to select candidate root operations for this pattern. - #[inline(always)] - pub const fn kind(&self) -> &PatternKind { - &self.kind - } - - /// Returns the benefit - the inverse of "cost" - of matching this pattern. - /// - /// The benefit of a [Pattern] is always static - rewrites that may have dynamic benefit can be - /// instantiated multiple times (different instances), for each benefit that they may return, - /// and be guarded by different match condition predicates. + /// Set whether or not this pattern has bounded rewrite recursion #[inline(always)] - pub const fn benefit(&self) -> PatternBenefit { - self.benefit - } - - /// Return a list of operations that may be generated when rewriting an operation instance - /// with this pattern. - #[inline] - pub fn generated_ops(&self) -> &[OperationName] { - &self.generated_ops + pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self { + self.has_bounded_recursion = yes; + self } /// Return the root operation that this pattern matches. /// /// Patterns that can match multiple root types return `None` - pub fn get_root_operation(&self) -> Option { + pub fn root_operation(&self) -> Option { match self.kind { PatternKind::Operation(ref name) => Some(name.clone()), _ => None, @@ -154,26 +184,17 @@ impl Pattern { /// Return the trait id used to match the root operation of this pattern. /// /// If the pattern does not use a trait id for deciding the root match, this returns `None` - pub fn get_root_trait(&self) -> Option { + pub fn root_trait(&self) -> Option { match self.kind { PatternKind::Trait(type_id) => Some(type_id), _ => None, } } +} - /// Returns true if this pattern is known to result in recursive application, i.e. this pattern - /// may generate IR that also matches this pattern, but is known to bound the recursion. This - /// signals to the rewrite driver that it is safe to apply this pattern recursively to the - /// generated IR. - #[inline(always)] - pub const fn has_bounded_rewrite_recursion(&self) -> bool { - self.has_bounded_recursion - } - - /// Set whether or not this pattern has bounded rewrite recursion +impl Pattern for PatternInfo { #[inline(always)] - pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self { - self.has_bounded_recursion = yes; + fn info(&self) -> &PatternInfo { self } } @@ -187,22 +208,12 @@ impl Pattern { /// /// Implementations must provide `matches` and `rewrite` implementations, from which the /// `match_and_rewrite` implementation is derived. -pub trait RewritePattern { - /// A name to use for this pattern in diagnostics - fn name(&self) -> &'static str { - core::any::type_name::() - } - /// The pattern used to match candidate root operations for this rewrite. - fn kind(&self) -> &PatternKind; - /// The estimated benefit of this pattern - fn benefit(&self) -> PatternBenefit; - /// Whether or not this rewrite pattern has bounded recursion - fn has_bounded_rewrite_recursion(&self) -> bool; +pub trait RewritePattern: Pattern { /// Rewrite the IR rooted at the specified operation with the result of this pattern, generating /// any new operations with the specified builder. If an unexpected error is encountered, i.e. 
/// an internal compiler error, it is emitted through the normal diagnostic system, and the IR /// is left in a valid state. - fn rewrite(&self, op: OperationRef, rewriter: &mut PatternRewriter); + fn rewrite(&self, op: OperationRef, rewriter: &mut dyn Rewriter); /// Attempt to match this pattern against the IR rooted at the specified operation, /// which is the same operation as [Pattern::kind]. @@ -213,7 +224,7 @@ pub trait RewritePattern { fn match_and_rewrite( &self, op: OperationRef, - rewriter: &mut PatternRewriter, + rewriter: &mut dyn Rewriter, ) -> Result { if self.matches(op.clone())? { self.rewrite(op, rewriter); @@ -224,3 +235,148 @@ pub trait RewritePattern { } } } + +#[cfg(test)] +mod tests { + use alloc::rc::Rc; + + use pretty_assertions::{assert_eq, assert_str_eq}; + + use super::*; + use crate::{dialects::hir::*, *}; + + /// In Miden, `n << 1` is vastly inferior to `n * 2` in cost, so reverse it + /// + /// NOTE: These two ops have slightly different semantics, a real implementation would have + /// to handle the edge cases. + struct ConvertShiftLeftBy1ToMultiply { + info: PatternInfo, + } + impl ConvertShiftLeftBy1ToMultiply { + pub fn new(context: Rc) -> Self { + let dialect = context.get_or_register_dialect::(); + let op_name = ::register_with(&*dialect); + let mut info = PatternInfo::new( + context, + "convert-shl1-to-mul2", + PatternKind::Operation(op_name), + PatternBenefit::new(1), + ); + info.with_bounded_rewrite_recursion(true); + Self { info } + } + } + impl Pattern for ConvertShiftLeftBy1ToMultiply { + fn info(&self) -> &PatternInfo { + &self.info + } + } + impl RewritePattern for ConvertShiftLeftBy1ToMultiply { + fn matches(&self, op: OperationRef) -> Result { + use crate::matchers::{self, match_chain, match_op, MatchWith, Matcher}; + + let binder = MatchWith(|op: &UnsafeIntrusiveEntityRef| { + log::trace!( + "found matching 'hir.shl' operation, checking if `shift` operand is foldable" + ); + let op = op.borrow(); + let shift = op.shift().as_operand_ref(); + let matched = matchers::foldable_operand_of::().matches(&shift); + matched.and_then(|imm| { + log::trace!("`shift` operand is an immediate: {imm}"); + let imm = imm.as_u64(); + if imm.is_none() { + log::trace!("`shift` operand is not a valid u64 value"); + } + if imm.is_some_and(|imm| imm == 1) { + Some(()) + } else { + None + } + }) + }); + log::trace!("attempting to match '{}'", self.name()); + let matched = match_chain(match_op::(), binder).matches(&op.borrow()).is_some(); + log::trace!("'{}' matched: {matched}", self.name()); + Ok(matched) + } + + fn rewrite(&self, op: OperationRef, rewriter: &mut dyn Rewriter) { + log::trace!("found match, rewriting '{}'", op.borrow().name()); + let (span, lhs) = { + let shl = op.borrow(); + let shl = shl.downcast_ref::().unwrap(); + let span = shl.span(); + let lhs = shl.lhs().as_value_ref(); + (span, lhs) + }; + let constant_builder = rewriter.create::(span); + let constant: UnsafeIntrusiveEntityRef = + constant_builder(Immediate::U32(2)).unwrap(); + let shift = constant.borrow().result().as_value_ref(); + let mul_builder = rewriter.create::(span); + let mul = mul_builder(lhs, shift, Overflow::Wrapping).unwrap(); + let mul = mul.borrow().as_operation().as_operation_ref(); + log::trace!("replacing shl with mul"); + rewriter.replace_op(op, mul); + } + } + + #[test] + fn rewrite_pattern_api_test() { + let mut builder = env_logger::Builder::from_env("MIDENC_TRACE"); + builder.init(); + + let context = Rc::new(Context::default()); + let pattern = 
ConvertShiftLeftBy1ToMultiply::new(Rc::clone(&context)); + + let mut builder = OpBuilder::new(Rc::clone(&context)); + let mut function = { + let builder = builder.create::(SourceSpan::default()); + let id = Ident::new("test".into(), SourceSpan::default()); + let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]); + builder(id, signature).unwrap() + }; + + // Define function body + { + let mut func = function.borrow_mut(); + let mut builder = FunctionBuilder::new(&mut func); + let shift = builder.ins().u32(1, SourceSpan::default()).unwrap(); + let block = builder.current_block(); + let lhs = block.borrow().arguments()[0].clone().upcast(); + let result = builder.ins().shl(lhs, shift, SourceSpan::default()).unwrap(); + builder.ins().ret(Some(result), SourceSpan::default()).unwrap(); + } + + // Construct pattern set + let mut rewrites = RewritePatternSet::new(builder.context_rc()); + rewrites.push(pattern); + let rewrites = Rc::new(FrozenRewritePatternSet::new(rewrites)); + + // Execute pattern driver + let mut config = GreedyRewriteConfig::default(); + config.with_region_simplification_level(RegionSimplificationLevel::None); + let result = crate::apply_patterns_and_fold_greedily( + function.borrow().as_operation().as_operation_ref(), + rewrites, + config, + ); + + // The rewrite should converge and modify the IR + assert_eq!(result, Ok(true)); + + // Confirm that the expected rewrite occurred + let func = function.borrow(); + let output = func.as_operation().to_string(); + let expected = "\ +hir.function public @test(v0: u32) -> u32 { +^block0(v0: u32): + v1 = hir.constant 1 : u32; + v3 = hir.constant 2 : u32; + v4 = hir.mul v0, v3 : u32 #[overflow = wrapping]; + hir.ret v4; +};"; + assert_str_eq!(output.as_str(), expected); + } +} diff --git a/hir2/src/patterns/pattern_set.rs b/hir2/src/patterns/pattern_set.rs index e8375964b..0d7fb0134 100644 --- a/hir2/src/patterns/pattern_set.rs +++ b/hir2/src/patterns/pattern_set.rs @@ -71,7 +71,7 @@ impl FrozenRewritePatternSet { PatternKind::Trait(ref trait_id) => { for dialect in this.context.registered_dialects().values() { for op in dialect.registered_ops().iter() { - if op.implements_trait_id(&trait_id) { + if op.implements_trait_id(trait_id) { this.op_specific_patterns .entry(op.clone()) .or_default() diff --git a/hir2/src/patterns/rewriter.rs b/hir2/src/patterns/rewriter.rs index 88514e317..fc6b23900 100644 --- a/hir2/src/patterns/rewriter.rs +++ b/hir2/src/patterns/rewriter.rs @@ -1,186 +1,176 @@ -#![allow(unused)] use alloc::rc::Rc; use core::ops::{Deref, DerefMut}; +use smallvec::SmallVec; + use crate::{ - BlockRef, Builder, Context, InsertionPoint, Listener, ListenerType, OpBuilder, OpOperand, - OpResultRef, OperationRef, Pattern, RegionRef, Report, SourceSpan, Type, ValueRef, + Block, BlockRef, Builder, Context, EntityWithParent, InsertionGuard, InsertionPoint, Listener, + ListenerType, OpBuilder, OpOperandImpl, Operation, OperationRef, Pattern, PostOrderBlockIter, + ProgramPoint, RegionRef, Report, SourceSpan, Usable, ValueRef, }; -/// A special type of `RewriterBase` that coordinates the application of a rewrite pattern on the -/// current IR being matched, providing a way to keep track of any mutations made. -/// -/// This type should be used to perform all necessary IR mutations within a rewrite pattern, as -/// the pattern driver may be tracking various state that would be invalidated when a mutation takes -/// place. 
-pub struct PatternRewriter { - rewriter: RewriterImpl, - recoverable: bool, -} -impl PatternRewriter { - pub fn new(builder: OpBuilder) -> Self { - Self { - rewriter: RewriterImpl::new(builder), - recoverable: false, - } - } - - #[inline] - pub const fn can_recover_from_rewrite_failure(&self) -> bool { - self.recoverable - } -} -impl Deref for PatternRewriter { - type Target = RewriterImpl; +/// A [Rewriter] is a [Builder] extended with additional functionality that is of primary use when +/// rewriting the IR after it is initially constructed. It is the basis on which the pattern +/// rewriter infrastructure is built. +pub trait Rewriter: Builder + RewriterListener { + /// Returns true if this rewriter has a listener attached. + /// + /// When no listener is present, fast paths can be taken when rewriting the IR, whereas a + /// listener requires breaking mutations up into individual actions so that the listener can + /// be made aware of all of them, in the order they occur. + fn has_listener(&self) -> bool; - #[inline(always)] - fn deref(&self) -> &Self::Target { - &self.rewriter - } -} -impl DerefMut for PatternRewriter { - #[inline(always)] - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.rewriter - } -} + /// Replace the results of the given operation with the specified list of values (replacements). + /// + /// The result types of the given op and the replacements must match. The original op is erased. + fn replace_op_with_values(&mut self, op: OperationRef, values: &[ValueRef]) { + assert_eq!(op.borrow().num_results(), values.len()); -pub struct RewriterImpl { - builder: OpBuilder, - listener: Option>, -} + // Replace all result uses, notifies listener of the modifications + self.replace_all_op_uses_with_values(op.clone(), values); -impl Listener for RewriterImpl { - fn kind(&self) -> ListenerType { - ListenerType::Rewriter + // Erase the op and notify the listener + self.erase_op(op); } - fn notify_block_inserted( - &mut self, - block: BlockRef, - prev: Option, - ip: Option, - ) { - if let Some(listener) = self.listener.as_deref_mut() { - listener.notify_block_inserted(block, prev, ip); - } else { - self.builder.notify_block_inserted(block, prev, ip); - } - } + /// Replace the results of the given operation with the specified replacement op. + /// + /// The result types of the two ops must match. The original op is erased. + fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) { + assert_eq!(op.borrow().num_results(), new_op.borrow().num_results()); - fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { - if let Some(listener) = self.listener.as_deref_mut() { - listener.notify_operation_inserted(op, prev); - } else { - self.builder.notify_operation_inserted(op, prev); - } - } -} + // Replace all result uses, notifies listener of the modifications + self.replace_all_op_uses_with(op.clone(), new_op); -impl Builder for RewriterImpl { - #[inline(always)] - fn context(&self) -> &Context { - self.builder.context() + // Erase the op and notify the listener + self.erase_op(op); } - #[inline(always)] - fn context_rc(&self) -> Rc { - self.builder.context_rc() - } + /// This method erases an operation that is known to have no uses. + fn erase_op(&mut self, mut op: OperationRef) { + assert!(!op.borrow().is_used(), "expected op to have no uses"); - #[inline(always)] - fn insertion_point(&self) -> Option<&InsertionPoint> { - self.builder.insertion_point() - } + // If no listener is attached, the op can be dropped all at once. 
+ if self.has_listener() { + op.borrow_mut().erase(); + return; + } - #[inline(always)] - fn clear_insertion_point(&mut self) -> Option { - self.builder.clear_insertion_point() - } + // Helper function that erases a single operation + fn erase_single_op( + mut op: OperationRef, + rewrite_listener: &mut R, + ) { + let mut op_mut = op.borrow_mut(); + if cfg!(debug_assertions) { + // All nested ops should have been erased already + assert!(op_mut.regions().iter().all(|r| r.is_empty()), "expected empty regions"); + // All users should have been erased already if the op is in a region with SSA dominance + if op_mut.is_used() { + if let Some(region) = op_mut.parent_region() { + assert!( + region.borrow().may_be_graph_region(), + "expected that op has no uses" + ); + } + } + } - #[inline(always)] - fn restore_insertion_point(&mut self, ip: Option) { - self.builder.restore_insertion_point(ip); - } + rewrite_listener.notify_operation_erased(op); - #[inline(always)] - fn set_insertion_point(&mut self, ip: InsertionPoint) { - self.builder.set_insertion_point(ip); - } + // Explicitly drop all uses in case the op is in a graph region + op_mut.drop_all_uses(); + op_mut.erase(); + } - #[inline] - fn create_block

( - &mut self, - parent: RegionRef, - ip: Option, - args: P, - ) -> BlockRef - where - P: IntoIterator, - { - self.builder.create_block(parent, ip, args) - } + // Nested ops must be erased one-by-one, so that listeners have a consistent view of the + // IR every time a notification is triggered. Users must be erased before definitions, i.e. + // in post-order, reverse dominance. + fn erase_tree(op_ref: OperationRef, rewriter: &mut R) { + // Erase nested ops + let op = op_ref.borrow(); + for region in op.regions() { + // Erase all blocks in the right order. Successors should be erased before + // predecessors because successor blocks may use values defined in predecessor + // blocks. A post-order traversal of blocks within a region visits successors before + // predecessors. Repeat the traversal until the region is empty. (The block graph + // could be disconnected.) + let mut erased_blocks = SmallVec::<[BlockRef; 4]>::default(); + while !region.is_empty() { + erased_blocks.clear(); + for block_ref in PostOrderBlockIter::new(region.entry_block_ref().unwrap()) { + let block = block_ref.borrow(); + let mut cursor = block.body().front(); + while let Some(op) = cursor.as_pointer() { + erase_tree(op, rewriter); + cursor.move_next(); + } + erased_blocks.push(block_ref); + } + for mut block in erased_blocks.drain(..) { + // Explicitly drop all uses in case there is a cycle in the block + // graph. + for arg in block.borrow_mut().arguments_mut() { + arg.borrow_mut().uses_mut().clear(); + } + block.borrow_mut().drop_all_uses(); + rewriter.erase_block(block); + } + } + } + erase_single_op(op_ref, rewriter); + } - #[inline] - fn create_block_before

(&mut self, before: BlockRef, args: P) -> BlockRef - where - P: IntoIterator, - { - self.builder.create_block_before(before, args) + erase_tree(op, self); } - #[inline] - fn insert(&mut self, op: OperationRef) { - self.builder.insert(op); - } -} + /// This method erases all operations in a block. + fn erase_block(&mut self, mut block: BlockRef) { + assert!(!block.borrow().is_used(), "expected 'block' to be unused"); -impl RewriterImpl { - pub fn new(builder: OpBuilder) -> Self { - Self { - builder, - listener: None, + let mut blk = block.borrow_mut(); + let mut cursor = blk.body_mut().back_mut(); + while let Some(op) = cursor.remove() { + assert!(!op.borrow().is_used(), "expected 'op' to be unused"); + self.erase_op(op); } - } - pub fn with_listener(mut self, listener: impl RewriterListener) -> Self { - self.listener = Some(Box::new(listener)); - self + // Notify the listener that the block is about to be removed. + self.notify_block_erased(block.clone()); + + // Remove block from parent region + let mut region = blk.parent().expect("expected 'block' to have a parent region"); + let mut region_mut = region.borrow_mut(); + let mut cursor = unsafe { region_mut.body_mut().cursor_mut_from_ptr(block.clone()) }; + cursor.remove(); } /// Move the blocks that belong to `region` before the given insertion point in another region, /// `ip`. The two regions must be different. The caller is responsible for creating or /// updating the operation transferring flow of control to the region, and passing it the /// correct block arguments. - pub fn inline_region_before(&mut self, region: RegionRef, ip: InsertionPoint) { - todo!() - } - - /// Replace the results of the given operation with the specified list of values (replacements). - /// - /// The result types of the given op and the replacements must match. The original op is erased. - pub fn replace_op_with_values(&mut self, op: OperationRef, values: V) - where - V: IntoIterator, - { - todo!() - } - - /// Replace the results of the given operation with the specified replacement op. - /// - /// The result types of the two ops must match. The original op is erased. - pub fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) { - todo!() - } - - /// This method erases an operation that is known to have no uses. - pub fn erase_op(&mut self, op: OperationRef) { - todo!() - } - - /// This method erases all operations in a block. - pub fn erase_block(&mut self, block: BlockRef) { - todo!() + fn inline_region_before(&mut self, mut region: RegionRef, ip: BlockRef) { + let mut parent = ip.borrow().parent().expect("invalid 'ip': must be attached to a region"); + assert!(!RegionRef::ptr_eq(®ion, &parent), "cannot inline a region into itself"); + let mut region_body = region.borrow_mut().body_mut().take(); + if !self.has_listener() { + { + let mut region_cursor = region_body.front_mut(); + while let Some(block) = region_cursor.as_pointer() { + Block::on_inserted_into_parent(block, parent.clone()); + region_cursor.move_next(); + } + } + let mut parent_region = parent.borrow_mut(); + let parent_body = parent_region.body_mut(); + let mut cursor = unsafe { parent_body.cursor_mut_from_ptr(ip.clone()) }; + cursor.splice_before(region_body); + } else { + // Move blocks from beginning of the region one-by-one + for block in region_body { + self.move_block_before(block, ip.clone()); + } + } } /// Inline the operations of block `src` before the given insertion point. @@ -191,13 +181,62 @@ impl RewriterImpl { /// successors. 
Similarly, if the source block is inserted somewhere in the middle (or /// beginning) of the dest block, the source block must have no successors. Otherwise, the /// resulting IR would have unreachable operations. - pub fn inline_block_before( - &mut self, - src: BlockRef, - ip: InsertionPoint, - args: Option<&[ValueRef]>, - ) { - todo!() + fn inline_block_before(&mut self, mut src: BlockRef, ip: OperationRef, args: &[ValueRef]) { + assert!( + args.len() == src.borrow().num_arguments(), + "incorrect # of argument replacement values" + ); + + // The source block will be deleted, so it should not have any users (i.e., there should be + // no predecessors). + assert!(!src.borrow().has_predecessors(), "expected 'src' to have no predecessors"); + + let mut dest = ip.borrow().parent().expect("expected 'ip' to belong to a block"); + let insert_at_block_end = + OperationRef::ptr_eq(&ip, &dest.borrow().body().back().as_pointer().unwrap()); + if insert_at_block_end { + // The source block will be inserted at the end of the dest block, so the + // dest block should have no successors. Otherwise, the inserted operations + // will be unreachable. + assert!(!dest.borrow().has_successors(), "expected 'dest' to have no successors"); + } else { + // The source block will be inserted in the middle of the dest block, so + // the source block should have no successors. Otherwise, the remainder of + // the dest block would be unreachable. + assert!(!src.borrow().has_successors(), "expected 'src' to have no successors"); + } + + // Replace all of the successor arguments with the provided values. + for (arg, replacement) in src.borrow().arguments().iter().zip(args) { + self.replace_all_uses_of_value_with(arg.clone().upcast(), replacement.clone()); + } + + // Move operations from the source block to the dest block and erase the source block. + if self.has_listener() { + let mut src_ops = src.borrow_mut().body_mut().take(); + let mut src_cursor = src_ops.front_mut(); + while let Some(op) = src_cursor.remove() { + self.move_op_before(op, ip.clone()); + } + } else { + // Fast path: If no listener is attached, move all operations at once. + let mut dest_block = dest.borrow_mut(); + let dest_body = dest_block.body_mut(); + let mut src_ops = src.borrow_mut().body_mut().take(); + { + let mut src_cursor = src_ops.front_mut(); + while let Some(op) = src_cursor.as_pointer() { + Operation::on_inserted_into_parent(op, dest.clone()); + src_cursor.move_next(); + } + } + let mut cursor = unsafe { dest_body.cursor_mut_from_ptr(ip) }; + cursor.splice_before(src_ops); + } + + // Erase the source block. + assert!(src.borrow().body().is_empty(), "expected 'src' to be empty"); + self.erase_block(src); } /// Inline the operations of block `src` into the end of block `dest`. The source block will be @@ -206,197 +245,637 @@ impl RewriterImpl { /// /// The dest block must have no successors. Otherwise, the resulting IR will have unreachable /// operations. - pub fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: Option<&[ValueRef]>) { - todo!() + fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: &[ValueRef]) { + let ip = dest.borrow().body().back().as_pointer().unwrap(); + self.inline_block_before(src, ip, args); } /// Split the operations starting at `ip` (inclusive) out of the given block into a new block, /// and return it. 
- pub fn split_block(&mut self, block: BlockRef, ip: InsertionPoint) -> BlockRef { - todo!() - } + fn split_block(&mut self, mut block: BlockRef, ip: OperationRef) -> BlockRef { + // Fast path: if no listener is attached, split the block directly + if !self.has_listener() { + return block.borrow_mut().split_block(ip); + } - /// Unlink this operation from its current block and insert it right before `ip`, which - /// may be in the same or another block in the same function. - pub fn move_op_before(&mut self, op: OperationRef, ip: InsertionPoint) { - todo!() - } + assert_eq!( + block, + ip.borrow().parent().expect("expected 'ip' to be attached to a block"), + "expected 'ip' to be in 'block'" + ); - /// Unlink this operation from its current block and insert it right after `ip`, which may be - /// in the same or another block in the same function. - pub fn move_op_after(&mut self, op: OperationRef, ip: InsertionPoint) { - todo!() + let region = block + .borrow() + .parent() + .expect("cannot split a block which is not attached to a region"); + + // `create_block` sets the insertion point to the start of the new block + let mut guard = InsertionGuard::new(self); + let new_block = guard.create_block(region, Some(block.clone()), &[]); + + // If `ip` points to the end of the block, no ops should be moved + if OperationRef::ptr_eq(&ip, &block.borrow().body().back().as_pointer().unwrap()) { + return new_block; + } + + // Move ops one-by-one from the end of `block` to the start of `new_block`. + // Stop when the operation pointed to by `ip` has been moved. + let mut block = block.borrow_mut(); + let mut cursor = block.body_mut().back_mut(); + while let Some(op) = cursor.remove() { + let is_last_move = OperationRef::ptr_eq(&op, &ip); + guard.move_op_before(op, new_block.borrow().body().front().as_pointer().unwrap()); + if is_last_move { + break; + } + } + + new_block } /// Unlink this block and insert it right before `ip`. - pub fn move_block_before(&mut self, block: BlockRef, ip: InsertionPoint) { - todo!() + fn move_block_before(&mut self, mut block: BlockRef, ip: BlockRef) { + let current_region = block.borrow().parent(); + block.borrow_mut().move_before(ip.clone()); + self.notify_block_inserted(block, current_region, Some(ip)); } - /// This method is used to notify the rewriter that an in-place operation modification is about - /// to happen. - /// - /// The returned guard can be used to access the rewriter, as well as finalize or cancel the - /// in-place modification. - pub fn start_in_place_modification( - &mut self, - op: OperationRef, - ) -> InPlaceModificationGuard<'_> { - InPlaceModificationGuard::new(self, op) + /// Unlink this operation from its current block and insert it right before `ip`, which + /// may be in the same or another block in the same function. 
+ fn move_op_before(&mut self, mut op: OperationRef, ip: OperationRef) { + let current_block = op.borrow().parent(); + let current_ip = current_block.map(|block| { + let blk = block.borrow(); + let cursor = unsafe { blk.body().cursor_from_ptr(op.clone()) }; + if let Some(next_op) = cursor.peek_next().as_pointer() { + InsertionPoint::before(next_op) + } else if let Some(prev_op) = cursor.peek_prev().as_pointer() { + InsertionPoint::after(prev_op) + } else { + InsertionPoint::after(block) + } + }); + op.borrow_mut().move_before(ProgramPoint::Op(ip.clone())); + self.notify_operation_inserted(op, current_ip); } - /// Performs an in-place modification of `root` using `callback`, taking care of notifying the - /// rewriter of progress and outcome of the modification. - pub fn modify_op_in_place(&mut self, root: OperationRef, callback: F) - where - F: Fn(InPlaceModificationGuard<'_>), - { - let guard = self.start_in_place_modification(root); - callback(guard); + /// Unlink this operation from its current block and insert it right after `ip`, which may be + /// in the same or another block in the same function. + fn move_op_after(&mut self, mut op: OperationRef, ip: OperationRef) { + let current_block = op.borrow().parent(); + let current_ip = current_block.map(|block| { + let blk = block.borrow(); + let cursor = unsafe { blk.body().cursor_from_ptr(op.clone()) }; + if let Some(next_op) = cursor.peek_next().as_pointer() { + InsertionPoint::before(next_op) + } else if let Some(prev_op) = cursor.peek_prev().as_pointer() { + InsertionPoint::after(prev_op) + } else { + InsertionPoint::after(block) + } + }); + op.borrow_mut().move_after(ProgramPoint::Op(ip.clone())); + self.notify_operation_inserted(op, current_ip); } /// Find uses of `from` and replace them with `to`. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). - pub fn replace_all_uses_of_value_with(&mut self, from: ValueRef, to: ValueRef) { - todo!() + fn replace_all_uses_of_value_with(&mut self, mut from: ValueRef, mut to: ValueRef) { + let mut from_val = from.borrow_mut(); + let from_uses = from_val.uses_mut(); + let mut cursor = from_uses.front_mut(); + while let Some(mut operand) = cursor.remove() { + let to = &mut to; + let op = operand.borrow().owner.clone(); + self.notify_operation_modification_started(&op); + operand.borrow_mut().value = to.clone(); + to.borrow_mut().insert_use(operand); + self.notify_operation_modified(op); + } } /// Find uses of `from` and replace them with `to`. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). - pub fn replace_all_uses_of_block_with(&mut self, from: BlockRef, to: BlockRef) { - todo!() + fn replace_all_uses_of_block_with(&mut self, mut from: BlockRef, mut to: BlockRef) { + let mut from_block = from.borrow_mut(); + let from_uses = from_block.uses_mut(); + let mut cursor = from_uses.front_mut(); + while let Some(mut operand) = cursor.remove() { + let to = &mut to; + let op = operand.borrow().owner.clone(); + self.notify_operation_modification_started(&op); + operand.borrow_mut().block = to.clone(); + to.borrow_mut().insert_use(operand); + self.notify_operation_modified(op); + } } /// Find uses of `from` and replace them with `to`. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). 
- pub fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[ValueRef]) { - todo!() + fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[ValueRef]) { + assert_eq!(from.len(), to.len(), "incorrect number of replacements"); + for (from, to) in from.iter().cloned().zip(to.iter().cloned()) { + self.replace_all_uses_of_value_with(from, to); + } } /// Find uses of `from` and replace them with `to`. /// /// Notifies the listener about every in-place modification (for every use that was replaced), /// and that the `from` operation is about to be replaced. - pub fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[ValueRef]) { - todo!() + fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[ValueRef]) { + self.notify_operation_replaced_with_values(from.clone(), to); + + let results = from + .borrow() + .results() + .all() + .iter() + .map(|result| result.borrow().as_value_ref()) + .collect::>(); + + self.replace_all_uses_with(&results, to); } /// Find uses of `from` and replace them with `to`. /// /// Notifies the listener about every in-place modification (for every use that was replaced), /// and that the `from` operation is about to be replaced. - pub fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) { - todo!() + fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) { + self.notify_operation_replaced(from.clone(), to.clone()); + + let from_results = from + .borrow() + .results() + .all() + .iter() + .map(|result| result.borrow().as_value_ref()) + .collect::>(); + + let to_results = to + .borrow() + .results() + .all() + .iter() + .map(|result| result.borrow().as_value_ref()) + .collect::>(); + + self.replace_all_uses_with(&from_results, &to_results); } - /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// Find uses of `from` within `block` and replace them with `to`. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). /// /// Returns true if all uses were replaced, otherwise false. - pub fn maybe_replace_uses_of_value_with

( + fn replace_op_uses_within_block( + &mut self, + from: OperationRef, + to: &[ValueRef], + block: BlockRef, + ) -> bool { + let parent_op = block.borrow().parent_op(); + self.maybe_replace_op_uses_with(from, to, |operand| { + !parent_op + .as_ref() + .is_some_and(|op| op.borrow().is_proper_ancestor_of(&operand.owner)) + }) + } + + /// Find uses of `from` and replace them with `to`, except if the user is in `exceptions`. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + fn replace_all_uses_except( &mut self, from: ValueRef, to: ValueRef, - predicate: P, + exceptions: &[OperationRef], + ) { + self.maybe_replace_uses_of_value_with(from, to, |operand| { + !exceptions.contains(&operand.owner) + }); + } +} + +/// An extension trait for [Rewriter] implementations. +/// +/// This trait contains functionality that is not object safe, and would prevent using [Rewriter] as +/// a trait object. It is automatically implemented for all [Rewriter] impls. +pub trait RewriterExt: Rewriter { + /// Find uses of `from` and replace them with `to`, if `should_replace` returns true. + /// + /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// + /// Returns true if all uses were replaced, otherwise false. + fn maybe_replace_uses_of_value_with

( + &mut self, + mut from: ValueRef, + mut to: ValueRef, + should_replace: P, ) -> bool where - P: Fn(OpOperand) -> bool, + P: Fn(&OpOperandImpl) -> bool, { - todo!() + let mut all_replaced = true; + let mut from = from.borrow_mut(); + let from_uses = from.uses_mut(); + let mut cursor = from_uses.front_mut(); + while let Some(user) = cursor.as_pointer() { + if should_replace(&user.borrow()) { + let owner = user.borrow().owner.clone(); + self.notify_operation_modification_started(&owner); + let mut operand = cursor.remove().unwrap(); + { + operand.borrow_mut().value = to.clone(); + } + to.borrow_mut().insert_use(operand); + self.notify_operation_modified(owner); + } else { + all_replaced = false; + cursor.move_next(); + } + } + all_replaced } - /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// Find uses of `from` and replace them with `to`, if `should_replace` returns true. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). /// /// Returns true if all uses were replaced, otherwise false. - pub fn maybe_replace_uses_with

( + fn maybe_replace_uses_with

( &mut self, from: &[ValueRef], to: &[ValueRef], - predicate: P, + should_replace: P, ) -> bool where - P: Fn(OpOperand) -> bool, + P: Fn(&OpOperandImpl) -> bool, { - todo!() + assert_eq!(from.len(), to.len(), "incorrect number of replacements"); + let mut all_replaced = true; + for (from, to) in from.iter().cloned().zip(to.iter().cloned()) { + all_replaced &= self.maybe_replace_uses_of_value_with(from, to, &should_replace); + } + all_replaced } - /// Find uses of `from` and replace them with `to`, if `predicate` returns true. + /// Find uses of `from` and replace them with `to`, if `should_replace` returns true. /// /// Notifies the listener about every in-place op modification (for every use that was replaced). /// /// Returns true if all uses were replaced, otherwise false. - pub fn maybe_replace_op_uses_with

( + fn maybe_replace_op_uses_with

( &mut self, from: OperationRef, to: &[ValueRef], - predicate: P, + should_replace: P, ) -> bool where - P: Fn(OpOperand) -> bool, + P: Fn(&OpOperandImpl) -> bool, { - todo!() + let results = SmallVec::<[ValueRef; 2]>::from_iter( + from.borrow().results.all().iter().cloned().map(|result| result.upcast()), + ); + self.maybe_replace_uses_with(&results, to, should_replace) } +} - /// Find uses of `from` within `block` and replace them with `to`. +impl RewriterExt for R {} + +#[allow(unused_variables)] +pub trait RewriterListener: Listener { + /// Notify the listener that the specified block is about to be erased. /// - /// Notifies the listener about every in-place op modification (for every use that was replaced). + /// At this point, the block has zero uses. + fn notify_block_erased(&self, block: BlockRef) {} + + /// Notify the listener that an in-place modification of the specified operation has started + fn notify_operation_modification_started(&self, op: &OperationRef) {} + + /// Notify the listener that an in-place modification of the specified operation was canceled + fn notify_operation_modification_canceled(&self, op: &OperationRef) {} + + /// Notify the listener that the specified operation was modified in-place. + fn notify_operation_modified(&self, op: OperationRef) {} + + /// Notify the listener that all uses of the specified operation's results are about to be + /// replaced with the results of another operation. This is called before the uses of the old + /// operation have been changed. /// - /// Returns true if all uses were replaced, otherwise false. - pub fn replace_op_uses_within_block( - &mut self, - from: OperationRef, - to: &[ValueRef], - block: BlockRef, - ) -> bool { - let parent_op = block.borrow().parent_op(); - self.maybe_replace_op_uses_with(from, to, |operand| { - let operand = operand.borrow(); - !parent_op - .as_ref() - .is_some_and(|op| op.borrow().is_proper_ancestor_of(operand.owner.clone())) - }) + /// By default, this function calls the "operation replaced with values" notification. + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + let replacement = replacement.borrow(); + let values = replacement + .results() + .all() + .iter() + .cloned() + .map(|result| result.upcast()) + .collect::>(); + self.notify_operation_replaced_with_values(op, &values); } - /// Find uses of `from` and replace them with `to`, except if the user is in `exceptions`. + /// Notify the listener that all uses of the specified operation's results are about to be + /// replaced with the given range of values, potentially produced by other operations. This is + /// called before the uses of the operation have been changed. + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {} + + /// Notify the listener that the specified operation is about to be erased. At this point, the + /// operation has zero uses. /// - /// Notifies the listener about every in-place op modification (for every use that was replaced). - pub fn replace_all_uses_except( - &mut self, - from: ValueRef, - to: ValueRef, - exceptions: &[OperationRef], - ) { - self.maybe_replace_uses_of_value_with(from, to, |operand| { - let operand = operand.borrow(); - !exceptions.contains(&operand.owner) - }); + /// NOTE: This notification is not triggered when unlinking an operation. + fn notify_operation_erased(&self, op: OperationRef) {} + + /// Notify the listener that the specified pattern is about to be applied at the specified root + /// operation. 
+ fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {} + + /// Notify the listener that a pattern application finished with the specified status. + /// + /// `true` indicates that the pattern was applied successfully. `false` indicates that the + /// pattern could not be applied. The pattern may have communicated the reason for the failure + /// with `notify_match_failure` + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {} + + /// Notify the listener that the pattern failed to match, and provide a diagnostic explaining + /// the reason why the failure occurred. + fn notify_match_failure(&self, span: SourceSpan, reason: Report) {} +} + +impl RewriterListener for Option { + fn notify_block_erased(&self, block: BlockRef) { + if let Some(listener) = self.as_ref() { + listener.notify_block_erased(block); + } } - pub fn notify_match_failure(&mut self, span: SourceSpan, report: Report) { - if let Some(listener) = self.listener.as_mut() { - listener.notify_match_failure(span, report); + fn notify_operation_modification_started(&self, op: &OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_modification_started(op); + } + } + + fn notify_operation_modification_canceled(&self, op: &OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_modification_canceled(op); + } + } + + fn notify_operation_modified(&self, op: OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_modified(op); + } + } + + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_replaced(op, replacement); + } + } + + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_replaced_with_values(op, replacement); + } + } + + fn notify_operation_erased(&self, op: OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_erased(op); + } + } + + fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) { + if let Some(listener) = self.as_ref() { + listener.notify_pattern_begin(pattern, op); + } + } + + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) { + if let Some(listener) = self.as_ref() { + listener.notify_pattern_end(pattern, success); + } + } + + fn notify_match_failure(&self, span: SourceSpan, reason: Report) { + if let Some(listener) = self.as_ref() { + listener.notify_match_failure(span, reason); } } } +impl RewriterListener for Box { + fn notify_block_erased(&self, block: BlockRef) { + (**self).notify_block_erased(block); + } + + fn notify_operation_modification_started(&self, op: &OperationRef) { + (**self).notify_operation_modification_started(op); + } + + fn notify_operation_modification_canceled(&self, op: &OperationRef) { + (**self).notify_operation_modification_canceled(op); + } + + fn notify_operation_modified(&self, op: OperationRef) { + (**self).notify_operation_modified(op); + } + + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + (**self).notify_operation_replaced(op, replacement); + } + + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) { + (**self).notify_operation_replaced_with_values(op, replacement); + } + + fn notify_operation_erased(&self, op: OperationRef) { + (**self).notify_operation_erased(op) + } + + fn 
notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) { + (**self).notify_pattern_begin(pattern, op); + } + + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) { + (**self).notify_pattern_end(pattern, success); + } + + fn notify_match_failure(&self, span: SourceSpan, reason: Report) { + (**self).notify_match_failure(span, reason); + } +} + +impl RewriterListener for Rc { + fn notify_block_erased(&self, block: BlockRef) { + (**self).notify_block_erased(block); + } + + fn notify_operation_modification_started(&self, op: &OperationRef) { + (**self).notify_operation_modification_started(op); + } + + fn notify_operation_modification_canceled(&self, op: &OperationRef) { + (**self).notify_operation_modification_canceled(op); + } + + fn notify_operation_modified(&self, op: OperationRef) { + (**self).notify_operation_modified(op); + } + + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + (**self).notify_operation_replaced(op, replacement); + } + + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) { + (**self).notify_operation_replaced_with_values(op, replacement); + } + + fn notify_operation_erased(&self, op: OperationRef) { + (**self).notify_operation_erased(op) + } + + fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) { + (**self).notify_pattern_begin(pattern, op); + } + + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) { + (**self).notify_pattern_end(pattern, success); + } + + fn notify_match_failure(&self, span: SourceSpan, reason: Report) { + (**self).notify_match_failure(span, reason); + } +} + +/// A listener of kind `Rewriter` that does nothing +pub struct NoopRewriterListener; +impl Listener for NoopRewriterListener { + #[inline] + fn kind(&self) -> ListenerType { + ListenerType::Rewriter + } + + #[inline(always)] + fn notify_operation_inserted(&self, _op: OperationRef, _prev: Option) {} + + #[inline(always)] + fn notify_block_inserted( + &self, + _block: BlockRef, + _prev: Option, + _ip: Option, + ) { + } +} +impl RewriterListener for NoopRewriterListener { + fn notify_operation_replaced(&self, _op: OperationRef, _replacement: OperationRef) {} +} + +pub struct ForwardingListener { + base: Base, + derived: Derived, +} +impl ForwardingListener { + pub fn new(base: Base, derived: Derived) -> Self { + Self { base, derived } + } +} +impl Listener for ForwardingListener { + fn kind(&self) -> ListenerType { + self.derived.kind() + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + self.base.notify_block_inserted(block.clone(), prev.clone(), ip.clone()); + self.derived.notify_block_inserted(block, prev, ip); + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + self.base.notify_operation_inserted(op.clone(), prev.clone()); + self.derived.notify_operation_inserted(op, prev); + } +} +impl RewriterListener + for ForwardingListener +{ + fn notify_block_erased(&self, block: BlockRef) { + self.base.notify_block_erased(block.clone()); + self.derived.notify_block_erased(block); + } + + fn notify_operation_modification_started(&self, op: &OperationRef) { + self.base.notify_operation_modification_started(op); + self.derived.notify_operation_modification_started(op); + } + + fn notify_operation_modification_canceled(&self, op: &OperationRef) { + self.base.notify_operation_modification_canceled(op); + self.derived.notify_operation_modification_canceled(op); + } + + fn 
notify_operation_modified(&self, op: OperationRef) { + self.base.notify_operation_modified(op.clone()); + self.derived.notify_operation_modified(op); + } + + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + self.base.notify_operation_replaced(op.clone(), replacement.clone()); + self.derived.notify_operation_replaced(op, replacement); + } + + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) { + self.base.notify_operation_replaced_with_values(op.clone(), replacement); + self.derived.notify_operation_replaced_with_values(op, replacement); + } + + fn notify_operation_erased(&self, op: OperationRef) { + self.base.notify_operation_erased(op.clone()); + self.derived.notify_operation_erased(op); + } + + fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) { + self.base.notify_pattern_begin(pattern, op.clone()); + self.derived.notify_pattern_begin(pattern, op); + } + + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) { + self.base.notify_pattern_end(pattern, success); + self.derived.notify_pattern_end(pattern, success); + } + + fn notify_match_failure(&self, span: SourceSpan, reason: Report) { + let err = Report::msg(format!("{reason}")); + self.base.notify_match_failure(span, reason); + self.derived.notify_match_failure(span, err); + } +} + /// Wraps an in-place modification of an [Operation] to ensure the rewriter is properly notified /// about the progress and outcome of the in-place notification. /// /// This is a minor efficiency win, as it avoids creating a new operation, and removing the old one, /// but also often allows simpler code in the client. -pub struct InPlaceModificationGuard<'a> { - rewriter: &'a mut RewriterImpl, +pub struct InPlaceModificationGuard<'a, R: ?Sized + Rewriter> { + rewriter: &'a mut R, op: OperationRef, canceled: bool, } -impl<'a> InPlaceModificationGuard<'a> { - fn new(rewriter: &'a mut RewriterImpl, op: OperationRef) -> Self { +impl<'a, R> InPlaceModificationGuard<'a, R> +where + R: ?Sized + Rewriter, +{ + pub fn new(rewriter: &'a mut R, op: OperationRef) -> Self { + rewriter.notify_operation_modification_started(&op); Self { rewriter, op, @@ -405,7 +884,7 @@ impl<'a> InPlaceModificationGuard<'a> { } #[inline] - pub fn rewriter(&mut self) -> &mut RewriterImpl { + pub fn rewriter(&mut self) -> &mut R { self.rewriter } @@ -422,102 +901,257 @@ impl<'a> InPlaceModificationGuard<'a> { /// Signals the end of an in-place modification of the current operation. 
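The forwarding impls above, for the optional, boxed, and reference-counted wrappers as well as the two-way `ForwardingListener`, are all the same mechanical delegation pattern. A standalone sketch with a one-method listener trait (names invented for the example):

```rust
trait Notify {
    fn op_erased(&self, op: u32);
}

/// Forward to the wrapped listener when present, otherwise do nothing,
/// mirroring the `Option`-based impl above.
impl<L: Notify> Notify for Option<L> {
    fn op_erased(&self, op: u32) {
        if let Some(listener) = self.as_ref() {
            listener.op_erased(op);
        }
    }
}

/// Fan each notification out to two listeners, `base` first, like
/// `ForwardingListener` does.
struct Forwarding<B, D> {
    base: B,
    derived: D,
}

impl<B: Notify, D: Notify> Notify for Forwarding<B, D> {
    fn op_erased(&self, op: u32) {
        self.base.op_erased(op);
        self.derived.op_erased(op);
    }
}

struct Tag(&'static str);
impl Notify for Tag {
    fn op_erased(&self, op: u32) {
        println!("[{}] erased op {op}", self.0);
    }
}

fn main() {
    let fanout = Forwarding {
        base: Some(Tag("base")),
        derived: Tag("derived"),
    };
    fanout.op_erased(7); // notifies "base" first, then "derived"
}
```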
pub fn finalize(self) {} } -impl core::ops::Deref for InPlaceModificationGuard<'_> { - type Target = RewriterImpl; +impl core::ops::Deref for InPlaceModificationGuard<'_, R> { + type Target = R; #[inline(always)] fn deref(&self) -> &Self::Target { self.rewriter } } -impl core::ops::DerefMut for InPlaceModificationGuard<'_> { +impl core::ops::DerefMut for InPlaceModificationGuard<'_, R> { #[inline(always)] fn deref_mut(&mut self) -> &mut Self::Target { self.rewriter } } -impl Drop for InPlaceModificationGuard<'_> { +impl Drop for InPlaceModificationGuard<'_, R> { fn drop(&mut self) { if self.canceled { - //self.rewriter.cancel_op_modification(self.op.clone()); - todo!("cancel op modification") + self.rewriter.notify_operation_modification_canceled(&self.op); } else { - //self.rewriter.finalize_op_modification(self.op.clone()); - todo!("finalize op modification") + self.rewriter.notify_operation_modified(self.op.clone()); } } } -pub trait RewriterListener: Listener { - /// Notify the listener that the specified block is about to be erased. - /// - /// At this point, the block has zero uses. - fn notify_block_erased(&mut self, block: BlockRef) {} +/// A special type of `RewriterBase` that coordinates the application of a rewrite pattern on the +/// current IR being matched, providing a way to keep track of any mutations made. +/// +/// This type should be used to perform all necessary IR mutations within a rewrite pattern, as +/// the pattern driver may be tracking various state that would be invalidated when a mutation takes +/// place. +pub struct PatternRewriter { + rewriter: RewriterImpl, + recoverable: bool, +} - /// Notify the listener that the specified operation was modified in-place. - fn notify_operation_modified(&mut self, op: OperationRef) {} +impl PatternRewriter { + pub fn new(context: Rc) -> Self { + let rewriter = RewriterImpl::new(context); + Self { + rewriter, + recoverable: false, + } + } - /// Notify the listener that all uses of the specified operation's results are about to be - /// replaced with the results of another operation. This is called before the uses of the old - /// operation have been changed. - /// - /// By default, this function calls the "operation replaced with values" notification. - fn notify_operation_replaced(&mut self, op: OperationRef, replacement: OperationRef) { - let replacement = replacement.borrow(); - self.notify_operation_replaced_with_values(op, replacement.results().all().as_slice()); + pub fn from_builder(builder: OpBuilder) -> Self { + let (context, _, ip) = builder.into_parts(); + let mut rewriter = RewriterImpl::new(context); + rewriter.restore_insertion_point(ip); + Self { + rewriter, + recoverable: false, + } } +} - /// Notify the listener that all uses of the specified operation's results are about to be - /// replaced with the given range of values, potentially produced by other operations. This is - /// called before the uses of the operation have been changed. - fn notify_operation_replaced_with_values( - &mut self, - op: OperationRef, - replacement: &[OpResultRef], - ) { +impl PatternRewriter { + pub fn new_with_listener(context: Rc, listener: L) -> Self { + let rewriter = RewriterImpl::::new(context).with_listener(listener); + Self { + rewriter, + recoverable: false, + } } - /// Notify the listener that the specified operation is about to be erased. At this point, the - /// operation has zero uses. - /// - /// NOTE: This notification is not triggered when unlinking an operation. 
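The modification guard reports "modification started" when constructed and, in `Drop`, reports either a cancellation or a completed modification depending on how it was consumed. A minimal std-only illustration of that RAII shape (the `cancel` method and the string log here are inventions for the example, not the real rewriter API):

```rust
struct Guard<'a> {
    log: &'a mut Vec<String>,
    canceled: bool,
}

impl<'a> Guard<'a> {
    fn new(log: &'a mut Vec<String>) -> Self {
        log.push("modification started".into());
        Self { log, canceled: false }
    }

    /// Abandon the in-place modification; the drop handler reports a cancel.
    fn cancel(mut self) {
        self.canceled = true;
    }

    /// Commit the in-place modification; dropping `self` reports success.
    fn finalize(self) {}
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        if self.canceled {
            self.log.push("modification canceled".into());
        } else {
            self.log.push("operation modified".into());
        }
    }
}

fn main() {
    let mut log = Vec::new();
    Guard::new(&mut log).finalize();
    Guard::new(&mut log).cancel();
    assert_eq!(log, ["modification started", "operation modified",
                     "modification started", "modification canceled"]);
}
```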
- fn notify_operation_erased(&mut self, op: OperationRef) {} + #[inline] + pub const fn can_recover_from_rewrite_failure(&self) -> bool { + self.recoverable + } +} +impl Deref for PatternRewriter { + type Target = RewriterImpl; - /// Notify the listener that the specified pattern is about to be applied at the specified root - /// operation. - fn notify_pattern_begin(&mut self, pattern: &Pattern, op: OperationRef) {} + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.rewriter + } +} +impl DerefMut for PatternRewriter { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.rewriter + } +} - /// Notify the listener that a pattern application finished with the specified status. - /// - /// `Ok` indicates that the pattern was applied successfully. `Err` indicates that the pattern - /// could not be applied. The pattern may have communicated the reason for the failure with - /// `notify_match_failure` - fn notify_pattern_end(&mut self, pattern: &Pattern, status: Result<(), Report>) {} +pub struct RewriterImpl { + context: Rc, + listener: Option, + ip: Option, +} - /// Notify the listener that the pattern failed to match, and provide a diagnostic explaining - /// the reason why the failure occurred. - fn notify_match_failure(&mut self, span: SourceSpan, reason: Report) {} +impl RewriterImpl { + pub fn new(context: Rc) -> Self { + Self { + context, + listener: None, + ip: None, + } + } + + pub fn with_listener(self, listener: L2) -> RewriterImpl + where + L2: Listener, + { + RewriterImpl { + context: self.context, + listener: Some(listener), + ip: self.ip, + } + } } -struct RewriterListenerBase { - kind: ListenerType, +impl From> for RewriterImpl { + #[inline] + fn from(builder: OpBuilder) -> Self { + let (context, listener, ip) = builder.into_parts(); + Self { + context, + listener, + ip, + } + } } -impl Listener for RewriterListenerBase { + +impl Builder for RewriterImpl { #[inline(always)] + fn context(&self) -> &Context { + &self.context + } + + #[inline(always)] + fn context_rc(&self) -> Rc { + self.context.clone() + } + + #[inline(always)] + fn insertion_point(&self) -> Option<&InsertionPoint> { + self.ip.as_ref() + } + + #[inline(always)] + fn clear_insertion_point(&mut self) -> Option { + self.ip.take() + } + + #[inline(always)] + fn restore_insertion_point(&mut self, ip: Option) { + self.ip = ip; + } + + #[inline(always)] + fn set_insertion_point(&mut self, ip: InsertionPoint) { + self.ip = Some(ip); + } +} + +impl Rewriter for RewriterImpl { + #[inline(always)] + fn has_listener(&self) -> bool { + self.listener.is_some() + } +} + +impl Listener for RewriterImpl { fn kind(&self) -> ListenerType { ListenerType::Rewriter } + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_inserted(op, prev); + } + } + fn notify_block_inserted( - &mut self, + &self, block: BlockRef, prev: Option, - ip: Option, + ip: Option, ) { - todo!() + if let Some(listener) = self.listener.as_ref() { + listener.notify_block_inserted(block, prev, ip); + } + } +} + +impl RewriterListener for RewriterImpl { + fn notify_block_erased(&self, block: BlockRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_block_erased(block); + } + } + + fn notify_operation_modification_started(&self, op: &OperationRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_modification_started(op); + } + } + + fn 
notify_operation_modification_canceled(&self, op: &OperationRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_modification_canceled(op); + } } - fn notify_operation_inserted(&mut self, op: OperationRef, prev: Option) { - todo!() + fn notify_operation_modified(&self, op: OperationRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_modified(op); + } + } + + fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) { + if self.listener.is_some() { + let replacement = replacement.borrow(); + let values = replacement + .results() + .all() + .iter() + .cloned() + .map(|result| result.upcast()) + .collect::>(); + self.notify_operation_replaced_with_values(op, &values); + } + } + + fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_replaced_with_values(op, replacement); + } + } + + fn notify_operation_erased(&self, op: OperationRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_erased(op); + } + } + + fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_pattern_begin(pattern, op); + } + } + + fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_pattern_end(pattern, success); + } + } + + fn notify_match_failure(&self, span: SourceSpan, reason: Report) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_match_failure(span, reason); + } } } From 05c30361f6e305dffa61eaff27bb4e61d1268468 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Wed, 23 Oct 2024 08:19:38 -0400 Subject: [PATCH 24/31] chore: normalize use of fxhash-based hash maps --- Cargo.lock | 26 ++++++++----------- Cargo.toml | 2 +- codegen/masm/Cargo.toml | 2 +- codegen/masm/src/emulator/breakpoints.rs | 5 ++-- codegen/masm/src/emulator/mod.rs | 8 +++--- frontend-wasm/Cargo.toml | 1 + frontend-wasm/src/component/dfg.rs | 6 +++-- frontend-wasm/src/component/inline.rs | 9 +++---- frontend-wasm/src/component/parser.rs | 13 +++++----- frontend-wasm/src/component/translator.rs | 5 ++-- frontend-wasm/src/component/types/mod.rs | 6 +++-- .../src/component/types/resources.rs | 2 +- frontend-wasm/src/miden_abi/mod.rs | 3 +-- frontend-wasm/src/module/mod.rs | 3 +-- .../src/module/module_translation_state.rs | 3 +-- frontend-wasm/src/translation_utils.rs | 3 --- hir-analysis/Cargo.toml | 2 +- hir-analysis/src/data.rs | 4 +-- hir-analysis/src/liveness.rs | 3 +-- hir-analysis/src/validation/block.rs | 1 - hir-symbol/Cargo.toml | 1 + hir-symbol/build.rs | 4 ++- hir-transform/Cargo.toml | 1 - hir-transform/src/adt/scoped_map.rs | 2 +- hir-transform/src/inline_blocks.rs | 1 - hir-transform/src/spill.rs | 1 - hir-transform/src/split_critical_edges.rs | 1 - hir/Cargo.toml | 1 + hir/src/asm/import.rs | 5 ++-- hir/src/dataflow.rs | 3 +-- hir/src/lib.rs | 4 +++ hir/src/module.rs | 1 - hir/src/parser/ast/convert.rs | 4 +-- hir/src/pass/analysis.rs | 17 ++++-------- hir2/src/folder.rs | 7 +++-- hir2/src/lib.rs | 4 +++ 36 files changed, 75 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 51ffbb835..82b385845 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3230,6 +3230,7 @@ dependencies = [ "bitcode", "cranelift-entity", "env_logger 0.11.5", + "hashbrown 0.14.5", "intrusive-collections", "inventory", 
"log", @@ -3245,7 +3246,6 @@ dependencies = [ "paste", "petgraph", "proptest", - "rustc-hash 1.1.0", "serde 1.0.210", "serde_bytes", "smallvec", @@ -3321,6 +3321,7 @@ dependencies = [ "derive_more", "expect-test", "gimli", + "hashbrown 0.14.5", "indexmap 2.5.0", "log", "miden-core", @@ -3328,7 +3329,7 @@ dependencies = [ "midenc-hir", "midenc-hir-type", "midenc-session", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "wasmparser 0.214.0", "wat", @@ -3342,6 +3343,7 @@ dependencies = [ "cranelift-entity", "derive_more", "either", + "hashbrown 0.14.5", "indexmap 2.5.0", "intrusive-collections", "inventory", @@ -3361,7 +3363,7 @@ dependencies = [ "petgraph", "pretty_assertions", "rustc-demangle", - "rustc-hash 1.1.0", + "rustc-hash", "serde 1.0.210", "serde_bytes", "serde_repr", @@ -3377,13 +3379,13 @@ dependencies = [ "anyhow", "cranelift-bforest", "cranelift-entity", + "hashbrown 0.14.5", "intrusive-collections", "inventory", "miden-thiserror", "midenc-hir", "midenc-session", "pretty_assertions", - "rustc-hash 1.1.0", "smallvec", ] @@ -3404,10 +3406,11 @@ version = "0.0.6" dependencies = [ "Inflector", "compact_str", + "hashbrown 0.14.5", "lock_api", "miden-formatting", "parking_lot", - "rustc-hash 1.1.0", + "rustc-hash", "serde 1.0.210", "toml 0.8.19", ] @@ -3423,7 +3426,6 @@ dependencies = [ "midenc-hir-analysis", "midenc-session", "pretty_assertions", - "rustc-hash 1.1.0", "smallvec", ] @@ -3470,7 +3472,7 @@ dependencies = [ "petgraph", "pretty_assertions", "rustc-demangle", - "rustc-hash 1.1.0", + "rustc-hash", "serde 1.0.210", "serde_bytes", "serde_repr", @@ -4417,7 +4419,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", + "rustc-hash", "rustls", "socket2 0.5.7", "thiserror", @@ -4434,7 +4436,7 @@ dependencies = [ "bytes", "rand", "ring", - "rustc-hash 2.0.0", + "rustc-hash", "rustls", "slab", "thiserror", @@ -4751,12 +4753,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.0.0" diff --git a/Cargo.toml b/Cargo.toml index 15e62c400..5d3406143 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ parking_lot_core = "0.9" petgraph = "0.6" pretty_assertions = "1.0" proptest = "1.4" -rustc-hash = "1.1" +rustc-hash = { version = "2.0", default-features = false } serde = { version = "1.0.208", features = ["serde_derive", "alloc", "rc"] } serde_repr = "0.1.19" serde_bytes = "0.11.15" diff --git a/codegen/masm/Cargo.toml b/codegen/masm/Cargo.toml index 0a3b9cb96..cfae3541b 100644 --- a/codegen/masm/Cargo.toml +++ b/codegen/masm/Cargo.toml @@ -22,6 +22,7 @@ cranelift-entity.workspace = true intrusive-collections.workspace = true inventory.workspace = true log.workspace = true +hashbrown.workspace = true miden-assembly.workspace = true miden-core.workspace = true miden-processor.workspace = true @@ -32,7 +33,6 @@ midenc-hir-transform.workspace = true midenc-session = { workspace = true, features = ["serde"] } paste.workspace = true petgraph.workspace = true -rustc-hash.workspace = true serde.workspace = true serde_bytes.workspace = true smallvec.workspace = true diff --git a/codegen/masm/src/emulator/breakpoints.rs b/codegen/masm/src/emulator/breakpoints.rs index 19058ac4a..b04aee7bd 100644 --- 
a/codegen/masm/src/emulator/breakpoints.rs +++ b/codegen/masm/src/emulator/breakpoints.rs @@ -1,7 +1,6 @@ use std::collections::BTreeSet; -use midenc_hir::FunctionIdent; -use rustc_hash::{FxHashMap, FxHashSet}; +use midenc_hir::{FunctionIdent, FxHashMap, FxHashSet}; use super::{Addr, BreakpointEvent, EmulatorEvent, Instruction, InstructionPointer}; use crate::BlockId; @@ -143,7 +142,7 @@ impl BreakpointManager { /// Set the given breakpoint pub fn set(&mut self, bp: Breakpoint) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; match bp { Breakpoint::All => { diff --git a/codegen/masm/src/emulator/mod.rs b/codegen/masm/src/emulator/mod.rs index 526ccb7a1..892abdfe9 100644 --- a/codegen/masm/src/emulator/mod.rs +++ b/codegen/masm/src/emulator/mod.rs @@ -8,8 +8,10 @@ use std::{cell::RefCell, cmp, rc::Rc, sync::Arc}; use memory::Memory; use miden_assembly::{ast::ProcedureName, LibraryNamespace}; -use midenc_hir::{assert_matches, Felt, FieldElement, FunctionIdent, Ident, OperandStack, Stack}; -use rustc_hash::{FxHashMap, FxHashSet}; +use midenc_hir::{ + assert_matches, Felt, FieldElement, FunctionIdent, FxHashMap, FxHashSet, Ident, OperandStack, + Stack, +}; use self::functions::{Activation, Stub}; pub use self::{ @@ -313,7 +315,7 @@ impl Emulator { /// /// An error is returned if a module with the same name is already loaded. pub fn load_module(&mut self, module: Arc) -> Result<(), EmulationError> { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; assert_matches!( self.status, diff --git a/frontend-wasm/Cargo.toml b/frontend-wasm/Cargo.toml index dfa63d375..55f45e443 100644 --- a/frontend-wasm/Cargo.toml +++ b/frontend-wasm/Cargo.toml @@ -21,6 +21,7 @@ gimli = { version = "0.31", default-features = false, features = [ ] } indexmap.workspace = true log.workspace = true +hashbrown.workspace = true miden-core.workspace = true midenc-hir.workspace = true midenc-hir-type.workspace = true diff --git a/frontend-wasm/src/component/dfg.rs b/frontend-wasm/src/component/dfg.rs index 8375b1552..2c798d413 100644 --- a/frontend-wasm/src/component/dfg.rs +++ b/frontend-wasm/src/component/dfg.rs @@ -35,8 +35,10 @@ use std::{hash::Hash, ops::Index}; use indexmap::IndexMap; -use midenc_hir::cranelift_entity::{EntityRef, PrimaryMap}; -use rustc_hash::FxHashMap; +use midenc_hir::{ + cranelift_entity::{EntityRef, PrimaryMap}, + FxHashMap, +}; use crate::{ component::*, diff --git a/frontend-wasm/src/component/inline.rs b/frontend-wasm/src/component/inline.rs index c89cbb498..f080b22fe 100644 --- a/frontend-wasm/src/component/inline.rs +++ b/frontend-wasm/src/component/inline.rs @@ -46,12 +46,11 @@ // Based on wasmtime v16.0 Wasm component translation -use std::{borrow::Cow, collections::HashMap}; +use std::borrow::Cow; use anyhow::{bail, Result}; use indexmap::IndexMap; -use midenc_hir::cranelift_entity::PrimaryMap; -use rustc_hash::FxHashMap; +use midenc_hir::{cranelift_entity::PrimaryMap, FxBuildHasher, FxHashMap}; use wasmparser::types::{ComponentAnyTypeId, ComponentEntityType, ComponentInstanceTypeId}; use super::{ @@ -61,7 +60,6 @@ use super::{ use crate::{ component::{dfg, LocalInitializer}, module::{module_env::ParsedModule, types::*, ModuleImport}, - translation_utils::BuildFxHasher, }; pub fn run( @@ -87,8 +85,7 @@ pub fn run( // // Note that this is represents the abstract state of a host import of an // item since we don't know the precise structure of the host import. 
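hashbrown's `hash_map::Entry` mirrors the std API, so the swaps above are mechanical. For reference, the insert-or-reject pattern used by `load_module`, written against the std map:

```rust
use std::collections::{hash_map::Entry, HashMap};

fn load_module(loaded: &mut HashMap<String, u32>, name: &str, id: u32) -> Result<(), String> {
    match loaded.entry(name.to_string()) {
        // Only insert if nothing with this name is loaded yet.
        Entry::Vacant(entry) => {
            entry.insert(id);
            Ok(())
        }
        Entry::Occupied(_) => Err(format!("module `{name}` is already loaded")),
    }
}

fn main() {
    let mut loaded = HashMap::new();
    assert!(load_module(&mut loaded, "std::mem", 1).is_ok());
    assert!(load_module(&mut loaded, "std::mem", 2).is_err());
}
```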
- let mut args = - HashMap::with_capacity_and_hasher(root_component.exports.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(root_component.exports.len(), FxBuildHasher); let mut path = Vec::new(); types.resources_mut().set_current_instance(index); let types_ref = root_component.types_ref(); diff --git a/frontend-wasm/src/component/parser.rs b/frontend-wasm/src/component/parser.rs index a06002497..51bd871c1 100644 --- a/frontend-wasm/src/component/parser.rs +++ b/frontend-wasm/src/component/parser.rs @@ -3,15 +3,15 @@ // Based on wasmtime v16.0 Wasm component translation -use std::{collections::HashMap, mem}; +use std::mem; use indexmap::IndexMap; use midenc_hir::{ cranelift_entity::PrimaryMap, diagnostics::{IntoDiagnostic, Severity}, + FxBuildHasher, FxHashMap, }; use midenc_session::Session; -use rustc_hash::FxHashMap; use wasmparser::{ types::{ AliasableResourceId, ComponentEntityType, ComponentFuncTypeId, ComponentInstanceTypeId, @@ -30,7 +30,6 @@ use crate::{ TableIndex, WasmType, }, }, - translation_utils::BuildFxHasher, unsupported_diag, WasmTranslationConfig, }; @@ -707,7 +706,7 @@ impl<'a, 'data> ComponentParser<'a, 'data> { raw_args: &[wasmparser::ComponentInstantiationArg<'data>], ty: ComponentInstanceTypeId, ) -> WasmResult> { - let mut args = HashMap::with_capacity_and_hasher(raw_args.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(raw_args.len(), FxBuildHasher); for arg in raw_args { let idx = self.kind_to_item(arg.kind, arg.index)?; args.insert(arg.name, idx); @@ -722,7 +721,7 @@ impl<'a, 'data> ComponentParser<'a, 'data> { &mut self, exports: &[wasmparser::ComponentExport<'data>], ) -> WasmResult> { - let mut map = HashMap::with_capacity_and_hasher(exports.len(), BuildFxHasher::default()); + let mut map = FxHashMap::with_capacity_and_hasher(exports.len(), FxBuildHasher); for export in exports { let idx = self.kind_to_item(export.kind, export.index)?; map.insert(export.name.0, idx); @@ -825,7 +824,7 @@ fn instantiate_module<'data>( module: ModuleIndex, raw_args: &[wasmparser::InstantiationArg<'data>], ) -> LocalInitializer<'data> { - let mut args = HashMap::with_capacity_and_hasher(raw_args.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(raw_args.len(), FxBuildHasher); for arg in raw_args { match arg.kind { wasmparser::InstantiationArgKind::Instance => { @@ -842,7 +841,7 @@ fn instantiate_module<'data>( fn instantiate_module_from_exports<'data>( exports: &[wasmparser::Export<'data>], ) -> LocalInitializer<'data> { - let mut map = HashMap::with_capacity_and_hasher(exports.len(), BuildFxHasher::default()); + let mut map = FxHashMap::with_capacity_and_hasher(exports.len(), FxBuildHasher); for export in exports { let idx = match export.kind { wasmparser::ExternalKind::Func => { diff --git a/frontend-wasm/src/component/translator.rs b/frontend-wasm/src/component/translator.rs index e28ad0580..08e72f1bd 100644 --- a/frontend-wasm/src/component/translator.rs +++ b/frontend-wasm/src/component/translator.rs @@ -1,11 +1,10 @@ use midenc_hir::{ cranelift_entity::PrimaryMap, diagnostics::Severity, CanonAbiImport, ComponentBuilder, - ComponentExport, FunctionIdent, FunctionType, Ident, InterfaceFunctionIdent, InterfaceIdent, - Symbol, + ComponentExport, FunctionIdent, FunctionType, FxHashMap, Ident, InterfaceFunctionIdent, + InterfaceIdent, Symbol, }; use midenc_hir_type::Abi; use midenc_session::Session; -use rustc_hash::FxHashMap; use super::{ interface_type_to_ir, 
CanonicalOptions, ComponentTypes, CoreDef, CoreExport, Export, diff --git a/frontend-wasm/src/component/types/mod.rs b/frontend-wasm/src/component/types/mod.rs index 008608e5a..945534bf9 100644 --- a/frontend-wasm/src/component/types/mod.rs +++ b/frontend-wasm/src/component/types/mod.rs @@ -9,8 +9,10 @@ pub mod resources; use core::{hash::Hash, ops::Index}; use anyhow::{bail, Result}; -use midenc_hir::cranelift_entity::{EntityRef, PrimaryMap}; -use rustc_hash::FxHashMap; +use midenc_hir::{ + cranelift_entity::{EntityRef, PrimaryMap}, + FxHashMap, +}; use wasmparser::{collections::IndexSet, names::KebabString, types}; use self::resources::ResourcesBuilder; diff --git a/frontend-wasm/src/component/types/resources.rs b/frontend-wasm/src/component/types/resources.rs index 4d85f15e4..cd3faac21 100644 --- a/frontend-wasm/src/component/types/resources.rs +++ b/frontend-wasm/src/component/types/resources.rs @@ -66,7 +66,7 @@ // Based on wasmtime v16.0 Wasm component translation -use rustc_hash::FxHashMap; +use midenc_hir::FxHashMap; use wasmparser::types; use crate::component::{ diff --git a/frontend-wasm/src/miden_abi/mod.rs b/frontend-wasm/src/miden_abi/mod.rs index 3eafa8f0b..30d591587 100644 --- a/frontend-wasm/src/miden_abi/mod.rs +++ b/frontend-wasm/src/miden_abi/mod.rs @@ -3,8 +3,7 @@ pub(crate) mod transform; pub(crate) mod tx_kernel; use miden_core::crypto::hash::RpoDigest; -use midenc_hir::{FunctionType, Symbol}; -use rustc_hash::FxHashMap; +use midenc_hir::{FunctionType, FxHashMap, Symbol}; pub(crate) type FunctionTypeMap = FxHashMap<&'static str, FunctionType>; pub(crate) type ModuleFunctionTypeMap = FxHashMap<&'static str, FunctionTypeMap>; diff --git a/frontend-wasm/src/module/mod.rs b/frontend-wasm/src/module/mod.rs index 58952595c..912e48544 100644 --- a/frontend-wasm/src/module/mod.rs +++ b/frontend-wasm/src/module/mod.rs @@ -9,9 +9,8 @@ use indexmap::IndexMap; use midenc_hir::{ cranelift_entity::{packed_option::ReservedValue, EntityRef, PrimaryMap}, diagnostics::{DiagnosticsHandler, Severity}, - Ident, Symbol, + FxHashMap, Ident, Symbol, }; -use rustc_hash::FxHashMap; use self::types::*; use crate::{component::SignatureIndex, error::WasmResult, unsupported_diag}; diff --git a/frontend-wasm/src/module/module_translation_state.rs b/frontend-wasm/src/module/module_translation_state.rs index ac26f6854..66d8834db 100644 --- a/frontend-wasm/src/module/module_translation_state.rs +++ b/frontend-wasm/src/module/module_translation_state.rs @@ -1,9 +1,8 @@ use miden_core::crypto::hash::RpoDigest; use midenc_hir::{ diagnostics::{DiagnosticsHandler, Severity}, - AbiParam, CallConv, DataFlowGraph, FunctionIdent, Ident, Linkage, Signature, + AbiParam, CallConv, DataFlowGraph, FunctionIdent, FxHashMap, Ident, Linkage, Signature, }; -use rustc_hash::FxHashMap; use super::{instance::ModuleArgument, ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes}; use crate::{ diff --git a/frontend-wasm/src/translation_utils.rs b/frontend-wasm/src/translation_utils.rs index 067ee278d..284048aa4 100644 --- a/frontend-wasm/src/translation_utils.rs +++ b/frontend-wasm/src/translation_utils.rs @@ -5,14 +5,11 @@ use midenc_hir::{ AbiParam, CallConv, Felt, FieldElement, InstBuilder, Linkage, Signature, Value, }; use midenc_hir_type::{FunctionType, Type}; -use rustc_hash::FxHasher; use crate::{ error::WasmResult, module::function_builder_ext::FunctionBuilderExt, unsupported_diag, }; -pub type BuildFxHasher = std::hash::BuildHasherDefault; - /// Represents the possible sizes in bytes of the discriminant of a 
variant type in the component /// model #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] diff --git a/hir-analysis/Cargo.toml b/hir-analysis/Cargo.toml index 41aa1444c..4f303c598 100644 --- a/hir-analysis/Cargo.toml +++ b/hir-analysis/Cargo.toml @@ -17,9 +17,9 @@ cranelift-entity.workspace = true cranelift-bforest.workspace = true inventory.workspace = true intrusive-collections.workspace = true +hashbrown.workspace = true midenc-hir.workspace = true midenc-session.workspace = true -rustc-hash.workspace = true smallvec.workspace = true thiserror.workspace = true diff --git a/hir-analysis/src/data.rs b/hir-analysis/src/data.rs index ffa0ea8d2..60262c892 100644 --- a/hir-analysis/src/data.rs +++ b/hir-analysis/src/data.rs @@ -1,9 +1,9 @@ use midenc_hir::{ pass::{Analysis, AnalysisManager, AnalysisResult}, - Function, FunctionIdent, GlobalValue, GlobalValueData, GlobalVariableTable, Module, Program, + Function, FunctionIdent, FxHashMap, GlobalValue, GlobalValueData, GlobalVariableTable, Module, + Program, }; use midenc_session::Session; -use rustc_hash::FxHashMap; /// This analysis calculates the addresses/offsets of all global variables in a [Program] or /// [Module] diff --git a/hir-analysis/src/liveness.rs b/hir-analysis/src/liveness.rs index fcd69ed17..b6de57ed9 100644 --- a/hir-analysis/src/liveness.rs +++ b/hir-analysis/src/liveness.rs @@ -9,7 +9,6 @@ use midenc_hir::{ Block as BlockId, Inst as InstId, Value as ValueId, *, }; use midenc_session::Session; -use rustc_hash::FxHashMap; use super::{ControlFlowGraph, DominatorTree, LoopAnalysis}; @@ -417,7 +416,7 @@ fn compute_liveness( domtree: &DominatorTree, loops: &LoopAnalysis, ) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; let mut worklist = VecDeque::from_iter(domtree.cfg_postorder().iter().copied()); diff --git a/hir-analysis/src/validation/block.rs b/hir-analysis/src/validation/block.rs index 1f0234abe..4772a6797 100644 --- a/hir-analysis/src/validation/block.rs +++ b/hir-analysis/src/validation/block.rs @@ -2,7 +2,6 @@ use midenc_hir::{ diagnostics::{DiagnosticsHandler, Report, Severity, Spanned}, *, }; -use rustc_hash::FxHashSet; use smallvec::SmallVec; use super::Rule; diff --git a/hir-symbol/Cargo.toml b/hir-symbol/Cargo.toml index d4d4f5a1e..7bd8ace77 100644 --- a/hir-symbol/Cargo.toml +++ b/hir-symbol/Cargo.toml @@ -26,5 +26,6 @@ serde = { workspace = true, optional = true } [build-dependencies] Inflector.workspace = true +hashbrown.workspace = true rustc-hash.workspace = true toml.workspace = true diff --git a/hir-symbol/build.rs b/hir-symbol/build.rs index 96d42798a..8144efcc3 100644 --- a/hir-symbol/build.rs +++ b/hir-symbol/build.rs @@ -1,3 +1,4 @@ +extern crate hashbrown; extern crate inflector; extern crate rustc_hash; extern crate toml; @@ -12,9 +13,10 @@ use std::{ }; use inflector::Inflector; -use rustc_hash::FxHashSet; use toml::{value::Table, Value}; +type FxHashSet = hashbrown::HashSet; + #[derive(Debug, Default, Clone)] struct Symbol { key: String, diff --git a/hir-transform/Cargo.toml b/hir-transform/Cargo.toml index 88b3f959d..f9b8a8b5e 100644 --- a/hir-transform/Cargo.toml +++ b/hir-transform/Cargo.toml @@ -18,7 +18,6 @@ log.workspace = true midenc-hir.workspace = true midenc-hir-analysis.workspace = true midenc-session.workspace = true -rustc-hash.workspace = true smallvec.workspace = true [dev-dependencies] diff --git a/hir-transform/src/adt/scoped_map.rs b/hir-transform/src/adt/scoped_map.rs index fe6562b11..f1b9dca7d 100644 --- a/hir-transform/src/adt/scoped_map.rs +++ 
b/hir-transform/src/adt/scoped_map.rs @@ -1,6 +1,6 @@ use std::{borrow::Borrow, fmt, hash::Hash, rc::Rc}; -use rustc_hash::FxHashMap; +use midenc_hir::FxHashMap; #[derive(Clone)] pub struct ScopedMap diff --git a/hir-transform/src/inline_blocks.rs b/hir-transform/src/inline_blocks.rs index aa89c9522..b18467f7a 100644 --- a/hir-transform/src/inline_blocks.rs +++ b/hir-transform/src/inline_blocks.rs @@ -7,7 +7,6 @@ use midenc_hir::{ }; use midenc_hir_analysis::ControlFlowGraph; use midenc_session::{diagnostics::IntoDiagnostic, Session}; -use rustc_hash::FxHashSet; use smallvec::SmallVec; use crate::adt::ScopedMap; diff --git a/hir-transform/src/spill.rs b/hir-transform/src/spill.rs index b9e69ae10..effbcd8a0 100644 --- a/hir-transform/src/spill.rs +++ b/hir-transform/src/spill.rs @@ -11,7 +11,6 @@ use midenc_hir_analysis::{ spill::Placement, ControlFlowGraph, DominanceFrontier, DominatorTree, SpillAnalysis, Use, User, }; use midenc_session::{diagnostics::IntoDiagnostic, Emit, Session}; -use rustc_hash::FxHashSet; /// This pass places spills of SSA values to temporaries to cap the depth of the operand stack. /// diff --git a/hir-transform/src/split_critical_edges.rs b/hir-transform/src/split_critical_edges.rs index 1cef2d1e9..80113c99f 100644 --- a/hir-transform/src/split_critical_edges.rs +++ b/hir-transform/src/split_critical_edges.rs @@ -7,7 +7,6 @@ use midenc_hir::{ }; use midenc_hir_analysis::ControlFlowGraph; use midenc_session::{diagnostics::IntoDiagnostic, Session}; -use rustc_hash::FxHashSet; use smallvec::SmallVec; /// This pass breaks any critical edges in the CFG of a function. diff --git a/hir/Cargo.toml b/hir/Cargo.toml index d486dc05e..cf82a9638 100644 --- a/hir/Cargo.toml +++ b/hir/Cargo.toml @@ -32,6 +32,7 @@ intrusive-collections.workspace = true inventory.workspace = true lalrpop-util = "0.20" log.workspace = true +hashbrown.workspace = true miden-core.workspace = true miden-assembly.workspace = true midenc-hir-symbol.workspace = true diff --git a/hir/src/asm/import.rs b/hir/src/asm/import.rs index 2a0a2f844..7fb2b8499 100644 --- a/hir/src/asm/import.rs +++ b/hir/src/asm/import.rs @@ -5,11 +5,10 @@ use core::{ }; use anyhow::bail; -use rustc_hash::{FxHashMap, FxHashSet}; use crate::{ diagnostics::{SourceSpan, Spanned}, - FunctionIdent, Ident, Symbol, + FunctionIdent, FxHashMap, FxHashSet, Ident, Symbol, }; #[derive(Default, Debug, Clone)] @@ -32,7 +31,7 @@ impl ModuleImportInfo { /// /// NOTE: It is assumed that the caller is adding imports using fully-qualified names. 
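The `FxHashMap`/`FxHashSet` types imported from `midenc_hir` throughout these hunks are defined later in this patch (in `hir/src/lib.rs` and `hir2/src/lib.rs`), though their generic parameters appear to have been lost in this rendering. Presumably they are hashbrown aliases over `FxBuildHasher` from `rustc-hash` 2.0, roughly as below (this sketch assumes the `hashbrown` and `rustc-hash` crates as dependencies):

```rust
use rustc_hash::FxBuildHasher;

// Reconstructed shape of the aliases; the angle-bracketed parts are missing
// from the patch text as shown here.
pub type FxHashMap<K, V> = hashbrown::HashMap<K, V, FxBuildHasher>;
pub type FxHashSet<V> = hashbrown::HashSet<V, FxBuildHasher>;

// `FxBuildHasher` is a zero-sized `BuildHasher`, so pre-sized maps are built
// by value, as in `FxHashMap::with_capacity_and_hasher(raw_args.len(), FxBuildHasher)`.
fn presized_exports(names: &[&'static str]) -> FxHashMap<&'static str, usize> {
    let mut map = FxHashMap::with_capacity_and_hasher(names.len(), FxBuildHasher);
    for (idx, name) in names.iter().enumerate() {
        map.insert(*name, idx);
    }
    map
}
```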
pub fn add(&mut self, id: FunctionIdent) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; let module_id = id.module; match self.modules.entry(module_id) { diff --git a/hir/src/dataflow.rs b/hir/src/dataflow.rs index 1f5d5a4e2..f0b575cb4 100644 --- a/hir/src/dataflow.rs +++ b/hir/src/dataflow.rs @@ -1,7 +1,6 @@ use core::ops::{Deref, DerefMut, Index, IndexMut}; use cranelift_entity::{PrimaryMap, SecondaryMap}; -use rustc_hash::FxHashMap; use smallvec::SmallVec; use crate::{ @@ -114,7 +113,7 @@ impl DataFlowGraph { name: Ident, signature: Signature, ) -> Result { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; let id = FunctionIdent { module, diff --git a/hir/src/lib.rs b/hir/src/lib.rs index 20c1d622d..c9781baf0 100644 --- a/hir/src/lib.rs +++ b/hir/src/lib.rs @@ -27,6 +27,10 @@ pub use midenc_hir_type::{ }; pub use midenc_session::diagnostics::{self, SourceSpan}; +pub type FxHashMap = hashbrown::HashMap; +pub type FxHashSet = hashbrown::HashSet; +pub use rustc_hash::{FxBuildHasher, FxHasher}; + /// Represents a field element in Miden pub type Felt = miden_core::Felt; diff --git a/hir/src/module.rs b/hir/src/module.rs index ff90c3858..d695a1654 100644 --- a/hir/src/module.rs +++ b/hir/src/module.rs @@ -5,7 +5,6 @@ use intrusive_collections::{ linked_list::{Cursor, CursorMut}, LinkedList, LinkedListLink, RBTreeLink, }; -use rustc_hash::FxHashSet; use self::formatter::PrettyPrint; use crate::{ diff --git a/hir/src/parser/ast/convert.rs b/hir/src/parser/ast/convert.rs index 30a75c24f..fc35afea0 100644 --- a/hir/src/parser/ast/convert.rs +++ b/hir/src/parser/ast/convert.rs @@ -650,7 +650,7 @@ fn try_insert_inst( module: function.id.module, }; if let Some(sig) = functions_by_id.get(&local) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; if let Entry::Vacant(entry) = function.dfg.imports.entry(callee) { entry.insert(ExternalFunction { id: callee, @@ -672,7 +672,7 @@ fn try_insert_inst( callee } Right(external) => { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; used_imports.insert(external); if let Entry::Vacant(entry) = function.dfg.imports.entry(external) { if let Some(ef) = imports_by_id.get(&external) { diff --git a/hir/src/pass/analysis.rs b/hir/src/pass/analysis.rs index ab5a062b2..8384f3b43 100644 --- a/hir/src/pass/analysis.rs +++ b/hir/src/pass/analysis.rs @@ -5,11 +5,9 @@ use core::{ }; use midenc_session::Session; -use rustc_hash::{FxHashMap, FxHashSet, FxHasher}; +use rustc_hash::FxBuildHasher; -use crate::diagnostics::Report; - -type BuildFxHasher = std::hash::BuildHasherDefault; +use crate::{diagnostics::Report, FxHashMap, FxHashSet}; /// A convenient type alias for `Result` pub type AnalysisResult = Result; @@ -115,7 +113,7 @@ impl CachedAnalysisKey { { use core::hash::Hasher; - let mut hasher = FxHasher::default(); + let mut hasher = rustc_hash::FxHasher::default(); let entity_ty = TypeId::of::(); entity_ty.hash(&mut hasher); key.hash(&mut hasher); @@ -204,11 +202,9 @@ impl PreservedAnalyses { } fn with_capacity(current_entity_key: u64, cap: usize) -> Self { - use std::collections::HashMap; - Self { current_entity_key, - preserved: HashMap::with_capacity_and_hasher(cap, BuildFxHasher::default()), + preserved: FxHashMap::with_capacity_and_hasher(cap, FxBuildHasher), } } } @@ -349,8 +345,6 @@ impl AnalysisManager { where T: AnalysisKey, { - use std::collections::HashMap; - let current_entity_key = CachedAnalysisKey::entity_key::(key); if 
self.preserve_none.remove(¤t_entity_key) { @@ -381,8 +375,7 @@ impl AnalysisManager { preserve.insert(self.preserve.take(&key).unwrap()); } - let mut cached = - HashMap::with_capacity_and_hasher(to_invalidate.len(), BuildFxHasher::default()); + let mut cached = FxHashMap::with_capacity_and_hasher(to_invalidate.len(), FxBuildHasher); for key in to_invalidate.into_iter() { let (key, value) = self.cached.remove_entry(&key).unwrap(); cached.insert(key, value); diff --git a/hir2/src/folder.rs b/hir2/src/folder.rs index acbf39fee..b95367e15 100644 --- a/hir2/src/folder.rs +++ b/hir2/src/folder.rs @@ -1,14 +1,13 @@ use alloc::{collections::BTreeMap, rc::Rc}; -use rustc_hash::FxHashMap; use smallvec::{smallvec, SmallVec}; use crate::{ matchers::Matcher, traits::{ConstantLike, Foldable, IsolatedFromAbove}, - AttributeValue, BlockRef, Builder, Context, Dialect, FoldResult, OpFoldResult, OperationRef, - RegionRef, Rewriter, RewriterImpl, RewriterListener, SourceSpan, Spanned, Type, Value, - ValueRef, + AttributeValue, BlockRef, Builder, Context, Dialect, FoldResult, FxHashMap, OpFoldResult, + OperationRef, RegionRef, Rewriter, RewriterImpl, RewriterListener, SourceSpan, Spanned, Type, + Value, ValueRef, }; /// Represents a constant value uniqued by dialect, value, and type. diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 6b77b8970..5266a6a7f 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -32,6 +32,10 @@ pub use compact_str::{ CompactString as SmallStr, CompactStringExt as SmallStrExt, ToCompactString as ToSmallStr, }; +pub type FxHashMap = hashbrown::HashMap; +pub type FxHashSet = hashbrown::HashSet; +pub use rustc_hash::{FxBuildHasher, FxHasher}; + mod any; mod attributes; pub mod demangle; From 9c86f86089e2149fb1bf97d3537e77a1b484b6d2 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Fri, 25 Oct 2024 01:46:16 -0400 Subject: [PATCH 25/31] feat: implement pass infrastructure --- hir2/src/ir/entity.rs | 13 + hir2/src/ir/operation.rs | 6 +- hir2/src/ir/operation/name.rs | 10 +- hir2/src/lib.rs | 4 + hir2/src/pass.rs | 19 + hir2/src/pass/analysis.rs | 694 +++++++++++++++++++++++ hir2/src/pass/instrumentation.rs | 129 +++++ hir2/src/pass/manager.rs | 912 +++++++++++++++++++++++++++++++ hir2/src/pass/pass.rs | 371 +++++++++++++ hir2/src/pass/registry.rs | 397 ++++++++++++++ hir2/src/pass/specialization.rs | 142 +++++ hir2/src/pass/statistics.rs | 462 ++++++++++++++++ midenc-session/src/duration.rs | 36 ++ 13 files changed, 3187 insertions(+), 8 deletions(-) create mode 100644 hir2/src/pass.rs create mode 100644 hir2/src/pass/analysis.rs create mode 100644 hir2/src/pass/instrumentation.rs create mode 100644 hir2/src/pass/manager.rs create mode 100644 hir2/src/pass/pass.rs create mode 100644 hir2/src/pass/registry.rs create mode 100644 hir2/src/pass/specialization.rs create mode 100644 hir2/src/pass/statistics.rs diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs index 33155316e..c46492a5d 100644 --- a/hir2/src/ir/entity.rs +++ b/hir2/src/ir/entity.rs @@ -701,6 +701,19 @@ pub struct EntityMut<'b, T: ?Sized> { _marker: core::marker::PhantomData<&'b mut T>, } impl<'b, T: ?Sized> EntityMut<'b, T> { + #[inline] + pub fn map(mut orig: Self, f: F) -> EntityMut<'b, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let value = NonNull::from(f(&mut *orig)); + EntityMut { + value, + borrow: orig.borrow, + _marker: core::marker::PhantomData, + } + } + /// Splits an `EntityMut` into multiple `EntityMut`s for different components of the borrowed /// data. 
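`EntityMut::map`, added above, follows the same projection idiom as `std::cell::RefMut::map`: keep the original borrow alive while narrowing the guard down to a single component of the borrowed data. The std equivalent, for comparison:

```rust
use std::cell::{RefCell, RefMut};

struct Operation {
    name: String,
    order: u32,
}

fn main() {
    let op = RefCell::new(Operation { name: "hir.add".into(), order: 3 });

    // Borrow the whole entity mutably, then project the guard down to one
    // field. The borrow of `op` stays active until `name` is dropped.
    let mut name: RefMut<'_, String> = RefMut::map(op.borrow_mut(), |op| &mut op.name);
    name.push_str(".checked");
    drop(name);

    assert_eq!(op.borrow().name, "hir.add.checked");
    assert_eq!(op.borrow().order, 3);
}
```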
/// diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index bd189a1ea..bae90cbac 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -284,7 +284,7 @@ impl Operation { /// Returns true if the concrete type of this operation is `T` #[inline] - pub fn is(&self) -> bool { + pub fn is(&self) -> bool { self.name.is::() } @@ -298,12 +298,12 @@ impl Operation { } /// Attempt to downcast to the concrete [Op] type of this operation - pub fn downcast_ref(&self) -> Option<&T> { + pub fn downcast_ref(&self) -> Option<&T> { self.name.downcast_ref::(self.container()) } /// Attempt to downcast to the concrete [Op] type of this operation - pub fn downcast_mut(&mut self) -> Option<&mut T> { + pub fn downcast_mut(&mut self) -> Option<&mut T> { self.name.downcast_mut::(self.container().cast_mut()) } diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs index ab2802b88..d0317e12c 100644 --- a/hir2/src/ir/operation/name.rs +++ b/hir2/src/ir/operation/name.rs @@ -59,7 +59,7 @@ impl OperationName { } /// Returns true if `T` is the concrete type that implements this operation - pub fn is(&self) -> bool { + pub fn is(&self) -> bool { TypeId::of::() == self.0.type_id } @@ -79,7 +79,7 @@ impl OperationName { } #[inline] - pub(super) fn downcast_ref(&self, ptr: *const ()) -> Option<&T> { + pub(super) fn downcast_ref(&self, ptr: *const ()) -> Option<&T> { if self.is::() { Some(unsafe { self.downcast_ref_unchecked(ptr) }) } else { @@ -88,12 +88,12 @@ impl OperationName { } #[inline(always)] - unsafe fn downcast_ref_unchecked(&self, ptr: *const ()) -> &T { + unsafe fn downcast_ref_unchecked(&self, ptr: *const ()) -> &T { &*core::ptr::from_raw_parts(ptr.cast::(), ()) } #[inline] - pub(super) fn downcast_mut(&mut self, ptr: *mut ()) -> Option<&mut T> { + pub(super) fn downcast_mut(&mut self, ptr: *mut ()) -> Option<&mut T> { if self.is::() { Some(unsafe { self.downcast_mut_unchecked(ptr) }) } else { @@ -102,7 +102,7 @@ impl OperationName { } #[inline(always)] - unsafe fn downcast_mut_unchecked(&mut self, ptr: *mut ()) -> &mut T { + unsafe fn downcast_mut_unchecked(&mut self, ptr: *mut ()) -> &mut T { &mut *core::ptr::from_raw_parts_mut(ptr.cast::(), ()) } diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 5266a6a7f..68b294097 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -18,6 +18,9 @@ #![feature(unboxed_closures)] #![feature(const_type_id)] #![feature(exact_size_is_empty)] +#![feature(generic_const_exprs)] +#![feature(new_uninit)] +#![feature(clone_to_uninit)] #![allow(incomplete_features)] #![allow(internal_features)] @@ -46,6 +49,7 @@ pub mod formatter; mod hash; mod ir; pub mod matchers; +pub mod pass; mod patterns; pub use self::{ diff --git a/hir2/src/pass.rs b/hir2/src/pass.rs new file mode 100644 index 000000000..c46d3306d --- /dev/null +++ b/hir2/src/pass.rs @@ -0,0 +1,19 @@ +mod analysis; +mod instrumentation; +mod manager; +#[allow(clippy::module_inception)] +mod pass; +pub mod registry; +mod specialization; +pub mod statistics; + +use self::pass::PassExecutionState; +pub use self::{ + analysis::{Analysis, AnalysisManager, OperationAnalysis, PreservedAnalyses}, + instrumentation::{PassInstrumentation, PassInstrumentor, PipelineParentInfo}, + manager::{Nesting, OpPassManager, PassDisplayMode, PassManager}, + pass::{OperationPass, Pass}, + registry::{PassInfo, PassPipelineInfo}, + specialization::PassTarget, + statistics::{PassStatistic, Statistic, StatisticValue}, +}; diff --git a/hir2/src/pass/analysis.rs b/hir2/src/pass/analysis.rs new file mode 
100644 index 000000000..aa0bb6a50 --- /dev/null +++ b/hir2/src/pass/analysis.rs @@ -0,0 +1,694 @@ +use alloc::rc::Rc; +use core::{ + any::{Any, TypeId}, + cell::RefCell, +}; + +use smallvec::SmallVec; + +type FxHashMap = hashbrown::HashMap; + +use super::{PassInstrumentor, PassTarget}; +use crate::{Op, Operation, OperationRef}; + +/// The [Analysis] trait is used to define an analysis over some operation. +/// +/// Analyses must be default-constructible, and `Sized + 'static` to support downcasting. +/// +/// An analysis, when requested, is first constructed via its `Default` implementation, and then +/// [Analysis::analyze] is called on the target type in order to compute the analysis results. +/// The analysis type also acts as storage for the analysis results. +/// +/// When the IR is changed, analyses are invalidated by default, unless they are specifically +/// preserved via the [PreservedAnalyses] set. When an analysis is being asked if it should be +/// invalidated, via [Analysis::invalidate], it has the opportunity to identify if it actually +/// needs to be invalidated based on what analyses were preserved. If dependent analyses of this +/// analysis haven't been invalidated, then this analysis may be able preserve itself as well, +/// and avoid redundant recomputation. +pub trait Analysis: Default + Any { + /// The specific type on which this analysis is performed. + /// + /// The analysis will only be run when an operation is of this type. + type Target: ?Sized + PassTarget; + + /// The [TypeId] associated with the concrete underlying [Analysis] implementation + /// + /// This is automatically implemented for you, but in some cases, such as wrapping an + /// analysis in another type, you may want to implement this so that queries against the + /// type return the expected [TypeId] + #[inline] + fn analysis_id(&self) -> TypeId { + TypeId::of::() + } + + /// Get a `dyn Any` reference to the underlying [Analysis] implementation + /// + /// This is automatically implemented for you, but in some cases, such as wrapping an + /// analysis in another type, you may want to implement this so that queries against the + /// type return the expected [TypeId] + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + /// Same as [Analysis::as_any], but used specifically for getting a reference-counted handle, + /// rather than a raw reference. + #[inline(always)] + fn as_any_rc(self: Rc) -> Rc { + self as Rc + } + + /// Returns the display name for this analysis + /// + /// By default this simply returns the name of the concrete implementation type. + fn name(&self) -> &'static str { + core::any::type_name::() + } + + /// Analyze `op` using the provided [AnalysisManager]. + fn analyze(&mut self, op: &Self::Target, analysis_manager: AnalysisManager); + + /// Query this analysis for invalidation. + /// + /// Given a preserved analysis set, returns true if it should truly be invalidated. This allows + /// for more fine-tuned invalidation in cases where an analysis wasn't explicitly marked + /// preserved, but may be preserved(or invalidated) based upon other properties such as analyses + /// sets. + fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool; +} + +/// A type-erased [Analysis]. +/// +/// This is automatically derived for all [Analysis] implementations, and is the means by which +/// one can abstract over sets of analyses using dynamic dispatch. 
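Both the `Operation::is`/`downcast_ref` tweaks earlier in this patch and the `as_any` escape hatches on the analysis traits above rest on the same idea: compare a recorded `TypeId` against `TypeId::of::<T>()` before casting. The pattern, expressed with only `std::any` (the `Op`/`AddOp` names are illustrative, not the real HIR types):

```rust
use std::any::{Any, TypeId};

trait Op: Any {
    fn name(&self) -> &'static str;
    /// Escape hatch for downcasting, as the HIR traits do with `as_any`.
    fn as_any(&self) -> &dyn Any;
}

struct AddOp;
impl Op for AddOp {
    fn name(&self) -> &'static str {
        "hir.add"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn is<T: Op>(op: &dyn Op) -> bool {
    // Same shape as `OperationName::is`: compare against the recorded TypeId.
    op.as_any().type_id() == TypeId::of::<T>()
}

fn downcast_ref<T: Op>(op: &dyn Op) -> Option<&T> {
    // `Any::downcast_ref` performs the checked cast for us.
    op.as_any().downcast_ref::<T>()
}

fn main() {
    let op = AddOp;
    let dyn_op: &dyn Op = &op;
    assert!(is::<AddOp>(dyn_op));
    assert_eq!(downcast_ref::<AddOp>(dyn_op).map(|o| o.name()), Some("hir.add"));
}
```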
+/// +/// This essentially just delegates to the underlying [Analysis] implementation, but it also handles +/// converting a raw [OperationRef] to the appropriate target type expected by the underlying +/// [Analysis]. +pub trait OperationAnalysis { + /// The unique type id of this analysis + fn analysis_id(&self) -> TypeId; + + /// Used for dynamic casting to the underlying [Analysis] type + fn as_any(&self) -> &dyn Any; + + /// Used for dynamic casting to the underlying [Analysis] type + fn as_any_rc(self: Rc) -> Rc; + + /// The name of this analysis + fn name(&self) -> &'static str; + + /// Runs this analysis over `op`. + /// + /// NOTE: This is only ever called once per instantiation of the analysis, but in theory can + /// support multiple calls to re-analyze `op`. Each call should reset any internal state to + /// ensure that if an analysis is reused in this way, that each analysis gets a clean slate. + fn analyze(&mut self, op: &OperationRef, am: AnalysisManager); + + /// Query this analysis for invalidation. + /// + /// Given a preserved analysis set, returns true if it should truly be invalidated. This allows + /// for more fine-tuned invalidation in cases where an analysis wasn't explicitly marked + /// preserved, but may be preserved(or invalidated) based upon other properties such as analyses + /// sets. + /// + /// Invalidated analyses must be removed from `preserved_analyses`. + fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool; +} + +impl dyn OperationAnalysis { + /// Cast an reference-counted handle to this analysis to its concrete implementation type. + /// + /// Returns `None` if the underlying analysis is not of type `T` + #[inline] + pub fn downcast(self: Rc) -> Option> { + self.as_any_rc().downcast::().ok() + } +} + +impl OperationAnalysis for A +where + A: Analysis, +{ + #[inline] + fn analysis_id(&self) -> TypeId { + ::analysis_id(self) + } + + #[inline] + fn as_any(&self) -> &dyn Any { + ::as_any(self) + } + + #[inline] + fn as_any_rc(self: Rc) -> Rc { + ::as_any_rc(self) + } + + #[inline] + fn name(&self) -> &'static str { + ::name(self) + } + + #[inline] + fn analyze(&mut self, op: &OperationRef, am: AnalysisManager) { + let op = <::Target as PassTarget>::into_target(op); + ::analyze(self, &op, am) + } + + #[inline] + fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool { + ::invalidate(self, preserved_analyses) + } +} + +/// Represents a set of analyses that are known to be preserved after a rewrite has been applied. +#[derive(Default)] +pub struct PreservedAnalyses { + /// The set of preserved analysis type ids + preserved: SmallVec<[TypeId; 8]>, +} +impl PreservedAnalyses { + /// Mark all analyses as preserved. + /// + /// This is generally only useful when the IR is known not to have changed. + pub fn preserve_all(&mut self) { + self.insert(AllAnalyses::TYPE_ID); + } + + /// Mark the specified [Analysis] type as preserved. + pub fn preserve(&mut self) { + self.insert(TypeId::of::()); + } + + /// Mark a type as preserved using its raw [TypeId]. + /// + /// Typically it is best to use [Self::preserve] instead, but this can be useful in cases + /// where you can't express the type in Rust directly. + pub fn preserve_raw(&mut self, id: TypeId) { + self.insert(id); + } + + /// Returns true if the specified type is preserved. + pub fn is_preserved(&self) -> bool { + self.preserved.contains(&TypeId::of::()) + } + + /// Returns true if the specified [TypeId] is marked preserved. 
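The `downcast` helper above goes through `as_any_rc` so that the standard `Rc::downcast` can recover the concrete analysis without cloning its contents. In isolation, with a made-up `DominatorTree` analysis type:

```rust
use std::{any::Any, rc::Rc};

struct DominatorTree {
    roots: Vec<u32>,
}

fn main() {
    // Store the analysis behind a type-erased, reference-counted handle...
    let erased: Rc<dyn Any> = Rc::new(DominatorTree { roots: vec![0] });

    // ...and recover the concrete type later. `Rc::downcast` returns
    // `Err(original_rc)` if the stored type does not match.
    let concrete: Rc<DominatorTree> = erased.downcast::<DominatorTree>().ok().unwrap();
    assert_eq!(concrete.roots, vec![0]);
}
```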
+ pub fn is_preserved_raw(&self, ty: &TypeId) -> bool { + self.preserved.contains(ty) + } + + /// Mark a previously preserved type as invalidated. + pub fn unpreserve(&mut self) { + self.remove(&TypeId::of::()) + } + + /// Mark a previously preserved [TypeId] as invalidated. + pub fn unpreserve_raw(&mut self, ty: &TypeId) { + self.remove(ty) + } + + /// Returns true if all analyses are preserved + pub fn is_all(&self) -> bool { + self.preserved.contains(&AllAnalyses::TYPE_ID) + } + + /// Returns true if no analyses are being preserved + pub fn is_none(&self) -> bool { + self.preserved.is_empty() + } + + fn insert(&mut self, id: TypeId) { + match self.preserved.binary_search_by_key(&id, |probe| *probe) { + Ok(index) => self.preserved.insert(index, id), + Err(index) => self.preserved.insert(index, id), + } + } + + fn remove(&mut self, id: &TypeId) { + if let Ok(index) = self.preserved.binary_search_by_key(&id, |probe| probe) { + self.preserved.remove(index); + } + } +} + +/// A marker type that is used to represent all possible [Analysis] types +pub struct AllAnalyses; +impl AllAnalyses { + const TYPE_ID: TypeId = TypeId::of::(); +} + +/// This type wraps all analyses stored in an [AnalysisMap], and handles some of the boilerplate +/// details around invalidation by intercepting calls to [Analysis::invalidate] and wrapping it +/// with extra logic. Notably, ensuring that invalidated analyses are removed from the +/// [PreservedAnalyses] set is handled by this wrapper. +/// +/// It is a transparent wrapper around `A`, and otherwise acts as a simple proxy to `A`'s +/// implementation of the [Analysis] trait. +#[repr(transparent)] +struct AnalysisWrapper { + analysis: A, +} +impl AnalysisWrapper { + fn new(op: &::Target, am: AnalysisManager) -> Self { + let mut analysis = A::default(); + analysis.analyze(op, am); + + Self { analysis } + } +} +impl Default for AnalysisWrapper { + fn default() -> Self { + Self { + analysis: Default::default(), + } + } +} +impl Analysis for AnalysisWrapper { + type Target = ::Target; + + #[inline] + fn analysis_id(&self) -> TypeId { + self.analysis.analysis_id() + } + + #[inline] + fn as_any(&self) -> &dyn Any { + self.analysis.as_any() + } + + #[inline] + fn as_any_rc(self: Rc) -> Rc { + // SAFETY: This transmute is safe because AnalysisWrapper is a transparent wrapper + // around A, so a pointer to the former is a pointer to the latter + let ptr = Rc::into_raw(self); + unsafe { Rc::::from_raw(ptr.cast()) as Rc } + } + + #[inline] + fn name(&self) -> &'static str { + self.analysis.name() + } + + #[inline] + fn analyze(&mut self, op: &Self::Target, am: AnalysisManager) { + self.analysis.analyze(op, am); + } + + fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool { + let invalidated = self.analysis.invalidate(preserved_analyses); + if invalidated { + preserved_analyses.unpreserve::(); + } + invalidated + } +} + +/// An [AnalysisManager] is the primary entrypoint for performing analysis on a specific operation +/// instance that it is constructed for. +/// +/// It is used to manage and cache analyses for the operation, as well as those of child operations, +/// via nested [AnalysisManager] instances. +/// +/// This type is a thin wrapper around a pointer, and is meant to be passed by value. It can be +/// cheaply cloned. 
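`PreservedAnalyses` keeps the preserved ids in a small sorted vector and relies on binary search for membership tests. A std-only sketch of that set representation (this version simply skips the insert when the id is already present):

```rust
use std::any::TypeId;

#[derive(Default)]
struct PreservedSet {
    // Kept sorted so that membership checks are a binary search.
    preserved: Vec<TypeId>,
}

impl PreservedSet {
    fn preserve<T: 'static>(&mut self) {
        let id = TypeId::of::<T>();
        if let Err(index) = self.preserved.binary_search(&id) {
            self.preserved.insert(index, id);
        }
    }

    fn unpreserve<T: 'static>(&mut self) {
        if let Ok(index) = self.preserved.binary_search(&TypeId::of::<T>()) {
            self.preserved.remove(index);
        }
    }

    fn is_preserved<T: 'static>(&self) -> bool {
        self.preserved.binary_search(&TypeId::of::<T>()).is_ok()
    }
}

struct Liveness;
struct DominatorTree;

fn main() {
    let mut set = PreservedSet::default();
    set.preserve::<Liveness>();
    assert!(set.is_preserved::<Liveness>());
    assert!(!set.is_preserved::<DominatorTree>());
    set.unpreserve::<Liveness>();
    assert!(!set.is_preserved::<Liveness>());
}
```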
+#[derive(Clone)] +#[repr(transparent)] +pub struct AnalysisManager { + analyses: Rc, +} +impl AnalysisManager { + /// Create a new top-level [AnalysisManager] for `op` + pub fn new(op: OperationRef, instrumentor: Option>) -> Self { + Self { + analyses: Rc::new(NestedAnalysisMap::new(op, instrumentor)), + } + } + + /// Query for a cached analysis on the given parent operation. The analysis may not exist and if + /// it does it may be out-of-date. + pub fn get_cached_parent_analysis(&self, parent: &OperationRef) -> Option> + where + A: Analysis, + { + let mut current_parent = self.analyses.parent(); + while let Some(parent_am) = current_parent.take() { + if &parent_am.get_operation() == parent { + return parent_am.analyses().get_cached::(); + } + current_parent = parent_am.parent(); + } + None + } + + /// Query for the given analysis for the current operation. + pub fn get_analysis(&self) -> Rc + where + A: Analysis, + { + self.analyses.analyses.borrow_mut().get(self.pass_instrumentor(), self.clone()) + } + + /// Query for the given analysis for the current operation of a specific derived operation type. + /// + /// NOTE: This will panic if the current operation is not of type `O`. + pub fn get_analysis_for(&self) -> Rc + where + A: Analysis, + O: 'static, + { + self.analyses + .analyses + .borrow_mut() + .get_analysis_for::(self.pass_instrumentor(), self.clone()) + } + + /// Query for a cached entry of the given analysis on the current operation. + pub fn get_cached_analysis(&self) -> Option> + where + A: Analysis, + { + self.analyses.analyses().get_cached::() + } + + /// Query for an analysis of a child operation, constructing it if necessary. + pub fn get_child_analysis(&self, op: &OperationRef) -> Rc + where + A: Analysis, + { + self.clone().nest(op).get_analysis::() + } + + /// Query for an analysis of a child operation of a specific derived operation type, + /// constructing it if necessary. + /// + /// NOTE: This will panic if `op` is not of type `O`. + pub fn get_child_analysis_for(&self, op: &O) -> Rc + where + A: Analysis, + O: Op, + { + self.clone() + .nest(&op.as_operation().as_operation_ref()) + .get_analysis_for::() + } + + /// Query for a cached analysis of a child operation, or return `None`. + pub fn get_cached_child_analysis(&self, child: &OperationRef) -> Option> + where + A: Analysis, + { + assert!(child.borrow().parent_op().unwrap() == self.analyses.get_operation()); + let child_analyses = self.analyses.child_analyses.borrow(); + let child_analyses = child_analyses.get(child)?; + let child_analyses = child_analyses.analyses.borrow(); + child_analyses.get_cached::() + } + + /// Get an analysis manager for the given operation, which must be a proper descendant of the + /// current operation represented by this analysis manager. 
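+    ///
+    /// For example (sketch; `child` is assumed to be an operation nested somewhere below the
+    /// current operation, and `MyAnalysis` is a hypothetical [Analysis] implementation):
+    ///
+    /// ```text,ignore
+    /// let child_am = am.nest(&child);
+    /// let analysis = child_am.get_analysis::<MyAnalysis>();
+    /// ```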
+ pub fn nest(&self, op: &OperationRef) -> AnalysisManager { + let current_op = self.analyses.get_operation(); + assert!(current_op.borrow().is_proper_ancestor_of(op), "expected valid descendant op"); + + // Check for the base case where the provided operation is immediately nested + if current_op == op.borrow().parent_op().expect("expected `op` to have a parent") { + return self.nest_immediate(op.clone()); + } + + // Otherwise, we need to collect all ancestors up to the current operation + let mut ancestors = SmallVec::<[OperationRef; 4]>::default(); + let mut next_op = op.clone(); + while next_op != current_op { + ancestors.push(next_op.clone()); + next_op = next_op.borrow().parent_op().unwrap(); + } + + let mut manager = self.clone(); + while let Some(op) = ancestors.pop() { + manager = manager.nest_immediate(op); + } + manager + } + + fn nest_immediate(&self, op: OperationRef) -> AnalysisManager { + use hashbrown::hash_map::Entry; + + assert!( + Some(self.analyses.get_operation()) == op.borrow().parent_op(), + "expected immediate child operation" + ); + let parent = self.analyses.clone(); + let mut child_analyses = self.analyses.child_analyses.borrow_mut(); + match child_analyses.entry(op.clone()) { + Entry::Vacant(entry) => { + let analyses = entry.insert(Rc::new(parent.nest(op))); + AnalysisManager { + analyses: Rc::clone(analyses), + } + } + Entry::Occupied(entry) => AnalysisManager { + analyses: Rc::clone(entry.get()), + }, + } + } + + /// Invalidate any non preserved analyses. + #[inline] + pub fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) { + Rc::clone(&self.analyses).invalidate(preserved_analyses) + } + + /// Clear any held analyses. + #[inline] + pub fn clear(&mut self) { + self.analyses.clear(); + } + + /// Clear any held analyses when the returned guard is dropped. + #[inline] + pub fn defer_clear(&self) -> ResetAnalysesOnDrop { + ResetAnalysesOnDrop { + analyses: self.analyses.clone(), + } + } + + /// Returns a [PassInstrumentor] for the current operation, if one was installed. + #[inline] + pub fn pass_instrumentor(&self) -> Option> { + self.analyses.pass_instrumentor() + } +} + +#[must_use] +#[doc(hidden)] +pub struct ResetAnalysesOnDrop { + analyses: Rc, +} +impl Drop for ResetAnalysesOnDrop { + fn drop(&mut self) { + self.analyses.clear() + } +} + +/// An analysis map that contains a map for the current operation, and a set of maps for any child +/// operations. +struct NestedAnalysisMap { + parent: Option>, + instrumentor: Option>, + analyses: RefCell, + child_analyses: RefCell>>, +} +impl NestedAnalysisMap { + /// Create a new top-level [NestedAnalysisMap] for `op`, with the given optional pass + /// instrumentor. + pub fn new(op: OperationRef, instrumentor: Option>) -> Self { + Self { + parent: None, + instrumentor, + analyses: RefCell::new(AnalysisMap::new(op)), + child_analyses: Default::default(), + } + } + + /// Create a new [NestedAnalysisMap] for `op` nested under `self`. + pub fn nest(self: Rc, op: OperationRef) -> Self { + let instrumentor = self.instrumentor.clone(); + Self { + parent: Some(self), + instrumentor, + analyses: RefCell::new(AnalysisMap::new(op)), + child_analyses: Default::default(), + } + } + + /// Get the parent [NestedAnalysisMap], or `None` if this is a top-level map. + pub fn parent(&self) -> Option> { + self.parent.clone() + } + + /// Return a [PassInstrumentor] for the current operation, if one was installed. 
+ pub fn pass_instrumentor(&self) -> Option> { + self.instrumentor.clone() + } + + /// Get the operation for this analysis map. + #[inline] + pub fn get_operation(&self) -> OperationRef { + self.analyses.borrow().get_operation() + } + + fn analyses(&self) -> core::cell::Ref<'_, AnalysisMap> { + self.analyses.borrow() + } + + /// Invalidate any non preserved analyses. + pub fn invalidate(self: Rc, preserved_analyses: &mut PreservedAnalyses) { + // If all analyses were preserved, then there is nothing to do + if preserved_analyses.is_all() { + return; + } + + // Invalidate the analyses for the current operation directly + self.analyses.borrow_mut().invalidate(preserved_analyses); + + // If no analyses were preserved, then just simply clear out the child analysis results + if preserved_analyses.is_none() { + self.child_analyses.borrow_mut().clear(); + } + + // Otherwise, invalidate each child analysis map + let mut to_invalidate = SmallVec::<[Rc; 8]>::from_iter([self]); + while let Some(map) = to_invalidate.pop() { + map.child_analyses.borrow_mut().retain(|_op, nested_analysis_map| { + Rc::clone(nested_analysis_map).invalidate(preserved_analyses); + if nested_analysis_map.child_analyses.borrow().is_empty() { + false + } else { + to_invalidate.push(Rc::clone(nested_analysis_map)); + true + } + }); + } + } + + pub fn clear(&self) { + self.child_analyses.borrow_mut().clear(); + self.analyses.borrow_mut().clear(); + } +} + +/// This class represents a cache of analyses for a single operation. +/// +/// All computation, caching, and invalidation of analyses takes place here. +struct AnalysisMap { + analyses: FxHashMap>, + ir: OperationRef, +} +impl AnalysisMap { + pub fn new(ir: OperationRef) -> Self { + Self { + analyses: Default::default(), + ir, + } + } + + /// Get an analysis for the current IR unit, computing it if necessary. + pub fn get(&mut self, pi: Option>, am: AnalysisManager) -> Rc + where + A: Analysis, + { + Self::get_analysis_impl::( + &mut self.analyses, + pi, + &self.ir.borrow(), + &self.ir, + am, + ) + } + + /// Get a cached analysis instance if one exists, otherwise return `None`. + pub fn get_cached(&self) -> Option> + where + A: Analysis, + { + self.analyses.get(&TypeId::of::()).cloned().and_then(|a| a.downcast::()) + } + + /// Get an analysis for the current IR unit, assuming it's of the specified type, computing it + /// if necessary. + /// + /// NOTE: This will panic if the current operation is not of type `O`. + pub fn get_analysis_for( + &mut self, + pi: Option>, + am: AnalysisManager, + ) -> Rc + where + A: Analysis, + O: 'static, + { + let ir = <::Target as PassTarget>::into_target(&self.ir); + Self::get_analysis_impl::(&mut self.analyses, pi, &*ir, &self.ir, am) + } + + fn get_analysis_impl( + analyses: &mut FxHashMap>, + pi: Option>, + ir: &O, + op: &OperationRef, + am: AnalysisManager, + ) -> Rc + where + A: Analysis, + { + use hashbrown::hash_map::Entry; + + let id = TypeId::of::(); + match analyses.entry(id) { + Entry::Vacant(entry) => { + // We don't have a cached analysis for the operation, compute it directly and + // add it to the cache. 
+ if let Some(pi) = pi.as_deref() { + pi.run_before_analysis(core::any::type_name::(), &id, op); + } + + let analysis = entry.insert(Self::construct_analysis::(am, ir)); + + if let Some(pi) = pi.as_deref() { + pi.run_after_analysis(core::any::type_name::(), &id, op); + } + + Rc::clone(analysis).downcast::().unwrap() + } + Entry::Occupied(entry) => Rc::clone(entry.get()).downcast::().unwrap(), + } + } + + fn construct_analysis(am: AnalysisManager, op: &O) -> Rc + where + A: Analysis, + { + Rc::new(AnalysisWrapper::::new(op, am)) as Rc + } + + /// Returns the operation that this analysis map represents. + pub fn get_operation(&self) -> OperationRef { + self.ir.clone() + } + + /// Clear any held analyses. + pub fn clear(&mut self) { + self.analyses.clear(); + } + + /// Invalidate any cached analyses based upon the given set of preserved analyses. + pub fn invalidate(&mut self, preserved_analyses: &mut PreservedAnalyses) { + // Remove any analyses that were invalidated. + // + // Using `retain`, we preserve the original insertion order, and dependencies always go + // before users, so we need only a single pass through. + self.analyses.retain(|_, a| !a.invalidate(preserved_analyses)); + } +} diff --git a/hir2/src/pass/instrumentation.rs b/hir2/src/pass/instrumentation.rs new file mode 100644 index 000000000..d16265671 --- /dev/null +++ b/hir2/src/pass/instrumentation.rs @@ -0,0 +1,129 @@ +use core::{any::TypeId, cell::RefCell}; + +use compact_str::CompactString; +use smallvec::SmallVec; + +use super::OperationPass; +use crate::{OperationName, OperationRef}; + +#[allow(unused_variables)] +pub trait PassInstrumentation { + fn run_before_pipeline( + &mut self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + } + fn run_after_pipeline( + &mut self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + } + fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} + fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} + fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} + fn run_before_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {} + fn run_after_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {} +} + +pub struct PipelineParentInfo { + /// The pass that spawned this pipeline, if any + pub pass: Option, +} + +impl PassInstrumentation for Box

{ + fn run_before_pipeline( + &mut self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + (**self).run_before_pipeline(name, parent_info); + } + + fn run_after_pipeline( + &mut self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + (**self).run_after_pipeline(name, parent_info); + } + + fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) { + (**self).run_before_pass(pass, op); + } + + fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) { + (**self).run_after_pass(pass, op); + } + + fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) { + (**self).run_after_pass_failed(pass, op); + } + + fn run_before_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) { + (**self).run_before_analysis(name, id, op); + } + + fn run_after_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) { + (**self).run_after_analysis(name, id, op); + } +} + +#[derive(Default)] +pub struct PassInstrumentor { + instrumentations: RefCell; 1]>>, +} + +impl PassInstrumentor { + pub fn run_before_pipeline( + &self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + self.instrument(|pi| pi.run_before_pipeline(name, parent_info)); + } + + pub fn run_after_pipeline( + &self, + name: Option<&OperationName>, + parent_info: &PipelineParentInfo, + ) { + self.instrument(|pi| pi.run_after_pipeline(name, parent_info)); + } + + pub fn run_before_pass(&self, pass: &dyn OperationPass, op: &OperationRef) { + self.instrument(|pi| pi.run_before_pass(pass, op)); + } + + pub fn run_after_pass(&self, pass: &dyn OperationPass, op: &OperationRef) { + self.instrument(|pi| pi.run_after_pass(pass, op)); + } + + pub fn run_after_pass_failed(&self, pass: &dyn OperationPass, op: &OperationRef) { + self.instrument(|pi| pi.run_after_pass_failed(pass, op)); + } + + pub fn run_before_analysis(&self, name: &str, id: &TypeId, op: &OperationRef) { + self.instrument(|pi| pi.run_before_analysis(name, id, op)); + } + + pub fn run_after_analysis(&self, name: &str, id: &TypeId, op: &OperationRef) { + self.instrument(|pi| pi.run_after_analysis(name, id, op)); + } + + pub fn add_instrumentation(&self, pi: Box) { + self.instrumentations.borrow_mut().push(pi); + } + + #[inline(always)] + fn instrument(&self, callback: F) + where + F: Fn(&mut dyn PassInstrumentation), + { + let mut instrumentations = self.instrumentations.borrow_mut(); + for pi in instrumentations.iter_mut() { + callback(pi); + } + } +} diff --git a/hir2/src/pass/manager.rs b/hir2/src/pass/manager.rs new file mode 100644 index 000000000..aaba3a5df --- /dev/null +++ b/hir2/src/pass/manager.rs @@ -0,0 +1,912 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use compact_str::{CompactString, ToCompactString}; +use miden_assembly::diagnostics::Severity; +use smallvec::{smallvec, SmallVec}; + +use super::{ + AnalysisManager, OperationPass, Pass, PassExecutionState, PassInstrumentation, + PassInstrumentor, PipelineParentInfo, Statistic, +}; +use crate::{ + traits::IsolatedFromAbove, Context, EntityMut, OpPrintingFlags, OpRegistration, Operation, + OperationName, OperationRef, Report, +}; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] +pub enum Nesting { + Implicit, + #[default] + Explicit, +} + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] +pub enum PassDisplayMode { + List, + #[default] + Pipeline, +} + +// TODO(pauls) +#[allow(unused)] +pub struct IRPrintingConfig { + print_module_scope: bool, + 
print_after_only_on_change: bool, + print_after_only_on_failure: bool, + flags: OpPrintingFlags, +} + +/// The main pass manager and pipeline builder +pub struct PassManager { + context: Rc, + /// The underlying pass manager + pm: OpPassManager, + /// A manager for pass instrumentation + instrumentor: Rc, + /// An optional crash reproducer generator, if this pass manager is setup to + /// generate reproducers. + ///crash_reproducer_generator: Rc, + /// Indicates whether to print pass statistics + statistics: Option, + /// Indicates whether or not pass timing is enabled + timing: bool, + /// Indicates whether or not to run verification between passes + verification: bool, +} + +impl PassManager { + /// Create a new pass manager under the given context with a specific nesting + /// style. The created pass manager can schedule operations that match + /// `operationName`. + pub fn new(context: Rc, name: impl AsRef, nesting: Nesting) -> Self { + let pm = OpPassManager::new(name.as_ref(), nesting, context.clone()); + Self { + context, + pm, + instrumentor: Default::default(), + statistics: None, + timing: false, + verification: true, + } + } + + /// Create a new pass manager under the given context with a specific nesting + /// style. The created pass manager can schedule operations that match + /// `OperationTy`. + pub fn on(context: Rc, nesting: Nesting) -> Self { + Self::new(context, ::name(), nesting) + } + + /// Run the passes within this manager on the provided operation. The + /// specified operation must have the same name as the one provided the pass + /// manager on construction. + pub fn run(&mut self, op: OperationRef) -> Result<(), Report> { + use crate::Spanned; + + let op_name = op.borrow().name(); + let anchor = self.pm.name(); + if let Some(anchor) = anchor { + if anchor != &op_name { + return Err(self + .context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("failed to construct pass manager") + .with_primary_label( + op.borrow().span(), + format!("can't run '{anchor}' pass manager on '{op_name}'"), + ) + .into_report()); + } + } + + // Register all dialects for the current pipeline. + /* + let dependent_dialects = self.get_dependent_dialects(); + self.context.append_dialect_registry(dependent_dialects); + for dialect_name in dependent_dialects.names() { + self.context.get_or_register_dialect(dialect_name); + } + */ + + // Before running, make sure to finalize the pipeline pass list. + self.pm.finalize_pass_list()?; + + // Construct a top level analysis manager for the pipeline. + let analysis_manager = AnalysisManager::new(op.clone(), Some(self.instrumentor.clone())); + + // If reproducer generation is enabled, run the pass manager with crash handling enabled. + /* + let result = if self.crash_reproducer_generator.is_some() { + self.run_with_crash_recovery(op, analysis_manager); + } else { + self.run_passes(op, analysis_manager); + } + */ + let result = self.run_passes(op, analysis_manager); + + // Dump all of the pass statistics if necessary. 
+ if self.statistics.is_some() { + let mut output = String::new(); + self.dump_statistics(&mut output).map_err(Report::msg)?; + println!("{output}"); + } + + result + } + + fn run_passes( + &mut self, + op: OperationRef, + analysis_manager: AnalysisManager, + ) -> Result<(), Report> { + OpToOpPassAdaptor::run_pipeline( + &mut self.pm, + op, + analysis_manager, + self.verification, + Some(self.instrumentor.clone()), + None, + ) + } + + #[inline] + pub fn context(&self) -> Rc { + self.context.clone() + } + + /// Runs the verifier after each individual pass. + pub fn enable_verifier(&mut self, yes: bool) -> &mut Self { + self.verification = yes; + self + } + + pub fn add_instrumentation(&mut self, pi: Box) -> &mut Self { + self.instrumentor.add_instrumentation(pi); + self + } + + pub fn enable_ir_printing(&mut self, _config: IRPrintingConfig) { + todo!() + } + + pub fn enable_timing(&mut self, yes: bool) -> &mut Self { + self.timing = yes; + self + } + + pub fn enable_statistics(&mut self, mode: Option) -> &mut Self { + self.statistics = mode; + self + } + + fn dump_statistics(&mut self, out: &mut dyn core::fmt::Write) -> core::fmt::Result { + self.pm.print_statistics(out, self.statistics.unwrap_or_default()) + } +} + +/// This class represents a pass manager that runs passes on either a specific +/// operation type, or any isolated operation. This pass manager can not be run +/// on an operation directly, but must be run either as part of a top-level +/// `PassManager`(e.g. when constructed via `nest` calls), or dynamically within +/// a pass by using the `Pass::runPipeline` API. +pub struct OpPassManager { + /// The current context + context: Rc, + /// The name of the operation that passes of this pass manager operate on + name: Option, + /// The set of passes to run as part of this pass manager + passes: SmallVec<[Box; 8]>, + /// Control the implicit nesting of passes that mismatch the name set for this manager + nesting: Nesting, +} + +impl OpPassManager { + pub const ANY: &str = "any"; + + /// Construct a new op-agnostic ("any") pass manager with the given operation + /// type and nesting behavior. This is the same as invoking: + /// `OpPassManager(OpPassManager::ANY, nesting)`. 
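+    ///
+    /// Sketch of intended usage (assumes an existing `context: Rc<Context>` and a hypothetical
+    /// `MyPass` type implementing [Pass] + [Default]):
+    ///
+    /// ```text,ignore
+    /// let mut pm = OpPassManager::any(Nesting::Implicit, context);
+    /// pm.add_pass(Box::new(MyPass::default()));
+    /// ```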
+ pub fn any(nesting: Nesting, context: Rc) -> Self { + Self { + context, + name: None, + passes: Default::default(), + nesting, + } + } + + pub fn new(name: &str, nesting: Nesting, context: Rc) -> Self { + if name == Self::ANY { + return Self::any(nesting, context); + } + + let (dialect_name, opcode) = name.split_once('.').expect( + "invalid operation name: expected format `.`, but missing `.`", + ); + let dialect_name = + crate::DialectName::from_symbol(crate::interner::Symbol::intern(dialect_name)); + let dialect = context.get_registered_dialect(&dialect_name); + let ops = dialect.registered_ops(); + let name = + ops.iter() + .find(|name| name.name().as_str() == opcode) + .cloned() + .unwrap_or_else(|| { + panic!( + "invalid operation name: found dialect '{dialect_name}', but no operation \ + called '{opcode}' is registered to that dialect" + ) + }); + Self { + context, + name: Some(name), + passes: Default::default(), + nesting, + } + } + + pub fn for_operation(name: OperationName, nesting: Nesting, context: Rc) -> Self { + Self { + context, + name: Some(name), + passes: Default::default(), + nesting, + } + } + + pub fn context(&self) -> Rc { + self.context.clone() + } + + pub fn passes(&self) -> &[Box] { + &self.passes + } + + pub fn passes_mut(&mut self) -> &mut [Box] { + &mut self.passes + } + + pub fn is_empty(&self) -> bool { + self.passes.is_empty() + } + + pub fn len(&self) -> usize { + self.passes.len() + } + + pub fn clear(&mut self) { + self.passes.clear(); + } + + /// Nest a new op-specific pass manager (for the op with the given name), under this pass manager. + pub fn nest_with_type(&mut self, nested_name: &str) -> NestedOpPassManager<'_> { + self.nest(Self::new(nested_name, self.nesting, self.context.clone())) + } + + /// Nest a new op-agnostic ("any") pass manager under this pass manager. + pub fn nest_any(&mut self) -> NestedOpPassManager<'_> { + self.nest(Self::any(self.nesting, self.context.clone())) + } + + fn nest_for(&mut self, nested_name: OperationName) -> NestedOpPassManager<'_> { + self.nest(Self::for_operation(nested_name, self.nesting, self.context.clone())) + } + + pub fn add_pass(&mut self, pass: Box) { + // If this pass runs on a different operation than this pass manager, then implicitly + // nest a pass manager for this operation if enabled. + let pass_op_name = pass.target_name(&self.context); + if let Some(pass_op_name) = pass_op_name { + if self.name.as_ref().is_some_and(|name| name != &pass_op_name) { + if matches!(self.nesting, Nesting::Implicit) { + let mut nested = self.nest_for(pass_op_name); + nested.add_pass(pass); + return; + } + panic!( + "cannot add pass '{}' restricted to '{pass_op_name}' to a pass manager \ + intended to run on '{}', did you intend to nest?", + pass.name(), + self.name().unwrap(), + ); + } + } + + self.passes.push(pass); + } + + pub fn add_nested_pass(&mut self, pass: Box) { + let name = ::name(); + let mut nested = self.nest(Self::new(name.as_str(), self.nesting, self.context.clone())); + nested.add_pass(pass); + } + + pub fn finalize_pass_list(&mut self) -> Result<(), Report> { + /* + auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) { + for (auto &pm : adaptor->getPassManagers()) + if (failed(pm.getImpl().finalizePassList(ctx))) + return failure(); + return success(); + }; + + // Walk the pass list and merge adjacent adaptors. + OpToOpPassAdaptor *lastAdaptor = nullptr; + for (auto &pass : passes) { + // Check to see if this pass is an adaptor. 
+ if (auto *currentAdaptor = dyn_cast(pass.get())) { + // If it is the first adaptor in a possible chain, remember it and + // continue. + if (!lastAdaptor) { + lastAdaptor = currentAdaptor; + continue; + } + + // Otherwise, try to merge into the existing adaptor and delete the + // current one. If merging fails, just remember this as the last adaptor. + if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor))) + pass.reset(); + else + lastAdaptor = currentAdaptor; + } else if (lastAdaptor) { + // If this pass isn't an adaptor, finalize it and forget the last adaptor. + if (failed(finalizeAdaptor(lastAdaptor))) + return failure(); + lastAdaptor = nullptr; + } + } + + // If there was an adaptor at the end of the manager, finalize it as well. + if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor))) + return failure(); + + // Now that the adaptors have been merged, erase any empty slots corresponding + // to the merged adaptors that were nulled-out in the loop above. + llvm::erase_if(passes, std::logical_not>()); + + // If this is a op-agnostic pass manager, there is nothing left to do. + std::optional rawOpName = getOpName(*ctx); + if (!rawOpName) + return success(); + + // Otherwise, verify that all of the passes are valid for the current + // operation anchor. + std::optional opName = + rawOpName->getRegisteredInfo(); + for (std::unique_ptr &pass : passes) { + if (opName && !pass->canScheduleOn(*opName)) { + return emitError(UnknownLoc::get(ctx)) + << "unable to schedule pass '" << pass->getName() + << "' on a PassManager intended to run on '" << getOpAnchorName() + << "'!"; + } + } + return success(); + */ + todo!() + } + + pub fn name(&self) -> Option<&OperationName> { + self.name.as_ref() + } + + pub fn set_nesting(&mut self, nesting: Nesting) { + self.nesting = nesting; + } + + pub fn nesting(&self) -> Nesting { + self.nesting + } + + /// Indicate if this pass manager can be scheduled on the given operation + pub fn can_schedule_on(&self, name: &OperationName) -> bool { + // If this pass manager is op-specific, we simply check if the provided operation name + // is the same as this one. + if let Some(op_name) = self.name() { + return op_name == name; + } + + // Otherwise, this is an op-agnostic pass manager. Check that the operation can be + // scheduled on all passes within the manager. + if !name.implements::() { + return false; + } + self.passes.iter().all(|pass| pass.can_schedule_on(name)) + } + + fn initialize(&mut self) -> Result<(), Report> { + for pass in self.passes.iter_mut() { + // If this pass isn't an adaptor, directly initialize it + if let Some(adaptor) = pass.as_any_mut().downcast_mut::() { + for pm in adaptor.pass_managers_mut() { + pm.initialize()?; + } + } else { + pass.initialize(self.context.clone())?; + } + } + + Ok(()) + } + + #[allow(unused)] + fn merge_into(&mut self, rhs: &mut Self) { + assert_eq!(self.name, rhs.name, "merging unrelated pass managers"); + for pass in self.passes.drain(..) { + rhs.passes.push(pass); + } + } + + pub fn nest(&mut self, nested: Self) -> NestedOpPassManager<'_> { + let adaptor = Box::new(OpToOpPassAdaptor::new(nested)); + NestedOpPassManager { + parent: self, + nested: Some(adaptor), + } + } + + pub fn print_statistics( + &self, + out: &mut dyn core::fmt::Write, + display_mode: PassDisplayMode, + ) -> core::fmt::Result { + const PASS_STATS_DESCRIPTION: &str = "... Pass statistics report ..."; + + // Print the stats header. + writeln!(out, "=={:-<73}==", "")?; + // Figure out how many spaces for the description name. 
+ let padding = 80usize.saturating_sub(PASS_STATS_DESCRIPTION.len()); + writeln!(out, "{: <1$}", PASS_STATS_DESCRIPTION, padding)?; + writeln!(out, "=={:-<73}==", "")?; + + // Defer to a specialized printer for each display mode. + match display_mode { + PassDisplayMode::List => self.print_statistics_as_list(out), + PassDisplayMode::Pipeline => self.print_statistics_as_pipeline(out), + } + } + + fn add_stats( + pass: &dyn OperationPass, + merged_stats: &mut BTreeMap<&str, SmallVec<[Box; 4]>>, + ) { + use alloc::collections::btree_map::Entry; + + if let Some(adaptor) = pass.as_any().downcast_ref::() { + // Recursively add each of the children. + for pass_manager in adaptor.pass_managers() { + for pass in pass_manager.passes() { + Self::add_stats(&**pass, merged_stats); + } + } + } else { + // If this is not an adaptor, add the stats to the list if there are any. + if !pass.has_statistics() { + return; + } + let statistics = SmallVec::<[Box; 4]>::from_iter( + pass.statistics().iter().map(|stat| Statistic::clone(&**stat)), + ); + match merged_stats.entry(pass.name()) { + Entry::Vacant(entry) => { + entry.insert(statistics); + } + Entry::Occupied(mut entry) => { + let prev_stats = entry.get_mut(); + assert_eq!(prev_stats.len(), statistics.len()); + for (index, mut stat) in statistics.into_iter().enumerate() { + let _ = prev_stats[index].try_merge(&mut stat); + } + } + } + } + } + + /// Print the statistics results in a list form, where each pass is sorted by name. + fn print_statistics_as_list(&self, out: &mut dyn core::fmt::Write) -> core::fmt::Result { + let mut merged_stats = BTreeMap::<&str, SmallVec<[Box; 4]>>::default(); + for pass in self.passes.iter() { + Self::add_stats(&**pass, &mut merged_stats); + } + + // Print the timing information sequentially. + for (pass, stats) in merged_stats.iter() { + self.print_pass_entry(out, 2, pass, stats)?; + } + + Ok(()) + } + + fn print_statistics_as_pipeline(&self, _out: &mut dyn core::fmt::Write) -> core::fmt::Result { + todo!() + } + + fn print_pass_entry( + &self, + out: &mut dyn core::fmt::Write, + indent: usize, + pass: &str, + stats: &[Box], + ) -> core::fmt::Result { + use core::fmt::Write; + + writeln!(out, "{: <1$}", pass, indent)?; + if stats.is_empty() { + return Ok(()); + } + + // Collect the largest name and value length from each of the statistics. + + struct Rendered<'a> { + name: &'a str, + description: &'a str, + value: compact_str::CompactString, + } + + let mut largest_name = 0usize; + let mut largest_value = 0usize; + let mut rendered_stats = SmallVec::<[Rendered; 4]>::default(); + for stat in stats { + let mut value = compact_str::CompactString::default(); + let doc = stat.pretty_print(); + write!(&mut value, "{doc}")?; + let name = stat.name(); + largest_name = core::cmp::max(largest_name, name.len()); + largest_value = core::cmp::max(largest_value, value.len()); + rendered_stats.push(Rendered { + name, + description: stat.description(), + value, + }); + } + + // Sort the statistics by name. 
+ rendered_stats.sort_by(|a, b| a.name.cmp(b.name)); + + // Print statistics + for stat in rendered_stats { + write!(out, "{: <1$} (S) ", "", indent)?; + write!(out, "{: <1$} ", &stat.value, largest_value)?; + write!(out, "{: <1$}", &stat.name, largest_name)?; + if stat.description.is_empty() { + out.write_char('\n')?; + } else { + writeln!(out, " - {}", &stat.description)?; + } + } + + Ok(()) + } +} + +pub struct NestedOpPassManager<'parent> { + parent: &'parent mut OpPassManager, + nested: Option>, +} + +impl<'parent> core::ops::Deref for NestedOpPassManager<'parent> { + type Target = OpPassManager; + + fn deref(&self) -> &Self::Target { + &self.nested.as_deref().unwrap().pass_managers()[0] + } +} +impl<'parent> core::ops::DerefMut for NestedOpPassManager<'parent> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.nested.as_deref_mut().unwrap().pass_managers_mut()[0] + } +} + +impl<'parent> Drop for NestedOpPassManager<'parent> { + fn drop(&mut self) { + self.parent.add_pass(self.nested.take().unwrap() as Box); + } +} + +pub struct OpToOpPassAdaptor { + pms: SmallVec<[OpPassManager; 1]>, +} +impl OpToOpPassAdaptor { + pub fn new(pm: OpPassManager) -> Self { + Self { pms: smallvec![pm] } + } + + pub fn name(&self) -> CompactString { + use core::fmt::Write; + + let mut name = CompactString::default(); + let names = + crate::formatter::DisplayValues::new(self.pms.iter().map(|pm| match pm.name() { + None => alloc::borrow::Cow::Borrowed(OpPassManager::ANY), + Some(name) => alloc::borrow::Cow::Owned(name.to_string()), + })); + write!(&mut name, "Pipeline Collection: [{names}]").unwrap(); + name + } + + /// Try to merge the current pass adaptor into 'rhs'. + /// + /// This will try to append the pass managers of this adaptor into those within `rhs`, or return + /// failure if merging isn't possible. The main situation in which merging is not possible is if + /// one of the adaptors has an `any` pipeline that is not compatible with a pass manager in the + /// other adaptor. For example, if this adaptor has a `hir.function` pipeline and `rhs` has an + /// `any` pipeline that operates on a FunctionOpInterface. In this situation the pipelines have + /// a conflict (they both want to run on the same operations), so we can't merge. + #[allow(unused)] + pub fn try_merge_into(&mut self, _rhs: &mut Self) -> bool { + todo!() + } + + pub fn pass_managers(&self) -> &[OpPassManager] { + &self.pms + } + + pub fn pass_managers_mut(&mut self) -> &mut [OpPassManager] { + &mut self.pms + } + + /// Run the given operation and analysis manager on a provided op pass manager. + fn run_pipeline( + pm: &mut OpPassManager, + op: OperationRef, + analysis_manager: AnalysisManager, + verify: bool, + instrumentor: Option>, + parent_info: Option<&PipelineParentInfo>, + ) -> Result<(), Report> { + assert!( + instrumentor.is_none() || parent_info.is_some(), + "expected parent info if instrumentor is provided" + ); + + // Clear out any computed operation analyses on exit. + // + // These analyses won't be used anymore in this pipeline, and this helps reduce the + // current working set of memory. If preserving these analyses becomes important in the + // future, we can re-evaluate. + let _clear = analysis_manager.defer_clear(); + + // Run the pipeline over the provided operation. 
+ let mut op_name = None; + if let Some(instrumentor) = instrumentor.as_deref() { + op_name = pm.name().cloned(); + instrumentor.run_before_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap()); + } + + for pass in pm.passes_mut() { + Self::run(&mut **pass, op.clone(), analysis_manager.clone(), verify)?; + } + + if let Some(instrumentor) = instrumentor.as_deref() { + instrumentor.run_after_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap()); + } + + Ok(()) + } + + /// Run the given operation and analysis manager on a single pass. + fn run( + pass: &mut dyn OperationPass, + op: OperationRef, + analysis_manager: AnalysisManager, + verify: bool, + ) -> Result<(), Report> { + use crate::Spanned; + + let (op_name, span, context) = { + let op = op.borrow(); + (op.name(), op.span(), op.context_rc()) + }; + if !op_name.implements::() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("failed to execute pass") + .with_primary_label( + span, + "trying to schedule a pass on an operation which does not implement \ + `IsolatedFromAbove`", + ) + .into_report()); + } + if !pass.can_schedule_on(&op_name) { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("failed to execute pass") + .with_primary_label(span, "trying to schedule a pass on an unsupported operation") + .into_report()); + } + + // Initialize the pass state with a callback for the pass to dynamically execute a pipeline + // on the currently visited operation. + let pi = analysis_manager.pass_instrumentor(); + let parent_info = PipelineParentInfo { + pass: Some(pass.name().to_compact_string()), + }; + let callback_op = op.clone(); + let callback_analysis_manager = analysis_manager.clone(); + let pipeline_callback: Box = Box::new( + move |pipeline: &mut OpPassManager, root: OperationRef| -> Result<(), Report> { + let pi = callback_analysis_manager.pass_instrumentor(); + let context = callback_op.borrow().context_rc(); + let root_op = root.borrow(); + if !root_op.is_ancestor_of(&callback_op) { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("failed to execute pass") + .with_primary_label( + root_op.span(), + "trying to schedule a dynamic pass pipeline on an operation that \ + isn't nested under the current operation the pass is processing", + ) + .into_report()); + } + assert!(pipeline.can_schedule_on(&root_op.name())); + // Before running, finalize the passes held by the pipeline + pipeline.finalize_pass_list()?; + + // Initialize the user-provided pipeline and execute the pipeline + pipeline.initialize()?; + + let nested_am = if root == callback_op { + callback_analysis_manager.clone() + } else { + callback_analysis_manager.nest(&root) + }; + Self::run_pipeline(pipeline, root, nested_am, verify, pi, Some(&parent_info)) + }, + ); + + let mut execution_state = PassExecutionState::new( + op.clone(), + context.clone(), + analysis_manager.clone(), + Some(pipeline_callback), + ); + + // Instrument before the pass has run + if let Some(instrumentor) = pi.as_deref() { + instrumentor.run_before_pass(pass, &op); + } + + let mut result = + if let Some(adaptor) = pass.as_any_mut().downcast_mut::() { + adaptor.run_on_operation(op.clone(), &mut execution_state, verify) + } else { + pass.run_on_operation(op.clone(), &mut execution_state) + }; + + // Invalidate any non-preserved analyses + analysis_manager.invalidate(execution_state.preserved_analyses_mut()); + + // When `verify == true`, we run the verifier (unless 
the pass failed) + if result.is_ok() && verify { + // If the pass is an adaptor pass, we don't run the verifier recursively because the + // nested operations should have already been verified after nested passes had run + let run_verifier_recursively = !pass.as_any().is::(); + + // Reduce compile time by avoiding running the verifier if the pass didn't change the + // IR since the last time the verifier was run: + // + // * If the pass said that it preserved all analyses then it can't have permuted the IR + let run_verifier_now = !execution_state.preserved_analyses().is_all(); + if run_verifier_now { + result = Self::verify(&op, run_verifier_recursively); + } + } + + if let Some(instrumentor) = pi.as_deref() { + if result.is_err() { + instrumentor.run_after_pass_failed(pass, &op); + } else { + instrumentor.run_after_pass(pass, &op); + } + } + + // Return the pass result + result + } + + fn verify(_op: &OperationRef, _verify_recursively: bool) -> Result<(), Report> { + todo!() + } + + fn run_on_operation( + &mut self, + op: OperationRef, + state: &mut PassExecutionState, + verify: bool, + ) -> Result<(), Report> { + let analysis_manager = state.analysis_manager(); + let instrumentor = analysis_manager.pass_instrumentor(); + let parent_info = PipelineParentInfo { + pass: Some(self.name()), + }; + + // Collection region refs so we aren't holding borrows during pass execution + let mut next_region = op.borrow().regions().back().as_pointer(); + while let Some(region) = next_region.take() { + next_region = region.next(); + let mut next_block = region.borrow().body().front().as_pointer(); + while let Some(block) = next_block.take() { + next_block = block.next(); + let mut next_op = block.borrow().front(); + while let Some(op) = next_op.take() { + next_op = op.next(); + let op_name = op.borrow().name(); + if let Some(manager) = + self.pms.iter_mut().find(|pm| pm.can_schedule_on(&op_name)) + { + let am = analysis_manager.nest(&op); + Self::run_pipeline( + manager, + op, + am, + verify, + instrumentor.clone(), + Some(&parent_info), + )?; + } + } + } + } + + Ok(()) + } +} + +impl Pass for OpToOpPassAdaptor { + type Target = Operation; + + fn name(&self) -> &'static str { + crate::interner::Symbol::intern(self.name()).as_str() + } + + #[inline(always)] + fn target_name(&self, _context: &Context) -> Option { + None + } + + fn print_as_textual_pipeline(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{}", &self.name()) + } + + #[inline(always)] + fn statistics(&self) -> &[Box] { + &[] + } + + #[inline(always)] + fn statistics_mut(&mut self) -> &mut [Box] { + &mut [] + } + + #[inline(always)] + fn can_schedule_on(&self, _name: &OperationName) -> bool { + true + } + + fn run_on_operation( + &mut self, + _op: EntityMut<'_, Operation>, + _state: &mut PassExecutionState, + ) -> Result<(), Report> { + unreachable!("unexpected call to `Pass::run_on_operation` for OpToOpPassAdaptor") + } + + fn run_pipeline( + &mut self, + _pipeline: &mut OpPassManager, + _op: OperationRef, + _state: &mut PassExecutionState, + ) -> Result<(), Report> { + todo!() + } +} diff --git a/hir2/src/pass/pass.rs b/hir2/src/pass/pass.rs new file mode 100644 index 000000000..8cf728b87 --- /dev/null +++ b/hir2/src/pass/pass.rs @@ -0,0 +1,371 @@ +use alloc::rc::Rc; +use core::{any::Any, fmt}; + +use super::*; +use crate::{Context, EntityMut, OperationName, OperationRef, Report}; + +/// A type-erased [Pass]. +/// +/// This is used to allow heterogenous passes to be operated on uniformly. 
+/// +/// Semantically, an [OperationPass] behaves like a `Pass`. +#[allow(unused_variables)] +pub trait OperationPass { + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn name(&self) -> &'static str; + fn argument(&self) -> &'static str { + // NOTE: Could we compute an argument string from the type name? + "" + } + fn description(&self) -> &'static str { + "" + } + fn info(&self) -> PassInfo { + PassInfo::lookup(self.argument()).expect("could not find pass information") + } + /// The name of the operation that this pass operates on, or `None` if this is a generic pass. + fn target_name(&self, context: &Context) -> Option; + fn initialize_options(&mut self, options: &str) -> Result<(), Report> { + Ok(()) + } + fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result; + fn has_statistics(&self) -> bool { + !self.statistics().is_empty() + } + fn statistics(&self) -> &[Box]; + fn statistics_mut(&mut self) -> &mut [Box]; + fn initialize(&mut self, context: Rc) -> Result<(), Report> { + Ok(()) + } + fn can_schedule_on(&self, name: &OperationName) -> bool; + fn run_on_operation( + &mut self, + op: OperationRef, + state: &mut PassExecutionState, + ) -> Result<(), Report>; + fn run_pipeline( + &mut self, + pipeline: &mut OpPassManager, + op: OperationRef, + state: &mut PassExecutionState, + ) -> Result<(), Report>; +} + +impl

<P> OperationPass for P
+where
+    P: Pass + 'static,
+{
+    fn as_any(&self) -> &dyn Any {
+        <P as Pass>::as_any(self)
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        <P as Pass>::as_any_mut(self)
+    }
+
+    fn name(&self) -> &'static str {
+        <P as Pass>::name(self)
+    }
+
+    fn argument(&self) -> &'static str {
+        <P as Pass>::argument(self)
+    }
+
+    fn description(&self) -> &'static str {
+        <P as Pass>::description(self)
+    }
+
+    fn info(&self) -> PassInfo {
+        <P as Pass>::info(self)
+    }
+
+    fn target_name(&self, context: &Context) -> Option<OperationName> {
+        <P as Pass>::target_name(self, context)
+    }
+
+    fn initialize_options(&mut self, options: &str) -> Result<(), Report> {
+        <P as Pass>::initialize_options(self, options)
+    }
+
+    fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        <P as Pass>::print_as_textual_pipeline(self, f)
+    }
+
+    fn has_statistics(&self) -> bool {
+        <P as Pass>::has_statistics(self)
+    }
+
+    fn statistics(&self) -> &[Box<dyn Statistic>] {
+        <P as Pass>::statistics(self)
+    }
+
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>] {
+        <P as Pass>::statistics_mut(self)
+    }
+
+    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
+        <P as Pass>::initialize(self, context)
+    }
+
+    fn can_schedule_on(&self, name: &OperationName) -> bool {
+        <P as Pass>::can_schedule_on(self, name)
+    }
+
+    fn run_on_operation(
+        &mut self,
+        mut op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        let op = <<P as Pass>::Target as PassTarget>::into_target_mut(&mut op);
+        <P as Pass>::run_on_operation(self, op, state)
+    }
+
+    fn run_pipeline(
+        &mut self,
+        pipeline: &mut OpPassManager,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        <P as Pass>
::run_pipeline(self, pipeline, op, state) + } +} + +/// A compiler pass which operates on an [Operation] of some kind. +#[allow(unused_variables)] +pub trait Pass: Sized + Any { + /// The concrete/trait type targeted by this pass. + /// + /// Calls to `get_operation` will return a reference of this type. + type Target: ?Sized + PassTarget; + + /// Used for downcasting + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + /// Used for downcasting + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } + + /// The display name of this pass + fn name(&self) -> &'static str; + /// The command line option name used to control this pass + fn argument(&self) -> &'static str { + // NOTE: Could we compute an argument string from the type name? + "" + } + /// A description of what this pass does. + fn description(&self) -> &'static str { + "" + } + /// Obtain the underlying [PassInfo] object for this pass. + fn info(&self) -> PassInfo { + PassInfo::lookup(self.argument()).expect("pass is not currently registered") + } + /// The name of the operation that this pass operates on, or `None` if this is a generic pass. + fn target_name(&self, context: &Context) -> Option { + <::Target as PassTarget>::target_name(context) + } + /// If command-line options are provided for this pass, implementations must parse the raw + /// options here, returning `Err` if parsing fails for some reason. + /// + /// By default, this is a no-op. + fn initialize_options(&mut self, options: &str) -> Result<(), Report> { + Ok(()) + } + /// Print this pass as a textual pipeline. + fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result; + /// Returns true if this pass has associated statistics + fn has_statistics(&self) -> bool { + !self.statistics().is_empty() + } + /// Get pass statistics associated with this pass + fn statistics(&self) -> &[Box]; + /// Get mutable access to the pass statistics associated with this pass + fn statistics_mut(&mut self) -> &mut [Box]; + /// Initialize any complex state necessary for running this pass. + /// + /// This hook should not rely on any state accessible during the execution of a pass. For + /// example, `context`/`get_operation`/`get_analysis`/etc. should not be invoked within this + /// hook. + /// + /// This method is invoked after all dependent dialects for the pipeline are loaded, and is not + /// allowed to load any further dialects (override the `get_dependent_dialects()` hook for this + /// purpose instead). Returns `Err` with a diagnostic if initialization fails, in which case the + /// pass pipeline won't execute. + fn initialize(&mut self, context: Rc) -> Result<(), Report> { + Ok(()) + } + /// Query if this pass can be scheduled to run on the given operation type. + fn can_schedule_on(&self, name: &OperationName) -> bool; + /// Run this pass on the current operation + fn run_on_operation( + &mut self, + op: EntityMut<'_, Self::Target>, + state: &mut PassExecutionState, + ) -> Result<(), Report>; + /// Run this pass as part of `pipeline` on `op` + fn run_pipeline( + &mut self, + pipeline: &mut OpPassManager, + op: OperationRef, + state: &mut PassExecutionState, + ) -> Result<(), Report>; +} + +impl

<P> Pass for Box<P>
+where
+    P: Pass,
+{
+    type Target = <P as Pass>
::Target; + + fn as_any(&self) -> &dyn Any { + (**self).as_any() + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + (**self).as_any_mut() + } + + #[inline] + fn name(&self) -> &'static str { + (**self).name() + } + + #[inline] + fn argument(&self) -> &'static str { + (**self).argument() + } + + #[inline] + fn description(&self) -> &'static str { + (**self).description() + } + + #[inline] + fn info(&self) -> PassInfo { + (**self).info() + } + + #[inline] + fn target_name(&self, context: &Context) -> Option { + (**self).target_name(context) + } + + #[inline] + fn initialize_options(&mut self, options: &str) -> Result<(), Report> { + (**self).initialize_options(options) + } + + #[inline] + fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result { + (**self).print_as_textual_pipeline(f) + } + + #[inline] + fn has_statistics(&self) -> bool { + (**self).has_statistics() + } + + #[inline] + fn statistics(&self) -> &[Box] { + (**self).statistics() + } + + #[inline] + fn statistics_mut(&mut self) -> &mut [Box] { + (**self).statistics_mut() + } + + #[inline] + fn initialize(&mut self, context: Rc) -> Result<(), Report> { + (**self).initialize(context) + } + + #[inline] + fn can_schedule_on(&self, name: &OperationName) -> bool { + (**self).can_schedule_on(name) + } + + #[inline] + fn run_on_operation( + &mut self, + op: EntityMut<'_, Self::Target>, + state: &mut PassExecutionState, + ) -> Result<(), Report> { + (**self).run_on_operation(op, state) + } + + #[inline] + fn run_pipeline( + &mut self, + pipeline: &mut OpPassManager, + op: OperationRef, + state: &mut PassExecutionState, + ) -> Result<(), Report> { + (**self).run_pipeline(pipeline, op, state) + } +} + +pub type DynamicPipelineExecutor = + dyn FnMut(&mut OpPassManager, OperationRef) -> Result<(), Report>; + +/// The state for a single execution of a pass. This provides a unified +/// interface for accessing and initializing necessary state for pass execution. +pub struct PassExecutionState { + /// The operation being transformed + op: OperationRef, + context: Rc, + analysis_manager: AnalysisManager, + /// The set of preserved analyses for the current execution + preserved_analyses: PreservedAnalyses, + // Callback in the pass manager that allows one to schedule dynamic pipelines that will be + // rooted at the provided operation. 
+ #[allow(unused)] + pipeline_executor: Option>, +} +impl PassExecutionState { + pub fn new( + op: OperationRef, + context: Rc, + analysis_manager: AnalysisManager, + pipeline_executor: Option>, + ) -> Self { + Self { + op, + context, + analysis_manager, + preserved_analyses: Default::default(), + pipeline_executor, + } + } + + #[inline(always)] + pub fn context(&self) -> Rc { + self.context.clone() + } + + #[inline(always)] + pub const fn current_operation(&self) -> &OperationRef { + &self.op + } + + #[inline(always)] + pub const fn analysis_manager(&self) -> &AnalysisManager { + &self.analysis_manager + } + + #[inline(always)] + pub const fn preserved_analyses(&self) -> &PreservedAnalyses { + &self.preserved_analyses + } + + #[inline(always)] + pub fn preserved_analyses_mut(&mut self) -> &mut PreservedAnalyses { + &mut self.preserved_analyses + } +} diff --git a/hir2/src/pass/registry.rs b/hir2/src/pass/registry.rs new file mode 100644 index 000000000..88b68b87f --- /dev/null +++ b/hir2/src/pass/registry.rs @@ -0,0 +1,397 @@ +use alloc::{collections::BTreeMap, sync::Arc}; +use core::any::TypeId; + +use midenc_hir_symbol::sync::{LazyLock, RwLock}; +use midenc_session::diagnostics::DiagnosticsHandler; + +use super::*; +use crate::Report; + +static PASS_REGISTRY: LazyLock = LazyLock::new(PassRegistry::new); + +/// A global, thread-safe pass and pass pipeline registry +/// +/// You should generally _not_ need to work with this directly. +pub struct PassRegistry { + passes: RwLock>, + pipelines: RwLock>, +} +impl Default for PassRegistry { + fn default() -> Self { + Self::new() + } +} +impl PassRegistry { + /// Create a new [PassRegistry] instance. + pub fn new() -> Self { + let mut passes = BTreeMap::default(); + let mut pipelines = BTreeMap::default(); + for pass in inventory::iter::() { + passes.insert( + pass.0.arg, + PassRegistryEntry { + arg: pass.0.arg, + description: pass.0.description, + type_id: pass.0.type_id, + builder: Arc::clone(&pass.0.builder), + }, + ); + } + for pipeline in inventory::iter::() { + pipelines.insert( + pipeline.0.arg, + PassRegistryEntry { + arg: pipeline.0.arg, + description: pipeline.0.description, + type_id: pipeline.0.type_id, + builder: Arc::clone(&pipeline.0.builder), + }, + ); + } + + Self { + passes: RwLock::new(passes), + pipelines: RwLock::new(pipelines), + } + } + + /// Get the pass information for the pass whose argument name is `name` + pub fn get_pass(&self, name: &str) -> Option { + self.passes.read().get(name).cloned().map(PassInfo) + } + + /// Get the pass pipeline information for the pipeline whose argument name is `name` + pub fn get_pipeline(&self, name: &str) -> Option { + self.pipelines.read().get(name).cloned().map(PassPipelineInfo) + } + + /// Register the given pass + pub fn register_pass(&self, info: PassInfo) { + use alloc::collections::btree_map::Entry; + + let mut passes = self.passes.write(); + match passes.entry(info.argument()) { + Entry::Vacant(entry) => { + entry.insert(info.0); + } + Entry::Occupied(entry) => { + assert_eq!( + entry.get().type_id, + info.0.type_id, + "cannot register pass '{}': name already registered by a different type", + info.argument() + ); + } + } + } + + /// Register the given pass pipeline + pub fn register_pipeline(&self, info: PassPipelineInfo) { + use alloc::collections::btree_map::Entry; + + let mut pipelines = self.pipelines.write(); + match pipelines.entry(info.argument()) { + Entry::Vacant(entry) => { + entry.insert(info.0); + } + Entry::Occupied(entry) => { + assert_eq!( + entry.get().type_id, 
+ info.0.type_id, + "cannot register pass pipeline '{}': name already registered by a different \ + type", + info.argument() + ); + assert!(Arc::ptr_eq(&entry.get().builder, &info.0.builder)) + } + } + } +} + +inventory::collect!(PassInfo); +inventory::collect!(PassPipelineInfo); + +/// A type alias for the closure type for registering a pass with a pass manager +pub type PassRegistryFunction = dyn Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + + Send + + Sync + + 'static; + +/// A type alias for the closure type used for type-erased pass constructors +pub type PassAllocatorFunction = dyn Fn() -> Box; + +/// A [RegistryEntry] is a registered pass or pass pipeline. +/// +/// This trait provides the common functionality shared by both passes and pipelines. +pub trait RegistryEntry { + /// Returns the command-line option that may be passed to `midenc` that will cause this pass + /// or pass pipeline to run. + fn argument(&self) -> &'static str; + /// Return a description for the pass or pass pipeline. + fn description(&self) -> &'static str; + /// Adds this entry to the given pass manager. + /// + /// Note: `options` is an opaque string that will be parsed by the builder. + /// + /// Returns `Err` if an error occurred parsing the given options. + fn add_to_pipeline( + &self, + pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), Report>; +} + +/// Information about a pass or pass pipeline in the pass registry +#[derive(Clone)] +struct PassRegistryEntry { + /// The name of the compiler option for referencing on the command line + arg: &'static str, + /// A description of the pass or pass pipeline + description: &'static str, + /// The type id of the concrete pass type + type_id: Option, + /// Function that registers this entry with a pass manager pipeline + builder: Arc, +} +impl RegistryEntry for PassRegistryEntry { + #[inline] + fn add_to_pipeline( + &self, + pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), Report> { + (self.builder)(pm, options, diagnostics) + } + + #[inline(always)] + fn argument(&self) -> &'static str { + self.arg + } + + #[inline(always)] + fn description(&self) -> &'static str { + self.description + } +} + +/// Information about a registered pass pipeline +pub struct PassPipelineInfo(PassRegistryEntry); +impl PassPipelineInfo { + pub fn new(arg: &'static str, description: &'static str, builder: B) -> Self + where + B: Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + + Send + + Sync + + 'static, + { + Self(PassRegistryEntry { + arg, + description, + type_id: None, + builder: Arc::new(builder), + }) + } + + /// Find the [PassInfo] for a registered pass pipeline named `name` + pub fn lookup(name: &str) -> Option { + PASS_REGISTRY.get_pipeline(name) + } +} +impl RegistryEntry for PassPipelineInfo { + fn argument(&self) -> &'static str { + self.0.argument() + } + + fn description(&self) -> &'static str { + self.0.description() + } + + fn add_to_pipeline( + &self, + pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), Report> { + self.0.add_to_pipeline(pm, options, diagnostics) + } +} + +/// Information about a registered pass +pub struct PassInfo(PassRegistryEntry); +impl PassInfo { + /// Create a new [PassInfo] from the given argument name and description, for a default- + /// constructible pass type `P`. 
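+    ///
+    /// For example (sketch; `MyPass` is a hypothetical pass type implementing [Pass] and
+    /// [Default]):
+    ///
+    /// ```text,ignore
+    /// let info = PassInfo::new::<MyPass>("my-pass", "a short description of what my-pass does");
+    /// ```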
+ pub fn new(arg: &'static str, description: &'static str) -> Self { + let type_id = TypeId::of::

<P>();
+        Self(PassRegistryEntry {
+            arg,
+            description,
+            type_id: Some(type_id),
+            builder: Arc::new(default_registration::<P>

), + }) + } + + /// Find the [PassInfo] for a registered pass named `name` + pub fn lookup(name: &str) -> Option { + PASS_REGISTRY.get_pass(name) + } +} +impl RegistryEntry for PassInfo { + fn argument(&self) -> &'static str { + self.0.argument() + } + + fn description(&self) -> &'static str { + self.0.description() + } + + fn add_to_pipeline( + &self, + pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), Report> { + self.0.add_to_pipeline(pm, options, diagnostics) + } +} + +/// Register a specific dialect pipeline registry function with the system. +/// +/// # Example +/// +/// If your pipeline implements the [Default] trait, you can just do: +/// +/// ```text,ignore +/// register_pass_pipeline( +/// "my-pipeline", +/// "A simple test pipeline", +/// default_registration::(), +/// ) +/// ``` +/// +/// Otherwise, you need to pass a factor function which will be used to construct fresh instances +/// of the pipeline: +/// +/// ```text,ignore +/// register_pass_pipeline( +/// "my-pipeline", +/// "A simple test pipeline", +/// default_dyn_registration(|| MyPipeline::new(MyPipelineOptions::default())), +/// ) +/// ``` +/// +/// NOTE: The functions/closures passed above are required to be `Send + Sync + 'static`, as they +/// are stored in the global registry for the lifetime of the program, and may be accessed from any +/// thread. +pub fn register_pass_pipeline(arg: &'static str, description: &'static str, builder: B) +where + B: Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + + Send + + Sync + + 'static, +{ + PASS_REGISTRY.register_pipeline(PassPipelineInfo(PassRegistryEntry { + arg, + description, + type_id: None, + builder: Arc::new(builder), + })); +} + +/// Register a specific dialect pass allocator function with the system. +/// +/// # Example +/// +/// ```text,ignore +/// register_pass(|| MyPass::default()) +/// ``` +/// +/// NOTE: The allocator function provided is required to be `Send + Sync + 'static`, as it is +/// stored in the global registry for the lifetime of the program, and may be accessed from any +/// thread. +pub fn register_pass(ctor: impl Fn() -> Box + Send + Sync + 'static) { + let pass = ctor(); + let type_id = pass.as_any().type_id(); + let arg = pass.argument(); + assert!( + !arg.is_empty(), + "attempted to register pass '{}' without specifying an argument name", + pass.name() + ); + let description = pass.description(); + PASS_REGISTRY.register_pass(PassInfo(PassRegistryEntry { + arg, + description, + type_id: Some(type_id), + builder: Arc::new(default_registration_factory(ctor)), + })); +} + +/// A default implementation of a pass pipeline registration function. +/// +/// It expects that `P` (the type of the pass or pass pipeline), implements `Default`, so that an +/// instance is default-constructible. It then initializes the pass with the provided options, +/// validates that the pass/pipeline is valid for the parent pipeline, and adds it if so. +pub fn default_registration( + pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler, +) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + + let mut pass = Box::
<P>
::default() as Box; + let result = pass.initialize_options(options); + let pm_op_name = pm.name(); + let pass_op_name = pass.target_name(&pm.context()); + let pass_op_name = pass_op_name.as_ref(); + if matches!(pm.nesting(), Nesting::Explicit) && pm_op_name != pass_op_name { + return Err(diagnostics + .diagnostic(Severity::Error) + .with_message(format!( + "registration error for pass '{}': can't add pass restricted to '{}' on a pass \ + manager intended to run on '{}', did you intend to nest?", + pass.name(), + crate::formatter::DisplayOptional(pass_op_name.as_ref()), + crate::formatter::DisplayOptional(pm_op_name), + )) + .into_report()); + } + pm.add_pass(pass); + result +} + +/// Like [default_registration], but takes an arbitrary constructor in the form of a zero-arity +/// closure, rather than relying on [Default]. Thus, this is actually a registration function +/// _factory_, rather than a registration function itself. +pub fn default_registration_factory Box + Send + Sync + 'static>( + builder: B, +) -> impl Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + Send + Sync + 'static +{ + use midenc_session::diagnostics::Severity; + move |pm: &mut OpPassManager, + options: &str, + diagnostics: &DiagnosticsHandler| + -> Result<(), Report> { + let mut pass = builder(); + let result = pass.initialize_options(options); + let pm_op_name = pm.name(); + let pass_op_name = pass.target_name(&pm.context()); + let pass_op_name = pass_op_name.as_ref(); + if matches!(pm.nesting(), Nesting::Explicit) && pm_op_name != pass_op_name { + return Err(diagnostics + .diagnostic(Severity::Error) + .with_message(format!( + "registration error for pass '{}': can't add pass restricted to '{}' on a \ + pass manager intended to run on '{}', did you intend to nest?", + pass.name(), + crate::formatter::DisplayOptional(pass_op_name.as_ref()), + crate::formatter::DisplayOptional(pm_op_name), + )) + .into_report()); + } + pm.add_pass(pass); + result + } +} diff --git a/hir2/src/pass/specialization.rs b/hir2/src/pass/specialization.rs new file mode 100644 index 000000000..4c10a378f --- /dev/null +++ b/hir2/src/pass/specialization.rs @@ -0,0 +1,142 @@ +use crate::{ + traits::BranchOpInterface, Context, EntityMut, EntityRef, Op, Operation, OperationName, + OperationRef, Symbol, SymbolTable, +}; + +pub trait PassTarget { + fn target_name(context: &Context) -> Option; + fn into_target(op: &OperationRef) -> EntityRef<'_, Self>; + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, Self>; +} + +impl PassTarget for T { + default fn target_name(_context: &Context) -> Option { + None + } + + #[inline] + default fn into_target(op: &OperationRef) -> EntityRef<'_, T> { + EntityRef::map(op.borrow(), |t| { + t.downcast_ref::().unwrap_or_else(|| expected_type::(op)) + }) + } + + #[inline] + default fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, T> { + EntityMut::map(op.borrow_mut(), |t| { + t.downcast_mut::().unwrap_or_else(|| expected_type::(op)) + }) + } +} +impl PassTarget for Operation { + #[inline(always)] + fn target_name(_context: &Context) -> Option { + None + } + + #[inline] + fn into_target(op: &OperationRef) -> EntityRef<'_, Operation> { + op.borrow() + } + + #[inline] + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, Operation> { + op.borrow_mut() + } +} +impl PassTarget for dyn Op { + #[inline(always)] + fn target_name(_context: &Context) -> Option { + None + } + + fn into_target(op: &OperationRef) -> EntityRef<'_, dyn Op> { + EntityRef::map(op.borrow(), |op| 
op.as_trait::().unwrap()) + } + + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn Op> { + EntityMut::map(op.borrow_mut(), |op| op.as_trait_mut::().unwrap()) + } +} +impl PassTarget for dyn BranchOpInterface { + #[inline(always)] + fn target_name(_context: &Context) -> Option { + None + } + + fn into_target(op: &OperationRef) -> EntityRef<'_, dyn BranchOpInterface> { + EntityRef::map(op.borrow(), |t| { + t.as_trait::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } + + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn BranchOpInterface> { + EntityMut::map(op.borrow_mut(), |t| { + t.as_trait_mut::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } +} +impl PassTarget for dyn Symbol { + #[inline(always)] + fn target_name(_context: &Context) -> Option { + None + } + + fn into_target(op: &OperationRef) -> EntityRef<'_, dyn Symbol> { + EntityRef::map(op.borrow(), |t| { + t.as_trait::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } + + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn Symbol> { + EntityMut::map(op.borrow_mut(), |t| { + t.as_trait_mut::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } +} +impl PassTarget for dyn SymbolTable + 'static { + #[inline(always)] + fn target_name(_context: &Context) -> Option { + None + } + + fn into_target(op: &OperationRef) -> EntityRef<'_, dyn SymbolTable + 'static> { + EntityRef::map(op.borrow(), |t| { + t.as_trait::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } + + fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn SymbolTable + 'static> { + EntityMut::map(op.borrow_mut(), |t| { + t.as_trait_mut::() + .unwrap_or_else(|| expected_implementation::(op)) + }) + } +} + +#[cold] +#[inline(never)] +#[track_caller] +fn expected_type(op: &OperationRef) -> ! { + panic!( + "expected operation '{}' to be a `{}`", + op.borrow().name(), + core::any::type_name::(), + ) +} + +#[cold] +#[inline(never)] +#[track_caller] +fn expected_implementation(op: &OperationRef) -> ! { + panic!( + "expected '{}' to implement `{}`, but no vtable was found", + op.borrow().name(), + core::any::type_name::() + ) +} diff --git a/hir2/src/pass/statistics.rs b/hir2/src/pass/statistics.rs new file mode 100644 index 000000000..042d29c2b --- /dev/null +++ b/hir2/src/pass/statistics.rs @@ -0,0 +1,462 @@ +use core::{any::Any, fmt}; + +use compact_str::CompactString; + +use crate::Report; + +/// A [Statistic] represents some stateful datapoint collected by and across passes. +/// +/// Statistics are named, have a description, and have a value. The value can be pretty printed, +/// and multiple instances of the same statistic can be merged together. 
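As a quick illustration of the merge semantics defined below (booleans merge by OR, integers by saturating addition, vectors by appending), here is a minimal sketch; it assumes the `StatisticValue` trait from this module is in scope:

```rust
// Sketch only: `StatisticValue` must be imported for these method calls to resolve.
fn merge_examples() {
    // Integers merge by saturating addition.
    let mut removed: usize = 10;
    removed.merge(&mut 5usize);
    assert_eq!(removed, 15);

    // Booleans merge by OR.
    let mut changed = false;
    changed.merge(&mut true);
    assert!(changed);

    // Vectors merge by appending (and draining) the other vector.
    let mut samples = vec![1u32, 2];
    let mut more = vec![3u32];
    samples.merge(&mut more);
    assert_eq!(samples, [1, 2, 3]);
    assert!(more.is_empty());
}
```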
+#[derive(Clone)] +pub struct PassStatistic { + pub name: CompactString, + pub description: CompactString, + pub value: V, +} +impl PassStatistic +where + V: StatisticValue, +{ + pub fn new(name: CompactString, description: CompactString, value: V) -> Self { + Self { + name, + description, + value, + } + } +} +impl Eq for PassStatistic {} +impl PartialEq for PassStatistic { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} +impl PartialOrd for PassStatistic { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.name.cmp(&other.name)) + } +} +impl Ord for PassStatistic { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.name.cmp(&other.name) + } +} +impl Statistic for PassStatistic +where + V: Clone + StatisticValue + 'static, +{ + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + &self.description + } + + fn pretty_print(&self) -> crate::formatter::Document { + self.value.pretty_print() + } + + fn try_merge(&mut self, other: &mut dyn Any) -> Result<(), Report> { + let lhs = &mut self.value; + if let Some(rhs) = other.downcast_mut::<::Value>() { + lhs.merge(rhs); + Ok(()) + } else { + let name = &self.name; + let expected_ty = core::any::type_name::<::Value>(); + Err(Report::msg(format!( + "could not merge statistic '{name}': expected value of type '{expected_ty}', but \ + got a value of some other type" + ))) + } + } + + fn clone(&self) -> Box { + use core::clone::CloneToUninit; + let mut this = Box::new_uninit(); + unsafe { + self.clone_to_uninit(this.as_mut_ptr()); + this.assume_init() + } + } +} + +/// An abstraction over statistics that allows operating generically over statistics with different +/// types of values. +pub trait Statistic { + /// The display name of this statistic + fn name(&self) -> &str; + /// A description of what this statistic means and why it is significant + fn description(&self) -> &str; + /// Pretty prints this statistic as a value + fn pretty_print(&self) -> crate::formatter::Document; + /// Merges another instance of this statistic into this one, given a mutable reference to the + /// raw underlying value of the other instance. + /// + /// Returns `Err` if `other` is not a valid value type for this statistic + fn try_merge(&mut self, other: &mut dyn Any) -> Result<(), Report>; + /// Clones the underlying statistic + fn clone(&self) -> Box; +} + +pub trait StatisticValue { + type Value: Any + Clone; + + fn value(&self) -> &Self::Value; + fn value_mut(&mut self) -> &mut Self::Value; + fn value_as_any(&self) -> &dyn Any { + self.value() as &dyn Any + } + fn value_as_any_mut(&mut self) -> &mut dyn Any { + self.value_mut() as &mut dyn Any + } + fn expected_type(&self) -> &'static str { + core::any::type_name::<::Value>() + } + fn merge(&mut self, other: &mut Self::Value); + fn pretty_print(&self) -> crate::formatter::Document; +} + +impl dyn StatisticValue { + pub fn downcast_ref(&self) -> Option<&T> { + self.value_as_any().downcast_ref::() + } + + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.value_as_any_mut().downcast_mut::() + } +} + +/// Merges via OR +impl StatisticValue for bool { + type Value = bool; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + *self |= *other; + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } +} + +/// A boolean flag which evalutates to true, only if all observed values are false. 
+/// +/// Defaults to false, and merges by OR. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagNone(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagNone) -> Self { + !flag.0 + } +} +impl StatisticValue for FlagNone { + type Value = FlagNone; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + if !self.0 && !other.0 { + self.0 = true; + } else { + self.0 ^= other.0 + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +/// A boolean flag which evaluates to true, only if at least one true value was observed. +/// +/// Defaults to false, and merges by OR. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagAny(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagAny) -> Self { + flag.0 + } +} +impl StatisticValue for FlagAny { + type Value = FlagAny; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.0 |= other.0; + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +/// A boolean flag which evaluates to true, only if all observed values were true. +/// +/// Defaults to true, and merges by AND. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagAll(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagAll) -> Self { + flag.0 + } +} +impl StatisticValue for FlagAll { + type Value = FlagAll; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.0 &= other.0 + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +macro_rules! 
numeric_statistic { + (#[cfg $($args:tt)*] $int_ty:ty) => { + /// Adds two numbers by saturating addition + #[cfg $($args)*] + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } + } + }; + + (#[cfg $($args:tt)*] $int_ty:ty as $wrapper_ty:ty) => { + /// Adds two numbers by saturating addition + #[cfg $($args)*] + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(<$wrapper_ty>::from(*self)) + } + } + }; + + ($int_ty:ty) => { + /// Adds two numbers by saturating addition + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } + } + } +} + +numeric_statistic!(u8); +numeric_statistic!(i8); +numeric_statistic!(u16); +numeric_statistic!(i16); +numeric_statistic!(u32); +numeric_statistic!(i32); +numeric_statistic!(u64); +numeric_statistic!(i64); +numeric_statistic!(usize); +numeric_statistic!(isize); +numeric_statistic!( + #[cfg(feature = "std")] + std::time::Duration as midenc_session::HumanDuration +); +numeric_statistic!( + #[cfg(feature = "std")] + midenc_session::HumanDuration +); + +impl StatisticValue for f64 { + type Value = f64; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + *self += *other; + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } +} + +/// Merges an array of statistic values element-wise +impl StatisticValue for [T; N] +where + T: Any + StatisticValue + Clone, +{ + type Value = [T; N]; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + for index in 0..N { + self[index].merge(other[index].value_mut()); + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::const_text; + + let doc = const_text("["); + self.iter().enumerate().fold(doc, |mut doc, (i, item)| { + if i > 0 { + doc += const_text(", "); + } + doc + item.pretty_print() + }) + const_text("]") + } +} + +/// Merges two vectors of statistics by appending +impl StatisticValue for Vec +where + T: Any + StatisticValue + Clone, +{ + type Value = Vec; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.append(other); + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::const_text; + + let doc = const_text("["); + self.iter().enumerate().fold(doc, |mut doc, (i, item)| { + if i > 0 { + doc += const_text(", "); + } + doc + item.pretty_print() + }) + const_text("]") + } +} + +/// Merges two maps of statistics by merging values 
of identical keys, and appending missing keys +impl StatisticValue for alloc::collections::BTreeMap +where + K: Ord + Clone + fmt::Display + 'static, + V: Any + StatisticValue + Clone, +{ + type Value = alloc::collections::BTreeMap; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + use alloc::collections::btree_map::Entry; + + while let Some((k, mut v)) = other.pop_first() { + match self.entry(k) { + Entry::Vacant(entry) => { + entry.insert(v); + } + Entry::Occupied(mut entry) => { + entry.get_mut().merge(v.value_mut()); + } + } + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::{const_text, indent, nl, text, Document}; + if self.is_empty() { + const_text("{}") + } else { + let single_line = const_text("{") + + self.iter().enumerate().fold(Document::Empty, |mut doc, (i, (k, v))| { + if i > 0 { + doc += const_text(", "); + } + doc + text(format!("{k}: ")) + v.pretty_print() + }) + + const_text("}"); + let multi_line = const_text("{") + + indent( + 4, + self.iter().enumerate().fold(nl(), |mut doc, (i, (k, v))| { + if i > 0 { + doc += const_text(",") + nl(); + } + doc + text(format!("{k}: ")) + v.pretty_print() + }) + nl(), + ) + + const_text("}"); + single_line | multi_line + } + } +} diff --git a/midenc-session/src/duration.rs b/midenc-session/src/duration.rs index 34aae7d9a..99d053ddf 100644 --- a/midenc-session/src/duration.rs +++ b/midenc-session/src/duration.rs @@ -3,6 +3,7 @@ use std::{ time::{Duration, Instant}, }; +#[derive(Copy, Clone)] pub struct HumanDuration(Duration); impl HumanDuration { pub fn since(i: Instant) -> Self { @@ -14,6 +15,41 @@ impl HumanDuration { pub fn as_secs_f64(&self) -> f64 { self.0.as_secs_f64() } + + /// Adds two [HumanDuration], using saturating arithmetic + #[inline] + pub fn saturating_add(self, rhs: Self) -> Self { + Self(self.0.saturating_add(rhs.0)) + } +} +impl core::ops::Add for HumanDuration { + type Output = HumanDuration; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } +} +impl core::ops::Add for HumanDuration { + type Output = HumanDuration; + + fn add(self, rhs: Duration) -> Self::Output { + Self(self.0 + rhs) + } +} +impl core::ops::AddAssign for HumanDuration { + fn add_assign(&mut self, rhs: Self) { + self.0 += rhs.0; + } +} +impl core::ops::AddAssign for HumanDuration { + fn add_assign(&mut self, rhs: Duration) { + self.0 += rhs; + } +} +impl From for Duration { + fn from(d: HumanDuration) -> Self { + d.0 + } } impl From for HumanDuration { fn from(d: Duration) -> Self { From f7f89f444116c421ef92028acfc283a6900ad103 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:39:35 -0400 Subject: [PATCH 26/31] feat: add small data structures to hir2, implement smalldeque --- Cargo.toml | 2 +- hir2/src/adt.rs | 21 + hir2/src/adt/smalldeque.rs | 3768 +++++++++++++++++++++++++++++++++++ hir2/src/adt/smallmap.rs | 491 +++++ hir2/src/adt/smallordset.rs | 189 ++ hir2/src/adt/smallprio.rs | 167 ++ hir2/src/adt/smallset.rs | 298 +++ hir2/src/adt/sparsemap.rs | 233 +++ hir2/src/lib.rs | 20 + 9 files changed, 5188 insertions(+), 1 deletion(-) create mode 100644 hir2/src/adt.rs create mode 100644 hir2/src/adt/smalldeque.rs create mode 100644 hir2/src/adt/smallmap.rs create mode 100644 hir2/src/adt/smallordset.rs create mode 100644 hir2/src/adt/smallprio.rs create mode 100644 hir2/src/adt/smallset.rs create mode 100644 hir2/src/adt/sparsemap.rs diff --git 
a/Cargo.toml b/Cargo.toml index 5d3406143..75a276703 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ exclude = [ [workspace.package] version = "0.0.6" -rust-version = "1.80" +rust-version = "1.81" authors = ["Miden contributors"] description = "An intermediate representation and compiler for Miden Assembly" repository = "https://github.com/0xPolygonMiden/compiler" diff --git a/hir2/src/adt.rs b/hir2/src/adt.rs new file mode 100644 index 000000000..2b8471fa9 --- /dev/null +++ b/hir2/src/adt.rs @@ -0,0 +1,21 @@ +pub mod smalldeque; +pub mod smallmap; +pub mod smallordset; +pub mod smallprio; +pub mod smallset; +pub mod sparsemap; + +pub use self::{ + smalldeque::SmallDeque, + smallmap::SmallMap, + smallordset::SmallOrdSet, + smallprio::SmallPriorityQueue, + smallset::SmallSet, + sparsemap::{SparseMap, SparseMapValue}, +}; + +#[doc(hidden)] +pub trait SizedTypeProperties: Sized { + const IS_ZST: bool = core::mem::size_of::() == 0; +} +impl SizedTypeProperties for T {} diff --git a/hir2/src/adt/smalldeque.rs b/hir2/src/adt/smalldeque.rs new file mode 100644 index 000000000..7995bc50a --- /dev/null +++ b/hir2/src/adt/smalldeque.rs @@ -0,0 +1,3768 @@ +use core::{ + cmp::Ordering, + fmt, + iter::{repeat_n, repeat_with, ByRefSized}, + ops::{Index, IndexMut, Range, RangeBounds}, + ptr::{self, NonNull}, +}; + +use smallvec::SmallVec; + +use super::SizedTypeProperties; + +/// [SmallDeque] is a [alloc::collections::VecDeque]-like structure that can store a specified +/// number of elements inline (i.e. on the stack) without allocating memory from the heap. +/// +/// This data structure is designed to basically provide the functionality of `VecDeque` without +/// needing to allocate on the heap for small numbers of nodes. +/// +/// Internally, [SmallDeque] is implemented on top of [SmallVec]. +/// +/// Most of the implementation is ripped from the standard library `VecDeque` impl, but adapted +/// for `SmallVec` +pub struct SmallDeque { + /// `self[0]`, if it exists, is `buf[head]`. + /// `head < buf.capacity()`, unless `buf.capacity() == 0` when `head == 0`. + head: usize, + /// The number of initialized elements, starting from the one at `head` and potentially + /// wrapping around. + /// + /// If `len == 0`, the exact value of `head` is unimportant. + /// + /// If `T` is zero-sized, then `self.len <= usize::MAX`, otherwise + /// `self.len <= isize::MAX as usize` + len: usize, + buf: SmallVec<[T; N]>, +} +impl Clone for SmallDeque { + fn clone(&self) -> Self { + let mut deq = Self::with_capacity(self.len()); + deq.extend(self.iter().cloned()); + deq + } + + fn clone_from(&mut self, source: &Self) { + self.clear(); + self.extend(source.iter().cloned()); + } +} +impl Default for SmallDeque { + fn default() -> Self { + Self { + head: 0, + len: 0, + buf: Default::default(), + } + } +} +impl SmallDeque { + /// Returns a new, empty [SmallDeque] + #[inline] + #[must_use] + pub const fn new() -> Self { + Self { + head: 0, + len: 0, + buf: SmallVec::new_const(), + } + } + + /// Create an empty deque with pre-allocated space for `capacity` elements. + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + Self { + head: 0, + len: 0, + buf: SmallVec::with_capacity(capacity), + } + } + + /// Returns true if this map is empty + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the number of key/value pairs in this map + pub fn len(&self) -> usize { + self.len + } + + /// Return a front-to-back iterator. 
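A brief usage sketch of the deque API above; `SmallDeque` is assumed to be imported from this crate's `adt` module:

```rust
// Sketch only: up to N (here 4) elements are stored inline before spilling to the heap.
fn deque_demo() {
    let mut queue = SmallDeque::<u32, 4>::new();
    queue.push_back(1);
    queue.push_back(2);
    queue.push_front(0);
    assert_eq!(queue.len(), 3);
    assert_eq!(queue.front(), Some(&0));
    assert_eq!(queue.back(), Some(&2));
    assert_eq!(queue.pop_front(), Some(0));
    assert_eq!(queue.pop_back(), Some(2));
    assert_eq!(queue.len(), 1);
}
```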
+ pub fn iter(&self) -> Iter<'_, T> { + let (a, b) = self.as_slices(); + Iter::new(a.iter(), b.iter()) + } + + /// Return a front-to-back iterator that returns mutable references + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + let (a, b) = self.as_mut_slices(); + IterMut::new(a.iter_mut(), b.iter_mut()) + } + + /// Returns a pair of slices which contain, in order, the contents of the + /// deque. + /// + /// If [`SmallDeque::make_contiguous`] was previously called, all elements of the + /// deque will be in the first slice and the second slice will be empty. + #[inline] + pub fn as_slices(&self) -> (&[T], &[T]) { + let (a_range, b_range) = self.slice_ranges(.., self.len); + // SAFETY: `slice_ranges` always returns valid ranges into the physical buffer. + unsafe { (&*self.buffer_range(a_range), &*self.buffer_range(b_range)) } + } + + /// Returns a pair of slices which contain, in order, the contents of the + /// deque. + /// + /// If [`SmallDeque::make_contiguous`] was previously called, all elements of the + /// deque will be in the first slice and the second slice will be empty. + #[inline] + pub fn as_mut_slices(&mut self) -> (&mut [T], &mut [T]) { + let (a_range, b_range) = self.slice_ranges(.., self.len); + // SAFETY: `slice_ranges` always returns valid ranges into the physical buffer. + unsafe { (&mut *self.buffer_range_mut(a_range), &mut *self.buffer_range_mut(b_range)) } + } + + /// Given a range into the logical buffer of the deque, this function + /// return two ranges into the physical buffer that correspond to + /// the given range. The `len` parameter should usually just be `self.len`; + /// the reason it's passed explicitly is that if the deque is wrapped in + /// a `Drain`, then `self.len` is not actually the length of the deque. + /// + /// # Safety + /// + /// This function is always safe to call. For the resulting ranges to be valid + /// ranges into the physical buffer, the caller must ensure that the result of + /// calling `slice::range(range, ..len)` represents a valid range into the + /// logical buffer, and that all elements in that range are initialized. + fn slice_ranges(&self, range: R, len: usize) -> (Range, Range) + where + R: RangeBounds, + { + let Range { start, end } = core::slice::range(range, ..len); + let len = end - start; + + if len == 0 { + (0..0, 0..0) + } else { + // `slice::range` guarantees that `start <= end <= len`. + // because `len != 0`, we know that `start < end`, so `start < len` + // and the indexing is valid. + let wrapped_start = self.to_physical_idx(start); + + // this subtraction can never overflow because `wrapped_start` is + // at most `self.capacity()` (and if `self.capacity != 0`, then `wrapped_start` is strictly less + // than `self.capacity`). + let head_len = self.capacity() - wrapped_start; + + if head_len >= len { + // we know that `len + wrapped_start <= self.capacity <= usize::MAX`, so this addition can't overflow + (wrapped_start..wrapped_start + len, 0..0) + } else { + // can't overflow because of the if condition + let tail_len = len - head_len; + (wrapped_start..self.capacity(), 0..tail_len) + } + } + } + + /// Creates an iterator that covers the specified range in the deque. + /// + /// # Panics + /// + /// Panics if the starting point is greater than the end point or if + /// the end point is greater than the length of the deque. 
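Because the deque is a ring buffer, its contents may be split across two slices; `make_contiguous` repacks them into one. A minimal sketch, under the same assumptions as the previous example:

```rust
// Sketch only.
fn slices_demo() {
    let mut dq = SmallDeque::<i32, 4>::new();
    dq.push_back(2);
    dq.push_back(3);
    dq.push_front(1); // the head wraps toward the end of the buffer
    // The logical contents are [1, 2, 3], possibly split across the two slices.
    let (front, back) = dq.as_slices();
    assert_eq!(front.len() + back.len(), 3);
    // After make_contiguous, everything is in one slice, in order.
    assert_eq!(dq.make_contiguous(), &mut [1, 2, 3]);
    let (front, back) = dq.as_slices();
    assert_eq!(front, &[1, 2, 3]);
    assert!(back.is_empty());
}
```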
+ #[inline] + pub fn range(&self, range: R) -> Iter<'_, T> + where + R: RangeBounds, + { + let (a_range, b_range) = self.slice_ranges(range, self.len); + // SAFETY: The ranges returned by `slice_ranges` + // are valid ranges into the physical buffer, so + // it's ok to pass them to `buffer_range` and + // dereference the result. + let a = unsafe { &*self.buffer_range(a_range) }; + let b = unsafe { &*self.buffer_range(b_range) }; + Iter::new(a.iter(), b.iter()) + } + + /// Creates an iterator that covers the specified mutable range in the deque. + /// + /// # Panics + /// + /// Panics if the starting point is greater than the end point or if + /// the end point is greater than the length of the deque. + #[inline] + pub fn range_mut(&mut self, range: R) -> IterMut<'_, T> + where + R: RangeBounds, + { + let (a_range, b_range) = self.slice_ranges(range, self.len); + // SAFETY: The ranges returned by `slice_ranges` + // are valid ranges into the physical buffer, so + // it's ok to pass them to `buffer_range` and + // dereference the result. + let a = unsafe { &mut *self.buffer_range_mut(a_range) }; + let b = unsafe { &mut *self.buffer_range_mut(b_range) }; + IterMut::new(a.iter_mut(), b.iter_mut()) + } + + /// Get a reference to the element at the given index. + /// + /// Element at index 0 is the front of the queue. + pub fn get(&self, index: usize) -> Option<&T> { + if index < self.len { + let index = self.to_physical_idx(index); + unsafe { Some(&*self.ptr().add(index)) } + } else { + None + } + } + + /// Get a mutable reference to the element at the given index. + /// + /// Element at index 0 is the front of the queue. + pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { + if index < self.len { + let index = self.to_physical_idx(index); + unsafe { Some(&mut *self.ptr_mut().add(index)) } + } else { + None + } + } + + /// Swaps elements at indices `i` and `j` + /// + /// `i` and `j` may be equal. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if either index is out of bounds. + pub fn swap(&mut self, i: usize, j: usize) { + assert!(i < self.len()); + assert!(j < self.len()); + let ri = self.to_physical_idx(i); + let rj = self.to_physical_idx(j); + unsafe { ptr::swap(self.ptr_mut().add(ri), self.ptr_mut().add(rj)) } + } + + /// Returns the number of elements the deque can hold without reallocating. + #[inline] + pub fn capacity(&self) -> usize { + if T::IS_ZST { + usize::MAX + } else { + self.buf.capacity() + } + } + + /// Reserves the minimum capacity for at least `additional` more elements to be inserted in the + /// given deque. Does nothing if the capacity is already sufficient. + /// + /// Note that the allocator may give the collection more space than it requests. Therefore + /// capacity can not be relied upon to be precisely minimal. Prefer [`reserve`] if future + /// insertions are expected. + /// + /// # Panics + /// + /// Panics if the new capacity overflows `usize`. + pub fn reserve_exact(&mut self, additional: usize) { + let new_cap = self.len.checked_add(additional).expect("capacity overflow"); + let old_cap = self.capacity(); + + if new_cap > old_cap { + self.buf.try_grow(new_cap).expect("capacity overflow"); + unsafe { + self.handle_capacity_increase(old_cap); + } + } + } + + /// Reserves capacity for at least `additional` more elements to be inserted in the given + /// deque. The collection may reserve more space to speculatively avoid frequent reallocations. 
+ /// + /// # Panics + /// + /// Panics if the new capacity overflows `usize`. + pub fn reserve(&mut self, additional: usize) { + let new_cap = self.len.checked_add(additional).expect("capacity overflow"); + let old_cap = self.capacity(); + + if new_cap > old_cap { + // we don't need to reserve_exact(), as the size doesn't have + // to be a power of 2. + self.buf.try_grow(new_cap).expect("capacity overflow"); + unsafe { + self.handle_capacity_increase(old_cap); + } + } + } + + /// Shortens the deque, keeping the first `len` elements and dropping + /// the rest. + /// + /// If `len` is greater or equal to the deque's current length, this has + /// no effect. + pub fn truncate(&mut self, len: usize) { + /// Runs the destructor for all items in the slice when it gets dropped (normally or + /// during unwinding). + struct Dropper<'a, T>(&'a mut [T]); + + impl<'a, T> Drop for Dropper<'a, T> { + fn drop(&mut self) { + unsafe { + ptr::drop_in_place(self.0); + } + } + } + + // Safe because: + // + // * Any slice passed to `drop_in_place` is valid; the second case has + // `len <= front.len()` and returning on `len > self.len()` ensures + // `begin <= back.len()` in the first case + // * The head of the SmallDeque is moved before calling `drop_in_place`, + // so no value is dropped twice if `drop_in_place` panics + unsafe { + if len >= self.len { + return; + } + + let (front, back) = self.as_mut_slices(); + if len > front.len() { + let begin = len - front.len(); + let drop_back = back.get_unchecked_mut(begin..) as *mut _; + self.len = len; + ptr::drop_in_place(drop_back); + } else { + let drop_back = back as *mut _; + let drop_front = front.get_unchecked_mut(len..) as *mut _; + self.len = len; + + // Make sure the second half is dropped even when a destructor + // in the first one panics. + let _back_dropper = Dropper(&mut *drop_back); + ptr::drop_in_place(drop_front); + } + } + } + + /// Removes the specified range from the deque in bulk, returning all + /// removed elements as an iterator. If the iterator is dropped before + /// being fully consumed, it drops the remaining removed elements. + /// + /// The returned iterator keeps a mutable borrow on the queue to optimize + /// its implementation. + /// + /// + /// # Panics + /// + /// Panics if the starting point is greater than the end point or if + /// the end point is greater than the length of the deque. + /// + /// # Leaking + /// + /// If the returned iterator goes out of scope without being dropped (due to + /// [`mem::forget`], for example), the deque may have lost and leaked + /// elements arbitrarily, including elements outside the range. + #[inline] + pub fn drain(&mut self, range: R) -> Drain<'_, T, N> + where + R: RangeBounds, + { + // Memory safety + // + // When the Drain is first created, the source deque is shortened to + // make sure no uninitialized or moved-from elements are accessible at + // all if the Drain's destructor never gets to run. + // + // Drain will ptr::read out the values to remove. + // When finished, the remaining data will be copied back to cover the hole, + // and the head/tail values will be restored correctly. 
+ // + let Range { start, end } = core::slice::range(range, ..self.len); + let drain_start = start; + let drain_len = end - start; + + // The deque's elements are parted into three segments: + // * 0 -> drain_start + // * drain_start -> drain_start+drain_len + // * drain_start+drain_len -> self.len + // + // H = self.head; T = self.head+self.len; t = drain_start+drain_len; h = drain_head + // + // We store drain_start as self.len, and drain_len and self.len as + // drain_len and orig_len respectively on the Drain. This also + // truncates the effective array such that if the Drain is leaked, we + // have forgotten about the potentially moved values after the start of + // the drain. + // + // H h t T + // [. . . o o x x o o . . .] + // + // "forget" about the values after the start of the drain until after + // the drain is complete and the Drain destructor is run. + + unsafe { Drain::new(self, drain_start, drain_len) } + } + + /// Clears the deque, removing all values. + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + // Not strictly necessary, but leaves things in a more consistent/predictable state. + self.head = 0; + } + + /// Returns `true` if the deque contains an element equal to the + /// given value. + /// + /// This operation is *O*(*n*). + /// + /// Note that if you have a sorted `SmallDeque`, [`binary_search`] may be faster. + /// + /// [`binary_search`]: SmallDeque::binary_search + pub fn contains(&self, x: &T) -> bool + where + T: PartialEq, + { + let (a, b) = self.as_slices(); + a.contains(x) || b.contains(x) + } + + /// Provides a reference to the front element, or `None` if the deque is + /// empty. + pub fn front(&self) -> Option<&T> { + self.get(0) + } + + /// Provides a mutable reference to the front element, or `None` if the + /// deque is empty. + pub fn front_mut(&mut self) -> Option<&mut T> { + self.get_mut(0) + } + + /// Provides a reference to the back element, or `None` if the deque is + /// empty. + pub fn back(&self) -> Option<&T> { + self.get(self.len.wrapping_sub(1)) + } + + /// Provides a mutable reference to the back element, or `None` if the + /// deque is empty. + pub fn back_mut(&mut self) -> Option<&mut T> { + self.get_mut(self.len.wrapping_sub(1)) + } + + /// Removes the first element and returns it, or `None` if the deque is + /// empty. + pub fn pop_front(&mut self) -> Option { + if self.is_empty() { + None + } else { + let old_head = self.head; + self.head = self.to_physical_idx(1); + self.len -= 1; + unsafe { + core::hint::assert_unchecked(self.len < self.capacity()); + Some(self.buffer_read(old_head)) + } + } + } + + /// Removes the last element from the deque and returns it, or `None` if + /// it is empty. + pub fn pop_back(&mut self) -> Option { + if self.is_empty() { + None + } else { + self.len -= 1; + unsafe { + core::hint::assert_unchecked(self.len < self.capacity()); + Some(self.buffer_read(self.to_physical_idx(self.len))) + } + } + } + + /// Prepends an element to the deque. + pub fn push_front(&mut self, value: T) { + if self.is_full() { + self.grow(); + } + + self.head = self.wrap_sub(self.head, 1); + self.len += 1; + + unsafe { + self.buffer_write(self.head, value); + } + } + + /// Appends an element to the back of the deque. 
+ pub fn push_back(&mut self, value: T) { + if self.is_full() { + self.grow(); + } + + unsafe { self.buffer_write(self.to_physical_idx(self.len), value) } + self.len += 1; + } + + #[inline] + fn is_contiguous(&self) -> bool { + // Do the calculation like this to avoid overflowing if len + head > usize::MAX + self.head <= self.capacity() - self.len + } + + /// Removes an element from anywhere in the deque and returns it, + /// replacing it with the first element. + /// + /// This does not preserve ordering, but is *O*(1). + /// + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. + pub fn swap_remove_front(&mut self, index: usize) -> Option { + let length = self.len; + if index < length && index != 0 { + self.swap(index, 0); + } else if index >= length { + return None; + } + self.pop_front() + } + + /// Removes an element from anywhere in the deque and returns it, + /// replacing it with the last element. + /// + /// This does not preserve ordering, but is *O*(1). + /// + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. + pub fn swap_remove_back(&mut self, index: usize) -> Option { + let length = self.len; + if length > 0 && index < length - 1 { + self.swap(index, length - 1); + } else if index >= length { + return None; + } + self.pop_back() + } + + /// Inserts an element at `index` within the deque, shifting all elements + /// with indices greater than or equal to `index` towards the back. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if `index` is greater than deque's length + pub fn insert(&mut self, index: usize, value: T) { + assert!(index <= self.len(), "index out of bounds"); + if self.is_full() { + self.grow(); + } + + let k = self.len - index; + if k < index { + // `index + 1` can't overflow, because if index was usize::MAX, then either the + // assert would've failed, or the deque would've tried to grow past usize::MAX + // and panicked. + unsafe { + // see `remove()` for explanation why this wrap_copy() call is safe. + self.wrap_copy(self.to_physical_idx(index), self.to_physical_idx(index + 1), k); + self.buffer_write(self.to_physical_idx(index), value); + self.len += 1; + } + } else { + let old_head = self.head; + self.head = self.wrap_sub(self.head, 1); + unsafe { + self.wrap_copy(old_head, self.head, index); + self.buffer_write(self.to_physical_idx(index), value); + self.len += 1; + } + } + } + + /// Removes and returns the element at `index` from the deque. + /// Whichever end is closer to the removal point will be moved to make + /// room, and all the affected elements will be moved to new positions. + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. + pub fn remove(&mut self, index: usize) -> Option { + if self.len <= index { + return None; + } + + let wrapped_idx = self.to_physical_idx(index); + + let elem = unsafe { Some(self.buffer_read(wrapped_idx)) }; + + let k = self.len - index - 1; + // safety: due to the nature of the if-condition, whichever wrap_copy gets called, + // its length argument will be at most `self.len / 2`, so there can't be more than + // one overlapping area. 
+ if k < index { + unsafe { self.wrap_copy(self.wrap_add(wrapped_idx, 1), wrapped_idx, k) }; + self.len -= 1; + } else { + let old_head = self.head; + self.head = self.to_physical_idx(1); + unsafe { self.wrap_copy(old_head, self.head, index) }; + self.len -= 1; + } + + elem + } + + /// Splits the deque into two at the given index. + /// + /// Returns a newly allocated `SmallDeque`. `self` contains elements `[0, at)`, + /// and the returned deque contains elements `[at, len)`. + /// + /// Note that the capacity of `self` does not change. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if `at > len`. + #[inline] + #[must_use = "use `.truncate()` if you don't need the other half"] + pub fn split_off(&mut self, at: usize) -> Self { + let len = self.len; + assert!(at <= len, "`at` out of bounds"); + + let other_len = len - at; + let mut other = Self::with_capacity(other_len); + + unsafe { + let (first_half, second_half) = self.as_slices(); + + let first_len = first_half.len(); + let second_len = second_half.len(); + if at < first_len { + // `at` lies in the first half. + let amount_in_first = first_len - at; + + ptr::copy_nonoverlapping( + first_half.as_ptr().add(at), + other.ptr_mut(), + amount_in_first, + ); + + // just take all of the second half. + ptr::copy_nonoverlapping( + second_half.as_ptr(), + other.ptr_mut().add(amount_in_first), + second_len, + ); + } else { + // `at` lies in the second half, need to factor in the elements we skipped + // in the first half. + let offset = at - first_len; + let amount_in_second = second_len - offset; + ptr::copy_nonoverlapping( + second_half.as_ptr().add(offset), + other.ptr_mut(), + amount_in_second, + ); + } + } + + // Cleanup where the ends of the buffers are + self.len = at; + other.len = other_len; + + other + } + + /// Moves all the elements of `other` into `self`, leaving `other` empty. + /// + /// # Panics + /// + /// Panics if the new number of elements in self overflows a `usize`. + #[inline] + pub fn append(&mut self, other: &mut Self) { + if T::IS_ZST { + self.len = self.len.checked_add(other.len).expect("capacity overflow"); + other.len = 0; + other.head = 0; + return; + } + + self.reserve(other.len); + unsafe { + let (left, right) = other.as_slices(); + self.copy_slice(self.to_physical_idx(self.len), left); + // no overflow, because self.capacity() >= old_cap + left.len() >= self.len + left.len() + self.copy_slice(self.to_physical_idx(self.len + left.len()), right); + } + // SAFETY: Update pointers after copying to avoid leaving doppelganger + // in case of panics. + self.len += other.len; + // Now that we own its values, forget everything in `other`. + other.len = 0; + other.head = 0; + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns false. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&T) -> bool, + { + self.retain_mut(|elem| f(elem)); + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns false. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. 
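For example, filtering in place with `retain` keeps the surviving elements in their original order (sketch, same assumptions as the earlier deque examples; `Iter` is assumed to implement `Iterator`):

```rust
// Sketch only.
fn retain_demo() {
    let mut dq = SmallDeque::<i32, 8>::new();
    for i in 0..6 {
        dq.push_back(i);
    }
    // Keep only even values; the relative order of retained elements is preserved.
    dq.retain(|x| x % 2 == 0);
    let kept: Vec<i32> = dq.iter().copied().collect();
    assert_eq!(kept, vec![0, 2, 4]);
}
```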
+ pub fn retain_mut(&mut self, mut f: F) + where + F: FnMut(&mut T) -> bool, + { + let len = self.len; + let mut idx = 0; + let mut cur = 0; + + // Stage 1: All values are retained. + while cur < len { + if !f(&mut self[cur]) { + cur += 1; + break; + } + cur += 1; + idx += 1; + } + // Stage 2: Swap retained value into current idx. + while cur < len { + if !f(&mut self[cur]) { + cur += 1; + continue; + } + + self.swap(idx, cur); + cur += 1; + idx += 1; + } + // Stage 3: Truncate all values after idx. + if cur != idx { + self.truncate(idx); + } + } + + // Double the buffer size. This method is inline(never), so we expect it to only + // be called in cold paths. + // This may panic or abort + #[inline(never)] + fn grow(&mut self) { + // Extend or possibly remove this assertion when valid use-cases for growing the + // buffer without it being full emerge + debug_assert!(self.is_full()); + let old_cap = self.capacity(); + self.buf.grow(old_cap + 1); + unsafe { + self.handle_capacity_increase(old_cap); + } + debug_assert!(!self.is_full()); + } + + /// Modifies the deque in-place so that `len()` is equal to `new_len`, + /// either by removing excess elements from the back or by appending + /// elements generated by calling `generator` to the back. + pub fn resize_with(&mut self, new_len: usize, generator: impl FnMut() -> T) { + let len = self.len; + + if new_len > len { + self.extend(repeat_with(generator).take(new_len - len)) + } else { + self.truncate(new_len); + } + } + + /// Rearranges the internal storage of this deque so it is one contiguous + /// slice, which is then returned. + /// + /// This method does not allocate and does not change the order of the + /// inserted elements. As it returns a mutable slice, this can be used to + /// sort a deque. + /// + /// Once the internal storage is contiguous, the [`as_slices`] and + /// [`as_mut_slices`] methods will return the entire contents of the + /// deque in a single slice. + /// + /// [`as_slices`]: SmallDeque::as_slices + /// [`as_mut_slices`]: SmallDeque::as_mut_slices + pub fn make_contiguous(&mut self) -> &mut [T] { + if T::IS_ZST { + self.head = 0; + } + + if self.is_contiguous() { + unsafe { + return core::slice::from_raw_parts_mut(self.ptr_mut().add(self.head), self.len); + } + } + + let &mut Self { head, len, .. } = self; + let ptr = self.ptr_mut(); + let cap = self.capacity(); + + let free = cap - len; + let head_len = cap - head; + let tail = len - head_len; + let tail_len = tail; + + if free >= head_len { + // there is enough free space to copy the head in one go, + // this means that we first shift the tail backwards, and then + // copy the head to the correct position. + // + // from: DEFGH....ABC + // to: ABCDEFGH.... + unsafe { + self.copy(0, head_len, tail_len); + // ...DEFGH.ABC + self.copy_nonoverlapping(head, 0, head_len); + // ABCDEFGH.... + } + + self.head = 0; + } else if free >= tail_len { + // there is enough free space to copy the tail in one go, + // this means that we first shift the head forwards, and then + // copy the tail to the correct position. + // + // from: FGH....ABCDE + // to: ...ABCDEFGH. + unsafe { + self.copy(head, tail, head_len); + // FGHABCDE.... + self.copy_nonoverlapping(0, tail + head_len, tail_len); + // ...ABCDEFGH. + } + + self.head = tail; + } else { + // `free` is smaller than both `head_len` and `tail_len`. 
+ // the general algorithm for this first moves the slices + // right next to each other and then uses `slice::rotate` + // to rotate them into place: + // + // initially: HIJK..ABCDEFG + // step 1: ..HIJKABCDEFG + // step 2: ..ABCDEFGHIJK + // + // or: + // + // initially: FGHIJK..ABCDE + // step 1: FGHIJKABCDE.. + // step 2: ABCDEFGHIJK.. + + // pick the shorter of the 2 slices to reduce the amount + // of memory that needs to be moved around. + if head_len > tail_len { + // tail is shorter, so: + // 1. copy tail forwards + // 2. rotate used part of the buffer + // 3. update head to point to the new beginning (which is just `free`) + + unsafe { + // if there is no free space in the buffer, then the slices are already + // right next to each other and we don't need to move any memory. + if free != 0 { + // because we only move the tail forward as much as there's free space + // behind it, we don't overwrite any elements of the head slice, and + // the slices end up right next to each other. + self.copy(0, free, tail_len); + } + + // We just copied the tail right next to the head slice, + // so all of the elements in the range are initialized + let slice = &mut *self.buffer_range_mut(free..self.capacity()); + + // because the deque wasn't contiguous, we know that `tail_len < self.len == slice.len()`, + // so this will never panic. + slice.rotate_left(tail_len); + + // the used part of the buffer now is `free..self.capacity()`, so set + // `head` to the beginning of that range. + self.head = free; + } + } else { + // head is shorter so: + // 1. copy head backwards + // 2. rotate used part of the buffer + // 3. update head to point to the new beginning (which is the beginning of the buffer) + + unsafe { + // if there is no free space in the buffer, then the slices are already + // right next to each other and we don't need to move any memory. + if free != 0 { + // copy the head slice to lie right behind the tail slice. + self.copy(self.head, tail_len, head_len); + } + + // because we copied the head slice so that both slices lie right + // next to each other, all the elements in the range are initialized. + let slice = &mut *self.buffer_range_mut(0..self.len); + + // because the deque wasn't contiguous, we know that `head_len < self.len == slice.len()` + // so this will never panic. + slice.rotate_right(head_len); + + // the used part of the buffer now is `0..self.len`, so set + // `head` to the beginning of that range. + self.head = 0; + } + } + } + + unsafe { core::slice::from_raw_parts_mut(ptr.add(self.head), self.len) } + } + + /// Rotates the double-ended queue `n` places to the left. + /// + /// Equivalently, + /// - Rotates item `n` into the first position. + /// - Pops the first `n` items and pushes them to the end. + /// - Rotates `len() - n` places to the right. + /// + /// # Panics + /// + /// If `n` is greater than `len()`. Note that `n == len()` + /// does _not_ panic and is a no-op rotation. + /// + /// # Complexity + /// + /// Takes `*O*(min(n, len() - n))` time and no extra space. + pub fn rotate_left(&mut self, n: usize) { + assert!(n <= self.len()); + let k = self.len - n; + if n <= k { + unsafe { self.rotate_left_inner(n) } + } else { + unsafe { self.rotate_right_inner(k) } + } + } + + /// Rotates the double-ended queue `n` places to the right. + /// + /// Equivalently, + /// - Rotates the first item into position `n`. + /// - Pops the last `n` items and pushes them to the front. + /// - Rotates `len() - n` places to the left. 
+ /// + /// # Panics + /// + /// If `n` is greater than `len()`. Note that `n == len()` + /// does _not_ panic and is a no-op rotation. + /// + /// # Complexity + /// + /// Takes `*O*(min(n, len() - n))` time and no extra space. + pub fn rotate_right(&mut self, n: usize) { + assert!(n <= self.len()); + let k = self.len - n; + if n <= k { + unsafe { self.rotate_right_inner(n) } + } else { + unsafe { self.rotate_left_inner(k) } + } + } + + // SAFETY: the following two methods require that the rotation amount + // be less than half the length of the deque. + // + // `wrap_copy` requires that `min(x, capacity() - x) + copy_len <= capacity()`, + // but then `min` is never more than half the capacity, regardless of x, + // so it's sound to call here because we're calling with something + // less than half the length, which is never above half the capacity. + unsafe fn rotate_left_inner(&mut self, mid: usize) { + debug_assert!(mid * 2 <= self.len()); + unsafe { + self.wrap_copy(self.head, self.to_physical_idx(self.len), mid); + } + self.head = self.to_physical_idx(mid); + } + + unsafe fn rotate_right_inner(&mut self, k: usize) { + debug_assert!(k * 2 <= self.len()); + self.head = self.wrap_sub(self.head, k); + unsafe { + self.wrap_copy(self.to_physical_idx(self.len), self.head, k); + } + } + + /// Binary searches this `SmallDeque` for a given element. + /// If the `SmallDeque` is not sorted, the returned result is unspecified and + /// meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search_by`], [`binary_search_by_key`], and [`partition_point`]. + /// + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + /// [`partition_point`]: SmallDeque::partition_point + /// + #[inline] + pub fn binary_search(&self, x: &T) -> Result + where + T: Ord, + { + self.binary_search_by(|e| e.cmp(x)) + } + + /// Binary searches this `SmallDeque` with a comparator function. + /// + /// The comparator function should return an order code that indicates + /// whether its argument is `Less`, `Equal` or `Greater` the desired + /// target. + /// If the `SmallDeque` is not sorted or if the comparator function does not + /// implement an order consistent with the sort order of the underlying + /// `SmallDeque`, the returned result is unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search`], [`binary_search_by_key`], and [`partition_point`]. 
+ /// + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + /// [`partition_point`]: SmallDeque::partition_point + pub fn binary_search_by<'a, F>(&'a self, mut f: F) -> Result + where + F: FnMut(&'a T) -> Ordering, + { + let (front, back) = self.as_slices(); + // clippy doesn't recognize that `f` would be moved if we followed it's recommendation + #[allow(clippy::redundant_closure)] + let cmp_back = back.first().map(|e| f(e)); + + if let Some(Ordering::Equal) = cmp_back { + Ok(front.len()) + } else if let Some(Ordering::Less) = cmp_back { + back.binary_search_by(f) + .map(|idx| idx + front.len()) + .map_err(|idx| idx + front.len()) + } else { + front.binary_search_by(f) + } + } + + /// Binary searches this `SmallDeque` with a key extraction function. + /// + /// Assumes that the deque is sorted by the key, for instance with + /// [`make_contiguous().sort_by_key()`] using the same key extraction function. + /// If the deque is not sorted by the key, the returned result is + /// unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search`], [`binary_search_by`], and [`partition_point`]. + /// + /// [`make_contiguous().sort_by_key()`]: SmallDeque::make_contiguous + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`partition_point`]: SmallDeque::partition_point + #[inline] + pub fn binary_search_by_key<'a, B, F>(&'a self, b: &B, mut f: F) -> Result + where + F: FnMut(&'a T) -> B, + B: Ord, + { + self.binary_search_by(|k| f(k).cmp(b)) + } + + /// Returns the index of the partition point according to the given predicate + /// (the index of the first element of the second partition). + /// + /// The deque is assumed to be partitioned according to the given predicate. + /// This means that all elements for which the predicate returns true are at the start of the deque + /// and all elements for which the predicate returns false are at the end. + /// For example, `[7, 15, 3, 5, 4, 12, 6]` is partitioned under the predicate `x % 2 != 0` + /// (all odd numbers are at the start, all even at the end). + /// + /// If the deque is not partitioned, the returned result is unspecified and meaningless, + /// as this method performs a kind of binary search. + /// + /// See also [`binary_search`], [`binary_search_by`], and [`binary_search_by_key`]. + /// + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + pub fn partition_point
<P>
(&self, mut pred: P) -> usize + where + P: FnMut(&T) -> bool, + { + let (front, back) = self.as_slices(); + + #[allow(clippy::redundant_closure)] + if let Some(true) = back.first().map(|v| pred(v)) { + back.partition_point(pred) + front.len() + } else { + front.partition_point(pred) + } + } +} + +impl SmallDeque { + /// Modifies the deque in-place so that `len()` is equal to new_len, + /// either by removing excess elements from the back or by appending clones of `value` + /// to the back. + pub fn resize(&mut self, new_len: usize, value: T) { + if new_len > self.len() { + let extra = new_len - self.len(); + self.extend(repeat_n(value, extra)) + } else { + self.truncate(new_len); + } + } +} + +impl SmallDeque { + #[inline] + fn ptr(&self) -> *const T { + self.buf.as_ptr() + } + + #[inline] + fn ptr_mut(&mut self) -> *mut T { + self.buf.as_mut_ptr() + } + + /// Appends an element to the buffer. + /// + /// # Safety + /// + /// May only be called if `deque.len() < deque.capacity()` + #[inline] + unsafe fn push_unchecked(&mut self, element: T) { + // SAFETY: Because of the precondition, it's guaranteed that there is space in the logical + // array after the last element. + unsafe { self.buffer_write(self.to_physical_idx(self.len), element) }; + // This can't overflow because `deque.len() < deque.capacity() <= usize::MAX` + self.len += 1; + } + + /// Moves an element out of the buffer + #[inline] + unsafe fn buffer_read(&mut self, offset: usize) -> T { + unsafe { ptr::read(self.ptr().add(offset)) } + } + + /// Writes an element into the buffer, moving it. + #[inline] + unsafe fn buffer_write(&mut self, offset: usize, value: T) { + unsafe { + ptr::write(self.ptr_mut().add(offset), value); + } + } + + /// Returns a slice pointer into the buffer. + /// `range` must lie inside `0..self.capacity()`. + #[inline] + unsafe fn buffer_range(&self, range: core::ops::Range) -> *const [T] { + unsafe { ptr::slice_from_raw_parts(self.ptr().add(range.start), range.end - range.start) } + } + + /// Returns a slice pointer into the buffer. + /// `range` must lie inside `0..self.capacity()`. + #[inline] + unsafe fn buffer_range_mut(&mut self, range: core::ops::Range) -> *mut [T] { + unsafe { + ptr::slice_from_raw_parts_mut(self.ptr_mut().add(range.start), range.end - range.start) + } + } + + /// Returns `true` if the buffer is at full capacity. + #[inline] + fn is_full(&self) -> bool { + self.len == self.capacity() + } + + /// Returns the index in the underlying buffer for a given logical element index + addend. + #[inline] + fn wrap_add(&self, idx: usize, addend: usize) -> usize { + wrap_index(idx.wrapping_add(addend), self.capacity()) + } + + #[inline] + fn to_physical_idx(&self, idx: usize) -> usize { + self.wrap_add(self.head, idx) + } + + /// Returns the index in the underlying buffer for a given logical element index - subtrahend. 
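To make the search APIs above concrete, here is a worked sketch using the doc comment's partitioned sequence and a sorted deque (same assumptions as the earlier examples):

```rust
// Sketch only.
fn search_demo() {
    // All odd values first, then evens, matching the partition_point example above.
    let mut dq = SmallDeque::<i32, 8>::new();
    for v in [7, 15, 3, 5, 4, 12, 6] {
        dq.push_back(v);
    }
    // The first element for which `x % 2 != 0` is false sits at index 4.
    assert_eq!(dq.partition_point(|&x| x % 2 != 0), 4);

    // binary_search on a sorted deque: Ok(index) on a hit, Err(insertion point) on a miss.
    let mut sorted = SmallDeque::<i32, 8>::new();
    for v in [1, 3, 5, 7] {
        sorted.push_back(v);
    }
    assert_eq!(sorted.binary_search(&5), Ok(2));
    assert_eq!(sorted.binary_search(&4), Err(2));
}
```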
+ #[inline] + fn wrap_sub(&self, idx: usize, subtrahend: usize) -> usize { + wrap_index(idx.wrapping_sub(subtrahend).wrapping_add(self.capacity()), self.capacity()) + } + + /// Copies a contiguous block of memory len long from src to dst + #[inline] + unsafe fn copy(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + dst + len <= self.capacity(), + "cpy dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + debug_assert!( + src + len <= self.capacity(), + "cpy dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + unsafe { + ptr::copy(self.ptr().add(src), self.ptr_mut().add(dst), len); + } + } + + /// Copies a contiguous block of memory len long from src to dst + #[inline] + unsafe fn copy_nonoverlapping(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + dst + len <= self.capacity(), + "cno dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + debug_assert!( + src + len <= self.capacity(), + "cno dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + unsafe { + ptr::copy_nonoverlapping(self.ptr().add(src), self.ptr_mut().add(dst), len); + } + } + + /// Copies a potentially wrapping block of memory len long from src to dest. + /// (abs(dst - src) + len) must be no larger than capacity() (There must be at + /// most one continuous overlapping region between src and dest). + unsafe fn wrap_copy(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + core::cmp::min(src.abs_diff(dst), self.capacity() - src.abs_diff(dst)) + len + <= self.capacity(), + "wrc dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + + // If T is a ZST, don't do any copying. + if T::IS_ZST || src == dst || len == 0 { + return; + } + + let dst_after_src = self.wrap_sub(dst, src) < len; + + let src_pre_wrap_len = self.capacity() - src; + let dst_pre_wrap_len = self.capacity() - dst; + let src_wraps = src_pre_wrap_len < len; + let dst_wraps = dst_pre_wrap_len < len; + + match (dst_after_src, src_wraps, dst_wraps) { + (_, false, false) => { + // src doesn't wrap, dst doesn't wrap + // + // S . . . + // 1 [_ _ A A B B C C _] + // 2 [_ _ A A A A B B _] + // D . . . + // + unsafe { + self.copy(src, dst, len); + } + } + (false, false, true) => { + // dst before src, src doesn't wrap, dst wraps + // + // S . . . + // 1 [A A B B _ _ _ C C] + // 2 [A A B B _ _ _ A A] + // 3 [B B B B _ _ _ A A] + // . . D . + // + unsafe { + self.copy(src, dst, dst_pre_wrap_len); + self.copy(src + dst_pre_wrap_len, 0, len - dst_pre_wrap_len); + } + } + (true, false, true) => { + // src before dst, src doesn't wrap, dst wraps + // + // S . . . + // 1 [C C _ _ _ A A B B] + // 2 [B B _ _ _ A A B B] + // 3 [B B _ _ _ A A A A] + // . . D . + // + unsafe { + self.copy(src + dst_pre_wrap_len, 0, len - dst_pre_wrap_len); + self.copy(src, dst, dst_pre_wrap_len); + } + } + (false, true, false) => { + // dst before src, src wraps, dst doesn't wrap + // + // . . S . + // 1 [C C _ _ _ A A B B] + // 2 [C C _ _ _ B B B B] + // 3 [C C _ _ _ B B C C] + // D . . . + // + unsafe { + self.copy(src, dst, src_pre_wrap_len); + self.copy(0, dst + src_pre_wrap_len, len - src_pre_wrap_len); + } + } + (true, true, false) => { + // src before dst, src wraps, dst doesn't wrap + // + // . . S . + // 1 [A A B B _ _ _ C C] + // 2 [A A A A _ _ _ C C] + // 3 [C C A A _ _ _ C C] + // D . . . 
+ // + unsafe { + self.copy(0, dst + src_pre_wrap_len, len - src_pre_wrap_len); + self.copy(src, dst, src_pre_wrap_len); + } + } + (false, true, true) => { + // dst before src, src wraps, dst wraps + // + // . . . S . + // 1 [A B C D _ E F G H] + // 2 [A B C D _ E G H H] + // 3 [A B C D _ E G H A] + // 4 [B C C D _ E G H A] + // . . D . . + // + debug_assert!(dst_pre_wrap_len > src_pre_wrap_len); + let delta = dst_pre_wrap_len - src_pre_wrap_len; + unsafe { + self.copy(src, dst, src_pre_wrap_len); + self.copy(0, dst + src_pre_wrap_len, delta); + self.copy(delta, 0, len - dst_pre_wrap_len); + } + } + (true, true, true) => { + // src before dst, src wraps, dst wraps + // + // . . S . . + // 1 [A B C D _ E F G H] + // 2 [A A B D _ E F G H] + // 3 [H A B D _ E F G H] + // 4 [H A B D _ E F F G] + // . . . D . + // + debug_assert!(src_pre_wrap_len > dst_pre_wrap_len); + let delta = src_pre_wrap_len - dst_pre_wrap_len; + unsafe { + self.copy(0, delta, len - src_pre_wrap_len); + self.copy(self.capacity() - delta, 0, delta); + self.copy(src, dst, dst_pre_wrap_len); + } + } + } + } + + /// Copies all values from `src` to `dst`, wrapping around if needed. + /// Assumes capacity is sufficient. + #[inline] + unsafe fn copy_slice(&mut self, dst: usize, src: &[T]) { + debug_assert!(src.len() <= self.capacity()); + let head_room = self.capacity() - dst; + if src.len() <= head_room { + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), self.ptr_mut().add(dst), src.len()); + } + } else { + let (left, right) = src.split_at(head_room); + unsafe { + ptr::copy_nonoverlapping(left.as_ptr(), self.ptr_mut().add(dst), left.len()); + ptr::copy_nonoverlapping(right.as_ptr(), self.ptr_mut(), right.len()); + } + } + } + + /// Writes all values from `iter` to `dst`. + /// + /// # Safety + /// + /// Assumes no wrapping around happens. + /// Assumes capacity is sufficient. + #[inline] + unsafe fn write_iter( + &mut self, + dst: usize, + iter: impl Iterator, + written: &mut usize, + ) { + iter.enumerate().for_each(|(i, element)| unsafe { + self.buffer_write(dst + i, element); + *written += 1; + }); + } + + /// Writes all values from `iter` to `dst`, wrapping + /// at the end of the buffer and returns the number + /// of written values. + /// + /// # Safety + /// + /// Assumes that `iter` yields at most `len` items. + /// Assumes capacity is sufficient. + unsafe fn write_iter_wrapping( + &mut self, + dst: usize, + mut iter: impl Iterator, + len: usize, + ) -> usize { + struct Guard<'a, T, const N: usize> { + deque: &'a mut SmallDeque, + written: usize, + } + + impl<'a, T, const N: usize> Drop for Guard<'a, T, N> { + fn drop(&mut self) { + self.deque.len += self.written; + } + } + + let head_room = self.capacity() - dst; + + let mut guard = Guard { + deque: self, + written: 0, + }; + + if head_room >= len { + unsafe { guard.deque.write_iter(dst, iter, &mut guard.written) }; + } else { + unsafe { + guard.deque.write_iter( + dst, + ByRefSized(&mut iter).take(head_room), + &mut guard.written, + ); + guard.deque.write_iter(0, iter, &mut guard.written) + }; + } + + guard.written + } + + /// Frobs the head and tail sections around to handle the fact that we + /// just reallocated. Unsafe because it trusts old_capacity. 
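+    ///
+    /// As a worked sketch (values chosen purely for illustration): growing from
+    /// `old_capacity == 8` to a capacity of `16` with `head == 6` and `len == 4`
+    /// (elements at physical slots 6, 7, 0, 1) hits case C below: the two-element head
+    /// block is copied to the end of the grown buffer, `head` becomes `14`, and the
+    /// wrapped tail at slots 0 and 1 is left untouched.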
+ #[inline] + unsafe fn handle_capacity_increase(&mut self, old_capacity: usize) { + let new_capacity = self.capacity(); + debug_assert!(new_capacity >= old_capacity); + + // Move the shortest contiguous section of the ring buffer + // + // H := head + // L := last element (`self.to_physical_idx(self.len - 1)`) + // + // H L + // [o o o o o o o o ] + // H L + // A [o o o o o o o o . . . . . . . . ] + // L H + // [o o o o o o o o ] + // H L + // B [. . . o o o o o o o o . . . . . ] + // L H + // [o o o o o o o o ] + // L H + // C [o o o o o o . . . . . . . . o o ] + + // can't use is_contiguous() because the capacity is already updated. + if self.head <= old_capacity - self.len { + // A + // Nop + } else { + let head_len = old_capacity - self.head; + let tail_len = self.len - head_len; + if head_len > tail_len && new_capacity - old_capacity >= tail_len { + // B + unsafe { + self.copy_nonoverlapping(0, old_capacity, tail_len); + } + } else { + // C + let new_head = new_capacity - head_len; + unsafe { + // can't use copy_nonoverlapping here, because if e.g. head_len = 2 + // and new_capacity = old_capacity + 1, then the heads overlap. + self.copy(self.head, new_head, head_len); + } + self.head = new_head; + } + } + debug_assert!(self.head < self.capacity() || self.capacity() == 0); + } +} + +/// Returns the index in the underlying buffer for a given logical element index. +#[inline] +fn wrap_index(logical_index: usize, capacity: usize) -> usize { + debug_assert!( + (logical_index == 0 && capacity == 0) + || logical_index < capacity + || (logical_index - capacity) < capacity + ); + if logical_index >= capacity { + logical_index - capacity + } else { + logical_index + } +} + +impl PartialEq for SmallDeque { + fn eq(&self, other: &Self) -> bool { + if self.len != other.len() { + return false; + } + let (sa, sb) = self.as_slices(); + let (oa, ob) = other.as_slices(); + match sa.len().cmp(&oa.len()) { + Ordering::Equal => sa == oa && sb == ob, + Ordering::Less => { + // Always divisible in three sections, for example: + // self: [a b c|d e f] + // other: [0 1 2 3|4 5] + // front = 3, mid = 1, + // [a b c] == [0 1 2] && [d] == [3] && [e f] == [4 5] + let front = sa.len(); + let mid = oa.len() - front; + + let (oa_front, oa_mid) = oa.split_at(front); + let (sb_mid, sb_back) = sb.split_at(mid); + debug_assert_eq!(sa.len(), oa_front.len()); + debug_assert_eq!(sb_mid.len(), oa_mid.len()); + debug_assert_eq!(sb_back.len(), ob.len()); + sa == oa_front && sb_mid == oa_mid && sb_back == ob + } + Ordering::Greater => { + let front = oa.len(); + let mid = sa.len() - front; + + let (sa_front, sa_mid) = sa.split_at(front); + let (ob_mid, ob_back) = ob.split_at(mid); + debug_assert_eq!(sa_front.len(), oa.len()); + debug_assert_eq!(sa_mid.len(), ob_mid.len()); + debug_assert_eq!(sb.len(), ob_back.len()); + sa_front == oa && sa_mid == ob_mid && sb == ob_back + } + } + } +} + +impl Eq for SmallDeque {} + +macro_rules! __impl_slice_eq1 { + ([$($vars:tt)*] $lhs:ty, $rhs:ty, $($constraints:tt)*) => { + impl PartialEq<$rhs> for $lhs + where + T: PartialEq, + $($constraints)* + { + fn eq(&self, other: &$rhs) -> bool { + if self.len() != other.len() { + return false; + } + let (sa, sb) = self.as_slices(); + let (oa, ob) = other[..].split_at(sa.len()); + sa == oa && sb == ob + } + } + } +} + +__impl_slice_eq1! { [A: alloc::alloc::Allocator, const N: usize] SmallDeque, Vec, } +__impl_slice_eq1! { [const N: usize] SmallDeque, SmallVec<[U; N]>, } +__impl_slice_eq1! { [const N: usize] SmallDeque, &[U], } +__impl_slice_eq1! 
{ [const N: usize] SmallDeque, &mut [U], } +__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque, [U; M], } +__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque, &[U; M], } +__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque, &mut [U; M], } + +impl PartialOrd for SmallDeque { + fn partial_cmp(&self, other: &Self) -> Option { + self.iter().partial_cmp(other.iter()) + } +} + +impl Ord for SmallDeque { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.iter().cmp(other.iter()) + } +} + +impl core::hash::Hash for SmallDeque { + fn hash(&self, state: &mut H) { + state.write_length_prefix(self.len); + // It's not possible to use Hash::hash_slice on slices + // returned by as_slices method as their length can vary + // in otherwise identical deques. + // + // Hasher only guarantees equivalence for the exact same + // set of calls to its methods. + self.iter().for_each(|elem| elem.hash(state)); + } +} + +impl Index for SmallDeque { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &T { + self.get(index).expect("Out of bounds access") + } +} + +impl IndexMut for SmallDeque { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut T { + self.get_mut(index).expect("Out of bounds access") + } +} + +impl FromIterator for SmallDeque { + fn from_iter>(iter: I) -> Self { + SpecFromIter::spec_from_iter(iter.into_iter()) + } +} + +impl IntoIterator for SmallDeque { + type IntoIter = IntoIter; + type Item = T; + + /// Consumes the deque into a front-to-back iterator yielding elements by + /// value. + fn into_iter(self) -> IntoIter { + IntoIter::new(self) + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a SmallDeque { + type IntoIter = Iter<'a, T>; + type Item = &'a T; + + fn into_iter(self) -> Iter<'a, T> { + self.iter() + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a mut SmallDeque { + type IntoIter = IterMut<'a, T>; + type Item = &'a mut T; + + fn into_iter(self) -> IterMut<'a, T> { + self.iter_mut() + } +} + +impl Extend for SmallDeque { + fn extend>(&mut self, iter: I) { + >::spec_extend(self, iter.into_iter()); + } + + #[inline] + fn extend_one(&mut self, elem: T) { + self.push_back(elem); + } + + #[inline] + fn extend_reserve(&mut self, additional: usize) { + self.reserve(additional); + } + + #[inline] + unsafe fn extend_one_unchecked(&mut self, item: T) { + // SAFETY: Our preconditions ensure the space has been reserved, and `extend_reserve` is implemented correctly. + unsafe { + self.push_unchecked(item); + } + } +} + +impl<'a, T: 'a + Copy, const N: usize> Extend<&'a T> for SmallDeque { + fn extend>(&mut self, iter: I) { + self.spec_extend(iter.into_iter()); + } + + #[inline] + fn extend_one(&mut self, &elem: &'a T) { + self.push_back(elem); + } + + #[inline] + fn extend_reserve(&mut self, additional: usize) { + self.reserve(additional); + } + + #[inline] + unsafe fn extend_one_unchecked(&mut self, &item: &'a T) { + // SAFETY: Our preconditions ensure the space has been reserved, and `extend_reserve` is implemented correctly. + unsafe { + self.push_unchecked(item); + } + } +} + +impl fmt::Debug for SmallDeque { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.iter()).finish() + } +} + +impl From> for SmallDeque { + /// Turn a [`SmallVec<[T; N]>`] into a [`SmallDeque`]. 
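+    ///
+    /// An illustrative sketch of the conversion (hypothetical values; marked `ignore`
+    /// since it assumes the surrounding crate layout):
+    ///
+    /// ```ignore
+    /// use smallvec::{smallvec, SmallVec};
+    ///
+    /// let sv: SmallVec<[u32; 4]> = smallvec![1, 2, 3];
+    /// let dq: SmallDeque<u32, 4> = SmallDeque::from(sv);
+    /// assert_eq!(dq, [1, 2, 3]);
+    /// ```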
+    ///
+    /// [`SmallVec<[T; N]>`]: smallvec::SmallVec
+    /// [`SmallDeque`]: crate::adt::SmallDeque
+    ///
+    /// This conversion is guaranteed to run in *O*(1) time
+    /// and to not re-allocate the `SmallVec`'s buffer or allocate
+    /// any additional memory.
+    #[inline]
+    fn from(buf: SmallVec<[T; N]>) -> Self {
+        let len = buf.len();
+        Self { head: 0, len, buf }
+    }
+}
+
+impl<T, const N: usize> From<SmallDeque<T, N>> for SmallVec<[T; N]> {
+    /// Turn a [`SmallDeque`] into a [`SmallVec<[T; N]>`].
+    ///
+    /// [`SmallVec<[T; N]>`]: smallvec::SmallVec
+    /// [`SmallDeque`]: crate::adt::SmallDeque
+    ///
+    /// This never needs to re-allocate, but does need to do *O*(*n*) data movement if
+    /// the circular buffer doesn't happen to be at the beginning of the allocation.
+    fn from(mut other: SmallDeque<T, N>) -> Self {
+        use core::mem::ManuallyDrop;
+
+        other.make_contiguous();
+
+        unsafe {
+            if other.buf.spilled() {
+                let mut other = ManuallyDrop::new(other);
+                let buf = other.buf.as_mut_ptr();
+                let len = other.len();
+                let cap = other.capacity();
+
+                if other.head != 0 {
+                    ptr::copy(buf.add(other.head), buf, len);
+                }
+                SmallVec::from_raw_parts(buf, len, cap)
+            } else {
+                // `other` is entirely stack-allocated, so we need to produce a new copy that
+                // has all of the elements starting at index 0, if not already
+                if other.head == 0 {
+                    // Steal the underlying vec, and make sure that the length is set
+                    let mut buf = other.buf;
+                    buf.set_len(other.len);
+                    buf
+                } else {
+                    let mut other = ManuallyDrop::new(other);
+                    let ptr = other.buf.as_mut_ptr();
+                    let len = other.len();
+
+                    // Construct an uninitialized array on the stack of the same size as the target
+                    // SmallVec's inline size, "move" `len` items into it, and then construct the
+                    // SmallVec from the raw buffer and len
+                    let mut buf = core::mem::MaybeUninit::<T>::uninit_array::<N>();
+                    let buf_ptr = core::mem::MaybeUninit::slice_as_mut_ptr(&mut buf);
+                    ptr::copy(ptr.add(other.head), buf_ptr, len);
+                    // While we are technically letting a subset of elements in the array that
+                    // never got initialized be assumed to have been initialized here, that fact
+                    // is never material: no references are ever created to those items, and the
+                    // array is never dropped, as it is immediately placed in a ManuallyDrop, and
+                    // the vector length is set to `len` before any access can be made to the
+                    // vector
+                    SmallVec::from_buf_and_len(core::mem::MaybeUninit::array_assume_init(buf), len)
+                }
+            }
+        }
+    }
+}
+
+impl<T, const N: usize> From<[T; N]> for SmallDeque<T, N> {
+    /// Converts a `[T; N]` into a `SmallDeque`.
+    fn from(arr: [T; N]) -> Self {
+        use core::mem::ManuallyDrop;
+
+        let mut deq = SmallDeque::<_, N>::with_capacity(N);
+        let arr = ManuallyDrop::new(arr);
+        if !<T>::IS_ZST {
+            // SAFETY: SmallDeque::with_capacity ensures that there is enough capacity.
+            unsafe {
+                ptr::copy_nonoverlapping(arr.as_ptr(), deq.ptr_mut(), N);
+            }
+        }
+        deq.head = 0;
+        deq.len = N;
+        deq
+    }
+}
+
+/// An iterator over the elements of a `SmallDeque`.
+///
+/// This `struct` is created by the [`iter`] method on [`SmallDeque`]. See its
+/// documentation for more.
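+///
+/// A minimal usage sketch, assuming the `push_back`/`push_front` API added elsewhere
+/// in this patch (marked `ignore` since it depends on the crate layout):
+///
+/// ```ignore
+/// let mut dq = SmallDeque::<u32, 4>::new();
+/// dq.push_back(2);
+/// dq.push_front(1);
+/// let collected: Vec<u32> = dq.iter().copied().collect();
+/// assert_eq!(collected, vec![1, 2]);
+/// ```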
+/// +/// [`iter`]: SmallDeque::iter +#[derive(Clone)] +pub struct Iter<'a, T: 'a> { + i1: core::slice::Iter<'a, T>, + i2: core::slice::Iter<'a, T>, +} + +impl<'a, T> Iter<'a, T> { + pub(super) fn new(i1: core::slice::Iter<'a, T>, i2: core::slice::Iter<'a, T>) -> Self { + Self { i1, i2 } + } +} + +impl fmt::Debug for Iter<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Iter") + .field(&self.i1.as_slice()) + .field(&self.i2.as_slice()) + .finish() + } +} + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + + #[inline] + fn next(&mut self) -> Option<&'a T> { + match self.i1.next() { + Some(val) => Some(val), + None => { + // most of the time, the iterator will either always + // call next(), or always call next_back(). By swapping + // the iterators once the first one is empty, we ensure + // that the first branch is taken as often as possible, + // without sacrificing correctness, as i1 is empty anyways + core::mem::swap(&mut self.i1, &mut self.i2); + self.i1.next() + } + } + } + + fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + let remaining = self.i1.advance_by(n); + match remaining { + Ok(()) => Ok(()), + Err(n) => { + core::mem::swap(&mut self.i1, &mut self.i2); + self.i1.advance_by(n.get()) + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } + + fn fold(self, accum: Acc, mut f: F) -> Acc + where + F: FnMut(Acc, Self::Item) -> Acc, + { + let accum = self.i1.fold(accum, &mut f); + self.i2.fold(accum, &mut f) + } + + fn try_fold(&mut self, init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + let acc = self.i1.try_fold(init, &mut f)?; + self.i2.try_fold(acc, &mut f) + } + + #[inline] + fn last(mut self) -> Option<&'a T> { + self.next_back() + } +} + +impl<'a, T> DoubleEndedIterator for Iter<'a, T> { + #[inline] + fn next_back(&mut self) -> Option<&'a T> { + match self.i2.next_back() { + Some(val) => Some(val), + None => { + // most of the time, the iterator will either always + // call next(), or always call next_back(). By swapping + // the iterators once the second one is empty, we ensure + // that the first branch is taken as often as possible, + // without sacrificing correctness, as i2 is empty anyways + core::mem::swap(&mut self.i1, &mut self.i2); + self.i2.next_back() + } + } + } + + fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + match self.i2.advance_back_by(n) { + Ok(()) => Ok(()), + Err(n) => { + core::mem::swap(&mut self.i1, &mut self.i2); + self.i2.advance_back_by(n.get()) + } + } + } + + fn rfold(self, accum: Acc, mut f: F) -> Acc + where + F: FnMut(Acc, Self::Item) -> Acc, + { + let accum = self.i2.rfold(accum, &mut f); + self.i1.rfold(accum, &mut f) + } + + fn try_rfold(&mut self, init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + let acc = self.i2.try_rfold(init, &mut f)?; + self.i1.try_rfold(acc, &mut f) + } +} + +impl ExactSizeIterator for Iter<'_, T> { + fn len(&self) -> usize { + self.i1.len() + self.i2.len() + } + + fn is_empty(&self) -> bool { + self.i1.is_empty() && self.i2.is_empty() + } +} + +impl core::iter::FusedIterator for Iter<'_, T> {} + +unsafe impl core::iter::TrustedLen for Iter<'_, T> {} + +/// A mutable iterator over the elements of a `SmallDeque`. +/// +/// This `struct` is created by the [`iter_mut`] method on [`SmallDeque`]. See its +/// documentation for more. 
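+///
+/// A minimal mutation sketch (illustrative only; marked `ignore`):
+///
+/// ```ignore
+/// let mut dq: SmallDeque<u32, 3> = [1, 2, 3].into();
+/// for x in dq.iter_mut() {
+///     *x *= 10;
+/// }
+/// assert_eq!(dq, [10, 20, 30]);
+/// ```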
+/// +/// [`iter_mut`]: SmallDeque::iter_mut +pub struct IterMut<'a, T: 'a> { + i1: core::slice::IterMut<'a, T>, + i2: core::slice::IterMut<'a, T>, +} + +impl<'a, T> IterMut<'a, T> { + pub(super) fn new(i1: core::slice::IterMut<'a, T>, i2: core::slice::IterMut<'a, T>) -> Self { + Self { i1, i2 } + } +} + +impl fmt::Debug for IterMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IterMut") + .field(&self.i1.as_slice()) + .field(&self.i2.as_slice()) + .finish() + } +} + +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + #[inline] + fn next(&mut self) -> Option<&'a mut T> { + match self.i1.next() { + Some(val) => Some(val), + None => { + // most of the time, the iterator will either always + // call next(), or always call next_back(). By swapping + // the iterators once the first one is empty, we ensure + // that the first branch is taken as often as possible, + // without sacrificing correctness, as i1 is empty anyways + core::mem::swap(&mut self.i1, &mut self.i2); + self.i1.next() + } + } + } + + fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + match self.i1.advance_by(n) { + Ok(()) => Ok(()), + Err(remaining) => { + core::mem::swap(&mut self.i1, &mut self.i2); + self.i1.advance_by(remaining.get()) + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } + + fn fold(self, accum: Acc, mut f: F) -> Acc + where + F: FnMut(Acc, Self::Item) -> Acc, + { + let accum = self.i1.fold(accum, &mut f); + self.i2.fold(accum, &mut f) + } + + fn try_fold(&mut self, init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + let acc = self.i1.try_fold(init, &mut f)?; + self.i2.try_fold(acc, &mut f) + } + + #[inline] + fn last(mut self) -> Option<&'a mut T> { + self.next_back() + } +} + +impl<'a, T> DoubleEndedIterator for IterMut<'a, T> { + #[inline] + fn next_back(&mut self) -> Option<&'a mut T> { + match self.i2.next_back() { + Some(val) => Some(val), + None => { + // most of the time, the iterator will either always + // call next(), or always call next_back(). By swapping + // the iterators once the first one is empty, we ensure + // that the first branch is taken as often as possible, + // without sacrificing correctness, as i2 is empty anyways + core::mem::swap(&mut self.i1, &mut self.i2); + self.i2.next_back() + } + } + } + + fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + match self.i2.advance_back_by(n) { + Ok(()) => Ok(()), + Err(remaining) => { + core::mem::swap(&mut self.i1, &mut self.i2); + self.i2.advance_back_by(remaining.get()) + } + } + } + + fn rfold(self, accum: Acc, mut f: F) -> Acc + where + F: FnMut(Acc, Self::Item) -> Acc, + { + let accum = self.i2.rfold(accum, &mut f); + self.i1.rfold(accum, &mut f) + } + + fn try_rfold(&mut self, init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + let acc = self.i2.try_rfold(init, &mut f)?; + self.i1.try_rfold(acc, &mut f) + } +} + +impl ExactSizeIterator for IterMut<'_, T> { + fn len(&self) -> usize { + self.i1.len() + self.i2.len() + } + + fn is_empty(&self) -> bool { + self.i1.is_empty() && self.i2.is_empty() + } +} + +impl core::iter::FusedIterator for IterMut<'_, T> {} + +unsafe impl core::iter::TrustedLen for IterMut<'_, T> {} + +/// An owning iterator over the elements of a `SmallDeque`. 
+/// +/// This `struct` is created by the [`into_iter`] method on [`SmallDeque`] +/// (provided by the [`IntoIterator`] trait). See its documentation for more. +/// +/// [`into_iter`]: SmallDeque::into_iter +#[derive(Clone)] +pub struct IntoIter { + inner: SmallDeque, +} + +impl IntoIter { + pub(super) fn new(inner: SmallDeque) -> Self { + IntoIter { inner } + } + + pub(super) fn into_smalldeque(self) -> SmallDeque { + self.inner + } +} + +impl fmt::Debug for IntoIter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IntoIter").field(&self.inner).finish() + } +} + +impl Iterator for IntoIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.inner.pop_front() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.inner.len(); + (len, Some(len)) + } + + #[inline] + fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + let len = self.inner.len; + let rem = if len < n { + self.inner.clear(); + n - len + } else { + self.inner.drain(..n); + 0 + }; + core::num::NonZero::new(rem).map_or(Ok(()), Err) + } + + #[inline] + fn count(self) -> usize { + self.inner.len + } + + fn try_fold(&mut self, mut init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + struct Guard<'a, T, const M: usize> { + deque: &'a mut SmallDeque, + // `consumed <= deque.len` always holds. + consumed: usize, + } + + impl<'a, T, const M: usize> Drop for Guard<'a, T, M> { + fn drop(&mut self) { + self.deque.len -= self.consumed; + self.deque.head = self.deque.to_physical_idx(self.consumed); + } + } + + let mut guard = Guard { + deque: &mut self.inner, + consumed: 0, + }; + + let (head, tail) = guard.deque.as_slices(); + + init = head + .iter() + .map(|elem| { + guard.consumed += 1; + // SAFETY: Because we incremented `guard.consumed`, the + // deque effectively forgot the element, so we can take + // ownership + unsafe { ptr::read(elem) } + }) + .try_fold(init, &mut f)?; + + tail.iter() + .map(|elem| { + guard.consumed += 1; + // SAFETY: Same as above. + unsafe { ptr::read(elem) } + }) + .try_fold(init, &mut f) + } + + #[inline] + fn fold(mut self, init: B, mut f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + match self.try_fold(init, |b, item| Ok::(f(b, item))) { + Ok(b) => b, + Err(e) => match e {}, + } + } + + #[inline] + fn last(mut self) -> Option { + self.inner.pop_back() + } + + fn next_chunk( + &mut self, + ) -> Result<[Self::Item; N], core::array::IntoIter> { + let mut raw_arr = core::mem::MaybeUninit::uninit_array(); + let raw_arr_ptr = raw_arr.as_mut_ptr().cast(); + let (head, tail) = self.inner.as_slices(); + + if head.len() >= N { + // SAFETY: By manually adjusting the head and length of the deque, we effectively + // make it forget the first `N` elements, so taking ownership of them is safe. + unsafe { ptr::copy_nonoverlapping(head.as_ptr(), raw_arr_ptr, N) }; + self.inner.head = self.inner.to_physical_idx(N); + self.inner.len -= N; + // SAFETY: We initialized the entire array with items from `head` + return Ok(unsafe { raw_arr.transpose().assume_init() }); + } + + // SAFETY: Same argument as above. + unsafe { ptr::copy_nonoverlapping(head.as_ptr(), raw_arr_ptr, head.len()) }; + let remaining = N - head.len(); + + if tail.len() >= remaining { + // SAFETY: Same argument as above. 
+ unsafe { + ptr::copy_nonoverlapping(tail.as_ptr(), raw_arr_ptr.add(head.len()), remaining) + }; + self.inner.head = self.inner.to_physical_idx(N); + self.inner.len -= N; + // SAFETY: We initialized the entire array with items from `head` and `tail` + Ok(unsafe { raw_arr.transpose().assume_init() }) + } else { + // SAFETY: Same argument as above. + unsafe { + ptr::copy_nonoverlapping(tail.as_ptr(), raw_arr_ptr.add(head.len()), tail.len()) + }; + let init = head.len() + tail.len(); + // We completely drained all the deques elements. + self.inner.head = 0; + self.inner.len = 0; + // SAFETY: We copied all elements from both slices to the beginning of the array, so + // the given range is initialized. + Err(unsafe { core::array::IntoIter::new_unchecked(raw_arr, 0..init) }) + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + self.inner.pop_back() + } + + #[inline] + fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero> { + let len = self.inner.len; + let rem = if len < n { + self.inner.clear(); + n - len + } else { + self.inner.truncate(len - n); + 0 + }; + core::num::NonZero::new(rem).map_or(Ok(()), Err) + } + + fn try_rfold(&mut self, mut init: B, mut f: F) -> R + where + F: FnMut(B, Self::Item) -> R, + R: core::ops::Try, + { + struct Guard<'a, T, const N: usize> { + deque: &'a mut SmallDeque, + // `consumed <= deque.len` always holds. + consumed: usize, + } + + impl<'a, T, const N: usize> Drop for Guard<'a, T, N> { + fn drop(&mut self) { + self.deque.len -= self.consumed; + } + } + + let mut guard = Guard { + deque: &mut self.inner, + consumed: 0, + }; + + let (head, tail) = guard.deque.as_slices(); + + init = tail + .iter() + .map(|elem| { + guard.consumed += 1; + // SAFETY: See `try_fold`'s safety comment. + unsafe { ptr::read(elem) } + }) + .try_rfold(init, &mut f)?; + + head.iter() + .map(|elem| { + guard.consumed += 1; + // SAFETY: Same as above. + unsafe { ptr::read(elem) } + }) + .try_rfold(init, &mut f) + } + + #[inline] + fn rfold(mut self, init: B, mut f: F) -> B + where + F: FnMut(B, Self::Item) -> B, + { + match self.try_rfold(init, |b, item| Ok::(f(b, item))) { + Ok(b) => b, + Err(e) => match e {}, + } + } +} + +impl ExactSizeIterator for IntoIter { + #[inline] + fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl core::iter::FusedIterator for IntoIter {} + +unsafe impl core::iter::TrustedLen for IntoIter {} + +/// A draining iterator over the elements of a `SmallDeque`. +/// +/// This `struct` is created by the [`drain`] method on [`SmallDeque`]. See its +/// documentation for more. 
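+///
+/// A minimal draining sketch (illustrative values; marked `ignore`):
+///
+/// ```ignore
+/// let mut dq: SmallDeque<u32, 8> = (0u32..5).collect();
+/// let drained: Vec<u32> = dq.drain(1..3).collect();
+/// assert_eq!(drained, vec![1, 2]);
+/// assert_eq!(dq, [0, 3, 4]);
+/// ```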
+/// +/// [`drain`]: SmallDeque::drain +pub struct Drain<'a, T: 'a, const N: usize> { + // We can't just use a &mut SmallDeque, as that would make Drain invariant over T + // and we want it to be covariant instead + deque: NonNull>, + // drain_start is stored in deque.len + drain_len: usize, + // index into the logical array, not the physical one (always lies in [0..deque.len)) + idx: usize, + // number of elements remaining after dropping the drain + new_len: usize, + remaining: usize, + // Needed to make Drain covariant over T + _marker: core::marker::PhantomData<&'a T>, +} + +impl<'a, T, const N: usize> Drain<'a, T, N> { + pub(super) unsafe fn new( + deque: &'a mut SmallDeque, + drain_start: usize, + drain_len: usize, + ) -> Self { + let orig_len = core::mem::replace(&mut deque.len, drain_start); + let new_len = orig_len - drain_len; + Drain { + deque: NonNull::from(deque), + drain_len, + idx: drain_start, + new_len, + remaining: drain_len, + _marker: core::marker::PhantomData, + } + } + + // Only returns pointers to the slices, as that's all we need + // to drop them. May only be called if `self.remaining != 0`. + unsafe fn as_slices(&mut self) -> (*mut [T], *mut [T]) { + unsafe { + let deque = self.deque.as_mut(); + + // We know that `self.idx + self.remaining <= deque.len <= usize::MAX`, so this won't overflow. + let logical_remaining_range = self.idx..self.idx + self.remaining; + + // SAFETY: `logical_remaining_range` represents the + // range into the logical buffer of elements that + // haven't been drained yet, so they're all initialized, + // and `slice::range(start..end, end) == start..end`, + // so the preconditions for `slice_ranges` are met. + let (a_range, b_range) = + deque.slice_ranges(logical_remaining_range.clone(), logical_remaining_range.end); + (deque.buffer_range_mut(a_range), deque.buffer_range_mut(b_range)) + } + } +} + +impl fmt::Debug for Drain<'_, T, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Drain") + .field(&self.drain_len) + .field(&self.idx) + .field(&self.new_len) + .field(&self.remaining) + .finish() + } +} + +unsafe impl Sync for Drain<'_, T, N> {} +unsafe impl Send for Drain<'_, T, N> {} + +impl Drop for Drain<'_, T, N> { + fn drop(&mut self) { + struct DropGuard<'r, 'a, T, const N: usize>(&'r mut Drain<'a, T, N>); + + let guard = DropGuard(self); + + if core::mem::needs_drop::() && guard.0.remaining != 0 { + unsafe { + // SAFETY: We just checked that `self.remaining != 0`. + let (front, back) = guard.0.as_slices(); + // since idx is a logical index, we don't need to worry about wrapping. + guard.0.idx += front.len(); + guard.0.remaining -= front.len(); + ptr::drop_in_place(front); + guard.0.remaining = 0; + ptr::drop_in_place(back); + } + } + + // Dropping `guard` handles moving the remaining elements into place. + impl<'r, 'a, T, const N: usize> Drop for DropGuard<'r, 'a, T, N> { + #[inline] + fn drop(&mut self) { + if core::mem::needs_drop::() && self.0.remaining != 0 { + unsafe { + // SAFETY: We just checked that `self.remaining != 0`. 
+ let (front, back) = self.0.as_slices(); + ptr::drop_in_place(front); + ptr::drop_in_place(back); + } + } + + let source_deque = unsafe { self.0.deque.as_mut() }; + + let drain_len = self.0.drain_len; + let new_len = self.0.new_len; + + if T::IS_ZST { + // no need to copy around any memory if T is a ZST + source_deque.len = new_len; + return; + } + + let head_len = source_deque.len; // #elements in front of the drain + let tail_len = new_len - head_len; // #elements behind the drain + + // Next, we will fill the hole left by the drain with as few writes as possible. + // The code below handles the following control flow and reduces the amount of + // branches under the assumption that `head_len == 0 || tail_len == 0`, i.e. + // draining at the front or at the back of the dequeue is especially common. + // + // H = "head index" = `deque.head` + // h = elements in front of the drain + // d = elements in the drain + // t = elements behind the drain + // + // Note that the buffer may wrap at any point and the wrapping is handled by + // `wrap_copy` and `to_physical_idx`. + // + // Case 1: if `head_len == 0 && tail_len == 0` + // Everything was drained, reset the head index back to 0. + // H + // [ . . . . . d d d d . . . . . ] + // H + // [ . . . . . . . . . . . . . . ] + // + // Case 2: else if `tail_len == 0` + // Don't move data or the head index. + // H + // [ . . . h h h h d d d d . . . ] + // H + // [ . . . h h h h . . . . . . . ] + // + // Case 3: else if `head_len == 0` + // Don't move data, but move the head index. + // H + // [ . . . d d d d t t t t . . . ] + // H + // [ . . . . . . . t t t t . . . ] + // + // Case 4: else if `tail_len <= head_len` + // Move data, but not the head index. + // H + // [ . . h h h h d d d d t t . . ] + // H + // [ . . h h h h t t . . . . . . ] + // + // Case 5: else + // Move data and the head index. + // H + // [ . . h h d d d d t t t t . . ] + // H + // [ . . . . . . h h t t t t . . ] + + // When draining at the front (`.drain(..n)`) or at the back (`.drain(n..)`), + // we don't need to copy any data. The number of elements copied would be 0. + if head_len != 0 && tail_len != 0 { + join_head_and_tail_wrapping(source_deque, drain_len, head_len, tail_len); + // Marking this function as cold helps LLVM to eliminate it entirely if + // this branch is never taken. + // We use `#[cold]` instead of `#[inline(never)]`, because inlining this + // function into the general case (`.drain(n..m)`) is fine. + // See `tests/codegen/vecdeque-drain.rs` for a test. + #[cold] + fn join_head_and_tail_wrapping( + source_deque: &mut SmallDeque, + drain_len: usize, + head_len: usize, + tail_len: usize, + ) { + // Pick whether to move the head or the tail here. + let (src, dst, len); + if head_len < tail_len { + src = source_deque.head; + dst = source_deque.to_physical_idx(drain_len); + len = head_len; + } else { + src = source_deque.to_physical_idx(head_len + drain_len); + dst = source_deque.to_physical_idx(head_len); + len = tail_len; + }; + + unsafe { + source_deque.wrap_copy(src, dst, len); + } + } + } + + if new_len == 0 { + // Special case: If the entire dequeue was drained, reset the head back to 0, + // like `.clear()` does. + source_deque.head = 0; + } else if head_len < tail_len { + // If we moved the head above, then we need to adjust the head index here. 
+ source_deque.head = source_deque.to_physical_idx(drain_len); + } + source_deque.len = new_len; + } + } + } +} + +impl Iterator for Drain<'_, T, N> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let wrapped_idx = unsafe { self.deque.as_ref().to_physical_idx(self.idx) }; + self.idx += 1; + self.remaining -= 1; + Some(unsafe { self.deque.as_mut().buffer_read(wrapped_idx) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining; + (len, Some(len)) + } +} + +impl DoubleEndedIterator for Drain<'_, T, N> { + #[inline] + fn next_back(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + self.remaining -= 1; + let wrapped_idx = unsafe { self.deque.as_ref().to_physical_idx(self.idx + self.remaining) }; + Some(unsafe { self.deque.as_mut().buffer_read(wrapped_idx) }) + } +} + +impl ExactSizeIterator for Drain<'_, T, N> {} + +impl core::iter::FusedIterator for Drain<'_, T, N> {} + +/// Specialization trait used for `SmallDeque::from_iter` +trait SpecFromIter { + fn spec_from_iter(iter: I) -> Self; +} + +impl SpecFromIter for SmallDeque +where + I: Iterator, +{ + default fn spec_from_iter(iterator: I) -> Self { + // Since converting is O(1) now, just re-use the `SmallVec` logic for + // anything where we can't do something extra-special for `SmallDeque`, + // especially as that could save us some monomorphization work + // if one uses the same iterators (like slice ones) with both. + SmallVec::from_iter(iterator).into() + } +} + +impl SpecFromIter> for SmallDeque { + #[inline] + fn spec_from_iter(iterator: IntoIter) -> Self { + iterator.into_smalldeque() + } +} + +// Specialization trait used for SmallDeque::extend +trait SpecExtend { + fn spec_extend(&mut self, iter: I); +} + +impl SpecExtend for SmallDeque +where + I: Iterator, +{ + default fn spec_extend(&mut self, mut iter: I) { + // This function should be the moral equivalent of: + // + // for item in iter { + // self.push_back(item); + // } + + // May only be called if `deque.len() < deque.capacity()` + unsafe fn push_unchecked(deque: &mut SmallDeque, element: T) { + // SAFETY: Because of the precondition, it's guaranteed that there is space + // in the logical array after the last element. + unsafe { deque.buffer_write(deque.to_physical_idx(deque.len), element) }; + // This can't overflow because `deque.len() < deque.capacity() <= usize::MAX`. + deque.len += 1; + } + + while let Some(element) = iter.next() { + let (lower, _) = iter.size_hint(); + self.reserve(lower.saturating_add(1)); + + // SAFETY: We just reserved space for at least one element. + unsafe { push_unchecked(self, element) }; + + // Inner loop to avoid repeatedly calling `reserve`. + while self.len < self.capacity() { + let Some(element) = iter.next() else { + return; + }; + // SAFETY: The loop condition guarantees that `self.len() < self.capacity()`. + unsafe { push_unchecked(self, element) }; + } + } + } +} + +impl SpecExtend for SmallDeque +where + I: core::iter::TrustedLen, +{ + default fn spec_extend(&mut self, iter: I) { + // This is the case for a TrustedLen iterator. 
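+        // For instance, for a `TrustedLen` iterator such as `0..n`, `size_hint()` is
+        // `(n, Some(n))`, so `additional == n` below and a single `reserve(n)` covers
+        // the whole write.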
+ let (low, high) = iter.size_hint(); + if let Some(additional) = high { + debug_assert_eq!( + low, + additional, + "TrustedLen iterator's size hint is not exact: {:?}", + (low, high) + ); + self.reserve(additional); + + let written = unsafe { + self.write_iter_wrapping(self.to_physical_idx(self.len), iter, additional) + }; + + debug_assert_eq!( + additional, written, + "The number of items written to SmallDeque doesn't match the TrustedLen size hint" + ); + } else { + // Per TrustedLen contract a `None` upper bound means that the iterator length + // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway. + // Since the other branch already panics eagerly (via `reserve()`) we do the same here. + // This avoids additional codegen for a fallback code path which would eventually + // panic anyway. + panic!("capacity overflow"); + } + } +} + +impl<'a, T: 'a, I, const N: usize> SpecExtend<&'a T, I> for SmallDeque +where + I: Iterator, + T: Copy, +{ + default fn spec_extend(&mut self, iterator: I) { + self.spec_extend(iterator.copied()) + } +} + +impl<'a, T: 'a, const N: usize> SpecExtend<&'a T, core::slice::Iter<'a, T>> for SmallDeque +where + T: Copy, +{ + fn spec_extend(&mut self, iterator: core::slice::Iter<'a, T>) { + let slice = iterator.as_slice(); + self.reserve(slice.len()); + + unsafe { + self.copy_slice(self.to_physical_idx(self.len), slice); + self.len += slice.len(); + } + } +} + +#[cfg(test)] +mod tests { + use core::iter::TrustedLen; + + use smallvec::SmallVec; + + use super::*; + + #[test] + fn test_swap_front_back_remove() { + fn test(back: bool) { + // This test checks that every single combination of tail position and length is tested. + // Capacity 15 should be large enough to cover every case. + let mut tester = SmallDeque::<_, 16>::with_capacity(15); + let usable_cap = tester.capacity(); + let final_len = usable_cap / 2; + + for len in 0..final_len { + let expected: SmallDeque<_, 16> = if back { + (0..len).collect() + } else { + (0..len).rev().collect() + }; + for head_pos in 0..usable_cap { + tester.head = head_pos; + tester.len = 0; + if back { + for i in 0..len * 2 { + tester.push_front(i); + } + for i in 0..len { + assert_eq!(tester.swap_remove_back(i), Some(len * 2 - 1 - i)); + } + } else { + for i in 0..len * 2 { + tester.push_back(i); + } + for i in 0..len { + let idx = tester.len() - 1 - i; + assert_eq!(tester.swap_remove_front(idx), Some(len * 2 - 1 - i)); + } + } + assert!(tester.head <= tester.capacity()); + assert!(tester.len <= tester.capacity()); + assert_eq!(tester, expected); + } + } + } + test(true); + test(false); + } + + #[test] + fn test_insert() { + // This test checks that every single combination of tail position, length, and + // insertion position is tested. Capacity 15 should be large enough to cover every case. + + let mut tester = SmallDeque::<_, 16>::with_capacity(15); + // can't guarantee we got 15, so have to get what we got. 
+ // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else + // this test isn't covering what it wants to + let cap = tester.capacity(); + + // len is the length *after* insertion + let minlen = if cfg!(miri) { cap - 1 } else { 1 }; // Miri is too slow + for len in minlen..cap { + // 0, 1, 2, .., len - 1 + let expected = (0..).take(len).collect::>(); + for head_pos in 0..cap { + for to_insert in 0..len { + tester.head = head_pos; + tester.len = 0; + for i in 0..len { + if i != to_insert { + tester.push_back(i); + } + } + tester.insert(to_insert, to_insert); + assert!(tester.head <= tester.capacity()); + assert!(tester.len <= tester.capacity()); + assert_eq!(tester, expected); + } + } + } + } + + #[test] + fn test_get() { + let mut tester = SmallDeque::<_, 16>::new(); + tester.push_back(1); + tester.push_back(2); + tester.push_back(3); + + assert_eq!(tester.len(), 3); + + assert_eq!(tester.get(1), Some(&2)); + assert_eq!(tester.get(2), Some(&3)); + assert_eq!(tester.get(0), Some(&1)); + assert_eq!(tester.get(3), None); + + tester.remove(0); + + assert_eq!(tester.len(), 2); + assert_eq!(tester.get(0), Some(&2)); + assert_eq!(tester.get(1), Some(&3)); + assert_eq!(tester.get(2), None); + } + + #[test] + fn test_get_mut() { + let mut tester = SmallDeque::<_, 16>::new(); + tester.push_back(1); + tester.push_back(2); + tester.push_back(3); + + assert_eq!(tester.len(), 3); + + if let Some(elem) = tester.get_mut(0) { + assert_eq!(*elem, 1); + *elem = 10; + } + + if let Some(elem) = tester.get_mut(2) { + assert_eq!(*elem, 3); + *elem = 30; + } + + assert_eq!(tester.get(0), Some(&10)); + assert_eq!(tester.get(2), Some(&30)); + assert_eq!(tester.get_mut(3), None); + + tester.remove(2); + + assert_eq!(tester.len(), 2); + assert_eq!(tester.get(0), Some(&10)); + assert_eq!(tester.get(1), Some(&2)); + assert_eq!(tester.get(2), None); + } + + #[test] + fn test_swap() { + let mut tester = SmallDeque::<_, 3>::new(); + tester.push_back(1); + tester.push_back(2); + tester.push_back(3); + + assert_eq!(tester, [1, 2, 3]); + + tester.swap(0, 0); + assert_eq!(tester, [1, 2, 3]); + tester.swap(0, 1); + assert_eq!(tester, [2, 1, 3]); + tester.swap(2, 1); + assert_eq!(tester, [2, 3, 1]); + tester.swap(1, 2); + assert_eq!(tester, [2, 1, 3]); + tester.swap(0, 2); + assert_eq!(tester, [3, 1, 2]); + tester.swap(2, 2); + assert_eq!(tester, [3, 1, 2]); + } + + #[test] + #[should_panic = "assertion failed: j < self.len()"] + fn test_swap_panic() { + let mut tester = SmallDeque::<_>::new(); + tester.push_back(1); + tester.push_back(2); + tester.push_back(3); + tester.swap(2, 3); + } + + #[test] + fn test_reserve_exact() { + let mut tester: SmallDeque = SmallDeque::with_capacity(1); + assert_eq!(tester.capacity(), 1); + tester.reserve_exact(50); + assert_eq!(tester.capacity(), 50); + tester.reserve_exact(40); + // reserving won't shrink the buffer + assert_eq!(tester.capacity(), 50); + tester.reserve_exact(200); + assert_eq!(tester.capacity(), 200); + } + + #[test] + #[should_panic = "capacity overflow"] + fn test_reserve_exact_panic() { + let mut tester: SmallDeque = SmallDeque::new(); + tester.reserve_exact(usize::MAX); + } + + #[test] + fn test_contains() { + let mut tester = SmallDeque::<_>::new(); + tester.push_back(1); + tester.push_back(2); + tester.push_back(3); + + assert!(tester.contains(&1)); + assert!(tester.contains(&3)); + assert!(!tester.contains(&0)); + assert!(!tester.contains(&4)); + tester.remove(0); + assert!(!tester.contains(&1)); + assert!(tester.contains(&2)); + 
assert!(tester.contains(&3)); + } + + #[test] + fn test_rotate_left_right() { + let mut tester: SmallDeque<_> = (1..=10).collect(); + tester.reserve(1); + + assert_eq!(tester.len(), 10); + + tester.rotate_left(0); + assert_eq!(tester, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + tester.rotate_right(0); + assert_eq!(tester, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + tester.rotate_left(3); + assert_eq!(tester, [4, 5, 6, 7, 8, 9, 10, 1, 2, 3]); + + tester.rotate_right(5); + assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]); + + tester.rotate_left(tester.len()); + assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]); + + tester.rotate_right(tester.len()); + assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]); + + tester.rotate_left(1); + assert_eq!(tester, [10, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + } + + #[test] + #[should_panic = "assertion failed: n <= self.len()"] + fn test_rotate_left_panic() { + let mut tester: SmallDeque<_> = (1..=10).collect(); + tester.rotate_left(tester.len() + 1); + } + + #[test] + #[should_panic = "assertion failed: n <= self.len()"] + fn test_rotate_right_panic() { + let mut tester: SmallDeque<_> = (1..=10).collect(); + tester.rotate_right(tester.len() + 1); + } + + #[test] + fn test_binary_search() { + // If the givin SmallDeque is not sorted, the returned result is unspecified and meaningless, + // as this method performs a binary search. + + let tester: SmallDeque<_, 11> = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55].into(); + + assert_eq!(tester.binary_search(&0), Ok(0)); + assert_eq!(tester.binary_search(&5), Ok(5)); + assert_eq!(tester.binary_search(&55), Ok(10)); + assert_eq!(tester.binary_search(&4), Err(5)); + assert_eq!(tester.binary_search(&-1), Err(0)); + assert!(matches!(tester.binary_search(&1), Ok(1..=2))); + + let tester: SmallDeque<_, 14> = [1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3].into(); + assert_eq!(tester.binary_search(&1), Ok(0)); + assert!(matches!(tester.binary_search(&2), Ok(1..=4))); + assert!(matches!(tester.binary_search(&3), Ok(5..=13))); + assert_eq!(tester.binary_search(&-2), Err(0)); + assert_eq!(tester.binary_search(&0), Err(0)); + assert_eq!(tester.binary_search(&4), Err(14)); + assert_eq!(tester.binary_search(&5), Err(14)); + } + + #[test] + fn test_binary_search_by() { + // If the givin SmallDeque is not sorted, the returned result is unspecified and meaningless, + // as this method performs a binary search. + + let tester: SmallDeque = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55].into(); + + assert_eq!(tester.binary_search_by(|x| x.cmp(&0)), Ok(0)); + assert_eq!(tester.binary_search_by(|x| x.cmp(&5)), Ok(5)); + assert_eq!(tester.binary_search_by(|x| x.cmp(&55)), Ok(10)); + assert_eq!(tester.binary_search_by(|x| x.cmp(&4)), Err(5)); + assert_eq!(tester.binary_search_by(|x| x.cmp(&-1)), Err(0)); + assert!(matches!(tester.binary_search_by(|x| x.cmp(&1)), Ok(1..=2))); + } + + #[test] + fn test_binary_search_key() { + // If the givin SmallDeque is not sorted, the returned result is unspecified and meaningless, + // as this method performs a binary search. 
+ + let tester: SmallDeque<_, 13> = [ + (-1, 0), + (2, 10), + (6, 5), + (7, 1), + (8, 10), + (10, 2), + (20, 3), + (24, 5), + (25, 18), + (28, 13), + (31, 21), + (32, 4), + (54, 25), + ] + .into(); + + assert_eq!(tester.binary_search_by_key(&-1, |&(a, _b)| a), Ok(0)); + assert_eq!(tester.binary_search_by_key(&8, |&(a, _b)| a), Ok(4)); + assert_eq!(tester.binary_search_by_key(&25, |&(a, _b)| a), Ok(8)); + assert_eq!(tester.binary_search_by_key(&54, |&(a, _b)| a), Ok(12)); + assert_eq!(tester.binary_search_by_key(&-2, |&(a, _b)| a), Err(0)); + assert_eq!(tester.binary_search_by_key(&1, |&(a, _b)| a), Err(1)); + assert_eq!(tester.binary_search_by_key(&4, |&(a, _b)| a), Err(2)); + assert_eq!(tester.binary_search_by_key(&13, |&(a, _b)| a), Err(6)); + assert_eq!(tester.binary_search_by_key(&55, |&(a, _b)| a), Err(13)); + assert_eq!(tester.binary_search_by_key(&100, |&(a, _b)| a), Err(13)); + + let tester: SmallDeque<_, 13> = [ + (0, 0), + (2, 1), + (6, 1), + (5, 1), + (3, 1), + (1, 2), + (2, 3), + (4, 5), + (5, 8), + (8, 13), + (1, 21), + (2, 34), + (4, 55), + ] + .into(); + + assert_eq!(tester.binary_search_by_key(&0, |&(_a, b)| b), Ok(0)); + assert!(matches!(tester.binary_search_by_key(&1, |&(_a, b)| b), Ok(1..=4))); + assert_eq!(tester.binary_search_by_key(&8, |&(_a, b)| b), Ok(8)); + assert_eq!(tester.binary_search_by_key(&13, |&(_a, b)| b), Ok(9)); + assert_eq!(tester.binary_search_by_key(&55, |&(_a, b)| b), Ok(12)); + assert_eq!(tester.binary_search_by_key(&-1, |&(_a, b)| b), Err(0)); + assert_eq!(tester.binary_search_by_key(&4, |&(_a, b)| b), Err(7)); + assert_eq!(tester.binary_search_by_key(&56, |&(_a, b)| b), Err(13)); + assert_eq!(tester.binary_search_by_key(&100, |&(_a, b)| b), Err(13)); + } + + #[test] + fn make_contiguous_big_head() { + let mut tester = SmallDeque::<_>::with_capacity(15); + + for i in 0..3 { + tester.push_back(i); + } + + for i in 3..10 { + tester.push_front(i); + } + + // 012......9876543 + assert_eq!(tester.capacity(), 15); + assert_eq!((&[9, 8, 7, 6, 5, 4, 3] as &[_], &[0, 1, 2] as &[_]), tester.as_slices()); + + let expected_start = tester.as_slices().1.len(); + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!((&[9, 8, 7, 6, 5, 4, 3, 0, 1, 2] as &[_], &[] as &[_]), tester.as_slices()); + } + + #[test] + fn make_contiguous_big_tail() { + let mut tester = SmallDeque::<_>::with_capacity(15); + + for i in 0..8 { + tester.push_back(i); + } + + for i in 8..10 { + tester.push_front(i); + } + + // 01234567......98 + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!((&[9, 8, 0, 1, 2, 3, 4, 5, 6, 7] as &[_], &[] as &[_]), tester.as_slices()); + } + + #[test] + fn make_contiguous_small_free() { + let mut tester = SmallDeque::<_>::with_capacity(16); + + for i in b'A'..b'I' { + tester.push_back(i as char); + } + + for i in b'I'..b'N' { + tester.push_front(i as char); + } + + assert_eq!(tester, ['M', 'L', 'K', 'J', 'I', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']); + + // ABCDEFGH...MLKJI + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['M', 'L', 'K', 'J', 'I', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + + tester.clear(); + for i in b'I'..b'N' { + tester.push_back(i as char); + } + + for i in b'A'..b'I' { + tester.push_front(i as char); + } + + // IJKLM...HGFEDCBA + let expected_start = 3; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + 
assert_eq!( + ( + &['H', 'G', 'F', 'E', 'D', 'C', 'B', 'A', 'I', 'J', 'K', 'L', 'M'] as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + } + + #[test] + fn make_contiguous_head_to_end() { + let mut tester = SmallDeque::<_>::with_capacity(16); + + for i in b'A'..b'L' { + tester.push_back(i as char); + } + + for i in b'L'..b'Q' { + tester.push_front(i as char); + } + + assert_eq!( + tester, + ['P', 'O', 'N', 'M', 'L', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K'] + ); + + // ABCDEFGHIJKPONML + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['P', 'O', 'N', 'M', 'L', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K'] + as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + + tester.clear(); + for i in b'L'..b'Q' { + tester.push_back(i as char); + } + + for i in b'A'..b'L' { + tester.push_front(i as char); + } + + // LMNOPKJIHGFEDCBA + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['K', 'J', 'I', 'H', 'G', 'F', 'E', 'D', 'C', 'B', 'A', 'L', 'M', 'N', 'O', 'P'] + as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + } + + #[test] + fn make_contiguous_head_to_end_2() { + // Another test case for #79808, taken from #80293. + + let mut dq = SmallDeque::<_>::from_iter(0..6); + dq.pop_front(); + dq.pop_front(); + dq.push_back(6); + dq.push_back(7); + dq.push_back(8); + dq.make_contiguous(); + let collected: Vec<_> = dq.iter().copied().collect(); + assert_eq!(dq.as_slices(), (&collected[..], &[] as &[_])); + } + + #[test] + fn test_remove() { + // This test checks that every single combination of tail position, length, and + // removal position is tested. Capacity 15 should be large enough to cover every case. + + let mut tester = SmallDeque::<_>::with_capacity(15); + // can't guarantee we got 15, so have to get what we got. 
+ // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else + // this test isn't covering what it wants to + let cap = tester.capacity(); + + // len is the length *after* removal + let minlen = if cfg!(miri) { cap - 2 } else { 0 }; // Miri is too slow + for len in minlen..cap - 1 { + // 0, 1, 2, .., len - 1 + let expected = (0..).take(len).collect::>(); + for head_pos in 0..cap { + for to_remove in 0..=len { + tester.head = head_pos; + tester.len = 0; + for i in 0..len { + if i == to_remove { + tester.push_back(1234); + } + tester.push_back(i); + } + if to_remove == len { + tester.push_back(1234); + } + tester.remove(to_remove); + assert!(tester.head <= tester.capacity()); + assert!(tester.len <= tester.capacity()); + assert_eq!(tester, expected); + } + } + } + } + + #[test] + fn test_range() { + let mut tester: SmallDeque = SmallDeque::<_>::with_capacity(7); + + let cap = tester.capacity(); + let minlen = if cfg!(miri) { cap - 1 } else { 0 }; // Miri is too slow + for len in minlen..=cap { + for head in 0..=cap { + for start in 0..=len { + for end in start..=len { + tester.head = head; + tester.len = 0; + for i in 0..len { + tester.push_back(i); + } + + // Check that we iterate over the correct values + let range: SmallDeque<_> = tester.range(start..end).copied().collect(); + let expected: SmallDeque<_> = (start..end).collect(); + assert_eq!(range, expected); + } + } + } + } + } + + #[test] + fn test_range_mut() { + let mut tester: SmallDeque = SmallDeque::with_capacity(7); + + let cap = tester.capacity(); + for len in 0..=cap { + for head in 0..=cap { + for start in 0..=len { + for end in start..=len { + tester.head = head; + tester.len = 0; + for i in 0..len { + tester.push_back(i); + } + + let head_was = tester.head; + let len_was = tester.len; + + // Check that we iterate over the correct values + let range: SmallDeque<_> = + tester.range_mut(start..end).map(|v| *v).collect(); + let expected: SmallDeque<_> = (start..end).collect(); + assert_eq!(range, expected); + + // We shouldn't have changed the capacity or made the + // head or tail out of bounds + assert_eq!(tester.capacity(), cap); + assert_eq!(tester.head, head_was); + assert_eq!(tester.len, len_was); + } + } + } + } + } + + #[test] + fn test_drain() { + let mut tester: SmallDeque = SmallDeque::with_capacity(7); + + let cap = tester.capacity(); + for len in 0..=cap { + for head in 0..cap { + for drain_start in 0..=len { + for drain_end in drain_start..=len { + tester.head = head; + tester.len = 0; + for i in 0..len { + tester.push_back(i); + } + + // Check that we drain the correct values + let drained: SmallDeque<_> = tester.drain(drain_start..drain_end).collect(); + let drained_expected: SmallDeque<_> = (drain_start..drain_end).collect(); + assert_eq!(drained, drained_expected); + + // We shouldn't have changed the capacity or made the + // head or tail out of bounds + assert_eq!(tester.capacity(), cap); + assert!(tester.head <= tester.capacity()); + assert!(tester.len <= tester.capacity()); + + // We should see the correct values in the SmallDeque + let expected: SmallDeque<_> = + (0..drain_start).chain(drain_end..len).collect(); + assert_eq!(expected, tester); + } + } + } + } + } + + #[test] + fn test_split_off() { + // This test checks that every single combination of tail position, length, and + // split position is tested. Capacity 15 should be large enough to cover every case. + + let mut tester = SmallDeque::with_capacity(15); + // can't guarantee we got 15, so have to get what we got. 
+ // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else + // this test isn't covering what it wants to + let cap = tester.capacity(); + + // len is the length *before* splitting + let minlen = if cfg!(miri) { cap - 1 } else { 0 }; // Miri is too slow + for len in minlen..cap { + // index to split at + for at in 0..=len { + // 0, 1, 2, .., at - 1 (may be empty) + let expected_self = (0..).take(at).collect::>(); + // at, at + 1, .., len - 1 (may be empty) + let expected_other = (at..).take(len - at).collect::>(); + + for head_pos in 0..cap { + tester.head = head_pos; + tester.len = 0; + for i in 0..len { + tester.push_back(i); + } + let result = tester.split_off(at); + assert!(tester.head <= tester.capacity()); + assert!(tester.len <= tester.capacity()); + assert!(result.head <= result.capacity()); + assert!(result.len <= result.capacity()); + assert_eq!(tester, expected_self); + assert_eq!(result, expected_other); + } + } + } + } + + #[test] + fn test_from_smallvec() { + for cap in 0..35 { + for len in 0..=cap { + let mut vec = SmallVec::<[_; 16]>::with_capacity(cap); + vec.extend(0..len); + + let vd = SmallDeque::from(vec.clone()); + assert_eq!(vd.len(), vec.len()); + assert!(vd.into_iter().eq(vec)); + } + } + } + + #[test] + fn test_extend_basic() { + test_extend_impl(false); + } + + #[test] + fn test_extend_trusted_len() { + test_extend_impl(true); + } + + fn test_extend_impl(trusted_len: bool) { + struct SmallDequeTester { + test: SmallDeque, + expected: SmallDeque, + trusted_len: bool, + } + + impl SmallDequeTester { + fn new(trusted_len: bool) -> Self { + Self { + test: SmallDeque::new(), + expected: SmallDeque::new(), + trusted_len, + } + } + + fn test_extend(&mut self, iter: I) + where + I: Iterator + TrustedLen + Clone, + { + struct BasicIterator(I); + impl Iterator for BasicIterator + where + I: Iterator, + { + type Item = usize; + + fn next(&mut self) -> Option { + self.0.next() + } + } + + if self.trusted_len { + self.test.extend(iter.clone()); + } else { + self.test.extend(BasicIterator(iter.clone())); + } + + for item in iter { + self.expected.push_back(item) + } + + assert_eq!(self.test, self.expected); + } + + fn drain + Clone>(&mut self, range: R) { + self.test.drain(range.clone()); + self.expected.drain(range); + + assert_eq!(self.test, self.expected); + } + + fn clear(&mut self) { + self.test.clear(); + self.expected.clear(); + } + + fn remaining_capacity(&self) -> usize { + self.test.capacity() - self.test.len() + } + } + + let mut tester = SmallDequeTester::new(trusted_len); + + // Initial capacity + tester.test_extend(0..tester.remaining_capacity()); + + // Grow + tester.test_extend(1024..2048); + + // Wrap around + tester.drain(..128); + + tester.test_extend(0..tester.remaining_capacity()); + + // Continue + tester.drain(256..); + tester.test_extend(4096..8196); + + tester.clear(); + + // Start again + tester.test_extend(0..32); + } + + #[test] + fn test_from_array() { + fn test() { + let mut array: [usize; N] = [0; N]; + + for (i, v) in array.iter_mut().enumerate() { + *v = i; + } + + let deq: SmallDeque<_, N> = array.into(); + + for i in 0..N { + assert_eq!(deq[i], i); + } + + assert_eq!(deq.len(), N); + } + test::<0>(); + test::<1>(); + test::<2>(); + test::<32>(); + test::<35>(); + } + + #[test] + fn test_smallvec_from_smalldeque() { + fn create_vec_and_test_convert(capacity: usize, offset: usize, len: usize) { + let mut vd = SmallDeque::<_, 16>::with_capacity(capacity); + for _ in 0..offset { + vd.push_back(0); + vd.pop_front(); + } + 
vd.extend(0..len); + + let vec: SmallVec<_> = SmallVec::from(vd.clone()); + assert_eq!(vec.len(), vd.len()); + assert!(vec.into_iter().eq(vd)); + } + + // Miri is too slow + let max_pwr = if cfg!(miri) { 5 } else { 7 }; + + for cap_pwr in 0..max_pwr { + // Make capacity as a (2^x)-1, so that the ring size is 2^x + let cap = (2i32.pow(cap_pwr) - 1) as usize; + + // In these cases there is enough free space to solve it with copies + for len in 0..((cap + 1) / 2) { + // Test contiguous cases + for offset in 0..(cap - len) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at end of buffer is bigger than block at start + for offset in (cap - len)..(cap - (len / 2)) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at start of buffer is bigger than block at end + for offset in (cap - (len / 2))..cap { + create_vec_and_test_convert(cap, offset, len) + } + } + + // Now there's not (necessarily) space to straighten the ring with simple copies, + // the ring will use swapping when: + // (cap + 1 - offset) > (cap + 1 - len) && (len - (cap + 1 - offset)) > (cap + 1 - len)) + // right block size > free space && left block size > free space + for len in ((cap + 1) / 2)..cap { + // Test contiguous cases + for offset in 0..(cap - len) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at end of buffer is bigger than block at start + for offset in (cap - len)..(cap - (len / 2)) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at start of buffer is bigger than block at end + for offset in (cap - (len / 2))..cap { + create_vec_and_test_convert(cap, offset, len) + } + } + } + } + + #[test] + fn test_clone_from() { + use smallvec::smallvec; + + let m = smallvec![1; 8]; + let n = smallvec![2; 12]; + let limit = if cfg!(miri) { 4 } else { 8 }; // Miri is too slow + for pfv in 0..limit { + for pfu in 0..limit { + for longer in 0..2 { + let (vr, ur) = if longer == 0 { (&m, &n) } else { (&n, &m) }; + let mut v = SmallDeque::<_>::from(vr.clone()); + for _ in 0..pfv { + v.push_front(1); + } + let mut u = SmallDeque::<_>::from(ur.clone()); + for _ in 0..pfu { + u.push_front(2); + } + v.clone_from(&u); + assert_eq!(&v, &u); + } + } + } + } + + #[test] + fn test_vec_deque_truncate_drop() { + static mut DROPS: u32 = 0; + #[derive(Clone)] + struct Elem(#[allow(dead_code)] i32); + impl Drop for Elem { + fn drop(&mut self) { + unsafe { + DROPS += 1; + } + } + } + + let v = vec![Elem(1), Elem(2), Elem(3), Elem(4), Elem(5)]; + for push_front in 0..=v.len() { + let v = v.clone(); + let mut tester = SmallDeque::<_>::with_capacity(5); + for (index, elem) in v.into_iter().enumerate() { + if index < push_front { + tester.push_front(elem); + } else { + tester.push_back(elem); + } + } + assert_eq!(unsafe { DROPS }, 0); + tester.truncate(3); + assert_eq!(unsafe { DROPS }, 2); + tester.truncate(0); + assert_eq!(unsafe { DROPS }, 5); + unsafe { + DROPS = 0; + } + } + } + + #[test] + fn issue_53529() { + let mut dst = SmallDeque::<_>::new(); + dst.push_front(Box::new(1)); + dst.push_front(Box::new(2)); + assert_eq!(*dst.pop_back().unwrap(), 1); + + let mut src = SmallDeque::<_>::new(); + src.push_front(Box::new(2)); + dst.append(&mut src); + for a in dst { + assert_eq!(*a, 2); + } + } + + #[test] + fn issue_80303() { + use core::{ + hash::{Hash, Hasher}, + iter, + num::Wrapping, + }; + + // This is a valid, albeit rather bad hash function implementation. 
+ struct SimpleHasher(Wrapping); + + impl Hasher for SimpleHasher { + fn finish(&self) -> u64 { + self.0 .0 + } + + fn write(&mut self, bytes: &[u8]) { + // This particular implementation hashes value 24 in addition to bytes. + // Such an implementation is valid as Hasher only guarantees equivalence + // for the exact same set of calls to its methods. + for &v in iter::once(&24).chain(bytes) { + self.0 = Wrapping(31) * self.0 + Wrapping(u64::from(v)); + } + } + } + + fn hash_code(value: impl Hash) -> u64 { + let mut hasher = SimpleHasher(Wrapping(1)); + value.hash(&mut hasher); + hasher.finish() + } + + // This creates two deques for which values returned by as_slices + // method differ. + let vda: SmallDeque = (0..10).collect(); + let mut vdb = SmallDeque::with_capacity(10); + vdb.extend(5..10); + (0..5).rev().for_each(|elem| vdb.push_front(elem)); + assert_ne!(vda.as_slices(), vdb.as_slices()); + assert_eq!(vda, vdb); + assert_eq!(hash_code(vda), hash_code(vdb)); + } +} diff --git a/hir2/src/adt/smallmap.rs b/hir2/src/adt/smallmap.rs new file mode 100644 index 000000000..c0b61ac0f --- /dev/null +++ b/hir2/src/adt/smallmap.rs @@ -0,0 +1,491 @@ +use core::{ + borrow::Borrow, + cmp::Ordering, + fmt, + ops::{Index, IndexMut}, +}; + +use smallvec::SmallVec; + +/// [SmallMap] is a [BTreeMap]-like structure that can store a specified number +/// of elements inline (i.e. on the stack) without allocating memory from the heap. +/// +/// This data structure is designed with two goals in mind: +/// +/// * Support efficient key/value operations over a small set of keys +/// * Preserve the order of keys +/// * Avoid allocating data on the heap for the typical case +/// +/// Internally, [SmallMap] is implemented on top of [SmallVec], and uses binary search +/// to locate elements. This is quite efficient in general, and is particularly fast +/// when all of the data is stored inline, but may not be a good fit for all use cases. +/// +/// Due to its design constraints, it only supports keys which implement [Ord]. +pub struct SmallMap { + items: SmallVec<[KeyValuePair; N]>, +} +impl SmallMap +where + K: Ord, +{ + /// Returns a new, empty [SmallMap] + pub const fn new() -> Self { + Self { + items: SmallVec::new_const(), + } + } + + /// Returns a new, empty [SmallMap], with capacity for `capacity` nodes without reallocating. 
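For illustration only (not part of this patch): a minimal sketch of how the `SmallMap` described above is intended to be used. The generic parameters `<K, V, const N: usize>` and the `midenc_hir2::adt::SmallMap` path are assumptions reconstructed here, since this copy of the patch has lost most generic arguments.

use midenc_hir2::adt::SmallMap; // assumed re-export path

fn small_map_usage_sketch() {
    // Up to 4 entries are stored inline on the stack; keys are kept in `Ord` order,
    // because lookups use binary search.
    let mut map: SmallMap<&str, u32, 4> = SmallMap::new();
    assert_eq!(map.insert("c", 3), None);
    assert_eq!(map.insert("a", 1), None);
    assert_eq!(map.insert("b", 2), None);
    // Inserting an existing key replaces the value and returns the old one
    assert_eq!(map.insert("a", 10), Some(1));
    assert_eq!(map.get("b"), Some(&2));
    // Iteration yields entries sorted by key, not in insertion order
    let keys: Vec<_> = map.iter().map(|(k, _)| *k).collect();
    assert_eq!(keys, ["a", "b", "c"]);
}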
+ pub fn with_capacity(capacity: usize) -> Self { + Self { + items: SmallVec::with_capacity(capacity), + } + } + + /// Returns true if this map is empty + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Returns the number of key/value pairs in this map + pub fn len(&self) -> usize { + self.items.len() + } + + /// Return an iterator over the key/value pairs in this map + pub fn iter(&self) -> impl DoubleEndedIterator { + self.items.iter().map(|pair| (&pair.key, &pair.value)) + } + + /// Return an iterator over mutable key/value pairs in this map + pub fn iter_mut(&mut self) -> impl DoubleEndedIterator { + self.items.iter_mut().map(|pair| (&pair.key, &mut pair.value)) + } + + /// Returns true if `key` has been inserted in this map + pub fn contains(&self, key: &Q) -> bool + where + K: Borrow, + Q: Ord + ?Sized, + { + self.find(key).is_ok() + } + + /// Returns the value under `key` in this map, if it exists + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(&self.items[idx].value), + Err(_) => None, + } + } + + /// Returns a mutable reference to the value under `key` in this map, if it exists + pub fn get_mut(&mut self, key: &Q) -> Option<&mut V> + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(&mut self.items[idx].value), + Err(_) => None, + } + } + + /// Inserts a new entry in this map using `key` and `value`. + /// + /// Returns the previous value, if `key` was already present in the map. + pub fn insert(&mut self, key: K, value: V) -> Option { + match self.entry(key) { + Entry::Occupied(mut entry) => Some(core::mem::replace(entry.get_mut(), value)), + Entry::Vacant(entry) => { + entry.insert(value); + None + } + } + } + + /// Removes the value inserted under `key`, if it exists + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(self.items.remove(idx).value), + Err(_) => None, + } + } + + /// Clear the content of the map + pub fn clear(&mut self) { + self.items.clear(); + } + + /// Returns an [Entry] which can be used to combine `contains`+`insert` type operations. 
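A follow-up sketch (again illustrative, with the same assumed import and generics as the previous sketch): the `Entry` returned by `entry` lets callers express insert-or-update logic without a separate `contains` check.

fn count_occurrences(words: &[&'static str]) -> SmallMap<&'static str, usize, 8> {
    let mut counts = SmallMap::new();
    for word in words {
        // `and_modify` runs only when the key is already present;
        // `or_insert` fills the vacant slot otherwise.
        counts.entry(*word).and_modify(|count| *count += 1).or_insert(1);
    }
    counts
}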
+ pub fn entry(&mut self, key: K) -> Entry<'_, K, V, N> { + match self.find(&key) { + Ok(idx) => Entry::occupied(self, idx), + Err(idx) => Entry::vacant(self, idx, key), + } + } + + #[inline] + fn find(&self, item: &Q) -> Result + where + K: Borrow, + Q: Ord + ?Sized, + { + self.items.binary_search_by(|probe| Ord::cmp(probe.key.borrow(), item)) + } +} +impl Default for SmallMap { + fn default() -> Self { + Self { + items: Default::default(), + } + } +} +impl Eq for SmallMap +where + K: Eq, + V: Eq, +{ +} +impl PartialEq for SmallMap +where + K: PartialEq, + V: PartialEq, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.items + .iter() + .map(|pair| (&pair.key, &pair.value)) + .eq(other.items.iter().map(|pair| (&pair.key, &pair.value))) + } +} +impl fmt::Debug for SmallMap +where + K: fmt::Debug + Ord, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map() + .entries(self.items.iter().map(|item| (&item.key, &item.value))) + .finish() + } +} +impl Clone for SmallMap +where + K: Clone, + V: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { + items: self.items.clone(), + } + } +} +impl IntoIterator for SmallMap +where + K: Ord, +{ + type IntoIter = SmallMapIntoIter; + type Item = (K, V); + + #[inline] + fn into_iter(self) -> Self::IntoIter { + SmallMapIntoIter { + iter: self.items.into_iter(), + } + } +} +impl FromIterator<(K, V)> for SmallMap +where + K: Ord, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut map = Self::default(); + for (k, v) in iter { + map.insert(k, v); + } + map + } +} +impl Index<&Q> for SmallMap +where + K: Borrow + Ord, + Q: Ord + ?Sized, +{ + type Output = V; + + fn index(&self, key: &Q) -> &Self::Output { + self.get(key).unwrap() + } +} +impl IndexMut<&Q> for SmallMap +where + K: Borrow + Ord, + Q: Ord + ?Sized, +{ + fn index_mut(&mut self, key: &Q) -> &mut Self::Output { + self.get_mut(key).unwrap() + } +} + +#[doc(hidden)] +pub struct SmallMapIntoIter { + iter: smallvec::IntoIter<[KeyValuePair; N]>, +} +impl ExactSizeIterator for SmallMapIntoIter { + #[inline(always)] + fn len(&self) -> usize { + self.iter.len() + } +} +impl Iterator for SmallMapIntoIter { + type Item = (K, V); + + #[inline(always)] + fn next(&mut self) -> Option { + self.iter.next().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } + + #[inline] + fn last(self) -> Option<(K, V)> { + self.iter.last().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + self.iter.nth(n).map(|pair| (pair.key, pair.value)) + } +} +impl DoubleEndedIterator for SmallMapIntoIter { + #[inline] + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + self.iter.nth_back(n).map(|pair| (pair.key, pair.value)) + } +} + +/// Represents an key/value pair entry in a [SmallMap] +pub enum Entry<'a, K, V, const N: usize> { + Occupied(OccupiedEntry<'a, K, V, N>), + Vacant(VacantEntry<'a, K, V, N>), +} +impl<'a, K, V: Default, const N: usize> Entry<'a, K, V, N> { + pub fn or_default(self) -> &'a mut V { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(Default::default()), + } + } +} +impl<'a, K, V, const N: usize> Entry<'a, K, V, N> { + fn occupied(map: &'a mut SmallMap, idx: usize) -> Self { + 
Self::Occupied(OccupiedEntry { map, idx }) + } + + fn vacant(map: &'a mut SmallMap, idx: usize, key: K) -> Self { + Self::Vacant(VacantEntry { map, idx, key }) + } + + pub fn or_insert(self, default: V) -> &'a mut V { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(default), + } + } + + pub fn or_insert_with(self, default: F) -> &'a mut V + where + F: FnOnce() -> V, + { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(default()), + } + } + + pub fn key(&self) -> &K { + match self { + Self::Occupied(entry) => entry.key(), + Self::Vacant(entry) => entry.key(), + } + } + + pub fn and_modify(self, f: F) -> Self + where + F: FnOnce(&mut V), + { + match self { + Self::Occupied(mut entry) => { + f(entry.get_mut()); + Self::Occupied(entry) + } + vacant @ Self::Vacant(_) => vacant, + } + } +} + +/// Represents an occupied entry in a [SmallMap] +pub struct OccupiedEntry<'a, K, V, const N: usize> { + map: &'a mut SmallMap, + idx: usize, +} +impl<'a, K, V, const N: usize> OccupiedEntry<'a, K, V, N> { + #[inline(always)] + fn get_entry(&self) -> &KeyValuePair { + &self.map.items[self.idx] + } + + pub fn remove_entry(self) -> V { + self.map.items.remove(self.idx).value + } + + pub fn key(&self) -> &K { + &self.get_entry().key + } + + pub fn get(&self) -> &V { + &self.get_entry().value + } + + pub fn get_mut(&mut self) -> &mut V { + &mut self.map.items[self.idx].value + } + + pub fn into_mut(self) -> &'a mut V { + &mut self.map.items[self.idx].value + } +} + +/// Represents a vacant entry in a [SmallMap] +pub struct VacantEntry<'a, K, V, const N: usize> { + map: &'a mut SmallMap, + idx: usize, + key: K, +} +impl<'a, K, V, const N: usize> VacantEntry<'a, K, V, N> { + pub fn key(&self) -> &K { + &self.key + } + + pub fn into_key(self) -> K { + self.key + } + + pub fn insert_with(self, f: F) -> &'a mut V + where + F: FnOnce() -> V, + { + self.map.items.insert( + self.idx, + KeyValuePair { + key: self.key, + value: f(), + }, + ); + &mut self.map.items[self.idx].value + } + + pub fn insert(self, value: V) -> &'a mut V { + self.map.items.insert( + self.idx, + KeyValuePair { + key: self.key, + value, + }, + ); + &mut self.map.items[self.idx].value + } +} + +struct KeyValuePair { + key: K, + value: V, +} +impl AsRef for KeyValuePair { + #[inline] + fn as_ref(&self) -> &V { + &self.value + } +} +impl AsMut for KeyValuePair { + #[inline] + fn as_mut(&mut self) -> &mut V { + &mut self.value + } +} +impl Clone for KeyValuePair +where + K: Clone, + V: Clone, +{ + fn clone(&self) -> Self { + Self { + key: self.key.clone(), + value: self.value.clone(), + } + } +} +impl Copy for KeyValuePair +where + K: Copy, + V: Copy, +{ +} + +impl Eq for KeyValuePair where K: Eq {} + +impl PartialEq for KeyValuePair +where + K: PartialEq, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.key.eq(&other.key) + } +} + +impl Ord for KeyValuePair +where + K: Ord, +{ + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.key.cmp(&other.key) + } +} +impl PartialOrd for KeyValuePair +where + K: PartialOrd, +{ + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + self.key.partial_cmp(&other.key) + } +} diff --git a/hir2/src/adt/smallordset.rs b/hir2/src/adt/smallordset.rs new file mode 100644 index 000000000..1ab27c013 --- /dev/null +++ b/hir2/src/adt/smallordset.rs @@ -0,0 +1,189 @@ +use core::{borrow::Borrow, fmt}; + +use smallvec::SmallVec; + +/// [SmallOrdSet] is a [BTreeSet]-like structure that can store a 
specified number +/// of elements inline (i.e. on the stack) without allocating memory from the heap. +/// +/// This data structure is designed with two goals in mind: +/// +/// * Support efficient set operations over a small set of items +/// * Maintains the underlying set in order (according to the `Ord` impl of the element type) +/// * Avoid allocating data on the heap for the typical case +/// +/// Internally, [SmallOrdSet] is implemented on top of [SmallVec], and uses binary search +/// to locate elements. This is quite efficient in general, and is particularly fast +/// when all of the data is stored inline, but may not be a good fit for all use cases. +/// +/// Due to its design constraints, it only supports elements which implement [Ord]. +/// +/// NOTE: This type differs from [SmallSet] in that [SmallOrdSet] uses the [Ord] implementation +/// of the element type for ordering, while [SmallSet] preserves the insertion order of elements. +/// Beyond that, the two types are meant to be essentially equivalent. +pub struct SmallOrdSet { + items: SmallVec<[T; N]>, +} +impl Default for SmallOrdSet { + fn default() -> Self { + Self { + items: Default::default(), + } + } +} +impl Eq for SmallOrdSet where T: Eq {} +impl PartialEq for SmallOrdSet +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.items.eq(&other.items) + } +} +impl fmt::Debug for SmallOrdSet +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_set().entries(self.items.iter()).finish() + } +} +impl Clone for SmallOrdSet +where + T: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { + items: self.items.clone(), + } + } +} +impl SmallOrdSet +where + T: Ord, +{ + pub fn from_vec(items: SmallVec<[T; N]>) -> Self { + let mut set = Self { items }; + set.sort_and_dedup(); + set + } + + #[inline] + pub fn from_buf(buf: [T; N]) -> Self { + Self::from_vec(buf.into()) + } +} +impl IntoIterator for SmallOrdSet { + type IntoIter = smallvec::IntoIter<[T; N]>; + type Item = T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} +impl FromIterator for SmallOrdSet +where + T: Ord, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut set = Self::default(); + for item in iter { + set.insert(item); + } + set + } +} +impl SmallOrdSet +where + T: Ord, +{ + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + pub fn len(&self) -> usize { + self.items.len() + } + + pub fn as_slice(&self) -> &[T] { + self.items.as_slice() + } + + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.items.iter() + } + + pub fn insert(&mut self, item: T) -> bool { + match self.find(&item) { + Ok(_) => false, + Err(idx) => { + self.items.insert(idx, item); + true + } + } + } + + pub fn remove(&mut self, item: &Q) -> Option + where + T: Borrow, + Q: Ord + ?Sized, + { + match self.find(item) { + Ok(idx) => Some(self.items.remove(idx)), + Err(_) => None, + } + } + + /// Clear the content of the set + pub fn clear(&mut self) { + self.items.clear(); + } + + pub fn contains(&self, item: &Q) -> bool + where + T: Borrow, + Q: Ord + ?Sized, + { + self.find(item).is_ok() + } + + pub fn get(&self, item: &Q) -> Option<&T> + where + T: Borrow, + Q: Ord + ?Sized, + { + match self.find(item) { + Ok(idx) => Some(&self.items[idx]), + Err(_) => None, + } + } + + pub fn get_mut(&mut self, item: &Q) -> Option<&mut T> + where + T: Borrow, + Q: Ord + ?Sized, + { + match self.find(item) { + Ok(idx) => Some(&mut self.items[idx]), + Err(_) => None, + } + 
} + + #[inline] + fn find(&self, item: &Q) -> Result + where + T: Borrow, + Q: Ord + ?Sized, + { + self.items.binary_search_by(|probe| Ord::cmp(probe.borrow(), item)) + } + + fn sort_and_dedup(&mut self) { + self.items.sort_unstable(); + self.items.dedup(); + } +} diff --git a/hir2/src/adt/smallprio.rs b/hir2/src/adt/smallprio.rs new file mode 100644 index 000000000..b5b5258e8 --- /dev/null +++ b/hir2/src/adt/smallprio.rs @@ -0,0 +1,167 @@ +use core::{cmp::Ordering, fmt}; + +use smallvec::SmallVec; + +use super::SmallDeque; + +/// [SmallPriorityQueue] is a priority queue structure that can store a specified number +/// of elements inline (i.e. on the stack) without allocating memory from the heap. +/// +/// Elements in the queue are stored "largest priority first", as determined by the [Ord] +/// implementation of the element type. If you instead wish to have a "lowest priority first" +/// queue, you can use [core::cmp::Reverse] to invert the natural order of the type. +/// +/// It is an exercise for the reader to figure out how to wrap a type with a custom comparator +/// function. Since that isn't particularly needed yet, no built-in support for that is provided. +pub struct SmallPriorityQueue { + pq: SmallDeque, +} +impl fmt::Debug for SmallPriorityQueue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.iter()).finish() + } +} +impl Default for SmallPriorityQueue { + fn default() -> Self { + Self { + pq: Default::default(), + } + } +} +impl Clone for SmallPriorityQueue { + fn clone(&self) -> Self { + Self { + pq: self.pq.clone(), + } + } + + fn clone_from(&mut self, source: &Self) { + self.pq.clone_from(&source.pq); + } +} + +impl SmallPriorityQueue { + /// Returns true if this map is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.pq.is_empty() + } + + /// Returns the number of key/value pairs in this map + #[inline] + pub fn len(&self) -> usize { + self.pq.len() + } + + /// Pop the highest priority item from the queue + #[inline] + pub fn pop(&mut self) -> Option { + self.pq.pop_front() + } + + /// Pop the lowest priority item from the queue + #[inline] + pub fn pop_last(&mut self) -> Option { + self.pq.pop_back() + } + + /// Get a reference to the next highest priority item in the queue + #[inline] + pub fn front(&self) -> Option<&T> { + self.pq.front() + } + + /// Get a reference to the lowest highest priority item in the queue + #[inline] + pub fn back(&self) -> Option<&T> { + self.pq.back() + } + + /// Get a front-to-back iterator over the items in the queue + #[inline] + pub fn iter(&self) -> super::smalldeque::Iter<'_, T> { + self.pq.iter() + } +} + +impl SmallPriorityQueue +where + T: Ord, +{ + /// Returns a new, empty [SmallPriorityQueue] + pub const fn new() -> Self { + Self { + pq: SmallDeque::new(), + } + } + + /// Push an item on the queue. + /// + /// If the item's priority is equal to, or greater than any other item in the queue, the newly + /// pushed item will be placed at the front of the queue. Otherwise, the item is placed in the + /// queue at the next slot where it's priority is at least the same as the next value in the + /// queue at that slot. 
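An illustrative sketch (not part of the patch) of the queue's intended behavior, assuming the generic parameters are `<T, const N: usize>` and that the type is reachable as `midenc_hir2::adt::SmallPriorityQueue`: the largest element is always at the front, and wrapping elements in `core::cmp::Reverse` turns it into a smallest-first queue.

use core::cmp::Reverse;

use midenc_hir2::adt::SmallPriorityQueue; // assumed re-export path

fn priority_queue_sketch() {
    let mut queue: SmallPriorityQueue<u32, 8> = SmallPriorityQueue::new();
    queue.push(3);
    queue.push(7);
    queue.push(5);
    assert_eq!(queue.front(), Some(&7));
    assert_eq!(queue.pop(), Some(7));
    assert_eq!(queue.pop(), Some(5));
    // `pop_last` takes from the low-priority end
    assert_eq!(queue.pop_last(), Some(3));

    // A "lowest priority first" queue via `Reverse`
    let mut min_queue: SmallPriorityQueue<Reverse<u32>, 8> = SmallPriorityQueue::new();
    min_queue.push(Reverse(3));
    min_queue.push(Reverse(7));
    assert_eq!(min_queue.pop(), Some(Reverse(3)));
}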
+    pub fn push(&mut self, item: T) {
+        if let Some(head) = self.pq.front() {
+            match head.cmp(&item) {
+                Ordering::Greater => self.push_slow(item),
+                Ordering::Equal | Ordering::Less => {
+                    // Push to the front for efficiency
+                    self.pq.push_front(item);
+                }
+            }
+        } else {
+            self.pq.push_back(item);
+        }
+    }
+
+    /// Push an item on the queue, by conducting a search for the most appropriate index at which
+    /// to insert the new item, based upon a comparator function that compares the priorities of
+    /// the items.
+    fn push_slow(&mut self, item: T) {
+        // The queue is kept in descending order (highest priority at the front), so the
+        // comparator must be reversed for the binary search to agree with that ordering.
+        match self.pq.binary_search_by(|probe| probe.cmp(&item).reverse()) {
+            Ok(index) => {
+                self.pq.insert(index, item);
+            }
+            Err(index) => {
+                self.pq.insert(index, item);
+            }
+        }
+    }
+}
+
+impl<T, const N: usize> IntoIterator for SmallPriorityQueue<T, N> {
+    type IntoIter = super::smalldeque::IntoIter<T, N>;
+    type Item = T;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.pq.into_iter()
+    }
+}
+
+impl<T, const N: usize> FromIterator<T> for SmallPriorityQueue<T, N>
+where
+    T: Ord,
+{
+    #[inline]
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        let mut pq = SmallDeque::from_iter(iter);
+
+        let items = pq.make_contiguous();
+
+        // Sort in descending order, so the highest priority item ends up at the front
+        items.sort_by(|a, b| b.cmp(a));
+
+        Self { pq }
+    }
+}
+
+impl<T: Ord, const N: usize> From<SmallVec<[T; N]>> for SmallPriorityQueue<T, N> {
+    fn from(mut value: SmallVec<[T; N]>) -> Self {
+        // Sort in descending order, so the highest priority item ends up at the front
+        value.sort_by(|a, b| b.cmp(a));
+
+        Self {
+            pq: SmallDeque::from(value),
+        }
+    }
+}
diff --git a/hir2/src/adt/smallset.rs b/hir2/src/adt/smallset.rs
new file mode 100644
index 000000000..73c4aabc5
--- /dev/null
+++ b/hir2/src/adt/smallset.rs
@@ -0,0 +1,298 @@
+use core::{borrow::Borrow, fmt};
+
+use smallvec::SmallVec;
+
+/// [SmallSet] is a set data structure that can store a specified number
+/// of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// This data structure is designed with three goals in mind:
+///
+/// * Support efficient set operations over a small set of items
+/// * Preserve the insertion order of those items
+/// * Avoid allocating data on the heap for the typical case
+///
+/// Internally, [SmallSet] is implemented on top of [SmallVec], and uses linear search
+/// to locate elements. This is only reasonably efficient for small sets; for anything
+/// larger you should reach for a different set implementation. It is particularly fast
+/// when all of the data is stored inline, but may not be a good fit for all use cases.
+///
+/// Due to its design constraints, elements must implement [Eq].
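For illustration (not part of the patch, with generic parameters and re-export paths assumed as in the earlier sketches): the difference between `SmallSet`, which preserves insertion order and needs only `Eq`, and `SmallOrdSet`, which keeps its elements sorted by `Ord`.

use midenc_hir2::adt::{SmallOrdSet, SmallSet}; // assumed re-export paths

fn small_set_sketch() {
    // `SmallSet` keeps elements in the order they were inserted
    let mut set: SmallSet<u32, 4> = SmallSet::default();
    assert!(set.insert(30));
    assert!(set.insert(10));
    assert!(set.insert(20));
    // Duplicates are rejected
    assert!(!set.insert(10));
    assert_eq!(set.as_slice(), &[30, 10, 20]);

    // `SmallOrdSet` stores the same elements sorted by their `Ord` implementation
    let ord_set: SmallOrdSet<u32, 4> = [30u32, 10, 20].into_iter().collect();
    assert_eq!(ord_set.as_slice(), &[10, 20, 30]);
}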
+pub struct SmallSet { + items: SmallVec<[T; N]>, +} +impl Default for SmallSet { + fn default() -> Self { + Self { + items: Default::default(), + } + } +} +impl Eq for SmallSet where T: Eq {} +impl PartialEq for SmallSet +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.items.eq(&other.items) + } +} +impl fmt::Debug for SmallSet +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_set().entries(self.items.iter()).finish() + } +} +impl Clone for SmallSet +where + T: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { + items: self.items.clone(), + } + } +} +impl IntoIterator for SmallSet { + type IntoIter = smallvec::IntoIter<[T; N]>; + type Item = T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} +impl From<[T; N]> for SmallSet +where + T: Eq, +{ + #[inline] + fn from(items: [T; N]) -> Self { + Self::from_iter(items) + } +} +impl FromIterator for SmallSet +where + T: Eq, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut set = Self::default(); + for item in iter { + set.insert(item); + } + set + } +} +impl SmallSet +where + T: Eq, +{ + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + pub fn len(&self) -> usize { + self.items.len() + } + + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.items.iter() + } + + #[inline] + pub fn as_slice(&self) -> &[T] { + self.items.as_slice() + } + + pub fn insert(&mut self, item: T) -> bool { + if self.contains(&item) { + return false; + } + self.items.push(item); + true + } + + pub fn remove(&mut self, item: &Q) -> Option + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(self.items.remove(idx)), + None => None, + } + } + + /// Remove all items from the set for which `predicate` returns false. 
+ pub fn retain(&mut self, predicate: F) + where + F: FnMut(&mut T) -> bool, + { + self.items.retain(predicate); + } + + /// Clear the content of the set + pub fn clear(&mut self) { + self.items.clear(); + } + + pub fn contains(&self, item: &Q) -> bool + where + T: Borrow, + Q: Eq + ?Sized, + { + self.find(item).is_some() + } + + pub fn get(&self, item: &Q) -> Option<&T> + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(&self.items[idx]), + None => None, + } + } + + pub fn get_mut(&mut self, item: &Q) -> Option<&mut T> + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(&mut self.items[idx]), + None => None, + } + } + + /// Convert this set into a `SmallVec` containing the items of the set + #[inline] + pub fn into_vec(self) -> SmallVec<[T; N]> { + self.items + } + + #[inline] + fn find(&self, item: &Q) -> Option + where + T: Borrow, + Q: Eq + ?Sized, + { + self.items.iter().position(|elem| elem.borrow() == item) + } +} + +impl SmallSet +where + T: Clone + Eq, +{ + /// Obtain a new [SmallSet] containing the unique elements of both `self` and `other` + pub fn union(&self, other: &Self) -> Self { + let mut result = self.clone(); + + for item in other.items.iter() { + if result.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } + + /// Obtain a new [SmallSet] containing the unique elements of both `self` and `other` + pub fn into_union(mut self, other: &Self) -> Self { + for item in other.items.iter() { + if self.contains(item) { + continue; + } + self.items.push(item.clone()); + } + + self + } + + /// Obtain a new [SmallSet] containing the elements in common between `self` and `other` + pub fn intersection(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + result.items.push(item.clone()); + } + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in common between `self` and `other` + pub fn into_intersection(self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.into_iter() { + if other.contains(&item) { + result.items.push(item); + } + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in `self` but not in `other` + pub fn difference(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in `self` but not in `other` + pub fn into_difference(mut self, other: &Self) -> Self { + Self { + items: self.items.drain_filter(|item| !other.contains(item)).collect(), + } + } + + /// Obtain a new [SmallSet] containing the elements in `self` or `other`, but not in both + pub fn symmetric_difference(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + for item in other.items.iter() { + if self.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } +} + +impl core::iter::Extend for SmallSet +where + E: Eq, +{ + fn extend>(&mut self, iter: T) { + for item in iter { + self.insert(item); + } + } +} diff --git a/hir2/src/adt/sparsemap.rs b/hir2/src/adt/sparsemap.rs new file mode 100644 index 000000000..51a1ed290 --- /dev/null +++ b/hir2/src/adt/sparsemap.rs @@ -0,0 +1,233 @@ 
+//! This module is based on [cranelift_entity::SparseMap], but implemented in-tree +//! because the SparseMapValueTrait is not implemented for any standard library types +use cranelift_entity::{EntityRef, SecondaryMap}; + +pub trait SparseMapValue { + fn key(&self) -> K; +} +impl> SparseMapValue for Box { + fn key(&self) -> K { + (**self).key() + } +} +impl> SparseMapValue for alloc::rc::Rc { + fn key(&self) -> K { + (**self).key() + } +} +impl SparseMapValue for crate::ValueId { + fn key(&self) -> crate::ValueId { + *self + } +} +impl SparseMapValue for crate::BlockId { + fn key(&self) -> crate::BlockId { + *self + } +} + +/// A sparse mapping of entity references. +/// +/// A `SparseMap` map provides: +/// +/// - Memory usage equivalent to `SecondaryMap` + `Vec`, so much smaller than +/// `SecondaryMap` for sparse mappings of larger `V` types. +/// - Constant time lookup, slightly slower than `SecondaryMap`. +/// - A very fast, constant time `clear()` operation. +/// - Fast insert and erase operations. +/// - Stable iteration that is as fast as a `Vec`. +/// +/// # Compared to `SecondaryMap` +/// +/// When should we use a `SparseMap` instead of a secondary `SecondaryMap`? First of all, +/// `SparseMap` does not provide the functionality of a `PrimaryMap` which can allocate and assign +/// entity references to objects as they are pushed onto the map. It is only the secondary entity +/// maps that can be replaced with a `SparseMap`. +/// +/// - A secondary entity map assigns a default mapping to all keys. It doesn't distinguish between +/// an unmapped key and one that maps to the default value. `SparseMap` does not require `Default` +/// values, and it tracks accurately if a key has been mapped or not. +/// - Iterating over the contents of an `SecondaryMap` is linear in the size of the *key space*, +/// while iterating over a `SparseMap` is linear in the number of elements in the mapping. This is +/// an advantage precisely when the mapping is sparse. +/// - `SparseMap::clear()` is constant time and super-fast. `SecondaryMap::clear()` is linear in the +/// size of the key space. (Or, rather the required `resize()` call following the `clear()` is). +/// - `SparseMap` requires the values to implement `SparseMapValue` which means that they must +/// contain their own key. +pub struct SparseMap +where + K: EntityRef, + V: SparseMapValue, +{ + sparse: SecondaryMap, + dense: Vec, +} +impl Default for SparseMap +where + K: EntityRef, + V: SparseMapValue, +{ + fn default() -> Self { + Self { + sparse: SecondaryMap::new(), + dense: Vec::new(), + } + } +} +impl core::fmt::Debug for SparseMap +where + K: EntityRef + core::fmt::Debug, + V: SparseMapValue + core::fmt::Debug, +{ + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + f.debug_map().entries(self.values().map(|v| (v.key(), v))).finish() + } +} +impl SparseMap +where + K: EntityRef, + V: SparseMapValue, +{ + /// Create a new empty mapping. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Returns the number of elements in the map. + pub fn len(&self) -> usize { + self.dense.len() + } + + /// Returns true is the map contains no elements. + pub fn is_empty(&self) -> bool { + self.dense.is_empty() + } + + /// Remove all elements from the mapping. + pub fn clear(&mut self) { + self.dense.clear(); + } + + /// Returns a reference to the value corresponding to the key. 
+ pub fn get(&self, key: K) -> Option<&V> { + if let Some(idx) = self.sparse.get(key).cloned() { + if let Some(entry) = self.dense.get(idx as usize) { + if entry.key() == key { + return Some(entry); + } + } + } + None + } + + /// Returns a mutable reference to the value corresponding to the key. + /// + /// Note that the returned value must not be mutated in a way that would change its key. This + /// would invalidate the sparse set data structure. + pub fn get_mut(&mut self, key: K) -> Option<&mut V> { + if let Some(idx) = self.sparse.get(key).cloned() { + if let Some(entry) = self.dense.get_mut(idx as usize) { + if entry.key() == key { + return Some(entry); + } + } + } + None + } + + /// Return the index into `dense` of the value corresponding to `key`. + fn index(&self, key: K) -> Option { + if let Some(idx) = self.sparse.get(key).cloned() { + let idx = idx as usize; + if let Some(entry) = self.dense.get(idx) { + if entry.key() == key { + return Some(idx); + } + } + } + None + } + + /// Return `true` if the map contains a value corresponding to `key`. + pub fn contains_key(&self, key: K) -> bool { + self.get(key).is_some() + } + + /// Insert a value into the map. + /// + /// If the map did not have this key present, `None` is returned. + /// + /// If the map did have this key present, the value is updated, and the old value is returned. + /// + /// It is not necessary to provide a key since the value knows its own key already. + pub fn insert(&mut self, value: V) -> Option { + let key = value.key(); + + // Replace the existing entry for `key` if there is one. + if let Some(entry) = self.get_mut(key) { + return Some(core::mem::replace(entry, value)); + } + + // There was no previous entry for `key`. Add it to the end of `dense`. + let idx = self.dense.len(); + debug_assert!(idx <= u32::MAX as usize, "SparseMap overflow"); + self.dense.push(value); + self.sparse[key] = idx as u32; + None + } + + /// Remove a value from the map and return it. + pub fn remove(&mut self, key: K) -> Option { + if let Some(idx) = self.index(key) { + let back = self.dense.pop().unwrap(); + + // Are we popping the back of `dense`? + if idx == self.dense.len() { + return Some(back); + } + + // We're removing an element from the middle of `dense`. + // Replace the element at `idx` with the back of `dense`. + // Repair `sparse` first. + self.sparse[back.key()] = idx as u32; + return Some(core::mem::replace(&mut self.dense[idx], back)); + } + + // Nothing to remove. + None + } + + /// Remove the last value from the map. + pub fn pop(&mut self) -> Option { + self.dense.pop() + } + + /// Get an iterator over the values in the map. + /// + /// The iteration order is entirely determined by the preceding sequence of `insert` and + /// `remove` operations. In particular, if no elements were removed, this is the insertion + /// order. + pub fn values(&self) -> core::slice::Iter { + self.dense.iter() + } + + /// Get the values as a slice. + pub fn as_slice(&self) -> &[V] { + self.dense.as_slice() + } +} + +/// Iterating over the elements of a set. 
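A usage sketch for `SparseMap` (illustrative only): the key type must implement `cranelift_entity::EntityRef`, and values carry their own keys via `SparseMapValue`. The `Node` and `NodeInfo` types below are hypothetical, defined just for this sketch, and the `midenc_hir2::adt` re-export path is an assumption.

use cranelift_entity::EntityRef;
use midenc_hir2::adt::{SparseMap, SparseMapValue}; // assumed re-export path

// A hypothetical entity reference type
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
struct Node(usize);
impl EntityRef for Node {
    fn new(index: usize) -> Self {
        Self(index)
    }

    fn index(self) -> usize {
        self.0
    }
}

// A hypothetical value type which knows its own key
struct NodeInfo {
    node: Node,
    rank: usize,
}
impl SparseMapValue<Node> for NodeInfo {
    fn key(&self) -> Node {
        self.node
    }
}

fn sparse_map_sketch() {
    let mut map: SparseMap<Node, NodeInfo> = SparseMap::new();
    assert!(map.insert(NodeInfo { node: Node(0), rank: 3 }).is_none());
    assert!(map.insert(NodeInfo { node: Node(5), rank: 1 }).is_none());
    assert!(map.contains_key(Node(5)));
    assert_eq!(map.get(Node(0)).map(|info| info.rank), Some(3));
    // Iteration order follows the sequence of inserts/removes, not the key order
    let keys: Vec<Node> = map.values().map(|info| info.key()).collect();
    assert_eq!(keys, [Node(0), Node(5)]);
    assert!(map.remove(Node(0)).is_some());
}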
+impl<'a, K, V> IntoIterator for &'a SparseMap +where + K: EntityRef, + V: SparseMapValue, +{ + type IntoIter = core::slice::Iter<'a, V>; + type Item = &'a V; + + fn into_iter(self) -> Self::IntoIter { + self.values() + } +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 68b294097..12f0a3259 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -21,6 +21,25 @@ #![feature(generic_const_exprs)] #![feature(new_uninit)] #![feature(clone_to_uninit)] +// The following are used in impls of custom collection types based on SmallVec +#![feature(iter_repeat_n)] +#![feature(std_internals)] // for ByRefSized +#![feature(extend_one)] +#![feature(extend_one_unchecked)] +#![feature(iter_advance_by)] +#![feature(iter_next_chunk)] +#![feature(iter_collect_into)] +#![feature(trusted_len)] +#![feature(never_type)] +#![feature(maybe_uninit_slice)] +#![feature(maybe_uninit_array_assume_init)] +#![feature(maybe_uninit_uninit_array)] +#![feature(maybe_uninit_uninit_array_transpose)] +#![feature(array_into_iter_constructors)] +#![feature(slice_range)] +#![feature(slice_swap_unchecked)] +#![feature(hasher_prefixfree_extras)] +// Some of the above features require us to disable these warnings #![allow(incomplete_features)] #![allow(internal_features)] @@ -39,6 +58,7 @@ pub type FxHashMap = hashbrown::HashMap; pub type FxHashSet = hashbrown::HashSet; pub use rustc_hash::{FxBuildHasher, FxHasher}; +pub mod adt; mod any; mod attributes; pub mod demangle; From 0aa10e5989c376bb8e0f44965a20444ee6726d14 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:40:10 -0400 Subject: [PATCH 27/31] feat: add iterator extension trait for use in hir2 --- hir2/src/itertools.rs | 17 +++++++++++++++++ hir2/src/lib.rs | 2 ++ 2 files changed, 19 insertions(+) create mode 100644 hir2/src/itertools.rs diff --git a/hir2/src/itertools.rs b/hir2/src/itertools.rs new file mode 100644 index 000000000..10e985f07 --- /dev/null +++ b/hir2/src/itertools.rs @@ -0,0 +1,17 @@ +pub trait IteratorExt { + /// Returns true if the given iterator consists of exactly one element + fn has_single_element(&mut self) -> bool; +} + +impl IteratorExt for I { + default fn has_single_element(&mut self) -> bool { + self.next().is_some_and(|_| self.next().is_none()) + } +} + +impl IteratorExt for I { + #[inline] + fn has_single_element(&mut self) -> bool { + self.len() == 1 + } +} diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs index 12f0a3259..cabedb9e1 100644 --- a/hir2/src/lib.rs +++ b/hir2/src/lib.rs @@ -68,6 +68,7 @@ mod folder; pub mod formatter; mod hash; mod ir; +pub mod itertools; pub mod matchers; pub mod pass; mod patterns; @@ -81,6 +82,7 @@ pub use self::{ folder::OperationFolder, hash::{DynHash, DynHasher}, ir::*, + itertools::IteratorExt, patterns::*, }; From 8c32fc6f4f98c115d537e3780ed87fb3e35c29c7 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:45:09 -0400 Subject: [PATCH 28/31] feat: be more conservative with how unsafe refs interact with debug printing --- hir2/src/ir.rs | 3 ++- hir2/src/ir/entity.rs | 14 ++++++++++---- hir2/src/ir/entity/list.rs | 24 ++++++++++++++++++++++++ hir2/src/ir/operation.rs | 13 ++++++++++++- hir2/src/ir/operation/name.rs | 2 +- hir2/src/ir/region.rs | 1 + hir2/src/ir/region/successor.rs | 16 ++++++++++++++-- hir2/src/ir/successor.rs | 27 ++++++++++++++++++++++++--- hir2/src/ir/symbol_table.rs | 11 ++++++++++- 9 files changed, 98 insertions(+), 13 deletions(-) diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index d2af639cb..f54dea8c0 100644 --- a/hir2/src/ir.rs +++ 
b/hir2/src/ir.rs @@ -38,7 +38,8 @@ pub use self::{ entity::{ Entity, EntityCursor, EntityCursorMut, EntityGroup, EntityId, EntityIter, EntityList, EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, EntityWithId, - EntityWithParent, RawEntityRef, StorableEntity, UnsafeEntityRef, UnsafeIntrusiveEntityRef, + EntityWithParent, MaybeDefaultEntityIter, RawEntityRef, StorableEntity, UnsafeEntityRef, + UnsafeIntrusiveEntityRef, }, ident::{FunctionIdent, Ident}, immediates::{Felt, FieldElement, Immediate, StarkField}, diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs index c46492a5d..a6d1aa740 100644 --- a/hir2/src/ir/entity.rs +++ b/hir2/src/ir/entity.rs @@ -15,7 +15,7 @@ use core::{ pub use self::{ group::EntityGroup, - list::{EntityCursor, EntityCursorMut, EntityIter, EntityList}, + list::{EntityCursor, EntityCursorMut, EntityIter, EntityList, MaybeDefaultEntityIter}, storage::{EntityRange, EntityRangeMut, EntityStorage}, }; use crate::any::*; @@ -89,7 +89,7 @@ pub trait StorableEntity { } /// A trait that must be implemented by the unique identifier for an [Entity] -pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash { +pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash + fmt::Display { fn as_usize(&self) -> usize; } @@ -565,11 +565,17 @@ impl fmt::Pointer for RawEntityRef { fmt::Pointer::fmt(&Self::as_ptr(self), f) } } +impl fmt::Display for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.borrow().id()) + } +} -impl fmt::Debug for RawEntityRef { +impl fmt::Debug for RawEntityRef { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.borrow(), f) + write!(f, "{}", self.borrow().id()) } } impl crate::formatter::PrettyPrint diff --git a/hir2/src/ir/entity/list.rs b/hir2/src/ir/entity/list.rs index f3a8540fe..707a73024 100644 --- a/hir2/src/ir/entity/list.rs +++ b/hir2/src/ir/entity/list.rs @@ -572,6 +572,30 @@ impl<'a, T> DoubleEndedIterator for EntityIter<'a, T> { } } +pub struct MaybeDefaultEntityIter<'a, T> { + iter: Option>, +} +impl<'a, T> Default for MaybeDefaultEntityIter<'a, T> { + fn default() -> Self { + Self { iter: None } + } +} +impl<'a, T> core::iter::FusedIterator for MaybeDefaultEntityIter<'a, T> {} +impl<'a, T> Iterator for MaybeDefaultEntityIter<'a, T> { + type Item = EntityRef<'a, T>; + + #[inline] + fn next(&mut self) -> Option { + self.iter.as_mut().and_then(|iter| iter.next()) + } +} +impl<'a, T> DoubleEndedIterator for MaybeDefaultEntityIter<'a, T> { + #[inline] + fn next_back(&mut self) -> Option { + self.iter.as_mut().and_then(|iter| iter.next_back()) + } +} + pub type IntrusiveLink = intrusive_collections::LinkedListLink; impl RawEntityRef { diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs index bae90cbac..0a2436dfd 100644 --- a/hir2/src/ir/operation.rs +++ b/hir2/src/ir/operation.rs @@ -118,12 +118,23 @@ impl fmt::Debug for Operation { .field("order", &self.order) .field("attrs", &self.attrs) .field("block", &self.block.as_ref().map(|b| b.borrow().id())) - .field("operands", &self.operands) + .field_with("operands", |f| { + let mut list = f.debug_list(); + for operand in self.operands().all() { + list.entry(&operand.borrow()); + } + list.finish() + }) .field("results", &self.results) .field("successors", &self.successors) .finish_non_exhaustive() } } +impl fmt::Debug for OperationRef { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + 
fmt::Debug::fmt(&self.borrow(), f) + } +} impl AsRef for Operation { fn as_ref(&self) -> &dyn Op { diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs index d0317e12c..be9a298c7 100644 --- a/hir2/src/ir/operation/name.rs +++ b/hir2/src/ir/operation/name.rs @@ -5,7 +5,7 @@ use core::{ ptr::{DynMetadata, Pointee}, }; -use crate::{interner, traits::TraitInfo, DialectName, Op}; +use crate::{interner, traits::TraitInfo, DialectName}; /// The operation name, or mnemonic, that uniquely identifies an operation. /// diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index 86ed702e0..243fa8e85 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -55,6 +55,7 @@ pub struct Region { } impl Entity for Region {} + impl EntityWithParent for Region { type Parent = Operation; diff --git a/hir2/src/ir/region/successor.rs b/hir2/src/ir/region/successor.rs index 3d69b2f0a..a55f238f1 100644 --- a/hir2/src/ir/region/successor.rs +++ b/hir2/src/ir/region/successor.rs @@ -52,7 +52,13 @@ impl fmt::Debug for RegionSuccessor<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RegionSuccessor") .field("dest", &self.dest) - .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) .finish() } } @@ -90,7 +96,13 @@ impl fmt::Debug for RegionSuccessorMut<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RegionSuccessorMut") .field("dest", &self.dest) - .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) .finish() } } diff --git a/hir2/src/ir/successor.rs b/hir2/src/ir/successor.rs index 68668595d..b4f37aca3 100644 --- a/hir2/src/ir/successor.rs +++ b/hir2/src/ir/successor.rs @@ -254,7 +254,7 @@ pub trait KeyedSuccessor { } /// This struct tracks successor metadata needed by [crate::Operation] -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SuccessorInfo { pub block: BlockOperandRef, pub(crate) key: Option>, @@ -276,6 +276,15 @@ impl crate::StorableEntity for SuccessorInfo { self.block.unlink(); } } +impl fmt::Debug for SuccessorInfo { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SuccessorInfo") + .field("block", &self.block.borrow()) + .field("key", &self.key) + .field("operand_group", &self.operand_group) + .finish() + } +} /// An [OpSuccessor] is a BlockOperand + OpOperandRange for that block pub struct OpSuccessor<'a> { @@ -286,7 +295,13 @@ impl fmt::Debug for OpSuccessor<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("OpSuccessor") .field("block", &self.dest.borrow().block_id()) - .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) .finish() } } @@ -300,7 +315,13 @@ impl fmt::Debug for OpSuccessorMut<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("OpSuccessorMut") .field("block", &self.dest.borrow().block_id()) - .field_with("arguments", |f| f.debug_list().entries(self.arguments.iter()).finish()) + .field_with("arguments", 
|f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) .finish() } } diff --git a/hir2/src/ir/symbol_table.rs b/hir2/src/ir/symbol_table.rs index 8ee2bf030..f6424aa5e 100644 --- a/hir2/src/ir/symbol_table.rs +++ b/hir2/src/ir/symbol_table.rs @@ -44,7 +44,7 @@ pub enum InvalidSymbolRefError { }, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SymbolNameAttr { pub user: SymbolUseRef, /// The path through the abstract symbol space to the containing symbol table @@ -100,6 +100,15 @@ impl fmt::Display for SymbolNameAttr { } } } +impl fmt::Debug for SymbolNameAttr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SymbolNameAttr") + .field("user", &self.user.borrow()) + .field("path", &self.path) + .field("name", &self.name) + .finish() + } +} impl crate::formatter::PrettyPrint for SymbolNameAttr { fn render(&self) -> crate::formatter::Document { use crate::formatter::*; From 3ffd2c4befc85dcf4ff69e7a98b937eca6065ea4 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:47:08 -0400 Subject: [PATCH 29/31] feat: implement cfg abstractions and utilities This commit adds the following useful tools (used in later commits): * The `Graph` trait for abstracting over flow graphs * The `GraphVisitor` and `DfsVisitor` primitives for performing and hooking into depth-first traversals of a `Graph` in either pre-order or post-order, or both. * Type aliases for pre- and post-order dfs visitors for blocks * The `GraphDiff` trait and `CfgDiff` data structure for representing pending insertions/deletions in a flow graph. This enables incremental updating of flow graph dependent analyses such as the dominator tree. --- hir2/src/ir.rs | 7 +- hir2/src/ir/block.rs | 366 ++++++++++++++++++++++++++++++++++++ hir2/src/ir/cfg.rs | 245 ++++++++++++++++++++++++ hir2/src/ir/cfg/diff.rs | 301 +++++++++++++++++++++++++++++ hir2/src/ir/cfg/visit.rs | 313 ++++++++++++++++++++++++++++++ hir2/src/ir/region.rs | 49 +++++ hir2/src/ir/visit.rs | 2 - hir2/src/ir/visit/blocks.rs | 102 ---------- 8 files changed, 1278 insertions(+), 107 deletions(-) create mode 100644 hir2/src/ir/cfg.rs create mode 100644 hir2/src/ir/cfg/diff.rs create mode 100644 hir2/src/ir/cfg/visit.rs delete mode 100644 hir2/src/ir/visit/blocks.rs diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index f54dea8c0..de3f8ceb5 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -1,6 +1,7 @@ mod block; mod builder; mod callable; +pub mod cfg; mod component; mod context; mod dialect; @@ -29,7 +30,7 @@ pub use midenc_session::diagnostics::{Report, SourceSpan, Span, Spanned}; pub use self::{ block::{ Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, - BlockRef, + BlockRef, PostOrderBlockIter, PreOrderBlockIter, }, builder::{Builder, BuilderExt, InsertionGuard, Listener, ListenerType, OpBuilder}, callable::*, @@ -80,7 +81,7 @@ pub use self::{ }, verifier::{OpVerifier, Verify}, visit::{ - BlockIter, OpVisitor, OperationVisitor, PostOrderBlockIter, Searcher, SymbolVisitor, - Visitor, WalkOrder, WalkResult, WalkStage, Walkable, + OpVisitor, OperationVisitor, Searcher, SymbolVisitor, Visitor, WalkOrder, WalkResult, + WalkStage, Walkable, }, }; diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs index 1be0e2a6e..1a5d31e60 100644 --- a/hir2/src/ir/block.rs +++ b/hir2/src/ir/block.rs @@ -1,5 +1,7 @@ use core::fmt; +use smallvec::SmallVec; + use super::*; /// A pointer to a [Block] @@ -10,6 
+12,10 @@ pub type BlockList = EntityList; pub type BlockCursor<'a> = EntityCursor<'a, Block>; /// A mutable cursor into a [BlockList] pub type BlockCursorMut<'a> = EntityCursorMut<'a, Block>; +/// An iterator over blocks produced by a depth-first, pre-order visit of the CFG +pub type PreOrderBlockIter = cfg::PreOrderIter; +/// An iterator over blocks produced by a depth-first, post-order visit of the CFG +pub type PostOrderBlockIter = cfg::PostOrderIter; /// The unique identifier for a [Block] #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -90,7 +96,9 @@ impl fmt::Debug for Block { .finish_non_exhaustive() } } + impl Entity for Block {} + impl EntityWithId for Block { type Id = BlockId; @@ -98,6 +106,7 @@ impl EntityWithId for Block { self.id } } + impl EntityWithParent for Block { type Parent = Region; @@ -125,6 +134,7 @@ impl EntityWithParent for Block { } } } + impl Usable for Block { type Use = BlockOperand; @@ -138,6 +148,357 @@ impl Usable for Block { &mut self.uses } } + +impl cfg::Graph for Block { + type ChildEdgeIter = BlockSuccessorEdgesIter; + type ChildIter = BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn size(&self) -> usize { + if let Some(term) = self.terminator() { + term.borrow().num_successors() + } else { + 0 + } + } + + fn entry_node(&self) -> Self::Node { + self.as_block_ref() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + edge.borrow().block.clone() + } +} + +impl<'a> cfg::InvertibleGraph for &'a Block { + type Inverse = cfg::Inverse<&'a Block>; + type InvertibleChildEdgeIter = BlockPredecessorEdgesIter; + type InvertibleChildIter = BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + BlockPredecessorEdgesIter::new(parent) + } +} + +impl cfg::Graph for BlockRef { + type ChildEdgeIter = BlockSuccessorEdgesIter; + type ChildIter = BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn size(&self) -> usize { + if let Some(term) = self.borrow().terminator() { + term.borrow().num_successors() + } else { + 0 + } + } + + fn entry_node(&self) -> Self::Node { + self.clone() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + edge.borrow().block.clone() + } +} + +impl cfg::InvertibleGraph for BlockRef { + type Inverse = cfg::Inverse; + type InvertibleChildEdgeIter = BlockPredecessorEdgesIter; + type InvertibleChildIter = BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + BlockPredecessorEdgesIter::new(parent) + } +} + +#[doc(hidden)] +pub struct BlockSuccessorIter { + iter: BlockSuccessorEdgesIter, +} +impl BlockSuccessorIter { + pub fn new(parent: BlockRef) -> Self { + Self { + iter: 
BlockSuccessorEdgesIter::new(parent), + } + } +} +impl ExactSizeIterator for BlockSuccessorIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for BlockSuccessorIter { + type Item = BlockRef; + + fn next(&mut self) -> Option { + self.iter.next().map(|bo| bo.borrow().block.clone()) + } + + #[inline] + fn collect>(self) -> B + where + Self: Sized, + { + let Some(terminator) = self.iter.terminator.as_ref() else { + return B::from_iter([]); + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + B::from_iter( + successors.all().as_slice()[self.iter.index..self.iter.num_successors] + .iter() + .map(|succ| succ.block.borrow().block.clone()), + ) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + let Some(terminator) = self.iter.terminator.as_ref() else { + return collection; + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + collection.extend( + successors.all().as_slice()[self.iter.index..self.iter.num_successors] + .iter() + .map(|succ| succ.block.borrow().block.clone()), + ); + collection + } +} + +#[doc(hidden)] +pub struct BlockSuccessorEdgesIter { + terminator: Option, + num_successors: usize, + index: usize, +} +impl BlockSuccessorEdgesIter { + pub fn new(parent: BlockRef) -> Self { + let terminator = parent.borrow().terminator(); + let num_successors = terminator.as_ref().map(|t| t.borrow().num_successors()).unwrap_or(0); + Self { + terminator, + num_successors, + index: 0, + } + } +} +impl ExactSizeIterator for BlockSuccessorEdgesIter { + #[inline] + fn len(&self) -> usize { + self.num_successors.saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.num_successors + } +} +impl Iterator for BlockSuccessorEdgesIter { + type Item = BlockOperandRef; + + fn next(&mut self) -> Option { + if self.index >= self.num_successors { + return None; + } + + // SAFETY: We'll never have a none terminator if we have non-zero number of successors + let terminator = unsafe { self.terminator.as_ref().unwrap_unchecked() }; + let index = self.index; + self.index += 1; + Some(terminator.borrow().successor(index).dest.clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + let Some(terminator) = self.terminator.as_ref() else { + return B::from_iter([]); + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + B::from_iter( + successors.all().as_slice()[self.index..self.num_successors] + .iter() + .map(|succ| succ.block.clone()), + ) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + let Some(terminator) = self.terminator.as_ref() else { + return collection; + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + collection.extend( + successors.all().as_slice()[self.index..self.num_successors] + .iter() + .map(|succ| succ.block.clone()), + ); + collection + } +} + +#[doc(hidden)] +pub struct BlockPredecessorIter { + preds: SmallVec<[BlockRef; 4]>, + index: usize, +} +impl BlockPredecessorIter { + pub fn new(child: BlockRef) -> Self { + let preds = child.borrow().predecessors().map(|bo| bo.block.clone()).collect(); + Self { preds, index: 0 } + } + + #[inline(always)] + pub fn into_inner(self) -> SmallVec<[BlockRef; 4]> { + self.preds + } +} +impl ExactSizeIterator for BlockPredecessorIter { + #[inline] + fn len(&self) -> usize { + 
self.preds.len().saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.preds.len() + } +} +impl Iterator for BlockPredecessorIter { + type Item = BlockRef; + + #[inline] + fn next(&mut self) -> Option { + if self.is_empty() { + return None; + } + let index = self.index; + self.index += 1; + Some(self.preds[index].clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + B::from_iter(self.preds) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + collection.extend(self.preds); + collection + } +} + +#[doc(hidden)] +pub struct BlockPredecessorEdgesIter { + preds: SmallVec<[BlockOperandRef; 4]>, + index: usize, +} +impl BlockPredecessorEdgesIter { + pub fn new(child: BlockRef) -> Self { + let preds = child + .borrow() + .predecessors() + .map(|bo| unsafe { BlockOperandRef::from_raw(&*bo) }) + .collect(); + Self { preds, index: 0 } + } +} +impl ExactSizeIterator for BlockPredecessorEdgesIter { + #[inline] + fn len(&self) -> usize { + self.preds.len().saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.preds.len() + } +} +impl Iterator for BlockPredecessorEdgesIter { + type Item = BlockOperandRef; + + #[inline] + fn next(&mut self) -> Option { + if self.is_empty() { + return None; + } + let index = self.index; + self.index += 1; + Some(self.preds[index].clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + B::from_iter(self.preds) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + collection.extend(self.preds); + collection + } +} + impl Block { pub fn new(id: BlockId) -> Self { Self { @@ -165,6 +526,11 @@ impl Block { self.region.as_ref().and_then(|region| region.borrow().parent()) } + /// Get a handle to the ancestor [Block] of this block, if one is present + pub fn parent_block(&self) -> Option { + self.parent_op().and_then(|op| op.borrow().parent()) + } + /// Returns true if this block is the entry block for its containing region pub fn is_entry_block(&self) -> bool { if let Some(parent) = self.region.as_ref().map(|r| r.borrow()) { diff --git a/hir2/src/ir/cfg.rs b/hir2/src/ir/cfg.rs new file mode 100644 index 000000000..12bd4e764 --- /dev/null +++ b/hir2/src/ir/cfg.rs @@ -0,0 +1,245 @@ +mod diff; +mod visit; + +pub use self::{ + diff::{CfgDiff, CfgUpdate, CfgUpdateKind, GraphDiff}, + visit::{DefaultGraphVisitor, GraphVisitor, LazyDfsVisitor, PostOrderIter, PreOrderIter}, +}; + +/// This is an abstraction over graph-like structures used in the IR: +/// +/// * The CFG of a region, i.e. graph of blocks +/// * The CFG reachable from a single block, i.e. graph of blocks +/// * The dominator graph of a region, i.e. graph of dominator nodes +/// * The call graph of a program +/// * etc... +/// +/// It isn't strictly necessary, but it provides some uniformity, and is useful particularly +/// for implementation of various analyses. +pub trait Graph { + /// The type of node represented in the graph. + /// + /// Typically this should be a pointer-like reference type, cheap to copy/clone. + type Node: Clone; + /// Type used to iterate over children of a node in the graph. + type ChildIter: ExactSizeIterator; + /// The type used to represent an edge in the graph. + /// + /// This should be cheap to copy/clone. + type Edge; + /// Type used to iterate over child edges of a node in the graph. + type ChildEdgeIter: ExactSizeIterator; + + /// An empty graph has no nodes. 
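+    ///
+    /// For example, in the `Graph` implementation for [Block] and [BlockRef] (see `block.rs`),
+    /// `size()` is the number of successor edges of the block's terminator, so a block with no
+    /// terminator, or a terminator with no successors, is treated as an empty graph here.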
+ #[inline] + fn is_empty(&self) -> bool { + self.size() == 0 + } + /// Get the number of nodes in this graph + fn size(&self) -> usize; + /// Get the entry node of the graph. + /// + /// It is expected that a graph always has an entry. As such, this function will panic if + /// called on an "empty" graph. You should check whether the graph is empty _first_, if you + /// are working with a possibly-empty graph. + fn entry_node(&self) -> Self::Node; + /// Get an iterator over the children of `parent` + fn children(parent: Self::Node) -> Self::ChildIter; + /// Get an iterator over the children edges of `parent` + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter; + /// Return the destination node of an edge. + fn edge_dest(edge: Self::Edge) -> Self::Node; +} + +impl<'a, G: Graph> Graph for &'a G { + type ChildEdgeIter = ::ChildEdgeIter; + type ChildIter = ::ChildIter; + type Edge = ::Edge; + type Node = ::Node; + + fn is_empty(&self) -> bool { + (**self).is_empty() + } + + fn size(&self) -> usize { + (**self).size() + } + + fn entry_node(&self) -> Self::Node { + (**self).entry_node() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + ::children(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + ::children_edges(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + ::edge_dest(edge) + } +} + +/// An [InvertibleGraph] is a [Graph] which can be "inverted", i.e. edges are reversed. +/// +/// Technically, any graph is invertible, however we are primarily interested in supporting graphs +/// for which an inversion of itself has some semantic value. For example, visiting a CFG in +/// reverse is useful in various contexts, such as constructing dominator trees. +/// +/// This is primarily consumed via [Inverse]. +pub trait InvertibleGraph: Graph { + /// The type of this graph's inversion + /// + /// This is primarily useful in cases where you inverse the inverse of a graph - by allowing + /// the types to differ, you can recover the original graph, rather than having to emulate + /// both uninverted graphs using the inverse type. + /// + /// See [Inverse] for an example of how this is used. + type Inverse: Graph; + /// The type of iterator used to visit "inverted" children of a node in this graph, i.e. + /// the predecessors. + type InvertibleChildIter: ExactSizeIterator; + /// The type of iterator used to obtain the set of "inverted" children edges of a node in this + /// graph, i.e. the predecessor edges. + type InvertibleChildEdgeIter: ExactSizeIterator; + + /// Get an iterator over the predecessors of `parent`. + /// + /// NOTE: `parent` in this case will actually be a child of the nodes in the iterator, but we + /// preserve the naming so as to make it apparent we are working with an inversion of the + /// original graph. + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter; + /// Get an iterator over the predecessor edges of `parent`. + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter; + /// Obtain the inversion of this graph + fn inverse(self) -> Self::Inverse; +} + +/// This is a wrapper type for [Graph] implementations, used to indicate that iterating a +/// graph should be iterated in "inverse" order, the semantics of which depend on the graph. +/// +/// If used with an [InvertibleGraph], it uses the graph impls inverse iterators. If used with a +/// graph that is _not_ invertible, it uses the graph impls normal iterators. Effectively, this is +/// a specialization marker type. 
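+///
+/// A usage sketch (not part of this patch): assuming a [crate::BlockRef] named `exit`, the
+/// inversion lets the generic pre-order iterator walk predecessor edges instead of successors:
+///
+/// ```ignore
+/// // Depth-first walk over `exit` and its transitive predecessors.
+/// for block in PreOrderIter::<Inverse<BlockRef>>::new(exit.clone()) {
+///     // `block` is `exit` itself, or a block from which `exit` is reachable
+/// }
+/// ```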
+pub struct Inverse { + graph: G, +} + +impl Inverse { + /// Construct an inversion over `graph` + #[inline] + pub fn new(graph: G) -> Self { + Self { graph } + } +} + +impl Graph for Inverse { + type ChildEdgeIter = InverseChildEdgeIter<::InvertibleChildEdgeIter>; + type ChildIter = InverseChildIter<::InvertibleChildIter>; + type Edge = ::Edge; + type Node = ::Node; + + fn is_empty(&self) -> bool { + self.graph.is_empty() + } + + fn size(&self) -> usize { + self.graph.size() + } + + fn entry_node(&self) -> Self::Node { + self.graph.entry_node() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + InverseChildIter::new(::inverse_children(parent)) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + InverseChildEdgeIter::new(::inverse_children_edges(parent)) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + ::edge_dest(edge) + } +} + +impl InvertibleGraph for Inverse { + type Inverse = G; + type InvertibleChildEdgeIter = ::ChildEdgeIter; + type InvertibleChildIter = ::ChildIter; + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + ::children(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + ::children_edges(parent) + } + + fn inverse(self) -> Self::Inverse { + self.graph + } +} + +/// An iterator returned by `children` that iterates over `inverse_children` of the underlying graph +#[doc(hidden)] +pub struct InverseChildIter { + iter: I, +} + +impl InverseChildIter { + pub fn new(iter: I) -> Self { + Self { iter } + } +} +impl ExactSizeIterator for InverseChildIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for InverseChildIter { + type Item = ::Item; + + default fn next(&mut self) -> Option { + self.iter.next() + } +} + +/// An iterator returned by `children_edges` that iterates over `inverse_children_edges` of the +/// underlying graph. 
+#[doc(hidden)] +pub struct InverseChildEdgeIter { + iter: I, +} +impl InverseChildEdgeIter { + pub fn new(iter: I) -> Self { + Self { iter } + } +} +impl ExactSizeIterator for InverseChildEdgeIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for InverseChildEdgeIter { + type Item = ::Item; + + default fn next(&mut self) -> Option { + self.iter.next() + } +} diff --git a/hir2/src/ir/cfg/diff.rs b/hir2/src/ir/cfg/diff.rs new file mode 100644 index 000000000..876a2de66 --- /dev/null +++ b/hir2/src/ir/cfg/diff.rs @@ -0,0 +1,301 @@ +use core::fmt; + +use smallvec::SmallVec; + +use crate::{adt::SmallMap, BlockRef}; + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum CfgUpdateKind { + Insert, + Delete, +} + +#[derive(Clone, PartialEq, Eq)] +pub struct CfgUpdate { + kind: CfgUpdateKind, + from: BlockRef, + to: BlockRef, +} +impl CfgUpdate { + #[inline(always)] + pub const fn kind(&self) -> CfgUpdateKind { + self.kind + } + + #[inline(always)] + pub const fn from(&self) -> &BlockRef { + &self.from + } + + #[inline(always)] + pub const fn to(&self) -> &BlockRef { + &self.to + } +} +impl fmt::Debug for CfgUpdate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(match self.kind { + CfgUpdateKind::Insert => "Insert", + CfgUpdateKind::Delete => "Delete", + }) + .field("from", &self.from) + .field("to", &self.to) + .finish() + } +} + +#[derive(Default, Clone)] +struct DeletesInserts { + deletes: SmallVec<[BlockRef; 2]>, + inserts: SmallVec<[BlockRef; 2]>, +} +impl DeletesInserts { + pub fn di(&self, is_insert: bool) -> &SmallVec<[BlockRef; 2]> { + if is_insert { + &self.inserts + } else { + &self.deletes + } + } + + pub fn di_mut(&mut self, is_insert: bool) -> &mut SmallVec<[BlockRef; 2]> { + if is_insert { + &mut self.inserts + } else { + &mut self.deletes + } + } +} + +pub trait GraphDiff { + fn is_empty(&self) -> bool; + fn legalized_updates(&self) -> &[CfgUpdate]; + fn num_legalized_updates(&self) -> usize { + self.legalized_updates().len() + } + fn pop_update_for_incremental_updates(&mut self) -> CfgUpdate; + fn get_children(&self, node: &BlockRef) -> SmallVec<[BlockRef; 8]>; +} + +/// GraphDiff defines a CFG snapshot: given a set of Update, provides +/// a getChildren method to get a Node's children based on the additional updates +/// in the snapshot. The current diff treats the CFG as a graph rather than a +/// multigraph. Added edges are pruned to be unique, and deleted edges will +/// remove all existing edges between two blocks. +/// +/// Two booleans are used to define orders in graphs: +/// InverseGraph defines when we need to reverse the whole graph and is as such +/// also equivalent to applying updates in reverse. +/// InverseEdge defines whether we want to change the edges direction. E.g., for +/// a non-inversed graph, the children are naturally the successors when +/// InverseEdge is false and the predecessors when InverseEdge is true. +#[derive(Clone)] +pub struct CfgDiff { + succ: SmallMap, + pred: SmallMap, + /// By default, it is assumed that, given a CFG and a set of updates, we wish + /// to apply these updates as given. If UpdatedAreReverseApplied is set, the + /// updates will be applied in reverse: deleted edges are considered re-added + /// and inserted edges are considered deleted when returning children. 
+ updated_are_reverse_applied: bool, + /// Keep the list of legalized updates for a deterministic order of updates + /// when using a GraphDiff for incremental updates in the DominatorTree. + /// The list is kept in reverse to allow popping from end. + legalized_updates: SmallVec<[CfgUpdate; 4]>, +} + +impl Default for CfgDiff { + fn default() -> Self { + Self { + succ: Default::default(), + pred: Default::default(), + updated_are_reverse_applied: false, + legalized_updates: Default::default(), + } + } +} + +impl CfgDiff { + pub fn new(updates: I, reverse_apply_updates: bool) -> Self + where + I: ExactSizeIterator, + { + let mut this = Self { + legalized_updates: legalize_updates(updates, INVERSE_GRAPH, false), + ..Default::default() + }; + for update in this.legalized_updates.iter() { + let is_insert = matches!(update.kind(), CfgUpdateKind::Insert) || reverse_apply_updates; + this.succ + .entry(update.from.clone()) + .or_default() + .di_mut(is_insert) + .push(update.to.clone()); + this.pred + .entry(update.to.clone()) + .or_default() + .di_mut(is_insert) + .push(update.from.clone()); + } + this.updated_are_reverse_applied = reverse_apply_updates; + this + } +} + +impl GraphDiff for CfgDiff { + fn is_empty(&self) -> bool { + self.succ.is_empty() && self.pred.is_empty() && self.legalized_updates.is_empty() + } + + #[inline(always)] + fn legalized_updates(&self) -> &[CfgUpdate] { + &self.legalized_updates + } + + fn pop_update_for_incremental_updates(&mut self) -> CfgUpdate { + assert!(!self.legalized_updates.is_empty(), "no updates to apply"); + let update = self.legalized_updates.pop().unwrap(); + let is_insert = + matches!(update.kind(), CfgUpdateKind::Insert) || self.updated_are_reverse_applied; + let succ_di_list = &mut self.succ[update.from()]; + let is_empty = { + let succ_list = succ_di_list.di_mut(is_insert); + assert_eq!(succ_list.last(), Some(update.to())); + succ_list.pop(); + succ_list.is_empty() + }; + if is_empty && succ_di_list.di(!is_insert).is_empty() { + self.succ.remove(update.from()); + } + + let pred_di_list = &mut self.pred[update.to()]; + let pred_list = pred_di_list.di_mut(is_insert); + assert_eq!(pred_list.last(), Some(update.from())); + pred_list.pop(); + if pred_list.is_empty() && pred_di_list.di(!is_insert).is_empty() { + self.pred.remove(update.to()); + } + update + } + + fn get_children(&self, node: &BlockRef) -> SmallVec<[BlockRef; 8]> { + let mut r = crate::dominance::nca::get_children::(node); + if !INVERSE_EDGE { + r.reverse(); + } + + let children = if INVERSE_EDGE != INVERSE_GRAPH { + &self.pred + } else { + &self.succ + }; + let Some(found) = children.get(node) else { + return r; + }; + + // Remove children present in the CFG but not in the snapshot. + for child in found.di(false) { + r.retain(|c| c != child); + } + + // Add children present in the snapshot for not in the real CFG. + r.extend(found.di(true).iter().cloned()); + + r + } +} + +/// `legalize_updates` simplifies updates assuming a graph structure. +/// +/// This function serves double purpose: +/// +/// 1. It removes redundant updates, which makes it easier to reverse-apply them when traversing +/// CFG. +/// 2. It optimizes away updates that cancel each other out, as the end result is the same. 
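+///
+/// For example, a sketch from inside this module (the blocks `a` and `b` are hypothetical; the
+/// struct literals only work here because `CfgUpdate`'s fields are private to this file):
+///
+/// ```ignore
+/// let updates = [
+///     CfgUpdate { kind: CfgUpdateKind::Insert, from: a.clone(), to: b.clone() },
+///     CfgUpdate { kind: CfgUpdateKind::Delete, from: a.clone(), to: b.clone() },
+/// ];
+/// // The insert and delete of the same edge cancel out, so nothing is legalized.
+/// assert!(legalize_updates(updates.into_iter(), false, false).is_empty());
+/// ```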
+fn legalize_updates( + all_updates: I, + inverse_graph: bool, + reverse_result_order: bool, +) -> SmallVec<[CfgUpdate; 4]> +where + I: ExactSizeIterator, +{ + #[derive(Default, Copy, Clone)] + struct UpdateOp { + num_insertions: i32, + index: u32, + } + + // Count the total number of inserions of each edge. + // Each insertion adds 1 and deletion subtracts 1. The end number should be one of: + // + // * `-1` (deletion) + // * `0` (NOP), + // * `1` (insertion). + // + // Otherwise, the sequence of updates contains multiple updates of the same kind and we assert + // for that case. + let mut operations = + SmallMap::<(BlockRef, BlockRef), UpdateOp, 4>::with_capacity(all_updates.len()); + + for ( + i, + CfgUpdate { + kind, + mut from, + mut to, + }, + ) in all_updates.enumerate() + { + if inverse_graph { + // Reverse edge for post-dominators + core::mem::swap(&mut from, &mut to); + } + + operations + .entry((from, to)) + .or_insert_with(|| UpdateOp { + num_insertions: 0, + index: i as u32, + }) + .num_insertions += match kind { + CfgUpdateKind::Insert => 1, + CfgUpdateKind::Delete => -1, + }; + } + + let mut result = SmallVec::<[CfgUpdate; 4]>::with_capacity(operations.len()); + for ((from, to), update_op) in operations.iter() { + assert!(update_op.num_insertions.abs() <= 1, "unbalanced operations!"); + if update_op.num_insertions == 0 { + continue; + } + let kind = if update_op.num_insertions > 0 { + CfgUpdateKind::Insert + } else { + CfgUpdateKind::Delete + }; + result.push(CfgUpdate { + kind, + from: from.clone(), + to: to.clone(), + }); + } + + // Make the order consistent by not relying on pointer values within the set. Reuse the old + // operations map. + // + // In the future, we should sort by something else to minimize the amount of work needed to + // perform the series of updates. + result.sort_by(|a, b| { + let op_a = &operations[&(a.from.clone(), a.to.clone())]; + let op_b = &operations[&(b.from.clone(), b.to.clone())]; + if reverse_result_order { + op_a.index.cmp(&op_b.index) + } else { + op_a.index.cmp(&op_b.index).reverse() + } + }); + + result +} diff --git a/hir2/src/ir/cfg/visit.rs b/hir2/src/ir/cfg/visit.rs new file mode 100644 index 000000000..69c4ade14 --- /dev/null +++ b/hir2/src/ir/cfg/visit.rs @@ -0,0 +1,313 @@ +use core::ops::ControlFlow; + +use smallvec::SmallVec; + +use super::Graph; +use crate::adt::SmallSet; + +/// By implementing this trait, you can refine the traversal performed by [DfsVisitor], as well as +/// hook in custom behavior to be executed upon reaching a node in both pre-order and post-order +/// visits. +/// +/// There are two callbacks, both with default implementations that align with the default +/// semantics of a depth-first traversal. +/// +/// If you wish to prune the search, the best place to do so is [GraphVisitor::on_node_reached], +/// as it provides the opportunity to control whether or not the visitor will visit any of the +/// node's successors as well as emit the node during iteration. +#[allow(unused_variables)] +pub trait GraphVisitor { + type Node; + + /// Called when a node is first reached during a depth-first traversal, i.e. pre-order + /// + /// If this function returns `ControlFlow::Break`, none of `node`'s successors will be visited, + /// and `node` will not be emitted by the visitor. This can be used to prune the traversal, + /// e.g. confining a visit to a specific loop in a CFG. 
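+    ///
+    /// A sketch of a pruning visitor (the `skip` set here is hypothetical, not part of this API):
+    ///
+    /// ```ignore
+    /// struct PruningVisitor {
+    ///     skip: SmallSet<BlockRef, 4>,
+    /// }
+    ///
+    /// impl GraphVisitor for PruningVisitor {
+    ///     type Node = BlockRef;
+    ///
+    ///     fn on_node_reached(
+    ///         &mut self,
+    ///         _from: Option<&Self::Node>,
+    ///         node: &Self::Node,
+    ///     ) -> ControlFlow<()> {
+    ///         if self.skip.contains(node) {
+    ///             // Neither emit `node` nor descend into its successors
+    ///             ControlFlow::Break(())
+    ///         } else {
+    ///             ControlFlow::Continue(())
+    ///         }
+    ///     }
+    /// }
+    /// ```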
+ fn on_node_reached(&mut self, from: Option<&Self::Node>, node: &Self::Node) -> ControlFlow<()> { + ControlFlow::Continue(()) + } + + /// Called when all successors of a node have been visited by the depth-first traversal, i.e. + /// post-order. + fn on_block_visited(&mut self, node: &Self::Node) {} +} + +/// A useful no-op visitor for when you want the default behavior. +pub struct DefaultGraphVisitor(core::marker::PhantomData); +impl Default for DefaultGraphVisitor { + fn default() -> Self { + Self(core::marker::PhantomData) + } +} +impl GraphVisitor for DefaultGraphVisitor { + type Node = T; +} + +/// A basic iterator over a depth-first traversal of nodes in a graph, producing them in pre-order. +#[repr(transparent)] +pub struct PreOrderIter(LazyDfsVisitor::Node>>) +where + G: Graph; +impl PreOrderIter +where + G: Graph, + ::Node: Eq, +{ + /// Visit all nodes reachable from `root` in pre-order + pub fn new(root: ::Node) -> Self { + Self(LazyDfsVisitor::new(root, DefaultGraphVisitor::default())) + } + + /// Visit all nodes reachable from `root` in pre-order, treating the nodes in `visited` as + /// already visited, skipping them (and their successors) during the traversal. + pub fn new_with_visited( + root: ::Node, + visited: impl IntoIterator::Node>, + ) -> Self { + Self(LazyDfsVisitor::new_with_visited(root, DefaultGraphVisitor::default(), visited)) + } +} +impl core::iter::FusedIterator for PreOrderIter +where + G: Graph, + ::Node: Eq, +{ +} +impl Iterator for PreOrderIter +where + G: Graph, + ::Node: Eq, +{ + type Item = ::Node; + + #[inline] + fn next(&mut self) -> Option { + self.0.next::() + } +} + +/// A basic iterator over a depth-first traversal of nodes in a graph, producing them in post-order. +#[repr(transparent)] +pub struct PostOrderIter(LazyDfsVisitor::Node>>) +where + G: Graph; +impl PostOrderIter +where + G: Graph, + ::Node: Eq, +{ + /// Visit all nodes reachable from `root` in post-order + #[inline] + pub fn new(root: ::Node) -> Self { + Self(LazyDfsVisitor::new(root, DefaultGraphVisitor::default())) + } + + /// Visit all nodes reachable from `root` in post-order, treating the nodes in `visited` as + /// already visited, skipping them (and their successors) during the traversal. + pub fn new_with_visited( + root: ::Node, + visited: impl IntoIterator::Node>, + ) -> Self { + Self(LazyDfsVisitor::new_with_visited(root, DefaultGraphVisitor::default(), visited)) + } +} +impl core::iter::FusedIterator for PostOrderIter +where + G: Graph, + ::Node: Eq, +{ +} +impl Iterator for PostOrderIter +where + G: Graph, + ::Node: Eq, +{ + type Item = ::Node; + + #[inline] + fn next(&mut self) -> Option { + self.0.next::() + } +} + +/// This type is an iterator over a depth-first traversal of a graph, with customization hooks +/// provided via the [GraphVisitor] trait. +/// +/// The order in which nodes are produced by the iterator depends on how you invoke the `next` +/// method - it must be instantiated with a constant boolean that indicates whether or not the +/// iteration is to produce nodes in post-order. +/// +/// As a result, this type does not implement `Iterator` itself - it is meant to be consumed as +/// an internal detail of higher-level iterator types. 
Two such types are provided in this module +/// for common pre- and post-order iterations: +/// +/// * [PreOrderIter], for iterating in pre-order +/// * [PostOrderIter], for iterating in post-order +/// +pub struct LazyDfsVisitor { + /// The nodes we have already visited, or wish to consider visited + visited: SmallSet<::Node, 32>, + /// The stack of discovered nodes currently being visited + stack: SmallVec<[VisitNode<::Node>; 8]>, + /// A [GraphVisitor] implementation used to hook into the traversal machinery + visitor: V, +} + +/// Represents a node in the graph which has been reached during traversal, and is in the process of +/// being visited. +struct VisitNode { + /// The parent node in the graph from which this node was drived + parent: Option, + /// The node in the underlying graph being visited + node: T, + /// The successors of this node + successors: SmallVec<[T; 2]>, + /// Set to `true` once this node has been handled by [GraphVisitor::on_node_reached] + reached: bool, +} +impl VisitNode +where + T: Clone, +{ + #[inline] + pub fn node(&self) -> T { + self.node.clone() + } + + /// Returns true if no successors remain to be visited under this node + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.successors.is_empty() + } + + /// Get the next successor of this node, and advance the `next_child` index + /// + /// It is expected that the caller has already checked [is_empty]. Failing to do so + /// will cause this function to panic once all successors have been visited. + pub fn pop_successor(&mut self) -> T { + self.successors.pop().unwrap() + } +} + +impl LazyDfsVisitor +where + G: Graph, + ::Node: Eq, + V: GraphVisitor::Node>, +{ + /// Visit the graph rooted under `from`, using the provided visitor for customization hooks. + pub fn new(from: ::Node, visitor: V) -> Self { + Self::new_with_visited(from, visitor, None::<::Node>) + } + + /// Visit the graph rooted under `from`, using the provided visitor for customization hooks. + /// + /// The initial set of "visited" nodes is seeded with `visited`. Any node in this set (and their + /// children) will be skipped during iteration and by the traversal itself. If `from` is in this + /// set, then the resulting iterator will be empty (i.e. produce no nodes, and perform no + /// traversal). + pub fn new_with_visited( + from: ::Node, + visitor: V, + visited: impl IntoIterator::Node>, + ) -> Self { + use smallvec::smallvec; + + let visited = visited.into_iter().collect::>(); + if visited.contains(&from) { + // The root node itself is being ignored, return an empty iterator + return Self { + visited, + stack: smallvec![], + visitor, + }; + } + + let successors = SmallVec::from_iter(G::children(from.clone())); + Self { + visited, + stack: smallvec![VisitNode { + parent: None, + node: from, + successors, + reached: false, + }], + visitor, + } + } + + /// Step the visitor forward one step. + /// + /// The semantics of a step depend on the value of `POSTORDER`: + /// + /// * If `POSTORDER == true`, then we resume traversal of the graph until the next node that + /// has had all of its successors visited is on top of the visit stack. + /// * If `POSTORDER == false`, then we resume traversal of the graph until the next unvisited + /// node is reached for the first time. + /// + /// In both cases, the node we find by the search is what is returned. If no more nodes remain + /// to be visited, this returns `None`. 
+ /// + /// This function invokes the associated [GraphVisitor] callbacks during the traversal, at the + /// appropriate time. + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<::Node> { + loop { + let Some(node) = self.stack.last_mut() else { + break None; + }; + + if !node.reached { + node.reached = true; + let unvisited = self.visited.insert(node.node()); + if !unvisited { + let _ = unsafe { self.stack.pop().unwrap_unchecked() }; + continue; + } + + // Handle pre-order visit + let should_visit = + self.visitor.on_node_reached(node.parent.as_ref(), &node.node).is_continue(); + if !should_visit { + // It was indicated we shouldn't visit this node, so move to the next + let _ = unsafe { self.stack.pop().unwrap_unchecked() }; + continue; + } + + if POSTORDER { + // We need to visit this node's successors first + continue; + } else { + // We're going to visit this node's successors on the next call + break Some(node.node.clone()); + } + } + + // Otherwise, we're visiting a successor of this node. + // + // If this node has no successors, we're done visiting it + // If we've visited all successors of this node, we've got our next item + if node.is_empty() { + let node = unsafe { self.stack.pop().unwrap_unchecked() }; + self.visitor.on_block_visited(&node.node); + if POSTORDER { + break Some(node.node); + } else { + continue; + } + } + + // Otherwise, continue visiting successors + let parent = node.node(); + let successor = node.pop_successor(); + let successors = SmallVec::from_iter(G::children(successor.clone())); + self.stack.push(VisitNode { + parent: Some(parent), + node: successor, + successors, + reached: false, + }); + } + } +} diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs index 243fa8e85..289600d38 100644 --- a/hir2/src/ir/region.rs +++ b/hir2/src/ir/region.rs @@ -84,6 +84,55 @@ impl EntityWithParent for Region { } } +impl cfg::Graph for Region { + type ChildEdgeIter = block::BlockSuccessorEdgesIter; + type ChildIter = block::BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn is_empty(&self) -> bool { + self.body.is_empty() + } + + fn size(&self) -> usize { + self.body.len() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + block::BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + block::BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + edge.borrow().block.clone() + } + + fn entry_node(&self) -> Self::Node { + self.body.front().as_pointer().expect("empty region") + } +} + +impl<'a> cfg::InvertibleGraph for &'a Region { + type Inverse = cfg::Inverse<&'a Region>; + type InvertibleChildEdgeIter = block::BlockPredecessorEdgesIter; + type InvertibleChildIter = block::BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + block::BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + block::BlockPredecessorEdgesIter::new(parent) + } +} + /// Blocks impl Region { /// Returns true if this region is empty (has no blocks) diff --git a/hir2/src/ir/visit.rs b/hir2/src/ir/visit.rs index 64fe09412..9f36b1c7b 100644 --- a/hir2/src/ir/visit.rs +++ b/hir2/src/ir/visit.rs @@ -1,4 +1,3 @@ -mod blocks; mod searcher; mod visitor; mod walkable; @@ -6,7 +5,6 @@ mod walkable; pub use core::ops::ControlFlow; pub use self::{ - blocks::{BlockIter, 
PostOrderBlockIter}, searcher::Searcher, visitor::{OpVisitor, OperationVisitor, SymbolVisitor, Visitor}, walkable::{WalkOrder, WalkStage, Walkable}, diff --git a/hir2/src/ir/visit/blocks.rs b/hir2/src/ir/visit/blocks.rs deleted file mode 100644 index 54b7e69c2..000000000 --- a/hir2/src/ir/visit/blocks.rs +++ /dev/null @@ -1,102 +0,0 @@ -use alloc::collections::BTreeSet; - -use crate::BlockRef; - -#[allow(unused_variables)] -pub trait BlockVisitor { - /// Called when a block is first reached during a depth-first traversal, i.e. called in preorder - /// - /// If this function returns `false`, none of `block`'s children will be visited. This can be - /// used to prune the traversal, e.g. confining a visit to a specific loop in the CFG. - fn on_block_reached(&mut self, from: Option<&BlockRef>, block: &BlockRef) -> bool { - true - } - - /// Called when all children of a block have been visited by the depth-first traversal, i.e. - /// called in postorder. - fn on_block_visited(&mut self, block: &BlockRef) {} -} - -impl BlockVisitor for () {} - -#[repr(transparent)] -pub struct PostOrderBlockIter(BlockIter<()>); -impl PostOrderBlockIter { - #[inline] - pub fn new(root: BlockRef) -> Self { - Self(BlockIter::new(root, ())) - } -} -impl core::iter::FusedIterator for PostOrderBlockIter {} -impl Iterator for PostOrderBlockIter { - type Item = BlockRef; - - #[inline(always)] - fn next(&mut self) -> Option { - self.0.next() - } -} - -pub struct BlockIter { - visited: BTreeSet, - // First element is the basic block, second is the index of the next child to visit, third is the number of children - stack: Vec<(BlockRef, usize, usize)>, - visitor: V, -} - -impl BlockIter { - pub fn new(from: BlockRef, visitor: V) -> Self { - let mut this = Self { - visited: Default::default(), - stack: Default::default(), - visitor, - }; - this.insert_edge(None, from.clone()); - let num_successors = from.borrow().num_successors(); - this.stack.push((from, 0, num_successors)); - this.traverse_child(); - this - } - - /// Returns true if the target of the given edge should be visited. - /// - /// Called with `None` for `from` when adding the root node. - fn insert_edge(&mut self, from: Option, to: BlockRef) -> bool { - let should_visit = self.visitor.on_block_reached(from.as_ref(), &to); - let unvisited = self.visited.insert(to); - unvisited && should_visit - } - - fn traverse_child(&mut self) { - loop { - let Some((entry, index, max)) = self.stack.last_mut() else { - break; - }; - if index == max { - break; - } - let successor = entry.borrow().get_successor(*index); - *index += 1; - let entry = entry.clone(); - if self.insert_edge(Some(entry), successor.clone()) { - // If the block is not visited.. - let num_successors = successor.borrow().num_successors(); - self.stack.push((successor, 0, num_successors)); - } - } - } -} - -impl core::iter::FusedIterator for BlockIter {} -impl Iterator for BlockIter { - type Item = BlockRef; - - fn next(&mut self) -> Option { - let (next, ..) = self.stack.pop()?; - self.visitor.on_block_visited(&next); - if !self.stack.is_empty() { - self.traverse_child(); - } - Some(next) - } -} From 64a0be33a5f9e44d2191ed01fe122b5e37dd2e29 Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:51:36 -0400 Subject: [PATCH 30/31] feat: implement dominator and post-dominator analyses This introduces the incrementally updateable dominator and post-dominator analyses from LLVM, based on the Semi-NCA algorithm. 
This enables us to keep the (post-)dominator tree analysis up to date during rewriting of the IR, without needing to recompute it from scratch each time. This is an initial port of the original C++ code, modified to be more Rust-like, while still remaining quite close to the original. This could be easily cleaned up/rewritten in the future to be more idiomatic Rust, but for now it should suffice. --- hir2/src/ir.rs | 1 + hir2/src/ir/block.rs | 137 +++ hir2/src/ir/dominance.rs | 17 + hir2/src/ir/dominance/info.rs | 419 +++++++++ hir2/src/ir/dominance/nca.rs | 1531 +++++++++++++++++++++++++++++++ hir2/src/ir/dominance/traits.rs | 193 ++++ hir2/src/ir/dominance/tree.rs | 911 ++++++++++++++++++ 7 files changed, 3209 insertions(+) create mode 100644 hir2/src/ir/dominance.rs create mode 100644 hir2/src/ir/dominance/info.rs create mode 100644 hir2/src/ir/dominance/nca.rs create mode 100644 hir2/src/ir/dominance/traits.rs create mode 100644 hir2/src/ir/dominance/tree.rs diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index de3f8ceb5..9f0da7ebf 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -5,6 +5,7 @@ pub mod cfg; mod component; mod context; mod dialect; +pub mod dominance; mod entity; mod ident; mod immediates; diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs index 1a5d31e60..4b3622928 100644 --- a/hir2/src/ir/block.rs +++ b/hir2/src/ir/block.rs @@ -747,6 +747,143 @@ impl Block { } } +/// Ancestors +impl Block { + pub fn is_legal_to_hoist_into(&self) -> bool { + use crate::traits::{HasSideEffects, ReturnLike}; + + // No terminator means the block is under construction, and thus legal to hoist into + let Some(terminator) = self.terminator() else { + return true; + }; + + // If the block has no successors, it can never be legal to hoist into it, there is nothing + // to hoist! + if self.num_successors() == 0 { + return false; + } + + // Instructions should not be hoisted across effectful or return-like terminators. This is + // typically only exception handling intrinsics, which HIR doesn't really have, but which + // we may nevertheless want to represent in the future. + // + // NOTE: Most return-like terminators would have no successors, but in LLVM, for example, + // there are instructions like `catch_ret`, which semantically are return-like, but which + // have a successor block (the landing pad). + let terminator = terminator.borrow(); + terminator.implements::() || terminator.implements::() + } + + pub fn has_ssa_dominance(&self) -> bool { + self.parent_op() + .and_then(|op| { + op.borrow() + .as_trait::() + .map(|rki| rki.has_ssa_dominance()) + }) + .unwrap_or(true) + } + + /// Walk up the ancestor blocks of `block`, until `f` returns `true` for a block. + /// + /// NOTE: `block` is visited before any of its ancestors. + pub fn traverse_ancestors(block: BlockRef, mut f: F) -> Option + where + F: FnMut(BlockRef) -> bool, + { + let mut block = Some(block); + while let Some(current) = block.take() { + if f(current.clone()) { + return Some(current); + } + block = current.borrow().parent_block(); + } + + None + } + + /// Try to get a pair of blocks, starting with the given pair, which live in the same region, + /// by exploring the relationships of both blocks with respect to their regions. + /// + /// The returned block pair will either be the same input blocks, or some combination of those + /// blocks or their ancestors. 
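+    ///
+    /// A usage sketch (the block handles are hypothetical):
+    ///
+    /// ```ignore
+    /// if let Some((a, b)) = Block::get_blocks_in_same_region(&block_a, &block_b) {
+    ///     // Both returned blocks are attached to the same region, so they can be
+    ///     // compared using that region's dominator tree.
+    ///     debug_assert!(a.borrow().parent() == b.borrow().parent());
+    /// }
+    /// ```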
+ pub fn get_blocks_in_same_region(a: &BlockRef, b: &BlockRef) -> Option<(BlockRef, BlockRef)> { + // If both blocks do not live in the same region, we will have to check their parent + // operations. + let a_region = a.borrow().parent().unwrap(); + let b_region = b.borrow().parent().unwrap(); + if a_region == b_region { + return Some((a.clone(), b.clone())); + } + + // Iterate over all ancestors of `a`, counting the depth of `a`. + // + // If one of `a`'s ancestors are in the same region as `b`, then we stop early because we + // found our nearest common ancestor. + let mut a_depth = 0; + let result = Self::traverse_ancestors(a.clone(), |block| { + a_depth += 1; + block.borrow().parent().is_some_and(|r| r == b_region) + }); + if let Some(a) = result { + return Some((a, b.clone())); + } + + // Iterate over all ancestors of `b`, counting the depth of `b`. + // + // If one of `b`'s ancestors are in the same region as `a`, then we stop early because we + // found our nearest common ancestor. + let mut b_depth = 0; + let result = Self::traverse_ancestors(b.clone(), |block| { + b_depth += 1; + block.borrow().parent().is_some_and(|r| r == a_region) + }); + if let Some(b) = result { + return Some((a.clone(), b)); + } + + // Otherwise, we found two blocks that are siblings at some level. Walk the deepest one + // up until we reach the top or find a nearest common ancestor. + let mut a = Some(a.clone()); + let mut b = Some(b.clone()); + loop { + use core::cmp::Ordering; + + match a_depth.cmp(&b_depth) { + Ordering::Greater => { + a = a.and_then(|a| a.borrow().parent_block()); + a_depth -= 1; + } + Ordering::Less => { + b = b.and_then(|b| b.borrow().parent_block()); + b_depth -= 1; + } + Ordering::Equal => break, + } + } + + // If we found something with the same level, then we can march both up at the same time + // from here on out. + while let Some(next_a) = a.take() { + // If they are at the same level, and have the same parent region, then we succeeded. + let next_a_parent = next_a.borrow().parent(); + let b_parent = b.as_ref().and_then(|b| b.borrow().parent()); + if next_a_parent == b_parent { + return Some((next_a, b.unwrap())); + } + + a = next_a_parent + .and_then(|r| r.borrow().parent()) + .and_then(|op| op.borrow().parent()); + b = b_parent.and_then(|r| r.borrow().parent()).and_then(|op| op.borrow().parent()); + } + + // They don't share a nearest common ancestor, perhaps they are in different modules or + // something. 
+ None + } +} + /// Predecessors and Successors impl Block { /// Returns true if this block has predecessors diff --git a/hir2/src/ir/dominance.rs b/hir2/src/ir/dominance.rs new file mode 100644 index 000000000..e8186b315 --- /dev/null +++ b/hir2/src/ir/dominance.rs @@ -0,0 +1,17 @@ +mod info; +pub mod nca; +mod traits; +mod tree; + +pub use self::{ + info::{DominanceInfo, PostDominanceInfo, RegionDominanceInfo}, + traits::{Dominates, PostDominates}, + tree::{ + DomTreeError, DomTreeVerificationLevel, DominanceTree, PostDominanceTree, + PostOrderDomTreeIter, PreOrderDomTreeIter, + }, +}; +use self::{ + nca::{BatchUpdateInfo, SemiNCA}, + tree::{DomTreeBase, DomTreeNode, DomTreeRoots}, +}; diff --git a/hir2/src/ir/dominance/info.rs b/hir2/src/ir/dominance/info.rs new file mode 100644 index 000000000..f1a1bd185 --- /dev/null +++ b/hir2/src/ir/dominance/info.rs @@ -0,0 +1,419 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::cell::{LazyCell, Ref, RefCell}; + +use super::*; +use crate::{Block, BlockRef, OperationRef, RegionKindInterface, RegionRef}; + +/// [DominanceInfo] provides a high-level API for querying dominance information. +/// +/// Note that this type is aware of the different types of regions, and returns a region-kind +/// specific notion of dominance. See [RegionKindInterface] for details. +pub struct DominanceInfo { + info: DominanceInfoBase, +} + +impl DominanceInfo { + #[doc(hidden)] + #[inline(always)] + pub(crate) fn info(&self) -> &DominanceInfoBase { + &self.info + } + + /// Returns true if `a` dominates `b`. + /// + /// Note that if `a == b`, this returns true, if you want strict dominance, see + /// [Self::properly_dominates] instead. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [Dominates] trait for that information. + pub fn dominates(&self, a: &A, b: &B) -> bool + where + A: Dominates, + { + a.dominates(b, self) + } + + /// Returns true if `a` properly dominates `b`. + /// + /// This always returns false if `a == b`. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [Dominates] trait for that information. + pub fn properly_dominates(&self, a: &A, b: &B) -> bool + where + A: Dominates, + { + a.properly_dominates(b, self) + } + + /// An implementation of `properly_dominates` for operations, where we sometimes wish to treat + /// `a` as dominating `b`, if `b` is enclosed by a region of `a`. This behavior is controlled + /// by the `enclosing_op_ok` flag. + pub fn properly_dominates_with_options( + &self, + a: &OperationRef, + b: &OperationRef, + enclosing_op_ok: bool, + ) -> bool { + let a_block = a.borrow().parent().expect("`a` must be in a block"); + let mut b_block = b.borrow().parent().expect("`b` must be in a block"); + + // An instruction dominates itself, but does not properly dominate itself, unless this is + // a graph region. + if a == b { + return !a_block.borrow().has_ssa_dominance(); + } + + // If these ops are in different regions, then normalize one into the other. + let a_region = a_block.borrow().parent().unwrap(); + let mut b = b.clone(); + if a_region != b_block.borrow().parent().unwrap() { + // Walk up `b`'s region tree until we find an operation in `a`'s region that encloses + // it. If this fails, then we know there is no post-dominance relation. 
+ let Some(found) = a_region.borrow().find_ancestor_op(&b) else { + return false; + }; + b = found; + b_block = b.borrow().parent().expect("`b` must be in a block"); + assert!(b_block.borrow().parent().unwrap() == a_region); + + // If `a` encloses `b`, then we consider it to dominate. + if a == &b && enclosing_op_ok { + return true; + } + } + + // Ok, they are in the same region now. + if a_block == b_block { + // Dominance changes based on the region type. In a region with SSA dominance, uses + // insde the same block must follow defs. In other region kinds, uses and defs can + // come in any order inside a block. + return if a_block.borrow().has_ssa_dominance() { + // If the blocks are the same, then check if `b` is before `a` in the block. + a.borrow().is_before_in_block(&b) + } else { + true + }; + } + + // If the blocks are different, use the dominance tree to resolve the query + self.info + .dominance(&a_region) + .properly_dominates(Some(&a_block), Some(&b_block)) + } +} + +/// [PostDominanceInfo] provides a high-level API for querying post-dominance information. +/// +/// Note that this type is aware of the different types of regions, and returns a region-kind +/// specific notion of dominance. See [RegionKindInterface] for details. +pub struct PostDominanceInfo { + info: DominanceInfoBase, +} + +impl PostDominanceInfo { + #[doc(hidden)] + #[inline(always)] + pub(super) fn info(&self) -> &DominanceInfoBase { + &self.info + } + + /// Returns true if `a` post-dominates `b`. + /// + /// Note that if `a == b`, this returns true, if you want strict post-dominance, see + /// [Self::properly_post_dominates] instead. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [PostDominates] trait for that information. + pub fn post_dominates(&self, a: &A, b: &B) -> bool + where + A: PostDominates, + { + a.post_dominates(b, self) + } + + /// Returns true if `a` properly post-dominates `b`. + /// + /// This always returns false if `a == b`. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [PostDominates] trait for that information. + pub fn properly_post_dominates(&self, a: &A, b: &B) -> bool + where + A: PostDominates, + { + a.properly_post_dominates(b, self) + } +} + +/// This type carries the dominance information for a single region, lazily computed on demand. +pub struct RegionDominanceInfo { + /// The dominator tree for this region + domtree: LazyCell>>, RegionDomTreeCtor>, + /// A flag that indicates where blocks in this region have SSA dominance + has_ssa_dominance: bool, +} + +impl RegionDominanceInfo { + /// Construct a new [RegionDominanceInfo] for `region` + pub fn new(region: RegionRef) -> Self { + let r = region.borrow(); + let parent_op = r.parent().unwrap(); + // A region has SSA dominance if it tells us one way or the other, otherwise we must assume + // that it does. 
+ let has_ssa_dominance = parent_op + .borrow() + .as_trait::() + .map(|rki| rki.has_ssa_dominance()) + .unwrap_or(true); + + Self::create(region, has_ssa_dominance, r.has_one_block()) + } + + fn create(region: RegionRef, has_ssa_dominance: bool, has_one_block: bool) -> Self { + // We only create a dominator tree for multi-block regions + if has_one_block { + Self { + domtree: LazyCell::new(RegionDomTreeCtor(None)), + has_ssa_dominance, + } + } else { + Self { + domtree: LazyCell::new(RegionDomTreeCtor(Some(region))), + has_ssa_dominance, + } + } + } + + /// Get the dominance tree for this region. + /// + /// Returns `None` if the region was empty or had only a single block. + pub fn dominance(&self) -> Option>> { + self.domtree.clone() + } +} + +/// This type provides shared functionality to both [DominanceInfo] and [PostDominanceInfo]. +#[derive(Default)] +pub(crate) struct DominanceInfoBase { + /// A mapping of regions to their dominator tree and a flag that indicates whether or not they + /// have SSA dominance. + /// + /// This map does not contain dominator trees for empty or single block regions, however we + /// still compute whether or not they have SSA dominance regardless. + dominance_infos: RefCell>>, +} + +#[allow(unused)] +impl DominanceInfoBase { + /// Compute dominance information for all of the regions in `op`. + pub fn new(op: OperationRef) -> Self { + let op = op.borrow(); + let has_ssa_dominance = op + .as_trait::() + .map(|rki| rki.has_ssa_dominance()) + .unwrap_or(true); + let dominance_infos = BTreeMap::from_iter(op.regions().iter().map(|r| { + let region = r.as_region_ref(); + let info = RegionDominanceInfo::::create( + region.clone(), + has_ssa_dominance, + r.has_one_block(), + ); + (region, info) + })); + + Self { + dominance_infos: RefCell::new(dominance_infos), + } + } + + /// Invalidate all dominance info. + /// + /// This can be used by clients that make major changes to the CFG and don't have a good way to + /// update it. + pub fn invalidate(&mut self) { + self.dominance_infos.get_mut().clear(); + } + + /// Invalidate dominance info for the given region. + /// + /// This can be used by clients that make major changes to the CFG and don't have a good way to + /// update it. + pub fn invalidate_region(&mut self, region: RegionRef) { + self.dominance_infos.get_mut().remove(®ion); + } + + /// Finds the nearest common dominator block for the two given blocks `a` and `b`. + /// + /// If no common dominator can be found, this function will return `None`. + pub fn find_nearest_common_dominator_of( + &self, + a: Option<&BlockRef>, + b: Option<&BlockRef>, + ) -> Option { + // If either `a` or `b` are `None`, then conservatively return `None` + let a = a?; + let b = b?; + + // If they are the same block, then we are done. + if a == b { + return Some(a.clone()); + } + + // Try to find blocks that are in the same region. + let (a, b) = Block::get_blocks_in_same_region(a, b)?; + + // If the common ancestor in a common region is the same block, then return it. + if a == b { + return Some(a); + } + + // Otherwise, there must be multiple blocks in the region, check the dominance tree + self.dominance(&a.borrow().parent().unwrap()) + .find_nearest_common_dominator(&a, &b) + } + + /// Finds the nearest common dominator block for the given range of blocks. + /// + /// If no common dominator can be found, this function will return `None`. 
+ pub fn find_nearest_common_dominator_of_all<'a>( + &self, + mut blocks: impl ExactSizeIterator, + ) -> Option { + let mut dom = blocks.next().cloned(); + + for block in blocks { + dom = self.find_nearest_common_dominator_of(dom.as_ref(), Some(block)); + } + + dom + } + + /// Get the root dominance node of the given region. + /// + /// Panics if `region` is not a multi-block region. + pub fn root_node(&self, region: &RegionRef) -> Rc { + self.get_dominance_info(region) + .domtree + .as_deref() + .expect("`region` isn't multi-block") + .root_node() + .expect("expected region to have a root node") + } + + /// Return the dominance node for the region containing `block`. + /// + /// Panics if `block` is not a member of a multi-block region. + pub fn node(&self, block: &BlockRef) -> Option> { + self.get_dominance_info(&block.borrow().parent().expect("block isn't attached to region")) + .domtree + .as_deref() + .expect("`block` isn't in a multi-block region") + .get(Some(block)) + } + + /// Return true if the specified block is reachable from the entry block of its region. + pub fn is_reachable_from_entry(&self, block: &BlockRef) -> bool { + // If this is the first block in its region, then it is trivially reachable. + if block.borrow().is_entry_block() { + return true; + } + + let region = block.borrow().parent().expect("block isn't attached to region"); + self.dominance(®ion).is_reachable_from_entry(block) + } + + /// Return true if operations in the specified block are known to obey SSA dominance rules. + /// + /// Returns false if the block is a graph region or unknown. + pub fn block_has_ssa_dominance(&self, block: &BlockRef) -> bool { + let region = block.borrow().parent().expect("block isn't attached to region"); + self.get_dominance_info(®ion).has_ssa_dominance + } + + /// Return true if operations in the specified region are known to obey SSA dominance rules. + /// + /// Returns false if the region is a graph region or unknown. + pub fn region_has_ssa_dominance(&self, region: &RegionRef) -> bool { + self.get_dominance_info(region).has_ssa_dominance + } + + /// Returns the dominance tree for `region`. + /// + /// Panics if `region` is a single-block region. + pub fn dominance(&self, region: &RegionRef) -> Rc> { + self.get_dominance_info(region) + .dominance() + .expect("cannot get dominator tree for single block regions") + } + + /// Return the dominance information for `region`. + /// + /// NOTE: The dominance tree for single-block regions will be `None` + fn get_dominance_info(&self, region: &RegionRef) -> Ref<'_, RegionDominanceInfo> { + // Check to see if we already have this information. + self.dominance_infos + .borrow_mut() + .entry(region.clone()) + .or_insert_with(|| RegionDominanceInfo::new(region.clone())); + + Ref::map(self.dominance_infos.borrow(), |di| &di[region]) + } + + /// Return true if the specified block A properly dominates block B. + pub fn properly_dominates(&self, a: &BlockRef, b: &BlockRef) -> bool { + // A block dominates itself, but does not properly dominate itself. + if a == b { + return false; + } + + // If both blocks are not in the same region, `a` properly dominates `b` if `b` is defined + // in an operation region that (recursively) ends up being dominated by `a`. Walk up the + // ancestors of `b`. + let mut b = b.clone(); + let a_region = a.borrow().parent(); + if a_region != b.borrow().parent() { + // If we could not find a valid block `b` then it is not a dominator. 
+ let Some(found) = a_region.as_ref().and_then(|r| r.borrow().find_ancestor_block(&b)) + else { + return false; + }; + + b = found; + + // Check to see if the ancestor of `b` is the same block as `a`. `a` properly dominates + // `b` if it contains an op that contains the `b` block + if a == &b { + return true; + } + } + + // Otherwise, they are two different blocks in the same region, use dominance tree + self.dominance(&a_region.unwrap()).properly_dominates(Some(a), Some(&b)) + } +} + +/// A faux-constructor for [RegionDominanceInfo] for use with [LazyCell] without boxing. +struct RegionDomTreeCtor(Option); +impl FnOnce<()> for RegionDomTreeCtor { + type Output = Option>>; + + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + self.0.and_then(|region| DomTreeBase::new(region).ok().map(Rc::new)) + } +} +impl FnMut<()> for RegionDomTreeCtor { + extern "rust-call" fn call_mut(&mut self, _args: ()) -> Self::Output { + self.0 + .as_ref() + .and_then(|region| DomTreeBase::new(region.clone()).ok().map(Rc::new)) + } +} +impl Fn<()> for RegionDomTreeCtor { + extern "rust-call" fn call(&self, _args: ()) -> Self::Output { + self.0 + .as_ref() + .and_then(|region| DomTreeBase::new(region.clone()).ok().map(Rc::new)) + } +} diff --git a/hir2/src/ir/dominance/nca.rs b/hir2/src/ir/dominance/nca.rs new file mode 100644 index 000000000..025b634eb --- /dev/null +++ b/hir2/src/ir/dominance/nca.rs @@ -0,0 +1,1531 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::cell::{Cell, Ref, RefCell}; + +use smallvec::SmallVec; + +use super::{DomTreeBase, DomTreeNode, DomTreeRoots}; +use crate::{ + cfg::{self, Graph, GraphDiff, Inverse}, + formatter::{DisplayOptional, DisplayValues}, + BlockRef, EntityWithId, Region, +}; + +/// [SemiNCAInfo] provides functionality for constructing a dominator tree for a control-flow graph +/// based on the Semi-NCA algorithm described in the following dissertation: +/// +/// [1] Linear-Time Algorithms for Dominators and Related Problems +/// Loukas Georgiadis, Princeton University, November 2005, pp. 21-23: +/// ftp://ftp.cs.princeton.edu/reports/2005/737.pdf +/// +/// The Semi-NCA algorithm runs in O(n^2) worst-case time but usually slightly faster than Simple +/// Lengauer-Tarjan in practice. +/// +/// O(n^2) worst cases happen when the computation of nearest common ancestors requires O(n) average +/// time, which is very unlikely in real world. If this ever turns out to be an issue, consider +/// implementing a hybrid algorithm that uses SLT to perform full constructions and SemiNCA for +/// incremental updates. +/// +/// The file uses the Depth Based Search algorithm to perform incremental updates (insertion and +/// deletions). The implemented algorithm is based on this publication: +/// +/// [2] An Experimental Study of Dynamic Dominators +/// Loukas Georgiadis, et al., April 12 2016, pp. 5-7, 9-10: +/// https://arxiv.org/pdf/1604.02711.pdf +pub struct SemiNCA { + /// Number to node mapping is 1-based. + virtual_node: RefCell, + num_to_node: SmallVec<[BlockRef; 64]>, + /// Infos are mapped to nodes using block indices + node_infos: RefCell>, + batch_updates: Option>, +} + +/// Get the successors (or predecessors, if `INVERSED == true`) of `node`, incorporating insertions +/// and deletions from `bui` if available. +/// +/// The use of "children" here changes meaning depending on: +/// +/// * Whether or not the graph traversal is `INVERSED` +/// * Whether or not the graph is a post-dominator tree (i.e. 
`IS_POST_DOM`) +/// +/// If we're traversing a post-dominator tree, then the "children" of a node are actually +/// predecessors of the block in the CFG. However, if the traversal is _also_ `INVERSED`, then the +/// children actually are successors of the block in the CFG. +/// +/// For a forward-dominance tree, "children" do correspond to successors in the CFG, but again, if +/// the traversal is `INVERSED`, then the children are actually predecessors. +/// +/// This function (and others in this module) are written in such a way that we can abstract over +/// whether the underlying dominator tree is a forward- or post-dominance tree, as much of the +/// implementation is identical. +pub fn get_children_with_batch_updates( + node: &BlockRef, + bui: Option<&BatchUpdateInfo>, +) -> SmallVec<[BlockRef; 8]> { + use crate::cfg::GraphDiff; + + if let Some(bui) = bui { + bui.pre_cfg_view.get_children::(node) + } else { + get_children::(node) + } +} + +/// Get the successors (or predecessors, if `INVERSED == true`) of `node`. +pub fn get_children(node: &BlockRef) -> SmallVec<[BlockRef; 8]> { + if INVERSED { + Inverse::::children(node.clone()).collect() + } else { + let mut r = BlockRef::children(node.clone()).collect::>(); + r.reverse(); + r + } +} + +#[derive(Default)] +pub struct NodeInfo { + num: Cell, + parent: Cell, + semi: Cell, + label: Cell, + idom: Cell>, + reverse_children: SmallVec<[u32; 4]>, +} +impl NodeInfo { + pub fn idom(&self) -> Option { + unsafe { &*self.idom.as_ptr() }.clone() + } + + #[inline] + pub fn num(&self) -> u32 { + self.num.get() + } + + #[inline] + pub fn parent(&self) -> u32 { + self.parent.get() + } + + #[inline] + pub fn semi(&self) -> u32 { + self.semi.get() + } + + #[inline] + pub fn label(&self) -> u32 { + self.label.get() + } + + #[inline] + pub fn reverse_children(&self) -> &[u32] { + &self.reverse_children + } +} + +/// [BatchUpdateInfo] represents a batch of insertion/deletion operations that have been applied to +/// the CFG. This information is used to incrementally update the dominance tree as changes are +/// made to the CFG. +#[derive(Default, Clone)] +pub struct BatchUpdateInfo { + pub pre_cfg_view: cfg::CfgDiff, + pub post_cfg_view: cfg::CfgDiff, + pub num_legalized: usize, + // Remembers if the whole tree was recomputed at some point during the current batch update + pub is_recalculated: bool, +} + +impl BatchUpdateInfo { + pub fn new( + pre_cfg_view: cfg::CfgDiff, + post_cfg_view: Option>, + ) -> Self { + let num_legalized = pre_cfg_view.num_legalized_updates(); + Self { + pre_cfg_view, + post_cfg_view: post_cfg_view.unwrap_or_default(), + num_legalized, + is_recalculated: false, + } + } +} + +impl SemiNCA { + /// Obtain a fresh [SemiNCA] instance, using the provided set of [BatchUpdateInfo]. + pub fn new(batch_updates: Option>) -> Self { + Self { + virtual_node: Default::default(), + num_to_node: Default::default(), + node_infos: Default::default(), + batch_updates, + } + } + + /// Reset the [SemiNCA] state so it can be used to compute a dominator tree from scratch. + pub fn clear(&mut self) { + // Don't reset the pointer to BatchUpdateInfo here -- if there's an update in progress, + // we need this information to continue it. 
+ self.virtual_node = Default::default(); + self.num_to_node.clear(); + self.node_infos.get_mut().clear(); + } + + /// Look up information about a block in the Semi-NCA state + pub fn node_info(&self, block: Option<&BlockRef>) -> Ref<'_, NodeInfo> { + match block { + None => self.virtual_node.borrow(), + Some(block) => { + let id = block.borrow().id(); + let index = id.as_u32() as usize; + + if index >= self.node_infos.borrow().len() { + self.node_infos.borrow_mut().resize_with(index + 1, NodeInfo::default); + } + + Ref::map(self.node_infos.borrow(), |ni| unsafe { ni.get_unchecked(index) }) + } + } + } + + /// Get a mutable reference to the stored informaton for `block` + pub fn node_info_mut(&mut self, block: Option<&BlockRef>) -> &mut NodeInfo { + match block { + None => self.virtual_node.get_mut(), + Some(block) => { + let id = block.borrow().id(); + let index = id.as_u32() as usize; + + let node_infos = self.node_infos.get_mut(); + if index >= node_infos.len() { + node_infos.resize_with(index + 1, NodeInfo::default); + } + + unsafe { node_infos.get_unchecked_mut(index) } + } + } + } + + /// Look up the immediate dominator for `block`, if it has one. + /// + /// A value of `None` for `block` is meaningless, as virtual nodes only are present in post- + /// dominance graphs, and always post-dominate all other nodes in the graph. However, it is + /// convenient to have many of the APIs in this module take a `Option` for uniformity. + pub fn idom(&self, block: Option<&BlockRef>) -> Option { + self.node_info(block).idom() + } + + /// Get or compute the dominance tree node information for `block`, in `tree`, using the current + /// Semi-NCA state. + pub fn node_for_block( + &self, + block: Option<&BlockRef>, + tree: &mut DomTreeBase, + ) -> Option> { + let node = tree.get(block); + if node.is_some() { + return node; + } + + // Haven't calculated this node yet? Get or calculate the node for the immediate dominator + let idom = self.idom(block); + let idom_node = match idom { + None => Some(tree.get(None).expect("expected idom or virtual node")), + Some(idom_block) => self.node_for_block(Some(&idom_block), tree), + }; + + // Add a new tree node for this node, and link it as a child of idom_node + Some(tree.create_node(block.cloned(), idom_node)) + } + + /// Custom DFS implementation which can skip nodes based on a provided predicate. + /// + /// It also collects reverse children so that we don't have to spend time getting predecessors + /// in SemiNCA. + /// + /// If `IsReverse` is set to true, the DFS walk will be performed backwards relative to IS_POST_DOM + /// -- using reverse edges for dominators and forward edges for post-dominators. + /// + /// If `succ_order` is specified then that is the order in which the DFS traverses the children, + /// otherwise the order is implied by the results of `get_children`. + pub fn run_dfs( + &mut self, + v: Option<&BlockRef>, + mut last_num: u32, + mut condition: C, + attach_to_num: u32, + succ_order: Option<&BTreeMap>, + ) -> u32 + where + C: FnMut(Option<&BlockRef>, Option<&BlockRef>) -> bool, + { + let v = v.expect("expected valid root node for search"); + + let mut worklist = + SmallVec::<[(BlockRef, u32); 64]>::from_iter([(v.clone(), attach_to_num)]); + + self.node_info_mut(Some(v)).parent.set(attach_to_num); + + while let Some((block, parent_num)) = worklist.pop() { + let block_info = self.node_info_mut(Some(&block)); + block_info.reverse_children.push(parent_num); + + // Visited nodes always have positive DFS numbers. 
+ if block_info.num.get() != 0 { + continue; + } + + block_info.parent.set(parent_num); + last_num += 1; + block_info.num.set(last_num); + block_info.semi.set(last_num); + block_info.label.set(last_num); + self.num_to_node.push(block.clone()); + + let mut successors = if const { REVERSE != IS_POST_DOM } { + get_children_with_batch_updates::( + &block, + self.batch_updates.as_ref(), + ) + } else { + get_children_with_batch_updates::( + &block, + self.batch_updates.as_ref(), + ) + }; + if let Some(succ_order) = succ_order { + if successors.len() > 1 { + successors.sort_by(|a, b| succ_order[a].cmp(&succ_order[b])); + } + } + + for succ in successors.into_iter().filter(|succ| condition(Some(&block), Some(succ))) { + worklist.push((succ, last_num)); + } + } + + last_num + } + + // V is a predecessor of W. eval() returns V if V < W, otherwise the minimum + // of sdom(U), where U > W and there is a virtual forest path from U to V. The + // virtual forest consists of linked edges of processed vertices. + // + // We can follow Parent pointers (virtual forest edges) to determine the + // ancestor U with minimum sdom(U). But it is slow and thus we employ the path + // compression technique to speed up to O(m*log(n)). Theoretically the virtual + // forest can be organized as balanced trees to achieve almost linear + // O(m*alpha(m,n)) running time. But it requires two auxiliary arrays (Size + // and Child) and is unlikely to be faster than the simple implementation. + // + // For each vertex V, its Label points to the vertex with the minimal sdom(U) + // (Semi) in its path from V (included) to NodeToInfo[V].Parent (excluded). + fn eval<'a, 'b: 'a>( + v: u32, + last_linked: u32, + eval_stack: &mut SmallVec<[&'a NodeInfo; 32]>, + num_to_info: &'b [Ref<'b, NodeInfo>], + ) -> u32 { + let mut v_info = &*num_to_info[v as usize]; + if v_info.parent.get() < last_linked { + return v_info.label.get(); + } + + // Store ancestors except the last (root of a virtual tree) into a stack. + eval_stack.clear(); + loop { + let parent = &num_to_info[v_info.parent.get() as usize]; + eval_stack.push(v_info); + v_info = parent; + if v_info.parent.get() < last_linked { + break; + } + } + + // Path compression. Point each vertex's `parent` to the root and update its `label` if any + // of its ancestors `label` has a smaller `semi` + let mut p_info = v_info; + let mut p_label_info = &*num_to_info[p_info.label.get() as usize]; + while let Some(info) = eval_stack.pop() { + v_info = info; + v_info.parent.set(p_info.parent.get()); + let v_label_info = &*num_to_info[v_info.label.get() as usize]; + if p_label_info.semi.get() < v_label_info.semi.get() { + v_info.label.set(p_info.label.get()); + } else { + p_label_info = v_label_info; + } + p_info = v_info; + } + + v_info.label.get() + } + + /// This function requires DFS to be run before calling it. + pub fn run(&mut self) { + let next_num = self.num_to_node.len(); + let mut num_to_info = SmallVec::<[Ref<'_, NodeInfo>; 8]>::default(); + num_to_info.reserve(next_num); + + // Initialize idoms to spanning tree parents + for i in 0..next_num { + let v = &self.num_to_node[i]; + let v_info = self.node_info(Some(v)); + v_info.idom.set(Some(self.num_to_node[v_info.parent.get() as usize].clone())); + num_to_info.push(v_info); + } + + // Step 1: Calculate the semi-dominators of all vertices + let mut eval_stack = SmallVec::<[&NodeInfo; 32]>::default(); + for i in (2..(next_num - 1)).rev() { + let w_info = &num_to_info[i]; + + // Initialize the semi-dominator to point to the parent node. 
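+    //
+    // As a rough illustration (the shape of the chain here is hypothetical, not taken from any
+    // real CFG): suppose vertex V has spanning-tree parent P, and P's own parent was numbered
+    // below `last_linked`. The first `eval(V, ..)` pushes V onto the stack, stops at P, and on
+    // the way back re-points Parent(V) at P's parent, lowering Label(V) to Label(P) if the
+    // latter has a smaller Semi. A later `eval(V, ..)` therefore takes the fast path at the top
+    // of the function and returns Label(V) without re-walking the chain.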
+ w_info.semi.set(w_info.parent.get()); + for n in w_info.reverse_children.iter().copied() { + let semi_u = num_to_info + [Self::eval(n, i as u32 + 1, &mut eval_stack, &num_to_info) as usize] + .semi + .get(); + if semi_u < w_info.semi.get() { + w_info.semi.set(semi_u); + } + } + } + + // Step 2: Explicitly define the immediate dominator of each vertex. + // + // IDom[i] = NCA(SDom[i], SpanningTreeParent(i)) + // + // Note that the parents were stored in IDoms and later got invalidated during path + // compression in `eval` + for i in 2..next_num { + let w_info = &num_to_info[i]; + assert_ne!(w_info.semi.get(), 0); + let s_dom_num = num_to_info[w_info.semi.get() as usize].num.get(); + let mut w_idom_candidate = w_info.idom(); + loop { + let w_idom_candidate_info = self.node_info(w_idom_candidate.as_ref()); + if w_idom_candidate_info.num.get() <= s_dom_num { + break; + } + w_idom_candidate = w_idom_candidate_info.idom(); + } + + w_info.idom.set(w_idom_candidate); + } + } + + /// [PostDominatorTree] always has a virtual root that represents a virtual CFG node that serves + /// as a single exit from the region. + /// + /// All the other exits (CFG nodes with terminators and nodes in infinite loops) are logically + /// connected to this virtual CFG exit node. + /// + /// This function maps a null CFG node to the virtual root tree node. + fn add_virtual_root(&mut self) { + if const { IS_POST_DOM } { + assert!(self.num_to_node.is_empty(), "SemiNCAInfo must be freshly constructed"); + + let info = self.virtual_node.get_mut(); + info.num.set(1); + info.semi.set(1); + info.label.set(1); + + // num_to_node[1] = None + } + } + + /// For postdominators, nodes with no forward successors are trivial roots that + /// are always selected as tree roots. Roots with forward successors correspond + /// to CFG nodes within infinite loops. + fn has_forward_successors( + n: Option<&BlockRef>, + bui: Option<&BatchUpdateInfo>, + ) -> bool { + let n = n.expect("`n` must be a valid node"); + !get_children_with_batch_updates::(n, bui).is_empty() + } + + fn entry_node(tree: &DomTreeBase) -> BlockRef { + tree.parent() + .borrow() + .entry_block_ref() + .expect("expected region to have an entry block") + } + + pub fn find_roots( + tree: &DomTreeBase, + bui: Option<&BatchUpdateInfo>, + ) -> DomTreeRoots { + let mut roots = DomTreeRoots::default(); + + // For dominators, region entry CFG node is always a tree root node. + if !IS_POST_DOM { + roots.push(Some(Self::entry_node(tree))); + return roots; + } + + let mut snca = Self::new(bui.cloned()); + + // PostDominatorTree always has a virtual root. + snca.add_virtual_root(); + let mut num = 1u32; + + log::trace!("looking for trivial roots"); + + // Step 1: Find all the trivial roots that are going to definitely remain tree roots + let mut total = 0; + // It may happen that there are some new nodes in the CFG that are result of the ongoing + // batch update, but we cannot really pretend that they don't exist -- we won't see any + // outgoing or incoming edges to them, so it's fine to discover them here, as they would end + // up appearing in the CFG at some point anyway. + let region = tree.parent().borrow(); + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + total += 1; + // If it has no successors, it is definitely a root + if !Self::has_forward_successors(Some(&n), bui) { + roots.push(Some(n.clone())); + // Run DFS not to walk this part of CFG later. 
+ num = snca.run_dfs::(Some(&n), num, always_descend, 1, None); + log::trace!("found a new trivial root: {}", n.borrow().id()); + match snca.num_to_node.get(num as usize) { + None => log::trace!("last visited node: None"), + Some(last_visited) => { + log::trace!("last visited node: {}", last_visited.borrow().id()) + } + } + } + } + + log::trace!("looking for non-trivial roots"); + + // Step 2: Find all non-trivial root candidates. + // + // Those are CFG nodes that are reverse-unreachable were not visited by previous DFS walks + // (i.e. CFG nodes in infinite loops). + // + // Accounting for the virtual exit, see if we had any reverse-unreachable nodes. + let has_non_trivial_roots = total + 1 != num; + if has_non_trivial_roots { + // `succ_order` is the order of blocks in the region. It is needed to make the + // calculation of the `furthest_away` node and the whole PostDominanceTree immune to + // swapping successors (e.g. canonicalizing branch predicates). `succ_order` is + // initialized lazily only for successors of reverse unreachable nodes. + #[derive(Default)] + struct LazySuccOrder { + succ_order: BTreeMap, + initialized: bool, + } + impl LazySuccOrder { + pub fn get_or_init<'a, 'b: 'a, const IS_POST_DOM: bool>( + &'b mut self, + region: &Region, + bui: Option<&'a BatchUpdateInfo>, + snca: &SemiNCA, + ) -> &'a BTreeMap { + if !self.initialized { + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + let n_num = snca.node_info(Some(&n)).num.get(); + if n_num == 0 { + for succ in + get_children_with_batch_updates::(&n, bui) + { + self.succ_order.insert(succ, 0); + } + } + } + + // Add mapping for all entries of succ_order + let mut node_num = 0; + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + node_num += 1; + if let Some(order) = self.succ_order.get_mut(&n) { + assert_eq!(*order, 0); + *order = node_num; + } + } + self.initialized = true; + } + + &self.succ_order + } + } + + let mut succ_order = LazySuccOrder::default(); + + // Make another DFS pass over all other nodes to find the reverse-unreachable blocks, + // and find the furthest paths we'll be able to make. + // + // Note that this looks N^2, but it's really 2N worst case, if every node is unreachable. + // This is because we are still going to only visit each unreachable node once, we may + // just visit it in two directions, depending on how lucky we get. + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + + if snca.node_info(Some(&n)).num.get() == 0 { + log::trace!("visiting node {n}"); + + // Find the furthest away we can get by following successors, then + // follow them in reverse. This gives us some reasonable answer about + // the post-dom tree inside any infinite loop. In particular, it + // guarantees we get to the farthest away point along *some* + // path. This also matches the GCC's behavior. + // If we really wanted a totally complete picture of dominance inside + // this infinite loop, we could do it with SCC-like algorithms to find + // the lowest and highest points in the infinite loop. In theory, it + // would be nice to give the canonical backedge for the loop, but it's + // expensive and does not always lead to a minimal set of roots. 
+ log::trace!("running forward DFS.."); + + let succ_order = succ_order.get_or_init(®ion, bui, &snca); + let new_num = snca.run_dfs::( + Some(&n), + num, + always_descend, + num, + Some(succ_order), + ); + let furthest_away = snca.num_to_node[new_num as usize].clone(); + log::trace!( + "found a new furthest away node (non-trivial root): {furthest_away}", + ); + roots.push(Some(furthest_away.clone())); + log::trace!("previous `num`: {num}, new `num` {new_num}"); + log::trace!("removing DFS info.."); + for i in ((num + 1)..=new_num).rev() { + let n = snca.num_to_node[i as usize].clone(); + log::trace!("removing DFS info for {n}"); + *snca.node_info_mut(Some(&n)) = Default::default(); + snca.num_to_node.pop(); + } + let prev_num = num; + log::trace!("running reverse depth-first search"); + num = snca.run_dfs::( + Some(&furthest_away), + num, + always_descend, + 1, + None, + ); + for i in (prev_num + 1)..num { + match snca.num_to_node.get(i as usize) { + None => log::trace!("found virtual node"), + Some(n) => log::trace!("found node {n}"), + } + } + } + } + } + + log::trace!("total: {total}, num: {num}"); + log::trace!("discovered cfg nodes:"); + for i in 0..num { + log::trace!(" {i}: {}", &snca.num_to_node[i as usize]); + } + + assert_eq!(total + 1, num, "everything should have been visited"); + + // Step 3: If we found some non-trivial roots, make them non-redundant. + if has_non_trivial_roots { + Self::remove_redundant_roots(snca.batch_updates.as_ref(), &mut roots); + } + + log::trace!( + "found roots: {}", + DisplayValues::new(roots.iter().map(|v| DisplayOptional(v.as_ref()))) + ); + + roots + } + + // This function only makes sense for postdominators. + // + // We define roots to be some set of CFG nodes where (reverse) DFS walks have to start in order + // to visit all the CFG nodes (including the reverse-unreachable ones). + // + // When the search for non-trivial roots is done it may happen that some of the non-trivial + // roots are reverse-reachable from other non-trivial roots, which makes them redundant. This + // function removes them from the set of input roots. + fn remove_redundant_roots( + bui: Option<&BatchUpdateInfo>, + roots: &mut SmallVec<[Option; 4]>, + ) { + assert!(IS_POST_DOM, "this function is for post-dominators only"); + + log::trace!("removing redundant roots.."); + + let mut snca = Self::new(bui.cloned()); + + let mut root_index = 0; + 'roots: while root_index < roots.len() { + let root = &roots[root_index]; + + // Trivial roots are never redundant + if !Self::has_forward_successors(root.as_ref(), bui) { + continue; + } + + log::trace!("checking if {} remains a root", DisplayOptional(root.as_ref())); + snca.clear(); + + // Do a forward walk looking for the other roots. + let num = snca.run_dfs::(root.as_ref(), 0, always_descend, 0, None); + // Skip the start node and begin from the second one (note that DFS uses 1-based indexing) + for x in 2..(num as usize) { + let n = &snca.num_to_node[x]; + + // If we found another root in a (forward) DFS walk, remove the current root from + // the set of roots, as it is reverse-reachable from the other one. 
+ if roots.iter().any(|r| r.as_ref().is_some_and(|root| root == n)) { + log::trace!("forward DFS walk found another root {n}"); + log::trace!("removing root {}", DisplayOptional(root.as_ref())); + roots.swap_remove(root_index); + + // Root at the back takes the current root's place, so revisit the same index on + // the next iteration + continue 'roots; + } + } + + root_index += 1; + } + } + + pub fn do_full_dfs_walk(&mut self, tree: &DomTreeBase, condition: C) + where + for<'a> C: Copy + Fn(Option<&BlockRef>, Option<&BlockRef>) -> bool + 'a, + { + if const { !IS_POST_DOM } { + assert_eq!(tree.num_roots(), 1, "dominators should have a single root"); + self.run_dfs::(tree.roots()[0].as_ref(), 0, condition, 0, None); + return; + } + + self.add_virtual_root(); + let mut num = 1; + for root in tree.roots() { + num = self.run_dfs::(root.as_ref(), num, condition, 1, None); + } + } + + pub fn attach_new_subtree( + &mut self, + tree: &mut DomTreeBase, + attach_to: Rc, + ) { + // Attach the first unreachable block to `attach_to` + self.node_info(Some(&self.num_to_node[0])).idom.set(attach_to.block().cloned()); + // Loop over all of the discovered blocks in the function... + for w in self.num_to_node.iter() { + if tree.get(Some(w)).is_some() { + // Already computed the node before + continue; + } + + let idom = self.idom(Some(w)); + + // Get or compute the node for the immediate dominator + let idom_node = self.node_for_block(idom.as_ref(), tree); + + // Add a new tree node for this basic block, and link it as a child of idom_node + tree.create_node(Some(w.clone()), idom_node); + } + } + + pub fn reattach_existing_subtree( + &mut self, + tree: &mut DomTreeBase, + attach_to: Rc, + ) { + self.node_info(Some(&self.num_to_node[0])).idom.set(attach_to.block().cloned()); + for n in self.num_to_node.iter() { + let node = tree.get(Some(n)).unwrap(); + let idom = tree.get(self.node_info(Some(n)).idom().as_ref()); + node.set_idom(idom); + } + } + + // Checks if a node has proper support, as defined on the page 3 and later + // explained on the page 7 of [2]. 
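+    //
+    // For example (a hypothetical CFG, purely for illustration, described in forward-dominance
+    // terms; the edge directions are mirrored for post-dominators): if block B has predecessors
+    // A and C and the edge A -> B is being deleted, then B still has "proper support" as long as
+    // the nearest common dominator of B and one of its remaining tree-reachable predecessors is
+    // some block other than B itself. In that case `delete_edge` treats B as still reachable and
+    // only rebuilds the affected subtree, rather than tearing B's subtree out as unreachable.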
+ pub fn has_proper_support( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + node: &DomTreeNode, + ) -> bool { + log::trace!("is reachable from idom {node}"); + + let Some(block) = node.block() else { + return false; + }; + + let preds = if IS_POST_DOM { + get_children_with_batch_updates::(block, bui) + } else { + get_children_with_batch_updates::(block, bui) + }; + + for pred in preds { + log::trace!("pred {pred}"); + if tree.get(Some(&pred)).is_none() { + continue; + } + + let support = tree.find_nearest_common_dominator(block, &pred); + log::trace!("support {}", DisplayOptional(support.as_ref())); + if support.as_ref() != Some(block) { + log::trace!( + "{node} is reachable from support {}", + DisplayOptional(support.as_ref()) + ); + return true; + } + } + + false + } +} + +#[derive(Eq, PartialEq)] +struct InsertionInfoItem { + node: Rc, +} +impl From> for InsertionInfoItem { + fn from(node: Rc) -> Self { + Self { node } + } +} +impl PartialOrd for InsertionInfoItem { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for InsertionInfoItem { + #[inline] + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.node.level().cmp(&other.node.level()) + } +} + +#[derive(Default)] +struct InsertionInfo { + bucket: crate::adt::SmallPriorityQueue, + visited: crate::adt::SmallSet, 8>, + affected: SmallVec<[Rc; 8]>, +} + +/// Insertion and Deletion +impl SemiNCA { + pub fn insert_edge( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Option, + to: Option, + ) { + assert!( + from.as_ref().is_some() || IS_POST_DOM, + "'from' has to be a valid cfg node or a virtual root" + ); + let to = to.expect("expected a valid `to` node"); + + log::trace!("inserting edge {from:?} -> {to}"); + + let from_node = tree.get(from.as_ref()); + let from_node = if let Some(from_node) = from_node { + from_node + } else { + // Ignore edges from unreachable nodes for (forward) dominators. + if !IS_POST_DOM { + return; + } + + // The unreachable node becomes a new root -- a tree node for it. + let virtual_root = tree.get(None); + let from_node = tree.create_node(from.clone(), virtual_root); + tree.roots_mut().push(from); + from_node + }; + + tree.mark_invalid(); + + let to_node = tree.get(Some(&to)); + match to_node { + None => Self::insert_unreachable(tree, bui, from_node, to), + Some(to_node) => Self::insert_reachable(tree, bui, from_node, to_node), + } + } + + fn insert_unreachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Rc, + to: BlockRef, + ) { + log::trace!("inserting {from} -> {to} (unreachable)"); + + // Collect discovered edges to already reachable nodes + // Discover and connect nodes that became reachable with the insertion. 
+        let mut discovered_edges_to_reachable = SmallVec::default();
+        Self::compute_unreachable_dominators(
+            tree,
+            bui,
+            &to,
+            from.clone(),
+            &mut discovered_edges_to_reachable,
+        );
+
+        log::trace!("inserted {from} -> {to} (prev unreachable)");
+
+        // Use the discovered edges and insert discovered connecting (incoming) edges
+        for (from_block_ref, to_node) in discovered_edges_to_reachable {
+            log::trace!("inserting discovered connecting edge {from_block_ref:?} -> {to_node}",);
+            let from_node = tree.get(from_block_ref.as_ref()).unwrap();
+            Self::insert_reachable(tree, bui, from_node, to_node);
+        }
+    }
+
+    fn insert_reachable(
+        tree: &mut DomTreeBase,
+        bui: Option<&BatchUpdateInfo>,
+        from: Rc,
+        to: Rc,
+    ) {
+        log::trace!("reachable {from} -> {to}");
+
+        if const { IS_POST_DOM } {
+            let rebuilt = SemiNCA::::update_roots_before_insertion(
+                unsafe {
+                    core::mem::transmute::<&mut DomTreeBase, &mut DomTreeBase>(
+                        tree,
+                    )
+                },
+                bui.map(|bui| unsafe {
+                    core::mem::transmute::<&BatchUpdateInfo, &BatchUpdateInfo>(
+                        bui,
+                    )
+                }),
+                to.clone(),
+            );
+            if rebuilt {
+                return;
+            }
+        }
+
+        // find_nearest_common_dominator expects both pointers to be valid. When `from` is a virtual
+        // root, then its CFG block pointer is `None`, so we have to "compute" the NCD manually
+        let ncd_block = if from.block().is_some() && to.block().is_some() {
+            tree.find_nearest_common_dominator(from.block().unwrap(), to.block().unwrap())
+        } else {
+            None
+        };
+        assert!(ncd_block.is_some() || tree.is_post_dominator());
+        let ncd = tree.get(ncd_block.as_ref()).unwrap();
+
+        log::trace!("nearest common dominator == {ncd}");
+
+        // Based on Lemma 2.5 from [2], after insertion of (from, to), `v` is affected iff
+        // depth(ncd) + 1 < depth(v) && a path `P` from `to` to `v` exists where every `w` on `P`
+        // s.t. depth(v) <= depth(w)
+        //
+        // This reduces to a widest path problem (maximizing the depth of the minimum vertex in
+        // the path) which can be solved by a modified version of Dijkstra with a bucket queue
+        // (named depth-based search in [2]).
+        //
+        // `to` is in the path, so depth(ncd) + 1 < depth(v) <= depth(to). Nothing affected if
+        // this does not hold.
+        let ncd_level = ncd.level();
+        if ncd_level + 1 >= to.level() {
+            return;
+        }
+
+        let mut insertion_info = InsertionInfo::default();
+        let mut unaffected_on_current_level = SmallVec::<[Rc; 8]>::default();
+        insertion_info.bucket.push(to.clone().into());
+        insertion_info.visited.insert(to);
+
+        while let Some(InsertionInfoItem { mut node }) = insertion_info.bucket.pop() {
+            insertion_info.affected.push(node.clone());
+
+            let current_level = node.level();
+            log::trace!("mark {node} as affected, current level: {current_level}");
+
+            assert!(node.block().is_some() && insertion_info.visited.contains(&node));
+
+            loop {
+                // Unlike regular Dijkstra, we have an inner loop to expand more
+                // vertices. The first iteration is for the (affected) vertex popped
+                // from II.Bucket and the rest are for vertices in
+                // UnaffectedOnCurrentLevel, which may eventually expand to affected
+                // vertices.
+                //
+                // Invariant: there is an optimal path from `To` to TN with the minimum
+                // depth being CurrentLevel.
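+                //
+                // As a concrete illustration with made-up levels: if level(ncd) == 2 and
+                // level(to) == 5, only tree nodes deeper than level 3 can be affected. In the
+                // walk below, successors at level <= 3 are skipped outright, successors at or
+                // shallower than the current node's level are pushed into the bucket as affected,
+                // and strictly deeper successors are parked in `unaffected_on_current_level` in
+                // case they transitively lead to further affected nodes.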
+ for succ in get_children_with_batch_updates::( + node.block().unwrap(), + bui, + ) { + let succ_node = tree + .get(Some(&succ)) + .expect("unreachable successor found during reachable insertion"); + let succ_level = succ_node.level(); + log::trace!("successor {succ_node}, level = {succ_level}"); + + // There is an optimal path from `To` to Succ with the minimum depth + // being min(CurrentLevel, SuccLevel). + // + // If depth(NCD)+1 < depth(Succ) is not satisfied, Succ is unaffected + // and no affected vertex may be reached by a path passing through it. + // Stop here. Also, Succ may be visited by other predecessors but the + // first visit has the optimal path. Stop if Succ has been visited. + if succ_level <= ncd_level + 1 + || !insertion_info.visited.insert(succ_node.clone()) + { + continue; + } + + if succ_level > current_level { + // succ is unaffected, but it may (transitively) expand to affected vertices. + // Store it in unaffected_on_current_level + log::trace!("marking visiting not affected {succ}"); + unaffected_on_current_level.push(succ_node.clone()); + } else { + // The condition is satisfied (Succ is affected). Add Succ to the + // bucket queue. + log::trace!("add {succ} to a bucket"); + insertion_info.bucket.push(succ_node.clone().into()); + } + } + + if unaffected_on_current_level.is_empty() { + break; + } + + if let Some(n) = unaffected_on_current_level.pop() { + node = n; + } else { + break; + } + log::trace!("next: {node}"); + } + } + + // Finish by updating immediate dominators and levels. + Self::update_insertion(tree, bui, ncd, &insertion_info); + } + + pub fn delete_edge( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Option<&BlockRef>, + to: Option<&BlockRef>, + ) { + let from = from.expect("cannot disconnect virtual node"); + let to = to.expect("cannot disconnect virtual node"); + + log::trace!("deleting edge {from} -> {to}"); + + // Deletion in an unreachable subtree -- nothing to do. + let Some(from_node) = tree.get(Some(from)) else { + return; + }; + + let Some(to_node) = tree.get(Some(to)) else { + log::trace!("to {to} already unreachable -- there is no edge to delete",); + return; + }; + + let ncd_block = tree.find_nearest_common_dominator(from, to); + let ncd = tree.get(ncd_block.as_ref()); + + // If to dominates from -- nothing to do. + if Some(&to_node) != ncd.as_ref() { + tree.mark_invalid(); + + let to_idom = to_node.idom(); + log::trace!( + "ncd {}, to_idom {}", + DisplayOptional(ncd.as_ref()), + DisplayOptional(to_idom.as_ref()) + ); + + // To remains reachable after deletion (based on caption under figure 4, from [2]) + if (Some(&from_node) != to_idom.as_ref()) + || Self::has_proper_support(tree, bui, &to_node) + { + Self::delete_reachable(tree, bui, from_node, to_node) + } else { + Self::delete_unreachable(tree, bui, to_node) + } + + if const { IS_POST_DOM } { + SemiNCA::::update_roots_after_update( + unsafe { + core::mem::transmute::<&mut DomTreeBase, &mut DomTreeBase>( + tree, + ) + }, + bui.map(|bui| unsafe { + core::mem::transmute::<&BatchUpdateInfo, &BatchUpdateInfo>( + bui, + ) + }), + ); + } + } + } + + /// Handles deletions that leave destination nodes reachable. 
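+    ///
+    /// Only the subtree hanging below the nearest common dominator of `from` and `to` is
+    /// recomputed here (per lemma 2.6 from [2]); for example, deleting a back edge deep inside a
+    /// nested loop should leave every dominator-tree node outside that loop's subtree untouched.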
+ fn delete_reachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Rc, + to: Rc, + ) { + log::trace!("deleting reachable {from} -> {to} - rebuilding subtree.."); + + // Find the top of the subtree that needs to be rebuilt (based on the lemma 2.6 from [2]) + let to_idom = + tree.find_nearest_common_dominator(from.block().unwrap(), to.block().unwrap()); + assert!(to_idom.is_some() || tree.is_post_dominator()); + let to_idom_node = tree.get(to_idom.as_ref()).unwrap(); + let prev_idom_subtree = to_idom_node.idom(); + // Top of the subtree to rebuild is the root node. Rebuild the tree from scratch. + let Some(prev_idom_subtree) = prev_idom_subtree else { + log::trace!("the entire tree needs to be rebuilt"); + Self::compute_from_scratch(tree, bui.cloned()); + return; + }; + + // Only visit nodes in the subtree starting at `to` + let level = to_idom_node.level(); + let descend_below = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + tree.get(to).unwrap().level() > level + }; + + log::trace!("top of subtree {to_idom_node}"); + + let mut snca = Self::new(bui.cloned()); + snca.run_dfs::(to_idom.as_ref(), 0, descend_below, 0, None); + log::trace!("running Semi-NCA"); + snca.run(); + snca.reattach_existing_subtree(tree, prev_idom_subtree); + } + + /// Handle deletions that make destination node unreachable. + /// + /// (Based on the lemma 2.7 from the [2].) + fn delete_unreachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + to: Rc, + ) { + log::trace!("deleting unreachable subtree {to}"); + assert!(to.block().is_some()); + + if IS_POST_DOM { + // Deletion makes a region reverse-unreachable and creates a new root. + // + // Simulate that by inserting an edge from the virtual root to `to` and adding it as a new + // root. + log::trace!("deletion made a region reverse-unreachable, adding new root {to}"); + tree.roots_mut().push(to.block().cloned()); + Self::insert_reachable(tree, bui, tree.get(None).unwrap(), to); + return; + } + + let mut affected_queue = SmallVec::<[Option; 16]>::default(); + let level = to.level(); + + // Traverse destination node's descendants with greater level in the tree + // and collect visited nodes. + let descend_and_collect = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + let node = tree.get(to).unwrap(); + if node.level() > level { + return true; + } + let to = to.cloned(); + if !affected_queue.contains(&to) { + affected_queue.push(to) + } + false + }; + + let mut snca = Self::new(bui.cloned()); + let last_dfs_num = snca.run_dfs::(to.block(), 0, descend_and_collect, 0, None); + + let mut min_node = to.clone(); + // Identify the top of the subtree to rebuild by finding the NCD of all the affected nodes. + for n in affected_queue { + let node = tree.get(n.as_ref()).unwrap(); + let ncd_block = + tree.find_nearest_common_dominator(node.block().unwrap(), to.block().unwrap()); + assert!(ncd_block.is_some() || tree.is_post_dominator()); + let ncd = tree.get(ncd_block.as_ref()).unwrap(); + log::trace!( + "processing affected node {node} with: nearest common dominator = {ncd}, min node \ + = {min_node}" + ); + if ncd != node && ncd.level() < min_node.level() { + min_node = ncd; + } + } + + // Root reached, rebuild the whole tree from scratch. + if min_node.idom().is_none() { + log::trace!("the entire tree needs to be rebuilt"); + Self::compute_from_scratch(tree, bui.cloned()); + return; + } + + // Erase the unreachable subtree in reverse preorder to process all children before deleting + // their parent. 
+ for i in (1..=(last_dfs_num as usize)).rev() { + let n = &snca.num_to_node[i]; + log::trace!("erasing node {n}"); + tree.erase_node(n); + } + + // The affected subtree start at the `to` node -- there's no extra work to do. + if min_node == to { + return; + } + + log::trace!("delete_unreachable: running dfs with min_node = {min_node}"); + let min_level = min_node.level(); + let prev_idom = min_node.idom().unwrap(); + snca.clear(); + + // Identify nodes that remain in the affected subtree. + let descend_below = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + let to_node = tree.get(to); + to_node.is_some_and(|to_node| to_node.level() > min_level) + }; + snca.run_dfs::(min_node.block(), 0, descend_below, 0, None); + + log::trace!("previous idom(min_node) = {prev_idom}"); + log::trace!("running Semi-NCA"); + + // Rebuild the remaining part of affected subtree. + snca.run(); + snca.reattach_existing_subtree(tree, prev_idom); + } + + pub fn apply_updates( + tree: &mut DomTreeBase, + mut pre_view_cfg: cfg::CfgDiff, + post_view_cfg: cfg::CfgDiff, + ) { + // Note: the `post_view_cfg` is only used when computing from scratch. It's data should + // already included in the `pre_view_cfg` for incremental updates. + let num_updates = pre_view_cfg.num_legalized_updates(); + match num_updates { + 0 => (), + 1 => { + // Take the fast path for a single update and avoid running the batch update machinery. + let update = pre_view_cfg.pop_update_for_incremental_updates(); + let bui = if post_view_cfg.is_empty() { + None + } else { + Some(BatchUpdateInfo::new(post_view_cfg.clone(), Some(post_view_cfg))) + }; + match update.kind() { + cfg::CfgUpdateKind::Insert => { + Self::insert_edge( + tree, + bui.as_ref(), + Some(update.from().clone()), + Some(update.to().clone()), + ); + } + cfg::CfgUpdateKind::Delete => { + Self::delete_edge( + tree, + bui.as_ref(), + Some(update.from()), + Some(update.to()), + ); + } + } + } + _ => { + let mut bui = BatchUpdateInfo::new(pre_view_cfg, Some(post_view_cfg)); + // Recalculate the DominatorTree when the number of updates exceeds a threshold, + // which usually makes direct updating slower than recalculation. We select this + // threshold proportional to the size of the DominatorTree. The constant is selected + // by choosing the one with an acceptable performance on some real-world inputs. + + // Make unittests of the incremental algorithm work + // TODO(pauls): review this + if tree.len() <= 100 { + if bui.num_legalized > tree.len() { + Self::compute_from_scratch(tree, Some(bui.clone())); + } + } else if bui.num_legalized > tree.len() / 40 { + Self::compute_from_scratch(tree, Some(bui.clone())); + } + + // If the DominatorTree was recalculated at some point, stop the batch updates. Full + // recalculations ignore batch updates and look at the actual CFG. + for _ in 0..bui.num_legalized { + if bui.is_recalculated { + break; + } + + Self::apply_next_update(tree, &mut bui); + } + } + } + } + + fn apply_next_update( + tree: &mut DomTreeBase, + bui: &mut BatchUpdateInfo, + ) { + // Popping the next update, will move the `pre_view_cfg` to the next snapshot. 
+ let current_update = bui.pre_cfg_view.pop_update_for_incremental_updates(); + log::trace!("applying update: {current_update:?}"); + + match current_update.kind() { + cfg::CfgUpdateKind::Insert => { + Self::insert_edge( + tree, + Some(bui), + Some(current_update.from().clone()), + Some(current_update.to().clone()), + ); + } + cfg::CfgUpdateKind::Delete => { + Self::delete_edge( + tree, + Some(bui), + Some(current_update.from()), + Some(current_update.to()), + ); + } + } + } + + pub fn compute(tree: &mut DomTreeBase) { + Self::compute_from_scratch(tree, None); + } + + pub fn compute_from_scratch( + tree: &mut DomTreeBase, + mut bui: Option>, + ) { + use crate::cfg::GraphDiff; + + tree.reset(); + + // If the update is using the actual CFG, `bui` is `None`. If it's using a view, `bui` is + // `Some` and the `pre_cfg_view` is used. When calculating from scratch, make the + // `pre_cfg_view` equal to the `post_cfg_view`, so `post` is used. + let post_view_bui = bui.clone().and_then(|mut bui| { + if !bui.post_cfg_view.is_empty() { + bui.pre_cfg_view = bui.post_cfg_view.clone(); + Some(bui) + } else { + None + } + }); + + // This is rebuilding the whole tree, not incrementally, but `post_view_bui` is used in case + // the caller needs a dominator tree update with a cfg view + let mut snca = Self::new(post_view_bui); + + // Step 0: Number blocks in depth-first order, and initialize variables used in later stages + // of the algorithm. + let roots = Self::find_roots(tree, bui.as_ref()); + *tree.roots_mut() = roots; + snca.do_full_dfs_walk(tree, always_descend); + + snca.run(); + if let Some(bui) = bui.as_mut() { + bui.is_recalculated = true; + log::trace!("dominator tree recalculated, skipping future batch updates"); + } + + if tree.roots().is_empty() { + return; + } + + // Add a node for the root. If the tree is a post-dominator tree, it will be the virtual + // exit (denoted by a block ref of `None`), which post-dominates all real exits (including + // multiple exit blocks, infinite loops). 
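+        //
+        // As a hypothetical example of what that means in practice: a region whose CFG has two
+        // distinct returning blocks plus an infinite loop would typically yield three roots from
+        // `find_roots`, and in the post-dominator case all of them end up hanging off this single
+        // virtual exit node.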
+ let root = if IS_POST_DOM { + None + } else { + tree.roots()[0].clone() + }; + + let new_root = tree.create_node(root, None); + tree.set_root(new_root); + let root_node = tree.root_node().expect("expected root node"); + snca.attach_new_subtree(tree, root_node); + } + + fn update_insertion( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + ncd: Rc, + insertion_info: &InsertionInfo, + ) { + log::trace!("updating nearest common dominator = {ncd}"); + + for to_node in insertion_info.affected.iter() { + log::trace!("idom({to_node}) = {ncd}"); + to_node.set_idom(Some(ncd.clone())); + } + + if IS_POST_DOM { + SemiNCA::::update_roots_after_update( + unsafe { + core::mem::transmute::<&mut DomTreeBase, &mut DomTreeBase>( + tree, + ) + }, + bui.map(|bui| unsafe { + core::mem::transmute::<&BatchUpdateInfo, &BatchUpdateInfo>( + bui, + ) + }), + ); + } + } + + /// Connects nodes that become reachable with an insertion + fn compute_unreachable_dominators( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + root: &BlockRef, + incoming: Rc, + discovered_connecting_edges: &mut SmallVec<[(Option, Rc); 8]>, + ) { + assert!(tree.get(Some(root)).is_none(), "root must not be reachable"); + + // Visit only previously unreachable nodes + let unreachable_descender = |from: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + let to_node = tree.get(to); + match to_node { + None => true, + Some(to_node) => { + discovered_connecting_edges.push((from.cloned(), to_node)); + false + } + } + }; + + let mut snca = Self::new(bui.cloned()); + snca.run_dfs::(Some(root), 0, unreachable_descender, 0, None); + snca.run(); + snca.attach_new_subtree(tree, incoming); + + log::trace!("after adding unreachable nodes"); + } +} + +/// Verification +impl SemiNCA { + pub fn verify_roots(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_reachability(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_levels(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_dfs_numbers(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_parent_property(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_sibling_property(&self, _tree: &DomTreeBase) -> bool { + true + } +} + +impl SemiNCA { + /// Determines if some existing root becomes reverse-reachable after the insertion. + /// + /// Rebuilds the whole tree if that situation happens. + fn update_roots_before_insertion( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + to: Rc, + ) -> bool { + // Destination node is not attached to the virtual root, so it cannot be a root + if !tree.is_virtual_root(&to.idom().unwrap()) { + return false; + } + + if !tree.roots().contains(&to.block().cloned()) { + // To is not a root, nothing to update + return false; + } + + log::trace!("after the insertion, {to} is no longer a root - rebuilding the tree.."); + + Self::compute_from_scratch(tree, bui.cloned()); + true + } + + /// Updates the set of roots after insertion or deletion. + /// + /// This ensures that roots are the same when after a series of updates and when the tree would + /// be built from scratch. + fn update_roots_after_update( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + ) { + // The tree has only trivial roots -- nothing to update. 
+ if !tree.roots().iter().any(|n| Self::has_forward_successors(n.as_ref(), bui)) { + return; + } + + // Recalculate the set of roots + let roots = Self::find_roots(tree, bui); + if !is_permutation(tree.roots(), &roots) { + // The roots chosen in the CFG have changed. This is because the incremental algorithm + // does not really know or use the set of roots and can make a different (implicit) + // decision about which node within an infinite loop becomes a root. + log::trace!( + "roots are different in updated trees - the entire tree needs to be rebuilt" + ); + // It may be possible to update the tree without recalculating it, but we do not know + // yet how to do it, and it happens rarely in practice. + Self::compute_from_scratch(tree, bui.cloned()); + } + } +} + +fn is_permutation(a: &[Option], b: &[Option]) -> bool { + if a.len() != b.len() { + return false; + } + let set = crate::adt::SmallSet::<_, 4>::from_iter(a.iter().cloned()); + for n in b { + if !set.contains(n) { + return false; + } + } + true +} + +#[doc(hidden)] +#[inline(always)] +const fn always_descend(_: Option<&BlockRef>, _: Option<&BlockRef>) -> bool { + true +} diff --git a/hir2/src/ir/dominance/traits.rs b/hir2/src/ir/dominance/traits.rs new file mode 100644 index 000000000..4d6f83e34 --- /dev/null +++ b/hir2/src/ir/dominance/traits.rs @@ -0,0 +1,193 @@ +use super::{DominanceInfo, PostDominanceInfo}; +use crate::{Block, Operation, Value}; + +/// This trait is implemented on a type which has a dominance relationship with `Rhs`. +pub trait Dominates { + /// Returns true if `self` dominates `other`. + /// + /// In cases where `Rhs = Self`, implementations should return true when `self == other`. + /// + /// For a stricter form of dominance, use [Dominates::properly_dominates]. + fn dominates(&self, other: &Rhs, dom_info: &DominanceInfo) -> bool; + /// Returns true if `self` properly dominates `other`. + /// + /// In cases where `Rhs = Self`, implementations should return false when `self == other`. + fn properly_dominates(&self, other: &Rhs, dom_info: &DominanceInfo) -> bool; +} + +/// This trait is implemented on a type which has a post-dominance relationship with `Rhs`. +pub trait PostDominates { + /// Returns true if `self` post-dominates `other`. + /// + /// In cases where `Rhs = Self`, implementations should return true when `self == other`. + /// + /// For a stricter form of dominance, use [PostDominates::properly_dominates]. + fn post_dominates(&self, other: &Rhs, dom_info: &PostDominanceInfo) -> bool; + /// Returns true if `self` properly post-dominates `other`. + /// + /// In cases where `Rhs = Self`, implementations should return false when `self == other`. + fn properly_post_dominates(&self, other: &Rhs, dom_info: &PostDominanceInfo) -> bool; +} + +/// The dominance relationship between two blocks. +impl Dominates for Block { + /// Returns true if `a == b` or `a` properly dominates `b`. + fn dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_dominates(other, dom_info) + } + + /// Returns true if `a != b` and: + /// + /// * `a` is an ancestor of `b` + /// * The region containing `a` also contains `b` or some ancestor of `b`, and `a` dominates + /// that block in that kind of region. + /// * In SSA regions, `a` properly dominates `b` if all control flow paths from the entry + /// block to `b`, flow through `a`. + /// * In graph regions, all blocks dominate all other blocks. 
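+    ///
+    /// A small sketch of the distinction (the `bb1`, `bb2`, and `dom_info` bindings here are
+    /// hypothetical and shown for illustration only):
+    ///
+    /// ```ignore
+    /// // Every path from the region entry to `bb2` passes through `bb1`:
+    /// assert!(bb1.properly_dominates(&bb2, &dom_info));
+    /// // A block dominates itself, but never *properly* dominates itself:
+    /// assert!(bb1.dominates(&bb1, &dom_info));
+    /// assert!(!bb1.properly_dominates(&bb1, &dom_info));
+    /// ```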
+ fn properly_dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + dom_info.info().properly_dominates(&self.as_block_ref(), &other.as_block_ref()) + } +} + +/// The post-dominance relationship between two blocks. +impl PostDominates for Block { + fn post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_post_dominates(other, dom_info) + } + + /// Returns true if `a != b` and: + /// + /// * `a` is an ancestor of `b` + /// * The region containing `a` also contains `b` or some ancestor of `b`, and `a` dominates + /// that block in that kind of region. + /// * In SSA regions, `a` properly post-dominates `b` if all control flow paths from `b` to + /// an exit node, flow through `a`. + /// * In graph regions, all blocks post-dominate all other blocks. + fn properly_post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + dom_info.info().properly_dominates(&self.as_block_ref(), &other.as_block_ref()) + } +} + +/// The dominance relationship for operations +impl Dominates for Operation { + fn dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_dominates(other, dom_info) + } + + /// Returns true if `a != b`, and: + /// + /// * `a` and `b` are in the same block, and `a` properly dominates `b` within the block, or + /// * the block that contains `a` properly dominates the block that contains `b`. + /// * `b` is enclosed in a region of `a` + /// + /// In any SSA region, `a` dominates `b` in the same block if `a` precedes `b`. In a graph + /// region all operations in a block dominate all other operations in the same block. + fn properly_dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + let a = self.as_operation_ref(); + let b = other.as_operation_ref(); + dom_info.properly_dominates_with_options(&a, &b, /*enclosing_op_ok= */ true) + } +} + +/// The post-dominance relationship for operations +impl PostDominates for Operation { + fn post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_post_dominates(other, dom_info) + } + + /// Returns true if `a != b`, and: + /// + /// * `a` and `b` are in the same block, and `a` properly post-dominates `b` within the block + /// * the block that contains `a` properly post-dominates the block that contains `b`. + /// * `b` is enclosed in a region of `a` + /// + /// In any SSA region, `a` post-dominates `b` in the same block if `b` precedes `a`. In a graph + /// region all operations in a block post-dominate all other operations in the same block. + fn properly_post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + let a_block = self.parent().expect("`self` must be in a block"); + let mut b_block = other.parent().expect("`other` must be in a block"); + + // An instruction post dominates, but does not properly post-dominate itself unless this is + // a graph region. + if core::ptr::addr_eq(self, other) { + return !a_block.borrow().has_ssa_dominance(); + } + + // If these ops are in different regions, then normalize one into the other. + let a_region = a_block.borrow().parent(); + let b_region = b_block.borrow().parent(); + let a = self.as_operation_ref(); + let mut b = other.as_operation_ref(); + if a_region != b_region { + // Walk up `b`'s region tree until we find an operation in `a`'s region that encloses + // it. If this fails, then we know there is no post-dominance relation. 
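+            //
+            // For example (hypothetical nesting): if `other` lives inside the body region of a
+            // structured control flow op that is itself an operation in `self`'s region, the
+            // lookup below resolves `other` to that enclosing op, and the post-dominance question
+            // is then answered against that op (or its block) instead.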
+ let Some(found) = a_region.as_ref().and_then(|r| r.borrow().find_ancestor_op(&b)) + else { + return false; + }; + b = found; + b_block = b.borrow().parent().unwrap(); + assert!(b_block.borrow().parent() == a_region); + + // If `a` encloses `b`, then we consider it to post-dominate. + if a == b { + return true; + } + } + + // Ok, they are in the same region. If they are in the same block, check if `b` is before + // `a` in the block. + if a_block == b_block { + // Dominance changes based on the region type + return if a_block.borrow().has_ssa_dominance() { + // If the blocks are the same, then check if `b` is before `a` in the block. + b.borrow().is_before_in_block(&a) + } else { + true + }; + } + + // If the blocks are different, check if `a`'s block post-dominates `b`'s + dom_info + .info() + .dominance(&a_region.unwrap()) + .properly_dominates(Some(&a_block), Some(&b_block)) + } +} + +/// The dominance relationship between a value and an operation, e.g. between a definition of a +/// value and a user of that same value. +impl Dominates for dyn Value { + /// Return true if the definition of `self` dominates a use by operation `other`. + fn dominates(&self, other: &Operation, dom_info: &DominanceInfo) -> bool { + self.get_defining_op().is_some_and(|op| op == other.as_operation_ref()) + || self.properly_dominates(other, dom_info) + } + + /// Returns true if the definition of `self` properly dominates `other`. + /// + /// This requires the value to either be a block argument, where the block containing `other` + /// is dominated by the block defining `self`, OR that the value is an operation result, and + /// the defining op of `self` properly dominates `other`. + /// + /// If the defining op of `self` encloses `b` in one of its regions, `a` does not dominate `b`. + fn properly_dominates(&self, other: &Operation, dom_info: &DominanceInfo) -> bool { + // Block arguments properly dominate all operations in their own block, so we use a + // dominates check here, not a properly_dominates check. + if let Some(block_arg) = self.downcast_ref::() { + return block_arg + .owner() + .borrow() + .dominates(&other.parent().unwrap().borrow(), dom_info); + } + + // `a` properly dominates `b` if the operation defining `a` properly dominates `b`, but `a` + // does not itself enclose `b` in one of its regions. + let defining_op = self.get_defining_op().unwrap(); + dom_info.properly_dominates_with_options( + &defining_op, + &other.as_operation_ref(), + /*enclosing_op_ok= */ false, + ) + } +} diff --git a/hir2/src/ir/dominance/tree.rs b/hir2/src/ir/dominance/tree.rs new file mode 100644 index 000000000..decab0cd4 --- /dev/null +++ b/hir2/src/ir/dominance/tree.rs @@ -0,0 +1,911 @@ +use alloc::rc::Rc; +use core::{ + cell::{Cell, RefCell}, + fmt, + num::NonZeroU32, +}; + +use smallvec::{smallvec, SmallVec}; + +use super::{BatchUpdateInfo, SemiNCA}; +use crate::{ + cfg::{self, Graph, Inverse, InvertibleGraph}, + formatter::DisplayOptional, + BlockRef, EntityWithId, RegionRef, +}; + +#[derive(Debug, thiserror::Error)] +pub enum DomTreeError { + /// Tried to compute a dominator tree for an empty region + #[error("unable to create dominance tree for empty region")] + EmptyRegion, +} + +/// The level of verification to use with [DominatorTreeBase::verify] +pub enum DomTreeVerificationLevel { + /// Checks basic tree structure and compares with a freshly constructed tree + /// + /// O(n^2) time worst case, but is faster in practice. 
+ Fast, + /// Checks if the tree is correct, but compares it to a freshly constructed tree instead of + /// checking the sibling property. + /// + /// O(n^2) time. + Basic, + /// Verifies if the tree is correct by making sure all the properties, including the parent + /// and sibling property, hold. + /// + /// O(n^3) time. + Full, +} + +/// A forward dominance tree +pub type DominanceTree = DomTreeBase; + +/// A post (backward) dominance tree +pub type PostDominanceTree = DomTreeBase; + +pub type DomTreeRoots = SmallVec<[Option; 4]>; + +/// A dominator tree implementation that abstracts over the type of dominance it represents. +pub struct DomTreeBase { + /// The roots from which dominance is traced. + /// + /// For forward dominance trees, there is always a single root. For post-dominance trees, there + /// may be multiple, one for each exit from the region. + roots: DomTreeRoots, + /// The nodes represented in this dominance tree + nodes: SmallVec<[(Option, Rc); 64]>, + /// The root dominance tree node. + root: Option>, + /// The parent region for which this dominance tree was computed + parent: RegionRef, + /// Whether this dominance tree is valid (true), or outdated (false) + valid: Cell, + /// A counter for expensive queries that may cause us to perform some extra work in order to + /// speed up those queries after a certain point. + slow_queries: Cell, +} + +/// A node in a [DomTreeBase]. +pub struct DomTreeNode { + /// The block represented by this node + block: Option, + /// The immediate dominator of this node, if applicable + idom: Cell>>, + /// The children of this node in the tree + children: RefCell; 4]>>, + /// The depth of this node in the tree + level: Cell, + /// The DFS visitation order (forward) + num_in: Cell>, + /// The DFS visitation order (backward) + num_out: Cell>, +} + +impl fmt::Display for DomTreeNode { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", DisplayOptional(self.block.as_ref().map(|b| b.borrow().id()).as_ref())) + } +} + +impl fmt::Debug for DomTreeNode { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use crate::EntityWithId; + + f.debug_struct("DomTreeNode") + .field_with("block", |f| match self.block.as_ref() { + None => f.write_str("None"), + Some(block_ref) => write!(f, "{}", block_ref.borrow().id()), + }) + .field("idom", unsafe { &*self.idom.as_ptr() }) + .field_with("children", |f| { + f.debug_list().entries(self.children.borrow().iter()).finish() + }) + .field("level", &self.level.get()) + .field("num_in", &self.num_in.get()) + .field("num_out", &self.num_out.get()) + .finish() + } +} + +/// An iterator over nodes in a dominance tree produced by a depth-first, pre-order traversal +pub type PreOrderDomTreeIter = cfg::PreOrderIter>; +/// An iterator over nodes in a dominance tree produced by a depth-first, post-order traversal +pub type PostOrderDomTreeIter = cfg::PostOrderIter>; + +impl Graph for Rc { + type ChildEdgeIter = DomTreeSuccessorIter; + type ChildIter = DomTreeSuccessorIter; + type Edge = Rc; + type Node = Rc; + + fn size(&self) -> usize { + self.children.borrow().len() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + DomTreeSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + DomTreeSuccessorIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + // The edge is the child node + edge + } + + fn entry_node(&self) -> Self::Node { + Rc::clone(self) + } +} + +pub struct DomTreeSuccessorIter 
{ + node: Rc, + num_children: usize, + index: usize, +} +impl DomTreeSuccessorIter { + pub fn new(node: Rc) -> Self { + let num_children = node.num_children(); + Self { + node, + num_children, + index: 0, + } + } +} +impl core::iter::FusedIterator for DomTreeSuccessorIter {} +impl ExactSizeIterator for DomTreeSuccessorIter { + #[inline] + fn len(&self) -> usize { + self.num_children.saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.num_children + } +} +impl Iterator for DomTreeSuccessorIter { + type Item = Rc; + + fn next(&mut self) -> Option { + if self.index >= self.num_children { + return None; + } + let index = self.index; + self.index += 1; + Some(self.node.children.borrow()[index].clone()) + } +} +impl DoubleEndedIterator for DomTreeSuccessorIter { + fn next_back(&mut self) -> Option { + if self.num_children == 0 { + return None; + } + let index = self.num_children; + self.num_children -= 1; + Some(self.node.children.borrow()[index].clone()) + } +} + +impl DomTreeNode { + /// Create a new node for `block`, with the specified immediate dominator. + /// + /// If `block` is `None`, this must be a node in a post-dominator tree, and the resulting node + /// is a virtual node that post-dominates all nodes in the tree + pub fn new(block: Option, idom: Option>) -> Self { + Self { + block, + idom: Cell::new(idom), + children: Default::default(), + level: Cell::new(0), + num_in: Cell::new(None), + num_out: Cell::new(None), + } + } + + /// Build this node with the specified immediate dominator. + pub fn with_idom(self, idom: Rc) -> Self { + self.level.set(idom.level.get() + 1); + self.idom.set(Some(idom)); + self + } + + pub fn block(&self) -> Option<&BlockRef> { + self.block.as_ref() + } + + pub fn idom(&self) -> Option> { + unsafe { &*self.idom.as_ptr() }.clone() + } + + pub(super) fn set_idom(&self, idom: Option>) { + self.idom.set(idom); + } + + #[inline(always)] + pub fn level(&self) -> u32 { + self.level.get() + } + + pub fn is_leaf(&self) -> bool { + self.children.borrow().is_empty() + } + + pub fn num_children(&self) -> usize { + self.children.borrow().len() + } + + pub fn add_child(&self, child: Rc) { + self.children.borrow_mut().push(child); + } + + pub fn clear_children(&self) { + self.children.borrow_mut().clear(); + } + + /// Returns true if `self` is dominated by `other` in the tree. 
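+ ///
+ /// The check relies on the DFS entry/exit numbers assigned by [DomTreeBase::update_dfs_numbers]:
+ /// `other` is an ancestor of `self` in the dominator tree exactly when `self`'s
+ /// `[num_in, num_out]` interval is nested inside `other`'s interval (for example, with
+ /// `other = [1, 8]` and `self = [3, 4]`, `self` is dominated by `other`).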
+ pub fn is_dominated_by(&self, other: &Self) -> bool { + assert!( + self.num_in.get().is_some() && other.num_in.get().is_some(), + "you forgot to call update_dfs_numbers" + ); + self.num_in.get().is_some_and(|a| other.num_in.get().is_some_and(|b| a >= b)) + && self.num_out.get().is_some_and(|a| other.num_in.get().is_some_and(|b| a <= b)) + } + + /// Recomputes this node's depth in the dominator tree + fn update_level(self: Rc) { + let idom_level = self.idom().expect("expected to have an immediate dominator").level(); + if self.level() == idom_level + 1 { + return; + } + + let mut stack = SmallVec::<[Rc; 64]>::from_iter([self.clone()]); + while let Some(current) = stack.pop() { + current.level.set(current.idom().unwrap().level() + 1); + for child in current.children.borrow().iter() { + assert!(child.idom().is_some()); + if child.level() != child.idom().unwrap().level() + 1 { + stack.push(Rc::clone(child)); + } + } + } + } +} + +impl Eq for DomTreeNode {} +impl PartialEq for DomTreeNode { + fn eq(&self, other: &Self) -> bool { + self.block == other.block + } +} + +impl DomTreeBase { + #[inline] + pub fn root(&self) -> &BlockRef { + self.roots[0].as_ref().unwrap() + } +} + +impl DomTreeBase { + /// Compute a dominator tree for `region` + pub fn new(region: RegionRef) -> Result { + let entry = region.borrow().entry_block_ref().ok_or(DomTreeError::EmptyRegion)?; + let root = Rc::new(DomTreeNode::new(Some(entry.clone()), None)); + let nodes = smallvec![(Some(entry.clone()), root.clone())]; + let roots = smallvec![Some(entry)]; + + let mut this = Self { + parent: region, + root: Some(root), + roots, + nodes, + valid: Cell::new(false), + slow_queries: Cell::new(0), + }; + + this.compute(); + + Ok(this) + } + + #[inline] + pub fn parent(&self) -> &RegionRef { + &self.parent + } + + pub fn len(&self) -> usize { + self.nodes.len() + } + + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + #[inline] + pub fn num_roots(&self) -> usize { + self.roots.len() + } + + #[inline] + pub fn roots(&self) -> &[Option] { + &self.roots + } + + #[inline] + pub fn roots_mut(&mut self) -> &mut DomTreeRoots { + &mut self.roots + } + + pub(super) fn set_root(&mut self, root: Rc) { + self.root = Some(root); + } + + /// Returns true if this tree is a post-dominance tree. + #[inline(always)] + pub const fn is_post_dominator(&self) -> bool { + IS_POST_DOM + } + + pub(super) fn mark_invalid(&self) { + self.valid.set(false); + } + + /// Get the node for `block`, if one exists in the tree. + /// + /// Use `None` to get the virtual node, if this is a post-dominator tree + pub fn get(&self, block: Option<&BlockRef>) -> Option> { + self.node_index(block) + .map(|index| unsafe { self.nodes.get_unchecked(index).1.clone() }) + } + + #[inline] + fn node_index(&self, block: Option<&BlockRef>) -> Option { + assert!( + block.is_none_or(|block| block + .borrow() + .parent() + .is_some_and(|parent| parent == self.parent)), + "cannot get dominance info of block with different parent" + ); + if let Some(block) = block { + self.nodes.iter().position(|(b, _)| b.as_ref().is_some_and(|b| b == block)) + } else { + self.nodes.iter().position(|(b, _)| b.is_none()) + } + } + + /// Returns the entry node for the CFG of the region. + /// + /// However, if this tree represents the post-dominance relations for a region, this root may be + /// a node with `block` set to `None`. This is the case when there are multiple exit nodes from + /// a particular function. 
Consumers of post-dominance information must be capable of dealing + /// with this possibility. + pub fn root_node(&self) -> Option> { + self.root.clone() + } + + /// Get all nodes dominated by `r`, including `r` itself + pub fn get_descendants(&self, r: &BlockRef) -> SmallVec<[BlockRef; 2]> { + let mut results = SmallVec::default(); + let Some(rn) = self.get(Some(r)) else { + return results; + }; + let mut worklist = SmallVec::<[Rc; 8]>::default(); + worklist.push(rn); + + while let Some(n) = worklist.pop() { + let Some(n_block) = n.block() else { + continue; + }; + results.push(n_block.clone()); + worklist.extend(n.children.borrow().iter().cloned()); + } + + results + } + + /// Return true if `a` is dominated by the entry block of the region containing it. + pub fn is_reachable_from_entry(&self, a: &BlockRef) -> bool { + assert!(!self.is_post_dominator(), "unimplemented for post dominator trees"); + + self.get(Some(a)).is_some() + } + + #[inline] + pub const fn is_reachable_from_entry_node(&self, a: Option<&Rc>) -> bool { + a.is_some() + } + + /// Returns true if and only if `a` dominates `b` and `a != b` + /// + /// Note that this is not a constant time operation. + pub fn properly_dominates(&self, a: Option<&BlockRef>, b: Option<&BlockRef>) -> bool { + if a == b { + return false; + } + let a = self.get(a); + let b = self.get(b); + if a.is_none() || b.is_none() { + return false; + } + self.properly_dominates_node(a, b) + } + + /// Returns true if and only if `a` dominates `b` and `a != b` + /// + /// Note that this is not a constant time operation. + pub fn properly_dominates_node( + &self, + a: Option>, + b: Option>, + ) -> bool { + a != b && self.dominates_node(a, b) + } + + /// Returns true iff `a` dominates `b`. + /// + /// Note that this is not a constant time operation + pub fn dominates(&self, a: Option<&BlockRef>, b: Option<&BlockRef>) -> bool { + if a == b { + return true; + } + let a = self.get(a); + let b = self.get(b); + self.dominates_node(a, b) + } + + /// Returns true iff `a` dominates `b`. + /// + /// Note that this is not a constant time operation + pub fn dominates_node(&self, a: Option>, b: Option>) -> bool { + // A trivially dominates itself + if a == b { + return true; + } + + // An unreachable node is dominated by anything + if b.is_none() { + return true; + } + + // And dominates nothing. + if a.is_none() { + return false; + } + + let a = a.unwrap(); + let b = b.unwrap(); + + if b.idom().is_some_and(|idom| idom == a) { + return true; + } + + if a.idom().is_some_and(|idom| idom == b) { + return false; + } + + // A can only dominate B if it is higher in the tree + if a.level() >= b.level() { + return false; + } + + if self.valid.get() { + return b.is_dominated_by(&a); + } + + // If we end up with too many slow queries, just update the DFS numbers on the assumption + // that we are going to keep querying + self.slow_queries.set(self.slow_queries.get() + 1); + if self.slow_queries.get() > 32 { + self.update_dfs_numbers(); + return b.is_dominated_by(&a); + } + + self.dominated_by_slow_tree_walk(a, b) + } + + /// Finds the nearest block which is a common dominator of both `a` and `b` + pub fn find_nearest_common_dominator(&self, a: &BlockRef, b: &BlockRef) -> Option { + assert!(a.borrow().parent() == b.borrow().parent(), "two blocks are not in same region"); + + // If either A or B is an entry block then it is nearest common dominator (for forward + // dominators). 
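+ // For post-dominator trees there is no such shortcut: the root may be the virtual exit
+ // node (a node whose `block` is `None`), so we go straight to the level-based walk below.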
+ if !self.is_post_dominator() { + let parent = a.borrow().parent().unwrap(); + let entry = parent.borrow().entry_block_ref().unwrap(); + if a == &entry || b == &entry { + return Some(entry); + } + } + + let mut a = self.get(Some(a)).expect("'a' must be in the tree"); + let mut b = self.get(Some(b)).expect("'b' must be in the tree"); + + // Use level information to go up the tree until the levels match. Then continue going up + // until we arrive at the same node. + while a != b { + if a.level() < b.level() { + core::mem::swap(&mut a, &mut b); + } + + a = a.idom().unwrap(); + } + + a.block().cloned() + } +} + +impl DomTreeBase { + pub fn insert_edge(&mut self, mut from: Option, mut to: Option) { + if self.is_post_dominator() { + core::mem::swap(&mut from, &mut to); + } + SemiNCA::::insert_edge(self, None, from, to) + } + + pub fn delete_edge(&mut self, mut from: Option, mut to: Option) { + if self.is_post_dominator() { + core::mem::swap(&mut from, &mut to); + } + SemiNCA::::delete_edge(self, None, from.as_ref(), to.as_ref()) + } + + pub fn apply_updates( + &mut self, + pre_view_cfg: cfg::CfgDiff, + post_view_cfg: cfg::CfgDiff, + ) { + SemiNCA::::apply_updates(self, pre_view_cfg, post_view_cfg); + } + + pub fn compute(&mut self) { + SemiNCA::::compute_from_scratch(self, None); + } + + pub fn compute_with_updates(&mut self, updates: impl ExactSizeIterator) { + // FIXME: Updated to use the PreViewCFG and behave the same as until now. + // This behavior is however incorrect; this actually needs the PostViewCFG. + let pre_view_cfg = cfg::CfgDiff::new(updates, true); + let bui = BatchUpdateInfo::new(pre_view_cfg, None); + SemiNCA::::compute_from_scratch(self, Some(bui)); + } + + pub fn verify(&self, level: DomTreeVerificationLevel) -> bool { + let snca = SemiNCA::new(None); + + // Simplest check is to compare against a new tree. This will also usefully print the old + // and ne3w trees, if they are different. + if !self.is_same_as_fresh_tree() { + return false; + } + + // Common checks to verify the properties of the tree. O(n log n) at worst. + if !snca.verify_roots(self) + || !snca.verify_reachability(self) + || !snca.verify_levels(self) + || !snca.verify_dfs_numbers(self) + { + return false; + } + + // Extra checks depending on verification level. 
Up to O(n^3) + match level { + DomTreeVerificationLevel::Basic => { + if !snca.verify_parent_property(self) { + return false; + } + } + DomTreeVerificationLevel::Full => { + if !snca.verify_parent_property(self) || !snca.verify_sibling_property(self) { + return false; + } + } + _ => (), + } + + true + } + + fn is_same_as_fresh_tree(&self) -> bool { + let fresh = Self::new(self.parent.clone()).unwrap(); + let is_same = self == &fresh; + if !is_same { + log::error!( + "{} is different than a freshly computed one!", + if IS_POST_DOM { + "post-dominator tree" + } else { + "dominator tree" + } + ); + log::error!("Current: {self}"); + log::error!("Fresh: {fresh}"); + } + + is_same + } + + pub fn is_virtual_root(&self, node: &DomTreeNode) -> bool { + self.is_post_dominator() && node.block.is_none() + } + + pub fn add_new_block(&mut self, block: BlockRef, idom: Option) -> Rc { + assert!(self.get(Some(&block)).is_none(), "block already in dominator tree"); + let idom = self.get(idom.as_ref()).expect("no immediate dominator specified for `idom`"); + self.mark_invalid(); + self.create_node(Some(block), Some(idom)) + } + + pub fn set_new_root(&mut self, block: BlockRef) -> Rc { + assert!(self.get(Some(&block)).is_none(), "block already in dominator tree"); + assert!(!self.is_post_dominator(), "cannot change root of post-dominator tree"); + + self.valid.set(false); + let node = self.create_node(Some(block.clone()), None); + if self.roots.is_empty() { + self.roots.push(Some(block)); + } else { + assert_eq!(self.roots.len(), 1); + let old_node = self.get(self.roots[0].as_ref()).unwrap(); + node.add_child(old_node.clone()); + old_node.idom.set(Some(node.clone())); + old_node.update_level(); + self.roots[0] = Some(block); + } + self.root = Some(node.clone()); + node + } + + pub fn change_immediate_dominator(&mut self, n: &BlockRef, idom: Option<&BlockRef>) { + let n = self.get(Some(n)).expect("expected `n` to be in tree"); + let idom = self.get(idom).expect("expected `idom` to be in tree"); + self.change_immediate_dominator_node(n, idom); + } + + pub fn change_immediate_dominator_node(&mut self, n: Rc, idom: Rc) { + self.valid.set(false); + n.idom.set(Some(idom)); + } + + /// Removes a node from the dominator tree. + /// + /// Block must not dominate any other blocks. + /// + /// Removes node from the children of its immediate dominator. Deletes dominator node associated + /// with `block`. + pub fn erase_node(&mut self, block: &BlockRef) { + let node_index = self.node_index(Some(block)).expect("removing node that isn't in tree"); + let node = unsafe { self.nodes.get_unchecked(node_index).1.clone() }; + assert!(node.is_leaf(), "node is not a leaf node"); + + self.valid.set(false); + + // Remove node from immediate dominator's children + if let Some(idom) = node.idom() { + idom.children.borrow_mut().retain(|child| child != &node); + } + + self.nodes.swap_remove(node_index); + + if !IS_POST_DOM { + return; + } + + // Remember to update PostDominatorTree roots + if let Some(root_index) = + self.roots.iter().position(|r| r.as_ref().is_some_and(|r| r == block)) + { + self.roots.swap_remove(root_index); + } + } + + /// Assign in and out numbers to the nodes while walking the dominator tree in DFS order. + pub fn update_dfs_numbers(&self) { + if self.valid.get() { + self.slow_queries.set(0); + return; + } + + let mut worklist = SmallVec::<[(Rc, usize); 32]>::default(); + let this_root = self.root_node().unwrap(); + + // Both dominators and postdominators have a single root node. 
In the case + // case of PostDominatorTree, this node is a virtual root. + this_root.num_in.set(NonZeroU32::new(1)); + worklist.push((this_root, 0)); + + let mut dfs_num = 2u32; + + while let Some((node, child_index)) = worklist.last_mut() { + // If we visited all of the children of this node, "recurse" back up the + // stack setting the DFOutNum. + if *child_index >= node.num_children() { + node.num_out.set(Some(unsafe { NonZeroU32::new_unchecked(dfs_num) })); + dfs_num += 1; + worklist.pop(); + } else { + // Otherwise, recursively visit this child. + let index = *child_index; + *child_index += 1; + let child = node.children.borrow()[index].clone(); + dfs_num += 1; + child.num_in.set(Some(unsafe { NonZeroU32::new_unchecked(dfs_num) })); + worklist.push((child, 0)); + } + } + + self.slow_queries.set(0); + self.valid.set(true); + } + + /// Reset the dominator tree state + pub fn reset(&mut self) { + self.nodes.clear(); + self.root.take(); + self.roots.clear(); + self.valid.set(false); + self.slow_queries.set(0); + } + + pub(super) fn create_node( + &mut self, + block: Option, + idom: Option>, + ) -> Rc { + let node = Rc::new(DomTreeNode::new(block.clone(), idom.clone())); + self.nodes.push((block, node.clone())); + if let Some(idom) = idom { + idom.add_child(node.clone()); + } + node + } + + /// `block` is split and now it has one successor. + /// + /// Update dominator tree to reflect the change. + pub fn split_block(&mut self, block: &BlockRef) { + if IS_POST_DOM { + self.split::>(block.clone()); + } else { + self.split::(block.clone()); + } + } + + // `block` is split and now it has one successor. Update dominator tree to reflect this change. + fn split(&mut self, block: ::Node) + where + G: InvertibleGraph, + { + let mut successors = G::children(block.clone()); + assert_eq!(successors.len(), 1, "`block` should have a single successor"); + + let succ = successors.next().unwrap(); + let predecessors = G::inverse_children(block.clone()).collect::>(); + + assert!(!predecessors.is_empty(), "expected at at least one predecessor"); + + let mut block_dominates_succ = true; + for pred in G::inverse_children(succ.clone()) { + if pred != block + && !self.dominates(Some(&succ), Some(&pred)) + && self.is_reachable_from_entry(&pred) + { + block_dominates_succ = false; + break; + } + } + + // Find `block`'s immediate dominator and create new dominator tree node for `block`. + let idom = predecessors.iter().find(|p| self.is_reachable_from_entry(p)).cloned(); + + // It's possible that none of the predecessors of `block` are reachable; + // in that case, `block` itself is unreachable, so nothing needs to be + // changed. + let Some(idom) = idom else { + return; + }; + + let idom = predecessors.iter().fold(idom, |idom, p| { + if self.is_reachable_from_entry(p) { + self.find_nearest_common_dominator(&idom, p).expect("expected idom") + } else { + idom + } + }); + + // Create the new dominator tree node... and set the idom of `block`. + let node = self.add_new_block(block.clone(), Some(idom)); + + // If NewBB strictly dominates other blocks, then it is now the immediate + // dominator of NewBBSucc. Update the dominator tree as appropriate. + if block_dominates_succ { + let succ_node = self.get(Some(&succ)).expect("expected 'succ' to be in dominator tree"); + self.change_immediate_dominator_node(succ_node, node); + } + } + + fn dominated_by_slow_tree_walk(&self, a: Rc, b: Rc) -> bool { + assert_ne!(a, b); + + let a_level = a.level(); + let mut b = b; + + // Don't walk nodes above A's subtree. 
When we reach A's level, we must + // either find A or be in some other subtree not dominated by A. + while let Some(b_idom) = b.idom() { + if b_idom.level() >= a_level { + // Walk up the tree + b = b_idom; + } + } + + b == a + } +} + +impl fmt::Display for DomTreeBase { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use core::fmt::Write; + + f.write_str("=============================--------------------------------\n")?; + if IS_POST_DOM { + f.write_str("Inorder PostDominator Tree: ")?; + } else { + f.write_str("Inorder Dominator Tree: ")?; + } + if !self.valid.get() { + write!(f, "DFS numbers invalid: {} slow queries.", self.slow_queries.get())?; + } + f.write_char('\n')?; + + // The postdom tree can have a `None` root if there are no returns. + if let Some(root_node) = self.root_node() { + print_dom_tree(root_node, 1, f)? + } + f.write_str("Roots: ")?; + for (i, block) in self.roots.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + if let Some(block) = block { + write!(f, "{block}")?; + } else { + f.write_str("")?; + } + } + f.write_char('\n') + } +} + +fn print_dom_tree( + node: Rc, + level: usize, + f: &mut core::fmt::Formatter<'_>, +) -> core::fmt::Result { + write!(f, "{: <1$}", "", level)?; + writeln!(f, "[{level}] {node}")?; + for child_node in node.children.borrow().iter().cloned() { + print_dom_tree(child_node, level + 1, f)?; + } + Ok(()) +} + +impl Eq for DomTreeBase {} +impl PartialEq for DomTreeBase { + fn eq(&self, other: &Self) -> bool { + self.parent == other.parent + && self.roots.len() == other.roots.len() + && self.roots.iter().all(|root| other.roots.contains(root)) + && self.nodes.len() == other.nodes.len() + && self.nodes.iter().all(|(_, node)| { + let block = node.block(); + other.get(block).is_some_and(|n| node == &n) + }) + } +} From 9e9dfe876d9c0f3f7fa749ca5ddd8c2d0bad1eca Mon Sep 17 00:00:00 2001 From: Paul Schoenfelder Date: Mon, 28 Oct 2024 02:55:45 -0400 Subject: [PATCH 31/31] feat: implement loop analysis This commit implements the generic loop analysis from LLVM, in Rust. It is a more sophisticated analysis than the one in HIR1, and provides us with some additional useful information that will come in handy during certain transformations to MASM. See the "Loop Terminlogy" document on llvm.org for details on how LLVM reasons about loops, which corresponds to the analysis above. --- hir2/src/ir.rs | 1 + hir2/src/ir/loops.rs | 1277 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1278 insertions(+) create mode 100644 hir2/src/ir/loops.rs diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs index 9f0da7ebf..fd5981e20 100644 --- a/hir2/src/ir.rs +++ b/hir2/src/ir.rs @@ -11,6 +11,7 @@ mod ident; mod immediates; mod insert; mod interface; +pub mod loops; mod op; mod operands; mod operation; diff --git a/hir2/src/ir/loops.rs b/hir2/src/ir/loops.rs new file mode 100644 index 000000000..58bc2647f --- /dev/null +++ b/hir2/src/ir/loops.rs @@ -0,0 +1,1277 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::{ + cell::{Cell, Ref, RefCell, RefMut}, + fmt, +}; + +use smallvec::SmallVec; + +use super::dominance::{DominanceTree, PostOrderDomTreeIter}; +use crate::{ + adt::{SmallMap, SmallSet}, + cfg::{Graph, Inverse, InvertibleGraph}, + BlockRef, EntityWithId, OperationRef, PostOrderBlockIter, Report, +}; + +/// [LoopForest] represents all of the top-level loop structures in a specified region. 
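+/// It is an adaptation of LLVM's generic loop analysis (`LoopInfo`) to HIR regions and blocks.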
+/// +/// The [LoopForest] analysis is used to identify natural loops and determine the loop depth of +/// various nodes in a generic graph of blocks. A natural loop has exactly one entry-point, which +/// is called the header. Note that natural loops may actually be several loops that share the same +/// header node. +/// +/// This analysis calculates the nesting structure of loops in a function. For each natural loop +/// identified, this analysis identifies natural loops contained entirely within the loop and the +/// basic blocks that make up the loop. +/// +/// It can calculate on the fly various bits of information, for example: +/// +/// * Whether there is a preheader for the loop +/// * The number of back edges to the header +/// * Whether or not a particular block branches out of the loop +/// * The successor blocks of the loop +/// * The loop depth +/// * etc... +/// +/// Note that this analysis specifically identifies _loops_ not cycles or SCCs in the graph. There +/// can be strongly connected components in the graph which this analysis will not recognize and +/// that will not be represented by a loop instance. In particular, a loop might be inside such a +/// non-loop SCC, or a non-loop SCC might contain a sub-SCC which is a loop. +/// +/// For an overview of terminology used in this API (and thus all related loop analyses or +/// transforms), see [Loop Terminology](https://llvm.org/docs/LoopTerminology.html). +#[derive(Default)] +pub struct LoopForest { + /// The set of top-level loops in the forest + top_level_loops: SmallVec<[Rc; 4]>, + /// Mapping of basic blocks to the inner most loop they occur in + block_map: BTreeMap>, +} + +impl LoopForest { + /// Compute a new [LoopForest] from the given dominator tree + pub fn new(tree: &DominanceTree) -> Self { + let mut forest = Self::default(); + forest.analyze(tree); + forest + } + + /// Returns true if there are no loops in the forest + pub fn is_empty(&self) -> bool { + self.top_level_loops.is_empty() + } + + /// Returns the number of loops in the forest + pub fn len(&self) -> usize { + self.top_level_loops.len() + } + + /// Returns true if `block` is in this loop forest + #[inline] + pub fn contains_block(&self, block: &BlockRef) -> bool { + self.block_map.contains_key(block) + } + + /// Get the set of top-level/outermost loops in the forest + pub fn top_level_loops(&self) -> &[Rc] { + &self.top_level_loops + } + + /// Return all of the loops in the function in preorder across the loop nests, with siblings in + /// forward program order. + /// + /// Note that because loops form a forest of trees, preorder is equivalent to reverse postorder. + pub fn loops_in_preorder(&self) -> SmallVec<[Rc; 4]> { + // The outer-most loop actually goes into the result in the same relative order as we walk + // it. But LoopForest stores the top level loops in reverse program order so for here we + // reverse it to get forward program order. + // + // FIXME: If we change the order of LoopForest we will want to remove the reverse here. + let mut preorder_loops = SmallVec::<[Rc; 4]>::default(); + for l in self.top_level_loops.iter().cloned().rev() { + let mut loops_in_preorder = l.loops_in_preorder(); + preorder_loops.append(&mut loops_in_preorder); + } + preorder_loops + } + + /// Return all of the loops in the function in preorder across the loop nests, with siblings in + /// _reverse_ program order. + /// + /// Note that because loops form a forest of trees, preorder is equivalent to reverse postorder. 
+ /// + /// Also note that this is _not_ a reverse preorder. Only the siblings are in reverse program + /// order. + pub fn loops_in_reverse_sibling_preorder(&self) -> SmallVec<[Rc; 4]> { + // The outer-most loop actually goes into the result in the same relative order as we walk + // it. LoopForest stores the top level loops in reverse program order so we walk in order + // here. + // + // FIXME: If we change the order of LoopInfo we will want to add a reverse here. + let mut preorder_loops = SmallVec::<[Rc; 4]>::default(); + let mut preorder_worklist = SmallVec::<[Rc; 4]>::default(); + for l in self.top_level_loops.iter().cloned() { + assert!(preorder_worklist.is_empty()); + preorder_worklist.push(l); + while let Some(l) = preorder_worklist.pop() { + // Sub-loops are stored in forward program order, but will process the worklist + // backwards so we can just append them in order. + preorder_worklist.extend(l.nested().iter().cloned()); + preorder_loops.push(l); + } + } + + preorder_loops + } + + /// Return the inner most loop that `block` lives in. + /// + /// If a basic block is in no loop (for example the entry node), `None` is returned. + pub fn loop_for(&self, block: &BlockRef) -> Option> { + self.block_map.get(block).cloned() + } + + /// Return the loop nesting level of the specified block. + /// + /// A depth of 0 means the block is not inside any loop. + pub fn loop_depth(&self, block: &BlockRef) -> usize { + self.loop_for(block).map(|l| l.depth()).unwrap_or(0) + } + + /// Returns true if the block is a loop header + pub fn is_loop_header(&self, block: &BlockRef) -> bool { + self.loop_for(block).map(|l| &l.header() == block).unwrap_or(false) + } + + /// This removes the specified top-level loop from this loop info object. + /// + /// The loop is not deleted, as it will presumably be inserted into another loop. + /// + /// # Panics + /// + /// This function will panic if the given loop is not a top-level loop + pub fn remove_loop(&mut self, l: &Loop) -> Option> { + assert!(l.is_outermost(), "`l` is not an outermost loop"); + let index = self.top_level_loops.iter().position(|tll| core::ptr::addr_eq(&**tll, l))?; + Some(self.top_level_loops.swap_remove(index)) + } + + /// Change the top-level loop that contains `block` to the specified loop. + /// + /// This should be used by transformations that restructure the loop hierarchy tree. + pub fn change_loop_for(&mut self, block: BlockRef, l: Option>) { + if let Some(l) = l { + self.block_map.insert(block, l); + } else { + self.block_map.remove(&block); + } + } + + /// Replace the specified loop in the top-level loops list with the indicated loop. + pub fn change_top_level_loop(&mut self, old: Rc, new: Rc) { + assert!( + new.parent_loop().is_none() && old.parent_loop().is_none(), + "loops already embedded into a subloop" + ); + let index = self + .top_level_loops + .iter() + .position(|tll| Rc::ptr_eq(tll, &old)) + .expect("`old` loop is not a top-level loop"); + self.top_level_loops[index] = new; + } + + /// This adds the specified loop to the collection of top-level loops. + pub fn add_top_level_loop(&mut self, l: Rc) { + assert!(l.is_outermost(), "loop already in subloop"); + self.top_level_loops.push(l); + } + + /// This method completely removes `block` from all data structures, including all of the loop + /// objects it is nested in and our mapping from basic blocks to loops. 
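+ ///
+ /// The block is removed from the innermost loop containing it and, by walking the parent-loop
+ /// chain, from every ancestor of that loop as well.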
+ pub fn remove_block(&mut self, block: &BlockRef) { + if let Some(l) = self.block_map.remove(block) { + let mut next_l = Some(l); + while let Some(l) = next_l.take() { + next_l = l.parent_loop(); + l.remove_block_from_loop(block); + } + } + } + + pub fn is_not_already_contained_in(sub_loop: Option<&Loop>, parent: Option<&Loop>) -> bool { + let Some(sub_loop) = sub_loop else { + return true; + }; + if parent.is_some_and(|parent| parent == sub_loop) { + return false; + } + return Self::is_not_already_contained_in(sub_loop.parent_loop().as_deref(), parent); + } + + /// Analyze the given dominance tree to discover loops. + /// + /// The analysis discovers loops during a post-order traversal of the given dominator tree, + /// interleaved with backward CFG traversals within each subloop + /// (see `discover_and_map_subloop`). The backward traversal skips inner subloops, so this part + /// of the algorithm is linear in the number of CFG edges. Subloop and block vectors are then + /// populated during a single forward CFG traversal. + /// + /// During the two CFG traversals each block is seen three times: + /// + /// 1. Discovered and mapped by a reverse CFG traversal. + /// 2. Visited during a forward DFS CFG traversal. + /// 3. Reverse-inserted in the loop in postorder following forward DFS. + /// + /// The block vectors are inclusive, so step 3 requires loop-depth number of insertions per + /// block. + pub fn analyze(&mut self, tree: &DominanceTree) { + // Postorder traversal of the dominator tree. + let root = tree.root_node().unwrap(); + for node in PostOrderDomTreeIter::new(root.clone()) { + let header = node.block().expect("expected header block").clone(); + let mut backedges = SmallVec::<[BlockRef; 4]>::default(); + + // Check each predecessor of the potential loop header. + for backedge in BlockRef::inverse_children(header.clone()) { + // If `header` dominates `pred`, this is a new loop. Collect the backedges. + let backedge_node = tree.get(Some(&backedge)); + if backedge_node.is_some() && tree.dominates_node(Some(node.clone()), backedge_node) + { + backedges.push(backedge); + } + } + + // Perform a backward CFG traversal to discover and map blocks in this loop. + if !backedges.is_empty() { + let l = Rc::new(Loop::new(header.clone())); + self.discover_and_map_sub_loop(l, backedges, tree); + } + } + + // Perform a single forward CFG traversal to populate blocks and subloops for all loops. + for block in PostOrderBlockIter::new(root.block().cloned().unwrap()) { + self.insert_into_loop(block); + } + } + + /// Discover a subloop with the specified backedges such that: + /// + /// * All blocks within this loop are mapped to this loop or a subloop. + /// * All subloops within this loop have their parent loop set to this loop or a subloop. + fn discover_and_map_sub_loop( + &mut self, + l: Rc, + backedges: SmallVec<[BlockRef; 4]>, + tree: &DominanceTree, + ) { + let mut num_blocks = 0usize; + let mut num_subloops = 0usize; + + // Perform a backward CFG traversal using a worklist. + let mut reverse_cfg_worklist = backedges; + while let Some(pred) = reverse_cfg_worklist.pop() { + match self.loop_for(&pred) { + None if !tree.is_reachable_from_entry(&pred) => continue, + None => { + // This is an undiscovered block. Map it to the current loop. 
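+ // At this point the block is only attributed to the loop currently being discovered;
+ // `insert_into_loop` later adds it to all of its ancestor loops during the forward CFG pass.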
+ self.change_loop_for(pred.clone(), Some(l.clone())); + num_blocks += 1; + if pred == l.header() { + continue; + } + + // Push all block predecessors on the worklist + reverse_cfg_worklist.extend(Inverse::::children(pred.clone())); + } + Some(subloop) => { + // This is a discovered block. Find its outermost discovered loop. + let subloop = subloop.outermost_loop(); + + // If it is already discovered to be a subloop of this loop, continue. + if subloop == l { + continue; + } + + // Discover a subloop of this loop. + subloop.set_parent_loop(Some(l.clone())); + num_subloops += 1; + num_blocks += subloop.num_blocks(); + + // Continue traversal along predecessors that are not loop-back edges from + // within this subloop tree itself. Note that a predecessor may directly reach + // another subloop that is not yet discovered to be a subloop of this loop, + // which we must traverse. + for pred in BlockRef::inverse_children(subloop.header()) { + if self.loop_for(&pred).is_none_or(|l| l != subloop) { + reverse_cfg_worklist.push(pred); + } + } + } + } + } + + l.nested.borrow_mut().reserve(num_subloops); + l.reserve(num_blocks); + } + + /// Add a single block to its ancestor loops in post-order. + /// + /// If the block is a subloop header, add the subloop to its parent in post-order, then reverse + /// the block and subloop vectors of the now complete subloop to achieve RPO. + fn insert_into_loop(&mut self, block: BlockRef) { + let mut subloop = self.loop_for(&block); + if let Some(sl) = subloop.clone().filter(|sl| sl.header() == block) { + let parent = sl.parent_loop(); + // We reach this point once per subloop after processing all the blocks in the subloop. + if sl.is_outermost() { + self.add_top_level_loop(sl.clone()); + } else { + parent.as_ref().unwrap().nested.borrow_mut().push(sl.clone()); + } + + // For convenience, blocks and subloops are inserted in postorder. Reverse the lists, + // except for the loop header, which is always at the beginning. + sl.reverse_blocks(1); + sl.nested.borrow_mut().reverse(); + subloop = parent; + } + + while let Some(sl) = subloop.take() { + sl.add_block_entry(block.clone()); + subloop = sl.parent_loop(); + } + } + + /// Verify the loop forest structure using the provided [DominanceTree] + pub fn verify(&self, tree: &DominanceTree) -> Result<(), Report> { + let mut loops = SmallSet::, 2>::default(); + for l in self.top_level_loops.iter().cloned() { + if !l.is_outermost() { + return Err(Report::msg("top-level loop has a parent")); + } + l.verify_loop_nest(&mut loops)?; + } + + if cfg!(debug_assertions) { + // Verify that blocks are mapped to valid loops. + for (block, block_loop) in self.block_map.iter() { + if !loops.contains(block_loop) { + return Err(Report::msg("orphaned loop")); + } + if !block_loop.contains_block(block) { + return Err(Report::msg("orphaned block")); + } + for child_loop in block_loop.nested().iter() { + if child_loop.contains_block(block) { + return Err(Report::msg( + "expected block map to reflect the innermost loop containing `block`", + )); + } + } + } + + // Recompute forest to verify loops structure. + let other = LoopForest::new(tree); + + // Build a map we can use to move from our forest to the newly computed one. This allows + // us to ignore the particular order in any layer of the loop forest while still + // comparing the structure. 
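+ // The map is keyed by header block: a natural loop has exactly one entry (its header), so
+ // the header uniquely identifies each loop in the freshly computed forest.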
+ let mut other_headers = SmallMap::, 8>::default(); + + fn add_inner_loops_to_headers_map( + headers: &mut SmallMap, 8>, + l: &Rc, + ) { + let header = l.header(); + headers.insert(header, Rc::clone(l)); + for sl in l.nested().iter() { + add_inner_loops_to_headers_map(headers, sl); + } + } + + for l in other.top_level_loops() { + add_inner_loops_to_headers_map(&mut other_headers, l); + } + + // Walk the top level loops and ensure there is a corresponding top-level loop in the + // computed version and then recursively compare those loop nests. + for l in self.top_level_loops() { + let header = l.header(); + let other_l = other_headers.remove(&header); + match other_l { + None => { + return Err(Report::msg( + "top level loop is missing in computed loop forest", + )) + } + Some(other_l) => { + // Recursively compare the loops + Self::compare_loops(l.clone(), other_l, &mut other_headers)?; + } + } + } + + // Any remaining entries in the map are loops which were found when computing a fresh + // loop forest but not present in the current one. + if !other_headers.is_empty() { + for (_header, header_loop) in other_headers { + log::trace!("Found new loop {header_loop:?}"); + } + return Err(Report::msg("found new loops when recomputing loop forest")); + } + } + + Ok(()) + } + + #[cfg(debug_assertions)] + fn compare_loops( + l: Rc, + other_l: Rc, + other_loop_headers: &mut SmallMap, 8>, + ) -> Result<(), Report> { + let header = l.header(); + let other_header = other_l.header(); + if header != other_header { + return Err(Report::msg( + "mismatched headers even though found under the same map entry", + )); + } + + if l.depth() != other_l.depth() { + return Err(Report::msg("mismatched loop depth")); + } + + { + let mut parent_l = Some(l.clone()); + let mut other_parent_l = Some(other_l.clone()); + while let Some(pl) = parent_l.take() { + if let Some(opl) = other_parent_l.take() { + if pl.header() != opl.header() { + return Err(Report::msg("mismatched parent loop headers")); + } + parent_l = pl.parent_loop(); + other_parent_l = opl.parent_loop(); + } else { + return Err(Report::msg( + "`other_l` misreported its depth: expected a parent and got none", + )); + } + } + } + + for sl in l.nested().iter() { + let sl_header = sl.header(); + let other_sl = other_loop_headers.remove(&sl_header); + match other_sl { + None => return Err(Report::msg("inner loop is missing in computed loop forest")), + Some(other_sl) => { + Self::compare_loops(sl.clone(), other_sl, other_loop_headers)?; + } + } + } + + let mut blocks = l.blocks.borrow().clone(); + let mut other_blocks = other_l.blocks.borrow().clone(); + blocks.sort_by_key(|b| b.borrow().id()); + other_blocks.sort_by_key(|b| b.borrow().id()); + if blocks != other_blocks { + log::trace!("blocks: {}", crate::formatter::DisplayValues::new(blocks.iter())); + log::trace!( + "other_blocks: {}", + crate::formatter::DisplayValues::new(other_blocks.iter()) + ); + return Err(Report::msg("loops report mismatched blocks")); + } + + let block_set = l.block_set(); + let other_block_set = other_l.block_set(); + let diff = block_set.symmetric_difference(&other_block_set); + if block_set.len() != other_block_set.len() || !diff.is_empty() { + log::trace!( + "block_set: {}", + crate::formatter::DisplayValues::new(block_set.iter()) + ); + log::trace!( + "other_block_set: {}", + crate::formatter::DisplayValues::new(other_block_set.iter()) + ); + log::trace!("diff: {}", crate::formatter::DisplayValues::new(diff.iter())); + return Err(Report::msg("loops report mismatched block sets")); 
+ } + + Ok(()) + } + + #[cfg(not(debug_assertions))] + fn compare_loops( + _l: Rc, + _other_l: Rc, + _other_loop_headers: &mut SmallMap, 8>, + ) -> Result<(), Report> { + Ok(()) + } +} + +impl fmt::Debug for LoopForest { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("LoopInfo") + .field("top_level_loops", &self.top_level_loops) + .field("block_map", &self.block_map) + .finish() + } +} + +/// Edge type. +pub type LoopEdge = (BlockRef, BlockRef); + +/// [Loop] is used to represent loops that are detected in the control-flow graph. +#[derive(Default)] +pub struct Loop { + /// If this loop is an outermost loop, this field is `None`. + /// + /// Otherwise, it holds a handle to the parent loop which transfers control to this loop. + parent_loop: Cell>>, + /// Loops contained entirely within this one. + /// + /// All of the loops in this set will have their `parent` set to this loop + nested: RefCell; 2]>>, + /// The list of blocks in this loop. + /// + /// The header block is always at index 0. + blocks: RefCell>, + /// The uniqued set of blocks present in this loop + block_set: RefCell>, +} + +impl Eq for Loop {} +impl PartialEq for Loop { + fn eq(&self, other: &Self) -> bool { + core::ptr::addr_eq(self, other) + } +} + +impl Loop { + /// Create a new [Loop] with `block` as its header. + pub fn new(block: BlockRef) -> Self { + let mut this = Self::default(); + this.blocks.get_mut().push(block.clone()); + this.block_set.get_mut().insert(block); + this + } + + /// Get the nesting level of this loop. + /// + /// An outer-most loop has depth 1, for consistency with loop depth values used for basic + /// blocks, where depth 0 is used for blocks not inside any loops. + pub fn depth(&self) -> usize { + let mut depth = 1; + let mut current_loop = self.parent_loop(); + while let Some(curr) = current_loop.take() { + depth += 1; + current_loop = curr.parent_loop(); + } + depth + } + + /// Get the header block of this loop + pub fn header(&self) -> BlockRef { + self.blocks.borrow()[0].clone() + } + + /// Return the parent loop of this loop, if it has one, or `None` if it is a top-level loop. + /// + /// A loop is either top-level in a function (that is, it is not contained in any other loop) or + /// it is entirely enclosed in some other loop. If a loop is top-level, it has no parent, + /// otherwise its parent is the innermost loop in which it is enclosed. + pub fn parent_loop(&self) -> Option> { + unsafe { (*self.parent_loop.as_ptr()).clone() } + } + + /// This is a low-level API for bypassing [add_child_loop]. + pub fn set_parent_loop(&self, parent: Option>) { + self.parent_loop.set(parent); + } + + /// Discover the outermost loop that contains `self` + pub fn outermost_loop(self: Rc) -> Rc { + let mut l = self; + while let Some(parent) = l.parent_loop() { + l = parent; + } + l + } + + /// Return true if the specified loop is contained within in this loop. 
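+ ///
+ /// A loop is considered to contain itself, as well as any loop transitively nested within it.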
+ pub fn contains(&self, l: Rc) -> bool { + if core::ptr::addr_eq(self, &*l) { + return true; + } + + let Some(parent) = l.parent_loop() else { + return false; + }; + + self.contains(parent) + } + + /// Returns true if the specified basic block is in this loop + pub fn contains_block(&self, block: &BlockRef) -> bool { + self.block_set.borrow().contains(block) + } + + /// Returns true if the specified operation is in this loop + pub fn contains_op(&self, op: &OperationRef) -> bool { + let Some(block) = op.borrow().parent() else { + return false; + }; + self.contains_block(&block) + } + + /// Return the loops contained entirely within this loop. + pub fn nested(&self) -> Ref<'_, [Rc]> { + Ref::map(self.nested.borrow(), |nested| nested.as_slice()) + } + + /// Return true if the loop does not contain any (natural) loops. + /// + /// [Loop] does not detect irreducible control flow, just natural loops. That is, it is possible + /// that there is cyclic control flow within the innermost loop or around the outermost loop. + pub fn is_innermost(&self) -> bool { + self.nested.borrow().is_empty() + } + + /// Return true if the loop does not have a parent (natural) loop (i.e. it is outermost, which + /// is the same as top-level). + pub fn is_outermost(&self) -> bool { + unsafe { (*self.parent_loop.as_ptr()).is_none() } + } + + /// Get a list of the basic blocks which make up this loop. + pub fn blocks(&self) -> Ref<'_, [BlockRef]> { + Ref::map(self.blocks.borrow(), |blocks| blocks.as_slice()) + } + + /// Get a mutable reference to the basic blocks which make up this loop. + pub fn blocks_mut(&self) -> RefMut<'_, SmallVec<[BlockRef; 32]>> { + self.blocks.borrow_mut() + } + + /// Return the number of blocks contained in this loop + pub fn num_blocks(&self) -> usize { + self.blocks.borrow().len() + } + + /// Return a reference to the blocks set. + pub fn block_set(&self) -> Ref<'_, SmallSet> { + self.block_set.borrow() + } + + /// Return a mutable reference to the blocks set. + pub fn block_set_mut(&self) -> RefMut<'_, SmallSet> { + self.block_set.borrow_mut() + } + + /// Returns true if the terminator of `block` can branch to another block that is outside of the + /// current loop. + /// + /// # Panics + /// + /// This function will panic if `block` is not inside this loop. + pub fn is_loop_exiting(&self, block: &BlockRef) -> bool { + assert!(self.contains_block(block), "exiting block must be part of the loop"); + BlockRef::children(block.clone()).any(|succ| !self.contains_block(&succ)) + } + + /// Returns true if `block` is a loop-latch. + /// + /// A latch block is a block that contains a branch back to the header. + /// + /// This function is useful when there are multiple latches in a loop because `get_loop_latch` + /// will return `None` in that case. + pub fn is_loop_latch(&self, block: &BlockRef) -> bool { + assert!(self.contains_block(block), "block does not belong to the loop"); + BlockRef::inverse_children(self.header()).any(|pred| &pred == block) + } + + /// Calculate the number of back edges to the loop header + pub fn num_backedges(&self) -> usize { + BlockRef::inverse_children(self.header()) + .filter(|pred| self.contains_block(pred)) + .count() + } +} + +/// Loop Analysis +/// +/// Note that all of these methods can fail on general loops (ie, there may not be a preheader, +/// etc). For best success, the loop simplification and induction variable canonicalization pass +/// should be used to normalize loops for easy analysis. These methods assume canonical loops. 
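+///
+/// Terminology used by these methods (see the loop terminology document linked from
+/// [LoopForest]): the _preheader_ is the unique out-of-loop predecessor of the header, and it
+/// must have the header as its only successor; a _latch_ is an in-loop predecessor of the
+/// header (the source of a back edge); an _exiting_ block is an in-loop block with at least one
+/// successor outside the loop; and an _exit_ block is an out-of-loop successor of an exiting
+/// block.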
+impl Loop { + /// Get all blocks inside the loop that have successors outside of the loop. + /// + /// These are the blocks _inside of the current loop_ which branch out. The returned list is + /// always unique. + pub fn exiting_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut exiting_blocks = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + // A block must be an exit block if it is not contained in the current loop + if !self.contains_block(&succ) { + exiting_blocks.push(block.clone()); + break; + } + } + } + exiting_blocks + } + + /// If [Self::exiting_blocks] would return exactly one block, return it, otherwise `None`. + pub fn exiting_block(&self) -> Option { + let mut exiting_block = None; + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + if exiting_block.is_some() { + return None; + } else { + exiting_block = Some(block.clone()); + } + break; + } + } + } + exiting_block + } + + /// Get all of the successor blocks of this loop. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut exit_blocks = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + exit_blocks.push(succ); + } + } + } + exit_blocks + } + + /// If [Self::exit_blocks] would return exactly one block, return it, otherwise `None`. + pub fn exit_block(&self) -> Option { + let mut exit_block = None; + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + if exit_block.is_some() { + return None; + } else { + exit_block = Some(succ); + } + } + } + } + exit_block + } + + /// Returns true if no exit block for the loop has a predecessor that is outside the loop. + pub fn has_dedicated_exits(&self) -> bool { + // Each predecessor of each exit block of a normal loop is contained within the loop. + for exit_block in self.unique_exit_blocks() { + for pred in BlockRef::inverse_children(exit_block) { + if !self.contains_block(&pred) { + return false; + } + } + } + + // All the requirements are met. + true + } + + /// Return all unique successor blocks of this loop. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn unique_exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut unique_exits = SmallVec::default(); + unique_exit_blocks_helper(self, &mut unique_exits, |_| true); + unique_exits + } + + /// Return all unique successor blocks of this loop, except successors from the latch block + /// which are not considered. If an exit that comes from the latch block, but also has a non- + /// latch predecessor in the loop, it will be included. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn unique_non_latch_exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let latch_block = self.loop_latch().expect("latch must exist"); + let mut unique_exits = SmallVec::default(); + unique_exit_blocks_helper(self, &mut unique_exits, |block| block != &latch_block); + unique_exits + } + + /// If [Self::unique_exit_blocks] would return exactly one block, return it, otherwise `None`. + #[inline] + pub fn unique_exit_block(&self) -> Option { + self.exit_block() + } + + /// Return true if this loop does not have any exit blocks. 
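+ ///
+ /// Equivalently, every successor of every block in the loop is itself contained in the loop.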
+ pub fn has_no_exit_blocks(&self) -> bool { + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + return false; + } + } + } + true + } + + /// Return all pairs of (_inside_block_, _outside_block_). + pub fn exit_edges(&self) -> SmallVec<[LoopEdge; 2]> { + let mut exit_edges = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + exit_edges.push((block.clone(), succ)); + } + } + } + exit_edges + } + + /// Returns the pre-header for this loop, if there is one. + /// + /// A loop has a pre-header if there is only one edge to the header of the loop from outside of + /// the loop. If this is the case, the block branching to the header of the loop is the + /// pre-header node. + /// + /// This returns `None` if there is no pre-header for the loop. + pub fn preheader(&self) -> Option { + use crate::IteratorExt; + + // Keep track of nodes outside the loop branching to the header... + let out = self.loop_predecessor()?; + + // Make sure we are allowed to hoist instructions into the predecessor. + if !out.borrow().is_legal_to_hoist_into() { + return None; + } + + // Make sure there is only one exit out of the preheader. + if !BlockRef::children(out.clone()).has_single_element() { + // Multiple exits from the block, must not be a preheader. + return None; + } + + // The predecessor has exactly one successor, so it is a preheader. + Some(out) + } + + /// If the given loop's header has exactly one unique predecessor outside the loop, return it. + /// + /// This is less strict than the loop "preheader" concept, which requires the predecessor to + /// have exactly one successor. + pub fn loop_predecessor(&self) -> Option { + // Keep track of nodes outside the loop branching to the header... + let mut out = None; + // Loop over the predecessors of the header node... + let header = self.header(); + for pred in BlockRef::inverse_children(header) { + if !self.contains_block(&pred) { + if out.as_ref().is_some_and(|out| out != &pred) { + // Multiple predecessors outside the loop + return None; + } + out = Some(pred); + } + } + out + } + + /// If there is a single latch block for this loop, return it. + /// + /// A latch block is a block that contains a branch back to the header. + pub fn loop_latch(&self) -> Option { + let header = self.header(); + let mut latch_block = None; + for pred in BlockRef::inverse_children(header) { + if self.contains_block(&pred) { + if latch_block.is_some() { + return None; + } + latch_block = Some(pred); + } + } + latch_block + } + + /// Get all loop latch blocks of this loop. + /// + /// A latch block is a block that contains a branch back to the header. + pub fn loop_latches(&self) -> SmallVec<[BlockRef; 2]> { + BlockRef::inverse_children(self.header()) + .filter(|pred| self.contains_block(pred)) + .collect() + } + + /// Return all inner loops in the loop nest rooted by the loop in preorder, with siblings in + /// forward program order. + pub fn inner_loops_in_preorder(&self) -> SmallVec<[Rc; 2]> { + let mut worklist = SmallVec::<[Rc; 4]>::default(); + worklist.extend(self.nested().iter().rev().cloned()); + + let mut results = SmallVec::default(); + while let Some(l) = worklist.pop() { + // Sub-loops are stored in forward program order, but will process the + // worklist backwards so append them in reverse order. 
+ worklist.extend(l.nested().iter().rev().cloned()); + results.push(l); + } + + results + } + + /// Return all loops in the loop nest rooted by the loop in preorder, with siblings in forward + /// program order. + pub fn loops_in_preorder(self: Rc) -> SmallVec<[Rc; 2]> { + let mut loops = self.inner_loops_in_preorder(); + loops.insert(0, self); + loops + } +} + +fn unique_exit_blocks_helper( + l: &Loop, + exit_blocks: &mut SmallVec<[BlockRef; 2]>, + mut predicate: F, +) where + F: FnMut(&BlockRef) -> bool, +{ + let mut visited = SmallSet::::default(); + for block in l.blocks.borrow().iter().filter(|b| predicate(b)).cloned() { + for succ in BlockRef::children(block) { + if !l.contains_block(&succ) && visited.insert(succ.clone()) { + exit_blocks.push(succ); + } + } + } +} + +/// Updates +impl Loop { + /// Add `block` to this loop, and as a member of all parent loops. + /// + /// It is not valid to replace the loop header using this function. + /// + /// This is intended for use by analyses which need to update loop information. + pub fn add_block_to_loop(self: Rc, block: BlockRef, forest: &mut LoopForest) { + assert!(!forest.contains_block(&block), "`block` is already in this loop"); + + // Add the loop mapping to the LoopForest object... + forest.block_map.insert(block.clone(), self.clone()); + + // Add the basic block to this loop and all parent loops... + let mut next_l = Some(self); + while let Some(l) = next_l.take() { + l.add_block_entry(block.clone()); + next_l = l.parent_loop(); + } + } + + /// Replace `prev` with `new` in the set of children of this loop, updating the parent pointer + /// of `prev` to `None`, and of `new` to `self`. + /// + /// This also updates the loop depth of the new child. + /// + /// This is intended for use when splitting loops up. + pub fn replace_child_loop_with(self: Rc, prev: Rc, new: Rc) { + assert_eq!(prev.parent_loop().as_ref(), Some(&self), "this loop is already broken"); + assert!(new.parent_loop().is_none(), "`new` already has a parent"); + + // Set the parent of `new` to `self` + new.set_parent_loop(Some(self.clone())); + // Replace `prev` in `self.nested` with `new` + let mut nested = self.nested.borrow_mut(); + let entry = nested.iter_mut().find(|l| Rc::ptr_eq(l, &prev)).expect("`prev` not in loop"); + let _ = core::mem::replace(entry, new); + // Set the parent of `prev` to `None` + prev.set_parent_loop(None); + } + + /// Add the specified loop to be a child of this loop. + /// + /// This updates the loop depth of the new child. + pub fn add_child_loop(self: Rc, child: Rc) { + assert!(child.parent_loop().is_none(), "child already has a parent"); + child.set_parent_loop(Some(self.clone())); + self.nested.borrow_mut().push(child); + } + + /// This removes subloops of this loop based on the provided predicate, and returns them in a + /// vector. + /// + /// The loops are not deleted, as they will presumably be inserted into another loop. + pub fn take_child_loops(&self, should_remove: F) -> SmallVec<[Rc; 2]> + where + F: Fn(&Loop) -> bool, + { + let mut taken = SmallVec::default(); + self.nested.borrow_mut().retain(|l| { + if should_remove(l) { + l.set_parent_loop(None); + taken.push(Rc::clone(l)); + false + } else { + true + } + }); + taken + } + + /// This removes the specified child from being a subloop of this loop. + /// + /// The loop is not deleted, as it will presumably be inserted into another loop. 
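+ ///
+ /// Returns `None` if `child` is not a direct subloop of this loop.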
+ pub fn take_child_loop(&self, child: &Loop) -> Option> { + let mut nested = self.nested.borrow_mut(); + let index = nested.iter().position(|l| core::ptr::addr_eq(&**l, child))?; + Some(nested.swap_remove(index)) + } + + /// This adds a basic block directly to the basic block list. + /// + /// This should only be used by transformations that create new loops. Other transformations + /// should use [add_block_to_loop]. + pub fn add_block_entry(&self, block: BlockRef) { + self.blocks.borrow_mut().push(block.clone()); + self.block_set.borrow_mut().insert(block); + } + + /// Reverse the order of blocks in this loop starting from `index` to the end. + pub fn reverse_blocks(&self, index: usize) { + self.blocks.borrow_mut()[index..].reverse(); + } + + /// Reserve capacity for `capacity` blocks + pub fn reserve(&self, capacity: usize) { + self.blocks.borrow_mut().reserve(capacity); + } + + /// This method is used to move `block` (which must be part of this loop) to be the loop header + /// of the loop (the block that dominates all others). + pub fn move_to_header(&self, block: BlockRef) { + let mut blocks = self.blocks.borrow_mut(); + let index = blocks.iter().position(|b| b == &block).expect("loop does not contain `block`"); + if index == 0 { + return; + } + unsafe { + blocks.swap_unchecked(0, index); + } + } + + /// This removes the specified basic block from the current loop, updating the `self.blocks` as + /// appropriate. This does not update the mapping in the corresponding [LoopInfo]. + pub fn remove_block_from_loop(&self, block: &BlockRef) { + let mut blocks = self.blocks.borrow_mut(); + let index = blocks.iter().position(|b| b == block).expect("loop does not contain `block`"); + blocks.swap_remove(index); + self.block_set.borrow_mut().remove(block); + } + + /// Verify loop structure + #[cfg(debug_assertions)] + pub fn verify_loop(&self) -> Result<(), Report> { + use crate::PreOrderBlockIter; + + if self.blocks.borrow().is_empty() { + return Err(Report::msg("loop header is missing")); + } + + // Setup for using a depth-first iterator to visit every block in the loop. + let exit_blocks = self.exit_blocks(); + let mut visit_set = SmallSet::::default(); + visit_set.extend(exit_blocks.iter().cloned()); + + // Keep track of the BBs visited. + let mut visited_blocks = SmallSet::::default(); + + // Check the individual blocks. + let header = self.header(); + for block in + PreOrderBlockIter::new_with_visited(header.clone(), exit_blocks.iter().cloned()) + { + let has_in_loop_successors = + BlockRef::children(block.clone()).any(|b| self.contains_block(&b)); + if !has_in_loop_successors { + return Err(Report::msg("loop block has no in-loop successors")); + } + + let has_in_loop_predecessors = + BlockRef::inverse_children(block.clone()).any(|b| self.contains_block(&b)); + if !has_in_loop_predecessors { + return Err(Report::msg("loop block has no in-loop predecessors")); + } + + let outside_loop_preds = BlockRef::inverse_children(block.clone()) + .filter(|b| !self.contains_block(b)) + .collect::>(); + + if block == header && outside_loop_preds.is_empty() { + return Err(Report::msg("loop is unreachable")); + } else if !outside_loop_preds.is_empty() { + // A non-header loop shouldn't be reachable from outside the loop, though it is + // permitted if the predecessor is not itself actually reachable. 
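+ // Walk the CFG from the region's entry block to determine whether any of those
+ // outside-loop predecessors are actually reachable; if so, the loop has multiple
+ // entry points.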
+                let entry = block.borrow().parent().unwrap().borrow().entry_block_ref().unwrap();
+                for child_block in PreOrderBlockIter::new(entry) {
+                    if outside_loop_preds.iter().any(|pred| &child_block == pred) {
+                        return Err(Report::msg("loop has multiple entry points"));
+                    }
+                }
+            }
+            if block == header.borrow().parent().unwrap().borrow().entry_block_ref().unwrap() {
+                return Err(Report::msg("loop contains region entry block"));
+            }
+            visited_blocks.insert(block);
+        }
+
+        if visited_blocks.len() != self.num_blocks() {
+            log::trace!("the following blocks are unreachable in the loop:");
+            for block in self.blocks().iter() {
+                if !visited_blocks.contains(block) {
+                    log::trace!("{block}");
+                }
+            }
+            return Err(Report::msg("unreachable block in loop"));
+        }
+
+        // Check the subloops.
+        for subloop in self.nested().iter() {
+            // Each block in each subloop should be contained within this loop.
+            for block in subloop.blocks().iter() {
+                if !self.contains_block(block) {
+                    return Err(Report::msg(
+                        "loop does not contain all the blocks of its subloops",
+                    ));
+                }
+            }
+        }
+
+        // Check the parent loop pointer: this loop must be registered as a subloop of its parent.
+        if let Some(parent) = self.parent_loop() {
+            if !parent.nested().iter().any(|l| core::ptr::addr_eq(&**l, self)) {
+                return Err(Report::msg("loop is not a subloop of its parent"));
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Verify the structure of this loop (a no-op unless debug assertions are enabled).
+    #[cfg(not(debug_assertions))]
+    pub fn verify_loop(&self) -> Result<(), Report> {
+        Ok(())
+    }
+
+    /// Verify loop structure of this loop and all nested loops.
+    pub fn verify_loop_nest(
+        self: Rc<Self>,
+        loops: &mut SmallSet<Rc<Loop>, 2>,
+    ) -> Result<(), Report> {
+        loops.insert(self.clone());
+
+        // Verify this loop.
+        self.verify_loop()?;
+
+        // Verify the subloops.
+        for l in self.nested.borrow().iter().cloned() {
+            l.verify_loop_nest(loops)?;
+        }
+
+        Ok(())
+    }
+
+    /// Print this loop, with all of the blocks inside it.
+    pub fn print(&self, verbose: bool) -> impl fmt::Display + '_ {
+        PrintLoop {
+            loop_info: self,
+            nested: true,
+            verbose,
+        }
+    }
+}
+
+struct PrintLoop<'a> {
+    loop_info: &'a Loop,
+    nested: bool,
+    verbose: bool,
+}
+
+impl<'a> crate::formatter::PrettyPrint for PrintLoop<'a> {
+    fn render(&self) -> crate::formatter::Document {
+        use crate::formatter::*;
+
+        let mut doc = const_text("loop containing: ");
+        let header = self.loop_info.header();
+        for (i, block) in self.loop_info.blocks().iter().enumerate() {
+            if !self.verbose {
+                if i > 0 {
+                    doc += const_text(", ");
+                }
+                doc += display(block.clone());
+            } else {
+                doc += nl();
+            }
+
+            if block == &header {
+                doc += const_text("<header>");
+            } else if self.loop_info.is_loop_latch(block) {
+                doc += const_text("<latch>");
+            } else if self.loop_info.is_loop_exiting(block) {
+                doc += const_text("<exiting>");
+            }
+
+            if self.verbose {
+                doc += text(format!("{:?}", &block.borrow()));
+            }
+        }
+
+        if self.nested {
+            let nested = self.loop_info.nested().iter().fold(Document::Empty, |acc, l| {
+                let printer = PrintLoop {
+                    loop_info: l,
+                    nested: true,
+                    verbose: self.verbose,
+                };
+                acc + nl() + printer.render()
+            });
+            doc + indent(2, nested)
+        } else {
+            doc
+        }
+    }
+}
+
+impl<'a> fmt::Display for PrintLoop<'a> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        use crate::formatter::PrettyPrint;
+        self.pretty_print(f)
+    }
+}
+
+impl fmt::Display for Loop {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}", self.print(false))
+    }
+}
+
+impl fmt::Debug for Loop {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("Loop")
+            .field("parent_loop", &self.parent_loop())
+            .field("nested", &self.nested())
+            .field("blocks", &self.blocks())
+            .field("block_set", &self.block_set())
+            .finish()
+    }
+}