diff --git a/R/binary_classification.R b/R/binary_classification.R index b8a2aa7..6b52fff 100644 --- a/R/binary_classification.R +++ b/R/binary_classification.R @@ -49,17 +49,17 @@ auc <- function(actual, predicted) { #' @param predicted A numeric vector of predicted values, where the values correspond #' to the probabilities that each observation in \code{actual} #' belongs to the positive class +#' @param eps Log loss is undefined for p=0 or p=1, so probabilities are clipped to +#' \code{pmax(eps, pmin(1 - eps, p))}. #' @export #' @seealso \code{\link{logLoss}} #' @examples #' actual <- c(1, 1, 1, 0, 0, 0) #' predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2) #' ll(actual, predicted) -ll <- function(actual, predicted) { - score <- -(actual * log(predicted) + (1 - actual) * log(1 - predicted)) - score[actual == predicted] <- 0 - score[is.nan(score)] <- Inf - return(score) +ll <- function(actual, predicted, eps = 1e-12) { + predicted <- pmax(eps, pmin(1 - eps, predicted)) + return(-ifelse(actual, log(predicted), log(1 - predicted))) } #' Mean Log Loss diff --git a/man/ll.Rd b/man/ll.Rd index 0712a33..a9ade50 100644 --- a/man/ll.Rd +++ b/man/ll.Rd @@ -4,7 +4,7 @@ \alias{ll} \title{Log Loss} \usage{ -ll(actual, predicted) +ll(actual, predicted, eps = 1e-15) } \arguments{ \item{actual}{The ground truth binary numeric vector containing 1 for the positive @@ -13,6 +13,9 @@ class and 0 for the negative class.} \item{predicted}{A numeric vector of predicted values, where the values correspond to the probabilities that each observation in \code{actual} belongs to the positive class} + +\item{eps}{Log loss is undefined for p=0 or p=1, so probabilities are clipped to +\code{pmax(eps, pmin(1 - eps, p))}.} } \description{ \code{ll} computes the elementwise log loss between two numeric vectors. diff --git a/tests/testthat/test-binary_classification.R b/tests/testthat/test-binary_classification.R index b94d98f..109c189 100644 --- a/tests/testthat/test-binary_classification.R +++ b/tests/testthat/test-binary_classification.R @@ -10,14 +10,14 @@ test_that('area under ROC curve is calculated correctly', { test_that('log loss is calculated correctly', { expect_equal(ll(1,1), 0) - expect_equal(ll(1,0), Inf) - expect_equal(ll(0,1), Inf) + expect_equal(ll(1,0), -log(1e-12)) + expect_equal(ll(0,1), -log(1 - (1 - 1e-12))) expect_equal(ll(1,0.5), -log(0.5)) }) -test_that('mean los loss is calculated correctly', { +test_that('mean log loss is calculated correctly', { expect_equal(logLoss(c(1,1,0,0),c(1,1,0,0)), 0) - expect_equal(logLoss(c(1,1,0,0),c(1,1,1,0)), Inf) + expect_true(is.finite(logLoss(c(1,1,0,0),c(1,1,1,0)))) expect_equal(logLoss(c(1,1,1,0,0,0),c(.5,.1,.01,.9,.75,.001)), 1.881797068998267) })