Problem using different channels #4

Open
megalinier opened this issue Oct 1, 2019 · 4 comments

@megalinier

Hello! =)
I may have missed something, but I guess that if I want to apply your tfw.dwt2d to an image with different channels, I should be able to write something like:

size_image = 64
nb_channels = 4
tf_signal = tf.placeholder(dtype=tf.float32, shape=(size_image,size_image,nb_channels))
output = tfw.dwt2d(tf_signal, wavelet=dwtcoeffs.haar, levels=3)

But then, I get the following error:
Dimensions must be equal, but are 4 and 1 for 'conv1d_34/Conv2D' (op: 'Conv2D') with input shapes: [64,1,64,4], [1,2,1,1].

It looks like the dimensions of the filters are wrong, whereas there is no problem when nb_channels = 1.
Would you know how to fix this please?
Thank you very much!

Mathilde

@mathialo
Contributor

The different channels are just supposed to get the same transform applied to each of them (e.g. one per color), right? In that case, tfwavelets does not support that. To increase GPU parallelism, we implemented the 2D transform as 1D transforms with batching, so the channel dimension is already in use.
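To illustrate the batching idea, here is a minimal NumPy sketch (not the tfwavelets implementation) of a single-level 2D Haar transform built from 1D passes in which every row is treated as an independent batch item. The helper names `haar_1d_batched` and `haar_2d` are hypothetical and exist only for this example, and the sign convention of the detail coefficients is an assumption.

```python
import numpy as np


def haar_1d_batched(signals):
    """One-level Haar transform along axis 1 of a [batch, length] array.

    Every row is an independent signal, i.e. the first axis is a batch axis.
    Returns [approximation | detail] concatenated along axis 1.
    """
    even, odd = signals[:, 0::2], signals[:, 1::2]
    approx = (even + odd) / np.sqrt(2.0)
    detail = (even - odd) / np.sqrt(2.0)
    return np.concatenate([approx, detail], axis=1)


def haar_2d(image):
    """One-level 2D Haar transform of a single-channel [rows, cols] image."""
    # First pass: the rows are the batch, the transform runs along the columns.
    first_pass = haar_1d_batched(image)
    # Second pass: transpose so the columns become the batch, then undo it.
    return haar_1d_batched(first_pass.T).T


image = np.arange(16, dtype=np.float64).reshape(4, 4)
coeffs = haar_2d(image)
```

In this layout the batch axis is already occupied by the rows or the columns, which is consistent with the explanation of why an extra channel axis is not handled.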

To use channels, you sadly have to split the image into separate single-channel images, apply the transform to each one separately, and stack the results afterwards.
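The split, transform and stack workaround could look roughly like the sketch below. It uses NumPy and a placeholder transform so that it runs on its own; `apply_per_channel` and `demo_transform` are hypothetical names, and the callable stands in for whatever single-channel 2D transform is actually used.

```python
import numpy as np


def apply_per_channel(image, transform_2d):
    """Apply a single-channel 2D transform to each channel and restack.

    image:        array of shape [rows, cols, channels]
    transform_2d: callable taking and returning a [rows, cols, 1] array
                  (hypothetical stand-in for a single-channel transform)
    """
    # Split into single-channel images, keeping the trailing axis of size 1.
    channels = np.split(image, image.shape[-1], axis=-1)
    # Transform every channel independently.
    transformed = [transform_2d(channel) for channel in channels]
    # Stack the results back along the channel axis.
    return np.concatenate(transformed, axis=-1)


def demo_transform(channel):
    """Placeholder transform for the demo: subtracts the channel mean."""
    return channel - channel.mean()


image = np.random.default_rng(0).normal(size=(64, 64, 4))
output = apply_per_channel(image, demo_transform)
```

In a TensorFlow graph the same pattern can be expressed with `tf.split` and `tf.concat`, at the cost of building one transform subgraph per channel.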

@megalinier
Author

Thank you for your answer. After some modifications of your programs, I was finally able to apply the 2D transform to images with different channels =) I am not sure this implementation is optimal, but it does work. If you are interested in those modifications, please let me know.

@idevede

idevede commented May 31, 2020

> Thank you for your answer. After some modifications of your programs, I was finally able to apply the 2D transform to images with different channels =) I am not sure this implementation is optimal, but it does work. If you are interested in those modifications, please let me know.

I'm interested in how to support multi-channel input. If it is convenient, could you show me the modifications? Thanks~

@megalinier
Author

megalinier commented Jun 1, 2020

I couldn't find a way to upload the scripts directly, so here are all the (partially modified) functions that make the multi-channel computations possible. You can replace the corresponding functions in the original scripts with the following ones:

import numpy as np
import tensorflow as tf


def adapt_filter(filter):
    """
    Expands dimensions of a 1d vector to match the required tensor dimensions in a TF
    graph.

    Args:
        filter (np.ndarray):     A 1D vector containing filter coefficients

    Returns:
        np.ndarray: A 3D array of shape (n, 1, 1).

    """
    # Add singleton dimensions for input and output channels, as expected by tf.nn.conv1d
    return np.expand_dims(np.expand_dims(filter, -1), -1)


def to_tf_mat(matrices,precision_float_type):
    """
    Expands dimensions of 2D matrices to match the required tensor dimensions in a TF
    graph, and wrapping them as TF constants.

    Args:
        matrices (iterable):    A list (or tuple) of 2D numpy arrays.
        precision_float_type:   TF dtype of the resulting constants (e.g. tf.float32).

    Returns:
        iterable: A list of all the matrices converted to 3D TF tensors.
    """
    result = []

    for matrix in matrices:
        result.append(
            tf.constant(np.expand_dims(matrix, 0), dtype=precision_float_type)
        )

    return result


class Filter:
    """
    Class representing a filter.

    Attributes:
        coeffs (tf.constant):      Filter coefficients
        zero (int):                Origin of filter (which index of coeffs array is
                                   actually indexed as 0).
        edge_matrices (iterable):  List of edge matrices, used for circular convolution.
                                   Stored as 3D TF tensors (constants).
    """

    def __init__(self, coeffs, zero,precision_float_type,precision_float_numpy_type):
        """
        Create a filter based on given filter coefficients

        Args:
            coeffs (np.ndarray):         Filter coefficients
            zero (int):                  Origin of filter (which index of coeffs array is
                                         actually indexed as 0).
            precision_float_type:        tf.float16, tf.float32 or tf.float64
            precision_float_numpy_type:  Matching NumPy dtype (e.g. np.float32)
        """
        # Convert up front so that lists and tuples are accepted as well
        coeffs = np.asarray(coeffs)
        self.coeffs = tf.constant(adapt_filter(coeffs), dtype=precision_float_type)
        self._coeffs = coeffs.astype(precision_float_numpy_type)

        self.zero = zero

        self.edge_matrices = to_tf_mat(self._edge_matrices(),precision_float_type)

    def __getitem__(self, item):
        """
        Returns filter coefficients at requested indices. Indices are offset by the filter
        origin

        Args:
            item (int or slice):    Item(s) to get

        Returns:
            np.ndarray: Item(s) at specified place(s)
        """
        if isinstance(item, slice):
            return self._coeffs.__getitem__(
                slice(item.start + self.zero, item.stop + self.zero, item.step)
            )
        else:
            return self._coeffs.__getitem__(item + self.zero)


    def num_pos(self):
        """
        Number of positive indexed coefficients in filter, including the origin. Ie,
        strictly speaking it's the number of non-negative indexed coefficients.

        Returns:
            int: Number of positive indexed coefficients in filter.
        """
        return len(self._coeffs) - self.zero


    def num_neg(self):
        """
        Number of negative indexed coefficients, excluding the origin.

        Returns:
            int: Number of negative indexed coefficients
        """
        return self.zero


    def _edge_matrices(self):
        """Computes the submatrices needed at the ends for circular convolution.

        Returns:
            Tuple of 2d-arrays, (top-left, top-right, bottom-left, bottom-right).
        """
        if not isinstance(self._coeffs, np.ndarray):
            self._coeffs = np.array(self._coeffs)

        n, = self._coeffs.shape
        self._coeffs = self._coeffs[::-1]

        # Some padding is necessary to keep the submatrices
        # from having columns in common
        padding = max((self.zero, n - self.zero - 1))
        matrix_size = n + padding
        # Use the dtype of the coefficients so float64 precision is not truncated
        filter_matrix = np.zeros((matrix_size, matrix_size), dtype=self._coeffs.dtype)
        negative = self._coeffs[
                   -(self.zero + 1):]  # negative indexed filter coeffs (and 0)
        positive = self._coeffs[
                   :-(self.zero + 1)]  # filter coeffs with strictly positive indices

        # Insert first row
        filter_matrix[0, :len(negative)] = negative

        # Because -0 == 0, a length of 0 makes it impossible to broadcast
        # (nor is it necessary)
        if len(positive) > 0:
            filter_matrix[0, -len(positive):] = positive

        # Cycle previous row to compute the entire filter matrix
        for i in range(1, matrix_size):
            filter_matrix[i, :] = np.roll(filter_matrix[i - 1, :], 1)

        # TODO: Indexing not thoroughly tested
        num_pos = len(positive)
        num_neg = len(negative)
        top_left = filter_matrix[:num_pos, :(num_pos + num_neg - 1)]
        top_right = filter_matrix[:num_pos, -num_pos:]
        bottom_left = filter_matrix[-num_neg + 1:, :num_neg - 1]
        bottom_right = filter_matrix[-num_neg + 1:, -(num_pos + num_neg - 1):]

        # Indexing wrong when there are no negative indexed coefficients
        if num_neg == 1:
            bottom_left = np.zeros((0, 0), dtype=filter_matrix.dtype)
            bottom_right = np.zeros((0, 0), dtype=filter_matrix.dtype)

        return top_left, top_right, bottom_left, bottom_right

class Wavelet:
    """
    Class representing a wavelet.

    Attributes:
        decomp_lp (Filter):    Filter coefficients for decomposition low pass filter
        decomp_hp (Filter):    Filter coefficients for decomposition high pass filter
        recon_lp (Filter):     Filter coefficients for reconstruction low pass filter
        recon_hp (Filter):     Filter coefficients for reconstruction high pass filter
    """


    def __init__(self, decomp_lp, decomp_hp, recon_lp, recon_hp):
        """
        Create a new wavelet based on specified filters

        Args:
            decomp_lp (Filter):    Filter coefficients for decomposition low pass filter
            decomp_hp (Filter):    Filter coefficients for decomposition high pass filter
            recon_lp (Filter):     Filter coefficients for reconstruction low pass filter
            recon_hp (Filter):     Filter coefficients for reconstruction high pass filter
        """
        self.decomp_lp = decomp_lp
        self.decomp_hp = decomp_hp
        self.recon_lp = recon_lp
        self.recon_hp = recon_hp           

def cyclic_conv1d(input_node, filter_):
    """
    Cyclic convolution

    Args:
        input_node:  Input signal (3-tensor [rows, columns, channels]); the
                     convolution runs along the columns
        filter_:     Filter

    Returns:
        Tensor with the result of a periodic convolution
    """
    # Create shorthands for TF nodes
    kernel_node = filter_.coeffs
    tl_node, tr_node, bl_node, br_node = filter_.edge_matrices

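    # Do inner convolution. Fold the channels into the batch axis so that every
    # row of every channel is filtered as an independent single-channel signal.
    rows, columns, nb_channels = input_node.shape.as_list()
    filter_len = int(kernel_node.shape[0])
    inner_len = columns - filter_len + 1  # output width of a 'VALID' convolution
    input_node_resh = tf.transpose(input_node, perm=[2,0,1])
    input_node_resh = tf.reshape(input_node_resh,[rows*nb_channels,columns,1])
    inner_resh = tf.nn.conv1d(input_node_resh, kernel_node[::-1], stride=1, padding='VALID')
    inner_resh = tf.reshape(inner_resh,[nb_channels,rows,inner_len])
    inner = tf.transpose(inner_resh, perm=[1,2,0])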

    # Create shorthands for shapes
    input_shape = tf.shape(input_node)
    tl_shape = tf.shape(tl_node)
    tr_shape = tf.shape(tr_node)
    bl_shape = tf.shape(bl_node)
    br_shape = tf.shape(br_node)

    # Slices of the input signal corresponding to the corners
    tl_slice = tf.slice(input_node,
                        [0, 0, 0],
                        [-1, tl_shape[2], -1])
    tr_slice = tf.slice(input_node,
                        [0, input_shape[1] - tr_shape[2], 0],
                        [-1, tr_shape[2], -1])
    bl_slice = tf.slice(input_node,
                        [0, 0, 0],
                        [-1, bl_shape[2], -1])
    br_slice = tf.slice(input_node,
                        [0, input_shape[1] - br_shape[2], 0],
                        [-1, br_shape[2], -1])

    # Tile the edge matrices once per channel, and move the channels of each slice
    # to the batch axis, so that the batched matmul filters every channel separately.
    tl_node = tf.tile(tl_node, [nb_channels,1,1])
    tr_node = tf.tile(tr_node, [nb_channels,1,1])
    bl_node = tf.tile(bl_node, [nb_channels,1,1])
    br_node = tf.tile(br_node, [nb_channels,1,1])
    tl = tl_node @ tf.transpose(tl_slice, perm=[2, 1, 0])
    tr = tr_node @ tf.transpose(tr_slice, perm=[2, 1, 0])
    bl = bl_node @ tf.transpose(bl_slice, perm=[2, 1, 0])
    br = br_node @ tf.transpose(br_slice, perm=[2, 1, 0])

    head = tf.transpose(tl + tr, perm=[2, 1, 0])
    tail = tf.transpose(bl + br, perm=[2, 1, 0])

    return tf.concat((head, inner, tail), axis=1)


def upsample(input_node, odd=False):
    """Upsamples. Doubles the length of the input, filling with zeros

    Args:
        input_node: 3-tensor [batch, spatial dim, channels] to be upsampled
        odd:        Bool, optional. If True, content of input_node will be
                    placed on the odd indices of the output. Otherwise, the
                    content will be placed on the even indices. This is the
                    default behaviour.

    Returns:
        The upsampled output Tensor.
    """

    columns = []
    for col in tf.unstack(input_node, axis=1):
        columns.extend([tf.expand_dims(col,1), tf.expand_dims(tf.zeros_like(col),1)])
        
    if odd:
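        # Pairwise swap of neighbouring entries, so that the data lands on the odd
        # indices and the zeros on the even ones. See
        # https://stackoverflow.com/questions/30097512/how-to-perform-a-pairwise-swap-of-a-list
        # `& -2` rounds the length down to an even number.
        l = len(columns) & -2
        columns[1:l:2], columns[:l:2] = columns[:l:2], columns[1:l:2]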

    # TODO: Should we actually expand the dimension?
    return tf.concat(columns, 1)


def dwt1d(input_node, wavelet, levels=1):
    """
    Constructs a TF computational graph computing the 1D DWT of an input signal.

    Args:
        input_node:     A 3D tensor containing the signal. The dimensions should be
                        [batch, signal, channels].
        wavelet:        Wavelet object
        levels:         Number of levels.

    Returns:
        The output node of the DWT graph.
    """
    # TODO: Check that level is a reasonable number
    # TODO: Check types

    coeffs = [None] * (levels + 1)

    last_level = input_node

    for level in range(levels):
        lp_res = cyclic_conv1d(last_level, wavelet.decomp_lp)[:, ::2, :]
        hp_res = cyclic_conv1d(last_level, wavelet.decomp_hp)[:, 1::2, :]

        last_level = lp_res
        coeffs[levels - level] = hp_res

    coeffs[0] = last_level
    return tf.concat(coeffs, axis=1)


def dwt2d(input_node, wavelet, levels=1):
    """
    Constructs a TF computational graph computing the 2D DWT of an input signal.

    Args:
        input_node:     A 3D tensor containing the signal. The dimensions should be
                        [rows, cols, channels].
        wavelet:        Wavelet object.
        levels:         Number of levels.

    Returns:
        The output node of the DWT graph.
    """
    # TODO: Check that level is a reasonable number
    # TODO: Check types

    coeffs = [None] * levels
    
    last_level = input_node
    m, n = int(input_node.shape[0]), int(input_node.shape[1])

    for level in range(levels):
        local_m, local_n = m // (2 ** level), n // (2 ** level)

        first_pass = dwt1d(last_level, wavelet, 1)
        second_pass = tf.transpose(
            dwt1d(
                tf.transpose(first_pass, perm=[1, 0, 2]),
                wavelet,
                1
            ),
            perm=[1, 0, 2]
        )

        last_level = tf.slice(second_pass, [0, 0, 0], [local_m // 2, local_n // 2, -1])
        coeffs[level] = [
            tf.slice(second_pass, [local_m // 2, 0, 0], [local_m // 2, local_n // 2, -1]),
            tf.slice(second_pass, [0, local_n // 2, 0], [local_m // 2, local_n // 2, -1]),
            tf.slice(second_pass, [local_m // 2, local_n // 2, 0],
                     [local_m // 2, local_n // 2, -1])
        ]

    for level in range(levels - 1, -1, -1):
        upper_half = tf.concat([last_level, coeffs[level][0]], 0)
        lower_half = tf.concat([coeffs[level][1], coeffs[level][2]], 0)

        last_level = tf.concat([upper_half, lower_half], 1)

    return last_level


def idwt1d(input_node, wavelet, levels=1):
    """
    Constructs a TF graph that computes the 1D inverse DWT for a given wavelet.

    Args:
        input_node (tf.placeholder):             Input signal. A 3D tensor with dimensions
                                                 as [batch, signal, channels]
        wavelet (tfwavelets.dwtcoeffs.Wavelet):  Wavelet object.
        levels (int):                            Number of levels.

    Returns:
        Output node of IDWT graph.
    """
    m, n = int(input_node.shape[0]), int(input_node.shape[1])

    first_n = n // (2 ** levels)
    last_level = tf.slice(input_node, [0, 0, 0], [m, first_n, -1])

    for level in range(levels - 1, -1 , -1):
        local_n = n // (2 ** level)

        detail = tf.slice(input_node, [0, local_n//2, 0], [m, local_n//2, -1])

        lowres_padded = upsample(last_level, odd=False)
        detail_padded = upsample(detail, odd=True)

        lowres_filtered = cyclic_conv1d(lowres_padded, wavelet.recon_lp)
        detail_filtered = cyclic_conv1d(detail_padded, wavelet.recon_hp)

        last_level = lowres_filtered + detail_filtered

    return last_level


def idwt2d(input_node, wavelet, levels=1):
    """
    Constructs a TF graph that computes the 2D inverse DWT for a given wavelet.

    Args:
        input_node (tf.placeholder):             Input signal. A 3D tensor with dimensions
                                                 as [rows, cols, channels]
        wavelet (tfwavelets.dwtcoeffs.Wavelet):  Wavelet object.
        levels (int):                            Number of levels.

    Returns:
        Output node of IDWT graph.
    """
    m, n = int(input_node.shape[0]), int(input_node.shape[1])
    first_m, first_n = m // (2 ** levels), n // (2 ** levels)

    last_level = tf.slice(input_node, [0, 0, 0], [first_m, first_n, -1])

    for level in range(levels - 1, -1, -1):
        local_m, local_n = m // (2 ** level), n // (2 ** level)

        # Extract detail spaces (sizes are [rows, cols], so non-square inputs work too)
        detail_tr = tf.slice(input_node, [local_m // 2, 0, 0],
                             [local_m // 2, local_n // 2, -1])
        detail_bl = tf.slice(input_node, [0, local_n // 2, 0],
                             [local_m // 2, local_n // 2, -1])
        detail_br = tf.slice(input_node, [local_m // 2, local_n // 2, 0],
                             [local_m // 2, local_n // 2, -1])

        # Construct image of this DWT level
        upper_half = tf.concat([last_level, detail_tr], 0)
        lower_half = tf.concat([detail_bl, detail_br], 0)

        this_level = tf.concat([upper_half, lower_half], 1)

        # First pass, corresponding to second pass in dwt2d
        first_pass = tf.transpose(
            idwt1d(
                tf.transpose(this_level, perm=[1, 0, 2]),
                wavelet,
                1
            ),
            perm=[1, 0, 2]
        )
        # Second pass, corresponding to first pass in dwt2d
        second_pass = idwt1d(first_pass, wavelet, 1)

        last_level = second_pass

    return last_level
    
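def create_wavelet_transform_filters(wavelet_type, precision_float_type,precision_float_numpy_type):
    """
    Creates the decomposition and reconstruction filters of a wavelet.

    Args:
        wavelet_type (str):          'haar' or 'db2'
        precision_float_type:        tf.float16, tf.float32 or tf.float64
        precision_float_numpy_type:  Matching NumPy dtype (e.g. np.float32)

    Returns:
        Wavelet: The filters used for the wavelet transform
    """
    if wavelet_type not in ('haar', 'db2'):
        raise ValueError("Unsupported wavelet_type: %r" % (wavelet_type,))
    if wavelet_type=='haar':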
        print('creation of wavelet transform filters : haar by tfw')
        # Haar wavelet
        c0 = np.sqrt(2)/2
        return Wavelet(
            Filter(np.array([c0, c0]), 1, precision_float_type,precision_float_numpy_type),
            Filter(np.array([-c0, c0]), 0, precision_float_type,precision_float_numpy_type),
            Filter(np.array([c0, c0]), 0, precision_float_type,precision_float_numpy_type),
            Filter(np.array([c0, -c0]), 1, precision_float_type,precision_float_numpy_type),
         )
    elif wavelet_type=='db2':
        print('creation of wavelet transform filters : db2 by tfw')
        # Daubechies wavelets    
        c1 = (1+np.sqrt(3))/(4*np.sqrt(2))
        c2 = (3+np.sqrt(3))/(4*np.sqrt(2))
        c3 = (3-np.sqrt(3))/(4*np.sqrt(2))
        c4 = (1-np.sqrt(3))/(4*np.sqrt(2))
        
        return Wavelet(
            Filter(np.array([c4, c3, c2, c1]), 3, precision_float_type,precision_float_numpy_type),
            Filter(np.array([-c1, c2, -c3, c4]), 0, precision_float_type,precision_float_numpy_type),
            Filter(np.array([c1, c2, c3, c4]), 0, precision_float_type,precision_float_numpy_type),
            Filter(np.array([c4, -c3, c2, -c1]), 3, precision_float_type,precision_float_numpy_type)
        )
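The main change in `cyclic_conv1d` is to fold the channels into the batch axis before the 1D convolution and to unfold them afterwards. The NumPy sketch below mimics only that reshaping step, with `np.correlate` standing in for the TensorFlow convolution. `filter_rows` and `filter_multichannel` are hypothetical names, the two-tap kernel is an arbitrary example, and the edge handling for the periodic extension is left out.

```python
import numpy as np


def filter_rows(batch, kernel):
    """'VALID' 1D cross-correlation of every row of a [batch, width] array."""
    return np.stack([np.correlate(row, kernel, mode="valid") for row in batch])


def filter_multichannel(image, kernel):
    """Filter a [rows, cols, channels] array along the columns, per channel."""
    rows, cols, channels = image.shape
    # Move channels first, then merge them with the rows into one batch axis.
    folded = np.transpose(image, (2, 0, 1)).reshape(rows * channels, cols)
    filtered = filter_rows(folded, kernel)
    # Undo the folding: [channels, rows, new_width] -> [rows, new_width, channels].
    new_width = cols - len(kernel) + 1
    unfolded = filtered.reshape(channels, rows, new_width)
    return np.transpose(unfolded, (1, 2, 0))


rng = np.random.default_rng(0)
image = rng.normal(size=(8, 8, 3))
kernel = np.array([0.5, 0.5])
result = filter_multichannel(image, kernel)

# Reference: filter each channel on its own and stack the results.
reference = np.stack(
    [filter_rows(image[:, :, c], kernel) for c in range(image.shape[2])], axis=-1
)
```

Here `result` and `reference` should agree, which is the property the folding relies on: a 'VALID' filter shortens each row by `len(kernel) - 1` samples, and the reshape back has to use that shortened width.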
        
