Skip to content

Commit

Permalink
test Euclidean weights and weight standardisation for control_for_euc()
Browse files Browse the repository at this point in the history
  • Loading branch information
JackEdTaylor committed Apr 28, 2021
1 parent 4b9b9ee commit 82cd7ef
Showing 1 changed file with 96 additions and 2 deletions.
98 changes: 96 additions & 2 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ testthat::test_that("no id_col given", {
eg_df %>%
split_by(a, -5:0 ~ 0:5) %>%
control_for(b, -2.5:2.5) %>%
generate(17) %>%
generate(17, silent=TRUE) %>%
nrow()
),
17
Expand All @@ -198,7 +198,7 @@ testthat::test_that("no id_col given", {
eg_df %>%
split_by(a, -5:0 ~ 0:5) %>%
control_for(b, -2.5:2.5) %>%
generate(17),
generate(17, silent=TRUE),
"No id_col detected; will use row numbers."
)
# since eg_df's id column is just the row numbers anyway, these should be identical
Expand Down Expand Up @@ -777,4 +777,98 @@ testthat::test_that("control_for_euc", {
dplyr::filter(gen_euc_dist == man_euc_dist) %>%
nrow()
}, 20)
# test weighted Euclidean distance is calculated as expected
testthat::expect_equal({
weights <- runif(2, 0.1, 100)
weights_std <- weights / mean(weights)

wide_res <- eg_df %>%
set_options(id_col = "id") %>%
split_by(a, -5:0 ~ 0:5) %>%
control_for_euc(
c(b, e),
0:1.5,
name = "gen_euc_dist",
weights = weights
) %>%
generate(20, silent=TRUE)

manual_euc_dist <- wide_res %>%
dplyr::left_join(
eg_df %>%
dplyr::select(id, b, e) %>%
dplyr::mutate(b = weights_std[1]*scale(b), e = weights_std[2]*scale(e)) %>%
dplyr::rename(A1_b = b, A1_e = e),
by = c("A1" = "id")
) %>%
dplyr::left_join(
eg_df %>%
dplyr::select(id, b, e) %>%
dplyr::mutate(b = weights_std[1]*scale(b), e = weights_std[2]*scale(e)) %>%
dplyr::rename(A2_b = b, A2_e = e),
by = c("A2" = "id")
) %>%
dplyr::mutate(
dist_b = A1_b - A2_b,
dist_e = A1_e - A2_e,
man_euc_dist = sqrt(dist_b**2 + dist_e**2)
)

wide_res %>%
long_format() %>%
dplyr::filter(condition != match_null) %>%
dplyr::left_join(
dplyr::select(manual_euc_dist, item_nr, man_euc_dist),
by = "item_nr"
) %>%
dplyr::filter(gen_euc_dist == man_euc_dist) %>%
nrow()
}, 20)
# test that weight standardisation can be disabled
testthat::expect_equal({
weights <- runif(2, 0.1, 100)

wide_res <- eg_df %>%
set_options(id_col = "id") %>%
split_by(a, -5:0 ~ 0:5) %>%
control_for_euc(
c(b, e),
0:10,
name = "gen_euc_dist",
weights = weights,
standardise_weights = FALSE
) %>%
generate(20, silent=TRUE)

manual_euc_dist <- wide_res %>%
dplyr::left_join(
eg_df %>%
dplyr::select(id, b, e) %>%
dplyr::mutate(b = weights[1]*scale(b), e = weights[2]*scale(e)) %>%
dplyr::rename(A1_b = b, A1_e = e),
by = c("A1" = "id")
) %>%
dplyr::left_join(
eg_df %>%
dplyr::select(id, b, e) %>%
dplyr::mutate(b = weights[1]*scale(b), e = weights[2]*scale(e)) %>%
dplyr::rename(A2_b = b, A2_e = e),
by = c("A2" = "id")
) %>%
dplyr::mutate(
dist_b = A1_b - A2_b,
dist_e = A1_e - A2_e,
man_euc_dist = sqrt(dist_b**2 + dist_e**2)
)

wide_res %>%
long_format() %>%
dplyr::filter(condition != match_null) %>%
dplyr::left_join(
dplyr::select(manual_euc_dist, item_nr, man_euc_dist),
by = "item_nr"
) %>%
dplyr::filter(gen_euc_dist == man_euc_dist) %>%
nrow()
}, 20)
})

0 comments on commit 82cd7ef

Please sign in to comment.