diff --git a/src/internal/strided_impl.h b/src/internal/strided_impl.h index ea017f0..3dc4926 100644 --- a/src/internal/strided_impl.h +++ b/src/internal/strided_impl.h @@ -140,6 +140,8 @@ OSHMPI_STATIC_INLINE_PREFIX void OSHMPI_create_strided_dtype(size_t nelems, ptrd /* Slow path: create a new datatype and cache it */ MPI_Datatype vtype = MPI_DATATYPE_NULL; size_t elem_bytes = 0; + MPI_Aint lb, extent; + int typesize; OSHMPI_CALLMPI(MPI_Type_vector((int) nelems, 1, (int) stride, mpi_type, &vtype)); @@ -148,10 +150,19 @@ OSHMPI_STATIC_INLINE_PREFIX void OSHMPI_create_strided_dtype(size_t nelems, ptrd * Extent can be negative in MPI, however, we do not expect such case in OSHMPI. * Thus skip any negative one */ if (required_ext_nelems > 0) { - if (mpi_type == OSHMPI_MPI_COLL32_T) + if (mpi_type == OSHMPI_MPI_COLL_BYTE_T) + elem_bytes = 1; + else if (mpi_type == OSHMPI_MPI_COLL32_T) elem_bytes = 4; - else + else if (mpi_type == OSHMPI_MPI_COLL64_T) elem_bytes = 8; + else { + OSHMPI_CALLMPI(MPI_Type_get_extent(mpi_type, &lb, &extent)); + OSHMPI_ASSERT(lb == 0); + OSHMPI_CALLMPI(MPI_Type_size(mpi_type, &typesize)); + OSHMPI_ASSERT(extent == typesize); + elem_bytes = (size_t) extent; + } OSHMPI_CALLMPI(MPI_Type_create_resized (vtype, 0, required_ext_nelems * elem_bytes, strided_type)); } else