Skip to content

Commit

Permalink
workingon python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Sep 5, 2024
1 parent 842b38a commit 4561dce
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ jobs:
chmod +x examples.sh
./examples.sh
run-tests-python:
name: Run Python tests
runs-on: ubuntu-latest
strategy:

check-dependencies:
name: Check dependencies
runs-on: ubuntu-latest
Expand Down
103 changes: 100 additions & 3 deletions python/ndelement/reference_cell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Types."""
from ._ndelementrs import lib as _lib
"""Reference cell information."""
import typing
import ctypes
import numpy as np
import numpy.typing as npt
from ._ndelementrs import lib as _lib, ffi as _ffi
from enum import Enum


Expand All @@ -16,6 +20,99 @@ class ReferenceCellType(Enum):


def dim(cell: ReferenceCellType) -> int:
"""The topological dimension of a reference cell."""
"""Get the topological dimension of a reference cell."""
return _lib.dim(cell.value)


def is_simplex(cell: ReferenceCellType) -> bool:
"""Check if a reference cell is a simplex."""
return _lib.is_simplex(cell.value)


def vertices(cell: ReferenceCellType, dtype: str = np.float64) -> npt.NDArray:
"""Get the vertices of a reference cell."""
vertices = np.empty((entity_counts(cell)[0], dim(cell)), dtype=dtype)
if dtype == np.float64:
_lib.vertices_f64(cell.value, _ffi.cast("double*", vertices.ctypes.data))
elif dtype == np.float32:
_lib.vertices_f32(cell.value, _ffi.cast("float*", vertices.ctypes.data))
else:
raise TypeError(f"Unsupported dtype: {dtype}")
return vertices


def midpoint(cell: ReferenceCellType, dtype: str = np.float64) -> npt.NDArray:
"""Get the midpoint of a reference cell."""
point = np.empty(dim(cell), dtype=dtype)
if dtype == np.float64:
_lib.midpoint_f64(cell.value, _ffi.cast("double*", point.ctypes.data))
elif dtype == np.float32:
_lib.midpoint_f32(cell.value, _ffi.cast("float*", point.ctypes.data))
else:
raise TypeError(f"Unsupported dtype: {dtype}")
return point


def edges(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]:
"""Get the edges of a reference cell."""
edges = []
e = np.empty(2 * entity_counts(cell)[1], dtype=int)
_lib.faces(cell.value, _ffi.cast("uintptr_t* ", e.ctypes.data))
for i in range(entity_counts(cell)[1]):
edges.append(e[2*i:2*i+2])
return edges


def faces(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]:
"""Get the faces of a reference cell."""
faces = []
flen = 0
for t in entity_types(cell)[2]:
flen += entity_counts(t)[0]
f = np.empty(flen, dtype=int)
_lib.faces(cell.value, _ffi.cast("uintptr_t* ", f.ctypes.data))
start = 0
for t in entity_types(cell)[2]:
n = entity_counts(t)[0]
faces.append(f[start:start+n])
start += n
return faces


def volumes(cell: ReferenceCellType) -> typing.List[npt.NDArray[int]]:
"""Get the volumes of a reference cell."""
volumes = []
vlen = 0
for t in entity_types(cell)[3]:
vlen += entity_counts(t)[0]
v = np.empty(vlen, dtype=int)
_lib.volumes(cell.value, _ffi.cast("uintptr_t* ", v.ctypes.data))
start = 0
for t in entity_types(cell)[3]:
n = entity_counts(t)[0]
volumes.append(v[start:start+n])
start += n
return volumes


def entity_types(cell: ReferenceCellType) -> typing.List[typing.List[ReferenceCellType]]:
"""Get the types of the sub-entities of a reference cell."""
# TODO: should int be uintptr_t?
t = np.empty(sum(entity_counts(cell)), dtype=int)
_lib.entity_types(cell.value, _ffi.cast("uintptr_t* ", t.ctypes.data))
types = []
start = 0
for n in entity_counts(cell):
types.append([ReferenceCellType(i) for i in t[start:start+n]])
start += n
return types


def entity_counts(cell: ReferenceCellType) -> npt.NDArray[int]:
"""Get the number of the sub-entities of each dimension for a reference cell."""
counts = np.empty(4, dtype=int)
_lib.entity_counts(cell.value, _ffi.cast("uintptr_t* ", counts.ctypes.data))
return counts


# TODO: connectivity
26 changes: 25 additions & 1 deletion python/test/test_reference_cell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ndelement.reference_cell import dim, ReferenceCellType
import pytest
import numpy as np
from ndelement.reference_cell import dim, ReferenceCellType, midpoint, vertices


def test_dim():
Expand All @@ -8,3 +10,25 @@ def test_dim():
assert dim(ReferenceCellType.Quadrilateral) == 2
assert dim(ReferenceCellType.Tetrahedron) == 3
assert dim(ReferenceCellType.Hexahedron) == 3


def test_midpoint():
assert np.allclose(midpoint(ReferenceCellType.Interval), [0.5])
assert np.allclose(midpoint(ReferenceCellType.Triangle), [1 / 3, 1 / 3])
assert np.allclose(midpoint(ReferenceCellType.Quadrilateral), [1 / 2, 1 / 2])
assert np.allclose(midpoint(ReferenceCellType.Tetrahedron), [1 / 4, 1 / 4, 1 / 4])
assert np.allclose(midpoint(ReferenceCellType.Hexahedron), [1 / 2, 1 / 2, 1 / 2])


@pytest.mark.parametrize("cell", [
ReferenceCellType.Interval,
ReferenceCellType.Triangle,
ReferenceCellType.Quadrilateral,
ReferenceCellType.Tetrahedron,
ReferenceCellType.Hexahedron,
])
def test_vertices_and_midpoint(cell):
v = vertices(cell)
m = midpoint(cell)

assert np.allclose(sum(i for i in v) / v.shape[0], m)
92 changes: 90 additions & 2 deletions src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,101 @@

pub mod reference_cell {
use crate::types::ReferenceCellType;
use crate::reference_cell;
use rlst::RlstScalar;

#[no_mangle]
pub unsafe extern "C" fn dim(cell: u8) -> usize {
crate::reference_cell::dim(ReferenceCellType::from(cell).expect("Invalid cell type"))
reference_cell::dim(ReferenceCellType::from(cell).expect("Invalid cell type"))
}
#[no_mangle]
pub unsafe extern "C" fn is_simplex(cell: u8) -> bool {

Check warning on line 14 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 14 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
crate::reference_cell::is_simplex(ReferenceCellType::from(cell).expect("Invalid cell type"))
reference_cell::is_simplex(ReferenceCellType::from(cell).expect("Invalid cell type"))
}
unsafe fn vertices<T: RlstScalar<Real=T>>(cell: u8, vs: *mut T) {
let mut i = 0;
for v in reference_cell::vertices::<T>(ReferenceCellType::from(cell).expect("Invalid cell type")) {
for c in v {
*vs.add(i) = c;
i += 1;
}
}
}
#[no_mangle]
pub unsafe extern "C" fn vertices_f32(cell: u8, vs: *mut f32) {
vertices(cell, vs);
}
#[no_mangle]
pub unsafe extern "C" fn vertices_f64(cell: u8, vs: *mut f64) {

Check warning on line 31 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 31 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
vertices(cell, vs);
}
unsafe fn midpoint<T: RlstScalar<Real=T>>(cell: u8, pt: *mut T) {
let pt = pt as *mut T;
for (i, c) in reference_cell::midpoint(ReferenceCellType::from(cell).expect("Invalid cell type")).iter().enumerate() {
*pt.add(i) = *c;
}
}
#[no_mangle]
pub unsafe extern "C" fn midpoint_f32(cell: u8, pt: *mut f32) {
midpoint(cell, pt);
}
#[no_mangle]
pub unsafe extern "C" fn midpoint_f64(cell: u8, pt: *mut f64) {
midpoint(cell, pt);
}
#[no_mangle]
pub unsafe extern "C" fn edges(cell: u8, es: *mut usize) {
let mut i = 0;
for e in reference_cell::edges(ReferenceCellType::from(cell).expect("Invalid cell type")) {
for v in e {
*es.add(i) = v;
i += 1
}
}
}
#[no_mangle]
pub unsafe extern "C" fn faces(cell: u8, es: *mut usize) {
let mut i = 0;
for e in reference_cell::faces(ReferenceCellType::from(cell).expect("Invalid cell type")) {
for v in e {
*es.add(i) = v;
i += 1
}
}
}
#[no_mangle]

Check warning on line 68 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 68 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
pub unsafe extern "C" fn volumes(cell: u8, es: *mut usize) {
let mut i = 0;
for e in reference_cell::volumes(ReferenceCellType::from(cell).expect("Invalid cell type")) {
for v in e {
*es.add(i) = v;
i += 1
}
}
}
#[no_mangle]

Check warning on line 78 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 78 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
pub unsafe extern "C" fn entity_types(cell: u8, et: *mut u8) {
let mut i = 0;
for es in reference_cell::entity_types(ReferenceCellType::from(cell).expect("Invalid cell type")) {
for e in es {
*et.add(i) = e as u8;
i += 1
}
}
}

Check warning on line 87 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 87 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
#[no_mangle]
pub unsafe extern "C" fn entity_counts(cell: u8, ec: *mut usize) {
for (i, e) in reference_cell::entity_counts(ReferenceCellType::from(cell).expect("Invalid cell type")).iter().enumerate() {
*ec.add(i) = *e;
}
}
pub unsafe extern "C" fn connetivity_size(cell: u8, dim0: usize, index0: usize, dim1: usize) -> usize {

Check warning on line 94 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "mpi,serde,strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs

Check warning on line 94 in src/bindings.rs

View workflow job for this annotation

GitHub Actions / Rust style checks (--features "strict")

Diff in /home/runner/work/ndelement/ndelement/src/bindings.rs
reference_cell::connectivity(ReferenceCellType::from(cell).expect("Invalid cell type"))[dim0][index0][dim1].len()
}
#[no_mangle]
pub unsafe extern "C" fn connetivity(cell: u8, dim0: usize, index0: usize, dim1: usize, c: *mut usize) {
for (i, j) in reference_cell::connectivity(ReferenceCellType::from(cell).expect("Invalid cell type"))[dim0][index0][dim1].iter().enumerate() {
*c.add(i) = *j;
}
}
}
2 changes: 1 addition & 1 deletion src/reference_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub fn midpoint<T: RlstScalar<Real = T>>(cell: ReferenceCellType) -> Vec<T> {
ReferenceCellType::Interval => vec![half],
ReferenceCellType::Triangle => vec![third; 2],
ReferenceCellType::Quadrilateral => vec![half; 2],
ReferenceCellType::Tetrahedron => vec![T::from(1.0).unwrap() / T::from(6.0).unwrap(); 3],
ReferenceCellType::Tetrahedron => vec![T::from(1.0).unwrap() / T::from(4.0).unwrap(); 3],
ReferenceCellType::Hexahedron => vec![half; 3],
ReferenceCellType::Prism => vec![third, third, half],
ReferenceCellType::Pyramid => vec![
Expand Down

0 comments on commit 4561dce

Please sign in to comment.