Skip to content

Commit

Permalink
suggestion from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 14, 2024
1 parent 0b0296d commit 40cf6d7
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,18 @@ def get_assembler(form, *args, **kwargs):
"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
bcs = kwargs.get('bcs', None)
fc_params = kwargs.get('form_compiler_parameters', None)
if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed:
mat_type = kwargs.get('mat_type', None)
fc_params = kwargs.get('form_compiler_parameters', None)
# Preprocess the DAG and restructure the DAG
# Only pre-process `form` once beforehand to avoid pre-processing for each assembly call
form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params)
if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not BaseFormAssembler.base_form_operands(form):
diagonal = kwargs.pop('diagonal', False)
if len(form.arguments()) == 0:
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
return ZeroFormAssembler(form, **kwargs)
elif len(form.arguments()) == 1 or diagonal:
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
zero_bc_nodes=kwargs.get('zero_bc_nodes', True), diagonal=diagonal)
return OneFormAssembler(form, *args, diagonal=diagonal, **kwargs)
elif len(form.arguments()) == 2:
return TwoFormAssembler(form, *args, **kwargs)
else:
Expand Down Expand Up @@ -1149,13 +1147,13 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False, weight=1.0):
zero_bc_nodes=True, diagonal=False, weight=1.0):
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False, weight=1.0):
zero_bc_nodes=True, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
Expand Down

0 comments on commit 40cf6d7

Please sign in to comment.