Skip to content

Commit

Permalink
Extend functions from StatsAPI instead of StatsBase (#42)
Browse files Browse the repository at this point in the history
Also define methods for `modelmatrix` and `dof_residual`.

The direct dependency on StatsBase may actually be unnecessary now, but I
didn't bother to verify.
  • Loading branch information
ararslan authored Jul 19, 2022
1 parent c38b803 commit 2146c2c
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 25 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"

Expand All @@ -16,6 +17,7 @@ Compat = "3.43, 4"
DataFrames = "1"
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Optim = "1"
StatsAPI = "1"
StatsBase = "0.30, 0.31, 0.32, 0.33"
StatsModels = "0.6"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Survival = "8a913413-2070-5976-9d4c-2b364fdc2f7f"

Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Survival, Documenter, StatsBase
using Survival, Documenter, StatsBase, StatsAPI

makedocs(
modules = [Survival],
Expand Down
2 changes: 1 addition & 1 deletion docs/src/cox.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ the vector of coefficients in the model.
## API

```@docs
StatsBase.fit(::Type{CoxModel}, M::AbstractMatrix, y::AbstractVector; kwargs...)
StatsAPI.fit(::Type{CoxModel}, M::AbstractMatrix, y::AbstractVector; kwargs...)
```

## References
Expand Down
4 changes: 2 additions & 2 deletions docs/src/km.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ using Greenwood's formula:

```@docs
Survival.KaplanMeier
StatsBase.fit(::Type{KaplanMeier}, ::Any, ::Any)
StatsBase.confint(::KaplanMeier, ::Float64)
StatsAPI.fit(::Type{KaplanMeier}, ::Any, ::Any)
StatsAPI.confint(::KaplanMeier, ::Float64)
```

## References
Expand Down
4 changes: 2 additions & 2 deletions docs/src/na.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ from ``n_i`` samples:

```@docs
Survival.NelsonAalen
StatsBase.fit(::Type{NelsonAalen}, ::Any, ::Any)
StatsBase.confint(::NelsonAalen, ::Float64)
StatsAPI.fit(::Type{NelsonAalen}, ::Any, ::Any)
StatsAPI.confint(::NelsonAalen, ::Float64)
```

## References
Expand Down
3 changes: 3 additions & 0 deletions src/Survival.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Compat
using Distributions
using LinearAlgebra
using Optim
using StatsAPI
using StatsBase
using StatsModels

Expand All @@ -21,10 +22,12 @@ export
CoxModel,
coxph,
coef,
modelmatrix,
loglikelihood,
nullloglikelihood,
nobs,
dof,
dof_residual,
vcov,
stderror

Expand Down
23 changes: 14 additions & 9 deletions src/cox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct CoxModel{T<:Real} <: RegressionModel
vcov::Matrix{T}
end

function StatsBase.coeftable(obj::CoxModel)
function StatsAPI.coeftable(obj::CoxModel)
β = coef(obj)
se = stderror(obj)
z_score = β ./ se
Expand All @@ -105,18 +105,23 @@ function Base.show(io::IO, model::CoxModel)
show(io, ct)
end

StatsBase.coef(obj::CoxModel) = obj.β
StatsAPI.coef(obj::CoxModel) = obj.β

StatsBase.loglikelihood(obj::CoxModel) = obj.loglik
StatsBase.nullloglikelihood(obj::CoxModel{T}) where {T} = -_cox_f(obj.β * zero(T), obj.aux)
StatsAPI.modelmatrix(obj::CoxModel) = obj.aux.X

StatsBase.nobs(obj::CoxModel) = size(obj.aux.X, 1)
StatsAPI.loglikelihood(obj::CoxModel) = obj.loglik

StatsBase.dof(obj::CoxModel) = length(obj.β)
StatsAPI.nullloglikelihood(obj::CoxModel) = -_cox_f(zero(coef(obj)), obj.aux)

StatsBase.vcov(obj::CoxModel) = obj.vcov
StatsAPI.nobs(obj::CoxModel) = size(modelmatrix(obj), 1)

StatsBase.stderror(obj::CoxModel) = sqrt.(diag(vcov(obj)))
StatsAPI.dof(obj::CoxModel) = length(coef(obj))

StatsAPI.dof_residual(obj::CoxModel) = nobs(obj) - dof(obj)

StatsAPI.vcov(obj::CoxModel) = obj.vcov

StatsAPI.stderror(obj::CoxModel) = sqrt.(diag(vcov(obj)))

#compute negative loglikelihood

Expand Down Expand Up @@ -209,7 +214,7 @@ Given a matrix `M` of predictors and a corresponding vector of events, compute t
Cox proportional hazard model estimate of coefficients. Returns a `CoxModel`
object.
"""
function StatsBase.fit(::Type{CoxModel}, M::AbstractMatrix, y::AbstractVector; tol=1e-4, l2_cost=0)
function StatsAPI.fit(::Type{CoxModel}, M::AbstractMatrix, y::AbstractVector; tol=1e-4, l2_cost=0)
index_perm = sortperm(y)
X = M[index_perm,:]
s = y[index_perm]
Expand Down
8 changes: 4 additions & 4 deletions src/estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ function _estimator(::Type{S}, ets::AbstractVector{EventTime{T}}) where {S,T}
return S{T}(times, nevents, ncensor, natrisk, estimator, stderr)
end

function StatsBase.fit(::Type{S},
times::AbstractVector{T},
status::AbstractVector{<:Integer}) where {S<:NonparametricEstimator,T}
function StatsAPI.fit(::Type{S},
times::AbstractVector{T},
status::AbstractVector{<:Integer}) where {S<:NonparametricEstimator,T}
ntimes = length(times)
nstatus = length(status)
if ntimes != nstatus
Expand All @@ -87,7 +87,7 @@ function StatsBase.fit(::Type{S},
return fit(S, map(EventTime{T}, times, status))
end

function StatsBase.fit(::Type{S}, ets::AbstractVector{<:EventTime}) where S<:NonparametricEstimator
function StatsAPI.fit(::Type{S}, ets::AbstractVector{<:EventTime}) where S<:NonparametricEstimator
isempty(ets) && throw(ArgumentError("can't compute $(nameof(S)) from 0 observations"))
return _estimator(S, issorted(ets) ? ets : sort(ets))
end
6 changes: 3 additions & 3 deletions src/kaplanmeier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ stderr_update(::Type{KaplanMeier}, gw, dᵢ, nᵢ) = gw + dᵢ / (nᵢ * (nᵢ -
Compute the pointwise log-log transformed confidence intervals for the survivor
function as a vector of tuples.
"""
function StatsBase.confint(km::KaplanMeier, α::Float64=0.05)
function StatsAPI.confint(km::KaplanMeier, α::Float64=0.05)
q = quantile(Normal(), 1 - α/2)
return map(km.survival, km.stderr) do srv, se
l = log(-log(srv))
Expand All @@ -52,12 +52,12 @@ Given a vector of times to events and a corresponding vector of indicators that
denote whether each time is an observed event or is right censored, compute the
Kaplan-Meier estimate of the survivor function.
"""
StatsBase.fit(::Type{KaplanMeier}, times, status)
StatsAPI.fit(::Type{KaplanMeier}, times, status)

"""
fit(KaplanMeier, ets) -> KaplanMeier
Compute the Kaplan-Meier estimate of the survivor function from a vector of
[`EventTime`](@ref) values.
"""
StatsBase.fit(::Type{KaplanMeier}, ets)
StatsAPI.fit(::Type{KaplanMeier}, ets)
6 changes: 3 additions & 3 deletions src/nelsonaalen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ stderr_update(::Type{NelsonAalen}, gw, dᵢ, nᵢ) = gw + dᵢ * (nᵢ - dᵢ) /
Compute the pointwise confidence intervals for the cumulative hazard
function as a vector of tuples.
"""
function StatsBase.confint(na::NelsonAalen, α::Float64=0.05)
function StatsAPI.confint(na::NelsonAalen, α::Float64=0.05)
q = quantile(Normal(), 1 - α/2)
return map(na.chaz, na.stderr) do srv, se
srv - q * se, srv + q * se
Expand All @@ -49,12 +49,12 @@ Given a vector of times to events and a corresponding vector of indicators that
denote whether each time is an observed event or is right censored, compute the
Nelson-Aalen estimate of the cumulative hazard rate function.
"""
StatsBase.fit(::Type{NelsonAalen}, times, status)
StatsAPI.fit(::Type{NelsonAalen}, times, status)

"""
fit(NelsonAalen, ets) -> NelsonAalen
Compute the Nelson-Aalen estimate of the cumulative hazard rate function from a
vector of [`EventTime`](@ref) values.
"""
StatsBase.fit(::Type{NelsonAalen}, ets)
StatsAPI.fit(::Type{NelsonAalen}, ets)
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ end

outcome_without_formula = coxph(regressor_matrix, event_vector)

@test modelmatrix(outcome) == modelmatrix(outcome_without_formula)

@test sprint(show, outcome_without_formula) == """
CoxModel{Float64}
Expand All @@ -215,6 +217,9 @@ x7 0.0914971 0.0286485 3.19378 0.0014
──────────────────────────────────────────────"""

coef_matrix = ModelMatrix(ModelFrame(@formula(event ~ 0 + fin + age + race + wexp + mar + paro + prio), rossi)).m

@test modelmatrix(outcome) == coef_matrix[sortperm(event_vector), :]

outcome_from_matrix = coxph(coef_matrix, rossi.event; tol=1e-8, l2_cost=0)
outcome_from_matrix32 = coxph(Float32.(coef_matrix), rossi.event; tol=1e-5)
outcome_from_matrix_int = coxph(Int64.(coef_matrix), rossi.event; tol=1e-6, l2_cost=0.0)
Expand All @@ -234,6 +239,7 @@ x7 0.0914971 0.0286485 3.19378 0.0014
@test coef(outcome_from_matrix) ≈ coef(outcome_from_matrix_int) atol=1e-5
@test nobs(outcome) == size(rossi, 1)
@test dof(outcome) == 7
@test dof_residual(outcome) == 425
@test loglikelihood(outcome) > nullloglikelihood(outcome)
@test all(x->x > 0, eigen(outcome.model.fischer_info).values)
@test outcome.model.fischer_info * vcov(outcome) ≈ I atol=1e-10
Expand Down

0 comments on commit 2146c2c

Please sign in to comment.