Skip to content

Commit

Permalink
Piracy: Add kwarg treat_as_own (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens authored Jun 24, 2023
1 parent dd1d392 commit 175e78f
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 39 deletions.
73 changes: 44 additions & 29 deletions src/piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ end
##################################
# Generic fallback for type parameters that are instances, like the 1 in
# Array{T, 1}
is_foreign(@nospecialize(x), pkg::Base.PkgId) = is_foreign(typeof(x), pkg)
is_foreign(@nospecialize(x), pkg::Base.PkgId; treat_as_own) =
is_foreign(typeof(x), pkg; treat_as_own = treat_as_own)

# Symbols can be used as type params - we assume these are unique and not
# piracy. This implies that we have
Expand Down Expand Up @@ -77,87 +78,96 @@ is_foreign(@nospecialize(x), pkg::Base.PkgId) = is_foreign(typeof(x), pkg)
# a crazy API). The symbol name may also come from `gensym`. Since the aim of
# `Aqua.test_piracy` is to detect only "obvious" piracy, let us play on the
# safe side.
is_foreign(x::Symbol, pkg::Base.PkgId) = false
is_foreign(x::Symbol, pkg::Base.PkgId; treat_as_own) = false

is_foreign_module(mod::Module, pkg::Base.PkgId) = Base.PkgId(mod) != pkg

function is_foreign(@nospecialize(T::DataType), pkg::Base.PkgId)
function is_foreign(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as_own)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
return is_foreign(first(params), pkg)
return is_foreign(first(params), pkg; treat_as_own = treat_as_own)
else
# Both the type itself and all of its parameters must be foreign
return is_foreign_module(parentmodule(T), pkg) && all(params) do param
is_foreign(param, pkg)
end
return !(C in treat_as_own) &&
is_foreign_module(parentmodule(T), pkg) &&
all(param -> is_foreign(param, pkg; treat_as_own = treat_as_own), params)
end
end

function is_foreign(@nospecialize(U::UnionAll), pkg::Base.PkgId)
function is_foreign(@nospecialize(U::UnionAll), pkg::Base.PkgId; treat_as_own)
# We do not consider extending Set{T} to be piracy, if T is not foreign.
# Extending it goes against Julia style, but it's not piracy IIUC.
is_foreign(U.body, pkg) && is_foreign(U.var, pkg)
is_foreign(U.body, pkg; treat_as_own = treat_as_own) &&
is_foreign(U.var, pkg; treat_as_own = treat_as_own)
end

is_foreign(@nospecialize(T::TypeVar), pkg::Base.PkgId) = is_foreign(T.ub, pkg)
is_foreign(@nospecialize(T::TypeVar), pkg::Base.PkgId; treat_as_own) =
is_foreign(T.ub, pkg; treat_as_own = treat_as_own)

# Before 1.7, Vararg was a UnionAll, so the UnionAll method will work
@static if VERSION >= v"1.7"
is_foreign(@nospecialize(T::Core.TypeofVararg), pkg::Base.PkgId) = is_foreign(T.T, pkg)
is_foreign(@nospecialize(T::Core.TypeofVararg), pkg::Base.PkgId; treat_as_own) =
is_foreign(T.T, pkg; treat_as_own = treat_as_own)
end

function is_foreign(@nospecialize(U::Union), pkg::Base.PkgId)
function is_foreign(@nospecialize(U::Union), pkg::Base.PkgId; treat_as_own)
# Even if Foo is local, overloading f(::Union{Foo, Int}) with foreign f
# is piracy.
any(T -> is_foreign(T, pkg), Base.uniontypes(U))
any(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
end

function is_foreign_method(@nospecialize(U::Union), pkg::Base.PkgId)
function is_foreign_method(@nospecialize(U::Union), pkg::Base.PkgId; treat_as_own)
# When installing a method for a union type, then we only consider it as
# foreign if *all* parameters of the union are foreign, i.e. overloading
# Union{Foo, Int}() is not piracy.
all(T -> is_foreign(T, pkg), Base.uniontypes(U))
all(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
end

function is_foreign_method(@nospecialize(x::Any), pkg::Base.PkgId)
is_foreign(x, pkg)
function is_foreign_method(@nospecialize(x::Any), pkg::Base.PkgId; treat_as_own)
is_foreign(x, pkg; treat_as_own = treat_as_own)
end

function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId)
function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as_own)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
U = first(params)
return is_foreign_method(first(params), pkg)
return is_foreign_method(first(params), pkg; treat_as_own = treat_as_own)
end

# fallback to general code
return is_foreign(T, pkg)
return !(T in treat_as_own) &&
!(T <: Function && T.instance in treat_as_own) &&
is_foreign(T, pkg; treat_as_own = treat_as_own)
end


function is_pirate(meth::Method)
function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[])
method_pkg = Base.PkgId(meth.module)

signature = Base.unwrap_unionall(meth.sig)

# the first parameter in the signature is the function type, and it
# follows slightly other rules if it happens to be a Union type
is_foreign_method(signature.parameters[1], method_pkg) || return false
is_foreign_method(signature.parameters[1], method_pkg; treat_as_own = treat_as_own) ||
return false

all(param -> is_foreign(param, method_pkg), signature.parameters[2:end])
all(
param -> is_foreign(param, method_pkg; treat_as_own = treat_as_own),
signature.parameters[2:end],
)
end

hunt(mod::Module; from::Module = mod) = hunt(Base.PkgId(mod); from = from)
hunt(mod::Module; from::Module = mod, kwargs...) =
hunt(Base.PkgId(mod); from = from, kwargs...)

function hunt(pkg::Base.PkgId; from::Module)
function hunt(pkg::Base.PkgId; from::Module, kwargs...)
filter(all_methods(from)) do method
is_pirate(method) && Base.PkgId(method.module) === pkg
Base.PkgId(method.module) === pkg && is_pirate(method, kwargs...)
end
end

Expand All @@ -172,9 +182,14 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
# Keyword Arguments
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
`@test`.
- `treat_as_own = Union{Function, Type}[]`: The types in this container
are considered to be "owned" by the module `m`. This is useful for
testing packages that deliberately commit some type piracy, e.g. modules
adding higher-level functionality to a lightweight C-wrapper, or packages
that are extending `StatsAPI.jl`.
"""
function test_piracy(m::Module; broken::Bool = false)
v = Piracy.hunt(m)
function test_piracy(m::Module; broken::Bool = false, kwargs...)
v = Piracy.hunt(m; kwargs...)
if !isempty(v)
printstyled(
stderr,
Expand Down
2 changes: 2 additions & 0 deletions test/pkgs/PiracyForeignProject/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
name = "PiracyForeignProject"
uuid = "f592ac8b-a2e8-4dd0-be7a-e4053dab5b76"
6 changes: 6 additions & 0 deletions test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module PiracyForeignProject

struct ForeignType end
struct ForeignParameterizedType{T} end

end
82 changes: 72 additions & 10 deletions test/test_piracy.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject"))

baremodule PiracyModule

using PiracyForeignProject: ForeignType, ForeignParameterizedType

using Base:
Base,
Set,
Expand Down Expand Up @@ -28,6 +32,7 @@ Base.findlast(::Set{Foo}, x::Int) = x + 1
Base.findlast(::Type{Val{Foo}}, x::Int) = x + 1
Base.findlast(::Tuple{Vararg{Bar{Set{Int}}}}, x::Int) = x + 1
Base.findlast(::Val{:foo}, x::Int) = x + 1
Base.findlast(::ForeignParameterizedType{Foo}, x::Int) = x + 1

# Not piracy
const MyUnion = Union{Int,Foo}
Expand All @@ -40,12 +45,25 @@ Base.findfirst(::Set{Vector{Char}}, ::Int) = 1
Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1
Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1

# Piracy, but not for `ForeignType in treat_as_own`
Base.findmax(::ForeignType, x::Int) = x + 1
Base.findmax(::Set{Vector{ForeignType}}, x::Int) = x + 1
Base.findmax(::Union{Foo,ForeignType}, x::Int) = x + 1

# Piracy, but not for `ForeignParameterizedType in treat_as_own`
Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1

# Assign them names in this module so they can be found by all_methods
x = Base.findfirst
y = Base.findlast
a = Base.findfirst
b = Base.findlast
c = Base.findmax
d = Base.findmin
end # PiracyModule

using Aqua: Piracy
using PiracyForeignProject: ForeignType, ForeignParameterizedType

# Get all methods - test length
meths = filter(Piracy.all_methods(PiracyModule)) do m
Expand All @@ -55,23 +73,67 @@ end
# 2 Foo constructors
# 2 from f
# 1 from MyUnion
# 5 from findlast
# 6 from findlast
# 3 from findfirst
@test length(meths) == 2 + 2 + 1 + 5 + 3
# 3 from findmax
# 3 from findmin
@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3

# Test what is foreign
BasePkg = Base.PkgId(Base)
CorePkg = Base.PkgId(Core)
ThisPkg = Base.PkgId(PiracyModule)

@test Piracy.is_foreign(Int, BasePkg) # from Core
@test !Piracy.is_foreign(Int, CorePkg) # from Core
@test !Piracy.is_foreign(Set{Int}, BasePkg)
@test !Piracy.is_foreign(Set{Int}, CorePkg)
@test Piracy.is_foreign(Int, BasePkg; treat_as_own = []) # from Core
@test !Piracy.is_foreign(Int, CorePkg; treat_as_own = []) # from Core
@test !Piracy.is_foreign(Set{Int}, BasePkg; treat_as_own = [])
@test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = [])

# Test what is pirate
pirates = filter(Piracy.is_pirate, meths)
pirates = filter(m -> Piracy.is_pirate(m), meths)
@test length(pirates) == 3 + 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmax, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmax]
end

# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
meths,
)
@test length(pirates) == 3
@test all(pirates) do m
m.name === :findfirst
m.name in [:findfirst]
end

# Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax])
pirates =
filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths)
@test length(pirates) == 3
@test all(pirates) do m
m.name in [:findmin]
end

# Test what is pirate (excluding a cover of everything)
pirates = filter(
m -> Piracy.is_pirate(
m;
treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst],
),
meths,
)
@test length(pirates) == 0

0 comments on commit 175e78f

Please sign in to comment.