Skip to content

Commit

Permalink
variants
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Sep 15, 2023
1 parent c2860eb commit 158963c
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 140 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ The entries in this yaml file are:
<tr><td>`reference&#8209;elements`</td><td>{{tick}}</td><td>The reference element(s) that this finite element can be defined on.</td></tr>
<tr><td>`alt&#8209;names`</td><td></td><td>Alternative (HTML) names of the element.</td></tr>
<tr><td>`short&#8209;names`</td><td></td><td>Abbreviated names of the element.</td></tr>
<tr><td>`variants`</td><td></td><td>Variants of this element.</td></tr>
<tr><td>`complexes`</td><td></td><td>Any discretiations of complexes that this element is part of.</td></tr>
<tr><td>`dofs`</td><td></td><td>Description of the DOFs of this element.</td></tr>
<tr><td>`ndofs`</td><td></td><td>The number of DOFs the element has and the A-numbers of the [OEIS](http://oeis.org) sequence(s) giving the number of DOFs.</td></tr>
Expand Down
11 changes: 8 additions & 3 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,19 @@ def write_html_page(path, title, content):
assert e.implemented("symfem")

for eg in e.examples:
cell, order, kwargs = parse_example(eg)
symfem_name, params = e.get_implementation_string("symfem", cell)
cell, order, variant, kwargs = parse_example(eg)
symfem_name, params = e.get_implementation_string("symfem", cell, variant)

fname = f"{cell}-{e.filename}-{order}.html"
fname = f"{cell}-{e.filename}"
if variant is not None:
fname += f"-{variant}"
fname += f"-{order}.html"
for s in " ()":
fname = fname.replace(s, "-")

name = f"{cell}<br />order {order}"
if variant is not None:
name += f"<br />{e.variant_name(variant)} variant"
for i, j in kwargs.items():
name += f"<br />{i}={str(j).replace(' ', '&nbsp;')}"

Expand Down
125 changes: 96 additions & 29 deletions builder/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import implementations
from . import settings
from .families import keys_and_names, arnold_logg_reference, cockburn_fu_reference
from .implementations import VariantNotImplemented
from .markup import insert_links
from .polyset import make_poly_set, make_extra_info

Expand Down Expand Up @@ -183,6 +184,14 @@ def __init__(self, data, fname):
self.created = None
self.modified = None

def name_with_variant(self, variant):
if variant is None:
return self.name
return f"{self.name} ({self.variant_name(variant)} variant)"

def variant_name(self, variant):
return self.data["variants"][variant]["variant-name"]

def min_order(self, ref):
if "min-order" not in self.data:
return 0
Expand Down Expand Up @@ -441,14 +450,20 @@ def html_link(self):
def implemented(self, lib):
return lib in self.data

def get_implementation_string(self, lib, reference):
def get_implementation_string(self, lib, reference, variant=None):
assert self.implemented(lib)
if isinstance(self.data[lib], dict):
if reference not in self.data[lib]:
if variant is None:
data = self.data[lib]
else:
if variant not in self.data[lib]:
raise VariantNotImplemented()
data = self.data[lib][variant]
if isinstance(data, dict):
if reference not in data:
return None, {}
out = self.data[lib][reference]
out = data[reference]
else:
out = self.data[lib]
out = data
params = {}
if "=" in out:
sp = out.split("=")
Expand All @@ -467,33 +482,85 @@ def get_implementation_string(self, lib, reference):

def list_of_implementation_strings(self, lib, joiner="<br />"):
assert self.implemented(lib)
if isinstance(self.data[lib], str):
s, params = self.get_implementation_string(lib, None)
if lib == "basix":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += f", ..., basix.LagrangeVariant.{params['lagrange_variant']}"
else:
s = f"\"{s}\""
if "variant" in params:
s += f", variant=\"{params['variant']}\""
return f"<code>{s}</code>"

if "variants" in self.data:
variants = self.data["variants"]
else:
variants = {None: {}}

i_dict = {}
for i, j in self.data[lib].items():
s, params = self.get_implementation_string(lib, i)
if lib == "basix":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += f", ..., basix.LagrangeVariant.{params['lagrange_variant']}"
for v, vinfo in variants.items():
if v is None:
data = self.data[lib]
else:
s = f"\"{s}\""
if "variant" in params:
s += f", variant=\"{params['variant']}\""
if s not in i_dict:
i_dict[s] = []
i_dict[s].append(i)
imp_list = [f"<code>{i}</code> ({', '.join(j)})" for i, j in i_dict.items()]
if v not in self.data[lib]:
continue
data = self.data[lib][v]
if isinstance(data, str):
s, params = self.get_implementation_string(lib, None, v)
# TODO: move this to implementations.py
if lib == "basix":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}"
if "dpc_variant" in params:
s += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
if "discontinuous" in params:
s += f", discontinuous={params['discontinuous']}"
elif lib == "basix.ufl":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}"
if "dpc_variant" in params:
s += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
if "rank" in params:
s += f", rank={params['rank']}"
if "discontinuous" in params:
s += f", discontinuous={params['discontinuous']}"
else:
s = f"\"{s}\""
if "variant" in params:
s += f", variant=\"{params['variant']}\""
if s not in i_dict:
i_dict[s] = []
if v is None:
i_dict[s].append("")
else:
i_dict[s].append(vinfo["variant-name"])
else:
for i, j in data.items():
s, params = self.get_implementation_string(lib, i, v)
if lib == "basix":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += (", lagrange_variant=basix.LagrangeVariant."
f"{params['lagrange_variant']}")
if "dpc_variant" in params:
s += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
elif lib == "basix.ufl":
s = f"basix.ElementFamily.{s}"
if "lagrange_variant" in params:
s += f", lagrange_variant=basix.LagrangeVariant.{params['lagrange_variant']}"
if "dpc_variant" in params:
s += f", dpc_variant=basix.DPCVariant.{params['dpc_variant']}"
if "rank" in params:
s += f", rank={params['rank']}"
if "discontinuous" in params:
s += f", discontinuous={params['discontinuous']}"
else:
s = f"\"{s}\""
if "variant" in params:
s += f", variant=\"{params['variant']}\""
if s not in i_dict:
i_dict[s] = []
if v is None:
i_dict[s].append(i)
else:
i_dict[s].append(f"{i}, {vinfo['variant-name']}")
if len(i_dict) == 1:
return f"<code>{list(i_dict.keys())[0]}</code>"
imp_list = [f"<code>{i}</code> <span style='font-size:60%'>({'; '.join(j)})</span>"
for i, j in i_dict.items()]
if joiner is None:
return imp_list
else:
Expand Down
72 changes: 50 additions & 22 deletions builder/implementations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import re


class VariantNotImplemented(BaseException):
pass


def _parse_value(v):
v = v.strip()
if v[0] == "[" and v[-1] == "]":
Expand All @@ -22,21 +26,26 @@ def parse_example(e):
kwargs[key] = _parse_value(value)
else:
kwargs = {}
ref, order = e.split(",")
return ref, int(order), kwargs
s = e.split(",")
if len(s) == 3:
ref, order, variant = s
else:
ref, order = e.split(",")
variant = None
return ref, int(order), variant, kwargs


def symfem_example(element):
out = "import symfem"
for e in element.examples:
ref, ord, kwargs = parse_example(e)
ref, ord, variant, kwargs = parse_example(e)
ord = int(ord)

symfem_name, params = element.get_implementation_string("symfem", ref)
symfem_name, params = element.get_implementation_string("symfem", ref, variant)

if symfem_name is not None:
out += "\n\n"
out += f"# Create {element.name} order {ord} on a {ref}\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
if ref == "dual polygon":
out += f"element = symfem.create_element(\"{ref}(4)\","
else:
Expand All @@ -57,15 +66,18 @@ def symfem_example(element):
def basix_example(element):
out = "import basix"
for e in element.examples:
ref, ord, kwargs = parse_example(e)
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

basix_name, params = element.get_implementation_string("basix", ref)
try:
basix_name, params = element.get_implementation_string("basix", ref, variant)
except VariantNotImplemented:
continue

if basix_name is not None:
out += "\n\n"
out += f"# Create {element.name} order {ord} on a {ref}\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
out += "element = basix.create_element("
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}"
if "lagrange_variant" in params:
Expand All @@ -82,15 +94,18 @@ def basix_example(element):
def basix_ufl_example(element):
out = "import basix\nimport basix.ufl"
for e in element.examples:
ref, ord, kwargs = parse_example(e)
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

basix_name, params = element.get_implementation_string("basix.ufl", ref)
try:
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant)
except VariantNotImplemented:
continue

if basix_name is not None:
out += "\n\n"
out += f"# Create {element.name} order {ord} on a {ref}\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
out += "element = basix.ufl.element("
out += f"basix.ElementFamily.{basix_name}, basix.CellType.{ref}, {ord}"
if "lagrange_variant" in params:
Expand All @@ -109,15 +124,18 @@ def basix_ufl_example(element):
def ufl_example(element):
out = "import ufl"
for e in element.examples:
ref, ord, kwargs = parse_example(e)
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

ufl_name, params = element.get_implementation_string("ufl", ref)
try:
ufl_name, params = element.get_implementation_string("ufl", ref, variant)
except VariantNotImplemented:
continue

if ufl_name is not None:
out += "\n\n"
out += f"# Create {element.name} order {ord} on a {ref}\n"
out += f"# Create {element.name_with_variant(variant)} order {ord} on a {ref}\n"
if "type" in params:
out += f"element = ufl.{params['type']}("
else:
Expand All @@ -131,11 +149,15 @@ def bempp_example(element):
out += "\n"
out += "grid = bempp.api.shapes.regular_sphere(1)"
for e in element.examples:
ref, ord, kwargs = parse_example(e)
ref, ord, variant, kwargs = parse_example(e)
assert len(kwargs) == 0
ord = int(ord)

bempp_name, params = element.get_implementation_string("bempp", ref)
try:
bempp_name, params = element.get_implementation_string("bempp", ref, variant)
except VariantNotImplemented:
continue

if bempp_name is None:
continue
orders = [int(i) for i in params["orders"].split(",")]
Expand Down Expand Up @@ -177,9 +199,9 @@ def symfem_tabulate(element, example):
import numpy as np
import symfem

ref, ord, kwargs = parse_example(example)
ref, ord, variant, kwargs = parse_example(example)
ord = int(ord)
symfem_name, params = element.get_implementation_string("symfem", ref)
symfem_name, params = element.get_implementation_string("symfem", ref, variant)
assert symfem_name is not None
if ref == "dual polygon":
ref += "(4)"
Expand All @@ -190,10 +212,13 @@ def symfem_tabulate(element, example):
def basix_tabulate(element, example):
import basix

ref, ord, kwargs = parse_example(example)
ref, ord, variant, kwargs = parse_example(example)
assert len(kwargs) == 0
ord = int(ord)
basix_name, params = element.get_implementation_string("basix", ref)
try:
basix_name, params = element.get_implementation_string("basix", ref, variant)
except VariantNotImplemented:
raise NotImplementedError()
if basix_name is None:
raise NotImplementedError()
kwargs = {}
Expand All @@ -215,10 +240,13 @@ def basix_ufl_tabulate(element, example):
import basix
import basix.ufl

ref, ord, kwargs = parse_example(example)
ref, ord, variant, kwargs = parse_example(example)
assert len(kwargs) == 0
ord = int(ord)
basix_name, params = element.get_implementation_string("basix.ufl", ref)
try:
basix_name, params = element.get_implementation_string("basix.ufl", ref, variant)
except VariantNotImplemented:
raise NotImplementedError()
if basix_name is None:
raise NotImplementedError()
kwargs = {}
Expand Down
Loading

0 comments on commit 158963c

Please sign in to comment.