From af83daa54010fe015b019e3f12cdd01637c6481f Mon Sep 17 00:00:00 2001
From: Egil Martinsson
Date: Mon, 13 Jan 2025 14:12:58 -0800
Subject: [PATCH] Fix device placement linspace (#20)

---
 taildropout.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/taildropout.py b/taildropout.py
index 3072b81..05326d2 100644
--- a/taildropout.py
+++ b/taildropout.py
@@ -135,10 +135,10 @@ def forward(self, input: Tensor) -> Tensor:
         mode = 'first_k'
 
         if mode == 'random':
-            type_out = input.type()
+            type_out = input.dtype
+            device = input.device
 
-            # No cuda torch.linspace for old versions of pytorch.
-            linspace = torch.arange(1, n_features + 1, 1).type(type_out)
+            linspace = torch.arange(1, n_features + 1, 1, device=device,dtype=type_out)
             # resized [1,n_features] if input 2d, [1,1,..,n_features] if nd
             newshape = replace_w_ones_except(input.shape, self.dropout_dim)
             linspace.resize_(newshape)
@@ -147,7 +147,7 @@ def forward(self, input: Tensor) -> Tensor:
 
             # make [n_batch,1] noise if input 2d
             newshape = replace_w_ones_except(input.shape, self.batch_dim)
-            uniform = input.new(*newshape).uniform_()
+            uniform = torch.rand(newshape, device=device, dtype=type_out)
             mask = prob < uniform # 43% of cpu cumtime
             mask = mask.type(type_out) # 30% of cpu cumtime
             return input * mask # 23% of cpu cumtime # Note works due to broadcasting