Support compilation (Closes #24) (#25)
ragulpr authored Jan 19, 2025
1 parent 76ffe4e commit 173f350
Showing 6 changed files with 136 additions and 99 deletions.
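Editor's note: for orientation, a minimal usage sketch of the compile path this commit adds. It is hedged: it assumes only the TailDropout API exercised by the new test_compilation test in test.py below, and is not code from the commit itself.

import torch
from taildropout import TailDropout

dropout = TailDropout()
dropout.compile()            # compile the module in place, as test_compilation does below

x = torch.randn(32, 100)
y_random = dropout(x)        # train mode: random tail mask

dropout.set_k(10)            # first_k mode: keep the first 10 features, zero the rest
y_first_10 = dropout(x)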
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
@@ -23,11 +23,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r requirements.txt
pip install -r requirements.txt -r requirements-dev.txt
- name: Run tests
run: pytest -v test.py
run: pytest -vvrP test.py

- name: Run performance tests
if: contains(github.event.head_commit.message, '[perf]')
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -0,0 +1 @@
pytest
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1 +1,2 @@
torch==2.2.2
# torch==2.2.2 last with macos supported
torch>2
31 changes: 17 additions & 14 deletions taildropout.py
@@ -134,34 +134,37 @@ def forward(self, input: Tensor) -> Tensor:
type_out = input.dtype
device = input.device

linspace = torch.arange(1, n_features + 1, 1, device=device,dtype=type_out)
# resized [1,n_features] if input 2d, [1,1,..,n_features] if nd
newshape = replace_w_ones_except(input.shape, self.dropout_dim)
linspace.resize_(newshape)
linspace = torch.arange(1, n_features + 1, 1, device=device, dtype=type_out)
prob_shape = replace_w_ones_except(input.shape, self.dropout_dim) #[1,1,..,n_features]
linspace.resize_(prob_shape)
# self.scale*n_features faster than linspace/n_features
prob = self.cdf(linspace, self.scale * n_features)

# make [n_batch,1] noise if input 2d
newshape = replace_w_ones_except(input.shape, self.batch_dim)
uniform = torch.rand(newshape, device=device, dtype=type_out)
mask = prob < uniform # 43% of cpu cumtime
mask_shape = replace_w_ones_except(input.shape, self.batch_dim)
uniform = torch.rand(mask_shape, device=device, dtype=type_out) # [n_batch,1,1] if input 3d
mask = prob < uniform # 43% of cpu cumtime [n_batch,1,n_features]
mask = mask.type(type_out) # 30% of cpu cumtime
return input * mask # 23% of cpu cumtime # Note works due to broadcasting
# Similar performance / identical with torch.compile:
# inv_mask = prob >= uniform # ~mask
# Similar performance / identical with torch.compile but doesn't propagate NaN:
# inv_mask = prob >= uniform
# return input.masked_fill(inv_mask, 0)

if mode == 'straight-through':
return input

if mode == 'first_k':
# Do mask[:, :, (...), :, k:] = 0 in choice of dropout_dim
mask_shape = replace_w_ones_except(input.shape, self.dropout_dim)
mask = input.new_ones(*mask_shape)
# Do mask[:, :, (...), :, k:] = 0 depending on dropout_dim
slices = [slice(None)] * input.ndim # Start with full slices for all dimensions
slices[self.dropout_dim] = slice(self.k, None) # Modify only the dropout_dim
mask[tuple(slices)] = 0
slices = [slice(None)] * input.ndim

# Avoid recompilation for every k
with torch._dynamo.config.patch({"disable": True}):
slices[self.dropout_dim] = slice(self.k, None)
mask[tuple(slices)] = 0

return input * mask

if mode == 'zero':
return input * 0

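Editor's note: the shape comments added above ([1,1,..,n_features] for prob, [n_batch,1,1] for the noise) describe a broadcasted mask. Below is a small self-contained sketch of the same mechanics for the first_k branch, using hypothetical toy tensors rather than code from this commit; in taildropout.py the mask shape comes from replace_w_ones_except and dropout_dim need not be the last axis, but the sketch fixes it to -1 for brevity.

import torch

x = torch.randn(4, 6)                       # [n_batch, n_features]
k, dropout_dim = 2, -1

mask = x.new_ones(1, x.shape[dropout_dim])  # [1, n_features]; broadcasts over the batch dim
slices = [slice(None)] * x.ndim
slices[dropout_dim] = slice(k, None)        # select features k, k+1, ... along dropout_dim
mask[tuple(slices)] = 0

y = x * mask                                # first k features kept, tail zeroed
assert torch.equal(y[:, :k], x[:, :k])
assert torch.equal(y[:, k:], torch.zeros_like(y[:, k:]))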
128 changes: 70 additions & 58 deletions test.py
@@ -2,73 +2,74 @@
import torch
from taildropout import TailDropout, get_scale_param

def test_expected_mask():
def test_routes(dropout, input_shape, requires_grad=False):
x = torch.ones(input_shape, requires_grad=requires_grad)
if torch.cuda.is_available():
x = x.cuda()

# Assert shapes
dropout.train()
assert dropout(x).shape == x.shape
dropout.set_k(2)
assert dropout(x).shape == x.shape
dropout.eval()
assert dropout(x).shape == x.shape
dropout.set_k(2)
assert dropout(x).shape == x.shape


# Test values in train, eval, prune mode
dropout.eval()
y_all_eval = dropout(x)
dropout.set_k(2)
y_k_eval = dropout(x)
dropout.train()
dropout.set_k(f)
y_all_train = dropout(x)
def _check_routes(dropout: TailDropout, input_shape, requires_grad=False):
x = torch.ones(input_shape, requires_grad=requires_grad)
f = input_shape[dropout.dropout_dim]
if torch.cuda.is_available():
x = x.cuda()

# Assert shapes
dropout.train()
assert dropout(x).shape == x.shape
dropout.set_k(2)
assert dropout(x).shape == x.shape
dropout.eval()
assert dropout(x).shape == x.shape
dropout.set_k(2)
assert dropout(x).shape == x.shape


# Test values in train, eval, prune mode
dropout.eval()
y_all_eval = dropout(x)
dropout.set_k(2)
y_k_eval = dropout(x)
dropout.train()
dropout.set_k(f)
y_all_train = dropout(x)
dropout.set_k(2)

y_k_train = dropout(x)
torch.testing.assert_close(y_all_eval, y_all_train)
torch.testing.assert_close(y_k_eval, y_k_train)

# all columns exactly one
assert y_all_eval.mean().allclose(torch.tensor(1.))
assert y_k_eval.mean().allclose(torch.tensor(2/f))

if dropout.dropout_dim==-1 or dropout.dropout_dim == len(input_shape):
# Assumes dropout dimension is the last dimension.
x = torch.randn(input_shape)
# Assert values
dropout.set_k(2)
y = dropout(x)
torch.testing.assert_close(y[..., 2:], torch.zeros_like(y[..., 2:]))
torch.testing.assert_close(y[..., :2], x[..., :2])

y_k_train = dropout(x)
torch.testing.assert_close(y_all_eval, y_all_train)
torch.testing.assert_close(y_k_eval, y_k_train)

# all columns exactly one
assert y_all_eval.mean().allclose(torch.tensor(1.))
assert y_k_eval.mean().allclose(torch.tensor(2/f))

if dropout.dropout_dim==-1 or dropout.dropout_dim == len(input_shape):
# Assumes dropout dimension is the last dimension.
x = torch.randn(input_shape)
# Assert values
dropout.set_k(2)
y = dropout(x)
torch.testing.assert_close(y[..., 2:], torch.zeros_like(y[..., 2:]))
torch.testing.assert_close(y[..., :2], x[..., :2])

def test_expected_mask():
n = 5
f = 7

test_routes(dropout=TailDropout(), input_shape=(n, f)) # noqa
test_routes(dropout=TailDropout(), input_shape=(n, 1, f)) # noqa
test_routes(dropout=TailDropout(), input_shape=(n, n, f)) # noqa
_check_routes(dropout=TailDropout(), input_shape=(n, f)) # noqa
_check_routes(dropout=TailDropout(), input_shape=(n, 1, f)) # noqa
_check_routes(dropout=TailDropout(), input_shape=(n, n, f)) # noqa

test_routes(dropout=TailDropout(dropout_dim=1), input_shape=(n, f))
test_routes(dropout=TailDropout(dropout_dim=2), input_shape=(n, 1, f)) # noqa
test_routes(dropout=TailDropout(dropout_dim=2), input_shape=(n, n, f)) # noqa
_check_routes(dropout=TailDropout(dropout_dim=1), input_shape=(n, f))
_check_routes(dropout=TailDropout(dropout_dim=2), input_shape=(n, 1, f)) # noqa
_check_routes(dropout=TailDropout(dropout_dim=2), input_shape=(n, n, f)) # noqa

test_routes(dropout=TailDropout(batch_dim=0, dropout_dim=-1), input_shape=(n, 1, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=0, dropout_dim=-1), input_shape=(n, 1, f)) # noqa

test_routes(dropout=TailDropout(batch_dim=1), input_shape=(1, n, 1, f)) # noqa
test_routes(dropout=TailDropout(batch_dim=1), input_shape=(1, n, f)) # noqa
test_routes(dropout=TailDropout(batch_dim=1), input_shape=(n, 1, f)) # noqa
test_routes(dropout=TailDropout(batch_dim=1), input_shape=(n, n, f)) # noqa
test_routes(dropout=TailDropout(batch_dim=1, dropout_dim=-2), input_shape=(1, n, 1, f, 1)) # noqa
test_routes(dropout=TailDropout(batch_dim=1, dropout_dim=3), input_shape=(1, n, 1, f, 1)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1), input_shape=(1, n, 1, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1), input_shape=(1, n, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1), input_shape=(n, 1, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1), input_shape=(n, n, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1, dropout_dim=-2), input_shape=(1, n, 1, f, 1)) # noqa
_check_routes(dropout=TailDropout(batch_dim=1, dropout_dim=3), input_shape=(1, n, 1, f, 1)) # noqa


test_routes(dropout=TailDropout(batch_dim=[0, 1]), input_shape=(n, n, f)) # noqa
test_routes(dropout=TailDropout(batch_dim=[1, 0]), input_shape=(n, n, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=[0, 1]), input_shape=(n, n, f)) # noqa
_check_routes(dropout=TailDropout(batch_dim=[1, 0]), input_shape=(n, n, f)) # noqa


# Test 0/1 probability
@@ -79,7 +80,7 @@ def test_routes(dropout, input_shape, requires_grad=False):
torch.testing.assert_close(TailDropout(1)(x),torch.zeros_like(x))

# Variable with grad
test_routes(dropout=TailDropout(), input_shape=(n, f), requires_grad=True)
_check_routes(dropout=TailDropout(), input_shape=(n, f), requires_grad=True)


def test_multiple_batch_dim():
@@ -157,6 +158,17 @@ def test_first_k():
actual = dropout(x)
assert actual.equal(expected)


def test_compilation():
dropout = TailDropout()
dropout.compile()
_check_routes(dropout=dropout, input_shape=(10, 5, 3)) # noqa

x = torch.randn((1, 100))
for k in range(1,101):
dropout.set_k(k)
dropout(x)

print(f'torch version {torch.__version__}')
print(f'torch.cuda.is_available():{torch.cuda.is_available()}')

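Editor's note: the k-sweep in test_compilation above exercises the dynamo-disabled slice added in taildropout.py. A hedged diagnostic sketch, not part of this commit and assuming PyTorch's torch._logging.set_logs API, for watching recompilations during such a sweep:

import torch
from taildropout import TailDropout

torch._logging.set_logs(recompiles=True)   # log a line whenever dynamo recompiles

dropout = TailDropout()
dropout.compile()

x = torch.randn(1, 100)
for k in range(1, 101):
    dropout.set_k(k)
    dropout(x)                             # the commit's goal: no recompile per value of k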
67 changes: 44 additions & 23 deletions test_performance.py
@@ -6,6 +6,7 @@
from taildropout import TailDropout

import argparse

# python -m cProfile -s cumtime test_performance.py --repeats 10
parser = argparse.ArgumentParser(description='')
parser.add_argument('--repeats', type=int, default=1000000, metavar='N')
@@ -26,6 +27,13 @@
import time
import math

import os
os.environ['OMP_NUM_THREADS'] = '16'
os.environ['NUMEXPR_MAX_THREADS'] = '16'

# Configure PyTorch threading
torch.set_num_threads(16) # Intra-op threads
torch.set_num_interop_threads(16) # Inter-op threads

def time_since(since):
s = time.time() - since
@@ -34,12 +42,11 @@ def time_since(since):
return '%dm %ds' % (m, s)


def dropout_runner(Dropout,
def dropout_runner(dropout,
requires_grad = False,
eval_mode = False,
backward = False):
device = 'cuda' if args.cuda else 'cpu'
dropout = Dropout()
y = torch.ones(args.batch_size,
args.n_features,
requires_grad=requires_grad,
@@ -64,31 +71,45 @@ def dropout_runner(Dropout,

total_start = time.time()

print(f"{'Eval Mode':<12} {'Requires Grad':<15} {'Backward':<10} {'Timing':<20} {'Total (s)':<10} {'Layer Type'}")
print(f"{'Eval Mode':<12} "
f"{'Requires Grad':<15} "
f"{'Backward':<10} "
f"{'Compile':<10} "
f"{'Timing':<20} "
f"{'Total (s)':<10} "
f"{'Layer Type'}")
for eval_mode in [False, True]:
for requires_grad in [True, False]:
for backward in [True, False]:
if backward and not requires_grad:
break

for _ in range(2):
for Dropout in [TailDropout]: # [TailDropout, nn.Dropout]
timing, secs = dropout_runner(
Dropout,
requires_grad=requires_grad,
eval_mode=eval_mode,
backward=backward
)
for compile in [False, True]:
if backward and not requires_grad:
break

print(f"{str(eval_mode):<12} {str(requires_grad):<15} {str(backward):<10} "
f"{timing:<20} {f'({secs:.2f})':<10} {Dropout.__name__}")

if args.time_limit is not None:
secs_elapsed = round(time.time() - total_start)
if secs_elapsed >= args.time_limit:
raise TimeoutError(
f"Time limit exceeded: {secs_elapsed}s > {args.time_limit}s"
)
for _ in range(2):
for dropout in [TailDropout()]: # [TailDropout(), nn.Dropout()]
if compile:
dropout.compile()

timing, secs = dropout_runner(
dropout,
requires_grad=requires_grad,
eval_mode=eval_mode,
backward=backward
)

print(f"{str(eval_mode):<12} "
f"{str(requires_grad):<15} "
f"{str(backward):<10} "
f"{str(compile):<10} "
f"{timing:<20} {f'({secs:.2f})':<10} "
f"{dropout.__ne__}")

if args.time_limit is not None:
secs_elapsed = round(time.time() - total_start)
if secs_elapsed >= args.time_limit:
raise TimeoutError(
f"Time limit exceeded: {secs_elapsed}s > {args.time_limit}s"
)

print("-" * 80)
print(f"Total time: {time_since(total_start)} ({round(time.time() - total_start)}s)")
