Linear/Dense performance for PyTorch vs JAX (flax/stax) #8497
-
I have quite a lot of training code where JAX (in combination with flax) is faster than PyTorch. However, the code below is an example where JAX is much slower:

```python
import functools
import time
from typing import Any
import flax.linen as jnn
import jax
import jax.numpy as jnp
import jax.experimental.stax as stax
import numpy as np
import torch
import torch.nn as tnn
Array = Any

class TorchLinear(tnn.Module):

    def __init__(self, layer_size):
        super().__init__()
        self.ff = tnn.Linear(layer_size, layer_size, bias=False)

    def forward(self, input_):
        output = self.ff(input_)
        return output


class JaxLinear(jnn.Module):
    layer_size: int

    @jnn.compact
    def __call__(self, input_: Array) -> Array:
        output = jnn.Dense(self.layer_size, use_bias=False)(input_)
        return output


def benchmark_torch_linear(inputs, layer_size, device, label):
    with torch.no_grad():
        # create model
        model = TorchLinear(layer_size)
        model.to(device)
        # run
        start = time.time()
        outputs = []
        for input_ in inputs:
            output = model(input_)
            outputs.append(output)
        outputs = torch.stack(outputs)
        end = time.time()
        duration = end - start
        print(f"{label}: {duration}")
        return duration


def benchmark_jax_linear(inputs, layer_size, device, label):
    # create model
    rng = jax.random.PRNGKey(0)
    model = JaxLinear(layer_size)

    @jax.jit
    def init(*args):
        return model.init(*args)

    input_shape = (batch_size, layer_size)
    params = init(rng, jnp.ones(input_shape))

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(model, params, input_):
        output = model.apply(params, input_)
        return output

    # run
    start = time.time()
    outputs = []
    for input_ in inputs:
        output = step(model, params, input_)
        outputs.append(output)
    outputs = jnp.stack(outputs)
    outputs[0].block_until_ready()
    end = time.time()
    duration = end - start
    print(f"{label}: {duration}")
    return duration


def benchmark_stax_linear(inputs, layer_size, device, label):
    # create model
    rng = jax.random.PRNGKey(0)
    init_fn, apply_fn = stax.Dense(layer_size)  # TODO: has to have bias
    input_shape = (batch_size, layer_size)
    _, params = init_fn(rng, input_shape)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(apply_fn, params, input_):
        output = apply_fn(params, input_)
        return output

    # run
    start = time.time()
    outputs = []
    for input_ in inputs:
        output = step(apply_fn, params, input_)
        outputs.append(output)
    outputs = jnp.stack(outputs)
    outputs[0].block_until_ready()
    end = time.time()
    duration = end - start
    print(f"{label}: {duration}")
    return duration


if __name__ == "__main__":
    # benchmark
    runs = 10
    batch_size = 32
    seq_len = 100
    layer_size = 250
    device = torch.device("cuda")
    # input data
    input_torch = (torch.rand(seq_len, batch_size, layer_size, device=device) < 0.2).float().contiguous()
    input_np = np.asarray(input_torch.cpu(), dtype=jnp.float32)
    input_jnp = jnp.asarray(input_torch.cpu())
    # pytorch / loop outside module
    for _ in range(runs):
        benchmark_torch_linear(input_torch, layer_size, device, "torch, loop outside")
    # jax / np input / loop outside module
    for _ in range(runs):
        benchmark_jax_linear(input_np, layer_size, device, "jax, np input, loop outside")
    # stax / np input / loop outside module
    for _ in range(runs):
        benchmark_stax_linear(input_np, layer_size, device, "stax, np input, loop outside")
```

The output of the above, when run on a laptop GPU, is:

```
torch, loop outside: 0.004362821578979492
torch, loop outside: 0.002480745315551758
torch, loop outside: 0.002397775650024414
torch, loop outside: 0.0025179386138916016
torch, loop outside: 0.002394437789916992
torch, loop outside: 0.002406597137451172
torch, loop outside: 0.0026977062225341797
torch, loop outside: 0.002962827682495117
torch, loop outside: 0.0031156539916992188
torch, loop outside: 0.0030014514923095703
jax, np input, loop outside: 0.8470847606658936
jax, np input, loop outside: 0.0435943603515625
jax, np input, loop outside: 0.043665170669555664
jax, np input, loop outside: 0.042426347732543945
jax, np input, loop outside: 0.04245901107788086
jax, np input, loop outside: 0.045491933822631836
jax, np input, loop outside: 0.04279947280883789
jax, np input, loop outside: 0.042189836502075195
jax, np input, loop outside: 0.042439937591552734
jax, np input, loop outside: 0.04413032531738281
stax, np input, loop outside: 0.09287095069885254
stax, np input, loop outside: 0.03866934776306152
stax, np input, loop outside: 0.03732895851135254
stax, np input, loop outside: 0.03871488571166992
stax, np input, loop outside: 0.03256487846374512
stax, np input, loop outside: 0.03566265106201172
stax, np input, loop outside: 0.03671145439147949
stax, np input, loop outside: 0.03417348861694336
stax, np input, loop outside: 0.03485298156738281
stax, np input, loop outside: 0.03808856010437012
```

As you can see, the difference for feeding a sequence through a simple linear layer is substantial.
-
For a proper comparison you should pre-heat the jit: call the jitted function once before starting the timer, so that you don't profile jit time. For a fair comparison you should also feed `jnp.array`s and not numpy arrays to jax.
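
As a concrete illustration, here is a minimal sketch (not code from this thread) of how the flax benchmark above could incorporate both suggestions: the inputs are converted with `jnp.asarray` up front, and the jitted step function is called once before the timer starts so that compilation is excluded from the measurement. It reuses the `JaxLinear` module defined in the original post; the name `benchmark_jax_linear_warm` is made up for illustration.

```python
# Sketch only, assuming the JaxLinear module from the original post.
import time

import jax
import jax.numpy as jnp


def benchmark_jax_linear_warm(inputs, layer_size, label):
    rng = jax.random.PRNGKey(0)
    model = JaxLinear(layer_size)  # defined in the original post
    # inputs has shape (seq_len, batch_size, layer_size)
    params = model.init(rng, jnp.ones(inputs.shape[1:]))

    @jax.jit
    def step(params, input_):
        return model.apply(params, input_)

    # feed jnp arrays, not numpy arrays
    inputs = jnp.asarray(inputs)
    # pre-heat the jit: trigger compilation before the timer starts
    step(params, inputs[0]).block_until_ready()

    start = time.time()
    outputs = []
    for input_ in inputs:
        outputs.append(step(params, input_))
    outputs = jnp.stack(outputs)
    outputs.block_until_ready()
    duration = time.time() - start
    print(f"{label}: {duration}")
    return duration
```

Used in place of `benchmark_jax_linear`, the measured duration should then cover only the per-step dispatch and matmul work, not the one-off compilation.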