Skip to content

Commit

Permalink
Support assembly of linear forms with cell integrals on mixed-topolog…
Browse files Browse the repository at this point in the history
…y meshes (#3606)

* Start on RHS

* Expose to Python

* Call correct function

* Add comment

* Update assembler

* Get correct cells

* Add more interesting RHS

* Remove cout

* Remove prints

* Ruff

* Ruff

* Ruff

* Docs
  • Loading branch information
jpdean authored Jan 30, 2025
1 parent 372f8b7 commit d0844c9
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 113 deletions.
227 changes: 116 additions & 111 deletions cpp/dolfinx/fem/assemble_vector_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1140,13 +1140,12 @@ void apply_lifting(
/// @param[in,out] b The vector to be assembled. It will not be zeroed
/// before assembly.
/// @param[in] L Linear forms to assemble into b.
/// @param[in] x_dofmap Mesh geometry dofmap.
/// @param[in] x Mesh coordinates.
/// @param[in] constants Packed constants that appear in `L`.
/// @param[in] coefficients Packed coefficients that appear in `L`.
template <dolfinx::scalar T, std::floating_point U>
void assemble_vector(
std::span<T> b, const Form<T, U>& L, mdspan2_t x_dofmap,
std::span<T> b, const Form<T, U>& L,
std::span<const scalar_value_type_t<T>> x, std::span<const T> constants,
const std::map<std::pair<IntegralType, int>,
std::pair<std::span<const T>, int>>& coefficients)
Expand All @@ -1159,123 +1158,131 @@ void assemble_vector(
auto mesh0 = L.function_spaces().at(0)->mesh();
assert(mesh0);

// Get dofmap data
assert(L.function_spaces().at(0));
auto element = L.function_spaces().at(0)->element();
assert(element);
std::shared_ptr<const fem::DofMap> dofmap
= L.function_spaces().at(0)->dofmap();
assert(dofmap);
auto dofs = dofmap->map();
const int bs = dofmap->bs();

fem::DofTransformKernel<T> auto P0
= element->template dof_transformation_fn<T>(doftransform::standard);

std::span<const std::uint32_t> cell_info0;
if (element->needs_dof_transformations() or L.needs_facet_permutations())
{
mesh0->topology_mutable()->create_entity_permutations();
cell_info0 = std::span(mesh0->topology()->get_cell_permutation_info());
}

for (int i : L.integral_ids(IntegralType::cell))
const int num_cell_types = mesh->topology()->cell_types().size();
for (int cell_type_idx = 0; cell_type_idx < num_cell_types; ++cell_type_idx)
{
auto fn = L.kernel(IntegralType::cell, i);
assert(fn);
auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i});
std::span<const std::int32_t> cells = L.domain(IntegralType::cell, i);
if (bs == 1)
{
impl::assemble_cells<T, 1>(
P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, *mesh0)}, fn, constants,
coeffs, cstride, cell_info0);
}
else if (bs == 3)
// Geometry dofmap and data
mdspan2_t x_dofmap = mesh->geometry().dofmap(cell_type_idx);

// Get dofmap data
assert(L.function_spaces().at(0));
auto element = L.function_spaces().at(0)->elements(cell_type_idx);
assert(element);
std::shared_ptr<const fem::DofMap> dofmap
= L.function_spaces().at(0)->dofmaps(cell_type_idx);
assert(dofmap);
auto dofs = dofmap->map();
const int bs = dofmap->bs();

fem::DofTransformKernel<T> auto P0
= element->template dof_transformation_fn<T>(doftransform::standard);

std::span<const std::uint32_t> cell_info0;
if (element->needs_dof_transformations() or L.needs_facet_permutations())
{
impl::assemble_cells<T, 3>(
P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, *mesh0)}, fn, constants,
coeffs, cstride, cell_info0);
mesh0->topology_mutable()->create_entity_permutations();
cell_info0 = std::span(mesh0->topology()->get_cell_permutation_info());
}
else
{
impl::assemble_cells(P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, *mesh0)},
fn, constants, coeffs, cstride, cell_info0);
}
}

std::span<const std::uint8_t> perms;
if (L.needs_facet_permutations())
{
mesh->topology_mutable()->create_entity_permutations();
perms = std::span(mesh->topology()->get_facet_permutations());
}

mesh::CellType cell_type = mesh->topology()->cell_type();
int num_facets_per_cell
= mesh::cell_num_entities(cell_type, mesh->topology()->dim() - 1);
for (int i : L.integral_ids(IntegralType::exterior_facet))
{
auto fn = L.kernel(IntegralType::exterior_facet, i);
assert(fn);
auto& [coeffs, cstride]
= coefficients.at({IntegralType::exterior_facet, i});
std::span<const std::int32_t> facets
= L.domain(IntegralType::exterior_facet, i);
if (bs == 1)
for (int i : L.integral_ids(IntegralType::cell))
{
impl::assemble_exterior_facets<T, 1>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
}
else if (bs == 3)
{
impl::assemble_exterior_facets<T, 3>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
}
else
{
impl::assemble_exterior_facets(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
auto fn = L.kernel(IntegralType::cell, i, cell_type_idx);
assert(fn);
auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i});
std::vector<std::int32_t> cells = L.domain(IntegralType::cell, i, cell_type_idx);
if (bs == 1)
{
impl::assemble_cells<T, 1>(
P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, cell_type_idx, *mesh0)}, fn, constants,
coeffs, cstride, cell_info0);
}
else if (bs == 3)
{
impl::assemble_cells<T, 3>(
P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, cell_type_idx, *mesh0)}, fn, constants,
coeffs, cstride, cell_info0);
}
else
{
impl::assemble_cells(
P0, b, x_dofmap, x, cells,
{dofs, bs, L.domain(IntegralType::cell, i, cell_type_idx, *mesh0)}, fn, constants,
coeffs, cstride, cell_info0);
}
}
}

for (int i : L.integral_ids(IntegralType::interior_facet))
{
auto fn = L.kernel(IntegralType::interior_facet, i);
assert(fn);
auto& [coeffs, cstride]
= coefficients.at({IntegralType::interior_facet, i});
std::span<const std::int32_t> facets
= L.domain(IntegralType::interior_facet, i);
if (bs == 1)
std::span<const std::uint8_t> perms;
if (L.needs_facet_permutations())
{
impl::assemble_interior_facets<T, 1>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
mesh->topology_mutable()->create_entity_permutations();
perms = std::span(mesh->topology()->get_facet_permutations());
}
else if (bs == 3)

mesh::CellType cell_type = mesh->topology()->cell_types()[cell_type_idx];
int num_facets_per_cell
= mesh::cell_num_entities(cell_type, mesh->topology()->dim() - 1);
for (int i : L.integral_ids(IntegralType::exterior_facet))
{
impl::assemble_interior_facets<T, 3>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
auto fn = L.kernel(IntegralType::exterior_facet, i);
assert(fn);
auto& [coeffs, cstride]
= coefficients.at({IntegralType::exterior_facet, i});
std::span<const std::int32_t> facets
= L.domain(IntegralType::exterior_facet, i);
if (bs == 1)
{
impl::assemble_exterior_facets<T, 1>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
}
else if (bs == 3)
{
impl::assemble_exterior_facets<T, 3>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
}
else
{
impl::assemble_exterior_facets(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{dofs, bs, L.domain(IntegralType::exterior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
}
}
else

for (int i : L.integral_ids(IntegralType::interior_facet))
{
impl::assemble_interior_facets(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)}, fn,
constants, coeffs, cstride, cell_info0, perms);
auto fn = L.kernel(IntegralType::interior_facet, i);
assert(fn);
auto& [coeffs, cstride]
= coefficients.at({IntegralType::interior_facet, i});
std::span<const std::int32_t> facets
= L.domain(IntegralType::interior_facet, i);
if (bs == 1)
{
impl::assemble_interior_facets<T, 1>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)},
fn, constants, coeffs, cstride, cell_info0, perms);
}
else if (bs == 3)
{
impl::assemble_interior_facets<T, 3>(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)},
fn, constants, coeffs, cstride, cell_info0, perms);
}
else
{
impl::assemble_interior_facets(
P0, b, x_dofmap, x, num_facets_per_cell, facets,
{*dofmap, bs, L.domain(IntegralType::interior_facet, i, *mesh0)},
fn, constants, coeffs, cstride, cell_info0, perms);
}
}
}
}
Expand All @@ -1296,15 +1303,13 @@ void assemble_vector(
assert(mesh);
if constexpr (std::is_same_v<U, scalar_value_type_t<T>>)
{
assemble_vector(b, L, mesh->geometry().dofmap(), mesh->geometry().x(),
constants, coefficients);
assemble_vector(b, L, mesh->geometry().x(), constants, coefficients);
}
else
{
auto x = mesh->geometry().x();
std::vector<scalar_value_type_t<T>> _x(x.begin(), x.end());
assemble_vector(b, L, mesh->geometry().dofmap(), _x, constants,
coefficients);
assemble_vector(b, L, _x, constants, coefficients);
}
}
} // namespace dolfinx::fem::impl
10 changes: 9 additions & 1 deletion python/demo/demo_mixed-topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dolfinx.fem import (
FunctionSpace,
assemble_matrix,
assemble_vector,
coordinate_element,
mixed_topology_form,
)
Expand Down Expand Up @@ -112,28 +113,35 @@
# FIXME This hack is required at the moment because UFL does not yet know about
# mixed topology meshes.
a = []
L = []
for i, cell_name in enumerate(["hexahedron", "prism"]):
print(f"Creating form for {cell_name}")
element = basix.ufl.wrap_element(elements[i])
domain = ufl.Mesh(basix.ufl.element("Lagrange", cell_name, 1, shape=(3,)))
V = FunctionSpace(Mesh(mesh, domain), element, V_cpp)
u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
k = 12.0
x = ufl.SpatialCoordinate(domain)
a += [(ufl.inner(ufl.grad(u), ufl.grad(v)) - k**2 * u * v) * ufl.dx]
f = ufl.sin(ufl.pi * x[0]) * ufl.sin(ufl.pi * x[1])
L += [f * v * ufl.dx]

# Compile the form
# FIXME: For the time being, since UFL doesn't understand mixed topology meshes,
# we have to call mixed_topology_form instead of form.
a_form = mixed_topology_form(a, dtype=np.float64)
L_form = mixed_topology_form(L, dtype=np.float64)

# Assemble the matrix
A = assemble_matrix(a_form)
b = assemble_vector(L_form)

# Solve
A_scipy = A.to_scipy()
b_scipy = np.ones(A_scipy.shape[1])
b_scipy = b.array

x = spsolve(A_scipy, b_scipy)

print(f"Solution vector norm {np.linalg.norm(x)}")

# I/O
Expand Down
4 changes: 3 additions & 1 deletion python/dolfinx/fem/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def _pack(form):

def create_vector(L: Form) -> la.Vector:
"""Create a Vector that is compatible with a given linear form"""
dofmap = L.function_spaces[0].dofmap
# Can just take the first dofmap here, since all dof maps have the same
# index map in mixed-topology meshes
dofmap = L.function_spaces[0].dofmaps(0)
return la.vector(dofmap.index_map, dofmap.index_map_bs, dtype=L.dtype)


Expand Down
1 change: 1 addition & 0 deletions python/dolfinx/wrappers/fem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ void declare_function_space(nb::module_& m, std::string type)
.def_prop_ro("element", &dolfinx::fem::FunctionSpace<T>::element)
.def_prop_ro("mesh", &dolfinx::fem::FunctionSpace<T>::mesh)
.def_prop_ro("dofmap", &dolfinx::fem::FunctionSpace<T>::dofmap)
.def("dofmaps", &dolfinx::fem::FunctionSpace<T>::dofmaps, nb::arg("cell_type_index"))
.def("sub", &dolfinx::fem::FunctionSpace<T>::sub, nb::arg("component"))
.def("tabulate_dof_coordinates",
[](const dolfinx::fem::FunctionSpace<T>& self)
Expand Down

0 comments on commit d0844c9

Please sign in to comment.