diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ab4b7634b8..543c3520fa 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,143 +2,568 @@ # # See LICENSE for license information. -import os -from contextlib import nullcontext +import random +import contextlib import pytest import torch - +from typing import Optional +from transformer_engine.pytorch.cpu_offload import _CPUOffloadBackend +from transformer_engine.pytorch.cpu_offload import CPUOffload, get_cpu_offload_context +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine_torch as tex +EPSILON = 0.1 -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Disable garbage collection to tests if there are reference cycles. +# We do not want them, because they can result in CUDA out of memory errors. +import gc +gc.disable() -fp8_recipes = [ - None, # non-fp8 - # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet - recipe.Float8CurrentScaling(), - recipe.DelayedScaling(), -] +class Utils: + tensor1 = torch.randn((1000, 1000), device="cuda") + _B = 16 + _S = 256 + _H = 4 + _D = 1024 -SIZE = 512 -NUM_HEADS = 8 -NUM_LAYERS = 5 -EPSILON = 0.1 + @staticmethod + def long_job(stream: Optional[torch.cuda.Stream] = None): + NUM_ITERS = 6000 + if stream is None: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + for i in range(NUM_ITERS): + Utils.tensor1.normal_() + + @staticmethod + def measure_time(func): + import time + start = time.time() + func() + end = time.time() + return (end - start) * 1000 -# Flash attention saves some internal tensor for the backward pass -# that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN") == "0" + @staticmethod + def get_cuda_memory_mb(): + return torch.cuda.memory_allocated() / (1024**2) -# Offloading is supported for attention only for fused and flash attention backends, -# so the use of bfloat16 is required. -# -# For the TransformerLayer, activation offloading with dropout is not supported, -# so we set hidden_dropout to 0.0. -model_types = { - "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), - "multihead_attention": lambda: te.MultiheadAttention( - SIZE, NUM_HEADS, params_dtype=torch.bfloat16 - ), - "transformer_layer": lambda: te.TransformerLayer( - SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 - ), -} - - -def _get_input(): - return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() - - -def _get_fp8_weight_cache_size(models, fp8_recipe): - """ - Calculate the total FP8 weight cache size (in MB) for a list of models. 
- """ - if fp8_recipe is None: - return 0 - - params_bytes = 0 - for model in models: - for name, param in model.named_parameters(): - if "weight" in name: - params_bytes += param.numel() - - # One byte for columnwise and one byte for rowwise, - # hence multiply by 2 and convert to MB - # there is 1 byte of scale per 32 elements in mxFP8 - factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 - return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) - - -def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): - tensor = _get_input() - if cpu_offload: - offload_context, sync_function = te.get_cpu_offload_context( + @staticmethod + def get_max_cuda_memory_mb(): + return torch.cuda.max_memory_allocated() / (1024**2) + + @staticmethod + def get_cpu_memory_mb() -> float: + import psutil, os + return psutil.Process(os.getpid()).memory_info().rss / (1024**2) + + + @staticmethod + def get_layer_names(): + return ["linear", "layernorm_linear", "layernorm_mlp", "grouped_linear", "multihead_attention", "transformer_layer"] + + @staticmethod + def create_layer(layer_type: str): + if layer_type == "linear": + return te.Linear(Utils._D, Utils._D) + elif layer_type == "layernorm_linear": + return te.LayerNormLinear(Utils._D, Utils._D) + elif layer_type == "layernorm_mlp": + return te.LayerNormMLP(Utils._D, Utils._D) + elif layer_type == "multihead_attention": + return te.MultiheadAttention(Utils._D, Utils._H, attention_dropout=0.0) + elif layer_type == "grouped_linear": + return te.GroupedLinear(Utils._H, Utils._D, Utils._D) + elif layer_type == "transformer_layer": + return te.TransformerLayer(Utils._D, Utils._D, Utils._H, attention_dropout=0.0, hidden_dropout=0.0) + else: + raise ValueError(f"Unknown layer type: {layer_type}") + + @staticmethod + def get_recipe_names(): + return ["high precision", "fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling", "mxfp8"] + + @staticmethod + def create_tensor(recipe_name: str, requires_grad: bool = False): + shape = (Utils._B, Utils._S, Utils._D) + tensor = torch.randn(shape, device="cuda") + if recipe_name == "high precision": + tensor = tensor.requires_grad_() if requires_grad else tensor + return tensor + elif recipe_name == "fp8_delayed_scaling": + quantizer = te.tensor.float8_tensor.Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1.0], device="cuda"), + amax=torch.tensor([1.], device="cuda") + ) + return quantizer(tensor) + elif recipe_name == "fp8_current_scaling": + quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda") + return quantizer(tensor) + elif recipe_name == "fp8_block_scaling": + quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True) + return quantizer(tensor) + elif recipe_name == "mxfp8": + quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3) + return quantizer(tensor) + + @staticmethod + def skip_if_recipe_not_supported(recipe_name: str): + fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() + ) + if recipe_name == "fp8_delayed_scaling" and not fp8_available: + pytest.skip(reason_for_no_fp8) + elif recipe_name == "fp8_current_scaling" and not 
fp8_available: + pytest.skip(reason_for_no_fp8) + elif recipe_name == "fp8_block_scaling" and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) + elif recipe_name == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + + @staticmethod + def create_recipe_ctx(recipe_name: str): + if recipe_name == "high precision": + return lambda: contextlib.nullcontext() + elif recipe_name == "fp8_delayed_scaling": + return lambda: te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()) + elif recipe_name == "fp8_current_scaling": + return lambda: te.fp8_autocast(fp8_recipe=recipe.Float8CurrentScaling()) + elif recipe_name == "fp8_block_scaling": + return lambda: te.fp8_autocast(fp8_recipe=recipe.Float8BlockScaling()) + elif recipe_name == "mxfp8": + return lambda: te.fp8_autocast(fp8_recipe=recipe.MXFP8BlockScaling()) + + @staticmethod + def get_tensor_size_mb(tensor): + if type(tensor) == torch.Tensor: + return tensor.numel() * tensor.element_size() / (1024**2) + else: + # 1 byte for rowwise, 1 byte for columnwise + return tensor.numel() * 2 / (1024**2) + + @staticmethod + def memory_leak_check(): + # Should be called before each test. + # Only cublas workspaces and some global tensors are allowed to be allocated. + # All other allocations should be released. + # This is a simple check to catch memory leaks. + assert Utils.get_cuda_memory_mb() < 100, f"Memory leak: {Utils.get_cuda_memory_mb()} MB" + +class AddOneLayer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + 1 + +class TestsBackend: + @pytest.mark.parametrize("reuse_gpu_buffers", [True, False]) + @pytest.mark.parametrize("random_num_tensors", [True, False]) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_general(self, reuse_gpu_buffers, random_num_tensors, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + if reuse_gpu_buffers and random_num_tensors: + pytest.skip("Cannot have random number of tensors and reuse_gpu_buffers at the same time.") + backend = _CPUOffloadBackend( + reuse_gpu_buffers=reuse_gpu_buffers, + ) + NUM_LAYERS = 10 + NUM_ITERATIONS = 10 + + for _ in range(NUM_ITERATIONS): + original_tensors = [] + tensors_cpu = [] + layer_ids = [] + + for i in range(NUM_LAYERS): + NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1 + layer_tensors = [] + layer_tensors_cpu = [] + backend.start_offloaded_layer_fwd() + for _ in range(NUM_LAYER_TENSORS): + tensor = Utils.create_tensor(recipe_name) + layer_tensors.append(tensor) + if random.randint(0, 1) == 0: + backend.mark_can_start_offload(tensor) + tensor_cpu = backend.offload(tensor) + assert tensor.device.type == "cuda" + assert tensor_cpu.device.type == "cpu" + layer_tensors_cpu.append(tensor_cpu) + layer_id = backend.end_offloaded_layer_fwd() + layer_ids.append(layer_id) + tensors_cpu.append(layer_tensors_cpu) + original_tensors.append(layer_tensors) + backend.finish_fwd() + backend.start_bwd_reloading() + for i in range(NUM_LAYERS - 1, -1, -1): + backend.start_offloaded_layer_bwd(layer_ids[i]) + for j in range(len(tensors_cpu[i])): + tensor_gpu = backend.reload(tensors_cpu[i][j]) + assert tensor_gpu.device.type == "cuda" + assert tensor_gpu.shape == original_tensors[i][j].shape + assert tensor_gpu.dtype == original_tensors[i][j].dtype + torch.testing.assert_close(tensor_gpu, original_tensors[i][j]) + backend.end_offloaded_layer_bwd() + torch.cuda.synchronize() + + @pytest.mark.parametrize("recipe_name", 
Utils.get_recipe_names()) + def test_memory(self, recipe_name): + Utils.skip_if_recipe_not_supported(recipe_name) + #reset max memory allocated + init_cuda_memory = Utils.get_cuda_memory_mb() + x = Utils.create_tensor(recipe_name) + x_size = Utils.get_tensor_size_mb(x) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1) + + torch.cuda.synchronize() + backend = _CPUOffloadBackend() + backend.start_offloaded_layer_fwd() + x1_cpu = backend.offload(x) + del x + num1 = backend.end_offloaded_layer_fwd() + + torch.cuda.synchronize() + + # Memory is not released yet. + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1) + + backend.start_offloaded_layer_fwd() + # Next offloaded layer, memory should be released. + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + x = Utils.create_tensor(recipe_name) + x2_cpu = backend.offload(x) + del x + num2 = backend.end_offloaded_layer_fwd() + + backend.start_offloaded_layer_fwd() + x = Utils.create_tensor(recipe_name) + x3_cpu = backend.offload(x) + del x + num3 = backend.end_offloaded_layer_fwd() + + backend.finish_fwd() + torch.cuda.reset_max_memory_allocated() + + backend.start_bwd_reloading() + torch.cuda.synchronize() + + backend.start_offloaded_layer_bwd(num3) + backend.reload(x3_cpu) + backend.end_offloaded_layer_bwd() + + torch.cuda.synchronize() + + backend.start_offloaded_layer_bwd(num2) + backend.reload(x2_cpu) + backend.end_offloaded_layer_bwd() + + torch.cuda.synchronize() + + backend.start_offloaded_layer_bwd(num1) + backend.reload(x1_cpu) + backend.end_offloaded_layer_bwd() + + + torch.cuda.synchronize() + # Third copy is released. + assert Utils.get_max_cuda_memory_mb() < init_cuda_memory + 2 * x_size + 0.1 + + def test_mark_can_start_offload(self): + """ + Check that calling `mark_can_start_offload` lets the backend overlap the + D2H copy with computation. The runtime with the mark should therefore + be strictly smaller than without it. + """ + torch.cuda.synchronize() + tensor = torch.randn((128, 512, 512), device="cuda") + + def _timed_run(use_mark: bool) -> float: + """Run a single forward pass and return its wall-clock time (ms).""" + def _run(): + backend = _CPUOffloadBackend() + backend.start_offloaded_layer_fwd() + if use_mark: + backend.mark_can_start_offload(tensor) + + # Simulate compute that should overlap with the offload copy. + Utils.long_job() + + backend.offload(tensor) + backend.end_offloaded_layer_fwd() + backend.finish_fwd() + backend.start_bwd_reloading() + + # Make sure all CUDA work is finished before timing stops. + torch.cuda.current_stream().synchronize() + torch.cuda.synchronize() + + return Utils.measure_time(_run) + + # Warm-up + _timed_run(False) + time_without_mark = _timed_run(False) + time_with_mark = _timed_run(True) + print(f"time_without_mark: {time_without_mark} ms, " + f"time_with_mark: {time_with_mark} ms") + assert time_with_mark < time_without_mark + + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_multiple_tensor_offload(self, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + init_cpu_memory = Utils.get_cpu_memory_mb() + init_cuda_memory = Utils.get_cuda_memory_mb() + backend = _CPUOffloadBackend() + backend.start_offloaded_layer_fwd() + x1 = Utils.create_tensor(recipe_name) + x_size = Utils.get_tensor_size_mb(x1) + backend.offload(x1) + backend.offload(x1) + backend.offload(x1) + # Only one copy of tensor on cpu is allocated. 
+ assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1) + del x1 + backend.end_offloaded_layer_fwd() + backend.finish_fwd() + + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + @pytest.mark.parametrize("job_forward", [True, False]) + @pytest.mark.parametrize("job_backward", [True, False]) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_overlap(self, job_forward, job_backward, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + if not job_forward and not job_backward: + pytest.skip("") + NUM_LAYERS = 10 + def _run(job_forward, job_backward, offloads): + backend = _CPUOffloadBackend() + offloaded_tensors = [] + layer_ids = [] + for _ in range(NUM_LAYERS): + backend.start_offloaded_layer_fwd() + if offloads: + offloaded_tensors.append(backend.offload(Utils.create_tensor(recipe_name))) + else: + offloaded_tensors.append(Utils.create_tensor(recipe_name)) + if job_forward: + Utils.long_job() + layer_id = backend.end_offloaded_layer_fwd() + layer_ids.append(layer_id) + backend.finish_fwd() + backend.start_bwd_reloading() + for i in range(NUM_LAYERS - 1, -1, -1): + backend.start_offloaded_layer_bwd(layer_ids[i]) + if offloads: + backend.reload(offloaded_tensors[i]) + if job_backward: + Utils.long_job() + backend.end_offloaded_layer_bwd() + torch.cuda.synchronize() + + def _measure_time(job_forward, job_backward, offloads): + return Utils.measure_time(lambda: _run(job_forward, job_backward, offloads)) + + _run(True, True, True) # warm-up + + time_offload_only = _measure_time(False, False, True) + time_offload_and_selected_jobs = _measure_time(job_forward, job_backward, True) + time_selected_jobs = _measure_time(job_forward, job_backward, False) + + print(f"time_offload_only: {time_offload_only:.2f} ms, " + f"time_offload_and_selected_jobs: {time_offload_and_selected_jobs:.2f} ms, " + f"time_selected_jobs: {time_selected_jobs:.2f} ms") + + assert time_offload_only + time_selected_jobs > time_offload_and_selected_jobs + EPSILON + + +class TestTEAPI: + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_offload_one_layer(self, layer_type, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + torch.cuda.synchronize() + cpu_offload = CPUOffload() + recipe_ctx = Utils.create_recipe_ctx(recipe_name) + layer = Utils.create_layer(layer_type) + last_layer = AddOneLayer() + inp = Utils.create_tensor("high precision") + + m_splits = {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} \ + if layer_type == "grouped_linear" else {} + + with recipe_ctx(): + out = layer(inp, is_first_microbatch=True, **m_splits) + out = last_layer(out) + out.sum().backward() + + # run with is_first_microbatch=True to cache the fp8 casts + del inp + init_cuda_memory = Utils.get_cuda_memory_mb() + inp = Utils.create_tensor("high precision") + with recipe_ctx(): + out = layer(inp, is_first_microbatch=False, **m_splits) + out = last_layer(out) + del inp + activation_size = Utils.get_cuda_memory_mb() - init_cuda_memory + out.sum().backward() + + init_cuda_memory = Utils.get_cuda_memory_mb() + + # run layer with offload + layer_offload = cpu_offload(layer, offload_activations=True) + last_layer_offload = cpu_offload(last_layer, is_last_layer=True) + inp = Utils.create_tensor("high precision") + with recipe_ctx(): + out = layer_offload(inp, is_first_microbatch=False, 
**m_splits) + out = last_layer_offload(out) + del inp + offloaded_size = cpu_offload.backend.get_offloaded_total_size_mb() + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + assert offloaded_size == pytest.approx(activation_size, 0.1) + + out.sum().backward() + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_pipeline_parallel(self, layer_type, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + recipe_ctx = Utils.create_recipe_ctx(recipe_name) + cpu_offload = CPUOffload() + + layer1 = Utils.create_layer(layer_type) + layer2 = Utils.create_layer(layer_type) + layer3 = Utils.create_layer(layer_type) + + def _run(inp): + m_splits = {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} \ + if layer_type == "grouped_linear" else {} + + with recipe_ctx(): + out = layer1(inp, **m_splits) + out = layer2(out, **m_splits) + out = layer3(out, **m_splits) + return out.sum() + + + def _run_offload(inp): + m_splits = {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} \ + if layer_type == "grouped_linear" else {} + layer1_offload = cpu_offload(layer1, offload_activations=True) + layer2_offload = cpu_offload(layer2, offload_activations=True) + layer3_offload = cpu_offload(layer3, is_last_layer=True) + with recipe_ctx(): + out = layer1_offload(inp, **m_splits) + out = layer2_offload(out, **m_splits) + out = layer3_offload(out, **m_splits) + return out.sum() + + inps = [Utils.create_tensor("high precision", requires_grad=True) for _ in range(3)] + + outs = [] + for i in range(3): + outs.append(_run(inps[i])) + for out in outs: + out.backward() + + inps_offload = [inps[i].clone().detach() for i in range(3)] + for i in range(3): + inps_offload[i] = inps_offload[i].requires_grad_() + # run with offload + outs_offload = [] + for i in range(3): + outs_offload.append(_run_offload(inps_offload[i])) + for out in outs_offload: + out.backward() + + # check if inp grads are the same + if recipe_name == "high precision": + for i in range(3): + assert torch.allclose(inps[i].grad, inps_offload[i].grad) + + def test_fake_tensor(self): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported("high precision") + layer = Utils.create_layer("linear") + cpu_offload = CPUOffload() + layer_offload = cpu_offload(layer) + layer_offload_2 = cpu_offload(layer) + layer_offload_3 = cpu_offload(layer, is_last_layer=True) + inp = Utils.create_tensor("high precision") + model = torch.nn.Sequential(layer_offload, layer_offload_2, layer_offload_3) + # torch compile model + model = torch.compile(model) + out = model(inp) + out.sum().backward() + + +class TestLegacyAPI: + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_legacy_api(self, layer_type, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + offload_ctx, sync_function = get_cpu_offload_context( enabled=True, - num_layers=len(models) - 1, - model_layers=len(models), + num_layers=1, + model_layers=2, offload_activations=True, - offload_weights=False, + offload_weights=False ) - else: - offload_context = nullcontext() - sync_function = lambda x: x - - for model in models: - with te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe - ), offload_context: - tensor = model(tensor) - tensor = sync_function(tensor) - - max_mem_used = 
torch.cuda.memory_allocated() / (1024**2) - torch.cuda.synchronize() - - return max_mem_used - - -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model_key", model_types.keys()) -def test_cpu_offload(fp8_recipe, model_key) -> None: - """ - We run three configurations: - (1) No offloading: All activations remain on the GPU between forward and backward passes. - (2) No offloading (one layer): Only the first layer's activations remain on the GPU between - forward and backward passes. - (3) With offloading (all layers): Only the last layer's activations remain on the GPU - between forward and backward passes, while all other layers are offloaded to the CPU. - - We expect the memory consumption of configurations (2) and (3) to be similar, with - the difference being the size of the FP8 cache that is not offloaded to the CPU. - We also expect this memory consumption to be smaller than in scenario (1). - """ - - model_cls = model_types[model_key] - models_list = [model_cls() for _ in range(NUM_LAYERS)] - - if fp8_recipe and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None: - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) + recipe_ctx = Utils.create_recipe_ctx(recipe_name) + layer = Utils.create_layer(layer_type) + inp = Utils.create_tensor("high precision") - without_offloading = _measure_memory_between_forward_and_backward( - models_list, fp8_recipe, False - ) - without_offloading_one_layer = _measure_memory_between_forward_and_backward( - models_list[:1], fp8_recipe, False - ) - with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) - - assert with_offloading < without_offloading - - # The only difference between the memory consumption of with_offloading - # and without_offloading_one_layer should be the size of the FP8 weights cache, - # which is not offloaded to the CPU. 
- memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) - assert ( - memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON - ) + m_splits = {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} \ + if layer_type == "grouped_linear" else {} + + with recipe_ctx(): + out = layer(inp, is_first_microbatch=True, **m_splits) + out.sum().backward() + + init_cuda_memory = Utils.get_cuda_memory_mb() + + # run layer with offload + inp = Utils.create_tensor("high precision") + with offload_ctx, recipe_ctx(): + out = layer(inp, is_first_microbatch=False, **m_splits) + out = sync_function(out) + with offload_ctx, recipe_ctx(): + out = out + 1 + out = sync_function(out) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + out.sum().backward() + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe_name", Utils.get_recipe_names()) + def test_sanity_legacy_api(self, layer_type, recipe_name): + Utils.memory_leak_check() + Utils.skip_if_recipe_not_supported(recipe_name) + OFFLOAD_LAYERS = 6 + NUM_LAYERS = 10 + offload_ctx, sync_function = get_cpu_offload_context( + enabled=True, + num_layers=OFFLOAD_LAYERS, + model_layers=NUM_LAYERS, + ) + recipe_ctx = Utils.create_recipe_ctx(recipe_name) + layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)] + inp = Utils.create_tensor("high precision") + m_splits = {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} \ + if layer_type == "grouped_linear" else {} + for i in range(NUM_LAYERS): + with offload_ctx, recipe_ctx(): + out = layers[i](inp, is_first_microbatch=False, **m_splits) + out = sync_function(out) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2ca133e77b..91ceaa4e0e 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -30,7 +30,7 @@ TransformerLayer, RMSNorm, LayerNorm, - get_cpu_offload_context, + CPUOffload ) from transformer_engine.common import recipe import transformer_engine_torch as tex @@ -289,15 +289,14 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): _disable_wgrads(block) if cpu_offload: - offload_context, sync_function = get_cpu_offload_context(enabled=True) - else: - offload_context = nullcontext() - sync_function = lambda x: x + cpu_offload = CPUOffload() + block = cpu_offload(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): te_out = block(te_inp_hidden_states) - te_out = sync_function(te_out) + if cpu_offload: + cpu_offload.sync_before_bwd() loss = te_out.sum() loss.backward() torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 1b73c8667c..54f56cc40d 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -105,7 +105,7 @@ def _load_library(): from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker -from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context +from transformer_engine.pytorch.cpu_offload import CPUOffload from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy diff 
--git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 9feef64210..2dba301df8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -683,12 +683,12 @@ def forward( ) else: from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, + is_cpu_offload_enabled, + offload, ) - if CPUOffloadEnabled: - mark_activation_offload( + if is_cpu_offload_enabled(): + offload( query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv ) @@ -1054,19 +1054,19 @@ def forward( ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, + is_cpu_offload_enabled, + offload, ) - if CPUOffloadEnabled: + if is_cpu_offload_enabled(): if ctx.fp8: tensor_list = fp8_tensors else: tensor_list = [q, k, v, out_save] qkv_layout = "sbhd_sbhd_sbhd" - mark_activation_offload(*tensor_list) - mark_activation_offload(*aux_ctx_tensors) + offload(*tensor_list) + offload(*aux_ctx_tensors) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7d50b9fa54..1282a40d99 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1117,15 +1117,15 @@ def forward( cp_stream=self.cp_stream, cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, - fp8_meta=self.fp8_meta, + fp8_meta=self.fp8_meta, quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, ) - from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled + from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled - if CPUOffloadEnabled: + if is_cpu_offload_enabled(): warnings.warn( "Attention activation Offloading is only implemented" "with Flash Attention and Fused Attention!" diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 814e699557..b8ca902958 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -3,546 +3,617 @@ # See LICENSE for license information. 
"""Functionality for CPU offloading of tensors saved for backward pass.""" + from __future__ import annotations -from contextlib import nullcontext -from typing import Any, Dict, Optional +import contextlib import torch +from torch.autograd.graph import saved_tensors_hooks -from .tensor.float8_tensor import Float8Tensor - -__all__ = ["get_cpu_offload_context"] +__all__ = [ + "is_cpu_offload_enabled", "mark_is_weight", "mark_can_start_offload", + "CPUOffload", "get_cpu_offload_context" +] -CPUOffloadEnabled = False +MIN_TENSOR_SIZE_TO_OFFLOAD = 1000 +CURRENT_CPU_OFFLOAD_HANDLER = None +def is_cpu_offload_enabled(): + return CURRENT_CPU_OFFLOAD_HANDLER is not None -def mark_activation_offload(*tensors): - """Set the type of the offloading needed for a tensor.""" +def mark_is_weight(*tensors: torch.Tensor): for tensor in tensors: - if tensor is None: - continue - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: - tensor.activation_offloading = True - else: - data_tensors = tensor.get_data_tensors() - for tensor in data_tensors: - if tensor is not None: - tensor.activation_offloading = True - # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorBase classes are saved in the ctx, - # and they contain the reference to their data tensors. - tensor.needs_force_clear = True - - -def is_cpu_offload_enabled() -> bool: - """Check if CPU offloading is currently enabled.""" - return CPUOffloadEnabled - - -class CpuOffloadSavedTensorHook: - """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. - - In this context, the ``on_save_for_backward`` method will be called every time - a tensor is saved for backward (this includes intermediary results saved using - :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but - also those recorded by a PyTorch-defined operation). - - The ``on_get_saved_tensors`` method will be called when the backward function - of this op attempts to retrieve the saved tensor from context (this includes - :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the - as input the return value of the ``on_save_for_backward``, and is meant to return - an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of - size, device and element values. - - Example: - - >>> import torch - >>> from typing import Any - >>> - >>> class DummyHook(CpuOffloadSavedTensorHook): - ... - ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - ... logging.info("On save", tensor) - ... return (tensor,) - ... - ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - ... logging.info("On get", saved_state) - ... tensor, = saved_state - ... return tensor - ... - >>> a = torch.ones(5, requires_grad=True) - >>> b = torch.ones(5, requires_grad=True) * 2 - >>> with DummyHook(): - ... y = a * b - ... 
- On save tensor([1., 1., 1., 1., 1.], requires_grad=True) - On save tensor([2., 2., 2., 2., 2.], grad_fn=) - >>> y.sum().backward() - On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) - On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + if tensor is not None: + tensor.is_weight = True +def mark_can_start_offload(*gpu_tensors: torch.Tensor): """ - - def __init__(self) -> None: - self.inside_context = False - - def __enter__(self): - global CPUOffloadEnabled - CPUOffloadEnabled = True - - self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) - - def __exit__(self, *args: Any): - global CPUOffloadEnabled - CPUOffloadEnabled = False - - self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - """On save for backward.""" - raise NotImplementedError( - "`on_save_for_backward: Callable[[torch.Tensor], Any]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - """On get saved tensor.""" - raise NotImplementedError( - "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) - - -class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): - """Context-manager that offloads/recovers tensors through an offload hander. - - The hook just offloads/recovers the tensor object to the handler through `tensor_push` - and `tensor_pop` interface. How the offload-handler manages the offloading, recovering - or prefetching timing is transparent to this hook. + Marks the moment the tensor can be offloaded. The tensor passed to the offload function, + must be the same object as the one passed to the mark_can_start_offload function, + to not delay the offloading. """ + if CURRENT_CPU_OFFLOAD_HANDLER is not None: + for gpu_tensor in gpu_tensors: + if gpu_tensor is not None: + CURRENT_CPU_OFFLOAD_HANDLER.mark_can_start_offload(gpu_tensor) - def __init__( - self, - offload_handler: OffloadHandler, - handler_extra_kwargs: Optional[Dict[str, Any]] = None, - debug: bool = False, - ) -> None: - if handler_extra_kwargs is None: - handler_extra_kwargs = {} - self.debug: bool = debug - self.offload_handler: OffloadHandler = offload_handler - self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs - super().__init__() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) - return retrieve_identifier - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) - return tensor - - -class OffloadHandler: - """A base class for CPU offload-handler.""" - - def __init__(self) -> None: - pass - - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - """Tensor push.""" - raise NotImplementedError( - "`tensor_push is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_push." - ) - - def tensor_pop(self, tensor_tag: Any, **kwargs): - """Tensor pop.""" - raise NotImplementedError( - "`tensor_pop is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_pop." 
- ) - -class GroupCommitFunction(torch.autograd.Function): - """this is a dummy op with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. +class _StreamedOffloader: """ + _StreamedOffloader represents one stream used to offload the tensors to the CPU. + It provides easy to use interface for offloading and synchronization. + """ + def __init__(self): + self.stream = torch.cuda.Stream() + self.cpu_tensors = {} + self.gpu_tensors = {} + + def mark_can_start_offload(self, gpu_tensor: torch.Tensor): + """ + If gpu_tensor is ready to offload on current stream, record the event. + We may be able to start the offloading earlier than in the save_for_backward. + """ + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + gpu_tensor._offload_event = event + + def offload(self, gpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Offload the tensor from GPU to CPU. Return the CPU copy of the tensor. + """ + if id(gpu_tensor) in self.gpu_tensors.keys(): + # One tensor can be an argument to the offload function multiple times, + # but only one cpu copy of the tensor is created. + return self.cpu_tensors[id(gpu_tensor)] + if hasattr(gpu_tensor, "_offload_event"): + # Used to start early copy of the tensor + # marked with mark_can_start_offload method. + gpu_tensor._offload_event.wait(self.stream) + gpu_tensor._offload_event = None + else: + # Synchronize with the current stream to ensure that + # the gpu tensor is computed before the copy starts. + self.stream.wait_stream(torch.cuda.current_stream()) + + # Since we allocate the cpu tensor on the different stream, + # we keep the copies of the cpu tensors in self._offload_cpu_tensors. + # This tensor is released in start_offloaded_layer_fwd method of the next layer, + # after the synchronization with the main stream. + # This prevents the release of the memory of the cpu tensor before the copy is finished. + with torch.cuda.stream(self.stream): + cpu_tensor = torch.empty_like(gpu_tensor, device="cpu", pin_memory=True) + cpu_tensor.copy_(gpu_tensor, non_blocking=True) + self.cpu_tensors[id(gpu_tensor)] = cpu_tensor + self.gpu_tensors[id(gpu_tensor)] = gpu_tensor + return cpu_tensor + + def wait_for_offloading(self): + """ + Wait for the offloading to finish. + """ + torch.cuda.current_stream().wait_stream(self.stream) + + def get_offloaded(self, gpu_tensor: torch.Tensor): + """ + Return the CPU copy of the tensor. + """ + return self.cpu_tensors[id(gpu_tensor)] + + def get_all_offloaded_tensors(self) -> list[torch.Tensor]: + """ + Return all the CPU copies of the tensors. + """ + return list(self.cpu_tensors.values()) + + def release_memory(self): + """ + Release the memory of the CPU tensors. + """ + self.cpu_tensors = {} + self.gpu_tensors = {} + +class _StreamedReloader: + """ + _StreamedReloader represents one stream used to reload the tensors from the CPU to the GPU. + It provides easy to use interface for reloading and synchronization. + + Parameters + ---------- + reuse_gpu_buffers : bool, default = False + Re-use the same GPU buffers when reloading tensors. All offloaded + layers must therefore produce activations of identical shapes, or an + assertion will be raised. 
+    """
+    def __init__(self, reuse_gpu_buffers: bool):
+        self.stream = torch.cuda.Stream()
+        self.gpu_tensors = {}
+        self._reuse_gpu_buffers = reuse_gpu_buffers
+        self._gpu_buffer_pool = [] # used to re-use the same GPU buffers
+
+    def wait_for_main(self):
+        """
+        Postpone the reloading process until this point in the main stream.
+        """
+        self.stream.wait_stream(torch.cuda.current_stream())
+
+    def bulk_reload(self, cpu_tensors: list[torch.Tensor]):
+        """
+        Reload all provided tensors from the CPU to the GPU.
+        The main stream must wait for this call to finish,
+        but it can start before the main stream reaches this point.
+        """
+        if self._reuse_gpu_buffers:
+            assert len(cpu_tensors) == len(self._gpu_buffer_pool) or len(self._gpu_buffer_pool) == 0, \
+                "All offloaded layers must produce the same number of \
+                activation tensors with reuse_gpu_buffers=True"
+        for tensor in cpu_tensors:
+            with torch.cuda.stream(self.stream):
+                if self._reuse_gpu_buffers:
+                    if len(self._gpu_buffer_pool) > 0:
+                        buffer = self._gpu_buffer_pool.pop()
+                        assert buffer.shape == tensor.shape, \
+                            "All offloaded layers must produce activations of identical \
+                            shapes to run with reuse_gpu_buffers=True"
+                    else:
+                        # Buffers are not allocated yet - this branch runs during the first
+                        # offloaded layer handled by this _StreamedReloader.
+                        buffer = torch.empty_like(tensor, device="cuda")
+                else:
+                    buffer = torch.empty_like(tensor, device="cuda")
+                buffer.copy_(tensor, non_blocking=True)
+                self.gpu_tensors[id(tensor)] = buffer
+        torch.cuda.current_stream().wait_stream(self.stream)
+
+    def get_reloaded(self, cpu_tensor: torch.Tensor):
+        """
+        Return the GPU copy of the tensor.
+        """
+        assert id(cpu_tensor) in self.gpu_tensors.keys(),\
+            "The tensor that you are trying to reload was not offloaded. \
+            This error should not happen - please report the bug to Transformer Engine."
+        return self.gpu_tensors[id(cpu_tensor)]
+
+    def release_memory(self):
+        """
+        Release the memory of the GPU tensors.
+        """
+        if self._reuse_gpu_buffers:
+            self._gpu_buffer_pool.extend(self.gpu_tensors.values()) # return buffers to the pool
+        self.gpu_tensors.clear()
+
+    def release_buffers(self):
+        """
+        Release the memory of the GPU buffers.
+        """
+        self._gpu_buffer_pool = []
+
+
+class _CPUOffloadBackend:
+    """
+    Class providing a unified interface for offloading and reloading tensors.
+    It can be translated into different public APIs.
 
-    @staticmethod
-    def forward(ctx, tensor, cpu_offload_handler):
-        # pylint: disable=missing-function-docstring
-        cpu_offload_handler.on_group_commit_forward()
-        ctx.cpu_offload_handler = cpu_offload_handler
-        # return the identical tensor
-        return tensor
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # pylint: disable=missing-function-docstring
-        cpu_offload_handler = ctx.cpu_offload_handler
-        cpu_offload_handler.on_group_commit_backward()
-        return grad_output, None
+    The calls should be made in the following order, represented by this grammar:
+    CALLS -> PROGRAM*
+    PROGRAM -> FWD_LAYER* finish_fwd() start_bwd_reloading() BWD_LAYER*
+    FWD_LAYER -> start_offloaded_layer_fwd() (mark_can_start_offload()|offload())* end_offloaded_layer_fwd()
+    BWD_LAYER -> start_offloaded_layer_bwd() reload()* end_offloaded_layer_bwd()
 
-group_prefetch_offload_commit = GroupCommitFunction.apply
+    The method end_offloaded_layer_fwd() returns the number of the layer,
+    which should be passed to start_offloaded_layer_bwd(). This enables a different
+    order of forward and backward passes - used, for example, in pipeline parallelism.
 
-class SynchronizedGroupOffloadHandler(OffloadHandler):
-    """Offload Handler that offloads/reloads in a synchronized way.
-    The device-to-host and host-to-device copying happen in the same stream
-    as the computation kernels, thus the copying will block computation.
+    Parameters
+    ----------
+    reuse_gpu_buffers : bool, default = False
+        Re-use the same GPU buffers when reloading tensors. All offloaded
+        layers must produce activations of identical shapes, or an
+        assertion will be raised.
     """
-
-    def __init__(
-        self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
-    ) -> None:
-        super().__init__()
-
-        self.num_offload_group = num_offload_group
-        self.tensor_need_offloading_checker = tensor_need_offloading_checker
-        self.debug = debug
-
-        self.groupid_reset()
-
-    def groupid_reset(self):
-        """Groupid reset."""
-        # Data structures to label saved tensors and book-keep their cpu copies.
-        # Currently, on push, create a new cpu tensor and copies; on pop, copies
-        # the tensor back to gpu and deletes the cpu tensor.
-        # These will increment whenever `group_commit()` is invoked
-        self.current_group, self.tensor_count_current_group = (0, 0)
-        self.torch_tensor_count = 0
-        self.tensor_tag_to_state = {}
-
-    def on_group_commit_forward(self):
-        """On group commit forward."""
-        # finishing up with updating current group and tensor count
-        self.current_group += 1  # increment
-        self.tensor_count_current_group = 0  # reset
-
-    def on_group_commit_backward(self):
-        """On group commit backward."""
-        self.current_group -= 1
-        assert self.current_group >= 0
-
-    @staticmethod
-    def offload(src_tensor, pin_memory=True):
-        """Offload."""
-
-        cpu_backup = torch.empty(
-            src_tensor.size(),
-            dtype=src_tensor.dtype,
-            layout=src_tensor.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-
-        cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
-        state = (src_tensor.device, cpu_backup)
-        return state
-
-    @staticmethod
-    def reload(state, non_blocking=None):
-        """Reload."""
-        dev, cpu_backup = state
-        if non_blocking is None:
-            non_blocking = cpu_backup.is_pinned()
-        return cpu_backup.to(dev, non_blocking=non_blocking)
-
-    def tensor_push(self, tensor: torch.Tensor, **kwargs):
-        """Tensor push."""
-        # obtain a unique tensor tag
-        tensor_tag = (self.current_group, self.tensor_count_current_group)
-        self.tensor_count_current_group += 1
-        assert tensor_tag not in self.tensor_tag_to_state
-        if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
-            tensor
-        ):
-            state = SynchronizedGroupOffloadHandler.offload(tensor)
-            self.tensor_tag_to_state[tensor_tag] = state
+    def __init__(self, reuse_gpu_buffers: bool = False):
+        # Two streams are used to reload the tensors -
+        # we want to release the memory at the end of each layer,
+        # but we also want to start reloading the next layer before and finish reloading it after.
+        # It is hard to achieve this with a single reloading stream.
+ self.streamed_reloaders = [ + _StreamedReloader(reuse_gpu_buffers) for _ in range(2) + ] + # we switch the streamed reloader after every layer + # this int indicates which reloader to use + self.reloader_parity = 0 + + self.streamed_offloader = _StreamedOffloader() # one stream for offloading + + self.cur_layer_id = 0 + self.total_num_of_layers = 0 + self.total_num_of_reloaded_layers = 0 + + self._total_offloaded_size = 0 + + self.first_layer_fwd_flag = False + self.inside_offloaded_layer_bwd_flag = False + + # layer_num -> id of cpu tensor -> cpu tensor + # This dictionary of dictionaries stores the cpu copies of the tensors of all the layers. + # They are used for reloading. + self._offload_cpu_tensors_for_each_layer: dict[int, dict[int, torch.Tensor]] = {} + + def start_offloaded_layer_fwd(self): + """ + Invoked before the new offloaded layer is started. + """ + if not self.first_layer_fwd_flag: + self._finish_layer_offload() else: - # will be offloaded together after group commit - self.tensor_tag_to_state[tensor_tag] = tensor - - return tensor_tag - - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - state = self.tensor_tag_to_state.pop(tensor_tag) - if isinstance(state, tuple): - tensor = SynchronizedGroupOffloadHandler.reload(state) - else: - tensor = state - return tensor - - -class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): - """Compared to synchronize, this uses more memory because of the buffer but - achieves better performance due to the overlapping. D2h and h2d copying are - completely hidden behind computation if computation time of a layer is longer - than host-device communication time. Bulk offloading with delay and bulk reloading - with prefetch are implemented.""" - - def __init__( - self, - num_offload_group, # must be <= actual number of groups (number of commits) - num_model_group, - tensor_need_offloading_checker=(lambda t: True), - debug=False, - ) -> None: - super().__init__( - num_offload_group=num_offload_group, - tensor_need_offloading_checker=tensor_need_offloading_checker, - debug=debug, - ) - # Number of layers in the model - self.num_layers = num_model_group - # Data Structure to maintain reference to activation tensors - self.tensor_tag_to_buf = {} - # Data structure to hold the FP8/MXFP8 tensor objects - self.fp8_tensor_object_map = {} - self.float8_transpose_cache_valid = {} - # Tracking the number of layers offloaded - self.offloaded_group_count = 0 - # Core data structure that decides the window for offloading - self.layer_window_map = {} - - # Logic to make offloading load balance across computation - # for optimal CPU/GPU interconnect usage - constant = 0 - for i in range(self.num_offload_group): - self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 - if i < (self.num_layers % self.num_offload_group): - self.layer_window_map[i] += i + 1 - constant = i + 1 + self.first_layer_fwd_flag = False + self.cur_layer_id = self.total_num_of_layers + + def end_offloaded_layer_fwd(self) -> int: + """ + Call right after the forward pass of an offloaded layer. 
+ """ + self.total_num_of_layers += 1 + return self.cur_layer_id + + def finish_fwd(self) -> None: + """ + Synchronization after fwd + """ + self._finish_layer_offload() + self.cur_layer_id = self.total_num_of_layers + + self.first_layer_fwd_flag = True # reset the flag after the forward pass + + def start_bwd_reloading(self) -> None: + """ + Start reloading the first two backward layers. + """ + # Reload should wait for this call. + for reloader in self.streamed_reloaders: + reloader.wait_for_main() + + def start_offloaded_layer_bwd(self, layer_num: int): + """ + Invoked when the backward pass of offloaded layer is started. + """ + if self.inside_offloaded_layer_bwd_flag: + raise RuntimeError( + "Backward of one offloaded layer started before the previous one finished. " + "This is not supported by the Transformer Engine cpu offloading. " + "We support only offloading of subsequence of sequence of consecutive layers - " + "such that the output of one is the input of the next one." + ) + self.inside_offloaded_layer_bwd_flag = True + self.cur_layer_id = layer_num + self.reloader_parity = 1 - self.reloader_parity + cur_reloader = self._get_current_reloader() + cpu_tensors = self._offload_cpu_tensors_for_each_layer.pop(layer_num, {}) + + self._total_offloaded_size -= self._get_size(cpu_tensors) + + cur_reloader.bulk_reload(cpu_tensors) # main waits for cur_reloader to finish here. + + next_reloader = self._get_next_reloader() + next_reloader.wait_for_main() # Reload of next layer should not start before this call. + + def end_offloaded_layer_bwd(self): + """ + Invoked when the backward pass of offloaded layer is finished. + """ + self.inside_offloaded_layer_bwd_flag = False + cur_reloader = self._get_current_reloader() + cur_reloader.wait_for_main() + cur_reloader.release_memory() + + # Reset the number of reloaded layers. + self.total_num_of_reloaded_layers += 1 + if self.total_num_of_reloaded_layers == self.total_num_of_layers: + self.total_num_of_reloaded_layers = self.total_num_of_layers = 0 + + def mark_can_start_offload(self, gpu_tensor: torch.Tensor): + """ + Use this helper when a tensor is produced at the very beginning of + torch.autograd.Function.forward and you want to kick off the GPU → CPU + copy sooner than save_for_backward would allow. We attach a CUDA + event to the tensor so we can later wait on it and begin the transfer + exactly when the tensor is ready. + + Notes: + - The tensor must eventually be passed to offload(); otherwise this + call is a no-op. + - Any tensor that should be transferred early must also be handed to + offload() earlier than the rest. + - For Float8Tensor instances we want to copy as early as + possible, helper calls such as update_usage may discard the + row-wise data. By tagging the event first, then invoking + update_usage, and finally calling offload(), the copy is initiated + at the correct moment and correct data are copied. + """ + self.streamed_offloader.mark_can_start_offload(gpu_tensor) + + def offload(self, gpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Starts asynchronous copy of the tensor from GPU to CPU. + Returns the CPU copy of the tensor. + """ + self.layer_fwd_with_unfinished_offload = True + cpu_tensor = self.streamed_offloader.offload(gpu_tensor) + return cpu_tensor + + def reload(self, cpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Return the GPU copy corresponding to `cpu_tensor`. 
+        """
+        cur_reloader = self._get_current_reloader()
+        return cur_reloader.get_reloaded(cpu_tensor)
+
+    def clear_buffers(self):
+        """
+        Clear the buffers for the activations. Can be used only when reuse_gpu_buffers is True.
+        """
+        # The reuse_gpu_buffers flag lives on the reloaders, not on the backend itself.
+        assert self.streamed_reloaders[0]._reuse_gpu_buffers, \
+            "clear_buffers is only allowed when reuse_gpu_buffers is True"
+        for reloader in self.streamed_reloaders:
+            reloader.wait_for_main()
+            reloader.release_buffers()
+
+    def _get_next_reloader(self):
+        """
+        Get the next reloader.
+        """
+        return self.streamed_reloaders[(self.reloader_parity + 1) % 2]
+
+    def _get_current_reloader(self):
+        """
+        Get the current reloader.
+        """
+        return self.streamed_reloaders[self.reloader_parity]
+
+    def _finish_layer_offload(self):
+        """
+        Finish offloading of the previous layer.
+        """
+        self.layer_fwd_with_unfinished_offload = False
+        self.streamed_offloader.wait_for_offloading()
+        self._offload_cpu_tensors_for_each_layer[self.cur_layer_id] = \
+            self.streamed_offloader.get_all_offloaded_tensors()
+        self._total_offloaded_size += \
+            self._get_size(self._offload_cpu_tensors_for_each_layer[self.cur_layer_id])
+        self.streamed_offloader.release_memory()
+
+    def get_offloaded_total_size_mb(self):
+        """
+        Return the size of tensors that currently have a CPU copy,
+        in megabytes.
+        """
+        # For debugging purposes
+        return self._total_offloaded_size / 1024 / 1024
+
+    def _get_size(self, cpu_tensors: list[torch.Tensor]):
+        """
+        Get the total size of the cpu tensors in bytes.
+        """
+        total_size = 0
+        for cpu_tensor in cpu_tensors:
+            if type(cpu_tensor) == torch.Tensor:
+                total_size += cpu_tensor.numel() * cpu_tensor.element_size()
             else:
-            self.layer_window_map[i] += constant
-
-        # allocate streams and events for synchronization
-        self.d2h_stream = torch.cuda.Stream()
-        self.h2d_stream = torch.cuda.Stream()
+                for tensor in cpu_tensor.get_data_tensors():
+                    if tensor is not None:
+                        total_size += tensor.numel() * tensor.element_size()
+        return total_size
 
-    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
-        torch_stray_tensor = isinstance(
+class _CPUOffloadPackHooks:
+    """
+    Context manager inserting hooks into the packing and unpacking of tensors saved for backward.
+    """
+
+    def __init__(self, backend: _CPUOffloadBackend):
+        self.backend = backend
+        self.context = saved_tensors_hooks(self._pack_hook, self._unpack_hook)
+
+    def __enter__(self):
+        self.context.__enter__()
+
+    def __exit__(self, *args):
+        self.context.__exit__()
+
+    def _pack_hook(self, tensor: torch.Tensor) -> tuple[torch.Tensor, _CPUOffloadBackend, bool]:
+        """
+        If the tensor should be offloaded (as decided by _offload_checker), offload it.
+
+        Returns:
+        - tensor: the offloaded or non-offloaded tensor.
+        - handler: the offload handler or None - it is needed to restore the tensor in the backward pass.
+ """ + + if isinstance( tensor, ( torch._subclasses.fake_tensor.FakeTensor, torch._subclasses.functional_tensor.FunctionalTensor, ), - ) - - is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None)) - - if not torch_stray_tensor: - - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - - assert tensor_tag not in self.tensor_tag_to_state - - if is_quantized_tensor: - tensor_list, _ = tensor.prepare_for_saving() - - self.tensor_tag_to_state[tensor_tag] = [] - self.tensor_tag_to_buf[tensor_tag] = [] - - self.fp8_tensor_object_map[tensor_tag] = tensor - if isinstance(tensor, Float8Tensor): - self.float8_transpose_cache_valid[tensor_tag] = getattr( - tensor, "_transpose_invalid" - ) - else: - tensor_list = [tensor] + ): + return (tensor, None) - for t in tensor_list: - if is_quantized_tensor: - self.tensor_tag_to_state[tensor_tag].append(t) - else: - self.tensor_tag_to_state[tensor_tag] = t - - if ( - self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(t) - ): - if is_quantized_tensor: - self.tensor_tag_to_buf[tensor_tag].append(t) - # Need to clear the internal data reference for the quantized tensors - tensor.clear() - else: - self.tensor_tag_to_buf[tensor_tag] = t + if self._offload_checker(tensor): + return (self.backend.offload(tensor), self.backend) else: - tensor_tag = (-1, self.torch_tensor_count) - self.torch_tensor_count += 1 - self.tensor_tag_to_state[tensor_tag] = tensor - - return tensor_tag - - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - tensor = self.tensor_tag_to_state.pop(tensor_tag) - - # Handling the quantized tensor case specially here - if isinstance(tensor, list): - self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) - tensor = self.fp8_tensor_object_map.pop(tensor_tag) - - self.tensor_tag_to_buf.pop(tensor_tag, None) - - # the tensor should have been copied back in on_group_commit_backward() - # which invokes bulk_reload_group. - assert not isinstance(tensor, tuple) + return (tensor, None) + + def _unpack_hook(self, tensor_handler: tuple[torch.Tensor, _CPUOffloadBackend, bool]): + """ + Unpacks the tensor and the offload handler. + """ + tensor, handler = tensor_handler + if handler is not None: + tensor.offload_handler = handler + tensor = self.backend.reload(tensor) return tensor + + def _offload_checker(self, tensor: torch.Tensor): + """ + Check if the tensor should be offloaded. + """ + # We do not offload parameters/weights. + if isinstance(tensor, torch.nn.Parameter): + return False + + # Sometimes weights are processed inside the TransformerEngine layer, + # we do not want to offload them. + if hasattr(tensor, "is_weight"): + return False + + # We do not offload too small tensors. + if tensor.numel() < MIN_TENSOR_SIZE_TO_OFFLOAD: + return False + + return True + +class _SwitchCPUOffloadHandler: + """ + Context manager to switch the CPU offload handler. 
+ """ + def __init__(self, backend: _CPUOffloadBackend): + self.backend = backend - def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" - with torch.cuda.stream(self.d2h_stream): - for tensor_tag, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_to_offload: - assert not isinstance(state, tuple) - - is_quantized_tensor = isinstance(state, list) - - if is_quantized_tensor: - tensor_list = state - self.tensor_tag_to_state[tensor_tag] = [] - else: - tensor_list = [state] - - for tensor_on_device in tensor_list: - # `tensor_offloaded` is a hacky way of dealing with columnwise-only - # quantized tensors for CPU offloading. The complication is due to - # the `rowwise_data` being `None`. The offloading checker incorrectly - # returns `False` and the entire `state` ([None, columnwise_tensor]) - # is added to the tensor tag state dict. A better design would change - # how quantized tensors are kept track of in the offload handler. - # Currently at every stage it is ensured that a quantized tensor is a - # list whereas a non-quantized tensor is standalone object, which is - # not good! TODO(@sanandaraj5597) - tensor_offloaded = False - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - tensor_offloaded = True - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - if is_quantized_tensor: - if tensor_offloaded: - self.tensor_tag_to_state[tensor_tag].append(state) - else: - self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) - else: - self.tensor_tag_to_state[tensor_tag] = state - - def synchronize_on_group_commit_forward(self, current_group): - """Synchronize on group commit forward.""" - - # For the first group, kickstart the offload after we have - # the first compute completion - if current_group == 0: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - self.bulk_offload_group(current_group) - - # Window map data structure helps us synchronize based on number - # of layers offloaded - if self.layer_window_map[self.offloaded_group_count] == current_group: - - # Stream synchronization both ways - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) - - # Time to free the activation memory after usage - for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): - if tensor_tag[0] == self.offloaded_group_count: - if hasattr(tensor_buf, "needs_force_clear"): - # Need to clear activation tensor - sometimes references persist in the code. - # This is the case for example with the Float8TensorBase class, - # which is saved directly inside the ctx while its internal tensors are - # saved inside save_for_backward. 
- tensor_buf.data = torch.Tensor() - # Release the pointer to the tensor - self.tensor_tag_to_buf[tensor_tag] = None - - # Time to offload the next group - if self.offloaded_group_count < (self.num_offload_group - 1): - self.bulk_offload_group(self.offloaded_group_count + 1) - - # Increment the offload group count to keep track - self.offloaded_group_count += 1 - - def on_group_commit_forward(self): - """This function will cause host device synchronization""" - # handle synchronization events - self.synchronize_on_group_commit_forward(self.current_group) - - super().on_group_commit_forward() - - def bulk_reload_group(self, group_to_reload): - """Bulk reload group.""" - assert group_to_reload < self.num_offload_group - - with torch.cuda.stream(self.h2d_stream): - # move back tensors - for tensor_label, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload: - if isinstance(state, tuple): - recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) - self.tensor_tag_to_state[tensor_label] = recovered_tensor - elif isinstance(state, list): - tensor_list = [] - for state_tuple in state: - if isinstance(state_tuple, tuple): - tensor_list.append( - SynchronizedGroupOffloadHandler.reload(state_tuple) - ) - else: - tensor_list.append(state_tuple) - _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list) - if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): - self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( - self.float8_transpose_cache_valid.pop(tensor_label) - ) - self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( - tensor_label - ) - - def on_group_commit_backward(self): - # first decrement the current group. - # after last commit in forward, the group will +1; in backward it -1. - # Finally it should be decremented to 0. - self.current_group -= 1 - assert self.current_group >= 0 - - # Layer window data structure helps us to reload at right times - if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - - # Stream synchronization both ways - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) - - # Time to reload the next group - self.bulk_reload_group(self.offloaded_group_count - 1) - - # Decrease the offloading group counter - self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - - # Last group computation needs to wait till all the reloads complete - if self.current_group == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) - self.offloaded_group_count = 0 - + def __enter__(self): + global CURRENT_CPU_OFFLOAD_HANDLER + self.previous_backend = CURRENT_CPU_OFFLOAD_HANDLER + CURRENT_CPU_OFFLOAD_HANDLER = self.backend + + def __exit__(self, *args): + global CURRENT_CPU_OFFLOAD_HANDLER + CURRENT_CPU_OFFLOAD_HANDLER = self.previous_backend + +class CPUOffload: + """ + The CPUOffload class enables asynchronous offloading of activations. + If we have n consecutive transformer layers, we can choose some of them to be offloaded. + The offloading of the next layer begins after the offloading of the previous layer has finished. + The forward pass of the last layer starts after the offloading of the last layer has finished. + During the backward pass, the reload begins after the gradients of the last layer are computed. 
+    Only one layer is reloaded at a time; the reload of the next layer starts after the backward pass of the previous
+    offloaded layer has begun.
+
+    This ensures that if k out of n identical layers are offloaded, the activations of at most n - k layers
+    reside in GPU memory at the same time. We recommend offloading 1 out of every x layers for a sufficiently
+    large x, so that computation and offloading fully overlap while memory usage is reduced.
+
+    Each layer must be wrapped with an instance of the CPUOffload class; activation offloading
+    for a particular layer is enabled or disabled with the offload_activations parameter. The last layer
+    must be wrapped with is_last_layer=True and offload_activations=False, since its activations cannot be offloaded.
+
+    CPUOffload supports any torch.nn.Module, not only those provided by Transformer Engine.
+    Only sequences of torch.nn.Module are supported; other graph structures are not.
+    CPUOffload supports multiple autograd graphs, so it can be used with pipeline parallelism, for example.
+
+    Example:
+    --------
+    ```python
+    # ...
+    cpu_offload_wrapper = CPUOffload()
+
+    # Wrap all transformer layers, not just those you want to offload;
+    # this enables optimal synchronization.
+    layer1 = cpu_offload_wrapper(layer1, offload_activations=True)
+    layer2 = cpu_offload_wrapper(layer2, offload_activations=False)
+    layer3 = cpu_offload_wrapper(layer3, is_last_layer=True)
+
+    x2 = layer1(x1)
+    x3 = layer2(x2)
+    y = layer3(x3)
+    y.sum().backward()
+
+    # ...
+    ```
+
+    Parameters
+    ----------
+    reuse_gpu_buffers : bool, default = False
+        Re-use the same GPU buffers when reloading tensors. All offloaded
+        layers must produce activations of identical shapes, or an
+        assertion will be raised.
+    """
+    def __init__(self, reuse_gpu_buffers: bool = False):
+        self.backend = _CPUOffloadBackend(reuse_gpu_buffers)
+
+        self.pack_hooks = _CPUOffloadPackHooks(self.backend)
+        self.switch_cpu_offload_handler = _SwitchCPUOffloadHandler(self.backend)
+
+    def __call__(self, module, offload_activations: bool = False, is_last_layer: bool = False):
+        """
+        Wraps a module so that its activations can be offloaded to CPU.
+
+        Parameters
+        ----------
+        module : torch.nn.Module
+            The module whose activations are offloaded.
+        offload_activations : bool
+            If True, the activations of this module are offloaded.
+        is_last_layer : bool
+            Must be set to True for the last wrapped layer; its activations are never offloaded.
+
+        Returns
+        -------
+        torch.nn.Module
+            The wrapped module.
+        """
+        assert not is_last_layer or not offload_activations, "Last layer activations cannot be offloaded."
+
+        # The module is wrapped into CPUOffloadModule,
+        # and the hooks are registered on the wrapped module.
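+        # forward_pre_hook / forward_hook bracket the forward pass of the wrapped module:
+        # they install the pack/unpack hooks and start/finish offloading of its activations.
+        # backward_pre_hook / backward_hook bracket the backward pass and drive the reload.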
+        class CPUOffloadModule(torch.nn.Module):
+            def __init__(self, module: torch.nn.Module):
+                super().__init__()
+                self.module = module
+
+            def forward(self, *args, **kwargs):
+                return self.module(*args, **kwargs)
+
+        cpu_offload_module = CPUOffloadModule(module)
+
+        def forward_pre_hook(model, input):
+            self.switch_cpu_offload_handler.__enter__()
+            self.backend.start_offloaded_layer_fwd()
+            self.pack_hooks.__enter__()
+
+        def forward_hook(model, input, output):
+            self.switch_cpu_offload_handler.__exit__()
+            model.layer_id = self.backend.end_offloaded_layer_fwd()
+            self.pack_hooks.__exit__()
+
+        def backward_pre_hook(model, input):
+            self.backend.start_offloaded_layer_bwd(model.layer_id)
+
+        def backward_hook(model, grad_input, grad_output):
+            if len(grad_input) == 1 and grad_input[0] is None:
+                # For the last layer, when the input does not require gradients,
+                # this hook fires before the module's backward has actually run.
+                # Instead of calling end_offloaded_layer_bwd here (too early),
+                # queue it on the autograd engine so it runs once the backward pass finishes.
+                torch.autograd.variable.Variable._execution_engine.queue_callback(
+                    self.backend.end_offloaded_layer_bwd
+                )
+                return
+            self.backend.end_offloaded_layer_bwd()
+
+        if offload_activations:
+            cpu_offload_module.register_forward_pre_hook(forward_pre_hook)
+            cpu_offload_module.register_forward_hook(forward_hook)
+            cpu_offload_module.register_full_backward_pre_hook(backward_pre_hook)
+            cpu_offload_module.register_full_backward_hook(backward_hook)
+        if is_last_layer:
+            cpu_offload_module.register_forward_pre_hook(lambda *args: self.backend.finish_fwd())
+            cpu_offload_module.register_full_backward_hook(lambda *args: self.backend.start_bwd_reloading())
+
+        return cpu_offload_module
+
+CURRENT_CPU_OFFLOAD_HANDLER = None
 
 def get_cpu_offload_context(
     enabled: bool = False,
@@ -552,35 +623,11 @@ def get_cpu_offload_context(
     offload_weights: bool = False,
 ):
     """
-    This function returns the CPU Offload context and the synchronizer function that needs to be
-    used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
-
-    Usage:
-
-    .. code-block:: python
-
-        cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)
-
-        with cpu_offload_context:
-            te_layer.forward(inp_tensor)
-        cpu_offload_synchronizer()
-
-    Parameters
-    ----------
-    enabled: bool, default = `False`
-        When set to True, CPU Offloading functionality is enabled.
-    num_layers: int, default = 1
-        Determines the number of transformer layers
-        you want to offload activations/weights for.
-    model_layers: int, default = 1
-        Number of layers in the model that will be used under this context.
-    offload_activations: bool, default = `True`
-        When set to `True`, offloads the activations for the TE layer.
-    offload_weights: bool, default = `True`
-        When set to `True`, offloads the weights for the TE layer.
-
+    Legacy CPU offloading API, kept for backward compatibility. It is deprecated and will be
+    removed in a future release; use the CPUOffload class instead.
     """
+    print("[WARNING] get_cpu_offload_context is deprecated. Use CPUOffload instead.")
+
     if not offload_weights and not offload_activations:
         raise ValueError(
             "CPU Offloading is enabled while it is not "
@@ -598,25 +645,91 @@ def get_cpu_offload_context(
     # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations: - return nullcontext(), lambda x: x - - def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor, "activation_offloading") - - tensor_need_offloading_checker = tensor_need_offloading_checker_activations - - cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( - num_offload_group=num_layers, - num_model_group=model_layers, - tensor_need_offloading_checker=tensor_need_offloading_checker, - ) + return contextlib.nullcontext(), lambda x: x + + + class _CpuOffloadContext(contextlib.ContextDecorator): + def __init__(self, backend: _CPUOffloadBackend): + self.backend = backend + self.previous_backend = None + + self.current_layer = 0 + self.offload_layer = {} # int -> bool + self.pack_hooks = _CPUOffloadPackHooks(self.backend) + self.switch_cpu_offload_handler = _SwitchCPUOffloadHandler(self.backend) + + self.offload_layer = self._get_layers_to_offload(num_layers, model_layers) + + def _get_layers_to_offload(self, num_layers_to_offload: int, model_layers: int): + offload_layer = {} + offload_layer[0] = True + for i in range(1, model_layers): + offload_layer[i] = False + constant = 0 + for i in range(num_layers_to_offload - 1): + layer_to_offload = ((model_layers // num_layers_to_offload) * (i + 1)) - 1 + if i < (model_layers % num_layers_to_offload): + layer_to_offload += i + 1 + constant = i + 1 + else: + layer_to_offload += constant + + offload_layer[layer_to_offload] = True + + return offload_layer + + def __enter__(self): + if self.offload_layer[self.current_layer]: + self.switch_cpu_offload_handler.__enter__() + self.pack_hooks.__enter__() + self.backend.start_offloaded_layer_fwd() + + + def __exit__(self, *args): + if self.offload_layer[self.current_layer]: + self.switch_cpu_offload_handler.__exit__() + self.pack_hooks.__exit__() + self.backend.end_offloaded_layer_fwd() + + if self.current_layer == model_layers - 1: + # finish the forward pass + self.backend.finish_fwd() + + self.current_layer += 1 + + def synchronization_function(self, tensor): + assert tensor.requires_grad == True + + def hook(_): + self.current_layer -= 1 + if self.current_layer < 0: + return + + if self.current_layer == model_layers - 1: + # start reloading after the last layer bwd + self.backend.start_bwd_reloading() + + if self.offload_layer.get(self.current_layer + 1, False): + self.backend.end_offloaded_layer_bwd() + + if self.offload_layer[self.current_layer]: + self.backend.start_offloaded_layer_bwd(self.current_layer) + + if self.current_layer == 0: + torch.autograd.variable.Variable._execution_engine.queue_callback( + self.backend.end_offloaded_layer_bwd + ) - def group_prefetch_offload_commit_async(tensor): - return group_prefetch_offload_commit(tensor, cpu_offload_handler) + tensor.grad_fn.register_prehook(hook) + return tensor + + backend = _CPUOffloadBackend() + cpu_offload_context = _CpuOffloadContext(backend) if enabled: return ( - CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), - group_prefetch_offload_commit_async, + cpu_offload_context, + cpu_offload_context.synchronization_function, ) - return nullcontext(), group_prefetch_offload_commit_async + else: + return contextlib.nullcontext(), lambda x: x diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a9a8e61e2c..0e02473c62 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -41,7 +41,7 @@ from ..constants import GemmParallelModes, 
dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..cpu_offload import is_cpu_offload_enabled +from ..cpu_offload import is_cpu_offload_enabled, mark_is_weight from ..tensor.quantized_tensor import ( QuantizedTensor, @@ -175,16 +175,19 @@ def forward( input_quantizers[i].calibrate(inputmats[i]) for i in range(num_gemms): weight_quantizers[i].calibrate(weights[i]) + + if cpu_offloading: + mark_is_weight(*weights_fp8, *weights) if is_grad_enabled: ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] # TODO: update after #1638 is merged. # pylint: disable=fixme - if weight_requires_grad: - for inputmat in inputmats: - if isinstance(inputmat, QuantizedTensor): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + #if weight_requires_grad: + # for inputmat in inputmats: + # if isinstance(inputmat, QuantizedTensor): + # inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) if inp.requires_grad: for weight in weights_fp8: if isinstance(weight, QuantizedTensor): @@ -725,6 +728,7 @@ def forward( else: linear_fn = _GroupedLinear.forward args = [None] + print(is_cpu_offload_enabled()) args += ( inp, m_splits, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index eb207e6519..bc982e1eae 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -66,7 +66,7 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import is_cpu_offload_enabled, mark_can_start_offload from ..cpp_extensions import ( general_gemm, @@ -151,6 +151,8 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") + + mark_can_start_offload(inputmat) tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( @@ -210,6 +212,7 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, ) + mark_can_start_offload(ln_out) ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out @@ -361,9 +364,6 @@ def forward( if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) - if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) - # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves @@ -379,15 +379,11 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. 
Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight + weightmat.is_weight = True + weight.is_weight = True + bias.is_weight = True + ln_weight.is_weight = True + ln_bias.is_weight = True tensors_to_save, tensor_objects = prepare_for_saving( inputmat, @@ -511,6 +507,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None @@ -536,14 +533,6 @@ def backward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, - # we need to connect them into one. - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - origin_weight.main_grad = main_grad - ctx.ub_obj_gradout = None ub_obj_dgrad = None ub_obj_wgrad = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6f423f5534..397e76af37 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -67,7 +67,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import is_cpu_offload_enabled, mark_can_start_offload, mark_is_weight from ..tensor.quantized_tensor import ( QuantizedTensor, Quantizer, @@ -224,6 +224,8 @@ def forward( ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + mark_can_start_offload(inputmat) + # for fp8 DelayedScaling: layernorm output = FP8 # only output of the linear is returned @@ -276,6 +278,7 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, ) + mark_can_start_offload(ln_out) ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out @@ -431,6 +434,7 @@ def forward( else: act_out = activation_func(fc1_out, fc2_input_quantizer) + mark_can_start_offload(fc1_out, act_out) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -477,10 +481,6 @@ def forward( if not is_grad_enabled: clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) else: - if cpu_offloading: - mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out - ) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -508,6 +508,12 @@ def forward( if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None + + if cpu_offloading: + mark_is_weight( + ln_weight, ln_bias, fc1_weight_final, fc1_weight, \ + fc1_bias, fc2_weight_final, fc2_weight, fc2_bias) + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index aaca7d4fe6..4219f2097c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -66,7 +66,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from 
..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import is_cpu_offload_enabled, mark_can_start_offload, mark_is_weight from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.utils import any_feature_enabled @@ -135,6 +135,7 @@ def forward( # Note: Cast to expected dtype and perform tensor-parallel communication nvtx_range_push(f"{nvtx_label}.input_cast_comm") + mark_can_start_offload(inp) inputmat = inp inputmat_total = None with_input_all_gather_nccl = ( @@ -312,9 +313,6 @@ def forward( if isinstance(weightmat, QuantizedTensor): weightmat.update_usage(columnwise_usage=True) - if cpu_offloading and saved_inputmat is not None: - mark_activation_offload(saved_inputmat) - # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights nvtx_range_push(f"{nvtx_label}.fsdp_scatter") @@ -337,6 +335,7 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight + mark_is_weight(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -378,7 +377,8 @@ def forward( ctx.requires_wgrad = weight.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - ctx.owns_input = saved_inputmat is not inp + #ctx.owns_input = saved_inputmat is not inp + ctx.owns_input = False if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -404,6 +404,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + print("bwd ") # pylint: disable=missing-function-docstring # NVTX label for profiling diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 18f7e2031a..3ec24f98de 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -374,7 +374,7 @@ def _initialize_state( if store_param_remainders: data = torch.zeros_like(param, dtype=torch.int16) else: - data = torch.empty_like(param, dtype=dtype) + data = torch.empty_like(param.dequantize(), dtype=dtype) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 0840d57863..472498cf62 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -112,6 +112,11 @@ def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data + def set_data_tensors(self, rowwise_data: torch.Tensor, columnwise_data: torch.Tensor): + """Set this Tensor's data.""" + self._rowwise_data = rowwise_data + self._columnwise_data = columnwise_data + def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: """Takes dequantized columnwise data and permutes to a rowwise shape""" if columnwise_dq.dim() < 2: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py 
b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index f37055bde4..f97f5dc781 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -127,6 +127,11 @@ def restore_from_saved( def get_data_tensors(self): """Get this Tensor's data.""" return self._data, self._transpose + + def set_data_tensors(self, data: torch.Tensor, transpose: torch.Tensor): + """Set this Tensor's data.""" + self._data = data + self._transpose = transpose def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index de185844c2..690ec2b51d 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -130,6 +130,11 @@ def restore_from_saved( def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data + + def set_data_tensors(self, rowwise_data: torch.Tensor, columnwise_data: torch.Tensor): + """Set this Tensor's data.""" + self._rowwise_data = rowwise_data + self._columnwise_data = columnwise_data def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 7e101b2612..ab9e99dd59 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -377,6 +377,42 @@ def clone(self) -> Float8BlockwiseQTensor: }, ) + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_rowwise_data = ( + torch.empty_like(self._rowwise_data, *args, **kwargs) + if self._rowwise_data is not None + else None + ) + new_columnwise_data = ( + torch.empty_like(self._columnwise_data, *args, **kwargs) + if self._columnwise_data is not None + else None + ) + new_rowwise_scale_inv = ( + torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) + if self._rowwise_scale_inv is not None + else None + ) + new_columnwise_scale_inv = ( + torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) + if self._columnwise_scale_inv is not None + else None + ) + + return Float8BlockwiseQTensor( + shape=self.shape, + dtype=self.dtype, + fp8_dtype=self._fp8_dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=new_rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=new_columnwise_scale_inv, + quantizer=self._quantizer, + is_2D_scaled=self._is_2D_scaled, + requires_grad=self.requires_grad, + ) + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) @@ -409,6 +445,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8BlockwiseQTensor.make_like(tensor) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + if isinstance(src, Float8BlockwiseQTensor) and isinstance(dst, Float8BlockwiseQTensor): + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data, *args[2:]) + if dst._rowwise_scale_inv is not None: + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv, *args[2:]) + if 
dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data, *args[2:]) + if dst._columnwise_scale_inv is not None: + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv, *args[2:]) + return dst + elif func == torch.ops.aten.is_pinned.default: + if args[0]._rowwise_data is not None: + return args[0]._rowwise_data.is_pinned() + elif args[0]._columnwise_data is not None: + return args[0]._columnwise_data.is_pinned() + else: + raise RuntimeError( + "Cannot check if pinned for Float8BlockwiseQTensor with no data." + ) + # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a37eb4f632..ba2a06e424 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -474,6 +474,27 @@ def clone(self) -> Float8Tensor: }, ) + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_data = torch.empty_like(self._data, *args, **kwargs) if self._data is not None else None + new_transpose = ( + torch.empty_like(self._transpose, *args, **kwargs) + if self._transpose is not None + else None + ) + new_scale_inv = torch.empty_like(self._scale_inv, *args, **kwargs) + device = new_scale_inv.device + return Float8Tensor( + shape=self.shape, + dtype=self.dtype, + data=new_data, + fp8_scale_inv=new_scale_inv, + fp8_dtype=self._fp8_dtype, + data_transpose=new_transpose, + quantizer=self._quantizer, + device=device, + ) + def view(self, *shape: Tuple[int]) -> Float8Tensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) @@ -598,14 +619,28 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return cls.detach(args[0]) if func == torch.ops.aten.clone.default: return cls.clone(args[0]) + if func == torch.ops.aten.is_pinned.default: + if args[0]._data is not None: + return args[0]._data.is_pinned() + elif args[0]._transpose is not None: + return args[0]._transpose.is_pinned() + else: + raise RuntimeError( + "Cannot check if pinned for Float8Tensor with no data and no transpose." 
+ ) if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) - if src._transpose is not None or dst._transpose is not None: - dst._create_transpose() + if dst._data is not None: + dst._data.copy_(src._data, *args[2:]) + if dst._scale_inv is not None: + dst._scale_inv.copy_(src._scale_inv, *args[2:]) + if dst._transpose is not None and not dst._transpose_invalid: + if not src._transpose_invalid: + dst._transpose.copy_(src._transpose, *args[2:]) + else: + dst._create_transpose() return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d2124f8e1e..706d5d2f96 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -277,6 +277,39 @@ def clone(self) -> MXFP8Tensor: }, ) + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_rowwise_data = ( + torch.empty_like(self._rowwise_data, *args, **kwargs) + if self._rowwise_data is not None + else None + ) + new_columnwise_data = ( + torch.empty_like(self._columnwise_data, *args, **kwargs) + if self._columnwise_data is not None + else None + ) + new_rowwise_scale_inv = ( + torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) + if self._rowwise_scale_inv is not None + else None + ) + new_columnwise_scale_inv = ( + torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) + if self._columnwise_scale_inv is not None + else None + ) + return MXFP8Tensor( + shape=self.shape, + dtype=self.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=new_rowwise_scale_inv, + fp8_dtype=self._fp8_dtype, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=new_columnwise_scale_inv, + quantizer=self._quantizer, + ) + def view(self, *shape: Tuple[int]) -> MXFP8Tensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) @@ -330,6 +363,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, ) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + # Just copy FP8 attrs if copying between Float8Tensors + if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data, *args[2:]) + if dst._rowwise_scale_inv is not None: + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv, *args[2:]) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data, *args[2:]) + if dst._columnwise_scale_inv is not None: + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv, *args[2:]) + return dst + elif func == torch.ops.aten.is_pinned.default: + if args[0]._rowwise_data is not None: + return args[0]._rowwise_data.is_pinned() + elif args[0]._columnwise_data is not None: + return args[0]._columnwise_data.is_pinned() + else: + raise RuntimeError("Cannot check if pinned for MXFP8Tensor with no data.") # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py 
b/transformer_engine/pytorch/tensor/quantized_tensor.py index aa433e58bc..5fa6a01348 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -31,6 +31,7 @@ def prepare_for_saving( t, t_obj = tensor.prepare_for_saving() tensor_list.extend(t) tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list @@ -255,7 +256,7 @@ class QuantizedTensor(torch.Tensor): """ - def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False, device: Optional[torch.device] = None): # We are assuming only contiguous tensors stride = _stride_from_shape(shape) instance = torch.Tensor._make_wrapper_subclass( @@ -266,7 +267,7 @@ def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: boo dtype=dtype, layout=torch.strided, requires_grad=requires_grad, - device=torch.cuda.current_device(), + device=torch.cuda.current_device() if device is None else device, ) return instance @@ -311,6 +312,15 @@ def update_usage( def clear(self): """Deallocate this tensor's memory. Typically not needed and must be used carefully""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement clear function" + ) + + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement empty_like function" + ) def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" @@ -364,6 +374,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + # Empty like op + if func == torch.ops.aten.empty_like.default: + tensor = args[0] + return tensor.empty_like(*args[1:], **kwargs) + def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): return arg.dequantize(dtype=arg.dtype)
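The hunks above route `torch.ops.aten.empty_like` (as well as `copy_` and `is_pinned`) through `__torch_dispatch__` into a per-class `empty_like`, which presumably lets the offload backend allocate matching host-side staging buffers for quantized tensors. Below is a minimal, self-contained sketch of that dispatch pattern; `ToyWrapperTensor` and its methods are illustrative assumptions and are not part of Transformer Engine.

```python
# Toy illustration (not TE code) of the __torch_dispatch__ + empty_like pattern above.
import torch


class ToyWrapperTensor(torch.Tensor):
    """Wrapper subclass holding a single payload tensor in `_data`."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        instance = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device, requires_grad=False
        )
        instance._data = data
        return instance

    def empty_like_impl(self, **kwargs):
        # Rebuild the wrapper around a freshly allocated payload of the same
        # shape/dtype; forwarded kwargs (e.g. device=..., pin_memory=...) are how
        # pinned CPU buffers could be produced for offloading.
        return ToyWrapperTensor(torch.empty_like(self._data, **kwargs))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.empty_like.default:
            # torch.empty_like(wrapper, ...) lands here and is forwarded to the
            # subclass-specific allocation, mirroring the dispatch added in this patch.
            return args[0].empty_like_impl(**kwargs)
        raise NotImplementedError(f"{func} is not handled in this sketch")


if __name__ == "__main__":
    x = ToyWrapperTensor(torch.randn(4, 4))
    buf = torch.empty_like(x)  # routed through __torch_dispatch__
    print(type(buf).__name__, buf.shape)
```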