diff --git a/R/sgd_lm.R b/R/sgd_lm.R index db295c9..b2c77a3 100644 --- a/R/sgd_lm.R +++ b/R/sgd_lm.R @@ -51,7 +51,7 @@ sgd_lm = function(formula, mf <- eval(mf, parent.frame()) mt <- attr(mf, "terms") y <- model.response(mf, "numeric") - x <- model.matrix(mt, mf)[,-1] + x <- as.matrix(model.matrix(mt, mf, drop=F)[,-1]) if (studentize){ # Compute column means and standard errors and save them for later reconversion diff --git a/R/sgd_qr.R b/R/sgd_qr.R index ac517e4..35bc8fa 100644 --- a/R/sgd_qr.R +++ b/R/sgd_qr.R @@ -55,7 +55,7 @@ sgd_qr = function(formula, mf <- eval(mf, parent.frame()) mt <- attr(mf, "terms") y <- model.response(mf, "numeric") - x <- model.matrix(mt, mf)[,-1] + x <- as.matrix(model.matrix(mt, mf, drop=F)[,-1]) if (studentize){ # Compute column means and standard errors and save them for later reconversion diff --git a/R/sgdi_qr.R b/R/sgdi_qr.R index f1e201c..3bcfcf5 100644 --- a/R/sgdi_qr.R +++ b/R/sgdi_qr.R @@ -74,7 +74,7 @@ sgdi_qr = function(formula, mf <- eval(mf, parent.frame()) mt <- attr(mf, "terms") y <- model.response(mf, "numeric") - x <- model.matrix(mt, mf)[,-1] + x <- as.matrix(model.matrix(mt, mf, drop=F)[,-1]) if (inference == "rss"){ if (0 %in% rss_idx ){ stop("rss_idx includes 0 (the intercept term), where it should be bigger than 1.")