Skip to content

Commit

Permalink
More tests for ConvTranspose1d
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Sep 27, 2024
1 parent 7207500 commit 0b6364f
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 49 deletions.
13 changes: 12 additions & 1 deletion RTNeural/conv1d/conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ class Conv1D final : public Layer<T>
/** Returns the name of this layer. */
std::string getName() const noexcept override { return "conv1d"; }


/** Performs a stride step for this layer. */
RTNEURAL_REALTIME inline void skip(const T* input)
{
Expand Down Expand Up @@ -211,6 +210,18 @@ class Conv1DT
/** Resets the layer state. */
RTNEURAL_REALTIME void reset();

/** Performs a stride step for this layer. */
RTNEURAL_REALTIME inline void skip(const T (&ins)[in_size])
{
// insert input into a circular buffer
std::copy(std::begin(ins), std::end(ins), state[state_ptr].begin());

// set state pointers to particular columns of the buffer
setStatePointers();

state_ptr = (state_ptr == state_size - 1 ? 0 : state_ptr + 1); // iterate state pointer forwards
}

template <int _groups = groups, std::enable_if_t<_groups == 1, bool> = true>
/** Performs forward propagation for this layer. */
RTNEURAL_REALTIME inline void forward(const T (&ins)[in_size]) noexcept
Expand Down
12 changes: 12 additions & 0 deletions RTNeural/conv1d/conv1d_eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ class Conv1DT
/** Resets the layer state. */
RTNEURAL_REALTIME void reset();

/** Performs a stride step for this layer. */
RTNEURAL_REALTIME inline void skip(const Eigen::Matrix<T, in_size, 1>& ins)
{
// insert input into a circular buffer
state.col(state_ptr) = ins;

// set state pointers to the particular columns of the buffer
setStatePointers();

state_ptr = (state_ptr == state_size - 1 ? 0 : state_ptr + 1); // iterate state pointer forwards
}

/** Performs forward propagation for this layer. */
template <int _groups = groups, std::enable_if_t<_groups == 1, bool> = true>
RTNEURAL_REALTIME inline void forward(const Eigen::Matrix<T, in_size, 1>& ins) noexcept
Expand Down
12 changes: 12 additions & 0 deletions RTNeural/conv1d/conv1d_xsimd.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ class Conv1DT
/** Resets the layer state. */
RTNEURAL_REALTIME void reset();

/** Performs a stride step for this layer. */
RTNEURAL_REALTIME inline void skip(const v_type (&ins)[v_in_size])
{
// insert input into a circular buffer
std::copy(std::begin(ins), std::end(ins), state[state_ptr].begin());

// set state pointers to particular columns of the buffer
setStatePointers();

state_ptr = (state_ptr == state_size - 1 ? 0 : state_ptr + 1); // iterate state pointer forwards
}

/** Performs forward propagation for this layer. */
template <int G = groups>
RTNEURAL_REALTIME inline typename std::enable_if<(G > 1), void>::type
Expand Down
Loading

0 comments on commit 0b6364f

Please sign in to comment.