From c8c79102f02f66f7fb9576d51653843a8bf8ce6f Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Mon, 2 Jan 2023 15:51:03 +0800
Subject: [PATCH] [autoparallel] patch torch.flatten metainfo for autoparallel
 (#2247)

* [autoparallel] patch torch.flatten
---
 .../auto_parallel/meta_profiler/meta_registry/activation.py | 4 ++--
 .../auto_parallel/meta_profiler/meta_registry/pooling.py    | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 909232e61cb6..774457f7d3b6 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -30,7 +30,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
     input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    inplace = kwargs.get("inplace", False)
+    is_inplace = kwargs.get("inplace", False)
 
     # construct input args for forward
     fwd_in_args = [input_tensor]
@@ -51,7 +51,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
     # NOTE: the inplace ReLU don't have forward memory cost
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
     fwd_memory_cost = MemoryCost(
-        activation=activation_size(input_tensor) if inplace else activation_size([output_tensor, input_tensor]),
+        activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
         parameter=0,
         temp=0,
         buffer=0)
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index 3ecabb6dcf0d..79780c92eed4 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -14,6 +14,7 @@
 @meta_register.register(torch.nn.AdaptiveAvgPool1d)
 @meta_register.register(torch.nn.AdaptiveAvgPool2d)
 @meta_register.register(torch.nn.AdaptiveAvgPool3d)
+@meta_register.register(torch.flatten)
 def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """Meta info for AdaptiveAvgPool
     The aten graph of AdaptiveAvgPool is
@@ -32,6 +33,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+    is_inplace = kwargs.get("inplace", False)
 
     # construct forward args for flop mapping
     fwd_in_args = [input_tensor]
@@ -51,8 +53,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    fwd_mem_cost = MemoryCost(activation=activation_size(output_tensor))
-    bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor))
+    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
+    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
 
     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
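
A minimal sketch of the memory-cost rule this patch introduces, for illustration only and not part of the patch: MemoryCost and activation_size below are simplified stand-ins for colossalai's, and pooling_like_mem_cost is a hypothetical helper. When is_inplace is set, forward and backward activation costs are reported as zero; otherwise they equal the output and input tensor sizes, respectively.

# Sketch only: simplified stand-ins mirroring the patched memory-cost branch.
from dataclasses import dataclass

import torch


@dataclass
class MemoryCost:
    activation: int = 0
    parameter: int = 0
    temp: int = 0
    buffer: int = 0


def activation_size(tensor: torch.Tensor) -> int:
    # Rough stand-in: bytes occupied by the tensor's elements.
    return tensor.numel() * tensor.element_size()


def pooling_like_mem_cost(input_tensor, output_tensor, is_inplace=False):
    # Mirrors the patched avgpool_meta_info branch: in-place ops report no
    # extra activation memory; otherwise forward accounts for the output
    # tensor and backward for the input-sized gradient.
    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
    return fwd_mem_cost, bwd_mem_cost


if __name__ == "__main__":
    x = torch.rand(4, 3, 8, 8)
    y = torch.flatten(x, start_dim=1)
    print(pooling_like_mem_cost(x, y, is_inplace=False))
    print(pooling_like_mem_cost(x, y, is_inplace=True))

Note that whether the "inplace" flag is actually populated for torch.flatten depends on how the solver builds kwargs; the sketch only shows the branch added by this patch.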