Commit

add tf + censoring example
AshesITR committed Oct 15, 2023
1 parent 9f2783c commit 81e0d23
Showing 2 changed files with 120 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,5 +1,7 @@
# reservr (development version)

* fixed tensorflow log-density implementation for ErlangMixtureDistribution and ExponentialDistribution to work with censored data.

# reservr 0.0.1

* Initial CRAN release
119 changes: 118 additions & 1 deletion jss-paper/reservr.Rmd
@@ -922,7 +922,124 @@ str(coef_nnet)
str(coef_lm)
```

\textcolor{red}{Possibly add another example with a more complicated network model and a higher-dimensional loss? On a dataset that ships with R?}
We now discuss a more complex example involving censoring, using the right-censored `ovarian` dataset bundled with the \pkg{survival} package \citep{baseR}.
Our goal is to predict the rate parameter of an exponential survival time distribution in cancer patients given four features $X = (\mathtt{age}, \mathtt{resid.ds}, \mathtt{rx}, \mathtt{ecog.ps})$ collected in the study.
The variables $\mathtt{resid.ds}, \mathtt{rx}$ and $\mathtt{ecog.ps}$ are indicator variables coded in $\{1, 2\}$.
$\mathtt{age}$ is a continuous variable with values in $(38, 75)$.
Due to the different scale of the $\mathtt{age}$ variable, it is useful to separate it from the other variables in order to perform normalization.
Normalization using `keras::layer_normalization()` transforms its input variables to zero mean and unit variance.
This step is not necessary for the categorical features.
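
As a minimal illustration of the intended transformation, the following base R sketch standardizes a few hypothetical age values to zero mean and unit variance. It is not a description of the internals of `keras::layer_normalization()`; the vector `age` and the helper `standardize()` are made up for this example.

```r
# Hypothetical ages, chosen only for illustration (not taken from `ovarian`).
age <- c(38, 45, 52, 60, 75)

# Centre on the sample mean and divide by the sample standard deviation.
standardize <- function(x) (x - mean(x)) / sd(x)

age_std <- standardize(age)
mean(age_std) # numerically zero
sd(age_std)   # one
```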

```{r}
set.seed(1219L)
tensorflow::set_random_seed(1219L)
keras::k_set_floatx("float32")
dist <- dist_exponential()
ovarian <- survival::ovarian
dat <- list(
  y = trunc_obs(
    xmin = ovarian$futime,
    xmax = ifelse(ovarian$fustat == 1, ovarian$futime, Inf)
  ),
  x = list(
    age = keras::k_constant(ovarian$age, shape = nrow(ovarian)),
    flags = k_matrix(ovarian[, c("resid.ds", "rx", "ecog.ps")] - 1.0)
  )
)
```
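
The interval encoding passed to `trunc_obs()` above can be sketched in base R as follows. The toy vectors `futime` and `fustat` are hypothetical stand-ins for the corresponding `ovarian` columns: an observed death yields a degenerate interval with `xmin == xmax`, while a right-censored subject only provides a lower bound and receives `xmax = Inf`.

```r
# Toy follow-up data (hypothetical values): fustat == 1 means death observed,
# fustat == 0 means the subject was still alive at futime (right-censored).
futime <- c(59, 115, 421, 770)
fustat <- c(1, 1, 0, 0)

intervals <- data.frame(
  xmin = futime,
  xmax = ifelse(fustat == 1, futime, Inf)
)

# Censored rows are exactly those with an unbounded upper end.
intervals$censored <- is.infinite(intervals$xmax)
intervals
```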

Next, we define the input layers and shapes, conforming to our input predictor list `dat$x`.

```{r}
nnet_inputs <- list(
  keras::layer_input(shape = 1L, name = "age"),
  keras::layer_input(shape = 3L, name = "flags")
)
```

`age` will be normalized and then concatenated to the other features, stored in `flags`, resulting in a 4-dimensional representation.
We then add two hidden ReLU layers with $5$ neurons each to the network and compile the result, adapting the 5-dimensional hidden output to the parameter space $\Theta = (0, \infty)$ for the rate parameter of an exponential distribution.
This is accomplished using a dense layer with $1$ neuron and the $\mathrm{softplus}$ activation function.

```{r}
hidden1 <- keras::layer_concatenate(
  keras::layer_normalization(nnet_inputs[[1L]]),
  nnet_inputs[[2L]]
)
hidden2 <- keras::layer_dense(
  hidden1,
  units = 5L,
  activation = keras::activation_relu
)
nnet_output <- keras::layer_dense(
  hidden2,
  units = 5L,
  activation = keras::activation_relu
)
nnet <- tf_compile_model(
  inputs = nnet_inputs,
  intermediate_output = nnet_output,
  dist = dist,
  optimizer = keras::optimizer_adam(learning_rate = 0.01),
  censoring = TRUE,
  truncation = FALSE
)
nnet$model
```
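
The role of the softplus output activation can be illustrated outside of keras. The sketch below uses a hypothetical base R helper `softplus()`, written in a numerically stable form, to show that any real-valued input is mapped to a strictly positive value and hence to a valid rate.

```r
# softplus(x) = log(1 + exp(x)), written to avoid overflow for large |x|.
softplus <- function(x) pmax(x, 0) + log1p(exp(-abs(x)))

z <- c(-20, -1, 0, 1, 20)
rate <- softplus(z)

all(rate > 0)                      # every output is a valid rate parameter
isTRUE(all.equal(rate[3], log(2))) # softplus(0) = log(2)
```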

The default weight initialization can make training unstable.
To circumvent this, we fit a global exponential distribution to the observations and initialize the final layer weights such that the network initially predicts this global fit.

```{r}
str(predict(nnet, dat$x))
global_fit <- fit(dist, dat$y)
tf_initialise_model(nnet, params = global_fit$params, mode = "zero")
str(predict(nnet, dat$x))
```
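
For intuition, the global fit has a closed form under right censoring. With event indicator $\delta_i$ and follow-up time $t_i$, the exponential log-likelihood is $\ell(\lambda) = \sum_i \delta_i \log \lambda - \lambda \sum_i t_i$, which is maximized at $\hat\lambda = \sum_i \delta_i / \sum_i t_i$. The base R sketch below checks this on hypothetical toy data and computes the bias that a final layer with all-zero weights and softplus activation would need in order to predict $\hat\lambda$ for every input. Whether `tf_initialise_model()` proceeds exactly this way is an assumption, and all names in the sketch are illustrative.

```r
# Toy right-censored sample (hypothetical values).
futime <- c(59, 115, 421, 770)
fustat <- c(1, 1, 0, 0) # 1 = death observed, 0 = right-censored

# Negative log-likelihood: log-density for observed deaths, log-survival
# function for censored subjects.
nll <- function(rate) {
  -sum(ifelse(
    fustat == 1,
    dexp(futime, rate = rate, log = TRUE),
    pexp(futime, rate = rate, lower.tail = FALSE, log.p = TRUE)
  ))
}

# Closed-form maximum likelihood estimate: events / total time at risk.
rate_hat <- sum(fustat) / sum(futime)

# A numerical optimiser should agree with the closed form.
rate_num <- optimize(nll, interval = c(1e-6, 1), tol = 1e-10)$minimum

# Inverse softplus: bias b such that log(1 + exp(b)) equals rate_hat.
bias <- log(expm1(rate_hat))
```

With zero weights the inputs have no influence on the output, so every subject initially receives the same predicted rate, matching the identical predictions expected from the second `str()` call above.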

Finally, we can train the network and visualize the predictions.

```{r}
nnet_fit <- fit(
  nnet,
  x = dat$x,
  y = dat$y,
  epochs = 100L,
  batch_size = nrow(dat$y),
  shuffle = FALSE,
  verbose = FALSE
)
plot(nnet_fit)
ovarian$expected_lifetime <- 1.0 / predict(nnet, dat$x)$rate
```

A plot of expected lifetime by $(\mathtt{age}, \mathtt{rx})$ shows that the network learned longer expected lifetimes for lower $\mathtt{age}$ and for treatment group ($\mathtt{rx}$) 2.
The global fit is included as a dotted blue line.

```{r echo = FALSE}
ggplot(ovarian, aes(x = age, y = expected_lifetime, color = factor(rx))) +
  geom_point() +
  geom_hline(yintercept = 1.0 / global_fit$params$rate, color = "blue", linetype = "dotted") +
  scale_color_discrete(name = "treatment group")
```

Individual predictions and observations can also be plotted on a subject level.

```{r, echo = FALSE}
ggplot(ovarian[order(ovarian$futime), ], aes(x = seq_len(nrow(ovarian)))) +
  geom_linerange(aes(ymin = futime, ymax = ifelse(fustat == 1, futime, Inf)), show.legend = FALSE) +
  geom_point(aes(y = futime, shape = ifelse(fustat == 1, "observed", "censored"))) +
  geom_point(aes(y = expected_lifetime, shape = "predicted"), color = "blue") +
  geom_hline(yintercept = 1.0 / global_fit$params$rate, color = "blue", linetype = "dotted") +
  coord_flip() +
  scale_x_continuous(name = "subject", breaks = NULL) +
  scale_y_continuous(name = "lifetime") +
  scale_shape_manual(name = NULL, values = c(observed = "circle", censored = "circle open", predicted = "cross")) +
  guides(shape = guide_legend(override.aes = list(color = c("black", "black", "blue"))))
```

# Conclusion {#conclusion}
