Skip to content

Commit

Permalink
Update alpha broadcasting and put beta on device (#18927)
Browse files Browse the repository at this point in the history
  • Loading branch information
KhawajaAbaid authored Dec 12, 2023
1 parent 7d431ce commit a41eb5e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras/backend/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def shuffle(x, axis=0, seed=None):
def gamma(shape, alpha, dtype=None, seed=None):
dtype = dtype or floatx()
dtype = to_torch_dtype(dtype)
alpha = torch.ones(shape) * torch.tensor(alpha)
beta = torch.ones(shape)
alpha = torch.broadcast_to(convert_to_tensor(alpha), shape)
beta = torch.ones(shape, device=get_device())
prev_rng_state = torch.random.get_rng_state()
first_seed, second_seed = draw_seed(seed)
torch.manual_seed(first_seed + second_seed)
Expand Down

0 comments on commit a41eb5e

Please sign in to comment.