Skip to content

Commit

Permalink
add dtype and continuity to ElementFamily (#65)
Browse files Browse the repository at this point in the history
* add dtype and continuity to ElementFamily

* type
  • Loading branch information
mscroggs authored Nov 28, 2024
1 parent 4526a66 commit 560eaa3
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
37 changes: 32 additions & 5 deletions python/ndelement/ciarlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __del__(self):
_lib.ciarlet_element_t_free(self._rs_element)

@property
def dtype(self):
def dtype(self) -> typing.Type[np.floating]:
"""Data type."""
return _dtypes[_lib.ciarlet_element_dtype(self._rs_element)]

Expand Down Expand Up @@ -186,18 +186,31 @@ def tabulate(self, points: npt.NDArray[np.floating], nderivs: int) -> npt.NDArra
class ElementFamily(object):
"""Ciarlet element."""

def __init__(self, family: Family, degree: int, rs_family: _CDataBase, owned: bool = True):
def __init__(
self,
family: Family,
degree: int,
continuity: Continuity,
rs_family: _CDataBase,
owned: bool = True,
):
"""Initialise."""
self._rs_family = rs_family
self._owned = owned
self._family = family
self._degree = degree
self._continuity = continuity

def __del__(self):
"""Delete object."""
if self._owned:
_lib.element_family_t_free(self._rs_family)

@property
def dtype(self) -> typing.Type[np.floating]:
"""Data type."""
return _dtypes[_lib.element_family_dtype(self._rs_family)]

@property
def family(self) -> Family:
"""The family."""
Expand All @@ -208,6 +221,11 @@ def degree(self) -> int:
"""The degree."""
return self._degree

@property
def continuity(self) -> Continuity:
"""Continuity."""
return self._continuity

def element(self, cell: ReferenceCellType) -> CiarletElement:
"""Create an element."""
# TODO: remove these error once https://github.com/linalg-rs/rlst/issues/98 is fixed
Expand Down Expand Up @@ -255,15 +273,24 @@ def create_family(
rust_type = _rtypes[dtype]
if family == Family.Lagrange:
return ElementFamily(
family, degree, _lib.create_lagrange_family(degree, continuity.value, rust_type)
family,
degree,
continuity,
_lib.create_lagrange_family(degree, continuity.value, rust_type),
)
elif family == Family.RaviartThomas:
return ElementFamily(
family, degree, _lib.create_raviart_thomas_family(degree, continuity.value, rust_type)
family,
degree,
continuity,
_lib.create_raviart_thomas_family(degree, continuity.value, rust_type),
)
elif family == Family.NedelecFirstKind:
return ElementFamily(
family, degree, _lib.create_nedelec_family(degree, continuity.value, rust_type)
family,
degree,
continuity,
_lib.create_nedelec_family(degree, continuity.value, rust_type),
)
else:
raise ValueError(f"Unsupported family: {family}")
25 changes: 25 additions & 0 deletions src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ pub mod ciarlet {
DType::F64 => Box::new(ciarlet::RaviartThomasElementFamily::<f64>::new(
degree, continuity,
)),
DType::C32 => Box::new(ciarlet::RaviartThomasElementFamily::<c32>::new(
degree, continuity,
)),
DType::C64 => Box::new(ciarlet::RaviartThomasElementFamily::<c64>::new(
degree, continuity,
)),
_ => panic!("Unsupported dtype"),
};

Expand All @@ -309,6 +315,12 @@ pub mod ciarlet {
DType::F64 => Box::new(ciarlet::NedelecFirstKindElementFamily::<f64>::new(
degree, continuity,
)),
DType::C32 => Box::new(ciarlet::NedelecFirstKindElementFamily::<c32>::new(
degree, continuity,
)),
DType::C64 => Box::new(ciarlet::NedelecFirstKindElementFamily::<c64>::new(
degree, continuity,
)),
_ => panic!("Unsupported dtype"),
};

Expand All @@ -331,6 +343,19 @@ pub mod ciarlet {
ciarlet_element
}

#[concretise_types(
gen_type(name = "dtype", replace_with = ["f32", "f64", "c32", "c64"]),
field(arg = 0, name = "element_family", wrapper = "ElementFamilyT", replace_with = ["crate::ciarlet::LagrangeElementFamily<{{dtype}}>", "ciarlet::RaviartThomasElementFamily<{{dtype}}>", "ciarlet::NedelecFirstKindElementFamily<{{dtype}}>"])
)]
pub fn element_family_dtype<
T: RlstScalar + MatrixInverse + DTypeIdentifier,
F: ElementFamily<CellType = ReferenceCellType, T = T>,
>(
_elem: &F,
) -> DType {
<T as DTypeIdentifier>::dtype()
}

#[concretise_types(
gen_type(name = "dtype", replace_with = ["f32", "f64", "c32", "c64"]),
field(arg = 0, name = "element", wrapper = "CiarletElementT", replace_with = ["CiarletElement<{{dtype}}>"])
Expand Down

0 comments on commit 560eaa3

Please sign in to comment.