From 61699f6bdb2699709aa6118590bc66b3d4b42fb5 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 31 Jul 2023 17:10:47 -0400 Subject: [PATCH] AbstractInterpreter: add a hook to customize bestguess calculation Currently, the code that updates `bestguess` using `ReturnNode` information includes hardcodes that relate to `Conditional` and `LimitedAccuracy`. These behaviors are actually lattice-dependent and therefore should be overloadable by `AbstractInterpreter`. Additionally, particularly in Diffractor, a clever strategy is required to update return types in a way that it takes into account information from both the original method and its rule method (xref: JuliaDiff/Diffractor.jl#202). This also requires such an overload to exist. In response to these needs, this commit introduces an implementation of a hook named `update_bestguess!`. --- base/compiler/abstractinterpretation.jl | 67 ++++++++++++++----------- base/compiler/typeinfer.jl | 39 +++++++------- 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 0366a6353473c..6e552423cc25e 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2887,17 +2887,49 @@ function init_vartable!(vartable::VarTable, frame::InferenceState) return vartable end +function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState, + currstate::VarTable, @nospecialize(rt)) + bestguess = frame.bestguess + nargs = narguments(frame, #=include_va=#false) + slottypes = frame.slottypes + rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate)) + # narrow representation of bestguess slightly to prepare for tmerge with rt + if rt isa InterConditional && bestguess isa Const + slot_id = rt.slot + old_id_type = slottypes[slot_id] + if bestguess.val === true && rt.elsetype !== Bottom + bestguess = InterConditional(slot_id, old_id_type, Bottom) + elseif bestguess.val === false && rt.thentype !== Bottom + bestguess = InterConditional(slot_id, Bottom, old_id_type) + end + end + # copy limitations to return value + if !isempty(frame.pclimitations) + union!(frame.limitations, frame.pclimitations) + empty!(frame.pclimitations) + end + if !isempty(frame.limitations) + rt = LimitedAccuracy(rt, copy(frame.limitations)) + end + ๐•ƒโ‚š = ipo_lattice(interp) + if !โŠ‘(๐•ƒโ‚š, rt, bestguess) + # TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end + frame.bestguess = tmerge(๐•ƒโ‚š, bestguess, rt) # new (wider) return type for frame + return true + else + return false + end +end + # make as much progress on `frame` as possible (without handling cycles) function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) @assert !is_inferred(frame) frame.dont_work_on_me = true # mark that this function is currently on the stack W = frame.ip - nargs = narguments(frame, #=include_va=#false) - slottypes = frame.slottypes ssavaluetypes = frame.ssavaluetypes bbs = frame.cfg.blocks nbbs = length(bbs) - ๐•ƒโ‚š, ๐•ƒแตข = ipo_lattice(interp), typeinf_lattice(interp) + ๐•ƒแตข = typeinf_lattice(interp) currbb = frame.currbb if currbb != 1 @@ -2998,35 +3030,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end end elseif isa(stmt, ReturnNode) - bestguess = frame.bestguess rt = abstract_eval_value(interp, stmt.val, currstate, frame) - rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate)) - # narrow representation of bestguess slightly to prepare for tmerge with rt - if rt isa InterConditional && bestguess isa Const - let slot_id = rt.slot - old_id_type = slottypes[slot_id] - if bestguess.val === true && rt.elsetype !== Bottom - bestguess = InterConditional(slot_id, old_id_type, Bottom) - elseif bestguess.val === false && rt.thentype !== Bottom - bestguess = InterConditional(slot_id, Bottom, old_id_type) - end - end - end - # copy limitations to return value - if !isempty(frame.pclimitations) - union!(frame.limitations, frame.pclimitations) - empty!(frame.pclimitations) - end - if !isempty(frame.limitations) - rt = LimitedAccuracy(rt, copy(frame.limitations)) - end - if !โŠ‘(๐•ƒโ‚š, rt, bestguess) - # new (wider) return type for frame - bestguess = tmerge(๐•ƒโ‚š, bestguess, rt) - # TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end - frame.bestguess = bestguess + if update_bestguess!(interp, frame, currstate, rt) for (caller, caller_pc) in frame.cycle_backedges - if !(caller.ssavaluetypes[caller_pc] === Any) + if caller.ssavaluetypes[caller_pc] !== Any # no reason to revisit if that call-site doesn't affect the final result push!(caller.ip, block_for_inst(caller.cfg, caller_pc)) end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index afbf57cc04df8..e867f5e9ad9dc 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -870,26 +870,10 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize # since the inliner will request to use it later cache = :local else + rt = cached_return_type(code) effects = ipo_effects(code) update_valid_age!(caller, WorldRange(min_world(code), max_world(code))) - rettype = code.rettype - if isdefined(code, :rettype_const) - rettype_const = code.rettype_const - # the second subtyping/egal conditions are necessary to distinguish usual cases - # from rare cases when `Const` wrapped those extended lattice type objects - if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype) - rettype = PartialStruct(rettype, rettype_const) - elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure - rettype = rettype_const - elseif isa(rettype_const, InterConditional) && rettype !== InterConditional - rettype = rettype_const - elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias - rettype = rettype_const - else - rettype = Const(rettype_const) - end - end - return EdgeCallResult(rettype, mi, effects) + return EdgeCallResult(rt, mi, effects) end else cache = :global # cache edge targets by default @@ -933,6 +917,25 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame)) end +function cached_return_type(code::CodeInstance) + rettype = code.rettype + isdefined(code, :rettype_const) || return rettype + rettype_const = code.rettype_const + # the second subtyping/egal conditions are necessary to distinguish usual cases + # from rare cases when `Const` wrapped those extended lattice type objects + if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype) + return PartialStruct(rettype, rettype_const) + elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure + return rettype_const + elseif isa(rettype_const, InterConditional) && rettype !== InterConditional + return rettype_const + elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias + return rettype_const + else + return Const(rettype_const) + end +end + #### entry points for inferring a MethodInstance given a type signature #### # compute an inferred AST and return type