Skip to content

Commit

Permalink
tidy implementation code
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Sep 11, 2024
1 parent 0260ebc commit 2dcdd22
Show file tree
Hide file tree
Showing 10 changed files with 770 additions and 654 deletions.
654 changes: 0 additions & 654 deletions defelement/implementations.py

This file was deleted.

23 changes: 23 additions & 0 deletions defelement/implementations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Implementations."""

import importlib
import os
from inspect import isclass

from defelement.implementations.template import Implementation, VariantNotImplemented, parse_example

implementations = []
this_dir = os.path.dirname(os.path.realpath(__file__))
for file in os.listdir(this_dir):
if file.endswith(".py") and not file.startswith("_") and file != "template.py":
print(file)
mod = importlib.import_module(f"defelement.implementations.{file[:-3]}")
for name in dir(mod):
if not name.startswith("_"):
c = getattr(mod, name)
if isclass(c) and c != Implementation and issubclass(c, Implementation):
implementations.append(c)

formats = {i.name: i.format for i in implementations}
examples = {i.name: i.example for i in implementations}
verifications = {i.name: i.verify for i in implementations if i.verification}
105 changes: 105 additions & 0 deletions defelement/implementations/basix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Basix implementation."""

import typing

from defelement.implementations.template import (Array, Element, Implementation,
VariantNotImplemented, parse_example)


class BasixImplementation(Implementation):
"""Basix implementation."""

def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str:
"""Format implementation string.
Args:
string: Implementation string
params: Parameters
Returns:
Formatted implementation string
"""
out = f"basix.ElementFamily.{string}"
for p, v in params.items():
out += f", {p}="
if p == "lagrange_variant":
out += f"basix.LagrangeVariant.{v}"
elif p == "dpc_variant":
out += f"basix.DPCVariant.{v}"
elif p == "discontinuous":
out += v
return out

def example(element: Element) -> str:
"""Generate Symfem examples.
Args:
element: The element
Returns:
Example code
"""
out = "import basix"
for e in element.examples:
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

try:
basix_name, params = element.get_implementation_string("basix", ref, variant)
except VariantNotImplemented:
continue

if basix_name is not None:
out += "\n\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
out += "element = basix.create_element("
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}"
if "lagrange_variant" in params:
out += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}"
if "dpc_variant" in params:
out += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
if "discontinuous" in params:
assert params["discontinuous"] in ["True", "False"]
out += f", discontinuous={params['discontinuous']}"
out += ")"
return out

def verify(
element: Element, example: str
) -> typing.Tuple[typing.List[typing.List[typing.List[int]]], typing.Callable[[Array], Array]]:
"""Get verification data.
Args:
element: Element data
example: Example data
Returns:
List of entity dofs, and tabulation function
"""
import basix

ref, ord, variant, kwargs = parse_example(example)
assert len(kwargs) == 0
ord = int(ord)
try:
basix_name, params = element.get_implementation_string("basix", ref, variant)
except VariantNotImplemented:
raise NotImplementedError()
if basix_name is None:
raise NotImplementedError()
kwargs = {}
if "lagrange_variant" in params:
kwargs["lagrange_variant"] = getattr(basix.LagrangeVariant, params['lagrange_variant'])
if "dpc_variant" in params:
kwargs["dpc_variant"] = getattr(basix.DPCVariant, params['dpc_variant'])
if "discontinuous" in params:
kwargs["discontinuous"] = params["discontinuous"] == "True"

e = basix.create_element(
getattr(basix.ElementFamily, basix_name), getattr(basix.CellType, ref), ord,
**kwargs)
return e.entity_dofs, lambda points: e.tabulate(0, points)[0].transpose((0, 2, 1))

name = "basix"
verification = True
120 changes: 120 additions & 0 deletions defelement/implementations/basix_ufl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Basix.UFL implementation."""

import typing

from defelement.implementations.basix import BasixImplementation
from defelement.implementations.template import (Array, Element, Implementation,
VariantNotImplemented, parse_example)


class BasixUFLImplementation(Implementation):
"""Basix.UFL implementation."""

def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str:
"""Format implementation string.
Args:
string: Implementation string
params: Parameters
Returns:
Formatted implementation string
"""
out = BasixImplementation.format(string, {i: j for i, j in params.items() if i != "shape"})
if "shape" in params:
out += f", shape={params['shape']}"
return out

def example(element: Element) -> str:
"""Generate Symfem examples.
Args:
element: The element
Returns:
Example code
"""
out = "import basix\nimport basix.ufl"
for e in element.examples:
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

try:
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant)
except VariantNotImplemented:
continue

if basix_name is not None:
out += "\n\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
out += "element = basix.ufl.element("
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}"
if "lagrange_variant" in params:
out += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}"
if "dpc_variant" in params:
out += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
if "discontinuous" in params:
assert params["discontinuous"] in ["True", "False"]
out += f", discontinuous={params['discontinuous']}"
if "shape" in params:
if ref == "interval":
dim = 1
elif ref in ["triangle", "quadrilateral"]:
dim = 2
else:
dim = 3
out += ", shape=" + params["shape"].replace("dim", f"{dim}")
out += ")"
return out

def verify(
element: Element, example: str
) -> typing.Tuple[typing.List[typing.List[typing.List[int]]], typing.Callable[[Array], Array]]:
"""Get verification data.
Args:
element: Element data
example: Example data
Returns:
List of entity dofs, and tabulation function
"""
import basix
import basix.ufl

kwargs: typing.Dict[str, typing.Any]

ref, ord, variant, kwargs = parse_example(example)
assert len(kwargs) == 0
ord = int(ord)
try:
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant)
except VariantNotImplemented:
raise NotImplementedError()
if basix_name is None:
raise NotImplementedError()
kwargs = {}
if "lagrange_variant" in params:
kwargs["lagrange_variant"] = getattr(basix.LagrangeVariant, params['lagrange_variant'])
if "dpc_variant" in params:
kwargs["dpc_variant"] = getattr(basix.DPCVariant, params['dpc_variant'])
if "discontinuous" in params:
kwargs["discontinuous"] = params["discontinuous"] == "True"
if "shape" in params:
if ref == "interval":
dim = 1
elif ref in ["triangle", "quadrilateral"]:
dim = 2
else:
dim = 3
kwargs["shape"] = tuple(
dim if i == "dim" else int(i) for i in params["shape"][1:-1].split(",") if i != "")

e = basix.ufl.element(
getattr(basix.ElementFamily, basix_name), getattr(basix.CellType, ref), ord, **kwargs)
return e.entity_dofs, lambda points: e.tabulate(0, points)[0].reshape(
points.shape[0], e.reference_value_size, -1)

name = "basix.ufl"
verification = True
57 changes: 57 additions & 0 deletions defelement/implementations/bempp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Bempp implementation."""

import typing

from defelement.implementations.template import (Element, Implementation, VariantNotImplemented,
parse_example)


class BemppImplementation(Implementation):
"""Bempp implementation."""

def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str:
"""Format implementation string.
Args:
string: Implementation string
params: Parameters
Returns:
Formatted implementation string
"""
return f"\"{string}\""

def example(element: Element) -> str:
"""Generate Symfem examples.
Args:
element: The element
Returns:
Example code
"""
out = "import bempp.api"
out += "\n"
out += "grid = bempp.api.shapes.regular_sphere(1)"
for e in element.examples:
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

try:
bempp_name, params = element.get_implementation_string("bempp", ref, variant)
except VariantNotImplemented:
continue

if bempp_name is None:
continue
orders = [int(i) for i in params["orders"].split(",")]

if ord in orders:
out += "\n\n"
out += f"# Create {element.name} order {ord}\n"
out += "element = bempp.api.function_space(grid, "
out += f"\"{bempp_name}\", {ord})"
return out

name = "bempp"
Loading

0 comments on commit 2dcdd22

Please sign in to comment.