Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rolling windows and cv #183

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Add rolling windows and cv #183

wants to merge 3 commits into from

Conversation

jeffzi
Copy link

@jeffzi jeffzi commented Mar 25, 2020

Hi.

After reading #180 and #177, I thought I could give a shot at implementing cross-validation.

I defined common cv procedures ExpandingWindow(), SlidingWindow(), Holdout() as rolling_window objects + 2 S3 methods:

  • roll() is based on the slider package and apply an arbitrary function (identity by default) on sub-windows iteratively. It also cuts off the specified horizon. The output is a tibble with the tsibble keys + an extra list-column containing untransformed results.
  • cv() fits models on folds and return forecasts as a fable. Intermediary folds are not kept because in most cases we are only interested in forecasts to evaluate accuracy. It's also faster and more memory-efficient.

roll() can be used to compute features() on folds if we want to do timeseries classifications for example.

The window parameters .init, .size, .step, and the cutoff h can be specified in terms of calendar periods, or in terms of the number of observations if .period is NULL. The implementation relies on the warp package.

I also implemented an optional parallelism. I confirmed with microbenchmarks that it is more efficient to parallelize on the folds rather than models.

An example of usage:

library(tsibbledata)
library(fable)

ExpandingWindow(.init = 10) %>%
  roll(aus_retail, h = 5)#> # A tibble: 152 x 3
#>    State                 Industry                                      .fold    
#>    <chr>                 <chr>                                         <list>   
#>  1 Australian Capital T… Cafes, restaurants and catering services      <list [4…
#>  2 Australian Capital T… Cafes, restaurants and takeaway food services <list [4…
#>  3 Australian Capital T… Clothing retailing                            <list [4…
#>  4 Australian Capital T… Clothing, footwear and personal accessory re… <list [4…
#>  5 Australian Capital T… Department stores                             <list [4…
#>  6 Australian Capital T… Electrical and electronic goods retailing     <list [4…
#>  7 Australian Capital T… Food retailing                                <list [4…
#>  8 Australian Capital T… Footwear and other personal accessory retail… <list [4…
#>  9 Australian Capital T… Furniture, floor coverings, houseware and te… <list [4…
#> 10 Australian Capital T… Hardware, building and garden supplies retai… <list [4…
#> # … with 142 more rows
ts <- aus_retail %>%
 filter(State %in% c("Queensland", "Victoria"), Industry == "Food retailing")

models <- list(
  snaive = SNAIVE(Turnover),
  ets = TSLM(log(Turnover) ~ trend() + season())
)

suppressWarnings({
ExpandingWindow(.init = 25, .step = 1, .period = "year") %>%
  cv(ts, h = 3, !!!models)
})#> # A fable: 4,896 x 7 [1M]
#> # Key:     .fold, State, Industry, .model [136]
#>    .fold State      Industry       .model    Month      Turnover .mean
#>    <int> <chr>      <chr>          <chr>     <mth>        <dist> <dbl>
#>  1     1 Queensland Food retailing snaive 2007 Jan N(1143, 2777) 1143.
#>  2     1 Queensland Food retailing snaive 2007 Feb N(1057, 2777) 1057.
#>  3     1 Queensland Food retailing snaive 2007 Mar N(1176, 2777) 1176.
#>  4     1 Queensland Food retailing snaive 2007 Apr N(1156, 2777) 1156.
#>  5     1 Queensland Food retailing snaive 2007 May N(1163, 2777) 1163.
#>  6     1 Queensland Food retailing snaive 2007 Jun N(1158, 2777) 1158.
#>  7     1 Queensland Food retailing snaive 2007 Jul N(1220, 2777) 1220.
#>  8     1 Queensland Food retailing snaive 2007 Aug N(1251, 2777) 1251.
#>  9     1 Queensland Food retailing snaive 2007 Sep N(1224, 2777) 1224.
#> 10     1 Queensland Food retailing snaive 2007 Oct N(1261, 2777) 1261.
#> # … with 4,886 more rows

Dev version of tibble breaks forecast() and therefore cv(). It's caused by [[<-.tbl_df:

library(tibble)
df <- tibble(x = 1:3, y = 3:1)
df[["z"]] <- c("a", "b", "c")
df
#> # A tibble: 3 x 3
#>       x     y ...3 
#>   <int> <int> <chr>
#> 1     1     3 a    
#> 2     2     2 b    
#> 3     3     1 c

# works
add_column(df, z = c("a", "b", "c"))
#> # A tibble: 3 x 4
#>       x     y ...3  z    
#>   <int> <int> <chr> <chr>
#> 1     1     3 a     a    
#> 2     2     2 b     b    
#> 3     3     1 c     c

Created on 2020-03-25 by the reprex package (v0.3.0)

I did not write tests but I can work on them if you think my implementation is useful.

@edgBR
Copy link

edgBR commented Jun 5, 2020

Is there any progress in the conflict solving?

@jeffzi
Copy link
Author

jeffzi commented Jun 6, 2020

There were a lot of changes to vctrs and dplyr, and consequently to fabletools. I'm waiting for fabletools to be stable to push fixes.

@mitchelloharawild are you interested in my proposal of cv api?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants