Fixing indentation
Cydral committed Dec 16, 2024
1 parent caed8ff commit fbaa299
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions dlib/cuda/tensor_tools.h
@@ -172,49 +172,49 @@ namespace dlib { namespace tt
    requires
        - dest does not alias the memory of lhs or rhs
        - The dimensions of lhs and rhs must be compatible for matrix multiplication.
          The specific requirements depend on the mode:
          For CHANNEL_WISE mode (default):
            - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
            - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
            - Let D == mat(dest)
            - D.nr() == L.nr() && D.nc() == R.nc()
              (i.e. dest must be preallocated and have the correct output dimensions)
            - L.nc() == R.nr()
          For PLANE_WISE mode:
            - lhs.num_samples() == rhs.num_samples() && lhs.k() == rhs.k()
            - If !trans_lhs && !trans_rhs:
                lhs.nc() == rhs.nr()
                dest.nr() == lhs.nr() && dest.nc() == rhs.nc()
            - If trans_lhs && !trans_rhs:
                lhs.nr() == rhs.nr()
                dest.nr() == lhs.nc() && dest.nc() == rhs.nc()
            - If !trans_lhs && trans_rhs:
                lhs.nc() == rhs.nc()
                dest.nr() == lhs.nr() && dest.nc() == rhs.nr()
            - If trans_lhs && trans_rhs:
                lhs.nr() == rhs.nc()
                dest.nr() == lhs.nc() && dest.nc() == rhs.nr()
    ensures
        - Performs matrix multiplication based on the specified mode:
          For CHANNEL_WISE mode:
            - performs: dest = alpha*L*R + beta*mat(dest)
              where L, R, and D are as defined above.
          For PLANE_WISE mode:
            - Performs matrix multiplication for each corresponding 2D plane (nr x nc)
              in lhs and rhs across all samples and channels.
            - The operation is equivalent to performing the following for each sample
              and channel:
                dest[s][k] = alpha * (lhs[s][k] * rhs[s][k]) + beta * dest[s][k]
              where [s][k] represents the 2D plane for sample s and channel k.
        Note that the PLANE_WISE mode is particularly useful for operations like attention
        mechanisms in neural networks, where you want to perform matrix multiplications
        on 2D planes of 4D tensors while preserving the sample and channel dimensions.
!*/

// ----------------------------------------------------------------------------------------
