Added MatMul and parallel versions of the conv and bricked rnn layers
Anirudh0707 committed Jul 2, 2021
1 parent 6411d15 commit 7efca51
Showing 14 changed files with 1,187 additions and 564 deletions.
1 change: 0 additions & 1 deletion .gitattributes
@@ -65,7 +65,6 @@ c_reference/tests/kws/keyword_spotting_io_2.h filter=lfs diff=lfs merge=lfs -tex
c_reference/tests/kws/keyword_spotting_io_3.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/conv1d/conv1d_regular/conv_param.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/conv1d/conv1d_lr_depthwise/conv_param_lr_depth.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/kws/precnn_params.h filter=lfs diff=lfs merge=lfs -text
c_reference/tests/kws/postcnn_params.h filter=lfs diff=lfs merge=lfs -text
140 changes: 97 additions & 43 deletions c_reference/include/conv1d.h

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion c_reference/include/dscnn.h
@@ -14,6 +14,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*,
* @brief sub-layers : batchnorm1d -> conv1d_lr
* @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers
* @param[in] input_signal pointer to the input signal. size = in_time * in_channels
* @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params)
* @param[in] in_time number of time steps in the input_signal
* @param[in] in_channels number of input channels
* @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2
@@ -38,7 +39,7 @@ typedef int (*conv_layer)(float*, unsigned, unsigned, const float*,
* 3: relu
*/
int phon_pred_lr_cnn(float* output_signal, float* input_signal,
unsigned in_time, unsigned in_channels,
conv_layer cnn, unsigned in_time, unsigned in_channels,
const float* const mean, const float* const var,
unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place,
unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size,
@@ -49,6 +50,7 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal,
* @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d
* @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers
* @param[in] input_signal pointer to the input signal. size = in_time * in_channels
* @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params)
* @param[in] in_time number of time steps in the input
* @param[in] in_channels number of input channels
* @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2
105 changes: 65 additions & 40 deletions c_reference/include/rnn_bricked.h
@@ -4,35 +4,65 @@
#ifndef __RNN_BRICKED_H__
#define __RNN_BRICKED_H__

// Function pointer for the RNN to be passed as a parameter
typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned,
unsigned, const void*, void*, int, int);
/* All the matrices are stored in the row major format
NOTES for using the layers
-> Single-directional Computation
While using the bricked fastgrnn layers, the user needs to adhere to the following two constraints
1) in_time % hop = 0
2) fwd_window % hop = 0 and bwd_window % hop = 0
// NOTES for bi-direction
// If bi_direction = 1, then actual rnn_output_dims is twice the rnn_hidden(rnn_hidden is output dims for each cell).
// Each function will only process its given context(forward/backward).
// The other context will need to be called separately with an appropriate offset.
// E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...)
// 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...)
//
// Each cell will only calculate half the hidden state i.e. rnn_hidden slots of memory from the start of the output pointer
// Hence rnn_hidden is used as an offset for the backward pass. The offset for the forward pass is 0
// This use of an offset is a way to exploit the nature of bi-direction to bypass the concatenation step typically associated with bi-directional passes
//
// Constraints
// For Bi-Directional use, there are 3 constraints
// 1) (in_time - fwd_window) % hop == 0 and (in_time - bwd_window) % hop == 0
// 2) fwd_window % hop == 0 and bwd_window % hop == 0
// 3) sample_first_brick and sample_last_brick = 1
//
// Violation of these constraints can lead to one of the following issues
// 1) segmentation faults
// 2) forward out_time != backward out_time
// 3) mismatch between forward index and backward index during sampling i.e forward index 8 would correspond to backward index 6. This index error continues for all consecutive bricks
// Hence, padding of the input and appropriate window choice is necessary
//
// These constraints can be ignored while performing uni-directional passes. However, it is favorable to follow constraints 1 and 2
Violating the above two constraints (1 & 2) will cause segmentation faults
The layers first compute all the Wx steps and then compute Uh for all the windows in parallel
Hence, the user needs to adhere to constraints 1 & 2
-> Bi-directional Computation
For bi-directional cases, there are 2 additional constraints that need to be followed
A) sample_first_brick and sample_last_brick = 1
B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call
Each function will only process its given context(forward/backward). The other context will need to be called separately.
E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...)
2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...)
The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used
Violating the conditions would cause index mis-matches or data corruption
If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result
If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only)
*/
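The constraints in the notes above, together with the out_time formulas given in the parameter docs of the bricked layers, can be checked before calling a layer. The following is a minimal sketch, not part of the header: `bricked_constraints_ok` and `bricked_out_time` are hypothetical helper names, and the extra guards (non-zero hop, window no larger than in_time) are assumptions added for safety.

```c
#include <assert.h>

/* Hypothetical helper (not in rnn_bricked.h): returns 1 when constraints
   1 and 2 from the notes hold, i.e. in_time % hop == 0 and window % hop == 0.
   The zero and window > in_time guards are assumptions of this sketch. */
int bricked_constraints_ok(unsigned in_time, unsigned window, unsigned hop) {
  if (hop == 0 || window == 0 || window > in_time) return 0;
  return (in_time % hop == 0) && (window % hop == 0);
}

/* Hypothetical helper: number of output time steps, following the formulas
   quoted in the layer docs.
   sample_brick == 0 : out_time = (in_time - window) / hop + 1
   sample_brick == 1 : out_time = in_time / hop + 1 */
unsigned bricked_out_time(unsigned in_time, unsigned window, unsigned hop,
                          unsigned sample_brick) {
  assert(bricked_constraints_ok(in_time, window, hop));
  return sample_brick ? in_time / hop + 1
                      : (in_time - window) / hop + 1;
}
```

For a bi-directional setup the same check would be run once with fwd_window and once with bwd_window, and sampling would be enabled for both passes (constraint A).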

/**
* @brief Model parameters for the bricked low-rank FastGRNN layer
* @var W1 pointer to first low-rank component of W. shape = [rank * in_dims]
* @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank]
* @var wRank rank of W matrix
* @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden]
* @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank]
* @var uRank rank of U matrix
* @var Bg pointer to bias for sigmoid
* @var Bh pointer to bias for tanh
* @var sigmoid_zeta first weight parameter for update from input from next step
* @var sigmoid_nu second weight parameter for update from input from next step
* @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x
* @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x)
* @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h
* @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h)
*/
typedef struct BrickedFastGRNN_LR_Params {
float* W1;
float* W2;
unsigned wRank;
float* U1;
float* U2;
unsigned uRank;
float* Bg;
float* Bh;
float sigmoid_zeta;
float sigmoid_nu;
unsigned block_size_w_to_lr;
unsigned block_size_w_from_lr;
unsigned block_size_u_to_lr;
unsigned block_size_u_from_lr;
} BrickedFastGRNN_LR_Params;
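The W1/W2 shapes documented for the struct imply that the full matrix W is applied as W2 * (W1 * x), going through an intermediate vector of length rank. A minimal sketch of that product is given below for a single input vector; `dense_matvec` and `low_rank_matvec` are hypothetical names, and this is an illustration of the shapes, not the layer's actual (tiled, batched) implementation.

```c
#include <assert.h>

/* ret = mat * vec for a row-major [nrows x ncols] matrix (hypothetical helper). */
void dense_matvec(const float* mat, const float* vec,
                  unsigned nrows, unsigned ncols, float* ret) {
  for (unsigned r = 0; r < nrows; r++) {
    float acc = 0.0f;
    for (unsigned c = 0; c < ncols; c++)
      acc += mat[r * ncols + c] * vec[c];
    ret[r] = acc;
  }
}

/* Low-rank product ret = W2 * (W1 * x).
   W1 is [rank x in_dims], W2 is [rnn_hidden x rank], both row-major.
   tmp is caller-provided scratch space holding rank floats. */
void low_rank_matvec(const float* W1, const float* W2, const float* x,
                     unsigned in_dims, unsigned rank, unsigned rnn_hidden,
                     float* tmp, float* ret) {
  assert(rank > 0);
  dense_matvec(W1, x, rank, in_dims, tmp);      /* project down to rank */
  dense_matvec(W2, tmp, rnn_hidden, rank, ret); /* project up to rnn_hidden */
}
```

The factored form costs rank * (in_dims + rnn_hidden) multiply-adds instead of in_dims * rnn_hidden, which is the reason for storing W as two components.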

/** Forward Bricking and application of the forward RNN for an input signal
* @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden
@@ -42,18 +72,16 @@ typedef int (*rnn_layer)(float* const, unsigned, const float* const, unsigned,
* @param[in] in_dims input dimensions
* @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick)
* @param[in] hop hop distance between bricks
* @param[in] rnn function pointer to the RNN
* @param[in] params pointer to the parameters for the RNN
* @param[in,out] buffers pointer to buffer for the RNN
* @param[in] bi_direction determines if the output is for a bi-directional RNN.
* @param[in] sample_first_brick determines if the 1st brick should also be sampled
* -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1
* -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hidden state. out_time = in_time/hop + 1
*/
int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal,
unsigned in_time, unsigned in_dims, unsigned window, unsigned hop,
rnn_layer rnn, const void* params, void* buffers,
unsigned bi_direction, unsigned sample_first_brick, int normalize);
int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
float* input_signal, unsigned in_time, unsigned in_dims,
unsigned window, unsigned hop, const void* params,
unsigned bi_direction, unsigned sample_first_brick);

/** Backward Bricking and application of the backward RNN for an input signal
* @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden
@@ -63,18 +91,15 @@ int forward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_
* @param[in] in_dims input dimensions
* @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick)
* @param[in] hop hop distance between bricks
* @param[in] rnn function pointer to the RNN
* @param[in] params pointer to the parameters for the RNN
* @param[in,out] buffers pointer to buffer for the RNN
* @param[in] bi_direction determines if the output is for a bi-directional RNN.
* @param[in] sample_last_brick determines if the last brick should also be sampled
* -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1
* -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hidden state(last in reverse). out_time = in_time/hop + 1
*/
int backward_bricked_rnn(float* output_signal, unsigned rnn_hidden, float* input_signal,
unsigned in_time, unsigned in_dims, unsigned window, unsigned hop,
rnn_layer rnn, const void* params, void* buffers,
unsigned bi_direction, unsigned sample_last_brick, int normalize);

int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden,
float* input_signal, unsigned in_time, unsigned in_dims,
unsigned window, unsigned hop, const void* params,
unsigned bi_direction, unsigned sample_last_brick);

#endif
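The header notes describe how the forward and backward passes share one output buffer: with bi_direction = 1 each output time step holds twice rnn_hidden values, the forward pass writes at offset 0 and the backward pass is handed output + rnn_hidden. The sketch below only illustrates that memory layout. `fill_direction` is a hypothetical stand-in for one directional pass (it writes a constant instead of running an RNN), and the per-step stride of 2 * rnn_hidden is this sketch's reading of the notes.

```c
#include <assert.h>

/* Hypothetical stand-in for one directional pass: writes rnn_hidden values
   per output time step, advancing out_stride floats between time steps. */
void fill_direction(float* output, unsigned out_time, unsigned rnn_hidden,
                    unsigned out_stride, float value) {
  assert(out_stride >= rnn_hidden);
  for (unsigned t = 0; t < out_time; t++)
    for (unsigned d = 0; d < rnn_hidden; d++)
      output[t * out_stride + d] = value;
}

/* Bi-directional layout: each time step holds 2 * rnn_hidden floats,
   forward half at offset 0, backward half at offset rnn_hidden.
   No separate concatenation step is needed afterwards. */
void fill_bidirectional(float* output, unsigned out_time, unsigned rnn_hidden) {
  unsigned stride = 2 * rnn_hidden;
  fill_direction(output, out_time, rnn_hidden, stride, 1.0f);               /* "forward"  */
  fill_direction(output + rnn_hidden, out_time, rnn_hidden, stride, -1.0f); /* "backward" */
}
```

If the second call were made with `output` instead of `output + rnn_hidden`, it would overwrite the forward half, which is the failure mode the notes warn about.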
96 changes: 71 additions & 25 deletions c_reference/include/utils.h
@@ -31,37 +31,83 @@ void matVec(const float* const mat, const float* const vec,
float alpha, float beta,
float* const ret);

/* Matrix-vector multiplication with a row offset
This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis
ret is of size nrows, vec is of size ncols
mat is of size nrows * ncols, stored in row major
depthwise is to change the matVec to depthwise specific convolutions
row_stride is the offset factor between two adjacent rows
Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped
For a normal matVec case, this value will be ncols
Eg : for a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. For this eg ncols will be 100 and row_stride will be 400
vec_stride is the offset factor between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals
For a normal matVec case, this value will be 1
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. So it's possible to enter a 400 length vector and consider every 4th element. For this ncols will be 100 and vec_stride will be 4*/
/*
Matrix-vector multiplication with a row offset
This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis
ret is of size nrows, vec is of size ncols
mat is of size nrows * ncols, stored in row major
depthwise is to change the matVec to depthwise specific convolutions
row_stride is the offset factor between two adjacent rows
Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped
For a normal matVec case, this value will be ncols
Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication.
For this eg ncols will be 100 and row_stride will be 400
vec_stride is the offset factor between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals
For a normal matVec case, this value will be 1
Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed.
So it's possible to enter a 400 length vector and consider every 4th element.
For this ncols will be 100 and vec_stride will be 4
*/
void offset_matVec_conv1d(const float* mat, const float* vec,
unsigned nrows, unsigned ncols,
unsigned row_stride, unsigned vec_stride,
unsigned depthwise, float* ret);
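The row_stride and vec_stride parameters described above can be illustrated with a small strided matrix-vector product. This is a sketch under assumptions, not the repository's `offset_matVec_conv1d`: the name `strided_matvec` is hypothetical, the depthwise mode is left out, and the sketch overwrites ret (whether the real function overwrites or accumulates is not stated in the comment).

```c
#include <assert.h>

/* Hypothetical sketch: ret[r] = sum over c of
   mat[r * row_stride + c] * vec[c * vec_stride].
   row_stride: distance in floats between the starts of two adjacent rows.
   vec_stride: distance in floats between two consecutive vector elements.
   With row_stride == ncols and vec_stride == 1 this is a plain matVec. */
void strided_matvec(const float* mat, const float* vec,
                    unsigned nrows, unsigned ncols,
                    unsigned row_stride, unsigned vec_stride,
                    float* ret) {
  assert(row_stride >= ncols && vec_stride >= 1);
  for (unsigned r = 0; r < nrows; r++) {
    float acc = 0.0f;
    for (unsigned c = 0; c < ncols; c++)
      acc += mat[r * row_stride + c] * vec[c * vec_stride];
    ret[r] = acc;
  }
}
```

In the comment's own example, a 400 x 400 stored matrix used as 400 x 100 corresponds to ncols = 100 with row_stride = 400, and a 400 length vector sampled at every 4th element corresponds to vec_stride = 4.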

/* Scaled matrix-matrix multiplication: ret = alpha * ret + beta * matA * matB
matA first matrix; size = nrows * ncommon
matB second matrix; size = ncommon * ncols
nrows number of rows in the first matrix
ncommon number of columns in the first matrix/number of rows in the second matrix
ncols number of columns in the second matrix
alpha scaling factor for the previously-stored output matrix
beta scaling factor for the result of the multiplication (matA * matB)
ret matrix multiplication output
*/
void matMul(const float* const matA, const float* const matB,
/*
Tiled (cache-blocked) implementation of the Matrix Multiplication
Note: If only the MatMul output is needed, then please use calloc to initialize the output
An alternative is to use malloc, followed by memset 0
There is a second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix
If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly
This MatMul adds the result to the pre-existing values in ret. Hence either a zero-initialized or a pre-existing matrix is needed
matA first matrix; shape = [nrows, ncommon]
matB second matrix; shape = [ncommon, ncols]
nrows number of rows in the first matrix
ncommon number of columns in the first matrix/number of rows in the second matrix
ncols number of columns in the second matrix
total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored
total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
ret matrix multiplication output. shape = [nrows, ncols]
block_size tile/block size for optimal cache performance. A hardware specific parameter
*/
void tiledMatMul_float(const float* const matA, const float* const matB,
unsigned nrows, unsigned ncommon, unsigned ncols,
float alpha, float beta,
float* const ret);
unsigned total_comm_A, unsigned total_cols_B,
float* const ret, unsigned block_size);
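The comment for tiledMatMul_float describes a cache-blocked product that accumulates into ret and uses separate row strides for matA and matB. A minimal sketch of that behaviour follows. It is an assumption-laden illustration, not the repository's implementation: `tiled_matmul_sketch` and `min_unsigned` are hypothetical names, and the tile loop order is one of several reasonable choices.

```c
#include <assert.h>

static unsigned min_unsigned(unsigned a, unsigned b) { return a < b ? a : b; }

/* Hypothetical sketch: ret[i][j] += sum over k of matA[i][k] * matB[k][j],
   processed in block_size x block_size tiles.
   total_comm_A / total_cols_B are the stored row lengths of matA / matB,
   so fewer columns than stored can be used. ret is [nrows, ncols], row-major,
   and must be zero-initialized or hold a matrix to be added to. */
void tiled_matmul_sketch(const float* matA, const float* matB,
                         unsigned nrows, unsigned ncommon, unsigned ncols,
                         unsigned total_comm_A, unsigned total_cols_B,
                         float* ret, unsigned block_size) {
  assert(block_size > 0);
  for (unsigned i0 = 0; i0 < nrows; i0 += block_size) {
    unsigned i1 = min_unsigned(i0 + block_size, nrows);
    for (unsigned k0 = 0; k0 < ncommon; k0 += block_size) {
      unsigned k1 = min_unsigned(k0 + block_size, ncommon);
      for (unsigned j0 = 0; j0 < ncols; j0 += block_size) {
        unsigned j1 = min_unsigned(j0 + block_size, ncols);
        /* Multiply one tile of A with one tile of B, accumulating into ret. */
        for (unsigned i = i0; i < i1; i++)
          for (unsigned k = k0; k < k1; k++) {
            float a = matA[i * total_comm_A + k];
            for (unsigned j = j0; j < j1; j++)
              ret[i * ncols + j] += a * matB[k * total_cols_B + j];
          }
      }
    }
  }
}
```

Because the function only ever adds to ret, calling it on a buffer obtained from malloc without clearing it would mix the product with whatever values the buffer happened to contain, which is why the comment asks for calloc or memset.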

/*
Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format
The result will be the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage)
Note: If only the MatMul output is needed, then please use calloc to initialize the output
An alternative is to use malloc, followed by memset 0
There is a second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix
If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly
This MatMul adds the result to the pre-existing values in ret. Hence either a zero-initialized or a pre-existing matrix is needed
matA first matrix; shape = [nrows, ncommon]
matB second matrix; shape = [ncols, ncommon]
nrows number of rows in the first matrix
ncommon number of columns in the first matrix/number of rows in the second matrix
ncols number of columns in the second matrix
total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored
total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored.
Since matB is transposed the columns are now the ncommon axis
ret matrix multiplication output. shape = [nrows, ncols]
block_size tile/block size for optimal cache performance. A hardware specific parameter
*/
void transposed_tiledMatMul(const float* const matA, const float* const matB,
unsigned nrows, unsigned ncommon, unsigned ncols,
unsigned total_comm_A, unsigned total_comm_B,
float* const ret, unsigned block_size);
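With matB stored pre-transposed as [ncols, ncommon], each output element becomes a dot product of a row of matA with a row of matB, so both operands are read contiguously. The sketch below illustrates that indexing. It is not the repository's `transposed_tiledMatMul`: `transposed_tiled_matmul_sketch` and `tile_end` are hypothetical names, and tiling only over the ncols and ncommon axes is a simplification chosen for this example.

```c
#include <assert.h>

static unsigned tile_end(unsigned start, unsigned block, unsigned limit) {
  return (start + block < limit) ? start + block : limit;
}

/* Hypothetical sketch: ret[i][j] += sum over k of matA[i][k] * matB[j][k].
   matA is [nrows, ncommon] with row stride total_comm_A.
   matB is stored transposed, [ncols, ncommon], with row stride total_comm_B.
   ret is [nrows, ncols], row-major, zero-initialized or pre-filled. */
void transposed_tiled_matmul_sketch(const float* matA, const float* matB,
                                    unsigned nrows, unsigned ncommon,
                                    unsigned ncols,
                                    unsigned total_comm_A,
                                    unsigned total_comm_B,
                                    float* ret, unsigned block_size) {
  assert(block_size > 0);
  for (unsigned j0 = 0; j0 < ncols; j0 += block_size) {
    unsigned j1 = tile_end(j0, block_size, ncols);
    for (unsigned k0 = 0; k0 < ncommon; k0 += block_size) {
      unsigned k1 = tile_end(k0, block_size, ncommon);
      for (unsigned i = 0; i < nrows; i++)
        for (unsigned j = j0; j < j1; j++) {
          /* Partial dot product over the current k tile. */
          float acc = 0.0f;
          for (unsigned k = k0; k < k1; k++)
            acc += matA[i * total_comm_A + k] * matB[j * total_comm_B + k];
          ret[i * ncols + j] += acc;
        }
    }
  }
}
```

The test data used to check this sketch is the transpose of a 3 x 2 matrix, and the result matches the ordinary product of the untransposed matrices.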

// scaled vector addition: ret = scalar1 * vec1 + scalar2 * vector2
void v_add(float scalar1, const float* const vec1,