Skip to content

Commit

Permalink
Merge pull request #1160 from JuliaSymbolics/ChrisRackauckas-patch-2
Browse files Browse the repository at this point in the history
Fix prewalk_if dropping metadata
  • Loading branch information
ChrisRackauckas authored Jun 5, 2024
2 parents da10c34 + 8600ee7 commit 5cfa0db
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,14 +588,6 @@ function replace_by_scalarizing(ex, dict)
rule = @rule(getindex(~x, ~~i) =>
scalarize(~x, (map(j->substitute(j, dict), ~~i)...,)))

simterm = (x, f, args; kws...) -> begin
if metadata(x) !== nothing
maketerm(typeof(x), f, args, symtype(x), metadata(x))
else
f(args...)
end
end

function rewrite_operation(x)
if iscall(x) && iscall(operation(x))
f = operation(x)
Expand All @@ -612,20 +604,23 @@ function replace_by_scalarizing(ex, dict)

prewalk_if(x->!(x isa ArrayOp || x isa ArrayMaker),
Rewriters.PassThrough(Chain([rewrite_operation, rule])),
ex, simterm)
ex)
end

function prewalk_if(cond, f, t, maketerm)
function prewalk_if(cond, f, t)
t′ = cond(t) ? f(t) : return t
if iscall(t′)
return maketerm(typeof(t′), TermInterface.head(t′),
map(x->prewalk_if(cond, f, x, maketerm), children(t′)))
if metadata(t′) !== nothing
return maketerm(typeof(t′), TermInterface.head(t′),
map(x->prewalk_if(cond, f, x), children(t′)), symtype(t′), metadata(t′))
else
TermInterface.head(t′)(map(x->prewalk_if(cond, f, x), children(t′))...)
end
else
return t′
end
end


function scalarize(arr::AbstractArray, idx)
arr[idx...]
end
Expand Down

0 comments on commit 5cfa0db

Please sign in to comment.