Skip to content

Commit

Permalink
sanitise
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Apr 10, 2024
1 parent 0b2957f commit bc1a5e0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
19 changes: 13 additions & 6 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TSFCIntegralDataInfo = collections.namedtuple("TSFCIntegralDataInfo",
["domain", "integral_type", "subdomain_id", "domain_number", "domain_integral_type_map",
"arguments",
"coefficients", "coefficient_numbers"])
"coefficients", "coefficient_split", "coefficient_numbers"])
TSFCIntegralDataInfo.__doc__ = """
Minimal set of objects for kernel builders.
Expand Down Expand Up @@ -116,14 +116,20 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
raise NotImplementedError("Sorry, we can't assemble the diagonal of a form for interior facet integrals")
arguments = form_data.preprocessed_form.arguments()
kernel_name = f"{prefix}_{integral_type}_integral"
coefficients = [form_data.function_replace_map[c] for c in integral_data.integral_coefficients]
# This is which coefficient in the original form the
# current coefficient is.
# Consider f*v*dx + g*v*ds, the full form contains two
# coefficients, but each integral only requires one.
coefficient_numbers = tuple(form_data.original_coefficient_positions[i]
for i, (_, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients))
if enabled)
coefficients = []
coefficient_split = {}
coefficient_numbers = []
for i, (coeff_orig, enabled) in enumerate(zip(form_data.reduced_coefficients, integral_data.enabled_coefficients)):
if enabled:
coeff = form_data.function_replace_map[coeff_orig]
coefficients.append(coeff)
if coeff in form_data.coefficient_split:
coefficient_split[coeff] = form_data.coefficient_split[coeff]
coefficient_numbers.append(form_data.original_coefficient_positions[i])
mesh = integral_data.domain
all_meshes = extract_domains(form_data.original_form)
domain_number = all_meshes.index(mesh)
Expand All @@ -134,6 +140,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
domain_integral_type_map={mesh: integral_data.domain_integral_type_map[mesh] if mesh in integral_data.domain_integral_type_map else None for mesh in all_meshes},
arguments=arguments,
coefficients=coefficients,
coefficient_split=coefficient_split,
coefficient_numbers=coefficient_numbers)
builder = firedrake_interface_loopy.KernelBuilder(integral_data_info,
scalar_type,
Expand All @@ -142,7 +149,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
builder.set_coordinates(all_meshes)
builder.set_cell_orientations(all_meshes)
builder.set_cell_sizes(all_meshes)
builder.set_coefficients(integral_data, form_data)
builder.set_coefficients()
# TODO: We do not want pass constants to kernels that do not need them
# so we should attach the constants to integral data instead
builder.set_constants(form_data.constants)
Expand Down
30 changes: 12 additions & 18 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,29 +316,23 @@ def set_coordinates(self, domains):
self.domain_coordinate[domain] = f
self._coefficient(f, f"coords_{i}")

def set_coefficients(self, integral_data, form_data):
def set_coefficients(self):
"""Prepare the coefficients of the form.
:arg integral_data: UFL integral data
:arg form_data: UFL form data
"""
# enabled_coefficients is a boolean array that indicates which
# of reduced_coefficients the integral requires.
coefficient_split = form_data.coefficient_split
info = self.integral_data_info
n, k = 0, 0
for original_coeff, enabled in zip(form_data.reduced_coefficients, integral_data.enabled_coefficients):
if enabled:
coeff = form_data.function_replace_map[original_coeff]
if coeff in coefficient_split:
for i, c in enumerate(coefficient_split[coeff]):
self.coefficient_number_index_map[c] = (n, i)
self._coefficient(c, f"w_{k}")
k += 1
else:
self.coefficient_number_index_map[coeff] = (n, 0)
self._coefficient(coeff, f"w_{k}")
for coeff in info.coefficients:
if coeff in info.coefficient_split:
for i, c in enumerate(info.coefficient_split[coeff]):
self.coefficient_number_index_map[c] = (n, i)
self._coefficient(c, f"w_{k}")
k += 1
n += 1
else:
self.coefficient_number_index_map[coeff] = (n, 0)
self._coefficient(coeff, f"w_{k}")
k += 1
n += 1

def set_constants(self, constants):
for i, const in enumerate(constants):
Expand Down

0 comments on commit bc1a5e0

Please sign in to comment.