diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 7b49525..ab7ad58 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -49,6 +49,11 @@ jobs: chmod +x examples.sh ./examples.sh + run-tests-python: + name: Run Python tests + runs-on: ubuntu-latest + strategy: + check-dependencies: name: Check dependencies runs-on: ubuntu-latest diff --git a/python/ndelement/reference_cell.py b/python/ndelement/reference_cell.py index 7514db4..576e7bf 100644 --- a/python/ndelement/reference_cell.py +++ b/python/ndelement/reference_cell.py @@ -1,5 +1,9 @@ -"""Types.""" -from ._ndelementrs import lib as _lib +"""Reference cell information.""" +import typing +import ctypes +import numpy as np +import numpy.typing as npt +from ._ndelementrs import lib as _lib, ffi as _ffi from enum import Enum @@ -16,6 +20,99 @@ class ReferenceCellType(Enum): def dim(cell: ReferenceCellType) -> int: - """The topological dimension of a reference cell.""" + """Get the topological dimension of a reference cell.""" return _lib.dim(cell.value) + +def is_simplex(cell: ReferenceCellType) -> bool: + """Check if a reference cell is a simplex.""" + return _lib.is_simplex(cell.value) + + +def vertices(cell: ReferenceCellType, dtype: str = np.float64) -> npt.NDArray: + """Get the vertices of a reference cell.""" + vertices = np.empty((entity_counts(cell)[0], dim(cell)), dtype=dtype) + if dtype == np.float64: + _lib.vertices_f64(cell.value, _ffi.cast("double*", vertices.ctypes.data)) + elif dtype == np.float32: + _lib.vertices_f32(cell.value, _ffi.cast("float*", vertices.ctypes.data)) + else: + raise TypeError(f"Unsupported dtype: {dtype}") + return vertices + + +def midpoint(cell: ReferenceCellType, dtype: str = np.float64) -> npt.NDArray: + """Get the midpoint of a reference cell.""" + point = np.empty(dim(cell), dtype=dtype) + if dtype == np.float64: + _lib.midpoint_f64(cell.value, _ffi.cast("double*", point.ctypes.data)) + elif dtype == np.float32: + _lib.midpoint_f32(cell.value, _ffi.cast("float*", point.ctypes.data)) + else: + raise TypeError(f"Unsupported dtype: {dtype}") + return point + + +def edges(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]: + """Get the edges of a reference cell.""" + edges = [] + e = np.empty(2 * entity_counts(cell)[1], dtype=int) + _lib.faces(cell.value, _ffi.cast("uintptr_t* ", e.ctypes.data)) + for i in range(entity_counts(cell)[1]): + edges.append(e[2*i:2*i+2]) + return edges + + +def faces(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]: + """Get the faces of a reference cell.""" + faces = [] + flen = 0 + for t in entity_types(cell)[2]: + flen += entity_counts(t)[0] + f = np.empty(flen, dtype=int) + _lib.faces(cell.value, _ffi.cast("uintptr_t* ", f.ctypes.data)) + start = 0 + for t in entity_types(cell)[2]: + n = entity_counts(t)[0] + faces.append(f[start:start+n]) + start += n + return faces + + +def volumes(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]: + """Get the volumes of a reference cell.""" + volumes = [] + vlen = 0 + for t in entity_types(cell)[3]: + vlen += entity_counts(t)[0] + v = np.empty(vlen, dtype=int) + _lib.volumes(cell.value, _ffi.cast("uintptr_t* ", v.ctypes.data)) + start = 0 + for t in entity_types(cell)[3]: + n = entity_counts(t)[0] + volumes.append(v[start:start+n]) + start += n + return volumes + + +def entity_types(cell: ReferenceCellType) -> typing.List[typing.List[ReferenceCellType]]: + """Get the types of the sub-entities of a reference cell.""" + # TODO: should int be uintptr_t? + t = np.empty(sum(entity_counts(cell)), dtype=int) + _lib.entity_types(cell.value, _ffi.cast("uintptr_t* ", t.ctypes.data)) + types = [] + start = 0 + for n in entity_counts(cell): + types.append([ReferenceCellType(i) for i in t[start:start+n]]) + start += n + return types + + +def entity_counts(cell: ReferenceCellType) -> npt.NDArray[int]: + """Get the number of the sub-entities of each dimension for a reference cell.""" + counts = np.empty(4, dtype=int) + _lib.entity_counts(cell.value, _ffi.cast("uintptr_t* ", counts.ctypes.data)) + return counts + + +# TODO: connectivity diff --git a/python/test/test_reference_cell.py b/python/test/test_reference_cell.py index a90e83e..66bc03d 100644 --- a/python/test/test_reference_cell.py +++ b/python/test/test_reference_cell.py @@ -1,4 +1,6 @@ -from ndelement.reference_cell import dim, ReferenceCellType +import pytest +import numpy as np +from ndelement.reference_cell import dim, ReferenceCellType, midpoint, vertices def test_dim(): @@ -8,3 +10,25 @@ def test_dim(): assert dim(ReferenceCellType.Quadrilateral) == 2 assert dim(ReferenceCellType.Tetrahedron) == 3 assert dim(ReferenceCellType.Hexahedron) == 3 + + +def test_midpoint(): + assert np.allclose(midpoint(ReferenceCellType.Interval), [0.5]) + assert np.allclose(midpoint(ReferenceCellType.Triangle), [1 / 3, 1 / 3]) + assert np.allclose(midpoint(ReferenceCellType.Quadrilateral), [1 / 2, 1 / 2]) + assert np.allclose(midpoint(ReferenceCellType.Tetrahedron), [1 / 4, 1 / 4, 1 / 4]) + assert np.allclose(midpoint(ReferenceCellType.Hexahedron), [1 / 2, 1 / 2, 1 / 2]) + + +@pytest.mark.parametrize("cell", [ + ReferenceCellType.Interval, + ReferenceCellType.Triangle, + ReferenceCellType.Quadrilateral, + ReferenceCellType.Tetrahedron, + ReferenceCellType.Hexahedron, +]) +def test_vertices_and_midpoint(cell): + v = vertices(cell) + m = midpoint(cell) + + assert np.allclose(sum(i for i in v) / v.shape[0], m) diff --git a/src/bindings.rs b/src/bindings.rs index 9ca9db8..2805221 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -3,13 +3,101 @@ pub mod reference_cell { use crate::types::ReferenceCellType; + use crate::reference_cell; + use rlst::RlstScalar; #[no_mangle] pub unsafe extern "C" fn dim(cell: u8) -> usize { - crate::reference_cell::dim(ReferenceCellType::from(cell).expect("Invalid cell type")) + reference_cell::dim(ReferenceCellType::from(cell).expect("Invalid cell type")) } #[no_mangle] pub unsafe extern "C" fn is_simplex(cell: u8) -> bool { - crate::reference_cell::is_simplex(ReferenceCellType::from(cell).expect("Invalid cell type")) + reference_cell::is_simplex(ReferenceCellType::from(cell).expect("Invalid cell type")) + } + unsafe fn vertices>(cell: u8, vs: *mut T) { + let mut i = 0; + for v in reference_cell::vertices::(ReferenceCellType::from(cell).expect("Invalid cell type")) { + for c in v { + *vs.add(i) = c; + i += 1; + } + } + } + #[no_mangle] + pub unsafe extern "C" fn vertices_f32(cell: u8, vs: *mut f32) { + vertices(cell, vs); + } + #[no_mangle] + pub unsafe extern "C" fn vertices_f64(cell: u8, vs: *mut f64) { + vertices(cell, vs); + } + unsafe fn midpoint>(cell: u8, pt: *mut T) { + let pt = pt as *mut T; + for (i, c) in reference_cell::midpoint(ReferenceCellType::from(cell).expect("Invalid cell type")).iter().enumerate() { + *pt.add(i) = *c; + } + } + #[no_mangle] + pub unsafe extern "C" fn midpoint_f32(cell: u8, pt: *mut f32) { + midpoint(cell, pt); + } + #[no_mangle] + pub unsafe extern "C" fn midpoint_f64(cell: u8, pt: *mut f64) { + midpoint(cell, pt); + } + #[no_mangle] + pub unsafe extern "C" fn edges(cell: u8, es: *mut usize) { + let mut i = 0; + for e in reference_cell::edges(ReferenceCellType::from(cell).expect("Invalid cell type")) { + for v in e { + *es.add(i) = v; + i += 1 + } + } + } + #[no_mangle] + pub unsafe extern "C" fn faces(cell: u8, es: *mut usize) { + let mut i = 0; + for e in reference_cell::faces(ReferenceCellType::from(cell).expect("Invalid cell type")) { + for v in e { + *es.add(i) = v; + i += 1 + } + } + } + #[no_mangle] + pub unsafe extern "C" fn volumes(cell: u8, es: *mut usize) { + let mut i = 0; + for e in reference_cell::volumes(ReferenceCellType::from(cell).expect("Invalid cell type")) { + for v in e { + *es.add(i) = v; + i += 1 + } + } + } + #[no_mangle] + pub unsafe extern "C" fn entity_types(cell: u8, et: *mut u8) { + let mut i = 0; + for es in reference_cell::entity_types(ReferenceCellType::from(cell).expect("Invalid cell type")) { + for e in es { + *et.add(i) = e as u8; + i += 1 + } + } + } + #[no_mangle] + pub unsafe extern "C" fn entity_counts(cell: u8, ec: *mut usize) { + for (i, e) in reference_cell::entity_counts(ReferenceCellType::from(cell).expect("Invalid cell type")).iter().enumerate() { + *ec.add(i) = *e; + } + } + pub unsafe extern "C" fn connetivity_size(cell: u8, dim0: usize, index0: usize, dim1: usize) -> usize { + reference_cell::connectivity(ReferenceCellType::from(cell).expect("Invalid cell type"))[dim0][index0][dim1].len() + } + #[no_mangle] + pub unsafe extern "C" fn connetivity(cell: u8, dim0: usize, index0: usize, dim1: usize, c: *mut usize) { + for (i, j) in reference_cell::connectivity(ReferenceCellType::from(cell).expect("Invalid cell type"))[dim0][index0][dim1].iter().enumerate() { + *c.add(i) = *j; + } } } diff --git a/src/reference_cell.rs b/src/reference_cell.rs index 9e15d2a..5884ece 100644 --- a/src/reference_cell.rs +++ b/src/reference_cell.rs @@ -87,7 +87,7 @@ pub fn midpoint>(cell: ReferenceCellType) -> Vec { ReferenceCellType::Interval => vec![half], ReferenceCellType::Triangle => vec![third; 2], ReferenceCellType::Quadrilateral => vec![half; 2], - ReferenceCellType::Tetrahedron => vec![T::from(1.0).unwrap() / T::from(6.0).unwrap(); 3], + ReferenceCellType::Tetrahedron => vec![T::from(1.0).unwrap() / T::from(4.0).unwrap(); 3], ReferenceCellType::Hexahedron => vec![half; 3], ReferenceCellType::Prism => vec![third, third, half], ReferenceCellType::Pyramid => vec![