diff --git a/.gitattributes b/.gitattributes index 96d32bd35..dda5bfc74 100644 --- a/.gitattributes +++ b/.gitattributes @@ -65,7 +65,6 @@ c_reference/tests/kws/keyword_spotting_io_2.h filter=lfs diff=lfs merge=lfs -tex c_reference/tests/kws/keyword_spotting_io_3.h filter=lfs diff=lfs merge=lfs -text c_reference/tests/conv1d/conv1d_regular/conv_param.h filter=lfs diff=lfs merge=lfs -text c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h filter=lfs diff=lfs merge=lfs -text -c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h filter=lfs diff=lfs merge=lfs -text c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h filter=lfs diff=lfs merge=lfs -text c_reference/tests/kws/precnn_params.h filter=lfs diff=lfs merge=lfs -text c_reference/tests/kws/postcnn_params.h filter=lfs diff=lfs merge=lfs -text diff --git a/c_reference/include/conv1d.h b/c_reference/include/conv1d.h index 7fda9e4a4..a7ed49315 100644 --- a/c_reference/include/conv1d.h +++ b/c_reference/include/conv1d.h @@ -4,18 +4,45 @@ #ifndef __CONV1D_H__ #define __CONV1D_H__ -// Currently dilation is not supported. We have coded separate functions for regular and depthwise conv1d(and low-rank versions). They currently do not support the use of groups -// We use a custom matVec with offset (in utils) for our conv codes. This allows us to decompose the operation using the low-rank property and decrease the run-time -// The unoptimized version would be to first compute the weights and then perform the convolution +/* All the matrices/tensors are stored in the row major format + + NOTES for the conv layers +-> The conv1d & conv1d_lr layers work for all cases and can be used unconstrained. + There are no hard constraints for the parallel version, but a few points regarding its optimal usage are given below +-> Dilation = 1 (no dilation) for all cases +-> For the non-depthwise cases, store the matrices as described below. Permutation might be necessary +-> The low-rank decomposition cannot be applied to the depthwise weight matrices. This is due to the out_channels % in_channels = 0 constraint imposed by the depthwise convolution. + For full-rank this is satisfied since out_channels = in_channels + But, when the matrix is decomposed, the constraint is violated (since rank < out_channels and rank is not divisible by in_channels) + Since the decomposition is theoretically impossible here, we have not provided the support + However, we suggest a less-efficient alternative => first pre-compute the weights W = W2 * W1 and then use a regular conv +-> For the parallel cases, the non-overlapping cases of the convolution are computed in parallel using MatMul (since the blocked MatMul is faster) + This, however, is only valid when the filter is fully inside the input. The edge cases do not form such non-overlapping rows + Hence the MatVec code (regular code) is used to calculate these cases + + Important points regarding parallel versions +-> Due to the above reason, the parallel layers are only recommended for large in_time inputs + This should typically be for in_time (without the padding) > 2 * num_steps_one_row + stride. Else there would not be enough time-steps to efficiently parallelise + We need at least 2 rows for good MatMul performance. In the worst case the starting time step would be (stride - 1).
Hence we choose 2 * num_steps_one_row + stride as the threshold + For the short input cases, the code will skip the MatMul computation and use MatVec instead (but the MatMul-variable computation overhead would remain) + For such cases, the MatVec code (conv1d and conv1d_lr) would work more efficiently due to the lower RAM usage and lack of any major overheads +-> There is no support for depthwise for conv1d_parallel + The regular convolution acts on all the channels while the depthwise acts only on one channel at a time + This results in non-contiguous memory access. MatMul would need to process multiple such time-steps, while the MatVec would only need to process one + Hence, the MatVec would be able to enter the next channel earlier and would work much faster + The MatMul, in contrast, would suffer cache misses (when dealing with the small cache size of edge devices) +*/ /** * @brief Model parameters for the 1D Convolution Layer - * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] - * @var B pointer to the bias vector, original shape = [out_channels] + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] + * @var B pointer to the bias vector, original shape = [out_channels] + * @var depthwise flag for deciding between regular(=0) and depthwise(=1) conv */ typedef struct ConvLayers_Params { const float* const W; const float* const B; + unsigned depthwise; } ConvLayers_Params; /** @@ -23,47 +50,62 @@ typedef struct ConvLayers_Params { * @param[out] output_signal pointer to the output signal, size = out_time * out_channels * @param[in] out_time number of time steps in the output * @param[in] out_channels number of output channels for the output of the conv layer + * NOTE: out_channels = in_channels for depthwise. This is set manually in the function * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels * @param[in] padding padding applied to the input before the conv is performed. * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time). + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter * @param[in] params weights, bias and other essential parameters used to describe the layer * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. + * @param[in] activation an integer to choose the type of activation function.
More can be added as per the necessity * 0: none * 1: sigmoid * 2: tanh * 3: relu */ -int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation); /** - * @brief Model definition for the 1D Depthwise Convolution Layer. Currently only for dilation = 1 - * @param[out] output_signal pointer to the output signal, size = out_time * in_channels - * NOTE: out_channels == in_channels for depthwise + * @brief Model parameters for the 1D Parallel Convolution Layer + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] + * @var B pointer to the bias vector, original shape = [out_channels] + * @var block_size block/tile size for the cache. Used for tiled MatMul + */ +typedef struct ConvLayers_Parallel_Params { + const float* const W; + const float* const B; + unsigned block_size; +} ConvLayers_Parallel_Params; + +/** + * @brief Model definition for the 1D Parallel Convolution Layer. Currently only for dilation = 1. No depthwise. + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels. The output will have the same number of channels + * @param[in] in_channels number of input channels * @param[in] padding padding applied to the input before the conv is performed. * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time). + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter * @param[in] params weights, bias and other essential parameters used to describe the layer * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity * 0: none * 1: sigmoid * 2: tanh * 3: relu */ -int conv1d_depth(float* output_signal, unsigned out_time, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation); /** @@ -81,61 +123,73 @@ typedef struct ConvLayers_LR_Params { } ConvLayers_LR_Params; /** - * @brief Model definition for the 1D Low-Rank Convolution Layer. 
Currently only for dilation = 1 + * @brief Model definition for the 1D Low-Rank Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) * @param[out] output_signal pointer to the output signal, size = out_time * out_channels * @param[in] out_time number of time steps in the output - * @param[in] out_channels number of output channels for the ouput of the conv layer + * @param[in] out_channels number of output channels for the output of the conv layer * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels * @param[in] padding padding applied to the input before the conv is performed. * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time). + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter * @param[in] params weights, bias and other essential parameters used to describe the layer * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity * 0: none * 1: sigmoid * 2: tanh * 3: relu */ -int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation); /** - * @brief Model definition for the 1D Low-Rank Depthwise Convolution Layer. Currently only for dilation = 1 - * @param[out] output_signal pointer to the output signal, size = out_time * in_channels - * NOTE: out_channels == in_channels for depthwise conv + * @brief Model parameters for the 1D Low Rank Parallel Convolution Layer. + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1] + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels] + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage + * @var block_size_to_lr block/tile size for the cache. Used for tiled MatMul. Used for the input -> low-rank computation + * @var block_size_from_lr block/tile size for the cache. Used for tiled MatMul. 
Used for the low-rank -> output computation + */ +typedef struct ConvLayers_LR_Parallel_Params { + const float* const W1; + const float* const W2; + const float* const B; + unsigned rank; + unsigned block_size_to_lr; + unsigned block_size_from_lr; +} ConvLayers_LR_Parallel_Params; + +/** + * @brief Model definition for the 1D Low-Rank Parallel Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels. The output will have the same number of channels + * @param[in] in_channels number of input channels * @param[in] padding padding applied to the input before the conv is performed. * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time). + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter * @param[in] params weights, bias and other essential parameters used to describe the layer * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity * 0: none * 1: sigmoid * 2: tanh * 3: relu - Note for the usage of conv1d_depth_lr: - The depthwise with low-rank conv1d code currently uses an unoptimized implementation based on the computation of the conv weights, followed by the depthwise convolution - The reason for using the unoptimized implementation for the depthwise with low-rank conv1d is due to the violation of the depthwise constraints when the low-rank decomposition is applied - The use of depthwise convolution imposes a constraint on the out_channels of the weight matrix. When the low-rank decomposition is applied on top of this matrix, these constraints will be violated - The decomposition converts the depthwise conv into a fully-connected layer and a convolution layer with weight [rank, kernel_size, 1] - The new smaller weight matrix resembles a depthwise conv. But here, typically, in_channels > rank. 
This causes a violation in the matrix constraints for depthwise convolution - Hence, due to the violation, we cannot split the opeartion and would need to use the unoptimized solution with full-rank weight computation followed by convolution - - The depthwise with low-rank code is recommended for extreme storage constraints with no major constraints on the computation cost - For all other puposes, we recommend the use of a combinantion of depthwise conv, low-rank conv and regular conv */ -int conv1d_depth_lr(float* output_signal, unsigned out_time, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation); // Auxiliary Layers @@ -149,10 +203,10 @@ int conv1d_depth_lr(float* output_signal, unsigned out_time, const float* input_ * @param[in] in_channels number of input channels. The output will have the same number of channels * @param[in] padding padding applied to the input before the conv is performed. * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time). + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the pool filter * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity * 0: none * 1: sigmoid * 2: tanh diff --git a/c_reference/include/dscnn.h b/c_reference/include/dscnn.h index ce3a9e12d..541923056 100644 --- a/c_reference/include/dscnn.h +++ b/c_reference/include/dscnn.h @@ -14,6 +14,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*, * @brief sub-layers : batchnorm1d -> conv1d_lr * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params) * @param[in] in_time number of time steps in the input_signal * @param[in] in_channels number of input channels * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. 
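For orientation, here is a caller-side sketch of how the new cnn function pointer can be wired up. This is not part of the diff; the array sizes and weights are placeholders, and it only relies on the note above that any of the conv layers can be passed with the appropriate params struct.

    #include "conv1d.h"
    #include "dscnn.h"

    // Placeholder weights; a real model loads trained values with the shapes documented in conv1d.h.
    static const float cnn_W[8 /* out_channels */ * 3 /* kernel_size */ * 4 /* in_channels */] = { 0.0f };
    static const float cnn_B[8 /* out_channels */] = { 0.0f };
    static const ConvLayers_Params cnn_params = { .W = cnn_W, .B = cnn_B, .depthwise = 0 };

    // conv1d, conv1d_lr, etc. share the conv_layer signature; the chosen pointer and the
    // matching params struct are then handed to phon_pred_lr_cnn (or the point-wise variant below).
    static const conv_layer cnn = conv1d;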
Pass NULL/0 for affine_config = 2 @@ -38,7 +39,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*, * 3: relu */ int phon_pred_lr_cnn(float* output_signal, float* input_signal, - unsigned in_time, unsigned in_channels, + conv_layer cnn, unsigned in_time, unsigned in_channels, const float* const mean, const float* const var, unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size, @@ -49,6 +50,7 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal, * @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params) * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 diff --git a/c_reference/include/rnn_bricked.h b/c_reference/include/rnn_bricked.h index 39fb3221c..adc910d42 100644 --- a/c_reference/include/rnn_bricked.h +++ b/c_reference/include/rnn_bricked.h @@ -4,35 +4,65 @@ #ifndef __RNN_BRICKED_H__ #define __RNN_BRICKED_H__ -// Function pointer for the RNN to be passed as a parameter -typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned, - unsigned, const void*, void*, int, int); +/* All the matrices are stored in the row major format + + NOTES for using the layers +-> Single-directional Computation + While using the bricked fastgrnn layers, the user needs to adhered to the two following constraints + 1) in_time % hop = 0 + 2) fwd_window % hop = 0 and bwd_window % hop = 0 -// NOTES for bi-direction -// If bi_direction = 1, then actual rnn_output_dims is twice the rnn_hidden(rnn_hidden is output dims for each cell). -// Each function will only process its given context(forward/backward). -// The other context will need to be called separately with an appropriate offset. -// E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...) -// 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...) -// -// Each cell will only calculate half the hidden state i.e. rnn_hidden slots of memory from the start of the output pointer -// Hence rnn_hidden is used as an offset for the backward pass. The offset for the forward pass is 0 -// This use of an offset is a way to exploit the nature of bi-direction to bypass the concatenation step typically associated with bi-directional passes -// -// Constraints -// For Bi-Directional use, there are 3 constraints -// 1) (in_time - fwd_window) % hop == 0 and (in_time - bwd_window) % hop == 0 -// 2) fwd_window % hop == 0 and bwd_window % hop == 0 -// 3) sample_first_brick and sample_last_brick = 1 -// -// Violation of these constraints can lead to one of the following issues -// 1) segmentation faults -// 2) forward out_time != backward out_time -// 3) mismatch between forward index and backward index during sampling i.e forward index 8 would correspond to backward index 6. 
This index error continues for all consecutive bricks -// Hence, padding of the input and appropriate window choice is necessary -// -// These constraints can be ignored while performing uni-directional passes. However, it is favorable to follow constraints 1 and 2 + Violation of the above two constraints (1 & 2) will cause segmentation faults + The layers first compute all the Wx steps and then compute Uh for all the windows in parallel + Hence, the user needs to adhere to constraints 1 & 2 +-> Bi-directional Computation + For bi-directional cases, there are 2 additional constraints that need to be followed + A) sample_first_brick and sample_last_brick = 1 + B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call + Each function will only process its given context (forward/backward). The other context will need to be called separately. + E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...) + 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...) + + The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used + Violating the conditions would cause index mismatches or data corruption + If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result + If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only) +*/ + +/** + * @brief Model parameters for the Low-Rank Bricked FastGRNN Layer + * @var W1 pointer to first low-rank component of W. shape = [rank * in_dims] + * @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank] + * @var wRank rank of W matrix + * @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden] + * @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank] + * @var uRank rank of U matrix + * @var Bg pointer to bias for sigmoid + * @var Bh pointer to bias for tanh + * @var sigmoid_zeta first weight parameter for update from input from next step + * @var sigmoid_nu second weight parameter for update from input from next step + * @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x + * @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x) + * @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h + * @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h) + */ +typedef struct BrickedFastGRNN_LR_Params { + float* W1; + float* W2; + unsigned wRank; + float* U1; + float* U2; + unsigned uRank; + float* Bg; + float* Bh; + float sigmoid_zeta; + float sigmoid_nu; + unsigned block_size_w_to_lr; + unsigned block_size_w_from_lr; + unsigned block_size_u_to_lr; + unsigned block_size_u_from_lr; +} BrickedFastGRNN_LR_Params; /** Forward Bricking and application of the forward RNN for an input signal * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden @@ -42,18 +72,16 @@ typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned, * @param[in] in_dims input dimensions * @param[in] window window length for each brick.
For the final brick, the left over time steps are used(need not be window in length for the last brick) * @param[in] hop hop distance for between bricks - * @param[in] rnn function pointer to the RNN * @param[in] params pointer to the parameters for the RNN - * @param[in,out] buffers pointer to buffer for the RNN * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. * @param[in] sample_first_brick determine if the 1st brick should also be sampled * -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 * -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hiddens state. out_time = in_time/hop + 1 */ -int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal, - unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, - rnn_layer rnn, const void* params, void* buffers, - unsigned bi_direction, unsigned sample_first_brick, int normalize); +int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_first_brick); /** Backward Bricking and application of the backward RNN for an input signal * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden @@ -63,18 +91,15 @@ int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_ * @param[in] in_dims input dimensions * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick) * @param[in] hop hop distance for between bricks - * @param[in] rnn function pointer to the RNN * @param[in] params pointer to the parameters for the RNN - * @param[in,out] buffers pointer to buffer for the RNN * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. * @param[in] sample_last_brick determine if the last brick should also be sampled * -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 * -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hiddens state(last in reverse). out_time = in_time/hop + 1 */ -int backward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal, - unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, - rnn_layer rnn, const void* params, void* buffers, - unsigned bi_direction, unsigned sample_last_brick, int normalize); - +int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_last_brick); #endif diff --git a/c_reference/include/utils.h b/c_reference/include/utils.h index 37b89ae3b..07438d134 100644 --- a/c_reference/include/utils.h +++ b/c_reference/include/utils.h @@ -31,37 +31,83 @@ void matVec(const float* const mat, const float* const vec, float alpha, float beta, float* const ret); -/* Matrix-vector multiplication with a row offset - This function was developed primarily for the conv1d function. 
This helps bypass the permutation of the time and channel axis - ret is of size nrows, vec is of size ncols - mat is of size nrows * ncols, stored in row major - depthwise is to change the matVec to depthwise specific convolutions - row_stride is the offset factor between two adjacent rows - Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped - For a normal matVec case, this value will be ncols - Eg : for a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. For this eg ncols will be 100 and row_stride will be 400 - vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals - For a normal matVec case, this value will be 1 - Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. So it's possible to enter a 400 length vector and consider every 4th element. For this ncols will be 100 and vec_stride will be 4*/ +/* + Matrix-vector multiplication with a row offset + This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis + ret is of size nrows, vec is of size ncols + mat is of size nrows * ncols, stored in row major + depthwise is to change the matVec to depthwise specific convolutions + row_stride is the offset factor between two adjacent rows + Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped + For a normal matVec case, this value will be ncols + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. + For this eg ncols will be 100 and row_stride will be 400 + vec_stride is the offset factor between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals + For a normal matVec case, this value will be 1 + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + So it's possible to enter a 400 length vector and consider every 4th element.
+ For this ncols will be 100 and vec_stride will be 4 +*/ void offset_matVec_conv1d(const float* mat, const float* vec, unsigned nrows, unsigned ncols, unsigned row_stride, unsigned vec_stride, unsigned depthwise, float* ret); -/* Scaled matrix-matrix multiplication: ret = alpha * ret + beta * matA * matB - matA first matrix; size = nrows * ncommon - matB second matrix; size = ncommon * ncols - nrows number of rows in the first matrix - ncommon number of columns in the first matrix/number of rows in the second matrix - ncols number of columns in the second matrix - alpha scaling factor for the previously-stored output matrix - beta scaling factor for the result of the multiplication (matA * matB) - ret matrix multiplication output - */ -void matMul(const float* const matA, const float* const matB, +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication + Note: If only the MatMul output is needed, then please use calloc to initialize the output + An alternative is to use malloc, followed by memset 0 + There is a second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly + This MatMul adds the result onto the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed + matA first matrix; shape = [nrows, ncommon] + matB second matrix; shape = [ncommon, ncols] + nrows number of rows in the first matrix + ncommon number of columns in the first matrix/number of rows in the second matrix + ncols number of columns in the second matrix + total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored + total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored + ret matrix multiplication output. shape = [nrows, ncols] + block_size tile/block size for optimal cache performance. A hardware specific parameter +*/ +void tiledMatMul_float(const float* const matA, const float* const matB, unsigned nrows, unsigned ncommon, unsigned ncols, - float alpha, float beta, - float* const ret); + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size); + +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format + The result will be the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage) + Note: If only the MatMul output is needed, then please use calloc to initialize the output + An alternative is to use malloc, followed by memset 0 + There is a second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly + This MatMul adds the result onto the pre-existing values in ret.
Hence either a zero initialized or a pre-existing mat is needed + matA first matrix; shape = [nrows, ncommon] + matB second matrix; shape = [ncols, ncommon] + nrows number of rows in the first matrix + ncommon number of columns in the first matrix/number of rows in the second matrix + ncols number of columns in the second matrix + total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored + total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored + Since matB is transposed the columns are now the ncommon axis + ret matrix multiplication output. shape = [nrows, ncols] + block_size tile/block size for optimal cache performance. A hardware specific parameter +*/ +void transposed_tiledMatMul(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size); // scaled vector addition: ret = scalar1 * vec1 + scalar2 * vector2 void v_add(float scalar1, const float* const vec1, diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c index a22d0c682..2ab5b7f30 100644 --- a/c_reference/src/conv1d.c +++ b/c_reference/src/conv1d.c @@ -7,89 +7,88 @@ #include "conv1d.h" #include "utils.h" -int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation) { const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params; // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + unsigned rank = tparams->rank; // Buffer for W2 out - float* temp_rank_out = (float*)malloc(tparams->rank * sizeof(float)); + float* temp_rank_out = (float*)malloc(rank * sizeof(float)); // Buffer for W1 out float* temp_out = (float*)malloc(out_channels * sizeof(float)); for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { unsigned t_index = t_out * out_channels; - // There are typically 5 cases - // 1) Filter not yet inside the input - // 2) Filter partially inside the input - // 3) Filter fully inside the input - // 4) Filter partly outside the input - // 5) Filter fully outside the input if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { // Filter fully inside the input.
Kept as the initial condition, since this is the most common one - offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, - tparams->rank, kernel_size * in_channels, - kernel_size * in_channels, 1, 0, temp_rank_out); - // The row_stride and ncols are provided with the same value in the function call below and vec_stride = 1, depthwise = 0 - // Hence, this call will be the same as a regular MatVec function call (without any scaling) + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, - tparams->rank, tparams->rank, 1, 0, temp_out); + rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if ((t_in_start < padding) && (t_in_end >= padding)) { // Filter partially entered the input - // In this case we using only a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, - input_signal, tparams->rank, (t_in_end - padding + 1) * in_channels, - kernel_size * in_channels, 1, 0, temp_rank_out); - // The row_stride and ncols are provided with the same value in the function call below and vec_stride = 1, depthwise = 0 - // Hence, this call will be the same as a regular MatVec function call (without any scaling) + input_signal, rank, + (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, - tparams->rank, tparams->rank, 1, 0, temp_out); + rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { // Filter partially exited the input - // In this case we using only a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. 
So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another - offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, - tparams->rank, (in_time + padding - t_in_start) * in_channels, - kernel_size * in_channels, 1, 0, temp_rank_out); - // The row_stride and ncols are provided with the same value in the function call below and vec_stride = 1, depthwise = 0 - // Hence, this call will be the same as a regular MatVec function call (without any scaling) + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, - tparams->rank, tparams->rank, 1, 0, temp_out); + rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else { // Filter completely in the padding region // The filter is either fully outside the input or has not yet entered the input - // Hence we can skip the multiplication/addition operations and just set those output values to zero memset(output_signal + t_index, 0, out_channels * sizeof(float)); } for (unsigned co = 0; co < out_channels; co++) { // Post-Conv activation. 
More activation functions can be added should the necessity arise - if (activation == 1) { - output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 2) { - output_signal[t_index + co] = tanh(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 3) { - output_signal[t_index + co] = relu(output_signal[t_index + co] + tparams->B[co]); - } - else { - output_signal[t_index + co] += tparams->B[co]; + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; } } } @@ -98,90 +97,177 @@ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, co return 0; } -int conv1d_depth_lr(float* output_signal, unsigned out_time, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation) { - const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params; - // Pre-computation of the weights for conv - float* tempW = (float*)malloc(in_channels * kernel_size * sizeof(float)); - matMul(tparams->W1, tparams->W2, in_channels, tparams->rank, - kernel_size, 0, 1.0, tempW); - // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding - float* temp_out = (float*)malloc(in_channels * sizeof(float)); - for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; - t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { - unsigned t_index = t_out * in_channels; - - // There are typically 5 cases - // 1) Filter not yet inside the input - // 2) Filter partially inside the input - // 3) Filter fully inside the input - // 4) Filter partly outside the input - // 5) Filter fully outside the input - if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { - // Filter fully inside the input. Kept as the initial condition, since this is the most common one - offset_matVec_conv1d(tempW, input_signal + (t_in_start - padding) * in_channels, - in_channels, kernel_size, - kernel_size, in_channels, 1, temp_out); - memcpy(output_signal + t_index, temp_out, in_channels * sizeof(float)); + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; + + const ConvLayers_LR_Parallel_Params* tparams = (ConvLayers_LR_Parallel_Params*) params; + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Buffer to hold the output. For corner cases, this will be realtively big. + // But will be needed for the central condition (filter inside input). 
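  // Illustrative sizing with assumed numbers (not taken from the source): for kernel_size = 5
  // and stride = 2, the loop above gives num_steps_one_row = 6 and num_iter = 3. An input with
  // in_time = 128 then linearises into 128 / 6 = 21 rows for the MatMul buffer below, and the
  // MatMul path further down is only entered when in_time > 2 * 6 + 2 = 14; shorter inputs fall
  // back to the per-step MatVec handling.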
+ // If there are not enough time steps to linearise into one row, then allocate only 1 time step + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + unsigned rank = tparams->rank; + // Buffer for W2 out + float* temp_rank_out = (float*)malloc(buffer_steps * rank * sizeof(float)); + // Buffer for W1 out + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here + for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_in_start < padding && t_out < out_time; + t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_end < padding) { + // Filter outside the input region and in the padded region + memset(output_signal + t_out * out_channels, 0, + out_channels * sizeof(float)); } - else if ((t_in_start < padding) && (t_in_end >= padding)) { + else { //(t_in_end >= padding) // Filter partially entered the input - // In this case we using only a part of the weight matrix(assuming shape = in_channels, kernel_size) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another - offset_matVec_conv1d(tempW + (padding - t_in_start), - input_signal, in_channels, (t_in_end - padding + 1), - kernel_size, in_channels, 1, temp_out); - memcpy(output_signal + t_index, temp_out, in_channels * sizeof(float)); + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, + input_signal, rank, (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + } + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases + // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row + // Using the above logic, we can convert the MatVec opeartion into a MatMul operation + // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied) + // Skip if atleast 2 rows cannot be formed. 
The condition 2 * num_steps_one_row + stride is the worst case criteria + // The MatVec will be used for the computation in-case the following block is skipped + if (in_time > ((num_steps_one_row << 1) + stride)) { + t_in_start -= padding; // remove the padding offset temporarily + t_in_end -= padding; // Used to keep track of the final processed index + for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels); + iter++, t_in_start += stride, t_out++) { + unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; + memset(temp_rank_out, 0, buffer_steps * rank * sizeof(float)); + memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); + if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { + // t_in_end is used to find the furthest time step was used in the MatMul calculation + // This value will be used for calculating the index for the final section of the processing + t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; + } + transposed_tiledMatMul(input_signal + t_in_start * in_channels , tparams->W2, + in_rows, ncols, rank, total_in_cols, ncols, + temp_rank_out, tparams->block_size_to_lr); + transposed_tiledMatMul(temp_rank_out , tparams->W1, + in_rows, rank, out_channels, rank, rank, + temp_out, tparams->block_size_from_lr); + // Copy all the data into the output + float* output_offset = (float*)output_signal + t_out * out_channels; + float* temp_offset = (float*)temp_out; + unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; + while (t_iter--) { + memcpy(output_offset, temp_offset, out_channels * sizeof(float)); + output_offset += offset_factor_for_out; + temp_offset += out_channels; + } + } + // Initialize the time iterators + // Use the stored value in t_in_end to calculate the iterators + t_in_start = t_in_end + padding; // Add the padding and stride offsets again + t_in_end = t_in_start + kernel_size - 1; + t_out = t_in_start / stride; + } + for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { + // Filter fully in the input but very close to the edges. + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped + // Incase the MatMul is skipped, this block will be used to compute the results + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { // Filter partially exited the input - // In this case we using only a part of the weight matrix(assuming shape = in_channels, kernel_size) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. 
So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another - offset_matVec_conv1d(tempW, input_signal + (t_in_start - padding) * in_channels, - in_channels, (in_time + padding - t_in_start), - kernel_size, in_channels, 1, temp_out); - memcpy(output_signal + t_index, temp_out, in_channels * sizeof(float)); + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); } else { - // Filter completely in the padding region - // The filter is either fully outside the input or has not yet entered the input - // Hence we can skip the multiplication/addition operations and just set those output values to zero - memset(output_signal + t_index, 0, in_channels * sizeof(float)); + // Filter completely outside the input and in the padding region + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); } - for (unsigned co = 0; co < in_channels; co++) { + } + // Bias and activation + for (t_out = 0; t_out < out_time; t_out++) { + unsigned t_index = t_out * out_channels; + for (unsigned co = 0; co < out_channels; co++) { // Post-Conv activation. 
More activation functions can be added should the necessity arise - if (activation == 1) { - output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 2) { - output_signal[t_index + co] = tanh(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 3) { - output_signal[t_index + co] = relu(output_signal[t_index + co] + tparams->B[co]); - } - else { - output_signal[t_index + co] += tparams->B[co]; + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; } } } free(temp_out); - free(tempW); return 0; } -int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation) { const ConvLayers_Params* tparams= (ConvLayers_Params*) params; + unsigned vec_stride = 1, cols_scale = in_channels; + if (tparams->depthwise) { + vec_stride = in_channels; + out_channels = in_channels; + cols_scale = 1; + } // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding float* temp_out = (float*)malloc(out_channels * sizeof(float)); @@ -189,62 +275,61 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { unsigned t_index = t_out * out_channels; - // There are typically 5 cases - // 1) Filter not yet inside the input - // 2) Filter partially inside the input - // 3) Filter fully inside the input - // 4) Filter partly outside the input - // 5) Filter fully outside the input if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { // Filter fully inside the input. Kept as the initial condition, since this is the most common one - offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, - out_channels, kernel_size * in_channels, - kernel_size * in_channels, 1, 0, temp_out); + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, kernel_size * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if ((t_in_start < padding) && (t_in_end >= padding)) { // Filter partially entered the input - // In this case we using only a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. 
So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another - offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels, - input_signal, out_channels, (t_in_end - padding + 1) * in_channels, - kernel_size * in_channels, 1, 0, temp_out); + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W + (padding - t_in_start) * cols_scale, + input_signal, out_channels, (t_in_end - padding + 1) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { // Filter partially exited the input - // In this case we using only a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps - // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. So we need to discard/skip certain columns - // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values) - // Hence we provide a separate row_stride to hop from one row to another - offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, - out_channels, (in_time + padding - t_in_start) * in_channels, - kernel_size * in_channels, 1, 0, temp_out); + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, (in_time + padding - t_in_start) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else { // Filter completely in the padding region // The filter is either fully outside the input or has not yet entered the input - // Hence we can skip the multiplication/addition operations and just set those output values to zero memset(output_signal + t_index, 0, out_channels * sizeof(float)); } for (unsigned co = 0; co < out_channels; co++) { // Post-Conv activation. 
More activation functions can be added should the necessity arise - if (activation == 1) { - output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 2) { - output_signal[t_index + co] = tanh(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 3) { - output_signal[t_index + co] = relu(output_signal[t_index + co] + tparams->B[co]); - } - else { - output_signal[t_index + co] += tparams->B[co]; + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; } } } @@ -252,74 +337,140 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const return 0; } -int conv1d_depth(float* output_signal, unsigned out_time, const float* input_signal, - unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation) { + + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; - const ConvLayers_Params* tparams= (ConvLayers_Params*) params; - + const ConvLayers_Parallel_Params* tparams = (ConvLayers_Parallel_Params*) params; // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding - float* temp_out = (float*)malloc(in_channels * sizeof(float)); - for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; - t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { - unsigned t_index = t_out * in_channels; - - // There are typically 5 cases - // 1) Filter not yet inside the input - // 2) Filter partially inside the input - // 3) Filter fully inside the input - // 4) Filter partly outside the input - // 5) Filter fully outside the input - if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { - // Filter fully inside the input. Kept as the initial condition, since this is the most common one - offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, - in_channels, kernel_size, - kernel_size, in_channels, 1, temp_out); - memcpy(output_signal + t_index, temp_out, in_channels * sizeof(float)); + // Buffer to hold the output. For corner cases, this will be realtively big. + // But will be needed for the central condition (filter inside input). + // If there are not enough time steps to linearise into one row, then allocate only 1 time step + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. 
Hence declared here
+  for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0;
+       t_in_start < padding && t_out < out_time;
+       t_out++, t_in_start += stride, t_in_end += stride) {
+    if (t_in_end < padding) {
+      // Filter outside the input region and in the padded region
+      memset(output_signal + t_out * out_channels,
+        0, out_channels * sizeof(float));
     }
-    else if ((t_in_start < padding) && (t_in_end >= padding)) {
+    else { //(t_in_end >= padding)
       // Filter partially entered the input
-      // In this case we using only a part of the weight matrix(assuming shape = in_channels, kernel_size)
-      // As a part of the filter is outside the input, we need less than "kernel_size" time-steps
-      // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. So we need to discard/skip certain columns
-      // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values)
-      // Hence we provide a separate row_stride to hop from one row to another
-      offset_matVec_conv1d(tparams->W + (padding - t_in_start),
-        input_signal, in_channels, (t_in_end - padding + 1),
-        kernel_size, in_channels, 1, temp_out);
-      memcpy(output_signal + t_index, temp_out, in_channels * sizeof(float));
+      // As a part of the filter is outside the input, we need less than "kernel_size" time-steps.
+      // We will only be using a part of the weight matrix (assuming shape = out_channels, kernel_size * in_channels)
+      // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix
+      offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels,
+        input_signal, out_channels, (t_in_end - padding + 1) * in_channels,
+        ncols, 1, 0, temp_out);
+      memcpy(output_signal + t_out * out_channels,
+        temp_out, out_channels * sizeof(float));
+    }
+  }
+  // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases
+  // Each iteration strides the kernel to the next starting point, till the filter moves out of the overlapping region
+  // Hence we use num_steps_one_row to calculate the number of time steps to be linearized into one row
+  // Using the above logic, we can convert the MatVec operation into a MatMul operation
+  // Ideally both implementations would perform the same. However, on edge devices the MatMul was found to be faster than the MatVec (both tiled)
+  // Skip the MatMul if at least 2 rows cannot be formed. The condition in_time > 2 * num_steps_one_row + stride is the worst-case criterion (a small worked example of the row formation is sketched just below)
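/* Illustrative aside (not part of the patch): a standalone sketch of the
   row-formation logic described above. kernel_size = 5, stride = 2, in_time = 20
   and padding = 0 are toy values assumed purely for illustration; the variable
   names mirror the ones used in conv1d_parallel. */
#include <stdio.h>

int main() {
  unsigned kernel_size = 5, stride = 2, in_time = 20;
  unsigned num_steps_one_row = 0, num_iter = 0;
  // Same computation as in conv1d_parallel: the smallest multiple of stride >= kernel_size
  while (num_steps_one_row < kernel_size) {
    num_steps_one_row += stride;
    num_iter++;
  }
  // Here: num_steps_one_row = 6, num_iter = 3, and in_time > 2 * 6 + 2 holds
  printf("num_steps_one_row = %u, num_iter = %u\n", num_steps_one_row, num_iter);
  for (unsigned iter = 0, t_out = 0; iter < num_iter; iter++, t_out++) {
    unsigned t_in_start = iter * stride;
    unsigned in_rows = (in_time - t_in_start) / num_steps_one_row;
    // One MatMul pass: in_rows non-overlapping windows spaced num_steps_one_row
    // apart in the input; their results land num_iter time steps apart in the output
    printf("pass %u: windows start at t =", iter);
    for (unsigned r = 0; r < in_rows; r++) {
      printf(" %u", t_in_start + r * num_steps_one_row);
    }
    printf("  -> output time steps:");
    for (unsigned r = 0; r < in_rows; r++) {
      printf(" %u", t_out + r * num_iter);
    }
    printf("\n");
  }
  return 0;
}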
+  // The MatVec will be used for the computation in case the following block is skipped
+  if (in_time > ((num_steps_one_row << 1) + stride)) {
+    t_in_start -= padding; // remove the padding offset temporarily
+    t_in_end -= padding; // Used to keep track of the final processed index
+    for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels);
+         iter++, t_in_start += stride, t_out++) {
+      unsigned in_rows = (in_time - t_in_start) / num_steps_one_row;
+      memset(temp_out, 0, buffer_steps * out_channels * sizeof(float));
+      if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) {
+        // t_in_end is used to find the furthest time step that was used in the MatMul calculation
+        // This value will be used for calculating the index for the final section of the processing
+        t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride;
+      }
+      transposed_tiledMatMul(input_signal + t_in_start * in_channels, tparams->W,
+        in_rows, ncols, out_channels, total_in_cols, ncols,
+        temp_out, tparams->block_size);
+      // Copy all the data into the output
+      float* output_offset = (float*)output_signal + t_out * out_channels;
+      float* temp_offset = (float*)temp_out;
+      unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels;
+      while (t_iter--) {
+        memcpy(output_offset, temp_offset, out_channels * sizeof(float));
+        output_offset += offset_factor_for_out;
+        temp_offset += out_channels;
+      }
+    }
+    // Initialize the time iterators
+    // Use the stored value in t_in_end to calculate the iterators
+    t_in_start = t_in_end + padding; // Add the padding and stride offsets again
+    t_in_end = t_in_start + kernel_size - 1;
+    t_out = t_in_start / stride;
+  }
+  for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) {
+    if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) {
+      // Filter fully in the input but very close to the edges.
+      // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped
+      // In case the MatMul is skipped, this block will be used to compute the results
+      offset_matVec_conv1d(tparams->W,
+        input_signal + (t_in_start - padding) * in_channels,
+        out_channels, kernel_size * in_channels,
+        kernel_size * in_channels, 1, 0, temp_out);
+      memcpy(output_signal + t_out * out_channels,
+        temp_out, out_channels * sizeof(float));
     }
     else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) {
       // Filter partially exited the input
-      // In this case we using only a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels)
-      // As a part of the filter is outside the input, we need less than "kernel_size" time-steps
-      // Hence the number of columns needed reduces. But the whole matrix is a continuous piece of memory. So we need to discard/skip certain columns
-      // Since we discard the last(or first) few column elements we can't iterate regularly(else we access the unnecessary values)
-      // Hence we provide a separate row_stride to hop from one row to another
-      offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels,
-        out_channels, (in_time + padding - t_in_start) * in_channels,
-        kernel_size * in_channels, 1, 0, temp_out);
+      // As a part of the filter is outside the input, we need less than "kernel_size" time-steps.
+ // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, (in_time + padding - t_in_start) * in_channels, + ncols, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); } else { - // Filter completely in the padding region - // The filter is either fully outside the input or has not yet entered the input - // Hence we can skip the multiplication/addition operations and just set those output values to zero - memset(output_signal + t_index, 0, in_channels * sizeof(float)); + // Filter completely outside the input and in the padding region + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); } - for (unsigned co = 0; co < in_channels; co++) { + } + // Bias and activation + for (t_out = 0; t_out < out_time; t_out++) { + unsigned t_index = t_out * out_channels; + for (unsigned co = 0; co < out_channels; co++) { // Post-Conv activation. More activation functions can be added should the necessity arise - if (activation == 1) { - output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 2) { - output_signal[t_index + co] = tanh(output_signal[t_index + co] + tparams->B[co]); - } - else if (activation == 3) { - output_signal[t_index + co] = relu(output_signal[t_index + co] + tparams->B[co]); - } - else { - output_signal[t_index + co] += tparams->B[co]; + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; } } } @@ -331,8 +482,8 @@ int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation) { - // Iterate over the time steps and average them. 
Similar to Conv1D_Dept with a filter kernel of ones - float scale = 1.0/(float)kernel_size; + // Iterate over the time steps and average them + float scale = 1.0/(float)kernel_size; // To avoid divisions for (unsigned t_in = 0, t_out = 0; t_out < out_time; t_out++, t_in += stride) { for (unsigned ci = 0; ci < in_channels; ci++) { float sum = 0; @@ -344,17 +495,22 @@ int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal sum += (input_signal[((tf + t_in) - padding) * in_channels + ci]); } } - if (activation == 1) { - output_signal[t_out * in_channels + ci] = sigmoid(sum * scale); - } - else if (activation == 2) { - output_signal[t_out * in_channels + ci] = tanh(sum * scale); - } - else if (activation == 3) { - output_signal[t_out * in_channels + ci] = relu(sum * scale); - } - else { - output_signal[t_out * in_channels + ci] = sum * scale; + switch (activation) { + case 1 : + output_signal[t_out * in_channels + ci] = sigmoid(sum * scale); + break; + + case 2 : + output_signal[t_out * in_channels + ci] = tanh(sum * scale); + break; + + case 3 : + output_signal[t_out * in_channels + ci] = relu(sum * scale); + break; + + default : + output_signal[t_out * in_channels + ci] = sum * scale; + break; } } } @@ -366,67 +522,87 @@ int batchnorm1d(float* output_signal, float* input_signal, const float* const mean, const float* const var, unsigned affine_config, const float* const gamma , const float* const beta, unsigned in_place, float eps) { - // Check if affine values was learnt + float* ret = in_place ? (float*)input_signal : (float*)output_signal; + + // Check for affine_config + // = 1 ; Use gamma, beta, mean and var + // = 2 ; Use only gamma and beta + // = 3 ; Use only mean and var if (affine_config == 1) { - // Check for in-place computation - if (in_place) { - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - input_signal[t * in_channels + d] = gamma[d] - * ((input_signal[t * in_channels + d] - - mean[d]) / sqrt(var[d] + eps)) - + beta[d]; - } - } - } - else { - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - output_signal[t * in_channels + d] = gamma[d] - * ((input_signal[t * in_channels + d] - - mean[d]) / sqrt(var[d] + eps)) - + beta[d]; + while (in_time--) { + float* gamma_offset = (float*)gamma; + float* beta_offset = (float*)beta; + float* mean_offset = (float*)mean; + float* var_offset = (float*)var; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); } + #endif + + while (channels--) { + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); } } } else if (affine_config == 2) { - // Check for in-place computation - if (in_place) { - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - input_signal[t * in_channels + d] = (gamma[d] - * input_signal[t * in_channels + d]) - + beta[d]; - } - } - } - else { - 
for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - output_signal[t * in_channels + d] = (gamma[d] - * input_signal[t * in_channels + d]) - + beta[d]; + while (in_time--) { + float* gamma_offset = (float*)gamma; + float* beta_offset = (float*)beta; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); } + #endif + + while (channels--) { + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); } } } else { - // Check for in-place computation - if (in_place) { - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - input_signal[t * in_channels + d] = (input_signal[t * in_channels + d] - - mean[d]) / sqrt(var[d] + eps); - } - } - } - else { - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < in_channels; d++) { - output_signal[t * in_channels + d] = (input_signal[t * in_channels + d] - - mean[d]) / sqrt(var[d] + eps); + while (in_time--) { + float* mean_offset = (float*)mean; + float* var_offset = (float*)var; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); } + #endif + + while (channels--) { + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); } } } diff --git a/c_reference/src/dscnn.c b/c_reference/src/dscnn.c index 09f08e728..a304ff54f 100644 --- a/c_reference/src/dscnn.c +++ b/c_reference/src/dscnn.c @@ -9,7 +9,7 @@ #include "utils.h" int phon_pred_lr_cnn(float* output_signal, float* input_signal, - unsigned in_time, unsigned in_channels, + conv_layer cnn, unsigned in_time, unsigned in_channels, const float* const mean, const float* const var, unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size, @@ -23,7 +23,7 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal, mean, var, affine_config, gamma, beta, in_place, 0.00001); // CNN - conv1d_lr(output_signal, out_time, cnn_hidden, input_signal, + cnn(output_signal, out_time, cnn_hidden, input_signal, in_time, in_channels, cnn_padding, cnn_kernel_size, cnn_params, cnn_stride, cnn_activation); } @@ -35,7 +35,7 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal, mean, var, affine_config, gamma, beta, in_place, 0.00001); // CNN - conv1d_lr(output_signal, out_time, cnn_hidden, norm_out, + cnn(output_signal, out_time, cnn_hidden, norm_out, in_time, in_channels, cnn_padding, cnn_kernel_size, cnn_params, cnn_stride, cnn_activation); free(norm_out); @@ -70,7 +70,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, in_place, 0.00001); // Depth CNN depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); - conv1d_depth(depth_out, out_time, act_out, + conv1d(depth_out, 
out_time, 0, act_out, in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, depth_cnn_params, depth_cnn_stride, depth_cnn_activation); free(act_out); @@ -86,7 +86,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, free(act_out); // Depth CNN depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); - conv1d_depth(depth_out, out_time, norm_out, + conv1d(depth_out, out_time, 0, norm_out, in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, depth_cnn_params, depth_cnn_stride, depth_cnn_activation); free(norm_out); diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c index 8917c8eb9..041ae8f05 100644 --- a/c_reference/src/rnn_bricked.c +++ b/c_reference/src/rnn_bricked.c @@ -8,102 +8,296 @@ #include "utils.h" // Forward Pass -int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal, - unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, - rnn_layer rnn, const void* params, void* buffers, - unsigned bi_direction, unsigned sample_first_brick, int normalize) { - unsigned out_index = 0, t; // t is an index, but we want to remember the value after the loop. Hence we define it outside +int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_first_brick) { + + // Buffers and params + const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; - unsigned rnn_assign_offset = rnn_hidden; - float* temp_hiddenstate = (float*)calloc(rnn_hidden, sizeof(float)); + unsigned rnn_assign_offset = rnn_hidden, out_index = 0; + unsigned num_bricks = (in_time - window) / hop + 1; // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden // This function only processes the forward context if (bi_direction) { rnn_assign_offset <<= 1; } - // for the first window, sample every hop index only if sample_first_block = 1. else only the final hidden state is calculated - for (t = 0; t < window; t++) { - rnn(temp_hiddenstate, rnn_hidden, - input_signal + (t * in_dims) , in_dims, 1, - params, buffers, 0, normalize); + + // Compute W1 * W2 * X + float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); + float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); + float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims, + tparams->wRank, in_dims, in_dims, + tempLR, tparams->block_size_w_to_lr); + transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank, + rnn_hidden, tparams->wRank, tparams->wRank, + inputMulW, tparams->block_size_w_from_lr); + free(tempLR); + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch + // memset is used. 
Hence, malloc can be used here for matMul result initialization + tempLR = (float*)malloc(num_bricks * tparams->uRank * sizeof(float)); + for (unsigned t = 0; t < window; t++) { + // From higher dims to lower dims + memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); + transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, + tparams->uRank, rnn_hidden, rnn_hidden, + tempLR, tparams->block_size_u_to_lr); + // From lower dims to higher dims + // Add Wx with Uh + // The tiled MatMuls are codes such that they yield result += matA * matB + // Hence we use calloc and memset to equate the result to 0 + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + float* preComp_offset = (float*)preComp; + for (unsigned n = 0; n < num_bricks; n++) { + float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden %= 4; + while (len_unroll--) { + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + } + #endif + + while (hidden--) { + *preComp_offset++ = *inputMulW_offset++; + } + } + transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank, + rnn_hidden, tparams->uRank, tparams->uRank, + preComp, tparams->block_size_u_from_lr); + + // Apply the gating + float* hiddenState_offset = (float*)hiddenState; + preComp_offset = (float*)preComp; + unsigned bricks = num_bricks; + while (bricks--) { + float* gateBias = (float*)tparams->Bg; + float* hiddenBias = (float*)tparams->Bh; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden = rnn_hidden % 4; + float gate, update; + while (len_unroll--) { + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + #endif + + while (hidden--) { + float gate = sigmoid((*preComp_offset) + (*gateBias++)); + float update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + } + // Sample first block if necessary if (sample_first_brick) { - if (t % hop==0) { - memcpy(output_signal + ((out_index++) * rnn_assign_offset), - temp_hiddenstate, rnn_hidden * sizeof(float)); + if (t % hop == 0) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState, rnn_hidden * sizeof(float)); 
} } } - // sample the last hidden state of the first brick - memcpy(output_signal + ((out_index++) * rnn_assign_offset), - temp_hiddenstate, rnn_hidden * sizeof(float)); - // sample the last hidden state of all subsequent bricks, except the last - for (t = hop; t < in_time - window; t += hop ) { - memset(temp_hiddenstate, 0, rnn_hidden * sizeof(float)); - rnn(temp_hiddenstate, rnn_hidden, - input_signal + (t * in_dims) , in_dims, window, - params, buffers, 0, normalize); - memcpy(output_signal + ((out_index++) * rnn_assign_offset), - temp_hiddenstate, rnn_hidden * sizeof(float)); + if (bi_direction) { + // If bi-directional then a gap would need to be left for the backward outputs + float* hiddenState_offset = hiddenState; + for (unsigned n = 0; n < num_bricks; n++) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState_offset, rnn_hidden * sizeof(float)); + hiddenState_offset += rnn_hidden; + } + } + else { + // If only forward is needed, the the whole block of memory can be copied without the loop + memcpy(output_signal + out_index * rnn_assign_offset, + hiddenState, num_bricks * rnn_hidden * sizeof(float)); } - // Calculated seperately since, the time steps left need not be equal to "window" - // Hence if the last brick has less than "window" time steps - // We only pass those values by reducing the forward-window length - memset(temp_hiddenstate, 0, rnn_hidden * sizeof(float)); - rnn(temp_hiddenstate, rnn_hidden, - input_signal + (t * in_dims) , in_dims, in_time - t, - params, buffers, 0, normalize); - memcpy(output_signal + out_index * rnn_assign_offset, - temp_hiddenstate, rnn_hidden * sizeof(float)); - free(temp_hiddenstate); + free(hiddenState); + free(inputMulW); + free(preComp); + free(tempLR); return 0; } // Backward Pass -int backward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal, - unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, - rnn_layer rnn, const void* params, void* buffers, - unsigned bi_direction, unsigned sample_last_brick, int normalize) { - unsigned out_index = 0, t; +int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_last_brick) { + + // Buffers and params + const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; - // When bi-direction = 1, an offset of "rnn_hidden" will need to be provided during the function call(to the output_signal). - // This is to allocate the results of the backward pass correctly(each of size rnn_hidden, for each time step) unsigned rnn_assign_offset = rnn_hidden; - float* temp_hiddenstate = (float*)calloc(rnn_hidden, sizeof(float)); + unsigned num_bricks = (in_time - window) / hop + 1; + unsigned out_index = in_time / hop; // = out_time - 1; // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden - // This function only processes the backward context. 
for this context ouput size = rnn_hidden + // This function only processes the forward context if (bi_direction) { rnn_assign_offset <<= 1; } - // sample the last hidden state(in reverse) of all bricks, except the last one - for (t = 0; t < in_time - window; t += hop ) { - memset(temp_hiddenstate, 0, rnn_hidden * sizeof(float)); - rnn(temp_hiddenstate, rnn_hidden, - input_signal + (t * in_dims) , in_dims, window, - params, buffers, 1, normalize); - memcpy(output_signal + ((out_index++) * rnn_assign_offset), - temp_hiddenstate, rnn_hidden * sizeof(float)); - } - // Necessary offset for output allocation, for the sample_last_brick = 1 case - if (sample_last_brick) - out_index += window / hop; - // If sample_last_block = 1, sample every hop index only for the last window - // Else the final hidden state(in reverse) is calculated - unsigned stop_time = t; - memset(temp_hiddenstate, 0, rnn_hidden * sizeof(float)); - for (t = in_time - 1; t >= stop_time; t--) { - rnn(temp_hiddenstate, rnn_hidden, - input_signal + (t * in_dims) , in_dims, 1, - params, buffers, 0, normalize); // Since only one time step is passed at a time, the backward flag can be set either way + + // Compute W1 * W2 * X + float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); + float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); + float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims, + tparams->wRank, in_dims, in_dims, + tempLR, tparams->block_size_w_to_lr); + transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank, + rnn_hidden, tparams->wRank, tparams->wRank, + inputMulW, tparams->block_size_w_from_lr); + free(tempLR); + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch + tempLR = (float*)calloc(num_bricks * tparams->uRank, sizeof(float)); + for (int t = window - 1; t >= 0; t--) { + // From higher dims to lower dims + memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); + transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, + tparams->uRank, rnn_hidden, rnn_hidden, + tempLR, tparams->block_size_u_to_lr); + // From lower dims to higher dims + // Add Wx with Uh + // The tiled MatMuls are codes such that they yield result += matA * matB + // Hence we use calloc and memset to equate the result to 0 + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + float* preComp_offset = (float*)preComp; + for (unsigned n = 0; n < num_bricks; n++) { + float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden %= 4; + while (len_unroll--) { + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + } + #endif + + while (hidden--) { + *preComp_offset++ = *inputMulW_offset++; + } + } + transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank, + rnn_hidden, tparams->uRank, tparams->uRank, + preComp, tparams->block_size_u_from_lr); + + // Apply the gating + float* hiddenState_offset = (float*)hiddenState; + preComp_offset = (float*)preComp; + unsigned bricks = num_bricks; + while (bricks--) { + float* gateBias = (float*)tparams->Bg; + float* hiddenBias = (float*)tparams->Bh; + 
unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden = rnn_hidden % 4; + float gate, update; + while (len_unroll--) { + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + #endif + + while (hidden--) { + float gate = sigmoid((*preComp_offset) + (*gateBias++)); + float update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + } + // Sample first block if necessary if (sample_last_brick) { - if ((in_time - 1 - t) % hop == 0) { - memcpy(output_signal + ((out_index--) * rnn_assign_offset), - temp_hiddenstate, rnn_hidden * sizeof(float)); + if ((window - 1 - t) % hop == 0) { + // Iterate over the output in reverse + memcpy(output_signal + (out_index--) * rnn_assign_offset, + hiddenState + (num_bricks - 1) * rnn_hidden, rnn_hidden * sizeof(float)); } } } - // sample the last hidden state(in reverse) of the last brick - memcpy(output_signal + out_index * rnn_assign_offset, - temp_hiddenstate, rnn_hidden * sizeof(float)); - free(temp_hiddenstate); + // Since the all first (final in reverse) hiddenstates are calculated, we assign the whole block + out_index = 0; + if (bi_direction) { + // If bi-directional then a gap would need to be left for the backward outputs + float* hiddenState_offset = hiddenState; + for (unsigned n = 0; n < num_bricks; n++) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState_offset, rnn_hidden * sizeof(float)); + hiddenState_offset += rnn_hidden; + } + } + else { + // If only forward is needed, the the whole block of memory can be copied without the loop + memcpy(output_signal + out_index * rnn_assign_offset, + hiddenState, num_bricks * rnn_hidden * sizeof(float)); + } + free(hiddenState); + free(inputMulW); + free(preComp); + free(tempLR); return 0; } diff --git a/c_reference/src/utils.c b/c_reference/src/utils.c index 54c0e6ccd..0373d0c0b 100644 --- a/c_reference/src/utils.c +++ b/c_reference/src/utils.c @@ -76,35 +76,119 @@ void offset_matVec_conv1d(const float* mat, const float* vec, unsigned row_stride, unsigned vec_stride, unsigned depthwise, float* ret) { - for (unsigned row = 0; row < nrows; row++) { - float sum = 0.0f; + while (nrows--) { + // For depthwise, the vec(input) pointer is updated + // Since each row of the mat corresponds to a separate channel index + float* vec_offset = depthwise ? 
(float*)vec++ : (float*)vec; float* mat_offset = (float*)mat; - float* vec_offset = (float*)vec; + float sum = 0.0f; unsigned cols = ncols; + + #ifdef LOOP_UNROLL + unsigned len_unroll = cols >> 2; + cols %= 4; // ncols % 4 + while (len_unroll--) { + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + } + #endif + while (cols--) { sum += (*mat_offset++) * (*vec_offset); vec_offset += vec_stride; } *ret++ = sum; mat += row_stride; - // For depthwise, the vec(input) pointer is updated since each row of the mat corresponds to a separate index in the channels - if (depthwise) { - vec++; + } +} + +void tiledMatMul_float(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + comm * total_cols_B + block_col; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4 + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + *ret_offset++ += sum; + } + } + } } } } -void matMul(const float* const matA, const float* const matB, +void transposed_tiledMatMul(const float* const matA, const float* const matB, unsigned nrows, unsigned ncommon, unsigned ncols, - float alpha, float beta, - float* const ret) { - for (unsigned row = 0; row < nrows; row++) { - for (unsigned col = 0; col < ncols; col++) { - float sum = 0; - for (unsigned k = 0; k < ncommon; k++) { - sum += (matA[row * ncommon + k] * matB[k * ncols + col]); + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? 
block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + block_col * total_comm_B + comm; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4 + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset++); + } + *ret_offset++ += sum; + } + } } - ret[row * ncols + col] = alpha * ret[row * ncols + col] + beta * sum; } } } @@ -161,11 +245,31 @@ void softmax(const float* const input, unsigned len, float* const ret) { void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, unsigned in_time, unsigned in_channels) { - unsigned piv = in_channels >> 1; - for (unsigned t = 0; t < in_time; t++) { - for (unsigned d = 0; d < piv; d++) { - output_signal[t * piv + d] = sigmoid(input_signal[t * in_channels + d]) - * tanh(input_signal[t * in_channels + d + piv]); + unsigned time_step = 0; // used to avoid index multiplication + while (in_time--) { + unsigned pivot = in_channels >> 1; + float* input_sigmoid_offset = (float*)input_signal + time_step; + float* input_tanh_offset = (float*)input_signal + time_step + pivot; + + #ifdef LOOP_UNROLL + unsigned len_unroll = pivot >> 2; + pivot %= 4; + while (len_unroll--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + } + #endif + + while (pivot--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); } + time_step += in_channels; } } diff --git a/c_reference/tests/Makefile b/c_reference/tests/Makefile index 84512dc7b..4eb8c7d70 100644 --- a/c_reference/tests/Makefile +++ b/c_reference/tests/Makefile @@ -43,11 +43,11 @@ test_quantized_face_detection_sparse: $(FACE_DETECTION_DIR)/test_quantized_face_ $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -Wno-unused-result -lm RNNBRICKED_DIR=rnn_bricked -test_rnn_bricked: $(RNNBRICKED_DIR)/test_rnn_bricked.c $(SRC_DIR)/utils.o $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/rnn_bricked.o +test_rnn_bricked: $(RNNBRICKED_DIR)/test_rnn_bricked.c $(SRC_DIR)/utils.o $(SRC_DIR)/rnn_bricked.o $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm KWS_DIR=kws -test_phoneme_det_cnn_rnn: $(KWS_DIR)/test_phoneme_det_cnn_rnn.c $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/rnn_bricked.o +test_phoneme_det_cnn_rnn: $(KWS_DIR)/test_phoneme_det_cnn_rnn.c $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o $(SRC_DIR)/rnn_bricked.o $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm .PHONY: clean cleanest diff --git a/c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h b/c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h deleted file mode 100644 index 373dbf1fe..000000000 --- 
a/c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:df49d58db43d0ea21492325425e3ddcf18a16ff8c712aca42f1576dfc3db7d31 -size 36305 diff --git a/c_reference/tests/conv1d/test_conv1d.c b/c_reference/tests/conv1d/test_conv1d.c index d276bd2bb..189b4f257 100644 --- a/c_reference/tests/conv1d/test_conv1d.c +++ b/c_reference/tests/conv1d/test_conv1d.c @@ -9,7 +9,6 @@ #include "./conv1d_regular/conv_param.h" #include "./conv1d_depthwise/conv_param_depth.h" #include "./conv1d_lr/conv_param_lr.h" -#include "./conv1d_lr_depthwise/conv_param_lr_depth.h" // Error Check void errorCheck(float* pred, float* label, unsigned out_time, int out_features) { @@ -29,6 +28,7 @@ void conv1d_check() { ConvLayers_Params conv_params = { .W = CONV1D_CONV_WEIGHT, .B = CONV1D_CONV_BIAS, + .depthwise = 0, }; float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); @@ -41,15 +41,33 @@ void conv1d_check() { free(pred); } +void conv1d_parallel_check() { + ConvLayers_Parallel_Params conv_params = { + .W = CONV1D_CONV_WEIGHT, + .B = CONV1D_CONV_BIAS, + .block_size = 100, + }; + + float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); + conv1d_parallel(pred, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES, CONV1D_INPUT, + CONV1D_IN_TIME, CONV1D_IN_FEATURES, CONV1D_PAD, CONV1D_FILT, + &conv_params, CONV1D_STRIDE, CONV1D_ACT); + + printf("Testing Parallel Convolution\n"); + errorCheck(pred, CONV1D_OUTPUT, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES); + free(pred); +} + void conv1d_depth_check() { ConvLayers_Params conv_params = { .W = CONV1D_DEPTH_CONV_WEIGHT, .B = CONV1D_DEPTH_CONV_BIAS, + .depthwise = 1, }; float* pred = (float*)malloc(CONV1D_DEPTH_OUT_TIME * CONV1D_DEPTH_OUT_FEATURES * sizeof(float)); - conv1d_depth(pred, CONV1D_DEPTH_OUT_TIME, CONV1D_DEPTH_INPUT, + conv1d(pred, CONV1D_DEPTH_OUT_TIME, 0, CONV1D_DEPTH_INPUT, CONV1D_DEPTH_IN_TIME, CONV1D_DEPTH_IN_FEATURES, CONV1D_DEPTH_PAD, CONV1D_DEPTH_FILT, &conv_params, CONV1D_DEPTH_STRIDE, CONV1D_DEPTH_ACT); @@ -78,31 +96,35 @@ void conv1d_lr_check() { free(pred); } -void conv1d_lr_depth_check() { - ConvLayers_LR_Params conv_params = { - .W1 = CONV1D_LR_DEPTHWISE_CONV_W1, - .W2 = CONV1D_LR_DEPTHWISE_CONV_W2, - .B = CONV1D_LR_DEPTHWISE_CONV_BIAS, - .rank = CONV1D_LR_DEPTHWISE_LOW_RANK +void conv1d_lr_parallel_check() { + ConvLayers_LR_Parallel_Params conv_params = { + .W1 = CONV1D_LR_CONV_W1, + .W2 = CONV1D_LR_CONV_W2, + .B = CONV1D_LR_CONV_BIAS, + .rank = CONV1D_LR_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; - float* pred = (float*)malloc(CONV1D_LR_DEPTHWISE_OUT_TIME - * CONV1D_LR_DEPTHWISE_OUT_FEATURES * sizeof(float)); - conv1d_depth_lr(pred, CONV1D_LR_DEPTHWISE_OUT_TIME, CONV1D_LR_DEPTHWISE_INPUT, - CONV1D_LR_DEPTHWISE_IN_TIME, CONV1D_LR_DEPTHWISE_IN_FEATURES, - CONV1D_LR_DEPTHWISE_PAD, CONV1D_LR_DEPTHWISE_FILT, - &conv_params, CONV1D_LR_DEPTHWISE_STRIDE, CONV1D_LR_DEPTHWISE_ACT); - - printf("Testing Low-Rank Depthwise Convolution\n"); - errorCheck(pred, CONV1D_LR_DEPTHWISE_OUTPUT, - CONV1D_LR_DEPTHWISE_OUT_TIME, CONV1D_LR_DEPTHWISE_OUT_FEATURES); + float* pred = (float*)malloc(CONV1D_LR_OUT_TIME + * CONV1D_LR_OUT_FEATURES * sizeof(float)); + conv1d_lr_parallel(pred, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES, CONV1D_LR_INPUT, + CONV1D_LR_IN_TIME, CONV1D_LR_IN_FEATURES, CONV1D_LR_PAD, CONV1D_LR_FILT, + &conv_params, CONV1D_LR_STRIDE, CONV1D_LR_ACT); + + printf("Testing Low-Rank Parallel Convolution\n"); + 
errorCheck(pred, CONV1D_LR_OUTPUT, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES); free(pred); } int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif conv1d_check(); + conv1d_parallel_check(); conv1d_lr_check(); conv1d_depth_check(); - conv1d_lr_depth_check(); + conv1d_lr_parallel_check(); return 0; } diff --git a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c index 2ba0f6fe4..ff204c8ef 100644 --- a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c +++ b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c @@ -7,7 +7,6 @@ #include #include "conv1d.h" #include "dscnn.h" -#include "fastgrnn.h" #include "utils.h" #include "rnn_bricked.h" @@ -60,64 +59,76 @@ void checkError(float* pred, float* label) { These changes are subject to the input specs i.e fixing input buffer time steps, number of features from the deployed featurizer, method of reading the input into a buffer */ void phoneme_prediction(float* mem_buf) { - ConvLayers_LR_Params conv_params = { + ConvLayers_LR_Parallel_Params conv_params = { .W1 = CNN1_W1, .W2 = CNN1_W2, .B = CNN1_BIAS, - .rank = PRE_CNN_LOW_RANK + .rank = PRE_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; ConvLayers_Params depth_param_2 = { .W = CNN2_DEPTH_W, .B = CNN2_DEPTH_BIAS, + .depthwise = 1, }; - ConvLayers_LR_Params point_param_2 = { + ConvLayers_LR_Parallel_Params point_param_2 = { .W1 = CNN2_POINT_W1, .W2 = CNN2_POINT_W2, .B = CNN2_POINT_BIAS, - .rank = POST_CNN_LOW_RANK + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; ConvLayers_Params depth_param_3 = { .W = CNN3_DEPTH_W, .B = CNN3_DEPTH_BIAS, + .depthwise = 1, }; - ConvLayers_LR_Params point_param_3 = { + ConvLayers_LR_Parallel_Params point_param_3 = { .W1 = CNN3_POINT_W1, .W2 = CNN3_POINT_W2, .B = CNN3_POINT_BIAS, - .rank = POST_CNN_LOW_RANK + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; ConvLayers_Params depth_param_4 = { .W = CNN4_DEPTH_W, .B = CNN4_DEPTH_BIAS, + .depthwise = 1, }; - ConvLayers_LR_Params point_param_4 = { + ConvLayers_LR_Parallel_Params point_param_4 = { .W1 = CNN4_POINT_W1, .W2 = CNN4_POINT_W2, .B = CNN4_POINT_BIAS, - .rank = POST_CNN_LOW_RANK + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; ConvLayers_Params depth_param_5 = { .W = CNN5_DEPTH_W, .B = CNN5_DEPTH_BIAS, + .depthwise = 1, }; - ConvLayers_LR_Params point_param_5 = { + ConvLayers_LR_Parallel_Params point_param_5 = { .W1 = CNN5_POINT_W1, .W2 = CNN5_POINT_W2, .B = CNN5_POINT_BIAS, - .rank = POST_CNN_LOW_RANK + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, }; - FastGRNN_LR_Params bwd_RNN_params = { - .mean = 0, - .stdDev = 0, + BrickedFastGRNN_LR_Params bwd_RNN_params = { .W1 = B_W1, .W2 = B_W2, .wRank = RNN_LOW_RANK, @@ -127,12 +138,14 @@ void phoneme_prediction(float* mem_buf) { .Bg = B_BIAS_GATE, .Bh = B_BIAS_UPDATE, .sigmoid_zeta = sigmoid(B_ZETA), - .sigmoid_nu = sigmoid(B_NU) + .sigmoid_nu = sigmoid(B_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, }; - FastGRNN_LR_Params fwd_RNN_params = { - .mean = 0, - .stdDev = 0, + BrickedFastGRNN_LR_Params fwd_RNN_params = { .W1 = F_W1, .W2 = F_W2, .wRank = RNN_LOW_RANK, @@ -142,19 +155,13 @@ void phoneme_prediction(float* mem_buf) { .Bg = F_BIAS_GATE, .Bh = F_BIAS_UPDATE, .sigmoid_zeta = sigmoid(F_ZETA), - .sigmoid_nu = sigmoid(F_NU) + .sigmoid_nu = 
sigmoid(F_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, }; - float preComp[RNN_IN_FEATURES] = { 0.0 }; - float tempLRW[RNN_LOW_RANK] = { 0.0 }; - float tempLRU[RNN_LOW_RANK] = { 0.0 }; - float normFeatures[RNN_IN_FEATURES] = { 0.0 }; - FastGRNN_LR_Buffers buffers = { - .preComp = preComp, - .tempLRW = tempLRW, - .tempLRU = tempLRU, - .normFeatures = normFeatures - }; unsigned in_time, out_time; /* Pre-CNN */ @@ -164,7 +171,7 @@ void phoneme_prediction(float* mem_buf) { // Since batchnorm1d is the first layer and in-place will alter the input. // Use the in-place computation only if the input can be discarded/altered. Else avoid in-place computation for this layer phon_pred_lr_cnn(cnn1_out, mem_buf, - in_time, PRE_CNN_IN_FEATURES, + conv1d_lr_parallel, in_time, PRE_CNN_IN_FEATURES, 0, 0, PRE_CNN_BNORM_AFFINE, CNN1_SCALE, CNN1_OFFSET, PRE_CNN_BNORM_INPLACE, PRE_CNN_OUT_FEATURES, PRE_CNN_FILT_PAD, PRE_CNN_FILT, &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // regular tanh activation @@ -175,15 +182,14 @@ void phoneme_prediction(float* mem_buf) { /* Bricked Bi-FastGRNN Block */ out_time = in_time/RNN_HOP + 1; float* rnn_out = (float*)malloc(out_time * RNN_OUT_FEATURES * sizeof(float)); - forward_bricked_rnn(rnn_out, RNN_OUT_FEATURES >> 1, cnn1_out, + forward_bricked_fastgrnn_lr(rnn_out, RNN_OUT_FEATURES >> 1, cnn1_out, in_time, RNN_IN_FEATURES, RNN_FWD_WINDOW, RNN_HOP, - fastgrnn_lr, &fwd_RNN_params, &buffers, - RNN_BI_DIR, RNN_SAMPLE_FIRST_BRICK, 0); + &fwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_FIRST_BRICK); - backward_bricked_rnn(rnn_out + (RNN_OUT_FEATURES >> 1), RNN_OUT_FEATURES >> 1, cnn1_out, + backward_bricked_fastgrnn_lr(rnn_out + (RNN_OUT_FEATURES >> 1), + RNN_OUT_FEATURES >> 1, cnn1_out, in_time, RNN_IN_FEATURES, RNN_BWD_WINDOW, RNN_HOP, - fastgrnn_lr, &bwd_RNN_params, &buffers, - RNN_BI_DIR, RNN_SAMPLE_LAST_BRICK, 0); + &bwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_LAST_BRICK); free(cnn1_out); /* Post-CNN */ @@ -194,7 +200,7 @@ void phoneme_prediction(float* mem_buf) { out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; float* cnn2_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); phon_pred_depth_point_lr_cnn(cnn2_out, rnn_out, - conv1d_lr, in_time, POST_CNN_INTER_FEATURES, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, 0, 0, POST_CNN_BNORM_AFFINE, CNN2_SCALE, CNN2_OFFSET, POST_CNN_BNORM_INPLACE, POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, &depth_param_2, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, @@ -209,7 +215,7 @@ void phoneme_prediction(float* mem_buf) { out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; float* cnn3_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); phon_pred_depth_point_lr_cnn(cnn3_out, cnn2_out, - conv1d_lr, in_time, POST_CNN_INTER_FEATURES, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, 0, 0, POST_CNN_BNORM_AFFINE, CNN3_SCALE, CNN3_OFFSET, POST_CNN_BNORM_INPLACE, POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, &depth_param_3, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, @@ -224,7 +230,7 @@ void phoneme_prediction(float* mem_buf) { out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; float* cnn4_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); phon_pred_depth_point_lr_cnn(cnn4_out, cnn3_out, - conv1d_lr, in_time, POST_CNN_INTER_FEATURES, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, 0, 0, POST_CNN_BNORM_AFFINE, CNN4_SCALE, CNN4_OFFSET, 
POST_CNN_BNORM_INPLACE, POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, &depth_param_4, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, @@ -239,7 +245,7 @@ void phoneme_prediction(float* mem_buf) { out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; float* pred = (float*)malloc(out_time * POST_CNN_OUT_FEATURES * sizeof(float)); phon_pred_depth_point_lr_cnn(pred, cnn4_out, - conv1d_lr, in_time, POST_CNN_INTER_FEATURES, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, 0, 0, POST_CNN_BNORM_AFFINE, CNN5_SCALE, CNN5_OFFSET, POST_CNN_BNORM_INPLACE, POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, &depth_param_5, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, @@ -257,6 +263,9 @@ void phoneme_prediction(float* mem_buf) { } int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif clock_t begin = clock(); phoneme_prediction(INPUT); clock_t end = clock(); diff --git a/c_reference/tests/rnn_bricked/test_rnn_bricked.c b/c_reference/tests/rnn_bricked/test_rnn_bricked.c index 8fe846f00..701d73af4 100644 --- a/c_reference/tests/rnn_bricked/test_rnn_bricked.c +++ b/c_reference/tests/rnn_bricked/test_rnn_bricked.c @@ -4,7 +4,6 @@ #include #include #include "rnn_bricked.h" -#include "fastgrnn.h" #include "utils.h" #include "rnn_params.h" @@ -12,9 +11,7 @@ int main() { - FastGRNN_LR_Params bwd_RNN_params = { - .mean = 0, - .stdDev = 0, + BrickedFastGRNN_LR_Params bwd_RNN_params = { .W1 = B_W1, .W2 = B_W2, .wRank = RNN_LOW_RANK, @@ -24,12 +21,14 @@ int main() { .Bg = B_BIAS_GATE, .Bh = B_BIAS_UPDATE, .sigmoid_zeta = sigmoid(B_ZETA), - .sigmoid_nu = sigmoid(B_NU) + .sigmoid_nu = sigmoid(B_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, }; - FastGRNN_LR_Params fwd_RNN_params = { - .mean = 0, - .stdDev = 0, + BrickedFastGRNN_LR_Params fwd_RNN_params = { .W1 = F_W1, .W2 = F_W2, .wRank = RNN_LOW_RANK, @@ -39,31 +38,22 @@ int main() { .Bg = F_BIAS_GATE, .Bh = F_BIAS_UPDATE, .sigmoid_zeta = sigmoid(F_ZETA), - .sigmoid_nu = sigmoid(F_NU) + .sigmoid_nu = sigmoid(F_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, }; - float preComp[RNN_IN_FEATURES] = { 0.0 }; - float tempLRW[RNN_LOW_RANK] = { 0.0 }; - float tempLRU[RNN_LOW_RANK] = { 0.0 }; - float normFeatures[RNN_IN_FEATURES] = { 0.0 }; - FastGRNN_LR_Buffers buffers = { - .preComp = preComp, - .tempLRW = tempLRW, - .tempLRU = tempLRU, - .normFeatures = normFeatures - }; - - float pred[RNN_OUT_TIME * RNN_OUT_FEATURES] = {}; + float* pred = (float*)malloc(RNN_OUT_TIME * RNN_OUT_FEATURES * sizeof(float)); - forward_bricked_rnn(pred, RNN_OUT_FEATURES >> 1, INPUT, + forward_bricked_fastgrnn_lr(pred, RNN_OUT_FEATURES >> 1, INPUT, RNN_IN_TIME, RNN_IN_FEATURES, FWD_WINDOW, HOP, - fastgrnn_lr, &fwd_RNN_params, &buffers, - 1, 1, 0); + &fwd_RNN_params, 1, 1); - backward_bricked_rnn(pred + (RNN_OUT_FEATURES >> 1), RNN_OUT_FEATURES >> 1, INPUT, + backward_bricked_fastgrnn_lr(pred + (RNN_OUT_FEATURES >> 1), RNN_OUT_FEATURES >> 1, INPUT, RNN_IN_TIME, RNN_IN_FEATURES, BWD_WINDOW, HOP, - fastgrnn_lr, &bwd_RNN_params, &buffers, - 1, 1, 0); + &bwd_RNN_params, 1, 1); float error = 0; float denom = 0; @@ -76,7 +66,12 @@ int main() { } float avg_error = error / (RNN_OUT_TIME * RNN_OUT_FEATURES); float rmse = error / denom; + + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif printf("Testing Bricked RNNs Bi-Directional\n"); printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse); + 
free(pred); return 0; }
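/* Illustrative aside (not part of the patch): a standalone sketch of how the
   bi-directional output buffer in the test above is laid out. The toy values
   (in_time = 9, window = 3, hop = 3, rnn_hidden = 4) are assumptions for
   illustration only. forward_bricked_fastgrnn_lr is called with the base of the
   buffer and backward_bricked_fastgrnn_lr with an offset of rnn_hidden, so with
   bi_direction = 1 each time step holds the forward half followed by the
   backward half (rnn_assign_offset = 2 * rnn_hidden). */
#include <stdio.h>

int main() {
  unsigned in_time = 9, window = 3, hop = 3, rnn_hidden = 4;
  unsigned out_time = in_time / hop + 1;        // same sizing the tests use for the output buffer
  unsigned num_bricks = (in_time - window) / hop + 1;
  unsigned out_features = 2 * rnn_hidden;       // forward half + backward half
  printf("out_time = %u, num_bricks = %u\n", out_time, num_bricks);
  for (unsigned t = 0; t < out_time; t++) {
    printf("t = %u: forward floats [%u, %u), backward floats [%u, %u)\n",
      t, t * out_features, t * out_features + rnn_hidden,
      t * out_features + rnn_hidden, (t + 1) * out_features);
  }
  return 0;
}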