Skip to content

Commit b5b7bdf

Browse files
authored
Generalize recipe for regularizations for n dims (#69)
Generalize the recipe to build an l2 regularization to mesh dimensions other than 3.
1 parent 1d2259c commit b5b7bdf

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

src/inversion_ideas/recipes.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,14 @@ def create_tikhonov_regularization(
307307
-----
308308
TODO
309309
"""
310-
# TODO: raise errors:
311-
# if dims == 2 and alpha_z is passed
312-
# if dims == 1 and alpha_y or alpha_z are passed
310+
ndims = mesh.dim
311+
if ndims == 2 and alpha_z is not None:
312+
msg = f"Cannot pass 'alpha_z' when mesh has {ndims} dimensions."
313+
raise TypeError(msg)
314+
if ndims == 1 and (alpha_y is not None or alpha_z is not None):
315+
msg = "Cannot pass 'alpha_y' nor 'alpha_z' when mesh has 1 dimension."
316+
raise TypeError(msg)
317+
313318
smallness = Smallness(
314319
mesh,
315320
active_cells=active_cells,
@@ -326,16 +331,24 @@ def create_tikhonov_regularization(
326331
if reference_model_in_flatness:
327332
kwargs["reference_model"] = reference_model
328333

329-
flatness_x = Flatness(mesh, **kwargs, direction="x")
330-
if alpha_x is not None:
331-
flatness_x = alpha_x * flatness_x
332-
333-
flatness_y = Flatness(mesh, **kwargs, direction="y")
334-
if alpha_y is not None:
335-
flatness_y = alpha_y * flatness_y
336-
337-
flatness_z = Flatness(mesh, **kwargs, direction="z")
338-
if alpha_z is not None:
339-
flatness_z = alpha_z * flatness_z
340-
341-
return (smallness + flatness_x + flatness_y + flatness_z).flatten()
334+
match ndims:
335+
case 3:
336+
directions = ("x", "y", "z")
337+
alphas = (alpha_x, alpha_y, alpha_z)
338+
case 2:
339+
directions = ("x", "y")
340+
alphas = (alpha_x, alpha_y)
341+
case 1:
342+
directions = ("x",)
343+
alphas = (alpha_x,)
344+
case _:
345+
raise ValueError()
346+
347+
regularization = smallness
348+
for direction, alpha in zip(directions, alphas, strict=True):
349+
phi = Flatness(mesh, **kwargs, direction=direction)
350+
if alpha is not None:
351+
phi = alpha * phi
352+
regularization = regularization + phi
353+
354+
return regularization.flatten()

0 commit comments

Comments
 (0)