Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Piracy: Add kwarg treat_as_own #140

Merged
merged 6 commits into from
Jun 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions src/piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ end
##################################
# Generic fallback for type parameters that are instances, like the 1 in
# Array{T, 1}
is_foreign(@nospecialize(x), pkg::Base.PkgId) = is_foreign(typeof(x), pkg)
is_foreign(@nospecialize(x), pkg::Base.PkgId; treat_as_own) =
is_foreign(typeof(x), pkg; treat_as_own = treat_as_own)

# Symbols can be used as type params - we assume these are unique and not
# piracy. This implies that we have
Expand Down Expand Up @@ -77,87 +78,96 @@ is_foreign(@nospecialize(x), pkg::Base.PkgId) = is_foreign(typeof(x), pkg)
# a crazy API). The symbol name may also come from `gensym`. Since the aim of
# `Aqua.test_piracy` is to detect only "obvious" piracy, let us play on the
# safe side.
is_foreign(x::Symbol, pkg::Base.PkgId) = false
is_foreign(x::Symbol, pkg::Base.PkgId; treat_as_own) = false

is_foreign_module(mod::Module, pkg::Base.PkgId) = Base.PkgId(mod) != pkg

function is_foreign(@nospecialize(T::DataType), pkg::Base.PkgId)
function is_foreign(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as_own)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
return is_foreign(first(params), pkg)
return is_foreign(first(params), pkg; treat_as_own = treat_as_own)
else
# Both the type itself and all of its parameters must be foreign
return is_foreign_module(parentmodule(T), pkg) && all(params) do param
is_foreign(param, pkg)
end
return !(C in treat_as_own) &&
is_foreign_module(parentmodule(T), pkg) &&
all(param -> is_foreign(param, pkg; treat_as_own = treat_as_own), params)
end
end

function is_foreign(@nospecialize(U::UnionAll), pkg::Base.PkgId)
function is_foreign(@nospecialize(U::UnionAll), pkg::Base.PkgId; treat_as_own)
# We do not consider extending Set{T} to be piracy, if T is not foreign.
# Extending it goes against Julia style, but it's not piracy IIUC.
is_foreign(U.body, pkg) && is_foreign(U.var, pkg)
is_foreign(U.body, pkg; treat_as_own = treat_as_own) &&
is_foreign(U.var, pkg; treat_as_own = treat_as_own)
end

is_foreign(@nospecialize(T::TypeVar), pkg::Base.PkgId) = is_foreign(T.ub, pkg)
is_foreign(@nospecialize(T::TypeVar), pkg::Base.PkgId; treat_as_own) =
is_foreign(T.ub, pkg; treat_as_own = treat_as_own)

# Before 1.7, Vararg was a UnionAll, so the UnionAll method will work
@static if VERSION >= v"1.7"
is_foreign(@nospecialize(T::Core.TypeofVararg), pkg::Base.PkgId) = is_foreign(T.T, pkg)
is_foreign(@nospecialize(T::Core.TypeofVararg), pkg::Base.PkgId; treat_as_own) =
is_foreign(T.T, pkg; treat_as_own = treat_as_own)
end

function is_foreign(@nospecialize(U::Union), pkg::Base.PkgId)
function is_foreign(@nospecialize(U::Union), pkg::Base.PkgId; treat_as_own)
# Even if Foo is local, overloading f(::Union{Foo, Int}) with foreign f
# is piracy.
any(T -> is_foreign(T, pkg), Base.uniontypes(U))
any(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
end

function is_foreign_method(@nospecialize(U::Union), pkg::Base.PkgId)
function is_foreign_method(@nospecialize(U::Union), pkg::Base.PkgId; treat_as_own)
# When installing a method for a union type, then we only consider it as
# foreign if *all* parameters of the union are foreign, i.e. overloading
# Union{Foo, Int}() is not piracy.
all(T -> is_foreign(T, pkg), Base.uniontypes(U))
all(T -> is_foreign(T, pkg; treat_as_own = treat_as_own), Base.uniontypes(U))
end

function is_foreign_method(@nospecialize(x::Any), pkg::Base.PkgId)
is_foreign(x, pkg)
function is_foreign_method(@nospecialize(x::Any), pkg::Base.PkgId; treat_as_own)
is_foreign(x, pkg; treat_as_own = treat_as_own)
end

function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId)
function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as_own)
params = T.parameters
# For Type{Foo}, we consider it to originate from the same as Foo
C = getfield(parentmodule(T), nameof(T))
if C === Type
@assert length(params) == 1
U = first(params)
return is_foreign_method(first(params), pkg)
return is_foreign_method(first(params), pkg; treat_as_own = treat_as_own)
end

# fallback to general code
return is_foreign(T, pkg)
return !(T in treat_as_own) &&
!(T <: Function && T.instance in treat_as_own) &&
is_foreign(T, pkg; treat_as_own = treat_as_own)
end


function is_pirate(meth::Method)
function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[])
method_pkg = Base.PkgId(meth.module)

signature = Base.unwrap_unionall(meth.sig)

# the first parameter in the signature is the function type, and it
# follows slightly other rules if it happens to be a Union type
is_foreign_method(signature.parameters[1], method_pkg) || return false
is_foreign_method(signature.parameters[1], method_pkg; treat_as_own = treat_as_own) ||
return false

all(param -> is_foreign(param, method_pkg), signature.parameters[2:end])
all(
param -> is_foreign(param, method_pkg; treat_as_own = treat_as_own),
signature.parameters[2:end],
)
end

hunt(mod::Module; from::Module = mod) = hunt(Base.PkgId(mod); from = from)
hunt(mod::Module; from::Module = mod, kwargs...) =
hunt(Base.PkgId(mod); from = from, kwargs...)

function hunt(pkg::Base.PkgId; from::Module)
function hunt(pkg::Base.PkgId; from::Module, kwargs...)
filter(all_methods(from)) do method
is_pirate(method) && Base.PkgId(method.module) === pkg
Base.PkgId(method.module) === pkg && is_pirate(method, kwargs...)
end
end

Expand All @@ -172,9 +182,14 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
# Keyword Arguments
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
`@test`.
- `treat_as_own = Union{Function, Type}[]`: The types in this container
are considered to be "owned" by the module `m`. This is useful for
testing packages that deliberately commit some type piracy, e.g. modules
adding higher-level functionality to a lightweight C-wrapper, or packages
that are extending `StatsAPI.jl`.
"""
function test_piracy(m::Module; broken::Bool = false)
v = Piracy.hunt(m)
function test_piracy(m::Module; broken::Bool = false, kwargs...)
v = Piracy.hunt(m; kwargs...)
if !isempty(v)
printstyled(
stderr,
Expand Down
2 changes: 2 additions & 0 deletions test/pkgs/PiracyForeignProject/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
name = "PiracyForeignProject"
uuid = "f592ac8b-a2e8-4dd0-be7a-e4053dab5b76"
6 changes: 6 additions & 0 deletions test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module PiracyForeignProject

struct ForeignType end
struct ForeignParameterizedType{T} end

end
82 changes: 72 additions & 10 deletions test/test_piracy.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject"))

baremodule PiracyModule

using PiracyForeignProject: ForeignType, ForeignParameterizedType

using Base:
Base,
Set,
Expand Down Expand Up @@ -28,6 +32,7 @@ Base.findlast(::Set{Foo}, x::Int) = x + 1
Base.findlast(::Type{Val{Foo}}, x::Int) = x + 1
Base.findlast(::Tuple{Vararg{Bar{Set{Int}}}}, x::Int) = x + 1
Base.findlast(::Val{:foo}, x::Int) = x + 1
Base.findlast(::ForeignParameterizedType{Foo}, x::Int) = x + 1

# Not piracy
const MyUnion = Union{Int,Foo}
Expand All @@ -40,12 +45,25 @@ Base.findfirst(::Set{Vector{Char}}, ::Int) = 1
Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1
Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1

# Piracy, but not for `ForeignType in treat_as_own`
Base.findmax(::ForeignType, x::Int) = x + 1
Base.findmax(::Set{Vector{ForeignType}}, x::Int) = x + 1
Base.findmax(::Union{Foo,ForeignType}, x::Int) = x + 1

# Piracy, but not for `ForeignParameterizedType in treat_as_own`
Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1

# Assign them names in this module so they can be found by all_methods
x = Base.findfirst
y = Base.findlast
a = Base.findfirst
b = Base.findlast
c = Base.findmax
d = Base.findmin
end # PiracyModule

using Aqua: Piracy
using PiracyForeignProject: ForeignType, ForeignParameterizedType

# Get all methods - test length
meths = filter(Piracy.all_methods(PiracyModule)) do m
Expand All @@ -55,23 +73,67 @@ end
# 2 Foo constructors
# 2 from f
# 1 from MyUnion
# 5 from findlast
# 6 from findlast
# 3 from findfirst
@test length(meths) == 2 + 2 + 1 + 5 + 3
# 3 from findmax
# 3 from findmin
@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3

# Test what is foreign
BasePkg = Base.PkgId(Base)
CorePkg = Base.PkgId(Core)
ThisPkg = Base.PkgId(PiracyModule)

@test Piracy.is_foreign(Int, BasePkg) # from Core
@test !Piracy.is_foreign(Int, CorePkg) # from Core
@test !Piracy.is_foreign(Set{Int}, BasePkg)
@test !Piracy.is_foreign(Set{Int}, CorePkg)
@test Piracy.is_foreign(Int, BasePkg; treat_as_own = []) # from Core
@test !Piracy.is_foreign(Int, CorePkg; treat_as_own = []) # from Core
@test !Piracy.is_foreign(Set{Int}, BasePkg; treat_as_own = [])
@test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = [])

# Test what is pirate
pirates = filter(Piracy.is_pirate, meths)
pirates = filter(m -> Piracy.is_pirate(m), meths)
@test length(pirates) == 3 + 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmax, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmin]
end

# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
@test length(pirates) == 3 + 3
@test all(pirates) do m
m.name in [:findfirst, :findmax]
end

# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
meths,
)
@test length(pirates) == 3
@test all(pirates) do m
m.name === :findfirst
m.name in [:findfirst]
end

# Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax])
pirates =
filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths)
@test length(pirates) == 3
@test all(pirates) do m
m.name in [:findmin]
end

# Test what is pirate (excluding a cover of everything)
pirates = filter(
m -> Piracy.is_pirate(
m;
treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst],
),
meths,
)
@test length(pirates) == 0