From 060468967ee714cbad73eefc821727c044bd6858 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Fri, 10 Dec 2021 13:10:09 -0600 Subject: [PATCH] add extract_plot_data function and slight improvement to autoplot cm mosiac --- NAMESPACE | 2 + R/conf_mat.R | 162 ++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 116 insertions(+), 48 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index c049eebd..e29ad883 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,6 +17,7 @@ S3method(conf_mat,table) S3method(detection_prevalence,data.frame) S3method(detection_prevalence,matrix) S3method(detection_prevalence,table) +S3method(extract_plot_data,conf_mat) S3method(f_meas,data.frame) S3method(f_meas,matrix) S3method(f_meas,table) @@ -97,6 +98,7 @@ export(conf_mat) export(detection_prevalence) export(detection_prevalence_vec) export(dots_to_estimate) +export(extract_plot_data) export(f_meas) export(f_meas_vec) export(finalize_estimator) diff --git a/R/conf_mat.R b/R/conf_mat.R index 0932d3b0..7d6f8f1f 100644 --- a/R/conf_mat.R +++ b/R/conf_mat.R @@ -284,6 +284,9 @@ summary.conf_mat <- function(object, stats } + +conf_mat_plot_types <- c("mosaic", "heatmap") + # Dynamically exported autoplot.conf_mat <- function(object, type = "mosaic", ...) { type <- rlang::arg_match(type, conf_mat_plot_types) @@ -294,46 +297,29 @@ autoplot.conf_mat <- function(object, type = "mosaic", ...) { ) } -conf_mat_plot_types <- c("mosaic", "heatmap") -cm_heat <- function(x) { - `%+%` <- ggplot2::`%+%` +#' @export +extract_plot_data <- function(x, ...){ + UseMethod("extract_plot_data") +} - df <- as.data.frame.table(x$table) - # Force specific column names for referencing in ggplot2 code - names(df) <- c("Prediction", "Truth", "Freq") +#' @export +#' @rdname extract_plot_data +#' +#' @param object a yardstick conf_mat +#' +#' @param type type of conf_mat plot +#' +#' @return a list of plot data elements +extract_plot_data.conf_mat <- function(object, type = "mosaic", ...) { - # Have prediction levels going from high to low so they plot in an - # order that matches the LHS of the confusion matrix - lvls <- levels(df$Prediction) - df$Prediction <- factor(df$Prediction, levels = rev(lvls)) + type <- rlang::arg_match(type, conf_mat_plot_types) - axis_labels <- get_axis_labels(x) + switch(type, + mosaic = cm_mosaic_data(object), + heatmap = cm_heat_data(object) + ) - df %>% - ggplot2::ggplot( - ggplot2::aes( - x = Truth, - y = Prediction, - fill = Freq - ) - ) %+% - ggplot2::geom_tile() %+% - ggplot2::scale_fill_gradient( - low = "grey90", - high = "grey40" - ) %+% - ggplot2::theme( - panel.background = ggplot2::element_blank(), - legend.position = "none" - ) %+% - ggplot2::geom_text( - mapping = ggplot2::aes(label = Freq) - ) %+% - ggplot2::labs( - x = axis_labels$x, - y = axis_labels$y - ) } space_fun <- function(x, adjustment, rescale = FALSE) { @@ -360,8 +346,10 @@ space_y_fun <- function(data, id, x_data) { out } -cm_mosaic <- function(x) { - `%+%` <- ggplot2::`%+%` + +cm_mosaic_data <- function(x){ + + cols <- dim(x$table)[[1]] cm_zero <- (as.numeric(x$table == 0) / 2) + x$table @@ -372,34 +360,112 @@ cm_mosaic <- function(x) { ~ space_y_fun(cm_zero, .x, x_data) ) - full_data <- dplyr::bind_rows(full_data_list) + i <- seq(1, cols ^ 2, cols) + seq(0, cols - 1 , 1) + + pred_type <- rep("incorrect", cols * cols) + + pred_type[i] <- "correct" + + full_data <- dplyr::bind_rows(full_data_list) %>% + dplyr::bind_cols("pred_type" = pred_type, .) y1_data <- full_data_list[[1]] tick_labels <- colnames(cm_zero) axis_labels <- get_axis_labels(x) - ggplot2::ggplot(full_data) %+% + final_data_list <- list( + data = full_data, + x_breaks = (x_data$xmin + x_data$xmax) / 2, + y_breaks = (y1_data$ymin + y1_data$ymax) / 2, + tick_labels = tick_labels, + axis_labels = axis_labels + ) + +} + +cm_heat_data <- function(x){ + df <- as.data.frame.table(x$table) + # Force specific column names for referencing in ggplot2 code + names(df) <- c("Prediction", "Truth", "Freq") + + # Have prediction levels going from high to low so they plot in an + # order that matches the LHS of the confusion matrix + lvls <- levels(df$Prediction) + df$Prediction <- factor(df$Prediction, levels = rev(lvls)) + + axis_labels <- get_axis_labels(x) + + full_data_list <- list( + data = df, + axis_labels = axis_labels + ) +} + + +cm_heat <- function(x) { + `%+%` <- ggplot2::`%+%` + + full_data_list <- cm_heat_data(x) + + full_data_list$data %>% + ggplot2::ggplot( + ggplot2::aes( + x = Truth, + y = Prediction, + fill = Freq + ) + ) %+% + ggplot2::geom_tile() %+% + ggplot2::scale_fill_gradient( + low = "grey90", + high = "grey40" + ) %+% + ggplot2::theme( + panel.background = ggplot2::element_blank(), + legend.position = "none" + ) %+% + ggplot2::geom_text( + mapping = ggplot2::aes(label = Freq) + ) %+% + ggplot2::labs( + x = full_data_list$axis_labels$x, + y = full_data_list$axis_labels$y + ) +} + + +cm_mosaic <- function(x) { + `%+%` <- ggplot2::`%+%` + + full_data_list <- cm_mosaic_data(x) + + ggplot2::ggplot(full_data_list$data) %+% ggplot2::geom_rect( ggplot2::aes( xmin = xmin, xmax = xmax, ymin = ymin, - ymax = ymax - ) - ) %+% + ymax = ymax, + fill = pred_type + ), + alpha = 0.9, + show.legend = F + )%+% ggplot2::scale_x_continuous( - breaks = (x_data$xmin + x_data$xmax) / 2, - labels = tick_labels + breaks = full_data_list$x_breaks, + labels = full_data_list$tick_labels ) %+% ggplot2::scale_y_continuous( - breaks = (y1_data$ymin + y1_data$ymax) / 2, - labels = tick_labels + breaks = full_data_list$y_breaks, + labels = full_data_list$tick_labels ) %+% ggplot2::labs( - y = axis_labels$y, - x = axis_labels$x + y = full_data_list$axis_labels$y, + x = full_data_list$axis_labels$x ) %+% + ggplot2::scale_fill_manual(breaks = c("correct", "incorrect"), + values = c("#4f58bd", "grey70")) %+% ggplot2::theme(panel.background = ggplot2::element_blank()) }