Skip to content

Commit

Permalink
Fix replace_w_ones_except (#22) [perf]
Browse files Browse the repository at this point in the history
ragulpr authored Jan 13, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent af83daa commit 3f340b4
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions taildropout.py
Original file line number Diff line number Diff line change
@@ -27,14 +27,10 @@ def get_scale_param(p, tol=1e-9) -> float:
return a

def replace_w_ones_except(shape, dims):
# List like `shape` with ones everywhere except at `dims`.
newshape = [1 for _ in range(len(shape))]
try:
newshape[dims] = shape[dims]
except:
# dims iterable
for j in dims:
newshape[j] = shape[j]
newshape = [1]*len(shape)
dims = [dims] if isinstance(dims, int) else dims
for j in dims:
newshape[j] = shape[j]
return newshape


0 comments on commit 3f340b4

Please sign in to comment.