Nested Tensor Support (#76)
* Adds basic support for single-level nested tensor inference.
  * Partially addresses #72.
* Backward pass and forward-mode AD are not supported; the torch autograd API does not allow them yet. An issue will be opened to track this.
* The C++ API now expects output tensor references instead of creating its own output tensors.
  * This will finally fix #13.
alihassanijr authored Dec 30, 2023
1 parent 63f69d6 commit 7424939
Showing 25 changed files with 910 additions and 441 deletions.
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cpu/na1d.h
@@ -49,7 +49,7 @@ void na1d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int length,
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cpu/na2d.h
@@ -50,7 +50,7 @@ void na2d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int height,
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cpu/na3d.h
@@ -53,7 +53,7 @@ void na3d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int depth,
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cuda/na1d.cuh
@@ -49,7 +49,7 @@ void na1d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int length,
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cuda/na2d.cuh
@@ -50,7 +50,7 @@ void na2d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int height,
2 changes: 1 addition & 1 deletion csrc/include/natten/pytorch/cuda/na3d.cuh
@@ -53,7 +53,7 @@ void na3d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int depth,
98 changes: 98 additions & 0 deletions csrc/include/natten/pytorch/helpers.h
@@ -62,3 +62,101 @@
} \
}()
#endif

namespace natten {
namespace pytorch {

inline void CheckArgs(int kernel_size, int dilation) {
TORCH_CHECK(kernel_size > 1 && kernel_size % 2 == 1, "Kernel size must be an odd number greater than 1, got ", kernel_size, ".");
TORCH_CHECK(dilation >= 1, "Dilation must be a positive integer, got ", dilation, ".");
}

inline void CheckArgsAgainstDim(int dim, int kernel_size, int dilation) {
TORCH_CHECK(kernel_size * dilation <= dim, "The product of kernel size and dilation must be less than or equal to the corresponding input axis size. "
"Got kernel size ", kernel_size, ", dilation ", dilation, ", but dimension size was ", dim, ".");
}

inline void CheckIfPropertiesMatch(const at::Tensor& a, const at::Tensor& b) {
CHECK_CONTIGUOUS(a);
CHECK_CONTIGUOUS(b);
TORCH_CHECK(a.device().is_cuda() == b.device().is_cuda(), "Expected all tensors to be on the same device.");
TORCH_CHECK(a.scalar_type() == b.scalar_type(), "Input tensors must match in dtype!");
}

inline void CheckIfPropertiesMatch(const at::Tensor& a, const at::Tensor& b, const at::Tensor& c) {
CHECK_CONTIGUOUS(a);
CHECK_CONTIGUOUS(b);
CHECK_CONTIGUOUS(c);
TORCH_CHECK(a.device().is_cuda() == b.device().is_cuda() && b.device().is_cuda() == c.device().is_cuda(), "Expected all tensors to be on the same device.");
TORCH_CHECK(a.scalar_type() == b.scalar_type() && b.scalar_type() == c.scalar_type(), "Input tensors must match in dtype!");
}

template <size_t NaDim>
void CheckIfTensorShapesMatch(const at::Tensor& a, const at::Tensor& b) {
static_assert(NaDim >= 1 && NaDim < 4);
static constexpr size_t Rank = NaDim + 3;
TORCH_CHECK(a.dim() == b.dim() && a.dim() == Rank, "Expected ", Rank, "-D tensors.");
for (size_t i=0; i < Rank; ++i) {
TORCH_CHECK(a.size(i) == b.size(i), "Tensor shape mismatch at dimension ", i, ": ", a.size(i), " != ", b.size(i));
}
}

template <size_t NaDim>
void CheckAttnShape(const at::Tensor& input, const at::Tensor& attn, int kernel_size) {
static_assert(NaDim >= 1 && NaDim < 4);
TORCH_CHECK(attn.dim() == NaDim + 3, "Expected ", NaDim + 3, "-D tensors.");
for (size_t i=0; i < NaDim + 2; ++i) {
TORCH_CHECK(input.size(i) == attn.size(i), "Tensor shape mismatch at dimension ", i, ": ", input.size(i), " != ", attn.size(i));
}
auto expected_kernel_size = std::pow(kernel_size, NaDim);
TORCH_CHECK(attn.size(NaDim + 2) == expected_kernel_size, "Expected attention dim was ", expected_kernel_size, ", got ", attn.size(NaDim + 2));
}

template <size_t NaDim>
void CheckBias(const at::Tensor& input, const at::Tensor& bias, int kernel_size) {
static_assert(NaDim >= 1 && NaDim < 4);
TORCH_CHECK(input.scalar_type() == bias.scalar_type(), "Inputs and bias must match in dtype.");
TORCH_CHECK(bias.device().is_cuda() == input.device().is_cuda(),
"Expected positional bias to be on the same device as the inputs.");
CHECK_CONTIGUOUS(bias);
TORCH_CHECK(bias.size(0) == input.size(1), "Expected bias.shape[0] == input.shape[1] == heads.");
for (size_t i=0; i < NaDim; ++i) {
auto expected_bias_dim = kernel_size * 2 - 1;
TORCH_CHECK(bias.size(i + 1) == expected_bias_dim, "Invalid bias shape at dim ", i + 1, "; "
"expected ", expected_bias_dim, ", got ", bias.size(i + 1), ".");
}
}

// TODO: I resent this; please do it the right way.
template <size_t NaDim>
void CheckAttnShape(const at::Tensor& input, const at::Tensor& attn, int kernel_size, int kernel_size_d) {
static_assert(NaDim == 3);
TORCH_CHECK(attn.dim() == NaDim + 3, "Expected ", NaDim + 3, "-D tensors.");
for (size_t i=0; i < NaDim + 2; ++i) {
TORCH_CHECK(input.size(i) == attn.size(i), "Tensor shape mismatch at dimension ", i, ": ", input.size(i), " != ", attn.size(i));
}
auto expected_kernel_size = kernel_size * kernel_size * kernel_size_d;
TORCH_CHECK(attn.size(NaDim + 2) == expected_kernel_size, "Expected attention dim was ", expected_kernel_size, ", got ", attn.size(NaDim + 2));
}

template <size_t NaDim>
void CheckBias(const at::Tensor& input, const at::Tensor& bias, int kernel_size, int kernel_size_d) {
static_assert(NaDim == 3);
TORCH_CHECK(input.scalar_type() == bias.scalar_type(), "Inputs and bias must match in dtype.");
TORCH_CHECK(bias.device().is_cuda() == input.device().is_cuda(),
"Expected positional bias to be on the same device as the inputs.");
CHECK_CONTIGUOUS(bias);
TORCH_CHECK(bias.size(0) == input.size(1), "Expected bias.shape[0] == input.shape[1] == heads.");

auto expected_bias_dim_0 = kernel_size_d * 2 - 1;
TORCH_CHECK(bias.size(1) == expected_bias_dim_0, "Invalid bias shape at dim 1; expected ",
expected_bias_dim_0, ", got ", bias.size(1), ".");
for (size_t i=1; i < NaDim; ++i) {
auto expected_bias_dim = kernel_size * 2 - 1;
TORCH_CHECK(bias.size(i + 1) == expected_bias_dim, "Invalid bias shape at dim ", i + 1, "; "
"expected ", expected_bias_dim, ", got ", bias.size(i + 1), ".");
}
}

} // namespace pytorch
} // namespace natten
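For orientation, here is a minimal sketch of how these helpers might be composed by a 1-D QK binding. The check_na1d_qk wrapper, the [batch, heads, length, dim] layout, and the include path (assuming csrc/include is on the include path) are assumptions for illustration; this is not the actual NATTEN dispatch code.

#include <ATen/ATen.h>
#include <natten/pytorch/helpers.h>

// Illustrative only: validates query/key/attn and an optional bias for NaDim = 1,
// where query/key are assumed to be [batch, heads, length, dim] and
// attn is [batch, heads, length, kernel_size].
void check_na1d_qk(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &attn,
    const at::optional<at::Tensor> &bias,
    int kernel_size,
    int dilation) {
  natten::pytorch::CheckArgs(kernel_size, dilation);
  natten::pytorch::CheckIfPropertiesMatch(query, key, attn);
  natten::pytorch::CheckIfTensorShapesMatch<1>(query, key);
  natten::pytorch::CheckAttnShape<1>(query, attn, kernel_size);
  // Length is dimension 2 in the assumed layout; the product of kernel size
  // and dilation may not exceed it.
  natten::pytorch::CheckArgsAgainstDim(query.size(2), kernel_size, dilation);
  if (bias.has_value()) {
    natten::pytorch::CheckBias<1>(query, bias.value(), kernel_size);
  }
}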
17 changes: 11 additions & 6 deletions csrc/include/natten/pytorch/na1d.h
@@ -26,33 +26,38 @@

#pragma once
#include <ATen/ATen.h>
#include <vector>

namespace natten {
namespace pytorch {

at::Tensor na1d_qk_forward(
void na1d_qk_forward(
at::Tensor &attn,
const at::Tensor &query,
const at::Tensor &key,
const at::optional<at::Tensor> &bias,
const int kernel_size,
const int dilation);

std::vector<at::Tensor> na1d_qk_backward(
void na1d_qk_backward(
at::Tensor &d_query,
at::Tensor &d_key,
at::optional<at::Tensor> &d_bias,
const at::Tensor &d_attn,
const at::Tensor &query,
const at::Tensor &key,
const bool has_bias,
const int kernel_size,
const int dilation);

at::Tensor na1d_av_forward(
void na1d_av_forward(
at::Tensor &out,
const at::Tensor &attn,
const at::Tensor &value,
const int kernel_size,
const int dilation);

std::vector<at::Tensor> na1d_av_backward(
void na1d_av_backward(
at::Tensor &d_attn,
at::Tensor &d_value,
const at::Tensor &d_out,
const at::Tensor &attn,
const at::Tensor &value,
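With the new signatures above, the caller allocates the outputs and passes them in by reference. A minimal sketch of the 1-D QK + AV path follows; the [batch, heads, length, dim] layout, the helper function name, and the softmax step between the two calls are assumptions for illustration.

#include <ATen/ATen.h>
#include <natten/pytorch/na1d.h>

// Illustrative caller; shapes follow the assumed [batch, heads, length, dim] layout.
void na1d_attention_sketch(
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    int kernel_size,
    int dilation) {
  const auto batch = query.size(0);
  const auto heads = query.size(1);
  const auto length = query.size(2);

  // The C++ API no longer allocates outputs; the caller owns them now.
  at::Tensor attn = at::empty({batch, heads, length, kernel_size}, query.options());
  at::Tensor out = at::empty_like(value);

  at::optional<at::Tensor> no_bias;  // no relative positional bias in this sketch

  natten::pytorch::na1d_qk_forward(attn, query, key, no_bias, kernel_size, dilation);
  attn = attn.softmax(-1);
  natten::pytorch::na1d_av_forward(out, attn, value, kernel_size, dilation);
}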
17 changes: 11 additions & 6 deletions csrc/include/natten/pytorch/na2d.h
@@ -26,33 +26,38 @@

#pragma once
#include <ATen/ATen.h>
#include <vector>

namespace natten {
namespace pytorch {

at::Tensor na2d_qk_forward(
void na2d_qk_forward(
at::Tensor &attn,
const at::Tensor &query,
const at::Tensor &key,
const at::optional<at::Tensor> &bias,
const int kernel_size,
const int dilation);

std::vector<at::Tensor> na2d_qk_backward(
void na2d_qk_backward(
at::Tensor &d_query,
at::Tensor &d_key,
at::optional<at::Tensor> &d_bias,
const at::Tensor &d_attn,
const at::Tensor &query,
const at::Tensor &key,
const bool has_bias,
const int kernel_size,
const int dilation);

at::Tensor na2d_av_forward(
void na2d_av_forward(
at::Tensor &out,
const at::Tensor &attn,
const at::Tensor &value,
const int kernel_size,
const int dilation);

std::vector<at::Tensor> na2d_av_backward(
void na2d_av_backward(
at::Tensor &d_attn,
at::Tensor &d_value,
const at::Tensor &d_out,
const at::Tensor &attn,
const at::Tensor &value,
18 changes: 11 additions & 7 deletions csrc/include/natten/pytorch/na3d.h
@@ -26,12 +26,12 @@

#pragma once
#include <ATen/ATen.h>
#include <vector>

namespace natten {
namespace pytorch {

at::Tensor na3d_qk_forward(
void na3d_qk_forward(
at::Tensor &attn,
const at::Tensor &query,
const at::Tensor &key,
const at::optional<at::Tensor> &bias,
@@ -40,25 +40,30 @@ at::Tensor na3d_qk_forward(
const int depth_kernel_size,
const int depth_dilation);

std::vector<at::Tensor> na3d_qk_backward(
void na3d_qk_backward(
at::Tensor &d_query,
at::Tensor &d_key,
at::optional<at::Tensor> &d_bias,
const at::Tensor &d_attn,
const at::Tensor &query,
const at::Tensor &key,
const bool has_bias,
const int kernel_size,
const int dilation,
const int depth_kernel_size,
const int depth_dilation);

at::Tensor na3d_av_forward(
void na3d_av_forward(
at::Tensor &out,
const at::Tensor &attn,
const at::Tensor &value,
const int kernel_size,
const int dilation,
const int depth_kernel_size,
const int depth_dilation);

std::vector<at::Tensor> na3d_av_backward(
void na3d_av_backward(
at::Tensor &d_attn,
at::Tensor &d_value,
const at::Tensor &d_out,
const at::Tensor &attn,
const at::Tensor &value,
@@ -69,4 +74,3 @@ std::vector<at::Tensor> na3d_av_backward(

} // namespace pytorch
} // namespace natten

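In the 3-D variant the depth axis carries its own kernel size and dilation. Per CheckAttnShape<3> and CheckBias<3> in helpers.h, the attention tensor's last dimension is kernel_size * kernel_size * depth_kernel_size, and the relative positional bias spans 2*k-1 positions per axis with the depth axis first. A shape-only sketch follows; the [batch, heads, depth, height, width, dim] layout and the concrete sizes are assumptions.

#include <ATen/ATen.h>

// Shape sketch only; values are arbitrary and the layout is an assumption.
void na3d_shape_sketch() {
  const int64_t batch = 2, heads = 4, depth = 8, height = 16, width = 16, dim = 32;
  const int64_t kernel_size = 7, depth_kernel_size = 3;
  const auto opts = at::TensorOptions().dtype(at::kFloat);

  at::Tensor query = at::randn({batch, heads, depth, height, width, dim}, opts);

  // Last dim of attn is kernel_size^2 * depth_kernel_size (see CheckAttnShape<3>).
  at::Tensor attn = at::empty(
      {batch, heads, depth, height, width,
       kernel_size * kernel_size * depth_kernel_size},
      opts);

  // Relative positional bias: heads x (2*dk - 1) x (2*k - 1) x (2*k - 1) (see CheckBias<3>).
  at::Tensor bias = at::zeros(
      {heads, 2 * depth_kernel_size - 1, 2 * kernel_size - 1, 2 * kernel_size - 1},
      opts);
}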
4 changes: 2 additions & 2 deletions csrc/src/pytorch/cpu/na1d.cpp
@@ -61,7 +61,7 @@ void na1d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int length,
@@ -74,7 +74,7 @@ void na1d_qk_backward(
static_cast<void *>(d_attn.data_ptr()),
static_cast<void *>(d_query.data_ptr()),
static_cast<void *>(d_key.data_ptr()),
d_bias.has_storage() ? static_cast<void *>(d_bias.data_ptr()) : nullptr,
d_bias.has_value() ? static_cast<void *>(d_bias.value().data_ptr()) : nullptr,
batch_size, heads, length, dim,
kernel_size, dilation);
}
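The same has_value()/value() lowering recurs in every backend below. As a standalone illustration of the convention (the function name is made up for this sketch):

#include <ATen/ATen.h>

// Illustrative only: an optional bias-gradient buffer is lowered to a raw pointer,
// with nullptr signalling that no bias gradient should be computed.
void *bias_grad_ptr(at::optional<at::Tensor> &d_bias) {
  return d_bias.has_value()
      ? static_cast<void *>(d_bias.value().data_ptr())
      : nullptr;
}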
4 changes: 2 additions & 2 deletions csrc/src/pytorch/cpu/na2d.cpp
@@ -62,7 +62,7 @@ void na2d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int height,
@@ -76,7 +76,7 @@ void na2d_qk_backward(
static_cast<void *>(d_attn.data_ptr()),
static_cast<void *>(d_query.data_ptr()),
static_cast<void *>(d_key.data_ptr()),
d_bias.has_storage() ? static_cast<void *>(d_bias.data_ptr()) : nullptr,
d_bias.has_value() ? static_cast<void *>(d_bias.value().data_ptr()) : nullptr,
batch_size, heads, height, width, dim,
kernel_size, dilation);
}
4 changes: 2 additions & 2 deletions csrc/src/pytorch/cpu/na3d.cpp
@@ -65,7 +65,7 @@ void na3d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int depth,
@@ -82,7 +82,7 @@ void na3d_qk_backward(
static_cast<void *>(d_attn.data_ptr()),
static_cast<void *>(d_query.data_ptr()),
static_cast<void *>(d_key.data_ptr()),
d_bias.has_storage() ? static_cast<void *>(d_bias.data_ptr()) : nullptr,
d_bias.has_value() ? static_cast<void *>(d_bias.value().data_ptr()) : nullptr,
batch_size, heads, depth, height, width, dim,
kernel_size, dilation, depth_kernel_size, depth_dilation);
}
4 changes: 2 additions & 2 deletions csrc/src/pytorch/cuda/na1d.cu
@@ -62,7 +62,7 @@ void na1d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int length,
@@ -75,7 +75,7 @@ void na1d_qk_backward(
static_cast<void *>(d_attn.data_ptr()),
static_cast<void *>(d_query.data_ptr()),
static_cast<void *>(d_key.data_ptr()),
d_bias.has_storage() ? static_cast<void *>(d_bias.data_ptr()) : nullptr,
d_bias.has_value() ? static_cast<void *>(d_bias.value().data_ptr()) : nullptr,
batch_size, heads, length, dim,
kernel_size, dilation);
}
4 changes: 2 additions & 2 deletions csrc/src/pytorch/cuda/na2d.cu
@@ -63,7 +63,7 @@ void na2d_qk_backward(
const at::Tensor &key,
at::Tensor &d_query,
at::Tensor &d_key,
at::Tensor &d_bias,
at::optional<at::Tensor> &d_bias,
const int batch_size,
const int heads,
const int height,
@@ -77,7 +77,7 @@ void na2d_qk_backward(
static_cast<void *>(d_attn.data_ptr()),
static_cast<void *>(d_query.data_ptr()),
static_cast<void *>(d_key.data_ptr()),
d_bias.has_storage() ? static_cast<void *>(d_bias.data_ptr()) : nullptr,
d_bias.has_value() ? static_cast<void *>(d_bias.value().data_ptr()) : nullptr,
batch_size, heads, height, width, dim,
kernel_size, dilation);
}