diff --git a/Project.toml b/Project.toml
index 934359f..a42e3e2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 CategoricalArrays = "0.9, 0.10"
@@ -20,6 +21,7 @@ Optim = "1"
 StatsAPI = "1"
 StatsBase = "0.30, 0.31, 0.32, 0.33"
 StatsModels = "0.6"
+Tables = "1"
 julia = "1.6"
 
 [extras]
diff --git a/src/Survival.jl b/src/Survival.jl
index 9a2fd11..44955e4 100644
--- a/src/Survival.jl
+++ b/src/Survival.jl
@@ -7,9 +7,11 @@ using Optim
 using StatsAPI
 using StatsBase
 using StatsModels
+using Tables
 
 export
     EventTime,
+    EventTable,
     isevent,
     iscensored,
 
diff --git a/src/eventtimes.jl b/src/eventtimes.jl
index b66f87e..2887c1b 100644
--- a/src/eventtimes.jl
+++ b/src/eventtimes.jl
@@ -13,11 +13,11 @@ struct EventTime{T<:Real}
     status::Bool
 end
 
-EventTime(time::T) where {T<:Real} = EventTime{T}(time, true)
+EventTime(time, status=true) = EventTime{typeof(time)}(time, Bool(status))
 
 ## Overloaded Base functions
 
-Base.eltype(::EventTime{T}) where {T} = T
+Base.eltype(::Type{EventTime{T}}) where {T} = T
 
 Base.show(io::IO, ev::EventTime) = print(io, ev.time, ifelse(ev.status, "", "+"))
 Base.convert(T::Type{<:Real}, ev::EventTime) = convert(T, ev.time)
@@ -44,4 +44,185 @@ iscensored(ev::EventTime) = !ev.status
 
 StatsModels.concrete_term(t::Term, xs::AbstractVector{<:EventTime}, ::Nothing) =
     StatsModels.ContinuousTerm(t.sym, first(xs), first(xs), first(xs), first(xs))
+
 Base.copy(et::EventTime) = et
+
+#####
+##### `EventTable`
+#####
+
+"""
+    EventTable{T}
+
+Immutable object summarizing the unique observed event times, including the number of
+events, the number of censored observations, and the number remaining at risk for each
+unique time. Observations recorded at time zero are excluded from the summary.
+
+This type implements the Tables.jl interface for tables, which means that `EventTable`
+objects can be seamlessly converted to other tabular types such as `DataFrame`s.
+
+    EventTable(eventtimes)
+
+Construct an `EventTable` from an array of [`EventTime`](@ref) values.
+
+    EventTable(time, status)
+
+Construct an `EventTable` from an array of time values and an array of event status
+indicators.
+"""
+struct EventTable{T}
+    time::Vector{T}
+    nevents::Vector{Int}
+    ncensored::Vector{Int}
+    natrisk::Vector{Int}
+end
+
+function EventTable(ets)
+    T = eltype(eltype(ets))
+    isempty(ets) && return EventTable{T}(T[], Int[], Int[], Int[])
+    # Always work on a private vector: `sort` allocates one, and an already-sorted input
+    # is copied so that dropping time zero below never mutates the caller's data
+    ets = issorted(ets) ? copy(ets) : sort(ets)
+    _droptimezero!(ets)
+    # Every observation may have been at time zero, leaving nothing to summarize
+    isempty(ets) && return EventTable{T}(T[], Int[], Int[], Int[])
+    return _eventtable(ets)
+end
+
+function EventTable(time, status)
+    ntimes = length(time)
+    nstatus = length(status)
+    if ntimes != nstatus
+        throw(DimensionMismatch("number of event statuses does not match number of " *
+                                "event times; got $nstatus and $ntimes, respectively"))
+    end
+    T = eltype(time)
+    ntimes == 0 && return EventTable{T}(T[], Int[], Int[], Int[])
+    ets = map(EventTime, time, status)
+    issorted(ets) || sort!(ets)
+    _droptimezero!(ets)
+    # Every observation may have been at time zero, leaving nothing to summarize
+    isempty(ets) && return EventTable{T}(T[], Int[], Int[], Int[])
+    return _eventtable(ets)
+end
+
+function _droptimezero!(ets)
+    # Assumptions about the input:
+    # - iterates `EventTime`s
+    # - sorted ascending by elements' `.time` fields
+    # - times are nonnegative, so any zeros form a prefix
+    i = findfirst(et -> !iszero(et.time), ets)
+    if i === nothing
+        # No nonzero time exists (or the input is empty), so nothing is kept
+        empty!(ets)
+    elseif i > firstindex(ets)
+        # `i` indexes the first element to keep; the zeros occupy everything before it
+        deleteat!(ets, firstindex(ets):(i - 1))
+    end
+    return ets
+end
+
+function _eventtable(ets)
+    # Assumptions about the input:
+    # - nonempty
+    # - time 0 is not included
+    T = typeof(first(ets).time)
+    outlen = _nuniquetimes(ets)
+
+    nobs = length(ets)
+    dᵢ::Int = 0  # Number of observed events at time t
+    cᵢ::Int = 0  # Number of censored events at time t
+    nᵢ::Int = nobs  # Number remaining at risk at time t
+
+    times = Vector{T}(undef, outlen)  # The set of unique event times
+    nevents = Vector{Int}(undef, outlen)  # Total observed events at each time
+    ncensor = Vector{Int}(undef, outlen)  # Total censored events at each time
+    natrisk = Vector{Int}(undef, outlen)  # Number at risk at each time
+
+    t_prev = zero(T)
+    outind = 1
+
+    @inbounds begin
+        for et in ets
+            t = et.time
+            s = et.status
+            # Aggregate over tied times
+            if t == t_prev
+                dᵢ += s
+                cᵢ += !s
+                continue
+            elseif !iszero(t_prev)
+                times[outind] = t_prev
+                nevents[outind] = dᵢ
+                ncensor[outind] = cᵢ
+                natrisk[outind] = nᵢ
+                outind += 1
+            end
+            nᵢ -= dᵢ + cᵢ
+            dᵢ = s
+            cᵢ = !s
+            t_prev = t
+        end
+
+        # We need to do this one more time to capture the last time
+        # since everything in the loop is lagged
+        times[outind] = t_prev
+        nevents[outind] = dᵢ
+        ncensor[outind] = cᵢ
+        natrisk[outind] = nᵢ
+    end
+
+    return EventTable{eltype(times)}(times, nevents, ncensor, natrisk)
+end
+
+function _nuniquetimes(ets)
+    # Assumptions about the input:
+    # - nonempty
+    # - iterates `EventTime`s
+    # - sorted ascending by elements' `.time` fields
+    t_prev = first(ets).time
+    n = 1
+    for et in Iterators.drop(ets, 1)
+        t = et.time
+        if t != t_prev
+            n += 1
+            t_prev = t
+        end
+    end
+    return n
+end
+
+Base.copy(et::EventTable{T}) where {T} =
+    EventTable{T}(copy(et.time), copy(et.nevents), copy(et.ncensored), copy(et.natrisk))
+
+Base.:(==)(a::EventTable, b::EventTable) =
+    a.time == b.time && a.nevents == b.nevents &&
+    a.ncensored == b.ncensored && a.natrisk == b.natrisk
+
+# Keep `hash` consistent with `==` so tables behave correctly as `Dict`/`Set` keys
+Base.hash(et::EventTable, h::UInt) =
+    hash((et.time, et.nevents, et.ncensored, et.natrisk), hash(:EventTable, h))
+
+# Tables.jl integration
+
+Tables.istable(::Type{<:EventTable}) = true
+
+Tables.columnaccess(::Type{<:EventTable}) = true
+
+# The fields are the columns, so an `EventTable` acts as its own columns object: the
+# default `Tables.getcolumn` and `Tables.columnnames` methods read fields/properties
+Tables.columns(et::EventTable) = et
+
+_rowtype(T::Type{<:EventTable}) =
+    NamedTuple{fieldnames(T),Tuple{map(eltype, fieldtypes(T))...}}
+
+_rowtype(et::EventTable) = _rowtype(typeof(et))
+
+Tables.schema(et::EventTable) = Tables.Schema(_rowtype(et))
+
+function Tables.rows(et::EventTable)
+    NT = _rowtype(et)
+    nr = length(et.time)
+    nc = fieldcount(NT)
+    return (@inbounds(NT(ntuple(i -> getfield(et, i)[j], nc))) for j in 1:nr)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index be432a9..92c8ab4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -7,6 +7,7 @@ using Distributions
 using LinearAlgebra
 using StatsBase
 using StatsModels
+using Tables
 
 @testset "Event times" begin
     @test isevent(EventTime{Int}(44, true))
@@ -280,3 +281,33 @@ x7 0.0914971 0.0286485 3.19378 0.0014
     @test coeftable(outcome_fincatracecat).rownms == ["fin: 1", "race: 1","fin: 1 & race: 1"]
     @test coef(outcome_fincatracecat) ≈ coef(outcome_finrace) atol=1e-8
 end
+
+@testset "EventTable" begin
+    et = EventTable([4, 1, 3, 1, 5, 2, 3, 4], [0, 0, 1, 0, 0, 1, 0, 0])
+    @test Tables.istable(et)
+    @test Tables.columnaccess(et)
+    @test Tables.columns(et) === et
+    @test Tables.schema(et) isa Tables.Schema{(:time, :nevents, :ncensored, :natrisk),NTuple{4,Int}}
+    @test collect(Tables.rows(et)) == [(; time=1, nevents=0, ncensored=2, natrisk=8),
+                                       (; time=2, nevents=1, ncensored=0, natrisk=6),
+                                       (; time=3, nevents=1, ncensored=1, natrisk=5),
+                                       (; time=4, nevents=0, ncensored=2, natrisk=3),
+                                       (; time=5, nevents=0, ncensored=1, natrisk=1)]
+    et2 = copy(et)
+    @test et == et2
+    @test hash(et) == hash(et2)
+    @test et !== et2
+    @test all(1:4) do i
+        a = getfield(et, i)
+        b = getfield(et2, i)
+        return a == b && a !== b
+    end
+    @test_throws DimensionMismatch EventTable(1:10, false:true)
+    @test EventTable(Float32[], Int[]) == EventTable{Float32}(Float32[], Int[], Int[], Int[])
+    # Time-zero observations are dropped without touching later times or the input
+    @test EventTable([0, 1, 2], [1, 1, 0]) == EventTable{Int}([1, 2], [1, 0], [0, 1], [2, 1])
+    @test EventTable([0, 0], [1, 0]) == EventTable{Int}(Int[], Int[], Int[], Int[])
+    sorted = [EventTime(0), EventTime(1), EventTime(2)]
+    EventTable(sorted)
+    @test length(sorted) == 3
+end