-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
770 additions
and
654 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Implementations.""" | ||
|
||
import importlib | ||
import os | ||
from inspect import isclass | ||
|
||
from defelement.implementations.template import Implementation, VariantNotImplemented, parse_example | ||
|
||
implementations = [] | ||
this_dir = os.path.dirname(os.path.realpath(__file__)) | ||
for file in os.listdir(this_dir): | ||
if file.endswith(".py") and not file.startswith("_") and file != "template.py": | ||
print(file) | ||
mod = importlib.import_module(f"defelement.implementations.{file[:-3]}") | ||
for name in dir(mod): | ||
if not name.startswith("_"): | ||
c = getattr(mod, name) | ||
if isclass(c) and c != Implementation and issubclass(c, Implementation): | ||
implementations.append(c) | ||
|
||
formats = {i.name: i.format for i in implementations} | ||
examples = {i.name: i.example for i in implementations} | ||
verifications = {i.name: i.verify for i in implementations if i.verification} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Basix implementation.""" | ||
|
||
import typing | ||
|
||
from defelement.implementations.template import (Array, Element, Implementation, | ||
VariantNotImplemented, parse_example) | ||
|
||
|
||
class BasixImplementation(Implementation): | ||
"""Basix implementation.""" | ||
|
||
def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str: | ||
"""Format implementation string. | ||
Args: | ||
string: Implementation string | ||
params: Parameters | ||
Returns: | ||
Formatted implementation string | ||
""" | ||
out = f"basix.ElementFamily.{string}" | ||
for p, v in params.items(): | ||
out += f", {p}=" | ||
if p == "lagrange_variant": | ||
out += f"basix.LagrangeVariant.{v}" | ||
elif p == "dpc_variant": | ||
out += f"basix.DPCVariant.{v}" | ||
elif p == "discontinuous": | ||
out += v | ||
return out | ||
|
||
def example(element: Element) -> str: | ||
"""Generate Symfem examples. | ||
Args: | ||
element: The element | ||
Returns: | ||
Example code | ||
""" | ||
out = "import basix" | ||
for e in element.examples: | ||
ref, ord, variant, kwargs = parse_example(e) | ||
assert len(kwargs) == 0 | ||
ord = int(ord) | ||
|
||
try: | ||
basix_name, params = element.get_implementation_string("basix", ref, variant) | ||
except VariantNotImplemented: | ||
continue | ||
|
||
if basix_name is not None: | ||
out += "\n\n" | ||
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n" | ||
out += "element = basix.create_element(" | ||
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}" | ||
if "lagrange_variant" in params: | ||
out += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}" | ||
if "dpc_variant" in params: | ||
out += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}" | ||
if "discontinuous" in params: | ||
assert params["discontinuous"] in ["True", "False"] | ||
out += f", discontinuous={params['discontinuous']}" | ||
out += ")" | ||
return out | ||
|
||
def verify( | ||
element: Element, example: str | ||
) -> typing.Tuple[typing.List[typing.List[typing.List[int]]], typing.Callable[[Array], Array]]: | ||
"""Get verification data. | ||
Args: | ||
element: Element data | ||
example: Example data | ||
Returns: | ||
List of entity dofs, and tabulation function | ||
""" | ||
import basix | ||
|
||
ref, ord, variant, kwargs = parse_example(example) | ||
assert len(kwargs) == 0 | ||
ord = int(ord) | ||
try: | ||
basix_name, params = element.get_implementation_string("basix", ref, variant) | ||
except VariantNotImplemented: | ||
raise NotImplementedError() | ||
if basix_name is None: | ||
raise NotImplementedError() | ||
kwargs = {} | ||
if "lagrange_variant" in params: | ||
kwargs["lagrange_variant"] = getattr(basix.LagrangeVariant, params['lagrange_variant']) | ||
if "dpc_variant" in params: | ||
kwargs["dpc_variant"] = getattr(basix.DPCVariant, params['dpc_variant']) | ||
if "discontinuous" in params: | ||
kwargs["discontinuous"] = params["discontinuous"] == "True" | ||
|
||
e = basix.create_element( | ||
getattr(basix.ElementFamily, basix_name), getattr(basix.CellType, ref), ord, | ||
**kwargs) | ||
return e.entity_dofs, lambda points: e.tabulate(0, points)[0].transpose((0, 2, 1)) | ||
|
||
name = "basix" | ||
verification = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""Basix.UFL implementation.""" | ||
|
||
import typing | ||
|
||
from defelement.implementations.basix import BasixImplementation | ||
from defelement.implementations.template import (Array, Element, Implementation, | ||
VariantNotImplemented, parse_example) | ||
|
||
|
||
class BasixUFLImplementation(Implementation): | ||
"""Basix.UFL implementation.""" | ||
|
||
def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str: | ||
"""Format implementation string. | ||
Args: | ||
string: Implementation string | ||
params: Parameters | ||
Returns: | ||
Formatted implementation string | ||
""" | ||
out = BasixImplementation.format(string, {i: j for i, j in params.items() if i != "shape"}) | ||
if "shape" in params: | ||
out += f", shape={params['shape']}" | ||
return out | ||
|
||
def example(element: Element) -> str: | ||
"""Generate Symfem examples. | ||
Args: | ||
element: The element | ||
Returns: | ||
Example code | ||
""" | ||
out = "import basix\nimport basix.ufl" | ||
for e in element.examples: | ||
ref, ord, variant, kwargs = parse_example(e) | ||
assert len(kwargs) == 0 | ||
ord = int(ord) | ||
|
||
try: | ||
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant) | ||
except VariantNotImplemented: | ||
continue | ||
|
||
if basix_name is not None: | ||
out += "\n\n" | ||
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n" | ||
out += "element = basix.ufl.element(" | ||
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}" | ||
if "lagrange_variant" in params: | ||
out += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}" | ||
if "dpc_variant" in params: | ||
out += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}" | ||
if "discontinuous" in params: | ||
assert params["discontinuous"] in ["True", "False"] | ||
out += f", discontinuous={params['discontinuous']}" | ||
if "shape" in params: | ||
if ref == "interval": | ||
dim = 1 | ||
elif ref in ["triangle", "quadrilateral"]: | ||
dim = 2 | ||
else: | ||
dim = 3 | ||
out += ", shape=" + params["shape"].replace("dim", f"{dim}") | ||
out += ")" | ||
return out | ||
|
||
def verify( | ||
element: Element, example: str | ||
) -> typing.Tuple[typing.List[typing.List[typing.List[int]]], typing.Callable[[Array], Array]]: | ||
"""Get verification data. | ||
Args: | ||
element: Element data | ||
example: Example data | ||
Returns: | ||
List of entity dofs, and tabulation function | ||
""" | ||
import basix | ||
import basix.ufl | ||
|
||
kwargs: typing.Dict[str, typing.Any] | ||
|
||
ref, ord, variant, kwargs = parse_example(example) | ||
assert len(kwargs) == 0 | ||
ord = int(ord) | ||
try: | ||
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant) | ||
except VariantNotImplemented: | ||
raise NotImplementedError() | ||
if basix_name is None: | ||
raise NotImplementedError() | ||
kwargs = {} | ||
if "lagrange_variant" in params: | ||
kwargs["lagrange_variant"] = getattr(basix.LagrangeVariant, params['lagrange_variant']) | ||
if "dpc_variant" in params: | ||
kwargs["dpc_variant"] = getattr(basix.DPCVariant, params['dpc_variant']) | ||
if "discontinuous" in params: | ||
kwargs["discontinuous"] = params["discontinuous"] == "True" | ||
if "shape" in params: | ||
if ref == "interval": | ||
dim = 1 | ||
elif ref in ["triangle", "quadrilateral"]: | ||
dim = 2 | ||
else: | ||
dim = 3 | ||
kwargs["shape"] = tuple( | ||
dim if i == "dim" else int(i) for i in params["shape"][1:-1].split(",") if i != "") | ||
|
||
e = basix.ufl.element( | ||
getattr(basix.ElementFamily, basix_name), getattr(basix.CellType, ref), ord, **kwargs) | ||
return e.entity_dofs, lambda points: e.tabulate(0, points)[0].reshape( | ||
points.shape[0], e.reference_value_size, -1) | ||
|
||
name = "basix.ufl" | ||
verification = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""Bempp implementation.""" | ||
|
||
import typing | ||
|
||
from defelement.implementations.template import (Element, Implementation, VariantNotImplemented, | ||
parse_example) | ||
|
||
|
||
class BemppImplementation(Implementation): | ||
"""Bempp implementation.""" | ||
|
||
def format(string: typing.Optional[str], params: typing.Dict[str, typing.Any]) -> str: | ||
"""Format implementation string. | ||
Args: | ||
string: Implementation string | ||
params: Parameters | ||
Returns: | ||
Formatted implementation string | ||
""" | ||
return f"\"{string}\"" | ||
|
||
def example(element: Element) -> str: | ||
"""Generate Symfem examples. | ||
Args: | ||
element: The element | ||
Returns: | ||
Example code | ||
""" | ||
out = "import bempp.api" | ||
out += "\n" | ||
out += "grid = bempp.api.shapes.regular_sphere(1)" | ||
for e in element.examples: | ||
ref, ord, variant, kwargs = parse_example(e) | ||
assert len(kwargs) == 0 | ||
ord = int(ord) | ||
|
||
try: | ||
bempp_name, params = element.get_implementation_string("bempp", ref, variant) | ||
except VariantNotImplemented: | ||
continue | ||
|
||
if bempp_name is None: | ||
continue | ||
orders = [int(i) for i in params["orders"].split(",")] | ||
|
||
if ord in orders: | ||
out += "\n\n" | ||
out += f"# Create {element.name} order {ord}\n" | ||
out += "element = bempp.api.function_space(grid, " | ||
out += f"\"{bempp_name}\", {ord})" | ||
return out | ||
|
||
name = "bempp" |
Oops, something went wrong.