@@ -307,9 +307,14 @@ def create_tikhonov_regularization(
307307 -----
308308 TODO
309309 """
310- # TODO: raise errors:
311- # if dims == 2 and alpha_z is passed
312- # if dims == 1 and alpha_y or alpha_z are passed
310+ ndims = mesh .dim
311+ if ndims == 2 and alpha_z is not None :
312+ msg = f"Cannot pass 'alpha_z' when mesh has { ndims } dimensions."
313+ raise TypeError (msg )
314+ if ndims == 1 and (alpha_y is not None or alpha_z is not None ):
315+ msg = "Cannot pass 'alpha_y' nor 'alpha_z' when mesh has 1 dimension."
316+ raise TypeError (msg )
317+
313318 smallness = Smallness (
314319 mesh ,
315320 active_cells = active_cells ,
@@ -326,16 +331,24 @@ def create_tikhonov_regularization(
326331 if reference_model_in_flatness :
327332 kwargs ["reference_model" ] = reference_model
328333
329- flatness_x = Flatness (mesh , ** kwargs , direction = "x" )
330- if alpha_x is not None :
331- flatness_x = alpha_x * flatness_x
332-
333- flatness_y = Flatness (mesh , ** kwargs , direction = "y" )
334- if alpha_y is not None :
335- flatness_y = alpha_y * flatness_y
336-
337- flatness_z = Flatness (mesh , ** kwargs , direction = "z" )
338- if alpha_z is not None :
339- flatness_z = alpha_z * flatness_z
340-
341- return (smallness + flatness_x + flatness_y + flatness_z ).flatten ()
334+ match ndims :
335+ case 3 :
336+ directions = ("x" , "y" , "z" )
337+ alphas = (alpha_x , alpha_y , alpha_z )
338+ case 2 :
339+ directions = ("x" , "y" )
340+ alphas = (alpha_x , alpha_y )
341+ case 1 :
342+ directions = ("x" ,)
343+ alphas = (alpha_x ,)
344+ case _:
345+ raise ValueError ()
346+
347+ regularization = smallness
348+ for direction , alpha in zip (directions , alphas , strict = True ):
349+ phi = Flatness (mesh , ** kwargs , direction = direction )
350+ if alpha is not None :
351+ phi = alpha * phi
352+ regularization = regularization + phi
353+
354+ return regularization .flatten ()
0 commit comments