Skip to content

Commit

Permalink
Add second treat_as_own kwarg to piracy tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens committed Jun 24, 2023
1 parent 4e2756d commit f5d3ad4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 39 deletions.
85 changes: 54 additions & 31 deletions src/piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
##################################
# Generic fallback for type parameters that are instances, like the 1 in
# Array{T, 1}
is_foreign(@nospecialize(x), pkg::Base.PkgId; treat_as_own::Vector{<:Type} = Type[]) =
is_foreign(typeof(x), pkg; treat_as_own = treat_as_own)
is_foreign(@nospecialize(x), pkg::Base.PkgId; treat_as_own_type::Vector{<:Type} = Type[]) =
is_foreign(typeof(x), pkg; treat_as_own_type = treat_as_own_type)

# Symbols can be used as type params - we assume these are unique and not
# piracy. This implies that we have
Expand Down Expand Up @@ -78,113 +78,133 @@ is_foreign(@nospecialize(x), pkg::Base.PkgId; treat_as_own::Vector{<:Type} = Typ
# a crazy API). The symbol name may also come from `gensym`. Since the aim of
# `Aqua.test_piracy` is to detect only "obvious" piracy, let us play on the
# safe side.
is_foreign(x::Symbol, pkg::Base.PkgId; treat_as_own::Vector{<:Type} = Type[]) = false
is_foreign(x::Symbol, pkg::Base.PkgId; treat_as_own_type::Vector{<:Type} = Type[]) = false

is_foreign_module(mod::Module, pkg::Base.PkgId) = Base.PkgId(mod) != pkg

function is_foreign(
@nospecialize(T::DataType),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
return is_foreign(first(params), pkg; treat_as_own = treat_as_own)
return is_foreign(first(params), pkg; treat_as_own_type = treat_as_own_type)
else
# Both the type itself and all of its parameters must be foreign
return (!(C in treat_as_own) && is_foreign_module(parentmodule(T), pkg)) &&
all(param -> is_foreign(param, pkg; treat_as_own = treat_as_own), params)
return (!(C in treat_as_own_type) && is_foreign_module(parentmodule(T), pkg)) &&
all(
param -> is_foreign(param, pkg; treat_as_own_type = treat_as_own_type),
params,
)
end
end

function is_foreign(
@nospecialize(U::UnionAll),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
)
# We do not consider extending Set{T} to be piracy, if T is not foreign.
# Extending it goes against Julia style, but it's not piracy IIUC.
is_foreign(U.body, pkg; treat_as_own = treat_as_own) &&
is_foreign(U.var, pkg; treat_as_own = treat_as_own)
is_foreign(U.body, pkg; treat_as_own_type = treat_as_own_type) &&
is_foreign(U.var, pkg; treat_as_own_type = treat_as_own_type)
end

is_foreign(
@nospecialize(T::TypeVar),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
) = is_foreign(T.ub, pkg; treat_as_own = treat_as_own)
treat_as_own_type::Vector{<:Type} = Type[],
) = is_foreign(T.ub, pkg; treat_as_own_type = treat_as_own_type)

# Before 1.7, Vararg was a UnionAll, so the UnionAll method will work
@static if VERSION >= v"1.7"
is_foreign(
@nospecialize(T::Core.TypeofVararg),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
) = is_foreign(T.T, pkg; treat_as_own = treat_as_own)
treat_as_own_type::Vector{<:Type} = Type[],
) = is_foreign(T.T, pkg; treat_as_own_type = treat_as_own_type)
end

function is_foreign(
@nospecialize(U::Union),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
)
# Even if Foo is local, overloading f(::Union{Foo, Int}) with foreign f
# is piracy.
any(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
any(T -> is_foreign(T, pkg; treat_as_own_type = treat_as_own_type), Base.uniontypes(U))
end

function is_foreign_method(
@nospecialize(U::Union),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
treat_as_own_func::Vector{<:Function} = Function[],
)
# When installing a method for a union type, then we only consider it as
# foreign if *all* parameters of the union are foreign, i.e. overloading
# Union{Foo, Int}() is not piracy.
all(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
all(T -> is_foreign(T, pkg; treat_as_own_type = treat_as_own_type), Base.uniontypes(U))
end

function is_foreign_method(
@nospecialize(x::Any),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
treat_as_own_func::Vector{<:Function} = Function[],
)
is_foreign(x, pkg; treat_as_own = treat_as_own)
is_foreign(x, pkg; treat_as_own_type = treat_as_own_type)
end

function is_foreign_method(
@nospecialize(T::DataType),
pkg::Base.PkgId;
treat_as_own::Vector{<:Type} = Type[],
treat_as_own_type::Vector{<:Type} = Type[],
treat_as_own_func::Vector{<:Function} = Function[],
)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
return is_foreign_method(first(params), pkg; treat_as_own = treat_as_own)
return is_foreign_method(
first(params),
pkg;
treat_as_own_type = treat_as_own_type,
treat_as_own_func = treat_as_own_func,
)
end

# fallback to general code
return is_foreign(T, pkg; treat_as_own = treat_as_own)
return !(T <: Function && T.instance in treat_as_own_func) &&
is_foreign(T, pkg; treat_as_own_type = treat_as_own_type)
end


function is_pirate(meth::Method; treat_as_own::Vector{<:Type} = Type[])
function is_pirate(
meth::Method;
treat_as_own_type::Vector{<:Type} = Type[],
treat_as_own_func::Vector{<:Function} = Function[],
)
method_pkg = Base.PkgId(meth.module)

signature = Base.unwrap_unionall(meth.sig)

# the first parameter in the signature is the function type, and it
# follows slightly other rules if it happens to be a Union type
is_foreign_method(signature.parameters[1], method_pkg; treat_as_own = treat_as_own) ||
return false
is_foreign_method(
signature.parameters[1],
method_pkg;
treat_as_own_type = treat_as_own_type,
treat_as_own_func = treat_as_own_func,
) || return false

all(
param -> is_foreign(param, method_pkg; treat_as_own = treat_as_own),
param -> is_foreign(param, method_pkg; treat_as_own_type = treat_as_own_type),
signature.parameters[2:end],
)
end
Expand All @@ -207,10 +227,13 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
# Keyword Arguments
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
`@test`.
- `treat_as_own::Vector{<:Type} = Type[]`: The types in this vector are considered
to be "owned" by the module `m`. This is useful for testing packages that
deliberately commit some type piracy, e.g. modules adding higher-level
functionality to a lightweight C-wrapper.
- `treat_as_own_type::Vector{<:Type} = Type[]`: The types in this vector
are considered to be "owned" by the module `m`. This is useful for
testing packages that deliberately commit some type piracy, e.g. modules
adding higher-level functionality to a lightweight C-wrapper.
- `treat_as_own_func::Vector{<:Function} = Function[]`: The functions in
this vector are considered to be "owned" by the module `m`. This is useful,
e. g. for testing packages that are extending StatsAPI.jl.
"""
function test_piracy(m::Module; broken::Bool = false, kwargs...)
v = hunt(m; kwargs...)
Expand Down
38 changes: 30 additions & 8 deletions test/test_piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ Base.findfirst(::Set{Vector{Char}}, ::Int) = 1
Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1
Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1

# Piracy, but not for `ForeignType in treat_as_own`
# Piracy, but not for `ForeignType in treat_as_own_type`
Base.findmax(::ForeignType, x::Int) = x + 1
Base.findmax(::Set{Vector{ForeignType}}, x::Int) = x + 1
Base.findmax(::Union{Foo,ForeignType}, x::Int) = x + 1

# Piracy, but not for `ForeignParameterizedType in treat_as_own`
# Piracy, but not for `ForeignParameterizedType in treat_as_own_type`
Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1
Expand Down Expand Up @@ -96,26 +96,48 @@ pirates = filter(m -> Piracy.is_pirate(m), meths)
m.name in [:findfirst, :findmax, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
# Test what is pirate (with treat_as_own_type=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own_type = [ForeignType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
# Test what is pirate (with treat_as_own_type=[ForeignParameterizedType])
pirates =
filter(m -> Piracy.is_pirate(m; treat_as_own_type = [ForeignParameterizedType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmax]
end

# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
# Test what is pirate (with treat_as_own_type=[ForeignType, ForeignParameterizedType])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
m -> Piracy.is_pirate(m; treat_as_own_type = [ForeignType, ForeignParameterizedType]),
meths,
)
@test length(pirates) == 3
@test all(pirates) do m
m.name in [:findfirst]
end

# Test what is pirate (with treat_as_own_func=[Base.findfirst, Base.findmax])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own_func = [Base.findfirst, Base.findmax]),
meths,
)
@test length(pirates) == 3
@test all(pirates) do m
m.name in [:findmin]
end

# Test what is pirate (excluding a cover of everything)
pirates = filter(
m -> Piracy.is_pirate(
m;
treat_as_own_type = [ForeignType, ForeignParameterizedType],
treat_as_own_func = [Base.findfirst],
),
meths,
)
@test length(pirates) == 0

0 comments on commit f5d3ad4

Please sign in to comment.