equality and hash for terms and schemas #241

Open
wants to merge 15 commits into
base: master
13 changes: 13 additions & 0 deletions src/schema.jl
@@ -53,6 +53,19 @@ Base.merge!(a::Schema, b::Schema) = (merge!(a.schema, b.schema); a)
Base.keys(schema::Schema) = keys(schema.schema)
Base.haskey(schema::Schema, key) = haskey(schema.schema, key)

function Base.:(==)(first::Schema, second::Schema)
first === second && return true
first.schema === second.schema && return true
Comment on lines +57 to +58
Member

Note that this will be wrong if the dict contains missing (recursively). Can this happen?

Member Author

As a key or a value? Neither is possible at the moment (unless it's a pathological, manually constructed instance).

Member

@nalimilan, Sep 21, 2021

As a value, I think; for keys, dicts use isequal. Note that this also applies if the value contains an object that itself contains a missing value (at any depth of nesting).
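
To illustrate the concern raised in this thread, here is a minimal sketch using plain `Dict`s rather than an actual `Schema` (the variable names are made up for the example). It shows that `==` on dicts propagates `missing` from values, including nested ones, while `isequal` does not, and that a `missing` comparison result used in a `&&` condition throws.

```julia
# Dict equality follows three-valued logic: a `missing` value makes `==` return `missing`.
d1 = Dict(:a => 1, :b => missing)
d2 = Dict(:a => 1, :b => missing)

eq = d1 == d2            # `missing`, not `true`
iseq = isequal(d1, d2)   # `true`: isequal treats missing as equal to missing

# A value that merely contains a missing behaves the same way.
n1 = Dict(:a => [1, missing])
n2 = Dict(:a => [1, missing])
nested_eq = n1 == n2     # `missing`

# Using a `missing` comparison result in a short-circuit condition throws a TypeError.
threw = try
    (d1[:b] != d2[:b]) && error("unreachable")
    false
catch err
    err isa TypeError
end
```

Whether this matters for schemas depends on whether a `missing` can ever end up inside a stored term, which the author says is not currently possible.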

length(first.schema) != length(second.schema) && return false
for key in keys(first)
!haskey(second, key) && return false
second[key] != first[key] && return false
end
true
Comment on lines +59 to +64
Member

Isn't this identical to first.schema == second.schema? If not, maybe worth a comment. (Note that this throws if the dict contains missing, maybe that's OK though.)

Member Author

Hmmm, I think you're right. It's probably fine to just check first.schema == second.schema, since that's really what this check is about.
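
A rough sketch of the simplification agreed on here, assuming a wrapper type shaped like `Schema`. `MiniSchema` is a hypothetical stand-in, not the StatsModels type; the point is only that the manual length/key/value loop can be replaced by delegating to the wrapped dict.

```julia
# Hypothetical stand-in for StatsModels.Schema: a thin wrapper around a Dict.
struct MiniSchema
    schema::Dict{Symbol,Any}
end

# Delegate equality to the wrapped dict, which already checks length, keys and values.
Base.:(==)(a::MiniSchema, b::MiniSchema) = a.schema == b.schema

# Keep hash consistent with ==: equal schemas must hash equally.
Base.hash(s::MiniSchema, h::UInt) = hash(s.schema, h)

s1 = MiniSchema(Dict{Symbol,Any}(:a => 1, :b => "x"))
s2 = MiniSchema(Dict{Symbol,Any}(:a => 1, :b => "x"))
s3 = MiniSchema(Dict{Symbol,Any}(:a => 2, :b => "x"))
```

With these definitions, `s1 == s2` holds even though `s1 !== s2`, and `s1 != s3`. Unlike the hand-written loop, this version returns `missing` instead of throwing if a value is `missing`.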

end

Base.hash(schema::Schema, h::UInt) = hash(schema.schema, h)
Member

Sometimes (e.g. for arrays and tuples) we add an arbitrary constant to h (type-specific) to ensure that hash(schema) != hash(schema.schema). Not sure whether it's worth it here though.

Member Author

Yeah, I wondered about that too. Easy to do here, a bit trickier in the case of the terms (where the types might not be the same for things that we want to be ==, e.g. different instances of a function term where the anonymous function is different but the underlying expression is the same).

Member Author

And that can "corrupt" the type of containers because of the, ahem, zealous use of type parameters in this codebase.

Member

You don't necessarily need to hash the type: you can just define a constant and use it for all types which compare ==.
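
One way this suggestion could look, as a sketch only: `ToyFunctionTerm` and `TOY_TERM_SEED` are hypothetical names, and the type parameter `F` stands in for the anonymous-function type discussed above. The seed is mixed into `h` for every instance regardless of `F`, so terms that compare `==` hash equally even when their concrete types differ.

```julia
# Hypothetical term type; the parameter F stands in for an anonymous-function type.
struct ToyFunctionTerm{F}
    f::F          # closure generated at parse time; its type differs between instances
    ex::Expr      # original expression, e.g. :(log(b))
end

# Arbitrary constant shared by every ToyFunctionTerm, whatever its parameter F.
const TOY_TERM_SEED = 0x5f3a9c21d4e87b06 % UInt

# Equality ignores the closure and compares only the expression.
Base.:(==)(a::ToyFunctionTerm, b::ToyFunctionTerm) = a.ex == b.ex

# Mix in the shared seed instead of the concrete type, so == terms hash equally.
Base.hash(t::ToyFunctionTerm, h::UInt) = hash(t.ex, h + TOY_TERM_SEED)

t1 = ToyFunctionTerm(b -> log(b), :(log(b)))
t2 = ToyFunctionTerm(b -> log(b), :(log(b)))
t3 = ToyFunctionTerm(b -> exp(b), :(exp(b)))
```

Here `typeof(t1) != typeof(t2)` because each anonymous function has its own type, yet `t1 == t2` and `hash(t1) == hash(t2)`. Hashing `typeof(t)` instead of the shared constant would break that last property.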


"""
schema([terms::AbstractVector{<:AbstractTerm}, ]data, hints::Dict{Symbol})
schema(term::AbstractTerm, data, hints::Dict{Symbol})
15 changes: 14 additions & 1 deletion src/terms.jl
@@ -2,6 +2,14 @@ abstract type AbstractTerm end
const TermOrTerms = Union{AbstractTerm, Tuple{AbstractTerm, Vararg{AbstractTerm}}}
const TupleTerm = Tuple{TermOrTerms, Vararg{TermOrTerms}}

Base.hash(term::T, h::UInt) where {T<:AbstractTerm} =
foldl((h, x) -> hash(x, h), getfield(term, field) for field in fieldnames(T); init=h)

function Base.:(==)(a::A, b::B) where {A<:AbstractTerm, B<:AbstractTerm}
fieldnames(A) == fieldnames(B) || return false
return all(getfield(a, field) == getfield(b, field) for field in fieldnames(A))
end

width(::T) where {T<:AbstractTerm} =
throw(ArgumentError("terms of type $T have undefined width"))

@@ -127,7 +135,10 @@ FunctionTerm(forig::Fo, fanon::Fa, names::NTuple{N,Symbol},
FunctionTerm{Fo, Fa, names}(forig, fanon, exorig, args_parsed)
width(::FunctionTerm) = 1

-Base.:(==)(a::FunctionTerm, b::FunctionTerm) = a.forig == b.forig && a.exorig == b.exorig
+Base.:(==)(first::FunctionTerm, second::FunctionTerm) =
+    first.forig == second.forig &&
+    first.exorig == second.exorig
+Base.hash(term::FunctionTerm, h::UInt) = hash(term.forig, hash(term.exorig, h))

"""
InteractionTerm{Ts} <: AbstractTerm
@@ -191,6 +202,8 @@ via the [`implicit_intercept`](@ref) trait).
struct InterceptTerm{HasIntercept} <: AbstractTerm end
width(::InterceptTerm{H}) where {H} = H ? 1 : 0

Base.:(==)(first::InterceptTerm{T}, second::InterceptTerm{S}) where {T,S} = T == S

# Typed terms

"""
53 changes: 52 additions & 1 deletion test/schema.jl
@@ -1,5 +1,4 @@
@testset "schemas" begin

using StatsModels: schema, apply_schema, FullRank

@testset "no-op apply_schema" begin
@@ -70,4 +69,56 @@

end

@testset "basic hash and equality" begin
f = @formula(y ~ 1 + a + log(b) + c + b & c)
y = rand(9)
b = rand(9)

df = (y = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3))
f = apply_schema(f, schema(f, df))
@test f == apply_schema(f, schema(f, df))

sch1 = schema(f, df)
sch2 = schema(f, df)
@test sch1 == sch2
@test sch1 !== sch2
@test hash(sch1) == hash(sch2)

# double categorical column c to test for invariance based on levels
df2 = (y = y, a = 1:9, b = b, c = [df.c; df.c])
@test schema(df) == schema(df2)
@test hash(schema(df)) == hash(schema(df2))
@test apply_schema(f, schema(df)) == apply_schema(f, schema(df2))

# different levels
df3 = (y = y, a = 1:9, b = b, c = repeat(["a", "b", "c"], 3))
@test schema(df) != schema(df3)

# different length, so different summary stats for continuous
df4 = (y = [df.y; df.y], a = [1:9; 1:9], b = [b; b], c = [df.c; df.c])
@test schema(df) != schema(df4)

# different names for some columns
df5 = (z = y, a = 1:9, b = b, c = repeat(["d", "e", "f"], 3))
@test schema(df) != schema(df5)

# different values in continuous column so different stats
df6 = (y = y, a = 2:10, b = b, c = repeat(["a", "b", "c"], 3))
@test schema(df) != schema(df6)

# different names?
df7 = (w = y, d = 1:9, x = b, z = repeat(["d", "e", "f"], 3))
@test schema(df) != schema(df7)

# missing column
df8 = (y = y, a = 1:9, c = repeat(["d", "e", "f"], 3))
@test schema(df) != schema(df8)

# different coding/hints
sch = schema(df, Dict(:c => DummyCoding(base="e")))
sch2 = schema(df, Dict(:c => EffectsCoding(base="e")))
sch3 = schema(df, Dict(:y => DummyCoding()))
@test sch != sch2
@test sch != sch3
end
end
60 changes: 48 additions & 12 deletions test/terms.jl
@@ -30,26 +30,36 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
@test t0.var == var([1,2,3])
@test t0.min == 1.0
@test t0.max == 3.0
@test t0 == concrete_term(t, [3, 2, 1])
@test hash(t0) == hash(concrete_term(t, [3, 2, 1]))

t1 = concrete_term(t, [:a, :b, :c])
@test t1.contrasts isa StatsModels.ContrastsMatrix{DummyCoding}
@test string(t1) == "aaa"
@test mimestring(t1) == "aaa(DummyCoding:3→2)"
@test t1 == concrete_term(t, [:a, :b, :c])
@test t1 !== concrete_term(t, [:a, :b, :c])
@test hash(t1) == hash(concrete_term(t, [:a, :b, :c]))

t3 = concrete_term(t, [:a, :b, :c], DummyCoding())
@test t3.contrasts isa StatsModels.ContrastsMatrix{DummyCoding}
@test string(t3) == "aaa"
@test mimestring(t3) == "aaa(DummyCoding:3→2)"
@test t1 == t3
@test hash(t1) == hash(t3)

t2 = concrete_term(t, [:a, :a, :b], EffectsCoding())
@test t2.contrasts isa StatsModels.ContrastsMatrix{EffectsCoding}
@test mimestring(t2) == "aaa(EffectsCoding:2→1)"
@test string(t2) == "aaa"
@test t2 == concrete_term(t, [:a, :a, :b], EffectsCoding())
@test t1 != t2

t2full = concrete_term(t, [:a, :a, :b], StatsModels.FullDummyCoding())
@test t2full.contrasts isa StatsModels.ContrastsMatrix{StatsModels.FullDummyCoding}
@test mimestring(t2full) == "aaa(StatsModels.FullDummyCoding:2→2)"
@test string(t2full) == "aaa"
@test t1 != t2full
end

@testset "term operators" begin
@@ -89,18 +99,6 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
@test +a == a
end

-    @testset "uniqueness of FunctionTerms" begin
-        f1 = @formula(y ~ lag(x,1) + lag(x,1))
-        f2 = @formula(y ~ lag(x,1))
-        f3 = @formula(y ~ lag(x,1) + lag(x,2))
-
-        @test f1.rhs == f2.rhs
-        @test f1.rhs != f3.rhs
-
-        ## addition of two identical function terms
-        @test f2.rhs + f2.rhs == f2.rhs
-    end

@testset "expand nested tuples of terms during apply_schema" begin
sch = schema((a=rand(10), b=rand(10), c=rand(10)))

@@ -173,6 +171,44 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =

end

@testset "equality of function terms" begin
# for now, we use `@formula` to construct the function terms
f1 = @formula(0 ~ (1 | x)).rhs
f2 = @formula(0 ~ (1 | x)).rhs
@test f1 !== f2
@test f1 == f2
@test hash(f1) == hash(f2)

f3 = @formula(0 ~ (1 % x)).rhs
@test f1 != f3
@test hash(f1) != hash(f3)

f4 = @formula(0 ~ (x | 1)).rhs
@test f1 != f4
@test hash(f1) != hash(f4)

f5 = @formula(0 ~ (1 & y | x)).rhs
@test f1 != f5
@test hash(f1) != hash(f5)

ff1 = @formula(y ~ 1 + x + x & y + (1 + x | g))
ff2 = @formula(y ~ 1 + x + x & y + (1 + x | g))
@test ff1 == ff2
@test hash(ff1) == hash(ff2)
end

@testset "uniqueness of FunctionTerms" begin
f1 = @formula(y ~ lag(x,1) + lag(x,1))
f2 = @formula(y ~ lag(x,1))
f3 = @formula(y ~ lag(x,1) + lag(x,2))

@test f1.rhs == f2.rhs
@test f1.rhs != f3.rhs

## addition of two identical function terms
@test f2.rhs + f2.rhs == f2.rhs
end

@testset "Tuple terms" begin
using StatsModels: TermOrTerms, TupleTerm, Term
a, b, c = Term.((:a, :b, :c))