Skip to content

Commit

Permalink
added option to plot skip layer network in nnet models
Browse files Browse the repository at this point in the history
  • Loading branch information
fawda123 committed Sep 7, 2015
1 parent 758dd02 commit ef3b8d5
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 34 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: NeuralNetTools
Type: Package
Title: Visualization and Analysis Tools for Neural Networks
Version: 1.3.10.9000
Date: 2015-08-26
Version: 1.3.11.9000
Date: 2015-09-07
Author: Marcus W. Beck [aut, cre]
Maintainer: Marcus W. Beck <[email protected]>
Description: Visualization and analysis tools to aid in the interpretation of
Expand Down
79 changes: 62 additions & 17 deletions R/NeuralNetTools_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
#' @param bord_col chr string indicating border color around nodes, default \code{'lightblue'}
#' @param prune_col chr string indicating color of pruned connections, otherwise not shown
#' @param prune_lty line type for pruned connections, passed to \code{\link[graphics]{segments}}
#' @param max_sp logical value indicating if space between nodes in each layer is maximized, default \code{FALSE}
#' @param max_sp logical value indicating if space between nodes in each layer is maximized, default \code{FALSE}
#' @param skip logical if skip layer connections are plotted instead of the primary network
#' @param ... additional arguments passed to plot
#'
#' @export
Expand All @@ -37,6 +38,8 @@
#' @details
#' This function plots a neural network as a neural interpretation diagram as in Ozesmi and Ozesmi (1999). Options to plot without color-coding or shading of weights are also provided. The default settings plot positive weights between layers as black lines and negative weights as grey lines. Line thickness is in proportion to relative magnitude of each weight. The first layer includes only input variables with nodes labelled arbitrarily as I1 through In for n input variables. One through many hidden layers are plotted with each node in each layer labelled as H1 through Hn. The output layer is plotted last with nodes labeled as O1 through On. Bias nodes connected to the hidden and output layers are also shown. Neural networks created using \code{\link[RSNNS]{mlp}} do not show bias layers.
#'
#' A primary network and a skip layer network can be plotted for \code{\link[nnet]{nnet}} models with a skip layer connection. The default is to plot the primary network, whereas the skip layer network can be viewed with \code{skip = TRUE}. If \code{nid = TRUE}, the line widths for both the primary and skip layer plots are relative to all weights. Viewing both plots is recommended to see which network has larger relative weights.
#'
#' @examples
#' ## using numeric input
#'
Expand All @@ -56,6 +59,12 @@
#'
#' plotnet(mod)
#'
#' ## plot the skip layer from nnet model
#'
#' mod <- nnet(Y1 ~ X1 + X2 + X3, data = neuraldat, size = 5, skip = TRUE)
#'
#' plotnet(mod, skip = TRUE)
#'
#' ## using RSNNS, no bias layers
#'
#' library(RSNNS)
Expand Down Expand Up @@ -131,7 +140,7 @@ plotnet <- function(mod_in, ...) UseMethod('plotnet')
#' @export
#'
#' @method plotnet default
plotnet.default <- function(mod_in, x_names, y_names, struct = NULL, nid = TRUE, all_out = TRUE, all_in = TRUE, bias = TRUE, rel_rsc = 5, circle_cex = 5, node_labs = TRUE, var_labs = TRUE, line_stag = NULL, cex_val = 1, alpha_val = 1, circle_col = 'lightblue', pos_col = 'black', neg_col = 'grey', bord_col = 'lightblue', max_sp = FALSE, prune_col = NULL, prune_lty = 'dashed', ...){
plotnet.default <- function(mod_in, x_names, y_names, struct = NULL, nid = TRUE, all_out = TRUE, all_in = TRUE, bias = TRUE, rel_rsc = 5, circle_cex = 5, node_labs = TRUE, var_labs = TRUE, line_stag = NULL, cex_val = 1, alpha_val = 1, circle_col = 'lightblue', pos_col = 'black', neg_col = 'grey', bord_col = 'lightblue', max_sp = FALSE, prune_col = NULL, prune_lty = 'dashed', skip = NULL, ...){

wts <- neuralweights(mod_in, struct = struct)
struct <- wts$struct
Expand All @@ -157,6 +166,36 @@ plotnet.default <- function(mod_in, x_names, y_names, struct = NULL, nid = TRUE,
#initiate plot
plot(x_range, y_range, type = 'n', axes = FALSE, ylab = '', xlab = '')

# warning if nnet hidden is zero
if(struct[2] == 0) warning('No hidden layer, plotting skip layer only')

# subroutine for skip layer connections in nnet
if(any(skip)){

return({ # use this to exit

# plot connections usign layer lines with skip TRUE
mapply(
function(x) layer_lines(mod_in, x, layer1 = 1, layer2 = length(struct), out_layer = TRUE, nid = nid, rel_rsc = rel_rsc, all_in = all_in, pos_col = scales::alpha(pos_col, alpha_val), neg_col = scales::alpha(neg_col, alpha_val), x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x, max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty, skip = skip),
1:struct[length(struct)]
)

# plot only input, output nodes
for(i in c(1, length(struct))){
in_col <- circle_col
if(i == 1) { layer_name <- 'I'; in_col <- circle_col_inp}
if(i == length(struct)) layer_name <- 'O'
layer_points(struct[i], layer_x[i], x_range, layer_name, cex_val, circle_cex, bord_col, in_col,
node_labs, line_stag, var_labs, x_names, y_names, max_sp = max_sp, struct = struct,
y_range = y_range
)

}

})

}

#use functions to plot connections between layers
#bias lines
if(bias) bias_lines(bias_x, bias_y, mod_in, nid = nid, rel_rsc = rel_rsc, all_out = all_out, pos_col = scales::alpha(pos_col, alpha_val), neg_col = scales::alpha(neg_col, alpha_val), y_names = y_names, x_range = x_range, max_sp = max_sp, struct = struct, y_range = y_range, layer_x = layer_x, line_stag = line_stag)
Expand Down Expand Up @@ -191,19 +230,19 @@ plotnet.default <- function(mod_in, x_names, y_names, struct = NULL, nid = TRUE,
layer_lines(mod_in, node, layer1 = lay[1], layer2 = lay[2], nid = nid, rel_rsc = rel_rsc, all_in = TRUE,
pos_col = scales::alpha(pos_col, alpha_val), neg_col = scales::alpha(neg_col, alpha_val),
x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x,
max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty)
max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty, skip = skip)
}
}
#lines for hidden - output
#uses 'all_out' argument to plot connection lines for all output nodes or a single node
if(is.logical(all_out))
mapply(
function(x) layer_lines(mod_in, x, layer1 = length(struct) - 1, layer2 = length(struct), out_layer = TRUE, nid = nid, rel_rsc = rel_rsc, all_in = all_in, pos_col = scales::alpha(pos_col, alpha_val), neg_col = scales::alpha(neg_col, alpha_val), x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x, max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty),
function(x) layer_lines(mod_in, x, layer1 = length(struct) - 1, layer2 = length(struct), out_layer = TRUE, nid = nid, rel_rsc = rel_rsc, all_in = all_in, pos_col = scales::alpha(pos_col, alpha_val), neg_col = scales::alpha(neg_col, alpha_val), x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x, max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty, skip = skip),
1:struct[length(struct)]
)
else{
node_in <- which(y_names == all_out)
layer_lines(mod_in, node_in, layer1 = length(struct) - 1, layer2 = length(struct), out_layer = TRUE, nid = nid, rel_rsc = rel_rsc, pos_col = pos_col, neg_col = neg_col, x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x, max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty)
layer_lines(mod_in, node_in, layer1 = length(struct) - 1, layer2 = length(struct), out_layer = TRUE, nid = nid, rel_rsc = rel_rsc, pos_col = pos_col, neg_col = neg_col, x_range = x_range, y_range = y_range, line_stag = line_stag, x_names = x_names, layer_x = layer_x, max_sp = max_sp, struct = struct, prune_col = prune_col, prune_lty = prune_lty, skip = skip)
}

#use functions to plot nodes
Expand All @@ -228,12 +267,15 @@ plotnet.default <- function(mod_in, x_names, y_names, struct = NULL, nid = TRUE,
#' @export
#'
#' @method plotnet nnet
plotnet.nnet <- function(mod_in, x_names = NULL, y_names = NULL, ...){
plotnet.nnet <- function(mod_in, x_names = NULL, y_names = NULL, skip = FALSE, ...){

# check for skip layers
chk <- grepl('skip-layer', capture.output(mod_in))
if(any(chk))
warning('Skip layer used, line scaling is proportional to weights in current plot')
if(any(chk)){
warning('Skip layer used, line scaling is proportional to all weights including skip layer.')
} else {
skip <- FALSE
}

#get variable names from mod_in object
#change to user input if supplied
Expand All @@ -252,7 +294,7 @@ plotnet.nnet <- function(mod_in, x_names = NULL, y_names = NULL, ...){
if(is.null(x_names)) x_names <- xlabs
if(is.null(y_names)) y_names <- ylabs

plotnet.default(mod_in, x_names = x_names, y_names = y_names, ...)
plotnet.default(mod_in, x_names = x_names, y_names = y_names, skip = skip, ...)

}

Expand All @@ -272,7 +314,7 @@ plotnet.numeric <- function(mod_in, struct, x_names = NULL, y_names = NULL, ...)
if(is.null(y_names))
y_names <- paste0(rep('Y', struct[length(struct)]), seq(1:struct[length(struct)]))

plotnet.default(mod_in, struct = struct, x_names = x_names, y_names = y_names, ...)
plotnet.default(mod_in, struct = struct, x_names = x_names, y_names = y_names, skip = FALSE, ...)

}

Expand All @@ -293,7 +335,7 @@ plotnet.mlp <- function(mod_in, x_names = NULL, y_names = NULL, prune_col = NULL
bias <- FALSE

plotnet.default(mod_in, x_names = x_names, y_names = y_names, bias = bias, prune_col = prune_col,
prune_lty = prune_lty, ...)
prune_lty = prune_lty, skip = FALSE, ...)

}

Expand All @@ -310,7 +352,7 @@ plotnet.nn <- function(mod_in, x_names = NULL, y_names = NULL, ...){
if(is.null(y_names))
y_names <- mod_in$model.list$respons

plotnet.default(mod_in, x_names = x_names, y_names = y_names, ...)
plotnet.default(mod_in, x_names = x_names, y_names = y_names, skip = FALSE, ...)

}

Expand All @@ -319,7 +361,7 @@ plotnet.nn <- function(mod_in, x_names = NULL, y_names = NULL, ...){
#' @export
#'
#' @method plotnet train
plotnet.train <- function(mod_in, x_names = NULL, y_names = NULL, ...){
plotnet.train <- function(mod_in, x_names = NULL, y_names = NULL, skip = FALSE, ...){

if(is.null(y_names))
y_names <- strsplit(as.character(mod_in$terms[[2]]), ' + ', fixed = TRUE)[[1]]
Expand All @@ -329,9 +371,12 @@ plotnet.train <- function(mod_in, x_names = NULL, y_names = NULL, ...){

# check for skip layers
chk <- grepl('skip-layer', capture.output(mod_in))
if(any(chk))
warning('Skip layer used, line scaling is proportional to weights in current plot')

plotnet.default(mod_in, x_names = x_names, y_names = y_names, ...)
if(any(chk)){
warning('Skip layer used, line scaling is proportional to all weights including skip layer.')
} else {
skip <- FALSE
}

plotnet.default(mod_in, x_names = x_names, y_names = y_names, skip = skip, ...)

}
40 changes: 33 additions & 7 deletions R/NeuralNetTools_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,18 @@ neuralskips <- function(mod_in, ...) UseMethod('neuralskips')
#' @rdname neuralskips
#'
#' @import scales
#'
#' @param rel_rsc numeric value indicating maximum to rescale weights for plotting in a neural interpretation diagram. Default is \code{NULL} for no rescaling. Scaling is relative to all weights, not just those in the primary network.
#'
#' @export
#'
#' @method neuralskips nnet
neuralskips.nnet <- function(mod_in, ...){
neuralskips.nnet <- function(mod_in, rel_rsc = NULL, ...){

wts <- mod_in$wts

if(!is.null(rel_rsc)) wts <- scales::rescale(abs(wts), c(1, rel_rsc))

# get skip layer weights if present, otherwise exit
chk <- grepl('skip-layer', capture.output(mod_in))
if(any(chk)){
Expand Down Expand Up @@ -481,28 +485,50 @@ bias_points <- function(bias_x, bias_y, layer_name, node_labs, x_range, y_range,
#' @param max_sp logical indicating if space is maximized in plot
#' @param prune_col chr string indicating color of pruned connections, otherwise not shown
#' @param prune_lty line type for pruned connections, passed to \code{\link[graphics]{segments}}
#' @param skip logical to plot connections for skip layer
#'
layer_lines <- function(mod_in, h_layer, layer1 = 1, layer2 = 2, out_layer = FALSE, nid, rel_rsc, all_in, pos_col, neg_col, x_range, y_range, line_stag, x_names, layer_x, struct, max_sp, prune_col = NULL, prune_lty = 'dashed'){
layer_lines <- function(mod_in, h_layer, layer1 = 1, layer2 = 2, out_layer = FALSE, nid, rel_rsc, all_in, pos_col, neg_col, x_range, y_range, line_stag, x_names, layer_x, struct, max_sp, prune_col = NULL, prune_lty = 'dashed', skip){

x0 <- rep(layer_x[layer1] * diff(x_range) + line_stag * diff(x_range), struct[layer1])
x1 <- rep(layer_x[layer2] * diff(x_range) - line_stag * diff(x_range), struct[layer1])

# see if skip layer is presnet in nnet
chk <- grepl('skip-layer', capture.output(mod_in))

if(out_layer == TRUE){

y0 <- get_ys(struct[layer1], max_sp, struct, y_range)
y1 <- rep(get_ys(struct[layer2], max_sp, struct, y_range)[h_layer], struct[layer1])
src_str <- paste('out', h_layer)

# get weights for numeric, otherwise use different method for neuralweights
if(inherits(mod_in, c('numeric', 'integer'))){

wts <- neuralweights(mod_in, struct = struct)$wts
wts_rs <- neuralweights(mod_in, rel_rsc, struct = struct)$wts
wts <- wts[grep(src_str, names(wts))][[1]][-1]
wts_rs <- wts_rs[grep(src_str, names(wts_rs))][[1]][-1]

} else {
wts <- neuralweights(mod_in)$wts
wts_rs <- neuralweights(mod_in, rel_rsc)$wts

# get skip weights if both TRUE
if(skip & any(chk)){

wts <- neuralskips(mod_in)
wts_rs <- neuralskips(mod_in, rel_rsc)

# otherwise get normal connects
} else {

wts <- neuralweights(mod_in)$wts
wts_rs <- neuralweights(mod_in, rel_rsc)$wts
wts <- wts[grep(src_str, names(wts))][[1]][-1]
wts_rs <- wts_rs[grep(src_str, names(wts_rs))][[1]][-1]

}

}
wts <- wts[grep(src_str, names(wts))][[1]][-1]
wts_rs <- wts_rs[grep(src_str, names(wts_rs))][[1]][-1]


cols <- rep(pos_col, struct[layer1])
cols[wts<0] <- neg_col

Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ The `olden` function is an alternative and more flexible approach to evaluate va
olden(mod, 'Y1')
```

The `lekprofile` function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evalutaed across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.
The `lekprofile` function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evaluated across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.

```{r, results = 'hide', fig.height = 3, warning = FALSE, fig.width = 9}
# sensitivity analysis
Expand Down
2 changes: 1 addition & 1 deletion README.html
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ <h3>Functions</h3>
<pre class="r"><code># importance of each variable
olden(mod, &#39;Y1&#39;)</code></pre>
<p><img src="README_files/figure-html/unnamed-chunk-8-1.png" /></p>
<p>The <code>lekprofile</code> function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evalutaed across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.</p>
<p>The <code>lekprofile</code> function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evaluated across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.</p>
<pre class="r"><code># sensitivity analysis
lekprofile(mod)</code></pre>
<p><img src="README_files/figure-html/unnamed-chunk-9-1.png" /></p>
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ olden(mod, 'Y1')

![](README_files/figure-html/unnamed-chunk-8-1.png)

The `lekprofile` function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evalutaed across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.
The `lekprofile` function performs a simple sensitivity analysis for neural networks. The Lek profile method is fairly generic and can be extended to any statistical model in R with a predict method. However, it is one of few methods to evaluate sensitivity in neural networks. The function begins by predicting the response variable across the range of values for a given explanatory variable. All other explanatory variables are held constant at set values (e.g., minimum, 20th percentile, maximum) that are indicated in the plot legend. The final result is a set of predictions for the response that are evaluated across the range of values for one explanatory variable, while holding all other explanatory variables constant. This is repeated for each explanatory variable to describe the fitted response values returned by the model.


```r
Expand Down
Binary file modified README_files/figure-html/unnamed-chunk-6-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified README_files/figure-html/unnamed-chunk-7-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified README_files/figure-html/unnamed-chunk-8-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified README_files/figure-html/unnamed-chunk-9-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit ef3b8d5

Please sign in to comment.