diff --git a/libs/libvtrutil/src/vtr_ndmatrix.h b/libs/libvtrutil/src/vtr_ndmatrix.h index 57571cc865..fb6a4ad39e 100644 --- a/libs/libvtrutil/src/vtr_ndmatrix.h +++ b/libs/libvtrutil/src/vtr_ndmatrix.h @@ -34,9 +34,10 @@ class NdMatrixProxy { * @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension) * @param start: Pointer to the start of the sub-matrix this proxy represents */ - NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_strides, T* start) + NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_strides, size_t offset, const std::unique_ptr& start) : dim_sizes_(dim_sizes) , dim_strides_(dim_strides) + , offset_(offset) , start_(start) {} NdMatrixProxy& operator=(const NdMatrixProxy& other) = delete; @@ -50,7 +51,8 @@ class NdMatrixProxy { return NdMatrixProxy( dim_sizes_ + 1, // Pass the dimension information dim_strides_ + 1, // Pass the stride for the next dimension - start_ + dim_strides_[0] * index); // Advance to index in this dimension + offset_ + dim_strides_[0] * index, // Advance to index in this dimension + start_); // Pass the base pointer. } ///@brief [] operator @@ -62,7 +64,8 @@ class NdMatrixProxy { private: const size_t* dim_sizes_; const size_t* dim_strides_; - T* start_; + size_t offset_; + const std::unique_ptr& start_; }; ///@brief Base case: 1-dimensional array @@ -76,9 +79,10 @@ class NdMatrixProxy { * @param dim_stride: The stride of this dimension (i.e. how many element in memory between indicies of this dimension) * @param start: Pointer to the start of the sub-matrix this proxy represents */ - NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_stride, T* start) + NdMatrixProxy(const size_t* dim_sizes, const size_t* dim_stride, size_t offset, const std::unique_ptr& start) : dim_sizes_(dim_sizes) , dim_strides_(dim_stride) + , offset_(offset) , start_(start) {} NdMatrixProxy& operator=(const NdMatrixProxy& other) = delete; @@ -89,7 +93,7 @@ class NdMatrixProxy { VTR_ASSERT_SAFE_MSG(index < dim_sizes_[0], "Index out of range (above dimension maximum)"); //Base case - return start_[index]; + return start_[offset_ + index]; } ///@brief [] operator @@ -108,7 +112,7 @@ class NdMatrixProxy { * not to clobber elements in other dimensions */ const T* data() const { - return start_; + return start_.get() + offset_; } ///@brief same as above but allow update the value @@ -120,7 +124,8 @@ class NdMatrixProxy { private: const size_t* dim_sizes_; const size_t* dim_strides_; - T* start_; + size_t offset_; + const std::unique_ptr& start_; }; /** @@ -359,7 +364,8 @@ class NdMatrix : public NdMatrixBase { return NdMatrixProxy( this->dim_sizes_.data() + 1, //Pass the dimension information this->dim_strides_.data() + 1, //Pass the stride for the next dimension - this->data_.get() + this->dim_strides_[0] * index); //Advance to index in this dimension + this->dim_strides_[0] * index, //Advance to index in this dimension + this->data_); //Pass the base pointer } /**