Add an EventTable type that supports the Tables interface
Both Kaplan-Meier and Nelson-Aalen compute the same set of basic
counts at each unique time prior to computing their respective
quantities of interest. The counts have a notably table-like format,
so much so that they can implement the Tables.jl interface with minimal
effort.

Credit for the idea of integration with Tables.jl goes entirely to Tyler
Beason, author of PR #25, who has been added as a co-author of this
commit.

Co-Authored-By: Tyler Beason <[email protected]>
ararslan and tbeason committed Jul 25, 2022
1 parent c48d5c9 commit 954a5a8
Expand Up @@ -10,6 +10,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

CategoricalArrays = "0.9, 0.10"
Expand All @@ -20,6 +21,7 @@ Optim = "1"
StatsAPI = "1"
StatsBase = "0.30, 0.31, 0.32, 0.33"
StatsModels = "0.6"
Tables = "1"
julia = "1.6"

Expand Up @@ -7,9 +7,11 @@ using Optim
using StatsAPI
using StatsBase
using StatsModels
using Tables


Expand Up @@ -13,11 +13,11 @@ struct EventTime{T<:Real}

# Outer constructors for `EventTime`.

# A bare real-valued time is stored with `status = true`.
function EventTime(time::T) where {T<:Real}
    return EventTime{T}(time, true)
end

# General form: the type parameter is taken from `time`, and `status` (default `true`)
# is coerced to `Bool`.
function EventTime(time, status=true)
    return EventTime{typeof(time)}(time, Bool(status))
end

## Overloaded Base functions

# The element type of an `EventTime` is the type of its time value. Both the instance
# method and the type-level method are defined.
Base.eltype(::EventTime{T}) where {T} = T
Base.eltype(::Type{EventTime{T}}) where {T} = T

# Print the time, followed by a trailing `+` when `status` is `false`.
Base.show(io::IO, ev::EventTime) = print(io, ev.time, ifelse(ev.status, "", "+"))

# Converting an `EventTime` to a real type converts its time value; the status is not
# carried over.
function Base.convert(T::Type{<:Real}, ev::EventTime)
    return convert(T, ev.time)
end
Expand All @@ -44,4 +44,167 @@ iscensored(ev::EventTime) = !ev.status

# Turn a term backed by a vector of `EventTime`s into a `ContinuousTerm`. The first
# element of `xs` is passed for each of the four arguments that follow the term's symbol.
function StatsModels.concrete_term(t::Term, xs::AbstractVector{<:EventTime}, ::Nothing)
    x1 = first(xs)
    return StatsModels.ContinuousTerm(t.sym, x1, x1, x1, x1)
end

# `copy` on an `EventTime` returns the very same value; no new object is constructed.
function Base.copy(et::EventTime)
    return et
end

"""
    EventTable{T}

Immutable object summarizing the unique observed event times, including the number of
events, the number of censored observations, and the number remaining at risk for each
unique time.

This type implements the Tables.jl interface for tables, which means that `EventTable`
objects can be seamlessly converted to other tabular types such as `DataFrame`s.

# Fields

- `time`: the unique event times
- `nevents`: the number of observed events at each time
- `ncensored`: the number of censored observations at each time
- `natrisk`: the number remaining at risk at each time

# Constructors

    EventTable(ets)

Construct an `EventTable` from an array of [`EventTime`](@ref) values.

    EventTable(time, status)

Construct an `EventTable` from an array of time values and an array of event status
indicators.
"""
struct EventTable{T}
    time::Vector{T}
    nevents::Vector{Int}
    ncensored::Vector{Int}
    natrisk::Vector{Int}
end

# Construct an `EventTable` from a collection of `EventTime`s.
#
# An empty input yields an empty table whose time type is taken from the element type of
# the input's elements. Otherwise the input is sorted if it is not already (into a new
# array, so the caller's collection is left untouched) and then aggregated.
#
# NOTE(review): `_eventtable` lists "time 0 is not included" among its assumptions, but
# nothing in this method removes zero times -- confirm that callers guarantee this.
function EventTable(ets)
    T = eltype(eltype(ets))
    isempty(ets) && return EventTable{T}(T[], Int[], Int[], Int[])
    ets = issorted(ets) ? ets : sort(ets)  # re-binding, input is unaffected
    return _eventtable(ets)
end

# Construct an `EventTable` from parallel collections of times and event statuses.
#
# Throws a `DimensionMismatch` when the two collections differ in length. Empty input
# yields an empty table with time type `eltype(time)`. Otherwise each (time, status) pair
# is combined into an `EventTime`, the resulting collection is sorted in place if needed
# (it is freshly created here, so the caller's inputs are not modified), and the counts
# are aggregated.
#
# NOTE(review): `_eventtable` lists "time 0 is not included" among its assumptions, but
# nothing in this method removes zero times -- confirm that callers guarantee this.
function EventTable(time, status)
    ntimes = length(time)
    nstatus = length(status)
    if ntimes != nstatus
        throw(DimensionMismatch("number of event statuses does not match number of " *
                                "event times; got $nstatus and $ntimes, respectively"))
    end
    T = eltype(time)
    ntimes == 0 && return EventTable{T}(T[], Int[], Int[], Int[])
    ets = map(EventTime, time, status)
    issorted(ets) || sort!(ets)
    return _eventtable(ets)
end

# Remove, in place, the leading run of elements whose `.time` is zero and return `ets`.
#
# Assumptions about the input:
# - iterates `EventTime`s
# - sorted ascending by elements' `.time` fields
#
# If no element with a nonzero time is found (`findfirst` returns `nothing`), or the
# first element is already nonzero, the input is returned unmodified.
function _droptimezero!(ets)
    i = findfirst(et -> !iszero(et.time), ets)
    start = firstindex(ets)
    if i !== nothing && i > start
        # `i` is the index of the first nonzero time, so the zero-time elements occupy
        # `start:(i - 1)`; the element at `i` itself must be kept.
        deleteat!(ets, start:(i - 1))
    end
    return ets
end

# Aggregate a collection of `EventTime`s into an `EventTable` in a single pass,
# producing one row per unique time.
#
# Assumptions about the input:
# - nonempty
# - sorted ascending by elements' `.time` fields (both callers sort before calling, and
#   `_nuniquetimes` relies on it)
# - time 0 is not included
function _eventtable(ets)
    T = typeof(first(ets).time)
    outlen = _nuniquetimes(ets)

    nobs = length(ets)
    dᵢ::Int = 0     # Number of observed events at time t
    cᵢ::Int = 0     # Number of censored events at time t
    nᵢ::Int = nobs  # Number remaining at risk at time t

    times = Vector{T}(undef, outlen)      # The set of unique event times
    nevents = Vector{Int}(undef, outlen)  # Total observed events at each time
    ncensor = Vector{Int}(undef, outlen)  # Total censored events at each time
    natrisk = Vector{Int}(undef, outlen)  # Number at risk at each time

    # `zero(T)` doubles as the "no time seen yet" marker, which is why the input must
    # not contain time 0.
    t_prev = zero(T)
    outind = 1

    @inbounds begin
        for et in ets
            t = et.time
            s = et.status
            # Aggregate over tied times
            if t == t_prev
                dᵢ += s
                cᵢ += !s
                continue
            elseif !iszero(t_prev)
                # A new time has started: write out the counts for the previous one
                times[outind] = t_prev
                nevents[outind] = dᵢ
                ncensor[outind] = cᵢ
                natrisk[outind] = nᵢ
                outind += 1
            end
            # Those counted at the previous time leave the risk set, and the counters
            # restart with the current observation
            nᵢ -= dᵢ + cᵢ
            dᵢ = s
            cᵢ = !s
            t_prev = t
        end

        # We need to do this one more time to capture the last time
        # since everything in the loop is lagged
        times[outind] = t_prev
        nevents[outind] = dᵢ
        ncensor[outind] = cᵢ
        natrisk[outind] = nᵢ
    end

    return EventTable{T}(times, nevents, ncensor, natrisk)
end

# Count the number of distinct `.time` values in `ets` by counting the positions at
# which the time differs from the preceding element's time.
#
# Assumptions about the input:
# - nonempty
# - iterates `EventTime`s
# - sorted ascending by elements' `.time` fields
function _nuniquetimes(ets)
    t_prev = first(ets).time
    n = 1
    for et in Iterators.drop(ets, 1)
        t = et.time
        if t != t_prev
            n += 1
            t_prev = t
        end
    end
    return n
end

# Copying an `EventTable` copies each of its four column vectors, so the result shares
# no arrays with the original.
function Base.copy(et::EventTable{T}) where {T}
    time = copy(et.time)
    nevents = copy(et.nevents)
    ncensored = copy(et.ncensored)
    natrisk = copy(et.natrisk)
    return EventTable{T}(time, nevents, ncensored, natrisk)
end

# Two `EventTable`s are equal when all four of their columns are equal.
function Base.:(==)(a::EventTable, b::EventTable)
    return a.time == b.time && a.nevents == b.nevents &&
        a.ncensored == b.ncensored && a.natrisk == b.natrisk
end

# Keep `hash` consistent with the `==` method above, so that tables comparing equal
# also hash equally (needed for correct behavior as `Dict` keys or `Set` elements).
# The hash is built from the same four columns that `==` compares.
function Base.hash(et::EventTable, h::UInt)
    h = hash(:EventTable, h)
    h = hash(et.natrisk, h)
    h = hash(et.ncensored, h)
    h = hash(et.nevents, h)
    return hash(et.time, h)
end

# Tables.jl integration

# Declare every `EventTable` type to be a Tables.jl table.
function Tables.istable(::Type{<:EventTable})
    return true
end

# Declare that `EventTable` types support column access.
function Tables.columnaccess(::Type{<:EventTable})
    return true
end

# The `NamedTuple` type describing one row of an `EventTable`: names are the struct's
# field names and each value type is the element type of the corresponding field type.
function _rowtype(T::Type{<:EventTable})
    names = fieldnames(T)
    valuetypes = map(eltype, fieldtypes(T))
    return NamedTuple{names,Tuple{valuetypes...}}
end

# Instance form: defer to the method for the table's type.
function _rowtype(et::EventTable)
    return _rowtype(typeof(et))
end

# The schema is built from the table's row `NamedTuple` type.
function Tables.schema(et::EventTable)
    rowtype = _rowtype(et)
    return Tables.Schema(rowtype)
end

# Return a lazy generator over the rows of `et`. Row `j` is a `NamedTuple` (of the type
# given by `_rowtype`) holding the `j`-th element of each of the table's fields.
function Tables.rows(et::EventTable)
    NT = _rowtype(et)
    nr = length(et.time)   # number of rows
    nc = fieldcount(NT)    # number of columns
    return (@inbounds(NT(ntuple(i -> getfield(et, i)[j], nc))) for j in 1:nr)
end
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Distributions
using LinearAlgebra
using StatsBase
using StatsModels
using Tables

@testset "Event times" begin
@test isevent(EventTime{Int}(44, true))
Expand Down Expand Up @@ -280,3 +281,25 @@ x7 0.0914971 0.0286485 3.19378 0.0014
@test coeftable(outcome_fincatracecat).rownms == ["fin: 1", "race: 1","fin: 1 & race: 1"]
# Compare the two coefficient vectors approximately, within the given absolute tolerance.
@test coef(outcome_fincatracecat) ≈ coef(outcome_finrace) atol=1e-8

@testset "EventTable" begin
    # Unsorted input with ties, so both the sorting and tie aggregation are exercised
    et = EventTable([4, 1, 3, 1, 5, 2, 3, 4], [0, 0, 1, 0, 0, 1, 0, 0])

    # Tables.jl interface
    @test Tables.istable(et)
    @test Tables.columnaccess(et)
    @test Tables.schema(et) isa Tables.Schema{(:time, :nevents, :ncensored, :natrisk),NTuple{4,Int}}
    @test collect(Tables.rows(et)) == [(; time=1, nevents=0, ncensored=2, natrisk=8),
                                       (; time=2, nevents=1, ncensored=0, natrisk=6),
                                       (; time=3, nevents=1, ncensored=1, natrisk=5),
                                       (; time=4, nevents=0, ncensored=2, natrisk=3),
                                       (; time=5, nevents=0, ncensored=1, natrisk=1)]

    # `copy` must produce an equal table that shares no column arrays with the original
    et2 = copy(et)
    @test et == et2
    @test et !== et2
    @test all(1:4) do i
        a = getfield(et, i)
        b = getfield(et2, i)
        return a == b && a !== b
    end

    # Mismatched lengths and empty input
    @test_throws DimensionMismatch EventTable(1:10, false:true)
    @test EventTable(Float32[], Int[]) == EventTable{Float32}(Float32[], Int[], Int[], Int[])
end

