Skip to content

Commit

Permalink
Add element dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd committed May 5, 2024
1 parent 504de00 commit b18f4a2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 2 additions & 0 deletions tests/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_read_write_P_3D(
degree,
basix.LagrangeVariant.gll_warped,
shape=(mesh.geometry.dim,),
dtype=mesh.geometry.x.dtype,
)

def f(x):
Expand Down Expand Up @@ -178,6 +179,7 @@ def test_read_write_P_3D_time(
degree,
basix.LagrangeVariant.gll_warped,
shape=(mesh.geometry.dim,),
dtype=mesh.geometry.x.dtype,
)

def f(x):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_checkpointing_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_read_write_2D(
):
mesh = simplex_mesh_2D
f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree, dtype=mesh.geometry.x.dtype)

def f(x):
values = np.empty((2, x.shape[1]), dtype=f_dtype)
Expand All @@ -81,7 +81,7 @@ def test_read_write_3D(
):
mesh = simplex_mesh_3D
f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree, dtype=mesh.geometry.x.dtype)

def f(x):
values = np.empty((3, x.shape[1]), dtype=f_dtype)
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_read_write_2D_quad(
):
mesh = non_simplex_mesh_2D
f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree, dtype=mesh.geometry.x.dtype)

def f(x):
values = np.empty((2, x.shape[1]), dtype=f_dtype)
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_read_write_hex(
):
mesh = non_simplex_mesh_3D
f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree, dtype=mesh.geometry.x.dtype)

def f(x):
values = np.empty((3, x.shape[1]), dtype=f_dtype)
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_read_write_multiple(
):
mesh = non_simplex_mesh_2D
f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree)
el = basix.ufl.element(family, mesh.ufl_cell().cellname(), degree, dtype=mesh.geometry.x.dtype)

def f(x):
values = np.empty((2, x.shape[1]), dtype=f_dtype)
Expand Down

0 comments on commit b18f4a2

Please sign in to comment.