Skip to content

Commit

Permalink
Add an EventTable type that supports the Tables interface
Browse files Browse the repository at this point in the history
Both Kaplan-Meier and Nelson-Aalen compute the same set of basic
counts at each unique time prior to computing their respective
quantities of interest. The counts have a notably table-like format,
so much so that they can implement the Tables.jl interface with minimal
effort.

Credit for the idea of integration with Tables.jl goes entirely to Tyler
Beason, author of PR #25, who has been added as a co-author of this
commit.

Co-Authored-By: Tyler Beason <[email protected]>
  • Loading branch information
ararslan and tbeason committed Jul 25, 2022
1 parent c48d5c9 commit 954a5a8
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "0.9, 0.10"
Expand All @@ -20,6 +21,7 @@ Optim = "1"
StatsAPI = "1"
StatsBase = "0.30, 0.31, 0.32, 0.33"
StatsModels = "0.6"
Tables = "1"
julia = "1.6"

[extras]
Expand Down
2 changes: 2 additions & 0 deletions src/Survival.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ using Optim
using StatsAPI
using StatsBase
using StatsModels
using Tables

export
EventTime,
EventTable,
isevent,
iscensored,

Expand Down
167 changes: 165 additions & 2 deletions src/eventtimes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ struct EventTime{T<:Real}
status::Bool
end

# Single-argument constructor for real-valued times; the status is set to `true`.
EventTime(time::T) where {T<:Real} = EventTime{T}(time, true)
# General constructor: `status` defaults to `true` and is passed through `Bool(status)`,
# so non-`Bool` indicators such as `0`/`1` are accepted.
EventTime(time, status=true) = EventTime{typeof(time)}(time, Bool(status))

## Overloaded Base functions

# `eltype` of an `EventTime{T}` value is its type parameter `T`.
Base.eltype(::EventTime{T}) where {T} = T
# Same result when called on the `EventTime{T}` type itself; `EventTable(ets)` uses this
# form through `eltype(eltype(ets))`.
Base.eltype(::Type{EventTime{T}}) where {T} = T
# Print the time, followed by a trailing "+" when `status` is `false`.
function Base.show(io::IO, ev::EventTime)
    suffix = ev.status ? "" : "+"
    return print(io, ev.time, suffix)
end

# Converting an `EventTime` to a real number type converts its `time` field; the status
# is not carried over.
function Base.convert(::Type{T}, ev::EventTime) where {T<:Real}
    return convert(T, ev.time)
end
Expand All @@ -44,4 +44,167 @@ iscensored(ev::EventTime) = !ev.status

# Represent a vector of `EventTime`s as a `ContinuousTerm`. The first element of `xs` is
# passed for each of the four positional arguments that follow the term's symbol.
function StatsModels.concrete_term(t::Term, xs::AbstractVector{<:EventTime}, ::Nothing)
    x = first(xs)
    return StatsModels.ContinuousTerm(t.sym, x, x, x, x)
end

# `copy` on an `EventTime` hands back the very same object instead of building a new one.
function Base.copy(et::EventTime)
    return et
end

#####
##### `EventTable`
#####

"""
EventTable{T}
Immutable object summarizing the unique observed event times, including the number of
events, the number of censored observations, and the number remaining at risk for each
unique time.
This type implements the Tables.jl interface for tables, which means that `EventTable`
objects can be seamlessly converted to other tabular types such as `DataFrame`s.
EventTable(eventtimes)
Construct an `EventTable` from an array of [`EventTime`](@ref) values.
EventTable(time, status)
Construct an `EventTable` from an array of time values and an array of event status
indicators.
"""
struct EventTable{T}
time::Vector{T}
nevents::Vector{Int}
ncensored::Vector{Int}
natrisk::Vector{Int}
end

# Build an `EventTable` from a collection of `EventTime`s.
#
# An empty input yields an empty table whose time type is `eltype(eltype(ets))`.
# Otherwise the observations are sorted if necessary, leading time-zero observations are
# dropped, and the per-time counts are computed. The caller's collection is never
# modified.
function EventTable(ets)
    T = eltype(eltype(ets))
    isempty(ets) && return EventTable{T}(T[], Int[], Int[], Int[])
    # `_droptimezero!` deletes elements in place, so it must only ever see a private
    # array. `sort` already allocates a new one; an input that is already sorted has to
    # be copied explicitly, otherwise the caller's array would silently lose entries.
    sorted = issorted(ets) ? copy(ets) : sort(ets)
    _droptimezero!(sorted)
    return _eventtable(sorted)
end

# Build an `EventTable` from parallel collections of times and status indicators.
#
# Throws `DimensionMismatch` when the two collections differ in length. Empty inputs
# yield an empty table whose time type is `eltype(time)`. Otherwise each (time, status)
# pair is wrapped in an `EventTime`, the resulting vector is sorted if needed, and the
# table is built from it.
function EventTable(time, status)
    ntimes, nstatus = length(time), length(status)
    ntimes == nstatus || throw(DimensionMismatch(
        "number of event statuses does not match number of " *
        "event times; got $nstatus and $ntimes, respectively"))
    T = eltype(time)
    if ntimes == 0
        return EventTable{T}(T[], Int[], Int[], Int[])
    end
    # `map` allocates a fresh vector, so the in-place steps below never touch the inputs
    ets = map(EventTime, time, status)
    if !issorted(ets)
        sort!(ets)
    end
    _droptimezero!(ets)
    return _eventtable(ets)
end

# Remove, in place, the leading run of observations whose time is zero, and return `ets`.
#
# Assumptions about the input:
# - iterates `EventTime`s (anything with a `.time` field works)
# - sorted ascending by elements' `.time` fields
#
# When no element has a nonzero time (including empty input), `findfirst` returns
# `nothing` and `ets` is left untouched.
function _droptimezero!(ets)
    first_nonzero = findfirst(et -> !iszero(et.time), ets)
    start = firstindex(ets)
    if first_nonzero !== nothing && first_nonzero > start
        # `first_nonzero` is an index, not a count: the zeros occupy
        # `start:(first_nonzero - 1)`, and the element at `first_nonzero` must be kept.
        deleteat!(ets, start:(first_nonzero - 1))
    end
    return ets
end

# Aggregate sorted observations into one row per unique time.
#
# Assumptions about the input:
# - nonempty
# - time 0 is not included
#
# For every unique time the row records the number of observed events, the number of
# censored observations, and the number of observations still at risk at that time.
function _eventtable(ets)
    T = typeof(first(ets).time)
    nrows = _nuniquetimes(ets)

    times = Vector{T}(undef, nrows)      # The set of unique event times
    nevents = Vector{Int}(undef, nrows)  # Total observed events at each time
    ncensor = Vector{Int}(undef, nrows)  # Total censored events at each time
    natrisk = Vector{Int}(undef, nrows)  # Number at risk at each time

    events::Int = 0                # Observed events tallied for the current time
    censored::Int = 0              # Censored observations tallied for the current time
    remaining::Int = length(ets)   # Number still at risk at the current time
    current = zero(T)              # Zero doubles as the "no time seen yet" marker
    row = 1

    @inbounds for et in ets
        t = et.time
        s = et.status
        if t != current
            # A new time starts: write out the finished one, unless we are still at
            # the initial zero marker, in which case there is nothing to write yet.
            if !iszero(current)
                times[row] = current
                nevents[row] = events
                ncensor[row] = censored
                natrisk[row] = remaining
                row += 1
            end
            # Everyone counted at the previous time has left the risk set
            remaining -= events + censored
            events = 0
            censored = 0
            current = t
        end
        # Tally this observation under the current time (ties accumulate here)
        events += s
        censored += !s
    end

    # Rows are written one step behind the loop, so the last time is still pending
    @inbounds begin
        times[row] = current
        nevents[row] = events
        ncensor[row] = censored
        natrisk[row] = remaining
    end

    return EventTable{T}(times, nevents, ncensor, natrisk)
end

# Count the distinct times in a sorted collection by counting the positions where the
# time changes from one element to the next.
#
# Assumptions about the input:
# - nonempty
# - iterates `EventTime`s
# - sorted ascending by elements' `.time` fields
function _nuniquetimes(ets)
    last_seen = first(ets).time
    nunique = 1
    for ev in Iterators.drop(ets, 1)
        current = ev.time
        # Same time as before: still inside a run of ties
        current == last_seen && continue
        nunique += 1
        last_seen = current
    end
    return nunique
end

# Return a new `EventTable` whose four column vectors are copies of the originals, so the
# result shares no column storage with `et`.
function Base.copy(et::EventTable{T}) where {T}
    time = copy(et.time)
    nevents = copy(et.nevents)
    ncensored = copy(et.ncensored)
    natrisk = copy(et.natrisk)
    return EventTable{T}(time, nevents, ncensored, natrisk)
end

# Two `EventTable`s are equal when all four of their columns are equal, regardless of
# the tables' type parameters.
function Base.:(==)(a::EventTable, b::EventTable)
    return a.time == b.time && a.nevents == b.nevents &&
           a.ncensored == b.ncensored && a.natrisk == b.natrisk
end

# Keep `hash` consistent with `==`. Without this method the default hash depends on the
# identity of the column vectors, so two equal tables would hash differently and
# misbehave as `Dict` keys, in `Set`s, or with `unique`.
function Base.hash(et::EventTable, h::UInt)
    h = hash(:EventTable, h)
    return hash(et.natrisk, hash(et.ncensored, hash(et.nevents, hash(et.time, h))))
end

# Tables.jl integration

# Declare every `EventTable` type to be a Tables.jl table.
Tables.istable(::Type{<:EventTable}) = true

# Declare that `EventTable` offers column access; the struct stores one vector per column.
Tables.columnaccess(::Type{<:EventTable}) = true

# `NamedTuple` type describing one row of an `EventTable`: the names are the struct's
# field names and the types are the element types of the corresponding column vectors.
function _rowtype(T::Type{<:EventTable})
    names = fieldnames(T)
    eltypes = map(eltype, fieldtypes(T))
    return NamedTuple{names,Tuple{eltypes...}}
end

# Convenience method: row type for an `EventTable` instance.
function _rowtype(et::EventTable)
    return _rowtype(typeof(et))
end

# The table's schema is built from its row type.
function Tables.schema(et::EventTable)
    return Tables.Schema(_rowtype(et))
end

# Lazily iterate the rows of the table. Each row is a `NamedTuple` of type
# `_rowtype(et)` holding the `r`-th entry of every column.
function Tables.rows(et::EventTable)
    RowT = _rowtype(et)
    # Gather the column vectors once, in field order
    columns = ntuple(k -> getfield(et, k), fieldcount(RowT))
    return (RowT(map(col -> @inbounds(col[r]), columns)) for r in eachindex(et.time))
end
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Distributions
using LinearAlgebra
using StatsBase
using StatsModels
using Tables

@testset "Event times" begin
@test isevent(EventTime{Int}(44, true))
Expand Down Expand Up @@ -280,3 +281,25 @@ x7 0.0914971 0.0286485 3.19378 0.0014
@test coeftable(outcome_fincatracecat).rownms == ["fin: 1", "race: 1","fin: 1 & race: 1"]
@test coef(outcome_fincatracecat) coef(outcome_finrace) atol=1e-8
end

@testset "EventTable" begin
    # Eight unsorted observations with tied times; a status of 1 marks an observed event
    # and 0 a censored observation
    et = EventTable([4, 1, 3, 1, 5, 2, 3, 4], [0, 0, 1, 0, 0, 1, 0, 0])
    # Tables.jl traits and schema
    @test Tables.istable(et)
    @test Tables.columnaccess(et)
    @test Tables.schema(et) isa Tables.Schema{(:time, :nevents, :ncensored, :natrisk),NTuple{4,Int}}
    # One row per unique time, in ascending order, with the at-risk count shrinking by
    # the events and censorings of the preceding time
    @test collect(Tables.rows(et)) == [(; time=1, nevents=0, ncensored=2, natrisk=8),
                                       (; time=2, nevents=1, ncensored=0, natrisk=6),
                                       (; time=3, nevents=1, ncensored=1, natrisk=5),
                                       (; time=4, nevents=0, ncensored=2, natrisk=3),
                                       (; time=5, nevents=0, ncensored=1, natrisk=1)]
    # `copy` gives an equal table that shares none of its column vectors
    et2 = copy(et)
    @test et == et2
    @test et !== et2
    @test all(1:4) do i
        a = getfield(et, i)
        b = getfield(et2, i)
        return a == b && a !== b
    end
    # Mismatched lengths (10 times, 2 statuses) are rejected
    @test_throws DimensionMismatch EventTable(1:10, false:true)
    # Empty input yields an empty table that keeps the time element type
    @test EventTable(Float32[], Int[]) == EventTable{Float32}(Float32[], Int[], Int[], Int[])
end

0 comments on commit 954a5a8

Please sign in to comment.