Skip to content

Commit

Permalink
New bindings (#60)
Browse files Browse the repository at this point in the history
* WIP: Bindings

* WIP: Bindings

* run maturin with nightly

* install nightly

* copy change to weekly workflow

* Update bindings

* remove defunct tests

* equivalence

* nightly in build

* cargo install?

* run nightly?

* maturin with pip?

* no need for rustfmt in most workflows

* python -m maturin

* remove typo

---------

Co-authored-by: Timo Betcke <[email protected]>
  • Loading branch information
mscroggs and tbetcke authored Nov 25, 2024
1 parent ba256a6 commit c39459e
Show file tree
Hide file tree
Showing 13 changed files with 488 additions and 922 deletions.
1 change: 0 additions & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: "nightly"
components: rustfmt
- name: Set up MPI
uses: mpi4py/setup-mpi@v1
with:
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ jobs:
- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: ${{ matrix.rust-version }}
components: rustfmt

toolchain: nightly
- name: Install OpenBLAS, LAPACK, OpenSSL (yum)
run: yum -y install openblas-devel lapack-devel openssl-devel
if: matrix.platform.package-manager == 'yum'
Expand All @@ -42,7 +40,7 @@ jobs:

- name: Install maturin, CFFI
run: |
${{ matrix.platform.python }} -m pip install maturin>=1.7.2
${{ matrix.platform.python }} -m pip install "maturin>=1.7"
${{ matrix.platform.python }} -m pip install cffi
- name: Build wheel
Expand Down Expand Up @@ -70,6 +68,7 @@ jobs:
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
rust-toolchain: nightly
target: ${{ matrix.platform.target }}
args: --release --out dist --find-interpreter
sccache: 'true'
Expand All @@ -87,6 +86,7 @@ jobs:
- name: Build sdist
uses: PyO3/maturin-action@v1
with:
rust-toolchain: nightly
command: sdist
args: --out dist
- name: Upload sdist
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ jobs:
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: ${{ matrix.rust-version }}
components: rustfmt
- name: Set up MPI
uses: mpi4py/setup-mpi@v1
with:
Expand Down Expand Up @@ -56,6 +55,10 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: nightly
- name: Set up Python
uses: actions/setup-python@v4
with:
Expand All @@ -70,7 +73,7 @@ jobs:
- name: Install python package
run: |
source .venv/bin/activate
maturin develop --release
rustup run nightly maturin develop --release
- name: Run Python tests
run: |
source .venv/bin/activate
Expand All @@ -87,7 +90,6 @@ jobs:
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: ${{ matrix.rust-version }}
components: rustfmt
- name: Install cargo-upgrades
run: cargo install cargo-upgrades
- uses: actions/checkout@v4
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/run-weekly-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ jobs:
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: ${{ matrix.rust-version }}
components: rustfmt
- name: Set up MPI
uses: mpi4py/setup-mpi@v1
with:
Expand Down Expand Up @@ -65,6 +64,10 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: nightly
- name: Set up Python
uses: actions/setup-python@v4
with:
Expand All @@ -79,7 +82,7 @@ jobs:
- name: Install python package
run: |
source .venv/bin/activate
maturin develop --release
rustup run nightly maturin develop --release
- name: Run Python tests
run: |
source .venv/bin/activate
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ rlst = { version = "0.2.0", default-features = false }
serde = { version = "1", features = ["derive"], optional = true }
strum = "0.26"
strum_macros = "0.26"
c-api-tools = { git = "https://github.com/bempp/c-api-tools.git" }

[dev-dependencies]
paste = "1.*"
Expand Down
33 changes: 0 additions & 33 deletions build.rs

This file was deleted.

7 changes: 7 additions & 0 deletions cbindgen.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ exclude = []

[enum]
prefix_with_name = true

[parse]
parse_deps = true
include = ["c-api-tools", "ndelement"]

[parse.expand]
crates = ["ndelement"]
111 changes: 47 additions & 64 deletions python/ndelement/ciarlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
class Continuity(Enum):
"""Continuity."""

Standard = 0
Discontinuous = 1
Standard = _lib.Continuity_Standard
Discontinuous = _lib.Continuity_Discontinuous


class Family(Enum):
Expand All @@ -27,21 +27,21 @@ class Family(Enum):
class MapType(Enum):
"""Map type."""

Identity = 0
CovariantPiola = 1
ContravariantPiola = 2
L2Piola = 3
Identity = _lib.MapType_Identity
CovariantPiola = _lib.MapType_CovariantPiola
ContravariantPiola = _lib.MapType_ContravariantPiola
L2Piola = _lib.MapType_L2Piola


_dtypes = {
0: np.float32,
1: np.float64,
}
_ctypes = {
np.float32: "float",
np.float64: "double",
_rtypes = {
np.float32: _lib.DType_F32,
np.float64: _lib.DType_F64,
np.complex64: _lib.DType_C32,
np.complex128: _lib.DType_C64,
}

_dtypes = {j: i for i, j in _rtypes.items()}


class CiarletElement(object):
"""Ciarlet element."""
Expand All @@ -54,78 +54,77 @@ def __init__(self, rs_element: _CDataBase, owned: bool = True):
def __del__(self):
"""Delete object."""
if self._owned:
_lib.ciarlet_free_element(self._rs_element)
_lib.ciarlet_element_t_free(self._rs_element)

@property
def dtype(self):
"""Data type."""
return _dtypes[_lib.ciarlet_element_dtype(self._rs_element)]

@property
def _ctype(self):
"""C data type."""
return _ctypes[self.dtype]

@property
def value_size(self) -> int:
"""Value size of the element."""
return _lib.ciarlet_value_size(self._rs_element)
return _lib.element_value_size(self._rs_element)

@property
def value_shape(self) -> typing.Tuple[int, ...]:
"""Value size of the element."""
shape = np.empty(_lib.ciarlet_value_rank(self._rs_element), dtype=np.uintp)
_lib.ciarlet_value_shape(self._rs_element, _ffi.cast("uintptr_t*", shape.ctypes.data))
shape = np.empty(_lib.ciarlet_element_value_rank(self._rs_element), dtype=np.uintp)
_lib.ciarlet_element_value_shape(
self._rs_element, _ffi.cast("uintptr_t*", shape.ctypes.data)
)
return tuple(int(i) for i in shape)

@property
def degree(self) -> int:
"""Degree of the element."""
return _lib.ciarlet_degree(self._rs_element)
return _lib.ciarlet_element_degree(self._rs_element)

@property
def embedded_superdegree(self) -> int:
"""Embedded superdegree of the element."""
return _lib.ciarlet_embedded_superdegree(self._rs_element)
return _lib.ciarlet_element_embedded_superdegree(self._rs_element)

@property
def dim(self) -> int:
"""Dimension (number of basis functions) of the element."""
return _lib.ciarlet_dim(self._rs_element)
return _lib.ciarlet_element_dim(self._rs_element)

@property
def continuity(self) -> Continuity:
"""Continuity of the element."""
return Continuity(_lib.ciarlet_continuity(self._rs_element))
return Continuity(_lib.ciarlet_element_continuity(self._rs_element))

@property
def map_type(self) -> MapType:
"""Pullback map type of the element."""
return MapType(_lib.ciarlet_map_type(self._rs_element))
return MapType(_lib.ciarlet_element_map_type(self._rs_element))

@property
def cell_type(self) -> ReferenceCellType:
"""Cell type of the element."""
return ReferenceCellType(_lib.ciarlet_cell_type(self._rs_element))
return ReferenceCellType(_lib.ciarlet_element_cell_type(self._rs_element))

def entity_dofs(self, entity_dim: int, entity_index: int) -> typing.List[int]:
"""Get the DOFs associated with an entity."""
dofs = np.empty(
_lib.ciarlet_entity_dofs_size(self._rs_element, entity_dim, entity_index),
_lib.ciarlet_element_entity_dofs_size(self._rs_element, entity_dim, entity_index),
dtype=np.uintp,
)
_lib.ciarlet_entity_dofs(
_lib.ciarlet_element_entity_dofs(
self._rs_element, entity_dim, entity_index, _ffi.cast("uintptr_t*", dofs.ctypes.data)
)
return [int(i) for i in dofs]

def entity_closure_dofs(self, entity_dim: int, entity_index: int) -> typing.List[int]:
"""Get the DOFs associated with the closure of an entity."""
dofs = np.empty(
_lib.ciarlet_entity_closure_dofs_size(self._rs_element, entity_dim, entity_index),
_lib.ciarlet_element_entity_closure_dofs_size(
self._rs_element, entity_dim, entity_index
),
dtype=np.uintp,
)
_lib.ciarlet_entity_closure_dofs(
_lib.ciarlet_element_entity_closure_dofs(
self._rs_element, entity_dim, entity_index, _ffi.cast("uintptr_t*", dofs.ctypes.data)
)
return [int(i) for i in dofs]
Expand All @@ -137,9 +136,9 @@ def interpolation_points(self) -> typing.List[typing.List[npt.NDArray]]:
for d, n in enumerate(entity_counts(self.cell_type)):
points_d = []
for i in range(n):
shape = (_lib.ciarlet_interpolation_npoints(self._rs_element, d, i), tdim)
points_di = np.empty(shape, dtype=self.dtype)
_lib.ciarlet_interpolation_points(
shape = (_lib.ciarlet_element_interpolation_npoints(self._rs_element, d, i), tdim)
points_di = np.empty(shape, dtype=self.dtype(0).real.dtype)
_lib.ciarlet_element_interpolation_points(
self._rs_element, d, i, _ffi.cast("void*", points_di.ctypes.data)
)
points_d.append(points_di)
Expand All @@ -153,12 +152,12 @@ def interpolation_weights(self) -> typing.List[typing.List[npt.NDArray]]:
weights_d = []
for i in range(n):
shape = (
_lib.ciarlet_interpolation_ndofs(self._rs_element, d, i),
_lib.ciarlet_element_interpolation_ndofs(self._rs_element, d, i),
self.value_size,
_lib.ciarlet_interpolation_npoints(self._rs_element, d, i),
_lib.ciarlet_element_interpolation_npoints(self._rs_element, d, i),
)
weights_di = np.empty(shape, dtype=self.dtype)
_lib.ciarlet_interpolation_weights(
_lib.ciarlet_element_interpolation_weights(
self._rs_element, d, i, _ffi.cast("void*", weights_di.ctypes.data)
)
weights_d.append(weights_di)
Expand All @@ -167,12 +166,14 @@ def interpolation_weights(self) -> typing.List[typing.List[npt.NDArray]]:

def tabulate(self, points: npt.NDArray[np.floating], nderivs: int) -> npt.NDArray:
"""Tabulate the basis functions at a set of points."""
if points.dtype != self.dtype(0).real.dtype:
raise TypeError("points has incorrect type")
shape = np.empty(4, dtype=np.uintp)
_lib.ciarlet_tabulate_array_shape(
_lib.ciarlet_element_tabulate_array_shape(
self._rs_element, nderivs, points.shape[0], _ffi.cast("uintptr_t*", shape.ctypes.data)
)
data = np.empty(shape[::-1], dtype=self.dtype)
_lib.ciarlet_tabulate(
_lib.ciarlet_element_tabulate(
self._rs_element,
_ffi.cast("void*", points.ctypes.data),
points.shape[0],
Expand All @@ -193,10 +194,10 @@ def __init__(self, rs_family: _CDataBase, owned: bool = True):
def __del__(self):
"""Delete object."""
if self._owned:
_lib.ciarlet_free_family(self._rs_family)
_lib.element_family_t_free(self._rs_family)

def element(self, cell: ReferenceCellType) -> CiarletElement:
return CiarletElement(_lib.element_family_element(self._rs_family, cell.value))
return CiarletElement(_lib.element_family_create_element(self._rs_family, cell.value))


def create_family(
Expand All @@ -206,30 +207,12 @@ def create_family(
dtype: typing.Type[np.floating] = np.float64,
) -> ElementFamily:
"""Create a new element family."""
rust_type = _rtypes[dtype]
if family == Family.Lagrange:
if dtype == np.float64:
return ElementFamily(_lib.lagrange_element_family_new_f64(degree, continuity.value))
elif dtype == np.float32:
return ElementFamily(_lib.lagrange_element_family_new_f64(degree, continuity.value))
else:
raise TypeError(f"Unsupported dtype: {dtype}")
return ElementFamily(_lib.create_lagrange_family(degree, continuity.value, rust_type))
elif family == Family.RaviartThomas:
if dtype == np.float64:
return ElementFamily(
_lib.raviart_thomas_element_family_new_f64(degree, continuity.value)
)
elif dtype == np.float32:
return ElementFamily(
_lib.raviart_thomas_element_family_new_f64(degree, continuity.value)
)
else:
raise TypeError(f"Unsupported dtype: {dtype}")
return ElementFamily(_lib.create_raviart_thomas_family(degree, continuity.value, rust_type))
elif family == Family.NedelecFirstKind:
if dtype == np.float64:
return ElementFamily(_lib.nedelec_element_family_new_f64(degree, continuity.value))
elif dtype == np.float32:
return ElementFamily(_lib.nedelec_element_family_new_f64(degree, continuity.value))
else:
raise TypeError(f"Unsupported dtype: {dtype}")
return ElementFamily(_lib.create_nedelec_family(degree, continuity.value, rust_type))
else:
raise ValueError(f"Unsupported family: {family}")
Loading

0 comments on commit c39459e

Please sign in to comment.