Add a parametrized benchmark for swiglu (#1105)
IvanYashchuk authored Sep 5, 2024
1 parent f915c3c commit e64d347
Showing 2 changed files with 146 additions and 1 deletion.
82 changes: 82 additions & 0 deletions thunder/benchmarks/__init__.py
@@ -1240,6 +1240,88 @@ def foo(a):
return foo


class LitGPTSwigluBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
_args = (
BenchmarkArg(
name="config",
description="The LitGPT config (str, LitGPTConfig) to use. See the litgpt_model.py for details.",
),
BenchmarkArg(
name="batchdims",
description="The shape (Sequence[int]) of input batch dimensions. The input will have innermost dimensions of (config.seq_len,). Default is (16,).",
),
BenchmarkArg(
name="device",
description="A string representing the device to run on. Default is 'cuda'.",
),
BenchmarkArg(
name="dtype",
description="The dtype of the tensors. Default is thunder.float32.",
),
BenchmarkArg(
name="requires_grad",
description="Whether the input tensors require grad. Default is False.",
),
)

@classmethod
@property
def name(cls) -> str:
return "litgpt-swiglu"

@classmethod
@property
def description(cls) -> str:
return "LitGPT's 'swiglu' elementwise unary operation."

@classmethod
@property
def args(cls) -> tuple[BenchmarkArg, ...]:
return cls._args

def __init__(
self,
config: str | LitGPTConfig,
batchdims: Sequence[int],
device: str,
dtype: dtypes.dtype,
requires_grad: bool,
use_liger: bool = False,
) -> None:
super().__init__()

self.config = LitGPTConfig.from_name(config) if not isinstance(config, LitGPTConfig) else config
self.batchdims = batchdims
self.shape: Sequence[int] = batchdims + (self.config.block_size, self.config.intermediate_size)
self.device: str = device
self.dtype: dtypes.dtype = dtype
self.tdtype: torch.dtype = ltorch.to_torch_dtype(dtype)
self.requires_grad: bool = requires_grad
self.devices: list[str] = [device]
self.use_liger: bool = use_liger

def make_batch(self) -> tuple[list, dict]:
return (
make_tensor(self.shape, device=self.device, dtype=self.tdtype, requires_grad=self.requires_grad),
make_tensor(self.shape, device=self.device, dtype=self.tdtype, requires_grad=self.requires_grad),
), {}

def fn(self) -> Callable:
# https://github.com/Lightning-AI/litgpt/blob/fdf6a120056d1363287285599eb84907f6c589b9/litgpt/model.py#L372
def fn(x_fc_1, x_fc_2):
return torch.nn.functional.silu(x_fc_1) * x_fc_2

if self.use_liger:
try:
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

return LigerSiLUMulFunction.apply
except ImportError:
raise ImportError("Requested to use the Liger SiLU Mul function, but the Liger kernel is not available")

return fn


class NanoGPTBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
_args = (
BenchmarkArg(
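For reference, the operation the new LitGPTSwigluBenchmark class measures is the element-wise swiglu pattern from LitGPT's MLP: silu(x_fc_1) * x_fc_2, applied to two tensors of shape batchdims + (config.block_size, config.intermediate_size). Below is a minimal standalone sketch of the same forward and backward computation outside the benchmark harness; the batch size, shapes, dtype, and CUDA device are illustrative assumptions, not taken from a particular LitGPT config.

    import torch

    # Illustrative stand-ins for batchdims, config.block_size, and config.intermediate_size.
    batch, block_size, intermediate_size = 2, 4096, 11008

    x_fc_1 = torch.randn(batch, block_size, intermediate_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    x_fc_2 = torch.randn(batch, block_size, intermediate_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Forward pass: the swiglu pattern under test.
    out = torch.nn.functional.silu(x_fc_1) * x_fc_2

    # Backward pass, exercised when the benchmark is constructed with requires_grad=True.
    out.sum().backward()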
65 changes: 64 additions & 1 deletion thunder/benchmarks/targets.py
@@ -18,6 +18,7 @@
LitGPTBenchmark,
LitGPTCausalSelfAttentionBenchmark,
LitGPTSDPABenchmark,
LitGPTSwigluBenchmark,
LlamaMLPBenchmark,
NanoGPTBenchmark,
NanoGPTCrossEntropyBenchmark,
@@ -39,7 +40,7 @@
from thunder.tests.litgpt_model import Config as LitGPTConfig
from thunder.tests.make_tensor import make_tensor


LIGER_FUSED_SWIGLU_AVAILABLE: bool = package_available("liger_kernel.ops.swiglu")
APEX_FUSED_ROPE_AVAILABLE: bool = package_available("fused_rotary_positional_embedding")
IMPORTANT_CONFIGS = [
"Llama-2-13b-hf",
@@ -233,6 +234,68 @@ def test_litgpt_gelu(benchmark, executor: Callable, bs: int, compute_type: Compu
benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


# There are many configurations but only the following parameters affect the swiglu benchmark:
# - intermediate_size
# - block_size
# Let's select only the configurations that differ in these parameters
def get_configs_for_swiglu():
return get_unique_configs(("intermediate_size", "block_size"))


swiglu_executors = (
(torch_executor, False),
(torch_compile_executor, False),
(thunder_executor, False),
(torch_executor, True),
(torch_compile_executor, True),
)
swiglu_executors_ids = (
"torch",
"torch.compile",
"thunder",
"torch+liger",
"torch.compile+liger",
)


# Sample command to run this benchmark:
# pytest thunder/benchmarks/targets.py -k "test_litgpt_swiglu" --benchmark-group-by='param:config,param:bs,param:compute_type'
@pytest.mark.parametrize(
"executor,use_liger,",
swiglu_executors,
ids=swiglu_executors_ids,
)
# bs = batch size
# It's typically small for LLMs
@pytest.mark.parametrize(
"bs,",
(2**i for i in range(0, 2)),
ids=(f"bs{2**i}" for i in range(0, 2)),
)
@parametrize_compute_type
@pytest.mark.parametrize(
"config,",
get_configs_for_swiglu(),
)
def test_litgpt_swiglu(benchmark, executor: Callable, use_liger: bool, bs: int, compute_type: ComputeType, config: str):
if use_liger and not LIGER_FUSED_SWIGLU_AVAILABLE:
pytest.skip("Liger fused swiglu is unavailable")

bench: Benchmark = LitGPTSwigluBenchmark(
config=config,
batchdims=(bs,),
device="cuda:0",
dtype=thunder.bfloat16,
requires_grad=is_requires_grad(compute_type),
use_liger=use_liger,
)

args, kwargs = bench.make_batch()
fn = executor(bench.fn())

benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


@pytest.mark.parametrize(
"executor,",
executors,
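Outside of pytest, the new benchmark class can also be driven directly by mirroring the body of test_litgpt_swiglu above; a sketch, assuming a CUDA device and that the config name used here is one of those returned by get_configs_for_swiglu() (the name is only illustrative).

    import thunder
    from thunder.benchmarks import LitGPTSwigluBenchmark

    bench = LitGPTSwigluBenchmark(
        config="Llama-2-7b-hf",  # illustrative LitGPT config name
        batchdims=(2,),
        device="cuda:0",
        dtype=thunder.bfloat16,
        requires_grad=True,
        use_liger=False,  # True routes through liger_kernel's LigerSiLUMulFunction, if installed
    )

    args, kwargs = bench.make_batch()
    fn = bench.fn()  # in targets.py this callable is additionally wrapped by an executor before timing
    out = fn(*args, **kwargs)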
