diff --git a/Cargo.lock b/Cargo.lock index fa2b23a23..82b385845 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,9 +158,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "anymap2" @@ -170,9 +170,9 @@ checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" [[package]] name = "arrayref" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" @@ -392,9 +392,9 @@ dependencies = [ [[package]] name = "auth-git2" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51bd0e4592409df8631ca807716dc1e5caafae5d01ce0157c966c71c7e49c3c" +checksum = "3810b5af212b013fe7302b12d86616c6c39a48e18f2e4b812a5a9e5710213791" dependencies = [ "dirs", "git2", @@ -524,6 +524,18 @@ dependencies = [ "typenum", ] +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake3" version = "1.5.4" @@ -537,6 +549,15 @@ dependencies = [ "constant_time_eq", ] +[[package]] +name = "blink-alloc" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e669f146bb8b2327006ed94c69cf78c8ec81c100192564654230a40b4f091d82" +dependencies = [ + "allocator-api2", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -826,9 +847,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.18" +version = "1.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "2d74707dde2ba56f86ae90effb3b43ddd369504387e718014de010cec7959800" dependencies = [ "jobserver", "libc", @@ -1649,6 +1670,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.3.30" @@ -2318,9 +2345,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -3022,9 +3049,9 @@ dependencies = [ [[package]] name = "miden-crypto" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6fad06fc3af260ed3c4235821daa2132813d993f96d446856036ae97e9606dd" +checksum = "8a69f8362ca496a79c88cf8e5b9b349bf9c6ed49fa867d0548e670afc1f3fca5" dependencies = [ "blake3", "cc", @@ -3130,9 +3157,9 @@ dependencies = [ [[package]] name = "miden-processor" -version = "0.10.5" +version = "0.10.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "01e7b212b152b69373e89b069a18cb01742ef2c3f9c328e7b24c44e44f022e52" +checksum = "04a128e20400086c9a985f4e5702e438ba781338fb0bdf9acff16d996c640087" dependencies = [ "miden-air", "miden-core", @@ -3203,6 +3230,7 @@ dependencies = [ "bitcode", "cranelift-entity", "env_logger 0.11.5", + "hashbrown 0.14.5", "intrusive-collections", "inventory", "log", @@ -3218,7 +3246,6 @@ dependencies = [ "paste", "petgraph", "proptest", - "rustc-hash 1.1.0", "serde 1.0.210", "serde_bytes", "smallvec", @@ -3294,6 +3321,7 @@ dependencies = [ "derive_more", "expect-test", "gimli", + "hashbrown 0.14.5", "indexmap 2.5.0", "log", "miden-core", @@ -3301,7 +3329,7 @@ dependencies = [ "midenc-hir", "midenc-hir-type", "midenc-session", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "wasmparser 0.214.0", "wat", @@ -3315,6 +3343,7 @@ dependencies = [ "cranelift-entity", "derive_more", "either", + "hashbrown 0.14.5", "indexmap 2.5.0", "intrusive-collections", "inventory", @@ -3334,7 +3363,7 @@ dependencies = [ "petgraph", "pretty_assertions", "rustc-demangle", - "rustc-hash 1.1.0", + "rustc-hash", "serde 1.0.210", "serde_bytes", "serde_repr", @@ -3350,13 +3379,13 @@ dependencies = [ "anyhow", "cranelift-bforest", "cranelift-entity", + "hashbrown 0.14.5", "intrusive-collections", "inventory", "miden-thiserror", "midenc-hir", "midenc-session", "pretty_assertions", - "rustc-hash 1.1.0", "smallvec", ] @@ -3365,6 +3394,7 @@ name = "midenc-hir-macros" version = "0.0.6" dependencies = [ "Inflector", + "darling", "proc-macro2", "quote", "syn 2.0.77", @@ -3375,7 +3405,12 @@ name = "midenc-hir-symbol" version = "0.0.6" dependencies = [ "Inflector", - "rustc-hash 1.1.0", + "compact_str", + "hashbrown 0.14.5", + "lock_api", + "miden-formatting", + "parking_lot", + "rustc-hash", "serde 1.0.210", "toml 0.8.19", ] @@ -3391,7 +3426,6 @@ dependencies = [ "midenc-hir-analysis", "midenc-session", "pretty_assertions", - "rustc-hash 1.1.0", "smallvec", ] @@ -3399,9 +3433,52 @@ dependencies = [ name = "midenc-hir-type" version = "0.0.6" dependencies = [ + "miden-formatting", + "serde 1.0.210", + "serde_repr", + "smallvec", +] + +[[package]] +name = "midenc-hir2" +version = "0.0.6" +dependencies = [ + "anyhow", + "bitflags 2.6.0", + "bitvec", + "blink-alloc", + "compact_str", + "cranelift-entity", + "derive_more", + "either", + "env_logger 0.11.5", + "hashbrown 0.14.5", + "indexmap 2.5.0", + "intrusive-collections", + "inventory", + "lalrpop", + "lalrpop-util", + "log", + "miden-assembly", + "miden-core", + "miden-thiserror", + "midenc-hir-macros", + "midenc-hir-symbol", + "midenc-hir-type", + "midenc-session", + "num-bigint", + "num-traits 0.2.19", + "paste", + "petgraph", + "pretty_assertions", + "rustc-demangle", + "rustc-hash", "serde 1.0.210", + "serde_bytes", "serde_repr", "smallvec", + "typed-arena", + "unicode-width", ] [[package]] @@ -4135,9 +4212,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "pretty_assertions" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ "diff", "yansi", @@ -4342,7 +4419,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", + "rustc-hash", "rustls", "socket2 0.5.7", "thiserror", @@ -4359,7 +4436,7 @@ dependencies = [ "bytes", 
"rand", "ring", - "rustc-hash 2.0.0", + "rustc-hash", "rustls", "slab", "thiserror", @@ -4389,6 +4466,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -4670,12 +4753,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.0.0" @@ -5439,6 +5516,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.10.1" @@ -5931,9 +6014,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-truncate" @@ -7026,6 +7109,15 @@ dependencies = [ "wasmparser 0.216.0", ] +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xdg-home" version = "1.3.0" @@ -7047,9 +7139,9 @@ dependencies = [ [[package]] name = "yansi" -version = "0.5.1" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "zbus" diff --git a/Cargo.toml b/Cargo.toml index d1c632847..75a276703 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "codegen/*", "frontend-wasm", "hir", + "hir2", "hir-analysis", "hir-macros", "hir-symbol", @@ -27,7 +28,7 @@ exclude = [ [workspace.package] version = "0.0.6" -rust-version = "1.80" +rust-version = "1.81" authors = ["Miden contributors"] description = "An intermediate representation and compiler for Miden Assembly" repository = "https://github.com/0xPolygonMiden/compiler" @@ -52,9 +53,11 @@ clap = { version = "4.1", default-features = false, features = [ ] } cranelift-entity = "0.108" cranelift-bforest = "0.108" +compact_str = { version = "0.8", default-features = false } env_logger = "0.11" either = { version = "1.10", default-features = false } expect-test = "1.4.1" +hashbrown = { version = "0.14", features = ["nightly"] } Inflector = "0.11" intrusive-collections = "0.9" inventory = "0.3" @@ -67,7 +70,7 @@ parking_lot_core = "0.9" petgraph = "0.6" pretty_assertions = "1.0" proptest = "1.4" -rustc-hash = "1.1" +rustc-hash = { version = "2.0", default-features = false } serde = { version = "1.0.208", features = ["serde_derive", "alloc", "rc"] } serde_repr = "0.1.19" serde_bytes = "0.11.15" @@ -85,6 +88,7 @@ derive_more = "0.99" indexmap = "2.2" miden-assembly = { version = "0.10.3" } miden-core = { version = "0.10.3" } 
+miden-formatting = { version = "0.1", default-features = false } miden-parsing = "0.1" miden-processor = { version = "0.10.3" } miden-stdlib = { version = "0.10.3", features = ["with-debug-info"] } diff --git a/codegen/masm/Cargo.toml b/codegen/masm/Cargo.toml index 0a3b9cb96..cfae3541b 100644 --- a/codegen/masm/Cargo.toml +++ b/codegen/masm/Cargo.toml @@ -22,6 +22,7 @@ cranelift-entity.workspace = true intrusive-collections.workspace = true inventory.workspace = true log.workspace = true +hashbrown.workspace = true miden-assembly.workspace = true miden-core.workspace = true miden-processor.workspace = true @@ -32,7 +33,6 @@ midenc-hir-transform.workspace = true midenc-session = { workspace = true, features = ["serde"] } paste.workspace = true petgraph.workspace = true -rustc-hash.workspace = true serde.workspace = true serde_bytes.workspace = true smallvec.workspace = true diff --git a/codegen/masm/src/emulator/breakpoints.rs b/codegen/masm/src/emulator/breakpoints.rs index 19058ac4a..b04aee7bd 100644 --- a/codegen/masm/src/emulator/breakpoints.rs +++ b/codegen/masm/src/emulator/breakpoints.rs @@ -1,7 +1,6 @@ use std::collections::BTreeSet; -use midenc_hir::FunctionIdent; -use rustc_hash::{FxHashMap, FxHashSet}; +use midenc_hir::{FunctionIdent, FxHashMap, FxHashSet}; use super::{Addr, BreakpointEvent, EmulatorEvent, Instruction, InstructionPointer}; use crate::BlockId; @@ -143,7 +142,7 @@ impl BreakpointManager { /// Set the given breakpoint pub fn set(&mut self, bp: Breakpoint) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; match bp { Breakpoint::All => { diff --git a/codegen/masm/src/emulator/mod.rs b/codegen/masm/src/emulator/mod.rs index 526ccb7a1..892abdfe9 100644 --- a/codegen/masm/src/emulator/mod.rs +++ b/codegen/masm/src/emulator/mod.rs @@ -8,8 +8,10 @@ use std::{cell::RefCell, cmp, rc::Rc, sync::Arc}; use memory::Memory; use miden_assembly::{ast::ProcedureName, LibraryNamespace}; -use midenc_hir::{assert_matches, Felt, FieldElement, FunctionIdent, Ident, OperandStack, Stack}; -use rustc_hash::{FxHashMap, FxHashSet}; +use midenc_hir::{ + assert_matches, Felt, FieldElement, FunctionIdent, FxHashMap, FxHashSet, Ident, OperandStack, + Stack, +}; use self::functions::{Activation, Stub}; pub use self::{ @@ -313,7 +315,7 @@ impl Emulator { /// /// An error is returned if a module with the same name is already loaded. pub fn load_module(&mut self, module: Arc<Module>) -> Result<(), EmulationError> { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; assert_matches!( self.status, diff --git a/docs/design/ir.md b/docs/design/ir.md new file mode 100644 index 000000000..6740b8109 --- /dev/null +++ b/docs/design/ir.md @@ -0,0 +1,839 @@ +# High-Level Intermediate Representation (HIR) + +This document describes the concepts, usage, and overall structure of the intermediate +representation used by `midenc`. + +## Introduction + +TODO + +## Concepts + +### Components + +A _component_ is a named entity that encapsulates one or more [_interfaces_](#interfaces), and comes +in two forms: + +* An _executable_ component, which has a statically-defined entrypoint, a function which initializes +and executes a program encapsulated by the component. +* A _library_ component, which exports one or more interfaces, and can be used as a dependency of +other components. + +We also commonly refer to executable components as _programs_, and library components as _libraries_, +which also correspond to the equivalent concepts in Miden Assembly.
However, components are a more +general abstraction over programs and libraries, where the distinction is mostly one of intended +use and/or convention. + +Components can have zero or more dependencies, which are expressed in the form of interfaces that +they require instances of at runtime. Thus any component that provides the interface can be used +to satisfy the dependency. + +A component _instance_ refers to a component that has had all of its dependencies resolved +concretely, and is thus fully-defined. + +A component _definition_ specifies four things: + +1. The name of the component +2. The interfaces it imports +3. The interfaces it exports +4. The [_modules_](#modules) which implement the exported interfaces concretely + +### Interfaces + +An _interface_ is a named entity that describes one or more [_functions_](#functions) that it +exports. Conceptually, an _interface_ loosely corresponds to the notion of a module, in that both +a module and an interface define a namespace, in which one or more functions are exported. + +However, an _interface_, unlike a module, is abstract, and does not have any internal structure. +It is more like a trait, in that it abstractly represents a set of named behaviors implemented by +some [component](#components). + +### Modules + +A module is primarily two things: + +1. A container for one or more functions belonging to a common namespace +2. A concrete implementation of one or more [interfaces](#interfaces) + +Functions within a module may be exported, and so a module always has an implicit interface +consisting of all of its exported functions. Functions which are _not_ exported, are only visible +within the module, and do not form a part of the implicit interface of a module. + +Module names are used to name the implicit interface of the module. Thus, within a component, both +imported interfaces, and the implicit interfaces of all modules it defines, can be used to resolve +function references in those modules. + +A module defines a symbol table, whose entries are the functions defined in that module. + +### Functions + +A function is a special type of [_operation_](#operations). It is special in the following ways: + +* A function has a [_symbol_](#symbols-and-symbol-tables), and thus declares an entry in the nearest +containing [_symbol table_](#symbols-and-symbol-tables). +* A function is _isolated from above_, i.e. the contents of the function cannot escape the function, +nor reference things outside the function, except via symbol table references. Thus entities such as +[_values_](#values) and [_blocks_](#blocks) are function-scoped, if not more narrowly scoped, in the +case of operations with nested [_regions_](#regions). + +A function has an arbitrary set of parameters and results, corresponding to its type signature. A +function also has the notion of an _application binary interface_ (ABI), which drives how code is +generated for both caller and callee. For example, a function may have a specific calling convention +as a whole, and specific parameters/results may have type-specific semantics declared, such as +whether to zero- or sign-extend the value if the input is of a smaller range. + +A function always consists of a single [_region_](#regions), called the _body_, with at least one +[_block_](#blocks), which is called the _entry block_. The block parameters of the entry block +always correspond to the function parameters, i.e. the arity and type of the block parameters must +match the function signature. 
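+
+For example, given a function `add` with signature `(u32, u32) -> u32`, the entry block of its body
+region declares exactly two `u32` block parameters, which become the SSA values for the two
+arguments. A minimal sketch of what that might look like with a builder-style API follows (helper
+methods such as `param` are illustrative here, not the exact HIR API):
+
+```rust
+use midenc_hir::*;
+
+// Sketch only: constructs `fn add(a: u32, b: u32) -> u32` in a module.
+fn build_add(module: &mut Module) {
+    let mut function = module.define("add", FunctionType::new([Type::U32, Type::U32], [Type::U32]));
+    // The entry block's parameter list mirrors the function signature,
+    // materializing `a` and `b` as SSA values.
+    let entry = function.body().entry();
+    let a = entry.param(0);
+    let b = entry.param(1);
+    // ... append operations to the entry block, ending with a `ret` of the result
+}
+```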
+ +A function's body is subject to an additional constraint: all blocks in the region must end with +one of a restricted set of _terminator_ operations. These are: any branch operation, which +transfers control between blocks of the region; the `unreachable` operation, which aborts the +program if executed; or the `return` operation, which must return values matching the arity and +types declared in the function signature. + +### Global Variables + +A global variable is a second special type of [_operation_](#operations), after +[_functions_](#functions): + +* A global variable has a [_symbol_](#symbols-and-symbol-tables), and declares an entry in the +nearest containing [_symbol table_](#symbols-and-symbol-tables). +* The _initializer_ of a global variable is, like function bodies, _isolated from above_. + +A global variable may have an _initializer_, a single-region/single-block body which is implicitly +executed to initialize the value of the global variable. The initializer must be statically +evaluatable, i.e. a "constant" expression. In most cases, this will simply be a constant value, but +some limited forms of constant expressions are permitted. + +### Symbols and Symbol Tables + +A _symbol_ is simply a named entity, e.g. a function `foo` is a symbol whose value is `foo`. +On their own, symbols aren't particularly useful. This is where the concept of a _symbol table_ +becomes important. + +A _symbol table_ is a collection of uniqued symbols belonging to the same namespace, i.e. every +symbol has a single entry in the symbol table, regardless of entity type. Thus, it is not permitted +to have both a function and a global variable with the same name in the same symbol table. If such +a thing needed to be allowed, perhaps because the namespaces for functions and global variables are +separate, then you would use a per-entity symbol table. + +For our purposes, a _module_ defines a symbol table, and both functions and global variables share +that table. We do not currently use symbol tables for anything else. + +### Operations + +An _operation_ is the most important entity in HIR, and the most abstract. In the _Regionalized +Value State Dependence Graph_ paper, the entire representation consists of various types of +operations. In HIR, we do not go quite that abstract; however, we do take a fair amount of +inspiration from that paper, as well as from MLIR. + +Operations consist of the following pieces: + +* Zero or more [_regions_](#regions) and their constituent [_blocks_](#blocks) +* Zero or more [_operands_](#operands), i.e. arguments or inputs +* Zero or more _results_, or outputs, one of the two ways that [_values_](#values) can be introduced +* Zero or more [_successors_](#successors), in the case of operations which transfer control to +another block in the same region. +* Zero or more [_attributes_](#attributes), the semantics of which depend on the operation. +* Zero or more [_traits_](#traits), implemented by the operation. + +An operation always belongs to a _block_ when in use. + +As you can see, this is a highly flexible concept. It is capable of representing modules and +functions, as well as primitive instructions. It can represent both structured and unstructured +control-flow. There is very little in terms of an IR that _can't_ be represented using operations.
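+
+To make that anatomy concrete, the following sketch shows the pieces listed above as a Rust struct
+(the placeholder types let the sketch stand alone; HIR's actual `Operation` type is richer):
+
+```rust
+// Placeholder types for the sketch; the real definitions live in HIR.
+struct AttributeSet;
+struct OpOperand;
+struct OpResult;
+struct Successor;
+struct Region;
+
+/// An illustrative layout of the data every operation carries.
+struct Operation {
+    attrs: AttributeSet,        // key-value metadata, semantics are op-specific
+    operands: Vec<OpOperand>,   // inputs (values or immediates)
+    results: Vec<OpResult>,     // outputs, one way values are introduced
+    successors: Vec<Successor>, // blocks a terminator may transfer control to
+    regions: Vec<Region>,       // nested CFGs, e.g. the branches of an `if`
+}
+```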
+ +However, in our case, we use operations for five specific concepts: + +* Functions (the first of two special ops) +* Global Variables (the second of two special ops) +* Structured Control Flow (if/then, do/while and for loops) +* Unstructured Control Flow (br, cond_br, switch, ret) +* Primitive Instructions (i.e. things which correspond to the target ISA, e.g. `add`, `call`, etc.) + +For the most part, the fact that functions and global variables are implemented using operations +is not particularly important. Instead, most operations you will interact with are of the other +three varieties. While we've broken them up into three categories, they aren't actually all that +different from one another. The primary difference is that the unstructured control-flow ops are +valid _terminators_ for blocks in multi-block regions, while the structured control-flow ops are +not, and only a few special cases of primitive ops are also valid terminators (namely the `ret` +and `unreachable` ops). For most primitive and structured control-flow ops, their behavior appears +very similar: they take some operands, perform some action, and possibly return some results. + +### Regions + +A _region_ encapsulates a control-flow graph (CFG) of one or more [_basic blocks_](#blocks). In HIR, +the contents of a region are in _static single-assignment_ (SSA) form, meaning that values may only +be defined once, definitions must [_dominate_](#dominance-relation) uses, and operations in the CFG +described by the region are executed one-by-one, from the entry block of the region, until control +exits the region (e.g. via `ret` or some other terminator instruction). + +The order of operations in the region closely corresponds to their scheduling order, though the +code generator may reschedule operations when it is safe - and more efficient - to do so. + +Operations in a region may introduce nested regions. For example, the body of a function consists +of a single region, and it might contain an `if` operation that defines two nested regions, one for +the true branch, and one for the false branch. Nested regions may access any [_values_](#values) in +an ancestor region, so long as those values dominate the operation that introduced the nested region. +The exception to this is operations which are _isolated from above_. The regions of such an +operation are not permitted to reference anything defined in an outer scope, except via +[_symbols_](#symbols-and-symbol-tables). For example, [_functions_](#functions) are operations +which are isolated from above. + +The purpose of regions is to allow for hierarchical/structured control flow operations. Without +them, representing structured control flow in the IR is difficult and error-prone, due to the +semantics of SSA CFGs, particularly with regard to analyses like dominance and loops. Regions are +also an important part of what makes [_operations_](#operations) such a powerful abstraction, as +they provide a way to generically represent the concept of something like a function body, without +needing to special-case it. + +A region must always consist of at least one block (the entry block), but not all regions allow +multiple blocks. When multiple blocks are present, it implies the presence of unstructured control +flow, as the only way to transfer control between blocks is by using unstructured control flow +operations, such as `br`, `cond_br`, or `switch`.
Structured control flow operations such as `if` +introduce nested regions consisting of only a single block, as all control flow within a structured +control flow op must itself be structured. The specific rules for a region depend on the semantics +of the containing operation. + +### Blocks + +A _block_, or _basic block_, is a sequence of one or more [_operations_](#operations) in which there +is no control flow, except via the block _terminator_, i.e. the last operation in the block, which is +responsible for transferring control to another block, exiting the current region (e.g. returning +from a function body), or terminating program execution in some way (e.g. `unreachable`). + +A block may declare _block parameters_, the only other way to introduce [_values_](#values) into +the IR, aside from operation results. Predecessors of a block must ensure that they provide +arguments for all block parameters when transferring control to the block. + +Blocks always belong to a [_region_](#regions). The first block in a region is called the _entry +block_, and is special in that its block parameters (if any) correspond to whatever arguments +the region accepts. For example, the body of a function is a region, and the entry block in that +region must have a parameter list that exactly matches the arity and type of the parameters +declared in the function signature. In this way, the function parameters are materialized as +SSA values in the IR. + +### Values + +A _value_ represents a term in a program, a temporary created to store data as it flows through the +program. In HIR, which is in SSA form, values are immutable - once created they cannot be changed +nor destroyed. This property of values allows them to be reused, rather than recomputed, when the +operation that produced them has no side-effects, i.e. invoking the operation with the same +inputs must produce the same outputs. This forms the basis of one of the ways in which SSA IRs can +optimize programs. + +!!! note + +> One way in which you can form an intuition for values in an SSA IR is by thinking of them as +> registers in a virtual machine with no limit to the number of machine registers. This corresponds +> well to the fact that most values in an IR are of a type which corresponds to something that can +> fit in a typical machine register (e.g. 32-bit or 64-bit values, sometimes larger). +> +> Values which cannot be held in actual machine registers are usually managed in the form of heap- +> or stack-allocated memory, with various operations used to allocate, copy/move, or extract smaller +> values from them. While not strictly required by the SSA representation, this is almost always +> effectively enforced by the instruction set, which will only consist of instructions whose +> operands and results are of a type that can be held in machine registers. + +Value _definitions_ (aka "defs") can be introduced in two ways: + +1. Block parameters. Most notably, the entry block for function bodies materializes the function +parameters as values via block parameters. Block parameters are also used at places in the CFG +where two definitions for a single value are joined together. For example, if the value assigned to +a variable in the source language is assigned conditionally, then in the IR, there will be a block +with a parameter corresponding to the value of that variable after it is assigned. All uses after +that point refer to that block parameter, rather than the value from a specific branch.
+Similarly, loop-carried variables, such as an iteration count, are typically manifested as block +parameters of the block corresponding to the loop header. +2. Operation results. The most common way in which values are introduced. + +Values have _uses_ corresponding to operands or successor arguments (special operands which are used +to satisfy successor block parameters). As a result, values also have _users_, corresponding to the +specific operation and operand forming a _use_. + +All _uses_ of a value must be [_dominated_](#dominance-relation) by its _definition_. The IR is +invalid if this rule is ever violated. + +### Operands + +An _operand_ is a [_value_](#values) or [_immediate_](#immediates) used as an argument to an +operation. + +Beyond the semantics of any given operation, operand ordering is only significant insofar as it +is used as the order in which those items are expected to appear on the operand stack once lowered +to Miden Assembly. The earlier an operand appears in the list of operands for an operation, the +closer to the top of the operand stack it will appear. + +Similarly, the ordering of operation results also correlates to the operand stack order after +lowering. Specifically, the earlier a result appears in the result list, the closer to the top of +the operand stack it will appear after the operation executes. + +### Immediates + +An _immediate_ is a literal value, typically of integral type, used as an operand. Not all +operations support immediates, but those that do will typically use them to attempt to perform +optimizations only possible when there is static information available about the operands. For +example, multiplying any number by 2 will always produce an even number, so a sequence such as +`mul.2 is_odd` can be folded to `false` at compile-time, allowing further optimizations to occur. + +Immediates are separate from _constants_, in that immediates _are_ constants, but specifically +constants which are valid operand values. + +### Attributes + +An _attribute_ is (typically optional) metadata attached to an IR entity. In HIR, attributes can be +attached to functions, global variables, and operations. + +Attributes are stored as a set of arbitrary key-value data, where values can be one of four types: + +* `unit`. Attributes of this value type are usually "marker" attributes, i.e. they convey their +information simply by being present. +* `bool`. Attributes of this value type are somewhat similar to those of `unit` type, but by +carrying a boolean value, they can be used to convey both positive and negative meaning. For +example, you might want to support explicit inlining with `#[inline(true)]`, and prevent any form +of inlining with `#[inline(false)]`. Here, `unit` would be insufficient to describe both options +under a single attribute. +* `int`. Attributes of this value type are used to convey numeric metadata. For example, inliner +thresholds, or some other kind of per-operation limits. +* `string`. Attributes of this value type are used to convey arbitrary values. Most commonly +you might see this type with things that are enum-like, e.g. `#[cc(fast)]` to specify a particular +calling convention for a function. + +Some attributes are "first-class", in that they are defined as part of an operation.
For example, +the calling convention of a function is an intrinsic attribute, and feels like a native part of the +`Function` API - rather than having to look up the attribute, and cast the value to a more natural +Rust type, you can simply call `function.calling_convention()`. + +Attributes are not heavily used at this time, but are expected to serve more purposes in the future +as we increase the amount of information frontends need to convey to the compiler backend. + +### Traits + +A _trait_ defines some behavior that can be implemented by an operation. This allows operations +to be operated over generically in an analysis or rewrite, rather than having to handle every +possible concrete operation type. This makes passes less fragile to changes in the IR in general, +and allows the IR to be extended without having to update every single place where operations are +handled. + +An operation can be cast to a specific trait that it implements, and trait instances can be +downcast to the concrete operation type if known. + +There are a handful of built-in traits used to convey certain semantic information about the +operations they are attached to, and in particular, to validate those operations. For example: + +* `IsolatedFromAbove`, a marker trait that indicates that regions of the operation it is attached to +cannot reference items from any parents, except via [_symbols_](#symbols-and-symbol-tables). +* `Terminator`, a marker trait for operations which are valid block terminators. +* `ReturnLike`, a trait that describes behavior shared by instructions that exit from an enclosing +region, "returning" the results of executing that region. The most notable of these is `ret`, but +`yield` used by the structured control flow ops is also return-like in nature. +* `BranchOp`, a trait that describes behavior shared by all unstructured control-flow branch +instructions, e.g. `br`, `cond_br`, and `switch`. +* `ConstantLike`, a marker trait for operations that produce a constant value. +* `Commutative`, a marker trait for binary operations that exhibit commutativity, i.e. the order of +the operands can be swapped without changing semantics. + +There are others as well, responsible for aiding in type checking, decorating operations with the +types of side effects they do (or do not) exhibit, and more. + +### Successors and Predecessors + +The concept of _predecessor_ and _successor_ corresponds to a parent/child relationship in a +control-flow graph (CFG), where edges in the graph are directed, and describe the order in which +control flows through the program. If a node $A$ transfers control to a node $B$ after it is +finished executing, then $A$ is a _predecessor_ of $B$, and $B$ is a _successor_ of $A$. + +Successors and predecessors can be looked at from two similar, but slightly different, perspectives: + +1. In terms of operations. In an SSA CFG, operations in a basic block are executed in order, and +thus the successor of an operation in the block is the next operation to be executed in that block, +with the predecessor being the inverse of that relationship. At basic block boundaries, the +successor(s) of the _terminator_ operation are the set of operations to which control can be +transferred. Likewise, the predecessor(s) of the first operation in a block are the set of +terminators which can transfer control to the containing block. This is the most precise view, but +is not quite as intuitive as the alternative. +2. In terms of blocks.
The successor(s) of a basic block are the set of blocks to which control may +be transferred when exiting the block. Likewise, the predecessor(s) of a block are the set of blocks +which can transfer control to it. We are most frequently dealing with the concept of successors and +predecessors in terms of blocks, as it allows us to focus on the interesting parts of the CFG. For +example, the dominator tree and loop analyses are constructed in terms of a block-oriented CFG, +since we can trivially derive dominance and loop information for individual ops from their +containing blocks. + +Typically, you will see successors as a pair of `(block_id, &[value_id])`, i.e. the block to which +control is transferred, and the set of values being passed as block arguments. On the other hand, +predecessors are most often a pair of `(block_id, terminator_op_id)`, i.e. the block from which +control originates, and the specific operation responsible. + +### Dominance Relation + +In an SSA IR, the concept of _dominance_ is of critical importance. Dominance is a property of the +relationship between two or more entities and their respective program points. For example, between +the use of a value as an operand for an operation, and the definition of that value; or between a +basic block and its successors. The dominance property is anti-symmetric, i.e. if $A$ dominates $B$, +then $B$ cannot dominate $A$, unless $A = B$. Put simply: + +> Given a control-flow graph $G$, and a node $A \in G$, then $\forall B \in G$, $A$ _dom_ $B$ if all +> paths to $B$ from the root of $G$ pass through $A$. +> +> Furthermore, $A$ _strictly_ dominates $B$ if $A \neq B$. + +An example of why dominance is an important property of a program can be seen when considering the +meaning of a program like so (written in pseudocode): + +``` +if (...) { + var a = 1; +} + +foo(a) +``` + +Here, the definition of `a` does not dominate its usage in the call to `foo`. If the conditional +branch is ever false, `a` is never defined, nor initialized - so what should happen when we reach +the call to `foo`? + +In practice, of course, such a program is rarely possible to express in a high-level language; +however, in a low-level CFG, it is possible to reference values which are defined somewhere in the +graph, but in a way that is not _legal_ according to the "definitions must dominate uses" +rule of SSA CFGs. The dominance property is what we use to validate the correctness of the IR, as +well as to evaluate the range of valid transformations that can be applied to the IR. For example, +we might determine that it is valid to move an expression into a specific `if/then` branch, because +it is only used in that branch - the dominance property is how we determine that there are paths +through the program in which the result of the expression is unused, as well as what program points +represent the nearest point to one of its uses that still dominates _all_ of the uses. + +There is another useful notion of dominance, called _post-dominance_, which can be described much +like the regular notion of dominance, except in terms of paths to the exit of the CFG, rather than +paths from the entry: + +> Given a control-flow graph $G$, and a node $A \in G$, then $\forall B \in G$, $A$ _pdom_ $B$ if all +> paths through $B$ that exit the CFG must flow through $A$ first. +> +> Furthermore, $A$ _strictly_ post-dominates $B$ if $A \neq B$.
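+
+Both relations can be computed with the same textbook iterative fixpoint. The sketch below, which
+is standalone Rust rather than the compiler's actual analysis, computes dominator sets for the
+pseudocode example above; running it over the reversed CFG would compute post-dominators instead:
+
+```rust
+use std::collections::BTreeSet;
+
+/// dom(b) = {b} union (intersection of dom(p) over all predecessors p of b).
+fn dominators(preds: &[Vec<usize>], entry: usize) -> Vec<BTreeSet<usize>> {
+    let n = preds.len();
+    let all: BTreeSet<usize> = (0..n).collect();
+    let mut dom = vec![all; n];
+    dom[entry] = BTreeSet::from([entry]);
+    let mut changed = true;
+    while changed {
+        changed = false;
+        for b in (0..n).filter(|&b| b != entry) {
+            // Intersect the dominator sets of b's predecessors, then add b itself.
+            let mut new = preds[b]
+                .iter()
+                .map(|&p| dom[p].clone())
+                .reduce(|acc, d| acc.intersection(&d).copied().collect())
+                .unwrap_or_default();
+            new.insert(b);
+            if new != dom[b] {
+                dom[b] = new;
+                changed = true;
+            }
+        }
+    }
+    dom
+}
+
+fn main() {
+    // Block 0 is the entry, block 1 contains `var a = 1`, block 2 calls `foo(a)`.
+    let preds = vec![vec![], vec![0], vec![0, 1]];
+    let dom = dominators(&preds, 0);
+    // Block 1 is not in dom(2): the definition of `a` does not dominate its use.
+    assert!(!dom[2].contains(&1));
+}
+```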
+ +The notion of post-dominance is important in determining the applicability of certain transformations, +in particular with loops. + + +## Structure + +The hierarchy of HIR looks like so: + + Component <- Imports <- Interface + | + v + Exports + | + v + Interface + | + v + Module --------- + | | + v v + Function/Op Global Variable + | + v + -- Region + | | + | v + | Block -> Value <-- + | | | | + | | v | + | | Operand | + | v | | + > Operation <- | + | | + v | + Result ----------- + +In short: + +* A _component_ imports dependencies in the form of _interfaces_ +* A _component_ exports at least one _interface_. +* An _interface_ is concretely implemented with a _module_. +* A _module_ contains _function_ and _global variable_ definitions, and imports them from the set +of interfaces available to the component. +* A _function_, as a type of _operation_, consists of a body _region_. +* A _region_ consists of at least one _block_, that may define _values_ in the form of block +parameters. +* A _block_ consists of _operations_, whose results introduce new _values_, and whose operands are +_values_ introduced either as block parameters or the results of previous operations. +* An _operation_ may contain nested _regions_, with associated blocks and operations. + +### Passes + +A _pass_ is effectively a fancy function with a specific signature and some constraints on its +semantics. There are three primary types of passes: [_analysis_](#analyses), +[_rewrite_](#rewrites), and [_conversion_](#conversions). These three pass types have different +signatures and semantics, but play symbiotic roles in the compilation pipeline. + +There are two abstractions over all passes, used and extended by the three types described above: + +* The `Pass` trait, which provides an abstraction suitable for describing any type of compiler pass +used by `midenc`. It primarily exists to allow writing pass-generic helpers. +* The `PassInfo` trait, which exists to provide a common interface for pass metadata, such as the +name, description, and command-line flag prefix. All passes must implement this trait. + +#### Analyses + +Analysis of the IR is expressed in the form of a specialized pass, an `Analysis`, and an +`AnalysisManager`, which is responsible for computing analyses on demand, caching them, and +invalidating the relevant parts of the cache when the IR changes. Analyses are expressed in terms +of a specific entity, such as `Function`, and are cached based on the unique identity of that +entity. + +An _analysis_ is responsible for computing some fact about the given IR entity it is given. Facts +typically include things such as: computing dominance, identifying loops and their various component +parts, reachability, liveness, identifying unused (i.e. dead) code, and much more. + +To do this, analyses are given an immutable reference to the IR in question; a reference to the +current `AnalysisManager`, so that the results of other analyses can be consulted; and a reference +to the current compilation `Session` for access to configuration relevant to the analysis. + +Analysis results are computed as an instance of `Self`. This provides structure to the analysis +results, and provides a place to implement helpful functionality for querying the results. + +A well-written analysis should be based purely off the inputs given to `Analysis::analyze`, and +ideally be based on some formalism, so that properties of the analysis can be verified. 
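+
+As an illustration of that shape, the following standalone sketch restates the description above as
+a trait, along with a toy analysis implementing it (the placeholder types stand in for the real
+midenc types, and the real `Analysis` trait differs in its details):
+
+```rust
+// Placeholders standing in for the real midenc types.
+pub struct AnalysisManager;
+pub struct Session;
+pub struct Function {
+    pub block_count: usize,
+}
+
+/// An analysis is computed from an immutable reference to an entity, with
+/// access to other analyses and the session, producing `Self` as the result.
+pub trait Analysis: Sized {
+    type Entity;
+
+    fn analyze(
+        entity: &Self::Entity,
+        analyses: &AnalysisManager,
+        session: &Session,
+    ) -> Result<Self, String>;
+}
+
+/// A toy analysis that records how many blocks a function has.
+pub struct BlockCount(pub usize);
+
+impl Analysis for BlockCount {
+    type Entity = Function;
+
+    fn analyze(f: &Function, _: &AnalysisManager, _: &Session) -> Result<Self, String> {
+        Ok(BlockCount(f.block_count))
+    }
+}
+```
+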
Most of +the analyses in HIR today are based on the formalisms underpinning _dataflow analysis_, e.g. +semi-join lattices. + +#### Rewrites + +_Rewrites_ are a type of pass which mutate the IR entity to which they are applied. They can be +chained (via the `chain` method on the `RewritePass` trait) to form rewrite pipelines, called a +`RewriteSet`. A `RewriteSet` manages executing each rewrite in the set, and coordinating with the +`AnalysisManager` between rewrites. + +Rewrites are given a mutable reference to the IR they should apply to; the current `AnalysisManager` +so that analyses can be consulted (or computed) to facilitate the rewrite; and the current +compilation `Session`, for configuration. + +Rewrites must leave the IR in a valid state. Rewrites are also responsible for indicating, via the +`AnalysisManager`, which analyses can be preserved after applying the rewrite. A rewrite that makes +no changes should mark all analyses preserved, to avoid recomputing the analysis results the next +time they are requested. If an analysis is not explicitly preserved by a rewrite, it will be +invalidated by the containing `RewriteSet`. + +Rewrite passes written for a `Function` can be adapted for application to a `Module` using the +`ModuleRewriteAdapter`. This makes writing rewrites for use in the main compiler pipeline as simple +as defining the rewrite for a `Function`, and then using the `ModuleRewriteAdapter` to add it to the +pipeline. + +A well-written rewrite pass should only use data available in the IR itself, or analyses in the +provided `AnalysisManager`, to drive application of the rewrite. Additionally, rewrites should +focus on a single transformation, and rely on chaining rewrites to orchestrate more complex +transformations composed of multiple stages. Lastly, the rewrite should ideally have a logical +proof of its safety, or failing that, a basis in some formalism that can be suitably analyzed and/or +tested. If a rewrite cannot be described in such a way, it will be very difficult to provide +guarantees about the code produced by the transformation. This makes it hard to be confident +in the rewrite (and by extension, the compiler), and impossible to verify. + +#### Conversions + +_Conversions_ are a type of pass which convert between intermediate representations, or more +abstractly, between _dialects_. + +The concept of a dialect is focused primarily on semantics, and less so on concrete representations. +For example, source and target dialects might share the same underlying IR, with the target dialect +having a more restricted set of _legal_ operations, possibly with stricter semantics. + +This brings us to the concept of _legalization_, i.e. converting all _illegal_ operations in the IR +into _legal_ equivalents. Each dialect defines the set of operations which it considers legal. The +concept of legality is mostly important when the same underlying IR is used across multiple dialects, +as is the case with HIR. + +There are two types of conversion passes in HIR: _dialect conversion_ and _translation_: + +* _Dialect conversion_ is used at various points in the compilation pipeline to simplify the IR +for later passes. For example, Miden Assembly has no way to represent multi-way branches, such +as implemented by `switch`. At a certain point, we switch to a dialect where `switch` is illegal, +so that further passes can be written as if `switch` doesn't exist.
Yet another dialect later in +the pipeline makes all unstructured control flow illegal, in preparation for translation to +Miden Assembly, which has no unstructured control flow operators. +* _Translation_ refers to a conversion from one IR to a completely different one. Currently, the +only translation we have is the one responsible for translating from HIR to Miden Assembly. In +the future, we hope to also implement frontends as translations _to_ HIR, but that is not currently +the case. Translations are currently implemented as simple passes that take in some IR as input, +and produce _whatever_ as output. + +Dialect conversions are implemented in the form of generic _conversion infrastructure_. A dialect +conversion is described as a set of _conversion patterns_, which define what to do when a specific +operation is seen in the input IR; and the set of operations which are legal in the target dialect. +The conversion driver is responsible for visiting the IR, repeatedly applying any matched +conversion patterns until a fixpoint is reached, i.e. no more patterns are matched, or a conversion +pattern fails to apply successfully. If any illegal operations are found which do not have +corresponding conversion patterns, then a legalization error is raised, and the conversion overall +fails. + +## Usage + +Let's get into the gritty details of how the compiler works, particularly in relation to HIR. + +The entry point for compilation, generally speaking, is the _driver_. The driver is what is +responsible for collecting compiler inputs, including configuration, from the user, instantiating a +`Session`, and then invoking the compiler frontend with those items. + +The first stage of compilation is _parsing_, in which each input is converted to HIR by an +appropriate frontend. For example, Wasm modules are loaded and translated to HIR using the Wasm +frontend. The purpose of this stage is to get all inputs into HIR form for subsequent stages. The +only exceptions to this are MASM sources, which are assembled to MAST directly, and then set aside +until later in the pipeline when we go to assemble the final artifact. + +The second stage of compilation is _semantic analysis_. This is the point where we validate the +HIR we have so far, and ensure that there are no obvious issues that will cause compilation to +fail unexpectedly later on. In some cases, this stage is skipped, as we have already validated the +IR in the frontend. + +The third stage of compilation is _linking_. This is where we gather together all of the inputs, +as well as any compiler options that tell us what libraries to link against, and where to search +for them, and then ensure that there are no missing inputs, undefined symbols, or incompatible type +signatures. The output of this stage is a well-formed [_component_](#components). + +The fourth stage of compilation is _rewriting_, in which all of the rewrite passes we wish to +apply to the IR are applied. You could also think of this stage as a combination of optimization +and preparation for codegen. + +The fifth stage of compilation is _codegen_, where HIR is translated to Miden Assembly. + +The final stage of compilation is _assembly_, where the Miden Assembly we produced, along with +any other Miden Assembly libraries, are assembled to MAST, and then packaged in the Miden package +format. + +The `Session` object contains all of the important compiler configuration and exists for the +duration of a compiler invocation.
In addition to this, there is a `Context` object, which is used +to allocate IR entities contained within a single `Module`. The context of each `Module` is further +subdivided by `Function`, in the form of `FunctionContext`, from which all function-local IR +entities are allocated. This ensures that each `Function` gets its own set of values, blocks, and +regions, while sharing `Module`-wide entities, such as constants and symbols. + +Operations are allocated from the nearest `Context`, and use that context to allocate and access +IR entities used by the operation. + +The `Context` object is pervasive: it is needed just about anywhere that IR entities are used, to +allow accessing data associated with those entities. Most entity references are integer ids, which +uniquely identify the entity, but provide no access to it without going through the `Context`. + +This is a bit awkward, but is easier to work with in Rust. The alternative is to rely on +dynamically-checked interior mutability inside reference-counted allocations, e.g. `Rc<RefCell<T>>`. +Not only is this similarly pervasive in terms of how APIs are structured, but it loses some of +the performance benefits of allocating IR objects close together on the heap. + +The following is some example code in Rust, demonstrating how it might look to create a component +in HIR and work with it. + +```rust +use midenc_hir::*; + +/// Defining Ops + +/// `Add` is a binary integral arithmetic operator, i.e. `+`. +/// +/// It consists of two operands, which must be of the same type, and produces a result that +/// is also of that type. +/// +/// It supports various types of overflow semantics, see `Overflow`. +pub struct Add { + op: Operation, +} +impl Add { + pub fn build(overflow: Overflow, lhs: Operand, rhs: Operand, span: SourceSpan, context: &mut Context) -> OpId { + let ctrl_ty = context.value_type(lhs); + let mut builder = OpBuilder::<Self>::new(context); + // Set the source span for this op + builder.with_span(span); + // We specify the concrete operand values + builder.with_operands([lhs, rhs]); + // We specify results in terms of types, and the builder will materialize values + // for the instruction results based on those types. + builder.with_results([ctrl_ty]); + // We must also specify operation attributes, e.g. overflow behavior + // + // Attribute values must implement `AttributeValue`, a trait which represents the encoding + // and decoding of a type as an attribute value. + builder.with_attribute("overflow", overflow); + // Add op traits that this type implements + builder.with_trait::<dyn Commutative>(); + builder.with_trait::<dyn SameTypeOperands>(); + builder.with_trait::<dyn Canonicalize>(); + // Instantiate the op + // + // NOTE: In order to use `OpBuilder`, `Self` must implement `Default` if it has any other + // fields than the underlying `Operation`. This is because the `OpBuilder` will construct + // a default instance of the type when allocating it, and according to Rust rules, any + // subsequent reference to the type requires that all fields were properly initialized. + builder.build() + } + + /// The `OpOperand` type abstracts over information about a given operand, such as whether it + /// is a value or immediate, and its type. It also contains the link for the operand in the + /// original `Value`'s use list. + pub fn lhs(&self) -> &OpOperand { + &self.op.operands[0] + } + + pub fn rhs(&self) -> &OpOperand { + &self.op.operands[1] + } + + /// The `OpResult` type abstracts over information about a given result, such as whether it + /// is a value or constant, and its type.
+ pub fn result(&self) -> &OpResult { + &self.op.results[0] + } + + /// Attributes of the op can be reified from the `AttributeSet` of the underlying `Operation`, + /// using `get_attribute`, which will use the `AttributeValue` implementation for the type to + /// reify it from the raw attribute data. + pub fn overflow(&self) -> Overflow { + self.op.get_attribute("overflow").unwrap_or_default() + } +} +/// All ops must implement this trait, but most implementations will look like this. +impl Op for Add { + type Id = OpId; + + fn id(&self) -> Self::Id { self.op.key } + fn name(&self) -> &'static str { "add" } + fn as_operation(&self) -> &Operation { &self.op } + fn as_operation_mut(&mut self) -> &mut Operation { &mut self.op } +} +/// Marker trait used to indicate an op exhibits the commutativity property +impl Commutative for Add {} +/// Marker trait used to indicate that operands of this op should all be the same type +impl SameTypeOperands for Add {} +/// Canonicalization is optional, but encouraged when an operation has a canonical form. +/// +/// This is applied after transformations which introduce or modify an `Add` op, and ensures +/// that it is in canonical form. +/// +/// Canonicalizations ensure that pattern-based rewrites can be expressed in terms of the +/// canonical form, rather than needing to account for all possible variations. +impl Canonicalize for Add { + fn canonicalize(&mut self) { + // If `add` is given an immediate operand, always place it on the right-hand side + if self.lhs().is_immediate() { + if self.rhs().is_immediate() { + return; + } + self.op.operands.swap(0, 1); + } + } +} +/// Ops can optionally implement [PrettyPrint] and [PrettyParser], to allow for a less verbose +/// textual representation. If not implemented, the op will be printed using the generic format +/// driven by the underlying `Operation`. +impl formatter::PrettyPrint for Add { + fn render(&self) -> formatter::Document { + use formatter::*; + + let opcode = match self.overflow() { + Overflow::Unchecked => const_text("add"), + Overflow::Checked => const_text("add.checked"), + Overflow::Wrapping => const_text("add.wrapping"), + Overflow::Overflowing => const_text("add.overflowing"), + }; + opcode + const_text(" ") + display(self.lhs()) + const_text(", ") + display(self.rhs()) + } +} +impl parsing::PrettyParser for Add { + fn parse(s: &str, context: &mut Context) -> Result<Self, ParseError> { + todo!("not shown here") + } +} + +/// Constructing Components + +// An interface can consist of functions, global variables, and potentially types in the future +let mut std = Interface::new("std"); +std.insert("math::u64/add_checked", FunctionType::new([Type::U64, Type::U64], [Type::U64])); + +let test = Interface::new("test"); +test.insert("run", FunctionType::new([], [Type::U32])); + +// A component is instantiated empty +let mut component = Component::new("test"); + +// You must then declare the interfaces that it imports and exports +component.import(std.clone()); +component.export(test.clone()); + +// Then, to actually define a component instance, you must define modules which implement +// the interfaces exported by the component +let mut module = Module::new("test"); +let mut run = module.define("run", FunctionType::new([], [Type::U32])); +//... build 'run' function + +// And add them to the component, as shown below.
Here, `module` is added to the component as an +// implementation of the `test` interface. +component.implement(test, module); + +// Modules can also be added to a component without using them to implement an interface, +// in which case they are only accessible from other modules in the same component, e.g.: +let foo = Module::new("foo"); +// Here, the 'foo' module will only be accessible from the 'test' module. +component.add(foo); + +// Lastly, during compilation, imports are resolved by linking against the components which +// implement them. The linker step identifies the concrete paths of each item that provides +// an imported symbol, and rewrites the generic interface path with the concrete path of the +// item it was resolved to. + +/// Visiting IR + +// Visiting the CFG + +let entry = function.body().entry(); +let mut worklist = VecDeque::from_iter([entry]); +// Track which blocks have already been visited +let mut visited = FxHashSet::default(); + +while let Some(next) = worklist.pop_front() { + // Ignore blocks we've already visited + if !visited.insert(next) { + continue; + } + let terminator = context.block(next).last().unwrap(); + // Visit all successors after this block, if the terminator branches to another block + if let Some(branch) = terminator.downcast_ref::<dyn BranchOp>() { + worklist.extend(branch.successors().iter().map(|succ| succ.block)); + } + // Visit the operations in the block bottom-up + let mut current_op = terminator; + while let Some(op) = current_op.prev() { + current_op = op; + } +} + +// Visiting uses of a value + +let op = function.body().entry().first().unwrap(); +let result = op.results().first().unwrap(); +assert!(result.is_used()); +for user in result.uses() { + dbg!(user); +} + +// Applying a rewrite pattern to every match in a function body + +struct FoldAdd; +impl RewritePattern for FoldAdd { + fn matches(&self, op: &dyn Op) -> bool { + op.is::<Add>() + } + + fn apply(&mut self, op: &mut dyn Op) -> RewritePatternResult { + let add_op = op.downcast_mut::<Add>().unwrap(); + if let Some(lhs) = add_op.lhs().as_immediate() { + if let Some(rhs) = add_op.rhs().as_immediate() { + let result = lhs + rhs; + return Ok(RewriteAction::ReplaceAllUsesWith(add_op.result().value, result)); + } + } + Ok(RewriteAction::None) + } +} +``` diff --git a/frontend-wasm/Cargo.toml b/frontend-wasm/Cargo.toml index dfa63d375..55f45e443 100644 --- a/frontend-wasm/Cargo.toml +++ b/frontend-wasm/Cargo.toml @@ -21,6 +21,7 @@ gimli = { version = "0.31", default-features = false, features = [ ] } indexmap.workspace = true log.workspace = true +hashbrown.workspace = true miden-core.workspace = true midenc-hir.workspace = true midenc-hir-type.workspace = true diff --git a/frontend-wasm/src/component/dfg.rs b/frontend-wasm/src/component/dfg.rs index 8375b1552..2c798d413 100644 --- a/frontend-wasm/src/component/dfg.rs +++ b/frontend-wasm/src/component/dfg.rs @@ -35,8 +35,10 @@ use std::{hash::Hash, ops::Index}; use indexmap::IndexMap; -use midenc_hir::cranelift_entity::{EntityRef, PrimaryMap}; -use rustc_hash::FxHashMap; +use midenc_hir::{ + cranelift_entity::{EntityRef, PrimaryMap}, + FxHashMap, +}; use crate::{ component::*, diff --git a/frontend-wasm/src/component/inline.rs b/frontend-wasm/src/component/inline.rs index c89cbb498..f080b22fe 100644 --- a/frontend-wasm/src/component/inline.rs +++ b/frontend-wasm/src/component/inline.rs @@ -46,12 +46,11 @@ // Based on wasmtime v16.0 Wasm component translation -use std::{borrow::Cow, collections::HashMap}; +use std::borrow::Cow; use anyhow::{bail, Result}; use indexmap::IndexMap; -use midenc_hir::cranelift_entity::PrimaryMap; -use 
rustc_hash::FxHashMap; +use midenc_hir::{cranelift_entity::PrimaryMap, FxBuildHasher, FxHashMap}; use wasmparser::types::{ComponentAnyTypeId, ComponentEntityType, ComponentInstanceTypeId}; use super::{ @@ -61,7 +60,6 @@ use super::{ use crate::{ component::{dfg, LocalInitializer}, module::{module_env::ParsedModule, types::*, ModuleImport}, - translation_utils::BuildFxHasher, }; pub fn run( @@ -87,8 +85,7 @@ pub fn run( // // Note that this is represents the abstract state of a host import of an // item since we don't know the precise structure of the host import. - let mut args = - HashMap::with_capacity_and_hasher(root_component.exports.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(root_component.exports.len(), FxBuildHasher); let mut path = Vec::new(); types.resources_mut().set_current_instance(index); let types_ref = root_component.types_ref(); diff --git a/frontend-wasm/src/component/parser.rs b/frontend-wasm/src/component/parser.rs index a06002497..51bd871c1 100644 --- a/frontend-wasm/src/component/parser.rs +++ b/frontend-wasm/src/component/parser.rs @@ -3,15 +3,15 @@ // Based on wasmtime v16.0 Wasm component translation -use std::{collections::HashMap, mem}; +use std::mem; use indexmap::IndexMap; use midenc_hir::{ cranelift_entity::PrimaryMap, diagnostics::{IntoDiagnostic, Severity}, + FxBuildHasher, FxHashMap, }; use midenc_session::Session; -use rustc_hash::FxHashMap; use wasmparser::{ types::{ AliasableResourceId, ComponentEntityType, ComponentFuncTypeId, ComponentInstanceTypeId, @@ -30,7 +30,6 @@ use crate::{ TableIndex, WasmType, }, }, - translation_utils::BuildFxHasher, unsupported_diag, WasmTranslationConfig, }; @@ -707,7 +706,7 @@ impl<'a, 'data> ComponentParser<'a, 'data> { raw_args: &[wasmparser::ComponentInstantiationArg<'data>], ty: ComponentInstanceTypeId, ) -> WasmResult> { - let mut args = HashMap::with_capacity_and_hasher(raw_args.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(raw_args.len(), FxBuildHasher); for arg in raw_args { let idx = self.kind_to_item(arg.kind, arg.index)?; args.insert(arg.name, idx); @@ -722,7 +721,7 @@ impl<'a, 'data> ComponentParser<'a, 'data> { &mut self, exports: &[wasmparser::ComponentExport<'data>], ) -> WasmResult> { - let mut map = HashMap::with_capacity_and_hasher(exports.len(), BuildFxHasher::default()); + let mut map = FxHashMap::with_capacity_and_hasher(exports.len(), FxBuildHasher); for export in exports { let idx = self.kind_to_item(export.kind, export.index)?; map.insert(export.name.0, idx); @@ -825,7 +824,7 @@ fn instantiate_module<'data>( module: ModuleIndex, raw_args: &[wasmparser::InstantiationArg<'data>], ) -> LocalInitializer<'data> { - let mut args = HashMap::with_capacity_and_hasher(raw_args.len(), BuildFxHasher::default()); + let mut args = FxHashMap::with_capacity_and_hasher(raw_args.len(), FxBuildHasher); for arg in raw_args { match arg.kind { wasmparser::InstantiationArgKind::Instance => { @@ -842,7 +841,7 @@ fn instantiate_module<'data>( fn instantiate_module_from_exports<'data>( exports: &[wasmparser::Export<'data>], ) -> LocalInitializer<'data> { - let mut map = HashMap::with_capacity_and_hasher(exports.len(), BuildFxHasher::default()); + let mut map = FxHashMap::with_capacity_and_hasher(exports.len(), FxBuildHasher); for export in exports { let idx = match export.kind { wasmparser::ExternalKind::Func => { diff --git a/frontend-wasm/src/component/translator.rs b/frontend-wasm/src/component/translator.rs index e28ad0580..08e72f1bd 
100644 --- a/frontend-wasm/src/component/translator.rs +++ b/frontend-wasm/src/component/translator.rs @@ -1,11 +1,10 @@ use midenc_hir::{ cranelift_entity::PrimaryMap, diagnostics::Severity, CanonAbiImport, ComponentBuilder, - ComponentExport, FunctionIdent, FunctionType, Ident, InterfaceFunctionIdent, InterfaceIdent, - Symbol, + ComponentExport, FunctionIdent, FunctionType, FxHashMap, Ident, InterfaceFunctionIdent, + InterfaceIdent, Symbol, }; use midenc_hir_type::Abi; use midenc_session::Session; -use rustc_hash::FxHashMap; use super::{ interface_type_to_ir, CanonicalOptions, ComponentTypes, CoreDef, CoreExport, Export, diff --git a/frontend-wasm/src/component/types/mod.rs b/frontend-wasm/src/component/types/mod.rs index 008608e5a..945534bf9 100644 --- a/frontend-wasm/src/component/types/mod.rs +++ b/frontend-wasm/src/component/types/mod.rs @@ -9,8 +9,10 @@ pub mod resources; use core::{hash::Hash, ops::Index}; use anyhow::{bail, Result}; -use midenc_hir::cranelift_entity::{EntityRef, PrimaryMap}; -use rustc_hash::FxHashMap; +use midenc_hir::{ + cranelift_entity::{EntityRef, PrimaryMap}, + FxHashMap, +}; use wasmparser::{collections::IndexSet, names::KebabString, types}; use self::resources::ResourcesBuilder; diff --git a/frontend-wasm/src/component/types/resources.rs b/frontend-wasm/src/component/types/resources.rs index 4d85f15e4..cd3faac21 100644 --- a/frontend-wasm/src/component/types/resources.rs +++ b/frontend-wasm/src/component/types/resources.rs @@ -66,7 +66,7 @@ // Based on wasmtime v16.0 Wasm component translation -use rustc_hash::FxHashMap; +use midenc_hir::FxHashMap; use wasmparser::types; use crate::component::{ diff --git a/frontend-wasm/src/miden_abi/mod.rs b/frontend-wasm/src/miden_abi/mod.rs index 3eafa8f0b..30d591587 100644 --- a/frontend-wasm/src/miden_abi/mod.rs +++ b/frontend-wasm/src/miden_abi/mod.rs @@ -3,8 +3,7 @@ pub(crate) mod transform; pub(crate) mod tx_kernel; use miden_core::crypto::hash::RpoDigest; -use midenc_hir::{FunctionType, Symbol}; -use rustc_hash::FxHashMap; +use midenc_hir::{FunctionType, FxHashMap, Symbol}; pub(crate) type FunctionTypeMap = FxHashMap<&'static str, FunctionType>; pub(crate) type ModuleFunctionTypeMap = FxHashMap<&'static str, FunctionTypeMap>; diff --git a/frontend-wasm/src/module/mod.rs b/frontend-wasm/src/module/mod.rs index 58952595c..912e48544 100644 --- a/frontend-wasm/src/module/mod.rs +++ b/frontend-wasm/src/module/mod.rs @@ -9,9 +9,8 @@ use indexmap::IndexMap; use midenc_hir::{ cranelift_entity::{packed_option::ReservedValue, EntityRef, PrimaryMap}, diagnostics::{DiagnosticsHandler, Severity}, - Ident, Symbol, + FxHashMap, Ident, Symbol, }; -use rustc_hash::FxHashMap; use self::types::*; use crate::{component::SignatureIndex, error::WasmResult, unsupported_diag}; diff --git a/frontend-wasm/src/module/module_translation_state.rs b/frontend-wasm/src/module/module_translation_state.rs index ac26f6854..66d8834db 100644 --- a/frontend-wasm/src/module/module_translation_state.rs +++ b/frontend-wasm/src/module/module_translation_state.rs @@ -1,9 +1,8 @@ use miden_core::crypto::hash::RpoDigest; use midenc_hir::{ diagnostics::{DiagnosticsHandler, Severity}, - AbiParam, CallConv, DataFlowGraph, FunctionIdent, Ident, Linkage, Signature, + AbiParam, CallConv, DataFlowGraph, FunctionIdent, FxHashMap, Ident, Linkage, Signature, }; -use rustc_hash::FxHashMap; use super::{instance::ModuleArgument, ir_func_type, EntityIndex, FuncIndex, Module, ModuleTypes}; use crate::{ diff --git a/frontend-wasm/src/translation_utils.rs 
b/frontend-wasm/src/translation_utils.rs index 067ee278d..284048aa4 100644 --- a/frontend-wasm/src/translation_utils.rs +++ b/frontend-wasm/src/translation_utils.rs @@ -5,14 +5,11 @@ use midenc_hir::{ AbiParam, CallConv, Felt, FieldElement, InstBuilder, Linkage, Signature, Value, }; use midenc_hir_type::{FunctionType, Type}; -use rustc_hash::FxHasher; use crate::{ error::WasmResult, module::function_builder_ext::FunctionBuilderExt, unsupported_diag, }; -pub type BuildFxHasher = std::hash::BuildHasherDefault; - /// Represents the possible sizes in bytes of the discriminant of a variant type in the component /// model #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] diff --git a/hir-analysis/Cargo.toml b/hir-analysis/Cargo.toml index 41aa1444c..4f303c598 100644 --- a/hir-analysis/Cargo.toml +++ b/hir-analysis/Cargo.toml @@ -17,9 +17,9 @@ cranelift-entity.workspace = true cranelift-bforest.workspace = true inventory.workspace = true intrusive-collections.workspace = true +hashbrown.workspace = true midenc-hir.workspace = true midenc-session.workspace = true -rustc-hash.workspace = true smallvec.workspace = true thiserror.workspace = true diff --git a/hir-analysis/src/data.rs b/hir-analysis/src/data.rs index ffa0ea8d2..60262c892 100644 --- a/hir-analysis/src/data.rs +++ b/hir-analysis/src/data.rs @@ -1,9 +1,9 @@ use midenc_hir::{ pass::{Analysis, AnalysisManager, AnalysisResult}, - Function, FunctionIdent, GlobalValue, GlobalValueData, GlobalVariableTable, Module, Program, + Function, FunctionIdent, FxHashMap, GlobalValue, GlobalValueData, GlobalVariableTable, Module, + Program, }; use midenc_session::Session; -use rustc_hash::FxHashMap; /// This analysis calculates the addresses/offsets of all global variables in a [Program] or /// [Module] diff --git a/hir-analysis/src/liveness.rs b/hir-analysis/src/liveness.rs index fcd69ed17..b6de57ed9 100644 --- a/hir-analysis/src/liveness.rs +++ b/hir-analysis/src/liveness.rs @@ -9,7 +9,6 @@ use midenc_hir::{ Block as BlockId, Inst as InstId, Value as ValueId, *, }; use midenc_session::Session; -use rustc_hash::FxHashMap; use super::{ControlFlowGraph, DominatorTree, LoopAnalysis}; @@ -417,7 +416,7 @@ fn compute_liveness( domtree: &DominatorTree, loops: &LoopAnalysis, ) { - use std::collections::hash_map::Entry; + use hashbrown::hash_map::Entry; let mut worklist = VecDeque::from_iter(domtree.cfg_postorder().iter().copied()); diff --git a/hir-analysis/src/validation/block.rs b/hir-analysis/src/validation/block.rs index 1f0234abe..4772a6797 100644 --- a/hir-analysis/src/validation/block.rs +++ b/hir-analysis/src/validation/block.rs @@ -2,7 +2,6 @@ use midenc_hir::{ diagnostics::{DiagnosticsHandler, Report, Severity, Spanned}, *, }; -use rustc_hash::FxHashSet; use smallvec::SmallVec; use super::Rule; diff --git a/hir-macros/Cargo.toml b/hir-macros/Cargo.toml index c65f18312..c9311110c 100644 --- a/hir-macros/Cargo.toml +++ b/hir-macros/Cargo.toml @@ -17,6 +17,7 @@ proc-macro = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +darling = { version = "0.20", features = ["diagnostics"] } Inflector.workspace = true proc-macro2 = "1.0" quote = "1.0" diff --git a/hir-macros/src/lib.rs b/hir-macros/src/lib.rs index f1f1491fc..f6ca222c4 100644 --- a/hir-macros/src/lib.rs +++ b/hir-macros/src/lib.rs @@ -1,9 +1,10 @@ extern crate proc_macro; +mod operation; mod spanned; use inflector::cases::kebabcase::to_kebab_case; -use quote::quote; +use quote::{format_ident, quote}; use syn::{parse_macro_input, 
spanned::Spanned, Data, DeriveInput, Error, Ident, Token}; #[proc_macro_derive(Spanned, attributes(span))] @@ -25,6 +26,71 @@ pub fn derive_spanned(input: proc_macro::TokenStream) -> proc_macro::TokenStream } } +/// #[operation( +/// dialect = HirDialect, +/// traits(Terminator), +/// implements(BranchOpInterface), +/// )] +/// pub struct Switch { +/// #[operand] +/// selector: UInt32, +/// #[successors(keyed)] +/// cases: SwitchArm, +/// #[successor] +/// fallback: Successor, +/// } +/// +/// pub struct Call { +/// #[attr] +/// callee: Symbol, +/// #[operands] +/// arguments: Vec, +/// #[results] +/// results: Vec, +/// } +/// +/// #[operation] +/// pub struct If { +/// #[operand] +/// condition: Bool, +/// #[region] +/// then_region: RegionRef, +/// #[region] +/// else_region: RegionRef, +/// } +#[proc_macro_attribute] +pub fn operation( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = proc_macro2::TokenStream::from(attr); + let mut input = syn::parse_macro_input!(item as syn::ItemStruct); + let span = input.span(); + + // Reconstruct the input so we can treat this like a derive macro + // + // We can't _actually_ use derive, because we need to modify the item itself. + input.attrs.push(syn::Attribute { + pound_token: syn::token::Pound(span), + style: syn::AttrStyle::Outer, + bracket_token: syn::token::Bracket(span), + meta: syn::Meta::List(syn::MetaList { + path: syn::parse_str("operation").unwrap(), + delimiter: syn::MacroDelimiter::Paren(syn::token::Paren(span)), + tokens: attr, + }), + }); + + let input = syn::parse_quote! { + #input + }; + + match operation::derive_operation(input) { + Ok(token_stream) => proc_macro::TokenStream::from(token_stream), + Err(err) => err.write_errors().into(), + } +} + #[proc_macro_derive(PassInfo)] pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStream { let derive_input = parse_macro_input!(item as DeriveInput); @@ -36,7 +102,7 @@ pub fn derive_pass_info(item: proc_macro::TokenStream) -> proc_macro::TokenStrea let pass_name = to_kebab_case(&name); let pass_name_lit = syn::Lit::Str(syn::LitStr::new(&pass_name, id.span())); - let doc_ident = syn::Ident::new("doc", derive_span); + let doc_ident = format_ident!("doc", span = derive_span); let docs = derive_input .attrs .iter() diff --git a/hir-macros/src/op.rs b/hir-macros/src/op.rs new file mode 100644 index 000000000..baf0f473c --- /dev/null +++ b/hir-macros/src/op.rs @@ -0,0 +1,268 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{punctuated::Punctuated, DeriveInput, Token}; + +pub struct Op { + generics: syn::Generics, + ident: syn::Ident, + fields: syn::Fields, + args: OpArgs, +} + +#[derive(Default)] +pub struct OpArgs { + pub code: Option, + pub severity: Option, + pub help: Option, + pub labels: Option, + pub source_code: Option, + pub url: Option, + pub forward: Option, + pub related: Option, + pub diagnostic_source: Option, +} + +pub enum OpArg { + Transparent, + Code(Code), + Severity(Severity), + Help(Help), + Url(Url), + Forward(Forward), +} + +impl Parse for OpArg { + fn parse(input: ParseStream) -> syn::Result { + let ident = input.fork().parse::()?; + if ident == "transparent" { + // consume the token + let _: syn::Ident = input.parse()?; + Ok(OpArg::Transparent) + } else if ident == "forward" { + Ok(OpArg::Forward(input.parse()?)) + } else if ident == "code" { + Ok(OpArg::Code(input.parse()?)) + } else if ident == "severity" { + Ok(OpArg::Severity(input.parse()?)) + } else if ident == 
"help" { + Ok(OpArg::Help(input.parse()?)) + } else if ident == "url" { + Ok(OpArg::Url(input.parse()?)) + } else { + Err(syn::Error::new(ident.span(), "Unrecognized diagnostic option")) + } + } +} + +impl OpArgs { + pub(crate) fn forward_or_override_enum( + &self, + variant: &syn::Ident, + which_fn: WhichFn, + mut f: impl FnMut(&ConcreteOpArgs) -> Option, + ) -> Option { + match self { + Self::Transparent(forward) => Some(forward.gen_enum_match_arm(variant, which_fn)), + Self::Concrete(concrete) => f(concrete).or_else(|| { + concrete + .forward + .as_ref() + .map(|forward| forward.gen_enum_match_arm(variant, which_fn)) + }), + } + } +} + +impl OpArgs { + fn parse( + _ident: &syn::Ident, + fields: &syn::Fields, + attrs: &[&syn::Attribute], + allow_transparent: bool, + ) -> syn::Result { + let mut errors = Vec::new(); + + let mut concrete = OpArgs::for_fields(fields)?; + for attr in attrs { + let args = attr.parse_args_with(Punctuated::::parse_terminated); + let args = match args { + Ok(args) => args, + Err(error) => { + errors.push(error); + continue; + } + }; + + concrete.add_args(attr, args, &mut errors); + } + + let combined_error = errors.into_iter().reduce(|mut lhs, rhs| { + lhs.combine(rhs); + lhs + }); + if let Some(error) = combined_error { + Err(error) + } else { + Ok(concrete) + } + } + + fn for_fields(fields: &syn::Fields) -> Result { + let labels = Labels::from_fields(fields)?; + let source_code = SourceCode::from_fields(fields)?; + let related = Related::from_fields(fields)?; + let help = Help::from_fields(fields)?; + let diagnostic_source = DiagnosticSource::from_fields(fields)?; + Ok(Self { + code: None, + help, + related, + severity: None, + labels, + url: None, + forward: None, + source_code, + diagnostic_source, + }) + } + + fn add_args( + &mut self, + attr: &syn::Attribute, + args: impl Iterator, + errors: &mut Vec, + ) { + for arg in args { + match arg { + OpArg::Transparent => { + errors.push(syn::Error::new_spanned(attr, "transparent not allowed")); + } + OpArg::Forward(to_field) => { + if self.forward.is_some() { + errors.push(syn::Error::new_spanned( + attr, + "forward has already been specified", + )); + } + self.forward = Some(to_field); + } + OpArg::Code(new_code) => { + if self.code.is_some() { + errors + .push(syn::Error::new_spanned(attr, "code has already been specified")); + } + self.code = Some(new_code); + } + OpArg::Severity(sev) => { + if self.severity.is_some() { + errors.push(syn::Error::new_spanned( + attr, + "severity has already been specified", + )); + } + self.severity = Some(sev); + } + OpArg::Help(hl) => { + if self.help.is_some() { + errors + .push(syn::Error::new_spanned(attr, "help has already been specified")); + } + self.help = Some(hl); + } + OpArg::Url(u) => { + if self.url.is_some() { + errors + .push(syn::Error::new_spanned(attr, "url has already been specified")); + } + self.url = Some(u); + } + } + } + } +} + +impl Op { + pub fn from_derive_input(input: DeriveInput) -> Result { + let input_attrs = input + .attrs + .iter() + .filter(|x| x.path().is_ident("operation")) + .collect::>(); + Ok(match input.data { + syn::Data::Struct(data_struct) => { + let args = OpArgs::parse(&input.ident, &data_struct.fields, &input_attrs, true)?; + + Op { + fields: data_struct.fields, + ident: input.ident, + generics: input.generics, + args, + } + } + syn::Data::Enum(_) | syn::Data::Union(_) => { + return Err(syn::Error::new( + input.ident.span(), + "Can't derive Op for enums or unions", + )) + } + }) + } + + pub fn gen(&self) -> TokenStream { + let 
(impl_generics, ty_generics, where_clause) = &self.generics.split_for_impl(); + let concrete = &self.args; + let forward = |which| concrete.forward.as_ref().map(|fwd| fwd.gen_struct_method(which)); + let code_body = concrete + .code + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Code)); + let help_body = concrete + .help + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::Help)); + let sev_body = concrete + .severity + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Severity)); + let rel_body = concrete + .related + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::Related)); + let url_body = concrete + .url + .as_ref() + .and_then(|x| x.gen_struct(ident, fields)) + .or_else(|| forward(WhichFn::Url)); + let labels_body = concrete + .labels + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::Labels)); + let src_body = concrete + .source_code + .as_ref() + .and_then(|x| x.gen_struct(fields)) + .or_else(|| forward(WhichFn::SourceCode)); + let diagnostic_source = concrete + .diagnostic_source + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::DiagnosticSource)); + quote! { + impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { + #code_body + #help_body + #sev_body + #rel_body + #url_body + #labels_body + #src_body + #diagnostic_source + } + } + } +} diff --git a/hir-macros/src/operation.rs b/hir-macros/src/operation.rs new file mode 100644 index 000000000..62094fe85 --- /dev/null +++ b/hir-macros/src/operation.rs @@ -0,0 +1,2672 @@ +use std::rc::Rc; + +use darling::{ + util::{Flag, SpannedValue}, + Error, FromDeriveInput, FromField, FromMeta, +}; +use inflector::Inflector; +use quote::{format_ident, quote, ToTokens}; +use syn::{spanned::Spanned, Ident, Token}; + +pub fn derive_operation(input: syn::DeriveInput) -> darling::Result { + let op = OpDefinition::from_derive_input(&input)?; + + Ok(op.into_token_stream()) +} + +/// This struct represents the fully parsed and prepared definition of an operation, along with all +/// of its associated items, trait impls, etc. +pub struct OpDefinition { + /// The span of the original item decorated with `#[operation]` + span: proc_macro2::Span, + /// The name of the dialect type corresponding to the dialect this op belongs to + dialect: Ident, + /// The type name of the concrete `Op` implementation, i.e. the item with `#[operation]` on it + name: Ident, + /// The name of the operation in the textual form of the IR, e.g. `Add` would be `add`. + opcode: Ident, + /// The set of paths corresponding to the op traits we need to generate impls for + traits: darling::util::PathList, + /// The set of paths corresponding to the op traits manually implemented by this op + implements: darling::util::PathList, + /// The named regions declared for this op + regions: Vec, + /// The named attributes declared for this op + attrs: Vec, + /// The named operands, and operand groups, declared for this op + /// + /// Sequential individually named operands are collected into an "unnamed" operand group, i.e. + /// the group is not named, only the individual operands. Conversely, each "named" operand group + /// can refer to the group by name, but not the individual operands. + operands: Vec, + /// The named results of this operation + /// + /// An operation can have no results, one or more individually named results, or a single named + /// result group, but not a combination. 
+    results: Option<OpResultGroup>,
+    /// The named successors, and successor groups, declared for this op.
+    ///
+    /// This is represented almost identically to `operands`, except we also support successor
+    /// groups with "keyed" items represented by an implementation of the `KeyedSuccessor` trait.
+    /// Keyed successor groups are handled a bit differently from "normal" successor groups in
+    /// terms of the types expected by the op builder for this type.
+    successors: Vec<SuccessorGroup>,
+    symbols: Vec<Symbol>,
+    /// The struct definition
+    op: syn::ItemStruct,
+    /// The implementation of `{Op}Builder` for this op.
+    op_builder_impl: OpBuilderImpl,
+    /// The implementation of `OpVerifier` for this op.
+    op_verifier_impl: OpVerifierImpl,
+}
+impl OpDefinition {
+    /// Initialize an [OpDefinition] from the parsed [Operation] received as input
+    fn from_operation(span: proc_macro2::Span, op: &mut Operation) -> darling::Result<Self> {
+        let dialect = op.dialect.clone();
+        let name = op.ident.clone();
+        let opcode = op.name.clone().unwrap_or_else(|| {
+            let name = name.to_string().to_snake_case();
+            let name = name.strip_suffix("Op").unwrap_or(name.as_str());
+            format_ident!("{name}", span = name.span())
+        });
+        let traits = core::mem::take(&mut op.traits);
+        let implements = core::mem::take(&mut op.implements);
+
+        let fields = core::mem::replace(
+            &mut op.data,
+            darling::ast::Data::Struct(darling::ast::Fields::new(
+                darling::ast::Style::Struct,
+                vec![],
+            )),
+        )
+        .take_struct()
+        .unwrap();
+
+        let mut named_fields = syn::punctuated::Punctuated::<syn::Field, syn::Token![,]>::new();
+        // Add the `op` field (which holds the underlying Operation)
+        named_fields.push(syn::Field {
+            attrs: vec![],
+            vis: syn::Visibility::Inherited,
+            mutability: syn::FieldMutability::None,
+            ident: Some(format_ident!("op")),
+            colon_token: Some(syn::token::Colon(span)),
+            ty: make_type("::midenc_hir2::Operation"),
+        });
+
+        let op = syn::ItemStruct {
+            attrs: core::mem::take(&mut op.attrs),
+            vis: op.vis.clone(),
+            struct_token: syn::token::Struct(span),
+            ident: name.clone(),
+            generics: core::mem::take(&mut op.generics),
+            fields: syn::Fields::Named(syn::FieldsNamed {
+                brace_token: syn::token::Brace(span),
+                named: named_fields,
+            }),
+            semi_token: None,
+        };
+
+        let op_builder_impl = OpBuilderImpl::empty(name.clone());
+        let op_verifier_impl =
+            OpVerifierImpl::new(name.clone(), traits.clone(), implements.clone());
+
+        let mut this = Self {
+            span,
+            dialect,
+            name,
+            opcode,
+            traits,
+            implements,
+            regions: vec![],
+            attrs: vec![],
+            operands: vec![],
+            results: None,
+            successors: vec![],
+            symbols: vec![],
+            op,
+            op_builder_impl,
+            op_verifier_impl,
+        };
+
+        this.hydrate(fields)?;
+
+        Ok(this)
+    }
+
+    fn hydrate(&mut self, fields: darling::ast::Fields) -> darling::Result<()> {
+        let named_fields = match &mut self.op.fields {
+            syn::Fields::Named(syn::FieldsNamed { ref mut named, .. }) => named,
+            _ => unreachable!(),
+        };
+        let mut create_params = vec![];
+        let (_, mut fields) = fields.split();
+
+        // Compute the absolute ordering of op parameters as follows:
+        //
+        // * By default, the ordering is implied by the order of field declarations in the struct
+        // * A field can be decorated with #[order(N)], where `N` is an absolute index
+        // * If all fields have an explicit order, then the sort following that order is used
+        // * If a mix of fields have explicit ordering, so as to achieve a particular struct
+        //   layout, then the implicit order given to a field ensures that it appears after the
+        //   highest ordered field which comes before it in the struct. For example, given the
+        //   pseudo-struct definition `{ #[order(2)] a, b, #[order(1)] c, d }`, the actual order
+        //   of the parameters corresponding to those fields will be `c`, `a`, `b`, `d`. This is
+        //   because (a) `b` is assigned an index of `3`, the next available index following `2`,
+        //   which was assigned to `a` before it in the struct; and (b) `d` is assigned an index
+        //   of `4`, the next available index after `3`, since `2` is the highest explicit order
+        //   assigned to a field defined before it in the struct.
+        let mut assigned_highwater = 0;
+        let mut highwater = 0;
+        let mut claimed_indices =
+            fields.iter().filter_map(|f| f.attrs.order).collect::<Vec<_>>();
+        claimed_indices.sort();
+        claimed_indices.dedup();
+        for field in fields.iter_mut() {
+            match field.attrs.order {
+                // If this order precedes a previous #[order] field, skip it
+                Some(order) if highwater > order => continue,
+                Some(order) => {
+                    // Move high water mark to `order`
+                    highwater = order;
+                }
+                None => {
+                    // Find the next unused index > `highwater` && `assigned_highwater`
+                    assigned_highwater = core::cmp::max(assigned_highwater, highwater);
+                    let mut next_index = assigned_highwater + 1;
+                    while claimed_indices.contains(&next_index) {
+                        next_index += 1;
+                    }
+                    assigned_highwater = next_index;
+                    field.attrs.order = Some(next_index);
+                }
+            }
+        }
+        fields.sort_by_key(|field| field.attrs.order);
+
+        for field in fields {
+            let field_name = field.ident.clone().unwrap();
+            let field_span = field_name.span();
+            let field_ty = field.ty.clone();
+
+            let op_field_ty = field.attrs.pseudo_type();
+            match op_field_ty.as_deref() {
+                // Forwarded field
+                None => {
+                    create_params.push(OpCreateParam {
+                        param_ty: OpCreateParamType::CustomField(field_name.clone(), field_ty),
+                        r#default: field.attrs.default.is_present(),
+                    });
+                    named_fields.push(syn::Field {
+                        attrs: field.attrs.forwarded,
+                        vis: field.vis,
+                        mutability: syn::FieldMutability::None,
+                        ident: Some(field_name),
+                        colon_token: Some(syn::token::Colon(field_span)),
+                        ty: field.ty,
+                    });
+                    continue;
+                }
+                Some(OperationFieldType::Attr) => {
+                    let attr = OpAttribute {
+                        name: field_name,
+                        ty: field_ty,
+                    };
+                    create_params.push(OpCreateParam {
+                        param_ty: OpCreateParamType::Attr(attr.clone()),
+                        r#default: field.attrs.default.is_present(),
+                    });
+                    self.attrs.push(attr);
+                }
+                Some(OperationFieldType::Operand) => {
+                    let operand = Operand {
+                        name: field_name.clone(),
+                        constraint: field_ty,
+                    };
+                    create_params.push(OpCreateParam {
+                        param_ty: OpCreateParamType::Operand(operand.clone()),
+                        r#default: field.attrs.default.is_present(),
+                    });
+                    match self.operands.last_mut() {
+                        None => {
+                            self.operands.push(OpOperandGroup::Unnamed(vec![operand]));
+                        }
+                        Some(OpOperandGroup::Unnamed(ref mut operands)) => {
+                            operands.push(operand);
+                        }
+                        Some(OpOperandGroup::Named(..)) => {
+                            // Start a new group
+                            self.operands.push(OpOperandGroup::Unnamed(vec![operand]));
+                        }
+                    }
+                }
+                Some(OperationFieldType::Operands) => {
+                    create_params.push(OpCreateParam {
+                        param_ty: OpCreateParamType::OperandGroup(
+                            field_name.clone(),
+                            field_ty.clone(),
+                        ),
+                        r#default: field.attrs.default.is_present(),
+                    });
+                    self.operands.push(OpOperandGroup::Named(field_name, field_ty));
+                }
+                Some(OperationFieldType::Result) => {
+                    let result = OpResult {
+                        name: field_name.clone(),
+                        constraint: field_ty,
+                    };
+                    match self.results.as_mut() {
+                        None => {
+                            self.results = Some(OpResultGroup::Unnamed(vec![result]));
+                        }
+                        
Some(OpResultGroup::Unnamed(ref mut results)) => { + results.push(result); + } + Some(OpResultGroup::Named(..)) => { + return Err(Error::custom("#[result] and #[results] cannot be mixed") + .with_span(&field_name)); + } + } + } + Some(OperationFieldType::Results) => match self.results.as_mut() { + None => { + self.results = Some(OpResultGroup::Named(field_name, field_ty)); + } + Some(OpResultGroup::Unnamed(_)) => { + return Err(Error::custom("#[result] and #[results] cannot be mixed") + .with_span(&field_name)); + } + Some(OpResultGroup::Named(..)) => { + return Err(Error::custom("#[results] may only appear on a single field") + .with_span(&field_name)); + } + }, + Some(OperationFieldType::Region) => { + self.regions.push(field_name); + } + Some(OperationFieldType::Successor) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Successor(field_name.clone()), + r#default: field.attrs.default.is_present(), + }); + match self.successors.last_mut() { + None => { + self.successors.push(SuccessorGroup::Unnamed(vec![field_name])); + } + Some(SuccessorGroup::Unnamed(ref mut ids)) => { + ids.push(field_name); + } + Some(SuccessorGroup::Named(_) | SuccessorGroup::Keyed(..)) => { + // Start a new group + self.successors.push(SuccessorGroup::Unnamed(vec![field_name])); + } + } + } + Some(OperationFieldType::Successors(SuccessorsType::Default)) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::SuccessorGroupNamed(field_name.clone()), + r#default: field.attrs.default.is_present(), + }); + self.successors.push(SuccessorGroup::Named(field_name)); + } + Some(OperationFieldType::Successors(SuccessorsType::Keyed)) => { + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::SuccessorGroupKeyed( + field_name.clone(), + field_ty.clone(), + ), + r#default: field.attrs.default.is_present(), + }); + self.successors.push(SuccessorGroup::Keyed(field_name, field_ty)); + } + Some(OperationFieldType::Symbol(None)) => { + let symbol = Symbol { + name: field_name, + ty: SymbolType::Concrete(field_ty), + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Symbol(symbol.clone()), + r#default: field.attrs.default.is_present(), + }); + self.symbols.push(symbol); + } + Some(OperationFieldType::Symbol(Some(ty))) => { + let symbol = Symbol { + name: field_name, + ty: ty.clone(), + }; + create_params.push(OpCreateParam { + param_ty: OpCreateParamType::Symbol(symbol.clone()), + r#default: field.attrs.default.is_present(), + }); + self.symbols.push(symbol); + } + } + } + + self.op_builder_impl.set_create_params(&self.op.generics, create_params); + + Ok(()) + } +} +impl FromDeriveInput for OpDefinition { + fn from_derive_input(input: &syn::DeriveInput) -> darling::Result { + let span = input.span(); + let mut operation = Operation::from_derive_input(input)?; + Self::from_operation(span, &mut operation) + } +} + +struct OpCreateFn<'a> { + op: &'a OpDefinition, + generics: syn::Generics, +} +impl<'a> OpCreateFn<'a> { + pub fn new(op: &'a OpDefinition) -> Self { + // Op::create generic parameters + let generics = syn::Generics { + lt_token: Some(syn::token::Lt(op.span)), + params: syn::punctuated::Punctuated::from_iter( + [syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap()] + .into_iter() + .chain(op.op_builder_impl.buildable_op_impl.generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(op.span)), + where_clause: op.op_builder_impl.buildable_op_impl.generics.where_clause.clone(), + }; + + Self { op, generics } + } +} + +struct 
WithAttrs<'a>(&'a OpDefinition); +impl quote::ToTokens for WithAttrs<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::Attr(OpAttribute { name, .. }) = ¶m.param_ty { + let field_name = syn::Lit::Str(syn::LitStr::new(&format!("{name}"), name.span())); + tokens.extend(quote! { + op_builder.with_attr(#field_name, #name); + }); + } + } + } +} + +struct WithSymbols<'a>(&'a OpDefinition); +impl quote::ToTokens for WithSymbols<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::Symbol(Symbol { name, ty }) = ¶m.param_ty { + let field_name = syn::Lit::Str(syn::LitStr::new(&format!("{name}"), name.span())); + match ty { + SymbolType::Any | SymbolType::Concrete(_) | SymbolType::Trait(_) => { + tokens.extend(quote! { + op_builder.with_symbol(#field_name, #name); + }); + } + SymbolType::Callable => { + tokens.extend(quote! { + op_builder.with_callable_symbol(#field_name, #name); + }); + } + } + } + } + } +} + +struct WithOperands<'a>(&'a OpDefinition); +impl quote::ToTokens for WithOperands<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, group) in self.0.operands.iter().enumerate() { + match group { + OpOperandGroup::Unnamed(operands) => { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + operands[0].name.span(), + )); + let operand_name = operands.iter().map(|o| &o.name).collect::>(); + let operand_constraint = operands.iter().map(|o| &o.constraint); + let constraint_violation = operands.iter().map(|o| { + syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for '{}'", &o.name), + o.name.span(), + )) + }); + tokens.extend(quote! { + #( + { + let value = #operand_name.borrow(); + let value_ty = value.ty(); + if !<#operand_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#operand_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operand") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + )* + op_builder.with_operands_in_group(#group_index, [#(#operand_name),*]); + }); + } + OpOperandGroup::Named(group_name, group_constraint) => { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + group_name.span(), + )); + let constraint_violation = syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for operand in '{group_name}'"), + group_name.span(), + )); + tokens.extend(quote! 
{ + let #group_name = #group_name.into_iter().collect::<::alloc::vec::Vec<_>>(); + for operand in #group_name.iter() { + let value = operand.borrow(); + let value_ty = value.ty(); + if !<#group_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#group_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operand") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + op_builder.with_operands_in_group(#group_index, #group_name); + }); + } + } + } + } +} + +struct InitializeCustomFields<'a>(&'a OpDefinition); +impl quote::ToTokens for InitializeCustomFields<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for param in self.0.op_builder_impl.create_params.iter() { + if let OpCreateParamType::CustomField(id, ..) = ¶m.param_ty { + tokens.extend(quote! { + core::ptr::addr_of_mut!((*__ptr).#id).write(#id); + }); + } + } + } +} + +struct WithResults<'a>(&'a OpDefinition); +impl quote::ToTokens for WithResults<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self.0.results.as_ref() { + None => (), + Some(OpResultGroup::Unnamed(results)) => { + let num_results = syn::Lit::Int(syn::LitInt::new( + &format!("{}usize", results.len()), + results[0].name.span(), + )); + tokens.extend(quote! { + op_builder.with_results(#num_results); + }); + } + // Named result groups can have an arbitrary number of results + Some(OpResultGroup::Named(..)) => (), + } + } +} + +struct WithSuccessors<'a>(&'a OpDefinition); +impl quote::ToTokens for WithSuccessors<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for group in self.0.successors.iter() { + match group { + SuccessorGroup::Unnamed(successors) => { + let successor_args = successors.iter().map(|s| format_ident!("{s}_args")); + tokens.extend(quote! { + op_builder.with_successors([ + #(( + #successors, + #successor_args.into_iter().collect::<::alloc::vec::Vec<_>>(), + ),)* + ]); + }); + } + SuccessorGroup::Named(name) => { + tokens.extend(quote! { + op_builder.with_successors(#name); + }); + } + SuccessorGroup::Keyed(name, _) => { + tokens.extend(quote! { + op_builder.with_keyed_successors(#name); + }); + } + } + } + } +} + +struct BuildOp<'a>(&'a OpDefinition); +impl quote::ToTokens for BuildOp<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self.0.results.as_ref() { + None => { + tokens.extend(quote! { + op_builder.build() + }); + } + Some(group) => { + let verify_result_constraints = match group { + OpResultGroup::Unnamed(results) => { + let verify_result = results.iter().map(|result| { + let result_name = &result.name; + let result_constraint = &result.constraint; + let constraint_violation = syn::Lit::Str(syn::LitStr::new(&format!("type constraint violation for result '{result_name}'"), result_name.span())); + quote! 
{ + { + let op_result = op.#result_name(); + let value_ty = op_result.ty(); + if !<#result_constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#result_constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(op_result.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + } + }); + quote! { + #( + #verify_result + )* + } + } + OpResultGroup::Named(name, constraint) => { + let constraint_violation = syn::Lit::Str(syn::LitStr::new( + &format!("type constraint violation for result in '{name}'"), + name.span(), + )); + quote! { + { + let results = op.#name(); + for result in results.iter() { + let value = result.borrow(); + let value_ty = value.ty(); + if !<#constraint as ::midenc_hir2::traits::TypeConstraint>::matches(value_ty) { + let expected = <#constraint as ::midenc_hir2::traits::TypeConstraint>::description(); + return Err(builder.context() + .session + .diagnostics + .diagnostic(::midenc_session::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, #constraint_violation) + .with_secondary_label(value.span(), format!("this value has type '{value_ty}', but expected '{expected}'")) + .into_report()); + } + } + } + } + } + }; + + tokens.extend(quote! { + let op = op_builder.build()?; + + { + let op = op.borrow(); + #verify_result_constraints + } + + Ok(op) + }) + } + } + } +} + +impl quote::ToTokens for OpCreateFn<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let dialect = &self.op.dialect; + let (impl_generics, _, where_clause) = self.generics.split_for_impl(); + let param_names = + self.op.op_builder_impl.create_params.iter().flat_map(OpCreateParam::bindings); + let param_types = self + .op + .op_builder_impl + .create_params + .iter() + .flat_map(OpCreateParam::binding_types); + let initialize_custom_fields = InitializeCustomFields(self.op); + let with_symbols = WithSymbols(self.op); + let with_attrs = WithAttrs(self.op); + let with_operands = WithOperands(self.op); + let with_results = WithResults(self.op); + let with_regions = self.op.regions.iter().map(|_| { + quote! { + op_builder.create_region(); + } + }); + let with_successors = WithSuccessors(self.op); + let build_op = BuildOp(self.op); + + tokens.extend(quote! { + /// Manually construct a new [#op_ident] + /// + /// It is generally preferable to use [`::midenc_hir2::Builder::create`] instead. 
+            pub fn create #impl_generics(
+                builder: &mut B,
+                span: ::midenc_session::diagnostics::SourceSpan,
+                #(
+                    #param_names: #param_types,
+                )*
+            ) -> Result<::midenc_hir2::UnsafeIntrusiveEntityRef<Self>, ::midenc_session::diagnostics::Report>
+            #where_clause
+            {
+                use ::midenc_hir2::{Builder, Op};
+                let mut __this = {
+                    let __operation_name = {
+                        let context = builder.context();
+                        let dialect = context.get_or_register_dialect::<#dialect>();
+                        <Self as ::midenc_hir2::OpRegistration>::register_with(&*dialect)
+                    };
+                    let __context = builder.context_rc();
+                    let mut __op = __context.alloc_uninit_tracked::<Self>();
+                    unsafe {
+                        {
+                            let mut __uninit = __op.borrow_mut();
+                            let __ptr = (*__uninit).as_mut_ptr();
+                            let __offset = core::mem::offset_of!(Self, op);
+                            let __op_ptr = core::ptr::addr_of_mut!((*__ptr).op);
+                            __op_ptr.write(::midenc_hir2::Operation::uninit::<Self>(__context, __operation_name, __offset));
+                            #initialize_custom_fields
+                        }
+                        let mut __this = ::midenc_hir2::UnsafeIntrusiveEntityRef::assume_init(__op);
+                        __this.borrow_mut().set_span(span);
+                        __this
+                    }
+                };
+
+                let mut op_builder = ::midenc_hir2::OperationBuilder::new(builder, __this);
+                #with_attrs
+                #with_symbols
+                #with_operands
+                #(
+                    #with_regions
+                )*
+                #with_successors
+                #with_results
+
+                // Finalize construction of this op, verifying it
+                #build_op
+            }
+        });
+    }
+}
+
+impl quote::ToTokens for OpDefinition {
+    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
+        let op_ident = &self.name;
+        let (impl_generics, ty_generics, where_clause) = self.op.generics.split_for_impl();
+
+        // struct $Op
+        self.op.to_tokens(tokens);
+
+        // impl Spanned
+        tokens.extend(quote! {
+            impl #impl_generics ::midenc_session::diagnostics::Spanned for #op_ident #ty_generics #where_clause {
+                fn span(&self) -> ::midenc_session::diagnostics::SourceSpan {
+                    self.op.span()
+                }
+            }
+        });
+
+        // impl AsRef/AsMut
+        tokens.extend(quote! {
+            impl #impl_generics AsRef<::midenc_hir2::Operation> for #op_ident #ty_generics #where_clause {
+                #[inline(always)]
+                fn as_ref(&self) -> &::midenc_hir2::Operation {
+                    &self.op
+                }
+            }
+
+            impl #impl_generics AsMut<::midenc_hir2::Operation> for #op_ident #ty_generics #where_clause {
+                #[inline(always)]
+                fn as_mut(&mut self) -> &mut ::midenc_hir2::Operation {
+                    &mut self.op
+                }
+            }
+        });
+
+        // impl Op
+        // impl OpRegistration
+        let opcode = &self.opcode;
+        let opcode_str = syn::Lit::Str(syn::LitStr::new(&opcode.to_string(), opcode.span()));
+        let traits = &self.traits;
+        let implements = &self.implements;
+        tokens.extend(quote! 
{ + impl #impl_generics ::midenc_hir2::Op for #op_ident #ty_generics #where_clause { + #[inline] + fn name(&self) -> ::midenc_hir2::OperationName { + self.op.name() + } + + #[inline(always)] + fn as_operation(&self) -> &::midenc_hir2::Operation { + &self.op + } + + #[inline(always)] + fn as_operation_mut(&mut self) -> &mut ::midenc_hir2::Operation { + &mut self.op + } + } + + impl #impl_generics ::midenc_hir2::OpRegistration for #op_ident #ty_generics #where_clause { + fn name() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#opcode_str) + } + + fn register_with(dialect: &dyn ::midenc_hir2::Dialect) -> ::midenc_hir2::OperationName { + let opcode = ::name(); + dialect.get_or_register_op( + opcode, + |dialect_name, opcode| { + ::midenc_hir2::OperationName::new::( + dialect_name, + opcode, + [ + ::midenc_hir2::traits::TraitInfo::new::(), + ::midenc_hir2::traits::TraitInfo::new::(), + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + #( + ::midenc_hir2::traits::TraitInfo::new::(), + )* + ] + ) + } + ) + } + } + }); + + // impl $OpBuilder + // impl BuildableOp + self.op_builder_impl.to_tokens(tokens); + + // impl $Op + { + let create_fn = OpCreateFn::new(self); + let custom_field_fns = OpCustomFieldFns(self); + let attr_fns = OpAttrFns(self); + let symbol_fns = OpSymbolFns(self); + let operand_fns = OpOperandFns(self); + let result_fns = OpResultFns(self); + let region_fns = OpRegionFns(self); + let successor_fns = OpSuccessorFns(self); + tokens.extend(quote! { + /// Construction + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #create_fn + } + + /// User-defined Fields + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #custom_field_fns + } + + /// Attributes + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #attr_fns + } + + /// Symbols + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #symbol_fns + } + + /// Operands + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #operand_fns + } + + /// Results + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #result_fns + } + + /// Regions + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #region_fns + } + + /// Successors + #[allow(unused)] + impl #impl_generics #op_ident #ty_generics #where_clause { + #successor_fns + } + }); + } + + // impl $DerivedTrait + for derived_trait in self.traits.iter() { + tokens.extend(quote! 
{ + impl #impl_generics #derived_trait for #op_ident #ty_generics #where_clause {} + }); + } + + // impl OpVerifier + self.op_verifier_impl.to_tokens(tokens); + } +} + +struct OpCustomFieldFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpCustomFieldFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // User-defined fields + for field in self.0.op.fields.iter() { + let field_name = field.ident.as_ref().unwrap(); + // Do not generate field functions for custom fields with private visibility + if matches!(field.vis, syn::Visibility::Inherited) { + continue; + } + let field_name_mut = format_ident!("{field_name}_mut"); + let set_field_name = format_ident!("set_{field_name}"); + let field_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of `{field_name}`"), + field_name.span(), + )); + let field_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of `{field_name}`"), + field_name.span(), + )); + let set_field_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of `{field_name}`"), + field_name.span(), + )); + let field_ty = &field.ty; + tokens.extend(quote! { + #[doc = #field_doc] + #[inline] + pub fn #field_name(&self) -> &#field_ty { + &self.#field_name + } + + #[doc = #field_mut_doc] + #[inline] + pub fn #field_name_mut(&mut self) -> &mut #field_ty { + &mut self.#field_name + } + + #[doc = #set_field_doc] + #[inline] + pub fn #set_field_name(&mut self, #field_name: #field_ty) { + self.#field_name = #field_name; + } + }); + } + } +} + +struct OpSymbolFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpSymbolFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Symbols + for Symbol { + name: ref symbol, + ty: ref symbol_kind, + } in self.0.symbols.iter() + { + let span = symbol.span(); + let symbol_str = syn::Lit::Str(syn::LitStr::new(&symbol.to_string(), span)); + let symbol_mut = format_ident!("{symbol}_mut"); + let set_symbol = format_ident!("set_{symbol}"); + let set_symbol_unchecked = format_ident!("set_{symbol}_unchecked"); + let symbol_symbol = format_ident!("{symbol}_symbol"); + let symbol_symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get the symbol under which the `{symbol}` attribute is stored"), + span, + )); + let symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of the `{symbol}` attribute."), + span, + )); + let symbol_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of the `{symbol}` attribute."), + span, + )); + let set_symbol_doc_lines = [ + syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of the `{symbol}` symbol."), + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new( + " Returns `Err` if the symbol cannot be resolved in the nearest symbol table.", + span, + )), + ]; + let set_symbol_unchecked_doc_lines = [ + syn::Lit::Str(syn::LitStr::new( + &format!( + " Set the value of the `{symbol}` symbol without attempting to resolve it." + ), + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new( + " Because this does not resolve the given symbol, the caller is responsible \ + for updating the symbol use list.", + span, + )), + ]; + + tokens.extend(quote! 
{ + #[doc = #symbol_symbol_doc] + #[inline(always)] + pub fn #symbol_symbol() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#symbol_str) + } + + #[doc = #symbol_doc] + pub fn #symbol(&self) -> &::midenc_hir2::SymbolNameAttr { + self.op.get_typed_attribute(Self::#symbol_symbol()).unwrap() + } + + #[doc = #symbol_mut_doc] + pub fn #symbol_mut(&mut self) -> &mut ::midenc_hir2::SymbolNameAttr { + self.op.get_typed_attribute_mut(Self::#symbol_symbol()).unwrap() + } + + #( + #[doc = #set_symbol_unchecked_doc_lines] + )* + pub fn #set_symbol_unchecked(&mut self, value: ::midenc_hir2::SymbolNameAttr) { + self.op.set_attribute(Self::#symbol_symbol(), Some(value)); + } + }); + + let is_concrete_ty = match symbol_kind { + SymbolType::Concrete(ref ty) => [quote! { + // The way we check the type depends on whether `symbol` is a reference to `self` + let (data_ptr, _) = ::midenc_hir2::SymbolRef::as_ptr(&symbol).to_raw_parts(); + if core::ptr::addr_eq(data_ptr, (self as *const Self as *const ())) { + if !self.op.is::<#ty>() { + return Err(::midenc_hir2::InvalidSymbolRefError::InvalidType { + symbol: span, + expected: stringify!(#ty), + got: self.op.name(), + }); + } + } else { + if !symbol.borrow().is::<#ty>() { + return Err(::midenc_hir2::InvalidSymbolRefError::InvalidType { + symbol: span, + expected: stringify!(#ty), + got: symbol.as_symbol_operation().name(), + }); + } + } + }], + _ => [quote! {}], + }; + + match symbol_kind { + SymbolType::Any | SymbolType::Trait(_) | SymbolType::Concrete(_) => { + tokens.extend(quote! { + #( + #[doc = #set_symbol_doc_lines] + )* + pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::AsSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { + let symbol = symbol.as_symbol_ref(); + #(#is_concrete_ty)* + self.op.set_symbol_attribute(Self::#symbol_symbol(), symbol); + + Ok(()) + } + }); + } + SymbolType::Callable => { + tokens.extend(quote! 
{ + #( + #[doc = #set_symbol_doc_lines] + )* + pub fn #set_symbol(&mut self, symbol: impl ::midenc_hir2::AsCallableSymbolRef) -> Result<(), ::midenc_hir2::InvalidSymbolRefError> { + let symbol = symbol.as_callable_symbol_ref(); + let (data_ptr, _) = ::midenc_hir2::SymbolRef::as_ptr(&symbol).to_raw_parts(); + if core::ptr::addr_eq(data_ptr, (self as *const Self as *const ())) { + if !self.op.implements::() { + return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { + symbol: self.span(), + }); + } + } else { + let symbol = symbol.borrow(); + let symbol_op = symbol.as_symbol_operation(); + if !symbol_op.implements::() { + return Err(::midenc_hir2::InvalidSymbolRefError::NotCallable { + symbol: symbol_op.span(), + }); + } + } + self.op.set_symbol_attribute(Self::#symbol_symbol(), symbol.clone()); + + Ok(()) + } + }); + } + } + } + } +} + +struct OpAttrFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpAttrFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Attributes + for OpAttribute { + name: ref attr, + ty: ref attr_ty, + } in self.0.attrs.iter() + { + let attr_str = syn::Lit::Str(syn::LitStr::new(&attr.to_string(), attr.span())); + let attr_mut = format_ident!("{attr}_mut"); + let set_attr = format_ident!("set_{attr}"); + let attr_symbol = format_ident!("{attr}_symbol"); + let attr_symbol_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get the symbol under which the `{attr}` attribute is stored"), + attr.span(), + )); + let attr_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the value of the `{attr}` attribute."), + attr.span(), + )); + let attr_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the value of the `{attr}` attribute."), + attr.span(), + )); + let set_attr_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Set the value of the `{attr}` attribute."), + attr.span(), + )); + tokens.extend(quote! { + #[doc = #attr_symbol_doc] + #[inline(always)] + pub fn #attr_symbol() -> ::midenc_hir_symbol::Symbol { + ::midenc_hir_symbol::Symbol::intern(#attr_str) + } + + #[doc = #attr_doc] + pub fn #attr(&self) -> &#attr_ty { + self.op.get_typed_attribute::<#attr_ty>(Self::#attr_symbol()).unwrap() + } + + #[doc = #attr_mut_doc] + pub fn #attr_mut(&mut self) -> &mut #attr_ty { + self.op.get_typed_attribute_mut::<#attr_ty>(Self::#attr_symbol()).unwrap() + } + + #[doc = #set_attr_doc] + pub fn #set_attr(&mut self, value: impl Into<#attr_ty>) { + self.op.set_attribute(Self::#attr_symbol(), Some(value.into())); + } + }); + } + } +} + +struct OpOperandFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpOperandFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, operand_group) in self.0.operands.iter().enumerate() { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + proc_macro2::Span::call_site(), + )); + match operand_group { + // Operands + OpOperandGroup::Unnamed(operands) => { + for ( + operand_index, + Operand { + name: ref operand, .. 
+ }, + ) in operands.iter().enumerate() + { + let operand_index = syn::Lit::Int(syn::LitInt::new( + &format!("{operand_index}usize"), + proc_macro2::Span::call_site(), + )); + let operand_mut = format_ident!("{operand}_mut"); + let operand_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{operand}` operand."), + operand.span(), + )); + let operand_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{operand}` operand."), + operand.span(), + )); + tokens.extend(quote!{ + #[doc = #operand_doc] + #[inline] + pub fn #operand(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::OpOperandImpl> { + self.op.operands().group(#group_index)[#operand_index].borrow() + } + + #[doc = #operand_mut_doc] + #[inline] + pub fn #operand_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::OpOperandImpl> { + self.op.operands_mut().group_mut(#group_index)[#operand_index].borrow_mut() + } + }); + } + } + // User-defined operand groups + OpOperandGroup::Named(group_name, _) => { + let group_name_mut = format_ident!("{group_name}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group_name}` operand group."), + group_name.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group_name}` operand group."), + group_name.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group_name(&self) -> ::midenc_hir2::OpOperandRange<'_> { + self.op.operands().group(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_name_mut(&mut self) -> ::midenc_hir2::OpOperandRangeMut<'_> { + self.op.operands_mut().group_mut(#group_index) + } + }); + } + } + } + } +} + +struct OpResultFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpResultFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + if let Some(group) = self.0.results.as_ref() { + match group { + OpResultGroup::Unnamed(results) => { + for ( + index, + OpResult { + name: ref result, .. + }, + ) in results.iter().enumerate() + { + let index = syn::Lit::Int(syn::LitInt::new( + &format!("{index}usize"), + result.span(), + )); + let result_mut = format_ident!("{result}_mut"); + let result_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{result}` result."), + result.span(), + )); + let result_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{result}` result."), + result.span(), + )); + tokens.extend(quote!{ + #[doc = #result_doc] + #[inline] + pub fn #result(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::OpResult> { + self.op.results()[#index].borrow() + } + + #[doc = #result_mut_doc] + #[inline] + pub fn #result_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::OpResult> { + self.op.results_mut()[#index].borrow_mut() + } + }); + } + } + OpResultGroup::Named(group, _) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` result group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` result group."), + group.span(), + )); + tokens.extend(quote! 
{ + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir2::OpResultRange<'_> { + self.op.results().group(0) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir2::OpResultRangeMut<'_> { + self.op.results_mut().group_mut(0) + } + }); + } + } + } + } +} + +struct OpRegionFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpRegionFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Regions + for (index, region) in self.0.regions.iter().enumerate() { + let index = syn::Lit::Int(syn::LitInt::new(&format!("{index}usize"), region.span())); + let region_mut = format_ident!("{region}_mut"); + let region_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{region}` region."), + region.span(), + )); + let region_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{region}` region."), + region.span(), + )); + tokens.extend(quote! { + #[doc = #region_doc] + #[inline] + pub fn #region(&self) -> ::midenc_hir2::EntityRef<'_, ::midenc_hir2::Region> { + self.op.region(#index) + } + + #[doc = #region_mut_doc] + #[inline] + pub fn #region_mut(&mut self) -> ::midenc_hir2::EntityMut<'_, ::midenc_hir2::Region> { + self.op.region_mut(#index) + } + }); + } + } +} + +struct OpSuccessorFns<'a>(&'a OpDefinition); +impl quote::ToTokens for OpSuccessorFns<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + for (group_index, group) in self.0.successors.iter().enumerate() { + let group_index = syn::Lit::Int(syn::LitInt::new( + &format!("{group_index}usize"), + proc_macro2::Span::call_site(), + )); + match group { + // Successors + SuccessorGroup::Unnamed(successors) => { + for (index, successor) in successors.iter().enumerate() { + let index = syn::Lit::Int(syn::LitInt::new( + &format!("{index}usize"), + proc_macro2::Span::call_site(), + )); + let successor_mut = format_ident!("{successor}_mut"); + let successor_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{successor}` successor."), + successor.span(), + )); + let successor_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{successor}` successor."), + successor.span(), + )); + tokens.extend(quote! { + #[doc = #successor_doc] + #[inline] + pub fn #successor(&self) -> ::midenc_hir2::OpSuccessor<'_> { + self.op.successor_in_group(#group_index, #index) + } + + #[doc = #successor_mut_doc] + #[inline] + pub fn #successor_mut(&mut self) -> ::midenc_hir2::OpSuccessorMut<'_> { + self.op.successor_in_group_mut(#group_index, #index) + } + }); + } + } + // Variadic successor groups + SuccessorGroup::Named(group) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` successor group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` successor group."), + group.span(), + )); + tokens.extend(quote!
{ + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir2::OpSuccessorRange<'_> { + self.op.successor_group(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir2::OpSuccessorRangeMut<'_> { + self.op.successor_group_mut(#group_index) + } + }); + } + // User-defined successor groups + SuccessorGroup::Keyed(group, group_ty) => { + let group_mut = format_ident!("{group}_mut"); + let group_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a reference to the `{group}` successor group."), + group.span(), + )); + let group_mut_doc = syn::Lit::Str(syn::LitStr::new( + &format!(" Get a mutable reference to the `{group}` successor group."), + group.span(), + )); + tokens.extend(quote! { + #[doc = #group_doc] + #[inline] + pub fn #group(&self) -> ::midenc_hir2::KeyedSuccessorRange<'_, #group_ty> { + self.op.keyed_successor_group::<#group_ty>(#group_index) + } + + #[doc = #group_mut_doc] + #[inline] + pub fn #group_mut(&mut self) -> ::midenc_hir2::KeyedSuccessorRangeMut<'_, #group_ty> { + self.op.keyed_successor_group_mut::<#group_ty>(#group_index) + } + }); + } + } + } + } +} + +/// Represents a field decorated with `#[attr]` +/// +/// The type associated with an `#[attr]` field represents the concrete value type of the attribute, +/// and thus must implement the `AttributeValue` trait. +#[derive(Debug, Clone)] +pub struct OpAttribute { + /// The attribute name + pub name: Ident, + /// The value type of the attribute + pub ty: syn::Type, +} + +/// An abstraction over named vs unnamed groups of some IR entity +pub enum EntityGroup<T> { + /// An unnamed group consisting of individual named items + Unnamed(Vec<T>), + /// A named group consisting of unnamed items + Named(Ident, syn::Type), +} + +/// A type representing a type constraint applied to a `Value` impl +pub type Constraint = syn::Type; + +#[derive(Debug, Clone)] +pub struct Operand { + pub name: Ident, + pub constraint: Constraint, +} + +pub type OpOperandGroup = EntityGroup<Operand>; + +#[derive(Debug, Clone)] +pub struct OpResult { + pub name: Ident, + pub constraint: Constraint, +} + +pub type OpResultGroup = EntityGroup<OpResult>; + +#[derive(Debug)] +pub enum SuccessorGroup { + /// An unnamed group consisting of individual named successors + Unnamed(Vec<Ident>), + /// A named group consisting of unnamed successors + Named(Ident), + /// A named group consisting of unnamed successors with an associated key + Keyed(Ident, syn::Type), +}
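The grouping types above are easiest to read next to a concrete op definition. A minimal sketch follows — the `Call` op, its dialect, and the constraint types here are illustrative rather than taken from this diff (see `operation_impl_test` near the bottom of this file for a definition that actually exists in this PR):

```rust
#[operation(dialect = HirDialect)]
pub struct Call {
    // Parsed into an `OpOperandGroup::Unnamed` of one `Operand`: generates the
    // `callee_ptr()`/`callee_ptr_mut()` accessors for a single operand
    #[operand]
    callee_ptr: AnyPointer,
    // Parsed into an `OpOperandGroup::Named`: one variadic group of operands,
    // exposed through the generated `arguments()`/`arguments_mut()` range accessors
    #[operands]
    arguments: AnyInteger,
}
```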
+ +/// Represents the generated `$OpBuilder` type used to create instances of `$Op` +/// +/// The implementation of the type requires us to know the type signature specific to this op, +/// so that we can emit an implementation matching that signature. +pub struct OpBuilderImpl { + /// The `$Op` we're building + op: Ident, + /// The `$OpBuilder` type name + name: Ident, + /// The doc string for `$OpBuilder` + doc: DocString, + /// The doc string for `$OpBuilder::new` + new_doc: DocString, + /// The set of parameters expected by `$Op::create` + /// + /// The order of these parameters is determined by: + /// + /// 1. The `order = N` property of the corresponding attribute type, e.g. `#[attr(order = 1)]` + /// 2. The default "kind" ordering of: symbols, required user-defined fields, operands, successors, attributes + /// 3. The order of appearance of the fields in the struct + create_params: Rc<[OpCreateParam]>, + /// The implementation of the `BuildableOp` trait for `$Op` via `$OpBuilder` + buildable_op_impl: BuildableOpImpl, + /// The implementation of the `FnOnce` trait for `$OpBuilder` + fn_once_impl: OpBuilderFnOnceImpl, +} +impl OpBuilderImpl { + pub fn empty(op: Ident) -> Self { + let name = format_ident!("{}Builder", &op); + let doc = DocString::new( + op.span(), + format!( + " A specialized builder for [{op}], which is used by calling it like a function." + ), + ); + let new_doc = DocString::new( + op.span(), + format!( + " Get a new [{name}] from the provided [::midenc_hir2::Builder] impl and span." + ), + ); + let create_params = Rc::<[OpCreateParam]>::from([]); + let buildable_op_impl = BuildableOpImpl { + op: op.clone(), + op_builder: name.clone(), + op_generics: Default::default(), + generics: Default::default(), + required_generics: None, + params: Rc::clone(&create_params), + }; + let fn_once_impl = OpBuilderFnOnceImpl { + op: op.clone(), + op_builder: name.clone(), + generics: Default::default(), + required_generics: None, + params: Rc::clone(&create_params), + }; + Self { + op, + name, + doc, + new_doc, + create_params, + buildable_op_impl, + fn_once_impl, + } + } + + pub fn set_create_params(&mut self, op_generics: &syn::Generics, params: Vec<OpCreateParam>) { + let span = self.op.span(); + + let create_params = Rc::from(params.into_boxed_slice()); + self.create_params = Rc::clone(&create_params); + + let has_required_variant = self.create_params.iter().any(|param| param.default); + + // BuildableOp generic parameters + self.buildable_op_impl.params = Rc::clone(&create_params); + self.buildable_op_impl.op_generics = op_generics.clone(); + self.buildable_op_impl.required_generics = if has_required_variant { + Some(syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + op_generics.params.iter().cloned().chain( + self.create_params.iter().flat_map(OpCreateParam::required_generic_types), + ), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: op_generics.where_clause.clone(), + }) + } else { + None + }; + self.buildable_op_impl.generics = syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + op_generics + .params + .iter() + .cloned() + .chain(self.create_params.iter().flat_map(OpCreateParam::generic_types)), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: op_generics.where_clause.clone(), + }; + + // FnOnce generic parameters + self.fn_once_impl.params = create_params; + self.fn_once_impl.required_generics = + self.buildable_op_impl.required_generics.as_ref().map( + |buildable_op_impl_required_generics| syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + [ + syn::GenericParam::Lifetime(syn::LifetimeParam { + attrs: vec![], + lifetime: syn::Lifetime::new("'a", proc_macro2::Span::call_site()), + colon_token: None, + bounds: Default::default(), + }), + syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap(), + ] + .into_iter() + .chain(buildable_op_impl_required_generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: buildable_op_impl_required_generics.where_clause.clone(), + }, + ); + self.fn_once_impl.generics = syn::Generics { + lt_token: Some(syn::token::Lt(span)), + params: syn::punctuated::Punctuated::from_iter( + [ + syn::GenericParam::Lifetime(syn::LifetimeParam { + attrs:
vec![], + lifetime: syn::Lifetime::new("'a", proc_macro2::Span::call_site()), + colon_token: None, + bounds: Default::default(), + }), + syn::parse_str("B: ?Sized + ::midenc_hir2::Builder").unwrap(), + ] + .into_iter() + .chain(self.buildable_op_impl.generics.params.iter().cloned()), + ), + gt_token: Some(syn::token::Gt(span)), + where_clause: self.buildable_op_impl.generics.where_clause.clone(), + }; + } +} +impl quote::ToTokens for OpBuilderImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + // Emit `$OpBuilder` + tokens.extend({ + let op_builder = &self.name; + let op_builder_doc = &self.doc; + let op_builder_new_doc = &self.new_doc; + quote! { + #op_builder_doc + pub struct #op_builder <'a, B: ?Sized> { + builder: &'a mut B, + span: ::midenc_session::diagnostics::SourceSpan, + } + + impl<'a, B> #op_builder <'a, B> + where + B: ?Sized + ::midenc_hir2::Builder, + { + #op_builder_new_doc + #[inline(always)] + pub fn new(builder: &'a mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self { + Self { + builder, + span, + } + } + } + } + }); + + // Emit `impl BuildableOp for $OpBuilder` + self.buildable_op_impl.to_tokens(tokens); + + // Emit `impl FnOnce for $OpBuilder` + self.fn_once_impl.to_tokens(tokens); + } +}
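Taken together, the pieces emitted here let call sites construct ops by invoking the builder like a function. A rough sketch of the intended call shape — `Add` and `Overflow` come from the test at the end of this file, but the `Overflow::Checked` variant and the surrounding setup are assumptions, and the `FnOnce` call sugar on the builder requires a nightly toolchain:

```rust
// `AddBuilder` is the generated `$OpBuilder`; its `FnOnce` impl (emitted below)
// forwards the argument tuple to `Add::create(builder, span, lhs, rhs, overflow)`.
let add_builder = <Add as BuildableOp<_>>::builder(&mut builder, span);
// Output is `Result<UnsafeIntrusiveEntityRef<Add>, Report>`
let add = add_builder(lhs, rhs, Overflow::Checked)?;
```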
+ +pub struct BuildableOpImpl { + op: Ident, + op_builder: Ident, + op_generics: syn::Generics, + generics: syn::Generics, + required_generics: Option<syn::Generics>, + params: Rc<[OpCreateParam]>, +} +impl quote::ToTokens for BuildableOpImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + let op_builder = &self.op_builder; + + // Minimal builder (specify only required parameters) + // + // NOTE: This is only emitted if there are `default` parameters + if let Some(required_generics) = self.required_generics.as_ref() { + let required_params = + self.params.iter().flat_map(OpCreateParam::required_binding_types); + let (_, required_ty_generics, _) = self.op_generics.split_for_impl(); + let (required_impl_generics, _, required_where_clause) = + required_generics.split_for_impl(); + let required_params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(required_params), + }; + let quoted = quote! { + impl #required_impl_generics ::midenc_hir2::BuildableOp<#required_params_ty> for #op #required_ty_generics #required_where_clause { + type Builder<'a, T: ?Sized + ::midenc_hir2::Builder + 'a> = #op_builder <'a, T>; + + #[inline(always)] + fn builder<'b, B>(builder: &'b mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self::Builder<'b, B> + where + B: ?Sized + ::midenc_hir2::Builder + 'b, + { + #op_builder { + builder, + span, + } + } + } + }; + tokens.extend(quoted); + } + + // Maximal builder (specify all parameters) + let params = self.params.iter().flat_map(OpCreateParam::binding_types); + let (_, ty_generics, _) = self.op_generics.split_for_impl(); + let (impl_generics, _, where_clause) = self.generics.split_for_impl(); + let params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(params), + }; + let quoted = quote! { + impl #impl_generics ::midenc_hir2::BuildableOp<#params_ty> for #op #ty_generics #where_clause { + type Builder<'a, T: ?Sized + ::midenc_hir2::Builder + 'a> = #op_builder <'a, T>; + + #[inline(always)] + fn builder<'b, B>(builder: &'b mut B, span: ::midenc_session::diagnostics::SourceSpan) -> Self::Builder<'b, B> + where + B: ?Sized + ::midenc_hir2::Builder + 'b, + { + #op_builder { + builder, + span, + } + } + } + }; + tokens.extend(quoted); + } +} + +pub struct OpBuilderFnOnceImpl { + op: Ident, + op_builder: Ident, + generics: syn::Generics, + required_generics: Option<syn::Generics>, + params: Rc<[OpCreateParam]>, +} +impl quote::ToTokens for OpBuilderFnOnceImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + let op_builder = &self.op_builder; + + let create_param_names = + self.params.iter().flat_map(OpCreateParam::bindings).collect::<Vec<_>>(); + + // Minimal builder (specify only required parameters) + // + // NOTE: This is only emitted if there are `default` parameters + if let Some(required_generics) = self.required_generics.as_ref() { + let required_param_names = + self.params.iter().flat_map(OpCreateParam::required_bindings); + let defaulted_param_names = self.params.iter().flat_map(|param| { + if param.default { + param.bindings() + } else { + vec![] + } + }); + let required_param_types = + self.params.iter().flat_map(OpCreateParam::required_binding_types); + let (required_impl_generics, _required_ty_generics, required_where_clause) = + required_generics.split_for_impl(); + let required_params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(required_param_types), + }; + let required_params_bound = syn::PatTuple { + attrs: Default::default(), + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter( + required_param_names.into_iter().map(|id| { + syn::Pat::Ident(syn::PatIdent { + attrs: Default::default(), + by_ref: None, + mutability: None, + ident: id, + subpat: None, + }) + }), + ), + }; + tokens.extend(quote! { + impl #required_impl_generics ::core::ops::FnOnce<#required_params_ty> for #op_builder<'a, B> #required_where_clause { + type Output = Result<::midenc_hir2::UnsafeIntrusiveEntityRef<#op>, ::midenc_session::diagnostics::Report>; + + #[inline] + extern "rust-call" fn call_once(self, args: #required_params_ty) -> Self::Output { + let #required_params_bound = args; + #( + let #defaulted_param_names = Default::default(); + )* + <#op>::create(self.builder, self.span, #(#create_param_names),*) + } + } + }); + } + + // Maximal builder (specify all parameters) + let param_types = self.params.iter().flat_map(OpCreateParam::binding_types); + let (impl_generics, _ty_generics, where_clause) = self.generics.split_for_impl(); + let params_ty = syn::TypeTuple { + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(param_types), + }; + let params_bound = syn::PatTuple { + attrs: Default::default(), + paren_token: syn::token::Paren(op.span()), + elems: syn::punctuated::Punctuated::from_iter(create_param_names.iter().map(|id| { + syn::Pat::Ident(syn::PatIdent { + attrs: Default::default(), + by_ref: None, + mutability: None, + ident: id.clone(), + subpat: None, + }) + })), + }; + tokens.extend(quote!
{ + impl #impl_generics ::core::ops::FnOnce<#params_ty> for #op_builder<'a, B> #where_clause { + type Output = Result<::midenc_hir2::UnsafeIntrusiveEntityRef<#op>, ::midenc_session::diagnostics::Report>; + + #[inline] + extern "rust-call" fn call_once(self, args: #params_ty) -> Self::Output { + let #params_bound = args; + <#op>::create(self.builder, self.span, #(#create_param_names),*) + } + } + }); + } +} + +pub struct OpVerifierImpl { + op: Ident, + traits: darling::util::PathList, + implements: darling::util::PathList, +} +impl OpVerifierImpl { + pub fn new( + op: Ident, + traits: darling::util::PathList, + implements: darling::util::PathList, + ) -> Self { + Self { + op, + traits, + implements, + } + } +} +impl quote::ToTokens for OpVerifierImpl { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let op = &self.op; + if self.traits.is_empty() && self.implements.is_empty() { + tokens.extend(quote! { + /// No-op verifier implementation generated via `#[operation]` derive + /// + /// This implementation was chosen as no op traits were indicated as being derived _or_ + /// manually implemented by this type. + impl ::midenc_hir2::OpVerifier for #op { + #[inline(always)] + fn verify(&self, _context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + Ok(()) + } + } + }); + return; + } + + let op_verifier_doc_lines = { + let span = self.op.span(); + let mut lines = vec![ + syn::Lit::Str(syn::LitStr::new( + " Generated verifier implementation via `#[operation]` attribute", + span, + )), + syn::Lit::Str(syn::LitStr::new("", span)), + syn::Lit::Str(syn::LitStr::new(" Traits verified by this implementation:", span)), + syn::Lit::Str(syn::LitStr::new("", span)), + ]; + for derived_trait in self.traits.iter() { + lines.push(syn::Lit::Str(syn::LitStr::new( + &format!(" * [{}]", derived_trait.get_ident().unwrap()), + span, + ))); + } + for implemented_trait in self.implements.iter() { + lines.push(syn::Lit::Str(syn::LitStr::new( + &format!(" * [{}]", implemented_trait.get_ident().unwrap()), + span, + ))); + } + lines.push(syn::Lit::Str(syn::LitStr::new("", span))); + lines.push(syn::Lit::Str(syn::LitStr::new( + " Use `cargo-expand` to view the generated code if you suspect verification is \ + broken.", + span, + ))); + lines + }; + + let derived_traits = &self.traits; + let implemented_traits = &self.implements; + tokens.extend(quote! 
{ + #( + #[doc = #op_verifier_doc_lines] + )* + impl ::midenc_hir2::OpVerifier for #op { + fn verify(&self, context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + /// Type alias for the generated concrete verifier type + #[allow(unused_parens)] + type OpVerifierImpl<'a, T> = ::midenc_hir2::derive::DeriveVerifier<'a, T, (#(&'a dyn #derived_traits,)* #(&'a dyn #implemented_traits),*)>; + + #[allow(unused_parens)] + impl<'a> ::midenc_hir2::OpVerifier for OpVerifierImpl<'a, #op> + where + #( + #op: ::midenc_hir2::verifier::Verifier<dyn #derived_traits>, + )* + #( + #op: ::midenc_hir2::verifier::Verifier<dyn #implemented_traits>, + )* + { + #[inline] + fn verify(&self, context: &::midenc_hir2::Context) -> Result<(), ::midenc_session::diagnostics::Report> { + let op = self.downcast_ref::<#op>().unwrap(); + #( + if const { !<#op as ::midenc_hir2::verifier::Verifier<dyn #derived_traits>>::VACUOUS } { + <#op as ::midenc_hir2::verifier::Verifier<dyn #derived_traits>>::maybe_verify(op, context)?; + } + )* + #( + if const { !<#op as ::midenc_hir2::verifier::Verifier<dyn #implemented_traits>>::VACUOUS } { + <#op as ::midenc_hir2::verifier::Verifier<dyn #implemented_traits>>::maybe_verify(op, context)?; + } + )* + + Ok(()) + } + } + + let verifier = OpVerifierImpl::<#op>::new(&self.op); + verifier.verify(context) + } + } + }); + } +} + +/// Represents the parsed struct definition for the operation we wish to define +/// +/// Only named structs are allowed at this time. +#[derive(Debug, FromDeriveInput)] +#[darling( + attributes(operation), + supports(struct_named), + forward_attrs(doc, cfg, allow, derive) +)] +pub struct Operation { + ident: Ident, + vis: syn::Visibility, + generics: syn::Generics, + attrs: Vec<syn::Attribute>, + data: darling::ast::Data<(), OperationField>, + dialect: Ident, + #[darling(default)] + name: Option<String>, + #[darling(default)] + traits: darling::util::PathList, + #[darling(default)] + implements: darling::util::PathList, +} + +/// Represents a field in the input struct +#[derive(Debug, FromField)] +#[darling(forward_attrs( + doc, cfg, allow, attr, operand, operands, region, successor, successors, result, results, + default, order, symbol +))] +pub struct OperationField { + /// The name of this field. + /// + /// This will always be `Some`, as we do not support any types other than structs + ident: Option<Ident>, + /// The visibility assigned to this field + vis: syn::Visibility, + /// The type assigned to this field + ty: syn::Type, + /// The processed attributes of this field + #[darling(with = OperationFieldAttrs::new)] + attrs: OperationFieldAttrs, +} + +#[derive(Default, Debug)] +pub struct OperationFieldAttrs { + /// Attributes we don't care about, and are forwarding along untouched + forwarded: Vec<syn::Attribute>, + /// Whether or not to create instances of this op using the `Default` impl for this field + r#default: Flag, + /// Whether or not to assign an explicit order to this field. + /// + /// Once an explicit order has been assigned to a field, all subsequent fields must either have + /// an explicit order, or they will be assigned the next largest unallocated index in the order. + order: Option<usize>, + /// Was this an `#[attr]` field? + attr: Flag, + /// Was this an `#[operand]` field? + operand: Flag, + /// Was this an `#[operands]` field? + operands: Flag, + /// Was this a `#[result]` field? + result: Flag, + /// Was this a `#[results]` field? + results: Flag, + /// Was this a `#[region]` field? + region: Flag, + /// Was this a `#[successor]` field? + successor: Flag, + /// Was this a `#[successors]` field? + successors: Option<SpannedValue<SuccessorsType>>, + /// Was this a `#[symbol]` field? + symbol: Option<SpannedValue<Option<SymbolType>>>, +}
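Each field takes at most one of the decorators tracked above; the `prev_decorator` bookkeeping in `OperationFieldAttrs::new` (below) rejects combinations. A hypothetical field that mixes two decorators fails with the conflict error:

```rust
// Rejected by `OperationFieldAttrs::new` with:
//   "#[attr] conflicts with a previous #[operand] decorator"
#[operand]
#[attr]
value: AnyInteger,
```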
+ +impl OperationFieldAttrs { + pub fn new(attrs: Vec<syn::Attribute>) -> darling::Result<Self> { + let mut result = Self::default(); + let mut prev_decorator = None; + for attr in attrs { + if let Some(name) = attr.path().get_ident().map(|id| id.to_string()) { + match name.as_str() { + "attr" => { + if let Some(prev) = prev_decorator.replace("attr") { + return Err(Error::custom(format!( + "#[attr] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.attr = Flag::from_meta(&attr.meta).unwrap(); + } + "operand" => { + if let Some(prev) = prev_decorator.replace("operand") { + return Err(Error::custom(format!( + "#[operand] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.operand = Flag::from_meta(&attr.meta).unwrap(); + } + "operands" => { + if let Some(prev) = prev_decorator.replace("operands") { + return Err(Error::custom(format!( + "#[operands] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.operands = Flag::from_meta(&attr.meta).unwrap(); + } + "result" => { + if let Some(prev) = prev_decorator.replace("result") { + return Err(Error::custom(format!( + "#[result] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.result = Flag::from_meta(&attr.meta).unwrap(); + } + "results" => { + if let Some(prev) = prev_decorator.replace("results") { + return Err(Error::custom(format!( + "#[results] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.results = Flag::from_meta(&attr.meta).unwrap(); + } + "region" => { + if let Some(prev) = prev_decorator.replace("region") { + return Err(Error::custom(format!( + "#[region] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.region = Flag::from_meta(&attr.meta).unwrap(); + } + "successor" => { + if let Some(prev) = prev_decorator.replace("successor") { + return Err(Error::custom(format!( + "#[successor] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + result.successor = Flag::from_meta(&attr.meta).unwrap(); + } + "successors" => { + if let Some(prev) = prev_decorator.replace("successors") { + return Err(Error::custom(format!( + "#[successors] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + let span = attr.span(); + let mut succ_ty = SuccessorsType::Default; + match attr.parse_nested_meta(|meta| { + if meta.path.is_ident("keyed") { + succ_ty = SuccessorsType::Keyed; + Ok(()) + } else { + Err(meta.error(format!( + "invalid #[successors] decorator: unrecognized key '{}'", + meta.path.get_ident().unwrap() + ))) + } + }) { + Ok(_) => { + result.successors = Some(SpannedValue::new(succ_ty, span)); + } + Err(err) => { + return Err(Error::from(err)); + } + } + } + "symbol" => { + if let Some(prev) = prev_decorator.replace("symbol") { + return Err(Error::custom(format!( + "#[symbol] conflicts with a previous #[{prev}] decorator" + )) + .with_span(&attr)); + } + let span = attr.span(); + let mut symbol_ty = None; + match &attr.meta { + // A bare #[symbol], nothing to do + syn::Meta::Path(_) => (), + syn::Meta::List(ref list) => { + list.parse_nested_meta(|meta| { + if meta.path.is_ident("callable") { + symbol_ty = Some(SymbolType::Callable); + Ok(()) + } else if meta.path.is_ident("any") { + symbol_ty = Some(SymbolType::Any); + Ok(()) + } else if meta.path.is_ident("bounds") { + let symbol_bound = meta + .input + .parse::<SymbolTraitBound>() + .map_err(Error::from)?; + symbol_ty = Some(symbol_bound.into());
+ Ok(()) + } else { + Err(meta.error(format!( + "invalid #[symbol] decorator: unrecognized key '{}'", + meta.path.get_ident().unwrap() + ))) + } + }) + .map_err(Error::from)?; + } + meta @ syn::Meta::NameValue(_) => { + return Err(Error::custom( + "invalid #[symbol] decorator: invalid format, expected either \ bare 'symbol' or a meta list", + ) + .with_span(meta)); + } + } + result.symbol = Some(SpannedValue::new(symbol_ty, span)); + } + "default" => { + result.default = Flag::present(); + } + "order" => { + result.order = Some( + attr.parse_args::<syn::LitInt>() + .map_err(Error::from) + .and_then(|n| n.base10_parse::<usize>().map_err(Error::from))?, + ); + } + _ => { + result.forwarded.push(attr); + } + } + } else { + result.forwarded.push(attr); + } + } + + Ok(result) + } +} + +impl OperationFieldAttrs { + pub fn pseudo_type(&self) -> Option<SpannedValue<OperationFieldType>> { + use darling::util::SpannedValue; + if self.attr.is_present() { + Some(SpannedValue::new(OperationFieldType::Attr, self.attr.span())) + } else if self.operand.is_present() { + Some(SpannedValue::new(OperationFieldType::Operand, self.operand.span())) + } else if self.operands.is_present() { + Some(SpannedValue::new(OperationFieldType::Operands, self.operands.span())) + } else if self.result.is_present() { + Some(SpannedValue::new(OperationFieldType::Result, self.result.span())) + } else if self.results.is_present() { + Some(SpannedValue::new(OperationFieldType::Results, self.results.span())) + } else if self.region.is_present() { + Some(SpannedValue::new(OperationFieldType::Region, self.region.span())) + } else if self.successor.is_present() { + Some(SpannedValue::new(OperationFieldType::Successor, self.successor.span())) + } else if self.successors.is_some() { + self.successors.map(|succ| succ.map_ref(|s| OperationFieldType::Successors(*s))) + } else if self.symbol.is_some() { + self.symbol + .as_ref() + .map(|sym| sym.map_ref(|sym| OperationFieldType::Symbol(sym.clone()))) + } else { + None + } + } +} + +#[derive(Debug, Clone)] +pub enum OperationFieldType { + /// An operation attribute + Attr, + /// A named operand + Operand, + /// A named variadic operand group (zero or more operands) + Operands, + /// A named result + Result, + /// A named variadic result group (zero or more results) + Results, + /// A named region + Region, + /// A named successor + Successor, + /// A named variadic successor group (zero or more successors) + Successors(SuccessorsType), + /// A symbol operand + /// + /// Symbols are handled differently than regular operands, as they are not SSA values, and + /// are tracked using a different use/def graph than normal values. + /// + /// If the symbol type is `None`, it implies we should use the concrete field type as the + /// expected symbol type. Otherwise, use the provided symbol type to derive bounds for that + /// field. + Symbol(Option<SymbolType>), +}
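For the two `#[successors]` flavors distinguished above, a sketch of how they might appear on an op — the `Switch` op and its field types are hypothetical; a keyed group's element type must implement the `KeyedSuccessor` trait, per `SuccessorsType` below:

```rust
#[operation(dialect = HirDialect)]
pub struct Switch {
    #[operand]
    selector: AnyInteger,
    // Default flavor: each successor is a `BlockRef` plus forwarded `ValueRef` arguments
    #[successors]
    targets: SuccessorInfo,
    // Keyed flavor: each successor additionally carries a match key (e.g. a selector value)
    #[successors(keyed)]
    cases: SwitchCase,
}
```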
+impl core::fmt::Display for OperationFieldType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Attr => f.write_str("attr"), + Self::Operand => f.write_str("operand"), + Self::Operands => f.write_str("operands"), + Self::Result => f.write_str("result"), + Self::Results => f.write_str("results"), + Self::Region => f.write_str("region"), + Self::Successor => f.write_str("successor"), + Self::Successors(SuccessorsType::Default) => f.write_str("successors"), + Self::Successors(SuccessorsType::Keyed) => f.write_str("successors(keyed)"), + Self::Symbol(None) => f.write_str("symbol"), + Self::Symbol(Some(SymbolType::Any)) => f.write_str("symbol(any)"), + Self::Symbol(Some(SymbolType::Callable)) => f.write_str("symbol(callable)"), + Self::Symbol(Some(SymbolType::Concrete(_))) => write!(f, "symbol(concrete)"), + Self::Symbol(Some(SymbolType::Trait(_))) => write!(f, "symbol(trait)"), + } + } +} + +/// The type of successor group +#[derive(Default, Debug, darling::FromMeta, Copy, Clone)] +#[darling(default)] +pub enum SuccessorsType { + /// The default successor type consists of a `BlockRef` and an iterable of `ValueRef` + #[default] + Default, + /// A keyed successor is a custom type that implements the `KeyedSuccessor` trait + Keyed, +} + +/// Represents parameter information for `$Op::create` and the associated builder infrastructure. +#[derive(Debug)] +pub struct OpCreateParam { + /// The actual parameter type and payload + param_ty: OpCreateParamType, + /// Is this value initialized using `Default::default` when `Op::create` is called? + r#default: bool, +} + +#[derive(Debug)] +pub enum OpCreateParamType { + Attr(OpAttribute), + Operand(Operand), + #[allow(dead_code)] + OperandGroup(Ident, syn::Type), + CustomField(Ident, syn::Type), + Successor(Ident), + SuccessorGroupNamed(Ident), + SuccessorGroupKeyed(Ident, syn::Type), + Symbol(Symbol), +} +impl OpCreateParam { + /// Returns the names of all bindings implied by this parameter. + pub fn bindings(&self) -> Vec<Ident> { + match &self.param_ty { + OpCreateParamType::Attr(OpAttribute { name, .. }) + | OpCreateParamType::CustomField(name, _) + | OpCreateParamType::Operand(Operand { name, .. }) + | OpCreateParamType::OperandGroup(name, _) + | OpCreateParamType::SuccessorGroupNamed(name) + | OpCreateParamType::SuccessorGroupKeyed(name, _) + | OpCreateParamType::Symbol(Symbol { name, .. }) => vec![name.clone()], + OpCreateParamType::Successor(name) => { + vec![name.clone(), format_ident!("{}_args", name)] + } + } + } + + /// Returns the names of all required (i.e. non-defaulted) bindings implied by this parameter. + pub fn required_bindings(&self) -> Vec<Ident> { + if self.default { + return vec![]; + } + self.bindings() + } + + /// Returns the types assigned to the bindings returned by [Self::bindings] + pub fn binding_types(&self) -> Vec<syn::Type> { + match &self.param_ty { + OpCreateParamType::Attr(OpAttribute { ty, ..
}) + | OpCreateParamType::CustomField(_, ty) => { + vec![ty.clone()] + } + OpCreateParamType::Operand(_) => vec![make_type("::midenc_hir2::ValueRef")], + OpCreateParamType::OperandGroup(group_name, _) + | OpCreateParamType::SuccessorGroupNamed(group_name) + | OpCreateParamType::SuccessorGroupKeyed(group_name, _) => { + vec![make_type(format!("T{}", group_name.to_string().to_pascal_case()))] + } + OpCreateParamType::Successor(name) => vec![ + make_type("::midenc_hir2::BlockRef"), + make_type(format!("T{}Args", name.to_string().to_pascal_case())), + ], + OpCreateParamType::Symbol(Symbol { name, ty }) => match ty { + SymbolType::Any | SymbolType::Callable | SymbolType::Trait(_) => { + vec![make_type(format!("T{}", name.to_string().to_pascal_case()))] + } + SymbolType::Concrete(ty) => vec![ty.clone()], + }, + } + } + + /// Returns the types assigned to the bindings returned by [Self::required_bindings] + pub fn required_binding_types(&self) -> Vec<syn::Type> { + if self.default { + return vec![]; + } + self.binding_types() + } + + /// Returns the generic type parameters bound for use by the types in [Self::binding_types] + pub fn generic_types(&self) -> Vec<syn::GenericParam> { + match &self.param_ty { + OpCreateParamType::OperandGroup(name, _) => { + let value_iter_bound: syn::TypeParamBound = + syn::parse_str("IntoIterator<Item = ::midenc_hir2::ValueRef>").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::Successor(name) => { + let value_iter_bound: syn::TypeParamBound = + syn::parse_str("IntoIterator<Item = ::midenc_hir2::ValueRef>").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}Args", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::SuccessorGroupNamed(name) => { + let value_iter_bound: syn::TypeParamBound = syn::parse_str( + "IntoIterator<Item = (::midenc_hir2::BlockRef, Vec<::midenc_hir2::ValueRef>)>", + ) + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!( + "T{}", + &name.to_string().to_pascal_case(), + span = name.span() + ), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([value_iter_bound]), + eq_token: None, + r#default: None, + })] + } + OpCreateParamType::SuccessorGroupKeyed(name, ty) => { + let item_name = name.to_string().to_pascal_case(); + let iterator_ty = format_ident!("T{item_name}", span = name.span()); + vec![syn::parse_quote!
{ + #iterator_ty: IntoIterator<Item = #ty> + }] + } + OpCreateParamType::Symbol(Symbol { name, ty }) => match ty { + SymbolType::Any => { + let as_symbol_ref_bound = + syn::parse_str::<syn::TypeParamBound>("::midenc_hir2::AsSymbolRef") + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([as_symbol_ref_bound]), + eq_token: None, + r#default: None, + })] + } + SymbolType::Callable => { + let as_callable_symbol_ref_bound = + syn::parse_str::<syn::TypeParamBound>("::midenc_hir2::AsCallableSymbolRef") + .unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter([ + as_callable_symbol_ref_bound, + ]), + eq_token: None, + r#default: None, + })] + } + SymbolType::Concrete(_) => vec![], + SymbolType::Trait(bounds) => { + let as_symbol_ref_bound = syn::parse_str("::midenc_hir2::AsSymbolRef").unwrap(); + vec![syn::GenericParam::Type(syn::TypeParam { + attrs: vec![], + ident: format_ident!("T{}", name.to_string().to_pascal_case()), + colon_token: Some(syn::token::Colon(name.span())), + bounds: syn::punctuated::Punctuated::from_iter( + [as_symbol_ref_bound].into_iter().chain(bounds.iter().cloned()), + ), + eq_token: None, + r#default: None, + })] + } + }, + _ => vec![], + } + } + + /// Returns the generic type parameters bound for use by the types in [Self::required_binding_types] + pub fn required_generic_types(&self) -> Vec<syn::GenericParam> { + if self.default { + return vec![]; + } + self.generic_types() + } +} + +/// A symbol value +#[derive(Debug, Clone)] +pub struct Symbol { + pub name: Ident, + pub ty: SymbolType, +} + +/// Represents the type of a symbol +#[derive(Debug, Clone)] +pub enum SymbolType { + /// Any `Symbol` implementation can be used + Any, + /// Any `Symbol + CallableOpInterface` implementation can be used + Callable, + /// Only the specific concrete type can be used, it must implement the `Symbol` trait + Concrete(syn::Type), + /// Any implementation of the provided trait can be used. + /// + /// The given trait type _must_ have `Symbol` as a supertrait. + Trait(syn::punctuated::Punctuated<syn::TypeParamBound, Token![+]>), +}
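The four `SymbolType` cases map directly onto the `#[symbol]` decorator forms parsed earlier. A sketch, with the op and field types hypothetical:

```rust
#[operation(dialect = HirDialect)]
pub struct Exec {
    // Bare form (`Symbol(None)` upstream): the concrete field type is the
    // expected symbol type, and must itself implement `Symbol`
    #[symbol]
    function: FunctionRef,
    // `SymbolType::Any`: any `Symbol` implementation is accepted
    #[symbol(any)]
    module: SymbolRef,
    // `SymbolType::Callable`: any `Symbol + CallableOpInterface` implementation
    #[symbol(callable)]
    callee: SymbolRef,
    // `SymbolType::Trait`: any impl of the listed traits (supertrait of `Symbol`)
    #[symbol(bounds = CallableOpInterface)]
    target: SymbolRef,
}
```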
+ +struct SymbolTraitBound { + _eq_token: Token![=], + bounds: syn::punctuated::Punctuated<syn::TypeParamBound, Token![+]>, +} +impl syn::parse::Parse for SymbolTraitBound { + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { + let lookahead = input.lookahead1(); + if !lookahead.peek(Token![=]) { + return Err(lookahead.error()); + } + + let _eq_token = input.parse::<Token![=]>()?; + let bounds = syn::punctuated::Punctuated::parse_separated_nonempty(input)?; + + Ok(Self { _eq_token, bounds }) + } +} +impl From<SymbolTraitBound> for SymbolType { + #[inline] + fn from(value: SymbolTraitBound) -> Self { + SymbolType::Trait(value.bounds) + } +} + +pub struct DocString { + span: proc_macro2::Span, + doc: String, +} +impl DocString { + pub fn new(span: proc_macro2::Span, doc: String) -> Self { + Self { span, doc } + } +} +impl quote::ToTokens for DocString { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let attr = syn::Attribute { + pound_token: syn::token::Pound(self.span), + style: syn::AttrStyle::Outer, + bracket_token: syn::token::Bracket(self.span), + meta: syn::Meta::NameValue(syn::MetaNameValue { + path: attr_path("doc"), + eq_token: syn::token::Eq(self.span), + value: syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new(&self.doc, self.span)), + }), + }), + }; + + attr.to_tokens(tokens); + } +} + +#[derive(Copy, Clone)] +enum PathStyle { + Default, + Absolute, +} + +fn make_type(s: impl AsRef<str>) -> syn::Type { + let s = s.as_ref(); + let path = type_path(s); + syn::Type::Path(syn::TypePath { qself: None, path }) +} + +fn type_path(s: impl AsRef<str>) -> syn::Path { + let s = s.as_ref(); + let (s, style) = if let Some(s) = s.strip_prefix("::") { + (s, PathStyle::Absolute) + } else { + (s, PathStyle::Default) + }; + let parts = s.split("::"); + make_path(parts, style) +} + +fn attr_path(s: impl AsRef<str>) -> syn::Path { + make_path([s.as_ref()], PathStyle::Default) +} + +fn make_path<'a>(parts: impl IntoIterator<Item = &'a str>, style: PathStyle) -> syn::Path { + use proc_macro2::Span; + + syn::Path { + leading_colon: match style { + PathStyle::Default => None, + PathStyle::Absolute => Some(syn::token::PathSep(Span::call_site())), + }, + segments: syn::punctuated::Punctuated::from_iter(parts.into_iter().map(|part| { + syn::PathSegment { + ident: format_ident!("{}", part), + arguments: syn::PathArguments::None, + } + })), + } +} + +#[cfg(test)] +mod tests { + #![allow(dead_code)] + + #[test] + fn operation_impl_test() { + let item_input: syn::DeriveInput = syn::parse_quote!
{ + /// Two's complement sum + #[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface), + )] + pub struct Add { + /// The left-hand operand + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, + } + }; + + let output = super::derive_operation(item_input); + match output { + Ok(output) => { + let formatted = format_output(&output.to_string()); + println!("{formatted}"); + } + Err(err) => { + panic!("{err}"); + } + } + } + + fn format_output(input: &str) -> String { + use std::{ + io::{Read, Write}, + process::{Command, Stdio}, + }; + + let mut child = Command::new("rustfmt") + .args(["+nightly", "--edition", "2024"]) + .args([ + "--config", + "unstable_features=true,normalize_doc_attributes=true,\ use_field_init_shorthand=true,condense_wildcard_suffixes=true,\ format_strings=true,group_imports=StdExternalCrate,imports_granularity=Crate", + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("failed to spawn 'rustfmt'"); + + { + let mut stdin = child.stdin.take().unwrap(); + stdin.write_all(input.as_bytes()).expect("failed to write input to 'rustfmt'"); + } + let mut buf = String::new(); + let mut stdout = child.stdout.take().unwrap(); + stdout.read_to_string(&mut buf).expect("failed to read output from 'rustfmt'"); + match child.wait() { + Ok(status) => { + if status.success() { + buf + } else { + let mut stderr = child.stderr.take().unwrap(); + let mut err_buf = String::new(); + let _ = stderr.read_to_string(&mut err_buf).ok(); + panic!( + "command 'rustfmt' failed with status {:?}\n\nReason: {}", + status.code(), + if err_buf.is_empty() { + "" + } else { + err_buf.as_str() + }, + ); + } + } + Err(err) => panic!("command 'rustfmt' failed with {err}"), + } + } +} diff --git a/hir-symbol/Cargo.toml b/hir-symbol/Cargo.toml index 6274ede9d..7bd8ace77 100644 --- a/hir-symbol/Cargo.toml +++ b/hir-symbol/Cargo.toml @@ -12,13 +12,20 @@ readme.workspace = true edition.workspace = true [features] -default = [] +default = ["std"] +std = ["dep:parking_lot"] serde = ["dep:serde"] +compact_str = ["dep:compact_str"] [dependencies] +compact_str = { workspace = true, optional = true } +lock_api = "0.4" +miden-formatting.workspace = true +parking_lot = { version = "0.12", optional = true } serde = { workspace = true, optional = true } [build-dependencies] Inflector.workspace = true +hashbrown.workspace = true rustc-hash.workspace = true toml.workspace = true diff --git a/hir-symbol/build.rs b/hir-symbol/build.rs index 96d42798a..8144efcc3 100644 --- a/hir-symbol/build.rs +++ b/hir-symbol/build.rs @@ -1,3 +1,4 @@ +extern crate hashbrown; extern crate inflector; extern crate rustc_hash; extern crate toml; @@ -12,9 +13,10 @@ use std::{ }; use inflector::Inflector; -use rustc_hash::FxHashSet; use toml::{value::Table, Value}; +type FxHashSet<V> = hashbrown::HashSet<V, rustc_hash::FxBuildHasher>; + #[derive(Debug, Default, Clone)] struct Symbol { key: String, diff --git a/hir-symbol/src/lib.rs b/hir-symbol/src/lib.rs index 7d24dc6f6..a1bc3cfb1 100644 --- a/hir-symbol/src/lib.rs +++ b/hir-symbol/src/lib.rs @@ -1,26 +1,31 @@ -use core::{fmt, mem, ops::Deref, str}; -use std::{ +#![no_std] + +extern crate alloc; +#[cfg(feature = "std")] +extern crate std; + +pub mod sync; + +use alloc::{ + boxed::Box, collections::BTreeMap, - sync::{OnceLock, RwLock}, + string::{String, ToString}, + vec::Vec, }; +use core::{fmt, mem, ops::Deref, str}; -static SYMBOL_TABLE:
OnceLock<SymbolTable> = OnceLock::new(); +use miden_formatting::prettier::PrettyPrint; pub mod symbols { include!(env!("SYMBOLS_RS")); } +static SYMBOL_TABLE: sync::LazyLock<SymbolTable> = sync::LazyLock::new(SymbolTable::default); + +#[derive(Default)] struct SymbolTable { - interner: RwLock<Interner>, -} -impl SymbolTable { - pub fn new() -> Self { - Self { - interner: RwLock::new(Interner::new()), - } - } + interner: sync::RwLock<Interner>, } -unsafe impl Sync for SymbolTable {} /// A symbol is an interned string. #[derive(Clone, Copy, PartialEq, Eq, Hash)] @@ -68,8 +73,8 @@ impl Symbol { } /// Maps a string to its interned representation. - pub fn intern<S: Into<String>>(string: S) -> Self { - let string = string.into(); + pub fn intern(string: impl ToString) -> Self { + let string = string.to_string(); with_interner(|interner| interner.intern(string)) } @@ -107,6 +112,12 @@ impl fmt::Display for Symbol { fmt::Display::fmt(&self.as_str(), f) } } +impl PrettyPrint for Symbol { + fn render(&self) -> miden_formatting::prettier::Document { + use miden_formatting::prettier::*; + const_text(self.as_str()) + } +} impl PartialOrd for Symbol { fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { Some(self.cmp(other)) } } @@ -117,11 +128,53 @@ impl Ord for Symbol { self.as_str().cmp(other.as_str()) } } +impl AsRef<str> for Symbol { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} +impl core::borrow::Borrow<str> for Symbol { + #[inline(always)] + fn borrow(&self) -> &str { + self.as_str() + } +} impl<T: Deref<Target = str>> PartialEq<T> for Symbol { fn eq(&self, other: &T) -> bool { self.as_str() == other.deref() } } +impl From<&'static str> for Symbol { + fn from(s: &'static str) -> Self { + with_interner(|interner| interner.insert(s)) + } +} +impl From<String> for Symbol { + fn from(s: String) -> Self { + Self::intern(s) + } +} +impl From<Box<str>> for Symbol { + fn from(s: Box<str>) -> Self { + Self::intern(s) + } +} +impl From<alloc::borrow::Cow<'static, str>> for Symbol { + fn from(s: alloc::borrow::Cow<'static, str>) -> Self { + use alloc::borrow::Cow; + match s { + Cow::Borrowed(s) => s.into(), + Cow::Owned(s) => Self::intern(s), + } + } +} +#[cfg(feature = "compact_str")] +impl From<compact_str::CompactString> for Symbol { + fn from(s: compact_str::CompactString) -> Self { + Self::intern(s.into_string()) + } } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] struct SymbolIndex(u32); @@ -159,22 +212,26 @@ impl From<SymbolIndex> for usize { } } -#[derive(Default)] struct Interner { pub names: BTreeMap<&'static str, Symbol>, pub strings: Vec<&'static str>, } -impl Interner { - pub fn new() -> Self { - let mut this = Interner::default(); +impl Default for Interner { + fn default() -> Self { + let mut this = Self { + names: BTreeMap::default(), + strings: Vec::with_capacity(symbols::__SYMBOLS.len()), + }; for (sym, s) in symbols::__SYMBOLS { this.names.insert(s, *sym); this.strings.push(s); } this } +} +impl Interner { pub fn intern(&mut self, string: String) -> Symbol { if let Some(&name) = self.names.get(string.as_str()) { return name; @@ -189,7 +246,17 @@ impl Interner { name } - pub fn get(&self, symbol: Symbol) -> &str { + pub fn insert(&mut self, s: &'static str) -> Symbol { + if let Some(&name) = self.names.get(s) { + return name; + } + let name = Symbol::new(self.strings.len() as u32); + self.strings.push(s); + self.names.insert(s, name); + name + } + + pub fn get(&self, symbol: Symbol) -> &'static str { self.strings[symbol.0.as_usize()] } }
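A minimal sketch of the interning contract implemented above: equal strings intern to the same `Symbol` index, pre-seeded `__SYMBOLS` entries resolve without allocating, and `&'static str` conversions take the allocation-free `insert` path:

```rust
use midenc_hir_symbol::Symbol;

let a = Symbol::intern("my_function");
let b = Symbol::intern(String::from("my_function"));
assert_eq!(a, b); // same interned index; `Symbol` is `Copy`
assert_eq!(a.as_str(), "my_function");

// `From<&'static str>` routes through `Interner::insert`, which stores the
// static reference directly instead of copying it to the heap first
let c: Symbol = "my_function".into();
assert_eq!(a, c);
```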
@@ -197,14 +264,12 @@ impl Interner { // If an interner exists, return it. Otherwise, prepare a fresh one. #[inline] fn with_interner<T, F: FnOnce(&mut Interner) -> T>(f: F) -> T { - let table = SYMBOL_TABLE.get_or_init(SymbolTable::new); - let mut r = table.interner.write().unwrap(); - f(&mut r) + let mut table = SYMBOL_TABLE.interner.write(); + f(&mut table) } #[inline] fn with_read_only_interner<T, F: FnOnce(&Interner) -> T>(f: F) -> T { - let table = SYMBOL_TABLE.get_or_init(SymbolTable::new); - let r = table.interner.read().unwrap(); - f(&r) + let table = SYMBOL_TABLE.interner.read(); + f(&table) } diff --git a/hir-symbol/src/symbols.toml b/hir-symbol/src/symbols.toml index bababf11f..e6b40f87a 100644 --- a/hir-symbol/src/symbols.toml +++ b/hir-symbol/src/symbols.toml @@ -88,3 +88,6 @@ treeify = { value = "treeify" } [attributes] entrypoint = {} + +[dialects] +hir = {} diff --git a/hir-symbol/src/sync.rs b/hir-symbol/src/sync.rs new file mode 100644 index 000000000..ffcab049a --- /dev/null +++ b/hir-symbol/src/sync.rs @@ -0,0 +1,20 @@ +#[cfg(all(not(feature = "std"), target_family = "wasm"))] +mod lazy_lock; +#[cfg(all(not(feature = "std"), target_family = "wasm"))] +mod rw_lock; + +#[cfg(feature = "std")] +pub use std::sync::LazyLock; + +#[cfg(feature = "std")] +#[allow(unused)] +pub use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +#[cfg(all(not(feature = "std"), target_family = "wasm"))] +pub use self::lazy_lock::RacyLock as LazyLock; +#[cfg(all(not(feature = "std"), target_family = "wasm"))] +#[allow(unused)] +pub use self::rw_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +#[cfg(all(not(feature = "std"), not(target_family = "wasm")))] +compile_error!("no_std builds of this crate are only supported on wasm targets"); diff --git a/hir-symbol/src/sync/lazy_lock.rs b/hir-symbol/src/sync/lazy_lock.rs new file mode 100644 index 000000000..157cd41ed --- /dev/null +++ b/hir-symbol/src/sync/lazy_lock.rs @@ -0,0 +1,187 @@ +use alloc::boxed::Box; +use core::{ + fmt, + ops::Deref, + ptr, + sync::atomic::{AtomicPtr, Ordering}, +}; + +/// Thread-safe, non-blocking, lazily evaluated lock with the same interface +/// as [`std::sync::LazyLock`]. +/// +/// Concurrent threads will race to set the value atomically, and memory allocated by losing threads +/// will be dropped immediately after they fail to set the pointer. +/// +/// The underlying implementation is based on `once_cell::race::OnceBox` which relies on +/// [`core::sync::atomic::AtomicPtr`] to ensure that the data race results in a single successful +/// write to the relevant pointer, namely the first write. +/// See the `once_cell` source for details. +/// +/// Performs lazy evaluation and can be used for statics. +pub struct RacyLock<T, F = fn() -> T> +where + F: Fn() -> T, +{ + inner: AtomicPtr<T>, + f: F, +} + +impl<T, F> RacyLock<T, F> +where + F: Fn() -> T, +{ + /// Creates a new lazy, racy value with the given initializing function. + pub const fn new(f: F) -> Self { + Self { + inner: AtomicPtr::new(ptr::null_mut()), + f, + } + } + + /// Forces the evaluation of the locked value and returns a reference to + /// the result. This is equivalent to the [`Self::deref`]. + /// + /// There is no blocking involved in this operation. Instead, concurrent + /// threads will race to set the underlying pointer. Memory allocated by + /// losing threads will be dropped immediately after they fail to set the pointer. + /// + /// This function's interface is designed around [`std::sync::LazyLock::force`] but + /// the implementation is derived from `once_cell::race::OnceBox::get_or_try_init`. + pub fn force(this: &RacyLock<T, F>) -> &T { + let mut ptr = this.inner.load(Ordering::Acquire); + + // Pointer is not yet set, attempt to set it ourselves.
+ if ptr.is_null() { + // Execute the initialization function and allocate. + let val = (this.f)(); + ptr = Box::into_raw(Box::new(val)); + + // Attempt atomic store. + let exchange = this.inner.compare_exchange( + ptr::null_mut(), + ptr, + Ordering::AcqRel, + Ordering::Acquire, + ); + + // Pointer already set, load. + if let Err(old) = exchange { + drop(unsafe { Box::from_raw(ptr) }); + ptr = old; + } + } + + unsafe { &*ptr } + } +} + +impl<T: Default> Default for RacyLock<T> { + /// Creates a new lock that will evaluate the underlying value based on `T::default`. + #[inline] + fn default() -> RacyLock<T> { + RacyLock::new(T::default) + } +} + +impl<T, F> fmt::Debug for RacyLock<T, F> +where + T: fmt::Debug, + F: Fn() -> T, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RacyLock({:?})", self.inner.load(Ordering::Relaxed)) + } +} + +impl<T, F> Deref for RacyLock<T, F> +where + F: Fn() -> T, +{ + type Target = T; + + /// Either sets or retrieves the value, and dereferences it. + /// + /// See [`Self::force`] for more details. + #[inline] + fn deref(&self) -> &T { + RacyLock::force(self) + } +} + +impl<T, F> Drop for RacyLock<T, F> +where + F: Fn() -> T, +{ + /// Drops the underlying pointer. + fn drop(&mut self) { + let ptr = *self.inner.get_mut(); + if !ptr.is_null() { + // SAFETY: for any given value of `ptr`, we are guaranteed to have at most a single + // instance of `RacyLock` holding that value. Hence, synchronizing threads + // in `drop()` is not necessary, and we are guaranteed never to double-free. + // In short, since `RacyLock` doesn't implement `Clone`, the only scenario + // where there can be multiple instances of `RacyLock` across multiple threads + // referring to the same `ptr` value is when `RacyLock` is used in a static variable. + drop(unsafe { Box::from_raw(ptr) }); + } + } +} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use super::*; + + #[test] + fn deref_default() { + // Lock a copy type and validate default value. + let lock: RacyLock<i32> = RacyLock::default(); + assert_eq!(*lock, 0); + } + + #[test] + fn deref_copy() { + // Lock a copy type and validate value. + let lock = RacyLock::new(|| 42); + assert_eq!(*lock, 42); + } + + #[test] + fn deref_clone() { + // Lock a no copy type. + let lock = RacyLock::new(|| Vec::from([1, 2, 3])); + + // Use the value so that the compiler forces us to clone. + let mut v = lock.clone(); + v.push(4); + + // Validate the value. + assert_eq!(v, Vec::from([1, 2, 3, 4])); + } + + #[test] + fn deref_static() { + // Create a static lock. + static VEC: RacyLock<Vec<i32>> = RacyLock::new(|| Vec::from([1, 2, 3])); + + // Validate that the address of the value does not change. + let addr = &*VEC as *const Vec<i32>; + for _ in 0..5 { + assert_eq!(*VEC, [1, 2, 3]); + assert_eq!(addr, &(*VEC) as *const Vec<i32>) + } + } + + #[test] + fn type_inference() { + // Check that we can infer `T` from closure's type. + let _ = RacyLock::new(|| ()); + } + + #[test] + fn is_sync_send() { + fn assert_traits<T: Sync + Send>() {} + assert_traits::<RacyLock<Vec<i32>>>(); + } +}
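Before the reader-writer lock itself, a usage sketch for the `RacyLock` above: it stands in for `std::sync::LazyLock` on no-std wasm targets (where it is re-exported as `sync::LazyLock`), trading blocking for an initialize-and-race in which the closure may run more than once, but exactly one result is published:

```rust
use midenc_hir_symbol::sync::LazyLock;

// On no-std wasm this is `RacyLock`; with the `std` feature it is the real
// `std::sync::LazyLock`. Either way, the first deref forces initialization.
static PRIMES: LazyLock<Vec<u32>> = LazyLock::new(|| vec![2, 3, 5, 7]);

fn nth_prime(n: usize) -> u32 {
    PRIMES[n] // losing racers drop their allocation; all callers see one value
}
```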
diff --git a/hir-symbol/src/sync/rw_lock.rs b/hir-symbol/src/sync/rw_lock.rs new file mode 100644 index 000000000..66b61bf0a --- /dev/null +++ b/hir-symbol/src/sync/rw_lock.rs @@ -0,0 +1,220 @@ +use core::{ + hint, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use lock_api::RawRwLock; + +/// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible +/// +/// See [lock_api::RwLock] for usage. +pub type RwLock<T> = lock_api::RwLock<Spinlock, T>; + +/// See [lock_api::RwLockReadGuard] for usage. +pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>; + +/// See [lock_api::RwLockWriteGuard] for usage. +pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>; + +/// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock] +/// +/// This is fundamentally a spinlock, in that blocking operations on the lock will spin until +/// they succeed in acquiring/releasing the lock. +/// +/// To achieve the ability to share the underlying data with multiple readers, or hold +/// exclusive access for one writer, the lock state is based on a "locked" count, where shared +/// access increments the count by an even number, and acquiring exclusive access relies on the +/// use of the lowest order bit to stop further shared acquisition, and indicate that the lock +/// is exclusively held (the difference between the two is irrelevant from the perspective of +/// a thread attempting to acquire the lock, but internally the state uses `usize::MAX` as the +/// "exclusively locked" sentinel). +/// +/// This mechanism gets us the following: +/// +/// * Whether the lock has been acquired (shared or exclusive) +/// * Whether the lock is being exclusively acquired +/// * How many times the lock has been acquired +/// * Whether the acquisition(s) are exclusive or shared +/// +/// Further implementation details, such as how we manage draining readers once an attempt to +/// exclusively acquire the lock occurs, are described below. +/// +/// NOTE: This is a simple implementation, meant for use in no-std environments; there are much +/// more robust/performant implementations available when OS primitives can be used. +pub struct Spinlock { + /// The state of the lock, primarily representing the acquisition count, but relying on + /// the distinction between even and odd values to indicate whether or not exclusive access + /// is being acquired. + state: AtomicUsize, + /// A counter used to wake a parked writer once the last shared lock is released during + /// acquisition of an exclusive lock. The count value itself is not actually important, and + /// simply wraps around on overflow, but what is important is that when the value changes, + /// the writer will wake and resume attempting to acquire the exclusive lock. + writer_wake_counter: AtomicUsize, +} + +impl Default for Spinlock { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + +impl Spinlock { + pub const fn new() -> Self { + Self { + state: AtomicUsize::new(0), + writer_wake_counter: AtomicUsize::new(0), + } + } +}
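A short illustration of the state encoding described above, using the public `RwLock` alias built on this primitive (a sketch only; the state values are internal and shown in comments):

```rust
let lock: RwLock<u32> = RwLock::new(0);

let r1 = lock.read();                 // state: 0 -> 2 (one reader)
let r2 = lock.read();                 // state: 2 -> 4 (two readers)
assert!(lock.try_write().is_none()); // exclusive acquisition refused while readers live
drop((r1, r2));                       // state: back to 0

let mut w = lock.write();             // state: usize::MAX while exclusively held
*w = 42;
drop(w);                              // state: 0, and the wake counter is bumped
```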
+unsafe impl RawRwLock for Spinlock { + type GuardMarker = lock_api::GuardSend; + + // This is intentional on the part of the [RawRwLock] API, basically a hack to provide + // initial values as static items. + #[allow(clippy::declare_interior_mutable_const)] + const INIT: Spinlock = Spinlock::new(); + + /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired + fn lock_shared(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // If the exclusive bit is unset, attempt to acquire a read lock + if s & 1 == 0 { + match self.state.compare_exchange_weak( + s, + s + 2, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + // Someone else beat us to the punch and acquired a lock + Err(e) => s = e, + } + } + // If an exclusive lock is held/being acquired, loop until the lock state changes + // at which point, try to acquire the lock again + if s & 1 == 1 { + loop { + let next = self.state.load(Ordering::Relaxed); + if s == next { + hint::spin_loop(); + continue; + } else { + s = next; + break; + } + } + } + } + } + + /// The operation invoked when calling `RwLock::try_read`, returns whether or not the + /// lock was acquired + fn try_lock_shared(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s & 1 == 0 { + self.state + .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockReadGuard` + unsafe fn unlock_shared(&self) { + if self.state.fetch_sub(2, Ordering::Release) == 3 { + // The lock is being exclusively acquired, and we're the last shared acquisition + // to be released, so wake the writer by incrementing the wake counter + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } + } + + /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired + fn lock_exclusive(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // Attempt to acquire the lock immediately, or complete acquisition of the lock + // if we're continuing the loop after acquiring the exclusive bit. If another + // thread acquired it first, we race to be the first thread to acquire it once + // released, by busy looping here. + if s <= 1 { + match self.state.compare_exchange( + s, + usize::MAX, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(e) => { + s = e; + hint::spin_loop(); + continue; + } + } + } + + // Only shared locks have been acquired, attempt to acquire the exclusive bit, + // which will prevent further shared locks from being acquired. It does not + // in and of itself grant us exclusive access however. + if s & 1 == 0 { + if let Err(e) = + self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed) + { + // The lock state has changed before we could acquire the exclusive bit, + // update our view of the lock state and try again + s = e; + continue; + } + } + + // We've acquired the exclusive bit, now we need to busy wait until all shared + // acquisitions are released. + let w = self.writer_wake_counter.load(Ordering::Acquire); + s = self.state.load(Ordering::Relaxed); + + // "Park" the thread here (by busy looping), until the release of the last shared + // lock, which is communicated to us by it incrementing the wake counter. + if s >= 2 { + while self.writer_wake_counter.load(Ordering::Acquire) == w { + hint::spin_loop(); + } + s = self.state.load(Ordering::Relaxed); + } + + // All shared locks have been released, go back to the top and try to complete + // acquisition of exclusive access.
+ } + } + + /// The operation invoked when calling `RwLock::try_write`, returns whether or not the + /// lock was acquired + fn try_lock_exclusive(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s <= 1 { + self.state + .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockWriteGuard` + unsafe fn unlock_exclusive(&self) { + // Infallible, as we hold an exclusive lock + // + // Note the use of `Release` ordering here, which ensures any loads of the lock state + // by other threads, are ordered after this store. + self.state.store(0, Ordering::Release); + // This fetch_add isn't important for signaling purposes, however it serves a key + // purpose, in that it imposes a memory ordering on any loads of this field that + // have an `Acquire` ordering, i.e. they will read the value stored here. Without + // a `Release` store, loads/stores of this field could be reordered relative to + // each other. + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } +} diff --git a/hir-transform/Cargo.toml b/hir-transform/Cargo.toml index 88b3f959d..f9b8a8b5e 100644 --- a/hir-transform/Cargo.toml +++ b/hir-transform/Cargo.toml @@ -18,7 +18,6 @@ log.workspace = true midenc-hir.workspace = true midenc-hir-analysis.workspace = true midenc-session.workspace = true -rustc-hash.workspace = true smallvec.workspace = true [dev-dependencies] diff --git a/hir-transform/src/adt/scoped_map.rs b/hir-transform/src/adt/scoped_map.rs index fe6562b11..f1b9dca7d 100644 --- a/hir-transform/src/adt/scoped_map.rs +++ b/hir-transform/src/adt/scoped_map.rs @@ -1,6 +1,6 @@ use std::{borrow::Borrow, fmt, hash::Hash, rc::Rc}; -use rustc_hash::FxHashMap; +use midenc_hir::FxHashMap; #[derive(Clone)] pub struct ScopedMap diff --git a/hir-transform/src/inline_blocks.rs b/hir-transform/src/inline_blocks.rs index aa89c9522..b18467f7a 100644 --- a/hir-transform/src/inline_blocks.rs +++ b/hir-transform/src/inline_blocks.rs @@ -7,7 +7,6 @@ use midenc_hir::{ }; use midenc_hir_analysis::ControlFlowGraph; use midenc_session::{diagnostics::IntoDiagnostic, Session}; -use rustc_hash::FxHashSet; use smallvec::SmallVec; use crate::adt::ScopedMap; diff --git a/hir-transform/src/spill.rs b/hir-transform/src/spill.rs index b9e69ae10..effbcd8a0 100644 --- a/hir-transform/src/spill.rs +++ b/hir-transform/src/spill.rs @@ -11,7 +11,6 @@ use midenc_hir_analysis::{ spill::Placement, ControlFlowGraph, DominanceFrontier, DominatorTree, SpillAnalysis, Use, User, }; use midenc_session::{diagnostics::IntoDiagnostic, Emit, Session}; -use rustc_hash::FxHashSet; /// This pass places spills of SSA values to temporaries to cap the depth of the operand stack. /// diff --git a/hir-transform/src/split_critical_edges.rs b/hir-transform/src/split_critical_edges.rs index 1cef2d1e9..80113c99f 100644 --- a/hir-transform/src/split_critical_edges.rs +++ b/hir-transform/src/split_critical_edges.rs @@ -7,7 +7,6 @@ use midenc_hir::{ }; use midenc_hir_analysis::ControlFlowGraph; use midenc_session::{diagnostics::IntoDiagnostic, Session}; -use rustc_hash::FxHashSet; use smallvec::SmallVec; /// This pass breaks any critical edges in the CFG of a function. 
diff --git a/hir-type/Cargo.toml b/hir-type/Cargo.toml index bee94a9d2..7706f8b3c 100644 --- a/hir-type/Cargo.toml +++ b/hir-type/Cargo.toml @@ -16,6 +16,7 @@ default = ["serde"] serde = ["dep:serde", "dep:serde_repr"] [dependencies] +miden-formatting.workspace = true smallvec.workspace = true serde = { workspace = true, optional = true } serde_repr = { workspace = true, optional = true } diff --git a/hir-type/src/lib.rs b/hir-type/src/lib.rs index c19c93c2d..7fc802c08 100644 --- a/hir-type/src/lib.rs +++ b/hir-type/src/lib.rs @@ -7,6 +7,8 @@ mod layout; use alloc::{boxed::Box, vec::Vec}; use core::{fmt, num::NonZeroU16, str::FromStr}; +use miden_formatting::prettier::PrettyPrint; + pub use self::layout::Alignable; /// Represents the type of a value @@ -217,6 +219,11 @@ impl Type { matches!(self, Self::Array(_, _)) } + #[inline] + pub fn is_list(&self) -> bool { + matches!(self, Self::List(_)) + } + /// Returns true if `self` and `other` are compatible operand types for a binary operator, e.g. /// `add` /// @@ -328,6 +335,13 @@ impl fmt::Display for Type { } } } +impl PrettyPrint for Type { + fn render(&self) -> miden_formatting::prettier::Document { + use miden_formatting::prettier::*; + + display(self) + } +} /// This represents metadata about how a structured type will be represented in memory #[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)] diff --git a/hir/Cargo.toml b/hir/Cargo.toml index d486dc05e..cf82a9638 100644 --- a/hir/Cargo.toml +++ b/hir/Cargo.toml @@ -32,6 +32,7 @@ intrusive-collections.workspace = true inventory.workspace = true lalrpop-util = "0.20" log.workspace = true +hashbrown.workspace = true miden-core.workspace = true miden-assembly.workspace = true midenc-hir-symbol.workspace = true diff --git a/hir/src/asm/import.rs b/hir/src/asm/import.rs index 2a0a2f844..7fb2b8499 100644 --- a/hir/src/asm/import.rs +++ b/hir/src/asm/import.rs @@ -5,11 +5,10 @@ use core::{ }; use anyhow::bail; -use rustc_hash::{FxHashMap, FxHashSet}; use crate::{ diagnostics::{SourceSpan, Spanned}, - FunctionIdent, Ident, Symbol, + FunctionIdent, FxHashMap, FxHashSet, Ident, Symbol, }; #[derive(Default, Debug, Clone)] @@ -32,7 +31,7 @@ impl ModuleImportInfo { /// /// NOTE: It is assumed that the caller is adding imports using fully-qualified names. 
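+ ///
+ /// A usage sketch (illustrative only; `callee` is assumed to be a fully-qualified
+ /// `FunctionIdent`):
+ ///
+ /// ```text
+ /// let mut imports = ModuleImportInfo::default();
+ /// imports.add(callee);
+ /// ```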
pub fn add(&mut self, id: FunctionIdent) {
- use std::collections::hash_map::Entry;
+ use hashbrown::hash_map::Entry;
 
 let module_id = id.module;
 match self.modules.entry(module_id) {
diff --git a/hir/src/dataflow.rs b/hir/src/dataflow.rs
index 1f5d5a4e2..f0b575cb4 100644
--- a/hir/src/dataflow.rs
+++ b/hir/src/dataflow.rs
@@ -1,7 +1,6 @@
 use core::ops::{Deref, DerefMut, Index, IndexMut};
 
 use cranelift_entity::{PrimaryMap, SecondaryMap};
-use rustc_hash::FxHashMap;
 use smallvec::SmallVec;
 
 use crate::{
@@ -114,7 +113,7 @@ impl DataFlowGraph {
 name: Ident,
 signature: Signature,
 ) -> Result {
- use std::collections::hash_map::Entry;
+ use hashbrown::hash_map::Entry;
 
 let id = FunctionIdent {
 module,
diff --git a/hir/src/lib.rs b/hir/src/lib.rs
index 20c1d622d..c9781baf0 100644
--- a/hir/src/lib.rs
+++ b/hir/src/lib.rs
@@ -27,6 +27,10 @@ pub use midenc_hir_type::{
 };
 pub use midenc_session::diagnostics::{self, SourceSpan};
 
+pub type FxHashMap<K, V> = hashbrown::HashMap<K, V, FxBuildHasher>;
+pub type FxHashSet<V> = hashbrown::HashSet<V, FxBuildHasher>;
+pub use rustc_hash::{FxBuildHasher, FxHasher};
+
 /// Represents a field element in Miden
 pub type Felt = miden_core::Felt;
 
diff --git a/hir/src/module.rs b/hir/src/module.rs
index ff90c3858..d695a1654 100644
--- a/hir/src/module.rs
+++ b/hir/src/module.rs
@@ -5,7 +5,6 @@ use intrusive_collections::{
 linked_list::{Cursor, CursorMut},
 LinkedList, LinkedListLink, RBTreeLink,
 };
-use rustc_hash::FxHashSet;
 
 use self::formatter::PrettyPrint;
 use crate::{
diff --git a/hir/src/parser/ast/convert.rs b/hir/src/parser/ast/convert.rs
index 30a75c24f..fc35afea0 100644
--- a/hir/src/parser/ast/convert.rs
+++ b/hir/src/parser/ast/convert.rs
@@ -650,7 +650,7 @@ fn try_insert_inst(
 module: function.id.module,
 };
 if let Some(sig) = functions_by_id.get(&local) {
- use std::collections::hash_map::Entry;
+ use hashbrown::hash_map::Entry;
 if let Entry::Vacant(entry) = function.dfg.imports.entry(callee) {
 entry.insert(ExternalFunction {
 id: callee,
@@ -672,7 +672,7 @@ fn try_insert_inst(
 callee
 }
 Right(external) => {
- use std::collections::hash_map::Entry;
+ use hashbrown::hash_map::Entry;
 used_imports.insert(external);
 if let Entry::Vacant(entry) = function.dfg.imports.entry(external) {
 if let Some(ef) = imports_by_id.get(&external) {
diff --git a/hir/src/pass/analysis.rs b/hir/src/pass/analysis.rs
index ab5a062b2..8384f3b43 100644
--- a/hir/src/pass/analysis.rs
+++ b/hir/src/pass/analysis.rs
@@ -5,11 +5,9 @@ use core::{
 };
 
 use midenc_session::Session;
-use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
+use rustc_hash::FxBuildHasher;
 
-use crate::diagnostics::Report;
-
-type BuildFxHasher = std::hash::BuildHasherDefault<FxHasher>;
+use crate::{diagnostics::Report, FxHashMap, FxHashSet};
 
 /// A convenient type alias for `Result<T, Report>`
 pub type AnalysisResult<T> = Result<T, Report>;
 
@@ -115,7 +113,7 @@ impl CachedAnalysisKey {
 {
 use core::hash::Hasher;
 
- let mut hasher = FxHasher::default();
+ let mut hasher = rustc_hash::FxHasher::default();
 let entity_ty = TypeId::of::<T>();
 entity_ty.hash(&mut hasher);
 key.hash(&mut hasher);
@@ -204,11 +202,9 @@ impl PreservedAnalyses {
 }
 
 fn with_capacity(current_entity_key: u64, cap: usize) -> Self {
- use std::collections::HashMap;
-
 Self {
 current_entity_key,
- preserved: HashMap::with_capacity_and_hasher(cap, BuildFxHasher::default()),
+ preserved: FxHashMap::with_capacity_and_hasher(cap, FxBuildHasher),
 }
 }
 }
@@ -349,8 +345,6 @@ impl AnalysisManager {
 where
 T: AnalysisKey,
 {
- use std::collections::HashMap;
-
 let current_entity_key = CachedAnalysisKey::entity_key::<T>(key);
 if self.preserve_none.remove(&current_entity_key) {
@@ -381,8 +375,7 @@ impl AnalysisManager {
 preserve.insert(self.preserve.take(&key).unwrap());
 }
 
- let mut cached =
- HashMap::with_capacity_and_hasher(to_invalidate.len(), BuildFxHasher::default());
+ let mut cached = FxHashMap::with_capacity_and_hasher(to_invalidate.len(), FxBuildHasher);
 for key in to_invalidate.into_iter() {
 let (key, value) = self.cached.remove_entry(&key).unwrap();
 cached.insert(key, value);
diff --git a/hir2/Cargo.toml b/hir2/Cargo.toml
new file mode 100644
index 000000000..8f680aacf
--- /dev/null
+++ b/hir2/Cargo.toml
@@ -0,0 +1,68 @@
+[package]
+name = "midenc-hir2"
+description = "High-level Intermediate Representation for Miden Assembly"
+version.workspace = true
+rust-version.workspace = true
+authors.workspace = true
+repository.workspace = true
+categories.workspace = true
+keywords.workspace = true
+license.workspace = true
+readme.workspace = true
+edition.workspace = true
+
+[features]
+default = ["std"]
+std = ["rustc-demangle/std", "compact_str/std"]
+serde = [
+ "dep:serde",
+ "dep:serde_repr",
+ "dep:serde_bytes",
+ "midenc-hir-symbol/serde",
+]
+debug_refcell = []
+
+[build-dependencies]
+lalrpop = { version = "0.20", default-features = false }
+
+[dependencies]
+anyhow.workspace = true
+blink-alloc = { version = "0.3", default-features = false, features = [
+ "alloc",
+ "nightly",
+] }
+bitvec = { version = "1.0", default-features = false, features = ["alloc"] }
+bitflags.workspace = true
+either.workspace = true
+cranelift-entity.workspace = true
+compact_str.workspace = true
+hashbrown.workspace = true
+intrusive-collections.workspace = true
+inventory.workspace = true
+lalrpop-util = "0.20"
+log.workspace = true
+miden-core.workspace = true
+miden-assembly.workspace = true
+midenc-hir-symbol.workspace = true
+midenc-hir-type.workspace = true
+midenc-hir-macros.workspace = true
+midenc-session.workspace = true
+num-bigint = "0.4"
+num-traits = "0.2"
+petgraph.workspace = true
+paste.workspace = true
+rustc-hash.workspace = true
+rustc-demangle = "0.1.19"
+serde = { workspace = true, optional = true }
+serde_repr = { workspace = true, optional = true }
+serde_bytes = { workspace = true, optional = true }
+smallvec.workspace = true
+thiserror.workspace = true
+typed-arena = "2.0"
+unicode-width = { version = "0.1", features = ["no_std"] }
+derive_more.workspace = true
+indexmap.workspace = true
+
+[dev-dependencies]
+pretty_assertions = "1.0"
+env_logger.workspace = true
diff --git a/hir2/src/adt.rs b/hir2/src/adt.rs
new file mode 100644
index 000000000..2b8471fa9
--- /dev/null
+++ b/hir2/src/adt.rs
@@ -0,0 +1,21 @@
+pub mod smalldeque;
+pub mod smallmap;
+pub mod smallordset;
+pub mod smallprio;
+pub mod smallset;
+pub mod sparsemap;
+
+pub use self::{
+ smalldeque::SmallDeque,
+ smallmap::SmallMap,
+ smallordset::SmallOrdSet,
+ smallprio::SmallPriorityQueue,
+ smallset::SmallSet,
+ sparsemap::{SparseMap, SparseMapValue},
+};
+
+#[doc(hidden)]
+pub trait SizedTypeProperties: Sized {
+ const IS_ZST: bool = core::mem::size_of::<Self>() == 0;
+}
+impl<T> SizedTypeProperties for T {}
diff --git a/hir2/src/adt/smalldeque.rs b/hir2/src/adt/smalldeque.rs
new file mode 100644
index 000000000..7995bc50a
--- /dev/null
+++ b/hir2/src/adt/smalldeque.rs
@@ -0,0 +1,3768 @@
+use core::{
+ cmp::Ordering,
+ fmt,
+ iter::{repeat_n, repeat_with, ByRefSized},
+ ops::{Index, IndexMut, Range, RangeBounds},
+ ptr::{self, NonNull},
+};
+
+use smallvec::SmallVec;
+
+use super::SizedTypeProperties;
+
+/// [SmallDeque] is a [alloc::collections::VecDeque]-like structure that can store a specified
+/// number of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// This data structure is designed to basically provide the functionality of `VecDeque` without
+/// needing to allocate on the heap for small numbers of elements.
+///
+/// Internally, [SmallDeque] is implemented on top of [SmallVec].
+///
+/// Most of the implementation is ripped from the standard library `VecDeque` impl, but adapted
+/// for `SmallVec`.
+pub struct SmallDeque<T, const N: usize> {
+ /// `self[0]`, if it exists, is `buf[head]`.
+ /// `head < buf.capacity()`, unless `buf.capacity() == 0` when `head == 0`.
+ head: usize,
+ /// The number of initialized elements, starting from the one at `head` and potentially
+ /// wrapping around.
+ ///
+ /// If `len == 0`, the exact value of `head` is unimportant.
+ ///
+ /// If `T` is zero-sized, then `self.len <= usize::MAX`, otherwise
+ /// `self.len <= isize::MAX as usize`
+ len: usize,
+ buf: SmallVec<[T; N]>,
+}
+impl<T: Clone, const N: usize> Clone for SmallDeque<T, N> {
+ fn clone(&self) -> Self {
+ let mut deq = Self::with_capacity(self.len());
+ deq.extend(self.iter().cloned());
+ deq
+ }
+
+ fn clone_from(&mut self, source: &Self) {
+ self.clear();
+ self.extend(source.iter().cloned());
+ }
+}
+impl<T, const N: usize> Default for SmallDeque<T, N> {
+ fn default() -> Self {
+ Self {
+ head: 0,
+ len: 0,
+ buf: Default::default(),
+ }
+ }
+}
+impl<T, const N: usize> SmallDeque<T, N> {
+ /// Returns a new, empty [SmallDeque]
+ #[inline]
+ #[must_use]
+ pub const fn new() -> Self {
+ Self {
+ head: 0,
+ len: 0,
+ buf: SmallVec::new_const(),
+ }
+ }
+
+ /// Create an empty deque with pre-allocated space for `capacity` elements.
+ #[must_use]
+ pub fn with_capacity(capacity: usize) -> Self {
+ Self {
+ head: 0,
+ len: 0,
+ buf: SmallVec::with_capacity(capacity),
+ }
+ }
+
+ /// Returns true if the deque is empty
+ pub fn is_empty(&self) -> bool {
+ self.len == 0
+ }
+
+ /// Returns the number of elements in the deque
+ pub fn len(&self) -> usize {
+ self.len
+ }
+
+ /// Return a front-to-back iterator.
+ pub fn iter(&self) -> Iter<'_, T> {
+ let (a, b) = self.as_slices();
+ Iter::new(a.iter(), b.iter())
+ }
+
+ /// Return a front-to-back iterator that returns mutable references
+ pub fn iter_mut(&mut self) -> IterMut<'_, T> {
+ let (a, b) = self.as_mut_slices();
+ IterMut::new(a.iter_mut(), b.iter_mut())
+ }
+
+ /// Returns a pair of slices which contain, in order, the contents of the
+ /// deque.
+ ///
+ /// If [`SmallDeque::make_contiguous`] was previously called, all elements of the
+ /// deque will be in the first slice and the second slice will be empty.
+ #[inline]
+ pub fn as_slices(&self) -> (&[T], &[T]) {
+ let (a_range, b_range) = self.slice_ranges(.., self.len);
+ // SAFETY: `slice_ranges` always returns valid ranges into the physical buffer.
+ unsafe { (&*self.buffer_range(a_range), &*self.buffer_range(b_range)) }
+ }
+
+ /// Returns a pair of slices which contain, in order, the contents of the
+ /// deque.
+ ///
+ /// If [`SmallDeque::make_contiguous`] was previously called, all elements of the
+ /// deque will be in the first slice and the second slice will be empty.
+ #[inline]
+ pub fn as_mut_slices(&mut self) -> (&mut [T], &mut [T]) {
+ let (a_range, b_range) = self.slice_ranges(.., self.len);
+ // SAFETY: `slice_ranges` always returns valid ranges into the physical buffer.
+ unsafe { (&mut *self.buffer_range_mut(a_range), &mut *self.buffer_range_mut(b_range)) }
+ }
+
+ /// Given a range into the logical buffer of the deque, this function
+ /// returns two ranges into the physical buffer that correspond to
+ /// the given range. The `len` parameter should usually just be `self.len`;
+ /// the reason it's passed explicitly is that if the deque is wrapped in
+ /// a `Drain`, then `self.len` is not actually the length of the deque.
+ ///
+ /// # Safety
+ ///
+ /// This function is always safe to call. For the resulting ranges to be valid
+ /// ranges into the physical buffer, the caller must ensure that the result of
+ /// calling `slice::range(range, ..len)` represents a valid range into the
+ /// logical buffer, and that all elements in that range are initialized.
+ fn slice_ranges<R>(&self, range: R, len: usize) -> (Range<usize>, Range<usize>)
+ where
+ R: RangeBounds<usize>,
+ {
+ let Range { start, end } = core::slice::range(range, ..len);
+ let len = end - start;
+
+ if len == 0 {
+ (0..0, 0..0)
+ } else {
+ // `slice::range` guarantees that `start <= end <= len`.
+ // because `len != 0`, we know that `start < end`, so `start < len`
+ // and the indexing is valid.
+ let wrapped_start = self.to_physical_idx(start);
+
+ // this subtraction can never overflow because `wrapped_start` is
+ // at most `self.capacity()` (and if `self.capacity != 0`, then `wrapped_start` is strictly less
+ // than `self.capacity`).
+ let head_len = self.capacity() - wrapped_start;
+
+ if head_len >= len {
+ // we know that `len + wrapped_start <= self.capacity <= usize::MAX`, so this addition can't overflow
+ (wrapped_start..wrapped_start + len, 0..0)
+ } else {
+ // can't overflow because of the if condition
+ let tail_len = len - head_len;
+ (wrapped_start..self.capacity(), 0..tail_len)
+ }
+ }
+ }
+
+ /// Creates an iterator that covers the specified range in the deque.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the starting point is greater than the end point or if
+ /// the end point is greater than the length of the deque.
+ #[inline]
+ pub fn range<R>(&self, range: R) -> Iter<'_, T>
+ where
+ R: RangeBounds<usize>,
+ {
+ let (a_range, b_range) = self.slice_ranges(range, self.len);
+ // SAFETY: The ranges returned by `slice_ranges`
+ // are valid ranges into the physical buffer, so
+ // it's ok to pass them to `buffer_range` and
+ // dereference the result.
+ let a = unsafe { &*self.buffer_range(a_range) };
+ let b = unsafe { &*self.buffer_range(b_range) };
+ Iter::new(a.iter(), b.iter())
+ }
+
+ /// Creates an iterator that covers the specified mutable range in the deque.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the starting point is greater than the end point or if
+ /// the end point is greater than the length of the deque.
+ #[inline]
+ pub fn range_mut<R>(&mut self, range: R) -> IterMut<'_, T>
+ where
+ R: RangeBounds<usize>,
+ {
+ let (a_range, b_range) = self.slice_ranges(range, self.len);
+ // SAFETY: The ranges returned by `slice_ranges`
+ // are valid ranges into the physical buffer, so
+ // it's ok to pass them to `buffer_range` and
+ // dereference the result.
+ let a = unsafe { &mut *self.buffer_range_mut(a_range) };
+ let b = unsafe { &mut *self.buffer_range_mut(b_range) };
+ IterMut::new(a.iter_mut(), b.iter_mut())
+ }
+
+ /// Get a reference to the element at the given index.
+ ///
+ /// Element at index 0 is the front of the queue.
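+ ///
+ /// A usage sketch (illustrative):
+ ///
+ /// ```text
+ /// let mut deq: SmallDeque<i32, 4> = SmallDeque::new();
+ /// deq.push_back(1);
+ /// deq.push_back(2);
+ /// assert_eq!(deq.get(0), Some(&1));
+ /// assert_eq!(deq.get(2), None);
+ /// ```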
+ pub fn get(&self, index: usize) -> Option<&T> { + if index < self.len { + let index = self.to_physical_idx(index); + unsafe { Some(&*self.ptr().add(index)) } + } else { + None + } + } + + /// Get a mutable reference to the element at the given index. + /// + /// Element at index 0 is the front of the queue. + pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { + if index < self.len { + let index = self.to_physical_idx(index); + unsafe { Some(&mut *self.ptr_mut().add(index)) } + } else { + None + } + } + + /// Swaps elements at indices `i` and `j` + /// + /// `i` and `j` may be equal. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if either index is out of bounds. + pub fn swap(&mut self, i: usize, j: usize) { + assert!(i < self.len()); + assert!(j < self.len()); + let ri = self.to_physical_idx(i); + let rj = self.to_physical_idx(j); + unsafe { ptr::swap(self.ptr_mut().add(ri), self.ptr_mut().add(rj)) } + } + + /// Returns the number of elements the deque can hold without reallocating. + #[inline] + pub fn capacity(&self) -> usize { + if T::IS_ZST { + usize::MAX + } else { + self.buf.capacity() + } + } + + /// Reserves the minimum capacity for at least `additional` more elements to be inserted in the + /// given deque. Does nothing if the capacity is already sufficient. + /// + /// Note that the allocator may give the collection more space than it requests. Therefore + /// capacity can not be relied upon to be precisely minimal. Prefer [`reserve`] if future + /// insertions are expected. + /// + /// # Panics + /// + /// Panics if the new capacity overflows `usize`. + pub fn reserve_exact(&mut self, additional: usize) { + let new_cap = self.len.checked_add(additional).expect("capacity overflow"); + let old_cap = self.capacity(); + + if new_cap > old_cap { + self.buf.try_grow(new_cap).expect("capacity overflow"); + unsafe { + self.handle_capacity_increase(old_cap); + } + } + } + + /// Reserves capacity for at least `additional` more elements to be inserted in the given + /// deque. The collection may reserve more space to speculatively avoid frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new capacity overflows `usize`. + pub fn reserve(&mut self, additional: usize) { + let new_cap = self.len.checked_add(additional).expect("capacity overflow"); + let old_cap = self.capacity(); + + if new_cap > old_cap { + // we don't need to reserve_exact(), as the size doesn't have + // to be a power of 2. + self.buf.try_grow(new_cap).expect("capacity overflow"); + unsafe { + self.handle_capacity_increase(old_cap); + } + } + } + + /// Shortens the deque, keeping the first `len` elements and dropping + /// the rest. + /// + /// If `len` is greater or equal to the deque's current length, this has + /// no effect. + pub fn truncate(&mut self, len: usize) { + /// Runs the destructor for all items in the slice when it gets dropped (normally or + /// during unwinding). 
+ struct Dropper<'a, T>(&'a mut [T]); + + impl<'a, T> Drop for Dropper<'a, T> { + fn drop(&mut self) { + unsafe { + ptr::drop_in_place(self.0); + } + } + } + + // Safe because: + // + // * Any slice passed to `drop_in_place` is valid; the second case has + // `len <= front.len()` and returning on `len > self.len()` ensures + // `begin <= back.len()` in the first case + // * The head of the SmallDeque is moved before calling `drop_in_place`, + // so no value is dropped twice if `drop_in_place` panics + unsafe { + if len >= self.len { + return; + } + + let (front, back) = self.as_mut_slices(); + if len > front.len() { + let begin = len - front.len(); + let drop_back = back.get_unchecked_mut(begin..) as *mut _; + self.len = len; + ptr::drop_in_place(drop_back); + } else { + let drop_back = back as *mut _; + let drop_front = front.get_unchecked_mut(len..) as *mut _; + self.len = len; + + // Make sure the second half is dropped even when a destructor + // in the first one panics. + let _back_dropper = Dropper(&mut *drop_back); + ptr::drop_in_place(drop_front); + } + } + } + + /// Removes the specified range from the deque in bulk, returning all + /// removed elements as an iterator. If the iterator is dropped before + /// being fully consumed, it drops the remaining removed elements. + /// + /// The returned iterator keeps a mutable borrow on the queue to optimize + /// its implementation. + /// + /// + /// # Panics + /// + /// Panics if the starting point is greater than the end point or if + /// the end point is greater than the length of the deque. + /// + /// # Leaking + /// + /// If the returned iterator goes out of scope without being dropped (due to + /// [`mem::forget`], for example), the deque may have lost and leaked + /// elements arbitrarily, including elements outside the range. + #[inline] + pub fn drain(&mut self, range: R) -> Drain<'_, T, N> + where + R: RangeBounds, + { + // Memory safety + // + // When the Drain is first created, the source deque is shortened to + // make sure no uninitialized or moved-from elements are accessible at + // all if the Drain's destructor never gets to run. + // + // Drain will ptr::read out the values to remove. + // When finished, the remaining data will be copied back to cover the hole, + // and the head/tail values will be restored correctly. + // + let Range { start, end } = core::slice::range(range, ..self.len); + let drain_start = start; + let drain_len = end - start; + + // The deque's elements are parted into three segments: + // * 0 -> drain_start + // * drain_start -> drain_start+drain_len + // * drain_start+drain_len -> self.len + // + // H = self.head; T = self.head+self.len; t = drain_start+drain_len; h = drain_head + // + // We store drain_start as self.len, and drain_len and self.len as + // drain_len and orig_len respectively on the Drain. This also + // truncates the effective array such that if the Drain is leaked, we + // have forgotten about the potentially moved values after the start of + // the drain. + // + // H h t T + // [. . . o o x x o o . . .] + // + // "forget" about the values after the start of the drain until after + // the drain is complete and the Drain destructor is run. + + unsafe { Drain::new(self, drain_start, drain_len) } + } + + /// Clears the deque, removing all values. + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + // Not strictly necessary, but leaves things in a more consistent/predictable state. 
+ self.head = 0; + } + + /// Returns `true` if the deque contains an element equal to the + /// given value. + /// + /// This operation is *O*(*n*). + /// + /// Note that if you have a sorted `SmallDeque`, [`binary_search`] may be faster. + /// + /// [`binary_search`]: SmallDeque::binary_search + pub fn contains(&self, x: &T) -> bool + where + T: PartialEq, + { + let (a, b) = self.as_slices(); + a.contains(x) || b.contains(x) + } + + /// Provides a reference to the front element, or `None` if the deque is + /// empty. + pub fn front(&self) -> Option<&T> { + self.get(0) + } + + /// Provides a mutable reference to the front element, or `None` if the + /// deque is empty. + pub fn front_mut(&mut self) -> Option<&mut T> { + self.get_mut(0) + } + + /// Provides a reference to the back element, or `None` if the deque is + /// empty. + pub fn back(&self) -> Option<&T> { + self.get(self.len.wrapping_sub(1)) + } + + /// Provides a mutable reference to the back element, or `None` if the + /// deque is empty. + pub fn back_mut(&mut self) -> Option<&mut T> { + self.get_mut(self.len.wrapping_sub(1)) + } + + /// Removes the first element and returns it, or `None` if the deque is + /// empty. + pub fn pop_front(&mut self) -> Option { + if self.is_empty() { + None + } else { + let old_head = self.head; + self.head = self.to_physical_idx(1); + self.len -= 1; + unsafe { + core::hint::assert_unchecked(self.len < self.capacity()); + Some(self.buffer_read(old_head)) + } + } + } + + /// Removes the last element from the deque and returns it, or `None` if + /// it is empty. + pub fn pop_back(&mut self) -> Option { + if self.is_empty() { + None + } else { + self.len -= 1; + unsafe { + core::hint::assert_unchecked(self.len < self.capacity()); + Some(self.buffer_read(self.to_physical_idx(self.len))) + } + } + } + + /// Prepends an element to the deque. + pub fn push_front(&mut self, value: T) { + if self.is_full() { + self.grow(); + } + + self.head = self.wrap_sub(self.head, 1); + self.len += 1; + + unsafe { + self.buffer_write(self.head, value); + } + } + + /// Appends an element to the back of the deque. + pub fn push_back(&mut self, value: T) { + if self.is_full() { + self.grow(); + } + + unsafe { self.buffer_write(self.to_physical_idx(self.len), value) } + self.len += 1; + } + + #[inline] + fn is_contiguous(&self) -> bool { + // Do the calculation like this to avoid overflowing if len + head > usize::MAX + self.head <= self.capacity() - self.len + } + + /// Removes an element from anywhere in the deque and returns it, + /// replacing it with the first element. + /// + /// This does not preserve ordering, but is *O*(1). + /// + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. + pub fn swap_remove_front(&mut self, index: usize) -> Option { + let length = self.len; + if index < length && index != 0 { + self.swap(index, 0); + } else if index >= length { + return None; + } + self.pop_front() + } + + /// Removes an element from anywhere in the deque and returns it, + /// replacing it with the last element. + /// + /// This does not preserve ordering, but is *O*(1). + /// + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. 
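+ ///
+ /// Illustrative sketch:
+ ///
+ /// ```text
+ /// let mut deq: SmallDeque<i32, 4> = SmallDeque::from_iter([1, 2, 3]);
+ /// assert_eq!(deq.swap_remove_back(0), Some(1)); // last element (3) moved into index 0
+ /// assert_eq!(deq.get(0), Some(&3));
+ /// ```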
+ pub fn swap_remove_back(&mut self, index: usize) -> Option { + let length = self.len; + if length > 0 && index < length - 1 { + self.swap(index, length - 1); + } else if index >= length { + return None; + } + self.pop_back() + } + + /// Inserts an element at `index` within the deque, shifting all elements + /// with indices greater than or equal to `index` towards the back. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if `index` is greater than deque's length + pub fn insert(&mut self, index: usize, value: T) { + assert!(index <= self.len(), "index out of bounds"); + if self.is_full() { + self.grow(); + } + + let k = self.len - index; + if k < index { + // `index + 1` can't overflow, because if index was usize::MAX, then either the + // assert would've failed, or the deque would've tried to grow past usize::MAX + // and panicked. + unsafe { + // see `remove()` for explanation why this wrap_copy() call is safe. + self.wrap_copy(self.to_physical_idx(index), self.to_physical_idx(index + 1), k); + self.buffer_write(self.to_physical_idx(index), value); + self.len += 1; + } + } else { + let old_head = self.head; + self.head = self.wrap_sub(self.head, 1); + unsafe { + self.wrap_copy(old_head, self.head, index); + self.buffer_write(self.to_physical_idx(index), value); + self.len += 1; + } + } + } + + /// Removes and returns the element at `index` from the deque. + /// Whichever end is closer to the removal point will be moved to make + /// room, and all the affected elements will be moved to new positions. + /// Returns `None` if `index` is out of bounds. + /// + /// Element at index 0 is the front of the queue. + pub fn remove(&mut self, index: usize) -> Option { + if self.len <= index { + return None; + } + + let wrapped_idx = self.to_physical_idx(index); + + let elem = unsafe { Some(self.buffer_read(wrapped_idx)) }; + + let k = self.len - index - 1; + // safety: due to the nature of the if-condition, whichever wrap_copy gets called, + // its length argument will be at most `self.len / 2`, so there can't be more than + // one overlapping area. + if k < index { + unsafe { self.wrap_copy(self.wrap_add(wrapped_idx, 1), wrapped_idx, k) }; + self.len -= 1; + } else { + let old_head = self.head; + self.head = self.to_physical_idx(1); + unsafe { self.wrap_copy(old_head, self.head, index) }; + self.len -= 1; + } + + elem + } + + /// Splits the deque into two at the given index. + /// + /// Returns a newly allocated `SmallDeque`. `self` contains elements `[0, at)`, + /// and the returned deque contains elements `[at, len)`. + /// + /// Note that the capacity of `self` does not change. + /// + /// Element at index 0 is the front of the queue. + /// + /// # Panics + /// + /// Panics if `at > len`. + #[inline] + #[must_use = "use `.truncate()` if you don't need the other half"] + pub fn split_off(&mut self, at: usize) -> Self { + let len = self.len; + assert!(at <= len, "`at` out of bounds"); + + let other_len = len - at; + let mut other = Self::with_capacity(other_len); + + unsafe { + let (first_half, second_half) = self.as_slices(); + + let first_len = first_half.len(); + let second_len = second_half.len(); + if at < first_len { + // `at` lies in the first half. + let amount_in_first = first_len - at; + + ptr::copy_nonoverlapping( + first_half.as_ptr().add(at), + other.ptr_mut(), + amount_in_first, + ); + + // just take all of the second half. 
+ ptr::copy_nonoverlapping( + second_half.as_ptr(), + other.ptr_mut().add(amount_in_first), + second_len, + ); + } else { + // `at` lies in the second half, need to factor in the elements we skipped + // in the first half. + let offset = at - first_len; + let amount_in_second = second_len - offset; + ptr::copy_nonoverlapping( + second_half.as_ptr().add(offset), + other.ptr_mut(), + amount_in_second, + ); + } + } + + // Cleanup where the ends of the buffers are + self.len = at; + other.len = other_len; + + other + } + + /// Moves all the elements of `other` into `self`, leaving `other` empty. + /// + /// # Panics + /// + /// Panics if the new number of elements in self overflows a `usize`. + #[inline] + pub fn append(&mut self, other: &mut Self) { + if T::IS_ZST { + self.len = self.len.checked_add(other.len).expect("capacity overflow"); + other.len = 0; + other.head = 0; + return; + } + + self.reserve(other.len); + unsafe { + let (left, right) = other.as_slices(); + self.copy_slice(self.to_physical_idx(self.len), left); + // no overflow, because self.capacity() >= old_cap + left.len() >= self.len + left.len() + self.copy_slice(self.to_physical_idx(self.len + left.len()), right); + } + // SAFETY: Update pointers after copying to avoid leaving doppelganger + // in case of panics. + self.len += other.len; + // Now that we own its values, forget everything in `other`. + other.len = 0; + other.head = 0; + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns false. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&T) -> bool, + { + self.retain_mut(|elem| f(elem)); + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns false. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + pub fn retain_mut(&mut self, mut f: F) + where + F: FnMut(&mut T) -> bool, + { + let len = self.len; + let mut idx = 0; + let mut cur = 0; + + // Stage 1: All values are retained. + while cur < len { + if !f(&mut self[cur]) { + cur += 1; + break; + } + cur += 1; + idx += 1; + } + // Stage 2: Swap retained value into current idx. + while cur < len { + if !f(&mut self[cur]) { + cur += 1; + continue; + } + + self.swap(idx, cur); + cur += 1; + idx += 1; + } + // Stage 3: Truncate all values after idx. + if cur != idx { + self.truncate(idx); + } + } + + // Double the buffer size. This method is inline(never), so we expect it to only + // be called in cold paths. + // This may panic or abort + #[inline(never)] + fn grow(&mut self) { + // Extend or possibly remove this assertion when valid use-cases for growing the + // buffer without it being full emerge + debug_assert!(self.is_full()); + let old_cap = self.capacity(); + self.buf.grow(old_cap + 1); + unsafe { + self.handle_capacity_increase(old_cap); + } + debug_assert!(!self.is_full()); + } + + /// Modifies the deque in-place so that `len()` is equal to `new_len`, + /// either by removing excess elements from the back or by appending + /// elements generated by calling `generator` to the back. 
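+ ///
+ /// Illustrative sketch:
+ ///
+ /// ```text
+ /// let mut deq: SmallDeque<i32, 4> = SmallDeque::from_iter([1, 2]);
+ /// let mut n = 10;
+ /// deq.resize_with(4, || { n += 1; n });
+ /// // deq now contains [1, 2, 11, 12]
+ /// ```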
+ pub fn resize_with(&mut self, new_len: usize, generator: impl FnMut() -> T) { + let len = self.len; + + if new_len > len { + self.extend(repeat_with(generator).take(new_len - len)) + } else { + self.truncate(new_len); + } + } + + /// Rearranges the internal storage of this deque so it is one contiguous + /// slice, which is then returned. + /// + /// This method does not allocate and does not change the order of the + /// inserted elements. As it returns a mutable slice, this can be used to + /// sort a deque. + /// + /// Once the internal storage is contiguous, the [`as_slices`] and + /// [`as_mut_slices`] methods will return the entire contents of the + /// deque in a single slice. + /// + /// [`as_slices`]: SmallDeque::as_slices + /// [`as_mut_slices`]: SmallDeque::as_mut_slices + pub fn make_contiguous(&mut self) -> &mut [T] { + if T::IS_ZST { + self.head = 0; + } + + if self.is_contiguous() { + unsafe { + return core::slice::from_raw_parts_mut(self.ptr_mut().add(self.head), self.len); + } + } + + let &mut Self { head, len, .. } = self; + let ptr = self.ptr_mut(); + let cap = self.capacity(); + + let free = cap - len; + let head_len = cap - head; + let tail = len - head_len; + let tail_len = tail; + + if free >= head_len { + // there is enough free space to copy the head in one go, + // this means that we first shift the tail backwards, and then + // copy the head to the correct position. + // + // from: DEFGH....ABC + // to: ABCDEFGH.... + unsafe { + self.copy(0, head_len, tail_len); + // ...DEFGH.ABC + self.copy_nonoverlapping(head, 0, head_len); + // ABCDEFGH.... + } + + self.head = 0; + } else if free >= tail_len { + // there is enough free space to copy the tail in one go, + // this means that we first shift the head forwards, and then + // copy the tail to the correct position. + // + // from: FGH....ABCDE + // to: ...ABCDEFGH. + unsafe { + self.copy(head, tail, head_len); + // FGHABCDE.... + self.copy_nonoverlapping(0, tail + head_len, tail_len); + // ...ABCDEFGH. + } + + self.head = tail; + } else { + // `free` is smaller than both `head_len` and `tail_len`. + // the general algorithm for this first moves the slices + // right next to each other and then uses `slice::rotate` + // to rotate them into place: + // + // initially: HIJK..ABCDEFG + // step 1: ..HIJKABCDEFG + // step 2: ..ABCDEFGHIJK + // + // or: + // + // initially: FGHIJK..ABCDE + // step 1: FGHIJKABCDE.. + // step 2: ABCDEFGHIJK.. + + // pick the shorter of the 2 slices to reduce the amount + // of memory that needs to be moved around. + if head_len > tail_len { + // tail is shorter, so: + // 1. copy tail forwards + // 2. rotate used part of the buffer + // 3. update head to point to the new beginning (which is just `free`) + + unsafe { + // if there is no free space in the buffer, then the slices are already + // right next to each other and we don't need to move any memory. + if free != 0 { + // because we only move the tail forward as much as there's free space + // behind it, we don't overwrite any elements of the head slice, and + // the slices end up right next to each other. + self.copy(0, free, tail_len); + } + + // We just copied the tail right next to the head slice, + // so all of the elements in the range are initialized + let slice = &mut *self.buffer_range_mut(free..self.capacity()); + + // because the deque wasn't contiguous, we know that `tail_len < self.len == slice.len()`, + // so this will never panic. 
+ slice.rotate_left(tail_len); + + // the used part of the buffer now is `free..self.capacity()`, so set + // `head` to the beginning of that range. + self.head = free; + } + } else { + // head is shorter so: + // 1. copy head backwards + // 2. rotate used part of the buffer + // 3. update head to point to the new beginning (which is the beginning of the buffer) + + unsafe { + // if there is no free space in the buffer, then the slices are already + // right next to each other and we don't need to move any memory. + if free != 0 { + // copy the head slice to lie right behind the tail slice. + self.copy(self.head, tail_len, head_len); + } + + // because we copied the head slice so that both slices lie right + // next to each other, all the elements in the range are initialized. + let slice = &mut *self.buffer_range_mut(0..self.len); + + // because the deque wasn't contiguous, we know that `head_len < self.len == slice.len()` + // so this will never panic. + slice.rotate_right(head_len); + + // the used part of the buffer now is `0..self.len`, so set + // `head` to the beginning of that range. + self.head = 0; + } + } + } + + unsafe { core::slice::from_raw_parts_mut(ptr.add(self.head), self.len) } + } + + /// Rotates the double-ended queue `n` places to the left. + /// + /// Equivalently, + /// - Rotates item `n` into the first position. + /// - Pops the first `n` items and pushes them to the end. + /// - Rotates `len() - n` places to the right. + /// + /// # Panics + /// + /// If `n` is greater than `len()`. Note that `n == len()` + /// does _not_ panic and is a no-op rotation. + /// + /// # Complexity + /// + /// Takes `*O*(min(n, len() - n))` time and no extra space. + pub fn rotate_left(&mut self, n: usize) { + assert!(n <= self.len()); + let k = self.len - n; + if n <= k { + unsafe { self.rotate_left_inner(n) } + } else { + unsafe { self.rotate_right_inner(k) } + } + } + + /// Rotates the double-ended queue `n` places to the right. + /// + /// Equivalently, + /// - Rotates the first item into position `n`. + /// - Pops the last `n` items and pushes them to the front. + /// - Rotates `len() - n` places to the left. + /// + /// # Panics + /// + /// If `n` is greater than `len()`. Note that `n == len()` + /// does _not_ panic and is a no-op rotation. + /// + /// # Complexity + /// + /// Takes `*O*(min(n, len() - n))` time and no extra space. + pub fn rotate_right(&mut self, n: usize) { + assert!(n <= self.len()); + let k = self.len - n; + if n <= k { + unsafe { self.rotate_right_inner(n) } + } else { + unsafe { self.rotate_left_inner(k) } + } + } + + // SAFETY: the following two methods require that the rotation amount + // be less than half the length of the deque. + // + // `wrap_copy` requires that `min(x, capacity() - x) + copy_len <= capacity()`, + // but then `min` is never more than half the capacity, regardless of x, + // so it's sound to call here because we're calling with something + // less than half the length, which is never above half the capacity. + unsafe fn rotate_left_inner(&mut self, mid: usize) { + debug_assert!(mid * 2 <= self.len()); + unsafe { + self.wrap_copy(self.head, self.to_physical_idx(self.len), mid); + } + self.head = self.to_physical_idx(mid); + } + + unsafe fn rotate_right_inner(&mut self, k: usize) { + debug_assert!(k * 2 <= self.len()); + self.head = self.wrap_sub(self.head, k); + unsafe { + self.wrap_copy(self.to_physical_idx(self.len), self.head, k); + } + } + + /// Binary searches this `SmallDeque` for a given element. 
+ /// If the `SmallDeque` is not sorted, the returned result is unspecified and + /// meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search_by`], [`binary_search_by_key`], and [`partition_point`]. + /// + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + /// [`partition_point`]: SmallDeque::partition_point + /// + #[inline] + pub fn binary_search(&self, x: &T) -> Result + where + T: Ord, + { + self.binary_search_by(|e| e.cmp(x)) + } + + /// Binary searches this `SmallDeque` with a comparator function. + /// + /// The comparator function should return an order code that indicates + /// whether its argument is `Less`, `Equal` or `Greater` the desired + /// target. + /// If the `SmallDeque` is not sorted or if the comparator function does not + /// implement an order consistent with the sort order of the underlying + /// `SmallDeque`, the returned result is unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search`], [`binary_search_by_key`], and [`partition_point`]. + /// + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + /// [`partition_point`]: SmallDeque::partition_point + pub fn binary_search_by<'a, F>(&'a self, mut f: F) -> Result + where + F: FnMut(&'a T) -> Ordering, + { + let (front, back) = self.as_slices(); + // clippy doesn't recognize that `f` would be moved if we followed it's recommendation + #[allow(clippy::redundant_closure)] + let cmp_back = back.first().map(|e| f(e)); + + if let Some(Ordering::Equal) = cmp_back { + Ok(front.len()) + } else if let Some(Ordering::Less) = cmp_back { + back.binary_search_by(f) + .map(|idx| idx + front.len()) + .map_err(|idx| idx + front.len()) + } else { + front.binary_search_by(f) + } + } + + /// Binary searches this `SmallDeque` with a key extraction function. + /// + /// Assumes that the deque is sorted by the key, for instance with + /// [`make_contiguous().sort_by_key()`] using the same key extraction function. + /// If the deque is not sorted by the key, the returned result is + /// unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. If the value is not found then + /// [`Result::Err`] is returned, containing the index where a matching + /// element could be inserted while maintaining sorted order. + /// + /// See also [`binary_search`], [`binary_search_by`], and [`partition_point`]. 
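+ ///
+ /// Illustrative sketch (the deque is assumed sorted by the extracted key):
+ ///
+ /// ```text
+ /// let deq: SmallDeque<(i32, i32), 4> = SmallDeque::from_iter([(0, 1), (1, 3), (2, 5)]);
+ /// assert_eq!(deq.binary_search_by_key(&3, |&(_, b)| b), Ok(1));
+ /// ```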
+ /// + /// [`make_contiguous().sort_by_key()`]: SmallDeque::make_contiguous + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`partition_point`]: SmallDeque::partition_point + #[inline] + pub fn binary_search_by_key<'a, B, F>(&'a self, b: &B, mut f: F) -> Result + where + F: FnMut(&'a T) -> B, + B: Ord, + { + self.binary_search_by(|k| f(k).cmp(b)) + } + + /// Returns the index of the partition point according to the given predicate + /// (the index of the first element of the second partition). + /// + /// The deque is assumed to be partitioned according to the given predicate. + /// This means that all elements for which the predicate returns true are at the start of the deque + /// and all elements for which the predicate returns false are at the end. + /// For example, `[7, 15, 3, 5, 4, 12, 6]` is partitioned under the predicate `x % 2 != 0` + /// (all odd numbers are at the start, all even at the end). + /// + /// If the deque is not partitioned, the returned result is unspecified and meaningless, + /// as this method performs a kind of binary search. + /// + /// See also [`binary_search`], [`binary_search_by`], and [`binary_search_by_key`]. + /// + /// [`binary_search`]: SmallDeque::binary_search + /// [`binary_search_by`]: SmallDeque::binary_search_by + /// [`binary_search_by_key`]: SmallDeque::binary_search_by_key + pub fn partition_point
<P>
(&self, mut pred: P) -> usize + where + P: FnMut(&T) -> bool, + { + let (front, back) = self.as_slices(); + + #[allow(clippy::redundant_closure)] + if let Some(true) = back.first().map(|v| pred(v)) { + back.partition_point(pred) + front.len() + } else { + front.partition_point(pred) + } + } +} + +impl SmallDeque { + /// Modifies the deque in-place so that `len()` is equal to new_len, + /// either by removing excess elements from the back or by appending clones of `value` + /// to the back. + pub fn resize(&mut self, new_len: usize, value: T) { + if new_len > self.len() { + let extra = new_len - self.len(); + self.extend(repeat_n(value, extra)) + } else { + self.truncate(new_len); + } + } +} + +impl SmallDeque { + #[inline] + fn ptr(&self) -> *const T { + self.buf.as_ptr() + } + + #[inline] + fn ptr_mut(&mut self) -> *mut T { + self.buf.as_mut_ptr() + } + + /// Appends an element to the buffer. + /// + /// # Safety + /// + /// May only be called if `deque.len() < deque.capacity()` + #[inline] + unsafe fn push_unchecked(&mut self, element: T) { + // SAFETY: Because of the precondition, it's guaranteed that there is space in the logical + // array after the last element. + unsafe { self.buffer_write(self.to_physical_idx(self.len), element) }; + // This can't overflow because `deque.len() < deque.capacity() <= usize::MAX` + self.len += 1; + } + + /// Moves an element out of the buffer + #[inline] + unsafe fn buffer_read(&mut self, offset: usize) -> T { + unsafe { ptr::read(self.ptr().add(offset)) } + } + + /// Writes an element into the buffer, moving it. + #[inline] + unsafe fn buffer_write(&mut self, offset: usize, value: T) { + unsafe { + ptr::write(self.ptr_mut().add(offset), value); + } + } + + /// Returns a slice pointer into the buffer. + /// `range` must lie inside `0..self.capacity()`. + #[inline] + unsafe fn buffer_range(&self, range: core::ops::Range) -> *const [T] { + unsafe { ptr::slice_from_raw_parts(self.ptr().add(range.start), range.end - range.start) } + } + + /// Returns a slice pointer into the buffer. + /// `range` must lie inside `0..self.capacity()`. + #[inline] + unsafe fn buffer_range_mut(&mut self, range: core::ops::Range) -> *mut [T] { + unsafe { + ptr::slice_from_raw_parts_mut(self.ptr_mut().add(range.start), range.end - range.start) + } + } + + /// Returns `true` if the buffer is at full capacity. + #[inline] + fn is_full(&self) -> bool { + self.len == self.capacity() + } + + /// Returns the index in the underlying buffer for a given logical element index + addend. + #[inline] + fn wrap_add(&self, idx: usize, addend: usize) -> usize { + wrap_index(idx.wrapping_add(addend), self.capacity()) + } + + #[inline] + fn to_physical_idx(&self, idx: usize) -> usize { + self.wrap_add(self.head, idx) + } + + /// Returns the index in the underlying buffer for a given logical element index - subtrahend. 
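+ ///
+ /// E.g. (illustrative): with `capacity == 8`, `wrap_sub(1, 3)` computes
+ /// `wrap_index(1 - 3 + 8, 8) == 6`, using wrapping arithmetic throughout.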
+ #[inline] + fn wrap_sub(&self, idx: usize, subtrahend: usize) -> usize { + wrap_index(idx.wrapping_sub(subtrahend).wrapping_add(self.capacity()), self.capacity()) + } + + /// Copies a contiguous block of memory len long from src to dst + #[inline] + unsafe fn copy(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + dst + len <= self.capacity(), + "cpy dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + debug_assert!( + src + len <= self.capacity(), + "cpy dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + unsafe { + ptr::copy(self.ptr().add(src), self.ptr_mut().add(dst), len); + } + } + + /// Copies a contiguous block of memory len long from src to dst + #[inline] + unsafe fn copy_nonoverlapping(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + dst + len <= self.capacity(), + "cno dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + debug_assert!( + src + len <= self.capacity(), + "cno dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + unsafe { + ptr::copy_nonoverlapping(self.ptr().add(src), self.ptr_mut().add(dst), len); + } + } + + /// Copies a potentially wrapping block of memory len long from src to dest. + /// (abs(dst - src) + len) must be no larger than capacity() (There must be at + /// most one continuous overlapping region between src and dest). + unsafe fn wrap_copy(&mut self, src: usize, dst: usize, len: usize) { + debug_assert!( + core::cmp::min(src.abs_diff(dst), self.capacity() - src.abs_diff(dst)) + len + <= self.capacity(), + "wrc dst={} src={} len={} cap={}", + dst, + src, + len, + self.capacity() + ); + + // If T is a ZST, don't do any copying. + if T::IS_ZST || src == dst || len == 0 { + return; + } + + let dst_after_src = self.wrap_sub(dst, src) < len; + + let src_pre_wrap_len = self.capacity() - src; + let dst_pre_wrap_len = self.capacity() - dst; + let src_wraps = src_pre_wrap_len < len; + let dst_wraps = dst_pre_wrap_len < len; + + match (dst_after_src, src_wraps, dst_wraps) { + (_, false, false) => { + // src doesn't wrap, dst doesn't wrap + // + // S . . . + // 1 [_ _ A A B B C C _] + // 2 [_ _ A A A A B B _] + // D . . . + // + unsafe { + self.copy(src, dst, len); + } + } + (false, false, true) => { + // dst before src, src doesn't wrap, dst wraps + // + // S . . . + // 1 [A A B B _ _ _ C C] + // 2 [A A B B _ _ _ A A] + // 3 [B B B B _ _ _ A A] + // . . D . + // + unsafe { + self.copy(src, dst, dst_pre_wrap_len); + self.copy(src + dst_pre_wrap_len, 0, len - dst_pre_wrap_len); + } + } + (true, false, true) => { + // src before dst, src doesn't wrap, dst wraps + // + // S . . . + // 1 [C C _ _ _ A A B B] + // 2 [B B _ _ _ A A B B] + // 3 [B B _ _ _ A A A A] + // . . D . + // + unsafe { + self.copy(src + dst_pre_wrap_len, 0, len - dst_pre_wrap_len); + self.copy(src, dst, dst_pre_wrap_len); + } + } + (false, true, false) => { + // dst before src, src wraps, dst doesn't wrap + // + // . . S . + // 1 [C C _ _ _ A A B B] + // 2 [C C _ _ _ B B B B] + // 3 [C C _ _ _ B B C C] + // D . . . + // + unsafe { + self.copy(src, dst, src_pre_wrap_len); + self.copy(0, dst + src_pre_wrap_len, len - src_pre_wrap_len); + } + } + (true, true, false) => { + // src before dst, src wraps, dst doesn't wrap + // + // . . S . + // 1 [A A B B _ _ _ C C] + // 2 [A A A A _ _ _ C C] + // 3 [C C A A _ _ _ C C] + // D . . . 
+ // + unsafe { + self.copy(0, dst + src_pre_wrap_len, len - src_pre_wrap_len); + self.copy(src, dst, src_pre_wrap_len); + } + } + (false, true, true) => { + // dst before src, src wraps, dst wraps + // + // . . . S . + // 1 [A B C D _ E F G H] + // 2 [A B C D _ E G H H] + // 3 [A B C D _ E G H A] + // 4 [B C C D _ E G H A] + // . . D . . + // + debug_assert!(dst_pre_wrap_len > src_pre_wrap_len); + let delta = dst_pre_wrap_len - src_pre_wrap_len; + unsafe { + self.copy(src, dst, src_pre_wrap_len); + self.copy(0, dst + src_pre_wrap_len, delta); + self.copy(delta, 0, len - dst_pre_wrap_len); + } + } + (true, true, true) => { + // src before dst, src wraps, dst wraps + // + // . . S . . + // 1 [A B C D _ E F G H] + // 2 [A A B D _ E F G H] + // 3 [H A B D _ E F G H] + // 4 [H A B D _ E F F G] + // . . . D . + // + debug_assert!(src_pre_wrap_len > dst_pre_wrap_len); + let delta = src_pre_wrap_len - dst_pre_wrap_len; + unsafe { + self.copy(0, delta, len - src_pre_wrap_len); + self.copy(self.capacity() - delta, 0, delta); + self.copy(src, dst, dst_pre_wrap_len); + } + } + } + } + + /// Copies all values from `src` to `dst`, wrapping around if needed. + /// Assumes capacity is sufficient. + #[inline] + unsafe fn copy_slice(&mut self, dst: usize, src: &[T]) { + debug_assert!(src.len() <= self.capacity()); + let head_room = self.capacity() - dst; + if src.len() <= head_room { + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), self.ptr_mut().add(dst), src.len()); + } + } else { + let (left, right) = src.split_at(head_room); + unsafe { + ptr::copy_nonoverlapping(left.as_ptr(), self.ptr_mut().add(dst), left.len()); + ptr::copy_nonoverlapping(right.as_ptr(), self.ptr_mut(), right.len()); + } + } + } + + /// Writes all values from `iter` to `dst`. + /// + /// # Safety + /// + /// Assumes no wrapping around happens. + /// Assumes capacity is sufficient. + #[inline] + unsafe fn write_iter( + &mut self, + dst: usize, + iter: impl Iterator, + written: &mut usize, + ) { + iter.enumerate().for_each(|(i, element)| unsafe { + self.buffer_write(dst + i, element); + *written += 1; + }); + } + + /// Writes all values from `iter` to `dst`, wrapping + /// at the end of the buffer and returns the number + /// of written values. + /// + /// # Safety + /// + /// Assumes that `iter` yields at most `len` items. + /// Assumes capacity is sufficient. + unsafe fn write_iter_wrapping( + &mut self, + dst: usize, + mut iter: impl Iterator, + len: usize, + ) -> usize { + struct Guard<'a, T, const N: usize> { + deque: &'a mut SmallDeque, + written: usize, + } + + impl<'a, T, const N: usize> Drop for Guard<'a, T, N> { + fn drop(&mut self) { + self.deque.len += self.written; + } + } + + let head_room = self.capacity() - dst; + + let mut guard = Guard { + deque: self, + written: 0, + }; + + if head_room >= len { + unsafe { guard.deque.write_iter(dst, iter, &mut guard.written) }; + } else { + unsafe { + guard.deque.write_iter( + dst, + ByRefSized(&mut iter).take(head_room), + &mut guard.written, + ); + guard.deque.write_iter(0, iter, &mut guard.written) + }; + } + + guard.written + } + + /// Frobs the head and tail sections around to handle the fact that we + /// just reallocated. Unsafe because it trusts old_capacity. 
+
+    /// Frobs the head and tail sections around to handle the fact that we
+    /// just reallocated. Unsafe because it trusts old_capacity.
+    #[inline]
+    unsafe fn handle_capacity_increase(&mut self, old_capacity: usize) {
+        let new_capacity = self.capacity();
+        debug_assert!(new_capacity >= old_capacity);
+
+        // Move the shortest contiguous section of the ring buffer
+        //
+        // H := head
+        // L := last element (`self.to_physical_idx(self.len - 1)`)
+        //
+        //    H           L
+        //   [o o o o o o o o ]
+        //    H           L
+        // A [o o o o o o o o . . . . . . . . ]
+        //        L H
+        //   [o o o o o o o o ]
+        //          H           L
+        // B [. . . o o o o o o o o . . . . . ]
+        //              L H
+        //   [o o o o o o o o ]
+        //              L                   H
+        // C [o o o o o o . . . . . . . . o o ]
+
+        // can't use is_contiguous() because the capacity is already updated.
+        if self.head <= old_capacity - self.len {
+            // A
+            // Nop
+        } else {
+            let head_len = old_capacity - self.head;
+            let tail_len = self.len - head_len;
+            if head_len > tail_len && new_capacity - old_capacity >= tail_len {
+                // B
+                unsafe {
+                    self.copy_nonoverlapping(0, old_capacity, tail_len);
+                }
+            } else {
+                // C
+                let new_head = new_capacity - head_len;
+                unsafe {
+                    // can't use copy_nonoverlapping here, because if e.g. head_len = 2
+                    // and new_capacity = old_capacity + 1, then the heads overlap.
+                    self.copy(self.head, new_head, head_len);
+                }
+                self.head = new_head;
+            }
+        }
+        debug_assert!(self.head < self.capacity() || self.capacity() == 0);
+    }
+}
+
+/// Returns the index in the underlying buffer for a given logical element index.
+#[inline]
+fn wrap_index(logical_index: usize, capacity: usize) -> usize {
+    debug_assert!(
+        (logical_index == 0 && capacity == 0)
+            || logical_index < capacity
+            || (logical_index - capacity) < capacity
+    );
+    if logical_index >= capacity {
+        logical_index - capacity
+    } else {
+        logical_index
+    }
+}
+
+impl<T: PartialEq, const N: usize> PartialEq for SmallDeque<T, N> {
+    fn eq(&self, other: &Self) -> bool {
+        if self.len != other.len() {
+            return false;
+        }
+        let (sa, sb) = self.as_slices();
+        let (oa, ob) = other.as_slices();
+        match sa.len().cmp(&oa.len()) {
+            Ordering::Equal => sa == oa && sb == ob,
+            Ordering::Less => {
+                // Always divisible in three sections, for example:
+                // self:  [a b c|d e f]
+                // other: [0 1 2 3|4 5]
+                // front = 3, mid = 1,
+                // [a b c] == [0 1 2] && [d] == [3] && [e f] == [4 5]
+                let front = sa.len();
+                let mid = oa.len() - front;
+
+                let (oa_front, oa_mid) = oa.split_at(front);
+                let (sb_mid, sb_back) = sb.split_at(mid);
+                debug_assert_eq!(sa.len(), oa_front.len());
+                debug_assert_eq!(sb_mid.len(), oa_mid.len());
+                debug_assert_eq!(sb_back.len(), ob.len());
+                sa == oa_front && sb_mid == oa_mid && sb_back == ob
+            }
+            Ordering::Greater => {
+                let front = oa.len();
+                let mid = sa.len() - front;
+
+                let (sa_front, sa_mid) = sa.split_at(front);
+                let (ob_mid, ob_back) = ob.split_at(mid);
+                debug_assert_eq!(sa_front.len(), oa.len());
+                debug_assert_eq!(sa_mid.len(), ob_mid.len());
+                debug_assert_eq!(sb.len(), ob_back.len());
+                sa_front == oa && sa_mid == ob_mid && sb == ob_back
+            }
+        }
+    }
+}
+
+impl<T: Eq, const N: usize> Eq for SmallDeque<T, N> {}
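+
+// A small sketch of the equality semantics implemented above (hypothetical
+// values; assumes the `SmallDeque` API defined in this module): two deques
+// holding the same logical elements compare equal even when their physical
+// layouts, and therefore their `as_slices()` splits, differ.
+//
+//     let a: SmallDeque<i32, 8> = [1, 2, 3, 4].into(); // contiguous
+//     let mut b = SmallDeque::<i32, 8>::new();
+//     b.push_back(3);
+//     b.push_back(4);
+//     b.push_front(2);
+//     b.push_front(1); // b's buffer now wraps
+//     assert_ne!(a.as_slices(), b.as_slices());
+//     assert_eq!(a, b);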
+
+macro_rules! __impl_slice_eq1 {
+    ([$($vars:tt)*] $lhs:ty, $rhs:ty, $($constraints:tt)*) => {
+        impl<T, U, $($vars)*> PartialEq<$rhs> for $lhs
+        where
+            T: PartialEq<U>,
+            $($constraints)*
+        {
+            fn eq(&self, other: &$rhs) -> bool {
+                if self.len() != other.len() {
+                    return false;
+                }
+                let (sa, sb) = self.as_slices();
+                let (oa, ob) = other[..].split_at(sa.len());
+                sa == oa && sb == ob
+            }
+        }
+    }
+}
+
+__impl_slice_eq1! { [A: alloc::alloc::Allocator, const N: usize] SmallDeque<T, N>, Vec<U, A>, }
+__impl_slice_eq1! { [const N: usize] SmallDeque<T, N>, SmallVec<[U; N]>, }
+__impl_slice_eq1! { [const N: usize] SmallDeque<T, N>, &[U], }
+__impl_slice_eq1! { [const N: usize] SmallDeque<T, N>, &mut [U], }
+__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque<T, N>, [U; M], }
+__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque<T, N>, &[U; M], }
+__impl_slice_eq1! { [const N: usize, const M: usize] SmallDeque<T, N>, &mut [U; M], }
+
+impl<T: PartialOrd, const N: usize> PartialOrd for SmallDeque<T, N> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.iter().partial_cmp(other.iter())
+    }
+}
+
+impl<T: Ord, const N: usize> Ord for SmallDeque<T, N> {
+    #[inline]
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.iter().cmp(other.iter())
+    }
+}
+
+impl<T: core::hash::Hash, const N: usize> core::hash::Hash for SmallDeque<T, N> {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        state.write_length_prefix(self.len);
+        // It's not possible to use Hash::hash_slice on slices
+        // returned by as_slices method as their length can vary
+        // in otherwise identical deques.
+        //
+        // Hasher only guarantees equivalence for the exact same
+        // set of calls to its methods.
+        self.iter().for_each(|elem| elem.hash(state));
+    }
+}
+
+impl<T, const N: usize> Index<usize> for SmallDeque<T, N> {
+    type Output = T;
+
+    #[inline]
+    fn index(&self, index: usize) -> &T {
+        self.get(index).expect("Out of bounds access")
+    }
+}
+
+impl<T, const N: usize> IndexMut<usize> for SmallDeque<T, N> {
+    #[inline]
+    fn index_mut(&mut self, index: usize) -> &mut T {
+        self.get_mut(index).expect("Out of bounds access")
+    }
+}
+
+impl<T, const N: usize> FromIterator<T> for SmallDeque<T, N> {
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        SpecFromIter::spec_from_iter(iter.into_iter())
+    }
+}
+
+impl<T, const N: usize> IntoIterator for SmallDeque<T, N> {
+    type IntoIter = IntoIter<T, N>;
+    type Item = T;
+
+    /// Consumes the deque into a front-to-back iterator yielding elements by
+    /// value.
+    fn into_iter(self) -> IntoIter<T, N> {
+        IntoIter::new(self)
+    }
+}
+
+impl<'a, T, const N: usize> IntoIterator for &'a SmallDeque<T, N> {
+    type IntoIter = Iter<'a, T>;
+    type Item = &'a T;
+
+    fn into_iter(self) -> Iter<'a, T> {
+        self.iter()
+    }
+}
+
+impl<'a, T, const N: usize> IntoIterator for &'a mut SmallDeque<T, N> {
+    type IntoIter = IterMut<'a, T>;
+    type Item = &'a mut T;
+
+    fn into_iter(self) -> IterMut<'a, T> {
+        self.iter_mut()
+    }
+}
+
+impl<T, const N: usize> Extend<T> for SmallDeque<T, N> {
+    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
+        <Self as SpecExtend<T, I::IntoIter>>::spec_extend(self, iter.into_iter());
+    }
+
+    #[inline]
+    fn extend_one(&mut self, elem: T) {
+        self.push_back(elem);
+    }
+
+    #[inline]
+    fn extend_reserve(&mut self, additional: usize) {
+        self.reserve(additional);
+    }
+
+    #[inline]
+    unsafe fn extend_one_unchecked(&mut self, item: T) {
+        // SAFETY: Our preconditions ensure the space has been reserved, and
+        // `extend_reserve` is implemented correctly.
+        unsafe {
+            self.push_unchecked(item);
+        }
+    }
+}
+
+impl<'a, T: 'a + Copy, const N: usize> Extend<&'a T> for SmallDeque<T, N> {
+    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
+        self.spec_extend(iter.into_iter());
+    }
+
+    #[inline]
+    fn extend_one(&mut self, &elem: &'a T) {
+        self.push_back(elem);
+    }
+
+    #[inline]
+    fn extend_reserve(&mut self, additional: usize) {
+        self.reserve(additional);
+    }
+
+    #[inline]
+    unsafe fn extend_one_unchecked(&mut self, &item: &'a T) {
+        // SAFETY: Our preconditions ensure the space has been reserved, and
+        // `extend_reserve` is implemented correctly.
+        unsafe {
+            self.push_unchecked(item);
+        }
+    }
+}
+
+impl<T: fmt::Debug, const N: usize> fmt::Debug for SmallDeque<T, N> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_list().entries(self.iter()).finish()
+    }
+}
+
+impl<T, const N: usize> From<SmallVec<[T; N]>> for SmallDeque<T, N> {
+    /// Turn a [`SmallVec<[T; N]>`] into a [`SmallDeque`].
+    ///
+    /// [`SmallVec<[T; N]>`]: smallvec::SmallVec
+    /// [`SmallDeque`]: crate::adt::SmallDeque
+    ///
+    /// This conversion is guaranteed to run in *O*(1) time
+    /// and to not re-allocate the `SmallVec`'s buffer or allocate
+    /// any additional memory.
+    #[inline]
+    fn from(buf: SmallVec<[T; N]>) -> Self {
+        let len = buf.len();
+        Self { head: 0, len, buf }
+    }
+}
+
+impl<T, const N: usize> From<SmallDeque<T, N>> for SmallVec<[T; N]> {
+    /// Turn a [`SmallDeque`] into a [`SmallVec<[T; N]>`].
+    ///
+    /// [`SmallVec<[T; N]>`]: smallvec::SmallVec
+    /// [`SmallDeque`]: crate::adt::SmallDeque
+    ///
+    /// This never needs to re-allocate, but does need to do *O*(*n*) data movement if
+    /// the circular buffer doesn't happen to be at the beginning of the allocation.
+    fn from(mut other: SmallDeque<T, N>) -> Self {
+        use core::mem::ManuallyDrop;
+
+        other.make_contiguous();
+
+        unsafe {
+            if other.buf.spilled() {
+                let mut other = ManuallyDrop::new(other);
+                let buf = other.buf.as_mut_ptr();
+                let len = other.len();
+                let cap = other.capacity();
+
+                if other.head != 0 {
+                    ptr::copy(buf.add(other.head), buf, len);
+                }
+                SmallVec::from_raw_parts(buf, len, cap)
+            } else {
+                // `other` is entirely stack-allocated, so we need to produce a new copy that
+                // has all of the elements starting at index 0, if not already
+                if other.head == 0 {
+                    // Steal the underlying vec, and make sure that the length is set
+                    let mut buf = other.buf;
+                    buf.set_len(other.len);
+                    buf
+                } else {
+                    let mut other = ManuallyDrop::new(other);
+                    let ptr = other.buf.as_mut_ptr();
+                    let len = other.len();
+
+                    // Construct an uninitialized array on the stack of the same size as the
+                    // target SmallVec's inline size, "move" `len` items into it, and then
+                    // construct the SmallVec from the raw buffer and len
+                    let mut buf = core::mem::MaybeUninit::<T>::uninit_array::<N>();
+                    let buf_ptr = core::mem::MaybeUninit::slice_as_mut_ptr(&mut buf);
+                    ptr::copy(ptr.add(other.head), buf_ptr, len);
+                    // While we are technically letting a subset of the array's elements that
+                    // never got initialized be assumed to have been initialized here - that
+                    // fact is never material: no references are ever created to those items,
+                    // and the array is never dropped, as it is immediately placed in a
+                    // ManuallyDrop, and the vector length is set to `len` before any access
+                    // can be made to the vector
+                    SmallVec::from_buf_and_len(core::mem::MaybeUninit::array_assume_init(buf), len)
+                }
+            }
+        }
+    }
+}
+
+impl<T, const N: usize> From<[T; N]> for SmallDeque<T, N> {
+    /// Converts a `[T; N]` into a `SmallDeque`.
+    fn from(arr: [T; N]) -> Self {
+        use core::mem::ManuallyDrop;
+
+        let mut deq = SmallDeque::<_, N>::with_capacity(N);
+        let arr = ManuallyDrop::new(arr);
+        if !T::IS_ZST {
+            // SAFETY: SmallDeque::with_capacity ensures that there is enough capacity.
+            unsafe {
+                ptr::copy_nonoverlapping(arr.as_ptr(), deq.ptr_mut(), N);
+            }
+        }
+        deq.head = 0;
+        deq.len = N;
+        deq
+    }
+}
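+
+// A usage sketch for the three `From` conversions above (illustrative values):
+//
+//     let sv: SmallVec<[u32; 4]> = smallvec::smallvec![1, 2, 3];
+//     let dq = SmallDeque::from(sv);               // O(1): buffer reused as-is
+//     let back = SmallVec::<[u32; 4]>::from(dq);   // O(n) data movement at worst
+//     let dq2 = SmallDeque::from([1u32, 2, 3, 4]); // array -> deque, head = 0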
+
+/// An iterator over the elements of a `SmallDeque`.
+///
+/// This `struct` is created by the [`iter`] method on [`SmallDeque`]. See its
+/// documentation for more.
+///
+/// [`iter`]: SmallDeque::iter
+#[derive(Clone)]
+pub struct Iter<'a, T: 'a> {
+    i1: core::slice::Iter<'a, T>,
+    i2: core::slice::Iter<'a, T>,
+}
+
+impl<'a, T> Iter<'a, T> {
+    pub(super) fn new(i1: core::slice::Iter<'a, T>, i2: core::slice::Iter<'a, T>) -> Self {
+        Self { i1, i2 }
+    }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Iter<'_, T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("Iter")
+            .field(&self.i1.as_slice())
+            .field(&self.i2.as_slice())
+            .finish()
+    }
+}
+
+impl<'a, T> Iterator for Iter<'a, T> {
+    type Item = &'a T;
+
+    #[inline]
+    fn next(&mut self) -> Option<&'a T> {
+        match self.i1.next() {
+            Some(val) => Some(val),
+            None => {
+                // most of the time, the iterator will either always
+                // call next(), or always call next_back(). By swapping
+                // the iterators once the first one is empty, we ensure
+                // that the first branch is taken as often as possible,
+                // without sacrificing correctness, as i1 is empty anyways
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i1.next()
+            }
+        }
+    }
+
+    fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        let remaining = self.i1.advance_by(n);
+        match remaining {
+            Ok(()) => Ok(()),
+            Err(n) => {
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i1.advance_by(n.get())
+            }
+        }
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.len();
+        (len, Some(len))
+    }
+
+    fn fold<Acc, F>(self, accum: Acc, mut f: F) -> Acc
+    where
+        F: FnMut(Acc, Self::Item) -> Acc,
+    {
+        let accum = self.i1.fold(accum, &mut f);
+        self.i2.fold(accum, &mut f)
+    }
+
+    fn try_fold<B, F, R>(&mut self, init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        let acc = self.i1.try_fold(init, &mut f)?;
+        self.i2.try_fold(acc, &mut f)
+    }
+
+    #[inline]
+    fn last(mut self) -> Option<&'a T> {
+        self.next_back()
+    }
+}
+
+impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
+    #[inline]
+    fn next_back(&mut self) -> Option<&'a T> {
+        match self.i2.next_back() {
+            Some(val) => Some(val),
+            None => {
+                // most of the time, the iterator will either always
+                // call next(), or always call next_back(). By swapping
+                // the iterators once the second one is empty, we ensure
+                // that the first branch is taken as often as possible,
+                // without sacrificing correctness, as i2 is empty anyways
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i2.next_back()
+            }
+        }
+    }
+
+    fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        match self.i2.advance_back_by(n) {
+            Ok(()) => Ok(()),
+            Err(n) => {
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i2.advance_back_by(n.get())
+            }
+        }
+    }
+
+    fn rfold<Acc, F>(self, accum: Acc, mut f: F) -> Acc
+    where
+        F: FnMut(Acc, Self::Item) -> Acc,
+    {
+        let accum = self.i2.rfold(accum, &mut f);
+        self.i1.rfold(accum, &mut f)
+    }
+
+    fn try_rfold<B, F, R>(&mut self, init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        let acc = self.i2.try_rfold(init, &mut f)?;
+        self.i1.try_rfold(acc, &mut f)
+    }
+}
+
+impl<T> ExactSizeIterator for Iter<'_, T> {
+    fn len(&self) -> usize {
+        self.i1.len() + self.i2.len()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.i1.is_empty() && self.i2.is_empty()
+    }
+}
+
+impl<T> core::iter::FusedIterator for Iter<'_, T> {}
+
+unsafe impl<T> core::iter::TrustedLen for Iter<'_, T> {}
+
+/// A mutable iterator over the elements of a `SmallDeque`.
+///
+/// This `struct` is created by the [`iter_mut`] method on [`SmallDeque`]. See its
+/// documentation for more.
+///
+/// [`iter_mut`]: SmallDeque::iter_mut
+pub struct IterMut<'a, T: 'a> {
+    i1: core::slice::IterMut<'a, T>,
+    i2: core::slice::IterMut<'a, T>,
+}
+
+impl<'a, T> IterMut<'a, T> {
+    pub(super) fn new(i1: core::slice::IterMut<'a, T>, i2: core::slice::IterMut<'a, T>) -> Self {
+        Self { i1, i2 }
+    }
+}
+
+impl<T: fmt::Debug> fmt::Debug for IterMut<'_, T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("IterMut")
+            .field(&self.i1.as_slice())
+            .field(&self.i2.as_slice())
+            .finish()
+    }
+}
+
+impl<'a, T> Iterator for IterMut<'a, T> {
+    type Item = &'a mut T;
+
+    #[inline]
+    fn next(&mut self) -> Option<&'a mut T> {
+        match self.i1.next() {
+            Some(val) => Some(val),
+            None => {
+                // most of the time, the iterator will either always
+                // call next(), or always call next_back(). By swapping
+                // the iterators once the first one is empty, we ensure
+                // that the first branch is taken as often as possible,
+                // without sacrificing correctness, as i1 is empty anyways
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i1.next()
+            }
+        }
+    }
+
+    fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        match self.i1.advance_by(n) {
+            Ok(()) => Ok(()),
+            Err(remaining) => {
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i1.advance_by(remaining.get())
+            }
+        }
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.len();
+        (len, Some(len))
+    }
+
+    fn fold<Acc, F>(self, accum: Acc, mut f: F) -> Acc
+    where
+        F: FnMut(Acc, Self::Item) -> Acc,
+    {
+        let accum = self.i1.fold(accum, &mut f);
+        self.i2.fold(accum, &mut f)
+    }
+
+    fn try_fold<B, F, R>(&mut self, init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        let acc = self.i1.try_fold(init, &mut f)?;
+        self.i2.try_fold(acc, &mut f)
+    }
+
+    #[inline]
+    fn last(mut self) -> Option<&'a mut T> {
+        self.next_back()
+    }
+}
+
+impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
+    #[inline]
+    fn next_back(&mut self) -> Option<&'a mut T> {
+        match self.i2.next_back() {
+            Some(val) => Some(val),
+            None => {
+                // most of the time, the iterator will either always
+                // call next(), or always call next_back(). By swapping
+                // the iterators once the second one is empty, we ensure
+                // that the first branch is taken as often as possible,
+                // without sacrificing correctness, as i2 is empty anyways
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i2.next_back()
+            }
+        }
+    }
+
+    fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        match self.i2.advance_back_by(n) {
+            Ok(()) => Ok(()),
+            Err(remaining) => {
+                core::mem::swap(&mut self.i1, &mut self.i2);
+                self.i2.advance_back_by(remaining.get())
+            }
+        }
+    }
+
+    fn rfold<Acc, F>(self, accum: Acc, mut f: F) -> Acc
+    where
+        F: FnMut(Acc, Self::Item) -> Acc,
+    {
+        let accum = self.i2.rfold(accum, &mut f);
+        self.i1.rfold(accum, &mut f)
+    }
+
+    fn try_rfold<B, F, R>(&mut self, init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        let acc = self.i2.try_rfold(init, &mut f)?;
+        self.i1.try_rfold(acc, &mut f)
+    }
+}
+
+impl<T> ExactSizeIterator for IterMut<'_, T> {
+    fn len(&self) -> usize {
+        self.i1.len() + self.i2.len()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.i1.is_empty() && self.i2.is_empty()
+    }
+}
+
+impl<T> core::iter::FusedIterator for IterMut<'_, T> {}
+
+unsafe impl<T> core::iter::TrustedLen for IterMut<'_, T> {}
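+
+// The iterators above follow `VecDeque`'s strategy: a deque's elements occupy
+// at most two contiguous slices, so `Iter`/`IterMut` hold a pair of slice
+// iterators and swap them once the first side is exhausted. A sketch of the
+// resulting equivalence (assumes the `as_slices` API defined earlier; `deque`
+// is any `SmallDeque`):
+//
+//     let (front, back) = deque.as_slices();
+//     assert!(deque.iter().eq(front.iter().chain(back.iter())));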
+
+/// An owning iterator over the elements of a `SmallDeque`.
+///
+/// This `struct` is created by the [`into_iter`] method on [`SmallDeque`]
+/// (provided by the [`IntoIterator`] trait). See its documentation for more.
+///
+/// [`into_iter`]: SmallDeque::into_iter
+#[derive(Clone)]
+pub struct IntoIter<T, const N: usize> {
+    inner: SmallDeque<T, N>,
+}
+
+impl<T, const N: usize> IntoIter<T, N> {
+    pub(super) fn new(inner: SmallDeque<T, N>) -> Self {
+        IntoIter { inner }
+    }
+
+    pub(super) fn into_smalldeque(self) -> SmallDeque<T, N> {
+        self.inner
+    }
+}
+
+impl<T: fmt::Debug, const N: usize> fmt::Debug for IntoIter<T, N> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("IntoIter").field(&self.inner).finish()
+    }
+}
+
+impl<T, const N: usize> Iterator for IntoIter<T, N> {
+    type Item = T;
+
+    #[inline]
+    fn next(&mut self) -> Option<T> {
+        self.inner.pop_front()
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.inner.len();
+        (len, Some(len))
+    }
+
+    #[inline]
+    fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        let len = self.inner.len;
+        let rem = if len < n {
+            self.inner.clear();
+            n - len
+        } else {
+            self.inner.drain(..n);
+            0
+        };
+        core::num::NonZero::new(rem).map_or(Ok(()), Err)
+    }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.inner.len
+    }
+
+    fn try_fold<B, F, R>(&mut self, mut init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        struct Guard<'a, T, const M: usize> {
+            deque: &'a mut SmallDeque<T, M>,
+            // `consumed <= deque.len` always holds.
+            consumed: usize,
+        }
+
+        impl<'a, T, const M: usize> Drop for Guard<'a, T, M> {
+            fn drop(&mut self) {
+                self.deque.len -= self.consumed;
+                self.deque.head = self.deque.to_physical_idx(self.consumed);
+            }
+        }
+
+        let mut guard = Guard {
+            deque: &mut self.inner,
+            consumed: 0,
+        };
+
+        let (head, tail) = guard.deque.as_slices();
+
+        init = head
+            .iter()
+            .map(|elem| {
+                guard.consumed += 1;
+                // SAFETY: Because we incremented `guard.consumed`, the
+                // deque effectively forgot the element, so we can take
+                // ownership
+                unsafe { ptr::read(elem) }
+            })
+            .try_fold(init, &mut f)?;
+
+        tail.iter()
+            .map(|elem| {
+                guard.consumed += 1;
+                // SAFETY: Same as above.
+                unsafe { ptr::read(elem) }
+            })
+            .try_fold(init, &mut f)
+    }
+
+    #[inline]
+    fn fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        F: FnMut(B, Self::Item) -> B,
+    {
+        match self.try_fold(init, |b, item| Ok::<B, core::convert::Infallible>(f(b, item))) {
+            Ok(b) => b,
+            Err(e) => match e {},
+        }
+    }
+
+    #[inline]
+    fn last(mut self) -> Option<T> {
+        self.inner.pop_back()
+    }
+
+    fn next_chunk<const M: usize>(
+        &mut self,
+    ) -> Result<[Self::Item; M], core::array::IntoIter<Self::Item, M>> {
+        let mut raw_arr = core::mem::MaybeUninit::uninit_array();
+        let raw_arr_ptr = raw_arr.as_mut_ptr().cast();
+        let (head, tail) = self.inner.as_slices();
+
+        if head.len() >= M {
+            // SAFETY: By manually adjusting the head and length of the deque, we effectively
+            // make it forget the first `M` elements, so taking ownership of them is safe.
+            unsafe { ptr::copy_nonoverlapping(head.as_ptr(), raw_arr_ptr, M) };
+            self.inner.head = self.inner.to_physical_idx(M);
+            self.inner.len -= M;
+            // SAFETY: We initialized the entire array with items from `head`
+            return Ok(unsafe { raw_arr.transpose().assume_init() });
+        }
+
+        // SAFETY: Same argument as above.
+        unsafe { ptr::copy_nonoverlapping(head.as_ptr(), raw_arr_ptr, head.len()) };
+        let remaining = M - head.len();
+
+        if tail.len() >= remaining {
+            // SAFETY: Same argument as above.
+            unsafe {
+                ptr::copy_nonoverlapping(tail.as_ptr(), raw_arr_ptr.add(head.len()), remaining)
+            };
+            self.inner.head = self.inner.to_physical_idx(M);
+            self.inner.len -= M;
+            // SAFETY: We initialized the entire array with items from `head` and `tail`
+            Ok(unsafe { raw_arr.transpose().assume_init() })
+        } else {
+            // SAFETY: Same argument as above.
+            unsafe {
+                ptr::copy_nonoverlapping(tail.as_ptr(), raw_arr_ptr.add(head.len()), tail.len())
+            };
+            let init = head.len() + tail.len();
+            // We completely drained all the deque's elements.
+            self.inner.head = 0;
+            self.inner.len = 0;
+            // SAFETY: We copied all elements from both slices to the beginning of the array, so
+            // the given range is initialized.
+            Err(unsafe { core::array::IntoIter::new_unchecked(raw_arr, 0..init) })
+        }
+    }
+}
+
+impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
+    #[inline]
+    fn next_back(&mut self) -> Option<T> {
+        self.inner.pop_back()
+    }
+
+    #[inline]
+    fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> {
+        let len = self.inner.len;
+        let rem = if len < n {
+            self.inner.clear();
+            n - len
+        } else {
+            self.inner.truncate(len - n);
+            0
+        };
+        core::num::NonZero::new(rem).map_or(Ok(()), Err)
+    }
+
+    fn try_rfold<B, F, R>(&mut self, mut init: B, mut f: F) -> R
+    where
+        F: FnMut(B, Self::Item) -> R,
+        R: core::ops::Try<Output = B>,
+    {
+        struct Guard<'a, T, const N: usize> {
+            deque: &'a mut SmallDeque<T, N>,
+            // `consumed <= deque.len` always holds.
+            consumed: usize,
+        }
+
+        impl<'a, T, const N: usize> Drop for Guard<'a, T, N> {
+            fn drop(&mut self) {
+                self.deque.len -= self.consumed;
+            }
+        }
+
+        let mut guard = Guard {
+            deque: &mut self.inner,
+            consumed: 0,
+        };
+
+        let (head, tail) = guard.deque.as_slices();
+
+        init = tail
+            .iter()
+            .map(|elem| {
+                guard.consumed += 1;
+                // SAFETY: See `try_fold`'s safety comment.
+                unsafe { ptr::read(elem) }
+            })
+            .try_rfold(init, &mut f)?;
+
+        head.iter()
+            .map(|elem| {
+                guard.consumed += 1;
+                // SAFETY: Same as above.
+                unsafe { ptr::read(elem) }
+            })
+            .try_rfold(init, &mut f)
+    }
+
+    #[inline]
+    fn rfold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        F: FnMut(B, Self::Item) -> B,
+    {
+        match self.try_rfold(init, |b, item| Ok::<B, core::convert::Infallible>(f(b, item))) {
+            Ok(b) => b,
+            Err(e) => match e {},
+        }
+    }
+}
+
+impl<T, const N: usize> ExactSizeIterator for IntoIter<T, N> {
+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+}
+
+impl<T, const N: usize> core::iter::FusedIterator for IntoIter<T, N> {}
+
+unsafe impl<T, const N: usize> core::iter::TrustedLen for IntoIter<T, N> {}
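+
+// A usage sketch for `IntoIter::next_chunk` above (illustrative values): it
+// moves elements from the front of the deque into a fixed-size array, and on
+// exhaustion hands back the leftovers through the `Err` variant:
+//
+//     let mut it = SmallDeque::<u32, 8>::from([1, 2, 3, 4, 5]).into_iter();
+//     assert_eq!(it.next_chunk::<2>(), Ok([1, 2]));
+//     let rest = it.next_chunk::<4>().unwrap_err(); // only 3 elements remain
+//     assert!(rest.eq([3, 4, 5]));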
+
+/// A draining iterator over the elements of a `SmallDeque`.
+///
+/// This `struct` is created by the [`drain`] method on [`SmallDeque`]. See its
+/// documentation for more.
+///
+/// [`drain`]: SmallDeque::drain
+pub struct Drain<'a, T: 'a, const N: usize> {
+    // We can't just use a &mut SmallDeque, as that would make Drain invariant over T
+    // and we want it to be covariant instead
+    deque: NonNull<SmallDeque<T, N>>,
+    // drain_start is stored in deque.len
+    drain_len: usize,
+    // index into the logical array, not the physical one (always lies in [0..deque.len))
+    idx: usize,
+    // number of elements remaining after dropping the drain
+    new_len: usize,
+    remaining: usize,
+    // Needed to make Drain covariant over T
+    _marker: core::marker::PhantomData<&'a T>,
+}
+
+impl<'a, T, const N: usize> Drain<'a, T, N> {
+    pub(super) unsafe fn new(
+        deque: &'a mut SmallDeque<T, N>,
+        drain_start: usize,
+        drain_len: usize,
+    ) -> Self {
+        let orig_len = core::mem::replace(&mut deque.len, drain_start);
+        let new_len = orig_len - drain_len;
+        Drain {
+            deque: NonNull::from(deque),
+            drain_len,
+            idx: drain_start,
+            new_len,
+            remaining: drain_len,
+            _marker: core::marker::PhantomData,
+        }
+    }
+
+    // Only returns pointers to the slices, as that's all we need
+    // to drop them. May only be called if `self.remaining != 0`.
+    unsafe fn as_slices(&mut self) -> (*mut [T], *mut [T]) {
+        unsafe {
+            let deque = self.deque.as_mut();
+
+            // We know that `self.idx + self.remaining <= deque.len <= usize::MAX`, so this
+            // won't overflow.
+            let logical_remaining_range = self.idx..self.idx + self.remaining;
+
+            // SAFETY: `logical_remaining_range` represents the
+            // range into the logical buffer of elements that
+            // haven't been drained yet, so they're all initialized,
+            // and `slice::range(start..end, end) == start..end`,
+            // so the preconditions for `slice_ranges` are met.
+            let (a_range, b_range) =
+                deque.slice_ranges(logical_remaining_range.clone(), logical_remaining_range.end);
+            (deque.buffer_range_mut(a_range), deque.buffer_range_mut(b_range))
+        }
+    }
+}
+
+impl<T: fmt::Debug, const N: usize> fmt::Debug for Drain<'_, T, N> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("Drain")
+            .field(&self.drain_len)
+            .field(&self.idx)
+            .field(&self.new_len)
+            .field(&self.remaining)
+            .finish()
+    }
+}
+
+unsafe impl<T: Sync, const N: usize> Sync for Drain<'_, T, N> {}
+unsafe impl<T: Send, const N: usize> Send for Drain<'_, T, N> {}
+
+impl<T, const N: usize> Drop for Drain<'_, T, N> {
+    fn drop(&mut self) {
+        struct DropGuard<'r, 'a, T, const N: usize>(&'r mut Drain<'a, T, N>);
+
+        let guard = DropGuard(self);
+
+        if core::mem::needs_drop::<T>() && guard.0.remaining != 0 {
+            unsafe {
+                // SAFETY: We just checked that `self.remaining != 0`.
+                let (front, back) = guard.0.as_slices();
+                // since idx is a logical index, we don't need to worry about wrapping.
+                guard.0.idx += front.len();
+                guard.0.remaining -= front.len();
+                ptr::drop_in_place(front);
+                guard.0.remaining = 0;
+                ptr::drop_in_place(back);
+            }
+        }
+
+        // Dropping `guard` handles moving the remaining elements into place.
+        impl<'r, 'a, T, const N: usize> Drop for DropGuard<'r, 'a, T, N> {
+            #[inline]
+            fn drop(&mut self) {
+                if core::mem::needs_drop::<T>() && self.0.remaining != 0 {
+                    unsafe {
+                        // SAFETY: We just checked that `self.remaining != 0`.
+                        let (front, back) = self.0.as_slices();
+                        ptr::drop_in_place(front);
+                        ptr::drop_in_place(back);
+                    }
+                }
+
+                let source_deque = unsafe { self.0.deque.as_mut() };
+
+                let drain_len = self.0.drain_len;
+                let new_len = self.0.new_len;
+
+                if T::IS_ZST {
+                    // no need to copy around any memory if T is a ZST
+                    source_deque.len = new_len;
+                    return;
+                }
+
+                let head_len = source_deque.len; // #elements in front of the drain
+                let tail_len = new_len - head_len; // #elements behind the drain
+
+                // Next, we will fill the hole left by the drain with as few writes as possible.
+                // The code below handles the following control flow and reduces the amount of
+                // branches under the assumption that `head_len == 0 || tail_len == 0`, i.e.
+                // draining at the front or at the back of the deque is especially common.
+                //
+                // H = "head index" = `deque.head`
+                // h = elements in front of the drain
+                // d = elements in the drain
+                // t = elements behind the drain
+                //
+                // Note that the buffer may wrap at any point and the wrapping is handled by
+                // `wrap_copy` and `to_physical_idx`.
+                //
+                // Case 1: if `head_len == 0 && tail_len == 0`
+                // Everything was drained, reset the head index back to 0.
+                //               H
+                // [ . . . . . d d d d . . . . . ]
+                //   H
+                // [ . . . . . . . . . . . . . . ]
+                //
+                // Case 2: else if `tail_len == 0`
+                // Don't move data or the head index.
+                //         H
+                // [ . . . h h h h d d d d . . . ]
+                //         H
+                // [ . . . h h h h . . . . . . . ]
+                //
+                // Case 3: else if `head_len == 0`
+                // Don't move data, but move the head index.
+                //         H
+                // [ . . . d d d d t t t t . . . ]
+                //                 H
+                // [ . . . . . . . t t t t . . . ]
+                //
+                // Case 4: else if `tail_len <= head_len`
+                // Move data, but not the head index.
+                //       H
+                // [ . . h h h h d d d d t t . . ]
+                //       H
+                // [ . . h h h h t t . . . . . . ]
+                //
+                // Case 5: else
+                // Move data and the head index.
+                //       H
+                // [ . . h h d d d d t t t t . . ]
+                //             H
+                // [ . . . . . . h h t t t t . . ]
+
+                // When draining at the front (`.drain(..n)`) or at the back (`.drain(n..)`),
+                // we don't need to copy any data. The number of elements copied would be 0.
+                if head_len != 0 && tail_len != 0 {
+                    join_head_and_tail_wrapping(source_deque, drain_len, head_len, tail_len);
+                    // Marking this function as cold helps LLVM to eliminate it entirely if
+                    // this branch is never taken.
+                    // We use `#[cold]` instead of `#[inline(never)]`, because inlining this
+                    // function into the general case (`.drain(n..m)`) is fine.
+                    // See `tests/codegen/vecdeque-drain.rs` for a test.
+                    #[cold]
+                    fn join_head_and_tail_wrapping<T, const N: usize>(
+                        source_deque: &mut SmallDeque<T, N>,
+                        drain_len: usize,
+                        head_len: usize,
+                        tail_len: usize,
+                    ) {
+                        // Pick whether to move the head or the tail here.
+                        let (src, dst, len);
+                        if head_len < tail_len {
+                            src = source_deque.head;
+                            dst = source_deque.to_physical_idx(drain_len);
+                            len = head_len;
+                        } else {
+                            src = source_deque.to_physical_idx(head_len + drain_len);
+                            dst = source_deque.to_physical_idx(head_len);
+                            len = tail_len;
+                        };
+
+                        unsafe {
+                            source_deque.wrap_copy(src, dst, len);
+                        }
+                    }
+                }
+
+                if new_len == 0 {
+                    // Special case: If the entire deque was drained, reset the head back to 0,
+                    // like `.clear()` does.
+                    source_deque.head = 0;
+                } else if head_len < tail_len {
+                    // If we moved the head above, then we need to adjust the head index here.
+                    source_deque.head = source_deque.to_physical_idx(drain_len);
+                }
+                source_deque.len = new_len;
+            }
+        }
+    }
+}
+
+impl<T, const N: usize> Iterator for Drain<'_, T, N> {
+    type Item = T;
+
+    #[inline]
+    fn next(&mut self) -> Option<T> {
+        if self.remaining == 0 {
+            return None;
+        }
+        let wrapped_idx = unsafe { self.deque.as_ref().to_physical_idx(self.idx) };
+        self.idx += 1;
+        self.remaining -= 1;
+        Some(unsafe { self.deque.as_mut().buffer_read(wrapped_idx) })
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let len = self.remaining;
+        (len, Some(len))
+    }
+}
+
+impl<T, const N: usize> DoubleEndedIterator for Drain<'_, T, N> {
+    #[inline]
+    fn next_back(&mut self) -> Option<T> {
+        if self.remaining == 0 {
+            return None;
+        }
+        self.remaining -= 1;
+        let wrapped_idx =
+            unsafe { self.deque.as_ref().to_physical_idx(self.idx + self.remaining) };
+        Some(unsafe { self.deque.as_mut().buffer_read(wrapped_idx) })
+    }
+}
+
+impl<T, const N: usize> ExactSizeIterator for Drain<'_, T, N> {}
+
+impl<T, const N: usize> core::iter::FusedIterator for Drain<'_, T, N> {}
+
+/// Specialization trait used for `SmallDeque::from_iter`
+trait SpecFromIter<T, I> {
+    fn spec_from_iter(iter: I) -> Self;
+}
+
+impl<T, I, const N: usize> SpecFromIter<T, I> for SmallDeque<T, N>
+where
+    I: Iterator<Item = T>,
+{
+    default fn spec_from_iter(iterator: I) -> Self {
+        // Since converting is O(1) now, just re-use the `SmallVec` logic for
+        // anything where we can't do something extra-special for `SmallDeque`,
+        // especially as that could save us some monomorphization work
+        // if one uses the same iterators (like slice ones) with both.
+        SmallVec::from_iter(iterator).into()
+    }
+}
+
+impl<T, const N: usize> SpecFromIter<T, IntoIter<T, N>> for SmallDeque<T, N> {
+    #[inline]
+    fn spec_from_iter(iterator: IntoIter<T, N>) -> Self {
+        iterator.into_smalldeque()
+    }
+}
+
+// Specialization trait used for SmallDeque::extend
+trait SpecExtend<T, I> {
+    fn spec_extend(&mut self, iter: I);
+}
+
+impl<T, I, const N: usize> SpecExtend<T, I> for SmallDeque<T, N>
+where
+    I: Iterator<Item = T>,
+{
+    default fn spec_extend(&mut self, mut iter: I) {
+        // This function should be the moral equivalent of:
+        //
+        // for item in iter {
+        //     self.push_back(item);
+        // }
+
+        // May only be called if `deque.len() < deque.capacity()`
+        unsafe fn push_unchecked<T, const N: usize>(deque: &mut SmallDeque<T, N>, element: T) {
+            // SAFETY: Because of the precondition, it's guaranteed that there is space
+            // in the logical array after the last element.
+            unsafe { deque.buffer_write(deque.to_physical_idx(deque.len), element) };
+            // This can't overflow because `deque.len() < deque.capacity() <= usize::MAX`.
+            deque.len += 1;
+        }
+
+        while let Some(element) = iter.next() {
+            let (lower, _) = iter.size_hint();
+            self.reserve(lower.saturating_add(1));
+
+            // SAFETY: We just reserved space for at least one element.
+            unsafe { push_unchecked(self, element) };
+
+            // Inner loop to avoid repeatedly calling `reserve`.
+            while self.len < self.capacity() {
+                let Some(element) = iter.next() else {
+                    return;
+                };
+                // SAFETY: The loop condition guarantees that `self.len() < self.capacity()`.
+                unsafe { push_unchecked(self, element) };
+            }
+        }
+    }
+}
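+
+// Design note on the default `spec_extend` above: reserving once per element
+// would re-check capacity on every push, so the outer loop reserves based on
+// the iterator's lower size hint (plus one for the element already in hand)
+// and the inner loop fills the already-reserved spare capacity without calling
+// `reserve` again.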
+
+impl<T, I, const N: usize> SpecExtend<T, I> for SmallDeque<T, N>
+where
+    I: core::iter::TrustedLen<Item = T>,
+{
+    default fn spec_extend(&mut self, iter: I) {
+        // This is the case for a TrustedLen iterator.
+        let (low, high) = iter.size_hint();
+        if let Some(additional) = high {
+            debug_assert_eq!(
+                low,
+                additional,
+                "TrustedLen iterator's size hint is not exact: {:?}",
+                (low, high)
+            );
+            self.reserve(additional);
+
+            let written = unsafe {
+                self.write_iter_wrapping(self.to_physical_idx(self.len), iter, additional)
+            };
+
+            debug_assert_eq!(
+                additional, written,
+                "The number of items written to SmallDeque doesn't match the TrustedLen size hint"
+            );
+        } else {
+            // Per TrustedLen contract a `None` upper bound means that the iterator length
+            // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
+            // Since the other branch already panics eagerly (via `reserve()`) we do the same here.
+            // This avoids additional codegen for a fallback code path which would eventually
+            // panic anyway.
+            panic!("capacity overflow");
+        }
+    }
+}
+
+impl<'a, T: 'a, I, const N: usize> SpecExtend<&'a T, I> for SmallDeque<T, N>
+where
+    I: Iterator<Item = &'a T>,
+    T: Copy,
+{
+    default fn spec_extend(&mut self, iterator: I) {
+        self.spec_extend(iterator.copied())
+    }
+}
+
+impl<'a, T: 'a, const N: usize> SpecExtend<&'a T, core::slice::Iter<'a, T>> for SmallDeque<T, N>
+where
+    T: Copy,
+{
+    fn spec_extend(&mut self, iterator: core::slice::Iter<'a, T>) {
+        let slice = iterator.as_slice();
+        self.reserve(slice.len());
+
+        unsafe {
+            self.copy_slice(self.to_physical_idx(self.len), slice);
+            self.len += slice.len();
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use core::iter::TrustedLen;
+
+    use smallvec::SmallVec;
+
+    use super::*;
+
+    #[test]
+    fn test_swap_front_back_remove() {
+        fn test(back: bool) {
+            // This test checks that every single combination of tail position and length is
+            // tested. Capacity 15 should be large enough to cover every case.
+            let mut tester = SmallDeque::<_, 16>::with_capacity(15);
+            let usable_cap = tester.capacity();
+            let final_len = usable_cap / 2;
+
+            for len in 0..final_len {
+                let expected: SmallDeque<_, 16> = if back {
+                    (0..len).collect()
+                } else {
+                    (0..len).rev().collect()
+                };
+                for head_pos in 0..usable_cap {
+                    tester.head = head_pos;
+                    tester.len = 0;
+                    if back {
+                        for i in 0..len * 2 {
+                            tester.push_front(i);
+                        }
+                        for i in 0..len {
+                            assert_eq!(tester.swap_remove_back(i), Some(len * 2 - 1 - i));
+                        }
+                    } else {
+                        for i in 0..len * 2 {
+                            tester.push_back(i);
+                        }
+                        for i in 0..len {
+                            let idx = tester.len() - 1 - i;
+                            assert_eq!(tester.swap_remove_front(idx), Some(len * 2 - 1 - i));
+                        }
+                    }
+                    assert!(tester.head <= tester.capacity());
+                    assert!(tester.len <= tester.capacity());
+                    assert_eq!(tester, expected);
+                }
+            }
+        }
+        test(true);
+        test(false);
+    }
+
+    #[test]
+    fn test_insert() {
+        // This test checks that every single combination of tail position, length, and
+        // insertion position is tested. Capacity 15 should be large enough to cover every case.
+
+        let mut tester = SmallDeque::<_, 16>::with_capacity(15);
+        // can't guarantee we got 15, so have to get what we got.
+        // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else
+        // this test isn't covering what it wants to
+        let cap = tester.capacity();
+
+        // len is the length *after* insertion
+        let minlen = if cfg!(miri) { cap - 1 } else { 1 }; // Miri is too slow
+        for len in minlen..cap {
+            // 0, 1, 2, .., len - 1
+            let expected = (0..).take(len).collect::<SmallDeque<_, 16>>();
+            for head_pos in 0..cap {
+                for to_insert in 0..len {
+                    tester.head = head_pos;
+                    tester.len = 0;
+                    for i in 0..len {
+                        if i != to_insert {
+                            tester.push_back(i);
+                        }
+                    }
+                    tester.insert(to_insert, to_insert);
+                    assert!(tester.head <= tester.capacity());
+                    assert!(tester.len <= tester.capacity());
+                    assert_eq!(tester, expected);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_get() {
+        let mut tester = SmallDeque::<_, 16>::new();
+        tester.push_back(1);
+        tester.push_back(2);
+        tester.push_back(3);
+
+        assert_eq!(tester.len(), 3);
+
+        assert_eq!(tester.get(1), Some(&2));
+        assert_eq!(tester.get(2), Some(&3));
+        assert_eq!(tester.get(0), Some(&1));
+        assert_eq!(tester.get(3), None);
+
+        tester.remove(0);
+
+        assert_eq!(tester.len(), 2);
+        assert_eq!(tester.get(0), Some(&2));
+        assert_eq!(tester.get(1), Some(&3));
+        assert_eq!(tester.get(2), None);
+    }
+
+    #[test]
+    fn test_get_mut() {
+        let mut tester = SmallDeque::<_, 16>::new();
+        tester.push_back(1);
+        tester.push_back(2);
+        tester.push_back(3);
+
+        assert_eq!(tester.len(), 3);
+
+        if let Some(elem) = tester.get_mut(0) {
+            assert_eq!(*elem, 1);
+            *elem = 10;
+        }
+
+        if let Some(elem) = tester.get_mut(2) {
+            assert_eq!(*elem, 3);
+            *elem = 30;
+        }
+
+        assert_eq!(tester.get(0), Some(&10));
+        assert_eq!(tester.get(2), Some(&30));
+        assert_eq!(tester.get_mut(3), None);
+
+        tester.remove(2);
+
+        assert_eq!(tester.len(), 2);
+        assert_eq!(tester.get(0), Some(&10));
+        assert_eq!(tester.get(1), Some(&2));
+        assert_eq!(tester.get(2), None);
+    }
+
+    #[test]
+    fn test_swap() {
+        let mut tester = SmallDeque::<_, 3>::new();
+        tester.push_back(1);
+        tester.push_back(2);
+        tester.push_back(3);
+
+        assert_eq!(tester, [1, 2, 3]);
+
+        tester.swap(0, 0);
+        assert_eq!(tester, [1, 2, 3]);
+        tester.swap(0, 1);
+        assert_eq!(tester, [2, 1, 3]);
+        tester.swap(2, 1);
+        assert_eq!(tester, [2, 3, 1]);
+        tester.swap(1, 2);
+        assert_eq!(tester, [2, 1, 3]);
+        tester.swap(0, 2);
+        assert_eq!(tester, [3, 1, 2]);
+        tester.swap(2, 2);
+        assert_eq!(tester, [3, 1, 2]);
+    }
+
+    #[test]
+    #[should_panic = "assertion failed: j < self.len()"]
+    fn test_swap_panic() {
+        let mut tester = SmallDeque::<_>::new();
+        tester.push_back(1);
+        tester.push_back(2);
+        tester.push_back(3);
+        tester.swap(2, 3);
+    }
+
+    #[test]
+    fn test_reserve_exact() {
+        let mut tester: SmallDeque<usize, 1> = SmallDeque::with_capacity(1);
+        assert_eq!(tester.capacity(), 1);
+        tester.reserve_exact(50);
+        assert_eq!(tester.capacity(), 50);
+        tester.reserve_exact(40);
+        // reserving won't shrink the buffer
+        assert_eq!(tester.capacity(), 50);
+        tester.reserve_exact(200);
+        assert_eq!(tester.capacity(), 200);
+    }
+
+    #[test]
+    #[should_panic = "capacity overflow"]
+    fn test_reserve_exact_panic() {
+        let mut tester: SmallDeque<usize, 1> = SmallDeque::new();
+        tester.reserve_exact(usize::MAX);
+    }
+
+    #[test]
+    fn test_contains() {
+        let mut tester = SmallDeque::<_>::new();
+        tester.push_back(1);
+        tester.push_back(2);
+        tester.push_back(3);
+
+        assert!(tester.contains(&1));
+        assert!(tester.contains(&3));
+        assert!(!tester.contains(&0));
+        assert!(!tester.contains(&4));
+        tester.remove(0);
+        assert!(!tester.contains(&1));
+        assert!(tester.contains(&2));
+        assert!(tester.contains(&3));
+    }
+
+    #[test]
+    fn test_rotate_left_right() {
+        let mut tester: SmallDeque<_> = (1..=10).collect();
+        tester.reserve(1);
+
+        assert_eq!(tester.len(), 10);
+
+        tester.rotate_left(0);
+        assert_eq!(tester, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        tester.rotate_right(0);
+        assert_eq!(tester, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+
+        tester.rotate_left(3);
+        assert_eq!(tester, [4, 5, 6, 7, 8, 9, 10, 1, 2, 3]);
+
+        tester.rotate_right(5);
+        assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]);
+
+        tester.rotate_left(tester.len());
+        assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]);
+
+        tester.rotate_right(tester.len());
+        assert_eq!(tester, [9, 10, 1, 2, 3, 4, 5, 6, 7, 8]);
+
+        tester.rotate_left(1);
+        assert_eq!(tester, [10, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+    }
+
+    #[test]
+    #[should_panic = "assertion failed: n <= self.len()"]
+    fn test_rotate_left_panic() {
+        let mut tester: SmallDeque<_> = (1..=10).collect();
+        tester.rotate_left(tester.len() + 1);
+    }
+
+    #[test]
+    #[should_panic = "assertion failed: n <= self.len()"]
+    fn test_rotate_right_panic() {
+        let mut tester: SmallDeque<_> = (1..=10).collect();
+        tester.rotate_right(tester.len() + 1);
+    }
+
+    #[test]
+    fn test_binary_search() {
+        // If the given SmallDeque is not sorted, the returned result is unspecified and
+        // meaningless, as this method performs a binary search.
+
+        let tester: SmallDeque<_, 11> = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55].into();
+
+        assert_eq!(tester.binary_search(&0), Ok(0));
+        assert_eq!(tester.binary_search(&5), Ok(5));
+        assert_eq!(tester.binary_search(&55), Ok(10));
+        assert_eq!(tester.binary_search(&4), Err(5));
+        assert_eq!(tester.binary_search(&-1), Err(0));
+        assert!(matches!(tester.binary_search(&1), Ok(1..=2)));
+
+        let tester: SmallDeque<_, 14> = [1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3].into();
+        assert_eq!(tester.binary_search(&1), Ok(0));
+        assert!(matches!(tester.binary_search(&2), Ok(1..=4)));
+        assert!(matches!(tester.binary_search(&3), Ok(5..=13)));
+        assert_eq!(tester.binary_search(&-2), Err(0));
+        assert_eq!(tester.binary_search(&0), Err(0));
+        assert_eq!(tester.binary_search(&4), Err(14));
+        assert_eq!(tester.binary_search(&5), Err(14));
+    }
+
+    #[test]
+    fn test_binary_search_by() {
+        // If the given SmallDeque is not sorted, the returned result is unspecified and
+        // meaningless, as this method performs a binary search.
+
+        let tester: SmallDeque<_, 11> = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55].into();
+
+        assert_eq!(tester.binary_search_by(|x| x.cmp(&0)), Ok(0));
+        assert_eq!(tester.binary_search_by(|x| x.cmp(&5)), Ok(5));
+        assert_eq!(tester.binary_search_by(|x| x.cmp(&55)), Ok(10));
+        assert_eq!(tester.binary_search_by(|x| x.cmp(&4)), Err(5));
+        assert_eq!(tester.binary_search_by(|x| x.cmp(&-1)), Err(0));
+        assert!(matches!(tester.binary_search_by(|x| x.cmp(&1)), Ok(1..=2)));
+    }
+
+    #[test]
+    fn test_binary_search_key() {
+        // If the given SmallDeque is not sorted, the returned result is unspecified and
+        // meaningless, as this method performs a binary search.
+ + let tester: SmallDeque<_, 13> = [ + (-1, 0), + (2, 10), + (6, 5), + (7, 1), + (8, 10), + (10, 2), + (20, 3), + (24, 5), + (25, 18), + (28, 13), + (31, 21), + (32, 4), + (54, 25), + ] + .into(); + + assert_eq!(tester.binary_search_by_key(&-1, |&(a, _b)| a), Ok(0)); + assert_eq!(tester.binary_search_by_key(&8, |&(a, _b)| a), Ok(4)); + assert_eq!(tester.binary_search_by_key(&25, |&(a, _b)| a), Ok(8)); + assert_eq!(tester.binary_search_by_key(&54, |&(a, _b)| a), Ok(12)); + assert_eq!(tester.binary_search_by_key(&-2, |&(a, _b)| a), Err(0)); + assert_eq!(tester.binary_search_by_key(&1, |&(a, _b)| a), Err(1)); + assert_eq!(tester.binary_search_by_key(&4, |&(a, _b)| a), Err(2)); + assert_eq!(tester.binary_search_by_key(&13, |&(a, _b)| a), Err(6)); + assert_eq!(tester.binary_search_by_key(&55, |&(a, _b)| a), Err(13)); + assert_eq!(tester.binary_search_by_key(&100, |&(a, _b)| a), Err(13)); + + let tester: SmallDeque<_, 13> = [ + (0, 0), + (2, 1), + (6, 1), + (5, 1), + (3, 1), + (1, 2), + (2, 3), + (4, 5), + (5, 8), + (8, 13), + (1, 21), + (2, 34), + (4, 55), + ] + .into(); + + assert_eq!(tester.binary_search_by_key(&0, |&(_a, b)| b), Ok(0)); + assert!(matches!(tester.binary_search_by_key(&1, |&(_a, b)| b), Ok(1..=4))); + assert_eq!(tester.binary_search_by_key(&8, |&(_a, b)| b), Ok(8)); + assert_eq!(tester.binary_search_by_key(&13, |&(_a, b)| b), Ok(9)); + assert_eq!(tester.binary_search_by_key(&55, |&(_a, b)| b), Ok(12)); + assert_eq!(tester.binary_search_by_key(&-1, |&(_a, b)| b), Err(0)); + assert_eq!(tester.binary_search_by_key(&4, |&(_a, b)| b), Err(7)); + assert_eq!(tester.binary_search_by_key(&56, |&(_a, b)| b), Err(13)); + assert_eq!(tester.binary_search_by_key(&100, |&(_a, b)| b), Err(13)); + } + + #[test] + fn make_contiguous_big_head() { + let mut tester = SmallDeque::<_>::with_capacity(15); + + for i in 0..3 { + tester.push_back(i); + } + + for i in 3..10 { + tester.push_front(i); + } + + // 012......9876543 + assert_eq!(tester.capacity(), 15); + assert_eq!((&[9, 8, 7, 6, 5, 4, 3] as &[_], &[0, 1, 2] as &[_]), tester.as_slices()); + + let expected_start = tester.as_slices().1.len(); + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!((&[9, 8, 7, 6, 5, 4, 3, 0, 1, 2] as &[_], &[] as &[_]), tester.as_slices()); + } + + #[test] + fn make_contiguous_big_tail() { + let mut tester = SmallDeque::<_>::with_capacity(15); + + for i in 0..8 { + tester.push_back(i); + } + + for i in 8..10 { + tester.push_front(i); + } + + // 01234567......98 + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!((&[9, 8, 0, 1, 2, 3, 4, 5, 6, 7] as &[_], &[] as &[_]), tester.as_slices()); + } + + #[test] + fn make_contiguous_small_free() { + let mut tester = SmallDeque::<_>::with_capacity(16); + + for i in b'A'..b'I' { + tester.push_back(i as char); + } + + for i in b'I'..b'N' { + tester.push_front(i as char); + } + + assert_eq!(tester, ['M', 'L', 'K', 'J', 'I', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']); + + // ABCDEFGH...MLKJI + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['M', 'L', 'K', 'J', 'I', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'] as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + + tester.clear(); + for i in b'I'..b'N' { + tester.push_back(i as char); + } + + for i in b'A'..b'I' { + tester.push_front(i as char); + } + + // IJKLM...HGFEDCBA + let expected_start = 3; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + 
assert_eq!( + ( + &['H', 'G', 'F', 'E', 'D', 'C', 'B', 'A', 'I', 'J', 'K', 'L', 'M'] as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + } + + #[test] + fn make_contiguous_head_to_end() { + let mut tester = SmallDeque::<_>::with_capacity(16); + + for i in b'A'..b'L' { + tester.push_back(i as char); + } + + for i in b'L'..b'Q' { + tester.push_front(i as char); + } + + assert_eq!( + tester, + ['P', 'O', 'N', 'M', 'L', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K'] + ); + + // ABCDEFGHIJKPONML + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['P', 'O', 'N', 'M', 'L', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K'] + as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + + tester.clear(); + for i in b'L'..b'Q' { + tester.push_back(i as char); + } + + for i in b'A'..b'L' { + tester.push_front(i as char); + } + + // LMNOPKJIHGFEDCBA + let expected_start = 0; + tester.make_contiguous(); + assert_eq!(tester.head, expected_start); + assert_eq!( + ( + &['K', 'J', 'I', 'H', 'G', 'F', 'E', 'D', 'C', 'B', 'A', 'L', 'M', 'N', 'O', 'P'] + as &[_], + &[] as &[_] + ), + tester.as_slices() + ); + } + + #[test] + fn make_contiguous_head_to_end_2() { + // Another test case for #79808, taken from #80293. + + let mut dq = SmallDeque::<_>::from_iter(0..6); + dq.pop_front(); + dq.pop_front(); + dq.push_back(6); + dq.push_back(7); + dq.push_back(8); + dq.make_contiguous(); + let collected: Vec<_> = dq.iter().copied().collect(); + assert_eq!(dq.as_slices(), (&collected[..], &[] as &[_])); + } + + #[test] + fn test_remove() { + // This test checks that every single combination of tail position, length, and + // removal position is tested. Capacity 15 should be large enough to cover every case. + + let mut tester = SmallDeque::<_>::with_capacity(15); + // can't guarantee we got 15, so have to get what we got. 
+        // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else
+        // this test isn't covering what it wants to
+        let cap = tester.capacity();
+
+        // len is the length *after* removal
+        let minlen = if cfg!(miri) { cap - 2 } else { 0 }; // Miri is too slow
+        for len in minlen..cap - 1 {
+            // 0, 1, 2, .., len - 1
+            let expected = (0..).take(len).collect::<SmallDeque<_>>();
+            for head_pos in 0..cap {
+                for to_remove in 0..=len {
+                    tester.head = head_pos;
+                    tester.len = 0;
+                    for i in 0..len {
+                        if i == to_remove {
+                            tester.push_back(1234);
+                        }
+                        tester.push_back(i);
+                    }
+                    if to_remove == len {
+                        tester.push_back(1234);
+                    }
+                    tester.remove(to_remove);
+                    assert!(tester.head <= tester.capacity());
+                    assert!(tester.len <= tester.capacity());
+                    assert_eq!(tester, expected);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_range() {
+        let mut tester: SmallDeque<usize> = SmallDeque::with_capacity(7);
+
+        let cap = tester.capacity();
+        let minlen = if cfg!(miri) { cap - 1 } else { 0 }; // Miri is too slow
+        for len in minlen..=cap {
+            for head in 0..=cap {
+                for start in 0..=len {
+                    for end in start..=len {
+                        tester.head = head;
+                        tester.len = 0;
+                        for i in 0..len {
+                            tester.push_back(i);
+                        }
+
+                        // Check that we iterate over the correct values
+                        let range: SmallDeque<_> = tester.range(start..end).copied().collect();
+                        let expected: SmallDeque<_> = (start..end).collect();
+                        assert_eq!(range, expected);
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_range_mut() {
+        let mut tester: SmallDeque<usize> = SmallDeque::with_capacity(7);
+
+        let cap = tester.capacity();
+        for len in 0..=cap {
+            for head in 0..=cap {
+                for start in 0..=len {
+                    for end in start..=len {
+                        tester.head = head;
+                        tester.len = 0;
+                        for i in 0..len {
+                            tester.push_back(i);
+                        }
+
+                        let head_was = tester.head;
+                        let len_was = tester.len;
+
+                        // Check that we iterate over the correct values
+                        let range: SmallDeque<_> =
+                            tester.range_mut(start..end).map(|v| *v).collect();
+                        let expected: SmallDeque<_> = (start..end).collect();
+                        assert_eq!(range, expected);
+
+                        // We shouldn't have changed the capacity or made the
+                        // head or tail out of bounds
+                        assert_eq!(tester.capacity(), cap);
+                        assert_eq!(tester.head, head_was);
+                        assert_eq!(tester.len, len_was);
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_drain() {
+        let mut tester: SmallDeque<usize> = SmallDeque::with_capacity(7);
+
+        let cap = tester.capacity();
+        for len in 0..=cap {
+            for head in 0..cap {
+                for drain_start in 0..=len {
+                    for drain_end in drain_start..=len {
+                        tester.head = head;
+                        tester.len = 0;
+                        for i in 0..len {
+                            tester.push_back(i);
+                        }
+
+                        // Check that we drain the correct values
+                        let drained: SmallDeque<_> =
+                            tester.drain(drain_start..drain_end).collect();
+                        let drained_expected: SmallDeque<_> = (drain_start..drain_end).collect();
+                        assert_eq!(drained, drained_expected);
+
+                        // We shouldn't have changed the capacity or made the
+                        // head or tail out of bounds
+                        assert_eq!(tester.capacity(), cap);
+                        assert!(tester.head <= tester.capacity());
+                        assert!(tester.len <= tester.capacity());
+
+                        // We should see the correct values in the SmallDeque
+                        let expected: SmallDeque<_> =
+                            (0..drain_start).chain(drain_end..len).collect();
+                        assert_eq!(expected, tester);
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_split_off() {
+        // This test checks that every single combination of tail position, length, and
+        // split position is tested. Capacity 15 should be large enough to cover every case.
+
+        let mut tester = SmallDeque::with_capacity(15);
+        // can't guarantee we got 15, so have to get what we got.
+        // 15 would be great, but we will definitely get 2^k - 1, for k >= 4, or else
+        // this test isn't covering what it wants to
+        let cap = tester.capacity();
+
+        // len is the length *before* splitting
+        let minlen = if cfg!(miri) { cap - 1 } else { 0 }; // Miri is too slow
+        for len in minlen..cap {
+            // index to split at
+            for at in 0..=len {
+                // 0, 1, 2, .., at - 1 (may be empty)
+                let expected_self = (0..).take(at).collect::<SmallDeque<_>>();
+                // at, at + 1, .., len - 1 (may be empty)
+                let expected_other = (at..).take(len - at).collect::<SmallDeque<_>>();
+
+                for head_pos in 0..cap {
+                    tester.head = head_pos;
+                    tester.len = 0;
+                    for i in 0..len {
+                        tester.push_back(i);
+                    }
+                    let result = tester.split_off(at);
+                    assert!(tester.head <= tester.capacity());
+                    assert!(tester.len <= tester.capacity());
+                    assert!(result.head <= result.capacity());
+                    assert!(result.len <= result.capacity());
+                    assert_eq!(tester, expected_self);
+                    assert_eq!(result, expected_other);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_from_smallvec() {
+        for cap in 0..35 {
+            for len in 0..=cap {
+                let mut vec = SmallVec::<[_; 16]>::with_capacity(cap);
+                vec.extend(0..len);
+
+                let vd = SmallDeque::from(vec.clone());
+                assert_eq!(vd.len(), vec.len());
+                assert!(vd.into_iter().eq(vec));
+            }
+        }
+    }
+
+    #[test]
+    fn test_extend_basic() {
+        test_extend_impl(false);
+    }
+
+    #[test]
+    fn test_extend_trusted_len() {
+        test_extend_impl(true);
+    }
+
+    fn test_extend_impl(trusted_len: bool) {
+        struct SmallDequeTester {
+            test: SmallDeque<usize>,
+            expected: SmallDeque<usize>,
+            trusted_len: bool,
+        }
+
+        impl SmallDequeTester {
+            fn new(trusted_len: bool) -> Self {
+                Self {
+                    test: SmallDeque::new(),
+                    expected: SmallDeque::new(),
+                    trusted_len,
+                }
+            }
+
+            fn test_extend<I>(&mut self, iter: I)
+            where
+                I: Iterator<Item = usize> + TrustedLen + Clone,
+            {
+                struct BasicIterator<I>(I);
+                impl<I> Iterator for BasicIterator<I>
+                where
+                    I: Iterator<Item = usize>,
+                {
+                    type Item = usize;
+
+                    fn next(&mut self) -> Option<usize> {
+                        self.0.next()
+                    }
+                }
+
+                if self.trusted_len {
+                    self.test.extend(iter.clone());
+                } else {
+                    self.test.extend(BasicIterator(iter.clone()));
+                }
+
+                for item in iter {
+                    self.expected.push_back(item)
+                }
+
+                assert_eq!(self.test, self.expected);
+            }
+
+            fn drain<R: core::ops::RangeBounds<usize> + Clone>(&mut self, range: R) {
+                self.test.drain(range.clone());
+                self.expected.drain(range);
+
+                assert_eq!(self.test, self.expected);
+            }
+
+            fn clear(&mut self) {
+                self.test.clear();
+                self.expected.clear();
+            }
+
+            fn remaining_capacity(&self) -> usize {
+                self.test.capacity() - self.test.len()
+            }
+        }
+
+        let mut tester = SmallDequeTester::new(trusted_len);
+
+        // Initial capacity
+        tester.test_extend(0..tester.remaining_capacity());
+
+        // Grow
+        tester.test_extend(1024..2048);
+
+        // Wrap around
+        tester.drain(..128);
+
+        tester.test_extend(0..tester.remaining_capacity());
+
+        // Continue
+        tester.drain(256..);
+        tester.test_extend(4096..8196);
+
+        tester.clear();
+
+        // Start again
+        tester.test_extend(0..32);
+    }
+
+    #[test]
+    fn test_from_array() {
+        fn test<const N: usize>() {
+            let mut array: [usize; N] = [0; N];
+
+            for (i, v) in array.iter_mut().enumerate() {
+                *v = i;
+            }
+
+            let deq: SmallDeque<_, N> = array.into();
+
+            for i in 0..N {
+                assert_eq!(deq[i], i);
+            }
+
+            assert_eq!(deq.len(), N);
+        }
+        test::<0>();
+        test::<1>();
+        test::<2>();
+        test::<32>();
+        test::<35>();
+    }
+
+    #[test]
+    fn test_smallvec_from_smalldeque() {
+        fn create_vec_and_test_convert(capacity: usize, offset: usize, len: usize) {
+            let mut vd = SmallDeque::<_, 16>::with_capacity(capacity);
+            for _ in 0..offset {
+                vd.push_back(0);
+                vd.pop_front();
+            }
+
vd.extend(0..len); + + let vec: SmallVec<_> = SmallVec::from(vd.clone()); + assert_eq!(vec.len(), vd.len()); + assert!(vec.into_iter().eq(vd)); + } + + // Miri is too slow + let max_pwr = if cfg!(miri) { 5 } else { 7 }; + + for cap_pwr in 0..max_pwr { + // Make capacity as a (2^x)-1, so that the ring size is 2^x + let cap = (2i32.pow(cap_pwr) - 1) as usize; + + // In these cases there is enough free space to solve it with copies + for len in 0..((cap + 1) / 2) { + // Test contiguous cases + for offset in 0..(cap - len) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at end of buffer is bigger than block at start + for offset in (cap - len)..(cap - (len / 2)) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at start of buffer is bigger than block at end + for offset in (cap - (len / 2))..cap { + create_vec_and_test_convert(cap, offset, len) + } + } + + // Now there's not (necessarily) space to straighten the ring with simple copies, + // the ring will use swapping when: + // (cap + 1 - offset) > (cap + 1 - len) && (len - (cap + 1 - offset)) > (cap + 1 - len)) + // right block size > free space && left block size > free space + for len in ((cap + 1) / 2)..cap { + // Test contiguous cases + for offset in 0..(cap - len) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at end of buffer is bigger than block at start + for offset in (cap - len)..(cap - (len / 2)) { + create_vec_and_test_convert(cap, offset, len) + } + + // Test cases where block at start of buffer is bigger than block at end + for offset in (cap - (len / 2))..cap { + create_vec_and_test_convert(cap, offset, len) + } + } + } + } + + #[test] + fn test_clone_from() { + use smallvec::smallvec; + + let m = smallvec![1; 8]; + let n = smallvec![2; 12]; + let limit = if cfg!(miri) { 4 } else { 8 }; // Miri is too slow + for pfv in 0..limit { + for pfu in 0..limit { + for longer in 0..2 { + let (vr, ur) = if longer == 0 { (&m, &n) } else { (&n, &m) }; + let mut v = SmallDeque::<_>::from(vr.clone()); + for _ in 0..pfv { + v.push_front(1); + } + let mut u = SmallDeque::<_>::from(ur.clone()); + for _ in 0..pfu { + u.push_front(2); + } + v.clone_from(&u); + assert_eq!(&v, &u); + } + } + } + } + + #[test] + fn test_vec_deque_truncate_drop() { + static mut DROPS: u32 = 0; + #[derive(Clone)] + struct Elem(#[allow(dead_code)] i32); + impl Drop for Elem { + fn drop(&mut self) { + unsafe { + DROPS += 1; + } + } + } + + let v = vec![Elem(1), Elem(2), Elem(3), Elem(4), Elem(5)]; + for push_front in 0..=v.len() { + let v = v.clone(); + let mut tester = SmallDeque::<_>::with_capacity(5); + for (index, elem) in v.into_iter().enumerate() { + if index < push_front { + tester.push_front(elem); + } else { + tester.push_back(elem); + } + } + assert_eq!(unsafe { DROPS }, 0); + tester.truncate(3); + assert_eq!(unsafe { DROPS }, 2); + tester.truncate(0); + assert_eq!(unsafe { DROPS }, 5); + unsafe { + DROPS = 0; + } + } + } + + #[test] + fn issue_53529() { + let mut dst = SmallDeque::<_>::new(); + dst.push_front(Box::new(1)); + dst.push_front(Box::new(2)); + assert_eq!(*dst.pop_back().unwrap(), 1); + + let mut src = SmallDeque::<_>::new(); + src.push_front(Box::new(2)); + dst.append(&mut src); + for a in dst { + assert_eq!(*a, 2); + } + } + + #[test] + fn issue_80303() { + use core::{ + hash::{Hash, Hasher}, + iter, + num::Wrapping, + }; + + // This is a valid, albeit rather bad hash function implementation. 
+        struct SimpleHasher(Wrapping<u64>);
+
+        impl Hasher for SimpleHasher {
+            fn finish(&self) -> u64 {
+                self.0 .0
+            }
+
+            fn write(&mut self, bytes: &[u8]) {
+                // This particular implementation hashes value 24 in addition to bytes.
+                // Such an implementation is valid as Hasher only guarantees equivalence
+                // for the exact same set of calls to its methods.
+                for &v in iter::once(&24).chain(bytes) {
+                    self.0 = Wrapping(31) * self.0 + Wrapping(u64::from(v));
+                }
+            }
+        }
+
+        fn hash_code(value: impl Hash) -> u64 {
+            let mut hasher = SimpleHasher(Wrapping(1));
+            value.hash(&mut hasher);
+            hasher.finish()
+        }
+
+        // This creates two deques for which the values returned by the as_slices
+        // method differ.
+        let vda: SmallDeque<u8> = (0..10).collect();
+        let mut vdb = SmallDeque::with_capacity(10);
+        vdb.extend(5..10);
+        (0..5).rev().for_each(|elem| vdb.push_front(elem));
+        assert_ne!(vda.as_slices(), vdb.as_slices());
+        assert_eq!(vda, vdb);
+        assert_eq!(hash_code(vda), hash_code(vdb));
+    }
+}
diff --git a/hir2/src/adt/smallmap.rs b/hir2/src/adt/smallmap.rs
new file mode 100644
index 000000000..c0b61ac0f
--- /dev/null
+++ b/hir2/src/adt/smallmap.rs
@@ -0,0 +1,491 @@
+use core::{
+    borrow::Borrow,
+    cmp::Ordering,
+    fmt,
+    ops::{Index, IndexMut},
+};
+
+use smallvec::SmallVec;
+
+/// [SmallMap] is a [BTreeMap]-like structure that can store a specified number
+/// of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// This data structure is designed with three goals in mind:
+///
+/// * Support efficient key/value operations over a small set of keys
+/// * Preserve the order of keys
+/// * Avoid allocating data on the heap for the typical case
+///
+/// Internally, [SmallMap] is implemented on top of [SmallVec], and uses binary search
+/// to locate elements. This is quite efficient in general, and is particularly fast
+/// when all of the data is stored inline, but may not be a good fit for all use cases.
+///
+/// Due to its design constraints, it only supports keys which implement [Ord].
+pub struct SmallMap<K, V, const N: usize> {
+    items: SmallVec<[KeyValuePair<K, V>; N]>,
+}
+impl<K, V, const N: usize> SmallMap<K, V, N>
+where
+    K: Ord,
+{
+    /// Returns a new, empty [SmallMap]
+    pub const fn new() -> Self {
+        Self {
+            items: SmallVec::new_const(),
+        }
+    }
+
+    /// Returns a new, empty [SmallMap], with capacity for `capacity` entries without reallocating.
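+    ///
+    /// A minimal usage sketch (the inline capacity `N` and element types here are
+    /// arbitrary):
+    ///
+    /// ```rust,ignore
+    /// // Reserve room for 8 entries up front, so later inserts won't reallocate
+    /// let mut map = SmallMap::<u32, &str, 4>::with_capacity(8);
+    /// map.insert(1, "one");
+    /// assert_eq!(map.get(&1), Some(&"one"));
+    /// ```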
+ pub fn with_capacity(capacity: usize) -> Self { + Self { + items: SmallVec::with_capacity(capacity), + } + } + + /// Returns true if this map is empty + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Returns the number of key/value pairs in this map + pub fn len(&self) -> usize { + self.items.len() + } + + /// Return an iterator over the key/value pairs in this map + pub fn iter(&self) -> impl DoubleEndedIterator { + self.items.iter().map(|pair| (&pair.key, &pair.value)) + } + + /// Return an iterator over mutable key/value pairs in this map + pub fn iter_mut(&mut self) -> impl DoubleEndedIterator { + self.items.iter_mut().map(|pair| (&pair.key, &mut pair.value)) + } + + /// Returns true if `key` has been inserted in this map + pub fn contains(&self, key: &Q) -> bool + where + K: Borrow, + Q: Ord + ?Sized, + { + self.find(key).is_ok() + } + + /// Returns the value under `key` in this map, if it exists + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(&self.items[idx].value), + Err(_) => None, + } + } + + /// Returns a mutable reference to the value under `key` in this map, if it exists + pub fn get_mut(&mut self, key: &Q) -> Option<&mut V> + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(&mut self.items[idx].value), + Err(_) => None, + } + } + + /// Inserts a new entry in this map using `key` and `value`. + /// + /// Returns the previous value, if `key` was already present in the map. + pub fn insert(&mut self, key: K, value: V) -> Option { + match self.entry(key) { + Entry::Occupied(mut entry) => Some(core::mem::replace(entry.get_mut(), value)), + Entry::Vacant(entry) => { + entry.insert(value); + None + } + } + } + + /// Removes the value inserted under `key`, if it exists + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: Ord + ?Sized, + { + match self.find(key) { + Ok(idx) => Some(self.items.remove(idx).value), + Err(_) => None, + } + } + + /// Clear the content of the map + pub fn clear(&mut self) { + self.items.clear(); + } + + /// Returns an [Entry] which can be used to combine `contains`+`insert` type operations. 
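+    ///
+    /// # Example
+    ///
+    /// A minimal sketch of the entry API (types chosen arbitrarily):
+    ///
+    /// ```rust,ignore
+    /// let mut counts = SmallMap::<&str, usize, 4>::new();
+    /// // Inserts 0 the first time a key is seen, then increments in place
+    /// *counts.entry("foo").or_insert(0) += 1;
+    /// *counts.entry("foo").or_insert(0) += 1;
+    /// assert_eq!(counts.get("foo"), Some(&2));
+    /// ```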
+ pub fn entry(&mut self, key: K) -> Entry<'_, K, V, N> { + match self.find(&key) { + Ok(idx) => Entry::occupied(self, idx), + Err(idx) => Entry::vacant(self, idx, key), + } + } + + #[inline] + fn find(&self, item: &Q) -> Result + where + K: Borrow, + Q: Ord + ?Sized, + { + self.items.binary_search_by(|probe| Ord::cmp(probe.key.borrow(), item)) + } +} +impl Default for SmallMap { + fn default() -> Self { + Self { + items: Default::default(), + } + } +} +impl Eq for SmallMap +where + K: Eq, + V: Eq, +{ +} +impl PartialEq for SmallMap +where + K: PartialEq, + V: PartialEq, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.items + .iter() + .map(|pair| (&pair.key, &pair.value)) + .eq(other.items.iter().map(|pair| (&pair.key, &pair.value))) + } +} +impl fmt::Debug for SmallMap +where + K: fmt::Debug + Ord, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map() + .entries(self.items.iter().map(|item| (&item.key, &item.value))) + .finish() + } +} +impl Clone for SmallMap +where + K: Clone, + V: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { + items: self.items.clone(), + } + } +} +impl IntoIterator for SmallMap +where + K: Ord, +{ + type IntoIter = SmallMapIntoIter; + type Item = (K, V); + + #[inline] + fn into_iter(self) -> Self::IntoIter { + SmallMapIntoIter { + iter: self.items.into_iter(), + } + } +} +impl FromIterator<(K, V)> for SmallMap +where + K: Ord, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut map = Self::default(); + for (k, v) in iter { + map.insert(k, v); + } + map + } +} +impl Index<&Q> for SmallMap +where + K: Borrow + Ord, + Q: Ord + ?Sized, +{ + type Output = V; + + fn index(&self, key: &Q) -> &Self::Output { + self.get(key).unwrap() + } +} +impl IndexMut<&Q> for SmallMap +where + K: Borrow + Ord, + Q: Ord + ?Sized, +{ + fn index_mut(&mut self, key: &Q) -> &mut Self::Output { + self.get_mut(key).unwrap() + } +} + +#[doc(hidden)] +pub struct SmallMapIntoIter { + iter: smallvec::IntoIter<[KeyValuePair; N]>, +} +impl ExactSizeIterator for SmallMapIntoIter { + #[inline(always)] + fn len(&self) -> usize { + self.iter.len() + } +} +impl Iterator for SmallMapIntoIter { + type Item = (K, V); + + #[inline(always)] + fn next(&mut self) -> Option { + self.iter.next().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } + + #[inline] + fn last(self) -> Option<(K, V)> { + self.iter.last().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + self.iter.nth(n).map(|pair| (pair.key, pair.value)) + } +} +impl DoubleEndedIterator for SmallMapIntoIter { + #[inline] + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|pair| (pair.key, pair.value)) + } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + self.iter.nth_back(n).map(|pair| (pair.key, pair.value)) + } +} + +/// Represents an key/value pair entry in a [SmallMap] +pub enum Entry<'a, K, V, const N: usize> { + Occupied(OccupiedEntry<'a, K, V, N>), + Vacant(VacantEntry<'a, K, V, N>), +} +impl<'a, K, V: Default, const N: usize> Entry<'a, K, V, N> { + pub fn or_default(self) -> &'a mut V { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(Default::default()), + } + } +} +impl<'a, K, V, const N: usize> Entry<'a, K, V, N> { + fn occupied(map: &'a mut SmallMap, idx: usize) -> Self { + 
Self::Occupied(OccupiedEntry { map, idx }) + } + + fn vacant(map: &'a mut SmallMap, idx: usize, key: K) -> Self { + Self::Vacant(VacantEntry { map, idx, key }) + } + + pub fn or_insert(self, default: V) -> &'a mut V { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(default), + } + } + + pub fn or_insert_with(self, default: F) -> &'a mut V + where + F: FnOnce() -> V, + { + match self { + Self::Occupied(entry) => entry.into_mut(), + Self::Vacant(entry) => entry.insert(default()), + } + } + + pub fn key(&self) -> &K { + match self { + Self::Occupied(entry) => entry.key(), + Self::Vacant(entry) => entry.key(), + } + } + + pub fn and_modify(self, f: F) -> Self + where + F: FnOnce(&mut V), + { + match self { + Self::Occupied(mut entry) => { + f(entry.get_mut()); + Self::Occupied(entry) + } + vacant @ Self::Vacant(_) => vacant, + } + } +} + +/// Represents an occupied entry in a [SmallMap] +pub struct OccupiedEntry<'a, K, V, const N: usize> { + map: &'a mut SmallMap, + idx: usize, +} +impl<'a, K, V, const N: usize> OccupiedEntry<'a, K, V, N> { + #[inline(always)] + fn get_entry(&self) -> &KeyValuePair { + &self.map.items[self.idx] + } + + pub fn remove_entry(self) -> V { + self.map.items.remove(self.idx).value + } + + pub fn key(&self) -> &K { + &self.get_entry().key + } + + pub fn get(&self) -> &V { + &self.get_entry().value + } + + pub fn get_mut(&mut self) -> &mut V { + &mut self.map.items[self.idx].value + } + + pub fn into_mut(self) -> &'a mut V { + &mut self.map.items[self.idx].value + } +} + +/// Represents a vacant entry in a [SmallMap] +pub struct VacantEntry<'a, K, V, const N: usize> { + map: &'a mut SmallMap, + idx: usize, + key: K, +} +impl<'a, K, V, const N: usize> VacantEntry<'a, K, V, N> { + pub fn key(&self) -> &K { + &self.key + } + + pub fn into_key(self) -> K { + self.key + } + + pub fn insert_with(self, f: F) -> &'a mut V + where + F: FnOnce() -> V, + { + self.map.items.insert( + self.idx, + KeyValuePair { + key: self.key, + value: f(), + }, + ); + &mut self.map.items[self.idx].value + } + + pub fn insert(self, value: V) -> &'a mut V { + self.map.items.insert( + self.idx, + KeyValuePair { + key: self.key, + value, + }, + ); + &mut self.map.items[self.idx].value + } +} + +struct KeyValuePair { + key: K, + value: V, +} +impl AsRef for KeyValuePair { + #[inline] + fn as_ref(&self) -> &V { + &self.value + } +} +impl AsMut for KeyValuePair { + #[inline] + fn as_mut(&mut self) -> &mut V { + &mut self.value + } +} +impl Clone for KeyValuePair +where + K: Clone, + V: Clone, +{ + fn clone(&self) -> Self { + Self { + key: self.key.clone(), + value: self.value.clone(), + } + } +} +impl Copy for KeyValuePair +where + K: Copy, + V: Copy, +{ +} + +impl Eq for KeyValuePair where K: Eq {} + +impl PartialEq for KeyValuePair +where + K: PartialEq, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.key.eq(&other.key) + } +} + +impl Ord for KeyValuePair +where + K: Ord, +{ + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.key.cmp(&other.key) + } +} +impl PartialOrd for KeyValuePair +where + K: PartialOrd, +{ + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + self.key.partial_cmp(&other.key) + } +} diff --git a/hir2/src/adt/smallordset.rs b/hir2/src/adt/smallordset.rs new file mode 100644 index 000000000..1ab27c013 --- /dev/null +++ b/hir2/src/adt/smallordset.rs @@ -0,0 +1,189 @@ +use core::{borrow::Borrow, fmt}; + +use smallvec::SmallVec; + +/// [SmallOrdSet] is a [BTreeSet]-like structure that can store a 
specified number
+/// of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// This data structure is designed with three goals in mind:
+///
+/// * Support efficient set operations over a small set of items
+/// * Maintain the underlying set in order (according to the `Ord` impl of the element type)
+/// * Avoid allocating data on the heap for the typical case
+///
+/// Internally, [SmallOrdSet] is implemented on top of [SmallVec], and uses binary search
+/// to locate elements. This is quite efficient in general, and is particularly fast
+/// when all of the data is stored inline, but may not be a good fit for all use cases.
+///
+/// Due to its design constraints, it only supports elements which implement [Ord].
+///
+/// NOTE: This type differs from [SmallSet] in that [SmallOrdSet] uses the [Ord] implementation
+/// of the element type for ordering, while [SmallSet] preserves the insertion order of elements.
+/// Beyond that, the two types are meant to be essentially equivalent.
+pub struct SmallOrdSet<T, const N: usize> {
+    items: SmallVec<[T; N]>,
+}
+impl<T, const N: usize> Default for SmallOrdSet<T, N> {
+    fn default() -> Self {
+        Self {
+            items: Default::default(),
+        }
+    }
+}
+impl<T, const N: usize> Eq for SmallOrdSet<T, N> where T: Eq {}
+impl<T, const N: usize> PartialEq for SmallOrdSet<T, N>
+where
+    T: PartialEq,
+{
+    fn eq(&self, other: &Self) -> bool {
+        self.items.eq(&other.items)
+    }
+}
+impl<T, const N: usize> fmt::Debug for SmallOrdSet<T, N>
+where
+    T: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_set().entries(self.items.iter()).finish()
+    }
+}
+impl<T, const N: usize> Clone for SmallOrdSet<T, N>
+where
+    T: Clone,
+{
+    #[inline]
+    fn clone(&self) -> Self {
+        Self {
+            items: self.items.clone(),
+        }
+    }
+}
+impl<T, const N: usize> SmallOrdSet<T, N>
+where
+    T: Ord,
+{
+    pub fn from_vec(items: SmallVec<[T; N]>) -> Self {
+        let mut set = Self { items };
+        set.sort_and_dedup();
+        set
+    }
+
+    #[inline]
+    pub fn from_buf(buf: [T; N]) -> Self {
+        Self::from_vec(buf.into())
+    }
+}
+impl<T, const N: usize> IntoIterator for SmallOrdSet<T, N> {
+    type IntoIter = smallvec::IntoIter<[T; N]>;
+    type Item = T;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.into_iter()
+    }
+}
+impl<T, const N: usize> FromIterator<T> for SmallOrdSet<T, N>
+where
+    T: Ord,
+{
+    fn from_iter<I>(iter: I) -> Self
+    where
+        I: IntoIterator<Item = T>,
+    {
+        let mut set = Self::default();
+        for item in iter {
+            set.insert(item);
+        }
+        set
+    }
+}
+impl<T, const N: usize> SmallOrdSet<T, N>
+where
+    T: Ord,
+{
+    pub fn is_empty(&self) -> bool {
+        self.items.is_empty()
+    }
+
+    pub fn len(&self) -> usize {
+        self.items.len()
+    }
+
+    pub fn as_slice(&self) -> &[T] {
+        self.items.as_slice()
+    }
+
+    pub fn iter(&self) -> core::slice::Iter<'_, T> {
+        self.items.iter()
+    }
+
+    pub fn insert(&mut self, item: T) -> bool {
+        match self.find(&item) {
+            Ok(_) => false,
+            Err(idx) => {
+                self.items.insert(idx, item);
+                true
+            }
+        }
+    }
+
+    pub fn remove<Q>(&mut self, item: &Q) -> Option<T>
+    where
+        T: Borrow<Q>,
+        Q: Ord + ?Sized,
+    {
+        match self.find(item) {
+            Ok(idx) => Some(self.items.remove(idx)),
+            Err(_) => None,
+        }
+    }
+
+    /// Clear the content of the set
+    pub fn clear(&mut self) {
+        self.items.clear();
+    }
+
+    pub fn contains<Q>(&self, item: &Q) -> bool
+    where
+        T: Borrow<Q>,
+        Q: Ord + ?Sized,
+    {
+        self.find(item).is_ok()
+    }
+
+    pub fn get<Q>(&self, item: &Q) -> Option<&T>
+    where
+        T: Borrow<Q>,
+        Q: Ord + ?Sized,
+    {
+        match self.find(item) {
+            Ok(idx) => Some(&self.items[idx]),
+            Err(_) => None,
+        }
+    }
+
+    pub fn get_mut<Q>(&mut self, item: &Q) -> Option<&mut T>
+    where
+        T: Borrow<Q>,
+        Q: Ord + ?Sized,
+    {
+        match self.find(item) {
+            Ok(idx) => Some(&mut self.items[idx]),
+            Err(_) => None,
+        }
+    }
+
+    #[inline]
+    fn find<Q>(&self, item: &Q) -> Result<usize, usize>
+    where
+        T: Borrow<Q>,
+        Q: Ord + ?Sized,
+    {
+        self.items.binary_search_by(|probe| Ord::cmp(probe.borrow(), item))
+    }
+
+    fn sort_and_dedup(&mut self) {
+        self.items.sort_unstable();
+        self.items.dedup();
+    }
+}
diff --git a/hir2/src/adt/smallprio.rs b/hir2/src/adt/smallprio.rs
new file mode 100644
index 000000000..b5b5258e8
--- /dev/null
+++ b/hir2/src/adt/smallprio.rs
@@ -0,0 +1,167 @@
+use core::{cmp::Ordering, fmt};
+
+use smallvec::SmallVec;
+
+use super::SmallDeque;
+
+/// [SmallPriorityQueue] is a priority queue structure that can store a specified number
+/// of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// Elements in the queue are stored "largest priority first", as determined by the [Ord]
+/// implementation of the element type. If you instead wish to have a "lowest priority first"
+/// queue, you can use [core::cmp::Reverse] to invert the natural order of the type.
+///
+/// It is an exercise for the reader to figure out how to wrap a type with a custom comparator
+/// function. Since that isn't particularly needed yet, no built-in support for that is provided.
+pub struct SmallPriorityQueue<T, const N: usize> {
+    pq: SmallDeque<T, N>,
+}
+impl<T: fmt::Debug, const N: usize> fmt::Debug for SmallPriorityQueue<T, N> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_list().entries(self.iter()).finish()
+    }
+}
+impl<T, const N: usize> Default for SmallPriorityQueue<T, N> {
+    fn default() -> Self {
+        Self {
+            pq: Default::default(),
+        }
+    }
+}
+impl<T: Clone, const N: usize> Clone for SmallPriorityQueue<T, N> {
+    fn clone(&self) -> Self {
+        Self {
+            pq: self.pq.clone(),
+        }
+    }
+
+    fn clone_from(&mut self, source: &Self) {
+        self.pq.clone_from(&source.pq);
+    }
+}
+
+impl<T, const N: usize> SmallPriorityQueue<T, N> {
+    /// Returns true if this queue is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.pq.is_empty()
+    }
+
+    /// Returns the number of items in this queue
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.pq.len()
+    }
+
+    /// Pop the highest priority item from the queue
+    #[inline]
+    pub fn pop(&mut self) -> Option<T> {
+        self.pq.pop_front()
+    }
+
+    /// Pop the lowest priority item from the queue
+    #[inline]
+    pub fn pop_last(&mut self) -> Option<T> {
+        self.pq.pop_back()
+    }
+
+    /// Get a reference to the next highest priority item in the queue
+    #[inline]
+    pub fn front(&self) -> Option<&T> {
+        self.pq.front()
+    }
+
+    /// Get a reference to the lowest priority item in the queue
+    #[inline]
+    pub fn back(&self) -> Option<&T> {
+        self.pq.back()
+    }
+
+    /// Get a front-to-back iterator over the items in the queue
+    #[inline]
+    pub fn iter(&self) -> super::smalldeque::Iter<'_, T> {
+        self.pq.iter()
+    }
+}
+
+impl<T, const N: usize> SmallPriorityQueue<T, N>
+where
+    T: Ord,
+{
+    /// Returns a new, empty [SmallPriorityQueue]
+    pub const fn new() -> Self {
+        Self {
+            pq: SmallDeque::new(),
+        }
+    }
+
+    /// Push an item on the queue.
+    ///
+    /// If the item's priority is equal to, or greater than, that of any other item in the
+    /// queue, the newly pushed item will be placed at the front of the queue. Otherwise, the
+    /// item is inserted at the first slot where its priority is at least that of the value
+    /// currently in that slot.
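+    ///
+    /// A minimal sketch of the resulting order (type parameters chosen arbitrarily):
+    ///
+    /// ```rust,ignore
+    /// let mut pq = SmallPriorityQueue::<u32, 4>::new();
+    /// pq.push(1);
+    /// pq.push(3);
+    /// pq.push(2);
+    /// // Largest priority first: pop() always yields the current maximum
+    /// assert_eq!(pq.pop(), Some(3));
+    /// assert_eq!(pq.pop(), Some(2));
+    /// assert_eq!(pq.pop(), Some(1));
+    /// ```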
+    pub fn push(&mut self, item: T) {
+        if let Some(head) = self.pq.front() {
+            match head.cmp(&item) {
+                Ordering::Greater => self.push_slow(item),
+                Ordering::Equal | Ordering::Less => {
+                    // Push to the front for efficiency
+                    self.pq.push_front(item);
+                }
+            }
+        } else {
+            self.pq.push_back(item);
+        }
+    }
+
+    /// Push an item on the queue, by conducting a search for the most appropriate index at which
+    /// to insert the new item, based upon a comparator function that compares the priorities of
+    /// the items.
+    fn push_slow(&mut self, item: T) {
+        // The queue is sorted "largest priority first", so the comparison must be reversed
+        // for the binary search; whether or not an equal item is already present, the
+        // returned index is a valid insertion point.
+        let index = match self.pq.binary_search_by(|probe| probe.cmp(&item).reverse()) {
+            Ok(index) | Err(index) => index,
+        };
+        self.pq.insert(index, item);
+    }
+}
+
+impl<T, const N: usize> IntoIterator for SmallPriorityQueue<T, N> {
+    type IntoIter = super::smalldeque::IntoIter<T, N>;
+    type Item = T;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.pq.into_iter()
+    }
+}
+
+impl<T, const N: usize> FromIterator<T> for SmallPriorityQueue<T, N>
+where
+    T: Ord,
+{
+    #[inline]
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        let mut pq = SmallDeque::from_iter(iter);
+
+        let items = pq.make_contiguous();
+
+        // Sort "largest priority first"
+        items.sort_by(|a, b| b.cmp(a));
+
+        Self { pq }
+    }
+}
+
+impl<T: Ord, const N: usize> From<SmallVec<[T; N]>> for SmallPriorityQueue<T, N> {
+    fn from(mut value: SmallVec<[T; N]>) -> Self {
+        value.sort_by(|a, b| b.cmp(a));
+
+        Self {
+            pq: SmallDeque::from(value),
+        }
+    }
+}
diff --git a/hir2/src/adt/smallset.rs b/hir2/src/adt/smallset.rs
new file mode 100644
index 000000000..73c4aabc5
--- /dev/null
+++ b/hir2/src/adt/smallset.rs
@@ -0,0 +1,298 @@
+use core::{borrow::Borrow, fmt};
+
+use smallvec::SmallVec;
+
+/// [SmallSet] is a set data structure that can store a specified number
+/// of elements inline (i.e. on the stack) without allocating memory from the heap.
+///
+/// This data structure is designed with three goals in mind:
+///
+/// * Support efficient set operations over a small set of items
+/// * Preserve the insertion order of those items
+/// * Avoid allocating data on the heap for the typical case
+///
+/// Internally, [SmallSet] is implemented on top of [SmallVec], and uses linear search
+/// to locate elements. This is only reasonably efficient on small sets; for anything
+/// larger, you should reach for a hash- or tree-based set instead. Linear search is
+/// particularly fast when all of the data is stored inline, but may not be a good fit
+/// for all use cases.
+///
+/// Due to its design constraints, elements must implement [Eq].
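+///
+/// # Example
+///
+/// A minimal sketch showing that insertion order is preserved (type parameters
+/// chosen arbitrarily):
+///
+/// ```rust,ignore
+/// let mut set = SmallSet::<u32, 4>::default();
+/// set.insert(3);
+/// set.insert(1);
+/// set.insert(3); // duplicate, ignored
+/// assert_eq!(set.as_slice(), &[3, 1]);
+/// ```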
+pub struct SmallSet { + items: SmallVec<[T; N]>, +} +impl Default for SmallSet { + fn default() -> Self { + Self { + items: Default::default(), + } + } +} +impl Eq for SmallSet where T: Eq {} +impl PartialEq for SmallSet +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.items.eq(&other.items) + } +} +impl fmt::Debug for SmallSet +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_set().entries(self.items.iter()).finish() + } +} +impl Clone for SmallSet +where + T: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { + items: self.items.clone(), + } + } +} +impl IntoIterator for SmallSet { + type IntoIter = smallvec::IntoIter<[T; N]>; + type Item = T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} +impl From<[T; N]> for SmallSet +where + T: Eq, +{ + #[inline] + fn from(items: [T; N]) -> Self { + Self::from_iter(items) + } +} +impl FromIterator for SmallSet +where + T: Eq, +{ + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut set = Self::default(); + for item in iter { + set.insert(item); + } + set + } +} +impl SmallSet +where + T: Eq, +{ + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + pub fn len(&self) -> usize { + self.items.len() + } + + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.items.iter() + } + + #[inline] + pub fn as_slice(&self) -> &[T] { + self.items.as_slice() + } + + pub fn insert(&mut self, item: T) -> bool { + if self.contains(&item) { + return false; + } + self.items.push(item); + true + } + + pub fn remove(&mut self, item: &Q) -> Option + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(self.items.remove(idx)), + None => None, + } + } + + /// Remove all items from the set for which `predicate` returns false. 
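+    ///
+    /// A minimal sketch:
+    ///
+    /// ```rust,ignore
+    /// let mut set = SmallSet::<u32, 4>::from([1, 2, 3, 4]);
+    /// // Keep only the even items
+    /// set.retain(|item| *item % 2 == 0);
+    /// assert_eq!(set.as_slice(), &[2, 4]);
+    /// ```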
+ pub fn retain(&mut self, predicate: F) + where + F: FnMut(&mut T) -> bool, + { + self.items.retain(predicate); + } + + /// Clear the content of the set + pub fn clear(&mut self) { + self.items.clear(); + } + + pub fn contains(&self, item: &Q) -> bool + where + T: Borrow, + Q: Eq + ?Sized, + { + self.find(item).is_some() + } + + pub fn get(&self, item: &Q) -> Option<&T> + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(&self.items[idx]), + None => None, + } + } + + pub fn get_mut(&mut self, item: &Q) -> Option<&mut T> + where + T: Borrow, + Q: Eq + ?Sized, + { + match self.find(item) { + Some(idx) => Some(&mut self.items[idx]), + None => None, + } + } + + /// Convert this set into a `SmallVec` containing the items of the set + #[inline] + pub fn into_vec(self) -> SmallVec<[T; N]> { + self.items + } + + #[inline] + fn find(&self, item: &Q) -> Option + where + T: Borrow, + Q: Eq + ?Sized, + { + self.items.iter().position(|elem| elem.borrow() == item) + } +} + +impl SmallSet +where + T: Clone + Eq, +{ + /// Obtain a new [SmallSet] containing the unique elements of both `self` and `other` + pub fn union(&self, other: &Self) -> Self { + let mut result = self.clone(); + + for item in other.items.iter() { + if result.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } + + /// Obtain a new [SmallSet] containing the unique elements of both `self` and `other` + pub fn into_union(mut self, other: &Self) -> Self { + for item in other.items.iter() { + if self.contains(item) { + continue; + } + self.items.push(item.clone()); + } + + self + } + + /// Obtain a new [SmallSet] containing the elements in common between `self` and `other` + pub fn intersection(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + result.items.push(item.clone()); + } + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in common between `self` and `other` + pub fn into_intersection(self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.into_iter() { + if other.contains(&item) { + result.items.push(item); + } + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in `self` but not in `other` + pub fn difference(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } + + /// Obtain a new [SmallSet] containing the elements in `self` but not in `other` + pub fn into_difference(mut self, other: &Self) -> Self { + Self { + items: self.items.drain_filter(|item| !other.contains(item)).collect(), + } + } + + /// Obtain a new [SmallSet] containing the elements in `self` or `other`, but not in both + pub fn symmetric_difference(&self, other: &Self) -> Self { + let mut result = Self::default(); + + for item in self.items.iter() { + if other.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + for item in other.items.iter() { + if self.contains(item) { + continue; + } + result.items.push(item.clone()); + } + + result + } +} + +impl core::iter::Extend for SmallSet +where + E: Eq, +{ + fn extend>(&mut self, iter: T) { + for item in iter { + self.insert(item); + } + } +} diff --git a/hir2/src/adt/sparsemap.rs b/hir2/src/adt/sparsemap.rs new file mode 100644 index 000000000..51a1ed290 --- /dev/null +++ b/hir2/src/adt/sparsemap.rs @@ -0,0 +1,233 @@ 
+//! This module is based on [cranelift_entity::SparseMap], but implemented in-tree
+//! because the [SparseMapValue] trait is not implemented for any standard library types.
+use cranelift_entity::{EntityRef, SecondaryMap};
+
+pub trait SparseMapValue<K> {
+    fn key(&self) -> K;
+}
+impl<K, T: SparseMapValue<K>> SparseMapValue<K> for Box<T> {
+    fn key(&self) -> K {
+        (**self).key()
+    }
+}
+impl<K, T: SparseMapValue<K>> SparseMapValue<K> for alloc::rc::Rc<T> {
+    fn key(&self) -> K {
+        (**self).key()
+    }
+}
+impl SparseMapValue<crate::ValueId> for crate::ValueId {
+    fn key(&self) -> crate::ValueId {
+        *self
+    }
+}
+impl SparseMapValue<crate::BlockId> for crate::BlockId {
+    fn key(&self) -> crate::BlockId {
+        *self
+    }
+}
+
+/// A sparse mapping of entity references.
+///
+/// A `SparseMap<K, V>` map provides:
+///
+/// - Memory usage equivalent to `SecondaryMap<K, u32>` + `Vec<V>`, so much smaller than
+///   `SecondaryMap<K, V>` for sparse mappings of larger `V` types.
+/// - Constant time lookup, slightly slower than `SecondaryMap`.
+/// - A very fast, constant time `clear()` operation.
+/// - Fast insert and erase operations.
+/// - Stable iteration that is as fast as a `Vec<V>`.
+///
+/// # Compared to `SecondaryMap`
+///
+/// When should we use a `SparseMap` instead of a `SecondaryMap`? First of all,
+/// `SparseMap` does not provide the functionality of a `PrimaryMap` which can allocate and assign
+/// entity references to objects as they are pushed onto the map. It is only the secondary entity
+/// maps that can be replaced with a `SparseMap`.
+///
+/// - A secondary entity map assigns a default mapping to all keys. It doesn't distinguish between
+///   an unmapped key and one that maps to the default value. `SparseMap` does not require `Default`
+///   values, and it tracks accurately if a key has been mapped or not.
+/// - Iterating over the contents of a `SecondaryMap` is linear in the size of the *key space*,
+///   while iterating over a `SparseMap` is linear in the number of elements in the mapping. This is
+///   an advantage precisely when the mapping is sparse.
+/// - `SparseMap::clear()` is constant time and super-fast. `SecondaryMap::clear()` is linear in the
+///   size of the key space. (Or rather, the required `resize()` call following the `clear()` is.)
+/// - `SparseMap` requires the values to implement `SparseMapValue<K>`, which means that they must
+///   contain their own key.
+pub struct SparseMap<K, V>
+where
+    K: EntityRef,
+    V: SparseMapValue<K>,
+{
+    sparse: SecondaryMap<K, u32>,
+    dense: Vec<V>,
+}
+impl<K, V> Default for SparseMap<K, V>
+where
+    K: EntityRef,
+    V: SparseMapValue<K>,
+{
+    fn default() -> Self {
+        Self {
+            sparse: SecondaryMap::new(),
+            dense: Vec::new(),
+        }
+    }
+}
+impl<K, V> core::fmt::Debug for SparseMap<K, V>
+where
+    K: EntityRef + core::fmt::Debug,
+    V: SparseMapValue<K> + core::fmt::Debug,
+{
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        f.debug_map().entries(self.values().map(|v| (v.key(), v))).finish()
+    }
+}
+impl<K, V> SparseMap<K, V>
+where
+    K: EntityRef,
+    V: SparseMapValue<K>,
+{
+    /// Create a new empty mapping.
+    #[inline]
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Returns the number of elements in the map.
+    pub fn len(&self) -> usize {
+        self.dense.len()
+    }
+
+    /// Returns true if the map contains no elements.
+    pub fn is_empty(&self) -> bool {
+        self.dense.is_empty()
+    }
+
+    /// Remove all elements from the mapping.
+    pub fn clear(&mut self) {
+        self.dense.clear();
+    }
+
+    /// Returns a reference to the value corresponding to the key.
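+    ///
+    /// # Example
+    ///
+    /// A minimal sketch, assuming a hypothetical entity type `Inst`, an existing
+    /// reference `inst`, and a value type which carries its own key:
+    ///
+    /// ```rust,ignore
+    /// struct InstInfo {
+    ///     inst: Inst,
+    ///     size: u32,
+    /// }
+    /// impl SparseMapValue<Inst> for InstInfo {
+    ///     fn key(&self) -> Inst {
+    ///         self.inst
+    ///     }
+    /// }
+    ///
+    /// let mut map: SparseMap<Inst, InstInfo> = SparseMap::new();
+    /// map.insert(InstInfo { inst, size: 4 });
+    /// assert_eq!(map.get(inst).map(|info| info.size), Some(4));
+    /// ```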
+ pub fn get(&self, key: K) -> Option<&V> { + if let Some(idx) = self.sparse.get(key).cloned() { + if let Some(entry) = self.dense.get(idx as usize) { + if entry.key() == key { + return Some(entry); + } + } + } + None + } + + /// Returns a mutable reference to the value corresponding to the key. + /// + /// Note that the returned value must not be mutated in a way that would change its key. This + /// would invalidate the sparse set data structure. + pub fn get_mut(&mut self, key: K) -> Option<&mut V> { + if let Some(idx) = self.sparse.get(key).cloned() { + if let Some(entry) = self.dense.get_mut(idx as usize) { + if entry.key() == key { + return Some(entry); + } + } + } + None + } + + /// Return the index into `dense` of the value corresponding to `key`. + fn index(&self, key: K) -> Option { + if let Some(idx) = self.sparse.get(key).cloned() { + let idx = idx as usize; + if let Some(entry) = self.dense.get(idx) { + if entry.key() == key { + return Some(idx); + } + } + } + None + } + + /// Return `true` if the map contains a value corresponding to `key`. + pub fn contains_key(&self, key: K) -> bool { + self.get(key).is_some() + } + + /// Insert a value into the map. + /// + /// If the map did not have this key present, `None` is returned. + /// + /// If the map did have this key present, the value is updated, and the old value is returned. + /// + /// It is not necessary to provide a key since the value knows its own key already. + pub fn insert(&mut self, value: V) -> Option { + let key = value.key(); + + // Replace the existing entry for `key` if there is one. + if let Some(entry) = self.get_mut(key) { + return Some(core::mem::replace(entry, value)); + } + + // There was no previous entry for `key`. Add it to the end of `dense`. + let idx = self.dense.len(); + debug_assert!(idx <= u32::MAX as usize, "SparseMap overflow"); + self.dense.push(value); + self.sparse[key] = idx as u32; + None + } + + /// Remove a value from the map and return it. + pub fn remove(&mut self, key: K) -> Option { + if let Some(idx) = self.index(key) { + let back = self.dense.pop().unwrap(); + + // Are we popping the back of `dense`? + if idx == self.dense.len() { + return Some(back); + } + + // We're removing an element from the middle of `dense`. + // Replace the element at `idx` with the back of `dense`. + // Repair `sparse` first. + self.sparse[back.key()] = idx as u32; + return Some(core::mem::replace(&mut self.dense[idx], back)); + } + + // Nothing to remove. + None + } + + /// Remove the last value from the map. + pub fn pop(&mut self) -> Option { + self.dense.pop() + } + + /// Get an iterator over the values in the map. + /// + /// The iteration order is entirely determined by the preceding sequence of `insert` and + /// `remove` operations. In particular, if no elements were removed, this is the insertion + /// order. + pub fn values(&self) -> core::slice::Iter { + self.dense.iter() + } + + /// Get the values as a slice. + pub fn as_slice(&self) -> &[V] { + self.dense.as_slice() + } +} + +/// Iterating over the elements of a set. 
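+///
+/// A minimal sketch, reusing the hypothetical `map` from the `get` example above:
+///
+/// ```rust,ignore
+/// for info in &map {
+///     // values are visited in `dense` order, i.e. insertion order when no
+///     // removals have occurred
+/// }
+/// ```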
+impl<'a, K, V> IntoIterator for &'a SparseMap +where + K: EntityRef, + V: SparseMapValue, +{ + type IntoIter = core::slice::Iter<'a, V>; + type Item = &'a V; + + fn into_iter(self) -> Self::IntoIter { + self.values() + } +} diff --git a/hir2/src/any.rs b/hir2/src/any.rs new file mode 100644 index 000000000..3a3fd4296 --- /dev/null +++ b/hir2/src/any.rs @@ -0,0 +1,269 @@ +use core::{any::Any, marker::Unsize}; + +pub trait AsAny: Any + Unsize { + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self + } + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + #[inline(always)] + fn into_any(self: Box) -> Box { + self + } +} + +impl> AsAny for T {} + +/// # Safety +/// +/// This trait is not safe (or possible) to implement manually. +/// +/// It is automatically derived for all `T: Unsize` +pub unsafe trait Is: Unsize {} +unsafe impl Is for T where T: ?Sized + Unsize {} + +pub trait IsObjOf {} +impl IsObjOf for Trait +where + T: ?Sized + Is, + Trait: ?Sized, +{ +} + +#[allow(unused)] +pub trait DowncastRef: IsObjOf { + fn downcast_ref(&self) -> Option<&To>; + fn downcast_mut(&mut self) -> Option<&mut To>; +} +impl DowncastRef for From +where + From: ?Sized, + To: ?Sized + DowncastFromRef, +{ + #[inline(always)] + fn downcast_ref(&self) -> Option<&To> { + To::downcast_from_ref(self) + } + + #[inline(always)] + fn downcast_mut(&mut self) -> Option<&mut To> { + To::downcast_from_mut(self) + } +} + +pub trait DowncastFromRef: Is { + fn downcast_from_ref(from: &From) -> Option<&Self>; + fn downcast_from_mut(from: &mut From) -> Option<&mut Self>; +} +impl DowncastFromRef for To +where + From: ?Sized + AsAny, + To: Is + 'static, +{ + #[inline] + fn downcast_from_ref(from: &From) -> Option<&Self> { + from.as_any().downcast_ref() + } + + #[inline] + fn downcast_from_mut(from: &mut From) -> Option<&mut Self> { + from.as_any_mut().downcast_mut() + } +} + +#[allow(unused)] +pub trait Downcast: DowncastRef + Is { + fn downcast(self: Box) -> Result, Box>; +} +impl Downcast for From +where + From: ?Sized + DowncastRef + Is, + To: ?Sized + DowncastFrom, + Obj: ?Sized, +{ + #[inline] + fn downcast(self: Box) -> Result, Box> { + To::downcast_from(self) + } +} + +pub trait DowncastFrom: DowncastFromRef +where + From: ?Sized + Is, + Obj: ?Sized, +{ + fn downcast_from(from: Box) -> Result, Box>; +} +impl DowncastFrom for To +where + From: ?Sized + Is + AsAny + 'static, + To: DowncastFromRef + 'static, + Obj: ?Sized, +{ + fn downcast_from(from: Box) -> Result, Box> { + if !from.as_any().is::() { + Ok(from.into_any().downcast().unwrap()) + } else { + Err(from) + } + } +} + +pub trait Upcast: TryUpcastRef { + fn upcast_ref(&self) -> &To; + fn upcast_mut(&mut self) -> &mut To; + fn upcast(self: Box) -> Box; +} +impl Upcast for From +where + From: ?Sized + Is, + To: ?Sized, +{ + #[inline(always)] + fn upcast_ref(&self) -> &To { + self + } + + #[inline(always)] + fn upcast_mut(&mut self) -> &mut To { + self + } + + #[inline(always)] + fn upcast(self: Box) -> Box { + self + } +} + +pub trait TryUpcastRef: Is +where + To: ?Sized, +{ + #[allow(unused)] + fn is_of(&self) -> bool; + fn try_upcast_ref(&self) -> Option<&To>; + fn try_upcast_mut(&mut self) -> Option<&mut To>; +} +impl TryUpcastRef for From +where + From: Upcast + ?Sized, + To: ?Sized, +{ + #[inline(always)] + fn is_of(&self) -> bool { + true + } + + #[inline(always)] + fn try_upcast_ref(&self) -> Option<&To> { + Some(self.upcast_ref()) + } + + #[inline(always)] + fn try_upcast_mut(&mut self) -> Option<&mut To> { + 
Some(self.upcast_mut()) + } +} + +pub trait TryUpcast: Is + TryUpcastRef +where + To: ?Sized, + Obj: ?Sized, +{ + #[inline(always)] + fn try_upcast(self: Box) -> Result, Box> { + Err(self) + } +} +impl TryUpcast for From +where + From: Is + Upcast + ?Sized, + To: ?Sized, + Obj: ?Sized, +{ + #[inline] + fn try_upcast(self: Box) -> Result, Box> { + Ok(self.upcast()) + } +} + +/// Upcasts a type into a trait object +#[allow(unused)] +pub trait UpcastFrom +where + From: ?Sized, +{ + fn upcast_from_ref(from: &From) -> &Self; + fn upcast_from_mut(from: &mut From) -> &mut Self; + fn upcast_from(from: Box) -> Box; +} + +impl UpcastFrom for To +where + From: Upcast + ?Sized, + To: ?Sized, +{ + #[inline] + fn upcast_from_ref(from: &From) -> &Self { + from.upcast_ref() + } + + #[inline] + fn upcast_from_mut(from: &mut From) -> &mut Self { + from.upcast_mut() + } + + #[inline] + fn upcast_from(from: Box) -> Box { + from.upcast() + } +} + +#[allow(unused)] +pub trait TryUpcastFromRef: IsObjOf +where + From: ?Sized, +{ + fn try_upcast_from_ref(from: &From) -> Option<&Self>; + fn try_upcast_from_mut(from: &mut From) -> Option<&mut Self>; +} + +impl TryUpcastFromRef for To +where + From: TryUpcastRef + ?Sized, + To: ?Sized, +{ + #[inline] + fn try_upcast_from_ref(from: &From) -> Option<&Self> { + from.try_upcast_ref() + } + + #[inline] + fn try_upcast_from_mut(from: &mut From) -> Option<&mut Self> { + from.try_upcast_mut() + } +} + +#[allow(unused)] +pub trait TryUpcastFrom: TryUpcastFromRef +where + From: Is + ?Sized, + Obj: ?Sized, +{ + fn try_upcast_from(from: Box) -> Result, Box>; +} + +impl TryUpcastFrom for To +where + From: TryUpcast + ?Sized, + To: ?Sized, + Obj: ?Sized, +{ + #[inline] + fn try_upcast_from(from: Box) -> Result, Box> { + from.try_upcast() + } +} diff --git a/hir2/src/attributes.rs b/hir2/src/attributes.rs new file mode 100644 index 000000000..269c36c63 --- /dev/null +++ b/hir2/src/attributes.rs @@ -0,0 +1,486 @@ +mod call_conv; +mod overflow; +mod visibility; + +use alloc::collections::BTreeMap; +use core::{any::Any, borrow::Borrow, fmt}; + +pub use self::{call_conv::CallConv, overflow::Overflow, visibility::Visibility}; +use crate::interner::Symbol; + +pub mod markers { + use midenc_hir_symbol::symbols; + + use super::*; + + /// This attribute indicates that the decorated function is the entrypoint + /// for its containing program, regardless of what module it is defined in. 
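+    ///
+    /// A minimal sketch of how this marker might be consumed (assuming `attrs` is
+    /// the `AttributeSet` of some function):
+    ///
+    /// ```rust,ignore
+    /// if attrs.has(symbols::Entrypoint) {
+    ///     // treat the decorated function as the program entrypoint
+    /// }
+    /// ```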
+ pub const ENTRYPOINT: Attribute = Attribute { + name: symbols::Entrypoint, + value: None, + }; +} + +/// An [AttributeSet] is a uniqued collection of attributes associated with some IR entity +#[derive(Debug, Default, Hash)] +pub struct AttributeSet(Vec); +impl FromIterator for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator, + { + let mut map = BTreeMap::default(); + for attr in attrs.into_iter() { + map.insert(attr.name, attr.value); + } + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) + } +} +impl FromIterator<(Symbol, Option>)> for AttributeSet { + fn from_iter(attrs: T) -> Self + where + T: IntoIterator>)>, + { + let mut map = BTreeMap::default(); + for (name, value) in attrs.into_iter() { + map.insert(name, value); + } + Self(map.into_iter().map(|(name, value)| Attribute { name, value }).collect()) + } +} +impl AttributeSet { + /// Get a new, empty [AttributeSet] + pub fn new() -> Self { + Self::default() + } + + /// Insert a new [Attribute] in this set by `name` and `value` + pub fn insert(&mut self, name: impl Into, value: Option) { + self.set(Attribute { + name: name.into(), + value: value.map(|v| Box::new(v) as Box), + }); + } + + /// Adds `attr` to this set + pub fn set(&mut self, attr: Attribute) { + match self.0.binary_search_by_key(&attr.name, |attr| attr.name) { + Ok(index) => { + self.0[index].value = attr.value; + } + Err(index) => { + if index == self.0.len() { + self.0.push(attr); + } else { + self.0.insert(index, attr); + } + } + } + } + + /// Remove an [Attribute] by name from this set + pub fn remove(&mut self, name: impl Into) { + let name = name.into(); + match self.0.binary_search_by_key(&name, |attr| attr.name) { + Ok(index) if index + 1 == self.0.len() => { + self.0.pop(); + } + Ok(index) => { + self.0.remove(index); + } + Err(_) => (), + } + } + + /// Determine if the named [Attribute] is present in this set + pub fn has(&self, key: impl Into) -> bool { + let key = key.into(); + self.0.binary_search_by_key(&key, |attr| attr.name).is_ok() + } + + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get_any(&self, key: impl Into) -> Option<&dyn AttributeValue> { + let key = key.into(); + match self.0.binary_search_by_key(&key, |attr| attr.name) { + Ok(index) => self.0[index].value.as_deref(), + Err(_) => None, + } + } + + /// Get the [AttributeValue] associated with the named [Attribute] + pub fn get_any_mut(&mut self, key: impl Into) -> Option<&mut dyn AttributeValue> { + let key = key.into(); + match self.0.binary_search_by_key(&key, |attr| attr.name) { + Ok(index) => self.0[index].value.as_deref_mut(), + Err(_) => None, + } + } + + /// Get the value associated with the named [Attribute] as a value of type `V`, or `None`. + pub fn get(&self, key: impl Into) -> Option<&V> + where + V: AttributeValue, + { + self.get_any(key).and_then(|v| v.downcast_ref::()) + } + + /// Get the value associated with the named [Attribute] as a mutable value of type `V`, or + /// `None`. + pub fn get_mut(&mut self, key: impl Into) -> Option<&mut V> + where + V: AttributeValue, + { + self.get_any_mut(key).and_then(|v| v.downcast_mut::()) + } + + /// Iterate over each [Attribute] in this set + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter() + } +} + +/// An [Attribute] associates some data with a well-known identifier (name). +/// +/// Attributes are used for representing metadata that helps guide compilation, +/// but which is not part of the code itself. 
For example, `cfg` flags in Rust +/// are an example of something which you could represent using an [Attribute]. +/// They can also be used to store documentation, source locations, and more. +#[derive(Debug, Hash)] +pub struct Attribute { + /// The name of this attribute + pub name: Symbol, + /// The value associated with this attribute + pub value: Option>, +} +impl Attribute { + pub fn new(name: impl Into, value: Option) -> Self { + Self { + name: name.into(), + value: value.map(|v| Box::new(v) as Box), + } + } + + pub fn value(&self) -> Option<&dyn AttributeValue> { + self.value.as_deref() + } + + pub fn value_as(&self) -> Option<&V> + where + V: AttributeValue, + { + match self.value.as_deref() { + Some(value) => value.downcast_ref::(), + None => None, + } + } +} +impl fmt::Display for Attribute { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.value.as_deref().map(|v| v.render()) { + None => write!(f, "#[{}]", self.name.as_str()), + Some(value) => write!(f, "#[{}({value})]", &self.name), + } + } +} + +pub trait AttributeValue: + Any + fmt::Debug + crate::formatter::PrettyPrint + crate::DynHash + 'static +{ + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn clone_value(&self) -> Box; +} + +impl dyn AttributeValue { + pub fn is(&self) -> bool { + self.as_any().is::() + } + + pub fn downcast(self: Box) -> Result, Box> { + if self.is::() { + let ptr = Box::into_raw(self); + Ok(unsafe { Box::from_raw(ptr.cast()) }) + } else { + Err(self) + } + } + + pub fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() + } + + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.as_any_mut().downcast_mut::() + } +} + +impl core::hash::Hash for dyn AttributeValue { + fn hash(&self, state: &mut H) { + use crate::DynHash; + + let hashable = self as &dyn DynHash; + hashable.dyn_hash(state); + } +} + +#[derive(Clone)] +pub struct SetAttr { + values: Vec, +} +impl Default for SetAttr { + fn default() -> Self { + Self { + values: Default::default(), + } + } +} +impl SetAttr +where + K: Ord + Clone, +{ + pub fn insert(&mut self, key: K) -> bool { + match self.values.binary_search_by(|k| key.cmp(k)) { + Ok(index) => { + self.values[index] = key; + false + } + Err(index) => { + self.values.insert(index, key); + true + } + } + } + + pub fn contains(&self, key: &K) -> bool { + self.values.binary_search_by(|k| key.cmp(k)).is_ok() + } + + pub fn iter(&self) -> core::slice::Iter<'_, K> { + self.values.iter() + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|k| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index)), + Err(_) => None, + } + } +} +impl Eq for SetAttr where K: Eq {} +impl PartialEq for SetAttr +where + K: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for SetAttr +where + K: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_set().entries(self.values.iter()).finish() + } +} +impl crate::formatter::PrettyPrint for SetAttr +where + K: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl core::hash::Hash for SetAttr +where + K: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + as core::hash::Hash>::hash(&self.values, state); + } +} +impl AttributeValue for SetAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, +{ + #[inline(always)] + 
fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } + + #[inline] + fn clone_value(&self) -> Box { + Box::new(self.clone()) + } +} + +#[derive(Clone)] +pub struct DictAttr { + values: Vec<(K, V)>, +} +impl Default for DictAttr { + fn default() -> Self { + Self { values: vec![] } + } +} +impl DictAttr +where + K: Ord, + V: Clone, +{ + pub fn insert(&mut self, key: K, value: V) { + match self.values.binary_search_by(|(k, _)| key.cmp(k)) { + Ok(index) => { + self.values[index].1 = value; + } + Err(index) => { + self.values.insert(index, (key, value)); + } + } + } + + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow, + Q: ?Sized + Ord, + { + self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())).is_ok() + } + + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(&self.values[index].1), + Err(_) => None, + } + } + + pub fn remove(&mut self, key: &Q) -> Option + where + K: Borrow, + Q: ?Sized + Ord, + { + match self.values.binary_search_by(|(k, _)| key.cmp(k.borrow())) { + Ok(index) => Some(self.values.remove(index).1), + Err(_) => None, + } + } + + pub fn iter(&self) -> core::slice::Iter<'_, (K, V)> { + self.values.iter() + } +} +impl Eq for DictAttr +where + K: Eq, + V: Eq, +{ +} +impl PartialEq for DictAttr +where + K: PartialEq, + V: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.values == other.values + } +} +impl fmt::Debug for DictAttr +where + K: fmt::Debug, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.values.iter().map(|entry| (&entry.0, &entry.1))) + .finish() + } +} +impl crate::formatter::PrettyPrint for DictAttr +where + K: crate::formatter::PrettyPrint, + V: crate::formatter::PrettyPrint, +{ + fn render(&self) -> crate::formatter::Document { + todo!() + } +} +impl core::hash::Hash for DictAttr +where + K: core::hash::Hash, + V: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + as core::hash::Hash>::hash(&self.values, state); + } +} +impl AttributeValue for DictAttr +where + K: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, + V: fmt::Debug + crate::formatter::PrettyPrint + Clone + core::hash::Hash + 'static, +{ + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self as &dyn Any + } + + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self as &mut dyn Any + } + + #[inline] + fn clone_value(&self) -> Box { + Box::new(self.clone()) + } +} + +#[macro_export] +macro_rules! 
define_attr_type {
+    ($T:ty) => {
+        impl $crate::AttributeValue for $T {
+            #[inline(always)]
+            fn as_any(&self) -> &dyn core::any::Any {
+                self as &dyn core::any::Any
+            }
+
+            #[inline(always)]
+            fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
+                self as &mut dyn core::any::Any
+            }
+
+            #[inline]
+            fn clone_value(&self) -> Box<dyn $crate::AttributeValue> {
+                Box::new(self.clone())
+            }
+        }
+    };
+}
+
+define_attr_type!(bool);
+define_attr_type!(u8);
+define_attr_type!(i8);
+define_attr_type!(u16);
+define_attr_type!(i16);
+define_attr_type!(u32);
+define_attr_type!(core::num::NonZeroU32);
+define_attr_type!(i32);
+define_attr_type!(u64);
+define_attr_type!(i64);
+define_attr_type!(usize);
+define_attr_type!(isize);
+define_attr_type!(Symbol);
+define_attr_type!(super::Immediate);
+define_attr_type!(super::Type);
diff --git a/hir2/src/attributes/call_conv.rs b/hir2/src/attributes/call_conv.rs
new file mode 100644
index 000000000..abededc0e
--- /dev/null
+++ b/hir2/src/attributes/call_conv.rs
@@ -0,0 +1,57 @@
+use core::fmt;
+
+/// Represents the calling convention of a function.
+///
+/// Calling conventions are part of a program's ABI (Application Binary Interface), and
+/// they define things such as how arguments are passed to a function, how results are returned,
+/// etc. In essence, the contract between caller and callee is described by the calling convention
+/// of a function.
+///
+/// Importantly, it is perfectly normal to mix calling conventions. For example, the public
+/// API for a C library will use whatever calling convention is used by C on the target
+/// platform (for Miden, that would be `SystemV`). However, internally that library may use
+/// the `Fast` calling convention to allow the compiler to more effectively optimize calls
+/// from the public API to private functions. In short, choose a calling convention that is
+/// well-suited for a given function, to the extent that other constraints don't impose a choice
+/// on you.
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
+#[cfg_attr(
+    feature = "serde",
+    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
+)]
+#[repr(u8)]
+pub enum CallConv {
+    /// This calling convention is what I like to call "chef's choice" - the
+    /// compiler chooses its own convention that optimizes for call performance.
+    ///
+    /// As a result of this, it is not permitted to use this convention in externally
+    /// linked functions, as the convention is unstable, and the compiler can't ensure
+    /// that the caller in another translation unit will use the correct convention.
+    Fast,
+    /// The standard calling convention used for C on most platforms
+    #[default]
+    SystemV,
+    /// A function which is using the WebAssembly Component Model "Canonical ABI".
+    Wasm,
+    /// A function with this calling convention must be called using
+    /// the `syscall` instruction. Attempts to call it with any other
+    /// call instruction will cause a validation error. The one exception
+    /// to this rule is when calling another function with the `Kernel`
+    /// convention that is defined in the same module, which can use the
+    /// standard `call` instruction.
+    ///
+    /// Kernel functions may only be defined in a kernel [Module].
+    ///
+    /// In all other respects, this calling convention is the same as `SystemV`
+    Kernel,
+}
+impl fmt::Display for CallConv {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Fast => f.write_str("fast"),
+            Self::SystemV => f.write_str("C"),
+            Self::Wasm => f.write_str("wasm"),
+            Self::Kernel => f.write_str("kernel"),
+        }
+    }
+}
diff --git a/hir2/src/attributes/overflow.rs b/hir2/src/attributes/overflow.rs
new file mode 100644
index 000000000..8bc7eae00
--- /dev/null
+++ b/hir2/src/attributes/overflow.rs
@@ -0,0 +1,66 @@
+use core::fmt;
+
+use crate::define_attr_type;
+
+/// This enumeration represents the various ways in which arithmetic operations
+/// can be configured to behave when either the operands or results over/underflow
+/// the range of the integral type.
+///
+/// Always check the documentation of the specific instruction involved to see if there
+/// are any specific differences in how this enum is interpreted compared to the default
+/// meaning of each variant.
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub enum Overflow {
+    /// Typically, this means the operation is performed using the equivalent field element
+    /// operation, rather than a dedicated operation for the given type. Because of this, the
+    /// result of the operation may exceed that of the integral type expected, but this will
+    /// not be caught right away.
+    ///
+    /// It is the caller's responsibility to ensure that the resulting value is in range.
+    #[default]
+    Unchecked,
+    /// The operation will trap if the operands, or the result, are not valid for the range of
+    /// the integral type involved, e.g. `u32`.
+    Checked,
+    /// The operation will wrap around, depending on the range of the integral type. For example,
+    /// given a `u32` value, this is done by applying `mod 2^32` to the result.
+    Wrapping,
+    /// The result of the operation will be computed as in [Wrapping]; however, in addition to the
+    /// result, this variant also pushes a value on the stack which represents whether or not the
+    /// operation over/underflowed: either 1 if over/underflow occurred, or 0 otherwise.
+ Overflowing, +} +impl Overflow { + /// Returns true if overflow is unchecked + pub fn is_unchecked(&self) -> bool { + matches!(self, Self::Unchecked) + } + + /// Returns true if overflow will cause a trap + pub fn is_checked(&self) -> bool { + matches!(self, Self::Checked) + } + + /// Returns true if overflow will add an extra boolean on top of the stack + pub fn is_overflowing(&self) -> bool { + matches!(self, Self::Overflowing) + } +} +impl fmt::Display for Overflow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Unchecked => f.write_str("unchecked"), + Self::Checked => f.write_str("checked"), + Self::Wrapping => f.write_str("wrapping"), + Self::Overflowing => f.write_str("overflow"), + } + } +} +impl crate::formatter::PrettyPrint for Overflow { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + display(self) + } +} + +define_attr_type!(Overflow); diff --git a/hir2/src/attributes/visibility.rs b/hir2/src/attributes/visibility.rs new file mode 100644 index 000000000..1c23a8f8f --- /dev/null +++ b/hir2/src/attributes/visibility.rs @@ -0,0 +1,77 @@ +use core::{fmt, str::FromStr}; + +use crate::define_attr_type; + +/// The types of visibility that a [Symbol] may have +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Visibility { + /// The symbol is public and may be referenced anywhere internal or external to the visible + /// references in the IR. + /// + /// Public visibility implies that we cannot remove the symbol even if we are unaware of any + /// references, and no other constraints apply, as we must assume that the symbol has references + /// we don't know about. + #[default] + Public, + /// The symbol is private and may only be referenced by ops local to operations within the + /// current symbol table. + /// + /// Private visibility implies that we know all uses of the symbol, and that those uses must + /// all exist within the current symbol table. + Private, + /// The symbol is public, but may only be referenced by symbol tables in the current compilation + /// graph, thus retaining the ability to observe all uses, and optimize based on that + /// information. + /// + /// Internal visibility implies that we know all uses of the symbol, but that there may be uses + /// in other symbol tables in addition to the current one. 
+ Internal, +} +define_attr_type!(Visibility); +impl Visibility { + #[inline] + pub fn is_public(&self) -> bool { + matches!(self, Self::Public) + } + + #[inline] + pub fn is_private(&self) -> bool { + matches!(self, Self::Private) + } + + #[inline] + pub fn is_internal(&self) -> bool { + matches!(self, Self::Internal) + } +} +impl crate::formatter::PrettyPrint for Visibility { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + match self { + Self::Public => const_text("public"), + Self::Private => const_text("private"), + Self::Internal => const_text("internal"), + } + } +} +impl fmt::Display for Visibility { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Public => f.write_str("public"), + Self::Private => f.write_str("private"), + Self::Internal => f.write_str("internal"), + } + } +} +impl FromStr for Visibility { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "public" => Ok(Self::Public), + "private" => Ok(Self::Private), + "internal" => Ok(Self::Internal), + _ => Err(()), + } + } +} diff --git a/hir2/src/demangle.rs b/hir2/src/demangle.rs new file mode 100644 index 000000000..98493de1c --- /dev/null +++ b/hir2/src/demangle.rs @@ -0,0 +1,13 @@ +/// Demangle `name`, where `name` was mangled using Rust's mangling scheme +#[inline] +pub fn demangle>(name: S) -> String { + demangle_impl(name.as_ref()) +} + +fn demangle_impl(name: &str) -> String { + let mut input = name.as_bytes(); + let mut demangled = Vec::with_capacity(input.len() * 2); + rustc_demangle::demangle_stream(&mut input, &mut demangled, /* include_hash= */ false) + .expect("failed to write demangled identifier"); + String::from_utf8(demangled).expect("demangled identifier contains invalid utf-8") +} diff --git a/hir2/src/derive.rs b/hir2/src/derive.rs new file mode 100644 index 000000000..8aeba6d68 --- /dev/null +++ b/hir2/src/derive.rs @@ -0,0 +1,300 @@ +pub use midenc_hir_macros::operation; + +use crate::Operation; + +/// This macro is used to generate the boilerplate for operation trait implementations. +#[macro_export] +macro_rules! derive { + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn:ident($op:ident: &$OperationPath:path, $ctx:ident: &$ContextPath:path) -> $VerifyResult:ty $verify:block + )+ + } + + $($t:tt)* + ) => { + $crate::__derive_op_trait! { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn($op: &$OperationPath, $ctx: &$ContextPath) -> $VerifyResult $verify + )* + } + } + + $($t)* + }; + + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + $($t:tt)* + ) => { + $crate::__derive_op_trait! { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem:item + )* + } + } + + $($t)* + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! 
__derive_op_trait { + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + + verify { + $( + fn $verify_fn:ident($op:ident: &$OperationPath:path, $ctx:ident: &$ContextPath:path) -> $VerifyResult:ty $verify:block + )+ + } + ) => { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem + )* + } + + impl $crate::Verify for T { + #[inline] + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + <$crate::Operation as $crate::Verify>::verify(self.as_operation(), context) + } + } + + impl $crate::Verify for $crate::Operation { + fn should_verify(&self, _context: &$crate::Context) -> bool { + self.implements::() + } + + fn verify(&self, context: &$crate::Context) -> Result<(), $crate::Report> { + $( + #[inline] + fn $verify_fn($op: &$OperationPath, $ctx: &$ContextPath) -> $VerifyResult $verify + )* + + $( + $verify_fn(self, context)?; + )* + + Ok(()) + } + } + }; + + ( + $(#[$outer:meta])* + $vis:vis trait $OpTrait:ident { + $( + $OpTraitItem:item + )* + } + ) => { + $(#[$outer])* + $vis trait $OpTrait { + $( + $OpTraitItem + )* + } + }; +} + +/// This type represents the concrete set of derived traits for some op `T`, paired with a +/// type-erased [Operation] reference for an instance of that op. +/// +/// This is used for two purposes: +/// +/// 1. To generate a specialized [OpVerifier] for `T` which contains all of the type and +/// trait-specific validation logic for that `T`. +/// 2. To apply the specialized verifier for `T` using the wrapped [Operation] reference. +#[doc(hidden)] +pub struct DeriveVerifier<'a, T, Derived: ?Sized> { + op: &'a Operation, + _t: core::marker::PhantomData, + _derived: core::marker::PhantomData, +} +impl<'a, T, Derived: ?Sized> DeriveVerifier<'a, T, Derived> { + #[doc(hidden)] + pub const fn new(op: &'a Operation) -> Self { + Self { + op, + _t: core::marker::PhantomData, + _derived: core::marker::PhantomData, + } + } +} +impl<'a, T, Derived: ?Sized> core::ops::Deref for DeriveVerifier<'a, T, Derived> { + type Target = Operation; + + fn deref(&self) -> &Self::Target { + self.op + } +} + +#[cfg(test)] +mod tests { + use alloc::rc::Rc; + use core::fmt; + + use super::operation; + use crate::{ + define_attr_type, dialects::hir::HirDialect, formatter, traits::*, Builder, BuilderExt, + Context, Op, Operation, Report, Spanned, Value, + }; + + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] + enum Overflow { + #[allow(unused)] + None, + Wrapping, + #[allow(unused)] + Overflowing, + } + impl fmt::Display for Overflow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(self, f) + } + } + impl formatter::PrettyPrint for Overflow { + fn render(&self) -> formatter::Document { + use formatter::*; + display(self) + } + } + define_attr_type!(Overflow); + + /// An example op implementation to make sure all of the type machinery works + #[operation( + dialect = HirDialect, + traits(ArithmeticOp, BinaryOp, Commutative, SingleBlock, SameTypeOperands), + implements(InferTypeOpInterface) + )] + struct AddOp { + #[attr] + overflow: Overflow, + #[operand] + #[order(0)] + lhs: AnyInteger, + #[operand] + #[order(1)] + rhs: AnyInteger, + #[result] + result: AnyInteger, + } + + impl InferTypeOpInterface for AddOp { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty(); + { + let rhs = self.rhs(); + let rhs = rhs.value(); + let rhs_ty = rhs.ty(); + if &lhs != rhs_ty { + return Err(Report::msg(format!( + "lhs and rhs types do not match: expected 
'{lhs}', got '{rhs_ty}'" + ))); + } + } + self.result_mut().set_type(lhs); + Ok(()) + } + } + + derive! { + /// A marker trait for arithmetic ops + trait ArithmeticOp {} + + verify { + fn is_binary_op(op: &Operation, ctx: &Context) -> Result<(), Report> { + if op.num_operands() == 2 { + Ok(()) + } else { + Err( + ctx.session.diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 2, got {}", op.num_operands())) + .with_help("this operator implements 'ArithmeticOp' which requires ops to be binary") + .into_report() + ) + } + } + } + } + + #[test] + fn derived_op_builder_test() { + use crate::{SourceSpan, Type}; + + let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::U32, Type::U32]); + let (lhs, rhs) = { + let block = block.borrow(); + let lhs = block.get_argument(0).upcast::(); + let rhs = block.get_argument(1).upcast::(); + (lhs, rhs) + }; + let mut builder = context.builder(); + builder.set_insertion_point_to_end(block); + let op_builder = builder.create::(SourceSpan::default()); + let op = op_builder(lhs, rhs, Overflow::Wrapping); + let op = op.expect("failed to create AddOp"); + let op = op.borrow(); + assert!(op.as_operation().implements::()); + assert!(core::hint::black_box( + !>::VACUOUS + )); + } + + #[test] + #[should_panic = "lhs and rhs types do not match: expected 'u32', got 'i64'"] + fn derived_op_verifier_test() { + use crate::{SourceSpan, Type}; + + let context = Rc::new(Context::default()); + let block = context.create_block_with_params([Type::U32, Type::I64]); + let (lhs, invalid_rhs) = { + let block = block.borrow(); + let lhs = block.get_argument(0).upcast::(); + let rhs = block.get_argument(1).upcast::(); + (lhs, rhs) + }; + let mut builder = context.builder(); + builder.set_insertion_point_to_end(block); + // Try to create instance of AddOp with mismatched operand types + let op_builder = builder.create::(SourceSpan::default()); + let op = op_builder(lhs, invalid_rhs, Overflow::Wrapping); + let _op = op.unwrap(); + } +} diff --git a/hir2/src/dialects.rs b/hir2/src/dialects.rs new file mode 100644 index 000000000..85fe52210 --- /dev/null +++ b/hir2/src/dialects.rs @@ -0,0 +1 @@ +pub mod hir; diff --git a/hir2/src/dialects/hir.rs b/hir2/src/dialects/hir.rs new file mode 100644 index 000000000..31cd846ec --- /dev/null +++ b/hir2/src/dialects/hir.rs @@ -0,0 +1,158 @@ +mod builders; +mod ops; + +use alloc::rc::Rc; +use core::cell::{Cell, RefCell}; + +pub use self::{ + builders::{DefaultInstBuilder, FunctionBuilder, InstBuilder, InstBuilderBase}, + ops::*, +}; +use crate::{ + interner, AttributeValue, Builder, BuilderExt, Dialect, DialectName, DialectRegistration, + Immediate, OperationName, OperationRef, SourceSpan, Type, +}; + +#[derive(Default)] +pub struct HirDialect { + registered_ops: RefCell>, + registered_op_cache: Cell>>, +} + +impl core::fmt::Debug for HirDialect { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("HirDialect") + .field_with("registered_ops", |f| { + f.debug_set().entries(self.registered_ops.borrow().iter()).finish() + }) + .finish_non_exhaustive() + } +} + +impl HirDialect { + #[inline] + pub fn num_registered(&self) -> usize { + self.registered_ops.borrow().len() + } +} + +impl Dialect for HirDialect { + #[inline] + fn name(&self) -> DialectName { + DialectName::from_symbol(interner::symbols::Hir) + } + + fn registered_ops(&self) 
-> Rc<[OperationName]> { + let registered = unsafe { (*self.registered_op_cache.as_ptr()).clone() }; + if registered.as_ref().is_some_and(|ops| self.num_registered() == ops.len()) { + registered.unwrap() + } else { + let registered = self.registered_ops.borrow(); + let ops = Rc::from(registered.clone().into_boxed_slice()); + self.registered_op_cache.set(Some(Rc::clone(&ops))); + ops + } + } + + fn get_or_register_op( + &self, + opcode: midenc_hir_symbol::Symbol, + register: fn(DialectName, midenc_hir_symbol::Symbol) -> crate::OperationName, + ) -> crate::OperationName { + let mut registered = self.registered_ops.borrow_mut(); + match registered.binary_search_by_key(&opcode, |op| op.name()) { + Ok(index) => registered[index].clone(), + Err(index) => { + let name = register(self.name(), opcode); + registered.insert(index, name.clone()); + name + } + } + } + + fn materialize_constant( + &self, + builder: &mut dyn Builder, + attr: Box, + ty: &Type, + span: SourceSpan, + ) -> Option { + use crate::Op; + + // Save the current insertion point + let mut builder = crate::InsertionGuard::new(builder); + + // Only integer constants are supported for now + if !ty.is_integer() { + return None; + } + + // Currently, we expect folds to produce `Immediate`-valued attributes + if let Some(&imm) = attr.downcast_ref::() { + // If the immediate value is of the same type as the expected result type, we're ready + // to materialize the constant + let imm_ty = imm.ty(); + if &imm_ty == ty { + let op_builder = builder.create::(span); + return op_builder(imm) + .ok() + .map(|op| op.borrow().as_operation().as_operation_ref()); + } + + // The immediate value has a different type than expected, but we can coerce types, so + // long as the value fits in the target type + if imm_ty.size_in_bits() > ty.size_in_bits() { + return None; + } + + let imm = match ty { + Type::I8 => match imm { + Immediate::I1(value) => Immediate::I8(value as i8), + Immediate::U8(value) => Immediate::I8(i8::try_from(value).ok()?), + _ => return None, + }, + Type::U8 => match imm { + Immediate::I1(value) => Immediate::U8(value as u8), + Immediate::I8(value) => Immediate::U8(u8::try_from(value).ok()?), + _ => return None, + }, + Type::I16 => match imm { + Immediate::I1(value) => Immediate::I16(value as i16), + Immediate::I8(value) => Immediate::I16(value as i16), + Immediate::U8(value) => Immediate::I16(value.into()), + Immediate::U16(value) => Immediate::I16(i16::try_from(value).ok()?), + _ => return None, + }, + Type::U16 => match imm { + Immediate::I1(value) => Immediate::U16(value as u16), + Immediate::I8(value) => Immediate::U16(u16::try_from(value).ok()?), + Immediate::U8(value) => Immediate::U16(value as u16), + Immediate::I16(value) => Immediate::U16(u16::try_from(value).ok()?), + _ => return None, + }, + Type::I32 => Immediate::I32(imm.as_i32()?), + Type::U32 => Immediate::U32(imm.as_u32()?), + Type::I64 => Immediate::I64(imm.as_i64()?), + Type::U64 => Immediate::U64(imm.as_u64()?), + Type::I128 => Immediate::I128(imm.as_i128()?), + Type::U128 => Immediate::U128(imm.as_u128()?), + Type::Felt => Immediate::Felt(imm.as_felt()?), + ty => unimplemented!("unrecognized integral type '{ty}'"), + }; + + let op_builder = builder.create::(span); + return op_builder(imm).ok().map(|op| op.borrow().as_operation().as_operation_ref()); + } + + None + } +} + +impl DialectRegistration for HirDialect { + const NAMESPACE: &'static str = "hir"; + + #[inline] + fn init() -> Self { + Self::default() + } +} diff --git a/hir2/src/dialects/hir/builders.rs 
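+ +// Editor's sketch (not part of this diff): how `materialize_constant` above coerces +// folded immediates. The `dialect`, `builder`, and `span` bindings are hypothetical: +// +// // A u8 immediate fits in u16, so it is widened and a constant op is materialized: +// let attr: Box<dyn AttributeValue> = Box::new(Immediate::U8(42)); +// assert!(dialect.materialize_constant(&mut builder, attr, &Type::U16, span).is_some()); +// +// // A 64-bit immediate cannot be narrowed to u16, so no constant is materialized: +// let attr: Box<dyn AttributeValue> = Box::new(Immediate::U64(u64::MAX)); +// assert!(dialect.materialize_constant(&mut builder, attr, &Type::U16, span).is_none()); diff --git a/hir2/src/dialects/hir/builders.rs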
b/hir2/src/dialects/hir/builders.rs new file mode 100644 index 000000000..697d0e2d7 --- /dev/null +++ b/hir2/src/dialects/hir/builders.rs @@ -0,0 +1,3 @@ +mod function; + +pub use self::function::*; diff --git a/hir2/src/dialects/hir/builders/function.rs b/hir2/src/dialects/hir/builders/function.rs new file mode 100644 index 000000000..d5fefc69c --- /dev/null +++ b/hir2/src/dialects/hir/builders/function.rs @@ -0,0 +1,1161 @@ +use crate::{ + dialects::hir::*, AsCallableSymbolRef, Block, BlockRef, Builder, Immediate, InsertionPoint, Op, + OpBuilder, Region, RegionRef, Report, SourceSpan, Type, UnsafeIntrusiveEntityRef, Usable, + ValueRef, +}; + +pub struct FunctionBuilder<'f> { + pub func: &'f mut Function, + builder: OpBuilder, +} +impl<'f> FunctionBuilder<'f> { + pub fn new(func: &'f mut Function) -> Self { + let current_block = if func.body().is_empty() { + func.create_entry_block() + } else { + func.last_block() + }; + let context = func.as_operation().context_rc(); + let mut builder = OpBuilder::new(context); + + builder.set_insertion_point_to_end(current_block); + + Self { func, builder } + } + + pub fn at(func: &'f mut Function, ip: InsertionPoint) -> Self { + let context = func.as_operation().context_rc(); + let mut builder = OpBuilder::new(context); + builder.set_insertion_point(ip); + + Self { func, builder } + } + + pub fn body_region(&self) -> RegionRef { + unsafe { RegionRef::from_raw(&*self.func.body()) } + } + + pub fn entry_block(&self) -> BlockRef { + self.func.entry_block() + } + + #[inline] + pub fn current_block(&self) -> BlockRef { + self.builder.insertion_block().expect("builder has no insertion point set") + } + + #[inline] + pub fn switch_to_block(&mut self, block: BlockRef) { + self.builder.set_insertion_point_to_end(block); + } + + pub fn create_block(&mut self) -> BlockRef { + self.builder.create_block(self.body_region(), None, &[]) + } + + pub fn detach_block(&mut self, mut block: BlockRef) { + use crate::EntityWithParent; + + assert_ne!( + block, + self.current_block(), + "cannot remove block the builder is currently inserting in" + ); + assert_eq!( + block.borrow().parent().map(|p| RegionRef::as_ptr(&p)), + Some(&*self.func.body() as *const Region), + "cannot detach a block that does not belong to this function" + ); + let mut body = self.func.body_mut(); + unsafe { + body.body_mut().cursor_mut_from_ptr(block.clone()).remove(); + } + block.borrow_mut().uses_mut().clear(); + Block::on_removed_from_parent(block, body.as_region_ref()); + } + + pub fn append_block_param(&mut self, block: BlockRef, ty: Type, span: SourceSpan) -> ValueRef { + self.builder.context().append_block_argument(block, ty, span) + } + + pub fn ins<'a, 'b: 'a>(&'b mut self) -> DefaultInstBuilder<'a> { + DefaultInstBuilder::new(self.func, &mut self.builder) + } +} + +pub struct DefaultInstBuilder<'f> { + func: &'f mut Function, + builder: &'f mut OpBuilder, +} +impl<'f> DefaultInstBuilder<'f> { + pub(crate) fn new(func: &'f mut Function, builder: &'f mut OpBuilder) -> Self { + Self { func, builder } + } +} +impl<'f> InstBuilderBase<'f> for DefaultInstBuilder<'f> { + fn builder_parts(&mut self) -> (&mut Function, &mut OpBuilder) { + (self.func, self.builder) + } + + fn builder(&self) -> &OpBuilder { + self.builder + } + + fn builder_mut(&mut self) -> &mut OpBuilder { + self.builder + } +} + +pub trait InstBuilderBase<'f>: Sized { + fn builder(&self) -> &OpBuilder; + fn builder_mut(&mut self) -> &mut OpBuilder; + fn builder_parts(&mut self) -> (&mut Function, &mut OpBuilder); + /// Get a default 
instruction builder using the dataflow graph and insertion point of the + /// current builder + fn ins<'a, 'b: 'a>(&'b mut self) -> DefaultInstBuilder<'a> { + let (func, builder) = self.builder_parts(); + DefaultInstBuilder::new(func, builder) + } +} + +pub trait InstBuilder<'f>: InstBuilderBase<'f> { + fn assert( + mut self, + value: ValueRef, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Assert>, Report> { + let op_builder = + self.builder_mut().create::<Assert, _>(span); + op_builder(value) + } + + fn assert_with_error( + mut self, + value: ValueRef, + code: u32, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Assert>, Report> { + let op_builder = + self.builder_mut().create::<Assert, _>(span); + op_builder(value, code) + } + + fn assertz( + mut self, + value: ValueRef, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Assertz>, Report> { + let op_builder = + self.builder_mut().create::<Assertz, _>(span); + op_builder(value) + } + + fn assertz_with_error( + mut self, + value: ValueRef, + code: u32, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Assertz>, Report> { + let op_builder = self + .builder_mut() + .create::<Assertz, _>(span); + op_builder(value, code) + } + + fn assert_eq( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<AssertEq>, Report> { + let op_builder = self.builder_mut().create::<AssertEq, _>(span); + op_builder(lhs, rhs) + } + + fn assert_eq_imm( + mut self, + lhs: ValueRef, + rhs: Immediate, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<AssertEqImm>, Report> { + let op_builder = self.builder_mut().create::<AssertEqImm, _>(span); + op_builder(lhs, rhs) + } + + fn u32(mut self, value: u32, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Constant, _>(span); + let constant = op_builder(Immediate::U32(value))?; + Ok(constant.borrow().result().as_value_ref()) + } + + //signed_integer_literal!(1, bool); + //integer_literal!(8); + //integer_literal!(16); + //integer_literal!(32); + //integer_literal!(64); + //integer_literal!(128); + + /* + fn felt(self, i: Felt, span: SourceSpan) -> Value { + into_first_result!(self.UnaryImm(Opcode::ImmFelt, Type::Felt, Immediate::Felt(i), span)) + } + + fn f64(self, f: f64, span: SourceSpan) -> Value { + into_first_result!(self.UnaryImm(Opcode::ImmF64, Type::F64, Immediate::F64(f), span)) + } + + fn character(self, c: char, span: SourceSpan) -> Value { + self.i32((c as u32) as i32, span) + } + */ + + /// Grow the global heap by `num_pages` pages, in 64kb units. + /// + /// Returns the previous size (in pages) of the heap, or -1 if the heap could not be grown. + fn mem_grow(mut self, num_pages: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<MemGrow, _>(span); + let op = op_builder(num_pages)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Return the size of the global heap in pages, where each page is 64kb. + fn mem_size(mut self, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<MemSize, _>(span); + let op = op_builder()?; + Ok(op.borrow().result().as_value_ref()) + } + + /* + /// Get a [GlobalValue] which represents the address of a global variable whose symbol is `name` + /// + /// On its own, this does nothing, you must use the resulting [GlobalValue] with a builder + /// that expects one as an argument, or use `global_value` to obtain a [Value] from it. + fn symbol<S: AsRef<str>>(self, name: S, span: SourceSpan) -> GlobalValue { + self.symbol_relative(name, 0, span) + } + + /// Same semantics as `symbol`, but applies a constant offset to the address of the given + /// symbol. + /// + /// If the offset is zero, this is equivalent to `symbol` + fn symbol_relative<S: AsRef<str>>( + mut self, + name: S, + offset: i32, + span: SourceSpan, + ) -> GlobalValue { + self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset, + }) + } + + /// Get the address of a global variable whose symbol is `name` + /// + /// The type of the pointer produced is given as `ty`. It is up to the caller + /// to ensure that loading memory from that pointer is valid for the provided + /// type. + fn symbol_addr<S: AsRef<str>>(self, name: S, ty: Type, span: SourceSpan) -> Value { + self.symbol_relative_addr(name, 0, ty, span) + } + + /// Same semantics as `symbol_addr`, but applies a constant offset to the address of the given + /// symbol. + /// + /// If the offset is zero, this is equivalent to `symbol_addr` + fn symbol_relative_addr<S: AsRef<str>>( + mut self, + name: S, + offset: i32, + ty: Type, + span: SourceSpan, + ) -> Value { + assert!(ty.is_pointer(), "expected pointer type, got '{}'", &ty); + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset, + }); + into_first_result!(self.Global(gv, ty, span)) + } + + /// Loads a value of type `ty` from the global variable whose symbol is `name`. + /// + /// NOTE: There is no requirement that the memory contents at the given symbol + /// contain a valid value of type `ty`. That is left entirely up to the caller to + /// guarantee at a higher level. + fn load_symbol<S: AsRef<str>>(self, name: S, ty: Type, span: SourceSpan) -> Value { + self.load_symbol_relative(name, ty, 0, span) + } + + /// Same semantics as `load_symbol`, but a constant offset is applied to the address before + /// issuing the load. + fn load_symbol_relative<S: AsRef<str>>( + mut self, + name: S, + ty: Type, + offset: i32, + span: SourceSpan, + ) -> Value { + let base = self.data_flow_graph_mut().create_global_value(GlobalValueData::Symbol { + name: Ident::new(Symbol::intern(name.as_ref()), span), + offset: 0, + }); + self.load_global_relative(base, ty, offset, span) + } + + /// Loads a value of type `ty` from the address represented by `addr` + /// + /// NOTE: There is no requirement that the memory contents at the given symbol + /// contain a valid value of type `ty`. That is left entirely up to the caller to + /// guarantee at a higher level. + fn load_global(self, addr: GlobalValue, ty: Type, span: SourceSpan) -> Value { + self.load_global_relative(addr, ty, 0, span) + } + + /// Same semantics as `load_global`, but a constant offset is applied to the address + /// before issuing the load. + fn load_global_relative( + mut self, + base: GlobalValue, + ty: Type, + offset: i32, + span: SourceSpan, + ) -> Value { + if let GlobalValueData::Load { + ty: ref base_ty, ..
+ } = self.data_flow_graph().global_value(base) + { + // If the base global is a load, the target address cannot be computed until runtime, + // so expand this to the appropriate sequence of instructions to do so in that case + assert!(base_ty.is_pointer(), "expected global value to have pointer type"); + let base_ty = base_ty.clone(); + let base = self.ins().load_global(base, base_ty.clone(), span); + let addr = self.ins().ptrtoint(base, Type::U32, span); + let offset_addr = if offset >= 0 { + self.ins().add_imm_checked(addr, Immediate::U32(offset as u32), span) + } else { + self.ins().sub_imm_checked(addr, Immediate::U32(offset.unsigned_abs()), span) + }; + let ptr = self.ins().inttoptr(offset_addr, base_ty, span); + self.load(ptr, span) + } else { + // The global address can be computed statically + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::Load { + base, + offset, + ty: ty.clone(), + }); + into_first_result!(self.Global(gv, ty, span)) + } + } + + /// Computes an address relative to the pointer produced by `base`, by applying an offset + /// given by multiplying `offset` * the size in bytes of `unit_ty`. + /// + /// The type of the pointer produced is the same as the type of the pointer given by `base` + /// + /// This is useful in some scenarios where `load_global_relative` is not, namely when computing + /// the effective address of an element of an array stored in a global variable. + fn global_addr_offset( + mut self, + base: GlobalValue, + offset: i32, + unit_ty: Type, + span: SourceSpan, + ) -> Value { + if let GlobalValueData::Load { + ty: ref base_ty, .. + } = self.data_flow_graph().global_value(base) + { + // If the base global is a load, the target address cannot be computed until runtime, + // so expand this to the appropriate sequence of instructions to do so in that case + assert!(base_ty.is_pointer(), "expected global value to have pointer type"); + let base_ty = base_ty.clone(); + let base = self.ins().load_global(base, base_ty.clone(), span); + let addr = self.ins().ptrtoint(base, Type::U32, span); + let unit_size: i32 = unit_ty + .size_in_bytes() + .try_into() + .expect("invalid type: size is larger than 2^32"); + let computed_offset = unit_size * offset; + let offset_addr = if computed_offset >= 0 { + self.ins().add_imm_checked(addr, Immediate::U32(computed_offset as u32), span) + } else { + self.ins().sub_imm_checked(addr, Immediate::U32(computed_offset.unsigned_abs()), span) + }; + let ptr = self.ins().inttoptr(offset_addr, base_ty, span); + self.load(ptr, span) + } else { + // The global address can be computed statically + let gv = self.data_flow_graph_mut().create_global_value(GlobalValueData::IAddImm { + base, + offset, + ty: unit_ty.clone(), + }); + let ty = self.data_flow_graph().global_type(gv); + into_first_result!(self.Global(gv, ty, span)) + } + } + + /// Loads a value of the type pointed to by the given pointer, on to the stack + /// + /// NOTE: This function will panic if `ptr` is not a pointer typed value + fn load(self, addr: Value, span: SourceSpan) -> Value { + let ty = require_pointee!(self, addr).clone(); + let data = Instruction::Load(LoadOp { + op: Opcode::Load, + addr, + ty: ty.clone(), + }); + into_first_result!(self.build(data, Type::Ptr(Box::new(ty)), span)) + } + + /// Loads a value from the given temporary (local variable), of the type associated with that + /// local.
+ fn load_local(self, local: LocalId, span: SourceSpan) -> Value { + let data = Instruction::LocalVar(LocalVarOp { + op: Opcode::Load, + local, + args: ValueList::default(), + }); + let ty = self.data_flow_graph().local_type(local).clone(); + into_first_result!(self.build(data, Type::Ptr(Box::new(ty)), span)) + } + + /// Stores `value` to the address given by `ptr` + /// + /// NOTE: This function will panic if the pointer and pointee types do not match + fn store(mut self, ptr: Value, value: Value, span: SourceSpan) -> Inst { + let pointee_ty = require_pointee!(self, ptr); + let value_ty = self.data_flow_graph().value_type(value); + assert_eq!(pointee_ty, value_ty, "expected value to be a {}, got {}", pointee_ty, value_ty); + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + vlist.extend([ptr, value], &mut dfg.value_lists); + } + self.PrimOp(Opcode::Store, Type::Unit, vlist, span).0 + } + + /// Stores `value` to the given temporary (local variable). + /// + /// NOTE: This function will panic if the type of `value` does not match the type of the local + /// variable. + fn store_local(mut self, local: LocalId, value: Value, span: SourceSpan) -> Inst { + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + let local_ty = dfg.local_type(local); + let value_ty = dfg.value_type(value); + assert_eq!(local_ty, value_ty, "expected value to be a {}, got {}", local_ty, value_ty); + vlist.push(value, &mut dfg.value_lists); + } + let data = Instruction::LocalVar(LocalVarOp { + op: Opcode::Store, + local, + args: vlist, + }); + self.build(data, Type::Unit, span).0 + } + */ + + /// Writes `count` copies of `value` to memory starting at address `dst`. + /// + /// Each copy of `value` will be written to memory starting at the next aligned address from + /// the previous copy. This instruction will trap if the input address does not meet the + /// minimum alignment requirements of the type. + fn memset( + mut self, + dst: ValueRef, + count: ValueRef, + value: ValueRef, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<MemSet>, Report> { + let op_builder = self.builder_mut().create::<MemSet, _>(span); + op_builder(dst, count, value) + } + + /// Copies `count` values from the memory at address `src`, to the memory at address `dst`. + /// + /// The unit size for `count` is determined by the `src` pointer type, i.e. a pointer to u8 + /// will copy `count` bytes, a pointer to u16 will copy `count * 2` bytes, and so on. + /// + /// NOTE: The source and destination pointer types must match, or this function will panic. + fn memcpy( + mut self, + src: ValueRef, + dst: ValueRef, + count: ValueRef, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<MemCpy>, Report> { + let op_builder = self.builder_mut().create::<MemCpy, _>(span); + op_builder(src, dst, count) + } + + /// This is a cast operation that permits performing arithmetic on pointer values + /// by casting a pointer to a specified integral type. + fn ptrtoint(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<PtrToInt, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + }
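+ + // Editor's sketch (not part of this diff): pointer arithmetic via the two casts, + // assuming a `FunctionBuilder` `fb` and values `ptr`/`ptr_ty` are in scope: + // + // let addr = fb.ins().ptrtoint(ptr, Type::U32, span)?; // *T -> u32 + // let offset = fb.ins().u32(4, span)?; + // let addr = fb.ins().add(addr, offset, span)?; // traps on overflow + // let ptr = fb.ins().inttoptr(addr, ptr_ty, span)?; // u32 -> *T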
+ /// This is the inverse of `ptrtoint`, used to recover a pointer that was + /// previously cast to an integer type. It may also be used to cast arbitrary + /// integer values to pointers. + /// + /// In both cases, use of the resulting pointer must not violate the semantics + /// of the higher level language being represented in Miden IR. + fn inttoptr(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<IntToPtr, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /* + /// This is an intrinsic which derives a new pointer from an existing pointer to an aggregate. + /// + /// In short, this represents the common need to calculate a new pointer from an existing + /// pointer, but without losing provenance of the original pointer. It is specifically + /// intended for use in obtaining a pointer to an element/field of an array/struct, of the + /// correct type, given a well typed pointer to the aggregate. + /// + /// This function will panic if the pointer is not to an aggregate type + /// + /// The new pointer is derived by statically navigating the structure of the pointee type, using + /// `offsets` to guide the traversal. Initially, the first offset is relative to the original + /// pointer, where `0` refers to the base/first field of the object. The second offset is then + /// relative to the base of the object selected by the first offset, and so on. Offsets must + /// remain in bounds, any attempt to index outside a type's boundaries will result in a + /// panic. + fn getelementptr(mut self, ptr: ValueRef, mut indices: &[usize], span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<GetElementPtr, _>(span); + op_builder(arg, ty) + } */ + + /// Cast `arg` to a value of type `ty` + /// + /// NOTE: This is only supported for integral types currently, and the types must be of the same + /// size in bytes, i.e. i32 -> u32 or vice versa. + /// + /// The intention of bitcasts is to reinterpret a value with different semantics, with no + /// validation that is typically implied by casting from one type to another. + fn bitcast(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Bitcast, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Cast `arg` to a value of type `ty` + /// + /// NOTE: This is only valid for numeric to numeric, or pointer to pointer casts. + /// For numeric to pointer, or pointer to numeric casts, use `inttoptr` and `ptrtoint` + /// respectively. + fn cast(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Cast, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Truncates an integral value as necessary to fit in `ty`. + /// + /// NOTE: Truncating a value into a larger type has undefined behavior, it is + /// equivalent to extending a value without doing anything with the new high-order + /// bits of the resulting value. + fn trunc(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Trunc, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Extends an integer into a larger integral type, by zero-extending the value, + /// i.e. the new high-order bits of the resulting value will be all zero. + /// + /// NOTE: This function will panic if `ty` is smaller than `arg`. + /// + /// If `arg` is the same type as `ty`, `arg` is returned as-is + fn zext(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Zext, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Extends an integer into a larger integral type, by sign-extending the value, + /// i.e.
the new high-order bits of the resulting value will all match the sign bit. + /// + /// NOTE: This function will panic if `ty` is smaller than `arg`. + /// + /// If `arg` is the same type as `ty`, `arg` is returned as-is + fn sext(mut self, arg: ValueRef, ty: Type, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Sext, _>(span); + let op = op_builder(arg, ty)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement addition which traps on overflow + fn add(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Add, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement addition. Behavior is undefined if the result overflows. + fn add_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Add, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement addition which wraps around on overflow, e.g. `wrapping_add` + fn add_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Add, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement addition which wraps around on overflow, but returns a boolean flag that + /// indicates whether or not the operation overflowed, followed by the wrapped result, e.g. + /// `overflowing_add` (but with the result types inverted compared to Rust's version). + fn add_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::<AddOverflowing, _>(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Two's complement subtraction which traps on under/overflow + fn sub(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Sub, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement subtraction. Behavior is undefined if the result under/overflows. + fn sub_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Sub, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement subtraction which wraps around on under/overflow, e.g. `wrapping_sub` + fn sub_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Sub, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + }
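+ + // Editor's sketch (not part of this diff): the `*_overflowing` variants return the + // overflow flag first, then the wrapped result — the reverse of Rust's tuple order. + // Assumes a `FunctionBuilder` `fb` and integer values `x`/`y` are in scope: + // + // let (overflowed, sum) = fb.ins().add_overflowing(x, y, span)?; + // // `overflowed` is a boolean (i1) value; `sum` wraps on overflow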
+ /// Two's complement subtraction which wraps around on overflow, but returns a boolean flag that + /// indicates whether or not the operation under/overflowed, followed by the wrapped result, + /// e.g. `overflowing_sub` (but with the result types inverted compared to Rust's version). + fn sub_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::<SubOverflowing, _>(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Two's complement multiplication which traps on overflow + fn mul(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Mul, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Checked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unchecked two's complement multiplication. Behavior is undefined if the result overflows. + fn mul_unchecked( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Mul, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Unchecked)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement multiplication which wraps around on overflow, e.g. `wrapping_mul` + fn mul_wrapping( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Mul, _>(span); + let op = op_builder(lhs, rhs, crate::Overflow::Wrapping)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement multiplication which wraps around on overflow, but returns a boolean flag + /// that indicates whether or not the operation overflowed, followed by the wrapped result, + /// e.g. `overflowing_mul` (but with the result types inverted compared to Rust's version). + fn mul_overflowing( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::<MulOverflowing, _>(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let overflowed = op.overflowed().as_value_ref(); + let result = op.result().as_value_ref(); + Ok((overflowed, result)) + } + + /// Integer division. Traps if `rhs` is zero. + fn div(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Div, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Integer Euclidean modulo. Traps if `rhs` is zero. + fn r#mod(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Mod, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + }
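+ + // Editor's sketch (not part of this diff): `divmod` below fuses the two ops above, + // producing quotient and remainder from a single instruction (`fb`, `x`, `y` assumed): + // + // let q = fb.ins().div(x, y, span)?; // traps if y == 0 + // let r = fb.ins().r#mod(x, y, span)?; // traps if y == 0 + // let (q, r) = fb.ins().divmod(x, y, span)?; // one op, same trap semantics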
+ /// Combined integer Euclidean division and modulo. Traps if `rhs` is zero. + fn divmod( + mut self, + lhs: ValueRef, + rhs: ValueRef, + span: SourceSpan, + ) -> Result<(ValueRef, ValueRef), Report> { + let op_builder = self.builder_mut().create::<Divmod, _>(span); + let op = op_builder(lhs, rhs)?; + let op = op.borrow(); + let quotient = op.quotient().as_value_ref(); + let remainder = op.remainder().as_value_ref(); + Ok((quotient, remainder)) + } + + /// Exponentiation + fn exp(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Exp, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compute 2^n + fn pow2(mut self, n: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Pow2, _>(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compute ilog2(n) + fn ilog2(mut self, n: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Ilog2, _>(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Modular inverse + fn inv(mut self, n: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Inv, _>(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Unary negation + fn neg(mut self, n: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Neg, _>(span); + let op = op_builder(n)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Two's complement unary increment by one which traps on overflow + fn incr(mut self, lhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Incr, _>(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical AND + fn and(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<And, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical OR + fn or(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Or, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical XOR + fn xor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Xor, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Logical NOT + fn not(mut self, lhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Not, _>(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise AND + fn band(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Band, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise OR + fn bor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Bor, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise XOR + fn bxor(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Bxor, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Bitwise NOT + fn bnot(mut self, lhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Bnot, _>(span); + let op = op_builder(lhs)?; + Ok(op.borrow().result().as_value_ref()) + }
+ + fn rotl(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Rotl, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn rotr(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Rotr, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn shl(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Shl, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn shr(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Shr, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn popcnt(mut self, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Popcnt, _>(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn clz(mut self, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Clz, _>(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn ctz(mut self, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Ctz, _>(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn clo(mut self, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Clo, _>(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn cto(mut self, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Cto, _>(span); + let op = op_builder(rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn eq(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Eq, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn neq(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Neq, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compares two integers and returns the minimum value + fn min(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Min, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + /// Compares two integers and returns the maximum value + fn max(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Max, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn gt(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Gt, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn gte(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Gte, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn lt(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Lt, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + }
+ + fn lte(mut self, lhs: ValueRef, rhs: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Lte, _>(span); + let op = op_builder(lhs, rhs)?; + Ok(op.borrow().result().as_value_ref()) + } + + #[allow(clippy::wrong_self_convention)] + fn is_odd(mut self, value: ValueRef, span: SourceSpan) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<IsOdd, _>(span); + let op = op_builder(value)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn exec<C, A>( + mut self, + callee: C, + args: A, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Exec>, Report> + where + C: AsCallableSymbolRef, + A: IntoIterator<Item = ValueRef>, + { + let op_builder = self.builder_mut().create::<Exec, _>(span); + op_builder(callee, args) + } + + /* + fn call(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst { + let mut vlist = ValueList::default(); + { + let dfg = self.data_flow_graph_mut(); + assert!( + dfg.get_import(&callee).is_some(), + "must import callee ({}) before calling it", + &callee + ); + vlist.extend(args.iter().copied(), &mut dfg.value_lists); + } + self.Call(Opcode::Call, callee, vlist, span).0 + } + */ + + fn select( + mut self, + cond: ValueRef, + a: ValueRef, + b: ValueRef, + span: SourceSpan, + ) -> Result<ValueRef, Report> { + let op_builder = self.builder_mut().create::<Select, _>(span); + let op = op_builder(cond, a, b)?; + Ok(op.borrow().result().as_value_ref()) + } + + fn br<A>( + mut self, + block: BlockRef, + args: A, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Br>, Report> + where + A: IntoIterator<Item = ValueRef>, + { + let op_builder = self.builder_mut().create::<Br, _>(span); + op_builder(block, args) + } + + fn cond_br<T, F>( + mut self, + cond: ValueRef, + then_dest: BlockRef, + then_args: T, + else_dest: BlockRef, + else_args: F, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<CondBr>, Report> + where + T: IntoIterator<Item = ValueRef>, + F: IntoIterator<Item = ValueRef>, + { + let op_builder = + self.builder_mut().create::<CondBr, _>(span); + op_builder(cond, then_dest, then_args, else_dest, else_args) + } + + /* + fn switch(self, arg: ValueRef, span: SourceSpan) -> SwitchBuilder<'f, Self> { + require_integer!(self, arg, Type::U32); + SwitchBuilder::new(self, arg, span) + } + */ + + fn ret( + mut self, + returning: Option<ValueRef>, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Ret>, Report> { + let op_builder = self + .builder_mut() + .create::<Ret, (<Option<ValueRef> as IntoIterator>::IntoIter,)>( + span, + ); + op_builder(returning) + } + + fn ret_imm( + mut self, + arg: Immediate, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<RetImm>, Report> { + let op_builder = self.builder_mut().create::<RetImm, _>(span); + op_builder(arg) + } + + fn unreachable( + mut self, + span: SourceSpan, + ) -> Result<UnsafeIntrusiveEntityRef<Unreachable>, Report> { + let op_builder = self.builder_mut().create::<Unreachable, _>(span); + op_builder() + } + + /* + fn inline_asm( + self, + args: &[Value], + results: impl IntoIterator<Item = Type>, + span: SourceSpan, + ) -> MasmBuilder { + MasmBuilder::new(self, args, results.into_iter().collect(), span) + } + */ +} + +impl<'f, T: InstBuilderBase<'f>> InstBuilder<'f> for T {} + +/* +/// An instruction builder for `switch`, to ensure it is validated during construction +pub struct SwitchBuilder<'f, T: InstBuilder<'f>> { + builder: T, + arg: ValueRef, + span: SourceSpan, + arms: Vec<SwitchArm>, + _marker: core::marker::PhantomData<&'f Function>, +} +impl<'f, T: InstBuilder<'f>> SwitchBuilder<'f, T> { + fn new(builder: T, arg: ValueRef, span: SourceSpan) -> Self { + Self { + builder, + arg, + span, + arms: Default::default(), + _marker: core::marker::PhantomData, + } + } + + /// Specify to what block a specific discriminant value should be dispatched + pub fn case(mut self, discriminant: u32, target: Block, args: &[Value]) -> Self { + assert_eq!( + self.arms + .iter() + .find(|arm| arm.value == discriminant) + .map(|arm| arm.successor.destination), + None, +
"duplicate switch case value '{discriminant}': already matched" + ); + let mut vlist = ValueList::default(); + { + let pool = &mut self.builder.data_flow_graph_mut().value_lists; + vlist.extend(args.iter().copied(), pool); + } + let arm = SwitchArm { + value: discriminant, + successor: Successor { + destination: target, + args: vlist, + }, + }; + self.arms.push(arm); + self + } + + /// Build the `switch` by specifying the fallback destination if none of the arms match + pub fn or_else(mut self, target: Block, args: &[Value]) -> Inst { + let mut vlist = ValueList::default(); + { + let pool = &mut self.builder.data_flow_graph_mut().value_lists; + vlist.extend(args.iter().copied(), pool); + } + let fallback = Successor { + destination: target, + args: vlist, + }; + self.builder.Switch(self.arg, self.arms, fallback, self.span).0 + } +} + */ diff --git a/hir2/src/dialects/hir/ops.rs b/hir2/src/dialects/hir/ops.rs new file mode 100644 index 000000000..47701fd07 --- /dev/null +++ b/hir2/src/dialects/hir/ops.rs @@ -0,0 +1,17 @@ +mod assertions; +mod binary; +mod cast; +mod constants; +mod control; +mod function; +mod invoke; +mod mem; +mod module; +mod primop; +mod ternary; +mod unary; + +pub use self::{ + assertions::*, binary::*, cast::*, constants::*, control::*, function::*, invoke::*, mem::*, + module::*, primop::*, ternary::*, unary::*, +}; diff --git a/hir2/src/dialects/hir/ops/assertions.rs b/hir2/src/dialects/hir/ops/assertions.rs new file mode 100644 index 000000000..0c704349f --- /dev/null +++ b/hir2/src/dialects/hir/ops/assertions.rs @@ -0,0 +1,55 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct Assert { + #[operand] + value: Bool, + #[attr] + #[default] + code: u32, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct Assertz { + #[operand] + value: Bool, + #[attr] + #[default] + code: u32, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, Commutative, SameTypeOperands) +)] +pub struct AssertEq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects) +)] +pub struct AssertEqImm { + #[operand] + lhs: AnyInteger, + #[attr] + rhs: Immediate, +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, Terminator) +)] +pub struct Unreachable {} diff --git a/hir2/src/dialects/hir/ops/binary.rs b/hir2/src/dialects/hir/ops/binary.rs new file mode 100644 index 000000000..4e24a0d0c --- /dev/null +++ b/hir2/src/dialects/hir/ops/binary.rs @@ -0,0 +1,674 @@ +use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; + +// Implement `derive(InferTypeOpInterface)` with `#[infer]` helper attribute: +// +// * `#[infer]` on a result field indicates its type should be inferred from the type of the first +// operand field +// * `#[infer(from = field)]` on a result field indicates its type should be inferred from +// the given field. The field is expected to implement `AsRef` +// * `#[infer(type = I1)]` on a field indicates that the field should always be inferred to have the given type +// * `#[infer(with = path::to::function)]` on a field indicates that the given function should be called to +// compute the inferred type for that field +macro_rules! 
infer_return_ty_for_binary_op { + ($Op:ty) => { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.result_mut().set_type(lhs); + Ok(()) + } + } + }; + + + ($Op:ty as $manually_specified_ty:expr) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type($manually_specified_ty); + Ok(()) + } + } + } + }; + + ($Op:ty, $($manually_specified_field_name:ident : $manually_specified_field_ty:expr),+) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.result_mut().set_type(lhs); + $( + self.[<$manually_specified_field_name _mut>]().set_type($manually_specified_field_ty); + )* + Ok(()) + } + } + } + }; +} + +/// Two's complement sum +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Add { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +infer_return_ty_for_binary_op!(Add); + +/// Two's complement sum with overflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct AddOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(AddOverflowing, overflowed: Type::I1); + +/// Two's complement difference (subtraction) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Sub { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +infer_return_ty_for_binary_op!(Sub); + +/// Two's complement difference (subtraction) with underflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct SubOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(SubOverflowing, overflowed: Type::I1); + +/// Two's complement product +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Mul { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, + #[attr] + overflow: Overflow, +} + +infer_return_ty_for_binary_op!(Mul); + +/// Two's complement product with overflow bit +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct MulOverflowing { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + overflowed: Bool, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(MulOverflowing, overflowed: Type::I1); + +/// Exponentiation for field elements +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Exp { + #[operand] + lhs: IntFelt, + #[operand] + rhs: IntFelt, + #[result] + 
result: IntFelt, +} + +infer_return_ty_for_binary_op!(Exp); + +/// Unsigned integer division, traps on division by zero +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Div { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Div); + +/// Signed integer division, traps on division by zero or dividing the minimum signed value by -1 +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Sdiv { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Sdiv); + +/// Unsigned integer Euclidean modulo, traps on division by zero +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Mod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Mod); + +/// Signed integer Euclidean modulo, traps on division by zero +/// +/// The result has the same sign as the dividend (lhs) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Smod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Smod); + +/// Combined unsigned integer Euclidean division and remainder (modulo). +/// +/// Traps on division by zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Divmod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + remainder: AnyInteger, + #[result] + quotient: AnyInteger, +} + +impl InferTypeOpInterface for Divmod { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.remainder_mut().set_type(lhs.clone()); + self.quotient_mut().set_type(lhs); + Ok(()) + } +} + +/// Combined signed integer Euclidean division and remainder (modulo). +/// +/// Traps on division by zero. +/// +/// The remainder has the same sign as the dividend (lhs) +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Sdivmod { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + remainder: AnyInteger, + #[result] + quotient: AnyInteger, +} + +impl InferTypeOpInterface for Sdivmod { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.lhs().ty().clone(); + self.remainder_mut().set_type(lhs.clone()); + self.quotient_mut().set_type(lhs); + Ok(()) + } +} + +/// Logical AND +/// +/// Operands must be boolean. +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct And { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(And); + +/// Logical OR +/// +/// Operands must be boolean. 
+#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Or { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Or); + +/// Logical XOR +/// +/// Operands must be boolean. +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Xor { + #[operand] + lhs: Bool, + #[operand] + rhs: Bool, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Xor); + +/// Bitwise AND +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Band { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Band); + +/// Bitwise OR +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Bor { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Bor); + +/// Bitwise XOR +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Bxor { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Bxor); + +/// Bitwise shift-left +/// +/// Shifts larger than the bitwidth of the value will be wrapped to zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] +pub struct Shl { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Shl); + +/// Bitwise shift-left by immediate +/// +/// Shifts larger than the bitwidth of the value will be wrapped to zero. +#[operation( + dialect = HirDialect, + implements(InferTypeOpInterface) +)] +pub struct ShlImm { + #[operand] + lhs: AnyInteger, + #[attr] + shift: u32, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for ShlImm { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.lhs().ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +/// Bitwise (logical) shift-right +/// +/// Shifts larger than the bitwidth of the value will effectively truncate the value to zero. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] +pub struct Shr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Shr); + +/// Arithmetic (signed) shift-right +/// +/// The result of shifts larger than the bitwidth of the value depends on the sign of the value; +/// for positive values, it rounds to zero; for negative values, it rounds to MIN. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] +pub struct Ashr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Ashr); + +/// Bitwise rotate-left +/// +/// The rotation count must be < the bitwidth of the value type.
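+/// +/// For example, rotating left by one moves the most significant bit into the least significant +/// position, as with Rust's built-in rotate: `0x8000_0001u32.rotate_left(1) == 0x0000_0003`.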
+#[operation( + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] +pub struct Rotl { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Rotl); + +/// Bitwise rotate-right +/// +/// The rotation count must be < the bitwidth of the value type. +#[operation( + dialect = HirDialect, + traits(BinaryOp), + implements(InferTypeOpInterface) +)] +pub struct Rotr { + #[operand] + lhs: AnyInteger, + #[operand] + shift: UInt32, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Rotr); + +/// Equality comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Eq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Eq as Type::I1); + +/// Inequality comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Neq { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Neq as Type::I1); + +/// Greater-than comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Gt { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Gt as Type::I1); + +/// Greater-than-or-equal comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Gte { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Gte as Type::I1); + +/// Less-than comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Lt { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Lt as Type::I1); + +/// Less-than-or-equal comparison +#[operation( + dialect = HirDialect, + traits(BinaryOp, SameTypeOperands), + implements(InferTypeOpInterface) +)] +pub struct Lte { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_binary_op!(Lte as Type::I1); + +/// Select minimum value +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Min { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Min); + +/// Select maximum value +#[operation( + dialect = HirDialect, + traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType), + implements(InferTypeOpInterface) +)] +pub struct Max { + #[operand] + lhs: AnyInteger, + #[operand] + rhs: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_binary_op!(Max); diff --git a/hir2/src/dialects/hir/ops/cast.rs b/hir2/src/dialects/hir/ops/cast.rs new file mode 100644 index 000000000..4a42720f5 --- /dev/null +++ b/hir2/src/dialects/hir/ops/cast.rs @@ -0,0 +1,179 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +/* +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CastKind { 
+ /// Reinterpret the bits of the operand as the target type, without any consideration for + /// the original meaning of those bits. + /// + /// For example, transmuting `u32::MAX` to `i32`, produces a value of `-1`, because the input + /// value overflows when interpreted as a signed integer. + Transmute, + /// Like `Transmute`, but the input operand is checked to verify that it is a valid value + /// of both the source and target types. + /// + /// For example, a checked cast of `u32::MAX` to `i32` would assert, because the input value + /// cannot be represented as an `i32` due to overflow. + Checked, + /// Convert the input value to the target type, by zero-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Zext, + /// Convert the input value to the target type, by sign-extending the value to the target + /// bitwidth. A cast of this type must be a widening cast, i.e. from a smaller bitwidth to + /// a larger one. + Sext, + /// Convert the input value to the target type, by truncating the excess bits. A cast of this + /// type must be a narrowing cast, i.e. from a larger bitwidth to a smaller one. + Trunc, +} + */ + +#[operation( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct PtrToInt { + #[operand] + operand: AnyPointer, + #[attr] + ty: Type, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for PtrToInt { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct IntToPtr { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyPointer, +} + +impl InferTypeOpInterface for IntToPtr { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Cast { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Cast { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Bitcast { + #[operand] + operand: AnyPointerOrInteger, + #[attr] + ty: Type, + #[result] + result: AnyPointerOrInteger, +} + +impl InferTypeOpInterface for Bitcast { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Trunc { + #[operand] + operand: AnyInteger, + #[attr] + ty: Type, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Trunc { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Zext { + #[operand] + operand: AnyUnsignedInteger, + #[attr] + ty: Type, + #[result] + result: AnyUnsignedInteger, +} + +impl InferTypeOpInterface for Zext { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} + 
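+ +// NOTE: zero- vs sign-extension behave like Rust's own widening conversions, e.g.: +// +// assert_eq!(u32::from(0xFFu8), 255); // `Zext`: high bits are filled with zero +// assert_eq!(i32::from(0xFFu8 as i8), -1); // `Sext`: high bits copy the sign bit +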
+#[operation( + dialect = HirDialect, + traits(UnaryOp) +)] +pub struct Sext { + #[operand] + operand: AnySignedInteger, + #[attr] + ty: Type, + #[result] + result: AnySignedInteger, +} + +impl InferTypeOpInterface for Sext { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} diff --git a/hir2/src/dialects/hir/ops/constants.rs b/hir2/src/dialects/hir/ops/constants.rs new file mode 100644 index 000000000..c3681e47b --- /dev/null +++ b/hir2/src/dialects/hir/ops/constants.rs @@ -0,0 +1,41 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(ConstantLike), + implements(InferTypeOpInterface, Foldable) +)] +pub struct Constant { + #[attr] + value: Immediate, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Constant { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.value().ty(); + self.result_mut().set_type(ty); + + Ok(()) + } +} + +impl Foldable for Constant { + #[inline] + fn fold(&self, results: &mut smallvec::SmallVec<[OpFoldResult; 1]>) -> FoldResult { + results.push(OpFoldResult::Attribute(self.get_attribute("value").unwrap().clone_value())); + FoldResult::Ok(()) + } + + #[inline(always)] + fn fold_with( + &self, + _operands: &[Option<Box<dyn AttributeValue>>], + results: &mut smallvec::SmallVec<[OpFoldResult; 1]>, + ) -> FoldResult { + self.fold(results) + } +} diff --git a/hir2/src/dialects/hir/ops/control.rs b/hir2/src/dialects/hir/ops/control.rs new file mode 100644 index 000000000..bd3f31be6 --- /dev/null +++ b/hir2/src/dialects/hir/ops/control.rs @@ -0,0 +1,626 @@ +use midenc_hir_macros::operation; +use smallvec::{smallvec, SmallVec}; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +/// Returns from the enclosing function with the provided operands as its results. +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct Ret { + #[operands] + values: AnyType, +} + +/// Returns from the enclosing function with the provided immediate value as its result. +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike) +)] +pub struct RetImm { + #[attr] + value: Immediate, +} + +/// An unstructured control flow primitive representing an unconditional branch to `target` +#[operation( + dialect = HirDialect, + traits(Terminator), + implements(BranchOpInterface) +)] +pub struct Br { + #[successor] + target: Successor, +} + +impl BranchOpInterface for Br { + #[inline] + fn get_successor_for_operands( + &self, + _operands: &[Box<dyn AttributeValue>], + ) -> Option<BlockRef> { + Some(self.target().dest.borrow().block.clone()) + } +} + +/// An unstructured control flow primitive representing a conditional branch to either `then_dest` +/// or `else_dest` depending on the value of `condition`, a boolean value.
+#[operation( + dialect = HirDialect, + traits(Terminator), + implements(BranchOpInterface) +)] +pub struct CondBr { + #[operand] + condition: Bool, + #[successor] + then_dest: Successor, + #[successor] + else_dest: Successor, +} +impl BranchOpInterface for CondBr { + fn get_successor_for_operands(&self, operands: &[Box<dyn AttributeValue>]) -> Option<BlockRef> { + let value = &*operands[0]; + let cond = if let Some(imm) = value.downcast_ref::<Immediate>() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::<bool>() { + *yes + } else { + panic!("expected boolean immediate for '{}' condition, got: {:?}", self.name(), value) + }; + + Some(if cond { + self.then_dest().dest.borrow().block.clone() + } else { + self.else_dest().dest.borrow().block.clone() + }) + } +} + +/// An unstructured control flow primitive that represents a multi-way branch to one of multiple +/// branch targets, depending on the value of `selector`. +/// +/// If a specific selector value is matched by `cases`, the branch target corresponding to that +/// case is the one to which control is transferred. If no matching case is found for the selector, +/// then the `fallback` target is used instead. +/// +/// A `fallback` successor must always be provided. +#[operation( + dialect = HirDialect, + traits(Terminator), + implements(BranchOpInterface) +)] +pub struct Switch { + #[operand] + selector: UInt32, + #[successors(keyed)] + cases: SwitchCase, + #[successor] + fallback: Successor, +} + +impl BranchOpInterface for Switch { + #[inline] + fn get_successor_for_operands(&self, operands: &[Box<dyn AttributeValue>]) -> Option<BlockRef> { + let value = &*operands[0]; + let selector = if let Some(selector) = value.downcast_ref::<Immediate>() { + selector.as_u32().expect("invalid selector value for 'hir.switch'") + } else if let Some(selector) = value.downcast_ref::<u32>() { + *selector + } else if let Some(selector) = value.downcast_ref::<u64>() { + u32::try_from(*selector).expect("invalid selector value for 'hir.switch'") + } else if let Some(selector) = value.downcast_ref::<usize>() { + u32::try_from(*selector).expect("invalid selector value for 'hir.switch': out of range") + } else { + panic!("unsupported selector value type for '{}', got: {:?}", self.name(), value) + }; + + for switch_case in self.cases().iter() { + let key = switch_case.key().unwrap(); + if selector == key.value { + return Some(switch_case.block()); + } + } + + // If we reach here, no selector match was found, so use the fallback successor + Some(self.fallback().dest.borrow().block.clone()) + } +} + +/// Represents a single branch target by matching a specific selector value in a [Switch] +/// operation.
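+/// +/// For example, given cases keyed `0`, `2`, and `5`, a selector value of `2` transfers control +/// to the successor of the case keyed `2`, while a selector value of `3` matches no case and +/// falls through to the `fallback` successor.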
+#[derive(Debug, Clone)] +pub struct SwitchCase { + pub value: u32, + pub successor: BlockRef, + pub arguments: Vec<ValueRef>, +} + +#[doc(hidden)] +pub struct SwitchCaseRef<'a> { + pub value: u32, + pub successor: BlockOperandRef, + pub arguments: OpOperandRange<'a>, +} + +#[doc(hidden)] +pub struct SwitchCaseMut<'a> { + pub value: u32, + pub successor: BlockOperandRef, + pub arguments: OpOperandRangeMut<'a>, +} + +impl KeyedSuccessor for SwitchCase { + type Key = u32; + type Repr<'a> = SwitchCaseRef<'a>; + type ReprMut<'a> = SwitchCaseMut<'a>; + + fn key(&self) -> &Self::Key { + &self.value + } + + fn into_parts(self) -> (Self::Key, BlockRef, Vec<ValueRef>) { + (self.value, self.successor, self.arguments) + } + + fn into_repr( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRange<'_>, + ) -> Self::Repr<'_> { + SwitchCaseRef { + value: key, + successor: block, + arguments: operands, + } + } + + fn into_repr_mut( + key: Self::Key, + block: BlockOperandRef, + operands: OpOperandRangeMut<'_>, + ) -> Self::ReprMut<'_> { + SwitchCaseMut { + value: key, + successor: block, + arguments: operands, + } + } +} + +/// [If] is a structured control flow operation representing conditional execution. +/// +/// An [If] takes a single condition as an argument, which chooses between one of its two regions +/// based on the condition. If the condition is true, then the `then_body` region is executed, +/// otherwise `else_body`. +/// +/// Neither region allows any arguments, and both regions must be terminated with one of: +/// +/// * [Ret] to return from the enclosing function directly +/// * [Unreachable] to abort execution +/// * [Yield] to return from the enclosing [If] +#[operation( + dialect = HirDialect, + traits(SingleBlock, NoRegionArguments), + implements(RegionBranchOpInterface) +)] +pub struct If { + #[operand] + condition: Bool, + #[region] + then_body: Region, + #[region] + else_body: Region, +} + +impl RegionBranchOpInterface for If { + fn get_entry_successor_regions( + &self, + operands: &[Option<Box<dyn AttributeValue>>], + ) -> RegionSuccessorIter<'_> { + match operands[0].as_deref() { + None => self.get_successor_regions(RegionBranchPoint::Parent), + Some(value) => { + let cond = if let Some(imm) = value.downcast_ref::<Immediate>() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::<bool>() { + *yes + } else { + panic!( + "expected boolean immediate for '{}' condition, got: {:?}", + self.name(), + value + ) + }; + + if cond { + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.then_body().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } else { + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.else_body().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } + } + } + } + + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> { + match point { + RegionBranchPoint::Parent => { + // Either branch is reachable on entry + RegionSuccessorIter::new( + self.as_operation(), + [ + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.then_body().as_region_ref()), + key: None, + operand_group: 0, + }, + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.else_body().as_region_ref()), + key: None, + operand_group: 0, + }, + ], + ) + } + RegionBranchPoint::Child(_) => { + // Only the parent If is reachable from then_body/else_body + RegionSuccessorIter::new( 
self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // TODO(pauls): Need to handle operand groups properly, as this group refers + // to the operand groups of If, but the results of the If come from the + // Yield contained in the If body + operand_group: 0, + }], + ) + } + } + } + + fn get_region_invocation_bounds( + &self, + operands: &[Option<Box<dyn AttributeValue>>], + ) -> SmallVec<[InvocationBounds; 1]> { + use smallvec::smallvec; + + match operands[0].as_deref() { + None => { + // Only one region is invoked, and no more than a single time + smallvec![InvocationBounds::NoMoreThan(1); 2] + } + Some(value) => { + let cond = if let Some(imm) = value.downcast_ref::<Immediate>() { + imm.as_bool().expect("invalid boolean condition for 'hir.if'") + } else if let Some(yes) = value.downcast_ref::<bool>() { + *yes + } else { + panic!( + "expected boolean immediate for '{}' condition, got: {:?}", + self.name(), + value + ) + }; + if cond { + smallvec![InvocationBounds::Exact(1), InvocationBounds::Never] + } else { + smallvec![InvocationBounds::Never, InvocationBounds::Exact(1)] + } + } + } + } + + #[inline(always)] + fn is_repetitive_region(&self, _index: usize) -> bool { + false + } + + #[inline(always)] + fn has_loop(&self) -> bool { + false + } +} + +/// A while is a loop structure composed of two regions: a "before" region, and an "after" region. +/// +/// The "before" region's entry block parameters correspond to the operands expected by the +/// operation, and can be used to compute the condition that determines whether the "after" body +/// is executed or not, or simply forwarded to the "after" region. The "before" region must +/// terminate with a [Condition] operation, which will be evaluated to determine whether or not +/// to continue the loop. +/// +/// The "after" region corresponds to the loop body, and must terminate with a [Yield] operation, +/// whose operands must be of the same arity and type as the "before" region's argument list. In +/// this way, the "after" body can feed back input to the "before" body to determine whether to +/// continue the loop. +#[operation( + dialect = HirDialect, + traits(SingleBlock), + implements(RegionBranchOpInterface) +)] +pub struct While { + #[region] + before: Region, + #[region] + after: Region, +} + +impl RegionBranchOpInterface for While { + #[inline] + fn get_entry_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + // Operands being forwarded to the `before` region from outside the loop + SuccessorOperandRange::forward(self.operands().all()) + } + + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> { + match point { + RegionBranchPoint::Parent => { + // The only successor region when branching from outside the While op is the + // `before` region. + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.before().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } + RegionBranchPoint::Child(region) => { + // When branching from `before`, the only successor is `after` or the While itself, + // otherwise, when branching from `after` the only successor is `before`.
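+ // (The `before` region is terminated by a [Condition] op, which either enters `after` + // or exits the loop; the `after` region is terminated by a [Yield] back to `before`.)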
+ let after_region = self.after().as_region_ref(); + if region == after_region { + // TODO(pauls): We should handle operands properly here - the While op itself + // does not have any operands for this transfer of control, that comes from the + // Yield op + RegionSuccessorIter::new( + self.as_operation(), + [RegionSuccessorInfo { + successor: RegionBranchPoint::Child(self.before().as_region_ref()), + key: None, + operand_group: 0, + }], + ) + } else { + // TODO(pauls): We should handle operands properly here - the While op itself + // does not have any operands for this transfer of control, that comes from the + // Condition op + assert!( + region == self.before().as_region_ref(), + "unexpected region branch point" + ); + RegionSuccessorIter::new( + self.as_operation(), + [ + RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + operand_group: 0, + }, + RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + operand_group: 0, + }, + ], + ) + } + } + } + } + + #[inline] + fn get_region_invocation_bounds( + &self, + _operands: &[Option<Box<dyn AttributeValue>>], + ) -> SmallVec<[InvocationBounds; 1]> { + smallvec![InvocationBounds::Unknown; self.num_regions()] + } + + #[inline(always)] + fn is_repetitive_region(&self, _index: usize) -> bool { + // Both regions are in the loop (`before` -> `after` -> `before` -> `after`) + true + } + + #[inline(always)] + fn has_loop(&self) -> bool { + true + } +} + +/// The [Condition] op is used in conjunction with [While] as the terminator of its `before` region. +/// +/// This op represents a choice between continuing the loop, or exiting the [While] loop and +/// continuing execution after the loop. +/// +/// NOTE: Attempting to use this op in any other context than the one described above is invalid, +/// and the implementation of various interfaces by this op will panic if that assumption is +/// violated. +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike), + implements(RegionBranchTerminatorOpInterface) +)] +pub struct Condition { + #[operand] + condition: Bool, + #[operands] + forwarded: AnyType, +} + +impl RegionBranchTerminatorOpInterface for Condition { + #[inline] + fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + SuccessorOperandRange::forward(self.forwarded()) + } + + #[inline] + fn get_mutable_successor_operands( + &mut self, + _point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_> { + SuccessorOperandRangeMut::forward(self.forwarded_mut()) + } + + fn get_successor_regions( + &self, + operands: &[Option<Box<dyn AttributeValue>>], + ) -> SmallVec<[RegionSuccessorInfo; 2]> { + // A [While] loop has two regions: `before` (containing this op), and `after`, which this + // op branches to when the condition is true. If the condition is false, control is + // transferred back to the parent [While] operation, with the forwarded operands of the + // condition used as the results of the [While] operation. + // + // We can return a single statically-known region if we were given a constant condition + // value, otherwise we must return both possible regions.
+ let cond = operands[0].as_deref(); + match cond { + None => { + let after_region = self + .parent_op() + .unwrap() + .borrow() + .downcast_ref::<While>() + .expect("expected `Condition` op to be a child of a `While` op") + .after() + .as_region_ref(); + // We can't know the condition until runtime, so both the parent `while` op and + // the `after` region could be successors + let if_false = RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // the `forwarded` operand group + operand_group: 1, + }; + let if_true = RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + // the `forwarded` operand group + operand_group: 1, + }; + smallvec![if_false, if_true] + } + Some(value) => { + // Extract the boolean value of the condition + let should_continue = if let Some(imm) = value.downcast_ref::<Immediate>() { + imm.as_bool().expect("invalid boolean immediate for 'hir.condition'") + } else if let Some(yes) = value.downcast_ref::<bool>() { + *yes + } else { + panic!("expected boolean immediate for 'hir.condition'") + }; + + // Choose the specific region successor implied by the condition + if should_continue { + // Proceed to the 'after' region + let after_region = self + .parent_op() + .unwrap() + .borrow() + .downcast_ref::<While>() + .expect("expected `Condition` op to be a child of a `While` op") + .after() + .as_region_ref(); + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Child(after_region), + key: None, + // the `forwarded` operand group + operand_group: 1, + }] + } else { + // Break out to the parent 'while' operation + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + // the `forwarded` operand group + operand_group: 1, + }] + } + } + } + } +} + +/// The [Yield] op is used in conjunction with [If] and [While] ops as a return-like terminator. +/// +/// * With [If], its regions must be terminated with either a [Yield] or an [Unreachable] op. +/// * With [While], a [Yield] is only valid in the `after` region, and the yielded operands must +/// match the region arguments of the `before` region. Thus to return values from the body of a +/// loop, one must first yield them from the `after` region to the `before` region using [Yield], +/// and then yield them from the `before` region by passing them as forwarded operands of the +/// [Condition] op. +/// +/// Any number of operands can be yielded at the same time. However, when [Yield] is used in +/// conjunction with [While], the arity and type of the operands must match the region arguments +/// of the `before` region. When used in conjunction with [If], both the `if_true` and `if_false` +/// regions must yield the same arity and types. +#[operation( + dialect = HirDialect, + traits(Terminator, ReturnLike), + implements(RegionBranchTerminatorOpInterface) +)] +pub struct Yield { + #[operands] + yielded: AnyType, +} + +impl RegionBranchTerminatorOpInterface for Yield { + #[inline] + fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + SuccessorOperandRange::forward(self.yielded()) + } + + fn get_mutable_successor_operands( + &mut self, + _point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_> { + SuccessorOperandRangeMut::forward(self.yielded_mut()) + } + + fn get_successor_regions( + &self, + _operands: &[Option<Box<dyn AttributeValue>>], + ) -> SmallVec<[RegionSuccessorInfo; 2]> { + // Depending on the type of operation containing this yield, the set of successor regions + // is always known.
+ // + // * [While] may only have a yield to its `before` region + // * [If] may only yield to its parent + let parent_op = self.parent_op().unwrap(); + let parent_op = parent_op.borrow(); + if parent_op.is::<If>() { + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Parent, + key: None, + operand_group: 0, + }] + } else if let Some(while_op) = parent_op.downcast_ref::<While>() { + let before_region = while_op.before().as_region_ref(); + smallvec![RegionSuccessorInfo { + successor: RegionBranchPoint::Child(before_region), + key: None, + operand_group: 0, + }] + } else { + panic!("unsupported parent operation for '{}': '{}'", self.name(), parent_op.name()) + } + } +} diff --git a/hir2/src/dialects/hir/ops/function.rs b/hir2/src/dialects/hir/ops/function.rs new file mode 100644 index 000000000..0c474a4c7 --- /dev/null +++ b/hir2/src/dialects/hir/ops/function.rs @@ -0,0 +1,184 @@ +use crate::{ + derive::operation, + dialects::hir::HirDialect, + traits::{IsolatedFromAbove, SingleRegion}, + Block, BlockRef, CallableOpInterface, Ident, Op, Operation, OperationRef, RegionKind, + RegionKindInterface, RegionRef, Report, Signature, Symbol, SymbolName, SymbolNameAttr, + SymbolRef, SymbolUse, SymbolUseList, SymbolUseRef, SymbolUsesIter, Usable, Visibility, +}; + +trait UsableSymbol = Usable<Use = SymbolUse>; + +#[operation( + dialect = HirDialect, + traits(SingleRegion, IsolatedFromAbove), + implements( + UsableSymbol, + Symbol, + CallableOpInterface, + RegionKindInterface + ) +)] +pub struct Function { + #[attr] + name: Ident, + #[attr] + signature: Signature, + #[region] + body: RegionRef, + /// The uses of this function as a symbol + #[default] + uses: SymbolUseList, +} + +/// Builders +impl Function { + /// Convert this function from a declaration (no body) to a definition (has a body) by creating + /// the entry block based on the function signature. + /// + /// NOTE: The resulting function is _invalid_ until the block has a terminator inserted into it.
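+ /// +/// A usage sketch (surrounding builder details elided, names hypothetical): +/// ```ignore +/// let entry = function.create_entry_block(); +/// // ... append the body's ops to `entry`, ending with a terminator such as `Ret` ... +/// ```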
+ /// + /// This function will panic if an entry block has already been created + pub fn create_entry_block(&mut self) -> BlockRef { + use crate::EntityWithParent; + + assert!(self.body().is_empty(), "entry block already exists"); + let signature = self.signature(); + let block = self + .as_operation() + .context() + .create_block_with_params(signature.params().iter().map(|p| p.ty.clone())); + let mut body = self.body_mut(); + body.push_back(block.clone()); + Block::on_inserted_into_parent(block.clone(), body.as_region_ref()); + block + } +} + +/// Accessors +impl Function { + #[inline] + pub fn entry_block(&self) -> BlockRef { + unsafe { BlockRef::from_raw(&*self.body().entry()) } + } + + pub fn last_block(&self) -> BlockRef { + self.body() + .body() + .back() + .as_pointer() + .expect("cannot access blocks of a function declaration") + } +} + +impl RegionKindInterface for Function { + #[inline(always)] + fn kind(&self) -> RegionKind { + RegionKind::SSA + } +} + +impl Usable for Function { + type Use = SymbolUse; + + #[inline(always)] + fn uses(&self) -> &SymbolUseList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut SymbolUseList { + &mut self.uses + } +} + +impl Symbol for Function { + #[inline(always)] + fn as_symbol_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_symbol_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn name(&self) -> SymbolName { + Self::name(self).as_symbol() + } + + fn set_name(&mut self, name: SymbolName) { + let id = self.name_mut(); + id.name = name; + } + + fn visibility(&self) -> Visibility { + self.signature().visibility + } + + fn set_visibility(&mut self, visibility: Visibility) { + self.signature_mut().visibility = visibility; + } + + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { + SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { + if OperationRef::ptr_eq(&from, &user.owner) + || from.borrow().is_proper_ancestor_of(&user.owner) + { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else { + None + } + })) + } + + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report> { + for symbol_use in self.symbol_uses(from) { + let (mut owner, attr_name) = { + let user = symbol_use.borrow(); + (user.owner.clone(), user.symbol) + }; + let mut owner = owner.borrow_mut(); + // Unlink previously used symbol + { + let current_symbol = owner + .get_typed_attribute_mut::<SymbolNameAttr>(attr_name) + .expect("stale symbol user"); + unsafe { + self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); + } + } + // Link replacement symbol + owner.set_symbol_attribute(attr_name, replacement.clone()); + } + + Ok(()) + } + + /// Returns true if this operation is a declaration, rather than a definition, of a symbol + /// + /// The default implementation assumes that all operations are definitions + #[inline] + fn is_declaration(&self) -> bool { + self.body().is_empty() + } +} + +impl CallableOpInterface for Function { + fn get_callable_region(&self) -> Option<RegionRef> { + if self.is_declaration() { + None + } else { + self.op.regions().front().as_pointer() + } + } + + #[inline] + fn signature(&self) -> &Signature { + Function::signature(self) + } +} diff --git a/hir2/src/dialects/hir/ops/invoke.rs b/hir2/src/dialects/hir/ops/invoke.rs new file mode 100644 index 000000000..192e14edf --- /dev/null +++ b/hir2/src/dialects/hir/ops/invoke.rs @@ -0,0 +1,112 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, 
traits::*, *}; + +#[operation( + dialect = HirDialect, + implements(CallOpInterface) +)] +pub struct Exec { + #[symbol(callable)] + callee: SymbolNameAttr, + #[operands] + arguments: AnyType, +} + +impl InferTypeOpInterface for Exec { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + use midenc_session::diagnostics::Severity; + + let span = self.span(); + let owner = self.as_operation().as_operation_ref(); + if let Some(symbol) = self.resolve() { + let symbol = symbol.borrow(); + if let Some(callable) = + symbol.as_symbol_operation().as_trait::<dyn CallableOpInterface>() + { + let signature = callable.signature(); + for (i, result) in signature.results().iter().enumerate() { + let value = + context.make_result(span, result.ty.clone(), owner.clone(), i as u8); + self.op.results.push(value); + } + + Ok(()) + } else { + Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + span, + "invalid callee: does not implement CallableOpInterface", + ) + .with_secondary_label( + symbol.as_symbol_operation().span, + "symbol refers to this definition", + ) + .into_report()) + } + } else { + Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(span, "invalid callee: symbol is undefined") + .into_report()) + } + } +} + +/* +#[operation( + dialect = HirDialect, + implements(CallOpInterface) +)] +pub struct ExecIndirect { + #[attr] + signature: Signature, + /// TODO(pauls): Change this to FunctionType + #[operand] + callee: AnyType, +} + */ +impl CallOpInterface for Exec { + #[inline(always)] + fn callable_for_callee(&self) -> Callable { + self.callee().into() + } + + fn set_callee(&mut self, callable: Callable) { + let callee = callable.unwrap_symbol_name(); + *self.callee_mut() = callee; + } + + #[inline(always)] + fn arguments(&self) -> OpOperandRange<'_> { + self.operands().group(0) + } + + #[inline(always)] + fn arguments_mut(&mut self) -> OpOperandRangeMut<'_> { + self.operands_mut().group_mut(0) + } + + fn resolve(&self) -> Option<SymbolRef> { + let callee = self.callee(); + if callee.has_parent() { + todo!() + } + let module = self.as_operation().nearest_symbol_table()?; + let module = module.borrow(); + let symbol_table = module.as_trait::<dyn SymbolTable>().unwrap(); + symbol_table.get(callee.name) + } + + fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option<SymbolRef> { + let callee = self.callee(); + symbols.get(callee.name) + } +} diff --git a/hir2/src/dialects/hir/ops/mem.rs b/hir2/src/dialects/hir/ops/mem.rs new file mode 100644 index 000000000..1650e8aa4 --- /dev/null +++ b/hir2/src/dialects/hir/ops/mem.rs @@ -0,0 +1,62 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryWrite) +)] +pub struct Store { + #[operand] + addr: AnyPointer, + #[operand] + value: AnyType, +} + +// TODO(pauls): StoreLocal + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead), + implements(InferTypeOpInterface) +)] +pub struct Load { + #[operand] + addr: AnyPointer, + #[result] + result: AnyType, +} + +impl InferTypeOpInterface for Load { + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report> { + let span = self.span(); + let pointee = { + let addr = self.addr(); + let addr_value = addr.value(); + addr_value.ty().pointee().cloned() + }; + match pointee { + Some(pointee) => { + self.result_mut().set_type(pointee); + Ok(()) + 
} + None => { + let addr = self.addr(); + let addr_value = addr.value(); + let addr_ty = addr_value.ty(); + Err(context + .session + .diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operand for 'load'") + .with_primary_label( + span, + format!("invalid 'addr' operand, expected pointer, got '{addr_ty}'"), + ) + .into_report()) + } + } + } +} + +// TODO(pauls): LoadLocal diff --git a/hir2/src/dialects/hir/ops/module.rs b/hir2/src/dialects/hir/ops/module.rs new file mode 100644 index 000000000..eddaf27d5 --- /dev/null +++ b/hir2/src/dialects/hir/ops/module.rs @@ -0,0 +1,226 @@ +use alloc::collections::BTreeMap; + +use crate::{ + derive::operation, + dialects::hir::HirDialect, + symbol_table::SymbolUsesIter, + traits::{ + GraphRegionNoTerminator, HasOnlyGraphRegion, IsolatedFromAbove, NoRegionArguments, + NoTerminator, SingleBlock, SingleRegion, + }, + Ident, InsertionPoint, Operation, OperationRef, RegionKind, RegionKindInterface, Report, + Symbol, SymbolName, SymbolNameAttr, SymbolRef, SymbolTable, SymbolUseList, SymbolUseRef, + Usable, Visibility, +}; + +#[operation( + dialect = HirDialect, + traits( + SingleRegion, + SingleBlock, + NoRegionArguments, + NoTerminator, + HasOnlyGraphRegion, + GraphRegionNoTerminator, + IsolatedFromAbove, + ), + implements(RegionKindInterface, SymbolTable, Symbol) +)] +pub struct Module { + #[attr] + name: Ident, + #[attr] + #[default] + visibility: Visibility, + #[region] + body: RegionRef, + #[default] + registry: BTreeMap<SymbolName, SymbolRef>, + #[default] + uses: SymbolUseList, +} + +impl RegionKindInterface for Module { + #[inline(always)] + fn kind(&self) -> RegionKind { + RegionKind::Graph + } +} + +impl Usable for Module { + type Use = crate::SymbolUse; + + #[inline(always)] + fn uses(&self) -> &SymbolUseList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut SymbolUseList { + &mut self.uses + } +} + +impl Symbol for Module { + #[inline(always)] + fn as_symbol_operation(&self) -> &Operation { + &self.op + } + + #[inline(always)] + fn as_symbol_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn name(&self) -> SymbolName { + Module::name(self).as_symbol() + } + + fn set_name(&mut self, name: SymbolName) { + let id = self.name_mut(); + id.name = name; + } + + fn visibility(&self) -> Visibility { + *Module::visibility(self) + } + + fn set_visibility(&mut self, visibility: Visibility) { + *self.visibility_mut() = visibility; + } + + fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter { + SymbolUsesIter::from_iter(self.uses.iter().filter_map(|user| { + if OperationRef::ptr_eq(&from, &user.owner) + || from.borrow().is_proper_ancestor_of(&user.owner) + { + Some(unsafe { SymbolUseRef::from_raw(&*user) }) + } else { + None + } + })) + } + + fn replace_all_uses( + &mut self, + replacement: SymbolRef, + from: OperationRef, + ) -> Result<(), Report> { + for symbol_use in self.symbol_uses(from) { + let (mut owner, attr_name) = { + let user = symbol_use.borrow(); + (user.owner.clone(), user.symbol) + }; + let mut owner = owner.borrow_mut(); + // Unlink previously used symbol + { + let current_symbol = owner + .get_typed_attribute_mut::<SymbolNameAttr>(attr_name) + .expect("stale symbol user"); + unsafe { + self.uses.cursor_mut_from_ptr(current_symbol.user.clone()).remove(); + } + } + // Link replacement symbol + owner.set_symbol_attribute(attr_name, replacement.clone()); + } + + Ok(()) + } +} + +impl SymbolTable for Module { + #[inline(always)] + fn as_symbol_table_operation(&self) -> &Operation { + 
&self.op + } + + #[inline(always)] + fn as_symbol_table_operation_mut(&mut self) -> &mut Operation { + &mut self.op + } + + fn get(&self, name: SymbolName) -> Option<SymbolRef> { + self.registry.get(&name).cloned() + } + + fn insert_new(&mut self, entry: SymbolRef, ip: Option<InsertionPoint>) -> bool { + use crate::{BlockRef, Builder, OpBuilder}; + let op = { + let symbol = entry.borrow(); + let name = symbol.name(); + if self.registry.contains_key(&name) { + return false; + } + let op = symbol.as_operation_ref(); + drop(symbol); + self.registry.insert(name, entry.clone()); + op + }; + let mut builder = OpBuilder::new(self.op.context_rc()); + if let Some(ip) = ip { + builder.set_insertion_point(ip); + } else { + builder.set_insertion_point_to_end(unsafe { BlockRef::from_raw(&*self.body().entry()) }) + } + builder.insert(op); + true + } + + fn insert(&mut self, mut entry: SymbolRef, ip: Option<InsertionPoint>) -> SymbolName { + let name = { + let mut symbol = entry.borrow_mut(); + let mut name = symbol.name(); + if self.registry.contains_key(&name) { + // Unique the symbol name + let mut counter = 0; + name = crate::symbol_table::generate_symbol_name(name, &mut counter, |name| { + self.registry.contains_key(name) + }); + symbol.set_name(name); + } + name + }; + self.insert_new(entry, ip); + name + } + + fn remove(&mut self, name: SymbolName) -> Option<SymbolRef> { + if let Some(ptr) = self.registry.remove(&name) { + let op = ptr.borrow().as_operation_ref(); + let mut body = self.body_mut(); + let mut entry = body.entry_mut(); + let mut cursor = unsafe { entry.body_mut().cursor_mut_from_ptr(op) }; + cursor.remove().map(|_| ptr) + } else { + None + } + } + + fn rename(&mut self, from: SymbolName, to: SymbolName) -> Result<(), Report> { + if let Some(symbol) = self.registry.get_mut(&from) { + let mut sym = symbol.borrow_mut(); + sym.set_name(to); + let uses = sym.uses_mut(); + let mut cursor = uses.front_mut(); + while let Some(mut next_use) = cursor.as_pointer() { + { + let mut next_use = next_use.borrow_mut(); + let mut op = next_use.owner.borrow_mut(); + let symbol_name = op + .get_typed_attribute_mut::<SymbolNameAttr>(next_use.symbol) + .expect("stale symbol user"); + symbol_name.name = to; + } + cursor.move_next(); + } + + Ok(()) + } else { + Err(Report::msg(format!( + "unable to rename '{from}': no such symbol in '{}'", + self.name().as_str() + ))) + } + } +} diff --git a/hir2/src/dialects/hir/ops/primop.rs b/hir2/src/dialects/hir/ops/primop.rs new file mode 100644 index 000000000..df0401d74 --- /dev/null +++ b/hir2/src/dialects/hir/ops/primop.rs @@ -0,0 +1,63 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead, MemoryWrite, SameOperandsAndResultType) +)] +pub struct MemGrow { + #[operand] + pages: UInt32, + #[result] + result: UInt32, +} + +impl InferTypeOpInterface for MemGrow { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type(Type::I32); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryRead) +)] +pub struct MemSize { + #[result] + result: UInt32, +} + +impl InferTypeOpInterface for MemSize { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type(Type::I32); + Ok(()) + } +} + +#[operation( + dialect = HirDialect, + traits(HasSideEffects, MemoryWrite) +)] +pub struct MemSet { + #[operand] + addr: AnyPointer, + #[operand] + count: UInt32, + #[operand] + value: AnyType, +} + +#[operation( + 
dialect = HirDialect, + traits(HasSideEffects, MemoryRead, MemoryWrite) +)] +pub struct MemCpy { + #[operand] + source: AnyPointer, + #[operand] + destination: AnyPointer, + #[operand] + count: UInt32, +} diff --git a/hir2/src/dialects/hir/ops/ternary.rs b/hir2/src/dialects/hir/ops/ternary.rs new file mode 100644 index 000000000..90e4bcd9a --- /dev/null +++ b/hir2/src/dialects/hir/ops/ternary.rs @@ -0,0 +1,27 @@ +use midenc_hir_macros::operation; + +use crate::{dialects::hir::HirDialect, traits::*, *}; + +/// Choose a value based on a boolean condition +#[operation( + dialect = HirDialect, + implements(InferTypeOpInterface) +)] +pub struct Select { + #[operand] + cond: Bool, + #[operand] + first: AnyInteger, + #[operand] + second: AnyInteger, + #[result] + result: AnyInteger, +} + +impl InferTypeOpInterface for Select { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let ty = self.first().ty().clone(); + self.result_mut().set_type(ty); + Ok(()) + } +} diff --git a/hir2/src/dialects/hir/ops/unary.rs b/hir2/src/dialects/hir/ops/unary.rs new file mode 100644 index 000000000..c27165d28 --- /dev/null +++ b/hir2/src/dialects/hir/ops/unary.rs @@ -0,0 +1,206 @@ +use crate::{derive::operation, dialects::hir::HirDialect, traits::*, *}; + +macro_rules! infer_return_ty_for_unary_op { + ($Op:ty) => { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + let lhs = self.operand().ty().clone(); + self.result_mut().set_type(lhs); + Ok(()) + } + } + }; + + ($Op:ty as $manually_specified_ty:expr) => { + paste::paste! { + impl InferTypeOpInterface for $Op { + fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> { + self.result_mut().set_type($manually_specified_ty); + Ok(()) + } + } + } + }; +} + +/// Increment +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Incr { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_unary_op!(Incr); + +/// Negation +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Neg { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_unary_op!(Neg); + +/// Modular inverse +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Inv { + #[operand] + operand: IntFelt, + #[result] + result: IntFelt, +} + +infer_return_ty_for_unary_op!(Inv); + +/// log2(operand) +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Ilog2 { + #[operand] + operand: IntFelt, + #[result] + result: IntFelt, +} + +infer_return_ty_for_unary_op!(Ilog2); + +/// pow2(operand) +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Pow2 { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_unary_op!(Pow2); + +/// Logical NOT +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Not { + #[operand] + operand: Bool, + #[result] + result: Bool, +} + +infer_return_ty_for_unary_op!(Not); + +/// Bitwise NOT +#[operation ( + dialect = HirDialect, + traits(UnaryOp, SameOperandsAndResultType) + )] +pub struct Bnot { + #[operand] + operand: AnyInteger, + #[result] + result: AnyInteger, +} + +infer_return_ty_for_unary_op!(Bnot); + +/// is_odd(operand) +#[operation ( + dialect = HirDialect, + 
traits(UnaryOp) + )] +pub struct IsOdd { + #[operand] + operand: AnyInteger, + #[result] + result: Bool, +} + +infer_return_ty_for_unary_op!(IsOdd as Type::I1); + +/// Count of non-zero bits (population count) +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Popcnt { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} + +infer_return_ty_for_unary_op!(Popcnt as Type::U32); + +/// Count Leading Zeros +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Clz { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} + +infer_return_ty_for_unary_op!(Clz as Type::U32); + +/// Count Trailing Zeros +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Ctz { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} + +infer_return_ty_for_unary_op!(Ctz as Type::U32); + +/// Count Leading Ones +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Clo { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} + +infer_return_ty_for_unary_op!(Clo as Type::U32); + +/// Count Trailing Ones +#[operation ( + dialect = HirDialect, + traits(UnaryOp) + )] +pub struct Cto { + #[operand] + operand: AnyInteger, + #[result] + result: UInt32, +} + +infer_return_ty_for_unary_op!(Cto as Type::U32); diff --git a/hir2/src/folder.rs b/hir2/src/folder.rs new file mode 100644 index 000000000..b95367e15 --- /dev/null +++ b/hir2/src/folder.rs @@ -0,0 +1,487 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use smallvec::{smallvec, SmallVec}; + +use crate::{ + matchers::Matcher, + traits::{ConstantLike, Foldable, IsolatedFromAbove}, + AttributeValue, BlockRef, Builder, Context, Dialect, FoldResult, FxHashMap, OpFoldResult, + OperationRef, RegionRef, Rewriter, RewriterImpl, RewriterListener, SourceSpan, Spanned, Type, + Value, ValueRef, +}; + +/// Represents a constant value uniqued by dialect, value, and type. +struct UniquedConstant { + dialect: Rc<dyn Dialect>, + value: Box<dyn AttributeValue>, + ty: Type, +} +impl Eq for UniquedConstant {} +impl PartialEq for UniquedConstant { + fn eq(&self, other: &Self) -> bool { + use core::hash::{Hash, Hasher}; + + let v1_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + self.value.hash(&mut hasher); + hasher.finish() + }; + let v2_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + other.value.hash(&mut hasher); + hasher.finish() + }; + + self.dialect.name() == other.dialect.name() && v1_hash == v2_hash && self.ty == other.ty + } +} +impl UniquedConstant { + pub fn new(op: &OperationRef, value: Box<dyn AttributeValue>) -> Self { + let op = op.borrow(); + let dialect = op.dialect(); + let ty = op.results()[0].borrow().ty().clone(); + + Self { dialect, value, ty } + } +} +impl core::hash::Hash for UniquedConstant { + fn hash<H: core::hash::Hasher>(&self, state: &mut H) { + self.dialect.name().hash(state); + self.value.hash(state); + self.ty.hash(state); + } +} + +/// A map of uniqued constants, to their defining operations +type ConstantMap = FxHashMap<UniquedConstant, OperationRef>; + +/// The [OperationFolder] is responsible for orchestrating operation folding, and the effects that +/// folding an operation has on the containing region. +/// +/// It handles the following details related to operation folding: +/// +/// * Attempting to fold the operation itself +/// * Materializing constants +/// * Uniquing/de-duplicating materialized constants, including moving them up the CFG to ensure +/// that new uses of a uniqued constant are dominated by the constant definition.
+/// * Removing folded operations (or cleaning up failed attempts), and notifying any listeners of +/// those actions. +pub struct OperationFolder { + rewriter: Box<dyn Rewriter>, + /// A mapping between an insertion region and the constants that have been created within it. + scopes: BTreeMap<RegionRef, ConstantMap>, + /// This map tracks all of the dialects that an operation is referenced by, given that multiple + /// dialects may generate the same constant. + referenced_dialects: BTreeMap<OperationRef, SmallVec<[Rc<dyn Dialect>; 1]>>, + /// The location to use for folder-owned constants + erased_folded_location: SourceSpan, +} +impl OperationFolder { + pub fn new<L>(context: Rc<Context>, listener: L) -> Self + where + L: RewriterListener, + { + Self { + rewriter: Box::new(RewriterImpl::<L>::new(context).with_listener(listener)), + scopes: Default::default(), + referenced_dialects: Default::default(), + erased_folded_location: SourceSpan::UNKNOWN, + } + } + + /// Tries to perform folding on `op`, including unifying de-duplicated constants. + /// + /// If successful, replaces uses of `op`'s results with the folded results, and returns + /// a [FoldResult]. + pub fn try_fold(&mut self, mut op: OperationRef) -> FoldResult { + // If this is a uniqued constant, return failure as we know that it has already been + // folded. + if self.is_folder_owned_constant(&op) { + // Check to see if we should rehoist, i.e. if a non-constant operation was inserted + // before this one. + let block = op.borrow().parent().unwrap(); + if block.borrow().front().unwrap() != op + && !self.is_folder_owned_constant(&op.prev().unwrap()) + { + let mut op = op.borrow_mut(); + op.move_before(crate::ProgramPoint::Block(block)); + op.set_span(self.erased_folded_location); + } + return FoldResult::Failed; + } + + // Try to fold the operation + let mut fold_results = SmallVec::default(); + match op.borrow_mut().fold(&mut fold_results) { + FoldResult::InPlace => { + // Folding API does not notify listeners, so we need to do so manually + self.rewriter.notify_operation_modified(op); + + FoldResult::InPlace + } + FoldResult::Ok(_) => { + assert!( + !fold_results.is_empty(), + "expected non-empty fold results from a successful fold" + ); + if let FoldResult::Ok(replacements) = + self.process_fold_results(op.clone(), &fold_results) + { + // Constant folding succeeded. Replace all of the result values and erase the + // operation. + self.notify_removal(op.clone()); + self.rewriter.replace_op_with_values(op, &replacements); + FoldResult::Ok(()) + } else { + FoldResult::Failed + } + } + failed @ FoldResult::Failed => failed, + } + } + + /// Try to process a set of fold results. + /// + /// Returns the folded values if successful. + fn process_fold_results( + &mut self, + op: OperationRef, + fold_results: &[OpFoldResult], + ) -> FoldResult<SmallVec<[ValueRef; 1]>> { + let borrowed_op = op.borrow(); + assert_eq!(fold_results.len(), borrowed_op.num_results()); + + // Create a builder to insert new operations into the entry block of the insertion region + let insert_region = get_insertion_region(borrowed_op.parent().unwrap()); + let entry = insert_region.borrow().entry_block_ref().unwrap(); + self.rewriter.set_insertion_point_to_start(entry); + + // Create the result constants and replace the results. + let dialect = borrowed_op.dialect(); + let mut out = SmallVec::default(); + for (op_result, fold_result) in borrowed_op.results().iter().zip(fold_results) { + match fold_result { + // Check if the result was an SSA value.
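+ // (For example, a fold that simply forwards one of the op's own operands; such + // values are used as replacements directly, without materializing a constant.)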
+ OpFoldResult::Value(value) => { + out.push(value.clone()); + continue; + } + // Check to see if there is a canonicalized version of this constant. + OpFoldResult::Attribute(attr_repl) => { + if let Some(mut const_op) = self.try_get_or_create_constant( + insert_region.clone(), + dialect.clone(), + attr_repl.clone_value(), + op_result.borrow().ty().clone(), + self.erased_folded_location, + ) { + // Ensure that this constant dominates the operation we are replacing. + // + // This may not automatically happen if the operation being folded was + // inserted before the constant within the insertion block. + let op_block = borrowed_op.parent().unwrap(); + if op_block == const_op.borrow().parent().unwrap() + && op_block.borrow().front().unwrap() != const_op + { + const_op.borrow_mut().move_before(crate::ProgramPoint::Block(op_block)); + } + out.push(const_op.borrow().get_result(0).borrow().as_value_ref()); + continue; + } + + // If materialization fails, clean up any operations generated for the previous + // results and return failure. + let inserted_before = self.rewriter.insertion_point().unwrap().op(); + if let Some(inserted_before) = inserted_before { + while let Some(inserted_op) = inserted_before.prev() { + self.notify_removal(inserted_op.clone()); + self.rewriter.erase_op(inserted_op); + } + } + + return FoldResult::Failed; + } + } + } + + FoldResult::Ok(out) + } + + /// Notifies that the given constant `op` should be removed from this folder's internal + /// bookkeeping. + /// + /// NOTE: This method must be called if a constant op is to be deleted externally to this + /// folder. `op` must be constant. + pub fn notify_removal(&mut self, op: OperationRef) { + // Check to see if this operation is uniqued within the folder. + let Some(referenced_dialects) = self.referenced_dialects.get_mut(&op) else { + return; + }; + + let borrowed_op = op.borrow(); + + // Get the constant value for this operation, this is the value that was used to unique + // the operation internally. + let value = crate::matchers::constant().matches(&borrowed_op).unwrap(); + + // Get the constant map that this operation was uniqued in. + let insert_region = get_insertion_region(borrowed_op.parent().unwrap()); + let uniqued_constants = self.scopes.get_mut(&insert_region).unwrap(); + + // Erase all of the references to this operation. + let ty = borrowed_op.results()[0].borrow().ty().clone(); + for dialect in referenced_dialects.drain(..) { + let uniqued_constant = UniquedConstant { + dialect, + value: value.clone_value(), + ty: ty.clone(), + }; + uniqued_constants.remove(&uniqued_constant); + } + } + + /// Clear out any constants cached inside the folder. + pub fn clear(&mut self) { + self.scopes.clear(); + self.referenced_dialects.clear(); + } + + /// Tries to fold a pre-existing constant operation. + /// + /// `value` represents the value of the constant, and can be optionally passed if the value is + /// already known (e.g. if the constant was discovered by a pattern match). This is purely an + /// optimization opportunity for callers that already know the value of the constant. + /// + /// Returns `false` if an existing constant for `op` already exists in the folder, in which case + /// `op` is replaced and erased. Otherwise, returns `true` and `op` is inserted into the folder + /// and hoisted if necessary.
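+ /// +/// A sketch of the intended call pattern (names hypothetical): +/// ```ignore +/// // `op` was just identified as a ConstantLike operation during a rewrite +/// let value = crate::matchers::constant().matches(&op.borrow()); +/// if !folder.insert_known_constant(op, value) { +/// // a duplicate already existed; `op` has been replaced and erased +/// } +/// ```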
+ pub fn insert_known_constant( + &mut self, + mut op: OperationRef, + value: Option>, + ) -> bool { + let block = op.borrow().parent().unwrap(); + + // If this is a constant we uniqued, we don't need to insert, but we can check to see if + // we should rehoist it. + if self.is_folder_owned_constant(&op) { + if block.borrow().front().unwrap() != op + && !self.is_folder_owned_constant(&op.prev().unwrap()) + { + let mut op = op.borrow_mut(); + op.move_before(crate::ProgramPoint::Block(block)); + op.set_span(self.erased_folded_location); + } + return true; + } + + // Get the constant value of the op if necessary. + let value = value.unwrap_or_else(|| { + crate::matchers::constant() + .matches(&op.borrow()) + .expect("expected `op` to be a constant") + }); + + // Check for an existing constant operation for the attribute value. + let insert_region = get_insertion_region(block.clone()); + let uniqued_constants = self.scopes.entry(insert_region.clone()).or_default(); + let uniqued_constant = UniquedConstant::new(&op, value); + let mut is_new = false; + let mut folder_const_op = uniqued_constants + .entry(uniqued_constant) + .or_insert_with(|| { + is_new = true; + op.clone() + }) + .clone(); + + // If there is an existing constant, replace `op` + if !is_new { + self.notify_removal(op.clone()); + self.rewriter.replace_op(op, folder_const_op.clone()); + folder_const_op.borrow_mut().set_span(self.erased_folded_location); + return false; + } + + // Otherwise, we insert `op`. If `op` is in the insertion block and is either already at the + // front of the block, or the previous operation is already a constant we uniqued (i.e. one + // we inserted), then we don't need to do anything. Otherwise, we move the constant to the + // insertion block. + let insert_block = insert_region.borrow().entry_block_ref().unwrap(); + if block != insert_block + || (insert_block.borrow().front().unwrap() != op + && !self.is_folder_owned_constant(&op.prev().unwrap())) + { + let mut op = op.borrow_mut(); + op.move_before(crate::ProgramPoint::Block(insert_block)); + op.set_span(self.erased_folded_location); + } + + let referenced_dialects = self.referenced_dialects.entry(op.clone()).or_default(); + let dialect = op.borrow().dialect(); + let dialect_name = dialect.name(); + if !referenced_dialects.iter().any(|d| d.name() == dialect_name) { + referenced_dialects.push(dialect); + } + + true + } + + /// Get or create a constant for use in the specified block. + /// + /// The constant may be created in a parent block. On success, this returns the result of the + /// constant operation, or `None` otherwise. + pub fn get_or_create_constant( + &mut self, + block: BlockRef, + dialect: Rc, + value: Box, + ty: Type, + ) -> Option { + // Find an insertion point for the constant. + let insert_region = get_insertion_region(block.clone()); + let entry = insert_region.borrow().entry_block_ref().unwrap(); + self.rewriter.set_insertion_point_to_start(entry); + + // Get the constant map for the insertion region of this operation. + // Use erased location since the op is being built at the front of the block. + let const_op = self.try_get_or_create_constant( + insert_region, + dialect, + value, + ty, + self.erased_folded_location, + )?; + Some(const_op.borrow().results()[0].borrow().as_value_ref()) + } + + /// Try to get or create a new constant entry. 
+ /// + /// On success, this returns the constant operation, `None` otherwise + fn try_get_or_create_constant( + &mut self, + insert_region: RegionRef, + dialect: Rc, + value: Box, + ty: Type, + span: SourceSpan, + ) -> Option { + let uniqued_constants = self.scopes.entry(insert_region).or_default(); + let uniqued_constant = UniquedConstant { + dialect: dialect.clone(), + value, + ty, + }; + if let Some(mut const_op) = uniqued_constants.get(&uniqued_constant).cloned() { + { + let mut const_op = const_op.borrow_mut(); + if const_op.span() != span { + const_op.set_span(span); + } + } + return Some(const_op); + } + + // If one doesn't exist, try to materialize one. + let const_op = materialize_constant( + self.rewriter.as_mut(), + dialect.clone(), + uniqued_constant.value.clone_value(), + &uniqued_constant.ty, + span, + )?; + + // Check to see if the generated constant is in the expected dialect. + let new_dialect = const_op.borrow().dialect(); + if new_dialect.name() == dialect.name() { + self.referenced_dialects.entry(const_op.clone()).or_default().push(new_dialect); + return Some(const_op); + } + + // If it isn't, then we also need to make sure that the mapping for the new dialect is valid + let new_uniqued_constant = UniquedConstant { + dialect: new_dialect.clone(), + value: uniqued_constant.value.clone_value(), + ty: uniqued_constant.ty.clone(), + }; + let maybe_existing_op = uniqued_constants.get(&new_uniqued_constant).cloned(); + uniqued_constants.insert( + uniqued_constant, + maybe_existing_op.clone().unwrap_or_else(|| const_op.clone()), + ); + if let Some(mut existing_op) = maybe_existing_op { + self.notify_removal(const_op.clone()); + self.rewriter.erase_op(const_op); + self.referenced_dialects + .get_mut(&existing_op) + .unwrap() + .push(new_uniqued_constant.dialect.clone()); + let mut existing = existing_op.borrow_mut(); + if existing.span() != span { + existing.set_span(span); + } + Some(existing_op) + } else { + self.referenced_dialects + .insert(const_op.clone(), smallvec![dialect, new_dialect]); + uniqued_constants.insert(new_uniqued_constant, const_op.clone()); + Some(const_op) + } + } + + /// Returns true if the given operation is an already folded constant that is owned by this + /// folder. + #[inline(always)] + fn is_folder_owned_constant(&self, op: &OperationRef) -> bool { + self.referenced_dialects.contains_key(op) + } +} + +/// Materialize a constant for a given attribute and type. +/// +/// Returns a constant operation if successful, otherwise `None` +fn materialize_constant( + builder: &mut dyn Builder, + dialect: Rc, + value: Box, + ty: &Type, + span: SourceSpan, +) -> Option { + let ip = builder.insertion_point().cloned(); + + // Ask the dialect to materialize a constant operation for this value. + let const_op = dialect.materialize_constant(builder, value, ty, span)?; + assert_eq!(ip.as_ref(), builder.insertion_point()); + assert!(const_op.borrow().implements::()); + Some(const_op) +} + +/// Given the containing block of an operation, find the parent region that folded constants should +/// be inserted into. 
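// NOTE: An illustrative, self-contained sketch (not part of the patch itself)
// of the parent walk that `get_insertion_region` performs below: climb the
// parent chain until reaching a node that is isolated from above, or that has
// no parent at all. `Node` collapses the real block/region/operation hierarchy
// into a single toy type.
struct Node {
    isolated_from_above: bool,
    parent: Option<Box<Node>>,
}

fn insertion_scope(mut node: &Node) -> &Node {
    while let Some(parent) = node.parent.as_deref() {
        if node.isolated_from_above {
            break;
        }
        node = parent;
    }
    node
}

fn main() {
    let root = Node { isolated_from_above: true, parent: None };
    let mid = Node { isolated_from_above: false, parent: Some(Box::new(root)) };
    let leaf = Node { isolated_from_above: false, parent: Some(Box::new(mid)) };
    // The walk stops at the isolated ancestor; folded constants belong there.
    assert!(insertion_scope(&leaf).isolated_from_above);
}
// --- end sketch ---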
+fn get_insertion_region(insertion_block: BlockRef) -> RegionRef {
+    use crate::EntityWithId;
+
+    let mut insertion_block = Some(insertion_block);
+    while let Some(block) = insertion_block.take() {
+        let parent_region = block.borrow().parent().unwrap_or_else(|| {
+            panic!("expected block {} to be attached to a region", block.borrow().id())
+        });
+        // Insert in this region for any of the following scenarios:
+        //
+        // * The parent is known to be isolated from above
+        // * The parent is a top-level operation
+        let parent_op = parent_region
+            .borrow()
+            .parent()
+            .expect("expected region to be attached to an operation");
+        let parent = parent_op.borrow();
+        let parent_block = parent.parent();
+        if parent.implements::<dyn IsolatedFromAbove>() || parent_block.is_none() {
+            return parent_region;
+        }
+
+        insertion_block = parent_block;
+    }
+
+    unreachable!("expected valid insertion region")
+}
diff --git a/hir2/src/formatter.rs b/hir2/src/formatter.rs
new file mode 100644
index 000000000..e2025a620
--- /dev/null
+++ b/hir2/src/formatter.rs
@@ -0,0 +1,59 @@
+use core::{cell::Cell, fmt};
+
+pub use miden_core::{
+    prettier::*,
+    utils::{DisplayHex, ToHex},
+};
+
+pub struct DisplayIndent(pub usize);
+impl fmt::Display for DisplayIndent {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        const INDENT: &str = " ";
+        for _ in 0..self.0 {
+            f.write_str(INDENT)?;
+        }
+        Ok(())
+    }
+}
+
+/// Render an iterator of `T`, comma-separated
+pub struct DisplayValues<I>(Cell<Option<I>>);
+impl<I> DisplayValues<I> {
+    pub fn new(inner: I) -> Self {
+        Self(Cell::new(Some(inner)))
+    }
+}
+impl<T, I> fmt::Display for DisplayValues<I>
+where
+    T: fmt::Display,
+    I: Iterator<Item = T>,
+{
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let iter = self.0.take().unwrap();
+        for (i, item) in iter.enumerate() {
+            if i == 0 {
+                write!(f, "{}", item)?;
+            } else {
+                write!(f, ", {}", item)?;
+            }
+        }
+        Ok(())
+    }
+}
+
+/// Render an `Option` using the `Display` impl for `T`
+pub struct DisplayOptional<'a, T>(pub Option<&'a T>);
+impl<'a, T: fmt::Display> fmt::Display for DisplayOptional<'a, T> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self.0 {
+            None => f.write_str("None"),
+            Some(item) => write!(f, "Some({item})"),
+        }
+    }
+}
+impl<'a, T: fmt::Display> fmt::Debug for DisplayOptional<'a, T> {
+    #[inline]
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        fmt::Display::fmt(self, f)
+    }
+}
diff --git a/hir2/src/hash.rs b/hir2/src/hash.rs
new file mode 100644
index 000000000..d2c0f1959
--- /dev/null
+++ b/hir2/src/hash.rs
@@ -0,0 +1,97 @@
+use core::hash::{Hash, Hasher};
+
+/// A type-erased version of [core::hash::Hash]
+pub trait DynHash {
+    fn dyn_hash(&self, hasher: &mut dyn Hasher);
+}
+
+impl<H: Hash + ?Sized> DynHash for H {
+    #[inline]
+    fn dyn_hash(&self, hasher: &mut dyn Hasher) {
+        let mut hasher = DynHasher(hasher);
+        <Self as Hash>::hash(self, &mut hasher)
+    }
+}
+
+pub struct DynHasher<'a>(&'a mut dyn Hasher);
+
+impl<'a> DynHasher<'a> {
+    pub fn new<H>(hasher: &'a mut H) -> Self
+    where
+        H: Hasher,
+    {
+        Self(hasher)
+    }
+}
+
+impl<'a> Hasher for DynHasher<'a> {
+    #[inline]
+    fn finish(&self) -> u64 {
+        self.0.finish()
+    }
+
+    #[inline]
+    fn write(&mut self, bytes: &[u8]) {
+        self.0.write(bytes)
+    }
+
+    #[inline]
+    fn write_u8(&mut self, i: u8) {
+        self.0.write_u8(i);
+    }
+
+    #[inline]
+    fn write_i8(&mut self, i: i8) {
+        self.0.write_i8(i);
+    }
+
+    #[inline]
+    fn write_u16(&mut self, i: u16) {
+        self.0.write_u16(i);
+    }
+
+    #[inline]
+    fn write_i16(&mut self, i: i16) {
+        self.0.write_i16(i);
+    }
+
+    #[inline]
+    fn write_u32(&mut self, i: u32) {
self.0.write_u32(i); + } + + #[inline] + fn write_i32(&mut self, i: i32) { + self.0.write_i32(i); + } + + #[inline] + fn write_u64(&mut self, i: u64) { + self.0.write_u64(i); + } + + #[inline] + fn write_i64(&mut self, i: i64) { + self.0.write_i64(i); + } + + #[inline] + fn write_u128(&mut self, i: u128) { + self.0.write_u128(i); + } + + #[inline] + fn write_i128(&mut self, i: i128) { + self.0.write_i128(i); + } + + #[inline] + fn write_usize(&mut self, i: usize) { + self.0.write_usize(i); + } + + #[inline] + fn write_isize(&mut self, i: isize) { + self.0.write_isize(i); + } +} diff --git a/hir2/src/ir.rs b/hir2/src/ir.rs new file mode 100644 index 000000000..fd5981e20 --- /dev/null +++ b/hir2/src/ir.rs @@ -0,0 +1,89 @@ +mod block; +mod builder; +mod callable; +pub mod cfg; +mod component; +mod context; +mod dialect; +pub mod dominance; +mod entity; +mod ident; +mod immediates; +mod insert; +mod interface; +pub mod loops; +mod op; +mod operands; +mod operation; +mod print; +mod region; +mod successor; +pub(crate) mod symbol_table; +pub mod traits; +mod types; +mod usable; +mod value; +pub mod verifier; +mod visit; + +pub use midenc_hir_symbol as interner; +pub use midenc_session::diagnostics::{Report, SourceSpan, Span, Spanned}; + +pub use self::{ + block::{ + Block, BlockCursor, BlockCursorMut, BlockId, BlockList, BlockOperand, BlockOperandRef, + BlockRef, PostOrderBlockIter, PreOrderBlockIter, + }, + builder::{Builder, BuilderExt, InsertionGuard, Listener, ListenerType, OpBuilder}, + callable::*, + context::Context, + dialect::{Dialect, DialectName, DialectRegistration}, + entity::{ + Entity, EntityCursor, EntityCursorMut, EntityGroup, EntityId, EntityIter, EntityList, + EntityMut, EntityRange, EntityRangeMut, EntityRef, EntityStorage, EntityWithId, + EntityWithParent, MaybeDefaultEntityIter, RawEntityRef, StorableEntity, UnsafeEntityRef, + UnsafeIntrusiveEntityRef, + }, + ident::{FunctionIdent, Ident}, + immediates::{Felt, FieldElement, Immediate, StarkField}, + insert::{Insert, InsertionPoint, ProgramPoint}, + op::{BuildableOp, Op, OpExt, OpRegistration}, + operands::{ + OpOperand, OpOperandImpl, OpOperandList, OpOperandRange, OpOperandRangeMut, + OpOperandStorage, + }, + operation::{ + OpCursor, OpCursorMut, OpList, Operation, OperationBuilder, OperationName, OperationRef, + }, + print::{OpPrinter, OpPrintingFlags}, + region::{ + InvocationBounds, Region, RegionBranchOpInterface, RegionBranchPoint, + RegionBranchTerminatorOpInterface, RegionCursor, RegionCursorMut, RegionKind, + RegionKindInterface, RegionList, RegionRef, RegionSuccessor, RegionSuccessorInfo, + RegionSuccessorIter, RegionSuccessorMut, RegionTransformFailed, + }, + successor::{ + KeyedSuccessor, KeyedSuccessorRange, KeyedSuccessorRangeMut, OpSuccessor, OpSuccessorMut, + OpSuccessorRange, OpSuccessorRangeMut, OpSuccessorStorage, SuccessorInfo, SuccessorOperand, + SuccessorOperandRange, SuccessorOperandRangeMut, SuccessorOperands, SuccessorWithKey, + SuccessorWithKeyMut, + }, + symbol_table::{ + AsSymbolRef, InvalidSymbolRefError, Symbol, SymbolName, SymbolNameAttr, + SymbolNameComponent, SymbolNameComponents, SymbolRef, SymbolTable, SymbolUse, + SymbolUseCursor, SymbolUseCursorMut, SymbolUseIter, SymbolUseList, SymbolUseRef, + SymbolUsesIter, + }, + traits::{FoldResult, OpFoldResult}, + types::*, + usable::Usable, + value::{ + BlockArgument, BlockArgumentRef, OpResult, OpResultRange, OpResultRangeMut, OpResultRef, + OpResultStorage, Value, ValueId, ValueRef, + }, + verifier::{OpVerifier, Verify}, + visit::{ + OpVisitor, 
OperationVisitor, Searcher, SymbolVisitor, Visitor, WalkOrder, WalkResult, + WalkStage, Walkable, + }, +}; diff --git a/hir2/src/ir/block.rs b/hir2/src/ir/block.rs new file mode 100644 index 000000000..4b3622928 --- /dev/null +++ b/hir2/src/ir/block.rs @@ -0,0 +1,1118 @@ +use core::fmt; + +use smallvec::SmallVec; + +use super::*; + +/// A pointer to a [Block] +pub type BlockRef = UnsafeIntrusiveEntityRef; +/// An intrusive, doubly-linked list of [Block] +pub type BlockList = EntityList; +/// A cursor into a [BlockList] +pub type BlockCursor<'a> = EntityCursor<'a, Block>; +/// A mutable cursor into a [BlockList] +pub type BlockCursorMut<'a> = EntityCursorMut<'a, Block>; +/// An iterator over blocks produced by a depth-first, pre-order visit of the CFG +pub type PreOrderBlockIter = cfg::PreOrderIter; +/// An iterator over blocks produced by a depth-first, post-order visit of the CFG +pub type PostOrderBlockIter = cfg::PostOrderIter; + +/// The unique identifier for a [Block] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct BlockId(u32); +impl BlockId { + pub const fn from_u32(id: u32) -> Self { + Self(id) + } + + pub const fn as_u32(&self) -> u32 { + self.0 + } +} +impl EntityId for BlockId { + #[inline(always)] + fn as_usize(&self) -> usize { + self.0 as usize + } +} +impl fmt::Debug for BlockId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "block{}", &self.0) + } +} +impl fmt::Display for BlockId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "block{}", &self.0) + } +} + +/// Represents a basic block in the IR. +/// +/// Basic blocks are used in SSA regions to provide the structure of the control-flow graph. +/// Operations within a basic block appear in the order they will be executed. +/// +/// A block must have a [traits::Terminator], an operation which transfers control to another block +/// in the same region, or out of the containing operation (e.g. returning from a function). +/// +/// Blocks have _predecessors_ and _successors_, representing the inbound and outbound edges +/// (respectively) formed by operations that transfer control between blocks. A block can have +/// zero or more predecessors and/or successors. A well-formed region will generally only have a +/// single block (the entry block) with no predecessors (i.e. no unreachable blocks), and no blocks +/// with both multiple predecessors _and_ multiple successors (i.e. no critical edges). It is valid +/// to have both unreachable blocks and critical edges in the IR, but they must be removed during +/// the course of compilation. +pub struct Block { + /// The unique id of this block + id: BlockId, + /// Flag that indicates whether the ops in this block have a valid ordering + valid_op_ordering: bool, + /// The set of uses of this block + uses: BlockOperandList, + /// The region this block is attached to. 
+ /// + /// This will always be set if this block is attached to a region + region: Option, + /// The list of [Operation]s that comprise this block + body: OpList, + /// The parameter list for this block + arguments: Vec, +} +impl fmt::Debug for Block { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Block") + .field("id", &self.id) + .field_with("region", |f| match self.region.as_ref() { + None => f.write_str("None"), + Some(r) => write!(f, "Some({r:p})"), + }) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for arg in self.arguments.iter() { + list.entry_with(|f| f.write_fmt(format_args!("{}", &arg.borrow()))); + } + list.finish() + }) + .finish_non_exhaustive() + } +} + +impl Entity for Block {} + +impl EntityWithId for Block { + type Id = BlockId; + + fn id(&self) -> Self::Id { + self.id + } +} + +impl EntityWithParent for Block { + type Parent = Region; + + fn on_inserted_into_parent( + mut this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().region = Some(parent); + } + + fn on_removed_from_parent( + mut this: UnsafeIntrusiveEntityRef, + _parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().region = None; + } + + fn on_transfered_to_new_parent( + _from: UnsafeIntrusiveEntityRef, + to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + for mut transferred_block in transferred { + transferred_block.borrow_mut().region = Some(to.clone()); + } + } +} + +impl Usable for Block { + type Use = BlockOperand; + + #[inline(always)] + fn uses(&self) -> &BlockOperandList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut BlockOperandList { + &mut self.uses + } +} + +impl cfg::Graph for Block { + type ChildEdgeIter = BlockSuccessorEdgesIter; + type ChildIter = BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn size(&self) -> usize { + if let Some(term) = self.terminator() { + term.borrow().num_successors() + } else { + 0 + } + } + + fn entry_node(&self) -> Self::Node { + self.as_block_ref() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + edge.borrow().block.clone() + } +} + +impl<'a> cfg::InvertibleGraph for &'a Block { + type Inverse = cfg::Inverse<&'a Block>; + type InvertibleChildEdgeIter = BlockPredecessorEdgesIter; + type InvertibleChildIter = BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + BlockPredecessorEdgesIter::new(parent) + } +} + +impl cfg::Graph for BlockRef { + type ChildEdgeIter = BlockSuccessorEdgesIter; + type ChildIter = BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn size(&self) -> usize { + if let Some(term) = self.borrow().terminator() { + term.borrow().num_successors() + } else { + 0 + } + } + + fn entry_node(&self) -> Self::Node { + self.clone() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + 
edge.borrow().block.clone() + } +} + +impl cfg::InvertibleGraph for BlockRef { + type Inverse = cfg::Inverse; + type InvertibleChildEdgeIter = BlockPredecessorEdgesIter; + type InvertibleChildIter = BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + BlockPredecessorEdgesIter::new(parent) + } +} + +#[doc(hidden)] +pub struct BlockSuccessorIter { + iter: BlockSuccessorEdgesIter, +} +impl BlockSuccessorIter { + pub fn new(parent: BlockRef) -> Self { + Self { + iter: BlockSuccessorEdgesIter::new(parent), + } + } +} +impl ExactSizeIterator for BlockSuccessorIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for BlockSuccessorIter { + type Item = BlockRef; + + fn next(&mut self) -> Option { + self.iter.next().map(|bo| bo.borrow().block.clone()) + } + + #[inline] + fn collect>(self) -> B + where + Self: Sized, + { + let Some(terminator) = self.iter.terminator.as_ref() else { + return B::from_iter([]); + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + B::from_iter( + successors.all().as_slice()[self.iter.index..self.iter.num_successors] + .iter() + .map(|succ| succ.block.borrow().block.clone()), + ) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + let Some(terminator) = self.iter.terminator.as_ref() else { + return collection; + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + collection.extend( + successors.all().as_slice()[self.iter.index..self.iter.num_successors] + .iter() + .map(|succ| succ.block.borrow().block.clone()), + ); + collection + } +} + +#[doc(hidden)] +pub struct BlockSuccessorEdgesIter { + terminator: Option, + num_successors: usize, + index: usize, +} +impl BlockSuccessorEdgesIter { + pub fn new(parent: BlockRef) -> Self { + let terminator = parent.borrow().terminator(); + let num_successors = terminator.as_ref().map(|t| t.borrow().num_successors()).unwrap_or(0); + Self { + terminator, + num_successors, + index: 0, + } + } +} +impl ExactSizeIterator for BlockSuccessorEdgesIter { + #[inline] + fn len(&self) -> usize { + self.num_successors.saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.num_successors + } +} +impl Iterator for BlockSuccessorEdgesIter { + type Item = BlockOperandRef; + + fn next(&mut self) -> Option { + if self.index >= self.num_successors { + return None; + } + + // SAFETY: We'll never have a none terminator if we have non-zero number of successors + let terminator = unsafe { self.terminator.as_ref().unwrap_unchecked() }; + let index = self.index; + self.index += 1; + Some(terminator.borrow().successor(index).dest.clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + let Some(terminator) = self.terminator.as_ref() else { + return B::from_iter([]); + }; + let terminator = terminator.borrow(); + let successors = terminator.successors(); + B::from_iter( + successors.all().as_slice()[self.index..self.num_successors] + .iter() + .map(|succ| succ.block.clone()), + ) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + let Some(terminator) = self.terminator.as_ref() else { + return collection; + }; + let 
terminator = terminator.borrow(); + let successors = terminator.successors(); + collection.extend( + successors.all().as_slice()[self.index..self.num_successors] + .iter() + .map(|succ| succ.block.clone()), + ); + collection + } +} + +#[doc(hidden)] +pub struct BlockPredecessorIter { + preds: SmallVec<[BlockRef; 4]>, + index: usize, +} +impl BlockPredecessorIter { + pub fn new(child: BlockRef) -> Self { + let preds = child.borrow().predecessors().map(|bo| bo.block.clone()).collect(); + Self { preds, index: 0 } + } + + #[inline(always)] + pub fn into_inner(self) -> SmallVec<[BlockRef; 4]> { + self.preds + } +} +impl ExactSizeIterator for BlockPredecessorIter { + #[inline] + fn len(&self) -> usize { + self.preds.len().saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.preds.len() + } +} +impl Iterator for BlockPredecessorIter { + type Item = BlockRef; + + #[inline] + fn next(&mut self) -> Option { + if self.is_empty() { + return None; + } + let index = self.index; + self.index += 1; + Some(self.preds[index].clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + B::from_iter(self.preds) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + collection.extend(self.preds); + collection + } +} + +#[doc(hidden)] +pub struct BlockPredecessorEdgesIter { + preds: SmallVec<[BlockOperandRef; 4]>, + index: usize, +} +impl BlockPredecessorEdgesIter { + pub fn new(child: BlockRef) -> Self { + let preds = child + .borrow() + .predecessors() + .map(|bo| unsafe { BlockOperandRef::from_raw(&*bo) }) + .collect(); + Self { preds, index: 0 } + } +} +impl ExactSizeIterator for BlockPredecessorEdgesIter { + #[inline] + fn len(&self) -> usize { + self.preds.len().saturating_sub(self.index) + } + + #[inline] + fn is_empty(&self) -> bool { + self.index >= self.preds.len() + } +} +impl Iterator for BlockPredecessorEdgesIter { + type Item = BlockOperandRef; + + #[inline] + fn next(&mut self) -> Option { + if self.is_empty() { + return None; + } + let index = self.index; + self.index += 1; + Some(self.preds[index].clone()) + } + + fn collect>(self) -> B + where + Self: Sized, + { + B::from_iter(self.preds) + } + + fn collect_into>(self, collection: &mut E) -> &mut E + where + Self: Sized, + { + collection.extend(self.preds); + collection + } +} + +impl Block { + pub fn new(id: BlockId) -> Self { + Self { + id, + valid_op_ordering: true, + uses: Default::default(), + region: None, + body: Default::default(), + arguments: Default::default(), + } + } + + #[inline] + pub fn as_block_ref(&self) -> BlockRef { + unsafe { BlockRef::from_raw(self) } + } + + /// Get a handle to the containing [Region] of this block, if it is attached to one + pub fn parent(&self) -> Option { + self.region.clone() + } + + /// Get a handle to the containing [Operation] of this block, if it is attached to one + pub fn parent_op(&self) -> Option { + self.region.as_ref().and_then(|region| region.borrow().parent()) + } + + /// Get a handle to the ancestor [Block] of this block, if one is present + pub fn parent_block(&self) -> Option { + self.parent_op().and_then(|op| op.borrow().parent()) + } + + /// Returns true if this block is the entry block for its containing region + pub fn is_entry_block(&self) -> bool { + if let Some(parent) = self.region.as_ref().map(|r| r.borrow()) { + core::ptr::addr_eq(&*parent.entry(), self) + } else { + false + } + } + + /// Get the first operation in the body of this block + #[inline] + pub fn front(&self) -> Option { + 
self.body.front().as_pointer() + } + + /// Get the last operation in the body of this block + #[inline] + pub fn back(&self) -> Option { + self.body.back().as_pointer() + } + + /// Get the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body(&self) -> &OpList { + &self.body + } + + /// Get a mutable reference to the list of [Operation] comprising the body of this block + #[inline(always)] + pub fn body_mut(&mut self) -> &mut OpList { + &mut self.body + } +} + +/// Arguments +impl Block { + #[inline] + pub fn has_arguments(&self) -> bool { + !self.arguments.is_empty() + } + + #[inline] + pub fn num_arguments(&self) -> usize { + self.arguments.len() + } + + #[inline(always)] + pub fn arguments(&self) -> &[BlockArgumentRef] { + self.arguments.as_slice() + } + + #[inline(always)] + pub fn arguments_mut(&mut self) -> &mut Vec { + &mut self.arguments + } + + #[inline] + pub fn get_argument(&self, index: usize) -> BlockArgumentRef { + self.arguments[index].clone() + } + + /// Erase the block argument at `index` + /// + /// Panics if the argument still has uses. + pub fn erase_argument(&mut self, index: usize) { + assert!( + !self.arguments[index].borrow().is_used(), + "cannot erase block arguments with uses" + ); + self.arguments.remove(index); + } + + /// Erase every parameter of this block for which `should_erase` returns true. + /// + /// Panics if any argument to be erased still has uses. + pub fn erase_arguments(&mut self, should_erase: F) + where + F: Fn(&BlockArgument) -> bool, + { + self.arguments.retain(|arg| { + let arg = arg.borrow(); + let keep = !should_erase(&arg); + assert!(keep || !arg.is_used(), "cannot erase block arguments with uses"); + keep + }); + } +} + +/// Placement +impl Block { + /// Insert this block after `after` in its containing region. + /// + /// Panics if this block is already attached to a region, or if `after` is not attached. + pub fn insert_after(&mut self, after: BlockRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + let mut region = + after.borrow().parent().expect("'after' block is not attached to a region"); + { + let mut region = region.borrow_mut(); + let region_body = region.body_mut(); + let mut cursor = unsafe { region_body.cursor_mut_from_ptr(after) }; + cursor.insert_after(self.as_block_ref()); + } + self.region = Some(region); + } + + /// Insert this block before `before` in its containing region. + /// + /// Panics if this block is already attached to a region, or if `before` is not attached. + pub fn insert_before(&mut self, before: BlockRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + let mut region = + before.borrow().parent().expect("'before' block is not attached to a region"); + { + let mut region = region.borrow_mut(); + let region_body = region.body_mut(); + let mut cursor = unsafe { region_body.cursor_mut_from_ptr(before) }; + cursor.insert_before(self.as_block_ref()); + } + self.region = Some(region); + } + + /// Insert this block at the end of `region`. + /// + /// Panics if this block is already attached to a region. 
+ pub fn insert_at_end(&mut self, mut region: RegionRef) { + assert!( + self.region.is_none(), + "cannot insert block that is already attached to another region" + ); + { + let mut region = region.borrow_mut(); + region.body_mut().push_back(self.as_block_ref()); + } + self.region = Some(region); + } + + /// Unlink this block from its current region and insert it right before `before` + pub fn move_before(&mut self, before: BlockRef) { + self.unlink(); + self.insert_before(before); + } + + /// Remove this block from its containing region + fn unlink(&mut self) { + if let Some(mut region) = self.region.take() { + let mut region = region.borrow_mut(); + unsafe { + let mut cursor = region.body_mut().cursor_mut_from_ptr(self.as_block_ref()); + cursor.remove(); + } + } + } + + /// Split this block into two blocks before the specified operation + /// + /// Note that all operations in the block prior to `before` stay as part of the original block, + /// and the rest are moved to the new block, including the old terminator. The original block is + /// thus left without a terminator. + /// + /// Returns the newly created block. + pub fn split_block(&mut self, before: OperationRef) -> BlockRef { + let this = self.as_block_ref(); + assert!( + BlockRef::ptr_eq( + &this, + &before.borrow().parent().expect("'before' op is not attached to a block") + ), + "cannot split block using an operation that does not belong to the block being split" + ); + + // We need the parent op so we can get access to the current Context, but this also tells us + // that this block is attached to a region and operation. + let parent = self.parent_op().expect("block is not attached to an operation"); + // Create a new empty block + let mut new_block = parent.borrow().context().create_block(); + // Insert the block in the same region as `self`, immediately after `self` + let region = self.region.as_mut().unwrap(); + { + let mut region_mut = region.borrow_mut(); + let blocks = region_mut.body_mut(); + let mut cursor = unsafe { blocks.cursor_mut_from_ptr(this.clone()) }; + cursor.insert_after(new_block.clone()); + } + // Split the body of `self` at `before`, and splice everything after `before`, including + // `before` itself, into the new block we created. + let mut ops = { + let mut cursor = unsafe { self.body.cursor_mut_from_ptr(before) }; + cursor.split_before() + }; + // The split_before method returns the list containing all of the ops before the cursor, but + // we want the inverse, so we just swap the two lists. + core::mem::swap(&mut self.body, &mut ops); + // Visit all of the ops and notify them of the move + for op in ops.iter() { + Operation::on_inserted_into_parent(op.as_operation_ref(), new_block.clone()); + } + new_block.borrow_mut().body = ops; + new_block + } + + pub fn clear(&mut self) { + // Drop all references from within this block + self.drop_all_references(); + + // Drop all operations within this block + self.body_mut().clear(); + } +} + +/// Ancestors +impl Block { + pub fn is_legal_to_hoist_into(&self) -> bool { + use crate::traits::{HasSideEffects, ReturnLike}; + + // No terminator means the block is under construction, and thus legal to hoist into + let Some(terminator) = self.terminator() else { + return true; + }; + + // If the block has no successors, it can never be legal to hoist into it, there is nothing + // to hoist! + if self.num_successors() == 0 { + return false; + } + + // Instructions should not be hoisted across effectful or return-like terminators. 
This is
+        // typically only relevant for exception-handling intrinsics, which HIR doesn't really
+        // have, but which we may nevertheless want to represent in the future.
+        //
+        // NOTE: Most return-like terminators would have no successors, but in LLVM, for example,
+        // there are instructions like `catch_ret`, which semantically are return-like, but which
+        // have a successor block (the landing pad).
+        let terminator = terminator.borrow();
+        terminator.implements::<dyn HasSideEffects>() || terminator.implements::<dyn ReturnLike>()
+    }
+
+    pub fn has_ssa_dominance(&self) -> bool {
+        self.parent_op()
+            .and_then(|op| {
+                op.borrow()
+                    .as_trait::<dyn RegionKindInterface>()
+                    .map(|rki| rki.has_ssa_dominance())
+            })
+            .unwrap_or(true)
+    }
+
+    /// Walk up the ancestor blocks of `block`, until `f` returns `true` for a block.
+    ///
+    /// NOTE: `block` is visited before any of its ancestors.
+    pub fn traverse_ancestors<F>(block: BlockRef, mut f: F) -> Option<BlockRef>
+    where
+        F: FnMut(BlockRef) -> bool,
+    {
+        let mut block = Some(block);
+        while let Some(current) = block.take() {
+            if f(current.clone()) {
+                return Some(current);
+            }
+            block = current.borrow().parent_block();
+        }
+
+        None
+    }
+
+    /// Try to get a pair of blocks, starting with the given pair, which live in the same region,
+    /// by exploring the relationships of both blocks with respect to their regions.
+    ///
+    /// The returned block pair will either be the same input blocks, or some combination of those
+    /// blocks or their ancestors.
+    pub fn get_blocks_in_same_region(a: &BlockRef, b: &BlockRef) -> Option<(BlockRef, BlockRef)> {
+        // If both blocks do not live in the same region, we will have to check their parent
+        // operations.
+        let a_region = a.borrow().parent().unwrap();
+        let b_region = b.borrow().parent().unwrap();
+        if a_region == b_region {
+            return Some((a.clone(), b.clone()));
+        }
+
+        // Iterate over all ancestors of `a`, counting the depth of `a`.
+        //
+        // If one of `a`'s ancestors is in the same region as `b`, then we stop early, because we
+        // found our nearest common ancestor.
+        let mut a_depth = 0;
+        let result = Self::traverse_ancestors(a.clone(), |block| {
+            a_depth += 1;
+            block.borrow().parent().is_some_and(|r| r == b_region)
+        });
+        if let Some(a) = result {
+            return Some((a, b.clone()));
+        }
+
+        // Iterate over all ancestors of `b`, counting the depth of `b`.
+        //
+        // If one of `b`'s ancestors is in the same region as `a`, then we stop early, because we
+        // found our nearest common ancestor.
+        let mut b_depth = 0;
+        let result = Self::traverse_ancestors(b.clone(), |block| {
+            b_depth += 1;
+            block.borrow().parent().is_some_and(|r| r == a_region)
+        });
+        if let Some(b) = result {
+            return Some((a.clone(), b));
+        }
+
+        // Otherwise, we found two blocks that are siblings at some level. Walk the deepest one
+        // up until we reach the top or find a nearest common ancestor.
+        let mut a = Some(a.clone());
+        let mut b = Some(b.clone());
+        loop {
+            use core::cmp::Ordering;
+
+            match a_depth.cmp(&b_depth) {
+                Ordering::Greater => {
+                    a = a.and_then(|a| a.borrow().parent_block());
+                    a_depth -= 1;
+                }
+                Ordering::Less => {
+                    b = b.and_then(|b| b.borrow().parent_block());
+                    b_depth -= 1;
+                }
+                Ordering::Equal => break,
+            }
+        }
+
+        // If we found something with the same level, then we can march both up at the same time
+        // from here on out.
+        while let Some(next_a) = a.take() {
+            // If they are at the same level, and have the same parent region, then we succeeded.
+ let next_a_parent = next_a.borrow().parent(); + let b_parent = b.as_ref().and_then(|b| b.borrow().parent()); + if next_a_parent == b_parent { + return Some((next_a, b.unwrap())); + } + + a = next_a_parent + .and_then(|r| r.borrow().parent()) + .and_then(|op| op.borrow().parent()); + b = b_parent.and_then(|r| r.borrow().parent()).and_then(|op| op.borrow().parent()); + } + + // They don't share a nearest common ancestor, perhaps they are in different modules or + // something. + None + } +} + +/// Predecessors and Successors +impl Block { + /// Returns true if this block has predecessors + #[inline(always)] + pub fn has_predecessors(&self) -> bool { + self.is_used() + } + + /// Get an iterator over the predecessors of this block + #[inline(always)] + pub fn predecessors(&self) -> BlockOperandIter<'_> { + self.iter_uses() + } + + /// If this block has exactly one predecessor, return it, otherwise `None` + /// + /// NOTE: A predecessor block with multiple edges, e.g. a conditional branch that has this block + /// as the destination for both true/false branches is _not_ considered a single predecessor by + /// this function. + pub fn get_single_predecessor(&self) -> Option { + let front = self.uses.front(); + if front.is_null() { + return None; + } + let front = front.as_pointer().unwrap(); + let back = self.uses.back().as_pointer().unwrap(); + if BlockOperandRef::ptr_eq(&front, &back) { + Some(front.borrow().block.clone()) + } else { + None + } + } + + /// If this block has a unique predecessor, i.e. all incoming edges originate from one block, + /// return it, otherwise `None` + pub fn get_unique_predecessor(&self) -> Option { + let mut front = self.uses.front(); + let block_operand = front.get()?; + let block = block_operand.block.clone(); + loop { + front.move_next(); + if let Some(bo) = front.get() { + if !BlockRef::ptr_eq(&block, &bo.block) { + break None; + } + } else { + break Some(block); + } + } + } + + /// Returns true if this block has any successors + #[inline] + pub fn has_successors(&self) -> bool { + self.num_successors() > 0 + } + + /// Get the number of successors of this block in the CFG + pub fn num_successors(&self) -> usize { + self.terminator().map(|op| op.borrow().num_successors()).unwrap_or(0) + } + + /// Get the `index`th successor of this block's terminator operation + pub fn get_successor(&self, index: usize) -> BlockRef { + let op = self.terminator().expect("this block has no terminator"); + op.borrow().successor(index).dest.borrow().block.clone() + } + + /// This drops all operand uses from operations within this block, which is an essential step in + /// breaking cyclic dependences between references when they are to be deleted. + pub fn drop_all_references(&mut self) { + let mut cursor = self.body.front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_references(); + cursor.move_next(); + } + } + + /// This drops all uses of values defined in this block or in the blocks of nested regions + /// wherever the uses are located. 
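// NOTE: An illustrative, self-contained sketch (not part of the patch itself)
// of the distinction between `get_single_predecessor` and
// `get_unique_predecessor` above: a conditional branch that targets the same
// block on both arms produces two predecessor *edges*, but only one unique
// predecessor *block*. Predecessor edges are modeled as a slice of block ids.
fn single_predecessor(pred_edges: &[u32]) -> Option<u32> {
    match pred_edges {
        [only] => Some(*only),
        _ => None,
    }
}

fn unique_predecessor(pred_edges: &[u32]) -> Option<u32> {
    let (first, rest) = pred_edges.split_first()?;
    rest.iter().all(|p| p == first).then_some(*first)
}

fn main() {
    let edges = [0u32, 0]; // both arms of a conditional branch in block 0
    assert_eq!(single_predecessor(&edges), None);
    assert_eq!(unique_predecessor(&edges), Some(0));
}
// --- end sketch ---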
+ pub fn drop_all_defined_value_uses(&mut self) { + for arg in self.arguments.iter_mut() { + let mut arg = arg.borrow_mut(); + arg.uses_mut().clear(); + } + let mut cursor = self.body.front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_defined_value_uses(); + cursor.move_next(); + } + self.drop_all_uses(); + } + + /// Drop all uses of this block via [BlockOperand] + #[inline] + pub fn drop_all_uses(&mut self) { + self.uses_mut().clear(); + } + + #[inline(always)] + pub(super) const fn is_op_order_valid(&self) -> bool { + self.valid_op_ordering + } + + #[inline(always)] + pub(super) fn mark_op_order_valid(&mut self) { + self.valid_op_ordering = true; + } + + pub(super) fn invalidate_op_order(&mut self) { + // Validate the current ordering + assert!(self.verify_op_order()); + self.valid_op_ordering = false; + } + + /// Returns true if the current operation ordering in this block is valid + pub(super) fn verify_op_order(&self) -> bool { + // The order is already known to be invalid + if !self.valid_op_ordering { + return false; + } + + // The order is valid if there are less than 2 operations + if self.body.is_empty() + || OperationRef::ptr_eq( + &self.body.front().as_pointer().unwrap(), + &self.body.back().as_pointer().unwrap(), + ) + { + return true; + } + + let mut cursor = self.body.front(); + let mut prev = None; + while let Some(op) = cursor.as_pointer() { + cursor.move_next(); + if prev.is_none() { + prev = Some(op); + continue; + } + + // The previous operation must have a smaller order index than the next + let prev_order = prev.take().unwrap().borrow().order(); + let current_order = op.borrow().order().unwrap_or(u32::MAX); + if prev_order.is_some_and(|o| o >= current_order) { + return false; + } + prev = Some(op); + } + + true + } + + /// Get the terminator operation of this block, or `None` if the block does not have one. + pub fn terminator(&self) -> Option { + if !self.has_terminator() { + None + } else { + self.body.back().as_pointer() + } + } + + /// Returns true if this block has a terminator + pub fn has_terminator(&self) -> bool { + use crate::traits::Terminator; + !self.body.is_empty() + && self.body.back().get().is_some_and(|op| op.implements::()) + } +} + +pub type BlockOperandRef = UnsafeIntrusiveEntityRef; +/// An intrusive, doubly-linked list of [BlockOperand] +pub type BlockOperandList = EntityList; +#[allow(unused)] +pub type BlockOperandCursor<'a> = EntityCursor<'a, BlockOperand>; +#[allow(unused)] +pub type BlockOperandCursorMut<'a> = EntityCursorMut<'a, BlockOperand>; +pub type BlockOperandIter<'a> = EntityIter<'a, BlockOperand>; + +/// A [BlockOperand] represents a use of a [Block] by an [Operation] +pub struct BlockOperand { + /// The block value + pub block: BlockRef, + /// The owner of this operand, i.e. 
the operation it is an operand of + pub owner: OperationRef, + /// The index of this operand in the set of block operands of the operation + pub index: u8, +} +impl BlockOperand { + #[inline] + pub fn new(block: BlockRef, owner: OperationRef, index: u8) -> Self { + Self { + block, + owner, + index, + } + } + + pub fn block_id(&self) -> BlockId { + self.block.borrow().id + } +} +impl fmt::Debug for BlockOperand { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BlockOperand") + .field("block", &self.block.borrow().id()) + .field_with("owner", |f| write!(f, "{:p}", &self.owner)) + .field("index", &self.index) + .finish() + } +} +impl StorableEntity for BlockOperand { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many successors"); + } + + /// Remove this use of `block` + fn unlink(&mut self) { + let owner = unsafe { BlockOperandRef::from_raw(self) }; + let mut block = self.block.borrow_mut(); + let uses = block.uses_mut(); + unsafe { + let mut cursor = uses.cursor_mut_from_ptr(owner); + cursor.remove(); + } + } +} diff --git a/hir2/src/ir/builder.rs b/hir2/src/ir/builder.rs new file mode 100644 index 000000000..e8d1404c6 --- /dev/null +++ b/hir2/src/ir/builder.rs @@ -0,0 +1,441 @@ +use alloc::rc::Rc; + +use crate::{ + BlockArgument, BlockRef, BuildableOp, Context, InsertionPoint, OperationRef, ProgramPoint, + RegionRef, SourceSpan, Type, Value, +}; + +/// The [Builder] trait encompasses all of the functionality needed to construct and insert blocks +/// and operations into the IR. +pub trait Builder: Listener { + fn context(&self) -> &Context; + fn context_rc(&self) -> Rc; + /// Returns the current insertion point of the builder + fn insertion_point(&self) -> Option<&InsertionPoint>; + /// Clears the current insertion point + fn clear_insertion_point(&mut self) -> Option; + /// Restores the current insertion point to `ip` + fn restore_insertion_point(&mut self, ip: Option); + /// Sets the current insertion point to `ip` + fn set_insertion_point(&mut self, ip: InsertionPoint); + + /// Sets the insertion point to the specified program point, causing subsequent insertions to + /// be placed before it. + #[inline] + fn set_insertion_point_before(&mut self, pp: ProgramPoint) { + self.set_insertion_point(InsertionPoint::before(pp)); + } + + /// Sets the insertion point to the specified program point, causing subsequent insertions to + /// be placed after it. + #[inline] + fn set_insertion_point_after(&mut self, pp: ProgramPoint) { + self.set_insertion_point(InsertionPoint::after(pp)); + } + + /// Sets the insertion point to the node after the specified value is defined. + /// + /// If value has a defining operation, this sets the insertion point after that operation, so + /// that all insertions are placed following the definition. + /// + /// Otherwise, a value must be a block argument, so the insertion point is placed at the start + /// of the block, causing insertions to be placed starting at the front of the block. + fn set_insertion_point_after_value(&mut self, value: &dyn Value) { + let pp = if let Some(op) = value.get_defining_op() { + ProgramPoint::Op(op) + } else { + let block_argument = value.downcast_ref::().unwrap(); + ProgramPoint::Block(block_argument.owner()) + }; + self.set_insertion_point_after(pp); + } + + /// Sets the current insertion point to the start of `block`. 
+ /// + /// Operations inserted will be placed starting at the beginning of the block. + #[inline] + fn set_insertion_point_to_start(&mut self, block: BlockRef) { + self.set_insertion_point_before(block.into()); + } + + /// Sets the current insertion point to the end of `block`. + /// + /// Operations inserted will be placed starting at the end of the block. + #[inline] + fn set_insertion_point_to_end(&mut self, block: BlockRef) { + self.set_insertion_point_after(block.into()); + } + + /// Return the block the current insertion point belongs to. + /// + /// NOTE: The insertion point is not necessarily at the end of the block. + /// + /// Returns `None` if the insertion point is unset, or is pointing at an operation which is + /// detached from a block. + fn insertion_block(&self) -> Option { + self.insertion_point().and_then(|ip| ip.at.block()) + } + + /// Add a new block with `args` arguments, and set the insertion point to the end of it. + /// + /// The block is inserted after the provided insertion point `ip`, or at the end of `parent` if + /// not. + /// + /// Panics if `ip` is in a different region than `parent`, or if the position it refers to is no + /// longer valid. + fn create_block(&mut self, parent: RegionRef, ip: Option, args: &[Type]) -> BlockRef { + let mut block = self.context().create_block_with_params(args.iter().cloned()); + if let Some(at) = ip { + let region = at.borrow().parent().unwrap(); + assert!( + RegionRef::ptr_eq(&parent, ®ion), + "insertion point region differs from 'parent'" + ); + + block.borrow_mut().insert_after(at); + } else { + block.borrow_mut().insert_at_end(parent); + } + + self.notify_block_inserted(block.clone(), None, None); + + self.set_insertion_point_to_end(block.clone()); + + block + } + + /// Add a new block with `args` arguments, and set the insertion point to the end of it. + /// + /// The block is inserted before `before`. + fn create_block_before(&mut self, before: BlockRef, args: &[Type]) -> BlockRef { + let mut block = self.context().create_block_with_params(args.iter().cloned()); + block.borrow_mut().insert_before(before); + self.notify_block_inserted(block.clone(), None, None); + self.set_insertion_point_to_end(block.clone()); + block + } + + /// Insert `op` at the current insertion point. + /// + /// If the insertion point is inserting after the current operation, then after calling this + /// function, the insertion point will have been moved to the newly inserted operation. This + /// ensures that subsequent calls to `insert` will place operations in the block in the same + /// sequence as they were inserted. The other insertion point placements already have more or + /// less intuitive behavior, e.g. inserting _before_ the current operation multiple times will + /// result in operations being placed in the same sequence they were inserted, just before the + /// current op. + /// + /// This function will panic if no insertion point is set. 
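// NOTE: An illustrative, self-contained sketch (not part of the patch itself)
// of the insertion-point behavior documented above: when inserting *after* an
// op, the point advances to each newly inserted op, so repeated inserts keep
// program order. A block is modeled as a `Vec` of op names.
struct Cursor {
    ops: Vec<&'static str>,
    after: usize,
}

impl Cursor {
    fn insert_after(&mut self, op: &'static str) {
        self.ops.insert(self.after + 1, op);
        self.after += 1; // mirrors moving the insertion point to the new op
    }
}

fn main() {
    let mut cursor = Cursor { ops: vec!["entry"], after: 0 };
    cursor.insert_after("a");
    cursor.insert_after("b");
    cursor.insert_after("c");
    // Without advancing, this would have produced ["entry", "c", "b", "a"].
    assert_eq!(cursor.ops, ["entry", "a", "b", "c"]);
}
// --- end sketch ---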
+ fn insert(&mut self, mut op: OperationRef) { + let ip = self.insertion_point().expect("insertion point is unset").clone(); + match ip.at { + ProgramPoint::Block(block) => match ip.placement { + crate::Insert::Before => op.borrow_mut().insert_at_start(block), + crate::Insert::After => op.borrow_mut().insert_at_end(block), + }, + ProgramPoint::Op(other_op) => match ip.placement { + crate::Insert::Before => op.borrow_mut().insert_before(other_op), + crate::Insert::After => { + op.borrow_mut().insert_after(other_op.clone()); + self.set_insertion_point_after(ProgramPoint::Op(other_op)); + } + }, + } + self.notify_operation_inserted(op, None); + } +} + +pub trait BuilderExt: Builder { + /// Returns a specialized builder for a concrete [Op], `T`, which can be called like a closure + /// with the arguments required to create an instance of the specified operation. + /// + /// # How it works + /// + /// The set of arguments which are valid for the specialized builder returned by `create`, are + /// determined by what implementations of the [BuildableOp] trait exist for `T`. The specific + /// impl that is chosen will depend on the types of the arguments given to it. Typically, there + /// should only be one implementation, or if there are multiple, they should not overlap in + /// ways that may confuse type inference, or you will be forced to specify the full type of the + /// argument pack. + /// + /// This mechanism for constructing ops using arbitrary arguments is essentially a workaround + /// for the lack of variadic generics in Rust, and isn't quite as nice as what you can acheive + /// in C++ with varidadic templates and `std::forward` and such, but is close enough so that + /// the ergonomics are still a significant improvement over the alternative approaches. + /// + /// The nice thing about this is that we can generate all of the boilerplate, and hide all of + /// the sensitive/unsafe parts of initializing operations. Alternative approaches require + /// exposing more unsafe APIs for use by builders, whereas this approach can conceal those + /// details within this crate. 
+ /// + /// ## Example + /// + /// ```text,ignore + /// // Get an OpBuilder + /// let builder = context.builder(); + /// // Obtain a builder for AddOp + /// let add_builder = builder.create::(span); + /// // Consume the builder by creating the op with the given arguments + /// let add = add_builder(lhs, rhs, Overflow::Wrapping).expect("invalid add op"); + /// ``` + /// + /// Or, simplified/collapsed: + /// + /// ```text,ignore + /// let builder = context.builder(); + /// let add = builder.create::(span)(lhs, rhs, Overflow::Wrapping) + /// .expect("invalid add op"); + /// ``` + #[inline(always)] + fn create(&mut self, span: SourceSpan) -> >::Builder<'_, Self> + where + Args: core::marker::Tuple, + T: BuildableOp, + { + >::builder(self, span) + } +} + +impl BuilderExt for B {} + +pub struct OpBuilder { + context: Rc, + listener: Option, + ip: Option, +} + +impl OpBuilder { + pub fn new(context: Rc) -> Self { + Self { + context, + listener: None, + ip: None, + } + } +} + +impl OpBuilder { + /// Sets the listener of this builder to `listener` + pub fn with_listener(self, listener: L2) -> OpBuilder + where + L2: Listener, + { + OpBuilder { + context: self.context, + listener: Some(listener), + ip: self.ip, + } + } + + #[inline] + pub fn into_parts(self) -> (Rc, Option, Option) { + (self.context, self.listener, self.ip) + } +} + +impl Listener for OpBuilder { + fn kind(&self) -> ListenerType { + self.listener.as_ref().map(|l| l.kind()).unwrap_or(ListenerType::Builder) + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_operation_inserted(op, prev); + } + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + if let Some(listener) = self.listener.as_ref() { + listener.notify_block_inserted(block, prev, ip); + } + } +} + +impl Builder for OpBuilder { + #[inline(always)] + fn context(&self) -> &Context { + self.context.as_ref() + } + + #[inline(always)] + fn context_rc(&self) -> Rc { + self.context.clone() + } + + #[inline(always)] + fn insertion_point(&self) -> Option<&InsertionPoint> { + self.ip.as_ref() + } + + #[inline] + fn clear_insertion_point(&mut self) -> Option { + self.ip.take() + } + + #[inline] + fn restore_insertion_point(&mut self, ip: Option) { + self.ip = ip; + } + + #[inline(always)] + fn set_insertion_point(&mut self, ip: InsertionPoint) { + self.ip = Some(ip); + } +} + +#[derive(Debug, Copy, Clone)] +pub enum ListenerType { + Builder, + Rewriter, +} + +#[allow(unused_variables)] +pub trait Listener: 'static { + fn kind(&self) -> ListenerType; + /// Notify the listener that the specified operation was inserted. + /// + /// * If the operation was moved, then `prev` is the previous location of the op + /// * If the operation was unlinked before it was inserted, then `prev` is `None` + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) {} + /// Notify the listener that the specified block was inserted. + /// + /// * If the block was moved, then `prev` and `ip` represent the previous location of the block. 
+ /// * If the block was unlinked before it was inserted, then `prev` and `ip` are `None` + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + } +} + +impl Listener for Option { + fn kind(&self) -> ListenerType { + ListenerType::Builder + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + if let Some(listener) = self.as_ref() { + listener.notify_block_inserted(block, prev, ip); + } + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + if let Some(listener) = self.as_ref() { + listener.notify_operation_inserted(op, prev); + } + } +} + +impl Listener for Box { + #[inline] + fn kind(&self) -> ListenerType { + (**self).kind() + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + (**self).notify_operation_inserted(op, prev) + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + (**self).notify_block_inserted(block, prev, ip) + } +} + +impl Listener for Rc { + #[inline] + fn kind(&self) -> ListenerType { + (**self).kind() + } + + fn notify_operation_inserted(&self, op: OperationRef, prev: Option) { + (**self).notify_operation_inserted(op, prev) + } + + fn notify_block_inserted( + &self, + block: BlockRef, + prev: Option, + ip: Option, + ) { + (**self).notify_block_inserted(block, prev, ip) + } +} + +/// A listener of kind `Builder` that does nothing +pub struct NoopBuilderListener; +impl Listener for NoopBuilderListener { + #[inline] + fn kind(&self) -> ListenerType { + ListenerType::Builder + } +} + +/// This is used to allow [InsertionGuard] to be agnostic about the type of builder/rewriter it +/// wraps, while still performing the necessary insertion point restoration on drop. Without this, +/// we would be required to specify a `B: Builder` bound on the definition of [InsertionGuard]. 
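// NOTE: An illustrative, self-contained sketch (not part of the patch itself)
// of the guard pattern described above: capture the active insertion point
// when the guard is created, and restore it when the guard drops. The builder
// is modeled as a plain struct with an optional index for its insertion point.
struct ToyBuilder {
    ip: Option<usize>,
}

struct Guard<'a> {
    builder: &'a mut ToyBuilder,
    saved: Option<usize>,
}

impl<'a> Guard<'a> {
    fn new(builder: &'a mut ToyBuilder) -> Self {
        let saved = builder.ip;
        Self { builder, saved }
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // Restore whatever insertion point was active at guard creation.
        self.builder.ip = self.saved;
    }
}

fn main() {
    let mut builder = ToyBuilder { ip: Some(1) };
    {
        let guard = Guard::new(&mut builder);
        guard.builder.ip = Some(42); // temporarily move the insertion point
    } // guard drops here
    assert_eq!(builder.ip, Some(1));
}
// --- end sketch ---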
+#[doc(hidden)] +#[allow(unused_variables)] +trait RestoreInsertionPointOnDrop { + fn restore_insertion_point_on_drop(&mut self, ip: Option); +} +impl RestoreInsertionPointOnDrop for InsertionGuard<'_, B> { + #[inline(always)] + default fn restore_insertion_point_on_drop(&mut self, _ip: Option) {} +} +impl RestoreInsertionPointOnDrop for InsertionGuard<'_, B> { + fn restore_insertion_point_on_drop(&mut self, ip: Option) { + self.builder.restore_insertion_point(ip); + } +} + +pub struct InsertionGuard<'a, B: ?Sized> { + builder: &'a mut B, + ip: Option, +} +impl<'a, B> InsertionGuard<'a, B> +where + B: ?Sized + Builder, +{ + #[allow(unused)] + pub fn new(builder: &'a mut B) -> Self { + let ip = builder.insertion_point().cloned(); + Self { builder, ip } + } +} +impl core::ops::Deref for InsertionGuard<'_, B> { + type Target = B; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.builder + } +} +impl core::ops::DerefMut for InsertionGuard<'_, B> { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.builder + } +} +impl Drop for InsertionGuard<'_, B> { + fn drop(&mut self) { + let ip = self.ip.take(); + self.restore_insertion_point_on_drop(ip); + } +} diff --git a/hir2/src/ir/callable.rs b/hir2/src/ir/callable.rs new file mode 100644 index 000000000..b2e6c3f33 --- /dev/null +++ b/hir2/src/ir/callable.rs @@ -0,0 +1,397 @@ +use core::fmt; + +use crate::{ + formatter, CallConv, EntityRef, OpOperandRange, OpOperandRangeMut, RegionRef, Symbol, + SymbolNameAttr, SymbolRef, Type, UnsafeIntrusiveEntityRef, Value, ValueRef, Visibility, +}; + +/// A call-like operation is one that transfers control from one function to another. +/// +/// These operations may be traditional static calls, e.g. `call @foo`, or indirect calls, e.g. +/// `call_indirect v1`. An operation that uses this interface cannot _also_ implement the +/// `CallableOpInterface`. +pub trait CallOpInterface { + /// Get the callee of this operation. + /// + /// A callee is either a symbol, or a reference to an SSA value. + fn callable_for_callee(&self) -> Callable; + /// Sets the callee for this operation. + fn set_callee(&mut self, callable: Callable); + /// Get the operands of this operation that are used as arguments for the callee + fn arguments(&self) -> OpOperandRange<'_>; + /// Get a mutable reference to the operands of this operation that are used as arguments for the + /// callee + fn arguments_mut(&mut self) -> OpOperandRangeMut<'_>; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved, using the provided symbol table. + /// + /// This method is used to perform callee resolution using a cached symbol table, rather than + /// traversing the operation hierarchy looking for symbol tables to try resolving with. + fn resolve_in_symbol_table(&self, symbols: &dyn crate::SymbolTable) -> Option; + /// Resolve the callable operation for the current callee to a `CallableOpInterface`, or `None` + /// if a valid callable was not resolved. + fn resolve(&self) -> Option; +} + +/// A callable operation is one who represents a potential function, and may be a target for a call- +/// like operation (i.e. implementations of `CallOpInterface`). These operations may be traditional +/// function ops (i.e. `Function`), as well as function reference-producing operations, such as an +/// op that creates closures, or captures a function by reference. +/// +/// These operations may only contain a single region. 
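// NOTE: An illustrative, self-contained sketch (not part of the patch itself)
// of the callee dichotomy that the `Callable` type below models: a direct call
// names a symbol, while an indirect call consumes an SSA value. `String` and
// `u32` are toy stand-ins for `SymbolNameAttr` and `ValueRef`.
enum ToyCallable {
    Symbol(String), // e.g. `call @foo`
    Value(u32),     // e.g. `call_indirect v1`
}

fn describe(callee: &ToyCallable) -> String {
    match callee {
        ToyCallable::Symbol(name) => format!("call @{name}"),
        ToyCallable::Value(id) => format!("call_indirect v{id}"),
    }
}

fn main() {
    assert_eq!(describe(&ToyCallable::Symbol("foo".into())), "call @foo");
    assert_eq!(describe(&ToyCallable::Value(1)), "call_indirect v1");
}
// --- end sketch ---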
+pub trait CallableOpInterface {
+    /// Returns the region on the current operation that is callable.
+    ///
+    /// This may return `None` in the case of an external callable object, e.g. an externally-
+    /// defined function reference.
+    fn get_callable_region(&self) -> Option<RegionRef>;
+    /// Returns the signature of the callable
+    fn signature(&self) -> &Signature;
+}
+
+#[doc(hidden)]
+pub trait AsCallableSymbolRef {
+    fn as_callable_symbol_ref(&self) -> SymbolRef;
+}
+impl<T: Symbol> AsCallableSymbolRef for T {
+    #[inline(always)]
+    fn as_callable_symbol_ref(&self) -> SymbolRef {
+        unsafe { SymbolRef::from_raw(self as &dyn Symbol) }
+    }
+}
+impl<T: Symbol> AsCallableSymbolRef for UnsafeIntrusiveEntityRef<T> {
+    #[inline(always)]
+    fn as_callable_symbol_ref(&self) -> SymbolRef {
+        let t_ptr = Self::as_ptr(self);
+        unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) }
+    }
+}
+
+/// A [Callable] represents a symbol or a value which can be used as a valid _callee_ for a
+/// [CallOpInterface] implementation.
+///
+/// Symbols are not SSA values, but there are situations where we want to treat them as one, such
+/// as indirect calls. Abstracting over whether the callable is a symbol or an SSA value allows us
+/// to focus on the call semantics, rather than the difference between the two types of value.
+#[derive(Debug, Clone)]
+pub enum Callable {
+    Symbol(SymbolNameAttr),
+    Value(ValueRef),
+}
+impl From<&SymbolNameAttr> for Callable {
+    fn from(value: &SymbolNameAttr) -> Self {
+        Self::Symbol(value.clone())
+    }
+}
+impl From<SymbolNameAttr> for Callable {
+    fn from(value: SymbolNameAttr) -> Self {
+        Self::Symbol(value)
+    }
+}
+impl From<ValueRef> for Callable {
+    fn from(value: ValueRef) -> Self {
+        Self::Value(value)
+    }
+}
+impl Callable {
+    #[inline(always)]
+    pub fn new(callable: impl Into<Self>) -> Self {
+        callable.into()
+    }
+
+    pub fn is_symbol(&self) -> bool {
+        matches!(self, Self::Symbol(_))
+    }
+
+    pub fn is_value(&self) -> bool {
+        matches!(self, Self::Value(_))
+    }
+
+    pub fn as_symbol_name(&self) -> Option<&SymbolNameAttr> {
+        match self {
+            Self::Symbol(ref name) => Some(name),
+            _ => None,
+        }
+    }
+
+    pub fn as_value(&self) -> Option<EntityRef<'_, dyn Value>> {
+        match self {
+            Self::Value(ref value_ref) => Some(value_ref.borrow()),
+            _ => None,
+        }
+    }
+
+    pub fn unwrap_symbol_name(self) -> SymbolNameAttr {
+        match self {
+            Self::Symbol(name) => name,
+            Self::Value(value_ref) => panic!("expected symbol, got {}", value_ref.borrow().id()),
+        }
+    }
+
+    pub fn unwrap_value_ref(self) -> ValueRef {
+        match self {
+            Self::Value(value) => value,
+            Self::Symbol(ref name) => panic!("expected value, got {name}"),
+        }
+    }
+}
+
+/// Represents whether an argument or return value has a special purpose in
+/// the calling convention of a function.
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
+#[cfg_attr(
+    feature = "serde",
+    derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)
+)]
+#[repr(u8)]
+pub enum ArgumentPurpose {
+    /// No special purpose, the argument is passed/returned by value
+    #[default]
+    Default,
+    /// Used for platforms where the calling convention expects return values of
+    /// a certain size to be written to a pointer passed in by the caller.
+    StructReturn,
+}
+impl fmt::Display for ArgumentPurpose {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Default => f.write_str("default"),
+            Self::StructReturn => f.write_str("sret"),
+        }
+    }
+}
+
+/// Represents how to extend a small integer value to native machine integer width.
+/// +/// For Miden, native integrals are unsigned 64-bit field elements, but it is typically +/// going to be the case that we are targeting the subset of Miden Assembly where integrals +/// are unsigned 32-bit integers with a standard twos-complement binary representation. +/// +/// It is for the latter scenario that argument extension is really relevant. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr( + feature = "serde", + derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr) +)] +#[repr(u8)] +pub enum ArgumentExtension { + /// Do not perform any extension, high bits have undefined contents + #[default] + None, + /// Zero-extend the value + Zext, + /// Sign-extend the value + Sext, +} +impl fmt::Display for ArgumentExtension { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => f.write_str("none"), + Self::Zext => f.write_str("zext"), + Self::Sext => f.write_str("sext"), + } + } +} + +/// Describes a function parameter or result. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct AbiParam { + /// The type associated with this value + pub ty: Type, + /// The special purpose, if any, of this parameter or result + pub purpose: ArgumentPurpose, + /// The desired approach to extending the size of this value to + /// a larger bit width, if applicable. + pub extension: ArgumentExtension, +} +impl AbiParam { + pub fn new(ty: Type) -> Self { + Self { + ty, + purpose: ArgumentPurpose::default(), + extension: ArgumentExtension::default(), + } + } + + pub fn sret(ty: Type) -> Self { + assert!(ty.is_pointer(), "sret parameters must be pointers"); + Self { + ty, + purpose: ArgumentPurpose::StructReturn, + extension: ArgumentExtension::default(), + } + } +} +impl formatter::PrettyPrint for AbiParam { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + let mut doc = const_text("(") + const_text("param") + const_text(" "); + if !matches!(self.purpose, ArgumentPurpose::Default) { + doc += const_text("(") + display(self.purpose) + const_text(")") + const_text(" "); + } + if !matches!(self.extension, ArgumentExtension::None) { + doc += const_text("(") + display(self.extension) + const_text(")") + const_text(" "); + } + doc + text(format!("{}", &self.ty)) + const_text(")") + } +} + +impl fmt::Display for AbiParam { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_map(); + builder.entry(&"ty", &format_args!("{}", &self.ty)); + if !matches!(self.purpose, ArgumentPurpose::Default) { + builder.entry(&"purpose", &format_args!("{}", &self.purpose)); + } + if !matches!(self.extension, ArgumentExtension::None) { + builder.entry(&"extension", &format_args!("{}", &self.extension)); + } + builder.finish() + } +} + +/// A [Signature] represents the type, ABI, and linkage of a function. +/// +/// A function signature provides us with all of the necessary detail to correctly +/// validate and emit code for a function, whether from the perspective of a caller, +/// or the callee. 
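To tie the parameter metadata together before `Signature` (illustrative only; `Type::U32` and `Type::Ptr` are assumed constructors of the `Type` enum):

    fn demo_params() -> (AbiParam, AbiParam) {
        // An sret out-parameter must be a pointer, per the assertion in `AbiParam::sret`.
        let out = AbiParam::sret(Type::Ptr(Box::new(Type::U32))); // assumed Type constructors
        // A small integer argument, zero-extended to the native integral width.
        let n = AbiParam {
            ty: Type::U32, // assumed Type variant
            purpose: ArgumentPurpose::default(),
            extension: ArgumentExtension::Zext,
        };
        (out, n)
    }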
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct Signature {
+    /// The arguments expected by this function
+    pub params: Vec<AbiParam>,
+    /// The results returned by this function
+    pub results: Vec<AbiParam>,
+    /// The calling convention that applies to this function
+    pub cc: CallConv,
+    /// The linkage/visibility that should be used for this function
+    pub visibility: Visibility,
+}
+
+crate::define_attr_type!(Signature);
+
+impl fmt::Display for Signature {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_map()
+            .key(&"params")
+            .value_with(|f| {
+                let mut builder = f.debug_list();
+                for param in self.params.iter() {
+                    builder.entry(&format_args!("{param}"));
+                }
+                builder.finish()
+            })
+            .key(&"results")
+            .value_with(|f| {
+                let mut builder = f.debug_list();
+                for result in self.results.iter() {
+                    builder.entry(&format_args!("{result}"));
+                }
+                builder.finish()
+            })
+            .entry(&"cc", &format_args!("{}", &self.cc))
+            .entry(&"visibility", &format_args!("{}", &self.visibility))
+            .finish()
+    }
+}
+
+impl Signature {
+    /// Create a new signature with the given parameter and result types,
+    /// for a public function using the `SystemV` calling convention
+    pub fn new<P: IntoIterator<Item = AbiParam>, R: IntoIterator<Item = AbiParam>>(
+        params: P,
+        results: R,
+    ) -> Self {
+        Self {
+            params: params.into_iter().collect(),
+            results: results.into_iter().collect(),
+            cc: CallConv::SystemV,
+            visibility: Visibility::Public,
+        }
+    }
+
+    /// Returns true if this function is externally visible
+    pub fn is_public(&self) -> bool {
+        matches!(self.visibility, Visibility::Public)
+    }
+
+    /// Returns true if this function is only visible within its containing module
+    pub fn is_private(&self) -> bool {
+        matches!(self.visibility, Visibility::Private)
+    }
+
+    /// Returns true if this function is a kernel function
+    pub fn is_kernel(&self) -> bool {
+        matches!(self.cc, CallConv::Kernel)
+    }
+
+    /// Returns the number of arguments expected by this function
+    pub fn arity(&self) -> usize {
+        self.params().len()
+    }
+
+    /// Returns a slice containing the parameters for this function
+    pub fn params(&self) -> &[AbiParam] {
+        self.params.as_slice()
+    }
+
+    /// Returns the parameter at `index`, if present
+    #[inline]
+    pub fn param(&self, index: usize) -> Option<&AbiParam> {
+        self.params.get(index)
+    }
+
+    /// Returns a slice containing the results of this function
+    pub fn results(&self) -> &[AbiParam] {
+        match self.results.as_slice() {
+            [AbiParam { ty: Type::Unit, .. }] => &[],
+            [AbiParam {
+                ty: Type::Never, ..
+ }] => &[], + results => results, + } + } +} +impl formatter::PrettyPrint for Signature { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + let cc = if matches!(self.cc, CallConv::SystemV) { + None + } else { + Some( + const_text("(") + + const_text("cc") + + const_text(" ") + + display(self.cc) + + const_text(")"), + ) + }; + + let params = self.params.iter().fold(cc.unwrap_or(Document::Empty), |acc, param| { + if acc.is_empty() { + param.render() + } else { + acc + const_text(" ") + param.render() + } + }); + + if self.results.is_empty() { + params + } else { + let open = const_text("(") + const_text("result"); + let results = self + .results + .iter() + .fold(open, |acc, e| acc + const_text(" ") + text(format!("{}", &e.ty))) + + const_text(")"); + if matches!(params, Document::Empty) { + results + } else { + params + const_text(" ") + results + } + } + } +} diff --git a/hir2/src/ir/cfg.rs b/hir2/src/ir/cfg.rs new file mode 100644 index 000000000..12bd4e764 --- /dev/null +++ b/hir2/src/ir/cfg.rs @@ -0,0 +1,245 @@ +mod diff; +mod visit; + +pub use self::{ + diff::{CfgDiff, CfgUpdate, CfgUpdateKind, GraphDiff}, + visit::{DefaultGraphVisitor, GraphVisitor, LazyDfsVisitor, PostOrderIter, PreOrderIter}, +}; + +/// This is an abstraction over graph-like structures used in the IR: +/// +/// * The CFG of a region, i.e. graph of blocks +/// * The CFG reachable from a single block, i.e. graph of blocks +/// * The dominator graph of a region, i.e. graph of dominator nodes +/// * The call graph of a program +/// * etc... +/// +/// It isn't strictly necessary, but it provides some uniformity, and is useful particularly +/// for implementation of various analyses. +pub trait Graph { + /// The type of node represented in the graph. + /// + /// Typically this should be a pointer-like reference type, cheap to copy/clone. + type Node: Clone; + /// Type used to iterate over children of a node in the graph. + type ChildIter: ExactSizeIterator; + /// The type used to represent an edge in the graph. + /// + /// This should be cheap to copy/clone. + type Edge; + /// Type used to iterate over child edges of a node in the graph. + type ChildEdgeIter: ExactSizeIterator; + + /// An empty graph has no nodes. + #[inline] + fn is_empty(&self) -> bool { + self.size() == 0 + } + /// Get the number of nodes in this graph + fn size(&self) -> usize; + /// Get the entry node of the graph. + /// + /// It is expected that a graph always has an entry. As such, this function will panic if + /// called on an "empty" graph. You should check whether the graph is empty _first_, if you + /// are working with a possibly-empty graph. + fn entry_node(&self) -> Self::Node; + /// Get an iterator over the children of `parent` + fn children(parent: Self::Node) -> Self::ChildIter; + /// Get an iterator over the children edges of `parent` + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter; + /// Return the destination node of an edge. 
+ fn edge_dest(edge: Self::Edge) -> Self::Node; +} + +impl<'a, G: Graph> Graph for &'a G { + type ChildEdgeIter = ::ChildEdgeIter; + type ChildIter = ::ChildIter; + type Edge = ::Edge; + type Node = ::Node; + + fn is_empty(&self) -> bool { + (**self).is_empty() + } + + fn size(&self) -> usize { + (**self).size() + } + + fn entry_node(&self) -> Self::Node { + (**self).entry_node() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + ::children(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + ::children_edges(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + ::edge_dest(edge) + } +} + +/// An [InvertibleGraph] is a [Graph] which can be "inverted", i.e. edges are reversed. +/// +/// Technically, any graph is invertible, however we are primarily interested in supporting graphs +/// for which an inversion of itself has some semantic value. For example, visiting a CFG in +/// reverse is useful in various contexts, such as constructing dominator trees. +/// +/// This is primarily consumed via [Inverse]. +pub trait InvertibleGraph: Graph { + /// The type of this graph's inversion + /// + /// This is primarily useful in cases where you inverse the inverse of a graph - by allowing + /// the types to differ, you can recover the original graph, rather than having to emulate + /// both uninverted graphs using the inverse type. + /// + /// See [Inverse] for an example of how this is used. + type Inverse: Graph; + /// The type of iterator used to visit "inverted" children of a node in this graph, i.e. + /// the predecessors. + type InvertibleChildIter: ExactSizeIterator; + /// The type of iterator used to obtain the set of "inverted" children edges of a node in this + /// graph, i.e. the predecessor edges. + type InvertibleChildEdgeIter: ExactSizeIterator; + + /// Get an iterator over the predecessors of `parent`. + /// + /// NOTE: `parent` in this case will actually be a child of the nodes in the iterator, but we + /// preserve the naming so as to make it apparent we are working with an inversion of the + /// original graph. + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter; + /// Get an iterator over the predecessor edges of `parent`. + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter; + /// Obtain the inversion of this graph + fn inverse(self) -> Self::Inverse; +} + +/// This is a wrapper type for [Graph] implementations, used to indicate that iterating a +/// graph should be iterated in "inverse" order, the semantics of which depend on the graph. +/// +/// If used with an [InvertibleGraph], it uses the graph impls inverse iterators. If used with a +/// graph that is _not_ invertible, it uses the graph impls normal iterators. Effectively, this is +/// a specialization marker type. 
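Because `children` and `children_edges` are associated functions with no `&self`, the node type itself must be able to reach its successors, as `BlockRef` can. Before looking at `Inverse`, a toy implementation (illustrative only, not part of this changeset) shows the shape of the abstraction:

    use alloc::{rc::Rc, vec::Vec};
    use core::cell::RefCell;

    struct ToyNode {
        succs: RefCell<Vec<Rc<ToyNode>>>,
    }

    struct ToyGraph {
        nodes: Vec<Rc<ToyNode>>, // nodes[0] is the entry
    }

    impl Graph for ToyGraph {
        type Node = Rc<ToyNode>;
        type ChildIter = alloc::vec::IntoIter<Rc<ToyNode>>;
        type Edge = (Rc<ToyNode>, Rc<ToyNode>);
        type ChildEdgeIter = alloc::vec::IntoIter<(Rc<ToyNode>, Rc<ToyNode>)>;

        fn size(&self) -> usize {
            self.nodes.len()
        }

        fn entry_node(&self) -> Self::Node {
            // Panics on an empty graph, as documented by the trait.
            self.nodes.first().cloned().expect("empty graph has no entry node")
        }

        fn children(parent: Self::Node) -> Self::ChildIter {
            // The node carries its own adjacency, so no `&self` is needed.
            parent.succs.borrow().clone().into_iter()
        }

        fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter {
            let succs = parent.succs.borrow().clone();
            succs
                .into_iter()
                .map(|succ| (parent.clone(), succ))
                .collect::<Vec<_>>()
                .into_iter()
        }

        fn edge_dest(edge: Self::Edge) -> Self::Node {
            edge.1
        }
    }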
+pub struct Inverse { + graph: G, +} + +impl Inverse { + /// Construct an inversion over `graph` + #[inline] + pub fn new(graph: G) -> Self { + Self { graph } + } +} + +impl Graph for Inverse { + type ChildEdgeIter = InverseChildEdgeIter<::InvertibleChildEdgeIter>; + type ChildIter = InverseChildIter<::InvertibleChildIter>; + type Edge = ::Edge; + type Node = ::Node; + + fn is_empty(&self) -> bool { + self.graph.is_empty() + } + + fn size(&self) -> usize { + self.graph.size() + } + + fn entry_node(&self) -> Self::Node { + self.graph.entry_node() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + InverseChildIter::new(::inverse_children(parent)) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + InverseChildEdgeIter::new(::inverse_children_edges(parent)) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + ::edge_dest(edge) + } +} + +impl InvertibleGraph for Inverse { + type Inverse = G; + type InvertibleChildEdgeIter = ::ChildEdgeIter; + type InvertibleChildIter = ::ChildIter; + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + ::children(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + ::children_edges(parent) + } + + fn inverse(self) -> Self::Inverse { + self.graph + } +} + +/// An iterator returned by `children` that iterates over `inverse_children` of the underlying graph +#[doc(hidden)] +pub struct InverseChildIter { + iter: I, +} + +impl InverseChildIter { + pub fn new(iter: I) -> Self { + Self { iter } + } +} +impl ExactSizeIterator for InverseChildIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for InverseChildIter { + type Item = ::Item; + + default fn next(&mut self) -> Option { + self.iter.next() + } +} + +/// An iterator returned by `children_edges` that iterates over `inverse_children_edges` of the +/// underlying graph. 
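Putting the inversion machinery together (illustrative sketch): for an `InvertibleGraph`, `children` of `Inverse<G>` yields what `inverse_children` yields on `G`, so predecessor walks can reuse successor-oriented algorithms unchanged.

    use alloc::vec::Vec;

    // Collect the predecessors of `node` by viewing the graph inverted.
    // Note that no graph instance is required: `children` is associated.
    fn predecessors<G>(node: G::Node) -> Vec<G::Node>
    where
        G: InvertibleGraph,
    {
        <Inverse<G> as Graph>::children(node).collect()
    }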
+#[doc(hidden)] +pub struct InverseChildEdgeIter { + iter: I, +} +impl InverseChildEdgeIter { + pub fn new(iter: I) -> Self { + Self { iter } + } +} +impl ExactSizeIterator for InverseChildEdgeIter { + #[inline] + fn len(&self) -> usize { + self.iter.len() + } + + #[inline] + fn is_empty(&self) -> bool { + self.iter.is_empty() + } +} +impl Iterator for InverseChildEdgeIter { + type Item = ::Item; + + default fn next(&mut self) -> Option { + self.iter.next() + } +} diff --git a/hir2/src/ir/cfg/diff.rs b/hir2/src/ir/cfg/diff.rs new file mode 100644 index 000000000..876a2de66 --- /dev/null +++ b/hir2/src/ir/cfg/diff.rs @@ -0,0 +1,301 @@ +use core::fmt; + +use smallvec::SmallVec; + +use crate::{adt::SmallMap, BlockRef}; + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum CfgUpdateKind { + Insert, + Delete, +} + +#[derive(Clone, PartialEq, Eq)] +pub struct CfgUpdate { + kind: CfgUpdateKind, + from: BlockRef, + to: BlockRef, +} +impl CfgUpdate { + #[inline(always)] + pub const fn kind(&self) -> CfgUpdateKind { + self.kind + } + + #[inline(always)] + pub const fn from(&self) -> &BlockRef { + &self.from + } + + #[inline(always)] + pub const fn to(&self) -> &BlockRef { + &self.to + } +} +impl fmt::Debug for CfgUpdate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(match self.kind { + CfgUpdateKind::Insert => "Insert", + CfgUpdateKind::Delete => "Delete", + }) + .field("from", &self.from) + .field("to", &self.to) + .finish() + } +} + +#[derive(Default, Clone)] +struct DeletesInserts { + deletes: SmallVec<[BlockRef; 2]>, + inserts: SmallVec<[BlockRef; 2]>, +} +impl DeletesInserts { + pub fn di(&self, is_insert: bool) -> &SmallVec<[BlockRef; 2]> { + if is_insert { + &self.inserts + } else { + &self.deletes + } + } + + pub fn di_mut(&mut self, is_insert: bool) -> &mut SmallVec<[BlockRef; 2]> { + if is_insert { + &mut self.inserts + } else { + &mut self.deletes + } + } +} + +pub trait GraphDiff { + fn is_empty(&self) -> bool; + fn legalized_updates(&self) -> &[CfgUpdate]; + fn num_legalized_updates(&self) -> usize { + self.legalized_updates().len() + } + fn pop_update_for_incremental_updates(&mut self) -> CfgUpdate; + fn get_children(&self, node: &BlockRef) -> SmallVec<[BlockRef; 8]>; +} + +/// GraphDiff defines a CFG snapshot: given a set of Update, provides +/// a getChildren method to get a Node's children based on the additional updates +/// in the snapshot. The current diff treats the CFG as a graph rather than a +/// multigraph. Added edges are pruned to be unique, and deleted edges will +/// remove all existing edges between two blocks. +/// +/// Two booleans are used to define orders in graphs: +/// InverseGraph defines when we need to reverse the whole graph and is as such +/// also equivalent to applying updates in reverse. +/// InverseEdge defines whether we want to change the edges direction. E.g., for +/// a non-inversed graph, the children are naturally the successors when +/// InverseEdge is false and the predecessors when InverseEdge is true. +#[derive(Clone)] +pub struct CfgDiff { + succ: SmallMap, + pred: SmallMap, + /// By default, it is assumed that, given a CFG and a set of updates, we wish + /// to apply these updates as given. If UpdatedAreReverseApplied is set, the + /// updates will be applied in reverse: deleted edges are considered re-added + /// and inserted edges are considered deleted when returning children. 
+    updated_are_reverse_applied: bool,
+    /// Keep the list of legalized updates for a deterministic order of updates
+    /// when using a GraphDiff for incremental updates in the DominatorTree.
+    /// The list is kept in reverse to allow popping from end.
+    legalized_updates: SmallVec<[CfgUpdate; 4]>,
+}
+
+impl<const INVERSE_GRAPH: bool> Default for CfgDiff<INVERSE_GRAPH> {
+    fn default() -> Self {
+        Self {
+            succ: Default::default(),
+            pred: Default::default(),
+            updated_are_reverse_applied: false,
+            legalized_updates: Default::default(),
+        }
+    }
+}
+
+impl<const INVERSE_GRAPH: bool> CfgDiff<INVERSE_GRAPH> {
+    pub fn new<I>(updates: I, reverse_apply_updates: bool) -> Self
+    where
+        I: ExactSizeIterator<Item = CfgUpdate>,
+    {
+        let mut this = Self {
+            legalized_updates: legalize_updates(updates, INVERSE_GRAPH, false),
+            ..Default::default()
+        };
+        for update in this.legalized_updates.iter() {
+            let is_insert = matches!(update.kind(), CfgUpdateKind::Insert) || reverse_apply_updates;
+            this.succ
+                .entry(update.from.clone())
+                .or_default()
+                .di_mut(is_insert)
+                .push(update.to.clone());
+            this.pred
+                .entry(update.to.clone())
+                .or_default()
+                .di_mut(is_insert)
+                .push(update.from.clone());
+        }
+        this.updated_are_reverse_applied = reverse_apply_updates;
+        this
+    }
+}
+
+impl<const INVERSE_GRAPH: bool> GraphDiff for CfgDiff<INVERSE_GRAPH> {
+    fn is_empty(&self) -> bool {
+        self.succ.is_empty() && self.pred.is_empty() && self.legalized_updates.is_empty()
+    }
+
+    #[inline(always)]
+    fn legalized_updates(&self) -> &[CfgUpdate] {
+        &self.legalized_updates
+    }
+
+    fn pop_update_for_incremental_updates(&mut self) -> CfgUpdate {
+        assert!(!self.legalized_updates.is_empty(), "no updates to apply");
+        let update = self.legalized_updates.pop().unwrap();
+        let is_insert =
+            matches!(update.kind(), CfgUpdateKind::Insert) || self.updated_are_reverse_applied;
+        let succ_di_list = &mut self.succ[update.from()];
+        let is_empty = {
+            let succ_list = succ_di_list.di_mut(is_insert);
+            assert_eq!(succ_list.last(), Some(update.to()));
+            succ_list.pop();
+            succ_list.is_empty()
+        };
+        if is_empty && succ_di_list.di(!is_insert).is_empty() {
+            self.succ.remove(update.from());
+        }
+
+        let pred_di_list = &mut self.pred[update.to()];
+        let pred_list = pred_di_list.di_mut(is_insert);
+        assert_eq!(pred_list.last(), Some(update.from()));
+        pred_list.pop();
+        if pred_list.is_empty() && pred_di_list.di(!is_insert).is_empty() {
+            self.pred.remove(update.to());
+        }
+        update
+    }
+
+    fn get_children<const INVERSE_EDGE: bool>(&self, node: &BlockRef) -> SmallVec<[BlockRef; 8]> {
+        let mut r = crate::dominance::nca::get_children::<INVERSE_EDGE>(node);
+        if !INVERSE_EDGE {
+            r.reverse();
+        }
+
+        let children = if INVERSE_EDGE != INVERSE_GRAPH {
+            &self.pred
+        } else {
+            &self.succ
+        };
+        let Some(found) = children.get(node) else {
+            return r;
+        };
+
+        // Remove children present in the CFG but not in the snapshot.
+        for child in found.di(false) {
+            r.retain(|c| c != child);
+        }
+
+        // Add children present in the snapshot but not in the real CFG.
+        r.extend(found.di(true).iter().cloned());
+
+        r
+    }
+}
+
+/// `legalize_updates` simplifies updates assuming a graph structure.
+///
+/// This function serves a double purpose:
+///
+/// 1. It removes redundant updates, which makes it easier to reverse-apply them when traversing
+///    the CFG.
+/// 2. It optimizes away updates that cancel each other out, as the end result is the same.
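The counting scheme that `legalize_updates` (defined below) applies can be sketched standalone with a plain map; this test is illustrative only:

    #[test]
    fn legalization_counting_scheme() {
        use std::collections::BTreeMap;
        // Insertions add 1, deletions subtract 1; only edges with a non-zero
        // net count survive. ("A", "B") cancels out; ("A", "C") is kept.
        let updates = [
            (CfgUpdateKind::Insert, "A", "B"),
            (CfgUpdateKind::Delete, "A", "B"),
            (CfgUpdateKind::Insert, "A", "C"),
        ];
        let mut net = BTreeMap::new();
        for (kind, from, to) in updates {
            *net.entry((from, to)).or_insert(0i32) += match kind {
                CfgUpdateKind::Insert => 1,
                CfgUpdateKind::Delete => -1,
            };
        }
        assert_eq!(net[&("A", "B")], 0); // NOP: dropped by legalization
        assert_eq!(net[&("A", "C")], 1); // kept as a single Insert
    }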
+fn legalize_updates<I>(
+    all_updates: I,
+    inverse_graph: bool,
+    reverse_result_order: bool,
+) -> SmallVec<[CfgUpdate; 4]>
+where
+    I: ExactSizeIterator<Item = CfgUpdate>,
+{
+    #[derive(Default, Copy, Clone)]
+    struct UpdateOp {
+        num_insertions: i32,
+        index: u32,
+    }
+
+    // Count the total number of insertions of each edge.
+    // Each insertion adds 1 and each deletion subtracts 1. The end number should be one of:
+    //
+    // * `-1` (deletion)
+    // * `0` (NOP),
+    // * `1` (insertion).
+    //
+    // Otherwise, the sequence of updates contains multiple updates of the same kind, which we
+    // catch with an assertion below.
+    let mut operations =
+        SmallMap::<(BlockRef, BlockRef), UpdateOp, 4>::with_capacity(all_updates.len());
+
+    for (
+        i,
+        CfgUpdate {
+            kind,
+            mut from,
+            mut to,
+        },
+    ) in all_updates.enumerate()
+    {
+        if inverse_graph {
+            // Reverse edge for post-dominators
+            core::mem::swap(&mut from, &mut to);
+        }
+
+        operations
+            .entry((from, to))
+            .or_insert_with(|| UpdateOp {
+                num_insertions: 0,
+                index: i as u32,
+            })
+            .num_insertions += match kind {
+            CfgUpdateKind::Insert => 1,
+            CfgUpdateKind::Delete => -1,
+        };
+    }
+
+    let mut result = SmallVec::<[CfgUpdate; 4]>::with_capacity(operations.len());
+    for ((from, to), update_op) in operations.iter() {
+        assert!(update_op.num_insertions.abs() <= 1, "unbalanced operations!");
+        if update_op.num_insertions == 0 {
+            continue;
+        }
+        let kind = if update_op.num_insertions > 0 {
+            CfgUpdateKind::Insert
+        } else {
+            CfgUpdateKind::Delete
+        };
+        result.push(CfgUpdate {
+            kind,
+            from: from.clone(),
+            to: to.clone(),
+        });
+    }
+
+    // Make the order consistent by not relying on pointer values within the set. Reuse the old
+    // operations map.
+    //
+    // In the future, we should sort by something else to minimize the amount of work needed to
+    // perform the series of updates.
+    result.sort_by(|a, b| {
+        let op_a = &operations[&(a.from.clone(), a.to.clone())];
+        let op_b = &operations[&(b.from.clone(), b.to.clone())];
+        if reverse_result_order {
+            op_a.index.cmp(&op_b.index)
+        } else {
+            op_a.index.cmp(&op_b.index).reverse()
+        }
+    });
+
+    result
+}
diff --git a/hir2/src/ir/cfg/visit.rs b/hir2/src/ir/cfg/visit.rs
new file mode 100644
index 000000000..69c4ade14
--- /dev/null
+++ b/hir2/src/ir/cfg/visit.rs
@@ -0,0 +1,313 @@
+use core::ops::ControlFlow;
+
+use smallvec::SmallVec;
+
+use super::Graph;
+use crate::adt::SmallSet;
+
+/// By implementing this trait, you can refine the traversal performed by [LazyDfsVisitor], as
+/// well as hook in custom behavior to be executed upon reaching a node in both pre-order and
+/// post-order visits.
+///
+/// There are two callbacks, both with default implementations that align with the default
+/// semantics of a depth-first traversal.
+///
+/// If you wish to prune the search, the best place to do so is [GraphVisitor::on_node_reached],
+/// as it provides the opportunity to control whether or not the visitor will visit any of the
+/// node's successors, as well as whether the node is emitted during iteration.
+#[allow(unused_variables)]
+pub trait GraphVisitor {
+    type Node;
+
+    /// Called when a node is first reached during a depth-first traversal, i.e. pre-order
+    ///
+    /// If this function returns `ControlFlow::Break`, none of `node`'s successors will be visited,
+    /// and `node` will not be emitted by the visitor. This can be used to prune the traversal,
+    /// e.g. confining a visit to a specific loop in a CFG.
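For instance, a sketch (illustrative only, not part of this changeset) of a visitor that confines a traversal by cutting off one block via the callback declared next:

    // A visitor that prunes the DFS at a designated block, e.g. to confine
    // the traversal to the body of a loop.
    struct PruneAt {
        cut: BlockRef,
    }

    impl GraphVisitor for PruneAt {
        type Node = BlockRef;

        fn on_node_reached(&mut self, _from: Option<&BlockRef>, node: &BlockRef) -> ControlFlow<()> {
            if node == &self.cut {
                // Neither `cut` nor its successors will be emitted.
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            }
        }
    }

Driving a `LazyDfsVisitor` with this visitor skips `cut` and anything reachable only through it.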
+ fn on_node_reached(&mut self, from: Option<&Self::Node>, node: &Self::Node) -> ControlFlow<()> { + ControlFlow::Continue(()) + } + + /// Called when all successors of a node have been visited by the depth-first traversal, i.e. + /// post-order. + fn on_block_visited(&mut self, node: &Self::Node) {} +} + +/// A useful no-op visitor for when you want the default behavior. +pub struct DefaultGraphVisitor(core::marker::PhantomData); +impl Default for DefaultGraphVisitor { + fn default() -> Self { + Self(core::marker::PhantomData) + } +} +impl GraphVisitor for DefaultGraphVisitor { + type Node = T; +} + +/// A basic iterator over a depth-first traversal of nodes in a graph, producing them in pre-order. +#[repr(transparent)] +pub struct PreOrderIter(LazyDfsVisitor::Node>>) +where + G: Graph; +impl PreOrderIter +where + G: Graph, + ::Node: Eq, +{ + /// Visit all nodes reachable from `root` in pre-order + pub fn new(root: ::Node) -> Self { + Self(LazyDfsVisitor::new(root, DefaultGraphVisitor::default())) + } + + /// Visit all nodes reachable from `root` in pre-order, treating the nodes in `visited` as + /// already visited, skipping them (and their successors) during the traversal. + pub fn new_with_visited( + root: ::Node, + visited: impl IntoIterator::Node>, + ) -> Self { + Self(LazyDfsVisitor::new_with_visited(root, DefaultGraphVisitor::default(), visited)) + } +} +impl core::iter::FusedIterator for PreOrderIter +where + G: Graph, + ::Node: Eq, +{ +} +impl Iterator for PreOrderIter +where + G: Graph, + ::Node: Eq, +{ + type Item = ::Node; + + #[inline] + fn next(&mut self) -> Option { + self.0.next::() + } +} + +/// A basic iterator over a depth-first traversal of nodes in a graph, producing them in post-order. +#[repr(transparent)] +pub struct PostOrderIter(LazyDfsVisitor::Node>>) +where + G: Graph; +impl PostOrderIter +where + G: Graph, + ::Node: Eq, +{ + /// Visit all nodes reachable from `root` in post-order + #[inline] + pub fn new(root: ::Node) -> Self { + Self(LazyDfsVisitor::new(root, DefaultGraphVisitor::default())) + } + + /// Visit all nodes reachable from `root` in post-order, treating the nodes in `visited` as + /// already visited, skipping them (and their successors) during the traversal. + pub fn new_with_visited( + root: ::Node, + visited: impl IntoIterator::Node>, + ) -> Self { + Self(LazyDfsVisitor::new_with_visited(root, DefaultGraphVisitor::default(), visited)) + } +} +impl core::iter::FusedIterator for PostOrderIter +where + G: Graph, + ::Node: Eq, +{ +} +impl Iterator for PostOrderIter +where + G: Graph, + ::Node: Eq, +{ + type Item = ::Node; + + #[inline] + fn next(&mut self) -> Option { + self.0.next::() + } +} + +/// This type is an iterator over a depth-first traversal of a graph, with customization hooks +/// provided via the [GraphVisitor] trait. +/// +/// The order in which nodes are produced by the iterator depends on how you invoke the `next` +/// method - it must be instantiated with a constant boolean that indicates whether or not the +/// iteration is to produce nodes in post-order. +/// +/// As a result, this type does not implement `Iterator` itself - it is meant to be consumed as +/// an internal detail of higher-level iterator types. 
Two such types are provided in this module +/// for common pre- and post-order iterations: +/// +/// * [PreOrderIter], for iterating in pre-order +/// * [PostOrderIter], for iterating in post-order +/// +pub struct LazyDfsVisitor { + /// The nodes we have already visited, or wish to consider visited + visited: SmallSet<::Node, 32>, + /// The stack of discovered nodes currently being visited + stack: SmallVec<[VisitNode<::Node>; 8]>, + /// A [GraphVisitor] implementation used to hook into the traversal machinery + visitor: V, +} + +/// Represents a node in the graph which has been reached during traversal, and is in the process of +/// being visited. +struct VisitNode { + /// The parent node in the graph from which this node was drived + parent: Option, + /// The node in the underlying graph being visited + node: T, + /// The successors of this node + successors: SmallVec<[T; 2]>, + /// Set to `true` once this node has been handled by [GraphVisitor::on_node_reached] + reached: bool, +} +impl VisitNode +where + T: Clone, +{ + #[inline] + pub fn node(&self) -> T { + self.node.clone() + } + + /// Returns true if no successors remain to be visited under this node + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.successors.is_empty() + } + + /// Get the next successor of this node, and advance the `next_child` index + /// + /// It is expected that the caller has already checked [is_empty]. Failing to do so + /// will cause this function to panic once all successors have been visited. + pub fn pop_successor(&mut self) -> T { + self.successors.pop().unwrap() + } +} + +impl LazyDfsVisitor +where + G: Graph, + ::Node: Eq, + V: GraphVisitor::Node>, +{ + /// Visit the graph rooted under `from`, using the provided visitor for customization hooks. + pub fn new(from: ::Node, visitor: V) -> Self { + Self::new_with_visited(from, visitor, None::<::Node>) + } + + /// Visit the graph rooted under `from`, using the provided visitor for customization hooks. + /// + /// The initial set of "visited" nodes is seeded with `visited`. Any node in this set (and their + /// children) will be skipped during iteration and by the traversal itself. If `from` is in this + /// set, then the resulting iterator will be empty (i.e. produce no nodes, and perform no + /// traversal). + pub fn new_with_visited( + from: ::Node, + visitor: V, + visited: impl IntoIterator::Node>, + ) -> Self { + use smallvec::smallvec; + + let visited = visited.into_iter().collect::>(); + if visited.contains(&from) { + // The root node itself is being ignored, return an empty iterator + return Self { + visited, + stack: smallvec![], + visitor, + }; + } + + let successors = SmallVec::from_iter(G::children(from.clone())); + Self { + visited, + stack: smallvec![VisitNode { + parent: None, + node: from, + successors, + reached: false, + }], + visitor, + } + } + + /// Step the visitor forward one step. + /// + /// The semantics of a step depend on the value of `POSTORDER`: + /// + /// * If `POSTORDER == true`, then we resume traversal of the graph until the next node that + /// has had all of its successors visited is on top of the visit stack. + /// * If `POSTORDER == false`, then we resume traversal of the graph until the next unvisited + /// node is reached for the first time. + /// + /// In both cases, the node we find by the search is what is returned. If no more nodes remain + /// to be visited, this returns `None`. 
+ /// + /// This function invokes the associated [GraphVisitor] callbacks during the traversal, at the + /// appropriate time. + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<::Node> { + loop { + let Some(node) = self.stack.last_mut() else { + break None; + }; + + if !node.reached { + node.reached = true; + let unvisited = self.visited.insert(node.node()); + if !unvisited { + let _ = unsafe { self.stack.pop().unwrap_unchecked() }; + continue; + } + + // Handle pre-order visit + let should_visit = + self.visitor.on_node_reached(node.parent.as_ref(), &node.node).is_continue(); + if !should_visit { + // It was indicated we shouldn't visit this node, so move to the next + let _ = unsafe { self.stack.pop().unwrap_unchecked() }; + continue; + } + + if POSTORDER { + // We need to visit this node's successors first + continue; + } else { + // We're going to visit this node's successors on the next call + break Some(node.node.clone()); + } + } + + // Otherwise, we're visiting a successor of this node. + // + // If this node has no successors, we're done visiting it + // If we've visited all successors of this node, we've got our next item + if node.is_empty() { + let node = unsafe { self.stack.pop().unwrap_unchecked() }; + self.visitor.on_block_visited(&node.node); + if POSTORDER { + break Some(node.node); + } else { + continue; + } + } + + // Otherwise, continue visiting successors + let parent = node.node(); + let successor = node.pop_successor(); + let successors = SmallVec::from_iter(G::children(successor.clone())); + self.stack.push(VisitNode { + parent: Some(parent), + node: successor, + successors, + reached: false, + }); + } + } +} diff --git a/hir2/src/ir/component.rs b/hir2/src/ir/component.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/ir/context.rs b/hir2/src/ir/context.rs new file mode 100644 index 000000000..cd702775f --- /dev/null +++ b/hir2/src/ir/context.rs @@ -0,0 +1,237 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::{ + cell::{Cell, RefCell}, + mem::MaybeUninit, +}; + +use blink_alloc::Blink; +use midenc_session::Session; + +use super::*; + +/// Represents the shared state of the IR, used during a compilation session. +/// +/// The primary purpose(s) of the context are: +/// +/// * Provide storage/memory for all allocated IR entities for the lifetime of the session. +/// * Provide unique value and block identifiers for printing the IR +/// * Provide a uniqued constant pool +/// * Provide configuration used during compilation +/// +/// # Safety +/// +/// The [Context] _must_ live as long as any reference to an IR entity may be dereferenced. 
+pub struct Context { + pub session: Rc, + allocator: Rc, + registered_dialects: RefCell>>, + next_block_id: Cell, + next_value_id: Cell, + //pub constants: ConstantPool, +} + +impl Default for Context { + fn default() -> Self { + use alloc::sync::Arc; + + use midenc_session::diagnostics::DefaultSourceManager; + + let target_dir = std::env::current_dir().unwrap(); + let options = midenc_session::Options::default(); + let source_manager = Arc::new(DefaultSourceManager::default()); + let session = + Rc::new(Session::new([], None, None, target_dir, options, None, source_manager)); + Self::new(session) + } +} + +impl Context { + /// Create a new [Context] for the given [Session] + pub fn new(session: Rc) -> Self { + let allocator = Rc::new(Blink::new()); + Self { + session, + allocator, + registered_dialects: Default::default(), + next_block_id: Cell::new(0), + next_value_id: Cell::new(0), + //constants: Default::default(), + } + } + + pub fn registered_dialects( + &self, + ) -> core::cell::Ref<'_, BTreeMap>> { + self.registered_dialects.borrow() + } + + pub fn get_registered_dialect(&self, dialect: &DialectName) -> Rc { + self.registered_dialects.borrow()[dialect].clone() + } + + pub fn get_or_register_dialect(&self) -> Rc { + use alloc::collections::btree_map::Entry; + + let mut registered_dialects = self.registered_dialects.borrow_mut(); + let dialect_name = DialectName::new(T::NAMESPACE); + match registered_dialects.entry(dialect_name) { + Entry::Occupied(entry) => Rc::clone(entry.get()), + Entry::Vacant(entry) => { + let dialect = Rc::new(T::init()) as Rc; + entry.insert(Rc::clone(&dialect)); + dialect + } + } + } + + /// Get a new [OpBuilder] for this context + pub fn builder(self: Rc) -> OpBuilder { + OpBuilder::new(Rc::clone(&self)) + } + + /// Create a new, detached and empty [Block] with no parameters + pub fn create_block(&self) -> BlockRef { + let block = Block::new(self.alloc_block_id()); + self.alloc_tracked(block) + } + + /// Create a new, detached and empty [Block], with parameters corresponding to the given types + pub fn create_block_with_params(&self, tys: I) -> BlockRef + where + I: IntoIterator, + { + let block = Block::new(self.alloc_block_id()); + let mut block = self.alloc_tracked(block); + let owner = block.clone(); + let args = tys.into_iter().enumerate().map(|(index, ty)| { + let id = self.alloc_value_id(); + let arg = BlockArgument::new( + SourceSpan::default(), + id, + ty, + owner.clone(), + index.try_into().expect("too many block arguments"), + ); + self.alloc(arg) + }); + block.borrow_mut().arguments_mut().extend(args); + block + } + + /// Append a new [BlockArgument] to `block`, with the given type and source location + /// + /// Returns the block argument as a `dyn Value` reference + pub fn append_block_argument( + &self, + mut block: BlockRef, + ty: Type, + span: SourceSpan, + ) -> ValueRef { + let next_index = block.borrow().num_arguments(); + let id = self.alloc_value_id(); + let arg = BlockArgument::new( + span, + id, + ty, + block.clone(), + next_index.try_into().expect("too many block arguments"), + ); + let arg = self.alloc(arg); + block.borrow_mut().arguments_mut().push(arg.clone()); + arg.upcast() + } + + /// Create a new [OpOperand] with the given value, owner, and index. + /// + /// NOTE: This inserts the operand as a user of `value`, but does _not_ add the operand to + /// `owner`'s operand storage, the caller is expected to do that. This makes this function a + /// more useful primitive. 
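Stepping back, a usage sketch for the block-building APIs above (illustrative only; `Type::I32` is an assumed variant of `Type`):

    fn demo(context: &Context) -> BlockRef {
        // A detached block with one parameter; the context hands out the
        // block and value identifiers used when printing the IR.
        let block = context.create_block_with_params([Type::I32]); // assumed Type variant
        assert_eq!(block.borrow().num_arguments(), 1);
        block
    }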
+ pub fn make_operand(&self, mut value: ValueRef, owner: OperationRef, index: u8) -> OpOperand { + let op_operand = self.alloc_tracked(OpOperandImpl::new(value.clone(), owner, index)); + let mut value = value.borrow_mut(); + value.insert_use(op_operand.clone()); + op_operand + } + + /// Create a new [BlockOperand] with the given block, owner, and index. + /// + /// NOTE: This inserts the block operand as a user of `block`, but does _not_ add the block + /// operand to `owner`'s successor storage, the caller is expected to do that. This makes this + /// function a more useful primitive. + pub fn make_block_operand( + &self, + mut block: BlockRef, + owner: OperationRef, + index: u8, + ) -> BlockOperandRef { + let block_operand = self.alloc_tracked(BlockOperand::new(block.clone(), owner, index)); + let mut block = block.borrow_mut(); + block.insert_use(block_operand.clone()); + block_operand + } + + /// Create a new [OpResult] with the given type, owner, and index + /// + /// NOTE: This does not attach the result to the operation, it is expected that the caller will + /// do so. + pub fn make_result( + &self, + span: SourceSpan, + ty: Type, + owner: OperationRef, + index: u8, + ) -> OpResultRef { + let id = self.alloc_value_id(); + self.alloc(OpResult::new(span, id, ty, owner, index)) + } + + /// Allocate a new uninitialized entity of type `T` + /// + /// In general, you can probably prefer [Context::alloc] instead, but for use cases where you + /// need to allocate the space for `T` first, and then perform initialization, this can be + /// used. + pub fn alloc_uninit(&self) -> UnsafeEntityRef> { + UnsafeEntityRef::new_uninit(&self.allocator) + } + + /// Allocate a new uninitialized entity of type `T`, which needs to be tracked in an intrusive + /// doubly-linked list. + /// + /// In general, you can probably prefer [Context::alloc_tracked] instead, but for use cases + /// where you need to allocate the space for `T` first, and then perform initialization, + /// this can be used. + pub fn alloc_uninit_tracked(&self) -> UnsafeIntrusiveEntityRef> { + UnsafeIntrusiveEntityRef::new_uninit(&self.allocator) + } + + /// Allocate a new `EntityHandle`. + /// + /// [EntityHandle] is a smart-pointer type for IR entities, which behaves like a ref-counted + /// pointer with dynamically-checked borrow checking rules. It is designed to play well with + /// entities allocated from a [Context], and with the somewhat cyclical nature of the IR. + pub fn alloc(&self, value: T) -> UnsafeEntityRef { + UnsafeEntityRef::new(value, &self.allocator) + } + + /// Allocate a new `TrackedEntityHandle`. + /// + /// [TrackedEntityHandle] is like [EntityHandle], except that it is specially designed for + /// entities which are meant to be tracked in intrusive linked lists. For example, the blocks + /// in a region, or the ops in a block. It does this without requiring the entity to know about + /// the link at all, while still making it possible to access the link from the entity. 
+ pub fn alloc_tracked(&self, value: T) -> UnsafeIntrusiveEntityRef { + UnsafeIntrusiveEntityRef::new(value, &self.allocator) + } + + fn alloc_block_id(&self) -> BlockId { + let id = self.next_block_id.get(); + self.next_block_id.set(id + 1); + BlockId::from_u32(id) + } + + fn alloc_value_id(&self) -> ValueId { + let id = self.next_value_id.get(); + self.next_value_id.set(id + 1); + ValueId::from_u32(id) + } +} diff --git a/hir2/src/ir/dialect.rs b/hir2/src/ir/dialect.rs new file mode 100644 index 000000000..0cd062117 --- /dev/null +++ b/hir2/src/ir/dialect.rs @@ -0,0 +1,135 @@ +use alloc::rc::Rc; +use core::{borrow::Borrow, ops::Deref}; + +use crate::{AsAny, AttributeValue, Builder, OperationName, OperationRef, SourceSpan, Type}; + +/// A [Dialect] represents a collection of IR entities that are used in conjunction with one +/// another. Multiple dialects can co-exist _or_ be mutually exclusive. Converting between dialects +/// is the job of the conversion infrastructure, using a process called _legalization_. +pub trait Dialect { + /// Get the name(space) of this dialect + fn name(&self) -> DialectName; + /// Get the set of registered operations associated with this dialect + fn registered_ops(&self) -> Rc<[OperationName]>; + /// Get the registered [OperationName] for an op `opcode`, or register it with `register`. + /// + /// Registering an operation with the dialect allows various parts of the IR to introspect the + /// set of operations which belong to a given dialect namespace. + fn get_or_register_op( + &self, + opcode: ::midenc_hir_symbol::Symbol, + register: fn(DialectName, ::midenc_hir_symbol::Symbol) -> OperationName, + ) -> OperationName; + + /// A hook to materialize a single constant operation from a given attribute value and type. + /// + /// This method should use the provided builder to create the operation without changing the + /// insertion point. The generated operation is expected to be constant-like, i.e. single result + /// zero operands, no side effects, etc. + /// + /// Returns `None` if a constant cannot be materialized for the given attribute. + #[allow(unused_variables)] + #[inline] + fn materialize_constant( + &self, + builder: &mut dyn Builder, + attr: Box, + ty: &Type, + span: SourceSpan, + ) -> Option { + None + } +} + +/// A [DialectRegistration] must be implemented for any implementation of [Dialect], to allow the +/// dialect to be registered with a [crate::Context] and instantiated on demand when building ops +/// in the IR. +/// +/// This is not part of the [Dialect] trait itself, as that trait must be object safe, and this +/// trait is _not_ object safe. +pub trait DialectRegistration: AsAny + Dialect { + /// The namespace of the dialect to register + /// + /// A dialect namespace serves both as a way to namespace the operations of that dialect, as + /// well as a way to uniquely name/identify the dialect itself. Thus, no two dialects can have + /// the same namespace at the same time. + const NAMESPACE: &'static str; + + /// Initialize an instance of this dialect to be stored (uniqued) in the current + /// [crate::Context]. + /// + /// A dialect will only ever be initialized once per context. A dialect must use interior + /// mutability to satisfy the requirements of the [Dialect] trait, and to allow the context to + /// store the returned instance in a reference-counted smart pointer. + fn init() -> Self; +} + +/// A strongly-typed symbol representing the name of a [Dialect]. 
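To make the registration flow concrete, a minimal dialect sketch (illustrative only; assumes the crate's `AsAny` requirement is satisfied by a blanket impl, and elides the op-name caching a real dialect would perform):

    struct TestDialect;

    impl Dialect for TestDialect {
        fn name(&self) -> DialectName {
            DialectName::new("test")
        }

        fn registered_ops(&self) -> Rc<[OperationName]> {
            // A real dialect would record the names registered via `get_or_register_op`.
            Rc::new([])
        }

        fn get_or_register_op(
            &self,
            opcode: ::midenc_hir_symbol::Symbol,
            register: fn(DialectName, ::midenc_hir_symbol::Symbol) -> OperationName,
        ) -> OperationName {
            // Caching elided: always re-register.
            register(self.name(), opcode)
        }
    }

    impl DialectRegistration for TestDialect {
        const NAMESPACE: &'static str = "test";

        fn init() -> Self {
            TestDialect
        }
    }

With this in place, `context.get_or_register_dialect::<TestDialect>()` instantiates the dialect once per context and returns the shared instance thereafter.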
+/// +/// Dialect names should be in lowercase ASCII format, though this is not enforced. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct DialectName(::midenc_hir_symbol::Symbol); +impl DialectName { + pub fn new(name: S) -> Self + where + S: Into<::midenc_hir_symbol::Symbol>, + { + Self(name.into()) + } + + pub const fn from_symbol(name: ::midenc_hir_symbol::Symbol) -> Self { + Self(name) + } + + pub const fn as_symbol(&self) -> ::midenc_hir_symbol::Symbol { + self.0 + } +} +impl core::fmt::Debug for DialectName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.0.as_str()) + } +} +impl core::fmt::Display for DialectName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.0.as_str()) + } +} +impl From<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn from(value: ::midenc_hir_symbol::Symbol) -> Self { + Self(value) + } +} +impl From for ::midenc_hir_symbol::Symbol { + #[inline(always)] + fn from(value: DialectName) -> Self { + value.0 + } +} +impl Deref for DialectName { + type Target = ::midenc_hir_symbol::Symbol; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl AsRef<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn as_ref(&self) -> &::midenc_hir_symbol::Symbol { + &self.0 + } +} +impl Borrow<::midenc_hir_symbol::Symbol> for DialectName { + #[inline(always)] + fn borrow(&self) -> &::midenc_hir_symbol::Symbol { + &self.0 + } +} +impl Borrow for DialectName { + #[inline(always)] + fn borrow(&self) -> &str { + self.0.as_str() + } +} diff --git a/hir2/src/ir/dominance.rs b/hir2/src/ir/dominance.rs new file mode 100644 index 000000000..e8186b315 --- /dev/null +++ b/hir2/src/ir/dominance.rs @@ -0,0 +1,17 @@ +mod info; +pub mod nca; +mod traits; +mod tree; + +pub use self::{ + info::{DominanceInfo, PostDominanceInfo, RegionDominanceInfo}, + traits::{Dominates, PostDominates}, + tree::{ + DomTreeError, DomTreeVerificationLevel, DominanceTree, PostDominanceTree, + PostOrderDomTreeIter, PreOrderDomTreeIter, + }, +}; +use self::{ + nca::{BatchUpdateInfo, SemiNCA}, + tree::{DomTreeBase, DomTreeNode, DomTreeRoots}, +}; diff --git a/hir2/src/ir/dominance/info.rs b/hir2/src/ir/dominance/info.rs new file mode 100644 index 000000000..f1a1bd185 --- /dev/null +++ b/hir2/src/ir/dominance/info.rs @@ -0,0 +1,419 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::cell::{LazyCell, Ref, RefCell}; + +use super::*; +use crate::{Block, BlockRef, OperationRef, RegionKindInterface, RegionRef}; + +/// [DominanceInfo] provides a high-level API for querying dominance information. +/// +/// Note that this type is aware of the different types of regions, and returns a region-kind +/// specific notion of dominance. See [RegionKindInterface] for details. +pub struct DominanceInfo { + info: DominanceInfoBase, +} + +impl DominanceInfo { + #[doc(hidden)] + #[inline(always)] + pub(crate) fn info(&self) -> &DominanceInfoBase { + &self.info + } + + /// Returns true if `a` dominates `b`. + /// + /// Note that if `a == b`, this returns true, if you want strict dominance, see + /// [Self::properly_dominates] instead. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [Dominates] trait for that information. + pub fn dominates(&self, a: &A, b: &B) -> bool + where + A: Dominates, + { + a.dominates(b, self) + } + + /// Returns true if `a` properly dominates `b`. 
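A sketch of the intended ergonomics (illustrative only; assumes a `Dominates` impl between operation references, consistent with the operation-specific helper below):

    // In an SSA region, a definition must properly dominate each of its uses.
    fn def_dominates_use(dom: &DominanceInfo, def: &OperationRef, user: &OperationRef) -> bool {
        dom.properly_dominates(def, user)
    }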
+ /// + /// This always returns false if `a == b`. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [Dominates] trait for that information. + pub fn properly_dominates(&self, a: &A, b: &B) -> bool + where + A: Dominates, + { + a.properly_dominates(b, self) + } + + /// An implementation of `properly_dominates` for operations, where we sometimes wish to treat + /// `a` as dominating `b`, if `b` is enclosed by a region of `a`. This behavior is controlled + /// by the `enclosing_op_ok` flag. + pub fn properly_dominates_with_options( + &self, + a: &OperationRef, + b: &OperationRef, + enclosing_op_ok: bool, + ) -> bool { + let a_block = a.borrow().parent().expect("`a` must be in a block"); + let mut b_block = b.borrow().parent().expect("`b` must be in a block"); + + // An instruction dominates itself, but does not properly dominate itself, unless this is + // a graph region. + if a == b { + return !a_block.borrow().has_ssa_dominance(); + } + + // If these ops are in different regions, then normalize one into the other. + let a_region = a_block.borrow().parent().unwrap(); + let mut b = b.clone(); + if a_region != b_block.borrow().parent().unwrap() { + // Walk up `b`'s region tree until we find an operation in `a`'s region that encloses + // it. If this fails, then we know there is no post-dominance relation. + let Some(found) = a_region.borrow().find_ancestor_op(&b) else { + return false; + }; + b = found; + b_block = b.borrow().parent().expect("`b` must be in a block"); + assert!(b_block.borrow().parent().unwrap() == a_region); + + // If `a` encloses `b`, then we consider it to dominate. + if a == &b && enclosing_op_ok { + return true; + } + } + + // Ok, they are in the same region now. + if a_block == b_block { + // Dominance changes based on the region type. In a region with SSA dominance, uses + // insde the same block must follow defs. In other region kinds, uses and defs can + // come in any order inside a block. + return if a_block.borrow().has_ssa_dominance() { + // If the blocks are the same, then check if `b` is before `a` in the block. + a.borrow().is_before_in_block(&b) + } else { + true + }; + } + + // If the blocks are different, use the dominance tree to resolve the query + self.info + .dominance(&a_region) + .properly_dominates(Some(&a_block), Some(&b_block)) + } +} + +/// [PostDominanceInfo] provides a high-level API for querying post-dominance information. +/// +/// Note that this type is aware of the different types of regions, and returns a region-kind +/// specific notion of dominance. See [RegionKindInterface] for details. +pub struct PostDominanceInfo { + info: DominanceInfoBase, +} + +impl PostDominanceInfo { + #[doc(hidden)] + #[inline(always)] + pub(super) fn info(&self) -> &DominanceInfoBase { + &self.info + } + + /// Returns true if `a` post-dominates `b`. + /// + /// Note that if `a == b`, this returns true, if you want strict post-dominance, see + /// [Self::properly_post_dominates] instead. + /// + /// The specific details of how dominance is computed is specific to the types involved. See + /// the implementations of the [PostDominates] trait for that information. + pub fn post_dominates(&self, a: &A, b: &B) -> bool + where + A: PostDominates, + { + a.post_dominates(b, self) + } + + /// Returns true if `a` properly post-dominates `b`. + /// + /// This always returns false if `a == b`. + /// + /// The specific details of how dominance is computed is specific to the types involved. 
See + /// the implementations of the [PostDominates] trait for that information. + pub fn properly_post_dominates(&self, a: &A, b: &B) -> bool + where + A: PostDominates, + { + a.properly_post_dominates(b, self) + } +} + +/// This type carries the dominance information for a single region, lazily computed on demand. +pub struct RegionDominanceInfo { + /// The dominator tree for this region + domtree: LazyCell>>, RegionDomTreeCtor>, + /// A flag that indicates where blocks in this region have SSA dominance + has_ssa_dominance: bool, +} + +impl RegionDominanceInfo { + /// Construct a new [RegionDominanceInfo] for `region` + pub fn new(region: RegionRef) -> Self { + let r = region.borrow(); + let parent_op = r.parent().unwrap(); + // A region has SSA dominance if it tells us one way or the other, otherwise we must assume + // that it does. + let has_ssa_dominance = parent_op + .borrow() + .as_trait::() + .map(|rki| rki.has_ssa_dominance()) + .unwrap_or(true); + + Self::create(region, has_ssa_dominance, r.has_one_block()) + } + + fn create(region: RegionRef, has_ssa_dominance: bool, has_one_block: bool) -> Self { + // We only create a dominator tree for multi-block regions + if has_one_block { + Self { + domtree: LazyCell::new(RegionDomTreeCtor(None)), + has_ssa_dominance, + } + } else { + Self { + domtree: LazyCell::new(RegionDomTreeCtor(Some(region))), + has_ssa_dominance, + } + } + } + + /// Get the dominance tree for this region. + /// + /// Returns `None` if the region was empty or had only a single block. + pub fn dominance(&self) -> Option>> { + self.domtree.clone() + } +} + +/// This type provides shared functionality to both [DominanceInfo] and [PostDominanceInfo]. +#[derive(Default)] +pub(crate) struct DominanceInfoBase { + /// A mapping of regions to their dominator tree and a flag that indicates whether or not they + /// have SSA dominance. + /// + /// This map does not contain dominator trees for empty or single block regions, however we + /// still compute whether or not they have SSA dominance regardless. + dominance_infos: RefCell>>, +} + +#[allow(unused)] +impl DominanceInfoBase { + /// Compute dominance information for all of the regions in `op`. + pub fn new(op: OperationRef) -> Self { + let op = op.borrow(); + let has_ssa_dominance = op + .as_trait::() + .map(|rki| rki.has_ssa_dominance()) + .unwrap_or(true); + let dominance_infos = BTreeMap::from_iter(op.regions().iter().map(|r| { + let region = r.as_region_ref(); + let info = RegionDominanceInfo::::create( + region.clone(), + has_ssa_dominance, + r.has_one_block(), + ); + (region, info) + })); + + Self { + dominance_infos: RefCell::new(dominance_infos), + } + } + + /// Invalidate all dominance info. + /// + /// This can be used by clients that make major changes to the CFG and don't have a good way to + /// update it. + pub fn invalidate(&mut self) { + self.dominance_infos.get_mut().clear(); + } + + /// Invalidate dominance info for the given region. + /// + /// This can be used by clients that make major changes to the CFG and don't have a good way to + /// update it. + pub fn invalidate_region(&mut self, region: RegionRef) { + self.dominance_infos.get_mut().remove(®ion); + } + + /// Finds the nearest common dominator block for the two given blocks `a` and `b`. + /// + /// If no common dominator can be found, this function will return `None`. 
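A sketch of the kind of client this serves (crate-internal and illustrative only; it uses the n-ary helper defined just below):

    // Choose a block in which a computation used in all of `blocks` could be
    // hoisted: the nearest block dominating every use site.
    fn hoist_target(dom: &DominanceInfo, blocks: &[BlockRef]) -> Option<BlockRef> {
        dom.info().find_nearest_common_dominator_of_all(blocks.iter())
    }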
+ pub fn find_nearest_common_dominator_of( + &self, + a: Option<&BlockRef>, + b: Option<&BlockRef>, + ) -> Option { + // If either `a` or `b` are `None`, then conservatively return `None` + let a = a?; + let b = b?; + + // If they are the same block, then we are done. + if a == b { + return Some(a.clone()); + } + + // Try to find blocks that are in the same region. + let (a, b) = Block::get_blocks_in_same_region(a, b)?; + + // If the common ancestor in a common region is the same block, then return it. + if a == b { + return Some(a); + } + + // Otherwise, there must be multiple blocks in the region, check the dominance tree + self.dominance(&a.borrow().parent().unwrap()) + .find_nearest_common_dominator(&a, &b) + } + + /// Finds the nearest common dominator block for the given range of blocks. + /// + /// If no common dominator can be found, this function will return `None`. + pub fn find_nearest_common_dominator_of_all<'a>( + &self, + mut blocks: impl ExactSizeIterator, + ) -> Option { + let mut dom = blocks.next().cloned(); + + for block in blocks { + dom = self.find_nearest_common_dominator_of(dom.as_ref(), Some(block)); + } + + dom + } + + /// Get the root dominance node of the given region. + /// + /// Panics if `region` is not a multi-block region. + pub fn root_node(&self, region: &RegionRef) -> Rc { + self.get_dominance_info(region) + .domtree + .as_deref() + .expect("`region` isn't multi-block") + .root_node() + .expect("expected region to have a root node") + } + + /// Return the dominance node for the region containing `block`. + /// + /// Panics if `block` is not a member of a multi-block region. + pub fn node(&self, block: &BlockRef) -> Option> { + self.get_dominance_info(&block.borrow().parent().expect("block isn't attached to region")) + .domtree + .as_deref() + .expect("`block` isn't in a multi-block region") + .get(Some(block)) + } + + /// Return true if the specified block is reachable from the entry block of its region. + pub fn is_reachable_from_entry(&self, block: &BlockRef) -> bool { + // If this is the first block in its region, then it is trivially reachable. + if block.borrow().is_entry_block() { + return true; + } + + let region = block.borrow().parent().expect("block isn't attached to region"); + self.dominance(®ion).is_reachable_from_entry(block) + } + + /// Return true if operations in the specified block are known to obey SSA dominance rules. + /// + /// Returns false if the block is a graph region or unknown. + pub fn block_has_ssa_dominance(&self, block: &BlockRef) -> bool { + let region = block.borrow().parent().expect("block isn't attached to region"); + self.get_dominance_info(®ion).has_ssa_dominance + } + + /// Return true if operations in the specified region are known to obey SSA dominance rules. + /// + /// Returns false if the region is a graph region or unknown. + pub fn region_has_ssa_dominance(&self, region: &RegionRef) -> bool { + self.get_dominance_info(region).has_ssa_dominance + } + + /// Returns the dominance tree for `region`. + /// + /// Panics if `region` is a single-block region. + pub fn dominance(&self, region: &RegionRef) -> Rc> { + self.get_dominance_info(region) + .dominance() + .expect("cannot get dominator tree for single block regions") + } + + /// Return the dominance information for `region`. + /// + /// NOTE: The dominance tree for single-block regions will be `None` + fn get_dominance_info(&self, region: &RegionRef) -> Ref<'_, RegionDominanceInfo> { + // Check to see if we already have this information. 
+ self.dominance_infos + .borrow_mut() + .entry(region.clone()) + .or_insert_with(|| RegionDominanceInfo::new(region.clone())); + + Ref::map(self.dominance_infos.borrow(), |di| &di[region]) + } + + /// Return true if the specified block A properly dominates block B. + pub fn properly_dominates(&self, a: &BlockRef, b: &BlockRef) -> bool { + // A block dominates itself, but does not properly dominate itself. + if a == b { + return false; + } + + // If both blocks are not in the same region, `a` properly dominates `b` if `b` is defined + // in an operation region that (recursively) ends up being dominated by `a`. Walk up the + // ancestors of `b`. + let mut b = b.clone(); + let a_region = a.borrow().parent(); + if a_region != b.borrow().parent() { + // If we could not find a valid block `b` then it is not a dominator. + let Some(found) = a_region.as_ref().and_then(|r| r.borrow().find_ancestor_block(&b)) + else { + return false; + }; + + b = found; + + // Check to see if the ancestor of `b` is the same block as `a`. `a` properly dominates + // `b` if it contains an op that contains the `b` block + if a == &b { + return true; + } + } + + // Otherwise, they are two different blocks in the same region, use dominance tree + self.dominance(&a_region.unwrap()).properly_dominates(Some(a), Some(&b)) + } +} + +/// A faux-constructor for [RegionDominanceInfo] for use with [LazyCell] without boxing. +struct RegionDomTreeCtor(Option); +impl FnOnce<()> for RegionDomTreeCtor { + type Output = Option>>; + + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + self.0.and_then(|region| DomTreeBase::new(region).ok().map(Rc::new)) + } +} +impl FnMut<()> for RegionDomTreeCtor { + extern "rust-call" fn call_mut(&mut self, _args: ()) -> Self::Output { + self.0 + .as_ref() + .and_then(|region| DomTreeBase::new(region.clone()).ok().map(Rc::new)) + } +} +impl Fn<()> for RegionDomTreeCtor { + extern "rust-call" fn call(&self, _args: ()) -> Self::Output { + self.0 + .as_ref() + .and_then(|region| DomTreeBase::new(region.clone()).ok().map(Rc::new)) + } +} diff --git a/hir2/src/ir/dominance/nca.rs b/hir2/src/ir/dominance/nca.rs new file mode 100644 index 000000000..025b634eb --- /dev/null +++ b/hir2/src/ir/dominance/nca.rs @@ -0,0 +1,1531 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::cell::{Cell, Ref, RefCell}; + +use smallvec::SmallVec; + +use super::{DomTreeBase, DomTreeNode, DomTreeRoots}; +use crate::{ + cfg::{self, Graph, GraphDiff, Inverse}, + formatter::{DisplayOptional, DisplayValues}, + BlockRef, EntityWithId, Region, +}; + +/// [SemiNCAInfo] provides functionality for constructing a dominator tree for a control-flow graph +/// based on the Semi-NCA algorithm described in the following dissertation: +/// +/// [1] Linear-Time Algorithms for Dominators and Related Problems +/// Loukas Georgiadis, Princeton University, November 2005, pp. 21-23: +/// ftp://ftp.cs.princeton.edu/reports/2005/737.pdf +/// +/// The Semi-NCA algorithm runs in O(n^2) worst-case time but usually slightly faster than Simple +/// Lengauer-Tarjan in practice. +/// +/// O(n^2) worst cases happen when the computation of nearest common ancestors requires O(n) average +/// time, which is very unlikely in real world. If this ever turns out to be an issue, consider +/// implementing a hybrid algorithm that uses SLT to perform full constructions and SemiNCA for +/// incremental updates. +/// +/// The file uses the Depth Based Search algorithm to perform incremental updates (insertion and +/// deletions). 
The implemented algorithm is based on this publication: +/// +/// [2] An Experimental Study of Dynamic Dominators +/// Loukas Georgiadis, et al., April 12 2016, pp. 5-7, 9-10: +/// https://arxiv.org/pdf/1604.02711.pdf +pub struct SemiNCA { + /// Number to node mapping is 1-based. + virtual_node: RefCell, + num_to_node: SmallVec<[BlockRef; 64]>, + /// Infos are mapped to nodes using block indices + node_infos: RefCell>, + batch_updates: Option>, +} + +/// Get the successors (or predecessors, if `INVERSED == true`) of `node`, incorporating insertions +/// and deletions from `bui` if available. +/// +/// The use of "children" here changes meaning depending on: +/// +/// * Whether or not the graph traversal is `INVERSED` +/// * Whether or not the graph is a post-dominator tree (i.e. `IS_POST_DOM`) +/// +/// If we're traversing a post-dominator tree, then the "children" of a node are actually +/// predecessors of the block in the CFG. However, if the traversal is _also_ `INVERSED`, then the +/// children actually are successors of the block in the CFG. +/// +/// For a forward-dominance tree, "children" do correspond to successors in the CFG, but again, if +/// the traversal is `INVERSED`, then the children are actually predecessors. +/// +/// This function (and others in this module) are written in such a way that we can abstract over +/// whether the underlying dominator tree is a forward- or post-dominance tree, as much of the +/// implementation is identical. +pub fn get_children_with_batch_updates( + node: &BlockRef, + bui: Option<&BatchUpdateInfo>, +) -> SmallVec<[BlockRef; 8]> { + use crate::cfg::GraphDiff; + + if let Some(bui) = bui { + bui.pre_cfg_view.get_children::(node) + } else { + get_children::(node) + } +} + +/// Get the successors (or predecessors, if `INVERSED == true`) of `node`. +pub fn get_children(node: &BlockRef) -> SmallVec<[BlockRef; 8]> { + if INVERSED { + Inverse::::children(node.clone()).collect() + } else { + let mut r = BlockRef::children(node.clone()).collect::>(); + r.reverse(); + r + } +} + +#[derive(Default)] +pub struct NodeInfo { + num: Cell, + parent: Cell, + semi: Cell, + label: Cell, + idom: Cell>, + reverse_children: SmallVec<[u32; 4]>, +} +impl NodeInfo { + pub fn idom(&self) -> Option { + unsafe { &*self.idom.as_ptr() }.clone() + } + + #[inline] + pub fn num(&self) -> u32 { + self.num.get() + } + + #[inline] + pub fn parent(&self) -> u32 { + self.parent.get() + } + + #[inline] + pub fn semi(&self) -> u32 { + self.semi.get() + } + + #[inline] + pub fn label(&self) -> u32 { + self.label.get() + } + + #[inline] + pub fn reverse_children(&self) -> &[u32] { + &self.reverse_children + } +} + +/// [BatchUpdateInfo] represents a batch of insertion/deletion operations that have been applied to +/// the CFG. This information is used to incrementally update the dominance tree as changes are +/// made to the CFG. 
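Before the struct definition that follows, it may help to see the shape of a batch of CFG updates in miniature. This is a hedged sketch with hypothetical types (`Update`, `UpdateKind`, blocks as indices), not the crate's actual `cfg::CfgDiff` API: edges are inserted or deleted one legalized update at a time, in order.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UpdateKind {
    Insert,
    Delete,
}

// Hypothetical miniature of a single CFG edge update.
#[derive(Debug, Clone, Copy)]
struct Update {
    kind: UpdateKind,
    from: usize,
    to: usize,
}

// Apply a batch of updates to an adjacency list, mirroring how the incremental
// dominator-tree updater consumes one legalized update at a time.
fn apply_batch(succs: &mut Vec<Vec<usize>>, batch: &[Update]) {
    for u in batch {
        match u.kind {
            UpdateKind::Insert => succs[u.from].push(u.to),
            UpdateKind::Delete => succs[u.from].retain(|&s| s != u.to),
        }
    }
}

fn main() {
    let mut succs = vec![vec![1], vec![2], vec![]];
    apply_batch(
        &mut succs,
        &[
            Update { kind: UpdateKind::Insert, from: 0, to: 2 },
            Update { kind: UpdateKind::Delete, from: 1, to: 2 },
        ],
    );
    assert_eq!(succs, vec![vec![1, 2], vec![], vec![]]);
}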
+#[derive(Default, Clone)]
+pub struct BatchUpdateInfo<const IS_POST_DOM: bool> {
+    pub pre_cfg_view: cfg::CfgDiff<IS_POST_DOM>,
+    pub post_cfg_view: cfg::CfgDiff<IS_POST_DOM>,
+    pub num_legalized: usize,
+    // Remembers if the whole tree was recomputed at some point during the current batch update
+    pub is_recalculated: bool,
+}
+
+impl<const IS_POST_DOM: bool> BatchUpdateInfo<IS_POST_DOM> {
+    pub fn new(
+        pre_cfg_view: cfg::CfgDiff<IS_POST_DOM>,
+        post_cfg_view: Option<cfg::CfgDiff<IS_POST_DOM>>,
+    ) -> Self {
+        let num_legalized = pre_cfg_view.num_legalized_updates();
+        Self {
+            pre_cfg_view,
+            post_cfg_view: post_cfg_view.unwrap_or_default(),
+            num_legalized,
+            is_recalculated: false,
+        }
+    }
+}
+
+impl<const IS_POST_DOM: bool> SemiNCA<IS_POST_DOM> {
+    /// Obtain a fresh [SemiNCA] instance, using the provided set of [BatchUpdateInfo].
+    pub fn new(batch_updates: Option<BatchUpdateInfo<IS_POST_DOM>>) -> Self {
+        Self {
+            virtual_node: Default::default(),
+            num_to_node: Default::default(),
+            node_infos: Default::default(),
+            batch_updates,
+        }
+    }
+
+    /// Reset the [SemiNCA] state so it can be used to compute a dominator tree from scratch.
+    pub fn clear(&mut self) {
+        // Don't reset the pointer to BatchUpdateInfo here -- if there's an update in progress,
+        // we need this information to continue it.
+        self.virtual_node = Default::default();
+        self.num_to_node.clear();
+        self.node_infos.get_mut().clear();
+    }
+
+    /// Look up information about a block in the Semi-NCA state
+    pub fn node_info(&self, block: Option<&BlockRef>) -> Ref<'_, NodeInfo> {
+        match block {
+            None => self.virtual_node.borrow(),
+            Some(block) => {
+                let id = block.borrow().id();
+                let index = id.as_u32() as usize;
+
+                if index >= self.node_infos.borrow().len() {
+                    self.node_infos.borrow_mut().resize_with(index + 1, NodeInfo::default);
+                }
+
+                Ref::map(self.node_infos.borrow(), |ni| unsafe { ni.get_unchecked(index) })
+            }
+        }
+    }
+
+    /// Get a mutable reference to the stored information for `block`
+    pub fn node_info_mut(&mut self, block: Option<&BlockRef>) -> &mut NodeInfo {
+        match block {
+            None => self.virtual_node.get_mut(),
+            Some(block) => {
+                let id = block.borrow().id();
+                let index = id.as_u32() as usize;
+
+                let node_infos = self.node_infos.get_mut();
+                if index >= node_infos.len() {
+                    node_infos.resize_with(index + 1, NodeInfo::default);
+                }
+
+                unsafe { node_infos.get_unchecked_mut(index) }
+            }
+        }
+    }
+
+    /// Look up the immediate dominator for `block`, if it has one.
+    ///
+    /// A value of `None` for `block` is meaningless, as virtual nodes are only present in post-
+    /// dominance graphs, and always post-dominate all other nodes in the graph. However, it is
+    /// convenient to have many of the APIs in this module take an `Option` for uniformity.
+    pub fn idom(&self, block: Option<&BlockRef>) -> Option<BlockRef> {
+        self.node_info(block).idom()
+    }
+
+    /// Get or compute the dominance tree node information for `block`, in `tree`, using the
+    /// current Semi-NCA state.
+    pub fn node_for_block(
+        &self,
+        block: Option<&BlockRef>,
+        tree: &mut DomTreeBase<IS_POST_DOM>,
+    ) -> Option<Rc<DomTreeNode>> {
+        let node = tree.get(block);
+        if node.is_some() {
+            return node;
+        }
+
+        // Haven't calculated this node yet? Get or calculate the node for the immediate dominator.
+        let idom = self.idom(block);
+        let idom_node = match idom {
+            None => Some(tree.get(None).expect("expected idom or virtual node")),
+            Some(idom_block) => self.node_for_block(Some(&idom_block), tree),
+        };
+
+        // Add a new tree node for this node, and link it as a child of idom_node
+        Some(tree.create_node(block.cloned(), idom_node))
+    }
+
+    /// Custom DFS implementation which can skip nodes based on a provided predicate.
+ /// + /// It also collects reverse children so that we don't have to spend time getting predecessors + /// in SemiNCA. + /// + /// If `IsReverse` is set to true, the DFS walk will be performed backwards relative to IS_POST_DOM + /// -- using reverse edges for dominators and forward edges for post-dominators. + /// + /// If `succ_order` is specified then that is the order in which the DFS traverses the children, + /// otherwise the order is implied by the results of `get_children`. + pub fn run_dfs( + &mut self, + v: Option<&BlockRef>, + mut last_num: u32, + mut condition: C, + attach_to_num: u32, + succ_order: Option<&BTreeMap>, + ) -> u32 + where + C: FnMut(Option<&BlockRef>, Option<&BlockRef>) -> bool, + { + let v = v.expect("expected valid root node for search"); + + let mut worklist = + SmallVec::<[(BlockRef, u32); 64]>::from_iter([(v.clone(), attach_to_num)]); + + self.node_info_mut(Some(v)).parent.set(attach_to_num); + + while let Some((block, parent_num)) = worklist.pop() { + let block_info = self.node_info_mut(Some(&block)); + block_info.reverse_children.push(parent_num); + + // Visited nodes always have positive DFS numbers. + if block_info.num.get() != 0 { + continue; + } + + block_info.parent.set(parent_num); + last_num += 1; + block_info.num.set(last_num); + block_info.semi.set(last_num); + block_info.label.set(last_num); + self.num_to_node.push(block.clone()); + + let mut successors = if const { REVERSE != IS_POST_DOM } { + get_children_with_batch_updates::( + &block, + self.batch_updates.as_ref(), + ) + } else { + get_children_with_batch_updates::( + &block, + self.batch_updates.as_ref(), + ) + }; + if let Some(succ_order) = succ_order { + if successors.len() > 1 { + successors.sort_by(|a, b| succ_order[a].cmp(&succ_order[b])); + } + } + + for succ in successors.into_iter().filter(|succ| condition(Some(&block), Some(succ))) { + worklist.push((succ, last_num)); + } + } + + last_num + } + + // V is a predecessor of W. eval() returns V if V < W, otherwise the minimum + // of sdom(U), where U > W and there is a virtual forest path from U to V. The + // virtual forest consists of linked edges of processed vertices. + // + // We can follow Parent pointers (virtual forest edges) to determine the + // ancestor U with minimum sdom(U). But it is slow and thus we employ the path + // compression technique to speed up to O(m*log(n)). Theoretically the virtual + // forest can be organized as balanced trees to achieve almost linear + // O(m*alpha(m,n)) running time. But it requires two auxiliary arrays (Size + // and Child) and is unlikely to be faster than the simple implementation. + // + // For each vertex V, its Label points to the vertex with the minimal sdom(U) + // (Semi) in its path from V (included) to NodeToInfo[V].Parent (excluded). + fn eval<'a, 'b: 'a>( + v: u32, + last_linked: u32, + eval_stack: &mut SmallVec<[&'a NodeInfo; 32]>, + num_to_info: &'b [Ref<'b, NodeInfo>], + ) -> u32 { + let mut v_info = &*num_to_info[v as usize]; + if v_info.parent.get() < last_linked { + return v_info.label.get(); + } + + // Store ancestors except the last (root of a virtual tree) into a stack. + eval_stack.clear(); + loop { + let parent = &num_to_info[v_info.parent.get() as usize]; + eval_stack.push(v_info); + v_info = parent; + if v_info.parent.get() < last_linked { + break; + } + } + + // Path compression. 
Point each vertex's `parent` to the root and update its `label` if any + // of its ancestors `label` has a smaller `semi` + let mut p_info = v_info; + let mut p_label_info = &*num_to_info[p_info.label.get() as usize]; + while let Some(info) = eval_stack.pop() { + v_info = info; + v_info.parent.set(p_info.parent.get()); + let v_label_info = &*num_to_info[v_info.label.get() as usize]; + if p_label_info.semi.get() < v_label_info.semi.get() { + v_info.label.set(p_info.label.get()); + } else { + p_label_info = v_label_info; + } + p_info = v_info; + } + + v_info.label.get() + } + + /// This function requires DFS to be run before calling it. + pub fn run(&mut self) { + let next_num = self.num_to_node.len(); + let mut num_to_info = SmallVec::<[Ref<'_, NodeInfo>; 8]>::default(); + num_to_info.reserve(next_num); + + // Initialize idoms to spanning tree parents + for i in 0..next_num { + let v = &self.num_to_node[i]; + let v_info = self.node_info(Some(v)); + v_info.idom.set(Some(self.num_to_node[v_info.parent.get() as usize].clone())); + num_to_info.push(v_info); + } + + // Step 1: Calculate the semi-dominators of all vertices + let mut eval_stack = SmallVec::<[&NodeInfo; 32]>::default(); + for i in (2..(next_num - 1)).rev() { + let w_info = &num_to_info[i]; + + // Initialize the semi-dominator to point to the parent node. + w_info.semi.set(w_info.parent.get()); + for n in w_info.reverse_children.iter().copied() { + let semi_u = num_to_info + [Self::eval(n, i as u32 + 1, &mut eval_stack, &num_to_info) as usize] + .semi + .get(); + if semi_u < w_info.semi.get() { + w_info.semi.set(semi_u); + } + } + } + + // Step 2: Explicitly define the immediate dominator of each vertex. + // + // IDom[i] = NCA(SDom[i], SpanningTreeParent(i)) + // + // Note that the parents were stored in IDoms and later got invalidated during path + // compression in `eval` + for i in 2..next_num { + let w_info = &num_to_info[i]; + assert_ne!(w_info.semi.get(), 0); + let s_dom_num = num_to_info[w_info.semi.get() as usize].num.get(); + let mut w_idom_candidate = w_info.idom(); + loop { + let w_idom_candidate_info = self.node_info(w_idom_candidate.as_ref()); + if w_idom_candidate_info.num.get() <= s_dom_num { + break; + } + w_idom_candidate = w_idom_candidate_info.idom(); + } + + w_info.idom.set(w_idom_candidate); + } + } + + /// [PostDominatorTree] always has a virtual root that represents a virtual CFG node that serves + /// as a single exit from the region. + /// + /// All the other exits (CFG nodes with terminators and nodes in infinite loops) are logically + /// connected to this virtual CFG exit node. + /// + /// This function maps a null CFG node to the virtual root tree node. + fn add_virtual_root(&mut self) { + if const { IS_POST_DOM } { + assert!(self.num_to_node.is_empty(), "SemiNCAInfo must be freshly constructed"); + + let info = self.virtual_node.get_mut(); + info.num.set(1); + info.semi.set(1); + info.label.set(1); + + // num_to_node[1] = None + } + } + + /// For postdominators, nodes with no forward successors are trivial roots that + /// are always selected as tree roots. Roots with forward successors correspond + /// to CFG nodes within infinite loops. 
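Before the helper `has_forward_successors` is defined next, here is a concrete illustration of the trivial-root rule from the comment above: over a toy adjacency-list CFG (a hypothetical representation, not the crate's block types), the trivial post-dominator roots are exactly the blocks with no successors.

// Trivial post-dominator roots are exactly the blocks with no forward successors.
fn trivial_postdom_roots(succs: &[Vec<usize>]) -> Vec<usize> {
    (0..succs.len()).filter(|&b| succs[b].is_empty()).collect()
}

fn main() {
    // 0 -> 1, 1 -> {2, 3}; blocks 2 and 3 are exits.
    let succs = vec![vec![1], vec![2, 3], vec![], vec![]];
    assert_eq!(trivial_postdom_roots(&succs), vec![2, 3]);
}

Blocks inside infinite loops do have successors, so they are only discovered by the non-trivial root search described further below.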
+    fn has_forward_successors(
+        n: Option<&BlockRef>,
+        bui: Option<&BatchUpdateInfo<IS_POST_DOM>>,
+    ) -> bool {
+        let n = n.expect("`n` must be a valid node");
+        !get_children_with_batch_updates::<IS_POST_DOM, IS_POST_DOM>(n, bui).is_empty()
+    }
+
+    fn entry_node(tree: &DomTreeBase<IS_POST_DOM>) -> BlockRef {
+        tree.parent()
+            .borrow()
+            .entry_block_ref()
+            .expect("expected region to have an entry block")
+    }
+
+    pub fn find_roots(
+        tree: &DomTreeBase<IS_POST_DOM>,
+        bui: Option<&BatchUpdateInfo<IS_POST_DOM>>,
+    ) -> DomTreeRoots {
+        let mut roots = DomTreeRoots::default();
+
+        // For dominators, the region entry CFG node is always a tree root node.
+        if !IS_POST_DOM {
+            roots.push(Some(Self::entry_node(tree)));
+            return roots;
+        }
+
+        let mut snca = Self::new(bui.cloned());
+
+        // PostDominatorTree always has a virtual root.
+        snca.add_virtual_root();
+        let mut num = 1u32;
+
+        log::trace!("looking for trivial roots");
+
+        // Step 1: Find all the trivial roots that are going to definitely remain tree roots
+        let mut total = 0;
+        // It may happen that there are some new nodes in the CFG that are a result of the ongoing
+        // batch update, but we cannot really pretend that they don't exist -- we won't see any
+        // outgoing or incoming edges to them, so it's fine to discover them here, as they would
+        // end up appearing in the CFG at some point anyway.
+        let region = tree.parent().borrow();
+        let mut region_body = region.body().front();
+        while let Some(n) = region_body.as_pointer() {
+            region_body.move_next();
+            total += 1;
+            // If it has no successors, it is definitely a root
+            if !Self::has_forward_successors(Some(&n), bui) {
+                roots.push(Some(n.clone()));
+                // Run DFS so as not to walk this part of the CFG later.
+                num = snca.run_dfs::<false, _>(Some(&n), num, always_descend, 1, None);
+                log::trace!("found a new trivial root: {}", n.borrow().id());
+                match snca.num_to_node.get(num as usize) {
+                    None => log::trace!("last visited node: None"),
+                    Some(last_visited) => {
+                        log::trace!("last visited node: {}", last_visited.borrow().id())
+                    }
+                }
+            }
+        }
+
+        log::trace!("looking for non-trivial roots");
+
+        // Step 2: Find all non-trivial root candidates.
+        //
+        // Those are CFG nodes that are reverse-unreachable and were not visited by previous DFS
+        // walks (i.e. CFG nodes in infinite loops).
+        //
+        // Accounting for the virtual exit, see if we had any reverse-unreachable nodes.
+        let has_non_trivial_roots = total + 1 != num;
+        if has_non_trivial_roots {
+            // `succ_order` is the order of blocks in the region. It is needed to make the
+            // calculation of the `furthest_away` node and the whole PostDominanceTree immune to
+            // swapping successors (e.g. canonicalizing branch predicates). `succ_order` is
+            // initialized lazily, and only for successors of reverse-unreachable nodes.
+ #[derive(Default)] + struct LazySuccOrder { + succ_order: BTreeMap, + initialized: bool, + } + impl LazySuccOrder { + pub fn get_or_init<'a, 'b: 'a, const IS_POST_DOM: bool>( + &'b mut self, + region: &Region, + bui: Option<&'a BatchUpdateInfo>, + snca: &SemiNCA, + ) -> &'a BTreeMap { + if !self.initialized { + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + let n_num = snca.node_info(Some(&n)).num.get(); + if n_num == 0 { + for succ in + get_children_with_batch_updates::(&n, bui) + { + self.succ_order.insert(succ, 0); + } + } + } + + // Add mapping for all entries of succ_order + let mut node_num = 0; + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + node_num += 1; + if let Some(order) = self.succ_order.get_mut(&n) { + assert_eq!(*order, 0); + *order = node_num; + } + } + self.initialized = true; + } + + &self.succ_order + } + } + + let mut succ_order = LazySuccOrder::default(); + + // Make another DFS pass over all other nodes to find the reverse-unreachable blocks, + // and find the furthest paths we'll be able to make. + // + // Note that this looks N^2, but it's really 2N worst case, if every node is unreachable. + // This is because we are still going to only visit each unreachable node once, we may + // just visit it in two directions, depending on how lucky we get. + let mut region_body = region.body().front(); + while let Some(n) = region_body.as_pointer() { + region_body.move_next(); + + if snca.node_info(Some(&n)).num.get() == 0 { + log::trace!("visiting node {n}"); + + // Find the furthest away we can get by following successors, then + // follow them in reverse. This gives us some reasonable answer about + // the post-dom tree inside any infinite loop. In particular, it + // guarantees we get to the farthest away point along *some* + // path. This also matches the GCC's behavior. + // If we really wanted a totally complete picture of dominance inside + // this infinite loop, we could do it with SCC-like algorithms to find + // the lowest and highest points in the infinite loop. In theory, it + // would be nice to give the canonical backedge for the loop, but it's + // expensive and does not always lead to a minimal set of roots. 
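The "furthest away" idea from the comment above can be seen in miniature before the real implementation continues below: run a forward DFS from a reverse-unreachable node and take the last node to receive a DFS number. A hedged sketch over a toy graph, not the crate's `run_dfs`:

// Run an iterative DFS from `start` and report the last node to receive a DFS
// number -- a stand-in for the "furthest away" node used as a non-trivial
// post-dominator root inside an infinite loop.
fn furthest_away(succs: &[Vec<usize>], start: usize) -> usize {
    let mut num = vec![0u32; succs.len()];
    let mut next = 0u32;
    let mut last = start;
    let mut stack = vec![start];
    while let Some(n) = stack.pop() {
        if num[n] != 0 {
            continue;
        }
        next += 1;
        num[n] = next;
        last = n;
        for &s in &succs[n] {
            if num[s] == 0 {
                stack.push(s);
            }
        }
    }
    last
}

fn main() {
    // An infinite loop with no exits at all: 0 -> 1 -> 2 -> 0.
    let succs = vec![vec![1], vec![2], vec![0]];
    assert_eq!(furthest_away(&succs, 0), 2);
}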
+ log::trace!("running forward DFS.."); + + let succ_order = succ_order.get_or_init(®ion, bui, &snca); + let new_num = snca.run_dfs::( + Some(&n), + num, + always_descend, + num, + Some(succ_order), + ); + let furthest_away = snca.num_to_node[new_num as usize].clone(); + log::trace!( + "found a new furthest away node (non-trivial root): {furthest_away}", + ); + roots.push(Some(furthest_away.clone())); + log::trace!("previous `num`: {num}, new `num` {new_num}"); + log::trace!("removing DFS info.."); + for i in ((num + 1)..=new_num).rev() { + let n = snca.num_to_node[i as usize].clone(); + log::trace!("removing DFS info for {n}"); + *snca.node_info_mut(Some(&n)) = Default::default(); + snca.num_to_node.pop(); + } + let prev_num = num; + log::trace!("running reverse depth-first search"); + num = snca.run_dfs::( + Some(&furthest_away), + num, + always_descend, + 1, + None, + ); + for i in (prev_num + 1)..num { + match snca.num_to_node.get(i as usize) { + None => log::trace!("found virtual node"), + Some(n) => log::trace!("found node {n}"), + } + } + } + } + } + + log::trace!("total: {total}, num: {num}"); + log::trace!("discovered cfg nodes:"); + for i in 0..num { + log::trace!(" {i}: {}", &snca.num_to_node[i as usize]); + } + + assert_eq!(total + 1, num, "everything should have been visited"); + + // Step 3: If we found some non-trivial roots, make them non-redundant. + if has_non_trivial_roots { + Self::remove_redundant_roots(snca.batch_updates.as_ref(), &mut roots); + } + + log::trace!( + "found roots: {}", + DisplayValues::new(roots.iter().map(|v| DisplayOptional(v.as_ref()))) + ); + + roots + } + + // This function only makes sense for postdominators. + // + // We define roots to be some set of CFG nodes where (reverse) DFS walks have to start in order + // to visit all the CFG nodes (including the reverse-unreachable ones). + // + // When the search for non-trivial roots is done it may happen that some of the non-trivial + // roots are reverse-reachable from other non-trivial roots, which makes them redundant. This + // function removes them from the set of input roots. + fn remove_redundant_roots( + bui: Option<&BatchUpdateInfo>, + roots: &mut SmallVec<[Option; 4]>, + ) { + assert!(IS_POST_DOM, "this function is for post-dominators only"); + + log::trace!("removing redundant roots.."); + + let mut snca = Self::new(bui.cloned()); + + let mut root_index = 0; + 'roots: while root_index < roots.len() { + let root = &roots[root_index]; + + // Trivial roots are never redundant + if !Self::has_forward_successors(root.as_ref(), bui) { + continue; + } + + log::trace!("checking if {} remains a root", DisplayOptional(root.as_ref())); + snca.clear(); + + // Do a forward walk looking for the other roots. + let num = snca.run_dfs::(root.as_ref(), 0, always_descend, 0, None); + // Skip the start node and begin from the second one (note that DFS uses 1-based indexing) + for x in 2..(num as usize) { + let n = &snca.num_to_node[x]; + + // If we found another root in a (forward) DFS walk, remove the current root from + // the set of roots, as it is reverse-reachable from the other one. 
+ if roots.iter().any(|r| r.as_ref().is_some_and(|root| root == n)) { + log::trace!("forward DFS walk found another root {n}"); + log::trace!("removing root {}", DisplayOptional(root.as_ref())); + roots.swap_remove(root_index); + + // Root at the back takes the current root's place, so revisit the same index on + // the next iteration + continue 'roots; + } + } + + root_index += 1; + } + } + + pub fn do_full_dfs_walk(&mut self, tree: &DomTreeBase, condition: C) + where + for<'a> C: Copy + Fn(Option<&BlockRef>, Option<&BlockRef>) -> bool + 'a, + { + if const { !IS_POST_DOM } { + assert_eq!(tree.num_roots(), 1, "dominators should have a single root"); + self.run_dfs::(tree.roots()[0].as_ref(), 0, condition, 0, None); + return; + } + + self.add_virtual_root(); + let mut num = 1; + for root in tree.roots() { + num = self.run_dfs::(root.as_ref(), num, condition, 1, None); + } + } + + pub fn attach_new_subtree( + &mut self, + tree: &mut DomTreeBase, + attach_to: Rc, + ) { + // Attach the first unreachable block to `attach_to` + self.node_info(Some(&self.num_to_node[0])).idom.set(attach_to.block().cloned()); + // Loop over all of the discovered blocks in the function... + for w in self.num_to_node.iter() { + if tree.get(Some(w)).is_some() { + // Already computed the node before + continue; + } + + let idom = self.idom(Some(w)); + + // Get or compute the node for the immediate dominator + let idom_node = self.node_for_block(idom.as_ref(), tree); + + // Add a new tree node for this basic block, and link it as a child of idom_node + tree.create_node(Some(w.clone()), idom_node); + } + } + + pub fn reattach_existing_subtree( + &mut self, + tree: &mut DomTreeBase, + attach_to: Rc, + ) { + self.node_info(Some(&self.num_to_node[0])).idom.set(attach_to.block().cloned()); + for n in self.num_to_node.iter() { + let node = tree.get(Some(n)).unwrap(); + let idom = tree.get(self.node_info(Some(n)).idom().as_ref()); + node.set_idom(idom); + } + } + + // Checks if a node has proper support, as defined on the page 3 and later + // explained on the page 7 of [2]. 
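The proper-support check introduced in the comment above, defined next, leans on nearest-common-dominator queries. The standard way to answer such a query from idom links and per-node levels is to walk the deeper node up until the depths match, then climb both in lockstep until the walks meet. A minimal sketch, assuming a hypothetical array-encoded dominator tree rather than the crate's `DomTreeNode`:

// Nearest common ancestor in a dominator tree encoded as `idom[v]`, with
// `idom[root] == root`, plus precomputed `level[v]` (root at level 0).
fn nearest_common_dominator(idom: &[usize], level: &[usize], mut a: usize, mut b: usize) -> usize {
    // First equalize depths...
    while level[a] > level[b] {
        a = idom[a];
    }
    while level[b] > level[a] {
        b = idom[b];
    }
    // ...then climb in lockstep until the walks meet.
    while a != b {
        a = idom[a];
        b = idom[b];
    }
    a
}

fn main() {
    // Tree: 0 is the root; 1 and 2 are children of 0; 3 is a child of 1.
    let idom = [0, 0, 0, 1];
    let level = [0, 1, 1, 2];
    assert_eq!(nearest_common_dominator(&idom, &level, 3, 2), 0);
    assert_eq!(nearest_common_dominator(&idom, &level, 3, 1), 1);
}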
+ pub fn has_proper_support( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + node: &DomTreeNode, + ) -> bool { + log::trace!("is reachable from idom {node}"); + + let Some(block) = node.block() else { + return false; + }; + + let preds = if IS_POST_DOM { + get_children_with_batch_updates::(block, bui) + } else { + get_children_with_batch_updates::(block, bui) + }; + + for pred in preds { + log::trace!("pred {pred}"); + if tree.get(Some(&pred)).is_none() { + continue; + } + + let support = tree.find_nearest_common_dominator(block, &pred); + log::trace!("support {}", DisplayOptional(support.as_ref())); + if support.as_ref() != Some(block) { + log::trace!( + "{node} is reachable from support {}", + DisplayOptional(support.as_ref()) + ); + return true; + } + } + + false + } +} + +#[derive(Eq, PartialEq)] +struct InsertionInfoItem { + node: Rc, +} +impl From> for InsertionInfoItem { + fn from(node: Rc) -> Self { + Self { node } + } +} +impl PartialOrd for InsertionInfoItem { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for InsertionInfoItem { + #[inline] + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.node.level().cmp(&other.node.level()) + } +} + +#[derive(Default)] +struct InsertionInfo { + bucket: crate::adt::SmallPriorityQueue, + visited: crate::adt::SmallSet, 8>, + affected: SmallVec<[Rc; 8]>, +} + +/// Insertion and Deletion +impl SemiNCA { + pub fn insert_edge( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Option, + to: Option, + ) { + assert!( + from.as_ref().is_some() || IS_POST_DOM, + "'from' has to be a valid cfg node or a virtual root" + ); + let to = to.expect("expected a valid `to` node"); + + log::trace!("inserting edge {from:?} -> {to}"); + + let from_node = tree.get(from.as_ref()); + let from_node = if let Some(from_node) = from_node { + from_node + } else { + // Ignore edges from unreachable nodes for (forward) dominators. + if !IS_POST_DOM { + return; + } + + // The unreachable node becomes a new root -- a tree node for it. + let virtual_root = tree.get(None); + let from_node = tree.create_node(from.clone(), virtual_root); + tree.roots_mut().push(from); + from_node + }; + + tree.mark_invalid(); + + let to_node = tree.get(Some(&to)); + match to_node { + None => Self::insert_unreachable(tree, bui, from_node, to), + Some(to_node) => Self::insert_reachable(tree, bui, from_node, to_node), + } + } + + fn insert_unreachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Rc, + to: BlockRef, + ) { + log::trace!("inserting {from} -> {to} (unreachable)"); + + // Collect discovered edges to already reachable nodes + // Discover and connect nodes that became reachable with the insertion. 
+        let mut discovered_edges_to_reachable = SmallVec::default();
+        Self::compute_unreachable_dominators(
+            tree,
+            bui,
+            &to,
+            from.clone(),
+            &mut discovered_edges_to_reachable,
+        );
+
+        log::trace!("inserted {from} -> {to} (prev unreachable)");
+
+        // Use the discovered edges and insert discovered connecting (incoming) edges
+        for (from_block_ref, to_node) in discovered_edges_to_reachable {
+            log::trace!("inserting discovered connecting edge {from_block_ref:?} -> {to_node}");
+            let from_node = tree.get(from_block_ref.as_ref()).unwrap();
+            Self::insert_reachable(tree, bui, from_node, to_node);
+        }
+    }
+
+    fn insert_reachable(
+        tree: &mut DomTreeBase<IS_POST_DOM>,
+        bui: Option<&BatchUpdateInfo<IS_POST_DOM>>,
+        from: Rc<DomTreeNode>,
+        to: Rc<DomTreeNode>,
+    ) {
+        log::trace!("reachable {from} -> {to}");
+
+        if const { IS_POST_DOM } {
+            let rebuilt = SemiNCA::<true>::update_roots_before_insertion(
+                unsafe {
+                    core::mem::transmute::<&mut DomTreeBase<IS_POST_DOM>, &mut DomTreeBase<true>>(
+                        tree,
+                    )
+                },
+                bui.map(|bui| unsafe {
+                    core::mem::transmute::<&BatchUpdateInfo<IS_POST_DOM>, &BatchUpdateInfo<true>>(
+                        bui,
+                    )
+                }),
+                to.clone(),
+            );
+            if rebuilt {
+                return;
+            }
+        }
+
+        // find_nearest_common_dominator expects both pointers to be valid. When `from` is a
+        // virtual root, its CFG block pointer is `None`, so we have to "compute" the NCD manually.
+        let ncd_block = if from.block().is_some() && to.block().is_some() {
+            tree.find_nearest_common_dominator(from.block().unwrap(), to.block().unwrap())
+        } else {
+            None
+        };
+        assert!(ncd_block.is_some() || tree.is_post_dominator());
+        let ncd = tree.get(ncd_block.as_ref()).unwrap();
+
+        log::trace!("nearest common dominator == {ncd}");
+
+        // Based on Lemma 2.5 from [2], after insertion of (from, to), `v` is affected iff
+        // depth(ncd) + 1 < depth(v) and a path `P` from `to` to `v` exists where every `w` on `P`
+        // satisfies depth(v) <= depth(w).
+        //
+        // This reduces to a widest path problem (maximizing the depth of the minimum vertex in
+        // the path) which can be solved by a modified version of Dijkstra with a bucket queue
+        // (named depth-based search in [2]).
+        //
+        // `to` is in the path, so depth(ncd) + 1 < depth(v) <= depth(to). Nothing is affected if
+        // this does not hold.
+        let ncd_level = ncd.level();
+        if ncd_level + 1 >= to.level() {
+            return;
+        }
+
+        let mut insertion_info = InsertionInfo::default();
+        let mut unaffected_on_current_level = SmallVec::<[Rc<DomTreeNode>; 8]>::default();
+        insertion_info.bucket.push(to.clone().into());
+        insertion_info.visited.insert(to);
+
+        while let Some(InsertionInfoItem { mut node }) = insertion_info.bucket.pop() {
+            insertion_info.affected.push(node.clone());
+
+            let current_level = node.level();
+            log::trace!("mark {node} as affected, current level: {current_level}");
+
+            assert!(node.block().is_some() && insertion_info.visited.contains(&node));
+
+            loop {
+                // Unlike regular Dijkstra, we have an inner loop to expand more vertices. The
+                // first iteration is for the (affected) vertex popped from `insertion_info.bucket`
+                // and the rest are for vertices in `unaffected_on_current_level`, which may
+                // eventually expand to affected vertices.
+                //
+                // Invariant: there is an optimal path from `to` to `node` with the minimum depth
+                // being `current_level`.
+ for succ in get_children_with_batch_updates::( + node.block().unwrap(), + bui, + ) { + let succ_node = tree + .get(Some(&succ)) + .expect("unreachable successor found during reachable insertion"); + let succ_level = succ_node.level(); + log::trace!("successor {succ_node}, level = {succ_level}"); + + // There is an optimal path from `To` to Succ with the minimum depth + // being min(CurrentLevel, SuccLevel). + // + // If depth(NCD)+1 < depth(Succ) is not satisfied, Succ is unaffected + // and no affected vertex may be reached by a path passing through it. + // Stop here. Also, Succ may be visited by other predecessors but the + // first visit has the optimal path. Stop if Succ has been visited. + if succ_level <= ncd_level + 1 + || !insertion_info.visited.insert(succ_node.clone()) + { + continue; + } + + if succ_level > current_level { + // succ is unaffected, but it may (transitively) expand to affected vertices. + // Store it in unaffected_on_current_level + log::trace!("marking visiting not affected {succ}"); + unaffected_on_current_level.push(succ_node.clone()); + } else { + // The condition is satisfied (Succ is affected). Add Succ to the + // bucket queue. + log::trace!("add {succ} to a bucket"); + insertion_info.bucket.push(succ_node.clone().into()); + } + } + + if unaffected_on_current_level.is_empty() { + break; + } + + if let Some(n) = unaffected_on_current_level.pop() { + node = n; + } else { + break; + } + log::trace!("next: {node}"); + } + } + + // Finish by updating immediate dominators and levels. + Self::update_insertion(tree, bui, ncd, &insertion_info); + } + + pub fn delete_edge( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Option<&BlockRef>, + to: Option<&BlockRef>, + ) { + let from = from.expect("cannot disconnect virtual node"); + let to = to.expect("cannot disconnect virtual node"); + + log::trace!("deleting edge {from} -> {to}"); + + // Deletion in an unreachable subtree -- nothing to do. + let Some(from_node) = tree.get(Some(from)) else { + return; + }; + + let Some(to_node) = tree.get(Some(to)) else { + log::trace!("to {to} already unreachable -- there is no edge to delete",); + return; + }; + + let ncd_block = tree.find_nearest_common_dominator(from, to); + let ncd = tree.get(ncd_block.as_ref()); + + // If to dominates from -- nothing to do. + if Some(&to_node) != ncd.as_ref() { + tree.mark_invalid(); + + let to_idom = to_node.idom(); + log::trace!( + "ncd {}, to_idom {}", + DisplayOptional(ncd.as_ref()), + DisplayOptional(to_idom.as_ref()) + ); + + // To remains reachable after deletion (based on caption under figure 4, from [2]) + if (Some(&from_node) != to_idom.as_ref()) + || Self::has_proper_support(tree, bui, &to_node) + { + Self::delete_reachable(tree, bui, from_node, to_node) + } else { + Self::delete_unreachable(tree, bui, to_node) + } + + if const { IS_POST_DOM } { + SemiNCA::::update_roots_after_update( + unsafe { + core::mem::transmute::<&mut DomTreeBase, &mut DomTreeBase>( + tree, + ) + }, + bui.map(|bui| unsafe { + core::mem::transmute::<&BatchUpdateInfo, &BatchUpdateInfo>( + bui, + ) + }), + ); + } + } + } + + /// Handles deletions that leave destination nodes reachable. 
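Before the reachable-deletion handler below, it may help to compress the case analysis `delete_edge` just performed into a single predicate. This is a hedged simplification with hypothetical names, reduced to the three boolean facts the real code consults, not the crate's API:

#[derive(Debug, PartialEq)]
enum DeletionPlan {
    NothingToDo,    // the edge target already dominated the source
    RebuildSubtree, // target stays reachable: rebuild the affected subtree
    EraseSubtree,   // target becomes unreachable: erase its subtree
}

// Mirrors the branch structure of `delete_edge`: if `to` dominates `from`,
// the tree is unchanged; otherwise the target stays reachable when `from` is
// not its idom, or when it has proper support.
fn plan_deletion(
    to_dominates_from: bool,
    idom_of_to_is_from: bool,
    has_proper_support: bool,
) -> DeletionPlan {
    if to_dominates_from {
        DeletionPlan::NothingToDo
    } else if !idom_of_to_is_from || has_proper_support {
        DeletionPlan::RebuildSubtree
    } else {
        DeletionPlan::EraseSubtree
    }
}

fn main() {
    assert_eq!(plan_deletion(false, true, false), DeletionPlan::EraseSubtree);
    assert_eq!(plan_deletion(false, false, false), DeletionPlan::RebuildSubtree);
}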
+ fn delete_reachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + from: Rc, + to: Rc, + ) { + log::trace!("deleting reachable {from} -> {to} - rebuilding subtree.."); + + // Find the top of the subtree that needs to be rebuilt (based on the lemma 2.6 from [2]) + let to_idom = + tree.find_nearest_common_dominator(from.block().unwrap(), to.block().unwrap()); + assert!(to_idom.is_some() || tree.is_post_dominator()); + let to_idom_node = tree.get(to_idom.as_ref()).unwrap(); + let prev_idom_subtree = to_idom_node.idom(); + // Top of the subtree to rebuild is the root node. Rebuild the tree from scratch. + let Some(prev_idom_subtree) = prev_idom_subtree else { + log::trace!("the entire tree needs to be rebuilt"); + Self::compute_from_scratch(tree, bui.cloned()); + return; + }; + + // Only visit nodes in the subtree starting at `to` + let level = to_idom_node.level(); + let descend_below = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + tree.get(to).unwrap().level() > level + }; + + log::trace!("top of subtree {to_idom_node}"); + + let mut snca = Self::new(bui.cloned()); + snca.run_dfs::(to_idom.as_ref(), 0, descend_below, 0, None); + log::trace!("running Semi-NCA"); + snca.run(); + snca.reattach_existing_subtree(tree, prev_idom_subtree); + } + + /// Handle deletions that make destination node unreachable. + /// + /// (Based on the lemma 2.7 from the [2].) + fn delete_unreachable( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + to: Rc, + ) { + log::trace!("deleting unreachable subtree {to}"); + assert!(to.block().is_some()); + + if IS_POST_DOM { + // Deletion makes a region reverse-unreachable and creates a new root. + // + // Simulate that by inserting an edge from the virtual root to `to` and adding it as a new + // root. + log::trace!("deletion made a region reverse-unreachable, adding new root {to}"); + tree.roots_mut().push(to.block().cloned()); + Self::insert_reachable(tree, bui, tree.get(None).unwrap(), to); + return; + } + + let mut affected_queue = SmallVec::<[Option; 16]>::default(); + let level = to.level(); + + // Traverse destination node's descendants with greater level in the tree + // and collect visited nodes. + let descend_and_collect = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + let node = tree.get(to).unwrap(); + if node.level() > level { + return true; + } + let to = to.cloned(); + if !affected_queue.contains(&to) { + affected_queue.push(to) + } + false + }; + + let mut snca = Self::new(bui.cloned()); + let last_dfs_num = snca.run_dfs::(to.block(), 0, descend_and_collect, 0, None); + + let mut min_node = to.clone(); + // Identify the top of the subtree to rebuild by finding the NCD of all the affected nodes. + for n in affected_queue { + let node = tree.get(n.as_ref()).unwrap(); + let ncd_block = + tree.find_nearest_common_dominator(node.block().unwrap(), to.block().unwrap()); + assert!(ncd_block.is_some() || tree.is_post_dominator()); + let ncd = tree.get(ncd_block.as_ref()).unwrap(); + log::trace!( + "processing affected node {node} with: nearest common dominator = {ncd}, min node \ + = {min_node}" + ); + if ncd != node && ncd.level() < min_node.level() { + min_node = ncd; + } + } + + // Root reached, rebuild the whole tree from scratch. + if min_node.idom().is_none() { + log::trace!("the entire tree needs to be rebuilt"); + Self::compute_from_scratch(tree, bui.cloned()); + return; + } + + // Erase the unreachable subtree in reverse preorder to process all children before deleting + // their parent. 
+        for i in (1..=(last_dfs_num as usize)).rev() {
+            let n = &snca.num_to_node[i];
+            log::trace!("erasing node {n}");
+            tree.erase_node(n);
+        }
+
+        // The affected subtree starts at the `to` node -- there's no extra work to do.
+        if min_node == to {
+            return;
+        }
+
+        log::trace!("delete_unreachable: running dfs with min_node = {min_node}");
+        let min_level = min_node.level();
+        let prev_idom = min_node.idom().unwrap();
+        snca.clear();
+
+        // Identify nodes that remain in the affected subtree.
+        let descend_below = |_: Option<&BlockRef>, to: Option<&BlockRef>| -> bool {
+            let to_node = tree.get(to);
+            to_node.is_some_and(|to_node| to_node.level() > min_level)
+        };
+        snca.run_dfs::<false, _>(min_node.block(), 0, descend_below, 0, None);
+
+        log::trace!("previous idom(min_node) = {prev_idom}");
+        log::trace!("running Semi-NCA");
+
+        // Rebuild the remaining part of the affected subtree.
+        snca.run();
+        snca.reattach_existing_subtree(tree, prev_idom);
+    }
+
+    pub fn apply_updates(
+        tree: &mut DomTreeBase<IS_POST_DOM>,
+        mut pre_view_cfg: cfg::CfgDiff<IS_POST_DOM>,
+        post_view_cfg: cfg::CfgDiff<IS_POST_DOM>,
+    ) {
+        // Note: the `post_view_cfg` is only used when computing from scratch. Its data should
+        // already be included in the `pre_view_cfg` for incremental updates.
+        let num_updates = pre_view_cfg.num_legalized_updates();
+        match num_updates {
+            0 => (),
+            1 => {
+                // Take the fast path for a single update and avoid running the batch update
+                // machinery.
+                let update = pre_view_cfg.pop_update_for_incremental_updates();
+                let bui = if post_view_cfg.is_empty() {
+                    None
+                } else {
+                    Some(BatchUpdateInfo::new(post_view_cfg.clone(), Some(post_view_cfg)))
+                };
+                match update.kind() {
+                    cfg::CfgUpdateKind::Insert => {
+                        Self::insert_edge(
+                            tree,
+                            bui.as_ref(),
+                            Some(update.from().clone()),
+                            Some(update.to().clone()),
+                        );
+                    }
+                    cfg::CfgUpdateKind::Delete => {
+                        Self::delete_edge(
+                            tree,
+                            bui.as_ref(),
+                            Some(update.from()),
+                            Some(update.to()),
+                        );
+                    }
+                }
+            }
+            _ => {
+                let mut bui = BatchUpdateInfo::new(pre_view_cfg, Some(post_view_cfg));
+                // Recalculate the DominatorTree when the number of updates exceeds a threshold,
+                // which usually makes direct updating slower than recalculation. We select this
+                // threshold proportional to the size of the DominatorTree. The constant is
+                // selected by choosing the one with an acceptable performance on some real-world
+                // inputs.
+
+                // Make unit tests of the incremental algorithm work
+                // TODO(pauls): review this
+                if tree.len() <= 100 {
+                    if bui.num_legalized > tree.len() {
+                        Self::compute_from_scratch(tree, Some(bui.clone()));
+                    }
+                } else if bui.num_legalized > tree.len() / 40 {
+                    Self::compute_from_scratch(tree, Some(bui.clone()));
+                }
+
+                // If the DominatorTree was recalculated at some point, stop the batch updates.
+                // Full recalculations ignore batch updates and look at the actual CFG.
+                for _ in 0..bui.num_legalized {
+                    if bui.is_recalculated {
+                        break;
+                    }
+
+                    Self::apply_next_update(tree, &mut bui);
+                }
+            }
+        }
+    }
+
+    fn apply_next_update(
+        tree: &mut DomTreeBase<IS_POST_DOM>,
+        bui: &mut BatchUpdateInfo<IS_POST_DOM>,
+    ) {
+        // Popping the next update will move the `pre_view_cfg` to the next snapshot.
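A side note on the recalculation threshold in `apply_updates` above: it reduces to a small predicate. A sketch with a hypothetical function name, mirroring the 100 and 40 constants that appear in the diff:

// Fall back to a full recomputation when the batch is large relative to the tree.
fn should_recompute_from_scratch(num_updates: usize, tree_size: usize) -> bool {
    if tree_size <= 100 {
        num_updates > tree_size
    } else {
        num_updates > tree_size / 40
    }
}

fn main() {
    assert!(should_recompute_from_scratch(150, 100));
    assert!(!should_recompute_from_scratch(2, 100));
    assert!(should_recompute_from_scratch(30, 1000));
}

The body of `apply_next_update` then resumes below.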
+ let current_update = bui.pre_cfg_view.pop_update_for_incremental_updates(); + log::trace!("applying update: {current_update:?}"); + + match current_update.kind() { + cfg::CfgUpdateKind::Insert => { + Self::insert_edge( + tree, + Some(bui), + Some(current_update.from().clone()), + Some(current_update.to().clone()), + ); + } + cfg::CfgUpdateKind::Delete => { + Self::delete_edge( + tree, + Some(bui), + Some(current_update.from()), + Some(current_update.to()), + ); + } + } + } + + pub fn compute(tree: &mut DomTreeBase) { + Self::compute_from_scratch(tree, None); + } + + pub fn compute_from_scratch( + tree: &mut DomTreeBase, + mut bui: Option>, + ) { + use crate::cfg::GraphDiff; + + tree.reset(); + + // If the update is using the actual CFG, `bui` is `None`. If it's using a view, `bui` is + // `Some` and the `pre_cfg_view` is used. When calculating from scratch, make the + // `pre_cfg_view` equal to the `post_cfg_view`, so `post` is used. + let post_view_bui = bui.clone().and_then(|mut bui| { + if !bui.post_cfg_view.is_empty() { + bui.pre_cfg_view = bui.post_cfg_view.clone(); + Some(bui) + } else { + None + } + }); + + // This is rebuilding the whole tree, not incrementally, but `post_view_bui` is used in case + // the caller needs a dominator tree update with a cfg view + let mut snca = Self::new(post_view_bui); + + // Step 0: Number blocks in depth-first order, and initialize variables used in later stages + // of the algorithm. + let roots = Self::find_roots(tree, bui.as_ref()); + *tree.roots_mut() = roots; + snca.do_full_dfs_walk(tree, always_descend); + + snca.run(); + if let Some(bui) = bui.as_mut() { + bui.is_recalculated = true; + log::trace!("dominator tree recalculated, skipping future batch updates"); + } + + if tree.roots().is_empty() { + return; + } + + // Add a node for the root. If the tree is a post-dominator tree, it will be the virtual + // exit (denoted by a block ref of `None`), which post-dominates all real exits (including + // multiple exit blocks, infinite loops). 
+ let root = if IS_POST_DOM { + None + } else { + tree.roots()[0].clone() + }; + + let new_root = tree.create_node(root, None); + tree.set_root(new_root); + let root_node = tree.root_node().expect("expected root node"); + snca.attach_new_subtree(tree, root_node); + } + + fn update_insertion( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + ncd: Rc, + insertion_info: &InsertionInfo, + ) { + log::trace!("updating nearest common dominator = {ncd}"); + + for to_node in insertion_info.affected.iter() { + log::trace!("idom({to_node}) = {ncd}"); + to_node.set_idom(Some(ncd.clone())); + } + + if IS_POST_DOM { + SemiNCA::::update_roots_after_update( + unsafe { + core::mem::transmute::<&mut DomTreeBase, &mut DomTreeBase>( + tree, + ) + }, + bui.map(|bui| unsafe { + core::mem::transmute::<&BatchUpdateInfo, &BatchUpdateInfo>( + bui, + ) + }), + ); + } + } + + /// Connects nodes that become reachable with an insertion + fn compute_unreachable_dominators( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + root: &BlockRef, + incoming: Rc, + discovered_connecting_edges: &mut SmallVec<[(Option, Rc); 8]>, + ) { + assert!(tree.get(Some(root)).is_none(), "root must not be reachable"); + + // Visit only previously unreachable nodes + let unreachable_descender = |from: Option<&BlockRef>, to: Option<&BlockRef>| -> bool { + let to_node = tree.get(to); + match to_node { + None => true, + Some(to_node) => { + discovered_connecting_edges.push((from.cloned(), to_node)); + false + } + } + }; + + let mut snca = Self::new(bui.cloned()); + snca.run_dfs::(Some(root), 0, unreachable_descender, 0, None); + snca.run(); + snca.attach_new_subtree(tree, incoming); + + log::trace!("after adding unreachable nodes"); + } +} + +/// Verification +impl SemiNCA { + pub fn verify_roots(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_reachability(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_levels(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_dfs_numbers(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_parent_property(&self, _tree: &DomTreeBase) -> bool { + true + } + + pub fn verify_sibling_property(&self, _tree: &DomTreeBase) -> bool { + true + } +} + +impl SemiNCA { + /// Determines if some existing root becomes reverse-reachable after the insertion. + /// + /// Rebuilds the whole tree if that situation happens. + fn update_roots_before_insertion( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + to: Rc, + ) -> bool { + // Destination node is not attached to the virtual root, so it cannot be a root + if !tree.is_virtual_root(&to.idom().unwrap()) { + return false; + } + + if !tree.roots().contains(&to.block().cloned()) { + // To is not a root, nothing to update + return false; + } + + log::trace!("after the insertion, {to} is no longer a root - rebuilding the tree.."); + + Self::compute_from_scratch(tree, bui.cloned()); + true + } + + /// Updates the set of roots after insertion or deletion. + /// + /// This ensures that roots are the same when after a series of updates and when the tree would + /// be built from scratch. + fn update_roots_after_update( + tree: &mut DomTreeBase, + bui: Option<&BatchUpdateInfo>, + ) { + // The tree has only trivial roots -- nothing to update. 
+        if !tree.roots().iter().any(|n| Self::has_forward_successors(n.as_ref(), bui)) {
+            return;
+        }
+
+        // Recalculate the set of roots
+        let roots = Self::find_roots(tree, bui);
+        if !is_permutation(tree.roots(), &roots) {
+            // The roots chosen in the CFG have changed. This is because the incremental algorithm
+            // does not really know or use the set of roots and can make a different (implicit)
+            // decision about which node within an infinite loop becomes a root.
+            log::trace!(
+                "roots are different in updated trees - the entire tree needs to be rebuilt"
+            );
+            // It may be possible to update the tree without recalculating it, but we do not know
+            // yet how to do it, and it happens rarely in practice.
+            Self::compute_from_scratch(tree, bui.cloned());
+        }
+    }
+}
+
+fn is_permutation(a: &[Option<BlockRef>], b: &[Option<BlockRef>]) -> bool {
+    if a.len() != b.len() {
+        return false;
+    }
+    let set = crate::adt::SmallSet::<_, 4>::from_iter(a.iter().cloned());
+    for n in b {
+        if !set.contains(n) {
+            return false;
+        }
+    }
+    true
+}
+
+#[doc(hidden)]
+#[inline(always)]
+const fn always_descend(_: Option<&BlockRef>, _: Option<&BlockRef>) -> bool {
+    true
+}
diff --git a/hir2/src/ir/dominance/traits.rs b/hir2/src/ir/dominance/traits.rs
new file mode 100644
index 000000000..4d6f83e34
--- /dev/null
+++ b/hir2/src/ir/dominance/traits.rs
@@ -0,0 +1,193 @@
+use super::{DominanceInfo, PostDominanceInfo};
+use crate::{Block, Operation, Value};
+
+/// This trait is implemented on a type which has a dominance relationship with `Rhs`.
+pub trait Dominates<Rhs = Self> {
+    /// Returns true if `self` dominates `other`.
+    ///
+    /// In cases where `Rhs = Self`, implementations should return true when `self == other`.
+    ///
+    /// For a stricter form of dominance, use [Dominates::properly_dominates].
+    fn dominates(&self, other: &Rhs, dom_info: &DominanceInfo) -> bool;
+    /// Returns true if `self` properly dominates `other`.
+    ///
+    /// In cases where `Rhs = Self`, implementations should return false when `self == other`.
+    fn properly_dominates(&self, other: &Rhs, dom_info: &DominanceInfo) -> bool;
+}
+
+/// This trait is implemented on a type which has a post-dominance relationship with `Rhs`.
+pub trait PostDominates<Rhs = Self> {
+    /// Returns true if `self` post-dominates `other`.
+    ///
+    /// In cases where `Rhs = Self`, implementations should return true when `self == other`.
+    ///
+    /// For a stricter form of post-dominance, use [PostDominates::properly_post_dominates].
+    fn post_dominates(&self, other: &Rhs, dom_info: &PostDominanceInfo) -> bool;
+    /// Returns true if `self` properly post-dominates `other`.
+    ///
+    /// In cases where `Rhs = Self`, implementations should return false when `self == other`.
+    fn properly_post_dominates(&self, other: &Rhs, dom_info: &PostDominanceInfo) -> bool;
+}
+
+/// The dominance relationship between two blocks.
+impl Dominates for Block {
+    /// Returns true if `a == b` or `a` properly dominates `b`.
+    fn dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool {
+        core::ptr::addr_eq(self, other) || self.properly_dominates(other, dom_info)
+    }
+
+    /// Returns true if `a != b` and:
+    ///
+    /// * `a` is an ancestor of `b`
+    /// * The region containing `a` also contains `b` or some ancestor of `b`, and `a` dominates
+    ///   that block in that kind of region.
+    /// * In SSA regions, `a` properly dominates `b` if all control flow paths from the entry
+    ///   block to `b` flow through `a`.
+    /// * In graph regions, all blocks dominate all other blocks.
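The bullet points above pin down block dominance in SSA regions; the `properly_dominates` implementation then resumes below. The classic way to see the SSA rule is the iterative data-flow formulation, sketched here on a diamond-shaped CFG. This is a naive O(n^2) fixpoint, fine for illustration only; the Semi-NCA construction in this diff computes the same relation far faster:

use std::collections::BTreeSet;

// Naive iterative dominance: dom(v) is the intersection of dom(p) over
// predecessors p, plus v itself. Block 0 is assumed to be the entry.
fn dominators(preds: &[Vec<usize>]) -> Vec<BTreeSet<usize>> {
    let n = preds.len();
    let all: BTreeSet<usize> = (0..n).collect();
    let mut dom = vec![all; n];
    dom[0] = BTreeSet::from([0]);
    let mut changed = true;
    while changed {
        changed = false;
        for v in 1..n {
            let mut new: BTreeSet<usize> = preds[v]
                .iter()
                .map(|&p| dom[p].clone())
                .reduce(|a, b| a.intersection(&b).cloned().collect())
                .unwrap_or_default();
            new.insert(v);
            if new != dom[v] {
                dom[v] = new;
                changed = true;
            }
        }
    }
    dom
}

fn main() {
    // Diamond: 0 -> {1, 2}, 1 -> 3, 2 -> 3 (given as predecessor lists).
    let preds = vec![vec![], vec![0], vec![0], vec![1, 2]];
    let dom = dominators(&preds);
    // The entry properly dominates the merge block...
    assert!(dom[3].contains(&0));
    // ...but neither arm of the diamond does.
    assert!(!dom[3].contains(&1) && !dom[3].contains(&2));
}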
+ fn properly_dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + dom_info.info().properly_dominates(&self.as_block_ref(), &other.as_block_ref()) + } +} + +/// The post-dominance relationship between two blocks. +impl PostDominates for Block { + fn post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_post_dominates(other, dom_info) + } + + /// Returns true if `a != b` and: + /// + /// * `a` is an ancestor of `b` + /// * The region containing `a` also contains `b` or some ancestor of `b`, and `a` dominates + /// that block in that kind of region. + /// * In SSA regions, `a` properly post-dominates `b` if all control flow paths from `b` to + /// an exit node, flow through `a`. + /// * In graph regions, all blocks post-dominate all other blocks. + fn properly_post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + dom_info.info().properly_dominates(&self.as_block_ref(), &other.as_block_ref()) + } +} + +/// The dominance relationship for operations +impl Dominates for Operation { + fn dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_dominates(other, dom_info) + } + + /// Returns true if `a != b`, and: + /// + /// * `a` and `b` are in the same block, and `a` properly dominates `b` within the block, or + /// * the block that contains `a` properly dominates the block that contains `b`. + /// * `b` is enclosed in a region of `a` + /// + /// In any SSA region, `a` dominates `b` in the same block if `a` precedes `b`. In a graph + /// region all operations in a block dominate all other operations in the same block. + fn properly_dominates(&self, other: &Self, dom_info: &DominanceInfo) -> bool { + let a = self.as_operation_ref(); + let b = other.as_operation_ref(); + dom_info.properly_dominates_with_options(&a, &b, /*enclosing_op_ok= */ true) + } +} + +/// The post-dominance relationship for operations +impl PostDominates for Operation { + fn post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + core::ptr::addr_eq(self, other) || self.properly_post_dominates(other, dom_info) + } + + /// Returns true if `a != b`, and: + /// + /// * `a` and `b` are in the same block, and `a` properly post-dominates `b` within the block + /// * the block that contains `a` properly post-dominates the block that contains `b`. + /// * `b` is enclosed in a region of `a` + /// + /// In any SSA region, `a` post-dominates `b` in the same block if `b` precedes `a`. In a graph + /// region all operations in a block post-dominate all other operations in the same block. + fn properly_post_dominates(&self, other: &Self, dom_info: &PostDominanceInfo) -> bool { + let a_block = self.parent().expect("`self` must be in a block"); + let mut b_block = other.parent().expect("`other` must be in a block"); + + // An instruction post dominates, but does not properly post-dominate itself unless this is + // a graph region. + if core::ptr::addr_eq(self, other) { + return !a_block.borrow().has_ssa_dominance(); + } + + // If these ops are in different regions, then normalize one into the other. + let a_region = a_block.borrow().parent(); + let b_region = b_block.borrow().parent(); + let a = self.as_operation_ref(); + let mut b = other.as_operation_ref(); + if a_region != b_region { + // Walk up `b`'s region tree until we find an operation in `a`'s region that encloses + // it. If this fails, then we know there is no post-dominance relation. 
+ let Some(found) = a_region.as_ref().and_then(|r| r.borrow().find_ancestor_op(&b)) + else { + return false; + }; + b = found; + b_block = b.borrow().parent().unwrap(); + assert!(b_block.borrow().parent() == a_region); + + // If `a` encloses `b`, then we consider it to post-dominate. + if a == b { + return true; + } + } + + // Ok, they are in the same region. If they are in the same block, check if `b` is before + // `a` in the block. + if a_block == b_block { + // Dominance changes based on the region type + return if a_block.borrow().has_ssa_dominance() { + // If the blocks are the same, then check if `b` is before `a` in the block. + b.borrow().is_before_in_block(&a) + } else { + true + }; + } + + // If the blocks are different, check if `a`'s block post-dominates `b`'s + dom_info + .info() + .dominance(&a_region.unwrap()) + .properly_dominates(Some(&a_block), Some(&b_block)) + } +} + +/// The dominance relationship between a value and an operation, e.g. between a definition of a +/// value and a user of that same value. +impl Dominates for dyn Value { + /// Return true if the definition of `self` dominates a use by operation `other`. + fn dominates(&self, other: &Operation, dom_info: &DominanceInfo) -> bool { + self.get_defining_op().is_some_and(|op| op == other.as_operation_ref()) + || self.properly_dominates(other, dom_info) + } + + /// Returns true if the definition of `self` properly dominates `other`. + /// + /// This requires the value to either be a block argument, where the block containing `other` + /// is dominated by the block defining `self`, OR that the value is an operation result, and + /// the defining op of `self` properly dominates `other`. + /// + /// If the defining op of `self` encloses `b` in one of its regions, `a` does not dominate `b`. + fn properly_dominates(&self, other: &Operation, dom_info: &DominanceInfo) -> bool { + // Block arguments properly dominate all operations in their own block, so we use a + // dominates check here, not a properly_dominates check. + if let Some(block_arg) = self.downcast_ref::() { + return block_arg + .owner() + .borrow() + .dominates(&other.parent().unwrap().borrow(), dom_info); + } + + // `a` properly dominates `b` if the operation defining `a` properly dominates `b`, but `a` + // does not itself enclose `b` in one of its regions. + let defining_op = self.get_defining_op().unwrap(); + dom_info.properly_dominates_with_options( + &defining_op, + &other.as_operation_ref(), + /*enclosing_op_ok= */ false, + ) + } +} diff --git a/hir2/src/ir/dominance/tree.rs b/hir2/src/ir/dominance/tree.rs new file mode 100644 index 000000000..decab0cd4 --- /dev/null +++ b/hir2/src/ir/dominance/tree.rs @@ -0,0 +1,911 @@ +use alloc::rc::Rc; +use core::{ + cell::{Cell, RefCell}, + fmt, + num::NonZeroU32, +}; + +use smallvec::{smallvec, SmallVec}; + +use super::{BatchUpdateInfo, SemiNCA}; +use crate::{ + cfg::{self, Graph, Inverse, InvertibleGraph}, + formatter::DisplayOptional, + BlockRef, EntityWithId, RegionRef, +}; + +#[derive(Debug, thiserror::Error)] +pub enum DomTreeError { + /// Tried to compute a dominator tree for an empty region + #[error("unable to create dominance tree for empty region")] + EmptyRegion, +} + +/// The level of verification to use with [DominatorTreeBase::verify] +pub enum DomTreeVerificationLevel { + /// Checks basic tree structure and compares with a freshly constructed tree + /// + /// O(n^2) time worst case, but is faster in practice. 
+    Fast,
+    /// Checks if the tree is correct, but compares it to a freshly constructed tree instead of
+    /// checking the sibling property.
+    ///
+    /// O(n^2) time.
+    Basic,
+    /// Verifies if the tree is correct by making sure all the properties, including the parent
+    /// and sibling property, hold.
+    ///
+    /// O(n^3) time.
+    Full,
+}
+
+/// A forward dominance tree
+pub type DominanceTree = DomTreeBase<false>;
+
+/// A post (backward) dominance tree
+pub type PostDominanceTree = DomTreeBase<true>;
+
+pub type DomTreeRoots = SmallVec<[Option<BlockRef>; 4]>;
+
+/// A dominator tree implementation that abstracts over the type of dominance it represents.
+pub struct DomTreeBase<const IS_POST_DOM: bool> {
+    /// The roots from which dominance is traced.
+    ///
+    /// For forward dominance trees, there is always a single root. For post-dominance trees,
+    /// there may be multiple, one for each exit from the region.
+    roots: DomTreeRoots,
+    /// The nodes represented in this dominance tree
+    nodes: SmallVec<[(Option<BlockRef>, Rc<DomTreeNode>); 64]>,
+    /// The root dominance tree node.
+    root: Option<Rc<DomTreeNode>>,
+    /// The parent region for which this dominance tree was computed
+    parent: RegionRef,
+    /// Whether this dominance tree is valid (true), or outdated (false)
+    valid: Cell<bool>,
+    /// A counter for expensive queries that may cause us to perform some extra work in order to
+    /// speed up those queries after a certain point.
+    slow_queries: Cell<u32>,
+}
+
+/// A node in a [DomTreeBase].
+pub struct DomTreeNode {
+    /// The block represented by this node
+    block: Option<BlockRef>,
+    /// The immediate dominator of this node, if applicable
+    idom: Cell<Option<Rc<DomTreeNode>>>,
+    /// The children of this node in the tree
+    children: RefCell<SmallVec<[Rc<DomTreeNode>; 4]>>,
+    /// The depth of this node in the tree
+    level: Cell<u32>,
+    /// The DFS visitation order (forward)
+    num_in: Cell<Option<NonZeroU32>>,
+    /// The DFS visitation order (backward)
+    num_out: Cell<Option<NonZeroU32>>,
+}
+
+impl fmt::Display for DomTreeNode {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}", DisplayOptional(self.block.as_ref().map(|b| b.borrow().id()).as_ref()))
+    }
+}
+
+impl fmt::Debug for DomTreeNode {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        use crate::EntityWithId;
+
+        f.debug_struct("DomTreeNode")
+            .field_with("block", |f| match self.block.as_ref() {
+                None => f.write_str("None"),
+                Some(block_ref) => write!(f, "{}", block_ref.borrow().id()),
+            })
+            .field("idom", unsafe { &*self.idom.as_ptr() })
+            .field_with("children", |f| {
+                f.debug_list().entries(self.children.borrow().iter()).finish()
+            })
+            .field("level", &self.level.get())
+            .field("num_in", &self.num_in.get())
+            .field("num_out", &self.num_out.get())
+            .finish()
+    }
+}
+
+/// An iterator over nodes in a dominance tree produced by a depth-first, pre-order traversal
+pub type PreOrderDomTreeIter = cfg::PreOrderIter<Rc<DomTreeNode>>;
+/// An iterator over nodes in a dominance tree produced by a depth-first, post-order traversal
+pub type PostOrderDomTreeIter = cfg::PostOrderIter<Rc<DomTreeNode>>;
+
+impl Graph for Rc<DomTreeNode> {
+    type ChildEdgeIter = DomTreeSuccessorIter;
+    type ChildIter = DomTreeSuccessorIter;
+    type Edge = Rc<DomTreeNode>;
+    type Node = Rc<DomTreeNode>;
+
+    fn size(&self) -> usize {
+        self.children.borrow().len()
+    }
+
+    fn children(parent: Self::Node) -> Self::ChildIter {
+        DomTreeSuccessorIter::new(parent)
+    }
+
+    fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter {
+        DomTreeSuccessorIter::new(parent)
+    }
+
+    fn edge_dest(edge: Self::Edge) -> Self::Node {
+        // The edge is the child node
+        edge
+    }
+
+    fn entry_node(&self) -> Self::Node {
+        Rc::clone(self)
+    }
+}
+
+pub struct DomTreeSuccessorIter {
+    node: Rc<DomTreeNode>,
+    num_children: usize,
+    index: usize,
+}
+impl DomTreeSuccessorIter {
+    pub fn new(node: Rc<DomTreeNode>) -> Self {
+        let num_children = node.num_children();
+        Self {
+            node,
+            num_children,
+            index: 0,
+        }
+    }
+}
+impl core::iter::FusedIterator for DomTreeSuccessorIter {}
+impl ExactSizeIterator for DomTreeSuccessorIter {
+    #[inline]
+    fn len(&self) -> usize {
+        self.num_children.saturating_sub(self.index)
+    }
+
+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.index >= self.num_children
+    }
+}
+impl Iterator for DomTreeSuccessorIter {
+    type Item = Rc<DomTreeNode>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index >= self.num_children {
+            return None;
+        }
+        let index = self.index;
+        self.index += 1;
+        Some(self.node.children.borrow()[index].clone())
+    }
+}
+impl DoubleEndedIterator for DomTreeSuccessorIter {
+    fn next_back(&mut self) -> Option<Self::Item> {
+        // The back of the iterator is exhausted once it meets the front at `self.index`
+        if self.index >= self.num_children {
+            return None;
+        }
+        self.num_children -= 1;
+        Some(self.node.children.borrow()[self.num_children].clone())
+    }
+}
+
+impl DomTreeNode {
+    /// Create a new node for `block`, with the specified immediate dominator.
+    ///
+    /// If `block` is `None`, this must be a node in a post-dominator tree, and the resulting node
+    /// is a virtual node that post-dominates all nodes in the tree.
+    pub fn new(block: Option<BlockRef>, idom: Option<Rc<DomTreeNode>>) -> Self {
+        Self {
+            block,
+            idom: Cell::new(idom),
+            children: Default::default(),
+            level: Cell::new(0),
+            num_in: Cell::new(None),
+            num_out: Cell::new(None),
+        }
+    }
+
+    /// Build this node with the specified immediate dominator.
+    pub fn with_idom(self, idom: Rc<DomTreeNode>) -> Self {
+        self.level.set(idom.level.get() + 1);
+        self.idom.set(Some(idom));
+        self
+    }
+
+    pub fn block(&self) -> Option<&BlockRef> {
+        self.block.as_ref()
+    }
+
+    pub fn idom(&self) -> Option<Rc<DomTreeNode>> {
+        unsafe { &*self.idom.as_ptr() }.clone()
+    }
+
+    pub(super) fn set_idom(&self, idom: Option<Rc<DomTreeNode>>) {
+        self.idom.set(idom);
+    }
+
+    #[inline(always)]
+    pub fn level(&self) -> u32 {
+        self.level.get()
+    }
+
+    pub fn is_leaf(&self) -> bool {
+        self.children.borrow().is_empty()
+    }
+
+    pub fn num_children(&self) -> usize {
+        self.children.borrow().len()
+    }
+
+    pub fn add_child(&self, child: Rc<DomTreeNode>) {
+        self.children.borrow_mut().push(child);
+    }
+
+    pub fn clear_children(&self) {
+        self.children.borrow_mut().clear();
+    }
+
+    /// Returns true if `self` is dominated by `other` in the tree.
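+    ///
+    /// Concretely, this checks that `self`'s `[num_in, num_out]` DFS interval is nested inside
+    /// `other`'s, which requires that `update_dfs_numbers` has been run.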
+    pub fn is_dominated_by(&self, other: &Self) -> bool {
+        assert!(
+            self.num_in.get().is_some() && other.num_in.get().is_some(),
+            "you forgot to call update_dfs_numbers"
+        );
+        self.num_in.get().is_some_and(|a| other.num_in.get().is_some_and(|b| a >= b))
+            && self.num_out.get().is_some_and(|a| other.num_out.get().is_some_and(|b| a <= b))
+    }
+
+    /// Recomputes this node's depth in the dominator tree
+    fn update_level(self: Rc<Self>) {
+        let idom_level = self.idom().expect("expected to have an immediate dominator").level();
+        if self.level() == idom_level + 1 {
+            return;
+        }
+
+        let mut stack = SmallVec::<[Rc<DomTreeNode>; 64]>::from_iter([self.clone()]);
+        while let Some(current) = stack.pop() {
+            current.level.set(current.idom().unwrap().level() + 1);
+            for child in current.children.borrow().iter() {
+                assert!(child.idom().is_some());
+                if child.level() != child.idom().unwrap().level() + 1 {
+                    stack.push(Rc::clone(child));
+                }
+            }
+        }
+    }
+}
+
+impl Eq for DomTreeNode {}
+impl PartialEq for DomTreeNode {
+    fn eq(&self, other: &Self) -> bool {
+        self.block == other.block
+    }
+}
+
+impl DomTreeBase<false> {
+    #[inline]
+    pub fn root(&self) -> &BlockRef {
+        self.roots[0].as_ref().unwrap()
+    }
+}
+
+impl<const IS_POST_DOM: bool> DomTreeBase<IS_POST_DOM> {
+    /// Compute a dominator tree for `region`
+    pub fn new(region: RegionRef) -> Result<Self, DomTreeError> {
+        let entry = region.borrow().entry_block_ref().ok_or(DomTreeError::EmptyRegion)?;
+        let root = Rc::new(DomTreeNode::new(Some(entry.clone()), None));
+        let nodes = smallvec![(Some(entry.clone()), root.clone())];
+        let roots = smallvec![Some(entry)];
+
+        let mut this = Self {
+            parent: region,
+            root: Some(root),
+            roots,
+            nodes,
+            valid: Cell::new(false),
+            slow_queries: Cell::new(0),
+        };
+
+        this.compute();
+
+        Ok(this)
+    }
+
+    #[inline]
+    pub fn parent(&self) -> &RegionRef {
+        &self.parent
+    }
+
+    pub fn len(&self) -> usize {
+        self.nodes.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.nodes.is_empty()
+    }
+
+    #[inline]
+    pub fn num_roots(&self) -> usize {
+        self.roots.len()
+    }
+
+    #[inline]
+    pub fn roots(&self) -> &[Option<BlockRef>] {
+        &self.roots
+    }
+
+    #[inline]
+    pub fn roots_mut(&mut self) -> &mut DomTreeRoots {
+        &mut self.roots
+    }
+
+    pub(super) fn set_root(&mut self, root: Rc<DomTreeNode>) {
+        self.root = Some(root);
+    }
+
+    /// Returns true if this tree is a post-dominance tree.
+    #[inline(always)]
+    pub const fn is_post_dominator(&self) -> bool {
+        IS_POST_DOM
+    }
+
+    pub(super) fn mark_invalid(&self) {
+        self.valid.set(false);
+    }
+
+    /// Get the node for `block`, if one exists in the tree.
+    ///
+    /// Use `None` to get the virtual node, if this is a post-dominator tree
+    pub fn get(&self, block: Option<&BlockRef>) -> Option<Rc<DomTreeNode>> {
+        self.node_index(block)
+            .map(|index| unsafe { self.nodes.get_unchecked(index).1.clone() })
+    }
+
+    #[inline]
+    fn node_index(&self, block: Option<&BlockRef>) -> Option<usize> {
+        assert!(
+            block.is_none_or(|block| block
+                .borrow()
+                .parent()
+                .is_some_and(|parent| parent == self.parent)),
+            "cannot get dominance info of block with different parent"
+        );
+        if let Some(block) = block {
+            self.nodes.iter().position(|(b, _)| b.as_ref().is_some_and(|b| b == block))
+        } else {
+            self.nodes.iter().position(|(b, _)| b.is_none())
+        }
+    }
+
+    /// Returns the entry node for the CFG of the region.
+    ///
+    /// However, if this tree represents the post-dominance relations for a region, this root may
+    /// be a node with `block` set to `None`. This is the case when there are multiple exit nodes
+    /// from a particular function.
+    /// Consumers of post-dominance information must be capable of dealing with this possibility.
+    pub fn root_node(&self) -> Option<Rc<DomTreeNode>> {
+        self.root.clone()
+    }
+
+    /// Get all nodes dominated by `r`, including `r` itself
+    pub fn get_descendants(&self, r: &BlockRef) -> SmallVec<[BlockRef; 2]> {
+        let mut results = SmallVec::default();
+        let Some(rn) = self.get(Some(r)) else {
+            return results;
+        };
+        let mut worklist = SmallVec::<[Rc<DomTreeNode>; 8]>::default();
+        worklist.push(rn);
+
+        while let Some(n) = worklist.pop() {
+            let Some(n_block) = n.block() else {
+                continue;
+            };
+            results.push(n_block.clone());
+            worklist.extend(n.children.borrow().iter().cloned());
+        }
+
+        results
+    }
+
+    /// Return true if `a` is dominated by the entry block of the region containing it.
+    pub fn is_reachable_from_entry(&self, a: &BlockRef) -> bool {
+        assert!(!self.is_post_dominator(), "unimplemented for post dominator trees");
+
+        self.get(Some(a)).is_some()
+    }
+
+    #[inline]
+    pub const fn is_reachable_from_entry_node(&self, a: Option<&Rc<DomTreeNode>>) -> bool {
+        a.is_some()
+    }
+
+    /// Returns true if and only if `a` dominates `b` and `a != b`
+    ///
+    /// Note that this is not a constant time operation.
+    pub fn properly_dominates(&self, a: Option<&BlockRef>, b: Option<&BlockRef>) -> bool {
+        if a == b {
+            return false;
+        }
+        let a = self.get(a);
+        let b = self.get(b);
+        if a.is_none() || b.is_none() {
+            return false;
+        }
+        self.properly_dominates_node(a, b)
+    }
+
+    /// Returns true if and only if `a` dominates `b` and `a != b`
+    ///
+    /// Note that this is not a constant time operation.
+    pub fn properly_dominates_node(
+        &self,
+        a: Option<Rc<DomTreeNode>>,
+        b: Option<Rc<DomTreeNode>>,
+    ) -> bool {
+        a != b && self.dominates_node(a, b)
+    }
+
+    /// Returns true iff `a` dominates `b`.
+    ///
+    /// Note that this is not a constant time operation
+    pub fn dominates(&self, a: Option<&BlockRef>, b: Option<&BlockRef>) -> bool {
+        if a == b {
+            return true;
+        }
+        let a = self.get(a);
+        let b = self.get(b);
+        self.dominates_node(a, b)
+    }
+
+    /// Returns true iff `a` dominates `b`.
+    ///
+    /// Note that this is not a constant time operation
+    pub fn dominates_node(&self, a: Option<Rc<DomTreeNode>>, b: Option<Rc<DomTreeNode>>) -> bool {
+        // A trivially dominates itself
+        if a == b {
+            return true;
+        }
+
+        // An unreachable node is dominated by anything
+        if b.is_none() {
+            return true;
+        }
+
+        // And dominates nothing.
+        if a.is_none() {
+            return false;
+        }
+
+        let a = a.unwrap();
+        let b = b.unwrap();
+
+        if b.idom().is_some_and(|idom| idom == a) {
+            return true;
+        }
+
+        if a.idom().is_some_and(|idom| idom == b) {
+            return false;
+        }
+
+        // A can only dominate B if it is higher in the tree
+        if a.level() >= b.level() {
+            return false;
+        }
+
+        if self.valid.get() {
+            return b.is_dominated_by(&a);
+        }
+
+        // If we end up with too many slow queries, just update the DFS numbers on the assumption
+        // that we are going to keep querying
+        self.slow_queries.set(self.slow_queries.get() + 1);
+        if self.slow_queries.get() > 32 {
+            self.update_dfs_numbers();
+            return b.is_dominated_by(&a);
+        }
+
+        self.dominated_by_slow_tree_walk(a, b)
+    }
+
+    /// Finds the nearest block which is a common dominator of both `a` and `b`
+    pub fn find_nearest_common_dominator(&self, a: &BlockRef, b: &BlockRef) -> Option<BlockRef> {
+        assert!(a.borrow().parent() == b.borrow().parent(), "two blocks are not in same region");
+
+        // If either A or B is the entry block, then it is the nearest common dominator (for
+        // forward dominators).
+        if !self.is_post_dominator() {
+            let parent = a.borrow().parent().unwrap();
+            let entry = parent.borrow().entry_block_ref().unwrap();
+            if a == &entry || b == &entry {
+                return Some(entry);
+            }
+        }
+
+        let mut a = self.get(Some(a)).expect("'a' must be in the tree");
+        let mut b = self.get(Some(b)).expect("'b' must be in the tree");
+
+        // Use level information to go up the tree until the levels match. Then continue going up
+        // until we arrive at the same node.
+        while a != b {
+            if a.level() < b.level() {
+                core::mem::swap(&mut a, &mut b);
+            }
+
+            a = a.idom().unwrap();
+        }
+
+        a.block().cloned()
+    }
+}
+
+impl<const IS_POST_DOM: bool> DomTreeBase<IS_POST_DOM> {
+    pub fn insert_edge(&mut self, mut from: Option<BlockRef>, mut to: Option<BlockRef>) {
+        if self.is_post_dominator() {
+            core::mem::swap(&mut from, &mut to);
+        }
+        SemiNCA::<IS_POST_DOM>::insert_edge(self, None, from, to)
+    }
+
+    pub fn delete_edge(&mut self, mut from: Option<BlockRef>, mut to: Option<BlockRef>) {
+        if self.is_post_dominator() {
+            core::mem::swap(&mut from, &mut to);
+        }
+        SemiNCA::<IS_POST_DOM>::delete_edge(self, None, from.as_ref(), to.as_ref())
+    }
+
+    pub fn apply_updates(
+        &mut self,
+        pre_view_cfg: cfg::CfgDiff,
+        post_view_cfg: cfg::CfgDiff,
+    ) {
+        SemiNCA::<IS_POST_DOM>::apply_updates(self, pre_view_cfg, post_view_cfg);
+    }
+
+    pub fn compute(&mut self) {
+        SemiNCA::<IS_POST_DOM>::compute_from_scratch(self, None);
+    }
+
+    pub fn compute_with_updates(&mut self, updates: impl ExactSizeIterator) {
+        // FIXME: Updated to use the PreViewCFG and behave the same as until now.
+        // This behavior is however incorrect; this actually needs the PostViewCFG.
+        let pre_view_cfg = cfg::CfgDiff::new(updates, true);
+        let bui = BatchUpdateInfo::new(pre_view_cfg, None);
+        SemiNCA::<IS_POST_DOM>::compute_from_scratch(self, Some(bui));
+    }
+
+    pub fn verify(&self, level: DomTreeVerificationLevel) -> bool {
+        let snca = SemiNCA::new(None);
+
+        // The simplest check is to compare against a freshly constructed tree. This will also
+        // usefully print the old and new trees, if they are different.
+        if !self.is_same_as_fresh_tree() {
+            return false;
+        }
+
+        // Common checks to verify the properties of the tree. O(n log n) at worst.
+        if !snca.verify_roots(self)
+            || !snca.verify_reachability(self)
+            || !snca.verify_levels(self)
+            || !snca.verify_dfs_numbers(self)
+        {
+            return false;
+        }
+
+        // Extra checks depending on verification level. Up to O(n^3).
+        match level {
+            DomTreeVerificationLevel::Basic => {
+                if !snca.verify_parent_property(self) {
+                    return false;
+                }
+            }
+            DomTreeVerificationLevel::Full => {
+                if !snca.verify_parent_property(self) || !snca.verify_sibling_property(self) {
+                    return false;
+                }
+            }
+            _ => (),
+        }
+
+        true
+    }
+
+    fn is_same_as_fresh_tree(&self) -> bool {
+        let fresh = Self::new(self.parent.clone()).unwrap();
+        let is_same = self == &fresh;
+        if !is_same {
+            log::error!(
+                "{} is different from a freshly computed one!",
+                if IS_POST_DOM {
+                    "post-dominator tree"
+                } else {
+                    "dominator tree"
+                }
+            );
+            log::error!("Current: {self}");
+            log::error!("Fresh: {fresh}");
+        }
+
+        is_same
+    }
+
+    pub fn is_virtual_root(&self, node: &DomTreeNode) -> bool {
+        self.is_post_dominator() && node.block.is_none()
+    }
+
+    pub fn add_new_block(&mut self, block: BlockRef, idom: Option<BlockRef>) -> Rc<DomTreeNode> {
+        assert!(self.get(Some(&block)).is_none(), "block already in dominator tree");
+        let idom = self.get(idom.as_ref()).expect("no immediate dominator specified for `idom`");
+        self.mark_invalid();
+        self.create_node(Some(block), Some(idom))
+    }
+
+    pub fn set_new_root(&mut self, block: BlockRef) -> Rc<DomTreeNode> {
+        assert!(self.get(Some(&block)).is_none(), "block already in dominator tree");
+        assert!(!self.is_post_dominator(), "cannot change root of post-dominator tree");
+
+        self.valid.set(false);
+        let node = self.create_node(Some(block.clone()), None);
+        if self.roots.is_empty() {
+            self.roots.push(Some(block));
+        } else {
+            assert_eq!(self.roots.len(), 1);
+            let old_node = self.get(self.roots[0].as_ref()).unwrap();
+            node.add_child(old_node.clone());
+            old_node.idom.set(Some(node.clone()));
+            old_node.update_level();
+            self.roots[0] = Some(block);
+        }
+        self.root = Some(node.clone());
+        node
+    }
+
+    pub fn change_immediate_dominator(&mut self, n: &BlockRef, idom: Option<&BlockRef>) {
+        let n = self.get(Some(n)).expect("expected `n` to be in tree");
+        let idom = self.get(idom).expect("expected `idom` to be in tree");
+        self.change_immediate_dominator_node(n, idom);
+    }
+
+    pub fn change_immediate_dominator_node(&mut self, n: Rc<DomTreeNode>, idom: Rc<DomTreeNode>) {
+        self.valid.set(false);
+        n.idom.set(Some(idom));
+    }
+
+    /// Removes a node from the dominator tree.
+    ///
+    /// `block` must not dominate any other blocks.
+    ///
+    /// Removes the node from the children of its immediate dominator, and deletes the dominator
+    /// tree node associated with `block`.
+    pub fn erase_node(&mut self, block: &BlockRef) {
+        let node_index = self.node_index(Some(block)).expect("removing node that isn't in tree");
+        let node = unsafe { self.nodes.get_unchecked(node_index).1.clone() };
+        assert!(node.is_leaf(), "node is not a leaf node");
+
+        self.valid.set(false);
+
+        // Remove node from immediate dominator's children
+        if let Some(idom) = node.idom() {
+            idom.children.borrow_mut().retain(|child| child != &node);
+        }
+
+        self.nodes.swap_remove(node_index);
+
+        if !IS_POST_DOM {
+            return;
+        }
+
+        // Remember to update PostDominatorTree roots
+        if let Some(root_index) =
+            self.roots.iter().position(|r| r.as_ref().is_some_and(|r| r == block))
+        {
+            self.roots.swap_remove(root_index);
+        }
+    }
+
+    /// Assign in and out numbers to the nodes while walking the dominator tree in DFS order.
+    pub fn update_dfs_numbers(&self) {
+        if self.valid.get() {
+            self.slow_queries.set(0);
+            return;
+        }
+
+        let mut worklist = SmallVec::<[(Rc<DomTreeNode>, usize); 32]>::default();
+        let this_root = self.root_node().unwrap();
+
+        // Both dominators and postdominators have a single root node.
In the case + // case of PostDominatorTree, this node is a virtual root. + this_root.num_in.set(NonZeroU32::new(1)); + worklist.push((this_root, 0)); + + let mut dfs_num = 2u32; + + while let Some((node, child_index)) = worklist.last_mut() { + // If we visited all of the children of this node, "recurse" back up the + // stack setting the DFOutNum. + if *child_index >= node.num_children() { + node.num_out.set(Some(unsafe { NonZeroU32::new_unchecked(dfs_num) })); + dfs_num += 1; + worklist.pop(); + } else { + // Otherwise, recursively visit this child. + let index = *child_index; + *child_index += 1; + let child = node.children.borrow()[index].clone(); + dfs_num += 1; + child.num_in.set(Some(unsafe { NonZeroU32::new_unchecked(dfs_num) })); + worklist.push((child, 0)); + } + } + + self.slow_queries.set(0); + self.valid.set(true); + } + + /// Reset the dominator tree state + pub fn reset(&mut self) { + self.nodes.clear(); + self.root.take(); + self.roots.clear(); + self.valid.set(false); + self.slow_queries.set(0); + } + + pub(super) fn create_node( + &mut self, + block: Option, + idom: Option>, + ) -> Rc { + let node = Rc::new(DomTreeNode::new(block.clone(), idom.clone())); + self.nodes.push((block, node.clone())); + if let Some(idom) = idom { + idom.add_child(node.clone()); + } + node + } + + /// `block` is split and now it has one successor. + /// + /// Update dominator tree to reflect the change. + pub fn split_block(&mut self, block: &BlockRef) { + if IS_POST_DOM { + self.split::>(block.clone()); + } else { + self.split::(block.clone()); + } + } + + // `block` is split and now it has one successor. Update dominator tree to reflect this change. + fn split(&mut self, block: ::Node) + where + G: InvertibleGraph, + { + let mut successors = G::children(block.clone()); + assert_eq!(successors.len(), 1, "`block` should have a single successor"); + + let succ = successors.next().unwrap(); + let predecessors = G::inverse_children(block.clone()).collect::>(); + + assert!(!predecessors.is_empty(), "expected at at least one predecessor"); + + let mut block_dominates_succ = true; + for pred in G::inverse_children(succ.clone()) { + if pred != block + && !self.dominates(Some(&succ), Some(&pred)) + && self.is_reachable_from_entry(&pred) + { + block_dominates_succ = false; + break; + } + } + + // Find `block`'s immediate dominator and create new dominator tree node for `block`. + let idom = predecessors.iter().find(|p| self.is_reachable_from_entry(p)).cloned(); + + // It's possible that none of the predecessors of `block` are reachable; + // in that case, `block` itself is unreachable, so nothing needs to be + // changed. + let Some(idom) = idom else { + return; + }; + + let idom = predecessors.iter().fold(idom, |idom, p| { + if self.is_reachable_from_entry(p) { + self.find_nearest_common_dominator(&idom, p).expect("expected idom") + } else { + idom + } + }); + + // Create the new dominator tree node... and set the idom of `block`. + let node = self.add_new_block(block.clone(), Some(idom)); + + // If NewBB strictly dominates other blocks, then it is now the immediate + // dominator of NewBBSucc. Update the dominator tree as appropriate. + if block_dominates_succ { + let succ_node = self.get(Some(&succ)).expect("expected 'succ' to be in dominator tree"); + self.change_immediate_dominator_node(succ_node, node); + } + } + + fn dominated_by_slow_tree_walk(&self, a: Rc, b: Rc) -> bool { + assert_ne!(a, b); + + let a_level = a.level(); + let mut b = b; + + // Don't walk nodes above A's subtree. 
+        // either find A or be in some other subtree not dominated by A.
+        while let Some(b_idom) = b.idom() {
+            if b_idom.level() < a_level {
+                // Walking further would take us above A's level, so stop here
+                break;
+            }
+            // Walk up the tree
+            b = b_idom;
+        }
+
+        b == a
+    }
+}
+
+impl<const IS_POST_DOM: bool> fmt::Display for DomTreeBase<IS_POST_DOM> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        use core::fmt::Write;
+
+        f.write_str("=============================--------------------------------\n")?;
+        if IS_POST_DOM {
+            f.write_str("Inorder PostDominator Tree: ")?;
+        } else {
+            f.write_str("Inorder Dominator Tree: ")?;
+        }
+        if !self.valid.get() {
+            write!(f, "DFS numbers invalid: {} slow queries.", self.slow_queries.get())?;
+        }
+        f.write_char('\n')?;
+
+        // The postdom tree can have a `None` root if there are no returns.
+        if let Some(root_node) = self.root_node() {
+            print_dom_tree(root_node, 1, f)?
+        }
+        f.write_str("Roots: ")?;
+        for (i, block) in self.roots.iter().enumerate() {
+            if i > 0 {
+                f.write_str(", ")?;
+            }
+            if let Some(block) = block {
+                write!(f, "{block}")?;
+            } else {
+                f.write_str("")?;
+            }
+        }
+        f.write_char('\n')
+    }
+}
+
+fn print_dom_tree(
+    node: Rc<DomTreeNode>,
+    level: usize,
+    f: &mut core::fmt::Formatter<'_>,
+) -> core::fmt::Result {
+    write!(f, "{: <1$}", "", level)?;
+    writeln!(f, "[{level}] {node}")?;
+    for child_node in node.children.borrow().iter().cloned() {
+        print_dom_tree(child_node, level + 1, f)?;
+    }
+    Ok(())
+}
+
+impl<const IS_POST_DOM: bool> Eq for DomTreeBase<IS_POST_DOM> {}
+impl<const IS_POST_DOM: bool> PartialEq for DomTreeBase<IS_POST_DOM> {
+    fn eq(&self, other: &Self) -> bool {
+        self.parent == other.parent
+            && self.roots.len() == other.roots.len()
+            && self.roots.iter().all(|root| other.roots.contains(root))
+            && self.nodes.len() == other.nodes.len()
+            && self.nodes.iter().all(|(_, node)| {
+                let block = node.block();
+                other.get(block).is_some_and(|n| node == &n)
+            })
+    }
+}
diff --git a/hir2/src/ir/entity.rs b/hir2/src/ir/entity.rs
new file mode 100644
index 000000000..a6d1aa740
--- /dev/null
+++ b/hir2/src/ir/entity.rs
@@ -0,0 +1,1168 @@
+mod group;
+mod list;
+mod storage;
+
+use alloc::alloc::{AllocError, Layout};
+use core::{
+    any::Any,
+    cell::{Cell, UnsafeCell},
+    fmt,
+    hash::Hash,
+    mem::MaybeUninit,
+    ops::{Deref, DerefMut},
+    ptr::NonNull,
+};
+
+pub use self::{
+    group::EntityGroup,
+    list::{EntityCursor, EntityCursorMut, EntityIter, EntityList, MaybeDefaultEntityIter},
+    storage::{EntityRange, EntityRangeMut, EntityStorage},
+};
+use crate::any::*;
+
+/// A trait implemented by an IR entity
pub trait Entity: Any {}
+
+/// A trait implemented by an [Entity] that is a logical child of another entity type, and is
+/// stored in the parent using an [EntityList].
+///
+/// This trait defines callbacks that are executed any time the entity is modified in relation to
+/// its parent entity, i.e. inserted in a parent, removed from a parent, or moved from one to
+/// another.
+///
+/// By default, these callbacks are no-ops.
+pub trait EntityWithParent: Entity {
+    /// The parent entity that this entity logically belongs to.
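+    ///
+    /// For example, operations are logical children of the block that contains them, so an
+    /// operation entity would use its block as this associated type (illustrative pairing; the
+    /// concrete choices live with each entity's definition).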
+    type Parent: Entity;
+
+    /// Invoked when this entity type is inserted into an intrusive list
+    #[allow(unused_variables)]
+    #[inline]
+    fn on_inserted_into_parent(
+        this: UnsafeIntrusiveEntityRef<Self>,
+        parent: UnsafeIntrusiveEntityRef<Self::Parent>,
+    ) {
+    }
+
+    /// Invoked when this entity type is removed from an intrusive list
+    #[allow(unused_variables)]
+    #[inline]
+    fn on_removed_from_parent(
+        this: UnsafeIntrusiveEntityRef<Self>,
+        parent: UnsafeIntrusiveEntityRef<Self::Parent>,
+    ) {
+    }
+
+    /// Invoked when a set of entities is moved from one intrusive list to another
+    #[allow(unused_variables)]
+    #[inline]
+    fn on_transfered_to_new_parent(
+        from: UnsafeIntrusiveEntityRef<Self::Parent>,
+        to: UnsafeIntrusiveEntityRef<Self::Parent>,
+        transferred: impl IntoIterator<Item = UnsafeIntrusiveEntityRef<Self>>,
+    ) {
+    }
+}
+
+/// A trait implemented by an [Entity] that has a unique identifier
+///
+/// Currently, this is used only for [Value]s and [Block]s.
+pub trait EntityWithId: Entity {
+    type Id: EntityId;
+
+    fn id(&self) -> Self::Id;
+}
+
+/// A trait implemented by an IR entity that can be stored in [EntityStorage].
+pub trait StorableEntity {
+    /// Get the absolute index of this entity in its container.
+    fn index(&self) -> usize;
+
+    /// Set the absolute index of this entity in its container.
+    ///
+    /// # Safety
+    ///
+    /// This is intended to be called only by the [EntityStorage] implementation, as it is
+    /// responsible for maintaining indices of all items it is storing. However, entities commonly
+    /// want to know their own index in storage, so this trait allows them to conceptually own the
+    /// index, but delegate maintenance to [EntityStorage].
+    unsafe fn set_index(&mut self, index: usize);
+
+    /// Called when this entity is removed from [EntityStorage]
+    #[inline(always)]
+    fn unlink(&mut self) {}
+}
+
+/// A trait that must be implemented by the unique identifier for an [Entity]
+pub trait EntityId: Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash + fmt::Display {
+    fn as_usize(&self) -> usize;
+}
+
+/// An error raised when an aliasing violation is detected in the use of [RawEntityRef]
+#[non_exhaustive]
+pub struct AliasingViolationError {
+    #[cfg(debug_assertions)]
+    location: &'static core::panic::Location<'static>,
+    kind: AliasingViolationKind,
+}
+
+#[derive(Debug)]
+enum AliasingViolationKind {
+    /// Attempted to create an immutable alias for an entity that was mutably borrowed
+    Immutable,
+    /// Attempted to create a mutable alias for an entity that was immutably borrowed
+    Mutable,
+}
+
+impl fmt::Display for AliasingViolationKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Immutable => f.write_str("already mutably borrowed"),
+            Self::Mutable => f.write_str("already borrowed"),
+        }
+    }
+}
+
+impl fmt::Debug for AliasingViolationError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut builder = f.debug_struct("AliasingViolationError");
+        builder.field("kind", &self.kind);
+        #[cfg(debug_assertions)]
+        builder.field("location", &self.location);
+        builder.finish()
+    }
+}
+impl fmt::Display for AliasingViolationError {
+    #[cfg(debug_assertions)]
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{} in file '{}' at line {} and column {}",
+            &self.kind,
+            self.location.file(),
+            self.location.line(),
+            self.location.column()
+        )
+    }
+
+    #[cfg(not(debug_assertions))]
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", &self.kind)
+    }
+}
+
+/// A raw pointer to an IR entity that has no associated metadata
+pub type UnsafeEntityRef<T> = RawEntityRef<T, ()>;
+
+/// A raw pointer to an IR entity that has an intrusive linked-list link as its metadata
+pub type UnsafeIntrusiveEntityRef<T> = RawEntityRef<T, intrusive_collections::LinkedListLink>;
+
+/// A [RawEntityRef] is an unsafe smart pointer type for IR entities allocated in a [Context].
+///
+/// Along with the type of entity referenced, it can be instantiated with extra metadata of any
+/// type. For example, [UnsafeIntrusiveEntityRef] stores an intrusive link in the entity metadata,
+/// so that the entity can be added to an intrusive linked list without the entity needing to
+/// know about the link - and without violating aliasing rules when navigating the list.
+///
+/// Unlike regular references, no reference to the underlying `T` is constructed until one is
+/// needed, at which point the borrow (whether mutable or immutable) is dynamically checked to
+/// ensure that it is valid according to Rust's aliasing rules.
+///
+/// As a result, a [RawEntityRef] is not considered an alias, and it is possible to acquire a
+/// mutable reference to the underlying data even while other copies of the handle exist. Any
+/// attempt to construct invalid aliases (an immutable reference while a mutable reference
+/// exists, or vice versa) will result in a runtime panic.
+///
+/// This is a tradeoff, as we do not get compile-time guarantees that such panics will not occur,
+/// but in exchange we get a much more flexible and powerful IR structure.
+///
+/// # SAFETY
+///
+/// Unlike most smart-pointer types, e.g. `Rc`, [RawEntityRef] does not provide any protection
+/// against the underlying allocation being deallocated (i.e. the arena it points into is
+/// dropped). This is by design, as the type is meant to be stored in objects inside the arena,
+/// and _not_ dropped when the arena is dropped. This requires care when using it, however, to
+/// ensure that no [RawEntityRef] lives longer than the arena that allocated it.
+///
+/// For a safe entity reference, see [EntityRef], which binds a [RawEntityRef] to the lifetime
+/// of the arena.
+pub struct RawEntityRef<T: ?Sized, Metadata = ()> {
+    inner: NonNull<RawEntityMetadata<T, Metadata>>,
+}
+impl<T: ?Sized, Metadata> Clone for RawEntityRef<T, Metadata> {
+    fn clone(&self) -> Self {
+        Self { inner: self.inner }
+    }
+}
+impl<T: ?Sized, Metadata> RawEntityRef<T, Metadata> {
+    /// Create a new [RawEntityRef] from a raw pointer to the underlying [RawEntityMetadata].
+    ///
+    /// # SAFETY
+    ///
+    /// [RawEntityRef] is designed to operate like an owned smart-pointer type, a la `Rc`. As a
+    /// result, it expects that the underlying data _never moves_ after it is allocated, for as
+    /// long as any outstanding [RawEntityRef]s exist that might be used to access that data.
+    ///
+    /// Additionally, it is expected that all accesses to the underlying data flow through a
+    /// [RawEntityRef], as that is the foundation on which its soundness is built. You must
+    /// ensure that there are no other references to the underlying data, and that none can be
+    /// created, _except_ via [RawEntityRef].
+    ///
+    /// You should generally not be using this API, as it is meant solely for constructing a
+    /// [RawEntityRef] immediately after allocating the underlying [RawEntityMetadata].
+    #[inline]
+    unsafe fn from_inner(inner: NonNull<RawEntityMetadata<T, Metadata>>) -> Self {
+        Self { inner }
+    }
+
+    #[inline]
+    unsafe fn from_ptr(ptr: *mut RawEntityMetadata<T, Metadata>) -> Self {
+        debug_assert!(!ptr.is_null());
+        Self::from_inner(NonNull::new_unchecked(ptr))
+    }
+
+    #[inline]
+    fn into_inner(this: Self) -> NonNull<RawEntityMetadata<T, Metadata>> {
+        this.inner
+    }
+}
+
+impl<T, Metadata> RawEntityRef<T, Metadata> {
+    /// Create a new [RawEntityRef] by allocating `value` with `metadata` in the given arena
+    /// allocator.
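+    ///
+    /// For instance (an illustrative sketch; `&str` and `u32` stand in for a real entity type
+    /// and metadata type):
+    ///
+    /// ```rust
+    /// use blink_alloc::Blink;
+    /// use midenc_hir2::*;
+    ///
+    /// let arena = Blink::default();
+    /// // Attach a `u32` tag as out-of-band metadata alongside the entity itself
+    /// let entity = RawEntityRef::new_with_metadata("data", 7u32, &arena);
+    /// assert_eq!(*entity.borrow(), "data");
+    /// ```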
+ /// + /// # SAFETY + /// + /// The resulting [RawEntityRef] must not outlive the arena. This is not enforced statically, + /// it is up to the caller to uphold the invariants of this type. + pub fn new_with_metadata(value: T, metadata: Metadata, arena: &blink_alloc::Blink) -> Self { + unsafe { + Self::from_inner(NonNull::new_unchecked( + arena.put(RawEntityMetadata::new(value, metadata)), + )) + } + } + + /// Create a [RawEntityRef] for an entity which may not be fully initialized, using the provided + /// arena. + /// + /// # SAFETY + /// + /// The safety rules are much the same as [RawEntityRef::new], with the main difference + /// being that the `T` does not have to be initialized yet. No references to the `T` will + /// be created directly until [RawEntityRef::assume_init] is called. + pub fn new_uninit_with_metadata( + metadata: Metadata, + arena: &blink_alloc::Blink, + ) -> RawEntityRef, Metadata> { + unsafe { + RawEntityRef::from_ptr(RawEntityRef::allocate_for_layout( + metadata, + Layout::new::(), + |layout| arena.allocator().allocate(layout), + <*mut u8>::cast, + )) + } + } +} + +impl RawEntityRef { + pub fn new(value: T, arena: &blink_alloc::Blink) -> Self { + RawEntityRef::new_with_metadata(value, (), arena) + } + + pub fn new_uninit(arena: &blink_alloc::Blink) -> RawEntityRef, ()> { + RawEntityRef::new_uninit_with_metadata((), arena) + } +} + +impl RawEntityRef, Metadata> { + /// Converts to `RawEntityRef`. + /// + /// # Safety + /// + /// Just like with [MaybeUninit::assume_init], it is up to the caller to guarantee that the + /// value really is in an initialized state. Calling this when the content is not yet fully + /// initialized causes immediate undefined behavior. + #[inline] + pub unsafe fn assume_init(self) -> RawEntityRef { + let ptr = Self::into_inner(self); + unsafe { RawEntityRef::from_inner(ptr.cast()) } + } +} + +impl RawEntityRef { + /// Convert this handle into a raw pointer to the underlying entity. + /// + /// This should only be used in situations where the returned pointer will not be used to + /// actually access the underlying entity. Use [get] or [get_mut] for that. [RawEntityRef] + /// ensures that Rust's aliasing rules are not violated when using it, but if you use the + /// returned pointer to do so, no such guarantee is provided, and undefined behavior can + /// result. + /// + /// # Safety + /// + /// The returned pointer _must_ not be used to create a reference to the underlying entity + /// unless you can guarantee that such a reference does not violate Rust's aliasing rules. + /// + /// Do not use the pointer to create a mutable reference if other references exist, and do + /// not use the pointer to create an immutable reference if a mutable reference exists or + /// might be created while the immutable reference lives. + pub fn into_raw(this: Self) -> *const T { + Self::as_ptr(&this) + } + + pub fn as_ptr(this: &Self) -> *const T { + let ptr: *mut RawEntityMetadata = NonNull::as_ptr(this.inner); + + // SAFETY: This cannot go through Deref::deref or RawEntityRef::inner because this + // is required to retain raw/mut provenance such that e.g. `get_mut` can write through + // the pointer after the RawEntityRef is recovered through `from_raw` + let ptr = unsafe { core::ptr::addr_of_mut!((*ptr).entity.cell) }; + UnsafeCell::raw_get(ptr).cast_const() + } + + /// Convert a pointer returned by [RawEntityRef::into_raw] back into a [RawEntityRef]. 
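+    ///
+    /// A round-trip sketch (illustrative only; `i32` stands in for a real entity type):
+    ///
+    /// ```rust
+    /// use blink_alloc::Blink;
+    /// use midenc_hir2::*;
+    ///
+    /// let alloc = Blink::default();
+    /// let entity = UnsafeEntityRef::new(42i32, &alloc);
+    /// let raw = UnsafeEntityRef::into_raw(entity);
+    /// // SAFETY: `raw` came from `into_raw`, and the arena is still alive
+    /// let entity = unsafe { UnsafeEntityRef::<i32>::from_raw(raw) };
+    /// assert_eq!(*entity.borrow(), 42);
+    /// ```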
+    ///
+    /// # Safety
+    ///
+    /// * It is _only_ valid to call this method on a pointer returned by
+    ///   [RawEntityRef::into_raw].
+    /// * The pointer must be a valid pointer for `T`
+    pub unsafe fn from_raw(ptr: *const T) -> Self {
+        let offset = unsafe { RawEntityMetadata::<T, Metadata>::data_offset(ptr) };
+
+        // Reverse the offset to find the original RawEntityMetadata
+        let entity_ptr = unsafe { ptr.byte_sub(offset) as *mut RawEntityMetadata<T, Metadata> };
+
+        unsafe { Self::from_ptr(entity_ptr) }
+    }
+
+    /// Get a dynamically-checked immutable reference to the underlying `T`
+    #[track_caller]
+    pub fn borrow<'a, 'b: 'a>(&'a self) -> EntityRef<'b, T> {
+        let ptr: *mut RawEntityMetadata<T, Metadata> = NonNull::as_ptr(self.inner);
+        unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow() }
+    }
+
+    /// Get a dynamically-checked mutable reference to the underlying `T`
+    #[track_caller]
+    pub fn borrow_mut<'a, 'b: 'a>(&'a mut self) -> EntityMut<'b, T> {
+        let ptr: *mut RawEntityMetadata<T, Metadata> = NonNull::as_ptr(self.inner);
+        unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow_mut() }
+    }
+
+    /// Try to get a dynamically-checked mutable reference to the underlying `T`
+    ///
+    /// Returns `None` if the entity is already borrowed
+    pub fn try_borrow_mut<'a, 'b: 'a>(&'a mut self) -> Option<EntityMut<'b, T>> {
+        let ptr: *mut RawEntityMetadata<T, Metadata> = NonNull::as_ptr(self.inner);
+        unsafe { (*core::ptr::addr_of!((*ptr).entity)).try_borrow_mut().ok() }
+    }
+
+    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
+        core::ptr::addr_eq(this.inner.as_ptr(), other.inner.as_ptr())
+    }
+
+    unsafe fn allocate_for_layout<F, F2>(
+        metadata: Metadata,
+        value_layout: Layout,
+        allocate: F,
+        mem_to_metadata: F2,
+    ) -> *mut RawEntityMetadata<T, Metadata>
+    where
+        F: FnOnce(Layout) -> Result<NonNull<[u8]>, AllocError>,
+        F2: FnOnce(*mut u8) -> *mut RawEntityMetadata<T, Metadata>,
+    {
+        use alloc::alloc::handle_alloc_error;
+
+        let layout = raw_entity_metadata_layout_for_value_layout::<Metadata>(value_layout);
+        unsafe {
+            RawEntityRef::try_allocate_for_layout(metadata, value_layout, allocate, mem_to_metadata)
+                .unwrap_or_else(|_| handle_alloc_error(layout))
+        }
+    }
+
+    #[inline]
+    unsafe fn try_allocate_for_layout<F, F2>(
+        metadata: Metadata,
+        value_layout: Layout,
+        allocate: F,
+        mem_to_metadata: F2,
+    ) -> Result<*mut RawEntityMetadata<T, Metadata>, AllocError>
+    where
+        F: FnOnce(Layout) -> Result<NonNull<[u8]>, AllocError>,
+        F2: FnOnce(*mut u8) -> *mut RawEntityMetadata<T, Metadata>,
+    {
+        let layout = raw_entity_metadata_layout_for_value_layout::<Metadata>(value_layout);
+        let ptr = allocate(layout)?;
+        let inner = mem_to_metadata(ptr.as_non_null_ptr().as_ptr());
+        unsafe {
+            debug_assert_eq!(Layout::for_value_raw(inner), layout);
+
+            core::ptr::addr_of_mut!((*inner).metadata).write(metadata);
+            core::ptr::addr_of_mut!((*inner).entity.borrow).write(Cell::new(BorrowFlag::UNUSED));
+            #[cfg(debug_assertions)]
+            core::ptr::addr_of_mut!((*inner).entity.borrowed_at).write(Cell::new(None));
+        }
+
+        Ok(inner)
+    }
+}
+
+impl<From, Metadata> RawEntityRef<From, Metadata> {
+    /// Casts this reference to the concrete type `To`, if the underlying value is a `To`.
+    ///
+    /// If the cast is not valid for this reference, `Err` is returned containing the original
+    /// value.
+    #[inline]
+    pub fn try_downcast<To, Obj>(
+        self,
+    ) -> Result<RawEntityRef<To, Metadata>, RawEntityRef<From, Metadata>>
+    where
+        To: DowncastFromRef<Obj> + 'static,
+        From: Is<Obj> + AsAny + 'static,
+        Obj: ?Sized,
+    {
+        RawEntityRef::<To, Metadata>::try_downcast_from(self)
+    }
+
+    /// Casts this reference to the concrete type `To`, if the underlying value is a `To`.
+    ///
+    /// If the cast is not valid for this reference, `None` is returned.
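+    ///
+    /// Unlike [RawEntityRef::try_downcast], this operates on `&self`, so nothing is consumed
+    /// when the cast fails.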
+    #[inline]
+    pub fn try_downcast_ref<To, Obj>(&self) -> Option<RawEntityRef<To, Metadata>>
+    where
+        To: DowncastFromRef<Obj> + 'static,
+        From: Is<Obj> + AsAny + 'static,
+        Obj: ?Sized,
+    {
+        RawEntityRef::<To, Metadata>::try_downcast_from_ref(self)
+    }
+
+    /// Casts this reference to the concrete type `To`, if the underlying value is a `To`.
+    ///
+    /// Panics if the cast is not valid for this reference.
+    #[inline]
+    #[track_caller]
+    pub fn downcast<To, Obj>(self) -> RawEntityRef<To, Metadata>
+    where
+        To: DowncastFromRef<Obj> + 'static,
+        From: Is<Obj> + AsAny + 'static,
+        Obj: ?Sized,
+    {
+        RawEntityRef::<To, Metadata>::downcast_from(self)
+    }
+
+    /// Casts this reference to the concrete type `To`, if the underlying value is a `To`.
+    ///
+    /// Panics if the cast is not valid for this reference.
+    #[inline]
+    #[track_caller]
+    pub fn downcast_ref<To, Obj>(&self) -> RawEntityRef<To, Metadata>
+    where
+        To: DowncastFromRef<Obj> + 'static,
+        From: Is<Obj> + AsAny + 'static,
+        Obj: ?Sized,
+    {
+        RawEntityRef::<To, Metadata>::downcast_from_ref(self)
+    }
+}
+
+impl<To, Metadata> RawEntityRef<To, Metadata> {
+    pub fn try_downcast_from<From, Obj>(
+        from: RawEntityRef<From, Metadata>,
+    ) -> Result<Self, RawEntityRef<From, Metadata>>
+    where
+        From: ?Sized + Is<Obj> + AsAny + 'static,
+        To: DowncastFromRef<Obj> + 'static,
+        Obj: ?Sized,
+    {
+        let borrow = from.borrow();
+        if let Some(to) = borrow.as_any().downcast_ref() {
+            Ok(unsafe { RawEntityRef::from_raw(to) })
+        } else {
+            Err(from)
+        }
+    }
+
+    pub fn try_downcast_from_ref<From, Obj>(from: &RawEntityRef<From, Metadata>) -> Option<Self>
+    where
+        From: ?Sized + Is<Obj> + AsAny + 'static,
+        To: DowncastFromRef<Obj> + 'static,
+        Obj: ?Sized,
+    {
+        let borrow = from.borrow();
+        borrow.as_any().downcast_ref().map(|to| unsafe { RawEntityRef::from_raw(to) })
+    }
+
+    #[track_caller]
+    pub fn downcast_from<From, Obj>(from: RawEntityRef<From, Metadata>) -> Self
+    where
+        From: ?Sized + Is<Obj> + AsAny + 'static,
+        To: DowncastFromRef<Obj> + 'static,
+        Obj: ?Sized,
+    {
+        let borrow = from.borrow();
+        unsafe { RawEntityRef::from_raw(borrow.as_any().downcast_ref().expect("invalid cast")) }
+    }
+
+    #[track_caller]
+    pub fn downcast_from_ref<From, Obj>(from: &RawEntityRef<From, Metadata>) -> Self
+    where
+        From: ?Sized + Is<Obj> + AsAny + 'static,
+        To: DowncastFromRef<Obj> + 'static,
+        Obj: ?Sized,
+    {
+        let borrow = from.borrow();
+        unsafe { RawEntityRef::from_raw(borrow.as_any().downcast_ref().expect("invalid cast")) }
+    }
+}
+
+impl<From, Metadata> RawEntityRef<From, Metadata> {
+    /// Casts this reference to the unsized type `To`, e.g. a trait object type, which is valid
+    /// whenever `From` implements the corresponding trait.
+    ///
+    /// Unlike the downcasts above, this conversion is checked at compile time, and cannot fail.
+ #[inline] + pub fn upcast(self) -> RawEntityRef + where + To: ?Sized, + From: core::marker::Unsize + AsAny + 'static, + { + unsafe { RawEntityRef::::from_inner(self.inner) } + } +} + +impl core::ops::CoerceUnsized> + for RawEntityRef +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} +impl Eq for RawEntityRef {} +impl PartialEq for RawEntityRef { + #[inline] + fn eq(&self, other: &Self) -> bool { + Self::ptr_eq(self, other) + } +} +impl PartialOrd for RawEntityRef { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for RawEntityRef { + #[inline] + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + let a = self.inner.as_ptr() as *const () as usize; + let b = other.inner.as_ptr() as *const () as usize; + a.cmp(&b) + } +} +impl core::hash::Hash for RawEntityRef { + fn hash(&self, state: &mut H) { + self.inner.hash(state); + } +} +impl fmt::Pointer for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Pointer::fmt(&Self::as_ptr(self), f) + } +} +impl fmt::Display for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.borrow().id()) + } +} + +impl fmt::Debug for RawEntityRef { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.borrow().id()) + } +} +impl crate::formatter::PrettyPrint + for RawEntityRef +{ + #[inline] + fn render(&self) -> crate::formatter::Document { + self.borrow().render() + } +} +impl StorableEntity for RawEntityRef { + #[inline] + fn index(&self) -> usize { + self.borrow().index() + } + + #[inline] + unsafe fn set_index(&mut self, index: usize) { + unsafe { + self.borrow_mut().set_index(index); + } + } + + #[inline] + fn unlink(&mut self) { + self.borrow_mut().unlink() + } +} + +/// A guard that ensures a reference to an IR entity cannot be mutably aliased +pub struct EntityRef<'b, T: ?Sized + 'b> { + value: NonNull, + borrow: BorrowRef<'b>, +} +impl core::ops::Deref for EntityRef<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: the value is accessible as long as we hold our borrow. 
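+        // (While this guard lives, the entity's `BorrowFlag` stays in the reading state, so no
+        // mutable alias can be created behind our back.)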
+ unsafe { self.value.as_ref() } + } +} +impl<'b, T: ?Sized> EntityRef<'b, T> { + #[inline] + pub fn map(orig: Self, f: F) -> EntityRef<'b, U> + where + F: FnOnce(&T) -> &U, + { + EntityRef { + value: NonNull::from(f(&*orig)), + borrow: orig.borrow, + } + } +} + +impl<'b, T, U> core::ops::CoerceUnsized> for EntityRef<'b, T> +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} + +impl fmt::Debug for EntityRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} +impl fmt::Display for EntityRef<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} +impl crate::formatter::PrettyPrint for EntityRef<'_, T> { + #[inline] + fn render(&self) -> crate::formatter::Document { + (**self).render() + } +} +impl Eq for EntityRef<'_, T> {} +impl PartialEq for EntityRef<'_, T> { + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} +impl PartialOrd for EntityRef<'_, T> { + fn partial_cmp(&self, other: &Self) -> Option { + (**self).partial_cmp(&**other) + } + + fn ge(&self, other: &Self) -> bool { + **self >= **other + } + + fn gt(&self, other: &Self) -> bool { + **self > **other + } + + fn le(&self, other: &Self) -> bool { + **self <= **other + } + + fn lt(&self, other: &Self) -> bool { + **self < **other + } +} +impl Ord for EntityRef<'_, T> { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + (**self).cmp(&**other) + } +} +impl Hash for EntityRef<'_, T> { + fn hash(&self, state: &mut H) { + (**self).hash(state); + } +} + +/// A guard that provides exclusive access to an IR entity +pub struct EntityMut<'b, T: ?Sized> { + /// The raw pointer to the underlying data + /// + /// This is a pointer rather than a `&'b mut T` to avoid `noalias` violations, because a + /// `EntityMut` argument doesn't hold exclusivity for its whole scope, only until it drops. + value: NonNull, + /// This value provides the drop glue for tracking that the underlying allocation is + /// mutably borrowed, but it is otherwise not read. + #[allow(unused)] + borrow: BorrowRefMut<'b>, + /// `NonNull` is covariant over `T`, so we need to reintroduce invariance via phantom data + _marker: core::marker::PhantomData<&'b mut T>, +} +impl<'b, T: ?Sized> EntityMut<'b, T> { + #[inline] + pub fn map(mut orig: Self, f: F) -> EntityMut<'b, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let value = NonNull::from(f(&mut *orig)); + EntityMut { + value, + borrow: orig.borrow, + _marker: core::marker::PhantomData, + } + } + + /// Splits an `EntityMut` into multiple `EntityMut`s for different components of the borrowed + /// data. + /// + /// The underlying entity will remain mutably borrowed until both returned `EntityMut`s go out + /// of scope. + /// + /// The entity is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as `EntityMut::map_split(...)`, so as + /// to avoid conflicting with any method of the same name accessible via the `Deref` impl. 
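+    /// (The shape of this API mirrors that of `core::cell::RefMut::map_split`.)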
+ /// + /// # Examples + /// + /// ```rust + /// use midenc_hir2::*; + /// use blink_alloc::Blink; + /// + /// let alloc = Blink::default(); + /// let mut entity = UnsafeEntityRef::new([1, 2, 3, 4], &alloc); + /// let borrow = entity.borrow_mut(); + /// let (mut begin, mut end) = EntityMut::map_split(borrow, |slice| slice.split_at_mut(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// begin.copy_from_slice(&[4, 3]); + /// end.copy_from_slice(&[2, 1]); + /// ``` + #[inline] + pub fn map_split( + mut orig: Self, + f: F, + ) -> (EntityMut<'b, U>, EntityMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let borrow = orig.borrow.clone(); + let (a, b) = f(&mut *orig); + ( + EntityMut { + value: NonNull::from(a), + borrow, + _marker: core::marker::PhantomData, + }, + EntityMut { + value: NonNull::from(b), + borrow: orig.borrow, + _marker: core::marker::PhantomData, + }, + ) + } +} +impl Deref for EntityMut<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} +impl DerefMut for EntityMut<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_mut() } + } +} + +impl<'b, T, U> core::ops::CoerceUnsized> for EntityMut<'b, T> +where + T: ?Sized + core::marker::Unsize, + U: ?Sized, +{ +} + +impl fmt::Debug for EntityMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} +impl fmt::Display for EntityMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} +impl crate::formatter::PrettyPrint for EntityMut<'_, T> { + #[inline] + fn render(&self) -> crate::formatter::Document { + (**self).render() + } +} +impl Eq for EntityMut<'_, T> {} +impl PartialEq for EntityMut<'_, T> { + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} +impl PartialOrd for EntityMut<'_, T> { + fn partial_cmp(&self, other: &Self) -> Option { + (**self).partial_cmp(&**other) + } + + fn ge(&self, other: &Self) -> bool { + **self >= **other + } + + fn gt(&self, other: &Self) -> bool { + **self > **other + } + + fn le(&self, other: &Self) -> bool { + **self <= **other + } + + fn lt(&self, other: &Self) -> bool { + **self < **other + } +} +impl Ord for EntityMut<'_, T> { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + (**self).cmp(&**other) + } +} +impl Hash for EntityMut<'_, T> { + fn hash(&self, state: &mut H) { + (**self).hash(state); + } +} + +// This type wraps the entity data with extra metadata we want to associate with the entity, but +// separately from it, so that pointers to the metadata do not cause aliasing violations if the +// entity itself is borrowed. +// +// The kind of metadata stored here is unconstrained, but in practice should be limited to things +// that you _need_ to be able to access from a `RawEntityRef`, without aliasing the entity. For now +// the main reason we use this is for the intrusive link used to store entities in an intrusive +// linked list. We don't want traversing the intrusive list to require borrowing the entity, only +// the link, unless we explicitly want to borrow the entity, thus we use the metadata field here +// to hold the link. 
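+// For example, an operation stored in a block's operation list keeps its list link in this
+// metadata: walking the list touches only the links, and never needs to borrow the operations
+// themselves.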
+// +// This has to be `pub` for implementing the traits required for the intrusive collections +// integration, but its internals are hidden outside this module, and we hide it from the generated +// docs as well. +#[repr(C)] +#[doc(hidden)] +pub struct RawEntityMetadata { + metadata: Metadata, + entity: RawEntity, +} +impl RawEntityMetadata { + pub(crate) fn new(value: T, metadata: Metadata) -> Self { + Self { + metadata, + entity: RawEntity::new(value), + } + } +} +impl RawEntityMetadata { + pub(self) fn borrow(&self) -> EntityRef<'_, T> { + let ptr = self as *const Self; + unsafe { (*core::ptr::addr_of!((*ptr).entity)).borrow() } + } + + pub(self) fn borrow_mut(&self) -> EntityMut<'_, T> { + let ptr = (self as *const Self).cast_mut(); + unsafe { (*core::ptr::addr_of_mut!((*ptr).entity)).borrow_mut() } + } + + #[inline] + const fn metadata_offset() -> usize { + core::mem::offset_of!(RawEntityMetadata<(), Metadata>, metadata) + } + + /// Get the offset within a `RawEntityMetadata` for the payload behind a pointer. + /// + /// # Safety + /// + /// The pointer must point to (and have valid metadata for) a previously valid instance of T, but + /// the T is allowed to be dropped. + unsafe fn data_offset(ptr: *const T) -> usize { + use core::mem::align_of_val_raw; + + // Align the unsized value to the end of the RawEntityMetadata. + // Because RawEntityMetadata/RawEntity is repr(C), it will always be the last field in memory. + // + // SAFETY: since the only unsized types possible are slices, trait objects, and extern types, + // the input safety requirement is currently enough to satisfy the requirements of + // align_of_val_raw; but this is an implementation detail of the language that is unstable + unsafe { RawEntityMetadata::<(), Metadata>::data_offset_align(align_of_val_raw(ptr)) } + } + + #[inline] + fn data_offset_align(align: usize) -> usize { + let layout = Layout::new::>(); + layout.size() + layout.padding_needed_for(align) + } +} + +fn raw_entity_metadata_layout_for_value_layout(layout: Layout) -> Layout { + Layout::new::>() + .extend(layout) + .unwrap() + .0 + .pad_to_align() +} + +/// A [RawEntity] wraps an entity to be allocated in a [Context], and provides dynamic borrow- +/// checking functionality for [UnsafeEntityRef], thereby protecting the entity by ensuring that +/// all accesses adhere to Rust's aliasing rules. +#[repr(C)] +struct RawEntity { + borrow: Cell, + #[cfg(debug_assertions)] + borrowed_at: Cell>>, + cell: UnsafeCell, +} + +impl RawEntity { + pub fn new(value: T) -> Self { + Self { + borrow: Cell::new(BorrowFlag::UNUSED), + #[cfg(debug_assertions)] + borrowed_at: Cell::new(None), + cell: UnsafeCell::new(value), + } + } +} + +impl RawEntity { + #[track_caller] + #[inline] + pub fn borrow(&self) -> EntityRef<'_, T> { + match self.try_borrow() { + Ok(b) => b, + Err(err) => panic_aliasing_violation(err), + } + } + + #[inline] + #[cfg_attr(debug_assertions, track_caller)] + pub fn try_borrow(&self) -> Result, AliasingViolationError> { + match BorrowRef::new(&self.borrow) { + Some(b) => { + #[cfg(debug_assertions)] + { + // `borrowed_at` is always the *first* active borrow + if b.borrow.get() == BorrowFlag(1) { + self.borrowed_at.set(Some(core::panic::Location::caller())); + } + } + + // SAFETY: `BorrowRef` ensures that there is only immutable access to the value + // while borrowed. 
+ let value = unsafe { NonNull::new_unchecked(self.cell.get()) }; + Ok(EntityRef { value, borrow: b }) + } + None => Err(AliasingViolationError { + #[cfg(debug_assertions)] + location: self.borrowed_at.get().unwrap(), + kind: AliasingViolationKind::Immutable, + }), + } + } + + #[inline] + #[track_caller] + pub fn borrow_mut(&self) -> EntityMut<'_, T> { + match self.try_borrow_mut() { + Ok(b) => b, + Err(err) => panic_aliasing_violation(err), + } + } + + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow_mut(&self) -> Result, AliasingViolationError> { + match BorrowRefMut::new(&self.borrow) { + Some(b) => { + #[cfg(debug_assertions)] + { + self.borrowed_at.set(Some(core::panic::Location::caller())); + } + + // SAFETY: `BorrowRefMut` guarantees unique access. + let value = unsafe { NonNull::new_unchecked(self.cell.get()) }; + Ok(EntityMut { + value, + borrow: b, + _marker: core::marker::PhantomData, + }) + } + None => Err(AliasingViolationError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(debug_assertions)] + location: self.borrowed_at.get().unwrap(), + kind: AliasingViolationKind::Mutable, + }), + } + } +} + +struct BorrowRef<'b> { + borrow: &'b Cell, +} +impl<'b> BorrowRef<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option { + let b = borrow.get().wrapping_add(1); + if !b.is_reading() { + // Incrementing borrow can result in a non-reading value (<= 0) in these cases: + // 1. It was < 0, i.e. there are writing borrows, so we can't allow a read borrow due to + // Rust's reference aliasing rules + // 2. It was isize::MAX (the max amount of reading borrows) and it overflowed into + // isize::MIN (the max amount of writing borrows) so we can't allow an additional + // read borrow because isize can't represent so many read borrows (this can only + // happen if you mem::forget more than a small constant amount of `EntityRef`s, which + // is not good practice) + None + } else { + // Incrementing borrow can result in a reading value (> 0) in these cases: + // 1. It was = 0, i.e. it wasn't borrowed, and we are taking the first read borrow + // 2. It was > 0 and < isize::MAX, i.e. there were read borrows, and isize is large + // enough to represent having one more read borrow + borrow.set(b); + Some(Self { borrow }) + } + } +} +impl Drop for BorrowRef<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_reading()); + self.borrow.set(borrow - 1); + } +} +impl Clone for BorrowRef<'_> { + #[inline] + fn clone(&self) -> Self { + // Since this Ref exists, we know the borrow flag + // is a reading borrow. + let borrow = self.borrow.get(); + debug_assert!(borrow.is_reading()); + // Prevent the borrow counter from overflowing into + // a writing borrow. + assert!(borrow != BorrowFlag::MAX); + self.borrow.set(borrow + 1); + BorrowRef { + borrow: self.borrow, + } + } +} + +struct BorrowRefMut<'b> { + borrow: &'b Cell, +} +impl Drop for BorrowRefMut<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_writing()); + self.borrow.set(borrow + 1); + } +} +impl<'b> BorrowRefMut<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option { + // NOTE: Unlike BorrowRefMut::clone, new is called to create the initial + // mutable reference, and so there must currently be no existing + // references. Thus, while clone increments the mutable refcount, here + // we explicitly only allow going from UNUSED to UNUSED - 1. 
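+        // E.g. the flag steps from 0 to -1 here; `clone` (used by `map_split`) steps it further
+        // to -2, and each guard's `drop` adds 1 back until the flag returns to UNUSED.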
+ match borrow.get() { + BorrowFlag::UNUSED => { + borrow.set(BorrowFlag::UNUSED - 1); + Some(Self { borrow }) + } + _ => None, + } + } + + // Clones a `BorrowRefMut`. + // + // This is only valid if each `BorrowRefMut` is used to track a mutable + // reference to a distinct, nonoverlapping range of the original object. + // This isn't in a Clone impl so that code doesn't call this implicitly. + #[inline] + fn clone(&self) -> Self { + let borrow = self.borrow.get(); + debug_assert!(borrow.is_writing()); + // Prevent the borrow counter from underflowing. + assert!(borrow != BorrowFlag::MIN); + self.borrow.set(borrow - 1); + Self { + borrow: self.borrow, + } + } +} + +/// Positive values represent the number of outstanding immutable borrows, while negative values +/// represent the number of outstanding mutable borrows. Multiple mutable borrows can only be +/// active simultaneously if they refer to distinct, non-overlapping components of an entity. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +struct BorrowFlag(isize); +impl BorrowFlag { + const MAX: Self = Self(isize::MAX); + const MIN: Self = Self(isize::MIN); + const UNUSED: Self = Self(0); + + pub fn is_writing(&self) -> bool { + self.0 < Self::UNUSED.0 + } + + pub fn is_reading(&self) -> bool { + self.0 > Self::UNUSED.0 + } + + #[inline] + pub const fn wrapping_add(self, rhs: isize) -> Self { + Self(self.0.wrapping_add(rhs)) + } +} +impl core::ops::Add for BorrowFlag { + type Output = BorrowFlag; + + #[inline] + fn add(self, rhs: isize) -> Self::Output { + Self(self.0 + rhs) + } +} +impl core::ops::Sub for BorrowFlag { + type Output = BorrowFlag; + + #[inline] + fn sub(self, rhs: isize) -> Self::Output { + Self(self.0 - rhs) + } +} + +// This ensures the panicking code is outlined from `borrow` and `borrow_mut` for `EntityObj`. +#[cfg_attr(not(panic = "abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_aliasing_violation(err: AliasingViolationError) -> ! { + panic!("{err:?}") +} diff --git a/hir2/src/ir/entity/group.rs b/hir2/src/ir/entity/group.rs new file mode 100644 index 000000000..f3e8bb69c --- /dev/null +++ b/hir2/src/ir/entity/group.rs @@ -0,0 +1,217 @@ +use core::fmt; + +/// Represents size and range information for a contiguous grouping of entities in a vector. +/// +/// This is used so that individual groups can be grown or shrunk, while maintaining stability +/// of references to items in other groups. 
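+///
+/// A minimal usage sketch (illustrative only; mirrors the unit tests below):
+///
+/// ```rust,ignore
+/// // A group of 4 items starting at index 10 of the containing vector
+/// let mut group = EntityGroup::new(10, 4);
+/// assert_eq!(group.as_range(), 10..14);
+/// // Growing the group extends its end, leaving its start untouched
+/// group.grow(2);
+/// assert_eq!(group.as_range(), 10..16);
+/// ```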
+#[derive(Default, Copy, Clone)]
+pub struct EntityGroup(u16);
+impl fmt::Debug for EntityGroup {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("EntityGroup")
+            .field("range", &self.as_range())
+            .field("len", &self.len())
+            .finish()
+    }
+}
+impl EntityGroup {
+    const START_MASK: u16 = u8::MAX as u16;
+
+    /// Create a new group of size `len`, starting at index `start`
+    ///
+    /// Panics if `start` or `len` do not fit in a single byte, since both are packed into a `u16`
+    pub fn new(start: usize, len: usize) -> Self {
+        let start = u16::try_from(start).expect("too many items");
+        assert!(start <= Self::START_MASK, "too many items");
+        let len = u16::try_from(len).expect("group too large");
+        assert!(len <= u8::MAX as u16, "group too large");
+        let group = start | (len << 8);
+
+        Self(group)
+    }
+
+    /// Get the start index in the containing vector
+    #[inline]
+    pub fn start(&self) -> usize {
+        (self.0 & Self::START_MASK) as usize
+    }
+
+    /// Get the end index (exclusive) in the containing vector
+    #[inline]
+    pub fn end(&self) -> usize {
+        self.start() + self.len()
+    }
+
+    /// Returns true if this group is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Get the number of items in this group
+    #[inline]
+    pub fn len(&self) -> usize {
+        (self.0 >> 8) as usize
+    }
+
+    /// Get the [core::ops::Range] equivalent of this group
+    pub fn as_range(&self) -> core::ops::Range<usize> {
+        let start = self.start();
+        let len = self.len();
+        start..(start + len)
+    }
+
+    /// Increase the size of this group by `n` items
+    ///
+    /// Panics if `n` overflows `u16::MAX`, or if the resulting size overflows `u8::MAX`
+    pub fn grow(&mut self, n: usize) {
+        let n = u16::try_from(n).expect("group is too large");
+        let start = self.0 & Self::START_MASK;
+        let len = (self.0 >> 8) + n;
+        assert!(len <= u8::MAX as u16, "group is too large");
+        self.0 = start | (len << 8);
+    }
+
+    /// Decrease the size of this group by `n` items
+    ///
+    /// Panics if `n` overflows `u16::MAX`; if `n` exceeds the number of remaining items, the
+    /// size saturates to zero.
+ pub fn shrink(&mut self, n: usize) { + let n = u16::try_from(n).expect("cannot shrink by a size larger than the max group size"); + let start = self.0 & Self::START_MASK; + let len = (self.0 >> 8).saturating_sub(n); + self.0 = start | (len << 8); + } + + /// Shift the position of this group by `offset` + pub fn shift_start(&mut self, offset: isize) { + let offset = i16::try_from(offset).expect("offset too large"); + let mut start = self.0 & Self::START_MASK; + if offset >= 0 { + start += offset as u16; + } else { + start -= offset.unsigned_abs(); + } + assert!(start <= Self::START_MASK, "group offset cannot be larger than u8::MAX"); + self.0 &= !Self::START_MASK; + self.0 |= start; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn entity_group_empty() { + let group = EntityGroup::new(0, 0); + assert_eq!(group.start(), 0); + assert_eq!(group.end(), 0); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 0..0); + + let group = EntityGroup::new(101, 0); + assert_eq!(group.start(), 101); + assert_eq!(group.end(), 101); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 101..101); + } + + #[test] + fn entity_group_non_empty() { + let group = EntityGroup::new(0, 1); + assert_eq!(group.start(), 0); + assert_eq!(group.end(), 1); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 0..1); + + let group = EntityGroup::new(255, 255); + assert_eq!(group.start(), 255); + assert_eq!(group.end(), 510); + assert_eq!(group.len(), 255); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 255..510); + } + + #[test] + fn entity_group_grow() { + let mut group = EntityGroup::new(10, 0); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.grow(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 11); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..11); + + group.grow(3); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 14); + assert_eq!(group.len(), 4); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..14); + } + + #[test] + fn entity_group_shrink() { + let mut group = EntityGroup::new(10, 4); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 14); + assert_eq!(group.len(), 4); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..14); + + group.shrink(3); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 11); + assert_eq!(group.len(), 1); + assert!(!group.is_empty()); + assert_eq!(group.as_range(), 10..11); + + group.shrink(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.shrink(1); + + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + } + + #[test] + fn entity_group_shift_start() { + let mut group = EntityGroup::new(10, 0); + assert_eq!(group.start(), 10); + assert_eq!(group.end(), 10); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 10..10); + + group.shift_start(10); + assert_eq!(group.start(), 20); + assert_eq!(group.end(), 20); + assert_eq!(group.len(), 0); + assert!(group.is_empty()); + assert_eq!(group.as_range(), 20..20); + + group.shift_start(-5); 
+        assert_eq!(group.start(), 15);
+        assert_eq!(group.end(), 15);
+        assert_eq!(group.len(), 0);
+        assert!(group.is_empty());
+        assert_eq!(group.as_range(), 15..15);
+    }
+}
diff --git a/hir2/src/ir/entity/list.rs b/hir2/src/ir/entity/list.rs
new file mode 100644
index 000000000..707a73024
--- /dev/null
+++ b/hir2/src/ir/entity/list.rs
@@ -0,0 +1,755 @@
+use core::{fmt, mem::MaybeUninit, ptr::NonNull};
+
+use super::{EntityMut, EntityRef, RawEntityMetadata, RawEntityRef, UnsafeIntrusiveEntityRef};
+
+pub struct EntityList<T> {
+    list: intrusive_collections::linked_list::LinkedList<EntityAdapter<T>>,
+}
+impl<T> Default for EntityList<T> {
+    fn default() -> Self {
+        Self {
+            list: Default::default(),
+        }
+    }
+}
+impl<T> EntityList<T> {
+    /// Construct a new, empty [EntityList]
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Returns true if this list is empty
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.list.is_empty()
+    }
+
+    /// Returns the number of entities in this list
+    pub fn len(&self) -> usize {
+        let mut cursor = self.list.front();
+        let mut count = 0;
+        while !cursor.is_null() {
+            count += 1;
+            cursor.move_next();
+        }
+        count
+    }
+
+    /// Prepend `entity` to this list
+    pub fn push_front(&mut self, entity: UnsafeIntrusiveEntityRef<T>) {
+        self.list.push_front(entity);
+    }
+
+    /// Append `entity` to this list
+    pub fn push_back(&mut self, entity: UnsafeIntrusiveEntityRef<T>) {
+        self.list.push_back(entity);
+    }
+
+    /// Remove the entity at the front of the list, returning its [UnsafeIntrusiveEntityRef]
+    ///
+    /// Returns `None` if the list is empty.
+    pub fn pop_front(&mut self) -> Option<UnsafeIntrusiveEntityRef<T>> {
+        self.list.pop_front()
+    }
+
+    /// Remove the entity at the back of the list, returning its [UnsafeIntrusiveEntityRef]
+    ///
+    /// Returns `None` if the list is empty.
+    pub fn pop_back(&mut self) -> Option<UnsafeIntrusiveEntityRef<T>> {
+        self.list.pop_back()
+    }
+
+    #[doc(hidden)]
+    pub fn cursor(&self) -> EntityCursor<'_, T> {
+        EntityCursor {
+            cursor: self.list.cursor(),
+        }
+    }
+
+    #[doc(hidden)]
+    pub fn cursor_mut(&mut self) -> EntityCursorMut<'_, T> {
+        EntityCursorMut {
+            cursor: self.list.cursor_mut(),
+        }
+    }
+
+    /// Get an [EntityCursor] pointing to the first entity in the list, or the null object if
+    /// the list is empty
+    pub fn front(&self) -> EntityCursor<'_, T> {
+        EntityCursor {
+            cursor: self.list.front(),
+        }
+    }
+
+    /// Get an [EntityCursorMut] pointing to the first entity in the list, or the null object if
+    /// the list is empty
+    pub fn front_mut(&mut self) -> EntityCursorMut<'_, T> {
+        EntityCursorMut {
+            cursor: self.list.front_mut(),
+        }
+    }
+
+    /// Get an [EntityCursor] pointing to the last entity in the list, or the null object if
+    /// the list is empty
+    pub fn back(&self) -> EntityCursor<'_, T> {
+        EntityCursor {
+            cursor: self.list.back(),
+        }
+    }
+
+    /// Get an [EntityCursorMut] pointing to the last entity in the list, or the null object if
+    /// the list is empty
+    pub fn back_mut(&mut self) -> EntityCursorMut<'_, T> {
+        EntityCursorMut {
+            cursor: self.list.back_mut(),
+        }
+    }
+
+    /// Get an iterator over the entities in this list
+    ///
+    /// The iterator returned produces [EntityRef]s for each item in the list, with their lifetime
+    /// bound to the list itself, not the iterator.
+    pub fn iter(&self) -> EntityIter<'_, T> {
+        EntityIter {
+            cursor: self.cursor(),
+            started: false,
+        }
+    }
+
+    /// Removes all items from this list.
+    ///
+    /// This will unlink all entities currently in the list, which requires iterating through all
+    /// elements in the list.
If the entities may be used again, this ensures that their intrusive + /// link is properly unlinked. + pub fn clear(&mut self) { + self.list.clear(); + } + + /// Empties the list without properly unlinking the intrusive links of the items in the list. + /// + /// Since this does not unlink any objects, any attempts to link these objects into another + /// [EntityList] will fail but will not cause any memory unsafety. To unlink those objects + /// manually, you must call the `force_unlink` function on the link. + pub fn fast_clear(&mut self) { + self.list.fast_clear(); + } + + /// Takes all the elements out of the [EntityList], leaving it empty. + /// + /// The taken elements are returned as a new [EntityList]. + pub fn take(&mut self) -> Self { + Self { + list: self.list.take(), + } + } + + /// Get a cursor to the item pointed to by `ptr`. + /// + /// # Safety + /// + /// This function may only be called when it is known that `ptr` refers to an entity which is + /// linked into this list. This operation will panic if the entity is not linked into any list, + /// and may result in undefined behavior if the operation is linked into a different list. + pub unsafe fn cursor_from_ptr(&self, ptr: UnsafeIntrusiveEntityRef) -> EntityCursor<'_, T> { + unsafe { + let raw = UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr(); + EntityCursor { + cursor: self.list.cursor_from_ptr(raw), + } + } + } + + /// Get a mutable cursor to the item pointed to by `ptr`. + /// + /// # Safety + /// + /// This function may only be called when it is known that `ptr` refers to an entity which is + /// linked into this list. This operation will panic if the entity is not linked into any list, + /// and may result in undefined behavior if the operation is linked into a different list. + pub unsafe fn cursor_mut_from_ptr( + &mut self, + ptr: UnsafeIntrusiveEntityRef, + ) -> EntityCursorMut<'_, T> { + let raw = UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr(); + unsafe { + EntityCursorMut { + cursor: self.list.cursor_mut_from_ptr(raw), + } + } + } +} + +impl fmt::Debug for EntityList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_list(); + for entity in self.iter() { + builder.entry(&entity); + } + builder.finish() + } +} + +impl FromIterator> for EntityList { + fn from_iter(iter: I) -> Self + where + I: IntoIterator>, + { + let mut list = EntityList::::default(); + for handle in iter { + list.push_back(handle); + } + list + } +} + +impl IntoIterator for EntityList { + type IntoIter = intrusive_collections::linked_list::IntoIter>; + type Item = UnsafeIntrusiveEntityRef; + + fn into_iter(self) -> Self::IntoIter { + self.list.into_iter() + } +} + +impl<'a, T> IntoIterator for &'a EntityList { + type IntoIter = EntityIter<'a, T>; + type Item = EntityRef<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// A cursor which provides read-only access to an [EntityList]. +pub struct EntityCursor<'a, T> { + cursor: intrusive_collections::linked_list::Cursor<'a, EntityAdapter>, +} +impl<'a, T> EntityCursor<'a, T> { + /// Returns true if this cursor is pointing to the null object + #[inline] + pub fn is_null(&self) -> bool { + self.cursor.is_null() + } + + /// Get a shared reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. 
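+    ///
+    /// A usage sketch (assumes a `cursor` obtained from [EntityList::front]; illustrative only):
+    ///
+    /// ```rust,ignore
+    /// if let Some(entity) = cursor.get() {
+    ///     // `entity` is an `EntityRef<'_, T>` borrowed from the list
+    /// }
+    /// ```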
+ /// + /// NOTE: This returns an [EntityRef] whose lifetime is bound to the underlying [EntityList], + /// _not_ the [EntityCursor], since the cursor cannot mutate the list. + pub fn get(&self) -> Option> { + self.cursor.get().map(|obj| obj.entity.borrow()) + } + + /// Get the [TrackedEntityHandle] corresponding to the entity under the cursor. + /// + /// Returns `None` if the cursor is pointing to the null object. + #[inline] + pub fn as_pointer(&self) -> Option> { + self.cursor.clone_pointer() + } + + /// Consume the cursor and convert it into a borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow(self) -> Option> { + self.cursor.get().map(|item| item.borrow()) + } + + /// Moves the cursor to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the front of the + /// [EntityList]. If it is pointing to the back of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_next(&mut self) { + self.cursor.move_next(); + } + + /// Moves the cursor to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the back of the + /// [EntityList]. If it is pointing to the front of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_prev(&mut self) { + self.cursor.move_prev(); + } + + /// Returns a cursor pointing to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to the + /// front of the [EntityList]. If it is pointing to the last entity of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_next(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_next(), + } + } + + /// Returns a cursor pointing to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to + /// the last entity in the [EntityList]. If it is pointing to the front of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_prev(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_prev(), + } + } +} + +/// A cursor which provides mutable access to an [EntityList]. +pub struct EntityCursorMut<'a, T> { + cursor: intrusive_collections::linked_list::CursorMut<'a, EntityAdapter>, +} +impl<'a, T> EntityCursorMut<'a, T> { + /// Returns true if this cursor is pointing to the null object + #[inline] + pub fn is_null(&self) -> bool { + self.cursor.is_null() + } + + /// Get a shared reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. + /// + /// NOTE: This binds the lifetime of the [EntityRef] to the cursor, to ensure that the cursor + /// is frozen while the entity is being borrowed. This ensures that only one reference at a + /// time is being handed out by this cursor. + pub fn get(&self) -> Option> { + self.cursor.get().map(|obj| obj.entity.borrow()) + } + + /// Get a mutable reference to the entity under the cursor. + /// + /// Returns `None` if the cursor is currently pointing to the null object. 
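+    ///
+    /// A usage sketch (assumes a `cursor` obtained from [EntityList::front_mut]; illustrative
+    /// only):
+    ///
+    /// ```rust,ignore
+    /// if let Some(mut entity) = cursor.get_mut() {
+    ///     // `entity` is an `EntityMut<'_, T>`; the cursor is frozen until it is dropped
+    /// }
+    /// ```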
+ /// + /// Not only does this mutably borrow the cursor, the lifetime of the [EntityMut] is bound to + /// that of the cursor, which means it cannot outlive the cursor, and also prevents the cursor + /// from being accessed in any way until the mutable reference is dropped. This makes it + /// impossible to try and alias the underlying entity using the cursor. + pub fn get_mut(&mut self) -> Option> { + self.cursor.get().map(|obj| obj.entity.borrow_mut()) + } + + /// Returns a read-only cursor pointing to the current element. + /// + /// The lifetime of the returned [EntityCursor] is bound to that of the [EntityCursorMut], which + /// means it cannot outlive the [EntityCursorMut] and that the [EntityCursorMut] is frozen for + /// the lifetime of the [EntityCursor]. + pub fn as_cursor(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.as_cursor(), + } + } + + /// Get the [TrackedEntityHandle] corresponding to the entity under the cursor. + /// + /// Returns `None` if the cursor is pointing to the null object. + #[inline] + pub fn as_pointer(&self) -> Option> { + self.cursor.as_cursor().clone_pointer() + } + + /// Consume the cursor and convert it into a borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow(self) -> Option> { + self.cursor.into_ref().map(|item| item.borrow()) + } + + /// Consume the cursor and convert it into a mutable borrow of the current entity, or `None` if null. + #[inline] + pub fn into_borrow_mut(self) -> Option> { + self.cursor.into_ref().map(|item| item.borrow_mut()) + } + + /// Moves the cursor to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the front of the + /// [EntityList]. If it is pointing to the back of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_next(&mut self) { + self.cursor.move_next(); + } + + /// Moves the cursor to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will move it to the back of the + /// [EntityList]. If it is pointing to the front of the [EntityList] then this will move it to + /// the null object. + #[inline] + pub fn move_prev(&mut self) { + self.cursor.move_prev(); + } + + /// Returns a cursor pointing to the next element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to the + /// front of the [EntityList]. If it is pointing to the last entity of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_next(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_next(), + } + } + + /// Returns a cursor pointing to the previous element of the [EntityList]. + /// + /// If the cursor is pointing to the null object then this will return a cursor pointing to + /// the last entity in the [EntityList]. If it is pointing to the front of the [EntityList] then + /// this will return a null cursor. + #[inline] + pub fn peek_prev(&self) -> EntityCursor<'_, T> { + EntityCursor { + cursor: self.cursor.peek_prev(), + } + } + + /// Removes the current entity from the [EntityList]. + /// + /// A pointer to the element that was removed is returned, and the cursor is moved to point to + /// the next element in the [Entitylist]. + /// + /// If the cursor is currently pointing to the null object then nothing is removed and `None` is + /// returned. 
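+    ///
+    /// For example (sketch; assumes `list` is an [EntityList]): repeatedly removing from the
+    /// front drains the list.
+    ///
+    /// ```rust,ignore
+    /// let mut cursor = list.front_mut();
+    /// while let Some(entity) = cursor.remove() {
+    ///     // `entity` is the detached `UnsafeIntrusiveEntityRef<T>`
+    /// }
+    /// ```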
+ #[inline] + pub fn remove(&mut self) -> Option> { + self.cursor.remove() + } + + /// Removes the current entity from the [EntityList] and inserts another one in its place. + /// + /// A pointer to the entity that was removed is returned, and the cursor is modified to point to + /// the newly added entity. + /// + /// If the cursor is currently pointing to the null object then `Err` is returned containing the + /// entity we failed to insert. + /// + /// # Panics + /// Panics if the new entity is already linked to a different intrusive collection. + #[inline] + pub fn replace_with( + &mut self, + value: UnsafeIntrusiveEntityRef, + ) -> Result, UnsafeIntrusiveEntityRef> { + self.cursor.replace_with(value) + } + + /// Inserts a new entity into the [EntityList], after the current cursor position. + /// + /// If the cursor is pointing at the null object then the entity is inserted at the start of the + /// underlying [EntityList]. + /// + /// # Panics + /// + /// Panics if the entity is already linked to a different [EntityList] + #[inline] + pub fn insert_after(&mut self, value: UnsafeIntrusiveEntityRef) { + self.cursor.insert_after(value) + } + + /// Inserts a new entity into the [EntityList], before the current cursor position. + /// + /// If the cursor is pointing at the null object then the entity is inserted at the end of the + /// underlying [EntityList]. + /// + /// # Panics + /// + /// Panics if the entity is already linked to a different [EntityList] + #[inline] + pub fn insert_before(&mut self, value: UnsafeIntrusiveEntityRef) { + self.cursor.insert_before(value) + } + + /// This splices `list` into the underlying list of `self` by inserting the elements of `list` + /// after the current cursor position. + /// + /// For example, let's say we have the following list and cursor position: + /// + /// ```text,ignore + /// [A, B, C] + /// ^-- cursor + /// ``` + /// + /// Splicing a new list, `[D, E, F]` after the cursor would result in: + /// + /// ```text,ignore + /// [A, B, D, E, F, C] + /// ^-- cursor + /// ``` + /// + /// If the cursor is pointing at the null object, then `list` is appended to the start of the + /// underlying [EntityList] for this cursor. + #[inline] + pub fn splice_after(&mut self, list: EntityList) { + self.cursor.splice_after(list.list) + } + + /// This splices `list` into the underlying list of `self` by inserting the elements of `list` + /// before the current cursor position. + /// + /// For example, let's say we have the following list and cursor position: + /// + /// ```text,ignore + /// [A, B, C] + /// ^-- cursor + /// ``` + /// + /// Splicing a new list, `[D, E, F]` before the cursor would result in: + /// + /// ```text,ignore + /// [A, D, E, F, B, C] + /// ^-- cursor + /// ``` + /// + /// If the cursor is pointing at the null object, then `list` is appended to the end of the + /// underlying [EntityList] for this cursor. + #[inline] + pub fn splice_before(&mut self, list: EntityList) { + self.cursor.splice_before(list.list) + } + + /// Splits the list into two after the current cursor position. + /// + /// This will return a new list consisting of everything after the cursor, with the original + /// list retaining everything before. + /// + /// If the cursor is pointing at the null object then the entire contents of the [EntityList] + /// are moved. + pub fn split_after(&mut self) -> EntityList { + let list = self.cursor.split_after(); + EntityList { list } + } + + /// Splits the list into two before the current cursor position. 
+ /// + /// This will return a new list consisting of everything before the cursor, with the original + /// list retaining everything after. + /// + /// If the cursor is pointing at the null object then the entire contents of the [EntityList] + /// are moved. + pub fn split_before(&mut self) -> EntityList { + let list = self.cursor.split_before(); + EntityList { list } + } +} + +pub struct EntityIter<'a, T> { + cursor: EntityCursor<'a, T>, + started: bool, +} +impl<'a, T> core::iter::FusedIterator for EntityIter<'a, T> {} +impl<'a, T> Iterator for EntityIter<'a, T> { + type Item = EntityRef<'a, T>; + + fn next(&mut self) -> Option { + // If we haven't started iterating yet, then we're on the null cursor, so move to the + // front of the list now that we have started iterating. + if !self.started { + self.started = true; + self.cursor.move_next(); + } + let item = self.cursor.get()?; + self.cursor.move_next(); + Some(item) + } +} +impl<'a, T> DoubleEndedIterator for EntityIter<'a, T> { + fn next_back(&mut self) -> Option { + // If we haven't started iterating yet, then we're on the null cursor, so move to the + // back of the list now that we have started iterating. + if !self.started { + self.started = true; + self.cursor.move_prev(); + } + let item = self.cursor.get()?; + self.cursor.move_prev(); + Some(item) + } +} + +pub struct MaybeDefaultEntityIter<'a, T> { + iter: Option>, +} +impl<'a, T> Default for MaybeDefaultEntityIter<'a, T> { + fn default() -> Self { + Self { iter: None } + } +} +impl<'a, T> core::iter::FusedIterator for MaybeDefaultEntityIter<'a, T> {} +impl<'a, T> Iterator for MaybeDefaultEntityIter<'a, T> { + type Item = EntityRef<'a, T>; + + #[inline] + fn next(&mut self) -> Option { + self.iter.as_mut().and_then(|iter| iter.next()) + } +} +impl<'a, T> DoubleEndedIterator for MaybeDefaultEntityIter<'a, T> { + #[inline] + fn next_back(&mut self) -> Option { + self.iter.as_mut().and_then(|iter| iter.next_back()) + } +} + +pub type IntrusiveLink = intrusive_collections::LinkedListLink; + +impl RawEntityRef { + /// Create a new [UnsafeIntrusiveEntityRef] by allocating `value` in `arena` + /// + /// # SAFETY + /// + /// This function has the same requirements around safety as [RawEntityRef::new]. + pub fn new(value: T, arena: &blink_alloc::Blink) -> Self { + RawEntityRef::new_with_metadata(value, IntrusiveLink::new(), arena) + } + + pub fn new_uninit(arena: &blink_alloc::Blink) -> RawEntityRef, IntrusiveLink> { + RawEntityRef::new_uninit_with_metadata(IntrusiveLink::new(), arena) + } +} +impl RawEntityRef { + /// Returns true if this entity is linked into an intrusive list + pub fn is_linked(&self) -> bool { + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast::(); + current.as_ref().is_linked() + } + } + + /// Get the previous entity in the list of `T` containing the current entity + /// + /// For example, in a list of `Operation` in a `Block`, this would return the handle of the + /// previous operation in the block, or `None` if there are no other ops before this one. 
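+    ///
+    /// A sketch of walking backwards from an operation to the start of its block (assumes `op`
+    /// is an `UnsafeIntrusiveEntityRef<Operation>`; illustrative only):
+    ///
+    /// ```rust,ignore
+    /// let mut cur = Some(op);
+    /// while let Some(op) = cur {
+    ///     cur = op.prev();
+    /// }
+    /// ```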
+ pub fn prev(&self) -> Option { + use intrusive_collections::linked_list::{LinkOps, LinkedListOps}; + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast(); + LinkOps.prev(current).map(|link_ptr| Self::from_link_ptr(link_ptr)) + } + } + + /// Get the next entity in the list of `T` containing the current entity + /// + /// For example, in a list of `Operation` in a `Block`, this would return the handle of the + /// next operation in the block, or `None` if there are no other ops after this one. + pub fn next(&self) -> Option { + use intrusive_collections::linked_list::{LinkOps, LinkedListOps}; + unsafe { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let current = self.inner.byte_add(offset).cast(); + LinkOps.next(current).map(|link_ptr| Self::from_link_ptr(link_ptr)) + } + } + + #[inline] + unsafe fn from_link_ptr(link: NonNull) -> Self { + let offset = core::mem::offset_of!(RawEntityMetadata, metadata); + let ptr = link.byte_sub(offset).cast::>(); + Self { inner: ptr } + } +} + +#[doc(hidden)] +pub struct DefaultPointerOps(core::marker::PhantomData); +impl Copy for DefaultPointerOps {} +impl Clone for DefaultPointerOps { + fn clone(&self) -> Self { + *self + } +} +impl Default for DefaultPointerOps { + fn default() -> Self { + Self::new() + } +} +impl DefaultPointerOps { + const fn new() -> Self { + Self(core::marker::PhantomData) + } +} + +unsafe impl intrusive_collections::PointerOps + for DefaultPointerOps> +{ + type Pointer = UnsafeIntrusiveEntityRef; + type Value = RawEntityMetadata; + + #[inline] + unsafe fn from_raw(&self, value: *const Self::Value) -> Self::Pointer { + debug_assert!(!value.is_null() && value.is_aligned()); + UnsafeIntrusiveEntityRef::from_ptr(value.cast_mut()) + } + + #[inline] + fn into_raw(&self, ptr: Self::Pointer) -> *const Self::Value { + UnsafeIntrusiveEntityRef::into_inner(ptr).as_ptr().cast_const() + } +} + +/// An adapter for storing any `Entity` impl in a [intrusive_collections::LinkedList] +pub struct EntityAdapter { + link_ops: intrusive_collections::linked_list::LinkOps, + ptr_ops: DefaultPointerOps>, + marker: core::marker::PhantomData, +} +impl Copy for EntityAdapter {} +impl Clone for EntityAdapter { + fn clone(&self) -> Self { + *self + } +} +impl Default for EntityAdapter { + fn default() -> Self { + Self::new() + } +} +impl EntityAdapter { + pub const fn new() -> Self { + Self { + link_ops: intrusive_collections::linked_list::LinkOps, + ptr_ops: DefaultPointerOps::new(), + marker: core::marker::PhantomData, + } + } +} + +unsafe impl intrusive_collections::Adapter for EntityAdapter { + type LinkOps = intrusive_collections::linked_list::LinkOps; + type PointerOps = DefaultPointerOps>; + + unsafe fn get_value( + &self, + link: ::LinkPtr, + ) -> *const ::Value { + let raw_entity_ref = UnsafeIntrusiveEntityRef::::from_link_ptr(link); + raw_entity_ref.inner.as_ptr().cast_const() + } + + unsafe fn get_link( + &self, + value: *const ::Value, + ) -> ::LinkPtr { + let raw_entity_ref = UnsafeIntrusiveEntityRef::from_ptr(value.cast_mut()); + let offset = RawEntityMetadata::::metadata_offset(); + raw_entity_ref.inner.byte_add(offset).cast() + } + + fn link_ops(&self) -> &Self::LinkOps { + &self.link_ops + } + + fn link_ops_mut(&mut self) -> &mut Self::LinkOps { + &mut self.link_ops + } + + fn pointer_ops(&self) -> &Self::PointerOps { + &self.ptr_ops + } +} diff --git a/hir2/src/ir/entity/storage.rs b/hir2/src/ir/entity/storage.rs new file mode 100644 index 
000000000..76a918cc6 --- /dev/null +++ b/hir2/src/ir/entity/storage.rs @@ -0,0 +1,742 @@ +use core::fmt; + +use smallvec::{smallvec, SmallVec}; + +use super::{EntityGroup, StorableEntity}; + +/// [EntityStorage] provides an abstraction over storing IR entities in an [crate::Operation]. +/// +/// Specifically, it provides an abstraction for storing IR entities in a flat vector, while +/// retaining the ability to semantically group the entities, access them by group or individually, +/// and grow or shrink the group or overall set. +/// +/// The implementation expects the types stored in it to implement the [StorableEntity] trait, which +/// provides it the ability to ensure the entity is kept up to date with its position in the +/// set. Additionally, it ensures that removing an entity will unlink that entity from any +/// dependents or dependencies that it needs to maintain links for. +/// +/// Users can control the number of entities stored inline via the `INLINE` const parameter. By +/// default, only a single entity is stored inline, but sometimes more may be desired if you know +/// that a particular entity always has a particular cardinality. +pub struct EntityStorage { + /// The items being stored + items: SmallVec<[T; INLINE]>, + /// The semantic grouping information for this instance. + /// + /// There is always at least one group, and more can be explicitly added/removed. + groups: SmallVec<[EntityGroup; 2]>, +} + +impl fmt::Debug for EntityStorage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(core::any::type_name::()) + .field_with("groups", |f| { + let mut builder = f.debug_list(); + for group in self.groups.iter() { + let range = group.as_range(); + let items = &self.items[range.clone()]; + builder.entry_with(|f| { + f.debug_map().entry(&"range", &range).entry(&"items", &items).finish() + }); + } + builder.finish() + }) + .finish() + } +} + +impl Default for EntityStorage { + fn default() -> Self { + Self { + items: Default::default(), + groups: smallvec![EntityGroup::default()], + } + } +} + +impl EntityStorage { + /// Returns true if there are no items in storage. + #[inline] + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Returns the total number of items in storage. + #[inline] + pub fn len(&self) -> usize { + self.items.len() + } + + /// Returns the number of groups with allocated storage. 
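+    ///
+    /// For example (sketch; `Item` stands in for any [StorableEntity] impl, as in the tests
+    /// below):
+    ///
+    /// ```rust,ignore
+    /// let storage = EntityStorage::<Item, 1>::default();
+    /// // A fresh storage is empty, but always has one (empty) group
+    /// assert_eq!(storage.num_groups(), 1);
+    /// assert!(storage.is_empty());
+    /// ```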
+ #[inline] + pub fn num_groups(&self) -> usize { + self.groups.len() + } + + /// Get an iterator over all of the items in storage + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.items.iter() + } + + /// Get a mutable iterator over all of the items in storage + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { + self.items.iter_mut() + } +} + +impl EntityStorage { + /// Push an item to the last group + pub fn push(&mut self, mut item: T) { + let index = self.items.len(); + unsafe { + item.set_index(index); + } + self.items.push(item); + let group = self.groups.last_mut().unwrap(); + group.grow(1); + } + + /// Extend the last group with `items` + pub fn extend(&mut self, items: I) + where + I: IntoIterator, + { + let mut group = self.group_mut(self.groups.len() - 1); + group.extend(items); + } + + /// Push `items` as a new group, and return the group index + #[inline] + pub fn push_group(&mut self, items: impl IntoIterator) -> usize { + let group = self.groups.len(); + self.extend_group(group, items); + group + } + + /// Push `item` to the specified group + pub fn push_to_group(&mut self, group: usize, item: T) { + if self.groups.len() <= group { + let next_offset = self.groups.last().map(|group| group.as_range().end).unwrap_or(0); + self.groups.resize(group + 1, EntityGroup::new(next_offset, 0)); + } + let mut group = self.group_mut(group); + group.push(item); + } + + /// Pushes `items` to the given group, creating it if necessary, and allocating any intervening + /// implied groups if they have not been created it. + pub fn extend_group(&mut self, group: usize, items: I) + where + I: IntoIterator, + { + if self.groups.len() <= group { + let next_offset = self.groups.last().map(|group| group.as_range().end).unwrap_or(0); + self.groups.resize(group + 1, EntityGroup::new(next_offset, 0)); + } + let mut group = self.group_mut(group); + group.extend(items); + } + + /// Clear all items in storage + pub fn clear(&mut self) { + for mut item in self.items.drain(..) 
{ + item.unlink(); + } + self.groups.clear(); + self.groups.push(EntityGroup::default()); + } + + /// Get all the items in storage + pub fn all(&self) -> EntityRange<'_, T> { + EntityRange { + range: 0..self.items.len(), + items: self.items.as_slice(), + } + } + + /// Get an [EntityRange] covering items in the specified group + pub fn group(&self, group: usize) -> EntityRange<'_, T> { + EntityRange { + range: self.groups[group].as_range(), + items: self.items.as_slice(), + } + } + + /// Get an [EntityRangeMut] covering items in the specified group + pub fn group_mut(&mut self, group: usize) -> EntityRangeMut<'_, T, INLINE> { + let range = self.groups[group].as_range(); + EntityRangeMut { + group, + range, + groups: &mut self.groups, + items: &mut self.items, + } + } +} +impl core::ops::Index for EntityStorage { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.items[index] + } +} +impl core::ops::IndexMut for EntityStorage { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.items[index] + } +} + +/// A reference to a range of items in [EntityStorage] +#[derive(Clone)] +pub struct EntityRange<'a, T> { + range: core::ops::Range, + items: &'a [T], +} +impl<'a, T> EntityRange<'a, T> { + /// Get an empty range + pub fn empty() -> Self { + Self { + range: 0..0, + items: &[], + } + } + + /// Returns true if this range is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } + + /// Returns the size of this range + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + /// Get this range as a slice + #[inline] + pub fn as_slice(&self) -> &[T] { + if self.range.is_empty() { + &[] + } else { + &self.items[self.range.start..self.range.end] + } + } + + /// Get an iterator over the items in this range + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.as_slice().iter() + } + + /// Get an item at the specified index relative to this range, or `None` if the index is out of bounds. + #[inline] + pub fn get(&self, index: usize) -> Option<&T> { + self.as_slice().get(index) + } +} +impl<'a, T> core::ops::Index for EntityRange<'a, T> { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.as_slice()[index] + } +} +impl<'a, T> IntoIterator for EntityRange<'a, T> { + type IntoIter = core::slice::Iter<'a, T>; + type Item = &'a T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.iter() + } +} + +/// A mutable range of items in [EntityStorage] +/// +/// Items outside the range are not modified, however the range itself can have its size change, +/// which as a result will shift other items around. Any other groups in [EntityStorage] will +/// be updated to reflect such changes, so in general this should be transparent. 
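+///
+/// A usage sketch (assumes `storage` is an [EntityStorage] and `item` a [StorableEntity] impl;
+/// illustrative only):
+///
+/// ```rust,ignore
+/// let mut range = storage.group_mut(0);
+/// range.push(item);
+/// // Group 0 now covers one more element; any later groups have been
+/// // shifted to start one element further along.
+/// ```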
+pub struct EntityRangeMut<'a, T, const INLINE: usize = 1> { + group: usize, + range: core::ops::Range, + groups: &'a mut [EntityGroup], + items: &'a mut SmallVec<[T; INLINE]>, +} +impl<'a, T, const INLINE: usize> EntityRangeMut<'a, T, INLINE> { + /// Returns true if this range is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } + + /// Get the number of items covered by this range + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + /// Temporarily borrow this range immutably + pub fn as_immutable(&self) -> EntityRange<'_, T> { + EntityRange { + range: self.range.clone(), + items: self.items.as_slice(), + } + } + + /// Get this range as a slice + #[inline] + pub fn as_slice(&self) -> &[T] { + if self.range.is_empty() { + &[] + } else { + &self.items[self.range.start..self.range.end] + } + } + + /// Get this range as a mutable slice + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [T] { + if self.range.is_empty() { + &mut [] + } else { + &mut self.items[self.range.start..self.range.end] + } + } + + /// Get an iterator over the items covered by this range + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, T> { + self.as_slice().iter() + } + + /// Get a mutable iterator over the items covered by this range + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { + self.as_slice_mut().iter_mut() + } + + /// Get a reference to the item at `index`, relative to the start of this range. + #[inline] + pub fn get(&self, index: usize) -> Option<&T> { + self.as_slice().get(index) + } + + /// Get a mutable reference to the item at `index`, relative to the start of this range. + #[inline] + pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { + self.as_slice_mut().get_mut(index) + } +} + +impl<'a, T: StorableEntity, const INLINE: usize> EntityRangeMut<'a, T, INLINE> { + /// Append `item` to storage at the end of this range + #[inline] + pub fn push(&mut self, item: T) { + self.extend([item]); + } + + /// Append `items` to storage at the end of this range + pub fn extend(&mut self, operands: I) + where + I: IntoIterator, + { + // Handle edge case where group is the last group + let is_last = self.groups.len() == self.group + 1; + if is_last { + self.extend_last(operands); + } else { + self.extend_within(operands); + } + } + + fn extend_last(&mut self, items: I) + where + I: IntoIterator, + { + let prev_len = self.items.len(); + self.items.extend(items.into_iter().enumerate().map(|(i, mut item)| { + unsafe { + item.set_index(prev_len + i); + } + item + })); + let num_inserted = self.items.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group].grow(num_inserted); + self.range = self.groups[self.group].as_range(); + } + + fn extend_within(&mut self, items: I) + where + I: IntoIterator, + { + let prev_len = self.items.len(); + let start = self.range.end; + self.items.insert_many( + start, + items.into_iter().enumerate().map(|(i, mut item)| { + unsafe { + item.set_index(start + i); + } + item + }), + ); + let num_inserted = self.items.len().abs_diff(prev_len); + if num_inserted == 0 { + return; + } + self.groups[self.group].grow(num_inserted); + self.range = self.groups[self.group].as_range(); + + // Shift groups + for group in self.groups[(self.group + 1)..].iter_mut() { + group.shift_start(num_inserted as isize); + } + + // Shift item indices + let shifted = self.range.end; + for (offset, item) in self.items[shifted..].iter_mut().enumerate() { + unsafe { + item.set_index(shifted 
+ offset); + } + } + } + + /// Remove the item at `index` in this group, and return it. + /// + /// NOTE: This will panic if `index` is out of bounds of the group. + pub fn erase(&mut self, index: usize) -> T { + assert!(self.range.len() > index, "index out of bounds"); + self.range.end -= 1; + self.groups[self.group].shrink(1); + let mut removed = self.items.remove(self.range.start + index); + { + removed.unlink(); + } + + // Shift groups + let next_group = self.group + 1; + if next_group < self.groups.len() { + for group in self.groups[next_group..].iter_mut() { + group.shift_start(-1); + } + } + + // Shift item indices + let next_item = index; + if next_item < self.items.len() { + for (offset, item) in self.items[next_item..].iter_mut().enumerate() { + unsafe { + item.set_index(next_item + offset); + } + } + } + + removed + } + + /// Remove the last item from this group, or `None` if empty + pub fn pop(&mut self) -> Option { + if self.range.is_empty() { + return None; + } + let index = self.range.end - 1; + self.range.end = index; + self.groups[self.group].shrink(1); + let mut removed = self.items.remove(index); + { + removed.unlink(); + } + + // Shift groups + let next_group = self.group + 1; + if next_group < self.groups.len() { + for group in self.groups[next_group..].iter_mut() { + group.shift_start(-1); + } + } + + // Shift item indices + let next_item = index; + if next_item < self.items.len() { + for (offset, item) in self.items[next_item..].iter_mut().enumerate() { + unsafe { + item.set_index(next_item + offset); + } + } + } + + Some(removed) + } +} +impl<'a, T, const INLINE: usize> core::ops::Index for EntityRangeMut<'a, T, INLINE> { + type Output = T; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.as_slice()[index] + } +} +impl<'a, T, const INLINE: usize> core::ops::IndexMut for EntityRangeMut<'a, T, INLINE> { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.as_slice_mut()[index] + } +} +impl<'a, T> IntoIterator for EntityRangeMut<'a, T> { + type IntoIter = core::slice::IterMut<'a, T>; + type Item = &'a mut T; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + struct Item { + index: usize, + value: usize, + } + impl Item { + pub fn new(value: usize) -> Self { + Self { index: 0, value } + } + } + impl StorableEntity for Item { + fn index(&self) -> usize { + self.index + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index; + } + + fn unlink(&mut self) {} + } + + type ItemStorage = EntityStorage; + #[allow(unused)] + type ItemRange<'a> = EntityRange<'a, Item>; + #[allow(unused)] + type ItemRangeMut<'a> = EntityRangeMut<'a, Item, 1>; + + #[test] + fn entity_storage_empty_operations() { + let mut storage = ItemStorage::default(); + + // No items, but always have at least one group + assert_eq!(storage.len(), 0); + assert!(storage.is_empty()); + assert_eq!(storage.num_groups(), 1); + + { + let range = storage.all(); + assert_eq!(range.len(), 0); + assert!(range.is_empty()); + assert_eq!(range.as_slice(), &[]); + assert_eq!(range.iter().next(), None); + } + + // No items, two groups + let group = storage.push_group(None); + assert_eq!(group, 1); + assert_eq!(storage.num_groups(), 2); + assert_eq!(storage.len(), 0); + assert!(storage.is_empty()); + + { + let range = storage.group(0); + assert_eq!(range.len(), 0); + assert!(range.is_empty()); + 
assert_eq!(range.as_slice(), &[]); + assert_eq!(range.iter().next(), None); + } + } + + #[test] + fn entity_storage_push_to_empty_group_entity_range() { + let mut storage = ItemStorage::default(); + + // Get group as mutable range + let mut group_range = storage.group_mut(0); + + // Verify handling of empty group in EntityRangeMut + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + + // Push items to range + group_range.push(Item::new(0)); + group_range.push(Item::new(1)); + + // Verify range reflects changes + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 0, value: 0 }, Item { index: 1, value: 1 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 0, value: 0 })); + } + + #[test] + fn entity_storage_pop_from_non_empty_group_entity_range() { + let mut storage = ItemStorage::default(); + + assert_eq!(storage.num_groups(), 1); + storage.push_to_group(0, Item::new(0)); + assert_eq!(storage.len(), 1); + assert!(!storage.is_empty()); + + // Get group as mutable range + let mut group_range = storage.group_mut(0); + assert_eq!(group_range.len(), 1); + assert!(!group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[Item { index: 0, value: 0 }]); + assert_eq!(group_range.iter().next(), Some(&Item { index: 0, value: 0 })); + + // Pop item from range + let item = group_range.pop(); + assert_eq!(item, Some(Item { index: 0, value: 0 })); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + assert_eq!(group_range.range.clone(), 0..0); + + // Pop from empty range should have no effect + let item = group_range.pop(); + assert_eq!(item, None); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + assert_eq!(group_range.range.clone(), 0..0); + } + + #[test] + fn entity_storage_push_to_empty_group_entity_range_before_other_groups() { + let mut storage = ItemStorage::default(); + + storage.extend_group(0, [Item::new(0), Item::new(1)]); + let group1 = storage.push_group(None); + let group2 = storage.push_group(None); + let group3 = storage.push_group([Item::new(4), Item::new(5)]); + + assert!(!storage.is_empty()); + assert_eq!(storage.len(), 4); + assert_eq!(storage.num_groups(), 4); + + assert_eq!(storage.group(0).range.clone(), 0..2); + assert_eq!(storage.group(1).range.clone(), 2..2); + assert_eq!(storage.group(2).range.clone(), 2..2); + assert_eq!(storage.group(3).range.clone(), 2..4); + + // Insert items into first non-empty group + { + let mut group_range = storage.group_mut(group1); + + // Verify handling of empty group in EntityRangeMut + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + + // Push items to range + group_range.push(Item::new(2)); + group_range.push(Item::new(3)); + + // Verify range reflects changes + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 2, value: 2 }, Item { index: 3, value: 3 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 2, value: 2 })); + } + + // The subsequent empty group should still be empty, but at a new offset + let group_range = 
storage.group(group2); + assert_eq!(group_range.range.clone(), 4..4); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + + // The trailing non-empty group should have updated offsets + let group_range = storage.group(group3); + assert_eq!(group_range.range.clone(), 4..6); + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 4, value: 4 }, Item { index: 5, value: 5 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 4, value: 4 })); + } + + #[test] + fn entity_storage_pop_from_non_empty_group_entity_range_before_other_groups() { + let mut storage = ItemStorage::default(); + + storage.extend_group(0, [Item::new(0), Item::new(1)]); + let group1 = storage.push_group(None); + let group2 = storage.push_group(None); + let group3 = storage.push_group([Item::new(4), Item::new(5)]); + + assert!(!storage.is_empty()); + assert_eq!(storage.len(), 4); + assert_eq!(storage.num_groups(), 4); + + assert_eq!(storage.group(0).range.clone(), 0..2); + assert_eq!(storage.group(1).range.clone(), 2..2); + assert_eq!(storage.group(2).range.clone(), 2..2); + assert_eq!(storage.group(3).range.clone(), 2..4); + + // Pop from group0 + { + let mut group_range = storage.group_mut(0); + let item = group_range.pop(); + assert_eq!(item, Some(Item { index: 1, value: 1 })); + assert_eq!(group_range.len(), 1); + assert!(!group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[Item { index: 0, value: 0 }]); + } + + // The subsequent empty group(s) should still be empty, but at a new offset + for group_index in [group1, group2] { + let group_range = storage.group(group_index); + assert_eq!(group_range.range.clone(), 1..1); + assert_eq!(group_range.len(), 0); + assert!(group_range.is_empty()); + assert_eq!(group_range.as_slice(), &[]); + assert_eq!(group_range.iter().next(), None); + } + + // The trailing non-empty group should have updated offsets + let group_range = storage.group(group3); + assert_eq!(group_range.range.clone(), 1..3); + assert_eq!(group_range.len(), 2); + assert!(!group_range.is_empty()); + assert_eq!( + group_range.as_slice(), + &[Item { index: 1, value: 4 }, Item { index: 2, value: 5 }] + ); + assert_eq!(group_range.iter().next(), Some(&Item { index: 1, value: 4 })); + } +} diff --git a/hir2/src/ir/ident.rs b/hir2/src/ir/ident.rs new file mode 100644 index 000000000..2ec6c8125 --- /dev/null +++ b/hir2/src/ir/ident.rs @@ -0,0 +1,235 @@ +use core::{ + cmp::Ordering, + fmt, + hash::{Hash, Hasher}, + str::FromStr, +}; + +use anyhow::anyhow; + +use super::{ + interner::{symbols, Symbol}, + SourceSpan, Spanned, +}; +use crate::{ + define_attr_type, + formatter::{self, PrettyPrint}, +}; + +/// Represents a globally-unique module/function name pair, with corresponding source spans. 
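+///
+/// For example (sketch, using the `FromStr` impl below):
+///
+/// ```rust,ignore
+/// let id: FunctionIdent = "std::math::u64::checked_add".parse().unwrap();
+/// assert_eq!(id.module.as_str(), "std::math::u64");
+/// assert_eq!(id.function.as_str(), "checked_add");
+/// ```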
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Spanned)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct FunctionIdent { + pub module: Ident, + #[span] + pub function: Ident, +} +define_attr_type!(FunctionIdent); +impl FunctionIdent { + pub fn display(&self) -> impl fmt::Display + '_ { + use crate::formatter::*; + + flatten( + const_text(self.module.as_str()) + + const_text("::") + + const_text(self.function.as_str()), + ) + } +} +impl FromStr for FunctionIdent { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.rsplit_once("::") { + Some((ns, id)) => { + let module = Ident::with_empty_span(Symbol::intern(ns)); + let function = Ident::with_empty_span(Symbol::intern(id)); + Ok(Self { module, function }) + } + None => Err(anyhow!( + "invalid function name, expected fully-qualified identifier, e.g. \ + 'std::math::u64::checked_add'" + )), + } + } +} +impl fmt::Debug for FunctionIdent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("FunctionIdent") + .field("module", &self.module.name) + .field("function", &self.function.name) + .finish() + } +} +impl fmt::Display for FunctionIdent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl PrettyPrint for FunctionIdent { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + flatten(const_text("(") + display(self.module) + const_text(" ") + display(self.function)) + } +} +impl PartialOrd for FunctionIdent { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for FunctionIdent { + fn cmp(&self, other: &Self) -> Ordering { + self.module.cmp(&other.module).then(self.function.cmp(&other.function)) + } +} + +/// Represents an identifier in the IR. 
+/// +/// An identifier is some string, along with an associated source span +#[derive(Copy, Clone, Eq, Spanned)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(into = "Symbol", from = "Symbol"))] +pub struct Ident { + pub name: Symbol, + #[span] + pub span: SourceSpan, +} +define_attr_type!(Ident); +impl Default for Ident { + fn default() -> Self { + Self { + name: symbols::Empty, + span: SourceSpan::UNKNOWN, + } + } +} +impl FromStr for Ident { + type Err = core::convert::Infallible; + + fn from_str(name: &str) -> Result { + Ok(Self::from(name)) + } +} +impl<'a> From<&'a str> for Ident { + fn from(name: &'a str) -> Self { + Self::with_empty_span(Symbol::intern(name)) + } +} +impl From for Ident { + #[inline] + fn from(sym: Symbol) -> Self { + Self::with_empty_span(sym) + } +} +impl From for Symbol { + #[inline] + fn from(id: Ident) -> Self { + id.as_symbol() + } +} +impl Ident { + #[inline] + pub const fn new(name: Symbol, span: SourceSpan) -> Ident { + Ident { name, span } + } + + #[inline] + pub const fn with_empty_span(name: Symbol) -> Ident { + Ident::new(name, SourceSpan::UNKNOWN) + } + + #[inline] + pub fn as_str(self) -> &'static str { + self.name.as_str() + } + + #[inline(always)] + pub fn as_symbol(self) -> Symbol { + self.name + } + + // An identifier can be unquoted if is composed of any sequence of printable + // ASCII characters, except whitespace, quotation marks, comma, semicolon, or brackets + pub fn requires_quoting(&self) -> bool { + self.as_str().contains(|c| match c { + c if c.is_ascii_control() => true, + ' ' | '\'' | '"' | ',' | ';' | '[' | ']' => true, + c if c.is_ascii_graphic() => false, + _ => true, + }) + } +} +impl AsRef for Ident { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} +impl alloc::borrow::Borrow for Ident { + #[inline] + fn borrow(&self) -> &Symbol { + &self.name + } +} +impl alloc::borrow::Borrow for Ident { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} +impl Ord for Ident { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.as_str().cmp(other.as_str()) + } +} +impl PartialOrd for Ident { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl PartialEq for Ident { + #[inline] + fn eq(&self, rhs: &Self) -> bool { + self.name == rhs.name + } +} +impl PartialEq for Ident { + #[inline] + fn eq(&self, rhs: &Symbol) -> bool { + self.name.eq(rhs) + } +} +impl PartialEq for Ident { + fn eq(&self, rhs: &str) -> bool { + self.name.as_str() == rhs + } +} +impl Hash for Ident { + fn hash(&self, state: &mut H) { + self.name.hash(state); + } +} +impl fmt::Debug for Ident { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl fmt::Display for Ident { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.pretty_print(f) + } +} +impl PrettyPrint for Ident { + fn render(&self) -> formatter::Document { + use crate::formatter::*; + + if self.requires_quoting() { + text(format!("\"{}\"", self.as_str().escape_default())) + } else { + text(format!("#{}", self.as_str())) + } + } +} diff --git a/hir2/src/ir/immediates.rs b/hir2/src/ir/immediates.rs new file mode 100644 index 000000000..57d4795f5 --- /dev/null +++ b/hir2/src/ir/immediates.rs @@ -0,0 +1,798 @@ +use core::{ + fmt, + hash::{Hash, Hasher}, +}; + +pub use miden_core::{Felt, FieldElement, StarkField}; + +use crate::{formatter::PrettyPrint, Type}; + +#[derive(Debug, Copy, Clone)] +pub enum Immediate { 
+    I1(bool),
+    U8(u8),
+    I8(i8),
+    U16(u16),
+    I16(i16),
+    U32(u32),
+    I32(i32),
+    U64(u64),
+    I64(i64),
+    U128(u128),
+    I128(i128),
+    F64(f64),
+    Felt(Felt),
+}
+impl Immediate {
+    pub fn ty(&self) -> Type {
+        match self {
+            Self::I1(_) => Type::I1,
+            Self::U8(_) => Type::U8,
+            Self::I8(_) => Type::I8,
+            Self::U16(_) => Type::U16,
+            Self::I16(_) => Type::I16,
+            Self::U32(_) => Type::U32,
+            Self::I32(_) => Type::I32,
+            Self::U64(_) => Type::U64,
+            Self::I64(_) => Type::I64,
+            Self::U128(_) => Type::U128,
+            Self::I128(_) => Type::I128,
+            Self::F64(_) => Type::F64,
+            Self::Felt(_) => Type::Felt,
+        }
+    }
+
+    /// Returns true if this immediate is a non-negative value
+    pub fn is_non_negative(&self) -> bool {
+        match self {
+            Self::I1(_) => true,
+            Self::I8(i) => *i >= 0,
+            Self::U8(_) => true,
+            Self::I16(i) => *i >= 0,
+            Self::U16(_) => true,
+            Self::I32(i) => *i >= 0,
+            Self::U32(_) => true,
+            Self::I64(i) => *i >= 0,
+            Self::U64(_) => true,
+            Self::U128(_) => true,
+            Self::I128(i) => *i >= 0,
+            Self::F64(f) => f.is_sign_positive(),
+            Self::Felt(_) => true,
+        }
+    }
+
+    /// Returns true if this immediate can represent negative values
+    pub fn is_signed(&self) -> bool {
+        matches!(
+            self,
+            Self::I8(_) | Self::I16(_) | Self::I32(_) | Self::I64(_) | Self::I128(_) | Self::F64(_)
+        )
+    }
+
+    /// Returns true if this immediate can only represent non-negative values
+    pub fn is_unsigned(&self) -> bool {
+        matches!(
+            self,
+            Self::I1(_)
+                | Self::U8(_)
+                | Self::U16(_)
+                | Self::U32(_)
+                | Self::U64(_)
+                | Self::U128(_)
+                | Self::Felt(_)
+        )
+    }
+
+    /// Returns true if this immediate is an odd integer, otherwise false
+    ///
+    /// If the immediate is not an integer, returns `None`
+    pub fn is_odd(&self) -> Option<bool> {
+        match self {
+            Self::I1(b) => Some(*b),
+            Self::U8(i) => Some(*i % 2 != 0),
+            Self::I8(i) => Some(*i % 2 != 0),
+            Self::U16(i) => Some(*i % 2 != 0),
+            Self::I16(i) => Some(*i % 2 != 0),
+            Self::U32(i) => Some(*i % 2 != 0),
+            Self::I32(i) => Some(*i % 2 != 0),
+            Self::U64(i) => Some(*i % 2 != 0),
+            Self::I64(i) => Some(*i % 2 != 0),
+            Self::Felt(i) => Some(i.as_int() % 2 != 0),
+            Self::U128(i) => Some(*i % 2 != 0),
+            Self::I128(i) => Some(*i % 2 != 0),
+            Self::F64(_) => None,
+        }
+    }
+
+    /// Returns true if this immediate is a non-zero integer, otherwise false
+    ///
+    /// If the immediate is not an integer, returns `None`
+    pub fn as_bool(self) -> Option<bool> {
+        match self {
+            Self::I1(b) => Some(b),
+            Self::U8(i) => Some(i != 0),
+            Self::I8(i) => Some(i != 0),
+            Self::U16(i) => Some(i != 0),
+            Self::I16(i) => Some(i != 0),
+            Self::U32(i) => Some(i != 0),
+            Self::I32(i) => Some(i != 0),
+            Self::U64(i) => Some(i != 0),
+            Self::I64(i) => Some(i != 0),
+            Self::Felt(i) => Some(i.as_int() != 0),
+            Self::U128(i) => Some(i != 0),
+            Self::I128(i) => Some(i != 0),
+            Self::F64(_) => None,
+        }
+    }
+
+    /// Attempts to convert this value to a u32
+    pub fn as_u32(self) -> Option<u32> {
+        match self {
+            Self::I1(b) => Some(b as u32),
+            Self::U8(b) => Some(b as u32),
+            Self::I8(b) if b >= 0 => Some(b as u32),
+            Self::I8(_) => None,
+            Self::U16(b) => Some(b as u32),
+            Self::I16(b) if b >= 0 => Some(b as u32),
+            Self::I16(_) => None,
+            Self::U32(b) => Some(b),
+            Self::I32(b) if b >= 0 => Some(b as u32),
+            Self::I32(_) => None,
+            Self::U64(b) => u32::try_from(b).ok(),
+            Self::I64(b) if b >= 0 => u32::try_from(b as u64).ok(),
+            Self::I64(_) => None,
+            Self::Felt(i) => u32::try_from(i.as_int()).ok(),
+            Self::U128(b) if b <= (u32::MAX as u64 as u128) => Some(b as u32),
+            Self::U128(_) => None,
+            Self::I128(b) if b >= 0 && b <= (u32::MAX as u64 as i128) => Some(b as u32),
+            Self::I128(_) => None,
+            Self::F64(f) => FloatToInt::<u32>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to i32
+    pub fn as_i32(self) -> Option<i32> {
+        match self {
+            Self::I1(b) => Some(b as i32),
+            Self::U8(i) => Some(i as i32),
+            Self::I8(i) => Some(i as i32),
+            Self::U16(i) => Some(i as i32),
+            Self::I16(i) => Some(i as i32),
+            Self::U32(i) => i.try_into().ok(),
+            Self::I32(i) => Some(i),
+            Self::U64(i) => i.try_into().ok(),
+            Self::I64(i) => i.try_into().ok(),
+            Self::Felt(i) => i.as_int().try_into().ok(),
+            Self::U128(i) if i <= (i32::MAX as u32 as u128) => Some(i as u32 as i32),
+            Self::U128(_) => None,
+            Self::I128(i) if i >= (i32::MIN as i128) && i <= (i32::MAX as i128) => Some(i as i32),
+            Self::I128(_) => None,
+            Self::F64(f) => FloatToInt::<i32>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to a field element
+    pub fn as_felt(self) -> Option<Felt> {
+        match self {
+            Self::I1(b) => Some(Felt::new(b as u64)),
+            Self::U8(b) => Some(Felt::new(b as u64)),
+            Self::I8(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::U16(b) => Some(Felt::new(b as u64)),
+            Self::I16(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::U32(b) => Some(Felt::new(b as u64)),
+            Self::I32(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::U64(b) => Some(Felt::new(b)),
+            Self::I64(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::Felt(i) => Some(i),
+            Self::U128(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::I128(b) => u64::try_from(b).ok().map(Felt::new),
+            Self::F64(f) => FloatToInt::<Felt>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to u64
+    pub fn as_u64(self) -> Option<u64> {
+        match self {
+            Self::I1(b) => Some(b as u64),
+            Self::U8(i) => Some(i as u64),
+            Self::I8(i) if i >= 0 => Some(i as u64),
+            Self::I8(_) => None,
+            Self::U16(i) => Some(i as u64),
+            Self::I16(i) if i >= 0 => Some(i as u16 as u64),
+            Self::I16(_) => None,
+            Self::U32(i) => Some(i as u64),
+            Self::I32(i) if i >= 0 => Some(i as u32 as u64),
+            Self::I32(_) => None,
+            Self::U64(i) => Some(i),
+            Self::I64(i) if i >= 0 => Some(i as u64),
+            Self::I64(_) => None,
+            Self::Felt(i) => Some(i.as_int()),
+            Self::U128(i) => i.try_into().ok(),
+            Self::I128(i) if i >= 0 => i.try_into().ok(),
+            Self::I128(_) => None,
+            Self::F64(f) => FloatToInt::<u64>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to i64
+    pub fn as_i64(self) -> Option<i64> {
+        match self {
+            Self::I1(b) => Some(b as i64),
+            Self::U8(i) => Some(i as i64),
+            Self::I8(i) => Some(i as i64),
+            Self::U16(i) => Some(i as i64),
+            Self::I16(i) => Some(i as i64),
+            Self::U32(i) => Some(i as i64),
+            Self::I32(i) => Some(i as i64),
+            Self::U64(i) => i.try_into().ok(),
+            Self::I64(i) => Some(i),
+            Self::Felt(i) => i.as_int().try_into().ok(),
+            Self::U128(i) if i <= i64::MAX as u128 => Some(i as u64 as i64),
+            Self::U128(_) => None,
+            Self::I128(i) => i.try_into().ok(),
+            Self::F64(f) => FloatToInt::<i64>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to u128
+    pub fn as_u128(self) -> Option<u128> {
+        match self {
+            Self::I1(b) => Some(b as u128),
+            Self::U8(i) => Some(i as u128),
+            Self::I8(i) if i >= 0 => Some(i as u128),
+            Self::I8(_) => None,
+            Self::U16(i) => Some(i as u128),
+            Self::I16(i) if i >= 0 => Some(i as u16 as u128),
+            Self::I16(_) => None,
+            Self::U32(i) => Some(i as u128),
+            Self::I32(i) if i >= 0 => Some(i as u32 as u128),
+            Self::I32(_) => None,
+            Self::U64(i) => Some(i as u128),
+            Self::I64(i) if i >= 0 => Some(i as u128),
+            Self::I64(_) => None,
+            Self::Felt(i) => Some(i.as_int() as u128),
+            Self::U128(i) => Some(i),
+            Self::I128(i) if i >= 0 => i.try_into().ok(),
+            Self::I128(_) => None,
+            Self::F64(f) => FloatToInt::<u128>::to_int(f).ok(),
+        }
+    }
+
+    /// Attempts to convert this value to i128
+    pub fn as_i128(self) -> Option<i128> {
+        match self {
+            Self::I1(b) => Some(b as i128),
+            Self::U8(i) => Some(i as i128),
+            Self::I8(i) => Some(i as i128),
+            Self::U16(i) => Some(i as i128),
+            Self::I16(i) => Some(i as i128),
+            Self::U32(i) => Some(i as i128),
+            Self::I32(i) => Some(i as i128),
+            Self::U64(i) => Some(i as i128),
+            Self::I64(i) => Some(i as i128),
+            Self::Felt(i) => Some(i.as_int() as i128),
+            Self::U128(i) if i <= i128::MAX as u128 => Some(i as i128),
+            Self::U128(_) => None,
+            Self::I128(i) => Some(i),
+            Self::F64(f) => FloatToInt::<i128>::to_int(f).ok(),
+        }
+    }
+}
+impl fmt::Display for Immediate {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::I1(i) => write!(f, "{}", i),
+            Self::U8(i) => write!(f, "{}", i),
+            Self::I8(i) => write!(f, "{}", i),
+            Self::U16(i) => write!(f, "{}", i),
+            Self::I16(i) => write!(f, "{}", i),
+            Self::U32(i) => write!(f, "{}", i),
+            Self::I32(i) => write!(f, "{}", i),
+            Self::U64(i) => write!(f, "{}", i),
+            Self::I64(i) => write!(f, "{}", i),
+            Self::U128(i) => write!(f, "{}", i),
+            Self::I128(i) => write!(f, "{}", i),
+            Self::F64(n) => write!(f, "{}", n),
+            Self::Felt(i) => write!(f, "{}", i),
+        }
+    }
+}
+impl PrettyPrint for Immediate {
+    fn render(&self) -> crate::formatter::Document {
+        use crate::formatter::*;
+        display(self)
+    }
+}
+impl Hash for Immediate {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let d = core::mem::discriminant(self);
+        d.hash(state);
+        match self {
+            Self::I1(i) => i.hash(state),
+            Self::U8(i) => i.hash(state),
+            Self::I8(i) => i.hash(state),
+            Self::U16(i) => i.hash(state),
+            Self::I16(i) => i.hash(state),
+            Self::U32(i) => i.hash(state),
+            Self::I32(i) => i.hash(state),
+            Self::U64(i) => i.hash(state),
+            Self::I64(i) => i.hash(state),
+            Self::U128(i) => i.hash(state),
+            Self::I128(i) => i.hash(state),
+            Self::F64(f) => {
+                let bytes = f.to_be_bytes();
+                bytes.hash(state)
+            }
+            Self::Felt(i) => i.as_int().hash(state),
+        }
+    }
+}
+impl Eq for Immediate {}
+impl PartialEq for Immediate {
+    fn eq(&self, other: &Self) -> bool {
+        match (*self, *other) {
+            (Self::I1(x), Self::I1(y)) => x == y,
+            (Self::U8(x), Self::U8(y)) => x == y,
+            (Self::I8(x), Self::I8(y)) => x == y,
+            (Self::U16(x), Self::U16(y)) => x == y,
+            (Self::I16(x), Self::I16(y)) => x == y,
+            (Self::U32(x), Self::U32(y)) => x == y,
+            (Self::I32(x), Self::I32(y)) => x == y,
+            (Self::U64(x), Self::U64(y)) => x == y,
+            (Self::I64(x), Self::I64(y)) => x == y,
+            (Self::U128(x), Self::U128(y)) => x == y,
+            (Self::I128(x), Self::I128(y)) => x == y,
+            (Self::F64(x), Self::F64(y)) => x == y,
+            (Self::Felt(x), Self::Felt(y)) => x == y,
+            _ => false,
+        }
+    }
+}
+impl PartialEq<isize> for Immediate {
+    fn eq(&self, other: &isize) -> bool {
+        let y = *other;
+        match *self {
+            Self::I1(x) => x == (y == 1),
+            Self::U8(_) if y < 0 => false,
+            Self::U8(x) => x as isize == y,
+            Self::I8(x) => x as isize == y,
+            Self::U16(_) if y < 0 => false,
+            Self::U16(x) => x as isize == y,
+            Self::I16(x) => x as isize == y,
+            Self::U32(_) if y < 0 => false,
+            Self::U32(x) => x as isize == y,
+            Self::I32(x) => x as isize == y,
+            Self::U64(_) if y < 0 => false,
+            Self::U64(x) => x == y as i64 as u64,
+            Self::I64(x) => x == y as i64,
+            Self::U128(_) if y < 0 => false,
+            Self::U128(x) => x == y as i128 as u128,
+            Self::I128(x) => x == y as i128,
+            Self::F64(_) => false,
+            Self::Felt(_) if y < 0 => false,
+            Self::Felt(x) => x.as_int() == y as i64 as u64,
+        }
+    }
+}
+impl PartialOrd for Immediate {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl Ord for Immediate {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        use core::cmp::Ordering;
+
+        match (self, other) {
+            // Floats require special treatment
+            (Self::F64(x), Self::F64(y)) => x.total_cmp(y),
+            // Here we're attempting to compare against any integer immediate, so we must
+            // attempt to convert the float to the largest possible integer representation,
+            // i128, and then promote the integer immediate to i128 for comparison
+            //
+            // If the float is not an integer value, truncate it and compare, then adjust the
+            // result to account for the truncation
+            (Self::F64(x), y) => {
+                let y = y
+                    .as_i128()
+                    .expect("expected rhs to be an integer capable of fitting in an i128");
+                if let Ok(x) = FloatToInt::<i128>::to_int(*x) {
+                    x.cmp(&y)
+                } else {
+                    let is_positive = x.is_sign_positive();
+                    if let Ok(x) = FloatToInt::<i128>::to_int((*x).trunc()) {
+                        // Edge case for equality: the float must be bigger due to truncation
+                        match x.cmp(&y) {
+                            Ordering::Equal if is_positive => Ordering::Greater,
+                            Ordering::Equal => Ordering::Less,
+                            o => o,
+                        }
+                    } else {
+                        // The float is larger than i128 can represent, the sign tells us in
+                        // what direction
+                        if is_positive {
+                            Ordering::Greater
+                        } else {
+                            Ordering::Less
+                        }
+                    }
+                }
+            }
+            (x, y @ Self::F64(_)) => y.cmp(x).reverse(),
+            // u128 immediates require separate treatment
+            (Self::U128(x), Self::U128(y)) => x.cmp(y),
+            (Self::U128(x), y) => {
+                let y = y.as_u128().expect("expected rhs to be an integer in the range of u128");
+                x.cmp(&y)
+            }
+            (x, Self::U128(y)) => {
+                let x = x.as_u128().expect("expected lhs to be an integer in the range of u128");
+                x.cmp(y)
+            }
+            // i128 immediates require separate treatment
+            (Self::I128(x), Self::I128(y)) => x.cmp(y),
+            // We're only comparing against values here which are u64, i64, or smaller than 64-bits
+            (Self::I128(x), y) => {
+                let y = y.as_i128().expect("expected rhs to be an integer smaller than i128");
+                x.cmp(&y)
+            }
+            (x, Self::I128(y)) => {
+                let x = x.as_i128().expect("expected lhs to be an integer smaller than i128");
+                x.cmp(y)
+            }
+            // u64 immediates may not fit in an i64
+            (Self::U64(x), Self::U64(y)) => x.cmp(y),
+            // We're only comparing against values here which are i64, or smaller than 64-bits;
+            // a negative value is always ordered before any u64 value
+            (Self::U64(x), y) => {
+                let y =
+                    y.as_i64().expect("expected rhs to be an integer capable of fitting in an i64");
+                if y < 0 {
+                    Ordering::Greater
+                } else {
+                    x.cmp(&(y as u64))
+                }
+            }
+            (x, Self::U64(y)) => {
+                let x =
+                    x.as_i64().expect("expected lhs to be an integer capable of fitting in an i64");
+                if x < 0 {
+                    Ordering::Less
+                } else {
+                    (x as u64).cmp(y)
+                }
+            }
+            // All immediates at this point are i64 or smaller
+            (x, y) => {
+                let x =
+                    x.as_i64().expect("expected lhs to be an integer capable of fitting in an i64");
+                let y =
+                    y.as_i64().expect("expected rhs to be an integer capable of fitting in an i64");
+                x.cmp(&y)
+            }
+        }
+    }
+}
+impl From<Immediate> for Type {
+    #[inline]
+    fn from(imm: Immediate) -> Self {
+        imm.ty()
+    }
+}
+impl From<&Immediate> for Type {
+    #[inline(always)]
+    fn from(imm: &Immediate) -> Self {
+        imm.ty()
+    }
+}
+impl From<bool> for Immediate {
+    #[inline(always)]
+    fn from(value: bool) -> Self {
+        Self::I1(value)
+    }
+}
+impl From<i8> for Immediate {
+    #[inline(always)]
+    fn from(value: i8) -> Self {
+        Self::I8(value)
+    }
+}
+impl From<u8> for Immediate {
+    #[inline(always)]
+    fn from(value: u8) -> Self {
+        Self::U8(value)
+    }
+}
+impl From<i16> for Immediate {
+    #[inline(always)]
+    fn from(value: i16) -> Self {
+        Self::I16(value)
+    }
+}
+impl From<u16> for Immediate {
+    #[inline(always)]
+    fn from(value: u16) -> Self {
+        Self::U16(value)
+    }
+}
+impl From<i32> for Immediate {
+    #[inline(always)]
+    fn from(value: i32) -> Self {
+        Self::I32(value)
+    }
+}
+impl From<u32> for Immediate {
+    #[inline(always)]
+    fn from(value: u32) -> Self {
+        Self::U32(value)
+    }
+}
+impl From<i64> for Immediate {
+    #[inline(always)]
+    fn from(value: i64) -> Self {
+        Self::I64(value)
+    }
+}
+impl From<u64> for Immediate {
+    #[inline(always)]
+    fn from(value: u64) -> Self {
+        Self::U64(value)
+    }
+}
+impl From<u128> for Immediate {
+    #[inline(always)]
+    fn from(value: u128) -> Self {
+        Self::U128(value)
+    }
+}
+impl From<i128> for Immediate {
+    #[inline(always)]
+    fn from(value: i128) -> Self {
+        Self::I128(value)
+    }
+}
+impl From<f64> for Immediate {
+    #[inline(always)]
+    fn from(value: f64) -> Self {
+        Self::F64(value)
+    }
+}
+impl From<char> for Immediate {
+    #[inline(always)]
+    fn from(value: char) -> Self {
+        Self::I32(value as u32 as i32)
+    }
+}
+
+trait FloatToInt<T>: Sized {
+    const ZERO: T;
+
+    fn upper_bound() -> Self;
+    fn lower_bound() -> Self;
+    fn to_int(self) -> Result<T, ()>;
+    unsafe fn to_int_unchecked(self) -> T;
+}
+impl FloatToInt<i8> for f64 {
+    const ZERO: i8 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(i8::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        f64::from(i8::MIN) - 1.0
+    }
+
+    fn to_int(self) -> Result<i8, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> i8 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<u8> for f64 {
+    const ZERO: u8 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(u8::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<u8, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> u8 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<i16> for f64 {
+    const ZERO: i16 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(i16::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        f64::from(i16::MIN) - 1.0
+    }
+
+    fn to_int(self) -> Result<i16, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> i16 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<u16> for f64 {
+    const ZERO: u16 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(u16::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<u16, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> u16 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<i32> for f64 {
+    const ZERO: i32 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(i32::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        f64::from(i32::MIN) - 1.0
+    }
+
+    fn to_int(self) -> Result<i32, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> i32 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<u32> for f64 {
+    const ZERO: u32 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(u32::MAX) + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<u32, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> u32 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<i64> for f64 {
+    const ZERO: i64 = 0;
+
+    fn upper_bound() -> Self {
+        63.0f64.exp2()
+    }
+
+    fn lower_bound() -> Self {
+        (63.0f64.exp2() * -1.0) - 1.0
+    }
+
+    fn to_int(self) -> Result<i64, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> i64 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<u64> for f64 {
+    const ZERO: u64 = 0;
+
+    fn upper_bound() -> Self {
+        64.0f64.exp2()
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<u64, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> u64 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<Felt> for f64 {
+    const ZERO: Felt = Felt::ZERO;
+
+    fn upper_bound() -> Self {
+        64.0f64.exp2() - 32.0f64.exp2() + 1.0
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<Felt, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> Felt {
+        Felt::new(f64::to_int_unchecked::<u64>(self))
+    }
+}
+impl FloatToInt<u128> for f64 {
+    const ZERO: u128 = 0;
+
+    fn upper_bound() -> Self {
+        128.0f64.exp2()
+    }
+
+    fn lower_bound() -> Self {
+        0.0
+    }
+
+    fn to_int(self) -> Result<u128, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> u128 {
+        f64::to_int_unchecked(self)
+    }
+}
+impl FloatToInt<i128> for f64 {
+    const ZERO: i128 = 0;
+
+    fn upper_bound() -> Self {
+        f64::from(i128::BITS - 1).exp2()
+    }
+
+    fn lower_bound() -> Self {
+        -(f64::from(i128::BITS - 1).exp2()) - 1.0
+    }
+
+    fn to_int(self) -> Result<i128, ()> {
+        float_to_int(self)
+    }
+
+    unsafe fn to_int_unchecked(self) -> i128 {
+        f64::to_int_unchecked(self)
+    }
+}
+
+fn float_to_int<I>(f: f64) -> Result<I, ()>
+where
+    I: Copy,
+    f64: FloatToInt<I>,
+{
+    use core::num::FpCategory;
+    match f.classify() {
+        FpCategory::Nan | FpCategory::Infinite | FpCategory::Subnormal => Err(()),
+        FpCategory::Zero => Ok(<f64 as FloatToInt<I>>::ZERO),
+        FpCategory::Normal => {
+            if f == f.trunc()
+                && f > <f64 as FloatToInt<I>>::lower_bound()
+                && f < <f64 as FloatToInt<I>>::upper_bound()
+            {
+                // SAFETY: We know that `f` must be integral, and within the bounds of its type
+                Ok(unsafe { <f64 as FloatToInt<I>>::to_int_unchecked(f) })
+            } else {
+                Err(())
+            }
+        }
+    }
+}
diff --git a/hir2/src/ir/insert.rs b/hir2/src/ir/insert.rs
new file mode 100644
index 000000000..72ec5361a
--- /dev/null
+++ b/hir2/src/ir/insert.rs
@@ -0,0 +1,234 @@
+use core::fmt;
+
+use crate::{BlockRef, OperationRef};
+
+/// Represents the placement of inserted items relative to a [ProgramPoint]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub enum Insert {
+    /// New items will be inserted before the current program point
+    Before,
+    /// New items will be inserted after the current program point
+    After,
+}
+
+/// Represents a cursor within a region where new operations will be inserted.
+///
+/// The `placement` field determines how new operations will be inserted relative to `at`:
+///
+/// * If `at` is a block:
+///   * `Insert::Before` will always insert new operations as the first operation in the block,
+///     i.e. every insert pushes to the front of the list of operations.
+///   * `Insert::After` will always insert new operations at the end of the block, i.e. every
+///     insert pushes to the back of the list of operations.
+/// * If `at` is an operation:
+///   * `Insert::Before` will always insert new operations directly preceding the `at` operation.
+///   * `Insert::After` will always insert new operations directly following the `at` operation.
+///
+/// If a builder/rewriter wishes to insert new operations starting at some point in the middle of
+/// a block, but then move the insertion point forward as new operations are inserted, the builder
+/// must call [move_next] (or [move_prev]) after each insertion.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct InsertionPoint {
+    pub at: ProgramPoint,
+    pub placement: Insert,
+}
+impl InsertionPoint {
+    /// Create a new insertion point with the specified placement, at the given program point.
+    #[inline]
+    pub const fn new(at: ProgramPoint, placement: Insert) -> Self {
+        Self { at, placement }
+    }
+
+    /// Create a new insertion point at the given program point, which will place new operations
+    /// "before" that point.
+ /// + /// See [Insert::Before] for what the semantics of "before" means with regards to the different + /// kinds of program point. + #[inline] + pub fn before(at: impl Into) -> Self { + Self { + at: at.into(), + placement: Insert::Before, + } + } + + /// Create a new insertion point at the given program point, which will place new operations + /// "after" that point. + /// + /// See [Insert::After] for what the semantics of "after" means with regards to the different + /// kinds of program point. + #[inline] + pub fn after(at: impl Into) -> Self { + Self { + at: at.into(), + placement: Insert::After, + } + } + + /// Moves the insertion point to the previous operation relative to the current point. + /// + /// If there is no operation before the current point, this has no effect. + /// + /// If the current point is a [ProgramPoint::Block], and `self.placement` is `Insert::After`, + /// then this moves the insertion point to the operation immediately preceding the last + /// operation in the block, _if_ there are at least two operations in the block. In other words, + /// `self.at` becomes a [ProgramPoint::Op]. + pub fn move_prev(&mut self) { + match &mut self.at { + ProgramPoint::Op(ref mut current) => { + let prev = current.prev(); + if let Some(prev) = prev { + *current = prev; + } + } + ProgramPoint::Block(ref block) => { + if matches!(self.placement, Insert::After) { + if let Some(prev) = + block.borrow().body().back().as_pointer().and_then(|current| current.prev()) + { + self.at = ProgramPoint::Op(prev); + } + } + } + } + } + + /// Moves the insertion point to the next operation relative to the current point. + /// + /// If there is no operation after the current point, this has no effect. + /// + /// If the current point is a [ProgramPoint::Block], and `self.placement` is `Insert::Before`, + /// then this moves the insertion point to the operation immediately following the first + /// operation in the block, _if_ there are at least two operations in the block. In other words, + /// `self.at` becomes a [ProgramPoint::Op]. + pub fn move_next(&mut self) { + match &mut self.at { + ProgramPoint::Op(ref mut current) => { + let next = current.next(); + if let Some(next) = next { + *current = next; + } + } + ProgramPoint::Block(ref block) => { + if matches!(self.placement, Insert::Before) { + if let Some(next) = block + .borrow() + .body() + .front() + .as_pointer() + .and_then(|current| current.next()) + { + self.at = ProgramPoint::Op(next); + } + } + } + } + } + + /// Get a pointer to the [crate::Operation] on which this insertion point is positioned. + /// + /// Returns `None` if the insertion point is positioned in an empty block. + pub fn op(&self) -> Option { + match self.at { + ProgramPoint::Op(ref op) => Some(op.clone()), + ProgramPoint::Block(ref block) => match self.placement { + Insert::Before => block.borrow().front(), + Insert::After => block.borrow().back(), + }, + } + } + + /// Get a pointer to the [crate::Block] in which this insertion point is positioned. + /// + /// Panics if the current program point is an operation detached from any block. 
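+    ///
+    /// A rough usage sketch (illustrative only; assumes `op` is an [OperationRef] that is
+    /// attached to a block):
+    ///
+    /// ```ignore
+    /// let ip = InsertionPoint::after(op.clone());
+    /// // The insertion point is positioned on `op`, so this is the block containing `op`
+    /// let block = ip.block();
+    /// assert_eq!(Some(block), op.borrow().parent());
+    /// ```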
+ pub fn block(&self) -> BlockRef { + self.at + .block() + .expect("invalid insertion point: operation is detached from any block") + } + + /// Returns true if this insertion point is positioned at the end of the containing block + pub fn is_at_block_end(&self) -> bool { + let block = self.block().borrow(); + if block.body().is_empty() { + matches!(self.at, ProgramPoint::Block(_)) + } else if matches!(self.placement, Insert::Before) { + false + } else { + match &self.at { + ProgramPoint::Block(_) => true, + ProgramPoint::Op(ref op) => &block.back().unwrap() == op, + } + } + } +} + +/// A `ProgramPoint` represents a position in a region where the live range of an SSA value can +/// begin or end. It can be either: +/// +/// 1. An operation +/// 2. A block +/// +/// This corresponds more or less to the lines in the textual form of the IR. +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum ProgramPoint { + /// An operation + Op(OperationRef), + /// A block + Block(BlockRef), +} +impl ProgramPoint { + /// Unwrap this program point as an [OperationRef], or panic if this is not a [ProgramPoint::Op] + pub fn unwrap_op(self) -> OperationRef { + use crate::EntityWithId; + match self { + Self::Op(x) => x, + Self::Block(x) => panic!("expected operation, but got {}", x.borrow().id()), + } + } + + /// Get the closest operation associated with this program point + /// + /// If this program point refers to a block, this will return the last operation in the block, + /// or `None` if the block is empty. + pub fn op(&self) -> Option { + match self { + Self::Op(op) => Some(op.clone()), + Self::Block(block) => block.borrow().back(), + } + } + + /// Get the block associated with this program point + /// + /// Returns `None` if the program point is a detached operation. + pub fn block(&self) -> Option { + match self { + Self::Op(op) => op.borrow().parent(), + Self::Block(block) => Some(block.clone()), + } + } +} +impl From for ProgramPoint { + fn from(op: OperationRef) -> Self { + Self::Op(op) + } +} +impl From for ProgramPoint { + fn from(block: BlockRef) -> Self { + Self::Block(block) + } +} +impl fmt::Display for ProgramPoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use crate::EntityWithId; + match self { + Self::Op(x) => write!(f, "{}", x.borrow().name()), + Self::Block(x) => write!(f, "{}", x.borrow().id()), + } + } +} +impl fmt::Debug for ProgramPoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("ProgramPoint").field_with(|f| write!(f, "{}", self)).finish() + } +} diff --git a/hir2/src/ir/interface.rs b/hir2/src/ir/interface.rs new file mode 100644 index 000000000..e69de29bb diff --git a/hir2/src/ir/loops.rs b/hir2/src/ir/loops.rs new file mode 100644 index 000000000..58bc2647f --- /dev/null +++ b/hir2/src/ir/loops.rs @@ -0,0 +1,1277 @@ +use alloc::{collections::BTreeMap, rc::Rc}; +use core::{ + cell::{Cell, Ref, RefCell, RefMut}, + fmt, +}; + +use smallvec::SmallVec; + +use super::dominance::{DominanceTree, PostOrderDomTreeIter}; +use crate::{ + adt::{SmallMap, SmallSet}, + cfg::{Graph, Inverse, InvertibleGraph}, + BlockRef, EntityWithId, OperationRef, PostOrderBlockIter, Report, +}; + +/// [LoopForest] represents all of the top-level loop structures in a specified region. +/// +/// The [LoopForest] analysis is used to identify natural loops and determine the loop depth of +/// various nodes in a generic graph of blocks. A natural loop has exactly one entry-point, which +/// is called the header. 
Note that natural loops may actually be several loops that share the same +/// header node. +/// +/// This analysis calculates the nesting structure of loops in a function. For each natural loop +/// identified, this analysis identifies natural loops contained entirely within the loop and the +/// basic blocks that make up the loop. +/// +/// It can calculate on the fly various bits of information, for example: +/// +/// * Whether there is a preheader for the loop +/// * The number of back edges to the header +/// * Whether or not a particular block branches out of the loop +/// * The successor blocks of the loop +/// * The loop depth +/// * etc... +/// +/// Note that this analysis specifically identifies _loops_ not cycles or SCCs in the graph. There +/// can be strongly connected components in the graph which this analysis will not recognize and +/// that will not be represented by a loop instance. In particular, a loop might be inside such a +/// non-loop SCC, or a non-loop SCC might contain a sub-SCC which is a loop. +/// +/// For an overview of terminology used in this API (and thus all related loop analyses or +/// transforms), see [Loop Terminology](https://llvm.org/docs/LoopTerminology.html). +#[derive(Default)] +pub struct LoopForest { + /// The set of top-level loops in the forest + top_level_loops: SmallVec<[Rc; 4]>, + /// Mapping of basic blocks to the inner most loop they occur in + block_map: BTreeMap>, +} + +impl LoopForest { + /// Compute a new [LoopForest] from the given dominator tree + pub fn new(tree: &DominanceTree) -> Self { + let mut forest = Self::default(); + forest.analyze(tree); + forest + } + + /// Returns true if there are no loops in the forest + pub fn is_empty(&self) -> bool { + self.top_level_loops.is_empty() + } + + /// Returns the number of loops in the forest + pub fn len(&self) -> usize { + self.top_level_loops.len() + } + + /// Returns true if `block` is in this loop forest + #[inline] + pub fn contains_block(&self, block: &BlockRef) -> bool { + self.block_map.contains_key(block) + } + + /// Get the set of top-level/outermost loops in the forest + pub fn top_level_loops(&self) -> &[Rc] { + &self.top_level_loops + } + + /// Return all of the loops in the function in preorder across the loop nests, with siblings in + /// forward program order. + /// + /// Note that because loops form a forest of trees, preorder is equivalent to reverse postorder. + pub fn loops_in_preorder(&self) -> SmallVec<[Rc; 4]> { + // The outer-most loop actually goes into the result in the same relative order as we walk + // it. But LoopForest stores the top level loops in reverse program order so for here we + // reverse it to get forward program order. + // + // FIXME: If we change the order of LoopForest we will want to remove the reverse here. + let mut preorder_loops = SmallVec::<[Rc; 4]>::default(); + for l in self.top_level_loops.iter().cloned().rev() { + let mut loops_in_preorder = l.loops_in_preorder(); + preorder_loops.append(&mut loops_in_preorder); + } + preorder_loops + } + + /// Return all of the loops in the function in preorder across the loop nests, with siblings in + /// _reverse_ program order. + /// + /// Note that because loops form a forest of trees, preorder is equivalent to reverse postorder. + /// + /// Also note that this is _not_ a reverse preorder. Only the siblings are in reverse program + /// order. 
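+    ///
+    /// A hypothetical sketch of the traversal order (illustrative only; `forest` and `visit`
+    /// are assumed to exist):
+    ///
+    /// ```ignore
+    /// for l in forest.loops_in_reverse_sibling_preorder() {
+    ///     // Parents are always visited before their children; siblings are
+    ///     // visited in reverse program order
+    ///     visit(l);
+    /// }
+    /// ```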
+ pub fn loops_in_reverse_sibling_preorder(&self) -> SmallVec<[Rc; 4]> { + // The outer-most loop actually goes into the result in the same relative order as we walk + // it. LoopForest stores the top level loops in reverse program order so we walk in order + // here. + // + // FIXME: If we change the order of LoopInfo we will want to add a reverse here. + let mut preorder_loops = SmallVec::<[Rc; 4]>::default(); + let mut preorder_worklist = SmallVec::<[Rc; 4]>::default(); + for l in self.top_level_loops.iter().cloned() { + assert!(preorder_worklist.is_empty()); + preorder_worklist.push(l); + while let Some(l) = preorder_worklist.pop() { + // Sub-loops are stored in forward program order, but will process the worklist + // backwards so we can just append them in order. + preorder_worklist.extend(l.nested().iter().cloned()); + preorder_loops.push(l); + } + } + + preorder_loops + } + + /// Return the inner most loop that `block` lives in. + /// + /// If a basic block is in no loop (for example the entry node), `None` is returned. + pub fn loop_for(&self, block: &BlockRef) -> Option> { + self.block_map.get(block).cloned() + } + + /// Return the loop nesting level of the specified block. + /// + /// A depth of 0 means the block is not inside any loop. + pub fn loop_depth(&self, block: &BlockRef) -> usize { + self.loop_for(block).map(|l| l.depth()).unwrap_or(0) + } + + /// Returns true if the block is a loop header + pub fn is_loop_header(&self, block: &BlockRef) -> bool { + self.loop_for(block).map(|l| &l.header() == block).unwrap_or(false) + } + + /// This removes the specified top-level loop from this loop info object. + /// + /// The loop is not deleted, as it will presumably be inserted into another loop. + /// + /// # Panics + /// + /// This function will panic if the given loop is not a top-level loop + pub fn remove_loop(&mut self, l: &Loop) -> Option> { + assert!(l.is_outermost(), "`l` is not an outermost loop"); + let index = self.top_level_loops.iter().position(|tll| core::ptr::addr_eq(&**tll, l))?; + Some(self.top_level_loops.swap_remove(index)) + } + + /// Change the top-level loop that contains `block` to the specified loop. + /// + /// This should be used by transformations that restructure the loop hierarchy tree. + pub fn change_loop_for(&mut self, block: BlockRef, l: Option>) { + if let Some(l) = l { + self.block_map.insert(block, l); + } else { + self.block_map.remove(&block); + } + } + + /// Replace the specified loop in the top-level loops list with the indicated loop. + pub fn change_top_level_loop(&mut self, old: Rc, new: Rc) { + assert!( + new.parent_loop().is_none() && old.parent_loop().is_none(), + "loops already embedded into a subloop" + ); + let index = self + .top_level_loops + .iter() + .position(|tll| Rc::ptr_eq(tll, &old)) + .expect("`old` loop is not a top-level loop"); + self.top_level_loops[index] = new; + } + + /// This adds the specified loop to the collection of top-level loops. + pub fn add_top_level_loop(&mut self, l: Rc) { + assert!(l.is_outermost(), "loop already in subloop"); + self.top_level_loops.push(l); + } + + /// This method completely removes `block` from all data structures, including all of the loop + /// objects it is nested in and our mapping from basic blocks to loops. 
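+    ///
+    /// Illustrative sketch (assumes `forest` and `block` already exist, and that the caller is
+    /// about to erase `block` from the CFG):
+    ///
+    /// ```ignore
+    /// // Remove `block` from its innermost loop and every enclosing loop, as well
+    /// // as from the block-to-loop mapping, before deleting the block itself
+    /// forest.remove_block(&block);
+    /// assert!(!forest.contains_block(&block));
+    /// ```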
+ pub fn remove_block(&mut self, block: &BlockRef) { + if let Some(l) = self.block_map.remove(block) { + let mut next_l = Some(l); + while let Some(l) = next_l.take() { + next_l = l.parent_loop(); + l.remove_block_from_loop(block); + } + } + } + + pub fn is_not_already_contained_in(sub_loop: Option<&Loop>, parent: Option<&Loop>) -> bool { + let Some(sub_loop) = sub_loop else { + return true; + }; + if parent.is_some_and(|parent| parent == sub_loop) { + return false; + } + return Self::is_not_already_contained_in(sub_loop.parent_loop().as_deref(), parent); + } + + /// Analyze the given dominance tree to discover loops. + /// + /// The analysis discovers loops during a post-order traversal of the given dominator tree, + /// interleaved with backward CFG traversals within each subloop + /// (see `discover_and_map_subloop`). The backward traversal skips inner subloops, so this part + /// of the algorithm is linear in the number of CFG edges. Subloop and block vectors are then + /// populated during a single forward CFG traversal. + /// + /// During the two CFG traversals each block is seen three times: + /// + /// 1. Discovered and mapped by a reverse CFG traversal. + /// 2. Visited during a forward DFS CFG traversal. + /// 3. Reverse-inserted in the loop in postorder following forward DFS. + /// + /// The block vectors are inclusive, so step 3 requires loop-depth number of insertions per + /// block. + pub fn analyze(&mut self, tree: &DominanceTree) { + // Postorder traversal of the dominator tree. + let root = tree.root_node().unwrap(); + for node in PostOrderDomTreeIter::new(root.clone()) { + let header = node.block().expect("expected header block").clone(); + let mut backedges = SmallVec::<[BlockRef; 4]>::default(); + + // Check each predecessor of the potential loop header. + for backedge in BlockRef::inverse_children(header.clone()) { + // If `header` dominates `pred`, this is a new loop. Collect the backedges. + let backedge_node = tree.get(Some(&backedge)); + if backedge_node.is_some() && tree.dominates_node(Some(node.clone()), backedge_node) + { + backedges.push(backedge); + } + } + + // Perform a backward CFG traversal to discover and map blocks in this loop. + if !backedges.is_empty() { + let l = Rc::new(Loop::new(header.clone())); + self.discover_and_map_sub_loop(l, backedges, tree); + } + } + + // Perform a single forward CFG traversal to populate blocks and subloops for all loops. + for block in PostOrderBlockIter::new(root.block().cloned().unwrap()) { + self.insert_into_loop(block); + } + } + + /// Discover a subloop with the specified backedges such that: + /// + /// * All blocks within this loop are mapped to this loop or a subloop. + /// * All subloops within this loop have their parent loop set to this loop or a subloop. + fn discover_and_map_sub_loop( + &mut self, + l: Rc, + backedges: SmallVec<[BlockRef; 4]>, + tree: &DominanceTree, + ) { + let mut num_blocks = 0usize; + let mut num_subloops = 0usize; + + // Perform a backward CFG traversal using a worklist. + let mut reverse_cfg_worklist = backedges; + while let Some(pred) = reverse_cfg_worklist.pop() { + match self.loop_for(&pred) { + None if !tree.is_reachable_from_entry(&pred) => continue, + None => { + // This is an undiscovered block. Map it to the current loop. 
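+                    //
+                    // Loops are discovered innermost-first (via the post-order traversal of
+                    // the dominator tree), so an undiscovered block is mapped to the innermost
+                    // loop containing it; enclosing loops will later see it as already mapped.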
+ self.change_loop_for(pred.clone(), Some(l.clone())); + num_blocks += 1; + if pred == l.header() { + continue; + } + + // Push all block predecessors on the worklist + reverse_cfg_worklist.extend(Inverse::::children(pred.clone())); + } + Some(subloop) => { + // This is a discovered block. Find its outermost discovered loop. + let subloop = subloop.outermost_loop(); + + // If it is already discovered to be a subloop of this loop, continue. + if subloop == l { + continue; + } + + // Discover a subloop of this loop. + subloop.set_parent_loop(Some(l.clone())); + num_subloops += 1; + num_blocks += subloop.num_blocks(); + + // Continue traversal along predecessors that are not loop-back edges from + // within this subloop tree itself. Note that a predecessor may directly reach + // another subloop that is not yet discovered to be a subloop of this loop, + // which we must traverse. + for pred in BlockRef::inverse_children(subloop.header()) { + if self.loop_for(&pred).is_none_or(|l| l != subloop) { + reverse_cfg_worklist.push(pred); + } + } + } + } + } + + l.nested.borrow_mut().reserve(num_subloops); + l.reserve(num_blocks); + } + + /// Add a single block to its ancestor loops in post-order. + /// + /// If the block is a subloop header, add the subloop to its parent in post-order, then reverse + /// the block and subloop vectors of the now complete subloop to achieve RPO. + fn insert_into_loop(&mut self, block: BlockRef) { + let mut subloop = self.loop_for(&block); + if let Some(sl) = subloop.clone().filter(|sl| sl.header() == block) { + let parent = sl.parent_loop(); + // We reach this point once per subloop after processing all the blocks in the subloop. + if sl.is_outermost() { + self.add_top_level_loop(sl.clone()); + } else { + parent.as_ref().unwrap().nested.borrow_mut().push(sl.clone()); + } + + // For convenience, blocks and subloops are inserted in postorder. Reverse the lists, + // except for the loop header, which is always at the beginning. + sl.reverse_blocks(1); + sl.nested.borrow_mut().reverse(); + subloop = parent; + } + + while let Some(sl) = subloop.take() { + sl.add_block_entry(block.clone()); + subloop = sl.parent_loop(); + } + } + + /// Verify the loop forest structure using the provided [DominanceTree] + pub fn verify(&self, tree: &DominanceTree) -> Result<(), Report> { + let mut loops = SmallSet::, 2>::default(); + for l in self.top_level_loops.iter().cloned() { + if !l.is_outermost() { + return Err(Report::msg("top-level loop has a parent")); + } + l.verify_loop_nest(&mut loops)?; + } + + if cfg!(debug_assertions) { + // Verify that blocks are mapped to valid loops. + for (block, block_loop) in self.block_map.iter() { + if !loops.contains(block_loop) { + return Err(Report::msg("orphaned loop")); + } + if !block_loop.contains_block(block) { + return Err(Report::msg("orphaned block")); + } + for child_loop in block_loop.nested().iter() { + if child_loop.contains_block(block) { + return Err(Report::msg( + "expected block map to reflect the innermost loop containing `block`", + )); + } + } + } + + // Recompute forest to verify loops structure. + let other = LoopForest::new(tree); + + // Build a map we can use to move from our forest to the newly computed one. This allows + // us to ignore the particular order in any layer of the loop forest while still + // comparing the structure. 
+ let mut other_headers = SmallMap::, 8>::default(); + + fn add_inner_loops_to_headers_map( + headers: &mut SmallMap, 8>, + l: &Rc, + ) { + let header = l.header(); + headers.insert(header, Rc::clone(l)); + for sl in l.nested().iter() { + add_inner_loops_to_headers_map(headers, sl); + } + } + + for l in other.top_level_loops() { + add_inner_loops_to_headers_map(&mut other_headers, l); + } + + // Walk the top level loops and ensure there is a corresponding top-level loop in the + // computed version and then recursively compare those loop nests. + for l in self.top_level_loops() { + let header = l.header(); + let other_l = other_headers.remove(&header); + match other_l { + None => { + return Err(Report::msg( + "top level loop is missing in computed loop forest", + )) + } + Some(other_l) => { + // Recursively compare the loops + Self::compare_loops(l.clone(), other_l, &mut other_headers)?; + } + } + } + + // Any remaining entries in the map are loops which were found when computing a fresh + // loop forest but not present in the current one. + if !other_headers.is_empty() { + for (_header, header_loop) in other_headers { + log::trace!("Found new loop {header_loop:?}"); + } + return Err(Report::msg("found new loops when recomputing loop forest")); + } + } + + Ok(()) + } + + #[cfg(debug_assertions)] + fn compare_loops( + l: Rc, + other_l: Rc, + other_loop_headers: &mut SmallMap, 8>, + ) -> Result<(), Report> { + let header = l.header(); + let other_header = other_l.header(); + if header != other_header { + return Err(Report::msg( + "mismatched headers even though found under the same map entry", + )); + } + + if l.depth() != other_l.depth() { + return Err(Report::msg("mismatched loop depth")); + } + + { + let mut parent_l = Some(l.clone()); + let mut other_parent_l = Some(other_l.clone()); + while let Some(pl) = parent_l.take() { + if let Some(opl) = other_parent_l.take() { + if pl.header() != opl.header() { + return Err(Report::msg("mismatched parent loop headers")); + } + parent_l = pl.parent_loop(); + other_parent_l = opl.parent_loop(); + } else { + return Err(Report::msg( + "`other_l` misreported its depth: expected a parent and got none", + )); + } + } + } + + for sl in l.nested().iter() { + let sl_header = sl.header(); + let other_sl = other_loop_headers.remove(&sl_header); + match other_sl { + None => return Err(Report::msg("inner loop is missing in computed loop forest")), + Some(other_sl) => { + Self::compare_loops(sl.clone(), other_sl, other_loop_headers)?; + } + } + } + + let mut blocks = l.blocks.borrow().clone(); + let mut other_blocks = other_l.blocks.borrow().clone(); + blocks.sort_by_key(|b| b.borrow().id()); + other_blocks.sort_by_key(|b| b.borrow().id()); + if blocks != other_blocks { + log::trace!("blocks: {}", crate::formatter::DisplayValues::new(blocks.iter())); + log::trace!( + "other_blocks: {}", + crate::formatter::DisplayValues::new(other_blocks.iter()) + ); + return Err(Report::msg("loops report mismatched blocks")); + } + + let block_set = l.block_set(); + let other_block_set = other_l.block_set(); + let diff = block_set.symmetric_difference(&other_block_set); + if block_set.len() != other_block_set.len() || !diff.is_empty() { + log::trace!( + "block_set: {}", + crate::formatter::DisplayValues::new(block_set.iter()) + ); + log::trace!( + "other_block_set: {}", + crate::formatter::DisplayValues::new(other_block_set.iter()) + ); + log::trace!("diff: {}", crate::formatter::DisplayValues::new(diff.iter())); + return Err(Report::msg("loops report mismatched block sets")); 
+        }
+
+        Ok(())
+    }
+
+    #[cfg(not(debug_assertions))]
+    fn compare_loops(
+        _l: Rc<Loop>,
+        _other_l: Rc<Loop>,
+        _other_loop_headers: &mut SmallMap<BlockRef, Rc<Loop>, 8>,
+    ) -> Result<(), Report> {
+        Ok(())
+    }
+}
+
+impl fmt::Debug for LoopForest {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("LoopForest")
+            .field("top_level_loops", &self.top_level_loops)
+            .field("block_map", &self.block_map)
+            .finish()
+    }
+}
+
+/// Edge type.
+pub type LoopEdge = (BlockRef, BlockRef);
+
+/// [Loop] is used to represent loops that are detected in the control-flow graph.
+#[derive(Default)]
+pub struct Loop {
+    /// If this loop is an outermost loop, this field is `None`.
+    ///
+    /// Otherwise, it holds a handle to the parent loop which transfers control to this loop.
+    parent_loop: Cell<Option<Rc<Loop>>>,
+    /// Loops contained entirely within this one.
+    ///
+    /// All of the loops in this set will have their `parent` set to this loop
+    nested: RefCell<SmallVec<[Rc<Loop>; 2]>>,
+    /// The list of blocks in this loop.
+    ///
+    /// The header block is always at index 0.
+    blocks: RefCell<SmallVec<[BlockRef; 32]>>,
+    /// The uniqued set of blocks present in this loop
+    block_set: RefCell<SmallSet<BlockRef, 32>>,
+}
+
+impl Eq for Loop {}
+impl PartialEq for Loop {
+    fn eq(&self, other: &Self) -> bool {
+        core::ptr::addr_eq(self, other)
+    }
+}
+
+impl Loop {
+    /// Create a new [Loop] with `block` as its header.
+    pub fn new(block: BlockRef) -> Self {
+        let mut this = Self::default();
+        this.blocks.get_mut().push(block.clone());
+        this.block_set.get_mut().insert(block);
+        this
+    }
+
+    /// Get the nesting level of this loop.
+    ///
+    /// An outer-most loop has depth 1, for consistency with loop depth values used for basic
+    /// blocks, where depth 0 is used for blocks not inside any loops.
+    pub fn depth(&self) -> usize {
+        let mut depth = 1;
+        let mut current_loop = self.parent_loop();
+        while let Some(curr) = current_loop.take() {
+            depth += 1;
+            current_loop = curr.parent_loop();
+        }
+        depth
+    }
+
+    /// Get the header block of this loop
+    pub fn header(&self) -> BlockRef {
+        self.blocks.borrow()[0].clone()
+    }
+
+    /// Return the parent loop of this loop, if it has one, or `None` if it is a top-level loop.
+    ///
+    /// A loop is either top-level in a function (that is, it is not contained in any other loop)
+    /// or it is entirely enclosed in some other loop. If a loop is top-level, it has no parent,
+    /// otherwise its parent is the innermost loop in which it is enclosed.
+    pub fn parent_loop(&self) -> Option<Rc<Loop>> {
+        unsafe { (*self.parent_loop.as_ptr()).clone() }
+    }
+
+    /// This is a low-level API for bypassing [add_child_loop].
+    pub fn set_parent_loop(&self, parent: Option<Rc<Loop>>) {
+        self.parent_loop.set(parent);
+    }
+
+    /// Discover the outermost loop that contains `self`
+    pub fn outermost_loop(self: Rc<Self>) -> Rc<Loop> {
+        let mut l = self;
+        while let Some(parent) = l.parent_loop() {
+            l = parent;
+        }
+        l
+    }
+
+    /// Return true if the specified loop is contained within this loop.
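+    ///
+    /// A minimal sketch (hypothetical `outer`/`inner` loops, where `inner` is nested in `outer`):
+    ///
+    /// ```ignore
+    /// assert!(outer.contains(Rc::clone(&inner)));
+    /// assert!(!inner.contains(Rc::clone(&outer)));
+    /// // Every loop contains itself
+    /// assert!(outer.contains(Rc::clone(&outer)));
+    /// ```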
+ pub fn contains(&self, l: Rc) -> bool { + if core::ptr::addr_eq(self, &*l) { + return true; + } + + let Some(parent) = l.parent_loop() else { + return false; + }; + + self.contains(parent) + } + + /// Returns true if the specified basic block is in this loop + pub fn contains_block(&self, block: &BlockRef) -> bool { + self.block_set.borrow().contains(block) + } + + /// Returns true if the specified operation is in this loop + pub fn contains_op(&self, op: &OperationRef) -> bool { + let Some(block) = op.borrow().parent() else { + return false; + }; + self.contains_block(&block) + } + + /// Return the loops contained entirely within this loop. + pub fn nested(&self) -> Ref<'_, [Rc]> { + Ref::map(self.nested.borrow(), |nested| nested.as_slice()) + } + + /// Return true if the loop does not contain any (natural) loops. + /// + /// [Loop] does not detect irreducible control flow, just natural loops. That is, it is possible + /// that there is cyclic control flow within the innermost loop or around the outermost loop. + pub fn is_innermost(&self) -> bool { + self.nested.borrow().is_empty() + } + + /// Return true if the loop does not have a parent (natural) loop (i.e. it is outermost, which + /// is the same as top-level). + pub fn is_outermost(&self) -> bool { + unsafe { (*self.parent_loop.as_ptr()).is_none() } + } + + /// Get a list of the basic blocks which make up this loop. + pub fn blocks(&self) -> Ref<'_, [BlockRef]> { + Ref::map(self.blocks.borrow(), |blocks| blocks.as_slice()) + } + + /// Get a mutable reference to the basic blocks which make up this loop. + pub fn blocks_mut(&self) -> RefMut<'_, SmallVec<[BlockRef; 32]>> { + self.blocks.borrow_mut() + } + + /// Return the number of blocks contained in this loop + pub fn num_blocks(&self) -> usize { + self.blocks.borrow().len() + } + + /// Return a reference to the blocks set. + pub fn block_set(&self) -> Ref<'_, SmallSet> { + self.block_set.borrow() + } + + /// Return a mutable reference to the blocks set. + pub fn block_set_mut(&self) -> RefMut<'_, SmallSet> { + self.block_set.borrow_mut() + } + + /// Returns true if the terminator of `block` can branch to another block that is outside of the + /// current loop. + /// + /// # Panics + /// + /// This function will panic if `block` is not inside this loop. + pub fn is_loop_exiting(&self, block: &BlockRef) -> bool { + assert!(self.contains_block(block), "exiting block must be part of the loop"); + BlockRef::children(block.clone()).any(|succ| !self.contains_block(&succ)) + } + + /// Returns true if `block` is a loop-latch. + /// + /// A latch block is a block that contains a branch back to the header. + /// + /// This function is useful when there are multiple latches in a loop because `get_loop_latch` + /// will return `None` in that case. + pub fn is_loop_latch(&self, block: &BlockRef) -> bool { + assert!(self.contains_block(block), "block does not belong to the loop"); + BlockRef::inverse_children(self.header()).any(|pred| &pred == block) + } + + /// Calculate the number of back edges to the loop header + pub fn num_backedges(&self) -> usize { + BlockRef::inverse_children(self.header()) + .filter(|pred| self.contains_block(pred)) + .count() + } +} + +/// Loop Analysis +/// +/// Note that all of these methods can fail on general loops (ie, there may not be a preheader, +/// etc). For best success, the loop simplification and induction variable canonicalization pass +/// should be used to normalize loops for easy analysis. These methods assume canonical loops. 
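+///
+/// As a rough example of the kind of queries these methods support (illustrative only, on a
+/// hypothetical canonicalized loop `l`):
+///
+/// ```ignore
+/// if let Some(preheader) = l.preheader() {
+///     // There is a single edge into the header from outside the loop, so
+///     // loop-invariant operations can be hoisted into `preheader`
+/// }
+/// // Blocks outside `l` that control flow can escape to
+/// let exits = l.exit_blocks();
+/// ```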
+impl Loop { + /// Get all blocks inside the loop that have successors outside of the loop. + /// + /// These are the blocks _inside of the current loop_ which branch out. The returned list is + /// always unique. + pub fn exiting_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut exiting_blocks = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + // A block must be an exit block if it is not contained in the current loop + if !self.contains_block(&succ) { + exiting_blocks.push(block.clone()); + break; + } + } + } + exiting_blocks + } + + /// If [Self::exiting_blocks] would return exactly one block, return it, otherwise `None`. + pub fn exiting_block(&self) -> Option { + let mut exiting_block = None; + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + if exiting_block.is_some() { + return None; + } else { + exiting_block = Some(block.clone()); + } + break; + } + } + } + exiting_block + } + + /// Get all of the successor blocks of this loop. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut exit_blocks = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + exit_blocks.push(succ); + } + } + } + exit_blocks + } + + /// If [Self::exit_blocks] would return exactly one block, return it, otherwise `None`. + pub fn exit_block(&self) -> Option { + let mut exit_block = None; + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + if exit_block.is_some() { + return None; + } else { + exit_block = Some(succ); + } + } + } + } + exit_block + } + + /// Returns true if no exit block for the loop has a predecessor that is outside the loop. + pub fn has_dedicated_exits(&self) -> bool { + // Each predecessor of each exit block of a normal loop is contained within the loop. + for exit_block in self.unique_exit_blocks() { + for pred in BlockRef::inverse_children(exit_block) { + if !self.contains_block(&pred) { + return false; + } + } + } + + // All the requirements are met. + true + } + + /// Return all unique successor blocks of this loop. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn unique_exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let mut unique_exits = SmallVec::default(); + unique_exit_blocks_helper(self, &mut unique_exits, |_| true); + unique_exits + } + + /// Return all unique successor blocks of this loop, except successors from the latch block + /// which are not considered. If an exit that comes from the latch block, but also has a non- + /// latch predecessor in the loop, it will be included. + /// + /// These are the blocks _outside of the current loop_ which are branched to. + pub fn unique_non_latch_exit_blocks(&self) -> SmallVec<[BlockRef; 2]> { + let latch_block = self.loop_latch().expect("latch must exist"); + let mut unique_exits = SmallVec::default(); + unique_exit_blocks_helper(self, &mut unique_exits, |block| block != &latch_block); + unique_exits + } + + /// If [Self::unique_exit_blocks] would return exactly one block, return it, otherwise `None`. + #[inline] + pub fn unique_exit_block(&self) -> Option { + self.exit_block() + } + + /// Return true if this loop does not have any exit blocks. 
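+    ///
+    /// Sketch (hypothetical loop `l` whose body never branches outside itself):
+    ///
+    /// ```ignore
+    /// if l.has_no_exit_blocks() {
+    ///     // e.g. the loop was lowered from an infinite `loop { .. }` with no `break`
+    /// }
+    /// ```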
+ pub fn has_no_exit_blocks(&self) -> bool { + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + return false; + } + } + } + true + } + + /// Return all pairs of (_inside_block_, _outside_block_). + pub fn exit_edges(&self) -> SmallVec<[LoopEdge; 2]> { + let mut exit_edges = SmallVec::default(); + for block in self.blocks.borrow().iter() { + for succ in BlockRef::children(block.clone()) { + if !self.contains_block(&succ) { + exit_edges.push((block.clone(), succ)); + } + } + } + exit_edges + } + + /// Returns the pre-header for this loop, if there is one. + /// + /// A loop has a pre-header if there is only one edge to the header of the loop from outside of + /// the loop. If this is the case, the block branching to the header of the loop is the + /// pre-header node. + /// + /// This returns `None` if there is no pre-header for the loop. + pub fn preheader(&self) -> Option { + use crate::IteratorExt; + + // Keep track of nodes outside the loop branching to the header... + let out = self.loop_predecessor()?; + + // Make sure we are allowed to hoist instructions into the predecessor. + if !out.borrow().is_legal_to_hoist_into() { + return None; + } + + // Make sure there is only one exit out of the preheader. + if !BlockRef::children(out.clone()).has_single_element() { + // Multiple exits from the block, must not be a preheader. + return None; + } + + // The predecessor has exactly one successor, so it is a preheader. + Some(out) + } + + /// If the given loop's header has exactly one unique predecessor outside the loop, return it. + /// + /// This is less strict than the loop "preheader" concept, which requires the predecessor to + /// have exactly one successor. + pub fn loop_predecessor(&self) -> Option { + // Keep track of nodes outside the loop branching to the header... + let mut out = None; + // Loop over the predecessors of the header node... + let header = self.header(); + for pred in BlockRef::inverse_children(header) { + if !self.contains_block(&pred) { + if out.as_ref().is_some_and(|out| out != &pred) { + // Multiple predecessors outside the loop + return None; + } + out = Some(pred); + } + } + out + } + + /// If there is a single latch block for this loop, return it. + /// + /// A latch block is a block that contains a branch back to the header. + pub fn loop_latch(&self) -> Option { + let header = self.header(); + let mut latch_block = None; + for pred in BlockRef::inverse_children(header) { + if self.contains_block(&pred) { + if latch_block.is_some() { + return None; + } + latch_block = Some(pred); + } + } + latch_block + } + + /// Get all loop latch blocks of this loop. + /// + /// A latch block is a block that contains a branch back to the header. + pub fn loop_latches(&self) -> SmallVec<[BlockRef; 2]> { + BlockRef::inverse_children(self.header()) + .filter(|pred| self.contains_block(pred)) + .collect() + } + + /// Return all inner loops in the loop nest rooted by the loop in preorder, with siblings in + /// forward program order. + pub fn inner_loops_in_preorder(&self) -> SmallVec<[Rc; 2]> { + let mut worklist = SmallVec::<[Rc; 4]>::default(); + worklist.extend(self.nested().iter().rev().cloned()); + + let mut results = SmallVec::default(); + while let Some(l) = worklist.pop() { + // Sub-loops are stored in forward program order, but will process the + // worklist backwards so append them in reverse order. 
+            worklist.extend(l.nested().iter().rev().cloned());
+            results.push(l);
+        }
+
+        results
+    }
+
+    /// Return all loops in the loop nest rooted by the loop in preorder, with siblings in forward
+    /// program order.
+    pub fn loops_in_preorder(self: Rc<Self>) -> SmallVec<[Rc<Loop>; 2]> {
+        let mut loops = self.inner_loops_in_preorder();
+        loops.insert(0, self);
+        loops
+    }
+}
+
+fn unique_exit_blocks_helper<F>(
+    l: &Loop,
+    exit_blocks: &mut SmallVec<[BlockRef; 2]>,
+    mut predicate: F,
+) where
+    F: FnMut(&BlockRef) -> bool,
+{
+    let mut visited = SmallSet::<BlockRef, 2>::default();
+    for block in l.blocks.borrow().iter().filter(|b| predicate(b)).cloned() {
+        for succ in BlockRef::children(block) {
+            if !l.contains_block(&succ) && visited.insert(succ.clone()) {
+                exit_blocks.push(succ);
+            }
+        }
+    }
+}
+
+/// Updates
+impl Loop {
+    /// Add `block` to this loop, and as a member of all parent loops.
+    ///
+    /// It is not valid to replace the loop header using this function.
+    ///
+    /// This is intended for use by analyses which need to update loop information.
+    pub fn add_block_to_loop(self: Rc<Self>, block: BlockRef, forest: &mut LoopForest) {
+        assert!(!forest.contains_block(&block), "`block` is already in this loop");
+
+        // Add the loop mapping to the LoopForest object...
+        forest.block_map.insert(block.clone(), self.clone());
+
+        // Add the basic block to this loop and all parent loops...
+        let mut next_l = Some(self);
+        while let Some(l) = next_l.take() {
+            l.add_block_entry(block.clone());
+            next_l = l.parent_loop();
+        }
+    }
+
+    /// Replace `prev` with `new` in the set of children of this loop, updating the parent pointer
+    /// of `prev` to `None`, and of `new` to `self`.
+    ///
+    /// This also updates the loop depth of the new child.
+    ///
+    /// This is intended for use when splitting loops up.
+    pub fn replace_child_loop_with(self: Rc<Self>, prev: Rc<Loop>, new: Rc<Loop>) {
+        assert_eq!(prev.parent_loop().as_ref(), Some(&self), "this loop is already broken");
+        assert!(new.parent_loop().is_none(), "`new` already has a parent");
+
+        // Set the parent of `new` to `self`
+        new.set_parent_loop(Some(self.clone()));
+        // Replace `prev` in `self.nested` with `new`
+        let mut nested = self.nested.borrow_mut();
+        let entry = nested.iter_mut().find(|l| Rc::ptr_eq(l, &prev)).expect("`prev` not in loop");
+        let _ = core::mem::replace(entry, new);
+        // Set the parent of `prev` to `None`
+        prev.set_parent_loop(None);
+    }
+
+    /// Add the specified loop to be a child of this loop.
+    ///
+    /// This updates the loop depth of the new child.
+    pub fn add_child_loop(self: Rc<Self>, child: Rc<Loop>) {
+        assert!(child.parent_loop().is_none(), "child already has a parent");
+        child.set_parent_loop(Some(self.clone()));
+        self.nested.borrow_mut().push(child);
+    }
+
+    /// Remove subloops of this loop based on the provided predicate, and return them in a vector.
+    ///
+    /// The loops are not deleted, as they will presumably be inserted into another loop.
+    pub fn take_child_loops<F>(&self, should_remove: F) -> SmallVec<[Rc<Loop>; 2]>
+    where
+        F: Fn(&Loop) -> bool,
+    {
+        let mut taken = SmallVec::default();
+        self.nested.borrow_mut().retain(|l| {
+            if should_remove(l) {
+                l.set_parent_loop(None);
+                taken.push(Rc::clone(l));
+                false
+            } else {
+                true
+            }
+        });
+        taken
+    }
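As a sketch of how these update helpers compose when restructuring a loop nest (the `Rc<Loop>` values and the predicate are hypothetical):

```rust
// Sketch only: re-parent subloops of `outer` that no longer contain `block`.
fn reparent_subloops(outer: &Loop, new_parent: Rc<Loop>, block: &BlockRef) {
    // `take_child_loops` clears each removed child's parent pointer...
    let orphans = outer.take_child_loops(|l| !l.contains_block(block));
    for child in orphans {
        // ...which is exactly what `add_child_loop` requires of its argument.
        new_parent.clone().add_child_loop(child);
    }
}
```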
+
+    /// Remove the specified child from being a subloop of this loop.
+    ///
+    /// The loop is not deleted, as it will presumably be inserted into another loop.
+    pub fn take_child_loop(&self, child: &Loop) -> Option<Rc<Loop>> {
+        let mut nested = self.nested.borrow_mut();
+        let index = nested.iter().position(|l| core::ptr::addr_eq(&**l, child))?;
+        Some(nested.swap_remove(index))
+    }
+
+    /// Add a basic block directly to the basic block list.
+    ///
+    /// This should only be used by transformations that create new loops. Other transformations
+    /// should use [Loop::add_block_to_loop].
+    pub fn add_block_entry(&self, block: BlockRef) {
+        self.blocks.borrow_mut().push(block.clone());
+        self.block_set.borrow_mut().insert(block);
+    }
+
+    /// Reverse the order of blocks in this loop, from `index` to the end.
+    pub fn reverse_blocks(&self, index: usize) {
+        self.blocks.borrow_mut()[index..].reverse();
+    }
+
+    /// Reserve capacity for `capacity` blocks.
+    pub fn reserve(&self, capacity: usize) {
+        self.blocks.borrow_mut().reserve(capacity);
+    }
+
+    /// Move `block` (which must be part of this loop) to be the loop header of the loop (the
+    /// block that dominates all others).
+    pub fn move_to_header(&self, block: BlockRef) {
+        let mut blocks = self.blocks.borrow_mut();
+        let index = blocks.iter().position(|b| b == &block).expect("loop does not contain `block`");
+        if index == 0 {
+            return;
+        }
+        unsafe {
+            blocks.swap_unchecked(0, index);
+        }
+    }
+
+    /// Remove the specified basic block from the current loop, updating `self.blocks` as
+    /// appropriate. This does not update the mapping in the corresponding [LoopInfo].
+    pub fn remove_block_from_loop(&self, block: &BlockRef) {
+        let mut blocks = self.blocks.borrow_mut();
+        let index = blocks.iter().position(|b| b == block).expect("loop does not contain `block`");
+        blocks.swap_remove(index);
+        self.block_set.borrow_mut().remove(block);
+    }
+
+    /// Verify loop structure
+    #[cfg(debug_assertions)]
+    pub fn verify_loop(&self) -> Result<(), Report> {
+        use crate::PreOrderBlockIter;
+
+        if self.blocks.borrow().is_empty() {
+            return Err(Report::msg("loop header is missing"));
+        }
+
+        // Setup for using a depth-first iterator to visit every block in the loop.
+        let exit_blocks = self.exit_blocks();
+        let mut visit_set = SmallSet::<BlockRef, 4>::default();
+        visit_set.extend(exit_blocks.iter().cloned());
+
+        // Keep track of the blocks visited.
+        let mut visited_blocks = SmallSet::<BlockRef, 4>::default();
+
+        // Check the individual blocks.
+        let header = self.header();
+        for block in
+            PreOrderBlockIter::new_with_visited(header.clone(), exit_blocks.iter().cloned())
+        {
+            let has_in_loop_successors =
+                BlockRef::children(block.clone()).any(|b| self.contains_block(&b));
+            if !has_in_loop_successors {
+                return Err(Report::msg("loop block has no in-loop successors"));
+            }
+
+            let has_in_loop_predecessors =
+                BlockRef::inverse_children(block.clone()).any(|b| self.contains_block(&b));
+            if !has_in_loop_predecessors {
+                return Err(Report::msg("loop block has no in-loop predecessors"));
+            }
+
+            let outside_loop_preds = BlockRef::inverse_children(block.clone())
+                .filter(|b| !self.contains_block(b))
+                .collect::<SmallVec<[BlockRef; 2]>>();
+
+            if block == header && outside_loop_preds.is_empty() {
+                return Err(Report::msg("loop is unreachable"));
+            } else if !outside_loop_preds.is_empty() {
+                // A non-header loop block shouldn't be reachable from outside the loop, though it
+                // is permitted if the predecessor is not itself actually reachable.
+                let entry = block.borrow().parent().unwrap().borrow().entry_block_ref().unwrap();
+                for child_block in PreOrderBlockIter::new(entry) {
+                    if outside_loop_preds.iter().any(|pred| &child_block == pred) {
+                        return Err(Report::msg("loop has multiple entry points"));
+                    }
+                }
+            }
+            if block == header.borrow().parent().unwrap().borrow().entry_block_ref().unwrap() {
+                return Err(Report::msg("loop contains region entry block"));
+            }
+            visited_blocks.insert(block);
+        }
+
+        if visited_blocks.len() != self.num_blocks() {
+            log::trace!("The following blocks are unreachable in the loop: ");
+            for block in self.blocks().iter() {
+                if !visited_blocks.contains(block) {
+                    log::trace!("{block}");
+                }
+            }
+            return Err(Report::msg("unreachable block in loop"));
+        }
+
+        // Check the subloops.
+        for subloop in self.nested().iter() {
+            // Each block in each subloop should be contained within this loop.
+            for block in subloop.blocks().iter() {
+                if !self.contains_block(block) {
+                    return Err(Report::msg(
+                        "loop does not contain all the blocks of its subloops",
+                    ));
+                }
+            }
+        }
+
+        // Check the parent loop pointer.
+        if let Some(parent) = self.parent_loop() {
+            if !parent.nested().iter().any(|l| core::ptr::addr_eq(&**l, self)) {
+                return Err(Report::msg("loop is not a subloop of its parent"));
+            }
+        }
+
+        Ok(())
+    }
+
+    #[cfg(not(debug_assertions))]
+    pub fn verify_loop(&self) -> Result<(), Report> {
+        Ok(())
+    }
+
+    /// Verify loop structure of this loop and all nested loops.
+    pub fn verify_loop_nest(
+        self: Rc<Self>,
+        loops: &mut SmallSet<Rc<Loop>, 2>,
+    ) -> Result<(), Report> {
+        loops.insert(self.clone());
+
+        // Verify this loop.
+        self.verify_loop()?;
+
+        // Verify the subloops.
+        for l in self.nested.borrow().iter().cloned() {
+            l.verify_loop_nest(loops)?;
+        }
+
+        Ok(())
+    }
+
+    /// Print this loop, with all the blocks inside it.
+    pub fn print(&self, verbose: bool) -> impl fmt::Display + '_ {
+        PrintLoop {
+            loop_info: self,
+            nested: true,
+            verbose,
+        }
+    }
+}
+
+struct PrintLoop<'a> {
+    loop_info: &'a Loop,
+    nested: bool,
+    verbose: bool,
+}
+
+impl<'a> crate::formatter::PrettyPrint for PrintLoop<'a> {
+    fn render(&self) -> crate::formatter::Document {
+        use crate::formatter::*;
+
+        let mut doc = const_text("loop containing: ");
+        let header = self.loop_info.header();
+        for (i, block) in self.loop_info.blocks().iter().enumerate() {
+            if !self.verbose {
+                if i > 0 {
+                    doc += const_text(", ");
+                }
+                doc += display(block.clone());
+            } else {
+                doc += nl();
+            }
+
+            if block == &header {
+                doc += const_text("<header>");
+            } else if self.loop_info.is_loop_latch(block) {
+                doc += const_text("<latch>");
+            } else if self.loop_info.is_loop_exiting(block) {
+                doc += const_text("<exiting>");
+            }
+
+            if self.verbose {
+                doc += text(format!("{:?}", &block.borrow()));
+            }
+        }
+
+        if self.nested {
+            let nested = self.loop_info.nested().iter().fold(Document::Empty, |acc, l| {
+                let printer = PrintLoop {
+                    loop_info: l,
+                    nested: true,
+                    verbose: self.verbose,
+                };
+                acc + nl() + printer.render()
+            });
+            doc + indent(2, nested)
+        } else {
+            doc
+        }
+    }
+}
+
+impl<'a> fmt::Display for PrintLoop<'a> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        use crate::formatter::PrettyPrint;
+        self.pretty_print(f)
+    }
+}
+
+impl fmt::Display for Loop {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}", self.print(false))
+    }
+}
+
+impl fmt::Debug for Loop {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("Loop")
+            .field("parent_loop", &self.parent_loop())
+            .field("nested", &self.nested())
+            .field("blocks", &self.blocks())
+            .field("block_set", &self.block_set())
+            .finish()
+    }
+}
diff --git a/hir2/src/ir/op.rs b/hir2/src/ir/op.rs
new file mode 100644
index 000000000..5406f7a12
--- /dev/null
+++ b/hir2/src/ir/op.rs
@@ -0,0 +1,143 @@
+use super::*;
+use crate::{any::AsAny, AttributeValue};
+
+pub trait OpRegistration: Op {
+    fn name() -> ::midenc_hir_symbol::Symbol;
+    fn register_with(dialect: &dyn Dialect) -> OperationName;
+}
+
+pub trait BuildableOp<Args>: Op {
+    type Builder<'a, T>: FnOnce<Args, Output = Result<UnsafeIntrusiveEntityRef<Self>, crate::Report>>
+        + 'a
+    where
+        T: ?Sized + Builder + 'a;
+
+    fn builder<'b, B>(builder: &'b mut B, span: SourceSpan) -> Self::Builder<'b, B>
+    where
+        B: ?Sized + Builder + 'b;
+}
+
+pub trait Op: AsAny + OpVerifier {
+    /// The name of this operation's opcode
+    ///
+    /// The opcode must be distinct from all other opcodes in the same dialect
+    fn name(&self) -> OperationName;
+    fn as_operation(&self) -> &Operation;
+    fn as_operation_mut(&mut self) -> &mut Operation;
+
+    fn set_span(&mut self, span: SourceSpan) {
+        self.as_operation_mut().set_span(span);
+    }
+    fn parent(&self) -> Option<BlockRef> {
+        self.as_operation().parent()
+    }
+    fn parent_region(&self) -> Option<RegionRef> {
+        self.as_operation().parent_region()
+    }
+    fn parent_op(&self) -> Option<OperationRef> {
+        self.as_operation().parent_op()
+    }
+    fn num_regions(&self) -> usize {
+        self.as_operation().num_regions()
+    }
+    fn regions(&self) -> &RegionList {
+        self.as_operation().regions()
+    }
+    fn regions_mut(&mut self) -> &mut RegionList {
+        self.as_operation_mut().regions_mut()
+    }
+    fn region(&self, index: usize) -> EntityRef<'_, Region> {
+        self.as_operation().region(index)
+    }
+    fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> {
+        self.as_operation_mut().region_mut(index)
+    }
+    fn has_successors(&self) -> bool {
+        self.as_operation().has_successors()
+    }
+    fn num_successors(&self) -> usize {
+        self.as_operation().num_successors()
+    }
+    fn has_operands(&self) -> bool {
+        self.as_operation().has_operands()
+    }
+    fn num_operands(&self) -> usize {
+        self.as_operation().num_operands()
+    }
+    fn operands(&self) -> &OpOperandStorage {
+        self.as_operation().operands()
+    }
+    fn operands_mut(&mut self) -> &mut OpOperandStorage {
+        self.as_operation_mut().operands_mut()
+    }
+    fn results(&self) -> &OpResultStorage {
+        self.as_operation().results()
+    }
+    fn results_mut(&mut self) -> &mut OpResultStorage {
+        self.as_operation_mut().results_mut()
+    }
+    fn successors(&self) -> &OpSuccessorStorage {
+        self.as_operation().successors()
+    }
+    fn successors_mut(&mut self) -> &mut OpSuccessorStorage {
+        self.as_operation_mut().successors_mut()
+    }
+}
+
+impl Spanned for dyn Op {
+    fn span(&self) -> SourceSpan {
+        self.as_operation().span
+    }
+}
+
+pub trait OpExt {
+    /// Return the value associated with attribute `name` for this operation
+    fn get_attribute(
+        &self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+    ) -> Option<&dyn AttributeValue>;
+
+    /// Return true if this operation has an attribute named `name`
+    fn has_attribute(&self, name: impl Into<::midenc_hir_symbol::Symbol>) -> bool;
+
+    /// Set the attribute `name` with `value` for this operation.
+    fn set_attribute(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+        value: Option<impl AttributeValue>,
+    );
+
+    /// Remove any attribute with the given name from this operation
+    fn remove_attribute(&mut self, name: impl Into<::midenc_hir_symbol::Symbol>);
+
+    /// Returns a handle to the nearest containing [Operation] of type `T` for this operation, if
+    /// it is attached to one
+    fn nearest_parent_op<T: Op>(&self) -> Option<UnsafeIntrusiveEntityRef<T>>;
+}
+
+impl<O: ?Sized + Op> OpExt for O {
+    #[inline]
+    fn get_attribute(
+        &self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+    ) -> Option<&dyn AttributeValue> {
+        self.as_operation().get_attribute(name)
+    }
+
+    #[inline]
+    fn has_attribute(&self, name: impl Into<::midenc_hir_symbol::Symbol>) -> bool {
+        self.as_operation().has_attribute(name)
+    }
+
+    #[inline]
+    fn set_attribute(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+        value: Option<impl AttributeValue>,
+    ) {
+        self.as_operation_mut().set_attribute(name, value);
+    }
+
+    #[inline]
+    fn remove_attribute(&mut self, name: impl Into<::midenc_hir_symbol::Symbol>) {
+        self.as_operation_mut().remove_attribute(name);
+    }
+
+    #[inline]
+    fn nearest_parent_op<T: Op>(&self) -> Option<UnsafeIntrusiveEntityRef<T>> {
+        self.as_operation().nearest_parent_op()
+    }
+}
diff --git a/hir2/src/ir/operands.rs b/hir2/src/ir/operands.rs
new file mode 100644
index 000000000..e2db1b83b
--- /dev/null
+++ b/hir2/src/ir/operands.rs
@@ -0,0 +1,102 @@
+use core::fmt;
+
+use crate::{EntityRef, OperationRef, Type, UnsafeIntrusiveEntityRef, Value, ValueId, ValueRef};
+
+pub type OpOperand = UnsafeIntrusiveEntityRef<OpOperandImpl>;
+pub type OpOperandList = crate::EntityList<OpOperandImpl>;
+#[allow(unused)]
+pub type OpOperandIter<'a> = crate::EntityIter<'a, OpOperandImpl>;
+#[allow(unused)]
+pub type OpOperandCursor<'a> = crate::EntityCursor<'a, OpOperandImpl>;
+#[allow(unused)]
+pub type OpOperandCursorMut<'a> = crate::EntityCursorMut<'a, OpOperandImpl>;
+
+/// An [OpOperand] represents a use of a [Value] by an [Operation]
+pub struct OpOperandImpl {
+    /// The operand value
+    pub value: ValueRef,
+    /// The owner of this operand, i.e. the operation it is an operand of
+    pub owner: OperationRef,
+    /// The index of this operand in the operand list of an operation
+    pub index: u8,
+}
+
+impl OpOperandImpl {
+    #[inline]
+    pub fn new(value: ValueRef, owner: OperationRef, index: u8) -> Self {
+        Self {
+            value,
+            owner,
+            index,
+        }
+    }
+
+    pub fn value(&self) -> EntityRef<'_, dyn Value> {
+        self.value.borrow()
+    }
+
+    #[inline]
+    pub fn as_value_ref(&self) -> ValueRef {
+        self.value.clone()
+    }
+
+    #[inline]
+    pub fn as_operand_ref(&self) -> OpOperand {
+        unsafe { OpOperand::from_raw(self) }
+    }
+
+    pub fn owner(&self) -> EntityRef<'_, crate::Operation> {
+        self.owner.borrow()
+    }
+
+    pub fn ty(&self) -> crate::Type {
+        self.value().ty().clone()
+    }
+}
+
+impl fmt::Debug for OpOperandImpl {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        #[derive(Debug)]
+        #[allow(unused)]
+        struct ValueInfo<'a> {
+            id: ValueId,
+            ty: &'a Type,
+        }
+
+        let value = self.value.borrow();
+        let id = value.id();
+        let ty = value.ty();
+        f.debug_struct("OpOperand")
+            .field("index", &self.index)
+            .field("value", &ValueInfo { id, ty })
+            .finish_non_exhaustive()
+    }
+}
+
+impl crate::Spanned for OpOperandImpl {
+    fn span(&self) -> crate::SourceSpan {
+        self.value.borrow().span()
+    }
+}
+
+impl crate::Entity for OpOperandImpl {}
+impl crate::StorableEntity for OpOperandImpl {
+    #[inline(always)]
+    fn index(&self) -> usize {
+        self.index as usize
+    }
+
+    unsafe fn set_index(&mut self, index: usize) {
+        self.index = index.try_into().expect("too many operands");
+    }
+
+    fn unlink(&mut self) {
+        let ptr = self.as_operand_ref();
+        let mut value = self.value.borrow_mut();
+        let uses = value.uses_mut();
+        unsafe {
+            let mut cursor = uses.cursor_mut_from_ptr(ptr);
+            cursor.remove();
+        }
+    }
+}
+
+pub type OpOperandStorage = crate::EntityStorage<OpOperand, 1>;
+pub type OpOperandRange<'a> = crate::EntityRange<'a, OpOperand>;
+pub type OpOperandRangeMut<'a> = crate::EntityRangeMut<'a, OpOperand, 1>;
diff --git a/hir2/src/ir/operation.rs b/hir2/src/ir/operation.rs
new file mode 100644
index 000000000..0a2436dfd
--- /dev/null
+++ b/hir2/src/ir/operation.rs
@@ -0,0 +1,1234 @@
+mod builder;
+mod name;
+
+use alloc::rc::Rc;
+use core::{
+    fmt,
+    ptr::{DynMetadata, NonNull, Pointee},
+    sync::atomic::AtomicU32,
+};
+
+use smallvec::SmallVec;
+
+pub use self::{builder::OperationBuilder, name::OperationName};
+use super::*;
+use crate::{AttributeSet, AttributeValue};
+
+pub type OperationRef = UnsafeIntrusiveEntityRef<Operation>;
+pub type OpList = EntityList<Operation>;
+pub type OpCursor<'a> = EntityCursor<'a, Operation>;
+pub type OpCursorMut<'a> = EntityCursorMut<'a, Operation>;
+
+/// The [Operation] struct provides the common foundation for all [Op] implementations.
+///
+/// It provides:
+///
+/// * Support for casting between the concrete operation type `T`, `dyn Op`, the underlying
+///   `Operation`, and any of the operation traits that the op implements. Not only can the casts
+///   be performed, but an [Operation] can be queried to see if it implements a specific trait at
+///   runtime, to conditionally perform some behavior. This makes working with operations in the
+///   IR very flexible, and allows for adding or modifying operations without needing to change
+///   most of the compiler, which predominantly works on operation traits rather than concrete
+///   ops.
+/// * Storage for all IR entities attached to an operation, e.g. operands, results, nested
+///   regions, attributes, etc.
+/// * Navigation of the IR graph; navigate up to the containing block/region/op, down to nested
+///   regions/blocks/ops, or to next/previous sibling operations in the same block. Additionally,
+///   you can navigate directly to the definitions of operands used, to users of results produced,
+///   and to successor blocks.
+/// * Many utility functions related to working with operations, many of which are also accessible
+///   via the [Op] trait, so that working with an [Op] or an [Operation] is largely
+///   indistinguishable.
+///
+/// All [Op] implementations can be cast to the underlying [Operation], but most of the
+/// functionality is re-exported via default implementations of methods on the [Op] trait. The
+/// main benefit is avoiding any potential overhead of casting when going through the trait,
+/// rather than calling the underlying [Operation] method directly.
+///
+/// # Safety
+///
+/// [Operation] is implemented as part of a larger structure that relies on assumptions which
+/// depend on IR entities being allocated via [Context], i.e. the arena. Those allocations produce
+/// an [UnsafeIntrusiveEntityRef] or [UnsafeEntityRef], which allocate the pointee type inside a
+/// struct that provides metadata about the pointee that can be accessed without aliasing the
+/// pointee itself - in particular, links for intrusive collections. This is important, because
+/// while these pointer types are a bit like raw pointers, in that they lack any lifetime
+/// information and are thus unsafe to dereference in general, they _do_ ensure that the pointee
+/// can be safely reified as a reference without violating Rust's borrow checking rules, i.e. they
+/// are dynamically borrow-checked.
+///
+/// The reason why we are able to generally treat these "unsafe" references as safe, is because we
+/// require that all IR entities be allocated via [Context]. This makes it essential to keep the
+/// context around in order to work with the IR, and effectively guarantees that no [RawEntityRef]
+/// will be dereferenced after the context is dropped. This is not a guarantee provided by the
+/// compiler however, but one that is imposed in practice, as attempting to work with the IR in
+/// any capacity without a [Context] is almost impossible. We must ensure however, that we work
+/// within this set of rules to uphold the safety guarantees.
+///
+/// This "fragility" is a tradeoff - we get the performance characteristics of an arena-allocated
+/// IR, with the flexibility and power of using pointers rather than indexes as handles, while
+/// also maintaining the safety guarantees of Rust's borrowing system. The downside is that we
+/// can't just allocate IR entities wherever we want and use them the same way.
+#[derive(Spanned)]
+pub struct Operation {
+    /// The [Context] in which this [Operation] was allocated.
+    context: NonNull<Context>,
+    /// The dialect and opcode name for this operation, as well as trait implementation metadata
+    name: OperationName,
+    /// The offset of the field containing this struct inside the concrete [Op] it represents.
+    ///
+    /// This is required in order to be able to perform casts from [Operation]. An [Operation]
+    /// cannot be constructed without providing it to the `uninit` function, and callers of that
+    /// function are required to ensure that it is correct.
+    offset: usize,
+    /// The order of this operation in its containing block
+    ///
+    /// This is atomic to ensure that even if a mutable reference to this operation is held, loads
+    /// of this field cannot be elided, as the value can still be mutated at any time. In practice,
+    /// the only time this is ever written is when all operations in a block have their orders
+    /// recomputed, or when a single operation is updating its own order.
+    order: AtomicU32,
+    #[span]
+    pub span: SourceSpan,
+    /// Attributes that apply to this operation
+    pub attrs: AttributeSet,
+    /// The containing block of this operation
+    ///
+    /// Is set to `None` if this operation is detached
+    pub block: Option<BlockRef>,
+    /// The set of operands for this operation
+    ///
+    /// NOTE: If the op supports immediate operands, the storage for the immediates is handled
+    /// by the op, rather than here. Additionally, the semantics of the immediate operands are
+    /// determined by the op, e.g. whether the immediate operands are always applied first, or
+    /// what they are used for.
+    pub operands: OpOperandStorage,
+    /// The set of values produced by this operation.
+    pub results: OpResultStorage,
+    /// If this operation represents control flow, this field stores the set of successors,
+    /// and successor operands.
+    pub successors: OpSuccessorStorage,
+    /// The set of regions belonging to this operation, if any
+    pub regions: RegionList,
+}
+
+impl fmt::Debug for Operation {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Operation")
+            .field_with("name", |f| write!(f, "{}", &self.name()))
+            .field("offset", &self.offset)
+            .field("order", &self.order)
+            .field("attrs", &self.attrs)
+            .field("block", &self.block.as_ref().map(|b| b.borrow().id()))
+            .field_with("operands", |f| {
+                let mut list = f.debug_list();
+                for operand in self.operands().all() {
+                    list.entry(&operand.borrow());
+                }
+                list.finish()
+            })
+            .field("results", &self.results)
+            .field("successors", &self.successors)
+            .finish_non_exhaustive()
+    }
+}
+
+impl fmt::Debug for OperationRef {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        fmt::Debug::fmt(&self.borrow(), f)
+    }
+}
+
+impl AsRef<dyn Op> for Operation {
+    fn as_ref(&self) -> &dyn Op {
+        self.name.upcast(self.container()).unwrap()
+    }
+}
+
+impl AsMut<dyn Op> for Operation {
+    fn as_mut(&mut self) -> &mut dyn Op {
+        self.name.upcast_mut(self.container().cast_mut()).unwrap()
+    }
+}
+
+impl Entity for Operation {}
+impl EntityWithParent for Operation {
+    type Parent = Block;
+
+    fn on_inserted_into_parent(
+        mut this: UnsafeIntrusiveEntityRef<Operation>,
+        parent: UnsafeIntrusiveEntityRef<Block>,
+    ) {
+        let mut op = this.borrow_mut();
+        op.block = Some(parent);
+        op.order.store(Self::INVALID_ORDER, std::sync::atomic::Ordering::Release);
+    }
+
+    fn on_removed_from_parent(
+        mut this: UnsafeIntrusiveEntityRef<Operation>,
+        _parent: UnsafeIntrusiveEntityRef<Block>,
+    ) {
+        this.borrow_mut().block = None;
+    }
+
+    fn on_transfered_to_new_parent(
+        from: UnsafeIntrusiveEntityRef<Block>,
+        mut to: UnsafeIntrusiveEntityRef<Block>,
+        transferred: impl IntoIterator<Item = UnsafeIntrusiveEntityRef<Operation>>,
+    ) {
+        // Invalidate the ordering of the new parent block
+        to.borrow_mut().invalidate_op_order();
+
+        // If we are transferring operations within the same block, the block pointer doesn't
+        // need to be updated
+        if BlockRef::ptr_eq(&from, &to) {
+            return;
+        }
+
+        for mut transferred_op in transferred {
+            transferred_op.borrow_mut().block = Some(to.clone());
+        }
+    }
+}
+
+/// Construction
+impl Operation {
+    #[doc(hidden)]
+    pub unsafe fn uninit<T: Op>(context: Rc<Context>, name: OperationName, offset: usize) -> Self {
+        assert!(name.is::<T>());
+
+        Self {
+            context: unsafe { NonNull::new_unchecked(Rc::as_ptr(&context).cast_mut()) },
+            name,
+            offset,
+            order: AtomicU32::new(0),
+            span: Default::default(),
+            attrs: Default::default(),
+            block: Default::default(),
+            operands: Default::default(),
+            results: Default::default(),
+            successors: Default::default(),
+            regions: Default::default(),
+        }
+    }
+}
+
+/// Metadata
+impl Operation {
+    /// Get the name of this operation
+    ///
+    /// An operation name consists of both its dialect, and its opcode.
+    pub fn name(&self) -> OperationName {
+        self.name.clone()
+    }
+
+    /// Get the dialect associated with this operation
+    pub fn dialect(&self) -> Rc<dyn Dialect> {
+        self.context().get_registered_dialect(self.name.dialect())
+    }
+
+    /// Set the source location associated with this operation
+    #[inline]
+    pub fn set_span(&mut self, span: SourceSpan) {
+        self.span = span;
+    }
+
+    /// Get a borrowed reference to the owning [Context] of this operation
+    #[inline(always)]
+    pub fn context(&self) -> &Context {
+        // SAFETY: This is safe so long as this operation is allocated in a Context, since the
+        // Context by definition outlives the allocation.
+        unsafe { self.context.as_ref() }
+    }
+
+    /// Get an owned reference to the owning [Context] of this operation
+    pub fn context_rc(&self) -> Rc<Context> {
+        // SAFETY: This is safe so long as this operation is allocated in a Context, since the
+        // Context by definition outlives the allocation.
+        //
+        // Additionally, constructing the Rc from a raw pointer is safe here, as the pointer was
+        // obtained using `Rc::as_ptr`, so the only requirement to call `Rc::from_raw` is to
+        // increment the strong count, as `as_ptr` does not preserve the count for the reference
+        // held by this operation. Incrementing the count first is required to manufacture new
+        // clones of the `Rc` safely.
+        unsafe {
+            let ptr = self.context.as_ptr().cast_const();
+            Rc::increment_strong_count(ptr);
+            Rc::from_raw(ptr)
+        }
+    }
+}
+
+/// Verification
+impl Operation {
+    /// Run any verifiers for this operation
+    pub fn verify(&self, context: &Context) -> Result<(), Report> {
+        let dyn_op: &dyn Op = self.as_ref();
+        dyn_op.verify(context)
+    }
+
+    /// Run any verifiers for this operation, and all of its nested operations, recursively.
+    ///
+    /// The verification is performed in post-order, so that when the verifier(s) for `self` are
+    /// run, it is known that all of its children have successfully verified.
+    pub fn recursively_verify(&self, context: &Context) -> Result<(), Report> {
+        self.postwalk_interruptible(|op: OperationRef| {
+            let op = op.borrow();
+            op.verify(context).into()
+        })
+        .into_result()
+    }
+}
+
+/// Traits/Casts
+impl Operation {
+    pub(super) const fn container(&self) -> *const () {
+        unsafe {
+            let ptr = self as *const Self;
+            ptr.byte_sub(self.offset).cast()
+        }
+    }
+
+    #[inline(always)]
+    pub fn as_operation_ref(&self) -> OperationRef {
+        // SAFETY: This is safe under the assumption that we always allocate Operations using the
+        // arena, i.e. it is a child of a RawEntityMetadata structure.
+        unsafe { OperationRef::from_raw(self) }
+    }
+
+    /// Returns true if the concrete type of this operation is `T`
+    #[inline]
+    pub fn is<T: Op>(&self) -> bool {
+        self.name.is::<T>()
+    }
+
+    /// Returns true if this operation implements `Trait`
+    #[inline]
+    pub fn implements<Trait>(&self) -> bool
+    where
+        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+    {
+        self.name.implements::<Trait>()
+    }
+
+    /// Attempt to downcast to the concrete [Op] type of this operation
+    pub fn downcast_ref<T: Op>(&self) -> Option<&T> {
+        self.name.downcast_ref::<T>(self.container())
+    }
+
+    /// Attempt to downcast to the concrete [Op] type of this operation
+    pub fn downcast_mut<T: Op>(&mut self) -> Option<&mut T> {
+        self.name.downcast_mut::<T>(self.container().cast_mut())
+    }
+
+    /// Attempt to cast this operation reference to an implementation of `Trait`
+    pub fn as_trait<Trait>(&self) -> Option<&Trait>
+    where
+        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+    {
+        self.name.upcast(self.container())
+    }
+
+    /// Attempt to cast this operation reference to an implementation of `Trait`
+    pub fn as_trait_mut<Trait>(&mut self) -> Option<&mut Trait>
+    where
+        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+    {
+        self.name.upcast_mut(self.container().cast_mut())
+    }
+}
+
+/// Attributes
+impl Operation {
+    /// Get the underlying attribute set for this operation
+    #[inline(always)]
+    pub fn attributes(&self) -> &AttributeSet {
+        &self.attrs
+    }
+
+    /// Get a mutable reference to the underlying attribute set for this operation
+    #[inline(always)]
+    pub fn attributes_mut(&mut self) -> &mut AttributeSet {
+        &mut self.attrs
+    }
+
+    /// Return the value associated with attribute `name` for this operation
+    pub fn get_attribute(
+        &self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+    ) -> Option<&dyn AttributeValue> {
+        self.attrs.get_any(name.into())
+    }
+
+    /// Return a mutable reference to the value associated with attribute `name` for this
+    /// operation
+    pub fn get_attribute_mut(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+    ) -> Option<&mut dyn AttributeValue> {
+        self.attrs.get_any_mut(name.into())
+    }
+
+    /// Return the value associated with attribute `name` for this operation, as its concrete type
+    /// `T`, _if_ the attribute by that name is of that type.
+    pub fn get_typed_attribute<T>(&self, name: impl Into<::midenc_hir_symbol::Symbol>) -> Option<&T>
+    where
+        T: AttributeValue,
+    {
+        self.attrs.get(name.into())
+    }
+
+    /// Return the value associated with attribute `name` for this operation, as its concrete type
+    /// `T`, _if_ the attribute by that name is of that type.
+    pub fn get_typed_attribute_mut<T>(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+    ) -> Option<&mut T>
+    where
+        T: AttributeValue,
+    {
+        self.attrs.get_mut(name.into())
+    }
+
+    /// Return true if this operation has an attribute named `name`
+    pub fn has_attribute(&self, name: impl Into<::midenc_hir_symbol::Symbol>) -> bool {
+        self.attrs.has(name.into())
+    }
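A short sketch of the accessors above, together with `set_attribute` just below (the attribute name is made up, and it is assumed here that `u32` implements `AttributeValue`):

```rust
// Sketch only: round-trip a typed attribute on an operation.
fn tag_unroll_factor(op: &mut Operation) {
    op.set_attribute("unroll_factor", Some(4u32)); // assumes u32: AttributeValue
    if op.has_attribute("unroll_factor") {
        let factor = op.get_typed_attribute::<u32>("unroll_factor");
        debug_assert!(factor.is_some());
    }
}
```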
+
+    /// Set the attribute `name` with `value` for this operation.
+    pub fn set_attribute(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+        value: Option<impl AttributeValue>,
+    ) {
+        self.attrs.insert(name, value);
+    }
+
+    /// Remove any attribute with the given name from this operation
+    pub fn remove_attribute(&mut self, name: impl Into<::midenc_hir_symbol::Symbol>) {
+        self.attrs.remove(name.into());
+    }
+}
+
+/// Symbol Attributes
+impl Operation {
+    pub fn set_symbol_attribute(
+        &mut self,
+        name: impl Into<::midenc_hir_symbol::Symbol>,
+        symbol: impl AsSymbolRef,
+    ) {
+        let name = name.into();
+        let mut symbol = symbol.as_symbol_ref();
+
+        // Store the underlying attribute value
+        let user = self.context().alloc_tracked(SymbolUse {
+            owner: self.as_operation_ref(),
+            symbol: name,
+        });
+        if self.has_attribute(name) {
+            let attr = self.get_typed_attribute_mut::<SymbolNameAttr>(name).unwrap();
+            let symbol = symbol.borrow();
+            assert!(
+                !attr.user.is_linked(),
+                "attempted to replace symbol use without unlinking the previously used symbol \
+                 first"
+            );
+            attr.user = user.clone();
+            attr.name = symbol.name();
+            attr.path = symbol.components().into_path(true);
+        } else {
+            let attr = {
+                let symbol = symbol.borrow();
+                let name = symbol.name();
+                let path = symbol.components().into_path(true);
+                SymbolNameAttr {
+                    name,
+                    path,
+                    user: user.clone(),
+                }
+            };
+            self.set_attribute(name, Some(attr));
+        }
+
+        // Add `self` as a user of `symbol`, unless `self` is `symbol`
+        let (data_ptr, _) = SymbolRef::as_ptr(&symbol).to_raw_parts();
+        if core::ptr::addr_eq(data_ptr, self.container()) {
+            return;
+        }
+
+        let mut symbol = symbol.borrow_mut();
+        let symbol_uses = symbol.uses_mut();
+        symbol_uses.push_back(user);
+    }
+}
+
+/// Navigation
+impl Operation {
+    /// Returns a handle to the containing [Block] of this operation, if it is attached to one
+    pub fn parent(&self) -> Option<BlockRef> {
+        self.block.clone()
+    }
+
+    /// Returns a handle to the containing [Region] of this operation, if it is attached to one
+    pub fn parent_region(&self) -> Option<RegionRef> {
+        self.block.as_ref().and_then(|block| block.borrow().parent())
+    }
+
+    /// Returns a handle to the nearest containing [Operation] of this operation, if it is
+    /// attached to one
+    pub fn parent_op(&self) -> Option<OperationRef> {
+        self.block.as_ref().and_then(|block| block.borrow().parent_op())
+    }
+
+    /// Returns a handle to the nearest containing [Operation] of type `T` for this operation, if
+    /// it is attached to one
+    pub fn nearest_parent_op<T: Op>(&self) -> Option<UnsafeIntrusiveEntityRef<T>> {
+        let mut parent = self.parent_op();
+        while let Some(op) = parent.take() {
+            let entity_ref = op.borrow();
+            parent = entity_ref.parent_op();
+            if let Some(t_ref) = entity_ref.downcast_ref::<T>() {
+                return Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(t_ref) });
+            }
+        }
+        None
+    }
+}
+
+/// Regions
+impl Operation {
+    /// Returns true if this operation has any regions
+    #[inline]
+    pub fn has_regions(&self) -> bool {
+        !self.regions.is_empty()
+    }
+
+    /// Returns the number of regions owned by this operation.
+    ///
+    /// NOTE: This does not include regions of nested operations, just those directly attached
+    /// to this operation.
+    #[inline]
+    pub fn num_regions(&self) -> usize {
+        self.regions.len()
+    }
+
+    /// Get a reference to the region list for this operation
+    #[inline(always)]
+    pub fn regions(&self) -> &RegionList {
+        &self.regions
+    }
+
+    /// Get a mutable reference to the region list for this operation
+    #[inline(always)]
+    pub fn regions_mut(&mut self) -> &mut RegionList {
+        &mut self.regions
+    }
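The navigation methods above compose naturally for walking up the IR; a minimal sketch:

```rust
// Sketch only: count how many operations enclose `op`.
fn nesting_depth(op: &Operation) -> usize {
    let mut depth = 0;
    let mut parent = op.parent_op();
    while let Some(p) = parent {
        depth += 1;
        parent = p.borrow().parent_op();
    }
    depth
}
```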
+
+    /// Get a reference to a specific region, given its index.
+    ///
+    /// This function will panic if the index is invalid.
+    pub fn region(&self, index: usize) -> EntityRef<'_, Region> {
+        let mut cursor = self.regions.front();
+        let mut count = 0;
+        while !cursor.is_null() {
+            if index == count {
+                return cursor.into_borrow().unwrap();
+            }
+            cursor.move_next();
+            count += 1;
+        }
+        panic!("invalid region index {index}: out of bounds");
+    }
+
+    /// Get a mutable reference to a specific region, given its index.
+    ///
+    /// This function will panic if the index is invalid.
+    pub fn region_mut(&mut self, index: usize) -> EntityMut<'_, Region> {
+        let mut cursor = self.regions.front_mut();
+        let mut count = 0;
+        while !cursor.is_null() {
+            if index == count {
+                return cursor.into_borrow_mut().unwrap();
+            }
+            cursor.move_next();
+            count += 1;
+        }
+        panic!("invalid region index {index}: out of bounds");
+    }
+}
+
+/// Successors
+impl Operation {
+    /// Returns true if this operation has any successor blocks
+    #[inline]
+    pub fn has_successors(&self) -> bool {
+        !self.successors.is_empty()
+    }
+
+    /// Returns the number of successor blocks this operation may transfer control to
+    #[inline]
+    pub fn num_successors(&self) -> usize {
+        self.successors.len()
+    }
+
+    /// Get a reference to the successors of this operation
+    #[inline(always)]
+    pub fn successors(&self) -> &OpSuccessorStorage {
+        &self.successors
+    }
+
+    /// Get a mutable reference to the successors of this operation
+    #[inline(always)]
+    pub fn successors_mut(&mut self) -> &mut OpSuccessorStorage {
+        &mut self.successors
+    }
+
+    /// Get a reference to the successor group at `index`
+    #[inline]
+    pub fn successor_group(&self, index: usize) -> OpSuccessorRange<'_> {
+        self.successors.group(index)
+    }
+
+    /// Get a mutable reference to the successor group at `index`
+    #[inline]
+    pub fn successor_group_mut(&mut self, index: usize) -> OpSuccessorRangeMut<'_> {
+        self.successors.group_mut(index)
+    }
+
+    /// Get a reference to the keyed successor group at `index`
+    #[inline]
+    pub fn keyed_successor_group<T>(&self, index: usize) -> KeyedSuccessorRange<'_, T>
+    where
+        T: KeyedSuccessor,
+    {
+        let range = self.successors.group(index);
+        KeyedSuccessorRange::new(range, &self.operands)
+    }
+
+    /// Get a mutable reference to the keyed successor group at `index`
+    #[inline]
+    pub fn keyed_successor_group_mut<T>(&mut self, index: usize) -> KeyedSuccessorRangeMut<'_, T>
+    where
+        T: KeyedSuccessor,
+    {
+        let range = self.successors.group_mut(index);
+        KeyedSuccessorRangeMut::new(range, &mut self.operands)
+    }
+
+    /// Get a reference to the successor at `index` in the group at `group_index`
+    #[inline]
+    pub fn successor_in_group(&self, group_index: usize, index: usize) -> OpSuccessor<'_> {
+        let info = &self.successors.group(group_index)[index];
+        OpSuccessor {
+            dest: info.block.clone(),
+            arguments: self.operands.group(info.operand_group as usize),
+        }
+    }
+
+    /// Get a mutable reference to the successor at `index` in the group at `group_index`
+    #[inline]
+    pub fn successor_in_group_mut(
+        &mut self,
+        group_index: usize,
+        index: usize,
+    ) -> OpSuccessorMut<'_> {
+        let info = &self.successors.group(group_index)[index];
+        OpSuccessorMut {
+            dest: info.block.clone(),
+            arguments: self.operands.group_mut(info.operand_group as usize),
+        }
+    }
+
+    /// Get a reference to the successor at `index`
+    #[inline]
+    pub fn successor(&self, index: usize) -> OpSuccessor<'_> {
+        let info = &self.successors[index];
+        OpSuccessor {
+            dest: info.block.clone(),
+            arguments: self.operands.group(info.operand_group as usize),
+        }
+    }
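A sketch of inspecting successors via the indexed accessor above (gating on `num_successors`, since `successor` panics on an out-of-bounds index):

```rust
// Sketch only: visit each successor of a terminator-like operation.
fn visit_successors(op: &Operation) {
    for index in 0..op.num_successors() {
        let succ = op.successor(index);
        // `succ.dest` is the target block operand; `succ.arguments` are the
        // values forwarded to that block's parameters.
        let _ = succ;
    }
}
```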
+
+    /// Get a mutable reference to the successor at `index`
+    #[inline]
+    pub fn successor_mut(&mut self, index: usize) -> OpSuccessorMut<'_> {
+        let info = self.successors[index].clone();
+        OpSuccessorMut {
+            dest: info.block,
+            arguments: self.operands.group_mut(info.operand_group as usize),
+        }
+    }
+
+    /// Get an iterator over the successors of this operation
+    pub fn successor_iter(&self) -> impl DoubleEndedIterator<Item = OpSuccessor<'_>> + '_ {
+        self.successors.iter().map(|info| OpSuccessor {
+            dest: info.block.clone(),
+            arguments: self.operands.group(info.operand_group as usize),
+        })
+    }
+}
+
+/// Operands
+impl Operation {
+    /// Returns true if this operation has at least one operand
+    #[inline]
+    pub fn has_operands(&self) -> bool {
+        !self.operands.is_empty()
+    }
+
+    /// Returns the number of operands given to this operation
+    #[inline]
+    pub fn num_operands(&self) -> usize {
+        self.operands.len()
+    }
+
+    /// Get a reference to the operand storage for this operation
+    #[inline]
+    pub fn operands(&self) -> &OpOperandStorage {
+        &self.operands
+    }
+
+    /// Get a mutable reference to the operand storage for this operation
+    #[inline]
+    pub fn operands_mut(&mut self) -> &mut OpOperandStorage {
+        &mut self.operands
+    }
+
+    /// Replace the current operands of this operation with the ones provided in `operands`.
+    pub fn set_operands(&mut self, operands: impl Iterator<Item = ValueRef>) {
+        self.operands.clear();
+        let context = self.context_rc();
+        let owner = self.as_operation_ref();
+        self.operands.extend(
+            operands
+                .into_iter()
+                .enumerate()
+                .map(|(index, value)| context.make_operand(value, owner.clone(), index as u8)),
+        );
+    }
+
+    /// Replace any uses of `from` with `to` within this operation
+    pub fn replaces_uses_of_with(&mut self, mut from: ValueRef, mut to: ValueRef) {
+        if ValueRef::ptr_eq(&from, &to) {
+            return;
+        }
+
+        for operand in self.operands.iter_mut() {
+            debug_assert!(operand.is_linked());
+            if ValueRef::ptr_eq(&from, &operand.borrow().value) {
+                // Remove use of `from` by `operand`
+                {
+                    let mut from_mut = from.borrow_mut();
+                    let from_uses = from_mut.uses_mut();
+                    let mut cursor = unsafe { from_uses.cursor_mut_from_ptr(operand.clone()) };
+                    cursor.remove();
+                }
+                // Add use of `to` by `operand`
+                operand.borrow_mut().value = to.clone();
+                to.borrow_mut().insert_use(operand.clone());
+            }
+        }
+    }
+
+    /// Replace all uses of this operation's results with `values`
+    ///
+    /// The number of results and the number of values in `values` must be exactly the same,
+    /// otherwise this function will panic.
+    pub fn replace_all_uses_with(&mut self, values: impl ExactSizeIterator<Item = ValueRef>) {
+        assert_eq!(self.num_results(), values.len());
+        for (result, replacement) in self.results.iter_mut().zip(values) {
+            if ValueRef::ptr_eq(&result.clone().upcast(), &replacement) {
+                continue;
+            }
+            result.borrow_mut().replace_all_uses_with(replacement);
+        }
+    }
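A sketch of using `replace_all_uses_with` to bypass a single-operand, single-result identity-like op (the shape of the op is an assumption of the example):

```rust
// Sketch only: redirect all uses of an identity op's result to its operand.
fn bypass_identity(op: &mut Operation) {
    assert_eq!(op.num_operands(), 1);
    assert_eq!(op.num_results(), 1);
    // Indexing mirrors the `successors[index]` usage elsewhere in this file
    let forwarded = op.operands()[0].borrow().value.clone();
    op.replace_all_uses_with(core::iter::once(forwarded));
}
```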
+
+    /// Replace uses of this operation's results with `values`, for each use which, when provided
+    /// to the given callback, returns true.
+    ///
+    /// The number of results and the number of values in `values` must be exactly the same,
+    /// otherwise this function will panic.
+    pub fn replace_uses_with_if<V, F>(&mut self, values: V, should_replace: F)
+    where
+        V: ExactSizeIterator<Item = ValueRef>,
+        F: Fn(&OpOperandImpl) -> bool,
+    {
+        assert_eq!(self.num_results(), values.len());
+        for (result, replacement) in self.results.iter_mut().zip(values) {
+            let mut result = result.clone().upcast();
+            if ValueRef::ptr_eq(&result, &replacement) {
+                continue;
+            }
+            result.borrow_mut().replace_uses_with_if(replacement, &should_replace);
+        }
+    }
+}
+
+/// Results
+impl Operation {
+    /// Returns true if this operation produces any results
+    #[inline]
+    pub fn has_results(&self) -> bool {
+        !self.results.is_empty()
+    }
+
+    /// Returns the number of results produced by this operation
+    #[inline]
+    pub fn num_results(&self) -> usize {
+        self.results.len()
+    }
+
+    /// Get a reference to the result set of this operation
+    #[inline]
+    pub fn results(&self) -> &OpResultStorage {
+        &self.results
+    }
+
+    /// Get a mutable reference to the result set of this operation
+    #[inline]
+    pub fn results_mut(&mut self) -> &mut OpResultStorage {
+        &mut self.results
+    }
+
+    /// Get a reference to the result at `index` among all results of this operation
+    #[inline]
+    pub fn get_result(&self, index: usize) -> &OpResultRef {
+        &self.results[index]
+    }
+
+    /// Returns true if the results of this operation are used
+    pub fn is_used(&self) -> bool {
+        self.results.iter().any(|result| result.borrow().is_used())
+    }
+
+    /// Returns true if the results of this operation have exactly one user
+    pub fn has_exactly_one_use(&self) -> bool {
+        let mut used_by = None;
+        for result in self.results.iter() {
+            let result = result.borrow();
+            if !result.is_used() {
+                continue;
+            }
+
+            for used in result.iter_uses() {
+                if used_by.as_ref().is_some_and(|user| !OperationRef::eq(user, &used.owner)) {
+                    // We found more than one user
+                    return false;
+                } else if used_by.is_none() {
+                    used_by = Some(used.owner.clone());
+                }
+            }
+        }
+
+        // If we reach here and `used_by` is set, we have exactly one user
+        used_by.is_some()
+    }
+
+    /// Returns true if the results of this operation are used outside of the given block
+    pub fn is_used_outside_of_block(&self, block: &BlockRef) -> bool {
+        self.results
+            .iter()
+            .any(|result| result.borrow().is_used_outside_of_block(block))
+    }
+
+    /// Returns true if this operation is unused, and has no side effects that prevent it being
+    /// erased
+    pub fn is_trivially_dead(&self) -> bool {
+        !self.is_used() && self.would_be_trivially_dead()
+    }
+
+    /// Returns true if this operation would be dead if unused, and has no side effects that would
+    /// prevent erasing it. This is equivalent to checking `is_trivially_dead` if `self` is unused.
+    ///
+    /// NOTE: Terminators and symbols are never considered to be trivially dead by this function.
+    pub fn would_be_trivially_dead(&self) -> bool {
+        if self.implements::<dyn Terminator>() || self.implements::<dyn Symbol>() {
+            false
+        } else {
+            self.would_be_trivially_dead_even_if_terminator()
+        }
+    }
+
+    /// Implementation of `would_be_trivially_dead` that also considers terminator operations as
+    /// dead if they have no side effects. This allows for marking region operations as trivially
+    /// dead without always being conservative about terminators.
+    pub fn would_be_trivially_dead_even_if_terminator(&self) -> bool {
+        // The set of operations to consider when checking for side effects
+        let mut effecting_ops = SmallVec::<[OperationRef; 1]>::from_iter([self.as_operation_ref()]);
+        while let Some(op) = effecting_ops.pop() {
+            let op = op.borrow();
+            // If the operation has recursive effects, push all of the nested operations on to the
+            // stack to consider.
+            let has_recursive_effects = op.implements::<dyn HasRecursiveMemoryEffects>();
+            if has_recursive_effects {
+                for region in op.regions() {
+                    for block in region.body() {
+                        let mut cursor = block.body().front();
+                        while let Some(op) = cursor.as_pointer() {
+                            effecting_ops.push(op);
+                            cursor.move_next();
+                        }
+                    }
+                }
+            }
+
+            // If the op has memory effects, try to characterize them to see if the op is
+            // trivially dead here.
+            if op.implements::<dyn MemoryWrite>() || op.implements::<dyn MemoryFree>() {
+                return false;
+            }
+
+            // If there were no effect interfaces, we treat this op as conservatively having
+            // effects
+            if !op.implements::<dyn MemoryRead>() && !op.implements::<dyn MemoryWrite>() {
+                return false;
+            }
+        }
+
+        // If we get here, none of the operations had effects that prevented marking this
+        // operation as dead.
+        true
+    }
+}
+
+/// Insertion
+impl Operation {
+    pub fn insert_at_start(&mut self, mut block: BlockRef) {
+        assert!(
+            self.block.is_none(),
+            "cannot insert operation that is already attached to another block"
+        );
+        {
+            let mut block = block.borrow_mut();
+            block.body_mut().push_front(unsafe { OperationRef::from_raw(self) });
+        }
+        self.block = Some(block);
+    }
+
+    pub fn insert_at_end(&mut self, mut block: BlockRef) {
+        assert!(
+            self.block.is_none(),
+            "cannot insert operation that is already attached to another block"
+        );
+        {
+            let mut block = block.borrow_mut();
+            block.body_mut().push_back(unsafe { OperationRef::from_raw(self) });
+        }
+        self.block = Some(block);
+    }
+
+    pub fn insert_before(&mut self, before: OperationRef) {
+        assert!(
+            self.block.is_none(),
+            "cannot insert operation that is already attached to another block"
+        );
+        let mut block =
+            before.borrow().parent().expect("'before' block is not attached to a block");
+        {
+            let mut block = block.borrow_mut();
+            let block_body = block.body_mut();
+            let mut cursor = unsafe { block_body.cursor_mut_from_ptr(before) };
+            cursor.insert_before(unsafe { OperationRef::from_raw(self) });
+        }
+        self.block = Some(block);
+    }
+
+    pub fn insert_after(&mut self, after: OperationRef) {
+        assert!(
+            self.block.is_none(),
+            "cannot insert operation that is already attached to another block"
+        );
+        let mut block = after.borrow().parent().expect("'after' block is not attached to a block");
+        {
+            let mut block = block.borrow_mut();
+            let block_body = block.body_mut();
+            let mut cursor = unsafe { block_body.cursor_mut_from_ptr(after) };
+            cursor.insert_after(unsafe { OperationRef::from_raw(self) });
+        }
+        self.block = Some(block);
+    }
+}
+
+/// Movement
+impl Operation {
+    /// Remove this operation (and its descendants) from its containing block, and delete them
+    #[inline]
+    pub fn erase(&mut self) {
+        // We don't delete entities currently, so for now this is just an alias for `remove`
+        self.remove()
+    }
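For example, hoisting an operation is just a `move_before` to a block program point; `move_before` (defined just below) first unlinks the op from its current block:

```rust
// Sketch only: hoist `op` to the start of `block`.
fn hoist_to_block_start(op: &mut Operation, block: BlockRef) {
    op.move_before(ProgramPoint::Block(block));
}
```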
+
+    /// Remove the operation from its parent block, but don't delete it.
+    pub fn remove(&mut self) {
+        let Some(mut parent) = self.block.take() else {
+            return;
+        };
+        let mut block = parent.borrow_mut();
+        let body = block.body_mut();
+        let mut cursor = unsafe { body.cursor_mut_from_ptr(OperationRef::from_raw(self)) };
+        cursor.remove();
+    }
+
+    /// Unlink this operation from its current block and insert it right before `ip`, which may
+    /// be in the same or another block in the same function.
+    pub fn move_before(&mut self, ip: ProgramPoint) {
+        self.remove();
+        match ip {
+            ProgramPoint::Op(other) => {
+                self.insert_before(other);
+            }
+            ProgramPoint::Block(block) => {
+                self.insert_at_start(block);
+            }
+        }
+    }
+
+    /// Unlink this operation from its current block and insert it right after `ip`, which may
+    /// be in the same or another block in the same function.
+    pub fn move_after(&mut self, ip: ProgramPoint) {
+        self.remove();
+        match ip {
+            ProgramPoint::Op(other) => {
+                self.insert_after(other);
+            }
+            ProgramPoint::Block(block) => {
+                self.insert_at_end(block);
+            }
+        }
+    }
+
+    /// This drops all operand uses from this operation, which is used to break cyclic dependencies
+    /// between references when they are to be deleted
+    pub fn drop_all_references(&mut self) {
+        self.operands.clear();
+
+        {
+            let mut region_cursor = self.regions.front_mut();
+            while let Some(mut region) = region_cursor.as_pointer() {
+                region.borrow_mut().drop_all_references();
+                region_cursor.move_next();
+            }
+        }
+
+        self.successors.clear();
+    }
+
+    /// This drops all uses of any values defined by this operation or its nested regions,
+    /// wherever they are located.
+    pub fn drop_all_defined_value_uses(&mut self) {
+        for result in self.results.iter_mut() {
+            let mut res = result.borrow_mut();
+            res.uses_mut().clear();
+        }
+
+        let mut regions = self.regions.front_mut();
+        while let Some(mut region) = regions.as_pointer() {
+            let mut region = region.borrow_mut();
+            let blocks = region.body_mut();
+            let mut cursor = blocks.front_mut();
+            while let Some(mut block) = cursor.as_pointer() {
+                block.borrow_mut().drop_all_defined_value_uses();
+                cursor.move_next();
+            }
+            regions.move_next();
+        }
+    }
+
+    /// Drop all uses of results of this operation
+    pub fn drop_all_uses(&mut self) {
+        for result in self.results.iter_mut() {
+            result.borrow_mut().uses_mut().clear();
+        }
+    }
+}
+
+/// Ordering
+impl Operation {
+    /// This value represents an invalid index ordering for an operation within its containing
+    /// block
+    const INVALID_ORDER: u32 = u32::MAX;
+    /// This value represents the stride to use when computing a new order for an operation
+    const ORDER_STRIDE: u32 = 5;
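The stride leaves gaps between order indices, so an op inserted between two neighbors can usually claim a midpoint without renumbering the block. A toy model of the arithmetic used by `update_order_if_necessary` below:

```rust
// Toy model only: orders are handed out as 5, 10, 15, ... (ORDER_STRIDE = 5).
// An op inserted between two neighbors takes the midpoint, and renumbering is
// only forced once a gap is exhausted.
fn midpoint_order(prev: u32, next: u32) -> Option<u32> {
    if prev + 1 == next {
        None // gap exhausted; the whole block must be renumbered
    } else {
        Some(prev + (next - prev) / 2)
    }
}
```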
+
+    /// Returns true if this operation is an ancestor of `other`.
+    ///
+    /// An operation is considered its own ancestor, use [Self::is_proper_ancestor_of] if you do
+    /// not want this behavior.
+    pub fn is_ancestor_of(&self, other: &OperationRef) -> bool {
+        let this = self.as_operation_ref();
+        OperationRef::ptr_eq(&this, other) || Self::is_a_proper_ancestor_of_b(&this, other)
+    }
+
+    /// Returns true if this operation is a proper ancestor of `other`
+    pub fn is_proper_ancestor_of(&self, other: &OperationRef) -> bool {
+        let this = self.as_operation_ref();
+        Self::is_a_proper_ancestor_of_b(&this, other)
+    }
+
+    /// Returns true if operation `a` is a proper ancestor of operation `b`
+    fn is_a_proper_ancestor_of_b(a: &OperationRef, b: &OperationRef) -> bool {
+        let mut next = b.borrow().parent_op();
+        while let Some(b) = next.take() {
+            if OperationRef::ptr_eq(a, &b) {
+                return true;
+            }
+            next = b.borrow().parent_op();
+        }
+        false
+    }
+
+    /// Given an operation `other` that is within the same parent block, return whether the
+    /// current operation is before it in the operation list.
+    ///
+    /// NOTE: This function has an average complexity of O(1), but worst case may take O(N) where
+    /// N is the number of operations within the parent block.
+    pub fn is_before_in_block(&self, other: &OperationRef) -> bool {
+        use core::sync::atomic::Ordering;
+
+        let block = self.block.clone().expect("operations without parent blocks have no order");
+        let other = other.borrow();
+        assert!(
+            other
+                .block
+                .as_ref()
+                .is_some_and(|other_block| BlockRef::ptr_eq(&block, other_block)),
+            "expected both operations to have the same parent block"
+        );
+
+        // If the order of the block is already invalid, directly recompute the parent block
+        if !block.borrow().is_op_order_valid() {
+            Self::recompute_block_order(block);
+        } else {
+            // Update the order of either operation if necessary.
+            self.update_order_if_necessary();
+            other.update_order_if_necessary();
+        }
+
+        self.order.load(Ordering::Relaxed) < other.order.load(Ordering::Relaxed)
+    }
+
+    /// Update the order index of this operation if necessary, potentially recomputing the order
+    /// of the parent block.
+    fn update_order_if_necessary(&self) {
+        use core::sync::atomic::Ordering;
+
+        assert!(self.block.is_some(), "expected valid parent");
+
+        // If the order is valid for this operation there is nothing to do.
+        let block = self.block.clone().unwrap();
+        if self.has_valid_order() || block.borrow().body().iter().count() == 1 {
+            return;
+        }
+
+        let this = self.as_operation_ref();
+        let prev = this.prev();
+        let next = this.next();
+        assert!(prev.is_some() || next.is_some(), "expected more than one operation in block");
+
+        // If the operation is at the end of the block.
+        if next.is_none() {
+            let prev = prev.unwrap();
+            let prev = prev.borrow();
+            let prev_order = prev.order.load(Ordering::Acquire);
+            if prev_order == Self::INVALID_ORDER {
+                return Self::recompute_block_order(block);
+            }
+
+            // Add the stride to the previous operation.
+            self.order.store(prev_order + Self::ORDER_STRIDE, Ordering::Release);
+            return;
+        }
+
+        // If this is the first operation, try to use the next operation to compute the
+        // ordering.
+        if prev.is_none() {
+            let next = next.unwrap();
+            let next = next.borrow();
+            let next_order = next.order.load(Ordering::Acquire);
+            match next_order {
+                Self::INVALID_ORDER | 0 => {
+                    return Self::recompute_block_order(block);
+                }
+                // If we can't use the stride, just take the middle value left. This is safe
+                // because we know there is at least one valid index to assign to.
+                order if order <= Self::ORDER_STRIDE => {
+                    self.order.store(order / 2, Ordering::Release);
+                }
+                _ => {
+                    self.order.store(Self::ORDER_STRIDE, Ordering::Release);
+                }
+            }
+            return;
+        }
+
+        // Otherwise, this operation is between two others. Place this operation in
+        // the middle of the previous and next if possible.
+        let prev = prev.unwrap().borrow().order.load(Ordering::Acquire);
+        let next = next.unwrap().borrow().order.load(Ordering::Acquire);
+        if prev == Self::INVALID_ORDER || next == Self::INVALID_ORDER {
+            return Self::recompute_block_order(block);
+        }
+
+        // Check to see if there is a valid order between the two.
+        if prev + 1 == next {
+            return Self::recompute_block_order(block);
+        }
+        self.order.store(prev + ((next - prev) / 2), Ordering::Release);
+    }
+
+    fn recompute_block_order(mut block: BlockRef) {
+        use core::sync::atomic::Ordering;
+
+        let mut block = block.borrow_mut();
+        let mut cursor = block.body().front();
+        let mut index = 0;
+        while let Some(op) = cursor.as_pointer() {
+            index += Self::ORDER_STRIDE;
+            cursor.move_next();
+            let ptr = OperationRef::as_ptr(&op);
+            unsafe {
+                let order_addr = core::ptr::addr_of!((*ptr).order);
+                (*order_addr).store(index, Ordering::Release);
+            }
+        }
+
+        block.mark_op_order_valid();
+    }
+
+    /// Returns `None` if this operation has an invalid ordering
+    #[inline]
+    pub(super) fn order(&self) -> Option<u32> {
+        use core::sync::atomic::Ordering;
+        match self.order.load(Ordering::Acquire) {
+            Self::INVALID_ORDER => None,
+            order => Some(order),
+        }
+    }
+
+    /// Returns true if this operation has a valid order
+    #[inline(always)]
+    pub(super) fn has_valid_order(&self) -> bool {
+        self.order().is_some()
+    }
+}
+
+impl crate::traits::Foldable for Operation {
+    fn fold(&self, results: &mut smallvec::SmallVec<[OpFoldResult; 1]>) -> FoldResult {
+        use crate::traits::Foldable;
+
+        if let Some(foldable) = self.as_trait::<dyn Foldable>() {
+            foldable.fold(results)
+        } else {
+            FoldResult::Failed
+        }
+    }
+
+    fn fold_with<'operands>(
+        &self,
+        operands: &[Option<Box<dyn AttributeValue>>],
+        results: &mut smallvec::SmallVec<[OpFoldResult; 1]>,
+    ) -> FoldResult {
+        use crate::traits::Foldable;
+
+        if let Some(foldable) = self.as_trait::<dyn Foldable>() {
+            foldable.fold_with(operands, results)
+        } else {
+            FoldResult::Failed
+        }
+    }
+}
diff --git a/hir2/src/ir/operation/builder.rs b/hir2/src/ir/operation/builder.rs
new file mode 100644
index 000000000..8856b06a1
--- /dev/null
+++ b/hir2/src/ir/operation/builder.rs
@@ -0,0 +1,248 @@
+use crate::{
+    traits::Terminator, AsCallableSymbolRef, AsSymbolRef, AttributeValue, BlockRef, Builder,
+    KeyedSuccessor, Op, OpBuilder, OperationRef, Region, Report, Spanned, SuccessorInfo, Type,
+    UnsafeIntrusiveEntityRef, ValueRef,
+};
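Before the builder itself, a rough sketch of the call pattern it supports; this is approximately what the `#[operation]` macro expansion does for a branch-like op (the op value is assumed to have been allocated and initialized per the rules documented on `new` below, and it is assumed that `u32` implements `AttributeValue`):

```rust
// Sketch only: finalize an already-allocated op with an attribute and a
// successor.
fn finish_branch_like<T: Op, B: ?Sized + Builder>(
    builder: &mut B,
    op: UnsafeIntrusiveEntityRef<T>,
    dest: BlockRef,
) {
    let mut b = OperationBuilder::new(builder, op);
    b.with_attr("flags", 0u32); // assumes u32: AttributeValue
    // One successor, forwarding no block arguments
    b.with_successor(dest, core::iter::empty::<ValueRef>());
}
```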
+pub struct OperationBuilder<'a, T, B: ?Sized = OpBuilder> { + builder: &'a mut B, + op: OperationRef, + _marker: core::marker::PhantomData, +} +impl<'a, T, B> OperationBuilder<'a, T, B> +where + T: Op, + B: ?Sized + Builder, +{ + /// Create a new [OperationBuilder] for `op` using the provided [Builder]. + /// + /// The [Operation] underlying `op` must have been initialized correctly: + /// + /// * Allocated via the same context as `builder` + /// * Initialized via [crate::Operation::uninit] + /// * All op traits implemented by `T` must have been registered with its [OperationName] + /// * All fields of `T` must have been initialized to actual or default values. This builder + /// will invoke verification at the end, and if `T` is not correctly initialized, it will + /// result in undefined behavior. + pub fn new(builder: &'a mut B, op: UnsafeIntrusiveEntityRef) -> Self { + let op = unsafe { UnsafeIntrusiveEntityRef::from_raw(op.borrow().as_operation()) }; + Self { + builder, + op, + _marker: core::marker::PhantomData, + } + } + + /// Set attribute `name` on this op to `value` + #[inline] + pub fn with_attr(&mut self, name: &'static str, value: A) + where + A: AttributeValue, + { + self.op.borrow_mut().set_attribute(name, Some(value)); + } + + /// Set symbol `attr_name` on this op to `symbol`. + /// + /// Symbol references are stored as attributes, and have similar semantics to operands, i.e. + /// they require tracking uses. + #[inline] + pub fn with_symbol(&mut self, attr_name: &'static str, symbol: impl AsSymbolRef) { + self.op.borrow_mut().set_symbol_attribute(attr_name, symbol); + } + + /// Like [with_symbol], but further constrains the range of valid input symbols to those which + /// are valid [CallableOpInterface] implementations. + #[inline] + pub fn with_callable_symbol( + &mut self, + attr_name: &'static str, + callable: impl AsCallableSymbolRef, + ) { + let callable = callable.as_callable_symbol_ref(); + self.op.borrow_mut().set_symbol_attribute(attr_name, callable); + } + + /// Add a new [Region] to this operation. + /// + /// NOTE: You must ensure this is called _after_ [Self::with_operands], and [Self::implements] + /// if the op implements the [traits::NoRegionArguments] trait. Otherwise, the inserted region + /// may not be valid for this op. 
+ pub fn create_region(&mut self) { + let mut region = Region::default(); + unsafe { + region.set_owner(Some(self.op.clone())); + } + let region = self.builder.context().alloc_tracked(region); + let mut op = self.op.borrow_mut(); + op.regions.push_back(region); + } + + pub fn with_successor( + &mut self, + dest: BlockRef, + arguments: impl IntoIterator, + ) { + let owner = self.op.clone(); + // Insert operand group for this successor + let mut op = self.op.borrow_mut(); + let operand_group = + op.operands.push_group(arguments.into_iter().enumerate().map(|(index, arg)| { + self.builder.context().make_operand(arg, owner.clone(), index as u8) + })); + // Record SuccessorInfo for this successor in the op + let succ_index = u8::try_from(op.successors.len()).expect("too many successors"); + let successor = self.builder.context().make_block_operand(dest.clone(), owner, succ_index); + op.successors.push_group([SuccessorInfo { + block: successor, + key: None, + operand_group: operand_group.try_into().expect("too many operand groups"), + }]); + } + + pub fn with_successors(&mut self, succs: I) + where + I: IntoIterator)>, + { + let owner = self.op.clone(); + let mut op = self.op.borrow_mut(); + let mut group = vec![]; + for (i, (block, args)) in succs.into_iter().enumerate() { + let block = self.builder.context().make_block_operand(block, owner.clone(), i as u8); + let operands = args + .into_iter() + .map(|value_ref| self.builder.context().make_operand(value_ref, owner.clone(), 0)); + let operand_group = op.operands.push_group(operands); + group.push(SuccessorInfo { + block, + key: None, + operand_group: operand_group.try_into().expect("too many operand groups"), + }); + } + op.successors.push_group(group); + } + + pub fn with_keyed_successors(&mut self, succs: I) + where + S: KeyedSuccessor, + I: IntoIterator, + { + let owner = self.op.clone(); + let mut op = self.op.borrow_mut(); + let mut group = vec![]; + for (i, successor) in succs.into_iter().enumerate() { + let (key, block, args) = successor.into_parts(); + let block = self.builder.context().make_block_operand(block, owner.clone(), i as u8); + let operands = args + .into_iter() + .map(|value_ref| self.builder.context().make_operand(value_ref, owner.clone(), 0)); + let operand_group = op.operands.push_group(operands); + let key = Box::new(key); + let key = unsafe { core::ptr::NonNull::new_unchecked(Box::into_raw(key)) }; + group.push(SuccessorInfo { + block, + key: Some(key.cast()), + operand_group: operand_group.try_into().expect("too many operand groups"), + }); + } + op.successors.push_group(group); + } + + /// Append operands to the set of operands given to this op so far. 
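+    ///
+    /// For example (with hypothetical values `a`, `b`, and `c`), calling
+    /// `with_operands([a, b])` followed by `with_operands([c])` yields the operand
+    /// list `[a, b, c]`.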
+ pub fn with_operands(&mut self, operands: I) + where + I: IntoIterator, + { + let owner = self.op.clone(); + let operands = operands.into_iter().enumerate().map(|(index, value)| { + self.builder.context().make_operand(value, owner.clone(), index as u8) + }); + let mut op = self.op.borrow_mut(); + op.operands.extend(operands); + } + + /// Append operands to the set of operands in operand group `group` + pub fn with_operands_in_group(&mut self, group: usize, operands: I) + where + I: IntoIterator, + { + let owner = self.op.clone(); + let operands = operands.into_iter().enumerate().map(|(index, value)| { + self.builder.context().make_operand(value, owner.clone(), index as u8) + }); + let mut op = self.op.borrow_mut(); + op.operands.extend_group(group, operands); + } + + /// Allocate `n` results for this op, of unknown type, to be filled in later + pub fn with_results(&mut self, n: usize) { + let span = self.op.borrow().span; + let owner = self.op.clone(); + let results = (0..n).map(|idx| { + self.builder + .context() + .make_result(span, Type::Unknown, owner.clone(), idx as u8) + }); + let mut op = self.op.borrow_mut(); + op.results.clear(); + op.results.extend(results); + } + + /// Consume this builder, verify the op, and return a handle to it, or an error if validation + /// failed. + pub fn build(mut self) -> Result, Report> { + let op = { + let mut op = self.op.borrow_mut(); + + // Infer result types and apply any associated validation + if let Some(interface) = op.as_trait_mut::() { + interface.infer_return_types(self.builder.context())?; + } + + // Verify things that would require negative trait impls + if !op.implements::() && op.has_successors() { + return Err(self + .builder + .context() + .session + .diagnostics + .diagnostic(miden_assembly::diagnostics::Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation has successors, but does not implement the 'Terminator' \ + trait", + ) + .with_help("operations with successors must implement the 'Terminator' trait") + .into_report()); + } + + unsafe { UnsafeIntrusiveEntityRef::from_raw(op.container().cast()) } + }; + + // Run op-specific verification + { + let op: super::EntityRef = op.borrow(); + //let op = op.borrow(); + op.verify(self.builder.context())?; + } + + // Insert op at current insertion point, if set + if self.builder.insertion_point().is_some() { + self.builder.insert(self.op); + } + + Ok(op) + } +} diff --git a/hir2/src/ir/operation/name.rs b/hir2/src/ir/operation/name.rs new file mode 100644 index 000000000..be9a298c7 --- /dev/null +++ b/hir2/src/ir/operation/name.rs @@ -0,0 +1,191 @@ +use alloc::rc::Rc; +use core::{ + any::TypeId, + fmt, + ptr::{DynMetadata, Pointee}, +}; + +use crate::{interner, traits::TraitInfo, DialectName}; + +/// The operation name, or mnemonic, that uniquely identifies an operation. +/// +/// The operation name consists of its dialect name, and the opcode name within the dialect. +/// +/// No two operation names can share the same fully-qualified operation name. 
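+///
+/// For example, an addition op in a hypothetical `hir` dialect would have the
+/// fully-qualified name `hir.add`, where `hir` is the dialect name and `add` is the
+/// opcode name within that dialect.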
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct OperationName(Rc); + +struct OperationInfo { + /// The dialect of this operation + dialect: DialectName, + /// The opcode name for this operation + name: interner::Symbol, + /// The type id of the concrete type that implements this operation + type_id: TypeId, + /// Details of the traits implemented by this operation, used to answer questions about what + /// traits are implemented, as well as reconstruct `&dyn Trait` references given a pointer to + /// the data of a specific operation instance. + traits: Box<[TraitInfo]>, +} + +impl OperationName { + pub fn new(dialect: DialectName, name: S, traits: T) -> Self + where + O: crate::Op, + S: Into, + T: IntoIterator, + { + let type_id = TypeId::of::(); + let mut traits = traits.into_iter().collect::>(); + traits.sort_by_key(|ti| *ti.type_id()); + let traits = traits.into_boxed_slice(); + let info = Rc::new(OperationInfo::new(dialect, name.into(), type_id, traits)); + Self(info) + } + + /// Returns the dialect name of this operation + pub fn dialect(&self) -> &DialectName { + &self.0.dialect + } + + /// Returns the namespace to which this operation name belongs (i.e. dialect name) + pub fn namespace(&self) -> interner::Symbol { + self.0.dialect.as_symbol() + } + + /// Returns the name/opcode of this operation + pub fn name(&self) -> interner::Symbol { + self.0.name + } + + /// Returns true if `T` is the concrete type that implements this operation + pub fn is(&self) -> bool { + TypeId::of::() == self.0.type_id + } + + /// Returns true if this operation implements `Trait` + pub fn implements(&self) -> bool + where + Trait: ?Sized + Pointee> + 'static, + { + let type_id = TypeId::of::(); + self.0.traits.binary_search_by(|ti| ti.type_id().cmp(&type_id)).is_ok() + } + + /// Returns true if this operation implements `trait`, where `trait` is the `TypeId` of a + /// `dyn Trait` type. 
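+    ///
+    /// For example, `name.implements_trait_id(&TypeId::of::<dyn Terminator>())` is
+    /// equivalent to `name.implements::<dyn Terminator>()`; both perform a binary
+    /// search of the same sorted trait list.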
+ pub fn implements_trait_id(&self, trait_id: &TypeId) -> bool { + self.0.traits.binary_search_by(|ti| ti.type_id().cmp(trait_id)).is_ok() + } + + #[inline] + pub(super) fn downcast_ref(&self, ptr: *const ()) -> Option<&T> { + if self.is::() { + Some(unsafe { self.downcast_ref_unchecked(ptr) }) + } else { + None + } + } + + #[inline(always)] + unsafe fn downcast_ref_unchecked(&self, ptr: *const ()) -> &T { + &*core::ptr::from_raw_parts(ptr.cast::(), ()) + } + + #[inline] + pub(super) fn downcast_mut(&mut self, ptr: *mut ()) -> Option<&mut T> { + if self.is::() { + Some(unsafe { self.downcast_mut_unchecked(ptr) }) + } else { + None + } + } + + #[inline(always)] + unsafe fn downcast_mut_unchecked(&mut self, ptr: *mut ()) -> &mut T { + &mut *core::ptr::from_raw_parts_mut(ptr.cast::(), ()) + } + + pub(super) fn upcast(&self, ptr: *const ()) -> Option<&Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &*core::ptr::from_raw_parts(ptr, metadata) }) + } + + pub(super) fn upcast_mut(&mut self, ptr: *mut ()) -> Option<&mut Trait> + where + Trait: ?Sized + Pointee> + 'static, + { + let metadata = self + .get::() + .map(|trait_impl| unsafe { trait_impl.metadata_unchecked::() })?; + Some(unsafe { &mut *core::ptr::from_raw_parts_mut(ptr, metadata) }) + } + + fn get(&self) -> Option<&TraitInfo> { + let type_id = TypeId::of::(); + self.0 + .traits + .binary_search_by(|ti| ti.type_id().cmp(&type_id)) + .ok() + .map(|index| &self.0.traits[index]) + } +} +impl fmt::Debug for OperationName { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} +impl fmt::Display for OperationName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}.{}", &self.namespace(), &self.name()) + } +} + +impl OperationInfo { + pub fn new( + dialect: DialectName, + name: interner::Symbol, + type_id: TypeId, + traits: Box<[TraitInfo]>, + ) -> Self { + Self { + dialect, + name, + type_id, + traits, + } + } +} + +impl Eq for OperationInfo {} +impl PartialEq for OperationInfo { + fn eq(&self, other: &Self) -> bool { + self.dialect == other.dialect && self.name == other.name && self.type_id == other.type_id + } +} +impl PartialOrd for OperationInfo { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for OperationInfo { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.dialect + .cmp(&other.dialect) + .then_with(|| self.name.cmp(&other.name)) + .then_with(|| self.type_id.cmp(&other.type_id)) + } +} +impl core::hash::Hash for OperationInfo { + fn hash(&self, state: &mut H) { + self.dialect.hash(state); + self.name.hash(state); + self.type_id.hash(state); + } +} diff --git a/hir2/src/ir/print.rs b/hir2/src/ir/print.rs new file mode 100644 index 000000000..c91e586b2 --- /dev/null +++ b/hir2/src/ir/print.rs @@ -0,0 +1,222 @@ +use core::fmt; + +use super::{Context, Operation}; +use crate::{ + formatter::PrettyPrint, + matchers::Matcher, + traits::{SingleBlock, SingleRegion}, + CallableOpInterface, EntityWithId, Value, +}; + +#[derive(Default)] +pub struct OpPrintingFlags; + +/// The `OpPrinter` trait is expected to be implemented by all [Op] impls as a prequisite. +/// +/// The actual implementation is typically generated as part of deriving [Op]. 
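+///
+/// As an illustration, a hypothetical `my.add` op with two operands and one result
+/// might render roughly as `%2 = my.add %0, %1 : u32;` under the generic format
+/// described below (the exact value-id syntax is illustrative).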
+pub trait OpPrinter { + fn print( + &self, + flags: &OpPrintingFlags, + context: &Context, + f: &mut fmt::Formatter, + ) -> fmt::Result; +} + +impl OpPrinter for Operation { + #[inline] + fn print( + &self, + _flags: &OpPrintingFlags, + _context: &Context, + f: &mut fmt::Formatter, + ) -> fmt::Result { + write!(f, "{}", self.render()) + } +} + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.render()) + } +} + +/// The generic format for printed operations is: +/// +/// <%result..> = .(%operand : , ..) : #.. { +/// // Region +/// ^(<%block_argument...>): +/// // Block +/// }; +/// +/// Special handling is provided for SingleRegionSingleBlock and CallableOpInterface ops: +/// +/// * SingleRegionSingleBlock ops with no operands will have the block header elided +/// * CallableOpInterface ops with no operands will be printed differently, using their +/// symbol and signature, as shown below: +/// +/// . @() -> #.. { +/// ... +/// } +impl PrettyPrint for Operation { + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + let is_single_region_single_block = + self.implements::() && self.implements::(); + let is_callable_op = self.implements::(); + let is_symbol = self.is_symbol(); + let no_operands = self.operands().is_empty(); + + let results = self.results(); + let mut doc = if !results.is_empty() { + let results = results.iter().enumerate().fold(Document::Empty, |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + display(result.borrow().id()) + } else { + doc + display(result.borrow().id()) + } + }); + results + const_text(" = ") + } else { + Document::Empty + }; + doc += display(self.name()) + const_text(" "); + let doc = if is_callable_op && is_symbol && no_operands { + let name = self.as_symbol().unwrap().name(); + let callable = self.as_trait::().unwrap(); + let signature = callable.signature(); + let mut doc = doc + display(signature.visibility) + text(format!(" @{}", name)); + if let Some(body) = callable.get_callable_region() { + let body = body.borrow(); + let entry = body.entry(); + doc += entry.arguments().iter().enumerate().fold( + const_text("("), + |doc, (i, param)| { + let param = param.borrow(); + let doc = if i > 0 { doc + const_text(", ") } else { doc }; + doc + display(param.id()) + const_text(": ") + display(param.ty()) + }, + ) + const_text(")"); + if !signature.results.is_empty() { + doc += signature.results().iter().enumerate().fold( + const_text(" -> "), + |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + display(&result.ty) + } else { + doc + display(&result.ty) + } + }, + ); + } + } else { + doc += signature.render() + } + doc + } else { + let mut is_constant = false; + let doc = if let Some(value) = crate::matchers::constant().matches(self) { + is_constant = true; + doc + value.render() + } else { + let operands = self.operands(); + if !operands.is_empty() { + operands.iter().enumerate().fold(doc, |doc, (i, operand)| { + let operand = operand.borrow(); + let value = operand.value(); + if i > 0 { + doc + const_text(", ") + display(value.id()) + } else { + doc + display(value.id()) + } + }) + } else { + doc + } + }; + let doc = if !results.is_empty() { + let results = + results.iter().enumerate().fold(Document::Empty, |doc, (i, result)| { + if i > 0 { + doc + const_text(", ") + text(format!("{}", result.borrow().ty())) + } else { + doc + text(format!("{}", result.borrow().ty())) + } + }); + doc + const_text(" : ") + results + } else { + doc + 
}; + + if is_constant { + doc + } else { + self.attrs.iter().fold(doc, |doc, attr| { + let doc = doc + const_text(" "); + if let Some(value) = attr.value() { + doc + const_text("#[") + + display(attr.name) + + const_text(" = ") + + value.render() + + const_text("]") + } else { + doc + text(format!("#[{}]", &attr.name)) + } + }) + } + }; + + if self.has_regions() { + self.regions.iter().fold(doc, |doc, region| { + let blocks = region.body().iter().enumerate().fold( + Document::Empty, + |mut doc, (block_index, block)| { + if block_index > 0 { + doc += nl(); + } + let ops = block.body().iter().enumerate().fold( + Document::Empty, + |mut doc, (i, op)| { + if i > 0 { + doc += nl(); + } + doc + op.render() + }, + ); + if is_single_region_single_block && no_operands { + doc + indent(4, nl() + ops) + } else { + let block_args = block.arguments().iter().enumerate().fold( + Document::Empty, + |doc, (i, arg)| { + if i > 0 { + doc + const_text(", ") + arg.borrow().render() + } else { + doc + arg.borrow().render() + } + }, + ); + let block_args = if block_args.is_empty() { + block_args + } else { + const_text("(") + block_args + const_text(")") + }; + doc + indent( + 4, + text(format!("^{}", block.id())) + + block_args + + const_text(":") + + nl() + + ops, + ) + } + }, + ); + doc + const_text(" {") + nl() + blocks + nl() + const_text("}") + }) + const_text(";") + } else { + doc + const_text(";") + } + } +} diff --git a/hir2/src/ir/region.rs b/hir2/src/ir/region.rs new file mode 100644 index 000000000..289600d38 --- /dev/null +++ b/hir2/src/ir/region.rs @@ -0,0 +1,473 @@ +mod branch_point; +mod interfaces; +mod invocation_bounds; +mod kind; +mod successor; +mod transforms; + +use smallvec::SmallVec; + +pub use self::{ + branch_point::RegionBranchPoint, + interfaces::{ + RegionBranchOpInterface, RegionBranchTerminatorOpInterface, RegionKindInterface, + RegionSuccessorIter, + }, + invocation_bounds::InvocationBounds, + kind::RegionKind, + successor::{RegionSuccessor, RegionSuccessorInfo, RegionSuccessorMut}, + transforms::RegionTransformFailed, +}; +use super::*; +use crate::RegionSimplificationLevel; + +pub type RegionRef = UnsafeIntrusiveEntityRef; +/// An intrusive, doubly-linked list of [Region]s +pub type RegionList = EntityList; +/// A cursor in a [RegionList] +pub type RegionCursor<'a> = EntityCursor<'a, Region>; +/// A mutable cursor in a [RegionList] +pub type RegionCursorMut<'a> = EntityCursorMut<'a, Region>; + +/// A region is a container for [Block], in one of two forms: +/// +/// * Graph-like, in which the region consists of a single block, and the order of operations in +/// that block does not dictate any specific control flow semantics. It is up to the containing +/// operation to define. +/// * SSA-form, in which the region consists of one or more blocks that must obey the usual rules +/// of SSA dominance, and where operations in a block reflect the order in which those operations +/// are to be executed. Values defined by an operation must dominate any uses of those values in +/// the region. +/// +/// The first block in a region is the _entry_ block, and its argument list corresponds to the +/// arguments expected by the region itself. +/// +/// A region is only valid when it is attached to an [Operation], whereas the inverse is not true, +/// i.e. an operation without a parent region is a top-level operation, e.g. `Module`. +#[derive(Default)] +pub struct Region { + /// The operation this region is attached to. 
+ /// + /// If `link.is_linked() == true`, this will always be set to a valid pointer + owner: Option, + /// The list of [Block]s that comprise this region + body: BlockList, +} + +impl Entity for Region {} + +impl EntityWithParent for Region { + type Parent = Operation; + + fn on_inserted_into_parent( + mut this: UnsafeIntrusiveEntityRef, + parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().owner = Some(parent); + } + + fn on_removed_from_parent( + mut this: UnsafeIntrusiveEntityRef, + _parent: UnsafeIntrusiveEntityRef, + ) { + this.borrow_mut().owner = None; + } + + fn on_transfered_to_new_parent( + _from: UnsafeIntrusiveEntityRef, + to: UnsafeIntrusiveEntityRef, + transferred: impl IntoIterator>, + ) { + for mut transferred_region in transferred { + transferred_region.borrow_mut().owner = Some(to.clone()); + } + } +} + +impl cfg::Graph for Region { + type ChildEdgeIter = block::BlockSuccessorEdgesIter; + type ChildIter = block::BlockSuccessorIter; + type Edge = BlockOperandRef; + type Node = BlockRef; + + fn is_empty(&self) -> bool { + self.body.is_empty() + } + + fn size(&self) -> usize { + self.body.len() + } + + fn children(parent: Self::Node) -> Self::ChildIter { + block::BlockSuccessorIter::new(parent) + } + + fn children_edges(parent: Self::Node) -> Self::ChildEdgeIter { + block::BlockSuccessorEdgesIter::new(parent) + } + + fn edge_dest(edge: Self::Edge) -> Self::Node { + edge.borrow().block.clone() + } + + fn entry_node(&self) -> Self::Node { + self.body.front().as_pointer().expect("empty region") + } +} + +impl<'a> cfg::InvertibleGraph for &'a Region { + type Inverse = cfg::Inverse<&'a Region>; + type InvertibleChildEdgeIter = block::BlockPredecessorEdgesIter; + type InvertibleChildIter = block::BlockPredecessorIter; + + fn inverse(self) -> Self::Inverse { + cfg::Inverse::new(self) + } + + fn inverse_children(parent: Self::Node) -> Self::InvertibleChildIter { + block::BlockPredecessorIter::new(parent) + } + + fn inverse_children_edges(parent: Self::Node) -> Self::InvertibleChildEdgeIter { + block::BlockPredecessorEdgesIter::new(parent) + } +} + +/// Blocks +impl Region { + /// Returns true if this region is empty (has no blocks) + pub fn is_empty(&self) -> bool { + self.body.is_empty() + } + + /// Get a handle to the entry block for this region + pub fn entry(&self) -> EntityRef<'_, Block> { + self.body.front().into_borrow().unwrap() + } + + /// Get a mutable handle to the entry block for this region + pub fn entry_mut(&mut self) -> EntityMut<'_, Block> { + self.body.front_mut().into_borrow_mut().unwrap() + } + + /// Get the [BlockRef] of the entry block of this region, if it has one + #[inline] + pub fn entry_block_ref(&self) -> Option { + self.body.front().as_pointer() + } + + /// Get the list of blocks comprising the body of this region + pub fn body(&self) -> &BlockList { + &self.body + } + + /// Get a mutable reference to the list of blocks comprising the body of this region + pub fn body_mut(&mut self) -> &mut BlockList { + &mut self.body + } +} + +/// Metadata +impl Region { + #[inline] + pub fn as_region_ref(&self) -> RegionRef { + unsafe { RegionRef::from_raw(self) } + } + + /// Returns true if this region is an ancestor of `other`, i.e. it contains it. + /// + /// NOTE: This returns true if `self == other`, see [Self::is_proper_ancestor] if you do not + /// want this behavior. 
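+    ///
+    /// For example, given regions nested as `a { b { c } }` (illustrative), both
+    /// `a.is_ancestor(&c)` and `a.is_ancestor(&a)` return true, while
+    /// `a.is_proper_ancestor(&a)` returns false.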
+ pub fn is_ancestor(&self, other: &RegionRef) -> bool { + let this = self.as_region_ref(); + &this == other || Self::is_proper_ancestor_of(&this, other) + } + + /// Returns true if this region is a proper ancestor of `other`, i.e. `other` is contained by it + /// + /// NOTE: This returns false if `self == other`, see [Self::is_ancestor] if you do not want this + /// behavior. + pub fn is_proper_ancestor(&self, other: &RegionRef) -> bool { + let this = self.as_region_ref(); + Self::is_proper_ancestor_of(&this, other) + } + + fn is_proper_ancestor_of(this: &RegionRef, other: &RegionRef) -> bool { + if this == other { + return false; + } + + let mut parent = other.borrow().parent_region(); + while let Some(parent_region) = parent.take() { + if this == &parent_region { + return true; + } + parent = parent_region.borrow().parent_region(); + } + + false + } + + /// Returns true if this region may be a graph region without SSA dominance + pub fn may_be_graph_region(&self) -> bool { + if let Some(owner) = self.owner.as_ref() { + owner + .borrow() + .as_trait::() + .is_some_and(|rki| rki.has_graph_regions()) + } else { + true + } + } + + /// Returns true if this region has only one block + pub fn has_one_block(&self) -> bool { + !self.body.is_empty() + && BlockRef::ptr_eq( + &self.body.front().as_pointer().unwrap(), + &self.body.back().as_pointer().unwrap(), + ) + } + + /// Get the defining [Operation] for this region, if the region is attached to one. + pub fn parent(&self) -> Option { + self.owner.clone() + } + + /// Get the region which contains the parent operation of this region, if there is one. + pub fn parent_region(&self) -> Option { + self.owner.as_ref().and_then(|op| op.borrow().parent_region()) + } + + /// Set the owner of this region. + /// + /// Returns the previous owner. + /// + /// # Safety + /// + /// It is dangerous to set this field unless doing so as part of allocating the [Region] or + /// moving the [Region] from one op to another. If it is set to a different entity than actually + /// owns the region, it will result in undefined behavior or panics when we attempt to access + /// the owner via the region. + /// + /// You must ensure that the owner given _actually_ owns the region. Similarly, if you are + /// unsetting the owner, you must ensure that no entity _thinks_ it owns this region. + pub unsafe fn set_owner(&mut self, owner: Option) -> Option { + match owner { + None => self.owner.take(), + Some(owner) => self.owner.replace(owner), + } + } +} + +/// Mutation +impl Region { + /// Push `block` to the start of this region + #[inline] + pub fn push_front(&mut self, block: BlockRef) { + self.body.push_front(block); + } + + /// Push `block` to the end of this region + #[inline] + pub fn push_back(&mut self, block: BlockRef) { + self.body.push_back(block); + } + + /// Drop any references to blocks in this region - this is used to break cycles when cleaning + /// up regions. + pub fn drop_all_references(&mut self) { + let mut cursor = self.body_mut().front_mut(); + while let Some(mut op) = cursor.as_pointer() { + op.borrow_mut().drop_all_references(); + cursor.move_next(); + } + } +} + +/// Values +impl Region { + /// Check if every value in `values` is defined above this region, i.e. they are defined in a + /// region which is a proper ancestor of `self`. 
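+    ///
+    /// For example, a value defined in the entry block of a function-like op is
+    /// "defined above" any region nested within that function's body.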
+ pub fn values_are_defined_above(&self, values: &[ValueRef]) -> bool { + let this = self.as_region_ref(); + for value in values { + if !value + .borrow() + .parent_region() + .is_some_and(|value_region| Self::is_proper_ancestor_of(&value_region, &this)) + { + return false; + } + } + true + } + + /// Replace all uses of `value` with `replacement`, within this region. + pub fn replace_all_uses_in_region_with(&mut self, _value: ValueRef, _replacement: ValueRef) { + todo!("RegionUtils.h") + } + + /// Visit each use of a value in this region (and its descendants), where that value was defined + /// in an ancestor of `limit`. + pub fn visit_used_values_defined_above(&self, _limit: &RegionRef, _callback: F) + where + F: FnMut(OpOperand), + { + todo!("RegionUtils.h") + } + + /// Visit each use of a value in any of the provided regions (or their descendants), where that + /// value was defined in an ancestor of that region. + pub fn visit_used_values_defined_above_any(_regions: &[RegionRef], _callback: F) + where + F: FnMut(OpOperand), + { + todo!("RegionUtils.h") + } + + /// Return a vector of values used in this region (and its descendants), and defined in an + /// ancestor of the `limit` region. + pub fn get_used_values_defined_above(&self, _limit: &RegionRef) -> SmallVec<[ValueRef; 1]> { + todo!("RegionUtils.h") + } + + /// Return a vector of values used in any of the provided regions, but defined in an ancestor. + pub fn get_used_values_defined_above_any(_regions: &[RegionRef]) -> SmallVec<[ValueRef; 1]> { + todo!("RegionUtils.h") + } + + /// Make this region isolated from above. + /// + /// * Capture the values that are defined above the region and used within it. + /// * Append block arguments to the entry block that represent each captured value. + /// * Replace all uses of the captured values within the region, with the new block arguments + /// * `clone_into_region` is called with the defining op of a captured value. If it returns + /// true, it indicates that the op needs to be cloned into the region. As a result, the + /// operands of that operation become part of the captured value set (unless the operations + /// that define the operand values themselves are to be cloned). The cloned operations are + /// added to the entry block of the region. + /// + /// Returns the set of captured values. + pub fn make_isolated_from_above( + &mut self, + _rewriter: &mut R, + _clone_into_region: F, + ) -> SmallVec<[ValueRef; 1]> + where + R: crate::Rewriter, + F: Fn(&Operation) -> bool, + { + todo!("RegionUtils.h") + } +} + +/// Queries +impl Region { + pub fn find_common_ancestor(ops: &[OperationRef]) -> Option { + use bitvec::prelude::*; + + match ops.len() { + 0 => None, + 1 => unsafe { ops.get_unchecked(0) }.borrow().parent_region(), + num_ops => { + let (first, rest) = unsafe { ops.split_first().unwrap_unchecked() }; + let mut region = first.borrow().parent_region(); + let mut remaining_ops = bitvec![1; num_ops - 1]; + while let Some(r) = region.take() { + while let Some(index) = remaining_ops.first_one() { + // Is this op contained in `region`? + if r.borrow().find_ancestor_op(&rest[index]).is_some() { + unsafe { + remaining_ops.set_unchecked(index, false); + } + } + } + if remaining_ops.not_any() { + break; + } + region = r.borrow().parent_region(); + } + region + } + } + } + + /// Returns `block` if `block` lies in this region, or otherwise finds the ancestor of `block` + /// that lies in this region. + /// + /// Returns `None` if the latter fails. 
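+    ///
+    /// For example, if `block` belongs to a region of some op whose parent block lies
+    /// directly in this region, that parent block is returned.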
+ pub fn find_ancestor_block(&self, block: &BlockRef) -> Option { + let this = self.as_region_ref(); + let mut current = Some(block.clone()); + while let Some(current_block) = current.take() { + let parent = current_block.borrow().parent()?; + if parent == this { + return Some(current_block); + } + current = + parent.borrow().owner.as_ref().and_then(|parent_op| parent_op.borrow().parent()); + } + current + } + + /// Returns `op` if `op` lies in this region, or otherwise finds the ancestor of `op` that lies + /// in this region. + /// + /// Returns `None` if the latter fails. + pub fn find_ancestor_op(&self, op: &OperationRef) -> Option { + let this = self.as_region_ref(); + let mut current = Some(op.clone()); + while let Some(current_op) = current.take() { + let parent = current_op.borrow().parent_region()?; + if parent == this { + return Some(current_op); + } + current = parent.borrow().parent(); + } + current + } +} + +/// Transforms +impl Region { + /// Run a set of structural simplifications over the regions in `regions`. + /// + /// This includes transformations like unreachable block elimination, dead argument elimination, + /// as well as some other DCE. + /// + /// This function returns `Ok` if any of the regions were simplified, `Err` otherwise. + /// + /// The provided rewriter is used to notify callers of operation and block deletion. + /// + /// The provided [RegionSimplificationLevel] will be used to determine whether to apply more + /// aggressive simplifications, namely block merging. Note that when block merging is enabled, + /// this can lead to merged blocks with extra arguments. + pub fn simplify_all( + regions: &[RegionRef], + rewriter: &mut dyn crate::Rewriter, + simplification_level: RegionSimplificationLevel, + ) -> Result<(), RegionTransformFailed> { + let merge_blocks = matches!(simplification_level, RegionSimplificationLevel::Aggressive); + + let eliminated_blocks = Self::erase_unreachable_blocks(regions, rewriter).is_ok(); + let eliminated_ops_or_args = Self::dead_code_elimination(regions, rewriter).is_ok(); + + let mut merged_identical_blocks = false; + let mut dropped_redundant_arguments = false; + if merge_blocks { + merged_identical_blocks = Self::merge_identical_blocks(regions, rewriter).is_ok(); + dropped_redundant_arguments = Self::drop_redundant_arguments(regions, rewriter).is_ok(); + } + + if eliminated_blocks + || eliminated_ops_or_args + || merged_identical_blocks + || dropped_redundant_arguments + { + Ok(()) + } else { + Err(RegionTransformFailed) + } + } +} diff --git a/hir2/src/ir/region/branch_point.rs b/hir2/src/ir/region/branch_point.rs new file mode 100644 index 000000000..f3419f4e4 --- /dev/null +++ b/hir2/src/ir/region/branch_point.rs @@ -0,0 +1,50 @@ +use core::fmt; + +use super::*; + +/// This type represents a point being branched from in the methods of `RegionBranchOpInterface`. +/// +/// One can branch from one of two different kinds of places: +/// +/// * The parent operation (i.e. the op implementing `RegionBranchOpInterface`). +/// * A region within the parent operation (where the parent implements `RegionBranchOpInterface`). 
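+/// For example (sketch): when control first enters a loop-like op such as the
+/// `scf.for` shown in the [RegionBranchOpInterface] docs, the branch point is
+/// [RegionBranchPoint::Parent]; when the loop body yields back to the op, the branch
+/// point is `RegionBranchPoint::Child(body_region)` (hypothetical binding).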
+#[derive(Clone, PartialEq, Eq)] +pub enum RegionBranchPoint { + /// A branch from the current operation to one of its regions + Parent, + /// A branch from the given region, within a parent `RegionBranchOpInterface` op + Child(RegionRef), +} +impl fmt::Debug for RegionBranchPoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Parent => f.write_str("Parent"), + Self::Child(ref region) => { + f.debug_tuple("Child").field(&format_args!("{:p}", region)).finish() + } + } + } +} +impl RegionBranchPoint { + /// Returns true if branching from the parent op. + #[inline] + pub fn is_parent(&self) -> bool { + matches!(self, Self::Parent) + } + + /// Returns the region if branching from a region, otherwise `None`. + pub fn region(&self) -> Option { + match self { + Self::Child(ref region) => Some(region.clone()), + Self::Parent => None, + } + } +} +impl<'a> From> for RegionBranchPoint { + fn from(succ: RegionSuccessor<'a>) -> Self { + match succ.into_successor() { + None => Self::Parent, + Some(succ) => Self::Child(succ), + } + } +} diff --git a/hir2/src/ir/region/interfaces.rs b/hir2/src/ir/region/interfaces.rs new file mode 100644 index 000000000..009c3202b --- /dev/null +++ b/hir2/src/ir/region/interfaces.rs @@ -0,0 +1,247 @@ +use super::*; +use crate::{ + attributes::AttributeValue, traits::Terminator, Op, SuccessorOperandRange, + SuccessorOperandRangeMut, Type, +}; + +/// An op interface that indicates what types of regions it holds +pub trait RegionKindInterface { + /// Get the [RegionKind] for this operation + fn kind(&self) -> RegionKind; + /// Returns true if the kind of this operation's regions requires SSA dominance + #[inline] + fn has_ssa_dominance(&self) -> bool { + matches!(self.kind(), RegionKind::SSA) + } + #[inline] + fn has_graph_regions(&self) -> bool { + matches!(self.kind(), RegionKind::Graph) + } +} + +// TODO(pauls): Implement verifier +/// This interface provides information for region operations that exhibit branching behavior +/// between held regions. I.e., this interface allows for expressing control flow information for +/// region holding operations. +/// +/// This interface is meant to model well-defined cases of control-flow and value propagation, +/// where what occurs along control-flow edges is assumed to be side-effect free. +/// +/// A "region branch point" indicates a point from which a branch originates. It can indicate either +/// a region of this op or [RegionBranchPoint::Parent]. In the latter case, the branch originates +/// from outside of the op, i.e., when first executing this op. +/// +/// A "region successor" indicates the target of a branch. It can indicate either a region of this +/// op or this op. In the former case, the region successor is a region pointer and a range of block +/// arguments to which the "successor operands" are forwarded to. In the latter case, the control +/// flow leaves this op and the region successor is a range of results of this op to which the +/// successor operands are forwarded to. +/// +/// By default, successor operands and successor block arguments/successor results must have the +/// same type. `areTypesCompatible` can be implemented to allow non-equal types. +/// +/// ## Example +/// +/// ```hir,ignore +/// %r = scf.for %iv = %lb to %ub step %step iter_args(%a = %b) +/// -> tensor<5xf32> { +/// ... +/// scf.yield %c : tensor<5xf32> +/// } +/// ``` +/// +/// `scf.for` has one region. The region has two region successors: the region itself and the +/// `scf.for` op. 
`%b` is an entry successor operand. `%c` is a successor operand. `%a` is a +/// successor block argument. `%r` is a successor result. +pub trait RegionBranchOpInterface: Op { + /// Returns the operands of this operation that are forwarded to the region successor's block + /// arguments or this operation's results when branching to `point`. `point` is guaranteed to + /// be among the successors that are returned by `get_entry_succcessor_regions` or + /// `get_successor_regions(parent_op())`. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function returns the operand `%b` + /// of the `scf.for` op, regardless of the value of `point`, i.e. this op always forwards the + /// same operands, regardless of whether the loop has 0 or more iterations. + #[inline] + #[allow(unused_variables)] + fn get_entry_successor_operands(&self, point: RegionBranchPoint) -> SuccessorOperandRange<'_> { + crate::SuccessorOperandRange::empty() + } + /// Returns the potential region successors when first executing the op. + /// + /// Unlike [get_successor_regions], this method also passes along the constant operands of this + /// op. Based on these, the implementation may filter out certain successors. By default, it + /// simply dispatches to `get_successor_regions`. `operands` contains an entry for every operand + /// of this op, with `None` representing if the operand is non-constant. + /// + /// NOTE: The control flow does not necessarily have to enter any region of this op. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function may return two region + /// successors: the single region of the `scf.for` op and the `scf.for` operation (that + /// implements this interface). If `%lb`, `%ub`, `%step` are constants and it can be determined + /// the loop does not have any iterations, this function may choose to return only this + /// operation. Similarly, if it can be determined that the loop has at least one iteration, this + /// function may choose to return only the region of the loop. + #[inline] + #[allow(unused_variables)] + fn get_entry_successor_regions( + &self, + operands: &[Option>], + ) -> RegionSuccessorIter<'_> { + self.get_successor_regions(RegionBranchPoint::Parent) + } + /// Returns the potential region successors when branching from `point`. + /// + /// These are the regions that may be selected during the flow of control. + /// + /// When `point` is [RegionBranchPoint::Parent], this function returns the region successors + /// when entering the operation. Otherwise, this method returns the successor regions when + /// branching from the region indicated by `point`. + /// + /// ## Example + /// + /// In the example in the top-level docs of this trait, this function returns the region of the + /// `scf.for` and this operation for either region branch point (`parent` and the region of the + /// `scf.for`). An implementation may choose to filter out region successors when it is + /// statically known (e.g., by examining the operands of this op) that those successors are not + /// branched to. + fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_>; + /// Returns a set of invocation bounds, representing the minimum and maximum number of times + /// this operation will invoke each attached region (assuming the regions yield normally, i.e. + /// do not abort or invoke an infinite loop). The minimum number of invocations is at least 0. 
+ /// If the maximum number of invocations cannot be statically determined, then it will be set to + /// [InvocationBounds::unknown]. + /// + /// This function also passes along the constant operands of this op. `operands` contains an + /// entry for every operand of this op, with `None` representing if the operand is non-constant. + /// + /// This function may be called speculatively on operations where the provided operands are not + /// necessarily the same as the operation's current operands. This may occur in analyses that + /// wish to determine "what would be the region invocations if these were the operands?" + #[inline] + #[allow(unused_variables)] + fn get_region_invocation_bounds( + &self, + operands: &[Option>], + ) -> SmallVec<[InvocationBounds; 1]> { + use smallvec::smallvec; + + smallvec![InvocationBounds::Unknown; self.num_regions()] + } + /// This function is called to compare types along control-flow edges. + /// + /// By default, the types are check for exact equality. + #[inline] + fn are_types_compatible(&self, lhs: &Type, rhs: &Type) -> bool { + lhs == rhs + } + /// Returns `true` if control flow originating from the region at `index` may eventually branch + /// back to the same region, either from itself, or after passing through other regions first. + fn is_repetitive_region(&self, index: usize) -> bool; + /// Returns `true` if there is a loop in the region branching graph. + /// + /// Only reachable regions (starting from the entry region) are considered. + fn has_loop(&self) -> bool; +} + +// TODO(pauls): Implement verifier (should have no results and no successors) +/// This interface provides information for branching terminator operations in the presence of a +/// parent [RegionBranchOpInterface] implementation. It specifies which operands are passed to which +/// successor region. +pub trait RegionBranchTerminatorOpInterface: Op + Terminator { + /// Get a range of operands corresponding to values that are semantically "returned" by passing + /// them to the region successor indicated by `point`. + fn get_successor_operands(&self, point: RegionBranchPoint) -> SuccessorOperandRange<'_>; + /// Get a mutable range of operands corresponding to values that are semantically "returned" by + /// passing them to the region successor indicated by `point`. + fn get_mutable_successor_operands( + &mut self, + point: RegionBranchPoint, + ) -> SuccessorOperandRangeMut<'_>; + /// Returns the potential region successors that are branched to after this terminator based on + /// the given constant operands. + /// + /// This method also passes along the constant operands of this op. `operands` contains an entry + /// for every operand of this op, with `None` representing non-constant values. + /// + /// The default implementation simply dispatches to the parent `RegionBranchOpInterface`'s + /// `get_successor_regions` implementation. 
+    #[allow(unused_variables)]
+    fn get_successor_regions(
+        &self,
+        operands: &[Option<Box<dyn AttributeValue>>],
+    ) -> SmallVec<[RegionSuccessorInfo; 2]> {
+        let parent_region =
+            self.parent_region().expect("expected operation to have a parent region");
+        let parent_op =
+            parent_region.borrow().parent().expect("expected operation to have a parent op");
+        parent_op
+            .borrow()
+            .as_trait::<dyn RegionBranchOpInterface>()
+            .expect("invalid region terminator parent: must implement RegionBranchOpInterface")
+            .get_successor_regions(RegionBranchPoint::Child(parent_region))
+            .into_successor_infos()
+    }
+}
+
+pub struct RegionSuccessorIter<'a> {
+    op: &'a Operation,
+    successors: SmallVec<[RegionSuccessorInfo; 2]>,
+    index: usize,
+}
+impl<'a> RegionSuccessorIter<'a> {
+    pub fn new(
+        op: &'a Operation,
+        successors: impl IntoIterator<Item = RegionSuccessorInfo>,
+    ) -> Self {
+        Self {
+            op,
+            successors: SmallVec::from_iter(successors),
+            index: 0,
+        }
+    }
+
+    pub fn empty(op: &'a Operation) -> Self {
+        Self {
+            op,
+            successors: Default::default(),
+            index: 0,
+        }
+    }
+
+    pub fn get(&self, index: usize) -> Option<RegionSuccessor<'a>> {
+        self.successors.get(index).map(|info| RegionSuccessor {
+            dest: info.successor.clone(),
+            arguments: self.op.operands().group(info.operand_group as usize),
+        })
+    }
+
+    pub fn into_successor_infos(self) -> SmallVec<[RegionSuccessorInfo; 2]> {
+        self.successors
+    }
+}
+impl core::iter::FusedIterator for RegionSuccessorIter<'_> {}
+impl<'a> ExactSizeIterator for RegionSuccessorIter<'a> {
+    fn len(&self) -> usize {
+        // Account for elements already yielded
+        self.successors.len() - self.index
+    }
+}
+impl<'a> Iterator for RegionSuccessorIter<'a> {
+    type Item = RegionSuccessor<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index >= self.successors.len() {
+            return None;
+        }
+
+        let info = &self.successors[self.index];
+        // Advance the iterator so subsequent calls yield the next successor
+        self.index += 1;
+        Some(RegionSuccessor {
+            dest: info.successor.clone(),
+            arguments: self.op.operands().group(info.operand_group as usize),
+        })
+    }
+}
diff --git a/hir2/src/ir/region/invocation_bounds.rs b/hir2/src/ir/region/invocation_bounds.rs
new file mode 100644
index 000000000..574471747
--- /dev/null
+++ b/hir2/src/ir/region/invocation_bounds.rs
@@ -0,0 +1,132 @@
+use core::ops::{Bound, RangeBounds};
+
+/// This type represents upper and lower bounds on the number of times a region of a
+/// `RegionBranchOpInterface` op can be invoked. The lower bound is at least zero, but the upper
+/// bound may not be known.
+#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
+pub enum InvocationBounds {
+    /// The region can be invoked an unknown number of times, possibly never.
+    #[default]
+    Unknown,
+    /// The region can never be invoked
+    Never,
+    /// The region can be invoked exactly N times
+    Exact(u32),
+    /// The region can be invoked any number of times in the given range
+    Variable { min: u32, max: u32 },
+    /// The region can be invoked at least N times, but an unknown number of times beyond that.
+ AtLeastN(u32), + /// The region can be invoked any number of times up to N + NoMoreThan(u32), +} +impl InvocationBounds { + #[inline] + pub fn new(bounds: impl Into) -> Self { + bounds.into() + } + + #[inline] + pub fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } + + pub fn min(&self) -> Bound<&u32> { + self.start_bound() + } + + pub fn max(&self) -> Bound<&u32> { + self.end_bound() + } +} +impl From for InvocationBounds { + fn from(value: u32) -> Self { + if value == 0 { + Self::Never + } else { + Self::Exact(value) + } + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::Range) -> Self { + if range.start == range.end { + Self::Never + } else if range.end == range.start + 1 { + Self::Exact(range.start) + } else { + assert!(range.start < range.end); + Self::Variable { + min: range.start, + max: range.end, + } + } + } +} +impl From> for InvocationBounds { + fn from(value: core::ops::RangeFrom) -> Self { + if value.start == 0 { + Self::Unknown + } else { + Self::AtLeastN(value.start) + } + } +} +impl From> for InvocationBounds { + fn from(value: core::ops::RangeTo) -> Self { + if value.end == 1 { + Self::Never + } else if value.end == u32::MAX { + Self::Unknown + } else { + Self::NoMoreThan(value.end - 1) + } + } +} +impl From for InvocationBounds { + fn from(_value: core::ops::RangeFull) -> Self { + Self::Unknown + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::RangeInclusive) -> Self { + let (start, end) = range.into_inner(); + if start == 0 && end == 0 { + Self::Never + } else if start == end { + Self::Exact(start) + } else { + Self::Variable { + min: start, + max: end + 1, + } + } + } +} +impl From> for InvocationBounds { + fn from(range: core::ops::RangeToInclusive) -> Self { + if range.end == 0 { + Self::Never + } else { + Self::NoMoreThan(range.end) + } + } +} +impl RangeBounds for InvocationBounds { + fn start_bound(&self) -> Bound<&u32> { + match self { + Self::Unknown | Self::NoMoreThan(_) => Bound::Unbounded, + Self::Never => Bound::Excluded(&0), + Self::Exact(n) | Self::Variable { min: n, .. } => Bound::Included(n), + Self::AtLeastN(n) => Bound::Included(n), + } + } + + fn end_bound(&self) -> Bound<&u32> { + match self { + Self::Unknown | Self::AtLeastN(_) => Bound::Unbounded, + Self::Never => Bound::Excluded(&0), + Self::Exact(n) | Self::Variable { max: n, .. } => Bound::Excluded(n), + Self::NoMoreThan(n) => Bound::Excluded(n), + } + } +} diff --git a/hir2/src/ir/region/kind.rs b/hir2/src/ir/region/kind.rs new file mode 100644 index 000000000..83eb6da8e --- /dev/null +++ b/hir2/src/ir/region/kind.rs @@ -0,0 +1,22 @@ +/// Represents the types of regions that can be represented in the IR +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegionKind { + /// A graph region is one without control-flow semantics, i.e. dataflow between operations is + /// the only thing that dictates order, and operations can be conceptually executed in parallel + /// if the runtime supports it. + /// + /// As there is no control-flow in these regions, graph regions may only contain a single block. + Graph, + /// An SSA region is one where the strict control-flow semantics and properties of SSA (static + /// single assignment) form must be upheld. + /// + /// SSA regions must adhere to: + /// + /// * Values can only be defined once + /// * Definitions must dominate uses + /// * Ordering of operations in a block corresponds to execution order, i.e. operations earlier + /// in a block dominate those later in the block. 
+ /// * Blocks must end with a terminator. + #[default] + SSA, +} diff --git a/hir2/src/ir/region/successor.rs b/hir2/src/ir/region/successor.rs new file mode 100644 index 000000000..a55f238f1 --- /dev/null +++ b/hir2/src/ir/region/successor.rs @@ -0,0 +1,108 @@ +use core::fmt; + +use super::*; +use crate::{OpOperandRange, OpOperandRangeMut}; + +/// This struct represents owned region successor metadata +#[derive(Clone)] +pub struct RegionSuccessorInfo { + pub successor: RegionBranchPoint, + #[allow(unused)] + pub(crate) key: Option>, + pub(crate) operand_group: u8, +} + +/// A [RegionSuccessor] represents the successor of a region. +/// +/// +/// A region successor can either be another region, or the parent operation. If the successor is a +/// region, this class represents the destination region, as well as a set of arguments from that +/// region that will be populated when control flows into the region. If the successor is the parent +/// operation, this class represents an optional set of results that will be populated when control +/// returns to the parent operation. +/// +/// This interface assumes that the values from the current region that are used to populate the +/// successor inputs are the operands of the return-like terminator operations in the blocks within +/// this region. +pub struct RegionSuccessor<'a> { + pub dest: RegionBranchPoint, + pub arguments: OpOperandRange<'a>, +} +impl<'a> RegionSuccessor<'a> { + /// Returns true if the successor is the parent op + pub fn is_parent(&self) -> bool { + self.dest.is_parent() + } + + pub fn successor(&self) -> Option { + self.dest.region() + } + + pub fn into_successor(self) -> Option { + self.dest.region() + } + + /// Return the inputs to the successor that are remapped by the exit values of the current + /// region. + pub fn successor_inputs(&self) -> &OpOperandRange<'a> { + &self.arguments + } +} +impl fmt::Debug for RegionSuccessor<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RegionSuccessor") + .field("dest", &self.dest) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) + .finish() + } +} + +/// The mutable version of [RegionSuccessor] +pub struct RegionSuccessorMut<'a> { + pub dest: RegionBranchPoint, + pub arguments: OpOperandRangeMut<'a>, +} +impl<'a> RegionSuccessorMut<'a> { + /// Returns true if the successor is the parent op + pub fn is_parent(&self) -> bool { + self.dest.is_parent() + } + + pub fn successor(&self) -> Option { + self.dest.region() + } + + pub fn into_successor(self) -> Option { + self.dest.region() + } + + /// Return the inputs to the successor that are remapped by the exit values of the current + /// region. 
+ pub fn successor_inputs(&mut self) -> &mut OpOperandRangeMut<'a> { + &mut self.arguments + } + + pub fn into_successor_inputs(self) -> OpOperandRangeMut<'a> { + self.arguments + } +} +impl fmt::Debug for RegionSuccessorMut<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RegionSuccessorMut") + .field("dest", &self.dest) + .field_with("arguments", |f| { + let mut list = f.debug_list(); + for operand in self.arguments.iter() { + list.entry(&operand.borrow()); + } + list.finish() + }) + .finish() + } +} diff --git a/hir2/src/ir/region/transforms.rs b/hir2/src/ir/region/transforms.rs new file mode 100644 index 000000000..10f7d442d --- /dev/null +++ b/hir2/src/ir/region/transforms.rs @@ -0,0 +1,6 @@ +mod block_merging; +mod dce; +mod drop_redundant_args; + +#[derive(Debug, Copy, Clone)] +pub struct RegionTransformFailed; diff --git a/hir2/src/ir/region/transforms/block_merging.rs b/hir2/src/ir/region/transforms/block_merging.rs new file mode 100644 index 000000000..cf8bbc310 --- /dev/null +++ b/hir2/src/ir/region/transforms/block_merging.rs @@ -0,0 +1,263 @@ +#![allow(unused)] +use alloc::collections::BTreeMap; + +use super::RegionTransformFailed; +use crate::{ + BlockArgument, BlockRef, DynHash, OpResult, Operation, OperationRef, Region, RegionRef, + Rewriter, ValueRef, +}; + +bitflags::bitflags! { + struct EquivalenceFlags: u8 { + const IGNORE_LOCATIONS = 1; + } +} + +struct OpEquivalence { + flags: EquivalenceFlags, + operand_hasher: OperandHasher, + result_hasher: ResultHasher, +} + +type ValueHasher = Box; + +impl OpEquivalence { + pub fn new() -> Self { + Self { + flags: EquivalenceFlags::empty(), + operand_hasher: DefaultValueHasher, + result_hasher: DefaultValueHasher, + } + } +} +impl OpEquivalence { + #[inline] + pub fn with_flags(mut self, flags: EquivalenceFlags) -> Self { + self.flags.insert(flags); + self + } + + /// Ignore op operands when computing equivalence for operations + pub fn ignore_operands(self) -> OpEquivalence<(), ResultHasher> { + OpEquivalence { + flags: self.flags, + operand_hasher: (), + result_hasher: self.result_hasher, + } + } + + /// Ignore op results when computing equivalence for operations + pub fn ignore_results(self) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: self.operand_hasher, + result_hasher: (), + } + } + + /// Specify a custom hasher for op operands + pub fn with_operand_hasher( + self, + hasher: impl Fn(&ValueRef, &mut dyn core::hash::Hasher) + 'static, + ) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: Box::new(hasher), + result_hasher: self.result_hasher, + } + } + + /// Specify a custom hasher for op results + pub fn with_result_hasher( + self, + hasher: impl Fn(&ValueRef, &mut dyn core::hash::Hasher) + 'static, + ) -> OpEquivalence { + OpEquivalence { + flags: self.flags, + operand_hasher: self.operand_hasher, + result_hasher: Box::new(hasher), + } + } + + /// Compare if two operations are equivalent using the current equivalence configuration. + /// + /// This is equivalent to calling [compute_equivalence] with `are_values_equivalent` set to + /// `ValueRef::ptr_eq`, and `on_value_equivalence` to a no-op. + #[inline] + pub fn are_equivalent(&self, lhs: &OperationRef, rhs: &OperationRef) -> bool { + #[inline(always)] + fn noop(_: &ValueRef, _: &ValueRef) {} + + self.compute_equivalence(lhs, rhs, ValueRef::ptr_eq, noop) + } + + /// Compare if two operations (and their regions) are equivalent using the current equivalence + /// configuration. 
+ /// + /// * `are_values_equivalent` is a callback used to check if two values are equivalent. For + /// two operations to be equivalent, their operands must be the same SSA value, or this + /// callback must return `true`. + /// * `on_value_equivalence` is a callback to inform the caller that the analysis determined + /// that two values are equivalent. + /// + /// NOTE: Additional information regarding value equivalence can be injected into the analysis + /// via `are_values_equivalent`. Typically, callers may want values that were recorded as + /// equivalent via `on_value_equivalence` to be reflected in `are_values_equivalent`, but it + /// depends on the exact semantics desired by the caller. + pub fn compute_equivalence( + &self, + lhs: &OperationRef, + rhs: &OperationRef, + are_values_equivalent: VE, + on_value_equivalence: OVE, + ) -> bool + where + VE: Fn(&ValueRef, &ValueRef) -> bool, + OVE: FnMut(&ValueRef, &ValueRef), + { + todo!() + } + + /// Compare if two regions are equivalent using the current equivalence configuration. + /// + /// See [compute_equivalence] for more details. + pub fn compute_region_equivalence( + &self, + lhs: &RegionRef, + rhs: &RegionRef, + are_values_equivalent: VE, + on_value_equivalence: OVE, + ) -> bool + where + VE: Fn(&ValueRef, &ValueRef) -> bool, + OVE: FnMut(&ValueRef, &ValueRef), + { + todo!() + } + + /// Hashes an operation based on: + /// + /// * OperationName + /// * Attributes + /// * Result types + fn hash_operation(&self, op: &Operation, hasher: &mut impl core::hash::Hasher) { + use core::hash::Hash; + + use crate::Value; + + op.name().hash(hasher); + for attr in op.attributes().iter() { + attr.hash(hasher); + } + for result in op.results().iter() { + result.borrow().ty().hash(hasher); + } + } +} + +#[inline(always)] +pub fn ignore_value_equivalence(_lhs: &ValueRef, _rhs: &ValueRef) -> bool { + true +} + +struct DefaultValueHasher; +impl FnOnce<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + type Output = (); + + extern "rust-call" fn call_once( + self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} +impl FnMut<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + extern "rust-call" fn call_mut( + &mut self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} +impl Fn<(&ValueRef, &mut dyn core::hash::Hasher)> for DefaultValueHasher { + extern "rust-call" fn call( + &self, + args: (&ValueRef, &mut dyn core::hash::Hasher), + ) -> Self::Output { + use core::hash::Hash; + + let (value, hasher) = args; + value.dyn_hash(hasher); + } +} + +struct BlockEquivalenceData { + /// The block this data refers to + block: BlockRef, + /// The hash for this block + hash: u64, + /// A map of result producing operations to their relative orders within this block. The order + /// of an operation is the number of defined values that are produced within the block before + /// this operation. 
+
+struct BlockEquivalenceData {
+    /// The block this data refers to
+    block: BlockRef,
+    /// The hash for this block
+    hash: u64,
+    /// A map of result-producing operations to their relative orders within this block. The order
+    /// of an operation is the number of defined values that are produced within the block before
+    /// this operation.
+    op_order_index: BTreeMap<OperationRef, u32>,
+}
+impl BlockEquivalenceData {
+    pub fn new(block: BlockRef) -> Self {
+        use core::hash::Hasher;
+
+        let mut op_order_index = BTreeMap::default();
+
+        let b = block.borrow();
+        let mut order = b.num_arguments() as u32;
+        let op_equivalence = OpEquivalence::new()
+            .with_flags(EquivalenceFlags::IGNORE_LOCATIONS)
+            .ignore_operands()
+            .ignore_results();
+
+        let mut hasher = rustc_hash::FxHasher::default();
+        for op in b.body() {
+            let num_results = op.num_results() as u32;
+            if num_results > 0 {
+                op_order_index.insert(op.as_operation_ref(), order);
+                order += num_results;
+            }
+            op_equivalence.hash_operation(&op, &mut hasher);
+        }
+        let hash = hasher.finish();
+        // Release the borrow of `block` before moving it into the result
+        drop(b);
+
+        Self {
+            block,
+            hash,
+            op_order_index,
+        }
+    }
+
+    fn get_order_of(&self, value: &ValueRef) -> usize {
+        let value = value.borrow();
+        assert!(value.parent_block().unwrap() == self.block, "expected value of this block");
+
+        if let Some(block_arg) = value.downcast_ref::<BlockArgument>() {
+            return block_arg.index();
+        }
+
+        let result = value.downcast_ref::<OpResult>().unwrap();
+        let order =
+            *self.op_order_index.get(&result.owner()).expect("expected op to have an order");
+        result.index() + (order as usize)
+    }
+}
+
+impl Region {
+    // TODO(pauls)
+    pub(in crate::ir::region) fn merge_identical_blocks(
+        _regions: &[RegionRef],
+        _rewriter: &mut dyn Rewriter,
+    ) -> Result<(), RegionTransformFailed> {
+        Err(RegionTransformFailed)
+    }
+}
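Although `merge_identical_blocks` is still stubbed out above, the shape of the eventual algorithm is suggested by `BlockEquivalenceData`: bucket blocks by their structural hash, then compare blocks pairwise only within a bucket (this mirrors the MLIR block-merging utility this file appears to be ported from). A minimal sketch of the bucketing step, under those assumptions:

```rust
use alloc::{collections::BTreeMap, vec::Vec};

// Sketch: group a region's blocks by structural hash. Blocks with distinct
// hashes can never be identical, so only blocks sharing a bucket need the
// expensive pairwise equivalence check.
fn candidate_buckets(region: &Region) -> BTreeMap<u64, Vec<BlockEquivalenceData>> {
    let mut buckets: BTreeMap<u64, Vec<BlockEquivalenceData>> = BTreeMap::new();
    let mut blocks = region.body().front();
    while let Some(block) = blocks.as_pointer() {
        blocks.move_next();
        let data = BlockEquivalenceData::new(block);
        buckets.entry(data.hash).or_default().push(data);
    }
    buckets
}
```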
diff --git a/hir2/src/ir/region/transforms/dce.rs b/hir2/src/ir/region/transforms/dce.rs
new file mode 100644
index 000000000..84497cd77
--- /dev/null
+++ b/hir2/src/ir/region/transforms/dce.rs
@@ -0,0 +1,354 @@
+use alloc::collections::{BTreeSet, VecDeque};
+
+use smallvec::SmallVec;
+
+use super::RegionTransformFailed;
+use crate::{
+    traits::{BranchOpInterface, Terminator},
+    OpOperandImpl, OpResult, Operation, OperationRef, PostOrderBlockIter, Region, RegionRef,
+    Rewriter, SuccessorOperands, ValueRef,
+};
+
+/// Data structure used to track which values have already been proved live.
+///
+/// Because operations can have multiple results, this data structure tracks liveness for both
+/// values and operations to avoid having to look through all results when analyzing a use.
+///
+/// This data structure essentially tracks the dataflow lattice. The set of values/ops proved live
+/// increases monotonically to a fixed-point.
+#[derive(Default)]
+struct LiveMap {
+    values: BTreeSet<ValueRef>,
+    ops: BTreeSet<OperationRef>,
+    changed: bool,
+}
+impl LiveMap {
+    pub fn was_proven_live(&self, value: &ValueRef) -> bool {
+        // TODO(pauls): For results that are removable, e.g. for region based control flow,
+        // we could allow for these values to be tracked independently.
+        let val = value.borrow();
+        if let Some(result) = val.downcast_ref::<OpResult>() {
+            self.ops.contains(&result.owner())
+        } else {
+            self.values.contains(value)
+        }
+    }
+
+    #[inline]
+    pub fn was_op_proven_live(&self, op: &OperationRef) -> bool {
+        self.ops.contains(op)
+    }
+
+    pub fn set_proved_live(&mut self, value: ValueRef) {
+        // TODO(pauls): For results that are removable, e.g. for region based control flow,
+        // we could allow for these values to be tracked independently.
+        //
+        // NOTE: The owner is extracted first so that the borrow of `value` ends before `value`
+        // is moved into the set below.
+        let owner = value.borrow().downcast_ref::<OpResult>().map(|result| result.owner());
+        if let Some(owner) = owner {
+            self.changed |= self.ops.insert(owner);
+        } else {
+            self.changed |= self.values.insert(value);
+        }
+    }
+
+    pub fn set_op_proved_live(&mut self, op: OperationRef) {
+        self.changed |= self.ops.insert(op);
+    }
+
+    #[inline(always)]
+    pub fn mark_unchanged(&mut self) {
+        self.changed = false;
+    }
+
+    #[inline(always)]
+    pub const fn has_changed(&self) -> bool {
+        self.changed
+    }
+
+    pub fn is_use_specially_known_dead(&self, user: &OpOperandImpl) -> bool {
+        // DCE generally treats all uses of an op as live if the op itself is considered live.
+        // However, for successor operands to terminators we need a finer-grained notion where we
+        // deduce liveness for operands individually. The reason for this is easiest to think about
+        // in terms of a classical phi node based SSA IR, where each successor operand is really an
+        // operand to a _separate_ phi node, rather than all operands to the branch itself as with
+        // the block argument representation that we use.
+        //
+        // And similarly, because each successor operand is really an operand to a phi node, rather
+        // than to the terminator op itself, a terminator op can't e.g. "print" the value of a
+        // successor operand.
+        let owner = &user.owner;
+        if owner.borrow().implements::<dyn Terminator>() {
+            if let Some(branch_interface) = owner.borrow().as_trait::<dyn BranchOpInterface>() {
+                if let Some(arg) =
+                    branch_interface.get_successor_block_argument(user.index as usize)
+                {
+                    return !self.was_proven_live(&arg.upcast());
+                }
+            }
+        }
+
+        false
+    }
+
+    pub fn propagate_region_liveness(&mut self, region: &Region) {
+        if region.body().is_empty() {
+            return;
+        }
+
+        for block in PostOrderBlockIter::new(region.body().front().as_pointer().unwrap()) {
+            // We process block arguments after the ops in the block, to promote faster convergence
+            // to a fixed point (we try to visit uses before defs).
+            let block = block.borrow();
+            for op in block.body() {
+                self.propagate_liveness(&op);
+            }
+
+            // We currently do not remove entry block arguments, so there is no need to track their
+            // liveness.
+            //
+            // TODO(pauls): We could track these and enable removing dead operands/arguments from
+            // region control flow operations in the future.
+            if block.is_entry_block() {
+                continue;
+            }
+
+            for arg in block.arguments() {
+                let arg = arg.clone().upcast();
+                if !self.was_proven_live(&arg) {
+                    self.process_value(arg);
+                }
+            }
+        }
+    }
+
+    pub fn propagate_liveness(&mut self, op: &Operation) {
+        // Recurse on any regions the op has
+        for region in op.regions() {
+            self.propagate_region_liveness(&region);
+        }
+
+        // We process terminator operations separately
+        if op.implements::<dyn Terminator>() {
+            return self.propagate_terminator_liveness(op);
+        }
+
+        // Don't reprocess live operations.
+        if self.was_op_proven_live(&op.as_operation_ref()) {
+            return;
+        }
+
+        // Process this op
+        if !op.would_be_trivially_dead() {
+            self.set_op_proved_live(op.as_operation_ref());
+        }
+
+        // If the op isn't intrinsically alive, check its results
+        for result in op.results().iter() {
+            self.process_value(result.clone().upcast());
+        }
+    }
+
+    fn propagate_terminator_liveness(&mut self, op: &Operation) {
+        // Terminators are always live
+        self.set_op_proved_live(op.as_operation_ref());
+
+        // Check to see if we can reason about the successor operands instead
+        //
+        // If we can't reason about the operand to a successor, conservatively mark it as live
+        if let Some(branch_op) = op.as_trait::<dyn BranchOpInterface>() {
+            let num_successors = op.num_successors();
+            for succ_index in 0..num_successors {
+                let succ_operands = branch_op.get_successor_operands(succ_index);
+                for arg_index in 0..succ_operands.num_produced() {
+                    let succ = op.successors()[succ_index].block.borrow().block.clone();
+                    let succ_arg = succ.borrow().get_argument(arg_index).upcast();
+                    self.set_proved_live(succ_arg);
+                }
+            }
+        } else {
+            let num_successors = op.num_successors();
+            for succ_index in 0..num_successors {
+                let succ = op.successor(succ_index);
+                for arg in succ.arguments.iter() {
+                    let arg = arg.borrow().as_value_ref();
+                    self.set_proved_live(arg);
+                }
+            }
+        }
+    }
+
+    fn process_value(&mut self, value: ValueRef) {
+        let proved_live = value.borrow().iter_uses().any(|user| {
+            if self.is_use_specially_known_dead(&user) {
+                return false;
+            }
+            self.was_op_proven_live(&user.owner)
+        });
+        if proved_live {
+            self.set_proved_live(value);
+        }
+    }
+}
+
+impl Region {
+    pub fn dead_code_elimination(
+        regions: &[RegionRef],
+        rewriter: &mut dyn Rewriter,
+    ) -> Result<(), RegionTransformFailed> {
+        let mut live_map = LiveMap::default();
+        loop {
+            live_map.mark_unchanged();
+
+            for region in regions {
+                live_map.propagate_region_liveness(&region.borrow());
+            }
+
+            if !live_map.has_changed() {
+                break;
+            }
+        }
+
+        Self::cleanup_dead_code(regions, rewriter, &live_map)
+    }
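A sketch of how `dead_code_elimination` is presumably driven from a rewrite pass; note that `Err(RegionTransformFailed)` means "nothing was erased" rather than a hard error, so it is normal to treat it as a no-change signal (`body` and `rewriter` are assumed to come from the surrounding pass):

```rust
// Sketch: run DCE over a single region and convert the Result-based
// "did anything change?" convention into a bool.
let regions = [body];
let changed = Region::dead_code_elimination(&regions, rewriter).is_ok();
if changed {
    // e.g. iterate together with other transforms (unreachable-block
    // erasure, block merging, ...) until a fixpoint is reached
}
```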
+    /// Erase the unreachable blocks within the regions in `regions`.
+    ///
+    /// Returns `Ok` if any blocks were erased, `Err` otherwise.
+    pub fn erase_unreachable_blocks(
+        regions: &[RegionRef],
+        rewriter: &mut dyn crate::Rewriter,
+    ) -> Result<(), RegionTransformFailed> {
+        let mut erased_dead_blocks = false;
+        let mut reachable = BTreeSet::default();
+        let mut worklist = VecDeque::from_iter(regions.iter().cloned());
+        while let Some(mut region) = worklist.pop_front() {
+            let mut current_region = region.borrow_mut();
+            let blocks = current_region.body_mut();
+            if blocks.is_empty() {
+                continue;
+            }
+
+            // If this is a single block region, just collect nested regions.
+            if blocks.front().as_pointer() == blocks.back().as_pointer() {
+                for op in blocks.front().get().unwrap().body() {
+                    worklist.extend(op.regions().iter().map(|r| r.as_region_ref()));
+                }
+                continue;
+            }
+
+            // Mark all reachable blocks.
+            reachable.clear();
+            let iter = PostOrderBlockIter::new(blocks.front().as_pointer().unwrap());
+            reachable.extend(iter);
+
+            // Collect all of the dead blocks and push the live regions on the worklist.
+            //
+            // The entry block is always reachable, so it is skipped; the cursor must be advanced
+            // on every iteration, before any erasure, or we would revisit the same block forever.
+            let mut cursor = blocks.front_mut();
+            cursor.move_next();
+            while let Some(mut block) = cursor.as_pointer() {
+                cursor.move_next();
+                if reachable.contains(&block) {
+                    // Walk any regions within this block
+                    for op in block.borrow().body() {
+                        worklist.extend(op.regions().iter().map(|r| r.as_region_ref()));
+                    }
+                    continue;
+                }
+
+                // The block is unreachable, erase it
+                block.borrow_mut().drop_all_defined_value_uses();
+                rewriter.erase_block(block);
+                erased_dead_blocks = true;
+            }
+        }
+
+        if erased_dead_blocks {
+            Ok(())
+        } else {
+            Err(RegionTransformFailed)
+        }
+    }
+
+    fn cleanup_dead_code(
+        regions: &[RegionRef],
+        rewriter: &mut dyn Rewriter,
+        live_map: &LiveMap,
+    ) -> Result<(), RegionTransformFailed> {
+        let mut erased_anything = false;
+        for region in regions {
+            let current_region = region.borrow();
+            if current_region.body().is_empty() {
+                continue;
+            }
+
+            let has_single_block = current_region.has_one_block();
+
+            // Delete every operation that is not live. Graph regions may have cycles in the
+            // use-def graph, so we must explicitly drop all uses from each operation as we erase
+            // it. Visiting the operations in post-order guarantees that in SSA CFG regions, value
+            // uses are removed before defs, which makes `drop_all_uses` a no-op.
+            let iter = PostOrderBlockIter::new(current_region.entry_block_ref().unwrap());
+            for block in iter {
+                if !has_single_block {
+                    Self::erase_terminator_successor_operands(
+                        block.borrow().terminator().expect("expected block to have terminator"),
+                        live_map,
+                    );
+                }
+                let mut next_op = block.borrow().body().back().as_pointer();
+                while let Some(mut child_op) = next_op.take() {
+                    next_op = child_op.prev();
+                    if !live_map.was_op_proven_live(&child_op) {
+                        erased_anything = true;
+                        child_op.borrow_mut().drop_all_uses();
+                        rewriter.erase_op(child_op);
+                    } else {
+                        let child_regions = child_op
+                            .borrow()
+                            .regions()
+                            .iter()
+                            .map(|r| r.as_region_ref())
+                            .collect::<SmallVec<[RegionRef; 2]>>();
+                        erased_anything |=
+                            Self::cleanup_dead_code(&child_regions, rewriter, live_map).is_ok();
+                    }
+                }
+            }
+
+            // Delete block arguments.
+            //
+            // The entry block has an unknown contract with its enclosing region, so leave it
+            // alone.
+            let mut region = region.clone();
+            let mut current_region = region.borrow_mut();
+            let mut blocks = current_region.body_mut().front_mut();
+            // Skip the entry block, per the note above
+            blocks.move_next();
+            while let Some(mut block) = blocks.as_pointer() {
+                blocks.move_next();
+                block
+                    .borrow_mut()
+                    .erase_arguments(|arg| !live_map.was_proven_live(&arg.as_value_ref()));
+            }
+        }
+
+        if erased_anything {
+            Ok(())
+        } else {
+            Err(RegionTransformFailed)
+        }
+    }
+
+    fn erase_terminator_successor_operands(mut terminator: OperationRef, live_map: &LiveMap) {
+        let mut op = terminator.borrow_mut();
+        if !op.implements::<dyn BranchOpInterface>() {
+            return;
+        }
+
+        // Iterate successors in reverse to minimize the amount of operand shifting
+        for succ_index in (0..op.num_successors()).rev() {
+            let mut succ = op.successor_mut(succ_index);
+            // Iterate arguments in reverse so that erasing an argument does not shift the others
+            let num_arguments = succ.arguments.len();
+            for arg_index in (0..num_arguments).rev() {
+                if !live_map.was_proven_live(&succ.arguments[arg_index].borrow().as_value_ref()) {
+                    succ.arguments.erase(arg_index);
+                }
+            }
+        }
+    }
+}
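One detail shared by `erase_terminator_successor_operands` above and the argument erasure in the next file: erasing by index must run from the highest index down, since each erasure shifts every element behind it. A standalone illustration of the hazard:

```rust
// Sketch: erase indices 1 and 3 from a list of four elements.
let mut xs = vec!["a", "b", "c", "d"];
let to_erase = [1usize, 3];

// Ascending order is wrong: after `remove(1)` the element formerly at
// index 3 ("d") now lives at index 2, so `remove(3)` is out of bounds.
// Descending order is safe:
for i in to_erase.into_iter().rev() {
    xs.remove(i);
}
assert_eq!(xs, ["a", "c"]);
```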
diff --git a/hir2/src/ir/region/transforms/drop_redundant_args.rs b/hir2/src/ir/region/transforms/drop_redundant_args.rs
new file mode 100644
index 000000000..f31086ecd
--- /dev/null
+++ b/hir2/src/ir/region/transforms/drop_redundant_args.rs
@@ -0,0 +1,151 @@
+use smallvec::SmallVec;
+
+use super::RegionTransformFailed;
+use crate::{
+    traits::BranchOpInterface, BlockArgumentRef, BlockRef, Region, RegionRef, Rewriter,
+    SuccessorOperands, Usable,
+};
+
+impl Region {
+    /// This optimization drops redundant block arguments: if a given argument to a block receives
+    /// the same value from each of the block's predecessors, we can remove the argument from the
+    /// block and use the original value directly.
+    ///
+    /// ## Example
+    ///
+    /// A simple example:
+    ///
+    /// ```hir,ignore
+    /// %cond = llvm.call @rand() : () -> i1
+    /// %val0 = llvm.mlir.constant(1 : i64) : i64
+    /// %val1 = llvm.mlir.constant(2 : i64) : i64
+    /// %val2 = llvm.mlir.constant(3 : i64) : i64
+    /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2 : i64)
+    ///
+    /// ^bb1(%arg0 : i64, %arg1 : i64):
+    ///     llvm.call @foo(%arg0, %arg1)
+    /// ```
+    ///
+    /// That can be rewritten as:
+    ///
+    /// ```hir,ignore
+    /// %cond = llvm.call @rand() : () -> i1
+    /// %val0 = llvm.mlir.constant(1 : i64) : i64
+    /// %val1 = llvm.mlir.constant(2 : i64) : i64
+    /// %val2 = llvm.mlir.constant(3 : i64) : i64
+    /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+    ///
+    /// ^bb1(%arg0 : i64):
+    ///     llvm.call @foo(%val0, %arg0)
+    /// ```
+    pub(in crate::ir::region) fn drop_redundant_arguments(
+        regions: &[RegionRef],
+        rewriter: &mut dyn Rewriter,
+    ) -> Result<(), RegionTransformFailed> {
+        let mut worklist = SmallVec::<[RegionRef; 1]>::from_iter(regions.iter().cloned());
+
+        let mut any_changed = false;
+        while let Some(region) = worklist.pop() {
+            // Add any nested regions to the worklist
+            let region = region.borrow();
+            let mut blocks = region.body().front();
+            while let Some(block) = blocks.as_pointer() {
+                blocks.move_next();
+
+                any_changed |=
+                    Self::drop_redundant_block_arguments(block.clone(), rewriter).is_ok();
+
+                for op in block.borrow().body() {
+                    let mut regions = op.regions().front();
+                    while let Some(region) = regions.as_pointer() {
+                        worklist.push(region);
+                        regions.move_next();
+                    }
+                }
+            }
+        }
+
+        if any_changed {
+            Ok(())
+        } else {
+            Err(RegionTransformFailed)
+        }
+    }
+
+    /// If a block argument always receives the same value from every predecessor, drop the
+    /// argument and use that value directly inside the block
+    fn drop_redundant_block_arguments(
+        mut block: BlockRef,
+        rewriter: &mut dyn Rewriter,
+    ) -> Result<(), RegionTransformFailed> {
+        let mut args_to_erase = SmallVec::<[usize; 4]>::default();
+
+        // Go through the arguments of the block.
+        let mut block_mut = block.borrow_mut();
+        let block_args =
+            SmallVec::<[BlockArgumentRef; 2]>::from_iter(block_mut.arguments().iter().cloned());
+        for (arg_index, block_arg) in block_args.into_iter().enumerate() {
+            let mut same_arg = true;
+            let mut common_value = None;
+
+            // Go through the block predecessors and flag if they pass different values to the
+            // block for the same argument.
+            for pred in block_mut.predecessors() {
+                let pred_op = pred.owner.borrow();
+                if let Some(branch_op) = pred_op.as_trait::<dyn BranchOpInterface>() {
+                    let succ_index = pred.index as usize;
+                    let succ_operands = branch_op.get_successor_operands(succ_index);
+                    let branch_operands = succ_operands.forwarded();
+                    let arg = branch_operands[arg_index].borrow().as_value_ref();
+                    if common_value.is_none() {
+                        common_value = Some(arg);
+                        continue;
+                    }
+                    if common_value.as_ref().is_some_and(|cv| cv != &arg) {
+                        same_arg = false;
+                        break;
+                    }
+                } else {
+                    same_arg = false;
+                    break;
+                }
+            }
+
+            // If they are passing the same value, drop the argument.
+            if let Some(common_value) = common_value {
+                if same_arg {
+                    args_to_erase.push(arg_index);
+
+                    // Remove the argument from the block.
+                    rewriter.replace_all_uses_of_value_with(block_arg, common_value);
+                }
+            }
+        }
+
+        // Remove the arguments, from the highest index down, so that erasing one argument does
+        // not shift the indices of those still pending.
+        for arg_index in args_to_erase.iter().copied().rev() {
+            block_mut.erase_argument(arg_index);
+
+            // Remove the argument from the branch ops.
+            let mut preds = block_mut.uses_mut().front_mut();
+            while let Some(mut pred) = preds.as_pointer() {
+                preds.move_next();
+
+                let mut pred = pred.borrow_mut();
+                let mut pred_op = pred.owner.borrow_mut();
+                if let Some(branch_op) = pred_op.as_trait_mut::<dyn BranchOpInterface>() {
+                    let succ_index = pred.index as usize;
+                    let mut succ_operands = branch_op.get_successor_operands_mut(succ_index);
+                    succ_operands.forwarded_mut().erase(arg_index);
+                }
+            }
+        }
+
+        if !args_to_erase.is_empty() {
+            Ok(())
+        } else {
+            Err(RegionTransformFailed)
+        }
+    }
+}
diff --git a/hir2/src/ir/successor.rs b/hir2/src/ir/successor.rs
new file mode 100644
index 000000000..b4f37aca3
--- /dev/null
+++ b/hir2/src/ir/successor.rs
@@ -0,0 +1,460 @@
+use core::fmt;
+
+use super::OpOperandStorage;
+use crate::{AttributeValue, BlockOperandRef, BlockRef, OpOperandRange, OpOperandRangeMut};
+
+pub type OpSuccessorStorage = crate::EntityStorage<SuccessorInfo, 0>;
+pub type OpSuccessorRange<'a> = crate::EntityRange<'a, SuccessorInfo>;
+pub type OpSuccessorRangeMut<'a> = crate::EntityRangeMut<'a, SuccessorInfo, 0>;
+
+/// This trait represents common behavior shared by any range of successor operands.
+pub trait SuccessorOperands {
+    /// Returns true if there are no operands in this set
+    fn is_empty(&self) -> bool {
+        self.num_produced() == 0 && self.len() == 0
+    }
+    /// Returns the total number of operands in this set
+    fn len(&self) -> usize;
+    /// Returns the number of internally produced operands in this set
+    fn num_produced(&self) -> usize;
+    /// Returns true if the operand at `index` is internally produced
+    #[inline]
+    fn is_operand_produced(&self, index: usize) -> bool {
+        index < self.num_produced()
+    }
+    /// Get the range of forwarded operands
+    fn forwarded(&self) -> OpOperandRange<'_>;
+    /// Get a [SuccessorOperand] representing the operand at `index`
+    ///
+    /// Returns `None` if the index is out of range.
+    fn get(&self, index: usize) -> Option<SuccessorOperand> {
+        if self.is_operand_produced(index) {
+            Some(SuccessorOperand::Produced)
+        } else {
+            self.forwarded()
+                .get(index - self.num_produced())
+                .map(|op_operand| SuccessorOperand::Forwarded(op_operand.borrow().as_value_ref()))
+        }
+    }
+
+    /// Get a [SuccessorOperand] representing the operand at `index`.
+    ///
+    /// Panics if the index is out of range.
+    fn get_unchecked(&self, index: usize) -> SuccessorOperand {
+        if self.is_operand_produced(index) {
+            SuccessorOperand::Produced
+        } else {
+            SuccessorOperand::Forwarded(
+                self.forwarded()[index - self.num_produced()].borrow().as_value_ref(),
+            )
+        }
+    }
+
+    /// Gets the index of the forwarded operand which maps to the given successor block argument
+    /// index.
+    ///
+    /// Panics if the given block argument index does not correspond to a forwarded operand.
+    fn get_operand_index(&self, block_argument_index: usize) -> usize {
+        assert!(
+            !self.is_operand_produced(block_argument_index),
+            "cannot map operands produced by the operation"
+        );
+        let base_index = self.forwarded()[0].borrow().index as usize;
+        base_index + (block_argument_index - self.num_produced())
+    }
+}
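To make the produced/forwarded split concrete: for a successor taking one internally produced value followed by two forwarded operands, any `impl SuccessorOperands` (call it `succ`) behaves like this sketch:

```rust
// Combined view: [Produced, Forwarded(a), Forwarded(b)]
assert_eq!(succ.num_produced(), 1);
assert_eq!(succ.len(), 3);
assert!(succ.is_operand_produced(0));
assert!(matches!(succ.get_unchecked(0), SuccessorOperand::Produced));
assert!(matches!(succ.get_unchecked(1), SuccessorOperand::Forwarded(_)));
// Successor block argument 1 corresponds to the first forwarded operand;
// `get_operand_index(1)` returns that operand's absolute index in the op's
// operand list (and panics if called with a produced index such as 0).
let _abs_index = succ.get_operand_index(1);
```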
+
+/// This type models how operands are forwarded to block arguments in control flow. It consists of
+/// a number, denoting how many of the successor block arguments are produced by the operation,
+/// followed by a range of operands that are forwarded. The produced operands are passed to the
+/// first few block arguments of the successor, followed by the forwarded operands. Passing them
+/// in a different order is not supported.
+///
+/// An example operation with both of these concepts would be a branch-on-error operation, that
+/// internally produces an error object on the error path:
+///
+/// ```hir,ignore
+/// invoke %function(%0)
+///     label ^success ^error(%1 : i32)
+///
+/// ^error(%e: !error, %arg0 : i32):
+///     ...
+/// ```
+///
+/// This operation would return an instance of [SuccessorOperands] with a produced operand count of
+/// 1 (mapped to `%e` in the successor) and forwarded operands consisting of `%1` in the example
+/// above (mapped to `%arg0` in the successor).
+pub struct SuccessorOperandRange<'a> {
+    /// The explicit op operands which are to be passed along to the successor
+    forwarded: OpOperandRange<'a>,
+    /// The number of operands that are produced internally within the operation and which are to
+    /// be passed to the successor before any forwarded operands.
+    num_produced: usize,
+}
+impl<'a> SuccessorOperandRange<'a> {
+    /// Create an empty successor operand set
+    pub fn empty() -> Self {
+        Self {
+            forwarded: OpOperandRange::empty(),
+            num_produced: 0,
+        }
+    }
+
+    /// Create a successor operand set consisting solely of forwarded op operands
+    #[inline]
+    pub const fn forward(forwarded: OpOperandRange<'a>) -> Self {
+        Self {
+            forwarded,
+            num_produced: 0,
+        }
+    }
+
+    /// Create a successor operand set consisting solely of `num_produced` internally produced
+    /// results
+    pub fn produced(num_produced: usize) -> Self {
+        Self {
+            forwarded: OpOperandRange::empty(),
+            num_produced,
+        }
+    }
+
+    /// Create a new successor operand set with the given number of internally produced results,
+    /// and forwarded op operands.
+    #[inline]
+    pub const fn new(num_produced: usize, forwarded: OpOperandRange<'a>) -> Self {
+        Self {
+            forwarded,
+            num_produced,
+        }
+    }
+}
+impl<'a> SuccessorOperands for SuccessorOperandRange<'a> {
+    #[inline]
+    fn len(&self) -> usize {
+        self.num_produced + self.forwarded.len()
+    }
+
+    #[inline(always)]
+    fn num_produced(&self) -> usize {
+        self.num_produced
+    }
+
+    fn forwarded(&self) -> OpOperandRange<'_> {
+        self.forwarded.clone()
+    }
+}
+
+/// The mutable variant of [SuccessorOperandRange].
+pub struct SuccessorOperandRangeMut<'a> {
+    /// The explicit op operands which are to be passed along to the successor
+    forwarded: OpOperandRangeMut<'a>,
+    /// The number of operands that are produced internally within the operation and which are to
+    /// be passed to the successor before any forwarded operands.
+    num_produced: usize,
+}
+impl<'a> SuccessorOperandRangeMut<'a> {
+    /// Create a successor operand set consisting solely of forwarded op operands
+    #[inline]
+    pub const fn forward(forwarded: OpOperandRangeMut<'a>) -> Self {
+        Self {
+            forwarded,
+            num_produced: 0,
+        }
+    }
+
+    /// Create a new successor operand set with the given number of internally produced results,
+    /// and forwarded op operands.
+    #[inline]
+    pub const fn new(num_produced: usize, forwarded: OpOperandRangeMut<'a>) -> Self {
+        Self {
+            forwarded,
+            num_produced,
+        }
+    }
+
+    #[inline(always)]
+    pub fn forwarded_mut(&mut self) -> &mut OpOperandRangeMut<'a> {
+        &mut self.forwarded
+    }
+}
+impl<'a> SuccessorOperands for SuccessorOperandRangeMut<'a> {
+    #[inline]
+    fn len(&self) -> usize {
+        self.num_produced + self.forwarded.len()
+    }
+
+    #[inline(always)]
+    fn num_produced(&self) -> usize {
+        self.num_produced
+    }
+
+    #[inline(always)]
+    fn forwarded(&self) -> OpOperandRange<'_> {
+        self.forwarded.as_immutable()
+    }
+}
+
+/// Represents an operand in a [SuccessorOperands] set.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum SuccessorOperand {
+    /// This operand is internally produced by the operation, and passed to the successor before
+    /// any forwarded op operands.
+    Produced,
+    /// This operand is a forwarded operand of the operation.
+    Forwarded(crate::ValueRef),
+}
+
+/// This trait represents successor-like values for operations, with support for control-flow
+/// predicated on a "key", a sentinel value that must match in order for the successor block to be
+/// taken.
+///
+/// The ability to associate a successor with a user-defined key is intended for modeling things
+/// such as [crate::dialects::hir::Switch], which has one or more successors which are guarded by
+/// an integer value that is matched against the input, or selector, value. Most importantly, it
+/// does so in a way that keeps everything in sync as the IR is modified.
+///
+/// When used as a successor argument to an operation, each successor gets its own operand group,
+/// and keyed successors are stored in a special attribute which tracks each key and its associated
+/// successor index. This allows requesting the successor details and getting back the correct key,
+/// destination, and operands.
+pub trait KeyedSuccessor {
+    /// The type of key for this successor
+    type Key: AttributeValue + Clone + Eq;
+    /// The type of value which will represent a reference to this successor.
+    ///
+    /// You should use [OpSuccessor] if this successor is not keyed.
+    type Repr<'a>: 'a;
+    /// The type of value which will represent a mutable reference to this successor.
+    ///
+    /// You should use [OpSuccessorMut] if this successor is not keyed.
+    type ReprMut<'a>: 'a;
+
+    /// The value of the key for this successor.
+    ///
+    /// Keys must be valid attribute values, as they will be encoded in the operation attributes,
+    /// and each key must be unique across the set of keyed successors.
+    fn key(&self) -> &Self::Key;
+    /// Convert this value into the raw parts comprising the successor information:
+    ///
+    /// * The key under which this successor is selected
+    /// * The destination block
+    /// * The destination operands
+    fn into_parts(self) -> (Self::Key, BlockRef, Vec<crate::ValueRef>);
+    fn into_repr(
+        key: Self::Key,
+        block: BlockOperandRef,
+        operands: OpOperandRange<'_>,
+    ) -> Self::Repr<'_>;
+    fn into_repr_mut(
+        key: Self::Key,
+        block: BlockOperandRef,
+        operands: OpOperandRangeMut<'_>,
+    ) -> Self::ReprMut<'_>;
+}
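A sketch of what an implementation might look like for a switch-style op, where each case is guarded by an integer selector; the `SwitchCase*` types are hypothetical (not part of this patch), and `u32` is assumed to implement `AttributeValue`:

```rust
// Hypothetical keyed successor: a switch case guarded by a `u32` key.
pub struct SwitchCase {
    pub value: u32,
    pub successor: BlockRef,
    pub operands: Vec<crate::ValueRef>,
}

pub struct SwitchCaseRef<'a> {
    pub value: u32,
    pub block: BlockOperandRef,
    pub operands: OpOperandRange<'a>,
}

pub struct SwitchCaseMut<'a> {
    pub value: u32,
    pub block: BlockOperandRef,
    pub operands: OpOperandRangeMut<'a>,
}

impl KeyedSuccessor for SwitchCase {
    type Key = u32;
    type Repr<'a> = SwitchCaseRef<'a>;
    type ReprMut<'a> = SwitchCaseMut<'a>;

    fn key(&self) -> &Self::Key {
        &self.value
    }

    fn into_parts(self) -> (Self::Key, BlockRef, Vec<crate::ValueRef>) {
        (self.value, self.successor, self.operands)
    }

    fn into_repr(
        key: Self::Key,
        block: BlockOperandRef,
        operands: OpOperandRange<'_>,
    ) -> Self::Repr<'_> {
        SwitchCaseRef { value: key, block, operands }
    }

    fn into_repr_mut(
        key: Self::Key,
        block: BlockOperandRef,
        operands: OpOperandRangeMut<'_>,
    ) -> Self::ReprMut<'_> {
        SwitchCaseMut { value: key, block, operands }
    }
}
```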
+
+/// This struct tracks successor metadata needed by [crate::Operation]
+#[derive(Clone)]
+pub struct SuccessorInfo {
+    pub block: BlockOperandRef,
+    pub(crate) key: Option<core::ptr::NonNull<dyn AttributeValue>>,
+    pub(crate) operand_group: u8,
+}
+impl crate::StorableEntity for SuccessorInfo {
+    #[inline(always)]
+    fn index(&self) -> usize {
+        self.block.index()
+    }
+
+    #[inline(always)]
+    unsafe fn set_index(&mut self, index: usize) {
+        self.block.set_index(index);
+    }
+
+    #[inline(always)]
+    fn unlink(&mut self) {
+        self.block.unlink();
+    }
+}
+impl fmt::Debug for SuccessorInfo {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("SuccessorInfo")
+            .field("block", &self.block.borrow())
+            .field("key", &self.key)
+            .field("operand_group", &self.operand_group)
+            .finish()
+    }
+}
+
+/// An [OpSuccessor] is a BlockOperand + OpOperandRange for that block
+pub struct OpSuccessor<'a> {
+    pub dest: BlockOperandRef,
+    pub arguments: OpOperandRange<'a>,
+}
+impl fmt::Debug for OpSuccessor<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("OpSuccessor")
+            .field("block", &self.dest.borrow().block_id())
+            .field_with("arguments", |f| {
+                let mut list = f.debug_list();
+                for operand in self.arguments.iter() {
+                    list.entry(&operand.borrow());
+                }
+                list.finish()
+            })
+            .finish()
+    }
+}
+
+/// An [OpSuccessorMut] is a BlockOperand + OpOperandRangeMut for that block
+pub struct OpSuccessorMut<'a> {
+    pub dest: BlockOperandRef,
+    pub arguments: OpOperandRangeMut<'a>,
+}
+impl fmt::Debug for OpSuccessorMut<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("OpSuccessorMut")
+            .field("block", &self.dest.borrow().block_id())
+            .field_with("arguments", |f| {
+                let mut list = f.debug_list();
+                for operand in self.arguments.iter() {
+                    list.entry(&operand.borrow());
+                }
+                list.finish()
+            })
+            .finish()
+    }
+}
+
+pub struct KeyedSuccessorRange<'a, T> {
+    range: OpSuccessorRange<'a>,
+    operands: &'a OpOperandStorage,
+    _marker: core::marker::PhantomData<T>,
+}
+impl<'a, T> KeyedSuccessorRange<'a, T> {
+    pub fn new(range: OpSuccessorRange<'a>, operands: &'a OpOperandStorage) -> Self {
+        Self {
+            range,
+            operands,
+            _marker: core::marker::PhantomData,
+        }
+    }
+
+    pub fn get(&self, index: usize) -> Option<SuccessorWithKey<'_, T>> {
+        self.range.get(index).map(|info| {
+            let operands = self.operands.group(info.operand_group as usize);
+            SuccessorWithKey {
+                info,
+                operands,
+                _marker: core::marker::PhantomData,
+            }
+        })
+    }
+
+    pub fn iter(&self) -> KeyedSuccessorRangeIter<'a, '_, T> {
+        KeyedSuccessorRangeIter {
+            range: self,
+            index: 0,
+        }
+    }
+}
+
+pub struct KeyedSuccessorRangeMut<'a, T> {
+    range: OpSuccessorRangeMut<'a>,
+    operands: &'a mut OpOperandStorage,
+    _marker: core::marker::PhantomData<T>,
+}
+impl<'a, T> KeyedSuccessorRangeMut<'a, T> {
+    pub fn new(range: OpSuccessorRangeMut<'a>, operands: &'a mut OpOperandStorage) -> Self {
+        Self {
+            range,
+            operands,
+            _marker: core::marker::PhantomData,
+        }
+    }
+
+    pub fn get(&self, index: usize) -> Option<SuccessorWithKey<'_, T>> {
+        self.range.get(index).map(|info| {
+            let operands = self.operands.group(info.operand_group as usize);
+            SuccessorWithKey {
+                info,
+                operands,
+                _marker: core::marker::PhantomData,
+            }
+        })
+    }
+
+    pub fn get_mut(&mut self, index: usize) -> Option<SuccessorWithKeyMut<'_, T>> {
+        self.range.get_mut(index).map(|info| {
+            let operands = self.operands.group_mut(info.operand_group as usize);
+            SuccessorWithKeyMut {
+                info,
+                operands,
+                _marker: core::marker::PhantomData,
+            }
+        })
+    }
+}
+
+pub struct KeyedSuccessorRangeIter<'a, 'b: 'a, T> {
+    range: &'b KeyedSuccessorRange<'a, T>,
+    index: usize,
+}
+impl<'a, 'b: 'a, T> Iterator for KeyedSuccessorRangeIter<'a, 'b, T> {
+    type Item = SuccessorWithKey<'b, T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index >= self.range.range.len() {
+            return None;
+        }
+
+        let idx = self.index;
+        self.index += 1;
+        self.range.get(idx)
+    }
+}
+
+pub struct SuccessorWithKey<'a, T> {
+    info: &'a SuccessorInfo,
+    operands: OpOperandRange<'a>,
+    _marker: core::marker::PhantomData<T>,
+}
+impl<'a, T> SuccessorWithKey<'a, T> {
+    pub fn key(&self) -> Option<&T> {
+        self.info.key.map(|ptr| unsafe { &*(ptr.as_ptr() as *mut T) })
+    }
+
+    pub fn block(&self) -> BlockRef {
+        self.info.block.borrow().block.clone()
+    }
+
+    #[inline(always)]
+    pub fn arguments(&self) -> &OpOperandRange<'a> {
+        &self.operands
+    }
+}
+
+pub struct SuccessorWithKeyMut<'a, T> {
+    info: &'a SuccessorInfo,
+    operands: OpOperandRangeMut<'a>,
+    _marker: core::marker::PhantomData<T>,
+}
+impl<'a, T> SuccessorWithKeyMut<'a, T> {
+    pub fn key(&self) -> Option<&T> {
+        self.info.key.map(|ptr| unsafe { &*(ptr.as_ptr() as *mut T) })
+    }
+
+    pub fn block(&self) -> BlockRef {
+        self.info.block.borrow().block.clone()
+    }
+
+    #[inline(always)]
+    pub fn arguments(&self) -> &OpOperandRangeMut<'a> {
+        &self.operands
+    }
+
+    #[inline(always)]
+    pub fn arguments_mut(&mut self) -> &mut OpOperandRangeMut<'a> {
+        &mut self.operands
+    }
+}
diff --git a/hir2/src/ir/symbol_table.rs b/hir2/src/ir/symbol_table.rs
new file mode 100644
index 000000000..f6424aa5e
--- /dev/null
+++ b/hir2/src/ir/symbol_table.rs
@@ -0,0 +1,647 @@
+use alloc::collections::VecDeque;
+use core::fmt;
+
+use midenc_session::diagnostics::{miette, Diagnostic};
+
+use crate::{
+    define_attr_type, interner, InsertionPoint, Op, Operation, OperationRef, Report,
+    UnsafeIntrusiveEntityRef, Usable, Visibility, Walkable,
+};
+
+/// Represents the name of a [Symbol] in its local [SymbolTable]
+pub type SymbolName = interner::Symbol;
+
+#[derive(Debug, thiserror::Error, Diagnostic)]
+pub enum InvalidSymbolRefError {
+    #[error("invalid symbol reference: no symbol table available")]
+    NoSymbolTable {
+        #[label("cannot resolve this symbol")]
+        symbol: crate::SourceSpan,
+        #[label(
+            "because this operation has no parent symbol table with which to resolve the reference"
+        )]
+        user: crate::SourceSpan,
+    },
+    #[error("invalid symbol reference: undefined symbol")]
+    UnknownSymbol {
+        #[label("failed to resolve this symbol")]
+        symbol: crate::SourceSpan,
+        #[label("in the nearest symbol table from this operation")]
+        user: crate::SourceSpan,
+    },
+    #[error("invalid symbol reference: undefined component '{component}' of symbol")]
+    UnknownSymbolComponent {
+        #[label("failed to resolve this symbol")]
+        symbol: crate::SourceSpan,
+        #[label("from the root symbol table of this operation")]
+        user: crate::SourceSpan,
+        component: &'static str,
+    },
+    #[error("invalid symbol reference: expected callable")]
+    NotCallable {
+        #[label("expected this symbol to implement the CallableOpInterface")]
+        symbol: crate::SourceSpan,
+    },
+}
+
+#[derive(Clone)]
+pub struct SymbolNameAttr {
+    pub user: SymbolUseRef,
+    /// The path through the abstract symbol space to the containing symbol table
+    ///
+    /// It is assumed that all symbol tables are also symbols themselves, and thus the path to
+    /// `name` is formed from the names of all parent symbol tables, in hierarchical order.
+    ///
+    /// For example, consider a program consisting of a single component `@test_component`,
+    /// containing a module `@foo`, which in turn contains a function `@a`. The `path` for `@a`
+    /// would be `@test_component::@foo`, and `name` would be `@a`.
+    ///
+    /// If set to `interner::symbols::Empty`, the symbol `name` is in the global namespace.
+    ///
+    /// If set to any other value, then we recover the components of the path by splitting the
+    /// value on `::`. If only one part is present, the path represents a single namespace. If
+    /// multiple parts are present, then each part represents a nested namespace starting from
+    /// the global one.
+    pub path: SymbolName,
+    /// The name of the symbol
+    pub name: SymbolName,
+}
+define_attr_type!(SymbolNameAttr);
+impl SymbolNameAttr {
+    #[inline(always)]
+    pub const fn name(&self) -> SymbolName {
+        self.name
+    }
+
+    #[inline(always)]
+    pub const fn path(&self) -> SymbolName {
+        self.path
+    }
+
+    /// Returns true if this symbol name is fully-qualified
+    pub fn is_absolute(&self) -> bool {
+        self.path.as_str().starts_with("::")
+    }
+
+    #[inline]
+    pub fn has_parent(&self) -> bool {
+        self.path != interner::symbols::Empty
+    }
+
+    pub fn components(&self) -> SymbolNameComponents {
+        SymbolNameComponents::new(self.path, self.name)
+    }
+}
+impl fmt::Display for SymbolNameAttr {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.has_parent() {
+            write!(f, "{}::{}", &self.path, &self.name)
+        } else {
+            f.write_str(self.name.as_str())
+        }
+    }
+}
+impl fmt::Debug for SymbolNameAttr {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("SymbolNameAttr")
+            .field("user", &self.user.borrow())
+            .field("path", &self.path)
+            .field("name", &self.name)
+            .finish()
+    }
+}
+impl crate::formatter::PrettyPrint for SymbolNameAttr {
+    fn render(&self) -> crate::formatter::Document {
+        use crate::formatter::*;
+        display(self)
+    }
+}
+impl Eq for SymbolNameAttr {}
+impl PartialEq for SymbolNameAttr {
+    fn eq(&self, other: &Self) -> bool {
+        self.path == other.path && self.name == other.name
+    }
+}
+impl PartialOrd for SymbolNameAttr {
+    #[inline]
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl Ord for SymbolNameAttr {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.path.cmp(&other.path).then_with(|| self.name.cmp(&other.name))
+    }
+}
+impl core::hash::Hash for SymbolNameAttr {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.path.hash(state);
+    }
+}
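To make the `path`/`name` split concrete, here is how the `@test_component::@foo::@a` example from the field docs above iterates (sketch; `attr` is a `SymbolNameAttr` with `path = "test_component::foo"` and `name = "a"`):

```rust
// Sketch: `attr` displays as "test_component::foo::a", and its components
// are yielded in hierarchical order.
for component in attr.components() {
    match component {
        SymbolNameComponent::Root => {
            // only yielded for absolute paths beginning with `::`
        }
        SymbolNameComponent::Component(_ns) => {
            // "test_component" on the first iteration, then "foo"
        }
        SymbolNameComponent::Leaf(_name) => {
            // "a", always last
        }
    }
}
```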
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum SymbolNameComponent {
+    /// A component that signals the path is relative to the root symbol table
+    Root,
+    /// A component of the symbol name path
+    Component(SymbolName),
+    /// The name of the symbol in its local symbol table
+    Leaf(SymbolName),
+}
+
+pub struct SymbolNameComponents {
+    parts: VecDeque<&'static str>,
+    name: SymbolName,
+    done: bool,
+}
+impl SymbolNameComponents {
+    fn new(path: SymbolName, name: SymbolName) -> Self {
+        let mut parts = VecDeque::default();
+        if path == interner::symbols::Empty {
+            return Self {
+                parts,
+                name,
+                done: false,
+            };
+        }
+        let mut split = path.as_str().split("::");
+        let start = split.next().unwrap();
+        if start.is_empty() {
+            parts.push_back("::");
+        }
+
+        while let Some(part) = split.next() {
+            if part.is_empty() {
+                if let Some(part2) = split.next() {
+                    if part2.is_empty() {
+                        parts.push_back("::");
+                    } else {
+                        parts.push_back(part2);
+                    }
+                } else {
+                    break;
+                }
+            } else {
+                parts.push_back(part);
+            }
+        }
+
+        Self {
+            parts,
+            name,
+            done: false,
+        }
+    }
+
+    /// Convert this iterator into a symbol name representing the path prefix of a [Symbol].
+    ///
+    /// If `absolute` is set to true, then the resulting path will be prefixed with `::`
+    pub fn into_path(self, absolute: bool) -> SymbolName {
+        if self.parts.is_empty() {
+            return ::midenc_hir_symbol::symbols::Empty;
+        }
+
+        let mut buf = String::with_capacity(
+            (2 * self.parts.len()) + self.parts.iter().map(|p| p.len()).sum::<usize>(),
+        );
+        if absolute {
+            buf.push_str("::");
+        }
+        let mut first = true;
+        for part in self.parts {
+            // A root marker is already accounted for via `absolute`
+            if part == "::" {
+                continue;
+            }
+            if !first {
+                buf.push_str("::");
+            }
+            buf.push_str(part);
+            first = false;
+        }
+        SymbolName::intern(buf)
+    }
+}
+impl core::iter::FusedIterator for SymbolNameComponents {}
+impl Iterator for SymbolNameComponents {
+    type Item = SymbolNameComponent;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.done {
+            return None;
+        }
+        if let Some(part) = self.parts.pop_front() {
+            if part == "::" {
+                return Some(SymbolNameComponent::Root);
+            }
+            return Some(SymbolNameComponent::Component(part.into()));
+        }
+        self.done = true;
+        Some(SymbolNameComponent::Leaf(self.name))
+    }
+}
+
+/// A trait which allows multiple types to be coerced into a [SymbolRef].
+///
+/// This is primarily intended for use in operation builders.
+pub trait AsSymbolRef {
+    fn as_symbol_ref(&self) -> SymbolRef;
+}
+impl<T: Symbol> AsSymbolRef for &T {
+    #[inline]
+    fn as_symbol_ref(&self) -> SymbolRef {
+        unsafe { SymbolRef::from_raw(*self as &dyn Symbol) }
+    }
+}
+impl<T: Symbol> AsSymbolRef for UnsafeIntrusiveEntityRef<T> {
+    #[inline]
+    fn as_symbol_ref(&self) -> SymbolRef {
+        let t_ptr = Self::as_ptr(self);
+        unsafe { SymbolRef::from_raw(t_ptr as *const dyn Symbol) }
+    }
+}
+impl AsSymbolRef for SymbolRef {
+    #[inline(always)]
+    fn as_symbol_ref(&self) -> SymbolRef {
+        Self::clone(self)
+    }
+}
+
+/// A [SymbolTable] is an IR entity which contains other IR entities, called _symbols_, each of
+/// which has a name, aka symbol, that uniquely identifies it amongst all other entities in the
+/// same [SymbolTable].
+///
+/// The symbols in a [SymbolTable] do not need to all refer to the same entity type, however the
+/// concrete value type of the symbol itself, e.g. `String`, must be the same. This is enforced
+/// in the way that the [SymbolTable] and [Symbol] traits interact. A [SymbolTable] has an
+/// associated `Key` type, and a [Symbol] has an associated `Id` type - only types whose `Id`
+/// type matches the `Key` type of the [SymbolTable], can be stored in that table.
+pub trait SymbolTable {
+    /// Get a reference to the underlying [Operation]
+    fn as_symbol_table_operation(&self) -> &Operation;
+
+    /// Get a mutable reference to the underlying [Operation]
+    fn as_symbol_table_operation_mut(&mut self) -> &mut Operation;
+
+    /// Get the entry for `name` in this table
+    fn get(&self, name: SymbolName) -> Option<SymbolRef>;
+
+    /// Insert `entry` in the symbol table, but only if no other symbol with the same name exists.
+    ///
+    /// If provided, the symbol will be inserted at the given insertion point in the body of the
+    /// symbol table operation.
+    ///
+    /// This function will panic if the symbol is attached to another symbol table.
+    ///
+    /// Returns `true` if successful, `false` if the symbol is already defined
+    fn insert_new(&mut self, entry: SymbolRef, ip: Option<InsertionPoint>) -> bool;
+
+    /// Like [SymbolTable::insert_new], except the symbol is renamed to avoid collisions.
+    ///
+    /// Returns the name of the symbol after insertion.
+    fn insert(&mut self, entry: SymbolRef, ip: Option<InsertionPoint>) -> SymbolName;
+
+    /// Remove the symbol `name`, and return the entry if one was present.
+    fn remove(&mut self, name: SymbolName) -> Option<SymbolRef>;
+
+    /// Renames the symbol named `from`, as `to`, as well as all uses of that symbol.
+    ///
+    /// Returns `Err` if unable to update all uses.
+    fn rename(&mut self, from: SymbolName, to: SymbolName) -> Result<(), Report>;
+}
+
+impl dyn SymbolTable {
+    /// Look up a symbol with the given name and concrete type, returning `None` if no such symbol
+    /// exists
+    pub fn find<T: Op>(&self, name: SymbolName) -> Option<UnsafeIntrusiveEntityRef<T>> {
+        let op = self.get(name)?;
+        let op = op.borrow();
+        let op = op.as_symbol_operation().downcast_ref::<T>()?;
+        Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(op) })
+    }
+}
+
+/// A [Symbol] is an IR entity with an associated _symbol_, or name, which is expected to be unique
+/// amongst all other symbols in the same namespace.
+///
+/// For example, functions are named, and are expected to be unique within the same module,
+/// otherwise it would not be possible to unambiguously refer to a function by name. Likewise
+/// with modules in a program, etc.
+pub trait Symbol: Usable + 'static {
+    fn as_symbol_operation(&self) -> &Operation;
+    fn as_symbol_operation_mut(&mut self) -> &mut Operation;
+    /// Get the name of this symbol
+    fn name(&self) -> SymbolName;
+    /// Get an iterator over the components of the fully-qualified path of this symbol.
+    fn components(&self) -> SymbolNameComponents {
+        let mut parts = VecDeque::default();
+        if let Some(symbol_table) = self.root_symbol_table() {
+            let symbol_table = symbol_table.borrow();
+            symbol_table.walk_symbol_tables(true, |symbol_table, _| {
+                if let Some(sym) = symbol_table.as_symbol_table_operation().as_symbol() {
+                    parts.push_back(sym.name().as_str());
+                }
+            });
+        }
+        SymbolNameComponents {
+            parts,
+            name: self.name(),
+            done: false,
+        }
+    }
+    /// Set the name of this symbol
+    fn set_name(&mut self, name: SymbolName);
+    /// Get the visibility of this symbol
+    fn visibility(&self) -> Visibility;
+    /// Returns true if this symbol has private visibility
+    #[inline]
+    fn is_private(&self) -> bool {
+        self.visibility().is_private()
+    }
+    /// Returns true if this symbol has public visibility
+    #[inline]
+    fn is_public(&self) -> bool {
+        self.visibility().is_public()
+    }
+    /// Sets the visibility of this symbol
+    fn set_visibility(&mut self, visibility: Visibility);
+    /// Sets the visibility of this symbol to private
+    fn set_private(&mut self) {
+        self.set_visibility(Visibility::Private);
+    }
+    /// Sets the visibility of this symbol to internal
+    fn set_internal(&mut self) {
+        self.set_visibility(Visibility::Internal);
+    }
+    /// Sets the visibility of this symbol to public
+    fn set_public(&mut self) {
+        self.set_visibility(Visibility::Public);
+    }
+    /// Get all of the uses of this symbol that are nested within `from`
+    fn symbol_uses(&self, from: OperationRef) -> SymbolUsesIter;
+    /// Return true if there are no uses of this symbol nested within `from`
+    fn symbol_uses_known_empty(&self, from: OperationRef) -> bool {
+        self.symbol_uses(from).is_empty()
+    }
+    /// Attempt to replace all uses of this symbol nested within `from`, with the provided
+    /// replacement
+    fn replace_all_uses(
+        &mut self,
+        replacement: SymbolRef,
+        from: OperationRef,
+    ) -> Result<(), Report>;
+    /// Returns true if this operation can be discarded if it has no remaining symbol uses
+    ///
+    /// By default, if the visibility is non-public, a symbol is considered discardable
+    fn can_discard_when_unused(&self) -> bool {
+        !self.is_public()
+    }
+    /// Returns true if this operation is a declaration, rather than a definition, of a symbol
+    ///
+    /// The default implementation assumes that all operations are definitions
+    fn is_declaration(&self) -> bool {
+        false
+    }
+    /// Return the root symbol table in which this symbol is contained, if one exists.
+    ///
+    /// The root symbol table does not necessarily know about this symbol, rather the symbol table
+    /// which "owns" this symbol may itself be a symbol that belongs to another symbol table. This
+    /// function traces this chain as far as it goes, and returns the highest ancestor in the tree.
+    fn root_symbol_table(&self) -> Option<OperationRef> {
+        self.as_symbol_operation().root_symbol_table()
+    }
+}
+
+impl dyn Symbol {
+    pub fn is<T: Op>(&self) -> bool {
+        let op = self.as_symbol_operation();
+        op.is::<T>()
+    }
+
+    pub fn downcast_ref<T: Op>(&self) -> Option<&T> {
+        let op = self.as_symbol_operation();
+        op.downcast_ref::<T>()
+    }
+
+    pub fn downcast_mut<T: Op>(&mut self) -> Option<&mut T> {
+        let op = self.as_symbol_operation_mut();
+        op.downcast_mut::<T>()
+    }
+
+    /// Get an [OperationRef] for the operation underlying this symbol
+    ///
+    /// NOTE: This relies on the assumption that all ops are allocated via the arena, and that all
+    /// [Symbol] implementations are ops.
+    pub fn as_operation_ref(&self) -> OperationRef {
+        self.as_symbol_operation().as_operation_ref()
+    }
+}
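As a usage sketch, a cleanup pass might combine these defaults to decide whether a symbol can be dropped (assumes `module` implements `SymbolTable` and `func` is a `SymbolRef` defined inside it):

```rust
// Sketch: a non-public symbol with no remaining uses in its enclosing
// symbol table can be erased.
let module_op = module.as_symbol_table_operation().as_operation_ref();
let sym = func.borrow();
if sym.can_discard_when_unused() && sym.symbol_uses_known_empty(module_op) {
    // safe to erase the op that defines `func`
}
```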
+
+impl Operation {
+    /// Returns true if this operation implements [Symbol]
+    #[inline]
+    pub fn is_symbol(&self) -> bool {
+        self.implements::<dyn Symbol>()
+    }
+
+    /// Get this operation as a [Symbol], if this operation implements the trait.
+    #[inline]
+    pub fn as_symbol(&self) -> Option<&dyn Symbol> {
+        self.as_trait::<dyn Symbol>()
+    }
+
+    /// Get this operation as a [SymbolTable], if this operation implements the trait.
+    #[inline]
+    pub fn as_symbol_table(&self) -> Option<&dyn SymbolTable> {
+        self.as_trait::<dyn SymbolTable>()
+    }
+
+    /// Return the root symbol table in which this symbol is contained, if one exists.
+    ///
+    /// The root symbol table does not necessarily know about this symbol, rather the symbol table
+    /// which "owns" this symbol may itself be a symbol that belongs to another symbol table. This
+    /// function traces this chain as far as it goes, and returns the highest ancestor in the tree.
+    pub fn root_symbol_table(&self) -> Option<OperationRef> {
+        let mut parent = self.nearest_symbol_table();
+        let mut found = None;
+        while let Some(nearest_symbol_table) = parent.take() {
+            found = Some(nearest_symbol_table.clone());
+            parent = nearest_symbol_table.borrow().nearest_symbol_table();
+        }
+        found
+    }
+
+    /// Returns the nearest [SymbolTable] from this operation.
+    ///
+    /// Returns `None` if no parent of this operation is a valid symbol table.
+    pub fn nearest_symbol_table(&self) -> Option<OperationRef> {
+        let mut parent = self.parent_op();
+        while let Some(parent_op) = parent.take() {
+            let op = parent_op.borrow();
+            if op.implements::<dyn SymbolTable>() {
+                drop(op);
+                return Some(parent_op);
+            }
+            parent = op.parent_op();
+        }
+        None
+    }
+
+    /// Returns the operation registered with the given symbol name within the closest symbol table
+    /// including `self`.
+    ///
+    /// Returns `None` if the symbol is not found.
+    pub fn nearest_symbol(&self, symbol: SymbolName) -> Option<SymbolRef> {
+        if let Some(sym) = self.as_symbol() {
+            if sym.name() == symbol {
+                return Some(unsafe { UnsafeIntrusiveEntityRef::from_raw(sym) });
+            }
+        }
+        let symbol_table_op = self.nearest_symbol_table()?;
+        let op = symbol_table_op.borrow();
+        let symbol_table = op.as_trait::<dyn SymbolTable>().unwrap();
+        symbol_table.get(symbol)
+    }
+
+    /// Walks all symbol table operations nested within this operation, including itself.
+    ///
+    /// For each symbol table operation, the provided callback is invoked with the op and a boolean
+    /// signifying if the symbols within that symbol table can be treated as if all uses within the
+    /// IR are visible to the caller.
+    pub fn walk_symbol_tables<F>(&self, all_symbol_uses_visible: bool, mut callback: F)
+    where
+        F: FnMut(&dyn SymbolTable, bool),
+    {
+        self.prewalk(|op: OperationRef| {
+            let op = op.borrow();
+            if let Some(sym) = op.as_symbol_table() {
+                callback(sym, all_symbol_uses_visible);
+            }
+        });
+    }
+}
+
+pub type SymbolRef = UnsafeIntrusiveEntityRef<dyn Symbol>;
+
+impl<T> crate::Verify<dyn Symbol> for T
+where
+    T: Op + Symbol,
+{
+    fn verify(&self, context: &super::Context) -> Result<(), Report> {
+        verify_symbol(self, context)
+    }
+}
+
+impl crate::Verify<dyn Symbol> for Operation {
+    fn should_verify(&self, _context: &super::Context) -> bool {
+        self.implements::<dyn Symbol>()
+    }
+
+    fn verify(&self, context: &super::Context) -> Result<(), Report> {
+        verify_symbol(
+            self.as_trait::<dyn Symbol>()
+                .expect("this operation does not implement the `Symbol` trait"),
+            context,
+        )
+    }
+}
+
+fn verify_symbol(symbol: &dyn Symbol, context: &super::Context) -> Result<(), Report> {
+    use midenc_session::diagnostics::{Severity, Spanned};
+
+    // Symbols must either have no parent, or be an immediate child of a SymbolTable
+    let op = symbol.as_symbol_operation();
+    let parent = op.parent_op();
+    if !parent.is_none_or(|parent| parent.borrow().implements::<dyn SymbolTable>()) {
+        return Err(context
+            .session
+            .diagnostics
+            .diagnostic(Severity::Error)
+            .with_message("invalid operation")
+            .with_primary_label(op.span(), "expected parent of this operation to be a symbol table")
+            .with_help("required due to this operation implementing the 'Symbol' trait")
+            .into_report());
+    }
+    Ok(())
+}
+
+pub type SymbolUseRef = UnsafeIntrusiveEntityRef<SymbolUse>;
+pub type SymbolUseList = crate::EntityList<SymbolUse>;
+pub type SymbolUseIter<'a> = crate::EntityIter<'a, SymbolUse>;
+pub type SymbolUseCursor<'a> = crate::EntityCursor<'a, SymbolUse>;
+pub type SymbolUseCursorMut<'a> = crate::EntityCursorMut<'a, SymbolUse>;
+
+pub struct SymbolUsesIter {
+    items: VecDeque<SymbolUseRef>,
+}
+impl ExactSizeIterator for SymbolUsesIter {
+    #[inline(always)]
+    fn len(&self) -> usize {
+        self.items.len()
+    }
+}
+impl From<VecDeque<SymbolUseRef>> for SymbolUsesIter {
+    fn from(items: VecDeque<SymbolUseRef>) -> Self {
+        Self { items }
+    }
+}
+impl FromIterator<SymbolUseRef> for SymbolUsesIter {
+    fn from_iter<T: IntoIterator<Item = SymbolUseRef>>(iter: T) -> Self {
+        Self {
+            items: iter.into_iter().collect(),
+        }
+    }
+}
+impl core::iter::FusedIterator for SymbolUsesIter {}
+impl Iterator for SymbolUsesIter {
+    type Item = SymbolUseRef;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.items.pop_front()
+    }
+}
+
+/// A [SymbolUse] represents a use of a [Symbol] by an [Operation]
+pub struct SymbolUse {
+    /// The user of the symbol
+    pub owner: OperationRef,
+    /// The symbol attribute of the op that stores the symbol
+    pub symbol: crate::interner::Symbol,
+}
+impl SymbolUse {
+    #[inline]
+    pub fn new(owner: OperationRef, symbol: crate::interner::Symbol) -> Self {
+        Self { owner, symbol }
+    }
+}
+impl fmt::Debug for SymbolUse {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let op = self.owner.borrow();
+        let value = op.get_typed_attribute::<SymbolNameAttr>(self.symbol);
+        f.debug_struct("SymbolUse")
+            .field("attr", &self.symbol)
+            .field("symbol", &value)
+            .finish_non_exhaustive()
+    }
+}
+
+/// Generate a unique symbol name.
+///
+/// Iteratively increase `counter` and use it as a suffix for symbol names until `is_unique` does
+/// not detect any conflict.
+pub fn generate_symbol_name<F>(name: SymbolName, counter: &mut usize, is_unique: F) -> SymbolName
+where
+    F: Fn(&str) -> bool,
+{
+    use core::fmt::Write;
+
+    use crate::SmallStr;
+
+    if is_unique(name.as_str()) {
+        return name;
+    }
+
+    let base_len = name.as_str().len();
+    let mut buf = SmallStr::with_capacity(base_len + 2);
+    buf.push_str(name.as_str());
+    loop {
+        *counter += 1;
+        buf.truncate(base_len);
+        buf.push('_');
+        write!(&mut buf, "{counter}").unwrap();
+
+        if is_unique(buf.as_str()) {
+            break SymbolName::intern(buf);
+        }
+    }
+}
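A usage sketch for `generate_symbol_name`, probing an existing symbol table for collisions (assumes `module` implements `SymbolTable`):

```rust
// Sketch: derive a collision-free name from "foo". The first probe returns
// "foo" unchanged if it is free; otherwise "foo_1", "foo_2", ... are tried.
let mut counter = 0;
let fresh = generate_symbol_name(SymbolName::intern("foo"), &mut counter, |candidate| {
    module.get(SymbolName::intern(candidate)).is_none()
});
```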
diff --git a/hir2/src/ir/traits.rs b/hir2/src/ir/traits.rs
new file mode 100644
index 000000000..18920e1d2
--- /dev/null
+++ b/hir2/src/ir/traits.rs
@@ -0,0 +1,303 @@
+mod foldable;
+mod info;
+mod types;
+
+use midenc_session::diagnostics::Severity;
+
+pub(crate) use self::info::TraitInfo;
+pub use self::{
+    foldable::{FoldResult, Foldable, OpFoldResult},
+    types::*,
+};
+use crate::{derive, AttributeValue, Context, Operation, Report, Spanned};
+
+/// Marker trait for commutative ops, e.g. `X op Y == Y op X`
+pub trait Commutative {}
+
+/// Marker trait for constant-like ops
+pub trait ConstantLike {}
+
+/// Marker trait for ops with side effects
+pub trait HasSideEffects {}
+
+/// Marker trait for ops with recursive memory effects, i.e. the effects of the operation include
+/// the effects of operations nested within its regions. If the operation does not implement any
+/// effect markers, e.g. `MemoryWrite`, then it can be assumed to have no memory effects itself.
+pub trait HasRecursiveMemoryEffects {}
+
+/// Marker trait for ops which allocate memory
+pub trait MemoryAlloc {}
+
+/// Marker trait for ops which free memory
+pub trait MemoryFree {}
+
+/// Marker trait for ops which read memory
+pub trait MemoryRead {}
+
+/// Marker trait for ops which write memory
+pub trait MemoryWrite {}
+
+/// Marker trait for return-like ops
+pub trait ReturnLike {}
+
+/// Op is a terminator (i.e. it can be used to terminate a block)
+pub trait Terminator {}
+
+/// Op's regions do not require blocks to end with a [Terminator]
+pub trait NoTerminator {}
+
+/// Marker trait for idempotent ops, i.e. `op op X == op X` (unary) / `X op X == X` (binary)
+pub trait Idempotent {}
+
+/// Marker trait for ops that exhibit the property `op op X == X`
+pub trait Involution {}
+
+/// Marker trait for ops which are not permitted to access values defined above them
+pub trait IsolatedFromAbove {}
+
+/// Marker trait for ops which have only regions of [`RegionKind::Graph`]
+pub trait HasOnlyGraphRegion {}
+
+/// Op's regions are all single-block graph regions that do not require a terminator
+///
+/// This trait _cannot_ be derived via `derive!`
+pub trait GraphRegionNoTerminator:
+    NoTerminator + SingleBlock + crate::RegionKindInterface + HasOnlyGraphRegion
+{
+}
+
+// TODO(pauls): Implement verifier
+/// This interface provides information for branching terminator operations, i.e. terminator
+/// operations with successors.
+///
+/// This interface is meant to model well-defined cases of control-flow and value propagation,
+/// where what occurs along control-flow edges is assumed to be side-effect free. For example,
+/// corresponding successor operands and successor block arguments may have different types. In
+/// such cases, `are_types_compatible` can be implemented to compare types along control-flow
+/// edges. By default, type equality is used.
+pub trait BranchOpInterface: crate::Op {
+    /// Returns the operands that correspond to the arguments of the successor at `index`.
+    ///
+    /// It consists of a number of operands that are internally produced by the operation, followed
+    /// by a range of operands that are forwarded. An example operation making use of produced
+    /// operands would be:
+    ///
+    /// ```hir,ignore
+    /// invoke %function(%0)
+    ///     label ^success ^error(%1 : i32)
+    ///
+    /// ^error(%e: !error, %arg0: i32):
+    ///     ...
+    /// ```
+    ///
+    /// The operand that would map to the `^error`'s `%e` operand is produced by the `invoke`
+    /// operation, while `%1` is a forwarded operand that maps to `%arg0` in the successor.
+    ///
+    /// Produced operands always map to the first few block arguments of the successor, followed by
+    /// the forwarded operands. Mapping them in any other order is not supported by the interface.
+    ///
+    /// Having the forwarded operands last allows users of the interface to append more forwarded
+    /// operands to the branch operation without interfering with other successor operands.
+    fn get_successor_operands(&self, index: usize) -> crate::SuccessorOperandRange<'_> {
+        let op = <Self as crate::Op>::as_operation(self);
+        let operand_group = op.successors()[index].operand_group as usize;
+        crate::SuccessorOperandRange::forward(op.operands().group(operand_group))
+    }
+    /// The mutable version of [Self::get_successor_operands].
+    fn get_successor_operands_mut(&mut self, index: usize) -> crate::SuccessorOperandRangeMut<'_> {
+        let op = <Self as crate::Op>::as_operation_mut(self);
+        let operand_group = op.successors()[index].operand_group as usize;
+        crate::SuccessorOperandRangeMut::forward(op.operands_mut().group_mut(operand_group))
+    }
+    /// Returns the block argument of the successor corresponding to the operand at
+    /// `operand_index`.
+    ///
+    /// Returns `None` if the specified operand is not a successor operand.
+    fn get_successor_block_argument(
+        &self,
+        operand_index: usize,
+    ) -> Option<crate::BlockArgumentRef> {
+        let op = <Self as crate::Op>::as_operation(self);
+        let operand_groups = op.operands().num_groups();
+        let mut next_index = 0usize;
+        for operand_group in 0..operand_groups {
+            let group_size = op.operands().group(operand_group).len();
+            if (next_index..(next_index + group_size)).contains(&operand_index) {
+                let arg_index = operand_index - next_index;
+                // We found the operand group, now map that to a successor
+                let succ_info =
+                    op.successors().iter().find(|s| operand_group == s.operand_group as usize)?;
+                return succ_info.block.borrow().block.borrow().arguments().get(arg_index).cloned();
+            }
+
+            next_index += group_size;
+        }
+
+        None
+    }
+    /// Returns the successor that would be chosen with the given constant operands.
+    ///
+    /// Returns `None` if a single successor could not be chosen.
+    #[inline]
+    #[allow(unused_variables)]
+    fn get_successor_for_operands(
+        &self,
+        operands: &[Box<dyn AttributeValue>],
+    ) -> Option<crate::BlockRef> {
+        None
+    }
+    /// This is called to compare types along control-flow edges.
+    ///
+    /// By default, types must be exactly equal to be compatible.
+    fn are_types_compatible(&self, lhs: &crate::Type, rhs: &crate::Type) -> bool {
+        lhs == rhs
+    }
+}
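As a sketch of the one non-defaulted extension point, a two-way conditional branch could override `get_successor_for_operands` so a constant condition selects a single successor statically. The `CondBr` op and the `as_bool` helper are hypothetical stand-ins, not APIs defined by this patch:

```rust
impl BranchOpInterface for CondBr {
    // Successor 0 is taken on `true`, successor 1 on `false`; with a known
    // constant condition the branch target is statically known.
    fn get_successor_for_operands(
        &self,
        operands: &[Box<dyn AttributeValue>],
    ) -> Option<crate::BlockRef> {
        // `as_bool` stands in for however a constant boolean is extracted
        // from an `AttributeValue` in practice.
        let taken = operands.first()?.as_bool()?;
        let succ_index = if taken { 0 } else { 1 };
        let op = <Self as crate::Op>::as_operation(self);
        Some(op.successors()[succ_index].block.borrow().block.clone())
    }
}
```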
+ fn are_types_compatible(&self, lhs: &crate::Type, rhs: &crate::Type) -> bool { + lhs == rhs + } +} + +/// This interface provides information for select-like operations, i.e., operations that forward +/// specific operands to the output, depending on a binary condition. +/// +/// If the value of the condition is 1, then the `true` operand is returned, and the third operand +/// is ignored, even if it was poison. +/// +/// If the value of the condition is 0, then the `false` operand is returned, and the second operand +/// is ignored, even if it was poison. +/// +/// If the condition is poison, then poison is returned. +/// +/// Implementing operations can also accept shaped conditions, in which case the operation works +/// element-wise. +pub trait SelectLikeOpInterface { + /// Returns the operand that represents the boolean condition for this select-like op. + fn get_condition(&self) -> crate::ValueRef; + /// Returns the operand that would be chosen for a true condition. + fn get_true_value(&self) -> crate::ValueRef; + /// Returns the operand that would be chosen for a false condition. + fn get_false_value(&self) -> crate::ValueRef; +} + +derive! { + /// Marker trait for unary ops, i.e. those which take a single operand + pub trait UnaryOp {} + + verify { + fn is_unary_op(op: &Operation, context: &Context) -> Result<(), Report> { + if op.num_operands() == 1 { + Ok(()) + } else { + Err( + context.session.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 1, got {}", op.num_operands())) + .with_help("this operator implements 'UnaryOp', which requires it to have exactly one operand") + .into_report() + ) + } + } + } +} + +derive! { + /// Marker trait for binary ops, i.e. those which take two operands + pub trait BinaryOp {} + + verify { + fn is_binary_op(op: &Operation, context: &Context) -> Result<(), Report> { + if op.num_operands() == 2 { + Ok(()) + } else { + Err( + context.session.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label(op.span(), format!("incorrect number of operands, expected 2, got {}", op.num_operands())) + .with_help("this operator implements 'BinaryOp', which requires it to have exactly two operands") + .into_report() + ) + } + } + } +} + +derive! { + /// Op's regions have no arguments + pub trait NoRegionArguments {} + + verify { + fn no_region_arguments(op: &Operation, context: &Context) -> Result<(), Report> { + for region in op.regions().iter() { + if region.entry().has_arguments() { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation does not permit regions with arguments, but one was found" + ) + .into_report()); + } + } + + Ok(()) + } + } +} + +derive! { + /// Op's regions have a single block + pub trait SingleBlock {} + + verify { + fn has_only_single_block_regions(op: &Operation, context: &Context) -> Result<(), Report> { + for region in op.regions().iter() { + if region.body().iter().count() > 1 { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation requires single-block regions, but regions with multiple \ + blocks were found", + ) + .into_report()); + } + } + + Ok(()) + } + } +} + +// pub trait SingleBlockImplicitTerminator {} + +derive! 
{ + /// Op has a single region + pub trait SingleRegion {} + + verify { + fn has_exactly_one_region(op: &Operation, context: &Context) -> Result<(), Report> { + let num_regions = op.num_regions(); + if num_regions != 1 { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + format!("this operation requires exactly one region, but got {num_regions}") + ) + .into_report()); + } + + Ok(()) + } + } +} + +// pub trait HasParent {} +// pub trait ParentOneOf<(T,...)> {} diff --git a/hir2/src/ir/traits/foldable.rs b/hir2/src/ir/traits/foldable.rs new file mode 100644 index 000000000..cef1a08b1 --- /dev/null +++ b/hir2/src/ir/traits/foldable.rs @@ -0,0 +1,126 @@ +use smallvec::SmallVec; + +use crate::{AttributeValue, ValueRef}; + +/// Represents the outcome of an attempt to fold an operation. +#[must_use] +pub enum FoldResult { + /// The operation was folded and erased, and the given fold results were returned + Ok(T), + /// The operation was modified in-place, but not erased. + InPlace, + /// The operation could not be folded + Failed, +} +impl FoldResult { + /// Returns true if folding was successful + #[inline] + pub fn is_ok(&self) -> bool { + matches!(self, Self::Ok(_) | Self::InPlace) + } + + /// Returns true if folding was unsuccessful + #[inline] + pub fn is_failed(&self) -> bool { + matches!(self, Self::Failed) + } + + /// Convert this result to an `Option` representing a successful outcome, where `None` indicates + /// an in-place fold, and `Some(T)` indicates that the operation was folded away. + /// + /// Panics with the given message if the fold attempt failed. + #[inline] + #[track_caller] + pub fn expect(self, message: &'static str) -> Option { + match self { + Self::Ok(out) => Some(out), + Self::InPlace => None, + Self::Failed => unwrap_failed_fold_result(message), + } + } +} + +#[cold] +#[track_caller] +#[inline(never)] +fn unwrap_failed_fold_result(message: &'static str) -> ! { + panic!("tried to unwrap failed fold result as successful: {message}") +} + +/// Represents a single result value of a folded operation. +#[derive(Debug)] +pub enum OpFoldResult { + /// The value is constant + Attribute(Box), + /// The value is a non-constant SSA value + Value(ValueRef), +} +impl OpFoldResult { + #[inline] + pub fn is_constant(&self) -> bool { + matches!(self, Self::Attribute(_)) + } +} +impl Eq for OpFoldResult {} +impl PartialEq for OpFoldResult { + fn eq(&self, other: &Self) -> bool { + use core::hash::{Hash, Hasher}; + + match (self, other) { + (Self::Attribute(lhs), Self::Attribute(rhs)) => { + if lhs.as_any().type_id() != rhs.as_any().type_id() { + return false; + } + let lhs_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + lhs.hash(&mut hasher); + hasher.finish() + }; + let rhs_hash = { + let mut hasher = rustc_hash::FxHasher::default(); + rhs.hash(&mut hasher); + hasher.finish() + }; + lhs_hash == rhs_hash + } + (Self::Value(lhs), Self::Value(rhs)) => ValueRef::ptr_eq(lhs, rhs), + _ => false, + } + } +} +impl core::fmt::Display for OpFoldResult { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Attribute(ref attr) => attr.pretty_print(f), + Self::Value(ref value) => write!(f, "{}", value.borrow().id()), + } + } +} + +/// An operation that can be constant-folded must implement the folding logic via this trait. +/// +/// NOTE: Any `ConstantLike` operation must implement this trait as a no-op, i.e. 
returning the +/// value of the constant directly, as this is used by the pattern matching infrastructure to +/// extract the value of constant operations without knowing anything about the specific op. +pub trait Foldable { + /// Attempt to fold this operation using its current operand values. + /// + /// If folding was successful and the operation should be erased, `results` will contain the + /// folded results. See [FoldResult] for more details on what the various outcomes of folding + /// are. + fn fold(&self, results: &mut SmallVec<[OpFoldResult; 1]>) -> FoldResult; + + /// Attempt to fold this operation with the specified operand values. + /// + /// The elements in `operands` will correspond 1:1 with the operands of the operation, but will + /// be `None` if the value is non-constant. + /// + /// If folding was successful and the operation should be erased, `results` will contain the + /// folded results. See [FoldResult] for more details on what the various outcomes of folding + /// are. + fn fold_with( + &self, + operands: &[Option>], + results: &mut SmallVec<[OpFoldResult; 1]>, + ) -> FoldResult; +} diff --git a/hir2/src/ir/traits/info.rs b/hir2/src/ir/traits/info.rs new file mode 100644 index 000000000..1a4e0754f --- /dev/null +++ b/hir2/src/ir/traits/info.rs @@ -0,0 +1,83 @@ +use core::{ + any::{Any, TypeId}, + marker::Unsize, + ptr::{null, DynMetadata, Pointee}, +}; + +pub struct TraitInfo { + /// The [TypeId] of the trait type, used as a unique key for [TraitImpl]s + type_id: TypeId, + /// Type-erased dyn metadata containing the trait vtable pointer for the concrete type + /// + /// This is transmuted to the correct trait type when reifying a `&dyn Trait` reference, + /// which is safe as `DynMetadata` is always the same size for all types. + metadata: DynMetadata, +} +impl TraitInfo { + pub fn new() -> Self + where + T: Any + Unsize + crate::verifier::Verifier, + Trait: ?Sized + Pointee> + 'static, + { + let type_id = TypeId::of::(); + let ptr = null::() as *const Trait; + let (_, metadata) = ptr.to_raw_parts(); + Self { + type_id, + metadata: unsafe { + core::mem::transmute::, DynMetadata>(metadata) + }, + } + } + + #[inline(always)] + pub const fn type_id(&self) -> &TypeId { + &self.type_id + } + + /// Obtain the dyn metadata for `Trait` from this info. + /// + /// # Safety + /// + /// This is highly unsafe - you must guarantee that `Trait` is the same type as the one used + /// to create this `TraitInfo` instance. In debug mode, errors like this will be caught, but + /// in release builds, no checks are performed, and absolute havoc will result if you use this + /// incorrectly. + /// + /// It is intended _only_ for use by generated code which has all of the type information + /// available to it statically. It must be public so that operations can be defined in other + /// crates. 
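(Aside: to make the `Foldable` contract above concrete before returning to `TraitInfo`, here is a hedged sketch of a folding driver consuming a fold attempt. The rewrite plumbing a real driver would perform is only described in comments, and `try_fold` is a hypothetical helper, not part of this diff.)

```rust
use smallvec::SmallVec;

// Hedged sketch: `F: Foldable` stands in for a concrete op type.
fn try_fold<F: Foldable>(op: &F) {
    let mut results = SmallVec::<[OpFoldResult; 1]>::new();
    match op.fold(&mut results) {
        FoldResult::Ok(_) => {
            // Folded away: `results` holds one OpFoldResult per op result;
            // a driver would now replace each result's uses and erase the op.
        }
        FoldResult::InPlace => {
            // The op updated itself in place; it stays in the IR.
        }
        FoldResult::Failed => {
            // Not foldable with the current operands; leave the op as-is.
        }
    }
}
```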
+ pub unsafe fn metadata_unchecked<Trait>(&self) -> DynMetadata<Trait> + where + Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static, + { + debug_assert!(self.type_id == TypeId::of::<Trait>()); + core::mem::transmute(self.metadata) + } +} +impl Eq for TraitInfo {} +impl PartialEq for TraitInfo { + fn eq(&self, other: &Self) -> bool { + self.type_id == other.type_id + } +} +impl PartialEq<TypeId> for TraitInfo { + fn eq(&self, other: &TypeId) -> bool { + self.type_id.eq(other) + } +} +impl PartialOrd for TraitInfo { + fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { + Some(self.type_id.cmp(&other.type_id)) + } +} +impl PartialOrd<TypeId> for TraitInfo { + fn partial_cmp(&self, other: &TypeId) -> Option<core::cmp::Ordering> { + Some(self.type_id.cmp(other)) + } +} +impl Ord for TraitInfo { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.type_id.cmp(&other.type_id) + } +} diff --git a/hir2/src/ir/traits/types.rs b/hir2/src/ir/traits/types.rs new file mode 100644 index 000000000..493b0a71e --- /dev/null +++ b/hir2/src/ir/traits/types.rs @@ -0,0 +1,512 @@ +use core::fmt; + +use midenc_session::diagnostics::Severity; + +use crate::{derive, Context, Op, Operation, Report, Spanned, Type}; + +/// OpInterface to compute the return type(s) of an operation. +pub trait InferTypeOpInterface: Op { + /// Run type inference for this op's results, using the current state, and apply any changes. + /// + /// Returns an error if unable to infer types, or if some type constraint is violated. + fn infer_return_types(&mut self, context: &Context) -> Result<(), Report>; + + /// Returns whether the two sets of types are compatible + fn are_compatible_return_types(&self, lhs: &[Type], rhs: &[Type]) -> bool { + lhs == rhs + } +} + +derive! { + /// Op expects all operands to be of the same type + pub trait SameTypeOperands {} + + verify { + fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> { + let mut operands = op.operands().iter(); + if let Some(first_operand) = operands.next() { + let (expected_ty, set_by) = { + let operand = first_operand.borrow(); + let value = operand.value(); + (value.ty().clone(), value.span()) + }; + + for operand in operands { + let operand = operand.borrow(); + let value = operand.value(); + let value_ty = value.ty(); + if value_ty != &expected_ty { + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operation") + .with_primary_label( + op.span(), + "this operation expects all operands to be of the same type" + ) + .with_secondary_label( + set_by, + "inferred the expected type from this value" + ) + .with_secondary_label( + value.span(), + "which differs from this value" + ) + .with_help(format!("expected '{expected_ty}', got '{value_ty}'")) + .into_report() + ); + } + } + } + + Ok(()) + } + } +} + +derive! { + /// Op expects all operands and results to be of the same type + /// + /// TODO(pauls): Implement verification for this. Ideally we could require `SameTypeOperands` + /// as a super trait, check the operands using its implementation, and then check the results + /// separately + pub trait SameOperandsAndResultType {} +} + +/// An operation trait that indicates it expects a variable number of operands, matching the given +/// type constraint, i.e. zero or more of the base type.
+pub trait Variadic {} + +impl crate::Verify> for T +where + T: crate::Op + Variadic, +{ + fn verify(&self, context: &Context) -> Result<(), Report> { + self.as_operation().verify(context) + } +} +impl crate::Verify> for Operation { + fn should_verify(&self, _context: &Context) -> bool { + self.implements::>() + } + + fn verify(&self, context: &Context) -> Result<(), Report> { + for operand in self.operands().iter() { + let operand = operand.borrow(); + let value = operand.value(); + let ty = value.ty(); + if ::matches(ty) { + continue; + } else { + let description = ::description(); + return Err(context + .session + .diagnostics + .diagnostic(Severity::Error) + .with_message("invalid operand") + .with_primary_label( + value.span(), + format!("expected operand type to be {description}, but got {ty}"), + ) + .into_report()); + } + } + + Ok(()) + } +} + +pub trait TypeConstraint: 'static { + fn description() -> impl fmt::Display; + fn matches(ty: &crate::Type) -> bool; +} + +/// A type that can be constructed as a [crate::Type] +pub trait BuildableTypeConstraint: TypeConstraint { + fn build() -> crate::Type; +} + +macro_rules! type_constraint { + ($Constraint:ident, $description:literal, $matcher:literal) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches(_ty: &$crate::Type) -> bool { + $matcher + } + } + }; + + ($Constraint:ident, $description:literal, $matcher:path) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches(ty: &$crate::Type) -> bool { + $matcher(ty) + } + } + }; + + ($Constraint:ident, $description:literal, |$matcher_input:ident| $matcher:expr) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct $Constraint; + impl TypeConstraint for $Constraint { + #[inline(always)] + fn description() -> impl core::fmt::Display { + $description + } + + #[inline(always)] + fn matches($matcher_input: &$crate::Type) -> bool { + $matcher + } + } + }; +} + +type_constraint!(AnyType, "any type", true); +// TODO(pauls): Extend Type with new Function variant, we'll use that to represent function handles +//type_constraint!(AnyFunction, "a function type", crate::Type::is_function); +type_constraint!(AnyList, "any list type", crate::Type::is_list); +type_constraint!(AnyArray, "any array type", crate::Type::is_array); +type_constraint!(AnyStruct, "any struct type", crate::Type::is_struct); +type_constraint!(AnyPointer, "a pointer type", crate::Type::is_pointer); +type_constraint!(AnyInteger, "an integral type", crate::Type::is_integer); +type_constraint!(AnyPointerOrInteger, "an integral or pointer type", |ty| ty.is_pointer() + || ty.is_integer()); +type_constraint!(AnySignedInteger, "a signed integral type", crate::Type::is_signed_integer); +type_constraint!( + AnyUnsignedInteger, + "an unsigned integral type", + crate::Type::is_unsigned_integer +); +type_constraint!(IntFelt, "a field element", crate::Type::is_felt); + +/// A boolean +pub type Bool = SizedInt<1>; + +/// A signless 8-bit integer +pub type Int8 = SizedInt<8>; +/// A signed 8-bit integer +pub type SInt8 = And>; +/// An unsigned 8-bit integer +pub type UInt8 = And>; + +/// A signless 16-bit integer +pub type Int16 = SizedInt<16>; +/// A signed 16-bit integer 
+pub type SInt16 = And<AnySignedInteger, SizedInt<16>>; +/// An unsigned 16-bit integer +pub type UInt16 = And<AnyUnsignedInteger, SizedInt<16>>; + +/// A signless 32-bit integer +pub type Int32 = SizedInt<32>; +/// A signed 32-bit integer +pub type SInt32 = And<AnySignedInteger, SizedInt<32>>; +/// An unsigned 32-bit integer +pub type UInt32 = And<AnyUnsignedInteger, SizedInt<32>>; + +/// A signless 64-bit integer +pub type Int64 = SizedInt<64>; +/// A signed 64-bit integer +pub type SInt64 = And<AnySignedInteger, SizedInt<64>>; +/// An unsigned 64-bit integer +pub type UInt64 = And<AnyUnsignedInteger, SizedInt<64>>; + +impl BuildableTypeConstraint for IntFelt { + fn build() -> crate::Type { + crate::Type::Felt + } +} +impl BuildableTypeConstraint for Bool { + fn build() -> crate::Type { + crate::Type::I1 + } +} +impl BuildableTypeConstraint for UInt8 { + fn build() -> crate::Type { + crate::Type::U8 + } +} +impl BuildableTypeConstraint for SInt8 { + fn build() -> crate::Type { + crate::Type::I8 + } +} +impl BuildableTypeConstraint for UInt16 { + fn build() -> crate::Type { + crate::Type::U16 + } +} +impl BuildableTypeConstraint for SInt16 { + fn build() -> crate::Type { + crate::Type::I16 + } +} +impl BuildableTypeConstraint for UInt32 { + fn build() -> crate::Type { + crate::Type::U32 + } +} +impl BuildableTypeConstraint for SInt32 { + fn build() -> crate::Type { + crate::Type::I32 + } +} +impl BuildableTypeConstraint for UInt64 { + fn build() -> crate::Type { + crate::Type::U64 + } +} +impl BuildableTypeConstraint for SInt64 { + fn build() -> crate::Type { + crate::Type::I64 + } +} + +/// Represents a fixed-width integer of `N` bits +pub struct SizedInt<const N: usize>(core::marker::PhantomData<[(); N]>); +impl<const N: usize> Copy for SizedInt<N> {} +impl<const N: usize> Clone for SizedInt<N> { + fn clone(&self) -> Self { + *self + } +} +impl<const N: usize> fmt::Debug for SizedInt<N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::<Self>()) + } +} +impl<const N: usize> fmt::Display for SizedInt<N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{N}-bit integral type") + } +} +impl<const N: usize> TypeConstraint for SizedInt<N> { + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + ty.is_integer() + } +} +impl BuildableTypeConstraint for SizedInt<8> { + fn build() -> crate::Type { + crate::Type::I8 + } +} +impl BuildableTypeConstraint for SizedInt<16> { + fn build() -> crate::Type { + crate::Type::I16 + } +} +impl BuildableTypeConstraint for SizedInt<32> { + fn build() -> crate::Type { + crate::Type::I32 + } +} +impl BuildableTypeConstraint for SizedInt<64> { + fn build() -> crate::Type { + crate::Type::I64 + } +} + +/// A type constraint for pointer values +pub struct PointerOf<T: TypeConstraint>(core::marker::PhantomData<T>); +impl<T: TypeConstraint> Copy for PointerOf<T> {} +impl<T: TypeConstraint> Clone for PointerOf<T> { + fn clone(&self) -> Self { + *self + } +} +impl<T: TypeConstraint> fmt::Debug for PointerOf<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::<Self>()) + } +} +impl<T: TypeConstraint> fmt::Display for PointerOf<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let pointee = <T as TypeConstraint>::description(); + write!(f, "a pointer to {pointee}") + } +} +impl<T: TypeConstraint> TypeConstraint for PointerOf<T> { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + ty.pointee().is_some_and(|pointee| <T as TypeConstraint>::matches(pointee)) + } +} +impl<T: BuildableTypeConstraint> BuildableTypeConstraint for PointerOf<T> { + fn build() -> crate::Type { + let pointee = Box::new(<T as BuildableTypeConstraint>::build()); + crate::Type::Ptr(pointee) + } +} + +/// A type constraint for array values +pub struct AnyArrayOf(core::marker::PhantomData); +impl Copy for AnyArrayOf {} +impl Clone for
AnyArrayOf { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for AnyArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display for AnyArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let element = ::description(); + write!(f, "an array of {element}") + } +} +impl TypeConstraint for AnyArrayOf { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + match ty { + crate::Type::Array(ref elem, _) => ::matches(elem), + _ => false, + } + } +} + +/// A type constraint for array values +pub struct ArrayOf(core::marker::PhantomData<[T; N]>); +impl Copy for ArrayOf {} +impl Clone for ArrayOf { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for ArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl fmt::Display for ArrayOf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let element = ::description(); + write!(f, "an array of {N} {element}") + } +} +impl TypeConstraint for ArrayOf { + #[inline(always)] + fn description() -> impl fmt::Display { + Self(core::marker::PhantomData) + } + + fn matches(ty: &crate::Type) -> bool { + match ty { + crate::Type::Array(ref elem, arity) if *arity == N => { + ::matches(elem) + } + _ => false, + } + } +} +impl BuildableTypeConstraint for ArrayOf { + fn build() -> crate::Type { + let element = Box::new(::build()); + crate::Type::Array(element, N) + } +} + +/// Represents a conjunction of two constraints as a concrete value +pub struct And { + _left: core::marker::PhantomData, + _right: core::marker::PhantomData, +} +impl Copy for And {} +impl Clone for And { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for And { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl TypeConstraint for And { + fn description() -> impl fmt::Display { + struct Both { + left: L, + right: R, + } + impl fmt::Display for Both { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "both {} and {}", &self.left, &self.right) + } + } + let left = ::description(); + let right = ::description(); + Both { left, right } + } + + #[inline] + fn matches(ty: &crate::Type) -> bool { + ::matches(ty) && ::matches(ty) + } +} + +/// Represents a disjunction of two constraints as a concrete value +pub struct Or { + _left: core::marker::PhantomData, + _right: core::marker::PhantomData, +} +impl Copy for Or {} +impl Clone for Or { + fn clone(&self) -> Self { + *self + } +} +impl fmt::Debug for Or { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(core::any::type_name::()) + } +} +impl TypeConstraint for Or { + fn description() -> impl fmt::Display { + struct Either { + left: L, + right: R, + } + impl fmt::Display for Either { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "either {} or {}", &self.left, &self.right) + } + } + let left = ::description(); + let right = ::description(); + Either { left, right } + } + + #[inline] + fn matches(ty: &crate::Type) -> bool { + ::matches(ty) || ::matches(ty) + } +} diff --git a/hir2/src/ir/types.rs b/hir2/src/ir/types.rs new file mode 100644 index 000000000..aba5bea0a --- /dev/null +++ b/hir2/src/ir/types.rs @@ -0,0 +1 @@ +pub use midenc_hir_type::*; diff --git a/hir2/src/ir/usable.rs b/hir2/src/ir/usable.rs new file mode 100644 
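(Aside: composition is the intended usage pattern for the `TypeConstraint` combinators defined above. A hedged sketch, written as it might appear inside the crate and assuming the generic parameters reconstructed above; `PtrToSInt32` and both helper functions are hypothetical, not part of this diff.)

```rust
// "A pointer to a signed 32-bit integer", assembled from the combinators.
type PtrToSInt32 = PointerOf<SInt32>;

fn is_ptr_to_i32(ty: &crate::Type) -> bool {
    // `matches` recurses structurally: pointer -> And(signed, 32-bit).
    <PtrToSInt32 as TypeConstraint>::matches(ty)
}

fn build_ptr_to_i32() -> crate::Type {
    // Buildable constraints can also construct a canonical representative type.
    <PtrToSInt32 as BuildableTypeConstraint>::build()
}
```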
index 000000000..315c32e10 --- /dev/null +++ b/hir2/src/ir/usable.rs @@ -0,0 +1,57 @@ +use super::{ + entity::EntityIter, EntityCursor, EntityCursorMut, EntityList, UnsafeIntrusiveEntityRef, +}; + +/// The [Usable] trait is implemented for IR entities which are _defined_ and _used_, and as a +/// result, require a data structure called the _use-def list_. +/// +/// A _definition_ of an IR entity, is a unique instantiation of that entity, the result of which +/// is different from all other definitions, even if the data associated with that definition is +/// the same as another definition. For example, SSA values are defined as either block arguments +/// or operation results, and a given value can only be defined once. +/// +/// A _use_ represents a unique reference to a _definition_ of some IR entity. Each use is unique, +/// and can be used to obtain not only the _user_ of the reference, but the location of that use +/// within the user. Uses are tracked in a _use list_, also called the _use-def list_, which +/// associates all uses to the definition, or _def_, that they reference. For example, operations +/// in HIR _use_ SSA values defined previously in the program. +/// +/// A _user_ does not have to be of the same IR type as the _definition_, and the type representing +/// the _use_ is typically different than both, and represents the type of relationship between the +/// two. For example, an `OpOperand` represents a single use of a `Value` by an `Op`. The entity +/// being defined is a `Value`, the entity using that definition is an `Op`, and the data associated +/// with each use is represented by `OpOperand`. +pub trait Usable { + /// The type associated with each unique use, e.g. `OpOperand` + type Use; + + /// Get a list of uses of this definition + fn uses(&self) -> &EntityList; + /// Get a mutable list of uses of this definition + fn uses_mut(&mut self) -> &mut EntityList; + + /// Returns true if this definition is used + #[inline] + fn is_used(&self) -> bool { + !self.uses().is_empty() + } + /// Get an iterator over the uses of this definition + #[inline] + fn iter_uses(&self) -> EntityIter<'_, Self::Use> { + self.uses().iter() + } + /// Get a cursor positioned on the first use of this definition, or the null cursor if unused. + fn first_use(&self) -> EntityCursor<'_, Self::Use> { + self.uses().front() + } + /// Get a mutable cursor positioned on the first use of this definition, or the null cursor if + /// unused. 
+ #[inline] + fn first_use_mut(&mut self) -> EntityCursorMut<'_, Self::Use> { + self.uses_mut().front_mut() + } + /// Add `user` to the set of uses of this definition + fn insert_use(&mut self, user: UnsafeIntrusiveEntityRef) { + self.uses_mut().push_back(user); + } +} diff --git a/hir2/src/ir/value.rs b/hir2/src/ir/value.rs new file mode 100644 index 000000000..2596f9f42 --- /dev/null +++ b/hir2/src/ir/value.rs @@ -0,0 +1,406 @@ +use core::{any::Any, fmt}; + +use super::*; + +/// A unique identifier for a [Value] in the IR +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct ValueId(u32); +impl ValueId { + pub const fn from_u32(id: u32) -> Self { + Self(id) + } + + pub const fn as_u32(&self) -> u32 { + self.0 + } +} +impl EntityId for ValueId { + #[inline(always)] + fn as_usize(&self) -> usize { + self.0 as usize + } +} +impl fmt::Debug for ValueId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}", &self.0) + } +} +impl fmt::Display for ValueId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}", &self.0) + } +} + +/// Represents an SSA value in the IR. +/// +/// The data underlying a [Value] represents a _definition_, and thus implements [Usable]. The users +/// of a [Value] are operands (see [OpOperandImpl]). Operands are associated with an operation. Thus +/// the graph formed of the edges between values and operations via operands forms the data-flow +/// graph of the program. +pub trait Value: + EntityWithId + Spanned + Usable + fmt::Debug +{ + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + /// Set the source location of this value + fn set_span(&mut self, span: SourceSpan); + /// Get the type of this value + fn ty(&self) -> &Type; + /// Set the type of this value + fn set_type(&mut self, ty: Type); + /// Get the defining operation for this value, _if_ defined by an operation. + /// + /// Returns `None` if this value is defined by other means than an operation result. 
+ fn get_defining_op(&self) -> Option; + /// Get the region which contains the definition of this value + fn parent_region(&self) -> Option { + self.parent_block().and_then(|block| block.borrow().parent()) + } + /// Get the block which contains the definition of this value + fn parent_block(&self) -> Option; + /// Returns true if this value is used outside of the given block + fn is_used_outside_of_block(&self, block: &BlockRef) -> bool { + self.iter_uses().any(|user| { + user.owner.borrow().parent().is_some_and(|blk| !BlockRef::ptr_eq(&blk, block)) + }) + } + /// Replace all uses of `self` with `replacement` + fn replace_all_uses_with(&mut self, mut replacement: ValueRef) { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` + { + let mut user = user.borrow_mut(); + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } + /// Replace all uses of `self` with `replacement` unless the user is in `exceptions` + fn replace_all_uses_except(&mut self, mut replacement: ValueRef, exceptions: &[OperationRef]) { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` if user not in `exceptions` + { + let mut user = user.borrow_mut(); + if exceptions.contains(&user.owner) { + cursor.move_next(); + continue; + } + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } +} + +impl dyn Value { + #[inline] + pub fn is(&self) -> bool { + self.as_any().is::() + } + + #[inline] + pub fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() + } + + #[inline] + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.as_any_mut().downcast_mut::() + } + + /// Replace all uses of `self` with `replacement` if `should_replace` returns true + pub fn replace_uses_with_if(&mut self, mut replacement: ValueRef, should_replace: F) + where + F: Fn(&OpOperandImpl) -> bool, + { + let mut cursor = self.uses_mut().front_mut(); + while let Some(mut user) = cursor.as_pointer() { + // Rewrite use of `self` with `replacement` if `should_replace` returns true + { + let mut user = user.borrow_mut(); + if !should_replace(&user) { + cursor.move_next(); + continue; + } + user.value = replacement.clone(); + } + // Remove `user` from the use list of `self` + cursor.remove(); + // Add `user` to the use list of `replacement` + replacement.borrow_mut().insert_use(user); + } + } +} + +/// Generates the boilerplate for a concrete [Value] type. +macro_rules! 
value_impl { + ( + $(#[$outer:meta])* + $vis:vis struct $ValueKind:ident { + $(#[doc $($owner_doc_args:tt)*])* + owner: $OwnerTy:ty, + $(#[doc $($index_doc_args:tt)*])* + index: u8, + $( + $(#[$inner:ident $($args:tt)*])* + $Field:ident: $FieldTy:ty, + )* + } + + fn get_defining_op(&$GetDefiningOpSelf:ident) -> Option $GetDefiningOp:block + + fn parent_block(&$ParentBlockSelf:ident) -> Option $ParentBlock:block + + $($t:tt)* + ) => { + $(#[$outer])* + #[derive(Spanned)] + $vis struct $ValueKind { + id: ValueId, + #[span] + span: SourceSpan, + ty: Type, + uses: OpOperandList, + owner: $OwnerTy, + index: u8, + $( + $(#[$inner $($args)*])* + $Field: $FieldTy + ),* + } + + impl $ValueKind { + pub fn new( + span: SourceSpan, + id: ValueId, + ty: Type, + owner: $OwnerTy, + index: u8, + $( + $Field: $FieldTy + ),* + ) -> Self { + Self { + id, + ty, + span, + uses: Default::default(), + owner, + index, + $( + $Field + ),* + } + } + + $(#[doc $($owner_doc_args)*])* + pub fn owner(&self) -> $OwnerTy { + self.owner.clone() + } + + $(#[doc $($index_doc_args)*])* + pub fn index(&self) -> usize { + self.index as usize + } + } + + impl Value for $ValueKind { + #[inline(always)] + fn as_any(&self) -> &dyn Any { + self + } + #[inline(always)] + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + fn ty(&self) -> &Type { + &self.ty + } + + fn set_span(&mut self, span: SourceSpan) { + self.span = span; + } + + fn set_type(&mut self, ty: Type) { + self.ty = ty; + } + + fn get_defining_op(&$GetDefiningOpSelf) -> Option $GetDefiningOp + + fn parent_block(&$ParentBlockSelf) -> Option $ParentBlock + } + + impl Entity for $ValueKind {} + impl EntityWithId for $ValueKind { + type Id = ValueId; + + #[inline(always)] + fn id(&self) -> Self::Id { + self.id + } + } + + impl Usable for $ValueKind { + type Use = OpOperandImpl; + + #[inline(always)] + fn uses(&self) -> &OpOperandList { + &self.uses + } + + #[inline(always)] + fn uses_mut(&mut self) -> &mut OpOperandList { + &mut self.uses + } + } + + impl fmt::Display for $ValueKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use crate::formatter::PrettyPrint; + + self.pretty_print(f) + } + } + + impl fmt::Debug for $ValueKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut builder = f.debug_struct(stringify!($ValueKind)); + builder + .field("id", &self.id) + .field("ty", &self.ty) + .field("index", &self.index) + .field("is_used", &(!self.uses.is_empty())); + + $( + builder.field(stringify!($Field), &self.$Field); + )* + + builder.finish_non_exhaustive() + } + } + + $($t)* + } +} + +/// A pointer to a [Value] +pub type ValueRef = UnsafeEntityRef; +/// A pointer to a [BlockArgument] +pub type BlockArgumentRef = UnsafeEntityRef; +/// A pointer to a [OpResult] +pub type OpResultRef = UnsafeEntityRef; + +value_impl!( + /// A [BlockArgument] represents the definition of a [Value] by a block parameter + pub struct BlockArgument { + /// Get the [Block] to which this [BlockArgument] belongs + owner: BlockRef, + /// Get the index of this argument in the argument list of the owning [Block] + index: u8, + } + + fn get_defining_op(&self) -> Option { + None + } + + fn parent_block(&self) -> Option { + Some(self.owner.clone()) + } +); + +impl BlockArgument { + #[inline] + pub fn as_value_ref(&self) -> ValueRef { + self.as_block_argument_ref().upcast() + } + + #[inline] + pub fn as_block_argument_ref(&self) -> BlockArgumentRef { + unsafe { BlockArgumentRef::from_raw(self) } + } +} + +impl crate::formatter::PrettyPrint for BlockArgument { + fn 
render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + display(self.id) + const_text(": ") + self.ty.render() + } +} + +impl StorableEntity for BlockArgument { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many block arguments"); + } +} + +value_impl!( + /// An [OpResult] represents the definition of a [Value] by the result of an [Operation] + pub struct OpResult { + /// Get the [Operation] to which this [OpResult] belongs + owner: OperationRef, + /// Get the index of this result in the result list of the owning [Operation] + index: u8, + } + + fn get_defining_op(&self) -> Option { + Some(self.owner.clone()) + } + + fn parent_block(&self) -> Option { + self.owner.borrow().parent() + } +); + +impl OpResult { + #[inline] + pub fn as_value_ref(&self) -> ValueRef { + unsafe { ValueRef::from_raw(self as &dyn Value) } + } +} + +impl crate::formatter::PrettyPrint for OpResult { + #[inline] + fn render(&self) -> crate::formatter::Document { + use crate::formatter::*; + + display(self.id) + } +} + +impl StorableEntity for OpResult { + #[inline(always)] + fn index(&self) -> usize { + self.index as usize + } + + unsafe fn set_index(&mut self, index: usize) { + self.index = index.try_into().expect("too many op results"); + } + + /// Unlink all users of this result + /// + /// The users will still refer to this result, but the use list of this value will be empty + fn unlink(&mut self) { + let uses = self.uses_mut(); + uses.clear(); + } +} + +pub type OpResultStorage = crate::EntityStorage<OpResultRef, 1>; +pub type OpResultRange<'a> = crate::EntityRange<'a, OpResultRef>; +pub type OpResultRangeMut<'a> = crate::EntityRangeMut<'a, OpResultRef, 1>; diff --git a/hir2/src/ir/verifier.rs b/hir2/src/ir/verifier.rs new file mode 100644 index 000000000..9b5c2bcfe --- /dev/null +++ b/hir2/src/ir/verifier.rs @@ -0,0 +1,262 @@ +use super::{Context, Report}; + +/// The `OpVerifier` trait is expected to be implemented by all [Op] impls as a prerequisite. +/// +/// The actual implementation is typically generated as part of deriving [Op]. +pub trait OpVerifier { + fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// The `Verify` trait represents verification logic associated with implementations of some trait. +/// +/// This is specifically used for automatically deriving verification checks for [Op] impls that +/// implement traits that imply constraints on the representation or behavior of that op. +/// +/// For example, if some [Op] derives an op trait like `SingleBlock`, this information is recorded +/// in the underlying [Operation] metadata, so that we can recover a trait object reference for the +/// trait when needed. However, just deriving the trait is not sufficient to guarantee that the op +/// actually adheres to the implicit constraints and behavior of that trait. For example, +/// `SingleBlock` implies that the implementing op contains only regions that consist of a single +/// [Block]. This cannot be checked statically. The first step to addressing this, though, is to +/// reify the implicit validation rules as explicit checks - hence this trait.
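(Aside: a hedged sketch may help ground the machinery described in the surrounding docs. The following is roughly the shape of `Verify` impl that gets derived for a trait like `UnaryOp`; the parameterization of `Verify` by the verified trait is assumed from its usage elsewhere in this diff, and the real generated code reports through the session diagnostics rather than `Report::msg`.)

```rust
// Hand-written equivalent of the derive!-generated UnaryOp verifier (sketch).
impl<T: crate::Op + UnaryOp> Verify<dyn UnaryOp> for T {
    fn verify(&self, _context: &Context) -> Result<(), Report> {
        let op = self.as_operation();
        if op.num_operands() == 1 {
            Ok(())
        } else {
            Err(Report::msg(format!(
                "invalid operation: expected 1 operand, got {}",
                op.num_operands()
            )))
        }
    }
}
```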
+/// +/// So we've established that some op traits, such as `SingleBlock` mentioned above, have implicit +/// validation rules, and we can implement [Verify] to make the implicit validation rules of such +/// traits explicit - but how do we ensure that when an op derives an op trait, the [Verify] +/// impl is also derived, _and_ that it is called when the op is verified? +/// +/// The answer lies in the use of some tricky type-level code to accomplish the following goals: +/// +/// * Do not emit useless checks for op traits that have no verification rules +/// * Do not require storing data in each instance of an [Op] just to verify a trait +/// * Do not require emitting a bunch of redundant type checks for information we know statically +/// * Be able to automatically derive all of the verification machinery along with the op traits +/// +/// The way this works is as follows: +/// +/// * We `impl Verify for T where T: Op` for every trait `Trait` with validation rules. +/// * A blanket impl of [HasVerifier] exists for all `T: Verify`. This is a marker trait used +/// in conjunction with specialization. See the trait docs for more details on its purpose. +/// * The [Verifier] trait provides a default vacuous impl for all `Trait` and `T` pairs. However, +/// we also provide a specialized [Verifier] impl for all `T: Verify` using the +/// `HasVerifier` marker. The specialized impl applies the underlying `Verify` impl. +/// * When deriving the op traits for an `Op` impl, we generate a hidden type that encodes all of +/// the op traits implemented by the op. We then generate an `OpVerifier` impl for the op, which +/// uses the hidden type we generated to reify the `Verifier` impl for each trait. The +/// `OpVerifier` implementation uses const eval to strip out all of the vacuous verifier impls, +/// leaving behind just the "real" verification rules specific to the traits implemented by that +/// op. +/// * The `OpVerifier` impl is object-safe, and is in fact a required super-trait of `Op` to ensure +/// that verification is part of defining an `Op`, but also to ensure that `verify` is a method +/// of `Op`, and that we can cast an `Operation` to `&dyn OpVerifier` and call `verify` on that. +/// +/// As a result of all this, we end up with highly-specialized verifiers for each op, with no +/// dynamic dispatch, and automatically maintained as part of the `Op` definition. When a new +/// op trait is derived, the verifier for the op is automatically updated to verify the new trait. +pub trait Verify { + /// In cases where verification may be disabled via runtime configuration, or based on + /// dynamic properties of the type, this method can be overridden and used to signal to + /// the verification driver that verification should be skipped on this item. + #[inline(always)] + #[allow(unused_variables)] + fn should_verify(&self, context: &Context) -> bool { + true + } + /// Apply this verifier, but only if [Verify::should_verify] returns true. + #[inline] + fn maybe_verify(&self, context: &Context) -> Result<(), Report> { + if self.should_verify(context) { + self.verify(context) + } else { + Ok(()) + } + } + /// Apply this verifier to the current item. + fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// A marker trait used for verifier specialization.
+/// +/// # Safety +/// +/// In order for the `#[rustc_unsafe_specialization_marker]` attribute to be used safely and +/// correctly, the following rules must hold: +/// +/// * No associated items +/// * No impls with lifetime constraints, as specialization will ignore them +/// +/// For our use case, which is specializing verification for a given type and trait combination, +/// by optimizing out verification-related code for type combinations which have no verifier, these +/// are easy rules to uphold. +/// +/// However, we must ensure that we continue to uphold these rules moving forward. +#[rustc_unsafe_specialization_marker] +unsafe trait HasVerifier: Verify {} + +// While at first glance, it appears we would be using this to specialize on the fact that a type +// _has_ a verifier, which is strictly-speaking true, the actual goal we're aiming to achieve is +// to be able to identify the _absence_ of a verifier, so that we can eliminate the boilerplate for +// verifying that trait. See `Verifier` for more information. +unsafe impl HasVerifier for T where T: Verify {} + +/// The `Verifier` trait is used to derive a verifier for a given trait and concrete type. +/// +/// It does this by providing a default implementation for all combinations of `Trait` and `T`, +/// which always succeeds, and then specializing that implementation for `T: HasVerifier`. +/// +/// This has the effect of making all traits "verifiable", but only actually doing any verification +/// for types which implement `Verify`. +/// +/// We go a step further and actually set things up so that `rustc` can eliminate all of the dead +/// code when verification is vacuous. This is done by using const eval in the hidden type generated +/// for an [Op] impl's [OpVerifier] implementation, which will wrap verification in a const-evaluated +/// check of the `VACUOUS` associated const. It can also be used directly, but the general idea +/// behind all of this is that we don't need to directly touch any of this, it's all generated. +/// +/// NOTE: Because this trait provides a default blanket impl for all `T`, you should avoid bringing +/// it into scope unless absolutely needed. It is virtually always preferred to explicitly invoke +/// this trait using turbofish syntax, so as to avoid conflict with the [Verify] trait, and to +/// avoid polluting the namespace for all types in scope. +pub trait Verifier { + /// An implementation of `Verifier` sets this flag to true when its implementation is vacuous, + /// i.e. it always succeeds and is not dependent on runtime context. + /// + /// The default implementation of this trait sets this to `true`, since without a verifier for + /// the type, verification always succeeds. However, we can specialize on the presence of + /// a verifier and set this to `false`, which will result in all of the verification logic + /// being applied. + /// + /// ## Example Usage + /// + /// Shown below is an example of how one can use const eval to eliminate dead code branches + /// in verifier selection, so that the resulting implementation is specialized and able to + /// have more optimizations applied as a result. + /// + /// ```rust + /// use midenc_hir2::{Context, Report, Verify, verifier::Verifier}; + /// + /// /// This trait is a marker for a type that should never validate, i.e. verifying it always + /// /// returns an error.
+ /// trait Nope {} + /// impl Verify for Any { + /// fn verify(&self, context: &Context) -> Result<(), Report> { + /// Err(Report::msg("nope")) + /// } + /// } + /// + /// /// We can't impl the `Verify` trait for all T outside of `midenc_hir2`, so we newtype all T + /// /// to do so, in effect, we can mimic the effect of implementing for all T using this type. + /// struct Any(core::marker::PhantomData); + /// impl Any { + /// fn new() -> Self { + /// Self(core::marker::PhantomData) + /// } + /// } + /// + /// /// This struct implements `Nope`, so it has an explicit verifier, which always fails + /// struct AlwaysRejected; + /// impl Nope for AlwaysRejected {} + /// + /// /// This struct doesn't implement `Nope`, so it gets a vacuous verifier, which always + /// /// succeeds. + /// struct AlwaysAccepted; + /// + /// /// Our vacuous verifier impl + /// #[inline(always)] + /// fn noop(_: &Any, _: &Context) -> Result<(), Report> { Ok(()) } + /// + /// /// This block uses const-eval to select the verifier for Any statically + /// let always_accepted = const { + /// if as Verifier>::VACUOUS { + /// noop + /// } else { + /// as Verifier>::maybe_verify + /// } + /// }; + /// + /// /// This block uses const-eval to select the verifier for Any statically + /// let always_rejected = const { + /// if as Verifier>::VACUOUS { + /// noop + /// } else { + /// as Verifier>::maybe_verify + /// } + /// }; + /// + /// /// Verify that we got the correct impls. We can't verify that all of the abstraction was + /// /// eliminated, but from reviewing the assembly output, it appears that this is precisely + /// /// what happens. + /// let context = Context::default(); + /// assert!(always_accepted(&Any::new(), &context).is_ok()); + /// assert!(always_rejected(&Any::new(), &context).is_err()); + /// ``` + const VACUOUS: bool; + + /// Checks if this verifier is applicable for the current item + fn should_verify(&self, context: &Context) -> bool; + /// Applies the verifier for this item, if [Verifier::should_verify] returns `true` + fn maybe_verify(&self, context: &Context) -> Result<(), Report>; + /// Applies the verifier for this item + fn verify(&self, context: &Context) -> Result<(), Report>; +} + +/// The default blanket impl for all types and traits +impl Verifier for T { + default const VACUOUS: bool = true; + + #[inline(always)] + default fn should_verify(&self, _context: &Context) -> bool { + false + } + + #[inline(always)] + default fn maybe_verify(&self, _context: &Context) -> Result<(), Report> { + Ok(()) + } + + #[inline(always)] + default fn verify(&self, _context: &Context) -> Result<(), Report> { + Ok(()) + } +} + +/// THe specialized impl for types which implement `Verify` +impl Verifier for T +where + T: HasVerifier, +{ + const VACUOUS: bool = false; + + #[inline] + fn should_verify(&self, context: &Context) -> bool { + >::should_verify(self, context) + } + + #[inline(always)] + fn maybe_verify(&self, context: &Context) -> Result<(), Report> { + >::maybe_verify(self, context) + } + + #[inline] + fn verify(&self, context: &Context) -> Result<(), Report> { + >::verify(self, context) + } +} + +#[cfg(test)] +mod tests { + use core::hint::black_box; + + use super::*; + use crate::{traits::SingleBlock, Operation}; + + struct Vacuous; + + /// In this test, we're validating that a type that trivially verifies specializes as vacuous, + /// and that a type we know has a "real" verifier, specializes as _not_ vacuous + #[test] + fn verifier_specialization_concrete() { + assert!(black_box(>::VACUOUS)); + 
assert!(black_box(!<Operation as Verifier<dyn SingleBlock>>::VACUOUS)); + } +} diff --git a/hir2/src/ir/visit.rs b/hir2/src/ir/visit.rs new file mode 100644 index 000000000..9f36b1c7b --- /dev/null +++ b/hir2/src/ir/visit.rs @@ -0,0 +1,102 @@ +mod searcher; +mod visitor; +mod walkable; + +pub use core::ops::ControlFlow; + +pub use self::{ + searcher::Searcher, + visitor::{OpVisitor, OperationVisitor, SymbolVisitor, Visitor}, + walkable::{WalkOrder, WalkStage, Walkable}, +}; +use crate::Report; + +/// A result-like type used to control traversals of a [Walkable] entity. +/// +/// It is comparable to [core::ops::ControlFlow], with an additional option to continue the +/// traversal with the next sibling, rather than visiting any further children of the current item. +/// +/// It is compatible with the `?` operator; note however that `?` will exit early on _both_ `Skip` +/// and `Break`, so you should be aware of that when using the try operator. +#[derive(Clone)] +#[must_use] +pub enum WalkResult<B, C = ()> { + /// Continue the traversal normally, optionally producing a value for the current item. + Continue(C), + /// Skip traversing the current item and any children that have not been visited yet, and + /// continue the traversal with the next sibling of the current item. + Skip, + /// Stop the traversal entirely, optionally returning a value associated with the break. + /// + /// This can be used to represent both errors and the successful result of a search. + Break(B), +} +impl<B, C> WalkResult<B, C> { + /// Returns true if the walk should continue + pub fn should_continue(&self) -> bool { + matches!(self, Self::Continue(_)) + } + + /// Returns true if the walk was interrupted + pub fn was_interrupted(&self) -> bool { + matches!(self, Self::Break(_)) + } + + /// Returns true if the walk was skipped + pub fn was_skipped(&self) -> bool { + matches!(self, Self::Skip) + } +} +impl<B, C> WalkResult<B, C> { + /// Convert this [WalkResult] into an equivalent [Result] + #[inline] + pub fn into_result(self) -> Result<(), B> { + match self { + Self::Break(err) => Err(err), + Self::Skip | Self::Continue(_) => Ok(()), + } + } +} +impl<B> From<Result<(), B>> for WalkResult<B> { + fn from(value: Result<(), B>) -> Self { + match value { + Ok(_) => WalkResult::Continue(()), + Err(err) => WalkResult::Break(err), + } + } +} +impl<B, C> From<WalkResult<B, C>> for Result<(), B> { + #[inline(always)] + fn from(value: WalkResult<B, C>) -> Self { + value.into_result() + } +} +impl<B, C> core::ops::FromResidual for WalkResult<B, C> { + fn from_residual(residual: <Self as core::ops::Try>::Residual) -> Self { + match residual { + WalkResult::Break(b) => WalkResult::Break(b), + WalkResult::Skip => WalkResult::Skip, + _ => unreachable!(), + } + } +} +impl<B, C> core::ops::Residual<C> for WalkResult<B, core::convert::Infallible> { + type TryType = WalkResult<B, C>; +} +impl<B, C> core::ops::Try for WalkResult<B, C> { + type Output = C; + type Residual = WalkResult<B, core::convert::Infallible>; + + #[inline] + fn from_output(output: Self::Output) -> Self { + WalkResult::Continue(output) + } + + #[inline] + fn branch(self) -> ControlFlow<Self::Residual, Self::Output> { + match self { + WalkResult::Continue(c) => ControlFlow::Continue(c), + WalkResult::Skip => ControlFlow::Break(WalkResult::Skip), + WalkResult::Break(b) => ControlFlow::Break(WalkResult::Break(b)), + } + } +} diff --git a/hir2/src/ir/visit/searcher.rs b/hir2/src/ir/visit/searcher.rs new file mode 100644 index 000000000..138c369b6 --- /dev/null +++ b/hir2/src/ir/visit/searcher.rs @@ -0,0 +1,61 @@ +use super::{OpVisitor, OperationVisitor, SymbolVisitor, Visitor, WalkResult, Walkable}; +use crate::{Op, Operation, OperationRef, Symbol}; + +/// [Searcher] is a driver for [Visitor] impls as applied to some root [Operation].
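(Aside: as a quick illustration of the `WalkResult`/`Result` interplay defined above, here is a hedged sketch that verifies every operation under a root, stopping at the first failure. It leans on the `From<Result>` impl and `into_result`, plus `Operation::verify` from the verifier module; `verify_all` is a hypothetical helper, not part of this diff.)

```rust
fn verify_all(root: &Operation, context: &Context) -> Result<(), Report> {
    root.prewalk_interruptible(|op: OperationRef| {
        // Err is mapped to WalkResult::Break, interrupting the walk.
        WalkResult::from(op.borrow().verify(context))
    })
    .into_result()
}
```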
+/// +/// The searcher traverses the object graph in depth-first preorder, from operations to regions to +/// blocks to operations, etc. All nested items of an entity are visited before its siblings, i.e. +/// a region is fully visited before the next region of the same containing operation. +/// +/// This is effectively control-flow order, from an abstract interpretation perspective, i.e. an +/// actual program might only execute one region of a multi-region op, but this traversal will visit +/// all of them unless otherwise directed by a `WalkResult`. +pub struct Searcher { + visitor: V, + root: OperationRef, + _marker: core::marker::PhantomData, +} +impl> Searcher { + pub fn new(root: OperationRef, visitor: V) -> Self { + Self { + visitor, + root, + _marker: core::marker::PhantomData, + } + } +} + +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + self.visitor.visit(&op) + }) + } +} + +impl> Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(op) = op.downcast_ref::() { + self.visitor.visit(op) + } else { + WalkResult::Continue(()) + } + }) + } +} + +impl Searcher { + pub fn visit(&mut self) -> WalkResult<>::Output> { + self.root.borrow().prewalk_interruptible(|op: OperationRef| { + let op = op.borrow(); + if let Some(sym) = op.as_symbol() { + self.visitor.visit(sym) + } else { + WalkResult::Continue(()) + } + }) + } +} diff --git a/hir2/src/ir/visit/visitor.rs b/hir2/src/ir/visit/visitor.rs new file mode 100644 index 000000000..4a061d734 --- /dev/null +++ b/hir2/src/ir/visit/visitor.rs @@ -0,0 +1,36 @@ +use super::WalkResult; +use crate::{Op, Operation, Symbol}; + +/// A generic trait that describes visitors for all kinds +pub trait Visitor { + /// The type of output produced by visiting an item. + type Output; + + /// The function which is applied to each `T` as it is visited. + fn visit(&mut self, current: &T) -> WalkResult; +} + +/// We can automatically convert any closure of appropriate type to a `Visitor` +impl Visitor for F +where + F: FnMut(&T) -> WalkResult, +{ + type Output = U; + + #[inline] + fn visit(&mut self, op: &T) -> WalkResult { + self(op) + } +} + +/// Represents a visitor over [Operation] +pub trait OperationVisitor: Visitor {} +impl OperationVisitor for V where V: Visitor {} + +/// Represents a visitor over [Op] of type `T` +pub trait OpVisitor: Visitor {} +impl OpVisitor for V where V: Visitor {} + +/// Represents a visitor over [Symbol] +pub trait SymbolVisitor: Visitor {} +impl SymbolVisitor for V where V: Visitor {} diff --git a/hir2/src/ir/visit/walkable.rs b/hir2/src/ir/visit/walkable.rs new file mode 100644 index 000000000..a523a4c17 --- /dev/null +++ b/hir2/src/ir/visit/walkable.rs @@ -0,0 +1,404 @@ +use super::WalkResult; +use crate::{ + Block, BlockRef, Operation, OperationRef, Region, RegionRef, UnsafeIntrusiveEntityRef, +}; + +/// The traversal order for a walk of a region, block, or operation +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum WalkOrder { + PreOrder, + PostOrder, +} + +/// Encodes the current walk stage for generic walkers. 
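(Aside: since closures implement `Visitor`, a `Searcher` can drive ad-hoc queries without a dedicated visitor type. A hedged sketch; `count_constants` is hypothetical, and the `implements` check on `Operation` is borrowed from its use earlier in this diff.)

```rust
fn count_constants(root: OperationRef) -> usize {
    let mut constants = 0usize;
    let mut searcher = Searcher::new(root, |op: &Operation| {
        // Closures of type FnMut(&Operation) -> WalkResult<_> are Visitors.
        if op.implements::<dyn ConstantLike>() {
            constants += 1;
        }
        WalkResult::<()>::Continue(())
    });
    let _ = searcher.visit();
    constants
}
```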
+/// +/// When walking an operation, we can either choose a pre- or post-traversal walker which invokes +/// the callback on an operation before/after all its attached regions have been visited, or choose +/// a generic walker where the callback is invoked on the operation N+1 times, where N is the number +/// of regions attached to that operation. [WalkStage] encodes the current stage of the walk, i.e. +/// which regions have already been visited, and the callback accepts an additional argument for +/// the current stage. Such generic walkers that accept stage-aware callbacks are only applicable +/// when the callback operates on an operation (i.e. this doesn't apply to callbacks on blocks or +/// regions). +#[derive(Clone, PartialEq, Eq)] +pub struct WalkStage { + /// The number of regions in the operation + num_regions: usize, + /// The next region to visit in the operation + next_region: usize, +} +impl WalkStage { + pub fn new(op: OperationRef) -> Self { + let op = op.borrow(); + Self { + num_regions: op.num_regions(), + next_region: 0, + } + } + + /// Returns true if the parent operation is being visited before all regions. + #[inline] + pub fn is_before_all_regions(&self) -> bool { + self.next_region == 0 + } + + /// Returns true if the parent operation is being visited just before visiting `region` + #[inline] + pub fn is_before_region(&self, region: usize) -> bool { + self.next_region == region + } + + /// Returns true if the parent operation is being visited just after visiting `region` + #[inline] + pub fn is_after_region(&self, region: usize) -> bool { + self.next_region == region + 1 + } + + /// Returns true if the parent operation is being visited after all regions. + #[inline] + pub fn is_after_all_regions(&self) -> bool { + self.next_region == self.num_regions + } + + /// Advance the walk stage + #[inline] + pub fn advance(&mut self) { + self.next_region += 1; + } + + /// Returns the next region that will be visited + #[inline(always)] + pub const fn next_region(&self) -> usize { + self.next_region + } +} + +/// A [Walkable] is an entity which can be traversed depth-first in either pre- or post-order +/// +/// An implementation of this trait specifies a type, `T`, corresponding to the type of item being +/// walked, while `Self` is the root entity, possibly of the same type, which may contain `T`. Thus +/// traversing from the root to all of the leaves, we will visit all reachable `T` nested within +/// `Self`, possibly including itself. +pub trait Walkable<T> { + /// Walk all `T` in `self` in a specific order, applying the given callback to each. + /// + /// This is very similar to [Walkable::walk_interruptible], except the callback has no control + /// over the traversal, and must be infallible. + #[inline] + fn walk<F>(&self, order: WalkOrder, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef<T>), + { + let _ = self.walk_interruptible(order, |t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + } + + /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback + /// to each `T`. + #[inline] + fn prewalk<F>(&self, mut callback: F) + where + F: FnMut(UnsafeIntrusiveEntityRef<T>), + { + let _ = self.prewalk_interruptible(|t| { + callback(t); + + WalkResult::<()>::Continue(()) + }); + }
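(Aside: the `WalkStage` bookkeeping above is easiest to see for a hypothetical two-region operation; `demo_stage` is a hypothetical helper, not part of this diff.)

```rust
fn demo_stage(op: OperationRef) {
    // Assumes `op` has exactly two regions.
    let mut stage = WalkStage::new(op);
    assert!(stage.is_before_all_regions());
    stage.advance(); // region 0 has now been visited
    assert!(stage.is_after_region(0) && stage.is_before_region(1));
    stage.advance(); // region 1 has now been visited
    assert!(stage.is_after_all_regions());
}
```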
+/// A [Walkable] is an entity which can be traversed depth-first in either pre- or post-order
+///
+/// An implementation of this trait specifies a type, `T`, corresponding to the type of item being
+/// walked, while `Self` is the root entity, possibly of the same type, which may contain `T`. Thus,
+/// traversing from the root to all of the leaves, we will visit all reachable `T` nested within
+/// `Self`, possibly including itself.
+pub trait Walkable<T> {
+    /// Walk all `T` in `self` in a specific order, applying the given callback to each.
+    ///
+    /// This is very similar to [Walkable::walk_interruptible], except the callback has no control
+    /// over the traversal, and must be infallible.
+    #[inline]
+    fn walk<F>(&self, order: WalkOrder, mut callback: F)
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>),
+    {
+        let _ = self.walk_interruptible(order, |t| {
+            callback(t);
+
+            WalkResult::<()>::Continue(())
+        });
+    }
+
+    /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback
+    /// to each `T`.
+    #[inline]
+    fn prewalk<F>(&self, mut callback: F)
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>),
+    {
+        let _ = self.prewalk_interruptible(|t| {
+            callback(t);
+
+            WalkResult::<()>::Continue(())
+        });
+    }
+
+    /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback
+    /// to each `T`.
+    #[inline]
+    fn postwalk<F>(&self, mut callback: F)
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>),
+    {
+        let _ = self.postwalk_interruptible(|t| {
+            callback(t);
+
+            WalkResult::<()>::Continue(())
+        });
+    }
+
+    /// Walk `self` in the given order, visiting each `T` and applying the given callback to them.
+    ///
+    /// The given callback can control the traversal using the [WalkResult] it returns:
+    ///
+    /// * `WalkResult::Skip` will skip the walk of the current item and its nested elements that
+    ///   have not been visited already, continuing with the next item.
+    /// * `WalkResult::Break` will interrupt the walk, and no more items will be visited
+    /// * `WalkResult::Continue` will continue the walk
+    #[inline]
+    fn walk_interruptible<F, B>(&self, order: WalkOrder, callback: F) -> WalkResult<B>
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>) -> WalkResult<B>,
+    {
+        match order {
+            WalkOrder::PreOrder => self.prewalk_interruptible(callback),
+            WalkOrder::PostOrder => self.postwalk_interruptible(callback),
+        }
+    }
+
+    /// Walk all `T` in `self` using a pre-order, depth-first traversal, applying the given callback
+    /// to each `T`, and determining how to proceed based on the returned [WalkResult].
+    fn prewalk_interruptible<F, B>(&self, callback: F) -> WalkResult<B>
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>) -> WalkResult<B>;
+
+    /// Walk all `T` in `self` using a post-order, depth-first traversal, applying the given callback
+    /// to each `T`, and determining how to proceed based on the returned [WalkResult].
+    fn postwalk_interruptible<F, B>(&self, callback: F) -> WalkResult<B>
+    where
+        F: FnMut(UnsafeIntrusiveEntityRef<T>) -> WalkResult<B>;
+}
+
+/// Walking operations nested within an [Operation], including itself
+impl Walkable<Operation> for Operation {
+    fn prewalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(OperationRef) -> WalkResult<B>,
+    {
+        prewalk_operation_interruptible(self, &mut callback)
+    }
+
+    fn postwalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(OperationRef) -> WalkResult<B>,
+    {
+        postwalk_operation_interruptible(self, &mut callback)
+    }
+}
+
+fn prewalk_operation_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(OperationRef) -> WalkResult<B>,
+{
+    let result = callback(op.as_operation_ref());
+    if !result.should_continue() {
+        return result;
+    }
+
+    for region in op.regions().iter() {
+        for block in region.body().iter() {
+            let mut ops = block.body().front();
+            while let Some(op) = ops.as_pointer() {
+                ops.move_next();
+                let op = op.borrow();
+                prewalk_operation_interruptible(&op, callback)?;
+            }
+        }
+    }
+
+    WalkResult::Continue(())
+}
+
+fn postwalk_operation_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(OperationRef) -> WalkResult<B>,
+{
+    for region in op.regions().iter() {
+        for block in region.body().iter() {
+            let mut ops = block.body().front();
+            while let Some(op) = ops.as_pointer() {
+                ops.move_next();
+                let op = op.borrow();
+                postwalk_operation_interruptible(&op, callback)?;
+            }
+        }
+    }
+
+    callback(op.as_operation_ref())
+}
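As a usage sketch of the interruptible pre-order walk defined above (the `find_first` helper and its predicate are illustrative, not part of this change):

// Hypothetical: return the first operation satisfying `is_match`, using `Break`
// to stop the entire traversal as soon as it is found.
fn find_first(
    root: &Operation,
    is_match: impl Fn(&Operation) -> bool,
) -> Option<OperationRef> {
    match root.prewalk_interruptible(|op: OperationRef| {
        if is_match(&op.borrow()) {
            WalkResult::Break(op)
        } else {
            WalkResult::Continue(())
        }
    }) {
        WalkResult::Break(found) => Some(found),
        _ => None,
    }
}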
+/// Walking regions of an [Operation], and those of all nested operations
+impl Walkable<Region> for Operation {
+    fn prewalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(RegionRef) -> WalkResult<B>,
+    {
+        prewalk_regions_interruptible(self, &mut callback)
+    }
+
+    fn postwalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(RegionRef) -> WalkResult<B>,
+    {
+        postwalk_regions_interruptible(self, &mut callback)
+    }
+}
+
+fn prewalk_regions_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(RegionRef) -> WalkResult<B>,
+{
+    let mut regions = op.regions().front();
+    while let Some(region) = regions.as_pointer() {
+        regions.move_next();
+        match callback(region.clone()) {
+            WalkResult::Continue(_) => {
+                let region = region.borrow();
+                for block in region.body().iter() {
+                    for op in block.body().iter() {
+                        prewalk_regions_interruptible(&op, callback)?;
+                    }
+                }
+            }
+            WalkResult::Skip => continue,
+            result @ WalkResult::Break(_) => return result,
+        }
+    }
+
+    WalkResult::Continue(())
+}
+
+fn postwalk_regions_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(RegionRef) -> WalkResult<B>,
+{
+    let mut regions = op.regions().front();
+    while let Some(region) = regions.as_pointer() {
+        regions.move_next();
+        {
+            let region = region.borrow();
+            for block in region.body().iter() {
+                for op in block.body().iter() {
+                    postwalk_regions_interruptible(&op, callback)?;
+                }
+            }
+        }
+        callback(region)?;
+    }
+
+    WalkResult::Continue(())
+}
+
+/// Walking operations nested within a [Region]
+impl Walkable<Operation> for Region {
+    fn prewalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(OperationRef) -> WalkResult<B>,
+    {
+        prewalk_region_operations_interruptible(self, &mut callback)
+    }
+
+    fn postwalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(OperationRef) -> WalkResult<B>,
+    {
+        postwalk_region_operations_interruptible(self, &mut callback)
+    }
+}
+
+fn prewalk_region_operations_interruptible<F, B>(region: &Region, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(OperationRef) -> WalkResult<B>,
+{
+    for block in region.body().iter() {
+        let mut cursor = block.body().front();
+        while let Some(op) = cursor.as_pointer() {
+            cursor.move_next();
+            match callback(op.clone()) {
+                WalkResult::Continue(_) => {
+                    let op = op.borrow();
+                    for region in op.regions() {
+                        prewalk_region_operations_interruptible(&region, callback)?;
+                    }
+                }
+                WalkResult::Skip => continue,
+                result @ WalkResult::Break(_) => return result,
+            }
+        }
+    }
+
+    WalkResult::Continue(())
+}
+
+fn postwalk_region_operations_interruptible<F, B>(
+    region: &Region,
+    callback: &mut F,
+) -> WalkResult<B>
+where
+    F: FnMut(OperationRef) -> WalkResult<B>,
+{
+    for block in region.body().iter() {
+        let mut cursor = block.body().front();
+        while let Some(op) = cursor.as_pointer() {
+            cursor.move_next();
+            {
+                let op = op.borrow();
+                for region in op.regions() {
+                    postwalk_region_operations_interruptible(&region, callback)?;
+                }
+            }
+            callback(op)?;
+        }
+    }
+
+    WalkResult::Continue(())
+}
+
+/// Walking blocks of an [Operation], and those of all nested operations
+impl Walkable<Block> for Operation {
+    fn prewalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(BlockRef) -> WalkResult<B>,
+    {
+        prewalk_blocks_interruptible(self, &mut callback)
+    }
+
+    fn postwalk_interruptible<F, B>(&self, mut callback: F) -> WalkResult<B>
+    where
+        F: FnMut(BlockRef) -> WalkResult<B>,
+    {
+        postwalk_blocks_interruptible(self, &mut callback)
+    }
+}
+
+fn prewalk_blocks_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(BlockRef) -> WalkResult<B>,
+{
+    for region in op.regions().iter() {
+        let mut blocks = region.body().front();
+        while let Some(block) = blocks.as_pointer() {
+            blocks.move_next();
+            match callback(block.clone()) {
+                WalkResult::Continue(_) => {
+                    let block = block.borrow();
+                    for op in block.body().iter() {
+                        prewalk_blocks_interruptible(&op, callback)?;
+                    }
+                }
+                WalkResult::Skip => continue,
+                result @ WalkResult::Break(_) => return result,
+            }
+        }
+    }
+
+    WalkResult::Continue(())
+}
+fn postwalk_blocks_interruptible<F, B>(op: &Operation, callback: &mut F) -> WalkResult<B>
+where
+    F: FnMut(BlockRef) -> WalkResult<B>,
+{
+    for region in op.regions().iter() {
+        let mut blocks = region.body().front();
+        while let Some(block) = blocks.as_pointer() {
+            blocks.move_next();
+            {
+                let block = block.borrow();
+                for op in block.body().iter() {
+                    postwalk_blocks_interruptible(&op, callback)?;
+                }
+            }
+            callback(block.clone())?;
+        }
+    }
+
+    WalkResult::Continue(())
+}
diff --git a/hir2/src/itertools.rs b/hir2/src/itertools.rs
new file mode 100644
index 000000000..10e985f07
--- /dev/null
+++ b/hir2/src/itertools.rs
@@ -0,0 +1,17 @@
+pub trait IteratorExt {
+    /// Returns true if the given iterator consists of exactly one element
+    fn has_single_element(&mut self) -> bool;
+}
+
+impl<I: Iterator> IteratorExt for I {
+    default fn has_single_element(&mut self) -> bool {
+        self.next().is_some_and(|_| self.next().is_none())
+    }
+}
+
+impl<I: ExactSizeIterator> IteratorExt for I {
+    #[inline]
+    fn has_single_element(&mut self) -> bool {
+        self.len() == 1
+    }
+}
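A quick illustration of why the specialization above is useful: exact-size iterators answer in O(1) from `len()`, while arbitrary (even unbounded) iterators fall back to pulling at most two elements:

// Both assertions hold, but via different impls.
fn has_single_element_examples() {
    // ExactSizeIterator fast path: just compares len() to 1
    assert!([42].iter().has_single_element());
    // Generic fallback: draws at most two items from a lazy, unbounded iterator
    assert!((0..).filter(|n| *n % 7 == 0).take(1).has_single_element());
}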
diff --git a/hir2/src/lib.rs b/hir2/src/lib.rs
new file mode 100644
index 000000000..cabedb9e1
--- /dev/null
+++ b/hir2/src/lib.rs
@@ -0,0 +1,105 @@
+#![feature(allocator_api)]
+#![feature(alloc_layout_extra)]
+#![feature(coerce_unsized)]
+#![feature(unsize)]
+#![feature(ptr_metadata)]
+#![feature(layout_for_ptr)]
+#![feature(slice_ptr_get)]
+#![feature(specialization)]
+#![feature(rustc_attrs)]
+#![feature(debug_closure_helpers)]
+#![feature(trait_alias)]
+#![feature(trait_upcasting)]
+#![feature(is_none_or)]
+#![feature(try_trait_v2)]
+#![feature(try_trait_v2_residual)]
+#![feature(tuple_trait)]
+#![feature(fn_traits)]
+#![feature(unboxed_closures)]
+#![feature(const_type_id)]
+#![feature(exact_size_is_empty)]
+#![feature(generic_const_exprs)]
+#![feature(new_uninit)]
+#![feature(clone_to_uninit)]
+// The following are used in impls of custom collection types based on SmallVec
+#![feature(iter_repeat_n)]
+#![feature(std_internals)] // for ByRefSized
+#![feature(extend_one)]
+#![feature(extend_one_unchecked)]
+#![feature(iter_advance_by)]
+#![feature(iter_next_chunk)]
+#![feature(iter_collect_into)]
+#![feature(trusted_len)]
+#![feature(never_type)]
+#![feature(maybe_uninit_slice)]
+#![feature(maybe_uninit_array_assume_init)]
+#![feature(maybe_uninit_uninit_array)]
+#![feature(maybe_uninit_uninit_array_transpose)]
+#![feature(array_into_iter_constructors)]
+#![feature(slice_range)]
+#![feature(slice_swap_unchecked)]
+#![feature(hasher_prefixfree_extras)]
+// Some of the above features require us to disable these warnings
+#![allow(incomplete_features)]
+#![allow(internal_features)]
+
+extern crate alloc;
+
+#[cfg(feature = "std")]
+extern crate std;
+
+extern crate self as midenc_hir2;
+
+pub use compact_str::{
+    CompactString as SmallStr, CompactStringExt as SmallStrExt, ToCompactString as ToSmallStr,
+};
+
+pub type FxHashMap<K, V> = hashbrown::HashMap<K, V, FxBuildHasher>;
+pub type FxHashSet<V> = hashbrown::HashSet<V, FxBuildHasher>;
+pub use rustc_hash::{FxBuildHasher, FxHasher};
+
+pub mod adt;
+mod any;
+mod attributes;
+pub mod demangle;
+pub mod derive;
+pub mod dialects;
+mod folder;
+pub mod formatter;
+mod hash;
+mod ir;
+pub mod itertools;
+pub mod matchers;
+pub mod pass;
+mod patterns;
+
+pub use self::{
+    any::AsAny,
+    attributes::{
+        markers::*, Attribute, AttributeSet, AttributeValue, CallConv, DictAttr, Overflow, SetAttr,
+        Visibility,
+    },
+    folder::OperationFolder,
+    hash::{DynHash, DynHasher},
+    ir::*,
+    itertools::IteratorExt,
+    patterns::*,
+};
+
+// TODO(pauls): The following is a rough list of what needs to be implemented for the IR
+// refactoring to be complete and usable in place of the old IR (some are optional):
+//
+// * constants and constant-like ops
+// * global variables and global ops
+// * Need to implement InferTypeOpInterface for all applicable ops
+// * Builders (i.e. component builder, interface builder, module builder, function builder, last is most important)
+//   NOTE: The underlying builder infra is done, so layering on the high-level builders is pretty simple
+// * canonicalization (optional)
+// * pattern matching/rewrites (needed for legalization/conversion, mostly complete, see below)
+//   - Need to provide implementations of stubbed out rewriter methods
+//   - Need to implement the GreedyRewritePatternDriver
+//   - Need to implement matchers
+// * dataflow analysis framework (required to replace old analyses)
+// * linking/global symbol resolution (required to replace old linker, partially implemented via symbols/symbol tables already)
+// * legalization/dialect conversion (required to convert between unstructured and structured control flow dialects at minimum)
+// * lowering
diff --git a/hir2/src/matchers.rs b/hir2/src/matchers.rs
new file mode 100644
index 000000000..1c3a0dc73
--- /dev/null
+++ b/hir2/src/matchers.rs
@@ -0,0 +1,3 @@
+mod matcher;
+
+pub use self::matcher::*;
diff --git a/hir2/src/matchers/matcher.rs b/hir2/src/matchers/matcher.rs
new file mode 100644
index 000000000..818d697fd
--- /dev/null
+++ b/hir2/src/matchers/matcher.rs
@@ -0,0 +1,770 @@
+use core::ptr::{DynMetadata, Pointee};
+
+use smallvec::SmallVec;
+
+use crate::{
+    AttributeValue, Op, OpFoldResult, OpOperand, Operation, OperationRef, UnsafeIntrusiveEntityRef,
+    ValueRef,
+};
+
+/// [Matcher] is a pattern matching abstraction with support for expressing both matching and
+/// capturing semantics.
+///
+/// This is used to implement low-level pattern matching primitives for the IR, for use in:
+///
+/// * Folding
+/// * Canonicalization
+/// * Regionalized transformations and analyses
+pub trait Matcher<T: ?Sized> {
+    /// The value type produced as a result of a successful match
+    ///
+    /// Use `()` if this matcher does not capture any value, and simply signals whether or not
+    /// the pattern was matched.
+    type Matched;
+
+    /// Check if `entity` is matched by this matcher, returning `Self::Matched` if successful.
+    fn matches(&self, entity: &T) -> Option<Self::Matched>;
+}
+
+#[repr(transparent)]
+pub struct MatchWith<F>(pub F);
+impl<T: ?Sized, U, F> Matcher<T> for MatchWith<F>
+where
+    F: Fn(&T) -> Option<U>,
+{
+    type Matched = U;
+
+    #[inline(always)]
+    fn matches(&self, entity: &T) -> Option<Self::Matched> {
+        (self.0)(entity)
+    }
+}
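Because `MatchWith` lifts any `Fn(&T) -> Option<U>` into a `Matcher<T>`, one-off matchers can be written inline. A sketch (it assumes `Operation` exposes a `num_operands` accessor; the helper name is hypothetical):

// Hypothetical: match any operation with exactly two operands, binding `()`.
fn two_operand_matcher() -> impl Matcher<Operation, Matched = ()> {
    MatchWith(|op: &Operation| if op.num_operands() == 2 { Some(()) } else { None })
}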
+/// A match combinator representing the logical AND of two sub-matchers.
+///
+/// Both patterns must match on the same IR entity, but only the matched value of `B` is returned,
+/// i.e. the captured result of `A` is discarded.
+///
+/// Returns the result of matching `B` if successful, otherwise `None`
+pub struct AndMatcher<A, B> {
+    a: A,
+    b: B,
+}
+
+impl<A, B> AndMatcher<A, B> {
+    pub const fn new(a: A, b: B) -> Self {
+        Self { a, b }
+    }
+}
+
+impl<T: ?Sized, A, B> Matcher<T> for AndMatcher<A, B>
+where
+    A: Matcher<T>,
+    B: Matcher<T>,
+{
+    type Matched = <B as Matcher<T>>::Matched;
+
+    #[inline]
+    fn matches(&self, entity: &T) -> Option<Self::Matched> {
+        self.a.matches(entity).and_then(|_| self.b.matches(entity))
+    }
+}
+
+/// A match combinator representing a monadic bind of two patterns.
+///
+/// In other words, given two patterns `A` and `B`:
+///
+/// * `A` is matched, and if it fails, the entire match fails.
+/// * `B` is then matched against the output of `A`, and if it fails, the entire match fails.
+/// * If both matches were successful, the output of `B` is returned as the final result.
+pub struct ChainMatcher<A, B> {
+    a: A,
+    b: B,
+}
+
+impl<A, B> ChainMatcher<A, B> {
+    pub const fn new(a: A, b: B) -> Self {
+        Self { a, b }
+    }
+}
+
+impl<T: ?Sized, A, B> Matcher<T> for ChainMatcher<A, B>
+where
+    A: Matcher<T>,
+    B: Matcher<<A as Matcher<T>>::Matched>,
+{
+    type Matched = <B as Matcher<<A as Matcher<T>>::Matched>>::Matched;
+
+    #[inline]
+    fn matches(&self, entity: &T) -> Option<Self::Matched> {
+        self.a.matches(entity).and_then(|matched| self.b.matches(&matched))
+    }
+}
+
+/// Matches operations which implement some trait `Trait`, capturing the match as a trait object.
+///
+/// NOTE: `Trait` must be an object-safe trait.
+pub struct OpTraitMatcher<Trait: ?Sized> {
+    _marker: core::marker::PhantomData<Trait>,
+}
+
+impl<Trait> Default for OpTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    #[inline]
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<Trait> OpTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    /// Create a new [OpTraitMatcher]
+    pub const fn new() -> Self {
+        Self {
+            _marker: core::marker::PhantomData,
+        }
+    }
+}
+
+impl<Trait> Matcher<Operation> for OpTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    type Matched = UnsafeIntrusiveEntityRef<Trait>;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        entity
+            .as_trait::<Trait>()
+            .map(|op| unsafe { UnsafeIntrusiveEntityRef::from_raw(op) })
+    }
+}
+
+/// Matches operations which implement some trait `Trait`.
+///
+/// Returns a type-erased operation ref, not a trait object like [OpTraitMatcher]
+pub struct HasTraitMatcher<Trait: ?Sized> {
+    _marker: core::marker::PhantomData<Trait>,
+}
+
+impl<Trait> Default for HasTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    #[inline]
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<Trait> HasTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    /// Create a new [HasTraitMatcher]
+    pub const fn new() -> Self {
+        Self {
+            _marker: core::marker::PhantomData,
+        }
+    }
+}
+
+impl<Trait> Matcher<Operation> for HasTraitMatcher<Trait>
+where
+    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
+{
+    type Matched = OperationRef;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        if !entity.implements::<Trait>() {
+            return None;
+        }
+        Some(entity.as_operation_ref())
+    }
+}
+
+/// Matches any operation with an attribute named `name`, the value of which matches a matcher of
+/// type `M`.
+pub struct OpAttrMatcher<M> {
+    name: &'static str,
+    matcher: M,
+}
+
+impl<M> OpAttrMatcher<M>
+where
+    M: Matcher<dyn AttributeValue>,
+{
+    /// Create a new [OpAttrMatcher] from the given attribute name and matcher.
+    pub const fn new(name: &'static str, matcher: M) -> Self {
+        Self { name, matcher }
+    }
+}
+
+impl<M> Matcher<Operation> for OpAttrMatcher<M>
+where
+    M: Matcher<dyn AttributeValue>,
+{
+    type Matched = <M as Matcher<dyn AttributeValue>>::Matched;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        entity.get_attribute(self.name).and_then(|value| self.matcher.matches(value))
+    }
+}
+
+/// Matches any operation with an attribute `name` and concrete value type of `A`.
+///
+/// Binds the value as its concrete type `A`.
+pub type TypedOpAttrMatcher<A> = OpAttrMatcher<TypedAttrMatcher<A>>;
+
+/// Matches and binds any attribute value whose concrete type is `A`.
+pub struct TypedAttrMatcher<A>(core::marker::PhantomData<A>);
+impl<A> Default for TypedAttrMatcher<A> {
+    #[inline(always)]
+    fn default() -> Self {
+        Self(core::marker::PhantomData)
+    }
+}
+impl<A: AttributeValue + Clone> Matcher<dyn AttributeValue> for TypedAttrMatcher<A> {
+    type Matched = A;
+
+    #[inline]
+    fn matches(&self, entity: &dyn AttributeValue) -> Option<Self::Matched> {
+        entity.downcast_ref::<A>().cloned()
+    }
+}
+
+/// A matcher for operations that always succeeds, binding the operation reference in the process.
+struct AnyOpMatcher;
+impl Matcher<Operation> for AnyOpMatcher {
+    type Matched = OperationRef;
+
+    #[inline(always)]
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        Some(entity.as_operation_ref())
+    }
+}
+
+/// A matcher for operations whose concrete type is `T`, binding the op with a strongly-typed
+/// reference.
+struct OneOpMatcher<T>(core::marker::PhantomData<T>);
+impl<T: Op> OneOpMatcher<T> {
+    pub const fn new() -> Self {
+        Self(core::marker::PhantomData)
+    }
+}
+impl<T: Op> Matcher<Operation> for OneOpMatcher<T> {
+    type Matched = UnsafeIntrusiveEntityRef<T>;
+
+    #[inline(always)]
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        entity
+            .downcast_ref::<T>()
+            .map(|op| unsafe { UnsafeIntrusiveEntityRef::from_raw(op) })
+    }
+}
+
+/// A matcher for values that always succeeds, binding the value reference in the process.
+struct AnyValueMatcher;
+impl Matcher<ValueRef> for AnyValueMatcher {
+    type Matched = ValueRef;
+
+    #[inline(always)]
+    fn matches(&self, entity: &ValueRef) -> Option<Self::Matched> {
+        Some(entity.clone())
+    }
+}
+
+/// A matcher that only succeeds if it matches exactly the provided value.
+struct ExactValueMatcher(ValueRef);
+impl Matcher<ValueRef> for ExactValueMatcher {
+    type Matched = ValueRef;
+
+    #[inline(always)]
+    fn matches(&self, entity: &ValueRef) -> Option<Self::Matched> {
+        if ValueRef::ptr_eq(&self.0, entity) {
+            Some(entity.clone())
+        } else {
+            None
+        }
+    }
+}
+
+/// A matcher for operations that implement [crate::traits::ConstantLike]
+type ConstantOpMatcher = HasTraitMatcher<dyn crate::traits::ConstantLike>;
+
+/// Like [ConstantOpMatcher], this matcher matches constant operations, but rather than binding
+/// the operation itself, it binds the constant value produced by the operation.
+#[derive(Default)]
+struct ConstantOpBinder;
+impl Matcher<Operation> for ConstantOpBinder {
+    type Matched = Box<dyn AttributeValue>;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        use crate::traits::Foldable;
+
+        if !entity.implements::<dyn crate::traits::ConstantLike>() {
+            return None;
+        }
+
+        let mut out = SmallVec::default();
+        entity.fold(&mut out).expect("expected constant-like op to be foldable");
+        let Some(OpFoldResult::Attribute(value)) = out.pop() else {
+            return None;
+        };
+
+        Some(value)
+    }
+}
+
+/// An extension of [ConstantOpBinder] which only matches constant values of type `T`
+struct TypedConstantOpBinder<T>(core::marker::PhantomData<T>);
+impl<T: AttributeValue> TypedConstantOpBinder<T> {
+    pub const fn new() -> Self {
+        Self(core::marker::PhantomData)
+    }
+}
+impl<T: AttributeValue> Matcher<Operation> for TypedConstantOpBinder<T> {
+    type Matched = T;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        ConstantOpBinder.matches(entity).and_then(|value| {
+            if !value.is::<T>() {
+                None
+            } else {
+                Some(unsafe {
+                    let raw = Box::into_raw(value);
+                    *Box::from_raw(raw as *mut T)
+                })
+            }
+        })
+    }
+}
+
+/// Matches operations which implement [crate::traits::UnaryOp] and binds the operand.
+#[derive(Default)]
+struct UnaryOpBinder;
+impl Matcher<Operation> for UnaryOpBinder {
+    type Matched = OpOperand;
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        if !entity.implements::<dyn crate::traits::UnaryOp>() {
+            return None;
+        }
+
+        Some(entity.operands()[0].borrow().as_operand_ref())
+    }
+}
+
+/// Matches operations which implement [crate::traits::BinaryOp] and binds both operands.
+#[derive(Default)]
+struct BinaryOpBinder;
+impl Matcher<Operation> for BinaryOpBinder {
+    type Matched = [OpOperand; 2];
+
+    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
+        if !entity.implements::<dyn crate::traits::BinaryOp>() {
+            return None;
+        }
+
+        let operands = entity.operands();
+        let lhs = operands[0].borrow().as_operand_ref();
+        let rhs = operands[1].borrow().as_operand_ref();
+
+        Some([lhs, rhs])
+    }
+}
+
+/// Converts the output of [UnaryOpBinder] to an OpFoldResult, by checking if the operand definition
+/// is a constant-like op, and either binding the constant value, or the SSA value used as the
+/// operand.
+///
+/// This can be used to set up for folding.
+struct FoldResultBinder;
+impl Matcher<OpOperand> for FoldResultBinder {
+    type Matched = OpFoldResult;
+
+    fn matches(&self, operand: &OpOperand) -> Option<Self::Matched> {
+        let operand = operand.borrow();
+        let maybe_constant = operand
+            .value()
+            .get_defining_op()
+            .and_then(|defining_op| constant().matches(&defining_op.borrow()));
+        if let Some(const_operand) = maybe_constant {
+            Some(OpFoldResult::Attribute(const_operand))
+        } else {
+            Some(OpFoldResult::Value(operand.as_value_ref()))
+        }
+    }
+}
+
+/// Converts the output of [BinaryOpBinder] to a pair of OpFoldResults, by checking if the operand
+/// definitions are constant, and either binding the constant values, or the SSA values used by each
+/// operand.
+///
+/// This can be used to set up for folding.
+struct BinaryFoldResultBinder;
+impl Matcher<[OpOperand; 2]> for BinaryFoldResultBinder {
+    type Matched = [OpFoldResult; 2];
+
+    fn matches(&self, operands: &[OpOperand; 2]) -> Option<Self::Matched> {
+        let binder = FoldResultBinder;
+
+        let lhs = binder.matches(&operands[0]).unwrap();
+        let rhs = binder.matches(&operands[1]).unwrap();
+
+        Some([lhs, rhs])
+    }
+}
+
+/// Matches the operand of a unary op to determine if it is a candidate for folding.
+///
+/// A successful match binds the constant value of the operand for use by the [Foldable] impl.
+struct FoldableOperandBinder;
+impl Matcher<OpOperand> for FoldableOperandBinder {
+    type Matched = Box<dyn AttributeValue>;
+
+    fn matches(&self, operand: &OpOperand) -> Option<Self::Matched> {
+        let operand = operand.borrow();
+        let defining_op = operand.value().get_defining_op()?;
+        constant().matches(&defining_op.borrow())
+    }
+}
+
+struct TypedFoldableOperandBinder<T>(core::marker::PhantomData<T>);
+impl<T> Default for TypedFoldableOperandBinder<T> {
+    fn default() -> Self {
+        Self(core::marker::PhantomData)
+    }
+}
+impl<T: AttributeValue + Clone> Matcher<OpOperand> for TypedFoldableOperandBinder<T> {
+    type Matched = Box<T>;
+
+    fn matches(&self, operand: &OpOperand) -> Option<Self::Matched> {
+        FoldableOperandBinder
+            .matches(operand)
+            .and_then(|value| value.downcast::<T>().ok())
+    }
+}
+
+/// Matches the operands of a binary op to determine if it is a candidate for folding.
+///
+/// A successful match binds the constant value of the operands for use by the [Foldable] impl.
+///
+/// NOTE: Both operands must be constant for this to match. Use [BinaryFoldResultBinder] if you
+/// wish to let the [Foldable] impl decide what to do in the presence of mixed constant and non-
+/// constant operands.
+struct FoldableBinaryOpBinder;
+impl Matcher<[OpOperand; 2]> for FoldableBinaryOpBinder {
+    type Matched = [Box<dyn AttributeValue>; 2];
+
+    fn matches(&self, operands: &[OpOperand; 2]) -> Option<Self::Matched> {
+        let binder = FoldableOperandBinder;
+        let lhs = binder.matches(&operands[0])?;
+        let rhs = binder.matches(&operands[1])?;
+
+        Some([lhs, rhs])
+    }
+}
+
+// Match Combinators
+
+/// Matches both `a` and `b`, or fails
+pub const fn match_both<T: ?Sized, A, B>(
+    a: A,
+    b: B,
+) -> impl Matcher<T, Matched = <B as Matcher<T>>::Matched>
+where
+    A: Matcher<T>,
+    B: Matcher<T>,
+{
+    AndMatcher::new(a, b)
+}
+
+/// Matches `a` and if successful, matches `b` against the output of `a`, or fails.
+pub const fn match_chain<T: ?Sized, A, B>(
+    a: A,
+    b: B,
+) -> impl Matcher<T, Matched = <B as Matcher<<A as Matcher<T>>::Matched>>::Matched>
+where
+    A: Matcher<T>,
+    B: Matcher<<A as Matcher<T>>::Matched>,
+{
+    ChainMatcher::new(a, b)
+}
+
+// Operation Matchers
+
+/// Matches any operation, i.e. it always matches
+///
+/// Returns a type-erased operation reference
+pub const fn match_any() -> impl Matcher<Operation, Matched = OperationRef> {
+    AnyOpMatcher
+}
+
+/// Matches any operation whose concrete type is `T`
+///
+/// Returns a strongly-typed op reference
+pub const fn match_op<T: Op>() -> impl Matcher<Operation, Matched = UnsafeIntrusiveEntityRef<T>> {
+    OneOpMatcher::<T>::new()
+}
+
+/// Matches any operation that implements [crate::traits::ConstantLike].
+///
+/// These operations return a single result, and must be pure (no side effects)
+pub const fn constant_like() -> impl Matcher<Operation, Matched = OperationRef> {
+    ConstantOpMatcher::new()
+}
+
+// Constant Value Binders
+
+/// Matches any operation that implements [crate::traits::ConstantLike], and binds the constant
+/// value as the result of the match.
+pub const fn constant() -> impl Matcher<Operation, Matched = Box<dyn AttributeValue>> {
+    ConstantOpBinder
+}
+
+/// Like [constant], but only matches if the constant value has the concrete type `T`.
+///
+/// Typically, constant values will be [crate::Immediate], but any attribute value can be matched.
+pub const fn constant_of<T: AttributeValue>() -> impl Matcher<Operation, Matched = T> {
+    TypedConstantOpBinder::new()
+}
+
+// Value Binders
+
+/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]), and binds its operand.
+pub const fn unary() -> impl Matcher<Operation, Matched = OpOperand> {
+    UnaryOpBinder
+}
+
+/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]), and binds its operand
+/// as an [OpFoldResult].
+///
+/// This is done by examining the defining op of the operand to determine if it is a constant, and
+/// if so, it binds the constant value, rather than the SSA value.
+///
+/// This can be used to set up for folding.
+pub const fn unary_fold_result() -> impl Matcher<Operation, Matched = OpFoldResult> {
+    match_chain(UnaryOpBinder, FoldResultBinder)
+}
+
+/// Matches any unary operation (i.e. implements [crate::traits::UnaryOp]) whose operand is a
+/// materialized constant, and thus a prime candidate for folding.
+///
+/// The constant value is bound by this matcher, so it can be used immediately for folding.
+pub const fn unary_foldable() -> impl Matcher<Operation, Matched = Box<dyn AttributeValue>> {
+    match_chain(UnaryOpBinder, FoldableOperandBinder)
+}
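To illustrate how these binders compose in a folding context, here is a sketch of a fold helper for unary ops (the helper and the evaluation step are hypothetical; `unary_foldable` is defined above):

// Hypothetical fold helper: if the operand of a unary op is a materialized
// constant, bind its value so the caller can evaluate the op at compile time.
fn try_fold_unary(op: &Operation) -> Option<Box<dyn AttributeValue>> {
    let constant_operand = unary_foldable().matches(op)?;
    // ... evaluate `op` against `constant_operand` and return the new constant ...
    Some(constant_operand)
}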
+/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]), and binds its
+/// operands.
+pub const fn binary() -> impl Matcher<Operation, Matched = [OpOperand; 2]> {
+    BinaryOpBinder
+}
+
+/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]), and binds its
+/// operands as [OpFoldResult]s.
+///
+/// This is done by examining the defining op of the operands to determine if they are constant,
+/// and if so, binds the constant value, rather than the SSA value.
+///
+/// This can be used to set up for folding.
+pub const fn binary_fold_results() -> impl Matcher<Operation, Matched = [OpFoldResult; 2]> {
+    match_chain(BinaryOpBinder, BinaryFoldResultBinder)
+}
+
+/// Matches any binary operation (i.e. implements [crate::traits::BinaryOp]) whose operands are
+/// both materialized constants, and thus a prime candidate for folding.
+///
+/// The constant values are bound by this matcher, so they can be used immediately for folding.
+pub const fn binary_foldable(
+) -> impl Matcher<Operation, Matched = [Box<dyn AttributeValue>; 2]> {
+    match_chain(BinaryOpBinder, FoldableBinaryOpBinder)
+}
+
+// Value Matchers
+
+/// Matches any value, i.e. it always matches
+pub const fn match_any_value() -> impl Matcher<ValueRef, Matched = ValueRef> {
+    AnyValueMatcher
+}
+
+/// Matches any instance of `value`, i.e. it requires an exact match
+pub const fn match_value(value: ValueRef) -> impl Matcher<ValueRef, Matched = ValueRef> {
+    ExactValueMatcher(value)
+}
+
+pub const fn foldable_operand() -> impl Matcher<OpOperand, Matched = Box<dyn AttributeValue>> {
+    FoldableOperandBinder
+}
+
+pub const fn foldable_operand_of<T>() -> impl Matcher<OpOperand, Matched = Box<T>>
+where
+    T: AttributeValue + Clone,
+{
+    TypedFoldableOperandBinder(core::marker::PhantomData)
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::rc::Rc;
+
+    use super::*;
+    use crate::{
+        dialects::hir::{InstBuilder, *},
+        *,
+    };
+
+    #[test]
+    fn matcher_match_any_value() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, rhs, sum) = setup(context.clone());
+
+        // All three values should `match_any_value`
+        for value in [&lhs, &rhs, &sum] {
+            assert_eq!(match_any_value().matches(value).as_ref(), Some(value));
+        }
+    }
+
+    #[test]
+    fn matcher_match_value() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, rhs, sum) = setup(context.clone());
+
+        // All three values should match themselves via `match_value`
+        for value in [&lhs, &rhs, &sum] {
+            assert_eq!(match_value(value.clone()).matches(value).as_ref(), Some(value));
+        }
+    }
+
+    #[test]
+    fn matcher_match_any() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, _rhs, sum) = setup(context.clone());
+
+        // We should be able to match `lhs` and `sum` ops using `match_any`
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+
+        for op in [&lhs_op, &sum_op] {
+            assert_eq!(match_any().matches(&op.borrow()).as_ref(), Some(op));
+        }
+    }
+
+    #[test]
+    fn matcher_match_op() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, rhs, sum) = setup(context.clone());
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+        assert!(rhs.borrow().get_defining_op().is_none());
+
+        // Both `lhs` and `sum` ops should be matched as their respective operation types, and not
+        // as a different operation type
+        assert!(match_op::<Constant>().matches(&lhs_op.borrow()).is_some());
+        assert!(match_op::<Constant>().matches(&sum_op.borrow()).is_none());
+        assert!(match_op::<Add>().matches(&lhs_op.borrow()).is_none());
+        assert!(match_op::<Add>().matches(&sum_op.borrow()).is_some());
+    }
+
+    #[test]
+    fn matcher_match_both() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, _rhs, _sum) = setup(context.clone());
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+
+        // Ensure if the first matcher fails, then the whole match fails
+        assert!(match_both(match_op::<Add>(), constant_of::<Immediate>())
+            .matches(&lhs_op.borrow())
+            .is_none());
+        // Ensure if the second matcher fails, then the whole match fails
+        assert!(match_both(constant_like(), constant_of::<u32>())
+            .matches(&lhs_op.borrow())
+            .is_none());
+        // Ensure that if both matchers would succeed, then the whole match succeeds
+        assert!(match_both(constant_like(), constant_of::<Immediate>())
+            .matches(&lhs_op.borrow())
+            .is_some());
+    }
+    #[test]
+    fn matcher_match_chain() {
+        let context = Rc::new(Context::default());
+
+        let (_, rhs, sum) = setup(context.clone());
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+
+        let [lhs_fr, rhs_fr] = binary_fold_results()
+            .matches(&sum_op.borrow())
+            .expect("expected to bind both operands of 'add'");
+        assert_eq!(lhs_fr, OpFoldResult::Attribute(Box::new(Immediate::U32(1))));
+        assert_eq!(rhs_fr, OpFoldResult::Value(rhs));
+    }
+
+    #[test]
+    fn matcher_constant_like() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, _rhs, sum) = setup(context.clone());
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+
+        // Only `lhs` should be matched by `constant_like`
+        assert!(constant_like().matches(&lhs_op.borrow()).is_some());
+        assert!(constant_like().matches(&sum_op.borrow()).is_none());
+    }
+
+    #[test]
+    fn matcher_constant() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, _rhs, sum) = setup(context.clone());
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+
+        // Only `lhs` should produce a matching constant value
+        assert!(constant().matches(&lhs_op.borrow()).is_some());
+        assert!(constant().matches(&sum_op.borrow()).is_none());
+    }
+
+    #[test]
+    fn matcher_constant_of() {
+        let context = Rc::new(Context::default());
+
+        let (lhs, _rhs, sum) = setup(context.clone());
+        let lhs_op = lhs.borrow().get_defining_op().unwrap();
+        let sum_op = sum.borrow().get_defining_op().unwrap();
+
+        // `lhs` should produce a matching constant value of the correct type and value
+        assert_eq!(constant_of::<Immediate>().matches(&lhs_op.borrow()), Some(Immediate::U32(1)));
+        assert!(constant_of::<Immediate>().matches(&sum_op.borrow()).is_none());
+    }
+
+    fn setup(context: Rc<Context>) -> (ValueRef, ValueRef, ValueRef) {
+        let mut builder = OpBuilder::new(Rc::clone(&context));
+
+        let mut function = {
+            let builder = builder.create::<Function, _>(SourceSpan::default());
+            let id = Ident::new("test".into(), SourceSpan::default());
+            let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]);
+            builder(id, signature).unwrap()
+        };
+
+        // Define function body
+        let mut func = function.borrow_mut();
+        let mut builder = FunctionBuilder::new(&mut func);
+        let lhs = builder.ins().u32(1, SourceSpan::default()).unwrap();
+        let block = builder.current_block();
+        let rhs = block.borrow().arguments()[0].clone().upcast();
+        let sum = builder.ins().add(lhs.clone(), rhs.clone(), SourceSpan::default()).unwrap();
+        builder.ins().ret(Some(sum.clone()), SourceSpan::default()).unwrap();
+
+        (lhs, rhs, sum)
+    }
+}
diff --git a/hir2/src/pass.rs b/hir2/src/pass.rs
new file mode 100644
index 000000000..c46d3306d
--- /dev/null
+++ b/hir2/src/pass.rs
@@ -0,0 +1,19 @@
+mod analysis;
+mod instrumentation;
+mod manager;
+#[allow(clippy::module_inception)]
+mod pass;
+pub mod registry;
+mod specialization;
+pub mod statistics;
+
+use self::pass::PassExecutionState;
+pub use self::{
+    analysis::{Analysis, AnalysisManager, OperationAnalysis, PreservedAnalyses},
+    instrumentation::{PassInstrumentation, PassInstrumentor, PipelineParentInfo},
+    manager::{Nesting, OpPassManager, PassDisplayMode, PassManager},
+    pass::{OperationPass, Pass},
+    registry::{PassInfo, PassPipelineInfo},
+    specialization::PassTarget,
+    statistics::{PassStatistic, Statistic, StatisticValue},
+};
diff --git a/hir2/src/pass/analysis.rs b/hir2/src/pass/analysis.rs
new file mode 100644
index 000000000..aa0bb6a50
--- /dev/null
+++ b/hir2/src/pass/analysis.rs
@@ -0,0 +1,694 @@
+use alloc::rc::Rc;
+use core::{
+    any::{Any, TypeId},
+    cell::RefCell,
+};
+
+use smallvec::SmallVec;
+
+type FxHashMap<K, V> = hashbrown::HashMap<K, V, crate::FxBuildHasher>;
+
+use super::{PassInstrumentor, PassTarget};
+use crate::{Op, Operation, OperationRef};
+
+/// The [Analysis] trait is used to define an analysis over some operation.
+///
+/// Analyses must be default-constructible, and `Sized + 'static` to support downcasting.
+///
+/// An analysis, when requested, is first constructed via its `Default` implementation, and then
+/// [Analysis::analyze] is called on the target type in order to compute the analysis results.
+/// The analysis type also acts as storage for the analysis results.
+///
+/// When the IR is changed, analyses are invalidated by default, unless they are specifically
+/// preserved via the [PreservedAnalyses] set. When an analysis is being asked if it should be
+/// invalidated, via [Analysis::invalidate], it has the opportunity to identify if it actually
+/// needs to be invalidated based on what analyses were preserved. If dependent analyses of this
+/// analysis haven't been invalidated, then this analysis may be able to preserve itself as well,
+/// and avoid redundant recomputation.
+pub trait Analysis: Default + Any {
+    /// The specific type on which this analysis is performed.
+    ///
+    /// The analysis will only be run when an operation is of this type.
+    type Target: ?Sized + PassTarget;
+
+    /// The [TypeId] associated with the concrete underlying [Analysis] implementation
+    ///
+    /// This is automatically implemented for you, but in some cases, such as wrapping an
+    /// analysis in another type, you may want to implement this so that queries against the
+    /// type return the expected [TypeId]
+    #[inline]
+    fn analysis_id(&self) -> TypeId {
+        TypeId::of::<Self>()
+    }
+
+    /// Get a `dyn Any` reference to the underlying [Analysis] implementation
+    ///
+    /// This is automatically implemented for you, but in some cases, such as wrapping an
+    /// analysis in another type, you may want to implement this so that queries against the
+    /// type return the expected [TypeId]
+    #[inline(always)]
+    fn as_any(&self) -> &dyn Any {
+        self as &dyn Any
+    }
+
+    /// Same as [Analysis::as_any], but used specifically for getting a reference-counted handle,
+    /// rather than a raw reference.
+    #[inline(always)]
+    fn as_any_rc(self: Rc<Self>) -> Rc<dyn Any> {
+        self as Rc<dyn Any>
+    }
+
+    /// Returns the display name for this analysis
+    ///
+    /// By default this simply returns the name of the concrete implementation type.
+    fn name(&self) -> &'static str {
+        core::any::type_name::<Self>()
+    }
+
+    /// Analyze `op` using the provided [AnalysisManager].
+    fn analyze(&mut self, op: &Self::Target, analysis_manager: AnalysisManager);
+
+    /// Query this analysis for invalidation.
+    ///
+    /// Given a preserved analysis set, returns true if it should truly be invalidated. This allows
+    /// for more fine-tuned invalidation in cases where an analysis wasn't explicitly marked
+    /// preserved, but may be preserved (or invalidated) based upon other properties, such as
+    /// analysis sets.
+    fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool;
+}
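A minimal sketch of a concrete analysis over `Operation` (the op-counting logic is illustrative; it assumes `Operation` implements `PassTarget`, which the blanket `OperationAnalysis` impl below requires of all targets):

// Hypothetical analysis that counts the operations nested under the target.
#[derive(Default)]
struct OpCountAnalysis {
    count: usize,
}
impl Analysis for OpCountAnalysis {
    type Target = Operation;

    fn analyze(&mut self, op: &Self::Target, _analysis_manager: AnalysisManager) {
        // Reset so a re-analysis starts from a clean slate
        self.count = 0;
        op.prewalk(|_op: OperationRef| {
            self.count += 1;
        });
    }

    fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool {
        // Invalidate unless explicitly preserved
        !preserved_analyses.is_preserved::<Self>()
    }
}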
+/// A type-erased [Analysis].
+///
+/// This is automatically derived for all [Analysis] implementations, and is the means by which
+/// one can abstract over sets of analyses using dynamic dispatch.
+///
+/// This essentially just delegates to the underlying [Analysis] implementation, but it also
+/// handles converting a raw [OperationRef] to the appropriate target type expected by the
+/// underlying [Analysis].
+pub trait OperationAnalysis {
+    /// The unique type id of this analysis
+    fn analysis_id(&self) -> TypeId;
+
+    /// Used for dynamic casting to the underlying [Analysis] type
+    fn as_any(&self) -> &dyn Any;
+
+    /// Used for dynamic casting to the underlying [Analysis] type
+    fn as_any_rc(self: Rc<Self>) -> Rc<dyn Any>;
+
+    /// The name of this analysis
+    fn name(&self) -> &'static str;
+
+    /// Runs this analysis over `op`.
+    ///
+    /// NOTE: This is only ever called once per instantiation of the analysis, but in theory can
+    /// support multiple calls to re-analyze `op`. Each call should reset any internal state to
+    /// ensure that if an analysis is reused in this way, each analysis gets a clean slate.
+    fn analyze(&mut self, op: &OperationRef, am: AnalysisManager);
+
+    /// Query this analysis for invalidation.
+    ///
+    /// Given a preserved analysis set, returns true if it should truly be invalidated. This allows
+    /// for more fine-tuned invalidation in cases where an analysis wasn't explicitly marked
+    /// preserved, but may be preserved (or invalidated) based upon other properties, such as
+    /// analysis sets.
+    ///
+    /// Invalidated analyses must be removed from `preserved_analyses`.
+    fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool;
+}
+
+impl dyn OperationAnalysis {
+    /// Cast a reference-counted handle to this analysis to its concrete implementation type.
+    ///
+    /// Returns `None` if the underlying analysis is not of type `T`
+    #[inline]
+    pub fn downcast<T: Any>(self: Rc<Self>) -> Option<Rc<T>> {
+        self.as_any_rc().downcast::<T>().ok()
+    }
+}
+
+impl<A> OperationAnalysis for A
+where
+    A: Analysis,
+{
+    #[inline]
+    fn analysis_id(&self) -> TypeId {
+        <A as Analysis>::analysis_id(self)
+    }
+
+    #[inline]
+    fn as_any(&self) -> &dyn Any {
+        <A as Analysis>::as_any(self)
+    }
+
+    #[inline]
+    fn as_any_rc(self: Rc<Self>) -> Rc<dyn Any> {
+        <A as Analysis>::as_any_rc(self)
+    }
+
+    #[inline]
+    fn name(&self) -> &'static str {
+        <A as Analysis>::name(self)
+    }
+
+    #[inline]
+    fn analyze(&mut self, op: &OperationRef, am: AnalysisManager) {
+        let op = <<A as Analysis>::Target as PassTarget>::into_target(op);
+        <A as Analysis>::analyze(self, &op, am)
+    }
+
+    #[inline]
+    fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool {
+        <A as Analysis>::invalidate(self, preserved_analyses)
+    }
+}
+
+/// Represents a set of analyses that are known to be preserved after a rewrite has been applied.
+#[derive(Default)]
+pub struct PreservedAnalyses {
+    /// The set of preserved analysis type ids
+    preserved: SmallVec<[TypeId; 8]>,
+}
+impl PreservedAnalyses {
+    /// Mark all analyses as preserved.
+    ///
+    /// This is generally only useful when the IR is known not to have changed.
+    pub fn preserve_all(&mut self) {
+        self.insert(AllAnalyses::TYPE_ID);
+    }
+
+    /// Mark the specified [Analysis] type as preserved.
+    pub fn preserve<A: Any>(&mut self) {
+        self.insert(TypeId::of::<A>());
+    }
+
+    /// Mark a type as preserved using its raw [TypeId].
+    ///
+    /// Typically it is best to use [Self::preserve] instead, but this can be useful in cases
+    /// where you can't express the type in Rust directly.
+    pub fn preserve_raw(&mut self, id: TypeId) {
+        self.insert(id);
+    }
+
+    /// Returns true if the specified type is preserved.
+    pub fn is_preserved<A: Any>(&self) -> bool {
+        self.preserved.contains(&TypeId::of::<A>())
+    }
+
+    /// Returns true if the specified [TypeId] is marked preserved.
+    pub fn is_preserved_raw(&self, ty: &TypeId) -> bool {
+        self.preserved.contains(ty)
+    }
+
+    /// Mark a previously preserved type as invalidated.
+    pub fn unpreserve<A: Any>(&mut self) {
+        self.remove(&TypeId::of::<A>())
+    }
+
+    /// Mark a previously preserved [TypeId] as invalidated.
+    pub fn unpreserve_raw(&mut self, ty: &TypeId) {
+        self.remove(ty)
+    }
+
+    /// Returns true if all analyses are preserved
+    pub fn is_all(&self) -> bool {
+        self.preserved.contains(&AllAnalyses::TYPE_ID)
+    }
+
+    /// Returns true if no analyses are being preserved
+    pub fn is_none(&self) -> bool {
+        self.preserved.is_empty()
+    }
+
+    fn insert(&mut self, id: TypeId) {
+        // NOTE: Only insert if not already present, so the sorted set stays free of duplicates
+        if let Err(index) = self.preserved.binary_search_by_key(&id, |probe| *probe) {
+            self.preserved.insert(index, id);
+        }
+    }
+
+    fn remove(&mut self, id: &TypeId) {
+        if let Ok(index) = self.preserved.binary_search_by_key(id, |probe| *probe) {
+            self.preserved.remove(index);
+        }
+    }
+}
+
+/// A marker type that is used to represent all possible [Analysis] types
+pub struct AllAnalyses;
+impl AllAnalyses {
+    const TYPE_ID: TypeId = TypeId::of::<Self>();
+}
+
+/// This type wraps all analyses stored in an [AnalysisMap], and handles some of the boilerplate
+/// details around invalidation by intercepting calls to [Analysis::invalidate] and wrapping them
+/// with extra logic. Notably, ensuring that invalidated analyses are removed from the
+/// [PreservedAnalyses] set is handled by this wrapper.
+///
+/// It is a transparent wrapper around `A`, and otherwise acts as a simple proxy to `A`'s
+/// implementation of the [Analysis] trait.
+#[repr(transparent)]
+struct AnalysisWrapper<A> {
+    analysis: A,
+}
+impl<A: Analysis> AnalysisWrapper<A> {
+    fn new(op: &<A as Analysis>::Target, am: AnalysisManager) -> Self {
+        let mut analysis = A::default();
+        analysis.analyze(op, am);
+
+        Self { analysis }
+    }
+}
+impl<A: Analysis> Default for AnalysisWrapper<A> {
+    fn default() -> Self {
+        Self {
+            analysis: Default::default(),
+        }
+    }
+}
+impl<A: Analysis> Analysis for AnalysisWrapper<A> {
+    type Target = <A as Analysis>::Target;
+
+    #[inline]
+    fn analysis_id(&self) -> TypeId {
+        self.analysis.analysis_id()
+    }
+
+    #[inline]
+    fn as_any(&self) -> &dyn Any {
+        self.analysis.as_any()
+    }
+
+    #[inline]
+    fn as_any_rc(self: Rc<Self>) -> Rc<dyn Any> {
+        // SAFETY: This cast is safe because AnalysisWrapper<A> is a transparent wrapper
+        // around A, so a pointer to the former is a valid pointer to the latter
+        let ptr = Rc::into_raw(self);
+        unsafe { Rc::<A>::from_raw(ptr.cast()) as Rc<dyn Any> }
+    }
+
+    #[inline]
+    fn name(&self) -> &'static str {
+        self.analysis.name()
+    }
+
+    #[inline]
+    fn analyze(&mut self, op: &Self::Target, am: AnalysisManager) {
+        self.analysis.analyze(op, am);
+    }
+
+    fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) -> bool {
+        let invalidated = self.analysis.invalidate(preserved_analyses);
+        if invalidated {
+            preserved_analyses.unpreserve::<A>();
+        }
+        invalidated
+    }
+}
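Putting `PreservedAnalyses` and the wrapper together, end-of-pass bookkeeping might look like the following sketch (the flow is hypothetical; `OpCountAnalysis` is the example analysis sketched earlier, and `AnalysisManager::invalidate` is defined below):

// Hypothetical: a pass that did not touch the op structure preserves the
// count analysis; everything else is invalidated and dropped from the cache.
fn retain_op_counts(am: &AnalysisManager) {
    let mut preserved = PreservedAnalyses::default();
    preserved.preserve::<OpCountAnalysis>();
    am.invalidate(&mut preserved);
}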
+/// An [AnalysisManager] is the primary entrypoint for performing analysis on a specific operation
+/// instance that it is constructed for.
+///
+/// It is used to manage and cache analyses for the operation, as well as those of child operations,
+/// via nested [AnalysisManager] instances.
+///
+/// This type is a thin wrapper around a pointer, and is meant to be passed by value. It can be
+/// cheaply cloned.
+#[derive(Clone)]
+#[repr(transparent)]
+pub struct AnalysisManager {
+    analyses: Rc<NestedAnalysisMap>,
+}
+impl AnalysisManager {
+    /// Create a new top-level [AnalysisManager] for `op`
+    pub fn new(op: OperationRef, instrumentor: Option<Rc<PassInstrumentor>>) -> Self {
+        Self {
+            analyses: Rc::new(NestedAnalysisMap::new(op, instrumentor)),
+        }
+    }
+
+    /// Query for a cached analysis on the given parent operation. The analysis may not exist, and
+    /// if it does, it may be out-of-date.
+    pub fn get_cached_parent_analysis<A>(&self, parent: &OperationRef) -> Option<Rc<A>>
+    where
+        A: Analysis,
+    {
+        let mut current_parent = self.analyses.parent();
+        while let Some(parent_am) = current_parent.take() {
+            if &parent_am.get_operation() == parent {
+                return parent_am.analyses().get_cached::<A>();
+            }
+            current_parent = parent_am.parent();
+        }
+        None
+    }
+
+    /// Query for the given analysis for the current operation.
+    pub fn get_analysis<A>(&self) -> Rc<A>
+    where
+        A: Analysis<Target = Operation>,
+    {
+        self.analyses.analyses.borrow_mut().get(self.pass_instrumentor(), self.clone())
+    }
+
+    /// Query for the given analysis for the current operation of a specific derived operation type.
+    ///
+    /// NOTE: This will panic if the current operation is not of type `O`.
+    pub fn get_analysis_for<A, O>(&self) -> Rc<A>
+    where
+        A: Analysis<Target = O>,
+        O: 'static,
+    {
+        self.analyses
+            .analyses
+            .borrow_mut()
+            .get_analysis_for::<A, O>(self.pass_instrumentor(), self.clone())
+    }
+
+    /// Query for a cached entry of the given analysis on the current operation.
+    pub fn get_cached_analysis<A>(&self) -> Option<Rc<A>>
+    where
+        A: Analysis,
+    {
+        self.analyses.analyses().get_cached::<A>()
+    }
+
+    /// Query for an analysis of a child operation, constructing it if necessary.
+    pub fn get_child_analysis<A>(&self, op: &OperationRef) -> Rc<A>
+    where
+        A: Analysis<Target = Operation>,
+    {
+        self.clone().nest(op).get_analysis::<A>()
+    }
+
+    /// Query for an analysis of a child operation of a specific derived operation type,
+    /// constructing it if necessary.
+    ///
+    /// NOTE: This will panic if `op` is not of type `O`.
+    pub fn get_child_analysis_for<A, O>(&self, op: &O) -> Rc<A>
+    where
+        A: Analysis<Target = O>,
+        O: Op,
+    {
+        self.clone()
+            .nest(&op.as_operation().as_operation_ref())
+            .get_analysis_for::<A, O>()
+    }
+
+    /// Query for a cached analysis of a child operation, or return `None`.
+    pub fn get_cached_child_analysis<A>(&self, child: &OperationRef) -> Option<Rc<A>>
+    where
+        A: Analysis,
+    {
+        assert!(child.borrow().parent_op().unwrap() == self.analyses.get_operation());
+        let child_analyses = self.analyses.child_analyses.borrow();
+        let child_analyses = child_analyses.get(child)?;
+        let child_analyses = child_analyses.analyses.borrow();
+        child_analyses.get_cached::<A>()
+    }
+}
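From the perspective of a pass body, usage is a one-liner; the manager computes on first request and caches afterwards. A sketch (the surrounding pass scaffolding is hypothetical; `OpCountAnalysis` is from the earlier sketch):

// Hypothetical pass body: request the analysis for the current operation.
fn run_on_operation(_op: OperationRef, am: AnalysisManager) {
    let counts = am.get_analysis::<OpCountAnalysis>();
    let _total = counts.count;
}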
+impl AnalysisManager {
+    /// Get an analysis manager for the given operation, which must be a proper descendant of the
+    /// current operation represented by this analysis manager.
+    pub fn nest(&self, op: &OperationRef) -> AnalysisManager {
+        let current_op = self.analyses.get_operation();
+        assert!(current_op.borrow().is_proper_ancestor_of(op), "expected valid descendant op");
+
+        // Check for the base case where the provided operation is immediately nested
+        if current_op == op.borrow().parent_op().expect("expected `op` to have a parent") {
+            return self.nest_immediate(op.clone());
+        }
+
+        // Otherwise, we need to collect all ancestors up to the current operation
+        let mut ancestors = SmallVec::<[OperationRef; 4]>::default();
+        let mut next_op = op.clone();
+        while next_op != current_op {
+            ancestors.push(next_op.clone());
+            next_op = next_op.borrow().parent_op().unwrap();
+        }
+
+        let mut manager = self.clone();
+        while let Some(op) = ancestors.pop() {
+            manager = manager.nest_immediate(op);
+        }
+        manager
+    }
+
+    fn nest_immediate(&self, op: OperationRef) -> AnalysisManager {
+        use hashbrown::hash_map::Entry;
+
+        assert!(
+            Some(self.analyses.get_operation()) == op.borrow().parent_op(),
+            "expected immediate child operation"
+        );
+        let parent = self.analyses.clone();
+        let mut child_analyses = self.analyses.child_analyses.borrow_mut();
+        match child_analyses.entry(op.clone()) {
+            Entry::Vacant(entry) => {
+                let analyses = entry.insert(Rc::new(parent.nest(op)));
+                AnalysisManager {
+                    analyses: Rc::clone(analyses),
+                }
+            }
+            Entry::Occupied(entry) => AnalysisManager {
+                analyses: Rc::clone(entry.get()),
+            },
+        }
+    }
+
+    /// Invalidate any non-preserved analyses.
+    #[inline]
+    pub fn invalidate(&self, preserved_analyses: &mut PreservedAnalyses) {
+        Rc::clone(&self.analyses).invalidate(preserved_analyses)
+    }
+
+    /// Clear any held analyses.
+    #[inline]
+    pub fn clear(&mut self) {
+        self.analyses.clear();
+    }
+
+    /// Clear any held analyses when the returned guard is dropped.
+    #[inline]
+    pub fn defer_clear(&self) -> ResetAnalysesOnDrop {
+        ResetAnalysesOnDrop {
+            analyses: self.analyses.clone(),
+        }
+    }
+
+    /// Returns a [PassInstrumentor] for the current operation, if one was installed.
+    #[inline]
+    pub fn pass_instrumentor(&self) -> Option<Rc<PassInstrumentor>> {
+        self.analyses.pass_instrumentor()
+    }
+}
+
+#[must_use]
+#[doc(hidden)]
+pub struct ResetAnalysesOnDrop {
+    analyses: Rc<NestedAnalysisMap>,
+}
+impl Drop for ResetAnalysesOnDrop {
+    fn drop(&mut self) {
+        self.analyses.clear()
+    }
+}
+
+/// An analysis map that contains a map for the current operation, and a set of maps for any child
+/// operations.
+struct NestedAnalysisMap {
+    parent: Option<Rc<NestedAnalysisMap>>,
+    instrumentor: Option<Rc<PassInstrumentor>>,
+    analyses: RefCell<AnalysisMap>,
+    child_analyses: RefCell<FxHashMap<OperationRef, Rc<NestedAnalysisMap>>>,
+}
+impl NestedAnalysisMap {
+    /// Create a new top-level [NestedAnalysisMap] for `op`, with the given optional pass
+    /// instrumentor.
+    pub fn new(op: OperationRef, instrumentor: Option<Rc<PassInstrumentor>>) -> Self {
+        Self {
+            parent: None,
+            instrumentor,
+            analyses: RefCell::new(AnalysisMap::new(op)),
+            child_analyses: Default::default(),
+        }
+    }
+
+    /// Create a new [NestedAnalysisMap] for `op` nested under `self`.
+    pub fn nest(self: Rc<Self>, op: OperationRef) -> Self {
+        let instrumentor = self.instrumentor.clone();
+        Self {
+            parent: Some(self),
+            instrumentor,
+            analyses: RefCell::new(AnalysisMap::new(op)),
+            child_analyses: Default::default(),
+        }
+    }
+
+    /// Get the parent [NestedAnalysisMap], or `None` if this is a top-level map.
+    pub fn parent(&self) -> Option<Rc<NestedAnalysisMap>> {
+        self.parent.clone()
+    }
+
+    /// Return a [PassInstrumentor] for the current operation, if one was installed.
+    pub fn pass_instrumentor(&self) -> Option<Rc<PassInstrumentor>> {
+        self.instrumentor.clone()
+    }
+
+    /// Get the operation for this analysis map.
+    #[inline]
+    pub fn get_operation(&self) -> OperationRef {
+        self.analyses.borrow().get_operation()
+    }
+
+    fn analyses(&self) -> core::cell::Ref<'_, AnalysisMap> {
+        self.analyses.borrow()
+    }
+
+    /// Invalidate any non-preserved analyses.
+    pub fn invalidate(self: Rc<Self>, preserved_analyses: &mut PreservedAnalyses) {
+        // If all analyses were preserved, then there is nothing to do
+        if preserved_analyses.is_all() {
+            return;
+        }
+
+        // Invalidate the analyses for the current operation directly
+        self.analyses.borrow_mut().invalidate(preserved_analyses);
+
+        // If no analyses were preserved, then simply clear out the child analysis results
+        if preserved_analyses.is_none() {
+            self.child_analyses.borrow_mut().clear();
+        }
+
+        // Otherwise, invalidate each child analysis map
+        let mut to_invalidate = SmallVec::<[Rc<NestedAnalysisMap>; 8]>::from_iter([self]);
+        while let Some(map) = to_invalidate.pop() {
+            map.child_analyses.borrow_mut().retain(|_op, nested_analysis_map| {
+                Rc::clone(nested_analysis_map).invalidate(preserved_analyses);
+                if nested_analysis_map.child_analyses.borrow().is_empty() {
+                    false
+                } else {
+                    to_invalidate.push(Rc::clone(nested_analysis_map));
+                    true
+                }
+            });
+        }
+    }
+
+    pub fn clear(&self) {
+        self.child_analyses.borrow_mut().clear();
+        self.analyses.borrow_mut().clear();
+    }
+}
+
+/// This type represents a cache of analyses for a single operation.
+///
+/// All computation, caching, and invalidation of analyses takes place here.
+struct AnalysisMap {
+    analyses: FxHashMap<TypeId, Rc<dyn OperationAnalysis>>,
+    ir: OperationRef,
+}
+impl AnalysisMap {
+    pub fn new(ir: OperationRef) -> Self {
+        Self {
+            analyses: Default::default(),
+            ir,
+        }
+    }
+
+    /// Get an analysis for the current IR unit, computing it if necessary.
+    pub fn get<A>(&mut self, pi: Option<Rc<PassInstrumentor>>, am: AnalysisManager) -> Rc<A>
+    where
+        A: Analysis<Target = Operation>,
+    {
+        Self::get_analysis_impl::<A, Operation>(
+            &mut self.analyses,
+            pi,
+            &self.ir.borrow(),
+            &self.ir,
+            am,
+        )
+    }
+
+    /// Get a cached analysis instance if one exists, otherwise return `None`.
+    pub fn get_cached<A>(&self) -> Option<Rc<A>>
+    where
+        A: Analysis,
+    {
+        self.analyses.get(&TypeId::of::<A>()).cloned().and_then(|a| a.downcast::<A>())
+    }
+
+    /// Get an analysis for the current IR unit, assuming it's of the specified type, computing it
+    /// if necessary.
+    ///
+    /// NOTE: This will panic if the current operation is not of type `O`.
+    pub fn get_analysis_for<A, O>(
+        &mut self,
+        pi: Option<Rc<PassInstrumentor>>,
+        am: AnalysisManager,
+    ) -> Rc<A>
+    where
+        A: Analysis<Target = O>,
+        O: 'static,
+    {
+        let ir = <<A as Analysis>::Target as PassTarget>::into_target(&self.ir);
+        Self::get_analysis_impl::<A, O>(&mut self.analyses, pi, &*ir, &self.ir, am)
+    }
+
+    fn get_analysis_impl<A, O>(
+        analyses: &mut FxHashMap<TypeId, Rc<dyn OperationAnalysis>>,
+        pi: Option<Rc<PassInstrumentor>>,
+        ir: &O,
+        op: &OperationRef,
+        am: AnalysisManager,
+    ) -> Rc<A>
+    where
+        A: Analysis<Target = O>,
+    {
+        use hashbrown::hash_map::Entry;
+
+        let id = TypeId::of::<A>();
+        match analyses.entry(id) {
+            Entry::Vacant(entry) => {
+                // We don't have a cached analysis for the operation, compute it directly and
+                // add it to the cache.
+                if let Some(pi) = pi.as_deref() {
+                    pi.run_before_analysis(core::any::type_name::<A>(), &id, op);
+                }
+
+                let analysis = entry.insert(Self::construct_analysis::<A, O>(am, ir));
+
+                if let Some(pi) = pi.as_deref() {
+                    pi.run_after_analysis(core::any::type_name::<A>(), &id, op);
+                }
+
+                Rc::clone(analysis).downcast::<A>().unwrap()
+            }
+            Entry::Occupied(entry) => Rc::clone(entry.get()).downcast::<A>().unwrap(),
+        }
+    }
+
+    fn construct_analysis<A, O>(am: AnalysisManager, op: &O) -> Rc<dyn OperationAnalysis>
+    where
+        A: Analysis<Target = O>,
+    {
+        Rc::new(AnalysisWrapper::<A>::new(op, am)) as Rc<dyn OperationAnalysis>
+    }
+
+    /// Returns the operation that this analysis map represents.
+    pub fn get_operation(&self) -> OperationRef {
+        self.ir.clone()
+    }
+
+    /// Clear any held analyses.
+    pub fn clear(&mut self) {
+        self.analyses.clear();
+    }
+
+    /// Invalidate any cached analyses based upon the given set of preserved analyses.
+    pub fn invalidate(&mut self, preserved_analyses: &mut PreservedAnalyses) {
+        // Remove any analyses that were invalidated.
+        //
+        // Using `retain`, we preserve the original insertion order, and dependencies always go
+        // before users, so we need only a single pass through.
+        self.analyses.retain(|_, a| !a.invalidate(preserved_analyses));
+    }
+}
diff --git a/hir2/src/pass/instrumentation.rs b/hir2/src/pass/instrumentation.rs
new file mode 100644
index 000000000..d16265671
--- /dev/null
+++ b/hir2/src/pass/instrumentation.rs
@@ -0,0 +1,129 @@
+use core::{any::TypeId, cell::RefCell};
+
+use compact_str::CompactString;
+use smallvec::SmallVec;
+
+use super::OperationPass;
+use crate::{OperationName, OperationRef};
+
+#[allow(unused_variables)]
+pub trait PassInstrumentation {
+    fn run_before_pipeline(
+        &mut self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+    }
+
+    fn run_after_pipeline(
+        &mut self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+    }
+
+    fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {}
+    fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {}
+    fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) {}
+    fn run_before_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {}
+    fn run_after_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {}
+}
+
+pub struct PipelineParentInfo {
+    /// The pass that spawned this pipeline, if any
+    pub pass: Option<CompactString>,
+}
+
+impl PassInstrumentation for Box<dyn PassInstrumentation>
+{
+    fn run_before_pipeline(
+        &mut self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+        (**self).run_before_pipeline(name, parent_info);
+    }
+
+    fn run_after_pipeline(
+        &mut self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+        (**self).run_after_pipeline(name, parent_info);
+    }
+
+    fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {
+        (**self).run_before_pass(pass, op);
+    }
+
+    fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {
+        (**self).run_after_pass(pass, op);
+    }
+
+    fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) {
+        (**self).run_after_pass_failed(pass, op);
+    }
+
+    fn run_before_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {
+        (**self).run_before_analysis(name, id, op);
+    }
+
+    fn run_after_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {
+        (**self).run_after_analysis(name, id, op);
+    }
+}
+
+#[derive(Default)]
+pub struct PassInstrumentor {
+    instrumentations: RefCell<SmallVec<[Box<dyn PassInstrumentation>; 1]>>,
+}
+
+impl PassInstrumentor {
+    pub fn run_before_pipeline(
+        &self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+        self.instrument(|pi| pi.run_before_pipeline(name, parent_info));
+    }
+
+    pub fn run_after_pipeline(
+        &self,
+        name: Option<&OperationName>,
+        parent_info: &PipelineParentInfo,
+    ) {
+        self.instrument(|pi| pi.run_after_pipeline(name, parent_info));
+    }
+
+    pub fn run_before_pass(&self, pass: &dyn OperationPass, op: &OperationRef) {
+        self.instrument(|pi| pi.run_before_pass(pass, op));
+    }
+
+    pub fn run_after_pass(&self, pass: &dyn OperationPass, op: &OperationRef) {
+        self.instrument(|pi| pi.run_after_pass(pass, op));
+    }
+
+    pub fn run_after_pass_failed(&self, pass: &dyn OperationPass, op: &OperationRef) {
+        self.instrument(|pi| pi.run_after_pass_failed(pass, op));
+    }
+
+    pub fn run_before_analysis(&self, name: &str, id: &TypeId, op: &OperationRef) {
+        self.instrument(|pi| pi.run_before_analysis(name, id, op));
+    }
+
+    pub fn run_after_analysis(&self, name: &str, id: &TypeId, op: &OperationRef) {
+        self.instrument(|pi| pi.run_after_analysis(name, id, op));
+    }
+
+    pub fn add_instrumentation(&self, pi: Box<dyn PassInstrumentation>) {
+        self.instrumentations.borrow_mut().push(pi);
+    }
+
+    #[inline(always)]
+    fn instrument<F>(&self, callback: F)
+    where
+        F: Fn(&mut dyn PassInstrumentation),
+    {
+        let mut instrumentations = self.instrumentations.borrow_mut();
+        for pi in instrumentations.iter_mut() {
+            callback(pi);
+        }
+    }
+}
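As an example, a trivial instrumentation that logs pass boundaries could be implemented as in the sketch below (it assumes `OperationPass` exposes a `name` accessor; `log` is already a dependency of this crate). It would be installed via `PassManager::add_instrumentation(Box::new(PassLogger))`:

// Hypothetical: log every pass as it starts and finishes.
struct PassLogger;
impl PassInstrumentation for PassLogger {
    fn run_before_pass(&mut self, pass: &dyn OperationPass, _op: &OperationRef) {
        log::debug!("starting pass {}", pass.name());
    }

    fn run_after_pass(&mut self, pass: &dyn OperationPass, _op: &OperationRef) {
        log::debug!("finished pass {}", pass.name());
    }
}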
+    print_after_only_on_change: bool,
+    print_after_only_on_failure: bool,
+    flags: OpPrintingFlags,
+}
+
+/// The main pass manager and pipeline builder
+pub struct PassManager {
+    context: Rc<Context>,
+    /// The underlying pass manager
+    pm: OpPassManager,
+    /// A manager for pass instrumentation
+    instrumentor: Rc<PassInstrumentor>,
+    /// An optional crash reproducer generator, if this pass manager is setup to
+    /// generate reproducers.
+    ///crash_reproducer_generator: Rc,
+    /// Indicates whether to print pass statistics
+    statistics: Option<PassDisplayMode>,
+    /// Indicates whether or not pass timing is enabled
+    timing: bool,
+    /// Indicates whether or not to run verification between passes
+    verification: bool,
+}
+
+impl PassManager {
+    /// Create a new pass manager under the given context with a specific nesting
+    /// style. The created pass manager can schedule operations that match
+    /// `name`.
+    pub fn new(context: Rc<Context>, name: impl AsRef<str>, nesting: Nesting) -> Self {
+        let pm = OpPassManager::new(name.as_ref(), nesting, context.clone());
+        Self {
+            context,
+            pm,
+            instrumentor: Default::default(),
+            statistics: None,
+            timing: false,
+            verification: true,
+        }
+    }
+
+    /// Create a new pass manager under the given context with a specific nesting
+    /// style. The created pass manager can schedule operations that match
+    /// `OperationTy`.
+    pub fn on<OperationTy>(context: Rc<Context>, nesting: Nesting) -> Self
+    where
+        OperationTy: OpRegistration,
+    {
+        Self::new(context, <OperationTy as OpRegistration>::name(), nesting)
+    }
+
+    /// Run the passes within this manager on the provided operation. The
+    /// specified operation must have the same name as the one provided to the
+    /// pass manager on construction.
+    pub fn run(&mut self, op: OperationRef) -> Result<(), Report> {
+        use crate::Spanned;
+
+        let op_name = op.borrow().name();
+        let anchor = self.pm.name();
+        if let Some(anchor) = anchor {
+            if anchor != &op_name {
+                return Err(self
+                    .context
+                    .session
+                    .diagnostics
+                    .diagnostic(Severity::Error)
+                    .with_message("failed to construct pass manager")
+                    .with_primary_label(
+                        op.borrow().span(),
+                        format!("can't run '{anchor}' pass manager on '{op_name}'"),
+                    )
+                    .into_report());
+            }
+        }
+
+        // Register all dialects for the current pipeline.
+        /*
+        let dependent_dialects = self.get_dependent_dialects();
+        self.context.append_dialect_registry(dependent_dialects);
+        for dialect_name in dependent_dialects.names() {
+            self.context.get_or_register_dialect(dialect_name);
+        }
+        */
+
+        // Before running, make sure to finalize the pipeline pass list.
+        self.pm.finalize_pass_list()?;
+
+        // Construct a top-level analysis manager for the pipeline.
+        let analysis_manager = AnalysisManager::new(op.clone(), Some(self.instrumentor.clone()));
+
+        // If reproducer generation is enabled, run the pass manager with crash handling enabled.
+        /*
+        let result = if self.crash_reproducer_generator.is_some() {
+            self.run_with_crash_recovery(op, analysis_manager);
+        } else {
+            self.run_passes(op, analysis_manager);
+        }
+        */
+        let result = self.run_passes(op, analysis_manager);
+
+        // Dump all of the pass statistics if necessary.
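+        //
+        // With statistics enabled, the dumped report is shaped roughly like the
+        // following (the pass name and values here are made up for illustration):
+        //
+        //   ==-------------------------------------------------------------------==
+        //                        ... Pass statistics report ...
+        //   ==-------------------------------------------------------------------==
+        //   'some-pass'
+        //     (S) 4 num-rewrites - Number of rewrites applied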
+        if self.statistics.is_some() {
+            let mut output = String::new();
+            self.dump_statistics(&mut output).map_err(Report::msg)?;
+            println!("{output}");
+        }
+
+        result
+    }
+
+    fn run_passes(
+        &mut self,
+        op: OperationRef,
+        analysis_manager: AnalysisManager,
+    ) -> Result<(), Report> {
+        OpToOpPassAdaptor::run_pipeline(
+            &mut self.pm,
+            op,
+            analysis_manager,
+            self.verification,
+            Some(self.instrumentor.clone()),
+            None,
+        )
+    }
+
+    #[inline]
+    pub fn context(&self) -> Rc<Context> {
+        self.context.clone()
+    }
+
+    /// Runs the verifier after each individual pass.
+    pub fn enable_verifier(&mut self, yes: bool) -> &mut Self {
+        self.verification = yes;
+        self
+    }
+
+    pub fn add_instrumentation(&mut self, pi: Box<dyn PassInstrumentation>) -> &mut Self {
+        self.instrumentor.add_instrumentation(pi);
+        self
+    }
+
+    pub fn enable_ir_printing(&mut self, _config: IRPrintingConfig) {
+        todo!()
+    }
+
+    pub fn enable_timing(&mut self, yes: bool) -> &mut Self {
+        self.timing = yes;
+        self
+    }
+
+    pub fn enable_statistics(&mut self, mode: Option<PassDisplayMode>) -> &mut Self {
+        self.statistics = mode;
+        self
+    }
+
+    fn dump_statistics(&mut self, out: &mut dyn core::fmt::Write) -> core::fmt::Result {
+        self.pm.print_statistics(out, self.statistics.unwrap_or_default())
+    }
+}
+
+/// This type represents a pass manager that runs passes on either a specific
+/// operation type, or any isolated operation. This pass manager cannot be run
+/// on an operation directly, but must be run either as part of a top-level
+/// `PassManager` (e.g. when constructed via `nest` calls), or dynamically within
+/// a pass by using the `Pass::run_pipeline` API.
+pub struct OpPassManager {
+    /// The current context
+    context: Rc<Context>,
+    /// The name of the operation that passes of this pass manager operate on
+    name: Option<OperationName>,
+    /// The set of passes to run as part of this pass manager
+    passes: SmallVec<[Box<dyn OperationPass>; 8]>,
+    /// Control the implicit nesting of passes that mismatch the name set for this manager
+    nesting: Nesting,
+}
+
+impl OpPassManager {
+    pub const ANY: &str = "any";
+
+    /// Construct a new op-agnostic ("any") pass manager with the given operation
+    /// type and nesting behavior. This is the same as invoking:
+    /// `OpPassManager::new(OpPassManager::ANY, nesting, context)`.
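+    ///
+    /// # Example
+    ///
+    /// A rough sketch of constructing an op-agnostic manager (assumes `context`
+    /// is an existing `Rc<Context>`):
+    ///
+    /// ```text,ignore
+    /// let mut pm = OpPassManager::any(Nesting::Implicit, context.clone());
+    /// assert!(pm.name().is_none());
+    /// ```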
+    pub fn any(nesting: Nesting, context: Rc<Context>) -> Self {
+        Self {
+            context,
+            name: None,
+            passes: Default::default(),
+            nesting,
+        }
+    }
+
+    pub fn new(name: &str, nesting: Nesting, context: Rc<Context>) -> Self {
+        if name == Self::ANY {
+            return Self::any(nesting, context);
+        }
+
+        let (dialect_name, opcode) = name.split_once('.').expect(
+            "invalid operation name: expected format `<dialect>.<opcode>`, but missing `.`",
+        );
+        let dialect_name =
+            crate::DialectName::from_symbol(crate::interner::Symbol::intern(dialect_name));
+        let dialect = context.get_registered_dialect(&dialect_name);
+        let ops = dialect.registered_ops();
+        let name =
+            ops.iter()
+                .find(|name| name.name().as_str() == opcode)
+                .cloned()
+                .unwrap_or_else(|| {
+                    panic!(
+                        "invalid operation name: found dialect '{dialect_name}', but no operation \
+                         called '{opcode}' is registered to that dialect"
+                    )
+                });
+        Self {
+            context,
+            name: Some(name),
+            passes: Default::default(),
+            nesting,
+        }
+    }
+
+    pub fn for_operation(name: OperationName, nesting: Nesting, context: Rc<Context>) -> Self {
+        Self {
+            context,
+            name: Some(name),
+            passes: Default::default(),
+            nesting,
+        }
+    }
+
+    pub fn context(&self) -> Rc<Context> {
+        self.context.clone()
+    }
+
+    pub fn passes(&self) -> &[Box<dyn OperationPass>] {
+        &self.passes
+    }
+
+    pub fn passes_mut(&mut self) -> &mut [Box<dyn OperationPass>] {
+        &mut self.passes
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.passes.is_empty()
+    }
+
+    pub fn len(&self) -> usize {
+        self.passes.len()
+    }
+
+    pub fn clear(&mut self) {
+        self.passes.clear();
+    }
+
+    /// Nest a new op-specific pass manager (for the op with the given name), under this pass manager.
+    pub fn nest_with_type(&mut self, nested_name: &str) -> NestedOpPassManager<'_> {
+        self.nest(Self::new(nested_name, self.nesting, self.context.clone()))
+    }
+
+    /// Nest a new op-agnostic ("any") pass manager under this pass manager.
+    pub fn nest_any(&mut self) -> NestedOpPassManager<'_> {
+        self.nest(Self::any(self.nesting, self.context.clone()))
+    }
+
+    fn nest_for(&mut self, nested_name: OperationName) -> NestedOpPassManager<'_> {
+        self.nest(Self::for_operation(nested_name, self.nesting, self.context.clone()))
+    }
+
+    pub fn add_pass(&mut self, pass: Box<dyn OperationPass>) {
+        // If this pass runs on a different operation than this pass manager, then implicitly
+        // nest a pass manager for this operation if enabled.
+        let pass_op_name = pass.target_name(&self.context);
+        if let Some(pass_op_name) = pass_op_name {
+            if self.name.as_ref().is_some_and(|name| name != &pass_op_name) {
+                if matches!(self.nesting, Nesting::Implicit) {
+                    let mut nested = self.nest_for(pass_op_name);
+                    nested.add_pass(pass);
+                    return;
+                }
+                panic!(
+                    "cannot add pass '{}' restricted to '{pass_op_name}' to a pass manager \
+                     intended to run on '{}', did you intend to nest?",
+                    pass.name(),
+                    self.name().unwrap(),
+                );
+            }
+        }
+
+        self.passes.push(pass);
+    }
+
+    pub fn add_nested_pass<OperationTy>(&mut self, pass: Box<dyn OperationPass>)
+    where
+        OperationTy: OpRegistration,
+    {
+        let name = <OperationTy as OpRegistration>::name();
+        let mut nested = self.nest(Self::new(name.as_str(), self.nesting, self.context.clone()));
+        nested.add_pass(pass);
+    }
+
+    pub fn finalize_pass_list(&mut self) -> Result<(), Report> {
+        /*
+        auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) {
+            for (auto &pm : adaptor->getPassManagers())
+                if (failed(pm.getImpl().finalizePassList(ctx)))
+                    return failure();
+            return success();
+        };
+
+        // Walk the pass list and merge adjacent adaptors.
+        OpToOpPassAdaptor *lastAdaptor = nullptr;
+        for (auto &pass : passes) {
+            // Check to see if this pass is an adaptor.
+            if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
+                // If it is the first adaptor in a possible chain, remember it and
+                // continue.
+                if (!lastAdaptor) {
+                    lastAdaptor = currentAdaptor;
+                    continue;
+                }
+
+                // Otherwise, try to merge into the existing adaptor and delete the
+                // current one. If merging fails, just remember this as the last adaptor.
+                if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor)))
+                    pass.reset();
+                else
+                    lastAdaptor = currentAdaptor;
+            } else if (lastAdaptor) {
+                // If this pass isn't an adaptor, finalize it and forget the last adaptor.
+                if (failed(finalizeAdaptor(lastAdaptor)))
+                    return failure();
+                lastAdaptor = nullptr;
+            }
+        }
+
+        // If there was an adaptor at the end of the manager, finalize it as well.
+        if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor)))
+            return failure();
+
+        // Now that the adaptors have been merged, erase any empty slots corresponding
+        // to the merged adaptors that were nulled-out in the loop above.
+        llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
+
+        // If this is an op-agnostic pass manager, there is nothing left to do.
+        std::optional<OperationName> rawOpName = getOpName(*ctx);
+        if (!rawOpName)
+            return success();
+
+        // Otherwise, verify that all of the passes are valid for the current
+        // operation anchor.
+        std::optional<RegisteredOperationName> opName =
+            rawOpName->getRegisteredInfo();
+        for (std::unique_ptr<Pass> &pass : passes) {
+            if (opName && !pass->canScheduleOn(*opName)) {
+                return emitError(UnknownLoc::get(ctx))
+                       << "unable to schedule pass '" << pass->getName()
+                       << "' on a PassManager intended to run on '" << getOpAnchorName()
+                       << "'!";
+            }
+        }
+        return success();
+        */
+        todo!()
+    }
+
+    pub fn name(&self) -> Option<&OperationName> {
+        self.name.as_ref()
+    }
+
+    pub fn set_nesting(&mut self, nesting: Nesting) {
+        self.nesting = nesting;
+    }
+
+    pub fn nesting(&self) -> Nesting {
+        self.nesting
+    }
+
+    /// Indicate if this pass manager can be scheduled on the given operation
+    pub fn can_schedule_on(&self, name: &OperationName) -> bool {
+        // If this pass manager is op-specific, we simply check if the provided operation name
+        // is the same as this one.
+        if let Some(op_name) = self.name() {
+            return op_name == name;
+        }
+
+        // Otherwise, this is an op-agnostic pass manager. Check that the operation can be
+        // scheduled on all passes within the manager.
+        if !name.implements::<IsolatedFromAbove>() {
+            return false;
+        }
+        self.passes.iter().all(|pass| pass.can_schedule_on(name))
+    }
+
+    fn initialize(&mut self) -> Result<(), Report> {
+        for pass in self.passes.iter_mut() {
+            // If this pass is an adaptor, initialize its nested pass managers; otherwise,
+            // directly initialize the pass
+            if let Some(adaptor) = pass.as_any_mut().downcast_mut::<OpToOpPassAdaptor>() {
+                for pm in adaptor.pass_managers_mut() {
+                    pm.initialize()?;
+                }
+            } else {
+                pass.initialize(self.context.clone())?;
+            }
+        }
+
+        Ok(())
+    }
+
+    #[allow(unused)]
+    fn merge_into(&mut self, rhs: &mut Self) {
+        assert_eq!(self.name, rhs.name, "merging unrelated pass managers");
+        for pass in self.passes.drain(..) {
+            rhs.passes.push(pass);
+        }
+    }
+
+    pub fn nest(&mut self, nested: Self) -> NestedOpPassManager<'_> {
+        let adaptor = Box::new(OpToOpPassAdaptor::new(nested));
+        NestedOpPassManager {
+            parent: self,
+            nested: Some(adaptor),
+        }
+    }
+
+    pub fn print_statistics(
+        &self,
+        out: &mut dyn core::fmt::Write,
+        display_mode: PassDisplayMode,
+    ) -> core::fmt::Result {
+        const PASS_STATS_DESCRIPTION: &str = "... Pass statistics report ...";
+
+        // Print the stats header.
+        writeln!(out, "=={:-<73}==", "")?;
+        // Figure out how many spaces for the description name.
+        let padding = 80usize.saturating_sub(PASS_STATS_DESCRIPTION.len());
+        writeln!(out, "{: <1$}", PASS_STATS_DESCRIPTION, padding)?;
+        writeln!(out, "=={:-<73}==", "")?;
+
+        // Defer to a specialized printer for each display mode.
+        match display_mode {
+            PassDisplayMode::List => self.print_statistics_as_list(out),
+            PassDisplayMode::Pipeline => self.print_statistics_as_pipeline(out),
+        }
+    }
+
+    fn add_stats(
+        pass: &dyn OperationPass,
+        merged_stats: &mut BTreeMap<&str, SmallVec<[Box<dyn Statistic>; 4]>>,
+    ) {
+        use alloc::collections::btree_map::Entry;
+
+        if let Some(adaptor) = pass.as_any().downcast_ref::<OpToOpPassAdaptor>() {
+            // Recursively add each of the children.
+            for pass_manager in adaptor.pass_managers() {
+                for pass in pass_manager.passes() {
+                    Self::add_stats(&**pass, merged_stats);
+                }
+            }
+        } else {
+            // If this is not an adaptor, add the stats to the list if there are any.
+            if !pass.has_statistics() {
+                return;
+            }
+            let statistics = SmallVec::<[Box<dyn Statistic>; 4]>::from_iter(
+                pass.statistics().iter().map(|stat| Statistic::clone(&**stat)),
+            );
+            match merged_stats.entry(pass.name()) {
+                Entry::Vacant(entry) => {
+                    entry.insert(statistics);
+                }
+                Entry::Occupied(mut entry) => {
+                    let prev_stats = entry.get_mut();
+                    assert_eq!(prev_stats.len(), statistics.len());
+                    for (index, mut stat) in statistics.into_iter().enumerate() {
+                        let _ = prev_stats[index].try_merge(&mut stat);
+                    }
+                }
+            }
+        }
+    }
+
+    /// Print the statistics results in a list form, where each pass is sorted by name.
+    fn print_statistics_as_list(&self, out: &mut dyn core::fmt::Write) -> core::fmt::Result {
+        let mut merged_stats = BTreeMap::<&str, SmallVec<[Box<dyn Statistic>; 4]>>::default();
+        for pass in self.passes.iter() {
+            Self::add_stats(&**pass, &mut merged_stats);
+        }
+
+        // Print the statistics sequentially.
+        for (pass, stats) in merged_stats.iter() {
+            self.print_pass_entry(out, 2, pass, stats)?;
+        }
+
+        Ok(())
+    }
+
+    fn print_statistics_as_pipeline(&self, _out: &mut dyn core::fmt::Write) -> core::fmt::Result {
+        todo!()
+    }
+
+    fn print_pass_entry(
+        &self,
+        out: &mut dyn core::fmt::Write,
+        indent: usize,
+        pass: &str,
+        stats: &[Box<dyn Statistic>],
+    ) -> core::fmt::Result {
+        use core::fmt::Write;
+
+        writeln!(out, "{: <1$}", pass, indent)?;
+        if stats.is_empty() {
+            return Ok(());
+        }
+
+        // Collect the largest name and value length from each of the statistics.
+
+        struct Rendered<'a> {
+            name: &'a str,
+            description: &'a str,
+            value: compact_str::CompactString,
+        }
+
+        let mut largest_name = 0usize;
+        let mut largest_value = 0usize;
+        let mut rendered_stats = SmallVec::<[Rendered; 4]>::default();
+        for stat in stats {
+            let mut value = compact_str::CompactString::default();
+            let doc = stat.pretty_print();
+            write!(&mut value, "{doc}")?;
+            let name = stat.name();
+            largest_name = core::cmp::max(largest_name, name.len());
+            largest_value = core::cmp::max(largest_value, value.len());
+            rendered_stats.push(Rendered {
+                name,
+                description: stat.description(),
+                value,
+            });
+        }
+
+        // Sort the statistics by name.
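+        //
+        // After sorting, each entry below renders on its own line as
+        // `(S) <value> <name> - <description>`, e.g. (with a made-up value):
+        //
+        //   (S) 4 num-rewrites - Number of rewrites applied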
+        rendered_stats.sort_by(|a, b| a.name.cmp(b.name));
+
+        // Print statistics
+        for stat in rendered_stats {
+            write!(out, "{: <1$} (S) ", "", indent)?;
+            write!(out, "{: <1$} ", &stat.value, largest_value)?;
+            write!(out, "{: <1$}", &stat.name, largest_name)?;
+            if stat.description.is_empty() {
+                out.write_char('\n')?;
+            } else {
+                writeln!(out, " - {}", &stat.description)?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+pub struct NestedOpPassManager<'parent> {
+    parent: &'parent mut OpPassManager,
+    nested: Option<Box<OpToOpPassAdaptor>>,
+}
+
+impl<'parent> core::ops::Deref for NestedOpPassManager<'parent> {
+    type Target = OpPassManager;
+
+    fn deref(&self) -> &Self::Target {
+        &self.nested.as_deref().unwrap().pass_managers()[0]
+    }
+}
+impl<'parent> core::ops::DerefMut for NestedOpPassManager<'parent> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.nested.as_deref_mut().unwrap().pass_managers_mut()[0]
+    }
+}
+
+impl<'parent> Drop for NestedOpPassManager<'parent> {
+    fn drop(&mut self) {
+        self.parent.add_pass(self.nested.take().unwrap() as Box<dyn OperationPass>);
+    }
+}
+
+pub struct OpToOpPassAdaptor {
+    pms: SmallVec<[OpPassManager; 1]>,
+}
+impl OpToOpPassAdaptor {
+    pub fn new(pm: OpPassManager) -> Self {
+        Self { pms: smallvec![pm] }
+    }
+
+    pub fn name(&self) -> CompactString {
+        use core::fmt::Write;
+
+        let mut name = CompactString::default();
+        let names =
+            crate::formatter::DisplayValues::new(self.pms.iter().map(|pm| match pm.name() {
+                None => alloc::borrow::Cow::Borrowed(OpPassManager::ANY),
+                Some(name) => alloc::borrow::Cow::Owned(name.to_string()),
+            }));
+        write!(&mut name, "Pipeline Collection: [{names}]").unwrap();
+        name
+    }
+
+    /// Try to merge the current pass adaptor into 'rhs'.
+    ///
+    /// This will try to append the pass managers of this adaptor into those within `rhs`, or return
+    /// failure if merging isn't possible. The main situation in which merging is not possible is if
+    /// one of the adaptors has an `any` pipeline that is not compatible with a pass manager in the
+    /// other adaptor. For example, if this adaptor has a `hir.function` pipeline and `rhs` has an
+    /// `any` pipeline that operates on a FunctionOpInterface. In this situation the pipelines have
+    /// a conflict (they both want to run on the same operations), so we can't merge.
+    #[allow(unused)]
+    pub fn try_merge_into(&mut self, _rhs: &mut Self) -> bool {
+        todo!()
+    }
+
+    pub fn pass_managers(&self) -> &[OpPassManager] {
+        &self.pms
+    }
+
+    pub fn pass_managers_mut(&mut self) -> &mut [OpPassManager] {
+        &mut self.pms
+    }
+
+    /// Run the given operation and analysis manager on a provided op pass manager.
+    fn run_pipeline(
+        pm: &mut OpPassManager,
+        op: OperationRef,
+        analysis_manager: AnalysisManager,
+        verify: bool,
+        instrumentor: Option<Rc<PassInstrumentor>>,
+        parent_info: Option<&PipelineParentInfo>,
+    ) -> Result<(), Report> {
+        assert!(
+            instrumentor.is_none() || parent_info.is_some(),
+            "expected parent info if instrumentor is provided"
+        );
+
+        // Clear out any computed operation analyses on exit.
+        //
+        // These analyses won't be used anymore in this pipeline, and this helps reduce the
+        // current working set of memory. If preserving these analyses becomes important in the
+        // future, we can re-evaluate.
+        let _clear = analysis_manager.defer_clear();
+
+        // Run the pipeline over the provided operation.
+        let mut op_name = None;
+        if let Some(instrumentor) = instrumentor.as_deref() {
+            op_name = pm.name().cloned();
+            instrumentor.run_before_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap());
+        }
+
+        for pass in pm.passes_mut() {
+            Self::run(&mut **pass, op.clone(), analysis_manager.clone(), verify)?;
+        }
+
+        if let Some(instrumentor) = instrumentor.as_deref() {
+            instrumentor.run_after_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap());
+        }
+
+        Ok(())
+    }
+
+    /// Run the given operation and analysis manager on a single pass.
+    fn run(
+        pass: &mut dyn OperationPass,
+        op: OperationRef,
+        analysis_manager: AnalysisManager,
+        verify: bool,
+    ) -> Result<(), Report> {
+        use crate::Spanned;
+
+        let (op_name, span, context) = {
+            let op = op.borrow();
+            (op.name(), op.span(), op.context_rc())
+        };
+        if !op_name.implements::<IsolatedFromAbove>() {
+            return Err(context
+                .session
+                .diagnostics
+                .diagnostic(Severity::Error)
+                .with_message("failed to execute pass")
+                .with_primary_label(
+                    span,
+                    "trying to schedule a pass on an operation which does not implement \
+                     `IsolatedFromAbove`",
+                )
+                .into_report());
+        }
+        if !pass.can_schedule_on(&op_name) {
+            return Err(context
+                .session
+                .diagnostics
+                .diagnostic(Severity::Error)
+                .with_message("failed to execute pass")
+                .with_primary_label(span, "trying to schedule a pass on an unsupported operation")
+                .into_report());
+        }
+
+        // Initialize the pass state with a callback for the pass to dynamically execute a pipeline
+        // on the currently visited operation.
+        let pi = analysis_manager.pass_instrumentor();
+        let parent_info = PipelineParentInfo {
+            pass: Some(pass.name().to_compact_string()),
+        };
+        let callback_op = op.clone();
+        let callback_analysis_manager = analysis_manager.clone();
+        let pipeline_callback: Box<DynamicPipelineExecutor> = Box::new(
+            move |pipeline: &mut OpPassManager, root: OperationRef| -> Result<(), Report> {
+                let pi = callback_analysis_manager.pass_instrumentor();
+                let context = callback_op.borrow().context_rc();
+                let root_op = root.borrow();
+                if !root_op.is_ancestor_of(&callback_op) {
+                    return Err(context
+                        .session
+                        .diagnostics
+                        .diagnostic(Severity::Error)
+                        .with_message("failed to execute pass")
+                        .with_primary_label(
+                            root_op.span(),
+                            "trying to schedule a dynamic pass pipeline on an operation that \
+                             isn't nested under the current operation the pass is processing",
+                        )
+                        .into_report());
+                }
+                assert!(pipeline.can_schedule_on(&root_op.name()));
+                // Before running, finalize the passes held by the pipeline
+                pipeline.finalize_pass_list()?;
+
+                // Initialize the user-provided pipeline and execute the pipeline
+                pipeline.initialize()?;
+
+                let nested_am = if root == callback_op {
+                    callback_analysis_manager.clone()
+                } else {
+                    callback_analysis_manager.nest(&root)
+                };
+                Self::run_pipeline(pipeline, root, nested_am, verify, pi, Some(&parent_info))
+            },
+        );
+
+        let mut execution_state = PassExecutionState::new(
+            op.clone(),
+            context.clone(),
+            analysis_manager.clone(),
+            Some(pipeline_callback),
+        );
+
+        // Instrument before the pass has run
+        if let Some(instrumentor) = pi.as_deref() {
+            instrumentor.run_before_pass(pass, &op);
+        }
+
+        let mut result =
+            if let Some(adaptor) = pass.as_any_mut().downcast_mut::<OpToOpPassAdaptor>() {
+                adaptor.run_on_operation(op.clone(), &mut execution_state, verify)
+            } else {
+                pass.run_on_operation(op.clone(), &mut execution_state)
+            };
+
+        // Invalidate any non-preserved analyses
+        analysis_manager.invalidate(execution_state.preserved_analyses_mut());
+
+        // When `verify == true`, we run the verifier (unless the pass failed)
+        if result.is_ok() && verify {
+            // If the pass is an adaptor pass, we don't run the verifier recursively because the
+            // nested operations should have already been verified after nested passes had run
+            let run_verifier_recursively = !pass.as_any().is::<OpToOpPassAdaptor>();
+
+            // Reduce compile time by avoiding running the verifier if the pass didn't change the
+            // IR since the last time the verifier was run:
+            //
+            // * If the pass said that it preserved all analyses then it can't have permuted the IR
+            let run_verifier_now = !execution_state.preserved_analyses().is_all();
+            if run_verifier_now {
+                result = Self::verify(&op, run_verifier_recursively);
+            }
+        }
+
+        if let Some(instrumentor) = pi.as_deref() {
+            if result.is_err() {
+                instrumentor.run_after_pass_failed(pass, &op);
+            } else {
+                instrumentor.run_after_pass(pass, &op);
+            }
+        }
+
+        // Return the pass result
+        result
+    }
+
+    fn verify(_op: &OperationRef, _verify_recursively: bool) -> Result<(), Report> {
+        todo!()
+    }
+
+    fn run_on_operation(
+        &mut self,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+        verify: bool,
+    ) -> Result<(), Report> {
+        let analysis_manager = state.analysis_manager();
+        let instrumentor = analysis_manager.pass_instrumentor();
+        let parent_info = PipelineParentInfo {
+            pass: Some(self.name()),
+        };
+
+        // Collect region refs so we aren't holding borrows during pass execution
+        let mut next_region = op.borrow().regions().back().as_pointer();
+        while let Some(region) = next_region.take() {
+            next_region = region.next();
+            let mut next_block = region.borrow().body().front().as_pointer();
+            while let Some(block) = next_block.take() {
+                next_block = block.next();
+                let mut next_op = block.borrow().front();
+                while let Some(op) = next_op.take() {
+                    next_op = op.next();
+                    let op_name = op.borrow().name();
+                    if let Some(manager) =
+                        self.pms.iter_mut().find(|pm| pm.can_schedule_on(&op_name))
+                    {
+                        let am = analysis_manager.nest(&op);
+                        Self::run_pipeline(
+                            manager,
+                            op,
+                            am,
+                            verify,
+                            instrumentor.clone(),
+                            Some(&parent_info),
+                        )?;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl Pass for OpToOpPassAdaptor {
+    type Target = Operation;
+
+    fn name(&self) -> &'static str {
+        crate::interner::Symbol::intern(self.name()).as_str()
+    }
+
+    #[inline(always)]
+    fn target_name(&self, _context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    fn print_as_textual_pipeline(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        write!(f, "{}", &self.name())
+    }
+
+    #[inline(always)]
+    fn statistics(&self) -> &[Box<dyn Statistic>] {
+        &[]
+    }
+
+    #[inline(always)]
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>] {
+        &mut []
+    }
+
+    #[inline(always)]
+    fn can_schedule_on(&self, _name: &OperationName) -> bool {
+        true
+    }
+
+    fn run_on_operation(
+        &mut self,
+        _op: EntityMut<'_, Operation>,
+        _state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        unreachable!("unexpected call to `Pass::run_on_operation` for OpToOpPassAdaptor")
+    }
+
+    fn run_pipeline(
+        &mut self,
+        _pipeline: &mut OpPassManager,
+        _op: OperationRef,
+        _state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        todo!()
+    }
+}
diff --git a/hir2/src/pass/pass.rs b/hir2/src/pass/pass.rs
new file mode 100644
index 000000000..8cf728b87
--- /dev/null
+++ b/hir2/src/pass/pass.rs
@@ -0,0 +1,371 @@
+use alloc::rc::Rc;
+use core::{any::Any, fmt};
+
+use super::*;
+use crate::{Context, EntityMut, OperationName, OperationRef, Report};
+
+/// A type-erased [Pass].
+///
+/// This is used to allow heterogeneous passes to be operated on uniformly.
+///
+/// Semantically, an [OperationPass] behaves like a `Pass`.
+#[allow(unused_variables)]
+pub trait OperationPass {
+    fn as_any(&self) -> &dyn Any;
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+    fn name(&self) -> &'static str;
+    fn argument(&self) -> &'static str {
+        // NOTE: Could we compute an argument string from the type name?
+        ""
+    }
+    fn description(&self) -> &'static str {
+        ""
+    }
+    fn info(&self) -> PassInfo {
+        PassInfo::lookup(self.argument()).expect("could not find pass information")
+    }
+    /// The name of the operation that this pass operates on, or `None` if this is a generic pass.
+    fn target_name(&self, context: &Context) -> Option<OperationName>;
+    fn initialize_options(&mut self, options: &str) -> Result<(), Report> {
+        Ok(())
+    }
+    fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result;
+    fn has_statistics(&self) -> bool {
+        !self.statistics().is_empty()
+    }
+    fn statistics(&self) -> &[Box<dyn Statistic>];
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>];
+    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
+        Ok(())
+    }
+    fn can_schedule_on(&self, name: &OperationName) -> bool;
+    fn run_on_operation(
+        &mut self,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report>;
+    fn run_pipeline(
+        &mut self,
+        pipeline: &mut OpPassManager,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report>;
+}
+
+impl<P> OperationPass for P
+where
+    P: Pass + 'static,
+{
+    fn as_any(&self) -> &dyn Any {
+        <P as Pass>::as_any(self)
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        <P as Pass>::as_any_mut(self)
+    }
+
+    fn name(&self) -> &'static str {
+        <P as Pass>::name(self)
+    }
+
+    fn argument(&self) -> &'static str {
+        <P as Pass>::argument(self)
+    }
+
+    fn description(&self) -> &'static str {
+        <P as Pass>::description(self)
+    }
+
+    fn info(&self) -> PassInfo {
+        <P as Pass>::info(self)
+    }
+
+    fn target_name(&self, context: &Context) -> Option<OperationName> {
+        <P as Pass>::target_name(self, context)
+    }
+
+    fn initialize_options(&mut self, options: &str) -> Result<(), Report> {
+        <P as Pass>::initialize_options(self, options)
+    }
+
+    fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        <P as Pass>::print_as_textual_pipeline(self, f)
+    }
+
+    fn has_statistics(&self) -> bool {
+        <P as Pass>::has_statistics(self)
+    }
+
+    fn statistics(&self) -> &[Box<dyn Statistic>] {
+        <P as Pass>::statistics(self)
+    }
+
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>] {
+        <P as Pass>::statistics_mut(self)
+    }
+
+    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
+        <P as Pass>::initialize(self, context)
+    }
+
+    fn can_schedule_on(&self, name: &OperationName) -> bool {
+        <P as Pass>::can_schedule_on(self, name)
+    }
+
+    fn run_on_operation(
+        &mut self,
+        mut op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        let op = <<P as Pass>::Target as PassTarget>::into_target_mut(&mut op);
+        <P as Pass>::run_on_operation(self, op, state)
+    }
+
+    fn run_pipeline(
+        &mut self,
+        pipeline: &mut OpPassManager,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        <P as Pass>::run_pipeline(self, pipeline, op, state)
+    }
+}
+
+/// A compiler pass which operates on an [Operation] of some kind.
+#[allow(unused_variables)]
+pub trait Pass: Sized + Any {
+    /// The concrete/trait type targeted by this pass.
+    ///
+    /// Calls to `get_operation` will return a reference of this type.
+    type Target: ?Sized + PassTarget;
+
+    /// Used for downcasting
+    #[inline(always)]
+    fn as_any(&self) -> &dyn Any {
+        self as &dyn Any
+    }
+
+    /// Used for downcasting
+    #[inline(always)]
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self as &mut dyn Any
+    }
+
+    /// The display name of this pass
+    fn name(&self) -> &'static str;
+    /// The command line option name used to control this pass
+    fn argument(&self) -> &'static str {
+        // NOTE: Could we compute an argument string from the type name?
+        ""
+    }
+    /// A description of what this pass does.
+    fn description(&self) -> &'static str {
+        ""
+    }
+    /// Obtain the underlying [PassInfo] object for this pass.
+    fn info(&self) -> PassInfo {
+        PassInfo::lookup(self.argument()).expect("pass is not currently registered")
+    }
+    /// The name of the operation that this pass operates on, or `None` if this is a generic pass.
+    fn target_name(&self, context: &Context) -> Option<OperationName> {
+        <<Self as Pass>::Target as PassTarget>::target_name(context)
+    }
+    /// If command-line options are provided for this pass, implementations must parse the raw
+    /// options here, returning `Err` if parsing fails for some reason.
+    ///
+    /// By default, this is a no-op.
+    fn initialize_options(&mut self, options: &str) -> Result<(), Report> {
+        Ok(())
+    }
+    /// Print this pass as a textual pipeline.
+    fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result;
+    /// Returns true if this pass has associated statistics
+    fn has_statistics(&self) -> bool {
+        !self.statistics().is_empty()
+    }
+    /// Get pass statistics associated with this pass
+    fn statistics(&self) -> &[Box<dyn Statistic>];
+    /// Get mutable access to the pass statistics associated with this pass
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>];
+    /// Initialize any complex state necessary for running this pass.
+    ///
+    /// This hook should not rely on any state accessible during the execution of a pass. For
+    /// example, `context`/`get_operation`/`get_analysis`/etc. should not be invoked within this
+    /// hook.
+    ///
+    /// This method is invoked after all dependent dialects for the pipeline are loaded, and is not
+    /// allowed to load any further dialects (override the `get_dependent_dialects()` hook for this
+    /// purpose instead). Returns `Err` with a diagnostic if initialization fails, in which case the
+    /// pass pipeline won't execute.
+    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
+        Ok(())
+    }
+    /// Query if this pass can be scheduled to run on the given operation type.
+    fn can_schedule_on(&self, name: &OperationName) -> bool;
+    /// Run this pass on the current operation
+    fn run_on_operation(
+        &mut self,
+        op: EntityMut<'_, Self::Target>,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report>;
+    /// Run this pass as part of `pipeline` on `op`
+    fn run_pipeline(
+        &mut self,
+        pipeline: &mut OpPassManager,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report>;
+}
+
+impl<P> Pass for Box<P>
+where
+    P: Pass,
+{
+    type Target = <P as Pass>::Target;
+
+    fn as_any(&self) -> &dyn Any {
+        (**self).as_any()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        (**self).as_any_mut()
+    }
+
+    #[inline]
+    fn name(&self) -> &'static str {
+        (**self).name()
+    }
+
+    #[inline]
+    fn argument(&self) -> &'static str {
+        (**self).argument()
+    }
+
+    #[inline]
+    fn description(&self) -> &'static str {
+        (**self).description()
+    }
+
+    #[inline]
+    fn info(&self) -> PassInfo {
+        (**self).info()
+    }
+
+    #[inline]
+    fn target_name(&self, context: &Context) -> Option<OperationName> {
+        (**self).target_name(context)
+    }
+
+    #[inline]
+    fn initialize_options(&mut self, options: &str) -> Result<(), Report> {
+        (**self).initialize_options(options)
+    }
+
+    #[inline]
+    fn print_as_textual_pipeline(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        (**self).print_as_textual_pipeline(f)
+    }
+
+    #[inline]
+    fn has_statistics(&self) -> bool {
+        (**self).has_statistics()
+    }
+
+    #[inline]
+    fn statistics(&self) -> &[Box<dyn Statistic>] {
+        (**self).statistics()
+    }
+
+    #[inline]
+    fn statistics_mut(&mut self) -> &mut [Box<dyn Statistic>] {
+        (**self).statistics_mut()
+    }
+
+    #[inline]
+    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
+        (**self).initialize(context)
+    }
+
+    #[inline]
+    fn can_schedule_on(&self, name: &OperationName) -> bool {
+        (**self).can_schedule_on(name)
+    }
+
+    #[inline]
+    fn run_on_operation(
+        &mut self,
+        op: EntityMut<'_, Self::Target>,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        (**self).run_on_operation(op, state)
+    }
+
+    #[inline]
+    fn run_pipeline(
+        &mut self,
+        pipeline: &mut OpPassManager,
+        op: OperationRef,
+        state: &mut PassExecutionState,
+    ) -> Result<(), Report> {
+        (**self).run_pipeline(pipeline, op, state)
+    }
+}
+
+pub type DynamicPipelineExecutor =
+    dyn FnMut(&mut OpPassManager, OperationRef) -> Result<(), Report>;
+
+/// The state for a single execution of a pass. This provides a unified
+/// interface for accessing and initializing necessary state for pass execution.
+pub struct PassExecutionState {
+    /// The operation being transformed
+    op: OperationRef,
+    context: Rc<Context>,
+    analysis_manager: AnalysisManager,
+    /// The set of preserved analyses for the current execution
+    preserved_analyses: PreservedAnalyses,
+    // Callback in the pass manager that allows one to schedule dynamic pipelines that will be
+    // rooted at the provided operation.
+    #[allow(unused)]
+    pipeline_executor: Option<Box<DynamicPipelineExecutor>>,
+}
+impl PassExecutionState {
+    pub fn new(
+        op: OperationRef,
+        context: Rc<Context>,
+        analysis_manager: AnalysisManager,
+        pipeline_executor: Option<Box<DynamicPipelineExecutor>>,
+    ) -> Self {
+        Self {
+            op,
+            context,
+            analysis_manager,
+            preserved_analyses: Default::default(),
+            pipeline_executor,
+        }
+    }
+
+    #[inline(always)]
+    pub fn context(&self) -> Rc<Context> {
+        self.context.clone()
+    }
+
+    #[inline(always)]
+    pub const fn current_operation(&self) -> &OperationRef {
+        &self.op
+    }
+
+    #[inline(always)]
+    pub const fn analysis_manager(&self) -> &AnalysisManager {
+        &self.analysis_manager
+    }
+
+    #[inline(always)]
+    pub const fn preserved_analyses(&self) -> &PreservedAnalyses {
+        &self.preserved_analyses
+    }
+
+    #[inline(always)]
+    pub fn preserved_analyses_mut(&mut self) -> &mut PreservedAnalyses {
+        &mut self.preserved_analyses
+    }
+}
diff --git a/hir2/src/pass/registry.rs b/hir2/src/pass/registry.rs
new file mode 100644
index 000000000..88b68b87f
--- /dev/null
+++ b/hir2/src/pass/registry.rs
@@ -0,0 +1,397 @@
+use alloc::{collections::BTreeMap, sync::Arc};
+use core::any::TypeId;
+
+use midenc_hir_symbol::sync::{LazyLock, RwLock};
+use midenc_session::diagnostics::DiagnosticsHandler;
+
+use super::*;
+use crate::Report;
+
+static PASS_REGISTRY: LazyLock<PassRegistry> = LazyLock::new(PassRegistry::new);
+
+/// A global, thread-safe pass and pass pipeline registry
+///
+/// You should generally _not_ need to work with this directly.
+pub struct PassRegistry {
+    passes: RwLock<BTreeMap<&'static str, PassRegistryEntry>>,
+    pipelines: RwLock<BTreeMap<&'static str, PassRegistryEntry>>,
+}
+impl Default for PassRegistry {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+impl PassRegistry {
+    /// Create a new [PassRegistry] instance.
+    pub fn new() -> Self {
+        let mut passes = BTreeMap::default();
+        let mut pipelines = BTreeMap::default();
+        for pass in inventory::iter::<PassInfo>() {
+            passes.insert(
+                pass.0.arg,
+                PassRegistryEntry {
+                    arg: pass.0.arg,
+                    description: pass.0.description,
+                    type_id: pass.0.type_id,
+                    builder: Arc::clone(&pass.0.builder),
+                },
+            );
+        }
+        for pipeline in inventory::iter::<PassPipelineInfo>() {
+            pipelines.insert(
+                pipeline.0.arg,
+                PassRegistryEntry {
+                    arg: pipeline.0.arg,
+                    description: pipeline.0.description,
+                    type_id: pipeline.0.type_id,
+                    builder: Arc::clone(&pipeline.0.builder),
+                },
+            );
+        }
+
+        Self {
+            passes: RwLock::new(passes),
+            pipelines: RwLock::new(pipelines),
+        }
+    }
+
+    /// Get the pass information for the pass whose argument name is `name`
+    pub fn get_pass(&self, name: &str) -> Option<PassInfo> {
+        self.passes.read().get(name).cloned().map(PassInfo)
+    }
+
+    /// Get the pass pipeline information for the pipeline whose argument name is `name`
+    pub fn get_pipeline(&self, name: &str) -> Option<PassPipelineInfo> {
+        self.pipelines.read().get(name).cloned().map(PassPipelineInfo)
+    }
+
+    /// Register the given pass
+    pub fn register_pass(&self, info: PassInfo) {
+        use alloc::collections::btree_map::Entry;
+
+        let mut passes = self.passes.write();
+        match passes.entry(info.argument()) {
+            Entry::Vacant(entry) => {
+                entry.insert(info.0);
+            }
+            Entry::Occupied(entry) => {
+                assert_eq!(
+                    entry.get().type_id,
+                    info.0.type_id,
+                    "cannot register pass '{}': name already registered by a different type",
+                    info.argument()
+                );
+            }
+        }
+    }
+
+    /// Register the given pass pipeline
+    pub fn register_pipeline(&self, info: PassPipelineInfo) {
+        use alloc::collections::btree_map::Entry;
+
+        let mut pipelines = self.pipelines.write();
+        match pipelines.entry(info.argument()) {
+            Entry::Vacant(entry) => {
+                entry.insert(info.0);
+            }
+            Entry::Occupied(entry) => {
+                assert_eq!(
+                    entry.get().type_id,
+                    info.0.type_id,
+                    "cannot register pass pipeline '{}': name already registered by a different \
+                     type",
+                    info.argument()
+                );
+                assert!(Arc::ptr_eq(&entry.get().builder, &info.0.builder))
+            }
+        }
+    }
+}
+
+inventory::collect!(PassInfo);
+inventory::collect!(PassPipelineInfo);
+
+/// A type alias for the closure type for registering a pass with a pass manager
+pub type PassRegistryFunction = dyn Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report>
+    + Send
+    + Sync
+    + 'static;
+
+/// A type alias for the closure type used for type-erased pass constructors
+pub type PassAllocatorFunction = dyn Fn() -> Box<dyn OperationPass>;
+
+/// A [RegistryEntry] is a registered pass or pass pipeline.
+///
+/// This trait provides the common functionality shared by both passes and pipelines.
+pub trait RegistryEntry {
+    /// Returns the command-line option that may be passed to `midenc` that will cause this pass
+    /// or pass pipeline to run.
+    fn argument(&self) -> &'static str;
+    /// Return a description for the pass or pass pipeline.
+    fn description(&self) -> &'static str;
+    /// Adds this entry to the given pass manager.
+    ///
+    /// Note: `options` is an opaque string that will be parsed by the builder.
+    ///
+    /// Returns `Err` if an error occurred parsing the given options.
+    fn add_to_pipeline(
+        &self,
+        pm: &mut OpPassManager,
+        options: &str,
+        diagnostics: &DiagnosticsHandler,
+    ) -> Result<(), Report>;
+}
+
+/// Information about a pass or pass pipeline in the pass registry
+#[derive(Clone)]
+struct PassRegistryEntry {
+    /// The name of the compiler option for referencing on the command line
+    arg: &'static str,
+    /// A description of the pass or pass pipeline
+    description: &'static str,
+    /// The type id of the concrete pass type
+    type_id: Option<TypeId>,
+    /// Function that registers this entry with a pass manager pipeline
+    builder: Arc<PassRegistryFunction>,
+}
+impl RegistryEntry for PassRegistryEntry {
+    #[inline]
+    fn add_to_pipeline(
+        &self,
+        pm: &mut OpPassManager,
+        options: &str,
+        diagnostics: &DiagnosticsHandler,
+    ) -> Result<(), Report> {
+        (self.builder)(pm, options, diagnostics)
+    }
+
+    #[inline(always)]
+    fn argument(&self) -> &'static str {
+        self.arg
+    }
+
+    #[inline(always)]
+    fn description(&self) -> &'static str {
+        self.description
+    }
+}
+
+/// Information about a registered pass pipeline
+pub struct PassPipelineInfo(PassRegistryEntry);
+impl PassPipelineInfo {
+    pub fn new<B>(arg: &'static str, description: &'static str, builder: B) -> Self
+    where
+        B: Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report>
+            + Send
+            + Sync
+            + 'static,
+    {
+        Self(PassRegistryEntry {
+            arg,
+            description,
+            type_id: None,
+            builder: Arc::new(builder),
+        })
+    }
+
+    /// Find the [PassPipelineInfo] for a registered pass pipeline named `name`
+    pub fn lookup(name: &str) -> Option<PassPipelineInfo> {
+        PASS_REGISTRY.get_pipeline(name)
+    }
+}
+impl RegistryEntry for PassPipelineInfo {
+    fn argument(&self) -> &'static str {
+        self.0.argument()
+    }
+
+    fn description(&self) -> &'static str {
+        self.0.description()
+    }
+
+    fn add_to_pipeline(
+        &self,
+        pm: &mut OpPassManager,
+        options: &str,
+        diagnostics: &DiagnosticsHandler,
+    ) -> Result<(), Report> {
+        self.0.add_to_pipeline(pm, options, diagnostics)
+    }
+}
+
+/// Information about a registered pass
+pub struct PassInfo(PassRegistryEntry);
+impl PassInfo {
+    /// Create a new [PassInfo] from the given argument name and description, for a default-
+    /// constructible pass type `P`.
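+    ///
+    /// # Example
+    ///
+    /// A sketch, assuming a `MyPass` type implementing both [Default] and the
+    /// `Pass` trait:
+    ///
+    /// ```text,ignore
+    /// let info = PassInfo::new::<MyPass>("my-pass", "An example pass");
+    /// assert_eq!(info.argument(), "my-pass");
+    /// ```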
+    pub fn new<P>(arg: &'static str, description: &'static str) -> Self
+    where
+        P: OperationPass + Default + 'static,
+    {
+        let type_id = TypeId::of::<P>();
+        Self(PassRegistryEntry {
+            arg,
+            description,
+            type_id: Some(type_id),
+            builder: Arc::new(default_registration::<P>),
+        })
+    }
+
+    /// Find the [PassInfo] for a registered pass named `name`
+    pub fn lookup(name: &str) -> Option<PassInfo> {
+        PASS_REGISTRY.get_pass(name)
+    }
+}
+impl RegistryEntry for PassInfo {
+    fn argument(&self) -> &'static str {
+        self.0.argument()
+    }
+
+    fn description(&self) -> &'static str {
+        self.0.description()
+    }
+
+    fn add_to_pipeline(
+        &self,
+        pm: &mut OpPassManager,
+        options: &str,
+        diagnostics: &DiagnosticsHandler,
+    ) -> Result<(), Report> {
+        self.0.add_to_pipeline(pm, options, diagnostics)
+    }
+}
+
+/// Register a specific dialect pipeline registry function with the system.
+///
+/// # Example
+///
+/// If your pipeline implements the [Default] trait, you can just do:
+///
+/// ```text,ignore
+/// register_pass_pipeline(
+///     "my-pipeline",
+///     "A simple test pipeline",
+///     default_registration::<MyPipeline>,
+/// )
+/// ```
+///
+/// Otherwise, you need to pass a factory function which will be used to construct fresh instances
+/// of the pipeline:
+///
+/// ```text,ignore
+/// register_pass_pipeline(
+///     "my-pipeline",
+///     "A simple test pipeline",
+///     default_registration_factory(|| Box::new(MyPipeline::new(MyPipelineOptions::default()))),
+/// )
+/// ```
+///
+/// NOTE: The functions/closures passed above are required to be `Send + Sync + 'static`, as they
+/// are stored in the global registry for the lifetime of the program, and may be accessed from any
+/// thread.
+pub fn register_pass_pipeline<B>(arg: &'static str, description: &'static str, builder: B)
+where
+    B: Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report>
+        + Send
+        + Sync
+        + 'static,
+{
+    PASS_REGISTRY.register_pipeline(PassPipelineInfo(PassRegistryEntry {
+        arg,
+        description,
+        type_id: None,
+        builder: Arc::new(builder),
+    }));
+}
+
+/// Register a specific dialect pass allocator function with the system.
+///
+/// # Example
+///
+/// ```text,ignore
+/// register_pass(|| Box::new(MyPass::default()))
+/// ```
+///
+/// NOTE: The allocator function provided is required to be `Send + Sync + 'static`, as it is
+/// stored in the global registry for the lifetime of the program, and may be accessed from any
+/// thread.
+pub fn register_pass(ctor: impl Fn() -> Box<dyn OperationPass> + Send + Sync + 'static) {
+    let pass = ctor();
+    let type_id = pass.as_any().type_id();
+    let arg = pass.argument();
+    assert!(
+        !arg.is_empty(),
+        "attempted to register pass '{}' without specifying an argument name",
+        pass.name()
+    );
+    let description = pass.description();
+    PASS_REGISTRY.register_pass(PassInfo(PassRegistryEntry {
+        arg,
+        description,
+        type_id: Some(type_id),
+        builder: Arc::new(default_registration_factory(ctor)),
+    }));
+}
+
+/// A default implementation of a pass pipeline registration function.
+///
+/// It expects that `P` (the type of the pass or pass pipeline), implements `Default`, so that an
+/// instance is default-constructible. It then initializes the pass with the provided options,
+/// validates that the pass/pipeline is valid for the parent pipeline, and adds it if so.
+pub fn default_registration<P: OperationPass + Default + 'static>(
+    pm: &mut OpPassManager,
+    options: &str,
+    diagnostics: &DiagnosticsHandler,
+) -> Result<(), Report> {
+    use midenc_session::diagnostics::Severity;
+
+    let mut pass = Box::<P>::default() as Box<dyn OperationPass>;
+    let result = pass.initialize_options(options);
+    let pm_op_name = pm.name();
+    let pass_op_name = pass.target_name(&pm.context());
+    let pass_op_name = pass_op_name.as_ref();
+    if matches!(pm.nesting(), Nesting::Explicit) && pm_op_name != pass_op_name {
+        return Err(diagnostics
+            .diagnostic(Severity::Error)
+            .with_message(format!(
+                "registration error for pass '{}': can't add pass restricted to '{}' on a pass \
+                 manager intended to run on '{}', did you intend to nest?",
+                pass.name(),
+                crate::formatter::DisplayOptional(pass_op_name.as_ref()),
+                crate::formatter::DisplayOptional(pm_op_name),
+            ))
+            .into_report());
+    }
+    pm.add_pass(pass);
+    result
+}
+
+/// Like [default_registration], but takes an arbitrary constructor in the form of a zero-arity
+/// closure, rather than relying on [Default]. Thus, this is actually a registration function
+/// _factory_, rather than a registration function itself.
+pub fn default_registration_factory<B: Fn() -> Box<dyn OperationPass> + Send + Sync + 'static>(
+    builder: B,
+) -> impl Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + Send + Sync + 'static
+{
+    use midenc_session::diagnostics::Severity;
+    move |pm: &mut OpPassManager,
+          options: &str,
+          diagnostics: &DiagnosticsHandler|
+          -> Result<(), Report> {
+        let mut pass = builder();
+        let result = pass.initialize_options(options);
+        let pm_op_name = pm.name();
+        let pass_op_name = pass.target_name(&pm.context());
+        let pass_op_name = pass_op_name.as_ref();
+        if matches!(pm.nesting(), Nesting::Explicit) && pm_op_name != pass_op_name {
+            return Err(diagnostics
+                .diagnostic(Severity::Error)
+                .with_message(format!(
+                    "registration error for pass '{}': can't add pass restricted to '{}' on a \
+                     pass manager intended to run on '{}', did you intend to nest?",
+                    pass.name(),
+                    crate::formatter::DisplayOptional(pass_op_name.as_ref()),
+                    crate::formatter::DisplayOptional(pm_op_name),
+                ))
+                .into_report());
+        }
+        pm.add_pass(pass);
+        result
+    }
+}
diff --git a/hir2/src/pass/specialization.rs b/hir2/src/pass/specialization.rs
new file mode 100644
index 000000000..4c10a378f
--- /dev/null
+++ b/hir2/src/pass/specialization.rs
@@ -0,0 +1,142 @@
+use crate::{
+    traits::BranchOpInterface, Context, EntityMut, EntityRef, Op, Operation, OperationName,
+    OperationRef, Symbol, SymbolTable,
+};
+
+pub trait PassTarget {
+    fn target_name(context: &Context) -> Option<OperationName>;
+    fn into_target(op: &OperationRef) -> EntityRef<'_, Self>;
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, Self>;
+}
+
+impl<T: Op> PassTarget for T {
+    default fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    #[inline]
+    default fn into_target(op: &OperationRef) -> EntityRef<'_, T> {
+        EntityRef::map(op.borrow(), |t| {
+            t.downcast_ref::<T>().unwrap_or_else(|| expected_type::<T>(op))
+        })
+    }
+
+    #[inline]
+    default fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, T> {
+        EntityMut::map(op.borrow_mut(), |t| {
+            t.downcast_mut::<T>().unwrap_or_else(|| expected_type::<T>(op))
+        })
+    }
+}
+impl PassTarget for Operation {
+    #[inline(always)]
+    fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    #[inline]
+    fn into_target(op: &OperationRef) -> EntityRef<'_, Operation> {
+        op.borrow()
+    }
+
+    #[inline]
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, Operation> {
+        op.borrow_mut()
+    }
+}
+impl PassTarget for dyn Op {
+    #[inline(always)]
+    fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    fn into_target(op: &OperationRef) -> EntityRef<'_, dyn Op> {
+        EntityRef::map(op.borrow(), |op| op.as_trait::<dyn Op>().unwrap())
+    }
+
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn Op> {
+        EntityMut::map(op.borrow_mut(), |op| op.as_trait_mut::<dyn Op>().unwrap())
+    }
+}
+impl PassTarget for dyn BranchOpInterface {
+    #[inline(always)]
+    fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    fn into_target(op: &OperationRef) -> EntityRef<'_, dyn BranchOpInterface> {
+        EntityRef::map(op.borrow(), |t| {
+            t.as_trait::<dyn BranchOpInterface>()
+                .unwrap_or_else(|| expected_implementation::<dyn BranchOpInterface>(op))
+        })
+    }
+
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn BranchOpInterface> {
+        EntityMut::map(op.borrow_mut(), |t| {
+            t.as_trait_mut::<dyn BranchOpInterface>()
+                .unwrap_or_else(|| expected_implementation::<dyn BranchOpInterface>(op))
+        })
+    }
+}
+impl PassTarget for dyn Symbol {
+    #[inline(always)]
+    fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    fn into_target(op: &OperationRef) -> EntityRef<'_, dyn Symbol> {
+        EntityRef::map(op.borrow(), |t| {
+            t.as_trait::<dyn Symbol>()
+                .unwrap_or_else(|| expected_implementation::<dyn Symbol>(op))
+        })
+    }
+
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn Symbol> {
+        EntityMut::map(op.borrow_mut(), |t| {
+            t.as_trait_mut::<dyn Symbol>()
+                .unwrap_or_else(|| expected_implementation::<dyn Symbol>(op))
+        })
+    }
+}
+impl PassTarget for dyn SymbolTable + 'static {
+    #[inline(always)]
+    fn target_name(_context: &Context) -> Option<OperationName> {
+        None
+    }
+
+    fn into_target(op: &OperationRef) -> EntityRef<'_, dyn SymbolTable + 'static> {
+        EntityRef::map(op.borrow(), |t| {
+            t.as_trait::<dyn SymbolTable>()
+                .unwrap_or_else(|| expected_implementation::<dyn SymbolTable>(op))
+        })
+    }
+
+    fn into_target_mut(op: &mut OperationRef) -> EntityMut<'_, dyn SymbolTable + 'static> {
+        EntityMut::map(op.borrow_mut(), |t| {
+            t.as_trait_mut::<dyn SymbolTable>()
+                .unwrap_or_else(|| expected_implementation::<dyn SymbolTable>(op))
+        })
+    }
+}
+
+#[cold]
+#[inline(never)]
+#[track_caller]
+fn expected_type<T>(op: &OperationRef) -> ! {
+    panic!(
+        "expected operation '{}' to be a `{}`",
+        op.borrow().name(),
+        core::any::type_name::<T>(),
+    )
+}
+
+#[cold]
+#[inline(never)]
+#[track_caller]
+fn expected_implementation<T: ?Sized>(op: &OperationRef) -> ! {
+    panic!(
+        "expected '{}' to implement `{}`, but no vtable was found",
+        op.borrow().name(),
+        core::any::type_name::<T>()
+    )
+}
diff --git a/hir2/src/pass/statistics.rs b/hir2/src/pass/statistics.rs
new file mode 100644
index 000000000..042d29c2b
--- /dev/null
+++ b/hir2/src/pass/statistics.rs
@@ -0,0 +1,462 @@
+use core::{any::Any, fmt};
+
+use compact_str::CompactString;
+
+use crate::Report;
+
+/// A [Statistic] represents some stateful datapoint collected by and across passes.
+///
+/// Statistics are named, have a description, and have a value. The value can be pretty printed,
+/// and multiple instances of the same statistic can be merged together.
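+///
+/// # Example
+///
+/// A sketch of merging two counter statistics (the names and values are made
+/// up for illustration):
+///
+/// ```text,ignore
+/// let mut a = PassStatistic::new("num-rewrites".into(), "Number of rewrites applied".into(), 1usize);
+/// let mut b = PassStatistic::new("num-rewrites".into(), "Number of rewrites applied".into(), 2usize);
+/// // Numeric values merge via saturating addition, so `a.value` is now 3
+/// a.try_merge(&mut b.value).unwrap();
+/// ```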
+#[derive(Clone)]
+pub struct PassStatistic<V> {
+    pub name: CompactString,
+    pub description: CompactString,
+    pub value: V,
+}
+impl<V> PassStatistic<V>
+where
+    V: StatisticValue,
+{
+    pub fn new(name: CompactString, description: CompactString, value: V) -> Self {
+        Self {
+            name,
+            description,
+            value,
+        }
+    }
+}
+impl<V> Eq for PassStatistic<V> {}
+impl<V> PartialEq for PassStatistic<V> {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name
+    }
+}
+impl<V> PartialOrd for PassStatistic<V> {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        Some(self.name.cmp(&other.name))
+    }
+}
+impl<V> Ord for PassStatistic<V> {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.name.cmp(&other.name)
+    }
+}
+impl<V> Statistic for PassStatistic<V>
+where
+    V: Clone + StatisticValue + 'static,
+{
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn description(&self) -> &str {
+        &self.description
+    }
+
+    fn pretty_print(&self) -> crate::formatter::Document {
+        self.value.pretty_print()
+    }
+
+    fn try_merge(&mut self, other: &mut dyn Any) -> Result<(), Report> {
+        let lhs = &mut self.value;
+        if let Some(rhs) = other.downcast_mut::<<V as StatisticValue>::Value>() {
+            lhs.merge(rhs);
+            Ok(())
+        } else {
+            let name = &self.name;
+            let expected_ty = core::any::type_name::<<V as StatisticValue>::Value>();
+            Err(Report::msg(format!(
+                "could not merge statistic '{name}': expected value of type '{expected_ty}', but \
+                 got a value of some other type"
+            )))
+        }
+    }
+
+    fn clone(&self) -> Box<dyn Statistic> {
+        use core::clone::CloneToUninit;
+        let mut this = Box::new_uninit();
+        unsafe {
+            self.clone_to_uninit(this.as_mut_ptr());
+            this.assume_init()
+        }
+    }
+}
+
+/// An abstraction over statistics that allows operating generically over statistics with different
+/// types of values.
+pub trait Statistic {
+    /// The display name of this statistic
+    fn name(&self) -> &str;
+    /// A description of what this statistic means and why it is significant
+    fn description(&self) -> &str;
+    /// Pretty prints this statistic as a value
+    fn pretty_print(&self) -> crate::formatter::Document;
+    /// Merges another instance of this statistic into this one, given a mutable reference to the
+    /// raw underlying value of the other instance.
+    ///
+    /// Returns `Err` if `other` is not a valid value type for this statistic
+    fn try_merge(&mut self, other: &mut dyn Any) -> Result<(), Report>;
+    /// Clones the underlying statistic
+    fn clone(&self) -> Box<dyn Statistic>;
+}
+
+pub trait StatisticValue {
+    type Value: Any + Clone;
+
+    fn value(&self) -> &Self::Value;
+    fn value_mut(&mut self) -> &mut Self::Value;
+    fn value_as_any(&self) -> &dyn Any {
+        self.value() as &dyn Any
+    }
+    fn value_as_any_mut(&mut self) -> &mut dyn Any {
+        self.value_mut() as &mut dyn Any
+    }
+    fn expected_type(&self) -> &'static str {
+        core::any::type_name::<<Self as StatisticValue>::Value>()
+    }
+    fn merge(&mut self, other: &mut Self::Value);
+    fn pretty_print(&self) -> crate::formatter::Document;
+}
+
+impl<V: Any + Clone> dyn StatisticValue<Value = V> {
+    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
+        self.value_as_any().downcast_ref::<T>()
+    }
+
+    pub fn downcast_mut<T: Any>(&mut self) -> Option<&mut T> {
+        self.value_as_any_mut().downcast_mut::<T>()
+    }
+}
+
+/// Merges via OR
+impl StatisticValue for bool {
+    type Value = bool;
+
+    fn value(&self) -> &Self::Value {
+        self
+    }
+
+    fn value_mut(&mut self) -> &mut Self::Value {
+        self
+    }
+
+    fn merge(&mut self, other: &mut Self::Value) {
+        *self |= *other;
+    }
+
+    fn pretty_print(&self) -> crate::formatter::Document {
+        crate::formatter::display(*self)
+    }
+}
+
+/// A boolean flag which evaluates to true, only if all observed values are false.
+/// +/// Defaults to false, and merges by OR. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagNone(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagNone) -> Self { + !flag.0 + } +} +impl StatisticValue for FlagNone { + type Value = FlagNone; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + if !self.0 && !other.0 { + self.0 = true; + } else { + self.0 ^= other.0 + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +/// A boolean flag which evaluates to true, only if at least one true value was observed. +/// +/// Defaults to false, and merges by OR. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagAny(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagAny) -> Self { + flag.0 + } +} +impl StatisticValue for FlagAny { + type Value = FlagAny; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.0 |= other.0; + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +/// A boolean flag which evaluates to true, only if all observed values were true. +/// +/// Defaults to true, and merges by AND. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct FlagAll(bool); +impl From for bool { + #[inline(always)] + fn from(flag: FlagAll) -> Self { + flag.0 + } +} +impl StatisticValue for FlagAll { + type Value = FlagAll; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.0 &= other.0 + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(bool::from(*self)) + } +} + +macro_rules! 
numeric_statistic { + (#[cfg $($args:tt)*] $int_ty:ty) => { + /// Adds two numbers by saturating addition + #[cfg $($args)*] + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } + } + }; + + (#[cfg $($args:tt)*] $int_ty:ty as $wrapper_ty:ty) => { + /// Adds two numbers by saturating addition + #[cfg $($args)*] + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(<$wrapper_ty>::from(*self)) + } + } + }; + + ($int_ty:ty) => { + /// Adds two numbers by saturating addition + impl StatisticValue for $int_ty { + type Value = $int_ty; + fn value(&self) -> &Self::Value { self } + fn value_mut(&mut self) -> &mut Self::Value { self } + fn merge(&mut self, other: &mut Self::Value) { + *self = self.saturating_add(*other); + } + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } + } + } +} + +numeric_statistic!(u8); +numeric_statistic!(i8); +numeric_statistic!(u16); +numeric_statistic!(i16); +numeric_statistic!(u32); +numeric_statistic!(i32); +numeric_statistic!(u64); +numeric_statistic!(i64); +numeric_statistic!(usize); +numeric_statistic!(isize); +numeric_statistic!( + #[cfg(feature = "std")] + std::time::Duration as midenc_session::HumanDuration +); +numeric_statistic!( + #[cfg(feature = "std")] + midenc_session::HumanDuration +); + +impl StatisticValue for f64 { + type Value = f64; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + *self += *other; + } + + fn pretty_print(&self) -> crate::formatter::Document { + crate::formatter::display(*self) + } +} + +/// Merges an array of statistic values element-wise +impl StatisticValue for [T; N] +where + T: Any + StatisticValue + Clone, +{ + type Value = [T; N]; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + for index in 0..N { + self[index].merge(other[index].value_mut()); + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::const_text; + + let doc = const_text("["); + self.iter().enumerate().fold(doc, |mut doc, (i, item)| { + if i > 0 { + doc += const_text(", "); + } + doc + item.pretty_print() + }) + const_text("]") + } +} + +/// Merges two vectors of statistics by appending +impl StatisticValue for Vec +where + T: Any + StatisticValue + Clone, +{ + type Value = Vec; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + self.append(other); + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::const_text; + + let doc = const_text("["); + self.iter().enumerate().fold(doc, |mut doc, (i, item)| { + if i > 0 { + doc += const_text(", "); + } + doc + item.pretty_print() + }) + const_text("]") + } +} + +/// Merges two maps of statistics by merging values 
of identical keys, and appending missing keys +impl StatisticValue for alloc::collections::BTreeMap +where + K: Ord + Clone + fmt::Display + 'static, + V: Any + StatisticValue + Clone, +{ + type Value = alloc::collections::BTreeMap; + + fn value(&self) -> &Self::Value { + self + } + + fn value_mut(&mut self) -> &mut Self::Value { + self + } + + fn merge(&mut self, other: &mut Self::Value) { + use alloc::collections::btree_map::Entry; + + while let Some((k, mut v)) = other.pop_first() { + match self.entry(k) { + Entry::Vacant(entry) => { + entry.insert(v); + } + Entry::Occupied(mut entry) => { + entry.get_mut().merge(v.value_mut()); + } + } + } + } + + fn pretty_print(&self) -> crate::formatter::Document { + use crate::formatter::{const_text, indent, nl, text, Document}; + if self.is_empty() { + const_text("{}") + } else { + let single_line = const_text("{") + + self.iter().enumerate().fold(Document::Empty, |mut doc, (i, (k, v))| { + if i > 0 { + doc += const_text(", "); + } + doc + text(format!("{k}: ")) + v.pretty_print() + }) + + const_text("}"); + let multi_line = const_text("{") + + indent( + 4, + self.iter().enumerate().fold(nl(), |mut doc, (i, (k, v))| { + if i > 0 { + doc += const_text(",") + nl(); + } + doc + text(format!("{k}: ")) + v.pretty_print() + }) + nl(), + ) + + const_text("}"); + single_line | multi_line + } + } +} diff --git a/hir2/src/patterns.rs b/hir2/src/patterns.rs new file mode 100644 index 000000000..b70d11cc9 --- /dev/null +++ b/hir2/src/patterns.rs @@ -0,0 +1,13 @@ +mod applicator; +mod driver; +mod pattern; +mod pattern_set; +mod rewriter; + +pub use self::{ + applicator::{PatternApplicationError, PatternApplicator}, + driver::*, + pattern::*, + pattern_set::{FrozenRewritePatternSet, RewritePatternSet}, + rewriter::*, +}; diff --git a/hir2/src/patterns/applicator.rs b/hir2/src/patterns/applicator.rs new file mode 100644 index 000000000..dde6eb86a --- /dev/null +++ b/hir2/src/patterns/applicator.rs @@ -0,0 +1,222 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use smallvec::SmallVec; + +use super::{FrozenRewritePatternSet, PatternBenefit, RewritePattern, Rewriter}; +use crate::{OperationName, OperationRef, Report}; + +pub enum PatternApplicationError { + NoMatchesFound, + Report(Report), +} + +/// This type manages the application of a group of rewrite patterns, with a user-provided cost model +pub struct PatternApplicator { + /// The list that owns the patterns used within this applicator + rewrite_patterns_set: Rc, + /// The set of patterns to match for each operation, stable sorted by benefit. + patterns: BTreeMap; 2]>>, + /// The set of patterns that may match against any operation type, stable sorted by benefit. + match_any_patterns: SmallVec<[Rc; 1]>, +} +impl PatternApplicator { + pub fn new(rewrite_patterns_set: Rc) -> Self { + Self { + rewrite_patterns_set, + patterns: Default::default(), + match_any_patterns: Default::default(), + } + } + + /// Apply a cost model to the patterns within this applicator. 
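+    ///
+    /// For example, a cost model that simply reuses each pattern's static benefit
+    /// (which is exactly what [PatternApplicator::apply_default_cost_model] does)
+    /// can be applied like so (sketch):
+    ///
+    /// ```rust,ignore
+    /// applicator.apply_cost_model(|pattern| *pattern.benefit());
+    /// ```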
+ pub fn apply_cost_model(&mut self, model: CostModel) + where + CostModel: Fn(&dyn RewritePattern) -> PatternBenefit, + { + // Clear the results computed by the previous cost model + self.match_any_patterns.clear(); + self.patterns.clear(); + + // Filter out op-specific patterns with no benefit, and order by highest benefit first + let mut benefits = Vec::default(); + for (op, op_patterns) in self.rewrite_patterns_set.op_specific_patterns().iter() { + benefits + .extend(op_patterns.iter().filter_map(|p| filter_map_pattern_benefit(p, &model))); + benefits.sort_by_key(|(_, benefit)| *benefit); + self.patterns + .insert(op.clone(), benefits.drain(..).map(|(pat, _)| pat).collect()); + } + + // Filter out "match any" patterns with no benefit, and order by highest benefit first + benefits.extend( + self.rewrite_patterns_set + .any_op_patterns() + .iter() + .filter_map(|p| filter_map_pattern_benefit(p, &model)), + ); + benefits.sort_by_key(|(_, benefit)| *benefit); + self.match_any_patterns.extend(benefits.into_iter().map(|(pat, _)| pat)); + } + + /// Apply the default cost model that solely uses the pattern's static benefit + #[inline] + pub fn apply_default_cost_model(&mut self) { + log::debug!("applying default cost model"); + self.apply_cost_model(|pattern| *pattern.benefit()); + } + + /// Walk all of the patterns within the applicator. + pub fn walk_all_patterns(&self, mut callback: F) + where + F: FnMut(Rc), + { + for patterns in self.rewrite_patterns_set.op_specific_patterns().values() { + for pattern in patterns { + callback(Rc::clone(pattern)); + } + } + for pattern in self.rewrite_patterns_set.any_op_patterns() { + callback(Rc::clone(pattern)); + } + } + + pub fn match_and_rewrite( + &mut self, + op: OperationRef, + rewriter: &mut R, + can_apply: A, + mut on_failure: F, + mut on_success: S, + ) -> Result<(), PatternApplicationError> + where + A: for<'a> Fn(&'a dyn RewritePattern) -> bool, + F: for<'a> FnMut(&'a dyn RewritePattern), + S: for<'a> FnMut(&'a dyn RewritePattern) -> Result<(), Report>, + R: Rewriter, + { + // Check to see if there are patterns matching this specific operation type. + let op_name = { + let op = op.borrow(); + op.name() + }; + let op_specific_patterns = self.patterns.get(&op_name).map(|p| p.as_slice()).unwrap_or(&[]); + + if op_specific_patterns.is_empty() { + log::trace!("no op-specific patterns found for '{op_name}'"); + } else { + log::trace!( + "found {} op-specific patterns for '{op_name}'", + op_specific_patterns.len() + ); + } + + log::trace!("{} op-agnostic patterns available", self.match_any_patterns.len()); + + // Process the op-specific patterns and op-agnostic patterns in an interleaved fashion + let mut op_patterns = op_specific_patterns.iter().peekable(); + let mut any_op_patterns = self.match_any_patterns.iter().peekable(); + let mut result = Err(PatternApplicationError::NoMatchesFound); + loop { + // Find the next pattern with the highest benefit + // + // 1. Start with the assumption that we'll use the next op-specific pattern + let mut best_pattern = op_patterns.peek().copied(); + // 2. But take the next op-agnostic pattern instead, IF: + // a. There are no more op-specific patterns + // b. 
The benefit of the op-agnostic pattern is higher than the op-specific pattern + if let Some(next_any_pattern) = any_op_patterns + .next_if(|p| best_pattern.is_none_or(|bp| bp.benefit() < p.benefit())) + { + if let Some(best_pattern) = best_pattern { + log::trace!( + "selected op-agnostic pattern '{}' because its benefit is higher than the \ + next best op-specific pattern '{}'", + next_any_pattern.name(), + best_pattern.name() + ); + } else { + log::trace!( + "selected op-agnostic pattern '{}' because no op-specific pattern is \ + available", + next_any_pattern.name() + ); + } + best_pattern.replace(next_any_pattern); + } else { + // The op-specific pattern is best, if available, so actually consume it from the iterator + if let Some(best_pattern) = best_pattern { + log::trace!("selected op-specific pattern '{}'", best_pattern.name()); + } + best_pattern = op_patterns.next(); + } + + // Break if we have exhausted all patterns + let Some(best_pattern) = best_pattern else { + log::trace!("all patterns have been exhausted"); + break; + }; + + // Can we apply this pattern? + let applicable = can_apply(&**best_pattern); + if !applicable { + log::trace!("skipping pattern: can_apply returned false"); + continue; + } + + // Try to match and rewrite this pattern. + // + // The patterns are sorted by benefit, so if we match we can immediately rewrite. + rewriter.set_insertion_point_before(crate::ProgramPoint::Op(op.clone())); + + // TODO: Save the nearest parent IsolatedFromAbove op of this op for use in debug + // messages/rendering, as the rewrite may invalidate `op` + log::debug!("trying to match '{}'", best_pattern.name()); + + match best_pattern.match_and_rewrite(op.clone(), rewriter) { + Ok(matched) => { + if matched { + log::trace!("pattern matched successfully"); + result = + on_success(&**best_pattern).map_err(PatternApplicationError::Report); + break; + } else { + log::trace!("failed to match pattern"); + on_failure(&**best_pattern); + } + } + Err(err) => { + log::error!("error occurred during match_and_rewrite: {err}"); + result = Err(PatternApplicationError::Report(err)); + on_failure(&**best_pattern); + } + } + } + + result + } +} + +fn filter_map_pattern_benefit( + pattern: &Rc, + cost_model: &CostModel, +) -> Option<(Rc, PatternBenefit)> +where + CostModel: Fn(&dyn RewritePattern) -> PatternBenefit, +{ + let benefit = if pattern.benefit().is_impossible_to_match() { + PatternBenefit::NONE + } else { + cost_model(&**pattern) + }; + if benefit.is_impossible_to_match() { + log::debug!( + "ignoring pattern '{}' ({}) because it is impossible to match or cannot lead to legal \ + IR (by cost model)", + pattern.name(), + pattern.kind(), + ); + None + } else { + Some((Rc::clone(pattern), benefit)) + } +} diff --git a/hir2/src/patterns/driver.rs b/hir2/src/patterns/driver.rs new file mode 100644 index 000000000..236608de8 --- /dev/null +++ b/hir2/src/patterns/driver.rs @@ -0,0 +1,1011 @@ +use alloc::{collections::BTreeSet, rc::Rc}; +use core::cell::RefCell; + +use smallvec::SmallVec; + +use super::{ + ForwardingListener, FrozenRewritePatternSet, PatternApplicator, PatternRewriter, Rewriter, + RewriterListener, +}; +use crate::{ + traits::{ConstantLike, Foldable, IsolatedFromAbove}, + BlockRef, Builder, Context, InsertionGuard, Listener, OpFoldResult, OperationFolder, + OperationRef, ProgramPoint, Region, RegionRef, Report, RewritePattern, SourceSpan, Spanned, + Value, ValueRef, WalkResult, Walkable, +}; + +/// Rewrite ops in the given region, which must be isolated from above, by repeatedly 
applying the +/// highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// This function also performs folding and simple dead-code elimination before attempting to match +/// any of the provided patterns. +/// +/// A region scope can be set using [GreedyRewriteConfig]. By default, the scope is set to the +/// specified region. Only in-scope ops are added to the worklist and only in-scope ops are allowed +/// to be modified by the patterns. +/// +/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no +/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR +/// was modified at all. +/// +/// NOTE: This function does not apply patterns to the region's parent operation. +pub fn apply_patterns_and_fold_region_greedily( + region: RegionRef, + patterns: Rc, + mut config: GreedyRewriteConfig, +) -> Result { + // The top-level operation must be known to be isolated from above to prevent performing + // canonicalizations on operations defined at or above the region containing 'op'. + let context = { + let parent_op = region.borrow().parent().unwrap().borrow(); + assert!( + parent_op.implements::(), + "patterns can only be applied to operations which are isolated from above" + ); + parent_op.context_rc() + }; + + // Set scope if not specified + if config.scope.is_none() { + config.scope = Some(region.clone()); + } + + let mut driver = RegionPatternRewriteDriver::new(context, patterns, config, region); + let converged = driver.simplify(); + if converged.is_err() { + if let Some(max_iterations) = driver.driver.config.max_iterations { + log::trace!("pattern rewrite did not converge after scanning {max_iterations} times"); + } else { + log::trace!("pattern rewrite did not converge"); + } + } + converged +} + +/// Rewrite ops nested under the given operation, which must be isolated from above, by repeatedly +/// applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is +/// reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// Also performs folding and simple dead-code elimination before attempting to match any of the +/// provided patterns. +/// +/// This overload runs a separate greedy rewrite for each region of the specified op. A region +/// scope can be set in the configuration parameter. By default, the scope is set to the region of +/// the current greedy rewrite. Only in-scope ops are added to the worklist and only in-scope ops +/// and the specified op itself are allowed to be modified by the patterns. +/// +/// NOTE: The specified op may be modified, but it may not be removed by the patterns. +/// +/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no +/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR +/// was modified at all. +/// +/// NOTE: This function does not apply patterns to the given operation itself. 
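+///
+/// # Example
+///
+/// A minimal invocation sketch; assumes `op` is an operation that is isolated from
+/// above, and `pattern_set` is an already-populated [RewritePatternSet]:
+///
+/// ```rust,ignore
+/// let patterns = Rc::new(FrozenRewritePatternSet::new(pattern_set));
+/// match apply_patterns_and_fold_greedily(op, patterns, GreedyRewriteConfig::default()) {
+///     Ok(changed) => log::debug!("converged, changed={changed}"),
+///     Err(_) => log::warn!("failed to converge"),
+/// }
+/// ```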
+pub fn apply_patterns_and_fold_greedily( + op: OperationRef, + patterns: Rc, + config: GreedyRewriteConfig, +) -> Result { + let mut any_region_changed = false; + let mut failed = false; + let op = op.borrow(); + let mut cursor = op.regions().front(); + while let Some(region) = cursor.as_pointer() { + cursor.move_next(); + match apply_patterns_and_fold_region_greedily(region, patterns.clone(), config.clone()) { + Ok(region_changed) => { + any_region_changed |= region_changed; + } + Err(region_changed) => { + any_region_changed |= region_changed; + failed = true; + } + } + } + + if failed { + Err(any_region_changed) + } else { + Ok(any_region_changed) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum ApplyPatternsAndFoldEffect { + /// No effect, the IR remains unchanged + None, + /// The IR was modified + Changed, + /// The input IR was erased + Erased, +} + +pub type ApplyPatternsAndFoldResult = + Result; + +/// Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy +/// worklist driven manner until a fixpoint is reached. +/// +/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be +/// configured using [GreedyRewriteConfig]. +/// +/// This function also performs folding and simple dead-code elimination before attempting to match +/// any of the provided patterns. +/// +/// Newly created ops and other pre-existing ops that use results of rewritten ops or supply +/// operands to such ops are also processed, unless such ops are excluded via `config.restrict`. +/// Any other ops remain unmodified (i.e., regardless of restrictions). +/// +/// In addition to op restrictions, a region scope can be specified. Only ops within the scope are +/// simplified. This is similar to [apply_patterns_and_fold_greedily], where only ops within the +/// given region/op are simplified by default. If no scope is specified, it is assumed to be the +/// first common enclosing region of the given ops. +/// +/// Note that ops in `ops` could be erased as result of folding, becoming dead, or via pattern +/// rewrites. If more far reaching simplification is desired, [apply_patterns_and_fold_greedily] +/// should be used. +/// +/// Returns `Ok(effect)` if the iterative process converged (i.e., fixpoint was reached) and no more +/// patterns can be matched. `effect` is set to `Changed` if the IR was modified, but at least one +/// operation was not erased. It is set to `Erased` if all of the input ops were erased. +pub fn apply_patterns_and_fold( + ops: &[OperationRef], + patterns: Rc, + mut config: GreedyRewriteConfig, +) -> ApplyPatternsAndFoldResult { + if ops.is_empty() { + return Ok(ApplyPatternsAndFoldEffect::None); + } + + // Determine scope of rewrite + if let Some(scope) = config.scope.as_ref() { + // If a scope was provided, make sure that all ops are in scope. + let all_ops_in_scope = ops.iter().all(|op| scope.borrow().find_ancestor_op(op).is_some()); + assert!(all_ops_in_scope, "ops must be within the specified scope"); + } else { + // Compute scope if none was provided. The scope will remain `None` if there is a top-level + // op among `ops`. 
+ config.scope = Region::find_common_ancestor(ops); + } + + // Start the pattern driver + let max_rewrites = config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX); + let context = ops[0].borrow().context_rc(); + let mut driver = MultiOpPatternRewriteDriver::new(context, patterns, config, ops); + let converged = driver.simplify(ops); + let changed = match converged.as_ref() { + Ok(changed) | Err(changed) => *changed, + }; + let erased = driver.inner.surviving_ops.borrow().is_empty(); + let effect = if erased { + ApplyPatternsAndFoldEffect::Erased + } else if changed { + ApplyPatternsAndFoldEffect::Changed + } else { + ApplyPatternsAndFoldEffect::None + }; + if converged.is_ok() { + Ok(effect) + } else { + log::trace!("pattern rewrite did not converge after {max_rewrites} rewrites"); + Err(effect) + } +} + +/// This enum indicates which ops are put on the worklist during a greedy pattern rewrite +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum GreedyRewriteStrictness { + /// No restrictions on which ops are processed. + #[default] + Any, + /// Only pre-existing and newly created ops are processed. + /// + /// Pre-existing ops are those that were on the worklist at the very beginning. + ExistingAndNew, + /// Only pre-existing ops are processed. + /// + /// Pre-existing ops are those that were on the worklist at the very beginning. + Existing, +} + +/// This enum indicates the level of simplification to be applied to regions during a greedy +/// pattern rewrite. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegionSimplificationLevel { + /// Disable simplification. + None, + /// Perform basic simplifications (e.g. dead argument elimination) + #[default] + Normal, + /// Perform additional complex/expensive simplifications (e.g. block merging) + Aggressive, +} + +/// Configuration for [GreedyPatternRewriteDriver] +#[derive(Clone)] +pub struct GreedyRewriteConfig { + listener: Option>, + /// If set, only ops within the given region are added to the worklist. + /// + /// If no scope is specified, and no specific region is given when starting the greedy rewrite, + /// then the closest enclosing region of the initial list of operations is used. + scope: Option, + /// If set, specifies the maximum number of times the rewriter will iterate between applying + /// patterns and simplifying regions. + /// + /// NOTE: Only applicable when simplifying entire regions. + max_iterations: Option, + /// If set, specifies the maximum number of rewrites within an iteration. + max_rewrites: Option, + /// Perform control flow optimizations to the region tree after applying all patterns. + /// + /// NOTE: Only applicable when simplifying entire regions. + region_simplification: RegionSimplificationLevel, + /// The restrictions to apply, if any, to operations added to the worklist during the rewrite. + restrict: GreedyRewriteStrictness, + /// This flag specifies the order of initial traversal that populates the rewriter worklist. + /// + /// When true, operations are visited top-down, which is generally more efficient in terms of + /// compilation time. + /// + /// When false, the initial traversal of the region tree is bottom up on each block, which may + /// match larger patterns when given an ambiguous pattern set. + /// + /// NOTE: Only applicable when simplifying entire regions. 
+ use_top_down_traversal: bool, +} +impl Default for GreedyRewriteConfig { + fn default() -> Self { + Self { + listener: None, + scope: None, + max_iterations: core::num::NonZeroU32::new(10), + max_rewrites: None, + region_simplification: Default::default(), + restrict: Default::default(), + use_top_down_traversal: false, + } + } +} +impl GreedyRewriteConfig { + pub fn new_with_listener(listener: impl RewriterListener) -> Self { + Self { + listener: Some(Rc::new(listener)), + ..Default::default() + } + } + + /// Scope rewrites to operations within `region` + pub fn with_scope(&mut self, region: RegionRef) -> &mut Self { + self.scope = Some(region); + self + } + + /// Set the maximum number of times the rewriter will iterate between applying patterns and + /// simplifying regions. + /// + /// If `0` is given, the number of iterations is unlimited. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_max_iterations(&mut self, max: u32) -> &mut Self { + self.max_iterations = core::num::NonZeroU32::new(max); + self + } + + /// Set the maximum number of rewrites per iteration. + /// + /// If `0` is given, the number of rewrites is unlimited. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_max_rewrites(&mut self, max: u32) -> &mut Self { + self.max_rewrites = core::num::NonZeroU32::new(max); + self + } + + /// Set the level of control flow optimizations to apply to the region tree. + /// + /// NOTE: Only applicable when simplifying entire regions. + pub fn with_region_simplification_level( + &mut self, + level: RegionSimplificationLevel, + ) -> &mut Self { + self.region_simplification = level; + self + } + + /// Set the level of restriction to apply to operations added to the worklist during the rewrite. + pub fn with_restrictions(&mut self, level: GreedyRewriteStrictness) -> &mut Self { + self.restrict = level; + self + } + + /// Specify whether or not to use a top-down traversal when initially adding operations to the + /// worklist. 
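+    ///
+    /// Configuration methods return `&mut Self`, so they can be chained, e.g. (sketch):
+    ///
+    /// ```rust,ignore
+    /// let mut config = GreedyRewriteConfig::default();
+    /// config.with_max_iterations(4).with_top_down_traversal(true);
+    /// ```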
+    pub fn with_top_down_traversal(&mut self, yes: bool) -> &mut Self {
+        self.use_top_down_traversal = yes;
+        self
+    }
+}
+
+pub struct GreedyPatternRewriteDriver {
+    context: Rc<Context>,
+    worklist: RefCell<Worklist>,
+    config: GreedyRewriteConfig,
+    /// Not maintained when `config.restrict` is `GreedyRewriteStrictness::Any`
+    filtered_ops: RefCell<BTreeSet<OperationRef>>,
+    matcher: RefCell<PatternApplicator>,
+}
+
+impl GreedyPatternRewriteDriver {
+    pub fn new(
+        context: Rc<Context>,
+        patterns: Rc<FrozenRewritePatternSet>,
+        config: GreedyRewriteConfig,
+    ) -> Self {
+        // Apply a simple cost model based solely on pattern benefit
+        let mut matcher = PatternApplicator::new(patterns);
+        matcher.apply_default_cost_model();
+
+        Self {
+            context,
+            worklist: Default::default(),
+            config,
+            filtered_ops: Default::default(),
+            matcher: RefCell::new(matcher),
+        }
+    }
+}
+
+/// Worklist Management
+impl GreedyPatternRewriteDriver {
+    /// Add the given operation to the worklist
+    pub fn add_single_op_to_worklist(&self, op: OperationRef) {
+        if matches!(self.config.restrict, GreedyRewriteStrictness::Any)
+            || self.filtered_ops.borrow().contains(&op)
+        {
+            log::trace!("adding single op '{}' to worklist", op.borrow().name());
+            self.worklist.borrow_mut().push(op);
+        } else {
+            log::trace!(
+                "skipped adding single op '{}' to worklist due to strictness level",
+                op.borrow().name()
+            );
+        }
+    }
+
+    /// Add the given operation, and its ancestors, to the worklist
+    pub fn add_to_worklist(&self, op: OperationRef) {
+        // Gather potential ancestors while looking for a `scope` parent region
+        let mut ancestors = SmallVec::<[OperationRef; 8]>::default();
+        let mut op = Some(op);
+        while let Some(ancestor_op) = op.take() {
+            let region = ancestor_op.borrow().parent_region();
+            if self.config.scope.as_ref() == region.as_ref() {
+                ancestors.push(ancestor_op);
+                for op in ancestors {
+                    self.add_single_op_to_worklist(op);
+                }
+                return;
+            } else {
+                log::trace!(
+                    "gathering ancestors of '{}' for worklist",
+                    ancestor_op.borrow().name()
+                );
+                ancestors.push(ancestor_op);
+            }
+            if let Some(region) = region {
+                op = region.borrow().parent();
+            } else {
+                log::trace!("reached top level op while searching for ancestors");
+            }
+        }
+    }
+
+    /// Process operations until the worklist is empty, or `config.max_rewrites` is reached.
+    ///
+    /// Returns true if the IR was changed.
+    pub fn process_worklist(self: Rc<Self>) -> bool {
+        log::debug!("starting processing of greedy pattern rewrite driver worklist");
+        let mut rewriter =
+            PatternRewriter::new_with_listener(self.context.clone(), Rc::clone(&self));
+
+        let mut changed = false;
+        let mut num_rewrites = 0u32;
+        while self.config.max_rewrites.is_none_or(|max| num_rewrites < max.get()) {
+            let Some(op) = self.worklist.borrow_mut().pop() else {
+                // Worklist is empty, we've converged
+                log::debug!("processing worklist complete, rewrites have converged");
+                return changed;
+            };
+
+            if self.process_worklist_item(&mut rewriter, op) {
+                changed = true;
+                num_rewrites += 1;
+            }
+        }
+
+        log::debug!(
+            "processing worklist was canceled after {} rewrites without converging (reached max \
+             rewrite limit)",
+            self.config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX)
+        );
+
+        changed
+    }
+
+    /// Process a single operation from the worklist.
+    ///
+    /// Returns true if the IR was changed.
+    fn process_worklist_item(
+        &self,
+        rewriter: &mut PatternRewriter<Rc<GreedyPatternRewriteDriver>>,
+        mut op_ref: OperationRef,
+    ) -> bool {
+        let op = op_ref.borrow_mut();
+        log::trace!("processing operation '{}'", op.name());
+
+        // If the operation is trivially dead - remove it.
+        if op.is_trivially_dead() {
+            drop(op);
+            rewriter.erase_op(op_ref);
+            log::trace!("processing complete: operation is trivially dead");
+            return true;
+        }
+
+        // Try to fold this op, unless it is a constant op, as that would lead to an infinite
+        // folding loop, since the folded result would be immediately materialized as a constant
+        // op, and then revisited.
+        if !op.implements::<ConstantLike>() {
+            let mut results = SmallVec::<[OpFoldResult; 1]>::default();
+            log::trace!("attempting to fold operation..");
+            if op.fold(&mut results).is_ok() {
+                if results.is_empty() {
+                    // Op was modified in-place
+                    self.notify_operation_modified(op_ref.clone());
+                    log::trace!("operation was successfully folded/modified in-place");
+                    return true;
+                } else {
+                    log::trace!(
+                        "operation was successfully folded away, to be replaced with: {}",
+                        crate::formatter::DisplayValues::new(results.iter())
+                    );
+                }
+
+                // Op results can be replaced with `results`
+                assert_eq!(
+                    results.len(),
+                    op.num_results(),
+                    "folder produced incorrect number of results"
+                );
+                let mut rewriter = InsertionGuard::new(&mut **rewriter);
+                rewriter.set_insertion_point_before(ProgramPoint::Op(op_ref.clone()));
+
+                log::trace!("replacing op with fold results..");
+                let mut replacements = SmallVec::<[ValueRef; 2]>::default();
+                let mut materialization_succeeded = true;
+                for (fold_result, result_ty) in results
+                    .into_iter()
+                    .zip(op.results().all().iter().map(|r| r.borrow().ty().clone()))
+                {
+                    match fold_result {
+                        OpFoldResult::Value(value) => {
+                            assert_eq!(
+                                value.borrow().ty(),
+                                &result_ty,
+                                "folder produced value of incorrect type"
+                            );
+                            replacements.push(value);
+                        }
+                        OpFoldResult::Attribute(attr) => {
+                            // Materialize attributes as SSA values using a constant op
+                            let span = op.span();
+                            log::trace!(
+                                "materializing constant for value '{}' and type '{result_ty}'",
+                                attr.render()
+                            );
+                            let constant_op = op.dialect().materialize_constant(
+                                &mut *rewriter,
+                                attr,
+                                &result_ty,
+                                span,
+                            );
+                            match constant_op {
+                                None => {
+                                    log::trace!(
+                                        "materialization failed: cleaning up any materialized ops \
+                                         for {} previous results",
+                                        replacements.len()
+                                    );
+                                    // If materialization fails, clean up any operations generated
+                                    // for the previous results
+                                    let mut replacement_ops =
+                                        SmallVec::<[OperationRef; 2]>::default();
+                                    for replacement in replacements.iter() {
+                                        let replacement = replacement.borrow();
+                                        assert!(
+                                            !replacement.is_used(),
+                                            "folder reused existing op for one result, but \
+                                             constant materialization failed for another result"
+                                        );
+                                        let replacement_op = replacement.get_defining_op().unwrap();
+                                        if replacement_ops.contains(&replacement_op) {
+                                            continue;
+                                        }
+                                        replacement_ops.push(replacement_op);
+                                    }
+                                    for replacement_op in replacement_ops {
+                                        rewriter.erase_op(replacement_op);
+                                    }
+                                    materialization_succeeded = false;
+                                    break;
+                                }
+                                Some(constant_op) => {
+                                    let const_op = constant_op.borrow();
+                                    assert!(
+                                        const_op.implements::<ConstantLike>(),
+                                        "materialize_constant produced op that does not implement \
+                                         ConstantLike"
+                                    );
+                                    let result: ValueRef =
+                                        const_op.results().all()[0].clone().upcast();
+                                    assert_eq!(
+                                        result.borrow().ty(),
+                                        &result_ty,
+                                        "materialize_constant produced incorrect result type"
+                                    );
+                                    log::trace!(
+                                        "successfully materialized constant as {}",
+                                        result.borrow().id()
+                                    );
+                                    replacements.push(result);
+                                }
+                            }
+                        }
+                    }
+                }
+
+                if materialization_succeeded {
+                    log::trace!(
+                        "materialization of fold results was successful, performing replacement.."
+                    );
+                    drop(op);
+                    rewriter.replace_op_with_values(op_ref, &replacements);
+                    log::trace!(
+                        "fold succeeded: operation was replaced with materialized constants"
+                    );
+                    return true;
+                } else {
+                    log::trace!(
+                        "materialization of fold results failed, proceeding without folding"
+                    );
+                }
+            }
+        } else {
+            log::trace!("operation could not be folded");
+        }
+
+        // Try to match one of the patterns.
+        //
+        // The rewriter is automatically notified of any necessary changes, so there is nothing
+        // else to do here.
+        // TODO(pauls): if self.config.listener.is_some() {
+        //
+        //     We need to trigger `notify_pattern_begin` in `can_apply`, and `notify_pattern_end`
+        //     in `on_failure` and `on_success`, but we can't have multiple mutable aliases of
+        //     the listener captured by these closures.
+        //
+        //     This is another aspect of the listener infra that needs to be handled
+        log::trace!("attempting to match and rewrite one of the input patterns..");
+        let result = if let Some(listener) = self.config.listener.as_deref() {
+            let op_name = op.name();
+            let can_apply = |pattern: &dyn RewritePattern| {
+                log::trace!("applying pattern {} to op {}", pattern.name(), &op_name);
+                listener.notify_pattern_begin(pattern, op_ref.clone());
+                true
+            };
+            let on_failure = |pattern: &dyn RewritePattern| {
+                log::trace!("pattern failed to match");
+                listener.notify_pattern_end(pattern, false);
+            };
+            let on_success = |pattern: &dyn RewritePattern| {
+                log::trace!("pattern applied successfully");
+                listener.notify_pattern_end(pattern, true);
+                Ok(())
+            };
+            drop(op);
+            self.matcher.borrow_mut().match_and_rewrite(
+                op_ref.clone(),
+                &mut **rewriter,
+                can_apply,
+                on_failure,
+                on_success,
+            )
+        } else {
+            drop(op);
+            self.matcher.borrow_mut().match_and_rewrite(
+                op_ref.clone(),
+                &mut **rewriter,
+                |_| true,
+                |_| {},
+                |_| Ok(()),
+            )
+        };
+
+        match result {
+            Ok(_) => {
+                log::trace!("processing complete: pattern matched and operation was rewritten");
+                true
+            }
+            Err(crate::PatternApplicationError::NoMatchesFound) => {
+                log::debug!("processing complete: exhausted all patterns without finding a match");
+                false
+            }
+            Err(crate::PatternApplicationError::Report(report)) => {
+                log::debug!(
+                    "processing complete: error occurred during match and rewrite: {report}"
+                );
+                false
+            }
+        }
+    }
+
+    /// Look over the operands of the provided op for any defining operations that should be re-
+    /// added to the worklist. This function should be called when an operation is modified or
+    /// removed, as it may trigger further simplifications.
+    fn add_operands_to_worklist(&self, op: OperationRef) {
+        let current_op = op.borrow();
+        for operand in current_op.operands().all() {
+            // If this operand currently has at most 2 users, add its defining op to the worklist.
+            // After the op is deleted, then the operand will have at most 1 user left. If it has
+            // 0 users left, it can be deleted as well, and if it has 1 user left, there may be
+            // further canonicalization opportunities.
+            let operand = operand.borrow();
+            let Some(def_op) = operand.value().get_defining_op() else {
+                continue;
+            };
+
+            let mut other_user = None;
+            let mut has_more_than_two_uses = false;
+            for user in operand.value().iter_uses() {
+                if user.owner == op || other_user.as_ref().is_some_and(|ou| ou == &user.owner) {
+                    continue;
+                }
+                if other_user.is_none() {
+                    other_user = Some(user.owner.clone());
+                    continue;
+                }
+                has_more_than_two_uses = true;
+                break;
+            }
+            if !has_more_than_two_uses {
+                self.add_to_worklist(def_op);
+            }
+        }
+    }
+}
+
+/// Notifications
+impl Listener for GreedyPatternRewriteDriver {
+    fn kind(&self) -> crate::ListenerType {
+        crate::ListenerType::Rewriter
+    }
+
+    /// Notify the driver that the given block was inserted
+    fn notify_block_inserted(
+        &self,
+        block: crate::BlockRef,
+        prev: Option,
+        ip: Option,
+    ) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_block_inserted(block, prev, ip);
+        }
+    }
+
+    /// Notify the driver that the specified operation was inserted.
+    ///
+    /// Update the worklist as needed: the operation is enqueued depending on scope and strictness
+    fn notify_operation_inserted(&self, op: OperationRef, prev: Option) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_inserted(op.clone(), prev.clone());
+        }
+        if matches!(self.config.restrict, GreedyRewriteStrictness::ExistingAndNew) {
+            self.filtered_ops.borrow_mut().insert(op.clone());
+        }
+        self.add_to_worklist(op);
+    }
+}
+impl RewriterListener for GreedyPatternRewriteDriver {
+    /// Notify the driver that the given block is about to be removed.
+    fn notify_block_erased(&self, block: BlockRef) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_block_erased(block);
+        }
+    }
+
+    /// Notify the driver that the specified operation may have been modified in-place. The
+    /// operation is added to the worklist.
+    fn notify_operation_modified(&self, op: OperationRef) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_modified(op.clone());
+        }
+        self.add_to_worklist(op);
+    }
+
+    /// Notify the driver that the specified operation was removed.
+    ///
+    /// Update the worklist as needed: the operation and its children are removed from the worklist
+    fn notify_operation_erased(&self, op: OperationRef) {
+        // Only ops that are within the configured scope are added to the worklist of the greedy
+        // pattern rewriter.
+        //
+        // A greedy pattern rewrite is not allowed to erase the parent op of the scope region, as
+        // that would break the worklist handling and some sanity checks.
+        if let Some(scope) = self.config.scope.as_ref() {
+            assert!(
+                scope.borrow().parent().is_some_and(|parent_op| parent_op != op),
+                "scope region must not be erased during greedy pattern rewrite"
+            );
+        }
+
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_erased(op.clone());
+        }
+
+        self.add_operands_to_worklist(op.clone());
+        self.worklist.borrow_mut().remove(&op);
+
+        if self.config.restrict != GreedyRewriteStrictness::Any {
+            self.filtered_ops.borrow_mut().remove(&op);
+        }
+    }
+
+    /// Notify the driver that the specified operation was replaced.
+    ///
+    /// Update the worklist as needed: new users are enqueued
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_operation_replaced_with_values(op, replacement);
+        }
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        if let Some(listener) = self.config.listener.as_deref() {
+            listener.notify_match_failure(span, reason);
+        }
+    }
+}
+
+pub struct RegionPatternRewriteDriver {
+    driver: Rc<GreedyPatternRewriteDriver>,
+    region: RegionRef,
+}
+impl RegionPatternRewriteDriver {
+    pub fn new(
+        context: Rc<Context>,
+        patterns: Rc<FrozenRewritePatternSet>,
+        config: GreedyRewriteConfig,
+        region: RegionRef,
+    ) -> Self {
+        use crate::Walkable;
+        let mut driver = GreedyPatternRewriteDriver::new(context, patterns, config);
+        // Populate strict mode ops, if applicable
+        if driver.config.restrict != GreedyRewriteStrictness::Any {
+            let filtered_ops = driver.filtered_ops.get_mut();
+            region.borrow().postwalk(|op| {
+                filtered_ops.insert(op);
+            });
+        }
+        Self {
+            driver: Rc::new(driver),
+            region,
+        }
+    }
+
+    /// Simplify ops inside `self.region`, and simplify the region itself.
+    ///
+    /// Returns `Ok(changed)` if the transformation converged, with `changed` indicating whether or
+    /// not the IR was changed. Otherwise, `Err(changed)` is returned.
+    pub fn simplify(&mut self) -> Result<bool, bool> {
+        use crate::matchers::Matcher;
+
+        let mut continue_rewrites = false;
+        let mut iteration = 0;
+
+        while self.driver.config.max_iterations.is_none_or(|max| iteration < max.get()) {
+            log::trace!("starting iteration {iteration} of region pattern rewrite driver");
+            iteration += 1;
+
+            // New iteration: start with an empty worklist
+            self.driver.worklist.borrow_mut().clear();
+
+            // `OperationFolder` CSEs constant ops (and may move them into parent regions to
+            // enable more aggressive CSE).
+            let context = self.driver.context.clone();
+            let mut folder = OperationFolder::new(context, Rc::clone(&self.driver));
+            let mut insert_known_constant = |op: OperationRef| {
+                // Check for existing constants when populating the worklist. This avoids
+                // accidentally reversing the constant order during processing.
+                if let Some(const_value) = crate::matchers::constant().matches(&op.borrow()) {
+                    if !folder.insert_known_constant(op, Some(const_value)) {
+                        return true;
+                    }
+                }
+                false
+            };
+
+            if !self.driver.config.use_top_down_traversal {
+                // Add operations to the worklist in postorder.
+                log::trace!("adding operations in postorder");
+                self.region.borrow().postwalk(|op| {
+                    if !insert_known_constant(op.clone()) {
+                        self.driver.add_to_worklist(op);
+                    }
+                });
+            } else {
+                // Add all nested operations to the worklist in preorder.
+                log::trace!("adding operations in preorder");
+                self.region
+                    .borrow()
+                    .prewalk_interruptible(|op| {
+                        if !insert_known_constant(op.clone()) {
+                            self.driver.add_to_worklist(op);
+                            WalkResult::Continue(())
+                        } else {
+                            WalkResult::Skip
+                        }
+                    })
+                    .into_result()
+                    .expect("unexpected error occurred while walking region");
+
+                // Reverse the list so our loop processes them in-order
+                self.driver.worklist.borrow_mut().reverse();
+            }
+
+            continue_rewrites = self.driver.clone().process_worklist();
+            log::trace!(
+                "processing of worklist for this iteration has completed, \
+                 changed={continue_rewrites}"
+            );
+
+            // After applying patterns, make sure that the CFG of each of the regions is kept up to
+            // date.
+ if self.driver.config.region_simplification != RegionSimplificationLevel::None { + let mut rewriter = PatternRewriter::new_with_listener( + self.driver.context.clone(), + Rc::clone(&self.driver), + ); + continue_rewrites |= Region::simplify_all( + &[self.region.clone()], + &mut *rewriter, + self.driver.config.region_simplification, + ) + .is_ok(); + } else { + log::debug!("region simplification was disabled, skipping simplification rewrites"); + } + + if !continue_rewrites { + log::trace!("region pattern rewrites have converged"); + break; + } + } + + // If `continue_rewrites` is false, then the rewrite converged, i.e. the IR wasn't changed + // in the last iteration. + if !continue_rewrites { + Ok(iteration > 1) + } else { + Err(iteration > 1) + } + } +} + +pub struct MultiOpPatternRewriteDriver { + driver: Rc, + inner: Rc, +} + +struct MultiOpPatternRewriteDriverImpl { + surviving_ops: RefCell>, +} + +impl MultiOpPatternRewriteDriver { + pub fn new( + context: Rc, + patterns: Rc, + mut config: GreedyRewriteConfig, + ops: &[OperationRef], + ) -> Self { + let surviving_ops = BTreeSet::from_iter(ops.iter().cloned()); + let inner = Rc::new(MultiOpPatternRewriteDriverImpl { + surviving_ops: RefCell::new(surviving_ops), + }); + let listener = Rc::new(ForwardingListener::new(config.listener.take(), Rc::clone(&inner))); + config.listener = Some(listener); + + let mut driver = GreedyPatternRewriteDriver::new(context.clone(), patterns, config); + if driver.config.restrict != GreedyRewriteStrictness::Any { + driver.filtered_ops.get_mut().extend(ops.iter().cloned()); + } + + Self { + driver: Rc::new(driver), + inner, + } + } + + pub fn simplify(&mut self, ops: &[OperationRef]) -> Result { + // Populate the initial worklist + for op in ops { + self.driver.add_single_op_to_worklist(op.clone()); + } + + // Process ops on the worklist + let changed = self.driver.clone().process_worklist(); + if self.driver.worklist.borrow().is_empty() { + Ok(changed) + } else { + Err(changed) + } + } +} + +impl Listener for MultiOpPatternRewriteDriverImpl { + fn kind(&self) -> crate::ListenerType { + crate::ListenerType::Rewriter + } +} +impl RewriterListener for MultiOpPatternRewriteDriverImpl { + fn notify_operation_erased(&self, op: OperationRef) { + self.surviving_ops.borrow_mut().remove(&op); + } +} + +#[derive(Default)] +struct Worklist(Vec); +impl Worklist { + /// Clear all operations from the worklist + #[inline] + pub fn clear(&mut self) { + self.0.clear() + } + + /// Returns true if the worklist is empty + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Push an operation to the end of the worklist, unless it is already in the worklist. 
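+    ///
+    /// A small usage sketch (assumes `op` is an [OperationRef]):
+    ///
+    /// ```rust,ignore
+    /// let mut worklist = Worklist::default();
+    /// worklist.push(op.clone());
+    /// worklist.push(op); // already enqueued: this is a no-op
+    /// assert!(!worklist.is_empty());
+    /// ```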
+ pub fn push(&mut self, op: OperationRef) { + if self.0.contains(&op) { + return; + } + self.0.push(op); + } + + /// Pop the next operation from the worklist + #[inline] + pub fn pop(&mut self) -> Option { + self.0.pop() + } + + /// Remove `op` from the worklist + pub fn remove(&mut self, op: &OperationRef) { + if let Some(index) = self.0.iter().position(|o| o == op) { + self.0.remove(index); + } + } + + /// Reverse the worklist + pub fn reverse(&mut self) { + self.0.reverse(); + } +} diff --git a/hir2/src/patterns/pattern.rs b/hir2/src/patterns/pattern.rs new file mode 100644 index 000000000..d767bcb0d --- /dev/null +++ b/hir2/src/patterns/pattern.rs @@ -0,0 +1,382 @@ +use alloc::rc::Rc; +use core::{any::TypeId, fmt}; + +use smallvec::SmallVec; + +use super::Rewriter; +use crate::{interner, Context, OperationName, OperationRef, Report}; + +#[derive(Debug)] +pub enum PatternKind { + /// The pattern root matches any operation + Any, + /// The pattern root matches a specific named operation + Operation(OperationName), + /// The pattern root matches a specific trait + Trait(TypeId), +} +impl fmt::Display for PatternKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Any => f.write_str("for any"), + Self::Operation(name) => write!(f, "for operation '{name}'"), + Self::Trait(_) => write!(f, "for trait"), + } + } +} + +/// Represents the benefit a pattern has. +/// +/// More beneficial patterns are preferred over those with lesser benefit, while patterns with no +/// benefit whatsoever can be discarded. +/// +/// This is used to evaluate which patterns to apply, and in what order. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct PatternBenefit(Option); +impl PatternBenefit { + /// Represents a pattern which is the most beneficial + pub const MAX: Self = Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(u16::MAX) })); + /// Represents a pattern which is the least beneficial + pub const MIN: Self = Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(1) })); + /// Represents a pattern which can never match, and thus should be discarded + pub const NONE: Self = Self(None); + + /// Create a new [PatternBenefit] from a raw [u16] value. + /// + /// A value of `u16::MAX` is treated as impossible to match, while values from `0..=65534` range + /// from the least beneficial to the most beneficial. 
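+    ///
+    /// For example, given the mapping described above (sketch):
+    ///
+    /// ```rust,ignore
+    /// assert!(PatternBenefit::new(u16::MAX).is_impossible_to_match());
+    /// assert!(!PatternBenefit::new(0).is_impossible_to_match());
+    /// // Ordering is reversed, so that higher-benefit patterns sort first:
+    /// assert!(PatternBenefit::new(10) < PatternBenefit::new(1));
+    /// ```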
+    pub fn new(benefit: u16) -> Self {
+        if benefit == u16::MAX {
+            Self(None)
+        } else {
+            Self(Some(unsafe { core::num::NonZeroU16::new_unchecked(benefit + 1) }))
+        }
+    }
+
+    /// Returns true if the pattern benefit indicates it can never match
+    #[inline]
+    pub fn is_impossible_to_match(&self) -> bool {
+        self.0.is_none()
+    }
+}
+
+impl PartialOrd for PatternBenefit {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl Ord for PatternBenefit {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        use core::cmp::Ordering;
+        match (self.0, other.0) {
+            (None, None) => Ordering::Equal,
+            // Impossible to match is always last
+            (None, Some(_)) => Ordering::Greater,
+            (Some(_), None) => Ordering::Less,
+            // Benefits are ordered in reverse of integer order (higher benefit appears earlier)
+            (Some(a), Some(b)) => a.get().cmp(&b.get()).reverse(),
+        }
+    }
+}
+
+pub trait Pattern {
+    fn info(&self) -> &PatternInfo;
+    /// A name used when printing diagnostics related to this pattern
+    #[inline(always)]
+    fn name(&self) -> &'static str {
+        self.info().name
+    }
+    /// The kind of value used to select candidate root operations for this pattern.
+    #[inline(always)]
+    fn kind(&self) -> &PatternKind {
+        &self.info().kind
+    }
+    /// Returns the benefit - the inverse of "cost" - of matching this pattern.
+    ///
+    /// The benefit of a [Pattern] is always static - rewrites that may have dynamic benefit can be
+    /// instantiated multiple times (different instances), for each benefit that they may return,
+    /// and be guarded by different match condition predicates.
+    #[inline(always)]
+    fn benefit(&self) -> &PatternBenefit {
+        &self.info().benefit
+    }
+    /// Returns true if this pattern is known to result in recursive application, i.e. this pattern
+    /// may generate IR that also matches this pattern, but is known to bound the recursion. This
+    /// signals to the rewrite driver that it is safe to apply this pattern recursively to the
+    /// generated IR.
+    #[inline(always)]
+    fn has_bounded_rewrite_recursion(&self) -> bool {
+        self.info().has_bounded_recursion
+    }
+    /// Return a list of operations that may be generated when rewriting an operation instance
+    /// with this pattern.
+    #[inline(always)]
+    fn generated_ops(&self) -> &[OperationName] {
+        &self.info().generated_ops
+    }
+    /// Return the root operation that this pattern matches.
+    ///
+    /// Patterns that can match multiple root types return `None`
+    #[inline(always)]
+    fn get_root_operation(&self) -> Option<OperationName> {
+        self.info().root_operation()
+    }
+    /// Return the trait id used to match the root operation of this pattern.
+    ///
+    /// If the pattern does not use a trait id for deciding the root match, this returns `None`
+    #[inline(always)]
+    fn get_root_trait(&self) -> Option<TypeId> {
+        self.info().root_trait()
+    }
+}
+
+/// [PatternInfo] describes all of the data related to a pattern, but does not express any actual
+/// pattern logic, i.e. it is solely used for metadata about a pattern.
+pub struct PatternInfo {
+    #[allow(unused)]
+    context: Rc<Context>,
+    name: &'static str,
+    kind: PatternKind,
+    #[allow(unused)]
+    labels: SmallVec<[interner::Symbol; 1]>,
+    benefit: PatternBenefit,
+    has_bounded_recursion: bool,
+    generated_ops: SmallVec<[OperationName; 0]>,
+}
+
+impl PatternInfo {
+    /// Create a new [PatternInfo] from its component parts.
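+    ///
+    /// For example, mirroring the test pattern below (sketch):
+    ///
+    /// ```rust,ignore
+    /// let mut info = PatternInfo::new(
+    ///     context,
+    ///     "convert-shl1-to-mul2",
+    ///     PatternKind::Operation(op_name),
+    ///     PatternBenefit::new(1),
+    /// );
+    /// info.with_bounded_rewrite_recursion(true);
+    /// ```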
+    pub fn new(
+        context: Rc<Context>,
+        name: &'static str,
+        kind: PatternKind,
+        benefit: PatternBenefit,
+    ) -> Self {
+        Self {
+            context,
+            name,
+            kind,
+            labels: SmallVec::default(),
+            benefit,
+            has_bounded_recursion: false,
+            generated_ops: SmallVec::default(),
+        }
+    }
+
+    /// Set whether or not this pattern has bounded rewrite recursion
+    #[inline(always)]
+    pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self {
+        self.has_bounded_recursion = yes;
+        self
+    }
+
+    /// Return the root operation that this pattern matches.
+    ///
+    /// Patterns that can match multiple root types return `None`
+    pub fn root_operation(&self) -> Option<OperationName> {
+        match self.kind {
+            PatternKind::Operation(ref name) => Some(name.clone()),
+            _ => None,
+        }
+    }
+
+    /// Return the trait id used to match the root operation of this pattern.
+    ///
+    /// If the pattern does not use a trait id for deciding the root match, this returns `None`
+    pub fn root_trait(&self) -> Option<TypeId> {
+        match self.kind {
+            PatternKind::Trait(type_id) => Some(type_id),
+            _ => None,
+        }
+    }
+}
+
+impl Pattern for PatternInfo {
+    #[inline(always)]
+    fn info(&self) -> &PatternInfo {
+        self
+    }
+}
+
+/// A [RewritePattern] represents two things:
+///
+/// * A pattern which matches some IR that we're interested in, typically to replace with something
+///   else.
+/// * A rewrite which replaces IR that matches the pattern, with new IR, i.e. a DAG-to-DAG
+///   replacement
+///
+/// Implementations must provide `matches` and `rewrite` implementations, from which the
+/// `match_and_rewrite` implementation is derived.
+pub trait RewritePattern: Pattern {
+    /// Rewrite the IR rooted at the specified operation with the result of this pattern, generating
+    /// any new operations with the specified builder. If an unexpected error is encountered, i.e.
+    /// an internal compiler error, it is emitted through the normal diagnostic system, and the IR
+    /// is left in a valid state.
+    fn rewrite(&self, op: OperationRef, rewriter: &mut dyn Rewriter);
+
+    /// Attempt to match this pattern against the IR rooted at the specified operation, which must
+    /// be of the kind given by [Pattern::kind].
+    fn matches(&self, op: OperationRef) -> Result<bool, Report>;
+
+    /// Attempt to match this pattern against the IR rooted at the specified operation. If
+    /// matching is successful, the rewrite is automatically applied.
+    fn match_and_rewrite(
+        &self,
+        op: OperationRef,
+        rewriter: &mut dyn Rewriter,
+    ) -> Result<bool, Report> {
+        if self.matches(op.clone())? {
+            self.rewrite(op, rewriter);
+
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use alloc::rc::Rc;
+
+    use pretty_assertions::{assert_eq, assert_str_eq};
+
+    use super::*;
+    use crate::{dialects::hir::*, *};
+
+    /// In Miden, `n << 1` is vastly inferior to `n * 2` in cost, so reverse it
+    ///
+    /// NOTE: These two ops have slightly different semantics, a real implementation would have
+    /// to handle the edge cases.
+ struct ConvertShiftLeftBy1ToMultiply { + info: PatternInfo, + } + impl ConvertShiftLeftBy1ToMultiply { + pub fn new(context: Rc) -> Self { + let dialect = context.get_or_register_dialect::(); + let op_name = ::register_with(&*dialect); + let mut info = PatternInfo::new( + context, + "convert-shl1-to-mul2", + PatternKind::Operation(op_name), + PatternBenefit::new(1), + ); + info.with_bounded_rewrite_recursion(true); + Self { info } + } + } + impl Pattern for ConvertShiftLeftBy1ToMultiply { + fn info(&self) -> &PatternInfo { + &self.info + } + } + impl RewritePattern for ConvertShiftLeftBy1ToMultiply { + fn matches(&self, op: OperationRef) -> Result { + use crate::matchers::{self, match_chain, match_op, MatchWith, Matcher}; + + let binder = MatchWith(|op: &UnsafeIntrusiveEntityRef| { + log::trace!( + "found matching 'hir.shl' operation, checking if `shift` operand is foldable" + ); + let op = op.borrow(); + let shift = op.shift().as_operand_ref(); + let matched = matchers::foldable_operand_of::().matches(&shift); + matched.and_then(|imm| { + log::trace!("`shift` operand is an immediate: {imm}"); + let imm = imm.as_u64(); + if imm.is_none() { + log::trace!("`shift` operand is not a valid u64 value"); + } + if imm.is_some_and(|imm| imm == 1) { + Some(()) + } else { + None + } + }) + }); + log::trace!("attempting to match '{}'", self.name()); + let matched = match_chain(match_op::(), binder).matches(&op.borrow()).is_some(); + log::trace!("'{}' matched: {matched}", self.name()); + Ok(matched) + } + + fn rewrite(&self, op: OperationRef, rewriter: &mut dyn Rewriter) { + log::trace!("found match, rewriting '{}'", op.borrow().name()); + let (span, lhs) = { + let shl = op.borrow(); + let shl = shl.downcast_ref::().unwrap(); + let span = shl.span(); + let lhs = shl.lhs().as_value_ref(); + (span, lhs) + }; + let constant_builder = rewriter.create::(span); + let constant: UnsafeIntrusiveEntityRef = + constant_builder(Immediate::U32(2)).unwrap(); + let shift = constant.borrow().result().as_value_ref(); + let mul_builder = rewriter.create::(span); + let mul = mul_builder(lhs, shift, Overflow::Wrapping).unwrap(); + let mul = mul.borrow().as_operation().as_operation_ref(); + log::trace!("replacing shl with mul"); + rewriter.replace_op(op, mul); + } + } + + #[test] + fn rewrite_pattern_api_test() { + let mut builder = env_logger::Builder::from_env("MIDENC_TRACE"); + builder.init(); + + let context = Rc::new(Context::default()); + let pattern = ConvertShiftLeftBy1ToMultiply::new(Rc::clone(&context)); + + let mut builder = OpBuilder::new(Rc::clone(&context)); + let mut function = { + let builder = builder.create::(SourceSpan::default()); + let id = Ident::new("test".into(), SourceSpan::default()); + let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]); + builder(id, signature).unwrap() + }; + + // Define function body + { + let mut func = function.borrow_mut(); + let mut builder = FunctionBuilder::new(&mut func); + let shift = builder.ins().u32(1, SourceSpan::default()).unwrap(); + let block = builder.current_block(); + let lhs = block.borrow().arguments()[0].clone().upcast(); + let result = builder.ins().shl(lhs, shift, SourceSpan::default()).unwrap(); + builder.ins().ret(Some(result), SourceSpan::default()).unwrap(); + } + + // Construct pattern set + let mut rewrites = RewritePatternSet::new(builder.context_rc()); + rewrites.push(pattern); + let rewrites = Rc::new(FrozenRewritePatternSet::new(rewrites)); + + // Execute pattern driver + let mut config = 
GreedyRewriteConfig::default(); + config.with_region_simplification_level(RegionSimplificationLevel::None); + let result = crate::apply_patterns_and_fold_greedily( + function.borrow().as_operation().as_operation_ref(), + rewrites, + config, + ); + + // The rewrite should converge and modify the IR + assert_eq!(result, Ok(true)); + + // Confirm that the expected rewrite occurred + let func = function.borrow(); + let output = func.as_operation().to_string(); + let expected = "\ +hir.function public @test(v0: u32) -> u32 { +^block0(v0: u32): + v1 = hir.constant 1 : u32; + v3 = hir.constant 2 : u32; + v4 = hir.mul v0, v3 : u32 #[overflow = wrapping]; + hir.ret v4; +};"; + assert_str_eq!(output.as_str(), expected); + } +} diff --git a/hir2/src/patterns/pattern_set.rs b/hir2/src/patterns/pattern_set.rs new file mode 100644 index 000000000..0d7fb0134 --- /dev/null +++ b/hir2/src/patterns/pattern_set.rs @@ -0,0 +1,110 @@ +use alloc::{collections::BTreeMap, rc::Rc}; + +use smallvec::SmallVec; + +use super::*; +use crate::{Context, OperationName}; + +pub struct RewritePatternSet { + context: Rc, + patterns: Vec>, +} +impl RewritePatternSet { + pub fn new(context: Rc) -> Self { + Self { + context, + patterns: vec![], + } + } + + pub fn from_iter
<P>
(context: Rc, patterns: P) -> Self + where + P: IntoIterator>, + { + Self { + context, + patterns: patterns.into_iter().collect(), + } + } + + #[inline] + pub fn context(&self) -> Rc { + Rc::clone(&self.context) + } + + #[inline] + pub fn patterns(&self) -> &[Box] { + &self.patterns + } + + pub fn push(&mut self, pattern: impl RewritePattern + 'static) { + self.patterns.push(Box::new(pattern)); + } +} + +pub struct FrozenRewritePatternSet { + context: Rc, + patterns: Vec>, + op_specific_patterns: BTreeMap; 2]>>, + any_op_patterns: SmallVec<[Rc; 1]>, +} +impl FrozenRewritePatternSet { + pub fn new(patterns: RewritePatternSet) -> Self { + let RewritePatternSet { context, patterns } = patterns; + let mut this = Self { + context, + patterns: Default::default(), + op_specific_patterns: Default::default(), + any_op_patterns: Default::default(), + }; + + for pattern in patterns { + let pattern = Rc::::from(pattern); + match pattern.kind() { + PatternKind::Operation(name) => { + this.op_specific_patterns + .entry(name.clone()) + .or_default() + .push(Rc::clone(&pattern)); + this.patterns.push(pattern); + } + PatternKind::Trait(ref trait_id) => { + for dialect in this.context.registered_dialects().values() { + for op in dialect.registered_ops().iter() { + if op.implements_trait_id(trait_id) { + this.op_specific_patterns + .entry(op.clone()) + .or_default() + .push(Rc::clone(&pattern)); + } + } + } + this.patterns.push(pattern); + } + PatternKind::Any => { + this.any_op_patterns.push(Rc::clone(&pattern)); + this.patterns.push(pattern); + } + } + } + + this + } + + #[inline] + pub fn patterns(&self) -> &[Rc] { + &self.patterns + } + + #[inline] + pub fn op_specific_patterns( + &self, + ) -> &BTreeMap; 2]>> { + &self.op_specific_patterns + } + + #[inline] + pub fn any_op_patterns(&self) -> &[Rc] { + &self.any_op_patterns + } +} diff --git a/hir2/src/patterns/rewriter.rs b/hir2/src/patterns/rewriter.rs new file mode 100644 index 000000000..fc6b23900 --- /dev/null +++ b/hir2/src/patterns/rewriter.rs @@ -0,0 +1,1157 @@ +use alloc::rc::Rc; +use core::ops::{Deref, DerefMut}; + +use smallvec::SmallVec; + +use crate::{ + Block, BlockRef, Builder, Context, EntityWithParent, InsertionGuard, InsertionPoint, Listener, + ListenerType, OpBuilder, OpOperandImpl, Operation, OperationRef, Pattern, PostOrderBlockIter, + ProgramPoint, RegionRef, Report, SourceSpan, Usable, ValueRef, +}; + +/// A [Rewriter] is a [Builder] extended with additional functionality that is of primary use when +/// rewriting the IR after it is initially constructed. It is the basis on which the pattern +/// rewriter infrastructure is built. +pub trait Rewriter: Builder + RewriterListener { + /// Returns true if this rewriter has a listener attached. + /// + /// When no listener is present, fast paths can be taken when rewriting the IR, whereas a + /// listener requires breaking mutations up into individual actions so that the listener can + /// be made aware of all of them, in the order they occur. + fn has_listener(&self) -> bool; + + /// Replace the results of the given operation with the specified list of values (replacements). + /// + /// The result types of the given op and the replacements must match. The original op is erased. 
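+    ///
+    /// For example (sketch; `v0` and `v1` are existing SSA values whose types match the
+    /// result types of `op`):
+    ///
+    /// ```rust,ignore
+    /// rewriter.replace_op_with_values(op, &[v0, v1]);
+    /// ```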
+    fn replace_op_with_values(&mut self, op: OperationRef, values: &[ValueRef]) {
+        assert_eq!(op.borrow().num_results(), values.len());
+
+        // Replace all result uses, notifies listener of the modifications
+        self.replace_all_op_uses_with_values(op.clone(), values);
+
+        // Erase the op and notify the listener
+        self.erase_op(op);
+    }
+
+    /// Replace the results of the given operation with the specified replacement op.
+    ///
+    /// The result types of the two ops must match. The original op is erased.
+    fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) {
+        assert_eq!(op.borrow().num_results(), new_op.borrow().num_results());
+
+        // Replace all result uses, notifies listener of the modifications
+        self.replace_all_op_uses_with(op.clone(), new_op);
+
+        // Erase the op and notify the listener
+        self.erase_op(op);
+    }
+
+    /// This method erases an operation that is known to have no uses.
+    fn erase_op(&mut self, mut op: OperationRef) {
+        assert!(!op.borrow().is_used(), "expected op to have no uses");
+
+        // If no listener is attached, the op can be dropped all at once.
+        if !self.has_listener() {
+            op.borrow_mut().erase();
+            return;
+        }
+
+        // Helper function that erases a single operation
+        fn erase_single_op<R: Rewriter + ?Sized>(mut op: OperationRef, rewrite_listener: &mut R) {
+            if cfg!(debug_assertions) {
+                let op = op.borrow();
+                // All nested ops should have been erased already
+                assert!(op.regions().iter().all(|r| r.is_empty()), "expected empty regions");
+                // All users should have been erased already if the op is in a region with SSA
+                // dominance
+                if op.is_used() {
+                    if let Some(region) = op.parent_region() {
+                        assert!(
+                            region.borrow().may_be_graph_region(),
+                            "expected that op has no uses"
+                        );
+                    }
+                }
+            }
+
+            rewrite_listener.notify_operation_erased(op.clone());
+
+            // Explicitly drop all uses in case the op is in a graph region
+            let mut op_mut = op.borrow_mut();
+            op_mut.drop_all_uses();
+            op_mut.erase();
+        }
+
+        // Nested ops must be erased one-by-one, so that listeners have a consistent view of the
+        // IR every time a notification is triggered. Users must be erased before definitions,
+        // i.e. in post-order, reverse dominance.
+        fn erase_tree<R: Rewriter + ?Sized>(op_ref: OperationRef, rewriter: &mut R) {
+            // Erase nested ops
+            {
+                let op = op_ref.borrow();
+                for region in op.regions() {
+                    // Erase all blocks in the right order. Successors should be erased before
+                    // predecessors because successor blocks may use values defined in
+                    // predecessor blocks. A post-order traversal of blocks within a region
+                    // visits successors before predecessors. Repeat the traversal until the
+                    // region is empty. (The block graph could be disconnected.)
+                    let mut erased_blocks = SmallVec::<[BlockRef; 4]>::default();
+                    while !region.is_empty() {
+                        erased_blocks.clear();
+                        for block_ref in PostOrderBlockIter::new(region.entry_block_ref().unwrap())
+                        {
+                            let block = block_ref.borrow();
+                            let mut cursor = block.body().front();
+                            while let Some(op) = cursor.as_pointer() {
+                                erase_tree(op, rewriter);
+                                cursor.move_next();
+                            }
+                            erased_blocks.push(block_ref);
+                        }
+                        for mut block in erased_blocks.drain(..) {
+                            // Explicitly drop all uses in case there is a cycle in the block
+                            // graph.
+                            for arg in block.borrow_mut().arguments_mut() {
+                                arg.borrow_mut().uses_mut().clear();
+                            }
+                            block.borrow_mut().drop_all_uses();
+                            rewriter.erase_block(block);
+                        }
+                    }
+                }
+            }
+            erase_single_op(op_ref, rewriter);
+        }
+
+        erase_tree(op, self);
+    }
+
+    /// Erase a block, along with all of the operations it contains.
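+    ///
+    /// For example (illustrative; assumes `block` has no remaining uses):
+    ///
+    /// ```ignore
+    /// // Erases every op in the block (in reverse order), then unlinks and
+    /// // removes the block itself from its parent region:
+    /// rewriter.erase_block(block);
+    /// ```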
+    fn erase_block(&mut self, mut block: BlockRef) {
+        assert!(!block.borrow().is_used(), "expected 'block' to be unused");
+
+        let mut blk = block.borrow_mut();
+        let mut cursor = blk.body_mut().back_mut();
+        while let Some(op) = cursor.remove() {
+            assert!(!op.borrow().is_used(), "expected 'op' to be unused");
+            self.erase_op(op);
+        }
+
+        // Notify the listener that the block is about to be removed.
+        self.notify_block_erased(block.clone());
+
+        // Remove block from parent region
+        let mut region = blk.parent().expect("expected 'block' to have a parent region");
+        let mut region_mut = region.borrow_mut();
+        let mut cursor = unsafe { region_mut.body_mut().cursor_mut_from_ptr(block.clone()) };
+        cursor.remove();
+    }
+
+    /// Move the blocks that belong to `region` before the given insertion point in another
+    /// region, `ip`. The two regions must be different. The caller is responsible for creating
+    /// or updating the operation transferring flow of control to the region, and passing it the
+    /// correct block arguments.
+    fn inline_region_before(&mut self, mut region: RegionRef, ip: BlockRef) {
+        let mut parent = ip.borrow().parent().expect("invalid 'ip': must be attached to a region");
+        assert!(!RegionRef::ptr_eq(&region, &parent), "cannot inline a region into itself");
+        let mut region_body = region.borrow_mut().body_mut().take();
+        if !self.has_listener() {
+            {
+                let mut region_cursor = region_body.front_mut();
+                while let Some(block) = region_cursor.as_pointer() {
+                    Block::on_inserted_into_parent(block, parent.clone());
+                    region_cursor.move_next();
+                }
+            }
+            let mut parent_region = parent.borrow_mut();
+            let parent_body = parent_region.body_mut();
+            let mut cursor = unsafe { parent_body.cursor_mut_from_ptr(ip.clone()) };
+            cursor.splice_before(region_body);
+        } else {
+            // Move blocks from beginning of the region one-by-one
+            for block in region_body {
+                self.move_block_before(block, ip.clone());
+            }
+        }
+    }
+
+    /// Inline the operations of block `src` before the given insertion point.
+    /// The source block will be deleted and must have no uses. The `args` values, if provided,
+    /// are used to replace the block arguments of `src`.
+    ///
+    /// If the source block is inserted at the end of the dest block, the dest block must have no
+    /// successors. Similarly, if the source block is inserted somewhere in the middle (or
+    /// beginning) of the dest block, the source block must have no successors. Otherwise, the
+    /// resulting IR would have unreachable operations.
+    fn inline_block_before(&mut self, mut src: BlockRef, ip: OperationRef, args: &[ValueRef]) {
+        assert!(
+            args.len() == src.borrow().num_arguments(),
+            "incorrect # of argument replacement values"
+        );
+
+        // The source block will be deleted, so it should not have any users (i.e., there should
+        // be no predecessors).
+        assert!(!src.borrow().has_predecessors(), "expected 'src' to have no predecessors");
+
+        let mut dest = ip.borrow().parent().expect("expected 'ip' to belong to a block");
+        let insert_at_block_end =
+            OperationRef::ptr_eq(&ip, &dest.borrow().body().back().as_pointer().unwrap());
+        if insert_at_block_end {
+            // The source block will be inserted at the end of the dest block, so the
+            // dest block should have no successors. Otherwise, the inserted operations
+            // will be unreachable.
+            assert!(!dest.borrow().has_successors(), "expected 'dest' to have no successors");
+        } else {
+            // The source block will be inserted in the middle of the dest block, so
+            // the source block should have no successors. Otherwise, the remainder of
+            // the dest block would be unreachable.
+            assert!(!src.borrow().has_successors(), "expected 'src' to have no successors");
+        }
+
+        // Replace all of the successor arguments with the provided values.
+        for (arg, replacement) in src.borrow().arguments().iter().zip(args) {
+            self.replace_all_uses_of_value_with(arg.clone().upcast(), replacement.clone());
+        }
+
+        // Move operations from the source block to the dest block and erase the source block.
+        if self.has_listener() {
+            let mut src_ops = src.borrow_mut().body_mut().take();
+            let mut src_cursor = src_ops.front_mut();
+            while let Some(op) = src_cursor.remove() {
+                self.move_op_before(op, ip.clone());
+            }
+        } else {
+            // Fast path: If no listener is attached, move all operations at once.
+            let mut dest_block = dest.borrow_mut();
+            let dest_body = dest_block.body_mut();
+            let mut src_ops = src.borrow_mut().body_mut().take();
+            {
+                let mut src_cursor = src_ops.front_mut();
+                while let Some(op) = src_cursor.as_pointer() {
+                    Operation::on_inserted_into_parent(op, dest.clone());
+                    src_cursor.move_next();
+                }
+            }
+            let mut cursor = unsafe { dest_body.cursor_mut_from_ptr(ip) };
+            cursor.splice_before(src_ops);
+        }
+
+        // Erase the source block.
+        assert!(src.borrow().body().is_empty(), "expected 'src' to be empty");
+        self.erase_block(src);
+    }
+
+    /// Inline the operations of block `src` into the end of block `dest`. The source block will
+    /// be deleted and must have no uses. The `args` values, if present, are used to replace the
+    /// block arguments of `src`.
+    ///
+    /// The dest block must have no successors. Otherwise, the resulting IR will have unreachable
+    /// operations.
+    fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: &[ValueRef]) {
+        let ip = dest.borrow().body().back().as_pointer().unwrap();
+        self.inline_block_before(src, ip, args);
+    }
+
+    /// Split the operations starting at `ip` (inclusive) out of the given block into a new
+    /// block, and return it.
+    fn split_block(&mut self, mut block: BlockRef, ip: OperationRef) -> BlockRef {
+        // Fast path: if no listener is attached, split the block directly
+        if !self.has_listener() {
+            return block.borrow_mut().split_block(ip);
+        }
+
+        assert_eq!(
+            block,
+            ip.borrow().parent().expect("expected 'ip' to be attached to a block"),
+            "expected 'ip' to be in 'block'"
+        );
+
+        let region = block
+            .borrow()
+            .parent()
+            .expect("cannot split a block which is not attached to a region");
+
+        // `create_block` sets the insertion point to the start of the new block
+        let mut guard = InsertionGuard::new(self);
+        let new_block = guard.create_block(region, Some(block.clone()), &[]);
+
+        // If `ip` points to the end of the block, no ops should be moved
+        if OperationRef::ptr_eq(&ip, &block.borrow().body().back().as_pointer().unwrap()) {
+            return new_block;
+        }
+
+        // Move ops one-by-one from the end of `block` to the start of `new_block`.
+        // Stop when the operation pointed to by `ip` has been moved.
+        let mut block = block.borrow_mut();
+        let mut cursor = block.body_mut().back_mut();
+        while let Some(op) = cursor.remove() {
+            let is_last_move = OperationRef::ptr_eq(&op, &ip);
+            guard.move_op_before(op, new_block.borrow().body().front().as_pointer().unwrap());
+            if is_last_move {
+                break;
+            }
+        }
+
+        new_block
+    }
+
+    /// Unlink this block and insert it right before `ip`.
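+    ///
+    /// A sketch of intended usage (names assumed for illustration):
+    ///
+    /// ```ignore
+    /// // Detach `block` from wherever it currently lives, and re-insert it
+    /// // immediately before `target` in `target`'s region:
+    /// rewriter.move_block_before(block, target);
+    /// ```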
+    fn move_block_before(&mut self, mut block: BlockRef, ip: BlockRef) {
+        let current_region = block.borrow().parent();
+        block.borrow_mut().move_before(ip.clone());
+        self.notify_block_inserted(block, current_region, Some(ip));
+    }
+
+    /// Unlink this operation from its current block and insert it right before `ip`, which
+    /// may be in the same or another block in the same function.
+    fn move_op_before(&mut self, mut op: OperationRef, ip: OperationRef) {
+        let current_block = op.borrow().parent();
+        let current_ip = current_block.map(|block| {
+            let blk = block.borrow();
+            let cursor = unsafe { blk.body().cursor_from_ptr(op.clone()) };
+            if let Some(next_op) = cursor.peek_next().as_pointer() {
+                InsertionPoint::before(next_op)
+            } else if let Some(prev_op) = cursor.peek_prev().as_pointer() {
+                InsertionPoint::after(prev_op)
+            } else {
+                InsertionPoint::after(block.clone())
+            }
+        });
+        op.borrow_mut().move_before(ProgramPoint::Op(ip.clone()));
+        self.notify_operation_inserted(op, current_ip);
+    }
+
+    /// Unlink this operation from its current block and insert it right after `ip`, which may
+    /// be in the same or another block in the same function.
+    fn move_op_after(&mut self, mut op: OperationRef, ip: OperationRef) {
+        let current_block = op.borrow().parent();
+        let current_ip = current_block.map(|block| {
+            let blk = block.borrow();
+            let cursor = unsafe { blk.body().cursor_from_ptr(op.clone()) };
+            if let Some(next_op) = cursor.peek_next().as_pointer() {
+                InsertionPoint::before(next_op)
+            } else if let Some(prev_op) = cursor.peek_prev().as_pointer() {
+                InsertionPoint::after(prev_op)
+            } else {
+                InsertionPoint::after(block.clone())
+            }
+        });
+        op.borrow_mut().move_after(ProgramPoint::Op(ip.clone()));
+        self.notify_operation_inserted(op, current_ip);
+    }
+
+    /// Find uses of `from` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    fn replace_all_uses_of_value_with(&mut self, mut from: ValueRef, mut to: ValueRef) {
+        let mut from_val = from.borrow_mut();
+        let from_uses = from_val.uses_mut();
+        let mut cursor = from_uses.front_mut();
+        while let Some(mut operand) = cursor.remove() {
+            let op = operand.borrow().owner.clone();
+            self.notify_operation_modification_started(&op);
+            operand.borrow_mut().value = to.clone();
+            to.borrow_mut().insert_use(operand);
+            self.notify_operation_modified(op);
+        }
+    }
+
+    /// Find uses of `from` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    fn replace_all_uses_of_block_with(&mut self, mut from: BlockRef, mut to: BlockRef) {
+        let mut from_block = from.borrow_mut();
+        let from_uses = from_block.uses_mut();
+        let mut cursor = from_uses.front_mut();
+        while let Some(mut operand) = cursor.remove() {
+            let op = operand.borrow().owner.clone();
+            self.notify_operation_modification_started(&op);
+            operand.borrow_mut().block = to.clone();
+            to.borrow_mut().insert_use(operand);
+            self.notify_operation_modified(op);
+        }
+    }
+
+    /// Find uses of `from` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
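+    ///
+    /// Replacement is pairwise, i.e. uses of `from[i]` are rewritten to use `to[i]`. For
+    /// example (illustrative, assuming the paired values have matching types):
+    ///
+    /// ```ignore
+    /// rewriter.replace_all_uses_with(&[old_quot, old_rem], &[new_quot, new_rem]);
+    /// ```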
+    fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[ValueRef]) {
+        assert_eq!(from.len(), to.len(), "incorrect number of replacements");
+        for (from, to) in from.iter().cloned().zip(to.iter().cloned()) {
+            self.replace_all_uses_of_value_with(from, to);
+        }
+    }
+
+    /// Find uses of `from` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place modification (for every use that was
+    /// replaced), and that the `from` operation is about to be replaced.
+    fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[ValueRef]) {
+        self.notify_operation_replaced_with_values(from.clone(), to);
+
+        let results = from
+            .borrow()
+            .results()
+            .all()
+            .iter()
+            .map(|result| result.borrow().as_value_ref())
+            .collect::<SmallVec<[ValueRef; 2]>>();
+
+        self.replace_all_uses_with(&results, to);
+    }
+
+    /// Find uses of `from` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place modification (for every use that was
+    /// replaced), and that the `from` operation is about to be replaced.
+    fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) {
+        self.notify_operation_replaced(from.clone(), to.clone());
+
+        let from_results = from
+            .borrow()
+            .results()
+            .all()
+            .iter()
+            .map(|result| result.borrow().as_value_ref())
+            .collect::<SmallVec<[ValueRef; 2]>>();
+
+        let to_results = to
+            .borrow()
+            .results()
+            .all()
+            .iter()
+            .map(|result| result.borrow().as_value_ref())
+            .collect::<SmallVec<[ValueRef; 2]>>();
+
+        self.replace_all_uses_with(&from_results, &to_results);
+    }
+
+    /// Find uses of `from` within `block` and replace them with `to`.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    ///
+    /// Returns true if all uses were replaced, otherwise false.
+    fn replace_op_uses_within_block(
+        &mut self,
+        from: OperationRef,
+        to: &[ValueRef],
+        block: BlockRef,
+    ) -> bool {
+        let parent_op = block.borrow().parent_op();
+        self.maybe_replace_op_uses_with(from, to, |operand| {
+            // Only replace uses whose owner is nested within the block's parent op
+            parent_op
+                .as_ref()
+                .is_some_and(|op| op.borrow().is_proper_ancestor_of(&operand.owner))
+        })
+    }
+
+    /// Find uses of `from` and replace them with `to`, except if the user is in `exceptions`.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    fn replace_all_uses_except(
+        &mut self,
+        from: ValueRef,
+        to: ValueRef,
+        exceptions: &[OperationRef],
+    ) {
+        self.maybe_replace_uses_of_value_with(from, to, |operand| {
+            !exceptions.contains(&operand.owner)
+        });
+    }
+}
+
+/// An extension trait for [Rewriter] implementations.
+///
+/// This trait contains functionality that is not object safe, and would prevent using [Rewriter]
+/// as a trait object. It is automatically implemented for all [Rewriter] impls.
+pub trait RewriterExt: Rewriter {
+    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    ///
+    /// Returns true if all uses were replaced, otherwise false.
+    fn maybe_replace_uses_of_value_with<P>(
+        &mut self,
+        mut from: ValueRef,
+        mut to: ValueRef,
+        should_replace: P,
+    ) -> bool
+    where
+        P: Fn(&OpOperandImpl) -> bool,
+    {
+        let mut all_replaced = true;
+        let mut from = from.borrow_mut();
+        let from_uses = from.uses_mut();
+        let mut cursor = from_uses.front_mut();
+        while let Some(user) = cursor.as_pointer() {
+            if should_replace(&user.borrow()) {
+                let owner = user.borrow().owner.clone();
+                self.notify_operation_modification_started(&owner);
+                let mut operand = cursor.remove().unwrap();
+                {
+                    operand.borrow_mut().value = to.clone();
+                }
+                to.borrow_mut().insert_use(operand);
+                self.notify_operation_modified(owner);
+            } else {
+                all_replaced = false;
+                cursor.move_next();
+            }
+        }
+        all_replaced
+    }
+
+    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    ///
+    /// Returns true if all uses were replaced, otherwise false.
+    fn maybe_replace_uses_with<P>(
+        &mut self,
+        from: &[ValueRef],
+        to: &[ValueRef],
+        should_replace: P,
+    ) -> bool
+    where
+        P: Fn(&OpOperandImpl) -> bool,
+    {
+        assert_eq!(from.len(), to.len(), "incorrect number of replacements");
+        let mut all_replaced = true;
+        for (from, to) in from.iter().cloned().zip(to.iter().cloned()) {
+            all_replaced &= self.maybe_replace_uses_of_value_with(from, to, &should_replace);
+        }
+        all_replaced
+    }
+
+    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
+    ///
+    /// Notifies the listener about every in-place op modification (for every use that was
+    /// replaced).
+    ///
+    /// Returns true if all uses were replaced, otherwise false.
+    fn maybe_replace_op_uses_with<P>(
+        &mut self,
+        from: OperationRef,
+        to: &[ValueRef],
+        should_replace: P,
+    ) -> bool
+    where
+        P: Fn(&OpOperandImpl) -> bool,
+    {
+        let results = SmallVec::<[ValueRef; 2]>::from_iter(
+            from.borrow().results.all().iter().cloned().map(|result| result.upcast()),
+        );
+        self.maybe_replace_uses_with(&results, to, should_replace)
+    }
+}
+
+impl<R: ?Sized + Rewriter> RewriterExt for R {}
+
+#[allow(unused_variables)]
+pub trait RewriterListener: Listener {
+    /// Notify the listener that the specified block is about to be erased.
+    ///
+    /// At this point, the block has zero uses.
+    fn notify_block_erased(&self, block: BlockRef) {}
+
+    /// Notify the listener that an in-place modification of the specified operation has started
+    fn notify_operation_modification_started(&self, op: &OperationRef) {}
+
+    /// Notify the listener that an in-place modification of the specified operation was canceled
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {}
+
+    /// Notify the listener that the specified operation was modified in-place.
+    fn notify_operation_modified(&self, op: OperationRef) {}
+
+    /// Notify the listener that all uses of the specified operation's results are about to be
+    /// replaced with the results of another operation. This is called before the uses of the
+    /// old operation have been changed.
+    ///
+    /// By default, this function calls the "operation replaced with values" notification.
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        let replacement = replacement.borrow();
+        let values = replacement
+            .results()
+            .all()
+            .iter()
+            .cloned()
+            .map(|result| result.upcast())
+            .collect::<SmallVec<[ValueRef; 2]>>();
+        self.notify_operation_replaced_with_values(op, &values);
+    }
+
+    /// Notify the listener that all uses of the specified operation's results are about to be
+    /// replaced with the given range of values, potentially produced by other operations. This
+    /// is called before the uses of the operation have been changed.
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {}
+
+    /// Notify the listener that the specified operation is about to be erased. At this point,
+    /// the operation has zero uses.
+    ///
+    /// NOTE: This notification is not triggered when unlinking an operation.
+    fn notify_operation_erased(&self, op: OperationRef) {}
+
+    /// Notify the listener that the specified pattern is about to be applied at the specified
+    /// root operation.
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {}
+
+    /// Notify the listener that a pattern application finished with the specified status.
+    ///
+    /// `true` indicates that the pattern was applied successfully; `false` indicates that the
+    /// pattern could not be applied. The pattern may have communicated the reason for the
+    /// failure with `notify_match_failure`.
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {}
+
+    /// Notify the listener that the pattern failed to match, and provide a diagnostic
+    /// explaining the reason why the failure occurred.
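+    ///
+    /// A listener might simply log the diagnostic, e.g. (illustrative only; assumes
+    /// `SourceSpan` implements `Debug`):
+    ///
+    /// ```ignore
+    /// fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+    ///     log::debug!("pattern match failure at {span:?}: {reason}");
+    /// }
+    /// ```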
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {}
+}
+
+impl<L: RewriterListener> RewriterListener for Option<L> {
+    fn notify_block_erased(&self, block: BlockRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_block_erased(block);
+        }
+    }
+
+    fn notify_operation_modification_started(&self, op: &OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_modification_started(op);
+        }
+    }
+
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_modification_canceled(op);
+        }
+    }
+
+    fn notify_operation_modified(&self, op: OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_modified(op);
+        }
+    }
+
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_replaced(op, replacement);
+        }
+    }
+
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_replaced_with_values(op, replacement);
+        }
+    }
+
+    fn notify_operation_erased(&self, op: OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_operation_erased(op);
+        }
+    }
+
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_pattern_begin(pattern, op);
+        }
+    }
+
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_pattern_end(pattern, success);
+        }
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        if let Some(listener) = self.as_ref() {
+            listener.notify_match_failure(span, reason);
+        }
+    }
+}
+
+impl<L: ?Sized + RewriterListener> RewriterListener for Box<L> {
+    fn notify_block_erased(&self, block: BlockRef) {
+        (**self).notify_block_erased(block);
+    }
+
+    fn notify_operation_modification_started(&self, op: &OperationRef) {
+        (**self).notify_operation_modification_started(op);
+    }
+
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
+        (**self).notify_operation_modification_canceled(op);
+    }
+
+    fn notify_operation_modified(&self, op: OperationRef) {
+        (**self).notify_operation_modified(op);
+    }
+
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        (**self).notify_operation_replaced(op, replacement);
+    }
+
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        (**self).notify_operation_replaced_with_values(op, replacement);
+    }
+
+    fn notify_operation_erased(&self, op: OperationRef) {
+        (**self).notify_operation_erased(op)
+    }
+
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
+        (**self).notify_pattern_begin(pattern, op);
+    }
+
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
+        (**self).notify_pattern_end(pattern, success);
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        (**self).notify_match_failure(span, reason);
+    }
+}
+
+impl<L: ?Sized + RewriterListener> RewriterListener for Rc<L> {
+    fn notify_block_erased(&self, block: BlockRef) {
+        (**self).notify_block_erased(block);
+    }
+
+    fn notify_operation_modification_started(&self, op: &OperationRef) {
+        (**self).notify_operation_modification_started(op);
+    }
+
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
+        (**self).notify_operation_modification_canceled(op);
+    }
+
+    fn notify_operation_modified(&self, op: OperationRef) {
+        (**self).notify_operation_modified(op);
+    }
+
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        (**self).notify_operation_replaced(op, replacement);
+    }
+
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        (**self).notify_operation_replaced_with_values(op, replacement);
+    }
+
+    fn notify_operation_erased(&self, op: OperationRef) {
+        (**self).notify_operation_erased(op)
+    }
+
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
+        (**self).notify_pattern_begin(pattern, op);
+    }
+
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
+        (**self).notify_pattern_end(pattern, success);
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        (**self).notify_match_failure(span, reason);
+    }
+}
+
+/// A listener of kind `Rewriter` that does nothing
+pub struct NoopRewriterListener;
+impl Listener for NoopRewriterListener {
+    #[inline]
+    fn kind(&self) -> ListenerType {
+        ListenerType::Rewriter
+    }
+
+    #[inline(always)]
+    fn notify_operation_inserted(&self, _op: OperationRef, _prev: Option<InsertionPoint>) {}
+
+    #[inline(always)]
+    fn notify_block_inserted(
+        &self,
+        _block: BlockRef,
+        _prev: Option<RegionRef>,
+        _ip: Option<BlockRef>,
+    ) {
+    }
+}
+impl RewriterListener for NoopRewriterListener {
+    fn notify_operation_replaced(&self, _op: OperationRef, _replacement: OperationRef) {}
+}
+
+pub struct ForwardingListener<Base, Derived> {
+    base: Base,
+    derived: Derived,
+}
+impl<Base, Derived> ForwardingListener<Base, Derived> {
+    pub fn new(base: Base, derived: Derived) -> Self {
+        Self { base, derived }
+    }
+}
+impl<Base: Listener, Derived: Listener> Listener for ForwardingListener<Base, Derived> {
+    fn kind(&self) -> ListenerType {
+        self.derived.kind()
+    }
+
+    fn notify_block_inserted(
+        &self,
+        block: BlockRef,
+        prev: Option<RegionRef>,
+        ip: Option<BlockRef>,
+    ) {
+        self.base.notify_block_inserted(block.clone(), prev.clone(), ip.clone());
+        self.derived.notify_block_inserted(block, prev, ip);
+    }
+
+    fn notify_operation_inserted(&self, op: OperationRef, prev: Option<InsertionPoint>) {
+        self.base.notify_operation_inserted(op.clone(), prev.clone());
+        self.derived.notify_operation_inserted(op, prev);
+    }
+}
+impl<Base: RewriterListener, Derived: RewriterListener> RewriterListener
+    for ForwardingListener<Base, Derived>
+{
+    fn notify_block_erased(&self, block: BlockRef) {
+        self.base.notify_block_erased(block.clone());
+        self.derived.notify_block_erased(block);
+    }
+
+    fn notify_operation_modification_started(&self, op: &OperationRef) {
+        self.base.notify_operation_modification_started(op);
+        self.derived.notify_operation_modification_started(op);
+    }
+
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
+        self.base.notify_operation_modification_canceled(op);
+        self.derived.notify_operation_modification_canceled(op);
+    }
+
+    fn notify_operation_modified(&self, op: OperationRef) {
+        self.base.notify_operation_modified(op.clone());
+        self.derived.notify_operation_modified(op);
+    }
+
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        self.base.notify_operation_replaced(op.clone(), replacement.clone());
+        self.derived.notify_operation_replaced(op, replacement);
+    }
+
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        self.base.notify_operation_replaced_with_values(op.clone(), replacement);
+        self.derived.notify_operation_replaced_with_values(op, replacement);
+    }
+
+    fn notify_operation_erased(&self, op: OperationRef) {
+        self.base.notify_operation_erased(op.clone());
+        self.derived.notify_operation_erased(op);
+    }
+
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
+        self.base.notify_pattern_begin(pattern, op.clone());
+        self.derived.notify_pattern_begin(pattern, op);
+    }
+
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
+        self.base.notify_pattern_end(pattern, success);
+        self.derived.notify_pattern_end(pattern, success);
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        // `Report` cannot be cloned, so preserve the message for the second listener
+        let err = Report::msg(format!("{reason}"));
+        self.base.notify_match_failure(span, reason);
+        self.derived.notify_match_failure(span, err);
+    }
+}
+
+/// Wraps an in-place modification of an [Operation] to ensure the rewriter is properly notified
+/// about the progress and outcome of the in-place modification.
+///
+/// This is a minor efficiency win, as it avoids creating a new operation and removing the old
+/// one, but it also often allows simpler code in the client.
+pub struct InPlaceModificationGuard<'a, R: ?Sized + Rewriter> {
+    rewriter: &'a mut R,
+    op: OperationRef,
+    canceled: bool,
+}
+impl<'a, R> InPlaceModificationGuard<'a, R>
+where
+    R: ?Sized + Rewriter,
+{
+    pub fn new(rewriter: &'a mut R, op: OperationRef) -> Self {
+        rewriter.notify_operation_modification_started(&op);
+        Self {
+            rewriter,
+            op,
+            canceled: false,
+        }
+    }
+
+    #[inline]
+    pub fn rewriter(&mut self) -> &mut R {
+        self.rewriter
+    }
+
+    #[inline]
+    pub fn op(&self) -> &OperationRef {
+        &self.op
+    }
+
+    /// Cancels the pending in-place modification.
+    pub fn cancel(mut self) {
+        self.canceled = true;
+    }
+
+    /// Signals the end of an in-place modification of the current operation.
+    pub fn finalize(self) {}
+}
+impl<R: ?Sized + Rewriter> core::ops::Deref for InPlaceModificationGuard<'_, R> {
+    type Target = R;
+
+    #[inline(always)]
+    fn deref(&self) -> &Self::Target {
+        self.rewriter
+    }
+}
+impl<R: ?Sized + Rewriter> core::ops::DerefMut for InPlaceModificationGuard<'_, R> {
+    #[inline(always)]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.rewriter
+    }
+}
+impl<R: ?Sized + Rewriter> Drop for InPlaceModificationGuard<'_, R> {
+    fn drop(&mut self) {
+        if self.canceled {
+            self.rewriter.notify_operation_modification_canceled(&self.op);
+        } else {
+            self.rewriter.notify_operation_modified(self.op.clone());
+        }
+    }
+}
+
+/// A specialized [Rewriter] that coordinates the application of a rewrite pattern on the
+/// current IR being matched, providing a way to keep track of any mutations made.
+///
+/// This type should be used to perform all necessary IR mutations within a rewrite pattern, as
+/// the pattern driver may be tracking various state that would be invalidated when a mutation
+/// takes place.
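+///
+/// A minimal sketch of typical use from within a rewrite pattern (the matched `op` and the
+/// `build_replacement` helper are assumed for illustration):
+///
+/// ```ignore
+/// // Build the replacement op at the current insertion point, then replace the
+/// // matched op with it, letting the rewriter handle use rewriting, erasure,
+/// // and listener notifications:
+/// let new_op = build_replacement(&mut rewriter)?;
+/// rewriter.replace_op(op, new_op);
+/// ```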
+pub struct PatternRewriter<L = NoopRewriterListener> {
+    rewriter: RewriterImpl<L>,
+    recoverable: bool,
+}
+
+impl PatternRewriter {
+    pub fn new(context: Rc<Context>) -> Self {
+        let rewriter = RewriterImpl::new(context);
+        Self {
+            rewriter,
+            recoverable: false,
+        }
+    }
+
+    pub fn from_builder<L2: Listener>(builder: OpBuilder<L2>) -> Self {
+        let (context, _, ip) = builder.into_parts();
+        let mut rewriter = RewriterImpl::new(context);
+        rewriter.restore_insertion_point(ip);
+        Self {
+            rewriter,
+            recoverable: false,
+        }
+    }
+}
+
+impl<L: RewriterListener> PatternRewriter<L> {
+    pub fn new_with_listener(context: Rc<Context>, listener: L) -> Self {
+        let rewriter = RewriterImpl::<NoopRewriterListener>::new(context).with_listener(listener);
+        Self {
+            rewriter,
+            recoverable: false,
+        }
+    }
+
+    #[inline]
+    pub const fn can_recover_from_rewrite_failure(&self) -> bool {
+        self.recoverable
+    }
+}
+impl<L> Deref for PatternRewriter<L> {
+    type Target = RewriterImpl<L>;
+
+    #[inline(always)]
+    fn deref(&self) -> &Self::Target {
+        &self.rewriter
+    }
+}
+impl<L> DerefMut for PatternRewriter<L> {
+    #[inline(always)]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.rewriter
+    }
+}
+
+pub struct RewriterImpl<L = NoopRewriterListener> {
+    context: Rc<Context>,
+    listener: Option<L>,
+    ip: Option<InsertionPoint>,
+}
+
+impl<L> RewriterImpl<L> {
+    pub fn new(context: Rc<Context>) -> Self {
+        Self {
+            context,
+            listener: None,
+            ip: None,
+        }
+    }
+
+    pub fn with_listener<L2>(self, listener: L2) -> RewriterImpl<L2>
+    where
+        L2: Listener,
+    {
+        RewriterImpl {
+            context: self.context,
+            listener: Some(listener),
+            ip: self.ip,
+        }
+    }
+}
+
+impl<L: Listener> From<OpBuilder<L>> for RewriterImpl<L> {
+    #[inline]
+    fn from(builder: OpBuilder<L>) -> Self {
+        let (context, listener, ip) = builder.into_parts();
+        Self {
+            context,
+            listener,
+            ip,
+        }
+    }
+}
+
+impl<L: RewriterListener> Builder for RewriterImpl<L> {
+    #[inline(always)]
+    fn context(&self) -> &Context {
+        &self.context
+    }
+
+    #[inline(always)]
+    fn context_rc(&self) -> Rc<Context> {
+        self.context.clone()
+    }
+
+    #[inline(always)]
+    fn insertion_point(&self) -> Option<&InsertionPoint> {
+        self.ip.as_ref()
+    }
+
+    #[inline(always)]
+    fn clear_insertion_point(&mut self) -> Option<InsertionPoint> {
+        self.ip.take()
+    }
+
+    #[inline(always)]
+    fn restore_insertion_point(&mut self, ip: Option<InsertionPoint>) {
+        self.ip = ip;
+    }
+
+    #[inline(always)]
+    fn set_insertion_point(&mut self, ip: InsertionPoint) {
+        self.ip = Some(ip);
+    }
+}
+
+impl<L: RewriterListener> Rewriter for RewriterImpl<L> {
+    #[inline(always)]
+    fn has_listener(&self) -> bool {
+        self.listener.is_some()
+    }
+}
+
+impl<L: RewriterListener> Listener for RewriterImpl<L> {
+    fn kind(&self) -> ListenerType {
+        ListenerType::Rewriter
+    }
+
+    fn notify_operation_inserted(&self, op: OperationRef, prev: Option<InsertionPoint>) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_inserted(op, prev);
+        }
+    }
+
+    fn notify_block_inserted(
+        &self,
+        block: BlockRef,
+        prev: Option<RegionRef>,
+        ip: Option<BlockRef>,
+    ) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_block_inserted(block, prev, ip);
+        }
+    }
+}
+
+impl<L: RewriterListener> RewriterListener for RewriterImpl<L> {
+    fn notify_block_erased(&self, block: BlockRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_block_erased(block);
+        }
+    }
+
+    fn notify_operation_modification_started(&self, op: &OperationRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_modification_started(op);
+        }
+    }
+
+    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_modification_canceled(op);
+        }
+    }
+
+    fn notify_operation_modified(&self, op: OperationRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_modified(op);
+        }
+    }
+
+    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
+        if self.listener.is_some() {
+            let replacement = replacement.borrow();
+            let values = replacement
+                .results()
+                .all()
+                .iter()
+                .cloned()
+                .map(|result| result.upcast())
+                .collect::<SmallVec<[ValueRef; 2]>>();
+            self.notify_operation_replaced_with_values(op, &values);
+        }
+    }
+
+    fn notify_operation_replaced_with_values(&self, op: OperationRef, replacement: &[ValueRef]) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_replaced_with_values(op, replacement);
+        }
+    }
+
+    fn notify_operation_erased(&self, op: OperationRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_operation_erased(op);
+        }
+    }
+
+    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_pattern_begin(pattern, op);
+        }
+    }
+
+    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_pattern_end(pattern, success);
+        }
+    }
+
+    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
+        if let Some(listener) = self.listener.as_ref() {
+            listener.notify_match_failure(span, reason);
+        }
+    }
+}
diff --git a/midenc-session/src/duration.rs b/midenc-session/src/duration.rs
index 34aae7d9a..99d053ddf 100644
--- a/midenc-session/src/duration.rs
+++ b/midenc-session/src/duration.rs
@@ -3,6 +3,7 @@ use std::{
     time::{Duration, Instant},
 };
 
+#[derive(Copy, Clone)]
 pub struct HumanDuration(Duration);
 impl HumanDuration {
     pub fn since(i: Instant) -> Self {
@@ -14,6 +15,41 @@ impl HumanDuration {
     pub fn as_secs_f64(&self) -> f64 {
         self.0.as_secs_f64()
     }
+
+    /// Adds two [HumanDuration]s, using saturating arithmetic
+    #[inline]
+    pub fn saturating_add(self, rhs: Self) -> Self {
+        Self(self.0.saturating_add(rhs.0))
+    }
+}
+impl core::ops::Add for HumanDuration {
+    type Output = HumanDuration;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        Self(self.0 + rhs.0)
+    }
+}
+impl core::ops::Add<Duration> for HumanDuration {
+    type Output = HumanDuration;
+
+    fn add(self, rhs: Duration) -> Self::Output {
+        Self(self.0 + rhs)
+    }
+}
+impl core::ops::AddAssign for HumanDuration {
+    fn add_assign(&mut self, rhs: Self) {
+        self.0 += rhs.0;
+    }
+}
+impl core::ops::AddAssign<Duration> for HumanDuration {
+    fn add_assign(&mut self, rhs: Duration) {
+        self.0 += rhs;
+    }
+}
+impl From<HumanDuration> for Duration {
+    fn from(d: HumanDuration) -> Self {
+        d.0
+    }
 }
 impl From<Duration> for HumanDuration {
     fn from(d: Duration) {