Templated strided convolution wrapper with test
jatinchowdhury18 committed Sep 27, 2024
1 parent 6a2e066 commit 112e0ee
Showing 2 changed files with 158 additions and 4 deletions.
81 changes: 81 additions & 0 deletions RTNeural/conv1d/conv1d.h
@@ -440,4 +440,85 @@ class StrideConv1D final : public Layer<T>
int strides_counter = 0;
std::vector<T> skip_output {};
};

template <typename T, int in_sizet, int out_sizet, int kernel_size, int dilation_rate, int stride, int groups = 1, bool dynamic_state = false>
class StrideConv1DT
{
Conv1DT<T, in_sizet, out_sizet, kernel_size, dilation_rate, groups, dynamic_state> internal;

int strides_counter = 0;

public:
static constexpr auto in_size = in_sizet;
static constexpr auto out_size = out_sizet;
static constexpr auto filters_per_group = in_size / groups;
static constexpr auto channels_per_group = out_size / groups;

StrideConv1DT()
: outs(internal.outs)
{
}

/** Returns the name of this layer. */
std::string getName() const noexcept { return "strided_conv1d"; }

/** Returns false since convolution is not an activation layer. */
constexpr bool isActivation() const noexcept { return false; }

/** Resets the layer state. */
RTNEURAL_REALTIME void reset()
{
internal.reset();
}

/** Processes the input, updating the layer state without computing a new output (a "skipped" stride step). */
template <typename Inputs>
RTNEURAL_REALTIME inline void skip(const Inputs& ins) noexcept
{
internal.skip(ins);
}

/** Performs forward propagation for this layer. */
template <typename Inputs>
RTNEURAL_REALTIME inline void forward(const Inputs& ins) noexcept
{
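// Only the first call in each window of `stride` samples computes the
// convolution output; the remaining calls just push the input into the state.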
if(strides_counter == 0)
internal.forward(ins);
else
internal.skip(ins);

strides_counter = (strides_counter == stride - 1) ? 0 : strides_counter + 1;
}

/**
* Sets the layer weights.
*
* The weights vector must have size weights[out_size][group_count][kernel_size * dilation]
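* (e.g. weights[12][1][5] for the stride-3 configuration exercised in the test below)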
*/
RTNEURAL_REALTIME void setWeights(const std::vector<std::vector<std::vector<T>>>& weights)
{
internal.setWeights(weights);
}

/**
* Sets the layer biases.
*
* The bias vector must have size bias[out_size]
*/
RTNEURAL_REALTIME void setBias(const std::vector<T>& biasVals)
{
internal.setBias(biasVals);
}

/** Returns the size of the convolution kernel. */
RTNEURAL_REALTIME int getKernelSize() const noexcept { return kernel_size; }

/** Returns the convolution dilation rate. */
RTNEURAL_REALTIME int getDilationRate() const noexcept { return dilation_rate; }

/** Returns the number of "groups" in the convolution. */
int getGroups() const noexcept { return groups; }

/** Reference to the wrapped convolution layer's outputs, refreshed once every `stride` calls to forward(). */
decltype(internal.outs)& outs;
};
}
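
For illustration, here is a minimal usage sketch of the new templated wrapper, mirroring the compile-time test below. The function name, weight/bias values, and input samples are hypothetical placeholders, and the plain-array input assumes RTNeural's scalar backend (neither Eigen nor XSIMD enabled):

#include <RTNeural/RTNeural.h>
#include <vector>

void strideConvUsageSketch()
{
    // float samples, 1 input channel, 12 output channels,
    // kernel size 5, dilation rate 1, stride 3
    RTNeural::StrideConv1DT<float, 1, 12, 5, 1, 3> conv;

    // placeholder weights, shape [out_size][group_count][kernel_size * dilation] = [12][1][5]
    std::vector<std::vector<std::vector<float>>> weights(
        12, std::vector<std::vector<float>>(1, std::vector<float>(5, 0.1f)));
    conv.setWeights(weights);
    conv.setBias(std::vector<float>(12, 0.0f));
    conv.reset();

    float input[1] {};

    // prime the convolution state with kernel_size - 1 samples
    for(int i = 0; i < 4; ++i)
    {
        input[0] = 0.0f; // placeholder sample
        conv.skip(input);
    }

    // from here on, conv.outs is refreshed once every `stride` (3) calls
    for(int i = 0; i < 30; ++i)
    {
        input[0] = 0.0f; // placeholder sample
        conv.forward(input);
        if(i % 3 == 0)
        {
            // conv.outs[0] ... conv.outs[11] hold the newest strided output frame
        }
    }
}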
81 changes: 77 additions & 4 deletions tests/functional/torch_conv1d_stride_test.cpp
@@ -22,9 +22,8 @@ void testTorchConv1DModel()
jsonStream >> modelJson;
const size_t STRIDE = 3, KS = 5, OUT_CH = 12;

-// Use dynamic model. Call model.skip() for striding.
+// Use dynamic model.
RTNeural::StrideConv1D<T> model(1, OUT_CH, KS, 1, STRIDE, 1);
-// RTNeural::Conv1D<T> model(1, OUT_CH, KS, 1, 1);
RTNeural::torch_helpers::loadConv1D<T>(modelJson, "", model);
model.reset();

@@ -55,14 +54,88 @@ void testTorchConv1DModel()
}
}
}

template <typename T>
void testTorchConv1DModelComptime()
{
const auto model_file = std::string { RTNEURAL_ROOT_DIR } + "models/conv1d_torch_stride_3.json";
std::ifstream jsonStream(model_file, std::ifstream::binary);
nlohmann::json modelJson;
jsonStream >> modelJson;
static constexpr size_t STRIDE = 3, KS = 5, OUT_CH = 12;

RTNeural::StrideConv1DT<T, 1, OUT_CH, KS, 1, STRIDE> model;
RTNeural::torch_helpers::loadConv1D<T>(modelJson, "", model);
model.reset();

std::ifstream modelInputsFile { std::string { RTNEURAL_ROOT_DIR } + "test_data/conv1d_torch_x_python_stride_3.csv" };
const auto inputs = load_csv::loadFile<T>(modelInputsFile);
std::vector<std::array<T, OUT_CH>> outputs {};
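// The first KS - 1 samples only prime the convolution state via skip();
// after that, one output frame is produced for every STRIDE input samples.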
const size_t start_point = KS - 1;
outputs.resize((inputs.size() - start_point) / STRIDE, {});
//std::cout << "Out size " << outputs.size() << "\n";

#if RTNEURAL_USE_EIGEN
alignas(RTNEURAL_DEFAULT_ALIGNMENT) Eigen::Matrix<T, 1, 1> input_data {};
input_data.setZero();
#elif RTNEURAL_USE_XSIMD
alignas(RTNEURAL_DEFAULT_ALIGNMENT) xsimd::batch<T> input_data[RTNeural::ceil_div(1, (int) xsimd::batch<T>::size)] {};
#else
alignas(RTNEURAL_DEFAULT_ALIGNMENT) T input_data[1] {};
#endif

for(size_t i = 0; i < start_point; ++i)
{
input_data[0] = inputs[i];
model.skip(input_data);
}

for(size_t i = start_point; i < inputs.size(); ++i)
{
input_data[0] = inputs[i];
model.forward(input_data);

const auto out_idx = (i - start_point) / STRIDE;
#if RTNEURAL_USE_XSIMD
std::copy(reinterpret_cast<T*>(std::begin(model.outs)),
reinterpret_cast<T*>(std::end(model.outs)),
std::begin(outputs[out_idx]));
#else
std::copy(std::begin(model.outs),
std::end(model.outs),
std::begin(outputs[out_idx]));
#endif
}

std::ifstream modelOutputsFile { std::string { RTNEURAL_ROOT_DIR } + "test_data/conv1d_torch_y_python_stride_3.csv" };
const auto expected_y = RTNeural::torch_helpers::detail::transpose(load_csv::loadFile2d<T>(modelOutputsFile));

for(size_t n = 0; n < expected_y.size(); ++n)
{
for(size_t j = 0; j < outputs[n].size(); ++j)
{
expectNear(outputs[n][j], expected_y[n][j]);
}
}
}
}

-TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForFloats)
+TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForFloatsRuntime)
{
testTorchConv1DModel<float>();
}

-TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForDoubles)
+TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForFloatsComptime)
{
testTorchConv1DModelComptime<float>();
}

TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForDoublesRuntime)
{
testTorchConv1DModel<double>();
}

TEST(TestTorchConv1DStride, modelOutputMatchesPythonImplementationForDoublesComptime)
{
testTorchConv1DModelComptime<double>();
}
