chunk.py
from typing import Optional

import oneflow as flow
from oneflow.python.framework.tensor import Tensor, register_tensor_op
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module


class Chunk(Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, chunks, dim):
        if dim is not None:
            assert input.shape[dim] > 0, "chunk expects at least a 1-dimensional tensor"
            assert chunks > 0, "chunk expects `chunks` to be greater than 0"

        channel = input.dim()
        dim_size = input.shape[dim]
        # Every chunk takes floor(dim_size / chunks) elements along `dim`;
        # the last chunk absorbs the remainder when the split is uneven.
        chunk_size = dim_size // chunks if dim_size % chunks == 0 else int(dim_size / chunks)
        last_chunk_size = (
            dim_size // chunks
            if dim_size % chunks == 0
            else dim_size - (chunk_size * (chunks - 1))
        )

        chunk_dim_dict = {}
        tup_ndim = []
        splits = []

        # Build a [start, stop, step] range along `dim` for every chunk.
        for chunk in range(0, chunks):
            if dim_size % chunks == 0:
                start = chunk * chunk_size
                stop = (chunk + 1) * chunk_size
            else:
                start = chunk * chunk_size if chunk < chunks - 1 else chunk_size * (chunks - 1)
                stop = (chunk + 1) * chunk_size if chunk < chunks - 1 else dim_size
            step = 1
            chunk_dim_dict.setdefault(dim, []).append([int(start), int(stop), step])

        # For each chunk, keep the full range on every other dimension and
        # the chunk's range on `dim`, then slice.
        for k, v in chunk_dim_dict.items():
            for v_chunk in v:
                tup_list = []
                for i in range(0, channel):
                    if i != dim:
                        tup_list.append([None, None, None])
                    else:
                        tup_list.append(v_chunk)
                splits.append(flow.experimental.slice(input, slice_tup_list=tup_list))

        return splits


@oneflow_export("chunk")
@register_tensor_op("chunk")
@experimental_api
def chunk_op(input, chunks, dim):
    r"""Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.

    The last chunk will be larger if the tensor size along the given dimension `dim`
    is not divisible by `chunks`.

    Args:
        input (oneflow.experimental.Tensor): The tensor to split.
        chunks (int): Number of chunks to return.
        dim (int): Dimension along which to split the tensor.

    Returns:
        List of Tensors.

    For example:

    .. code-block:: python

        >>> import oneflow.experimental as flow
        >>> import numpy as np
        >>> flow.enable_eager_execution()

        >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
        >>> input = flow.Tensor(np_arr)
        >>> of_out = flow.chunk(input, chunks=3, dim=2)
        >>> chunks = 3
        >>> of_out_shape = []
        >>> for i in range(0, chunks):
        ...     of_out_shape.append(of_out[i].numpy().shape)
        >>> of_out_shape
        [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]

        >>> np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
        >>> input = flow.Tensor(np_arr)
        >>> of_out = flow.chunk(input, chunks=4, dim=3)
        >>> chunks = 4
        >>> of_out_shape = []
        >>> for i in range(0, chunks):
        ...     of_out_shape.append(of_out[i].numpy().shape)
        >>> of_out_shape
        [(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)]

    """
    return Chunk()(input, chunks, dim)


if __name__ == "__main__":
    import doctest

    doctest.testmod(raise_on_error=False)
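The sizing rule the module implements can be checked in plain Python. This is a minimal sketch, not part of the file above; the helper name chunk_sizes is made up for illustration.

def chunk_sizes(dim_size, chunks):
    # Hypothetical helper mirroring the arithmetic in Chunk.forward above:
    # every chunk gets floor(dim_size / chunks) elements and the last chunk
    # absorbs the remainder when the split is uneven.
    base = dim_size // chunks
    return [base] * (chunks - 1) + [dim_size - base * (chunks - 1)]

print(chunk_sizes(6, 3))  # [2, 2, 2]    -> the (5, 3, 2, 9) shapes in the docstring example
print(chunk_sizes(9, 4))  # [2, 2, 2, 3] -> the last chunk is larger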
test_chunk.py
import unittest
from collections import OrderedDict

import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList


def _test_2_dim_forward(test_case, device):
    np_arr = np.random.randn(2, 3).astype(np.float32)
    input = flow.Tensor(np_arr, device=flow.device(device))

    dim = 0
    chunks = 2
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(1, 3), (1, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 1
    chunks = 2
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(2, 1), (2, 2)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 1
    chunks = 3
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(2, 1), (2, 1), (2, 1)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))


def _test_2_dim_tensor_function_forward(test_case, device):
    np_arr = np.random.randn(2, 3).astype(np.float32)
    input = flow.Tensor(np_arr, device=flow.device(device))

    dim = 0
    chunks = 2
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(1, 3), (1, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 1
    chunks = 2
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(2, 1), (2, 2)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 1
    chunks = 3
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(2, 1), (2, 1), (2, 1)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))


def _test_4_dim_forward(test_case, device):
    np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
    input = flow.Tensor(np_arr, device=flow.device(device))

    dim = 2
    chunks = 3
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 2
    chunks = 4
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(5, 3, 1, 9), (5, 3, 1, 9), (5, 3, 1, 9), (5, 3, 3, 9)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 3
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(5, 3, 6, 3), (5, 3, 6, 3), (5, 3, 6, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 2
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(5, 3, 6, 4), (5, 3, 6, 5)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 4
    of_out = flow.chunk(input, chunks, dim)
    np_out_shape = [(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))


def _test_4_dim_tensor_function_forward(test_case, device):
    np_arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
    input = flow.Tensor(np_arr, device=flow.device(device))

    dim = 2
    chunks = 3
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 2
    chunks = 4
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(5, 3, 1, 9), (5, 3, 1, 9), (5, 3, 1, 9), (5, 3, 3, 9)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 3
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(5, 3, 6, 3), (5, 3, 6, 3), (5, 3, 6, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 2
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(5, 3, 6, 4), (5, 3, 6, 5)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))

    dim = 3
    chunks = 4
    of_out = input.chunk(chunks, dim)
    np_out_shape = [(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)]
    for i in range(0, chunks):
        of_out_shape = of_out[i].numpy().shape
        test_case.assertTrue(np.allclose(of_out_shape, np_out_shape[i], 1e-5, 1e-5))


def _test_chunk_backward(test_case, device):
    np_arr = np.random.randn(2, 3).astype(np.float32)
    input = flow.Tensor(np_arr, device=flow.device(device))
    input.requires_grad = True
    y = flow.chunk(input, chunks=2, dim=0)
    z1, z2 = y[0].sum(), y[1].sum()
    z1.backward()
    z2.backward()
    np_grad = np.ones((2, 3))
    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))


@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestChunk(flow.unittest.TestCase):
    def test_chunk(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_2_dim_forward,
            _test_4_dim_forward,
            _test_2_dim_tensor_function_forward,
            _test_4_dim_tensor_function_forward,
            _test_chunk_backward,
        ]
        arg_dict["device"] = ["cpu", "cuda"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()
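The hard-coded expected shape lists above all follow from the same splitting rule. A standalone sketch that reproduces them (the expected_chunk_shapes helper below is hypothetical and only shows where the tuples come from; it is not part of the test file):

def expected_chunk_shapes(shape, chunks, dim):
    # Hypothetical reference: all chunks take floor(dim_size / chunks) elements
    # along `dim`; the last chunk takes whatever remains.
    dim_size = shape[dim]
    base = dim_size // chunks
    sizes = [base] * (chunks - 1) + [dim_size - base * (chunks - 1)]
    return [tuple(s if i == dim else d for i, d in enumerate(shape)) for s in sizes]

assert expected_chunk_shapes((2, 3), chunks=2, dim=1) == [(2, 1), (2, 2)]
assert expected_chunk_shapes((5, 3, 6, 9), chunks=4, dim=3) == [
    (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3),
]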
Approach for flow.chunk

Step 1: compute the regular chunk size along dim, and the size of the last chunk, which absorbs the remainder when dim_size is not divisible by chunks:

channel = input.dim()
dim_size = input.shape[dim]
chunk_size = dim_size // chunks if dim_size % chunks == 0 else int(dim_size / chunks)
last_chunk_size = (
    dim_size // chunks if dim_size % chunks == 0 else dim_size - (chunk_size * (chunks - 1))
)

Step 2: build a [start, stop, step] range along dim for every chunk; in the uneven case the last range runs to dim_size:

chunk_dim_dict = {}
tup_ndim = []
splits = []
for chunk in range(0, chunks):
    if dim_size % chunks == 0:
        start = chunk * chunk_size
        stop = (chunk + 1) * chunk_size
    else:
        start = chunk * chunk_size if chunk < chunks - 1 else chunk_size * (chunks - 1)
        stop = (chunk + 1) * chunk_size if chunk < chunks - 1 else dim_size
    step = 1
    chunk_dim_dict.setdefault(dim, []).append([int(start), int(stop), step])

Step 3: for each chunk, keep the full range on every other dimension and the chunk's range on dim, then slice:

for k, v in chunk_dim_dict.items():
    for v_chunk in v:
        tup_list = []
        for i in range(0, channel):
            if i != dim:
                tup_list.append([None, None, None])
            else:
                tup_list.append(v_chunk)
        splits.append(flow.experimental.slice(input, slice_tup_list=tup_list))
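Concretely, for an input of shape (5, 3, 6, 9) with chunks=4 and dim=3, the ranges are [[0, 2, 1], [2, 4, 1], [4, 6, 1], [6, 9, 1]]. The trace below reproduces the construction in plain Python (no OneFlow required; it only prints the slice_tup_list that would be passed to flow.experimental.slice for each chunk):

# Trace the slice arguments built by the approach above for one concrete case.
shape = (5, 3, 6, 9)
chunks, dim = 4, 3

dim_size = shape[dim]
chunk_size = dim_size // chunks  # 2; the last chunk absorbs the remainder

for chunk in range(chunks):
    if chunk < chunks - 1:
        start, stop = chunk * chunk_size, (chunk + 1) * chunk_size
    else:
        start, stop = chunk_size * (chunks - 1), dim_size
    slice_tup_list = [
        [start, stop, 1] if i == dim else [None, None, None] for i in range(len(shape))
    ]
    print(slice_tup_list)
# [[None, None, None], [None, None, None], [None, None, None], [0, 2, 1]]
# [[None, None, None], [None, None, None], [None, None, None], [2, 4, 1]]
# [[None, None, None], [None, None, None], [None, None, None], [4, 6, 1]]
# [[None, None, None], [None, None, None], [None, None, None], [6, 9, 1]]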