This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Refactor CosineSimilarity
mejai1206 committed Sep 8, 2023
1 parent 5941c4e commit 70944c5
Showing 5 changed files with 228 additions and 409 deletions.
34 changes: 13 additions & 21 deletions benchmarks/benchmark_cosine_similarity.py
@@ -23,41 +23,33 @@
     "cosine similarity forward",
     ["x_size"],
     [256 * i for i in range(1, 21)],
-    {"num_batches": 16, "y_size": 16},
+    {"z_size": 16, "y_size": 16},
 )
-def bench_cosine_similarity_forward(num_batches, y_size, x_size, backend):
-    factory_kwargs = {"device": "cuda"}
-
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+def bench_cosine_similarity_forward(z_size, y_size, x_size, backend):
+    x1 = torch.randn(z_size, y_size, x_size, device="cuda")
+    x2 = torch.randn(z_size, y_size, x_size, device="cuda")
 
     if backend == "torch":
-        return triton.testing.do_bench_cudagraph(lambda: torch.nn.functional.cosine_similarity(input, other, 2))
+        return triton.testing.do_bench_cudagraph(lambda: torch.nn.functional.cosine_similarity(x1, x2, 2))
     else:
-        return triton.testing.do_bench_cudagraph(lambda: trident.function.cosine_similarity(input, other, 2))
+        return triton.testing.do_bench_cudagraph(lambda: trident.function.cosine_similarity(x1, x2, 2))
 
 
 @util.report(
     "cosine similarity backward",
     ["x_size"],
     [256 * i for i in range(1, 21)],
-    {"num_batches": 16, "y_size": 16},
+    {"z_size": 16, "y_size": 16},
 )
-def bench_cosine_similarity_backward(num_batches, y_size, x_size, backend):
-    factory_kwargs = {"device": "cuda"}
-
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-
-    input.requires_grad = other.requires_grad = True
+def bench_cosine_similarity_backward(z_size, y_size, x_size, backend):
+    x1 = torch.randn(z_size, y_size, x_size, device="cuda", requires_grad=True)
+    x2 = torch.randn(z_size, y_size, x_size, device="cuda", requires_grad=True)
+    grad_output = torch.randn(z_size, y_size, device="cuda")
 
     if backend == "torch":
-        operation = torch.nn.CosineSimilarity(2)
+        output = torch.nn.functional.cosine_similarity(x1, x2, 2)
     else:
-        operation = trident.CosineSimilarity(2)
-
-    output = operation.forward(input, other)
-    grad_output = torch.ones_like(output)
+        output = trident.function.cosine_similarity(x1, x2, 2)
 
     return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))
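Why the refactored backward benchmark can preallocate grad_output with shape (z_size, y_size): cosine_similarity reduces the dimension it is given, so dim=2 on (z_size, y_size, x_size) inputs yields a (z_size, y_size) output. A minimal CPU-only sketch of that shape invariant using plain PyTorch (the trident backend and the util/triton benchmark harness are not needed for this check):

import torch
import torch.nn.functional as F

z_size, y_size, x_size = 16, 16, 256
x1 = torch.randn(z_size, y_size, x_size, requires_grad=True)
x2 = torch.randn(z_size, y_size, x_size, requires_grad=True)

# cosine_similarity reduces the given dim: (z_size, y_size, x_size) -> (z_size, y_size).
output = F.cosine_similarity(x1, x2, 2)
assert output.shape == (z_size, y_size)

# The upstream gradient must match the output's shape, which is why the
# benchmark now preallocates it instead of building torch.ones_like(output)
# after the forward pass.
grad_output = torch.randn(z_size, y_size)
output.backward(grad_output, retain_graph=True)
assert x1.grad.shape == x1.shape and x2.grad.shape == x2.shape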
44 changes: 22 additions & 22 deletions tests/test_cosine_similarity.py
@@ -6,49 +6,49 @@
 
 
 @pytest.mark.parametrize(
-    "num_batches, y_size, x_size, dim",
+    "z_size, y_size, x_size, dim",
     [(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)],
 )
-def test_forward(num_batches, y_size, x_size, dim, device):
+def test_forward(z_size, y_size, x_size, dim, device):
     factory_kwargs = {"device": device}
 
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
 
     assert util.equal(
-        torch.nn.functional.cosine_similarity(input, other, dim=dim),
-        trident.function.cosine_similarity(input, other, dim=dim),
+        torch.nn.functional.cosine_similarity(x1, x2, dim=dim),
+        trident.function.cosine_similarity(x1, x2, dim=dim),
     )
 
 
 @pytest.mark.parametrize(
-    "num_batches, y_size, x_size, dim",
+    "z_size, y_size, x_size, dim",
     [(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)],
 )
-def test_backward(num_batches, y_size, x_size, dim, device):
+def test_backward(z_size, y_size, x_size, dim, device):
     factory_kwargs = {"device": device}
 
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
 
     if dim == 0:
         target_dim = (y_size, x_size)
     elif dim == 1:
-        target_dim = (num_batches, x_size)
+        target_dim = (z_size, x_size)
     else:
-        target_dim = (num_batches, y_size)
+        target_dim = (z_size, y_size)
 
-    target = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(target_dim, **factory_kwargs)
 
     def train(func):
-        x1 = input.clone()
-        x2 = other.clone()
-        x1.requires_grad = x2.requires_grad = True
-        func(x1, x2).backward(target, retain_graph=True)
-        return [x1.grad, x2.grad]
+        i = x1.clone()
+        j = x2.clone()
+        i.requires_grad = j.requires_grad = True
+        func(i, j).backward(grad_output, retain_graph=True)
+        return i.grad, j.grad
 
-    grad_a = train(torch.nn.CosineSimilarity(dim=dim))
-    grad_b = train(trident.CosineSimilarity(dim=dim))
+    (x, y) = train(torch.nn.CosineSimilarity(dim))
+    (a, b) = train(trident.CosineSimilarity(dim))
 
-    assert util.equal(grad_a[0], grad_b[0])
-    assert util.equal(grad_a[1], grad_b[1])
+    assert util.equal(x, a)
+    assert util.equal(y, b)
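The target_dim branches in test_backward encode a single rule: the output shape of cosine_similarity is the input shape with dim removed, and the upstream gradient must have that shape. A generic sketch of the same rule (build_grad_shape is a hypothetical helper written here for illustration; the test itself keeps the explicit if/elif/else):

import torch

def build_grad_shape(shape, dim):
    # Output shape of cosine_similarity: the input shape with `dim` dropped.
    return tuple(s for i, s in enumerate(shape) if i != dim)

shape = (640, 21, 86)  # one of the parametrized cases above
for dim in range(len(shape)):
    x1 = torch.randn(*shape, requires_grad=True)
    x2 = torch.randn(*shape)
    out = torch.nn.functional.cosine_similarity(x1, x2, dim=dim)
    assert out.shape == build_grad_shape(shape, dim)
    out.backward(torch.randn_like(out))
    assert x1.grad.shape == x1.shape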

0 comments on commit 70944c5
