Skip to content

Commit

Permalink
Fix device placement linspace (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
ragulpr authored Jan 13, 2025
1 parent 214d74a commit af83daa
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions taildropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ def forward(self, input: Tensor) -> Tensor:
mode = 'first_k'

if mode == 'random':
type_out = input.type()
type_out = input.dtype
device = input.device

# Old versions of PyTorch have no CUDA torch.linspace, so torch.arange is used instead.
linspace = torch.arange(1, n_features + 1, 1).type(type_out)
linspace = torch.arange(1, n_features + 1, 1, device=device,dtype=type_out)
# Resized to [1, n_features] if the input is 2-D, or [1, 1, ..., n_features] if it is N-D.
newshape = replace_w_ones_except(input.shape, self.dropout_dim)
linspace.resize_(newshape)
Expand All @@ -147,7 +147,7 @@ def forward(self, input: Tensor) -> Tensor:

# Make noise of shape [n_batch, 1] if the input is 2-D.
newshape = replace_w_ones_except(input.shape, self.batch_dim)
uniform = input.new(*newshape).uniform_()
uniform = torch.rand(newshape, device=device, dtype=type_out)
mask = prob < uniform # 43% of cpu cumtime
mask = mask.type(type_out) # 30% of cpu cumtime
return input * mask # 23% of cpu cumtime # Note works due to broadcasting
Expand Down

0 comments on commit af83daa

Please sign in to comment.