diff --git a/Project.toml b/Project.toml index 9d67914cc..c51fa023f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.47" +version = "0.4.48" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/make.jl b/docs/make.jl index b457a7fa4..c24fbea32 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,9 @@ makedocs( ), ) ), + size_threshold_ignore=[ + joinpath("developer_documentation", "internal_docstrings.md"), + ], ), modules=[Mooncake], checkdocs=:none, @@ -28,20 +31,21 @@ makedocs( pages = [ "Mooncake.jl" => "index.md", "Understanding Mooncake.jl" => [ - "Introduction" => "understanding_intro.md", - "Algorithmic Differentiation" => "algorithmic_differentiation.md", - "Mooncake.jl's Rule System" => "mathematical_interpretation.md", + joinpath("understanding_mooncake", "introduction.md"), + joinpath("understanding_mooncake", "algorithmic_differentiation.md"), + joinpath("understanding_mooncake", "rule_system.md"), ], "Utilities" => [ - "Tools for Rules" => "tools_for_rules.md", - "Debug Mode" => "debug_mode.md", - "Debugging and MWEs" => "debugging_and_mwes.md", + joinpath("utilities", "tools_for_rules.md"), + joinpath("utilities", "debug_mode.md"), + joinpath("utilities", "debugging_and_mwes.md"), ], "Developer Documentation" => [ - "Running Tests Locally" => "running_tests_locally.md", - "Developer Tools" => "developer_tools.md", + joinpath("developer_documentation", "running_tests_locally.md"), + joinpath("developer_documentation", "developer_tools.md"), + joinpath("developer_documentation", "internal_docstrings.md"), ], - "Known Limitations" => "known_limitations.md", + "known_limitations.md", ] ) diff --git a/docs/src/developer_tools.md b/docs/src/developer_documentation/developer_tools.md similarity index 100% rename from docs/src/developer_tools.md rename to docs/src/developer_documentation/developer_tools.md diff --git a/docs/src/developer_documentation/internal_docstrings.md b/docs/src/developer_documentation/internal_docstrings.md new file mode 100644 index 000000000..ce3fe608f --- /dev/null +++ b/docs/src/developer_documentation/internal_docstrings.md @@ -0,0 +1,11 @@ +# Internal Docstrings + +Docstrings listed here are _not_ part of the public Mooncake.jl interface. +Consequently, they can change between non-breaking changes to Mooncake.jl without warning. + +The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL. + +```@autodocs; canonical=false +Modules = [Mooncake] +Public = false +``` diff --git a/docs/src/running_tests_locally.md b/docs/src/developer_documentation/running_tests_locally.md similarity index 100% rename from docs/src/running_tests_locally.md rename to docs/src/developer_documentation/running_tests_locally.md diff --git a/docs/src/algorithmic_differentiation.md b/docs/src/understanding_mooncake/algorithmic_differentiation.md similarity index 100% rename from docs/src/algorithmic_differentiation.md rename to docs/src/understanding_mooncake/algorithmic_differentiation.md diff --git a/docs/src/understanding_intro.md b/docs/src/understanding_mooncake/introduction.md similarity index 98% rename from docs/src/understanding_intro.md rename to docs/src/understanding_mooncake/introduction.md index 5525cba4a..36f1f4bc8 100644 --- a/docs/src/understanding_intro.md +++ b/docs/src/understanding_mooncake/introduction.md @@ -1,4 +1,4 @@ -# Mooncake.jl and Reverse-Mode AD +# Introduction The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain _what_ precisely is meant by this, and _how_ it can be interpreted mathematically. diff --git a/docs/src/mathematical_interpretation.md b/docs/src/understanding_mooncake/rule_system.md similarity index 100% rename from docs/src/mathematical_interpretation.md rename to docs/src/understanding_mooncake/rule_system.md diff --git a/docs/src/debug_mode.md b/docs/src/utilities/debug_mode.md similarity index 100% rename from docs/src/debug_mode.md rename to docs/src/utilities/debug_mode.md diff --git a/docs/src/debugging_and_mwes.md b/docs/src/utilities/debugging_and_mwes.md similarity index 100% rename from docs/src/debugging_and_mwes.md rename to docs/src/utilities/debugging_and_mwes.md diff --git a/docs/src/tools_for_rules.md b/docs/src/utilities/tools_for_rules.md similarity index 100% rename from docs/src/tools_for_rules.md rename to docs/src/utilities/tools_for_rules.md diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 187f4453b..5ee8f9e50 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -68,8 +68,8 @@ include("stack.jl") include(joinpath("interpreter", "contexts.jl")) include(joinpath("interpreter", "abstract_interpretation.jl")) -include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) +include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 428f979f4..ade6a4725 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -901,9 +901,9 @@ Useful for implementation getfield-like rules for mutable structs, pointers, dic increment_rdata!!(t::T, r) where {T} = tangent(fdata(t), increment!!(rdata(t), r))::T """ - zero_tangent(p, ::NoFData) - + zero_tangent(primal, fdata) +Equivalent to `tangent(fdata, rdata(zero_tangent(primal)))`. """ zero_tangent(p, ::NoFData) = zero_tangent(p) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index eeb5d26b3..d8b15322e 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -61,7 +61,11 @@ MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx) context_type(::MooncakeInterpreter{C}) where {C} = C -# Globally cached interpreter. Should only be accessed via `get_interpreter`. +""" + const GLOBAL_INTERPRETER + +Globally cached interpreter. Should only be accessed via `get_interpreter`. +""" const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter()) """ diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 5cdbe8073..4fcd53cd3 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -45,6 +45,8 @@ struct IDPhiNode values::Vector{Any} end +Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.values + Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) """ @@ -107,15 +109,6 @@ A Union of the possible types of a terminator node. """ const Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} -""" - const InstVector = Vector{NewInstruction} - -Note: the `CC.NewInstruction` type is used to represent instructions because it has the -correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it -is used to represent all instructions in `BBCode`. -""" -const InstVector = Vector{NewInstruction} - """ BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) @@ -283,9 +276,13 @@ function compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} return _compute_all_successors(ir.blocks) end -# Internal method. Just requires that a Vector of BBlocks are passed. This method is easier -# to construct test cases for because there is no need to construct all the other stuff that -# goes into a `BBCode`. +""" + _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + +Internal method implementing [`compute_all_successors`](@ref). This method is easier to +construct test cases for because it only requires the collection of `BBlocks`, not all of +the other stuff that goes into a `BBCode`. +""" function _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} succs = map(enumerate(blks)) do (n, blk) return successors(terminator(blk), n, blks, n == length(blks)) @@ -312,6 +309,13 @@ function compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} return _compute_all_predecessors(ir.blocks) end +""" + _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + +Internal method implementing [`compute_all_predecessors`](@ref). This method is easier to +construct test cases for because it only requires the collection of `BBlocks`, not all of +the other stuff that goes into a `BBCode`. +""" function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} successor_map = _compute_all_successors(blks) @@ -363,8 +367,12 @@ Computes the `Core.Compiler.CFG` object associated to this `bb_code`. """ control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG = _control_flow_graph(bb_code.blocks) -# Internal function, used to implement control_flow_graph, but easier to write test cases -# for because there is no need to construct an ensure BBCode object. +""" + _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG + +Internal function, used to implement [`control_flow_graph`](@ref). Easier to write test +cases for because there is no need to construct an ensure BBCode object, just the `BBlock`s. +""" function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG # Get IDs of predecessors and successors. @@ -418,7 +426,12 @@ function BBCode(ir::IRCode) return BBCode(ir, blocks) end -# Convert an InstructionStream into a list of `NewInstruction`s. + +""" + new_inst_vec(x::CC.InstructionStream) + +Convert an `Compiler.InstructionStream` into a list of `Compiler.NewInstruction`s. +""" function new_inst_vec(x::CC.InstructionStream) return map((v..., ) -> NewInstruction(v...), stmt(x), x.type, x.info, x.line, x.flag) end @@ -427,19 +440,27 @@ end const SSAToIdDict = Dict{SSAValue, ID} const BlockNumToIdDict = Dict{Integer, ID} -# Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValue` in each -# line with the corresponding `ID`. For example, a call statement of the form -# `Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. +""" + _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} + +Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValue` in each +line with the corresponding `ID`. For example, a call statement of the form +`Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. +""" function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} ids = map(_ -> ID(), insts) val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) end -# Produce a new instance of `x` in which all instances of `SSAValue`s are replaced with -# the `ID`s prescribed by `d`, all basic block numbers are replaced with the `ID`s -# prescribed by `d`, and `GotoIfNot`, `GotoNode`, and `PhiNode` instances are replaced with -# the corresponding `ID` versions. +""" + _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + +Produce a new instance of `inst` in which all instances of `SSAValue`s are replaced with +the `ID`s prescribed by `d`, all basic block numbers are replaced with the `ID`s +prescribed by `d`, and `GotoIfNot`, `GotoNode`, and `PhiNode` instances are replaced with +the corresponding `ID` versions. +""" function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) end @@ -462,7 +483,12 @@ end _ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x _ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) -# Replace all integers corresponding to references to blocks with IDs. +""" + _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} + +Assign to each basic block in `cfg` an `ID`. Replace all integers referencing block numbers +in `insts` with the corresponding `ID`. Return the `ID`s and the updated instructions. +""" function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} ids = map(_ -> ID(), cfg.blocks) block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) @@ -498,7 +524,7 @@ collection of `GotoIfNot` nodes. function CC.IRCode(bb_code::BBCode) bb_code = _lower_switch_statements(bb_code) bb_code = _remove_double_edges(bb_code) - insts = _ids_to_line_positions(bb_code) + insts = _ids_to_line_numbers(bb_code) cfg = control_flow_graph(bb_code) insts = _lines_to_blocks(insts, cfg) return IRCode( @@ -517,8 +543,12 @@ function CC.IRCode(bb_code::BBCode) ) end -# Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the -# `Switch` docstring for an explanation of what is going on here. +""" + _lower_switch_statements(bb_code::BBCode) + +Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the +`Switch` docstring for an explanation of what is going on here. +""" function _lower_switch_statements(bb_code::BBCode) new_blocks = Vector{BBlock}(undef, 0) for block in bb_code.blocks @@ -545,9 +575,13 @@ function _lower_switch_statements(bb_code::BBCode) return BBCode(bb_code, new_blocks) end -# Returns a `Vector{Any}` of statements in which each `ID` has been replaced by either an -# `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. -function _ids_to_line_positions(bb_code::BBCode)::InstVector +""" + _ids_to_line_numbers(bb_code::BBCode)::InstVector + +For each statement in `bb_code`, returns a `NewInstruction` in which every `ID` is replaced +by either an `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. +""" +function _ids_to_line_numbers(bb_code::BBCode)::InstVector # Construct map from `ID`s to `SSAValue`s. block_ids = [b.id for b in bb_code.blocks] @@ -561,7 +595,12 @@ function _ids_to_line_positions(bb_code::BBCode)::InstVector return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] end -# Like `_to_ids`, but converts IDs to SSAValues / (integers corresponding to ssas). +""" + _to_ssas(d::Dict, inst::NewInstruction) + +Like `_ssas_to_ids`, but in reverse. Converts IDs to SSAValues / (integers corresponding +to ssas). +""" _to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) _to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x _to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) @@ -580,31 +619,14 @@ end _to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) _to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) -# Replaces references to blocks by line-number with references to block numbers. -function _lines_to_blocks(insts::InstVector, cfg::CC.CFG) - return map(inst -> __lines_to_blocks(cfg, inst), insts) -end - -function __lines_to_blocks(cfg::CC.CFG, inst::NewInstruction) - return NewInstruction(inst; stmt=__lines_to_blocks(cfg, inst.stmt)) -end -function __lines_to_blocks(cfg::CC.CFG, stmt::GotoNode) - return GotoNode(CC.block_for_inst(cfg, stmt.label)) -end -function __lines_to_blocks(cfg::CC.CFG, stmt::GotoIfNot) - return GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) -end -function __lines_to_blocks(cfg::CC.CFG, stmt::PhiNode) - return PhiNode(Int32[CC.block_for_inst(cfg, Int(e)) for e in stmt.edges], stmt.values) -end -function __lines_to_blocks(cfg::CC.CFG, stmt::Expr) - Meta.isexpr(stmt, :enter) && throw(error("Cannot handle enter yet")) - return stmt -end -__lines_to_blocks(::CC.CFG, stmt) = stmt +""" + _remove_double_edges(ir::BBCode)::BBCode -# If the `dest` field of a `GotoIfNot` node points towards the next block, replace it with -# a `GotoNode`. +If the `dest` field of an `IDGotoIfNot` node in block `n` of `ir` points towards the `n+1`th +block then we have two edges from block `n` to block `n+1`. This transformation replaces all +such `IDGotoIfNot` nodes with unconditional `IDGotoNode`s pointing towards the `n+1`th block +in `ir`. +""" function _remove_double_edges(ir::BBCode) new_blks = map(enumerate(ir.blocks)) do (n, blk) t = terminator(blk) @@ -652,7 +674,7 @@ function _distance_to_entry(blks::Vector{BBlock})::Vector{Int} return dijkstra_shortest_paths(g, id_to_int[blks[1].id]).dists end -#= +""" _sort_blocks!(ir::BBCode)::BBCode Ensure that blocks appear in order of distance-from-entry-point, where distance the @@ -666,14 +688,14 @@ there. WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in `ir`. -=# +""" function _sort_blocks!(ir::BBCode)::BBCode I = sortperm(_distance_to_entry(ir.blocks)) ir.blocks .= ir.blocks[I] return ir end -#= +""" characterise_unique_predecessor_blocks(blks::Vector{BBlock}) -> Tuple{Dict{ID, Bool}, Dict{ID, Bool}} @@ -700,7 +722,7 @@ working with cheap loops -- loops where the operations performed at each iterati are inexpensive -- for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all. -=# +""" function characterise_unique_predecessor_blocks( blks::Vector{BBlock} )::Tuple{Dict{ID, Bool}, Dict{ID, Bool}} @@ -767,7 +789,14 @@ function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool} return is_used end -# Helper function used in characterise_used_ids. +""" + _find_id_uses!(d::Dict{ID, Bool}, x) + +Helper function used in [`characterise_used_ids`](@ref). For all uses of `ID`s in `x`, set +the corresponding value of `d` to `true`. + +For example, if `x = ReturnNode(ID(5))`, then this function sets `d[ID(5)] = true`. +""" function _find_id_uses!(d::Dict{ID, Bool}, x::Expr) for arg in x.args in(arg, keys(d)) && setindex!(d, true, arg) @@ -854,3 +883,27 @@ function _remove_unreachable_blocks!(blks::Vector{BBlock}) return remaining_blks end + +""" + inc_args(stmt) + +Increment by `1` the `n` field of any `Argument`s present in `stmt`. +""" +inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) +inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x +inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) +inc_args(x::IDGotoNode) = x +function inc_args(x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return IDPhiNode(x.edges, new_values) +end +inc_args(::Nothing) = nothing +inc_args(x::GlobalRef) = x + +__inc(x::Argument) = Argument(x.n + 1) +__inc(x) = x diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index fb0f57f02..6c835f023 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -8,7 +8,8 @@ unchanged, but makes AD more straightforward. In particular, replace 3. `:splatnew` Expr`s with `:call`s to `Mooncake._splat_new_`, 4. `Core.IntrinsicFunction`s with counterparts from `Mooncake.IntrinsicWrappers`, 5. `getfield(x, 1)` with `lgetfield(x, Val(1))`, and related transformations, -6. `gc_preserve_begin` / `gc_preserve_end` exprs so that memory release is delayed. +6. `memoryrefget` calls to `lmemoryrefget` calls, and related transformations, +7. `gc_preserve_begin` / `gc_preserve_end` exprs so that memory release is delayed. `spnames` are the names associated to the static parameters of `ir`. These are needed when handling `:foreigncall` expressions, in which it is not necessarily the case that all @@ -218,7 +219,9 @@ if VERSION >= v"1.11-" """ lift_memoryrefget_and_memoryrefset_builtins(inst) -Replaces memoryrefget -> lmemoryrefget and memoryrefset! -> lmemoryrefset!. +Replaces memoryrefget -> lmemoryrefget and memoryrefset! -> lmemoryrefset! if their final +two arguments (`ordering` and `boundscheck`) are constants. See [`lmemoryrefget`] and +[`lmemoryrefset!`](@ref) for more context. """ function lift_memoryrefget_and_memoryrefset_builtins(inst) Meta.isexpr(inst, :call) || return inst diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index cfa1f3c16..9e736daf4 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -1,4 +1,18 @@ -# the field containing the instructions in `IRCode` changed name in 1.11 from inst to stmt. +""" + const InstVector = Vector{NewInstruction} + +Note: the `CC.NewInstruction` type is used to represent instructions because it has the +correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it +is used to represent all instructions in `BBCode`. +""" +const InstVector = Vector{NewInstruction} + +""" + stmt(ir::CC.InstructionStream) + +Get the field containing the instructions in `ir`. This changed name in 1.11 from `inst` to +`stmt`. +""" stmt(ir::CC.InstructionStream) = @static VERSION < v"1.11.0-rc4" ? ir.inst : ir.stmt """ @@ -30,7 +44,9 @@ function ircode( return CC.IRCode(stmts, cfg, linetable, argtypes, meta, CC.VarState[]) end -#= +""" + __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + Converts any edges in `GotoNode`s, `GotoIfNot`s, `PhiNode`s, and `:enter` expressions which refer to line numbers into references to block numbers. The `cfg` provides the information required to perform this conversion. @@ -39,7 +55,7 @@ For context, `CodeInfo` objects have references to line numbers, while `IRCode` block numbers. This code is copied over directly from the body of `Core.Compiler.inflate_ir!`. -=# +""" function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) for i in eachindex(insts) stmt = insts[i] @@ -59,9 +75,21 @@ function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) return insts end -#= +""" + _instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + +Pulls out the instructions from `insts`, and calls `__line_numbers_to_block_numbers!`. +""" +function _lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) + return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) +end + +""" + __insts_to_instruction_stream(insts::Vector{Any}) + Produces an instruction stream whose -- `inst` field is `insts`, +- `stmt` (v1.11 and up) / `inst` (v1.10) field is `insts`, - `type` field is all `Any`, - `info` field is all `Core.Compiler.NoCallInfo`, - `line` field is all `Int32(1)`, and @@ -69,7 +97,7 @@ Produces an instruction stream whose As such, if you wish to ensure that your `IRCode` prints nicely, you should ensure that its linetable field has at least one element. -=# +""" function __insts_to_instruction_stream(insts::Vector{Any}) return CC.InstructionStream( insts, @@ -271,30 +299,6 @@ Throw an `UnhandledLanguageFeatureException` with message `msg`. """ unhandled_feature(msg::String) = throw(UnhandledLanguageFeatureException(msg)) -""" - inc_args(stmt) - -Increment by `1` the `n` field of any `Argument`s present in `stmt`. -""" -inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) -inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x -inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) -inc_args(x::IDGotoNode) = x -function inc_args(x::IDPhiNode) - new_values = Vector{Any}(undef, length(x.values)) - for n in eachindex(x.values) - if isassigned(x.values, n) - new_values[n] = __inc(x.values[n]) - end - end - return IDPhiNode(x.edges, new_values) -end -inc_args(::Nothing) = nothing -inc_args(x::GlobalRef) = x - -__inc(x::Argument) = Argument(x.n + 1) -__inc(x) = x - """ new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index a25c20e12..0de5480ba 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -1,4 +1,4 @@ -#= +""" SharedDataPairs() A data structure used to manage the captured data in the `OpaqueClosures` which implement @@ -9,26 +9,26 @@ of the `pairs` field of this data structure means that `data` will be available This is achieved by storing all of the data in the `pairs` field in the captured tuple which is passed to an `OpaqueClosure`, and extracting this data into registers associated to the corresponding `ID`s. -=# +""" struct SharedDataPairs pairs::Vector{Tuple{ID, Any}} SharedDataPairs() = new(Tuple{ID, Any}[]) end -#= +""" add_data!(p::SharedDataPairs, data)::ID Puts `data` into `p`, and returns the `id` associated to it. This `id` should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this `id` is always `data`. -=# +""" function add_data!(p::SharedDataPairs, data)::ID id = ID() push!(p.pairs, (id, data)) return id end -#= +""" shared_data_tuple(p::SharedDataPairs)::Tuple Create the tuple that will constitute the captured variables in the forwards- and reverse- @@ -42,10 +42,10 @@ then the output of this function is ```julia (5.0, "hello") ``` -=# +""" shared_data_tuple(p::SharedDataPairs)::Tuple = tuple(map(last, p.pairs)...) -#= +""" shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair} Produce a sequence of id-statment pairs which will extract the data from @@ -62,7 +62,7 @@ IDInstPair[ (ID(3), new_inst(:(getfield(_1, 2)))), ] ``` -=# +""" function shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair} return map(enumerate(p.pairs)) do (n, p) return (p[1], new_inst(Expr(:call, get_shared_data_field, Argument(1), n))) @@ -71,16 +71,16 @@ end @inline get_shared_data_field(shared_data, n) = getfield(shared_data, n) -#= +""" The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than `typemax(Int32)` unique basic blocks in a given function, which ought to be reasonable. -=# +""" const BlockStack = Stack{Int32} -#= +""" ADInfo This data structure is used to hold "global" information associated to a particular call to @@ -119,7 +119,7 @@ codegen which produces the forwards- and reverse-passes. To achieve this, we construct a `LazyZeroRData` for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information. -=# +""" struct ADInfo interp::MooncakeInterpreter block_stack_id::ID @@ -176,20 +176,36 @@ function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool) return ADInfo(interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref) end -# Shortcut for `add_data!(info.shared_data_pairs, data)`. +""" + add_data!(info::ADInfo, data)::ID + +Equivalent to `add_data!(info.shared_data_pairs, data)`. +""" add_data!(info::ADInfo, data)::ID = add_data!(info.shared_data_pairs, data) -# Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it on the -# forwards- and reverse-passes. The reason for this is that if something is a singleton, it -# can be placed directly in the IR. +""" + add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x) + +Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it on the +forwards- and reverse-passes. The reason for this is that if something is a singleton, it +can be inserted directly into the IR. +""" function add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x) return Base.issingletontype(_typeof(x)) ? x : add_data!(p, x) end -# Returns `true` if `id` is used by any of the lines in the ir, false otherwise. +""" + is_used(info::ADInfo, id::ID)::Bool + +Returns `true` if `id` is used by any of the lines in the ir, false otherwise. +""" is_used(info::ADInfo, id::ID)::Bool = info.is_used_dict[id] -# Returns the static / inferred type associated to `x`. +""" + get_primal_type(info::ADInfo, x) + +Returns the static / inferred type associated to `x`. +""" get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x] get_primal_type(info::ADInfo, x::ID) = _type(info.ssa_insts[x].type) get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value) @@ -198,23 +214,38 @@ function get_primal_type(::ADInfo, x::GlobalRef) return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty end -# Returns the `ID` associated to the line in the reverse pass which will contain the -# reverse data for `x`. If `x` is not an `Argument` or `ID`, then `nothing` is returned. +""" + get_rev_data_id(info::ADInfo, x) + +Returns the `ID` associated to the line in the reverse pass which will contain the +reverse data for `x`. If `x` is not an `Argument` or `ID`, then `nothing` is returned. +""" get_rev_data_id(info::ADInfo, x::Argument) = info.arg_rdata_ref_ids[x] get_rev_data_id(info::ADInfo, x::ID) = info.ssa_rdata_ref_ids[x] get_rev_data_id(::ADInfo, ::Any) = nothing -# Create the statements which initialise the reverse-data `Ref`s. +""" + reverse_data_ref_stmts(info::ADInfo) + +Create the statements which initialise the reverse-data `Ref`s. +""" function reverse_data_ref_stmts(info::ADInfo) - arg_stmts = [(id, __ref(_type(info.arg_types[k]))) for (k, id) in info.arg_rdata_ref_ids] - ssa_stmts = [(id, __ref(_type(info.ssa_insts[k].type))) for (k, id) in info.ssa_rdata_ref_ids] - return vcat(arg_stmts, ssa_stmts) + return vcat( + map(collect(info.arg_rdata_ref_ids)) do (k, id) + (id, new_inst(Expr(:call, __make_ref, _type(info.arg_types[k])))) + end, + map(collect(info.ssa_rdata_ref_ids)) do (k, id) + (id, new_inst(Expr(:call, __make_ref, _type(info.ssa_insts[k].type)))) + end, + ) end -# Helper for reverse_data_ref_stmts. -__ref(P) = new_inst(Expr(:call, __make_ref, P)) +""" + __make_ref(p::Type{P}) where {P} -# Helper for reverse_data_ref_stmts. +Helper for [`reverse_data_ref_stmts`](@ref). Constructs a `Ref` whose element type is the +[`zero_like_rdata_type`](@ref) for `P`, and whose element is the zero-like rdata for `P`. +""" @inline @generated function __make_ref(p::Type{P}) where {P} _P = @isdefined(P) ? P : _typeof(p) R = zero_like_rdata_type(_P) @@ -232,12 +263,16 @@ end # Returns the number of arguments that the primal function has. num_args(info::ADInfo) = length(info.arg_types) -# This struct is used to ensure that `ZeroRData`s, which are used as placeholder zero -# elements whenever an actual instance of a zero rdata for a particular primal type cannot -# be constructed without also having an instance of said type, never reach rules. -# On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures -# that if it is a `ZeroRData`, we instead get an actual zero of the correct type. If it is -# not a zero rdata, the computation _should_ be elided via inlining + constant prop. +""" + RRuleZeroWrapper(rule) + +This struct is used to ensure that `ZeroRData`s, which are used as placeholder zero +elements whenever an actual instance of a zero rdata for a particular primal type cannot +be constructed without also having an instance of said type, never reach rules. +On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures +that if it is a `ZeroRData`, we instead get an actual zero of the correct type. If it is +not a zero rdata, the computation _should_ be elided via inlining + constant prop. +""" struct RRuleZeroWrapper{Trule} rule::Trule end @@ -257,7 +292,7 @@ end return y::CoDual, (pb!! isa NoPullback ? pb!! : RRuleWrapperPb(pb!!, l)) end -#= +""" ADStmtInfo Data structure which contains the result of `make_ad_stmts!`. Fields are @@ -269,7 +304,7 @@ Data structure which contains the result of `make_ad_stmts!`. Fields are per-block basis. - `fwds`: the instructions which run the forwards-pass of AD - `rvs`: the instructions which run the reverse-pass of AD / the pullback -=# +""" struct ADStmtInfo line::ID comms_id::Union{ID, Nothing} @@ -277,8 +312,12 @@ struct ADStmtInfo rvs::Vector{IDInstPair} end -# Convenient constructor for `ADStmtInfo`. If either `fwds` or `rvs` is not a vector, -# `__vec` promotes it to a single-element `Vector`. +""" + ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs) + +Convenient constructor for `ADStmtInfo`. If either `fwds` or `rvs` is not a vector, +`__vec` promotes it to a single-element `Vector`. +""" function ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs) if !(comms_id === nothing || in(comms_id, map(first, __vec(line, fwds)))) throw(ArgumentError("comms_id not found in IDs of `fwds` instructions.")) @@ -291,14 +330,18 @@ __vec(line::ID, x::NewInstruction) = IDInstPair[(line, x)] __vec(line::ID, x::Vector{Tuple{ID, Any}}) = throw(error("boooo")) __vec(line::ID, x::Vector{IDInstPair}) = x -# Return the element of `fwds` whose `ID` is the communcation `ID`. Returns `Nothing` if -# `comms_id` is `nothing`. +""" + comms_channel(info::ADStmtInfo) + +Return the element of `fwds` whose `ID` is the communcation `ID`. Returns `Nothing` if +`comms_id` is `nothing`. +""" function comms_channel(info::ADStmtInfo) info.comms_id === nothing && return nothing return only(filter(x -> x[1] == info.comms_id, info.fwds)) end -#= +""" make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo Every line in the primal code is associated to one or more lines in the forwards-pass of AD, @@ -312,28 +355,36 @@ form of an `ADStmtInfo`. `info` is a data structure containing various bits of global information that certain types of nodes need access to. -=# +""" function make_ad_stmts! end -# `nothing` as a statement in Julia IR indicates the presence of a line which will later be -# removed. We emit a no-op on both the forwards- and reverse-passes. No shared data. +#= + make_ad_stmts!(::Nothing, line::ID, ::ADInfo) + +`nothing` as a statement in Julia IR indicates the presence of a line which will later be +removed. We emit a no-op on both the forwards- and reverse-passes. No shared data. +=# function make_ad_stmts!(::Nothing, line::ID, ::ADInfo) return ad_stmt_info(line, nothing, nothing, nothing) end -# `ReturnNode`s have a single field, `val`, for which there are three cases to consider: -# -# 1. `val` is undefined: this `ReturnNode` is unreachable. Consequently, we'll never hit the -# associated statements on the forwards-pass or pullback. We just return the original -# statement on the forwards-pass, and `nothing` on the reverse-pass. -# 2. `val isa Union{Argument, ID}`: this is an active piece of data. Consequently, we know -# that it will be an `CoDual` already, and can just return it. Therefore `stmt` -# is returned as the forwards-pass (with any `Argument`s incremented). On the reverse-pass -# the associated rdata ref should be incremented with the rdata passed to the pullback, -# which lives in argument 2. -# 3. `val` is defined, but not a `Union{Argument, ID}`: in this case we're returning a -# constant -- build a constant CoDual and return that. There is nothing to do on the -# reverse pass. +#= + make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) + +`ReturnNode`s have a single field, `val`, for which there are three cases to consider: + +1. `val` is undefined: this `ReturnNode` is unreachable. Consequently, we'll never hit the + associated statements on the forwards-pass or pullback. We just return the original + statement on the forwards-pass, and `nothing` on the reverse-pass. +2. `val isa Union{Argument, ID}`: this is an active piece of data. Consequently, we know + that it will be an `CoDual` already, and can just return it. Therefore `stmt` + is returned as the forwards-pass (with any `Argument`s incremented). On the reverse-pass + the associated rdata ref should be incremented with the rdata passed to the pullback, + which lives in argument 2. +3. `val` is defined, but not a `Union{Argument, ID}`: in this case we're returning a + constant -- build a constant CoDual and return that. There is nothing to do on the + reverse pass. +=# function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) if !is_reachable_return_node(stmt) return ad_stmt_info(line, nothing, inc_args(stmt), nothing) @@ -445,15 +496,23 @@ make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, li # Literal constant. make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) -# `make_ad_stmts!` for constants. +""" + const_ad_stmt(stmt, line::ID, info::ADInfo) + +Implementation of `make_ad_stmts!` used for constants. +""" function const_ad_stmt(stmt, line::ID, info::ADInfo) x = const_codual(stmt, info) return ad_stmt_info(line, nothing, x isa ID ? Expr(:call, identity, x) : x, nothing) end -# Build a `CoDual` from `stmt`, with zero / uninitialised fdata. If the resulting CoDual is -# a bits type, then it is returned. If it is not, then the CoDual is put into shared data, -# and the ID associated to it in the forwards- and reverse-passes returned. +""" + const_codual(stmt, info::ADInfo) + +Build a `CoDual` from `stmt`, with zero / uninitialised fdata. If the resulting CoDual is +a bits type, then it is returned. If it is not, then the CoDual is put into shared data, +and the ID associated to it in the forwards- and reverse-passes returned. +""" function const_codual(stmt, info::ADInfo) v = get_const_primal_value(stmt) x = uninit_fcodual(v) @@ -464,7 +523,11 @@ safe_for_literal(v) = v isa String || v isa Type || isbitstype(_typeof(v)) inc_or_const(stmt, info::ADInfo) = is_active(stmt) ? __inc(stmt) : const_codual(stmt, info) -# Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant. +""" + get_const_primal_value(x::GlobalRef) + +Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant. +""" function get_const_primal_value(x::GlobalRef) isconst(x) || unhandled_feature("Non-constant GlobalRef not supported: $x") return getglobal(x.mod, x.name) @@ -663,7 +726,11 @@ end is_active(::Union{Argument, ID}) = true is_active(::Any) = false -# Get a bound on the pullback type, given a rule and associated primal types. +""" + pullback_type(Trule, arg_types) + +Get a bound on the pullback type, given a rule and associated primal types. +""" function pullback_type(Trule, arg_types) T = Core.Compiler.return_type(Tuple{Trule, map(fcodual_type, arg_types)...}) return T <: Tuple ? _pullback_type(T) : Any @@ -681,8 +748,16 @@ end __get_primal(x::CoDual) = primal(x) __get_primal(x) = x -# Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. -@inline function __run_rvs_pass!(P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) where {sig} +""" + __run_rvs_pass!( + P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs... + ) where {sig} + +Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. +""" +@inline function __run_rvs_pass!( + P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs... +) where {sig} tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[])) set_ret_ref_to_zero!!(P, ret_rev_data_ref) return nothing @@ -764,15 +839,23 @@ _copy(x) = copy(x) return fwds.fwds_oc.oc(uf_args...)::CoDual, fwds.pb end -# If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). +""" + __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} + +If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). +""" function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} isva || return args last_el = isa(args[end], NoRData) ? ntuple(n -> NoRData(), nvargs) : args[end] return (args[1:end-1]..., last_el...) end -# If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` -# are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. +""" + __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} + +If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` +are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. +""" function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} isva || return args group_primal = map(primal, args[nargs:end]) @@ -793,9 +876,13 @@ _is_primitive(C::Type, sig::Type) = is_primitive(C, sig) const RuleMC{A, R} = MistyClosure{OpaqueClosure{A, R}} -# Compute the concrete type of the rule that will be returned from `build_rrule`. This is -# important for performance in dynamic dispatch, and to ensure that recursion works -# properly. +""" + rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} + +Compute the concrete type of the rule that will be returned from `build_rrule`. This is +important for performance in dynamic dispatch, and to ensure that recursion works +properly. +""" function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} if _is_primitive(C, sig_or_mi) @@ -953,7 +1040,12 @@ function build_rrule( end end -# Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir. +""" + generate_ir( + interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true + ) +Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir. +""" function generate_ir( interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true ) @@ -996,25 +1088,37 @@ function generate_ir( return DerivedRuleInfo(ir, opt_fwd_ir, opt_rvs_ir, shared_data, info, isva) end -# Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new -# captured variables. This is needed for efficiency reasons -- if `build_rrule` is called -# repeatedly with the same signature and intepreter, it is important to avoid recompiling -# the `OpaqueClosure`s that it produces multiple times, because it can be quite expensive to -# do so. -@eval function replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} +""" + replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} + +Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new +captured variables. This is needed for efficiency reasons -- if `build_rrule` is called +repeatedly with the same signature and intepreter, it is important to avoid recompiling +the `OpaqueClosure`s that it produces multiple times, because it can be quite expensive to +do so. +""" +function replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} + return __replace_captures_internal(oc, new_captures) +end + +@eval function __replace_captures_internal(oc::Toc, new_captures) where {Toc<:OpaqueClosure} return $(Expr( :new, :(Toc), :new_captures, :(oc.world), :(oc.source), :(oc.invoke), :(oc.specptr) )) end -# Wrapper for `MistyClosure`s. +""" + replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure} + +Same as `replace_captures` for `Core.OpaqueClosure`s, but returns a new `MistyClosure`. +""" function replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure} return Tmc(replace_captures(mc.oc, new_captures), mc.ir) end const ADStmts = Vector{Tuple{ID, Vector{ADStmtInfo}}} -#= +""" create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) This function produces code which can be inserted into the forwards-pass and reverse-pass at @@ -1036,7 +1140,7 @@ For each basic block represented in `ADStmts`: Returns two a `Tuple{Vector{IDInstPair}, Vector{IDInstPair}`. The nth element of each `Vector` corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in `ad_stmts_blocks`. -=# +""" function create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) insts = map(ad_stmts_blocks) do (_, ad_stmts) @@ -1075,11 +1179,11 @@ function create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) return map(first, insts), map(last, insts) end -#= +""" forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) Produce the IR associated to the `OpaqueClosure` which runs most of the forwards-pass. -=# +""" function forwards_pass_ir( ir::BBCode, ad_stmts_blocks::ADStmts, fwds_comms_insts, info::ADInfo, Tshared_data ) @@ -1145,11 +1249,18 @@ function forwards_pass_ir( return remove_unreachable_blocks!(ir) end -# Going via this function, rather than just calling push!, makes it very straightforward to -# figure out much time is spent pushing to the block stack when profiling AD. +""" + __push_blk_stack!(block_stack::BlockStack, id::Int32) + +Equivalent to `push!(block_stack, id)`. Going via this function, rather than just calling +push! directly, is helpful for debugging and performance analysis -- it makes it very +straightforward to figure out much time is spent pushing to the block stack when profiling. +""" @inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) -@inline function __assemble_lazy_zero_rdata(r::Ref{T}, args::Vararg{CoDual, N}) where {T<:Tuple, N} +@inline function __assemble_lazy_zero_rdata( + r::Ref{T}, args::Vararg{CoDual, N} +) where {T<:Tuple, N} r[] = __make_tuples(T, args) return nothing end @@ -1161,11 +1272,11 @@ end return Expr(:call, tuple, lazy_exprs...) end -#= +""" pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) Produce the IR associated to the `OpaqueClosure` which runs most of the pullback. -=# +""" function pullback_ir( ir::BBCode, Tret, ad_stmts_blocks::ADStmts, pb_comms_insts, info::ADInfo, Tshared_data ) @@ -1307,14 +1418,14 @@ function pullback_ir( return remove_unreachable_blocks!(_sort_blocks!(pb_ir)) end -#= +""" conclude_rvs_block( blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to. -=# +""" function conclude_rvs_block( blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) @@ -1345,21 +1456,29 @@ function conclude_rvs_block( return vcat(deref_stmts, switch), new_blocks end -# Helper functionality for conclude_rvs_block. +""" + __get_value(edge::ID, x::IDPhiNode) + +Helper functionality for conclude_rvs_block. +""" function __get_value(edge::ID, x::IDPhiNode) edge in x.edges || return nothing n = only(findall(==(edge), x.edges)) return isassigned(x.values, n) ? x.values[n] : nothing end -# Helper, used in conclude_rvs_block. +""" + __deref_and_zero(::Type{P}, x::Ref) where {P} + +Helper, used in conclude_rvs_block. +""" @inline function __deref_and_zero(::Type{P}, x::Ref) where {P} t = x[] x[] = Mooncake.zero_like_rdata_from_type(P) return t end -#= +""" rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) Produces a `BBlock` which runs the reverse-pass for the edge associated to `pred_id` in a @@ -1386,7 +1505,7 @@ on. The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would be two `increment_ref!` calls because both `%5` and `_2` are not constants. -=# +""" function rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) @assert length(rdata_ids) == length(values) inc_stmts = map(rdata_ids, values) do id, val @@ -1397,7 +1516,7 @@ function rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, return BBlock(ID(), vcat(inc_stmts, goto_stmt)) end -#= +""" make_switch_stmts( pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) @@ -1419,7 +1538,7 @@ switch( In words: `make_switch_stmts` emits code which jumps to whichever block preceded the current block during the forwards-pass. -=# +""" function make_switch_stmts( pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) @@ -1454,15 +1573,23 @@ end function make_switch_stmts(pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo) return make_switch_stmts(pred_ids, pred_ids, pred_is_unique_pred, info) end +""" + __pop_blk_stack!(block_stack::BlockStack) -# Going via this function, rather than just calling pop! directly, makes it easy to figure -# out how much time is spent popping the block stack when profiling performance. +Equivalent to `pop!(block_stack)`. Going via this function, rather than just calling `pop!` +directly, makes it easy to figure out how much time is spent popping the block stack when +profiling performance, and to know that this function was hit when debugging. +""" @inline __pop_blk_stack!(block_stack::BlockStack) = pop!(block_stack) -# Helper function emitted by `make_switch_stmts`. +""" + __switch_case(id::Int32, predecessor_id::Int32) + +Helper function emitted by `make_switch_stmts`. +""" __switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id) -#= +""" DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool) For internal use only. @@ -1471,7 +1598,7 @@ A callable data structure which, when invoked, calls an rrule specific to the dy of its arguments. Stores rules in an internal cache to avoid re-deriving. This is used to implement dynamic dispatch. -=# +""" struct DynamicDerivedRule{V} cache::V debug_mode::Bool @@ -1491,7 +1618,7 @@ function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} return rule(args...) end -#= +""" LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool) For internal use only. @@ -1508,7 +1635,7 @@ Note: the signature of the primal for which this is a rule is stored in the type reason to keep this around is for debugging -- it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit. -=# +""" mutable struct LazyDerivedRule{primal_sig, Trule} debug_mode::Bool mi::Core.MethodInstance diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 912c72d80..6303013c5 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -160,20 +160,23 @@ end @inactive_intrinsic bswap_int @inactive_intrinsic ceil_llvm -#= +""" + __cglobal(::Val{s}, x::Vararg{Any, N}) where {s, N} + Replacement for `Core.Intrinsics.cglobal`. `cglobal` is different from the other intrinsics -in that the name `cglobal` is reversed by the language (try creating a variable called +in that the name `cglobal` is reserved by the language (try creating a variable called `cglobal` -- Julia will not let you). Additionally, it requires that its first argument, the specification of the name of the C cglobal variable that this intrinsic returns a pointer to, is known statically. In this regard it is like foreigncalls. As a consequence, it requires special handling. The name is converted into a `Val` so that it is available statically, and the function into which `cglobal` calls are converted is -named `Mooncake.IntrinsicsWrappers.__cglobal`, rather than `Mooncake.IntrinsicsWrappers.cglobal`. +named `Mooncake.IntrinsicsWrappers.__cglobal`, rather than +`Mooncake.IntrinsicsWrappers.cglobal`. If you examine the code associated with `Mooncake.intrinsic_to_function`, you will see that special handling of `cglobal` is used. -=# +""" __cglobal(::Val{s}, x::Vararg{Any, N}) where {s, N} = cglobal(s, x...) translate(::Val{Intrinsics.cglobal}) = __cglobal diff --git a/src/utils.jl b/src/utils.jl index 3628e614e..d52f6f9f3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -76,7 +76,7 @@ end end -#= +""" _map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P} For all `n`, if `x[n]` is assigned, then writes the value returned by `f(x[n])` to `y[n]`, @@ -85,7 +85,7 @@ otherwise leaves `y[n]` unchanged. Equivalent to `map!(f, y, x)` if `P` is a bits type as element will always be assigned. Requires that `y` and `x` have the same size. -=# +""" function _map_if_assigned!(f::F, y::DenseArray, x::DenseArray{P}) where {F, P} @assert size(y) == size(x) @inbounds for n in eachindex(y) @@ -96,14 +96,14 @@ function _map_if_assigned!(f::F, y::DenseArray, x::DenseArray{P}) where {F, P} return y end -#= +""" _map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray) Similar to the other method of `_map_if_assigned!` -- for all `n`, if `x1[n]` is assigned, writes `f(x1[n], x2[n])` to `y[n]`, otherwise leaves `y[n]` unchanged. Requires that `y`, `x1`, and `x2` have the same size. -=# +""" function _map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray) where {F, P} @assert size(y) == size(x1) @assert size(y) == size(x2) @@ -115,30 +115,30 @@ function _map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArra return y end -#= +""" _map(f, x...) Same as `map` but requires all elements of `x` to have equal length. The usual function `map` doesn't enforce this for `Array`s. -=# +""" @inline function _map(f::F, x::Vararg{Any, N}) where {F, N} @assert allequal(map(length, x)) return map(f, x...) end -#= +""" is_vararg_and_sparam_names(m::Method) Returns a 2-tuple. The first element is true if `m` is a vararg method, and false if not. The second element contains the names of the static parameters associated to `m`. -=# +""" is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m) -#= +""" is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it. -=# +""" function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} world = Base.get_world_counter() min = Base.RefValue{UInt}(typemin(UInt)) @@ -147,16 +147,20 @@ function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} return is_vararg_and_sparam_names(only(ms).method) end -#= +""" is_vararg_and_sparam_names(mi::Core.MethodInstance) Calls `is_vararg_and_sparam_names` on `mi.def::Method`. -=# +""" function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool, Vector{Symbol}} return is_vararg_and_sparam_names(mi.def) end -# Returns the names of all of the static parameters in `m`. +""" + sparam_names(m::Core.Method)::Vector{Symbol} + +Returns the names of all of the static parameters in `m`. +""" function sparam_names(m::Core.Method)::Vector{Symbol} whereparams = ExprTools.where_parameters(m.sig) whereparams === nothing && return Symbol[] diff --git a/test/integration_testing/gp/gp.jl b/test/integration_testing/gp/gp.jl index 0933161a7..b7b70bf6e 100644 --- a/test/integration_testing/gp/gp.jl +++ b/test/integration_testing/gp/gp.jl @@ -6,6 +6,7 @@ using AbstractGPs, KernelFunctions, LinearAlgebra, Mooncake, StableRNGs, Test using Mooncake.TestUtils: test_rule @testset "gp" begin + rng = StableRNG(123456) ks = Any[ ZeroKernel(), ConstantKernel(; c=1.0), @@ -17,22 +18,25 @@ using Mooncake.TestUtils: test_rule PolynomialKernel(; degree=2, c=0.5), ] xs = Any[ - (randn(10), randn(10)), - (randn(1), randn(1)), - (ColVecs(randn(2, 11)), ColVecs(randn(2, 11))), - (RowVecs(randn(3, 4)), RowVecs(randn(3, 4))), + (randn(rng, 10), randn(rng, 10)), + (randn(rng, 1), randn(rng, 1)), + (ColVecs(randn(rng, 2, 11)), ColVecs(randn(rng, 2, 11))), + (RowVecs(randn(rng, 3, 4)), RowVecs(randn(rng, 3, 4))), ] d_2_xs = Any[ - (ColVecs(randn(2, 11)), ColVecs(randn(2, 11))), - (RowVecs(randn(9, 2)), RowVecs(randn(9, 2))), + (ColVecs(randn(rng, 2, 11)), ColVecs(randn(rng, 2, 11))), + (RowVecs(randn(rng, 9, 2)), RowVecs(randn(rng, 9, 2))), ] @testset "$k, $(typeof(x1))" for (k, x1, x2) in vcat( Any[(k, x1, x2) for k in ks for (x1, x2) in xs], Any[(with_lengthscale(k, 1.1), x1, x2) for k in ks for (x1, x2) in xs], - Any[(with_lengthscale(k, rand(2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], - Any[(k ∘ LinearTransform(randn(2, 2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], + Any[(with_lengthscale(k, rand(rng, 2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], Any[ - (k ∘ LinearTransform(Diagonal(randn(2))), x1, x2) for + (k ∘ LinearTransform(randn(rng, 2, 2)), x1, x2) for + k in ks for (x1, x2) in d_2_xs + ], + Any[ + (k ∘ LinearTransform(Diagonal(randn(rng, 2))), x1, x2) for k in ks for (x1, x2) in d_2_xs ], ) @@ -43,10 +47,10 @@ using Mooncake.TestUtils: test_rule (kernelmatrix, k, x1), (kernelmatrix_diag, k, x1), (fx -> rand(StableRNG(123456), fx), fx), - (logpdf, fx, rand(fx)), + (logpdf, fx, rand(rng, fx)), ] @info typeof(args) - test_rule(StableRNG(123456), args...; is_primitive=false, unsafe_perturb=true) + test_rule(rng, args...; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl index e77a57fed..e844cc57b 100644 --- a/test/interpreter/bbcode.jl +++ b/test/interpreter/bbcode.jl @@ -267,4 +267,18 @@ end ] @test Mooncake.stmt(new_ir.stmts) == expected_stmts end + @testset "inc_args" begin + @test Mooncake.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) + @test Mooncake.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) + id = ID() + @test Mooncake.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) + @test Mooncake.inc_args(IDGotoNode(id)) == IDGotoNode(id) + ids = [id, ID()] + @test ==( + Mooncake.inc_args(IDPhiNode(ids, Any[Argument(1), 4])), + IDPhiNode(ids, Any[Argument(2), 4]), + ) + @test Mooncake.inc_args(nothing) === nothing + @test Mooncake.inc_args(GlobalRef(Base, :sin)) == GlobalRef(Base, :sin) + end end diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index f5e0f1b36..ea420ce8d 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -54,13 +54,6 @@ end Mooncake.UnhandledLanguageFeatureException, Mooncake.unhandled_feature("foo") ) end - @testset "inc_args" begin - @test Mooncake.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) - @test Mooncake.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) - id = ID() - @test Mooncake.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) - @test Mooncake.inc_args(IDGotoNode(id)) == IDGotoNode(id) - end @testset "replace_uses_with!" begin stmt = Expr(:call, sin, SSAValue(1)) Mooncake.replace_uses_with!(stmt, SSAValue(1), 5.0) diff --git a/test/runtests.jl b/test/runtests.jl index 79a402b17..56cf74d9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,8 @@ include("front_matter.jl") @testset "interpreter" begin include(joinpath("interpreter", "contexts.jl")) include(joinpath("interpreter", "abstract_interpretation.jl")) - include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) + include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))