From 83cca5a75f8dbe3a0a03d467a87728e6b486044a Mon Sep 17 00:00:00 2001 From: Saurabh Mogre Date: Tue, 16 Jul 2024 10:49:35 -0700 Subject: [PATCH] Specify weights while combining gradients (#271) * Add methods and recipe to use multiple gradients on a single ingredient * Add gradient weights to Environment.py and Gradient.py * Add gradient_weights parameter to Ingredient and Agent classes * Update decay length and gradient weights in test_combined_gradient.json * Increase strength of gradient in example recipe * Remove unused import * Refactor gradient weighting logic * Add validation for gradient names in Gradient.py * Add assertion to check for at least two gradients before combining * Check if gradient key exists in recipe before updating ingredient info * add validation for gradient information while creating ingredient * add tests for gradient information validation * Simplify ingredient gradient update by moving validation to previous step --- cellpack/autopack/Environment.py | 9 +---- cellpack/autopack/Gradient.py | 35 ++++++++++++----- cellpack/autopack/ingredient/Ingredient.py | 35 +++++++++++++++++ cellpack/autopack/ingredient/agent.py | 4 +- cellpack/autopack/ingredient/grow.py | 1 + .../autopack/ingredient/multi_cylinder.py | 1 + cellpack/autopack/ingredient/multi_sphere.py | 1 + cellpack/autopack/ingredient/single_cube.py | 1 + .../autopack/ingredient/single_cylinder.py | 3 +- cellpack/autopack/ingredient/single_sphere.py | 1 + .../recipes/v2/test_combined_gradient.json | 4 ++ cellpack/tests/test_ingredient.py | 38 +++++++++++++++++++ 12 files changed, 115 insertions(+), 18 deletions(-) diff --git a/cellpack/autopack/Environment.py b/cellpack/autopack/Environment.py index 7e2b5799d..76ce34924 100644 --- a/cellpack/autopack/Environment.py +++ b/cellpack/autopack/Environment.py @@ -974,13 +974,8 @@ def create_ingredient(self, recipe, arguments): ingredient_type = arguments["type"] ingredient_class = ingredient.get_ingredient_class(ingredient_type) ingr = ingredient_class(**arguments) - if ( - "gradient" in arguments - and arguments["gradient"] != "" - and arguments["gradient"] != "None" - ): - ingr.gradient = arguments["gradient"] - # TODO: allow ingrdients to have multiple gradients + if "gradient" in arguments: + ingr = Gradient.update_ingredient_gradient(ingr, arguments) if "results" in arguments: ingr.results = arguments["results"] ingr.initialize_mesh(self.mesh_store) diff --git a/cellpack/autopack/Gradient.py b/cellpack/autopack/Gradient.py index 9ee971268..4d6478ccf 100644 --- a/cellpack/autopack/Gradient.py +++ b/cellpack/autopack/Gradient.py @@ -91,6 +91,18 @@ def __init__(self, gradient_data): self.function = self.defaultFunction # lambda ? + @staticmethod + def update_ingredient_gradient(ingr, arguments): + """ + Update the ingredient gradient + """ + ingr.gradient = arguments["gradient"] + ingr.gradient_weights = None + if "gradient_weights" in arguments: + ingr.gradient_weights = arguments["gradient_weights"] + + return ingr + @staticmethod def scale_between_0_and_1(values): """ @@ -101,7 +113,7 @@ def scale_between_0_and_1(values): return (values - min_value) / (max_value - min_value) @staticmethod - def get_combined_gradient_weight(gradient_list): + def get_combined_gradient_weight(gradient_list, gradient_weights=None): """ Combine the gradient weights @@ -115,11 +127,13 @@ def get_combined_gradient_weight(gradient_list): numpy.ndarray the combined gradient weight """ + assert len(gradient_list) > 1, "Need at least two gradients to combine" + weight_list = numpy.zeros((len(gradient_list), len(gradient_list[0].weight))) for i in range(len(gradient_list)): weight_list[i] = Gradient.scale_between_0_and_1(gradient_list[i].weight) - combined_weight = numpy.mean(weight_list, axis=0) + combined_weight = numpy.average(weight_list, axis=0, weights=gradient_weights) combined_weight = Gradient.scale_between_0_and_1(combined_weight) return combined_weight @@ -143,8 +157,8 @@ def pick_point_from_weight(weight, points): the index of the picked point """ weights_to_use = numpy.take(weight, points) - weights_to_use = Gradient.scale_between_0_and_1(weights_to_use) weights_to_use[numpy.isnan(weights_to_use)] = 0 + weights_to_use = Gradient.scale_between_0_and_1(weights_to_use) point_probabilities = weights_to_use / numpy.sum(weights_to_use) @@ -176,13 +190,16 @@ def pick_point_for_ingredient(ingr, allIngrPts, all_gradients): if isinstance(ingr.gradient, list): if len(ingr.gradient) > 1: if not hasattr(ingr, "combined_weight"): - gradient_list = [ - gradient - for gradient_name, gradient in all_gradients.items() - if gradient_name in ingr.gradient - ] + gradient_list = [] + for gradient_name in ingr.gradient: + if gradient_name not in all_gradients: + raise ValueError( + f"Gradient {gradient_name} not found in gradient list" + ) + gradient_list.append(all_gradients[gradient_name]) + combined_weight = Gradient.get_combined_gradient_weight( - gradient_list + gradient_list, ingr.gradient_weights ) ingr.combined_weight = combined_weight diff --git a/cellpack/autopack/ingredient/Ingredient.py b/cellpack/autopack/ingredient/Ingredient.py index 1194ea5e2..78e6d677e 100644 --- a/cellpack/autopack/ingredient/Ingredient.py +++ b/cellpack/autopack/ingredient/Ingredient.py @@ -161,6 +161,7 @@ class Ingredient(Agent): "distance_function", "force_random", "gradient", + "gradient_weights", "is_attractor", "max_jitter", "molarity", @@ -199,6 +200,7 @@ def __init__( distance_function=None, force_random=False, # avoid any binding gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, @@ -231,6 +233,7 @@ def __init__( distance_function=distance_function, force_random=force_random, gradient=gradient, + gradient_weights=gradient_weights, is_attractor=is_attractor, overwrite_distance_function=overwrite_distance_function, packing_mode=packing_mode, @@ -407,6 +410,38 @@ def validate_ingredient_info(ingredient_info): ingredient_info["size_options"] ) + # check if gradient information is entered correctly + if "gradient" in ingredient_info: + if not isinstance(ingredient_info["gradient"], (list, str)): + raise Exception( + ( + f"Invalid gradient: {ingredient_info['gradient']} " + f"for ingredient {ingredient_info['name']}" + ) + ) + if ( + ingredient_info["gradient"] == "" + or ingredient_info["gradient"] == "None" + ): + raise Exception( + f"Missing gradient for ingredient {ingredient_info['name']}" + ) + + # if multiple gradients are provided with weights, check if weights are correct + if isinstance(ingredient_info["gradient"], list): + if "gradient_weights" in ingredient_info: + # check if gradient_weights are missing + if not isinstance(ingredient_info["gradient_weights"], list): + raise Exception( + f"Invalid gradient weights for ingredient {ingredient_info['name']}" + ) + if len(ingredient_info["gradient"]) != len( + ingredient_info["gradient_weights"] + ): + raise Exception( + f"Missing gradient weights for ingredient {ingredient_info['name']}" + ) + return ingredient_info def reset(self): diff --git a/cellpack/autopack/ingredient/agent.py b/cellpack/autopack/ingredient/agent.py index 8ae82dd2c..9cf497e2b 100644 --- a/cellpack/autopack/ingredient/agent.py +++ b/cellpack/autopack/ingredient/agent.py @@ -11,7 +11,8 @@ def __init__( distance_expression=None, distance_function=None, force_random=False, # avoid any binding - gradient="", + gradient=None, + gradient_weights=None, is_attractor=False, overwrite_distance_function=True, # overWrite packing_mode="random", @@ -42,6 +43,7 @@ def __init__( self.distance_expression = distance_expression self.overwrite_distance_function = overwrite_distance_function self.gradient = gradient + self.gradient_weights = gradient_weights self.cb = None self.radii = None self.recipe = None # weak ref to recipe diff --git a/cellpack/autopack/ingredient/grow.py b/cellpack/autopack/ingredient/grow.py index 07c6fabdb..a0b7a4c9b 100644 --- a/cellpack/autopack/ingredient/grow.py +++ b/cellpack/autopack/ingredient/grow.py @@ -50,6 +50,7 @@ def __init__( cutoff_boundary=1.0, cutoff_surface=0.5, gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), length=10.0, diff --git a/cellpack/autopack/ingredient/multi_cylinder.py b/cellpack/autopack/ingredient/multi_cylinder.py index cdca69364..58012ec17 100644 --- a/cellpack/autopack/ingredient/multi_cylinder.py +++ b/cellpack/autopack/ingredient/multi_cylinder.py @@ -30,6 +30,7 @@ def __init__( distance_function=None, force_random=False, # avoid any binding gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, diff --git a/cellpack/autopack/ingredient/multi_sphere.py b/cellpack/autopack/ingredient/multi_sphere.py index 132569341..6ba1c5447 100644 --- a/cellpack/autopack/ingredient/multi_sphere.py +++ b/cellpack/autopack/ingredient/multi_sphere.py @@ -23,6 +23,7 @@ def __init__( cutoff_boundary=None, cutoff_surface=None, gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, diff --git a/cellpack/autopack/ingredient/single_cube.py b/cellpack/autopack/ingredient/single_cube.py index 2d0f65be9..c8612fd60 100644 --- a/cellpack/autopack/ingredient/single_cube.py +++ b/cellpack/autopack/ingredient/single_cube.py @@ -27,6 +27,7 @@ def __init__( distance_function=None, force_random=False, # avoid any binding gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, diff --git a/cellpack/autopack/ingredient/single_cylinder.py b/cellpack/autopack/ingredient/single_cylinder.py index 01c003a2e..eb87e73a6 100644 --- a/cellpack/autopack/ingredient/single_cylinder.py +++ b/cellpack/autopack/ingredient/single_cylinder.py @@ -30,7 +30,8 @@ def __init__( distance_expression=None, distance_function=None, force_random=False, # avoid any binding - gradient="", + gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, diff --git a/cellpack/autopack/ingredient/single_sphere.py b/cellpack/autopack/ingredient/single_sphere.py index bcc35999c..d0e2b356c 100644 --- a/cellpack/autopack/ingredient/single_sphere.py +++ b/cellpack/autopack/ingredient/single_sphere.py @@ -28,6 +28,7 @@ def __init__( distance_function=None, force_random=False, # avoid any binding gradient=None, + gradient_weights=None, is_attractor=False, max_jitter=(1, 1, 1), molarity=0.0, diff --git a/cellpack/tests/recipes/v2/test_combined_gradient.json b/cellpack/tests/recipes/v2/test_combined_gradient.json index aebccbf71..e8776872f 100644 --- a/cellpack/tests/recipes/v2/test_combined_gradient.json +++ b/cellpack/tests/recipes/v2/test_combined_gradient.json @@ -88,6 +88,10 @@ "gradient": [ "X_gradient", "Y_gradient" + ], + "gradient_weights": [ + 70, + 30 ] } }, diff --git a/cellpack/tests/test_ingredient.py b/cellpack/tests/test_ingredient.py index 2e5b6344a..eea5d204e 100644 --- a/cellpack/tests/test_ingredient.py +++ b/cellpack/tests/test_ingredient.py @@ -65,6 +65,44 @@ }, "Missing option 'min' for uniform distribution", ), + ( + { + "name": "test", + "type": "single_sphere", + "count": 1, + "gradient": 3, + }, + "Invalid gradient: 3 for ingredient test", + ), + ( + { + "name": "test", + "type": "single_sphere", + "count": 1, + "gradient": "", + }, + "Missing gradient for ingredient test", + ), + ( + { + "name": "test", + "type": "single_sphere", + "count": 1, + "gradient": ["gradient_1", "gradient_2"], + "gradient_weights": 0.5, + }, + "Invalid gradient weights for ingredient test", + ), + ( + { + "name": "test", + "type": "single_sphere", + "count": 1, + "gradient": ["gradient_1", "gradient_2", "gradient_3"], + "gradient_weights": [0.5, 0.5], + }, + "Missing gradient weights for ingredient test", + ), ], ) def test_validate_ingredient_info(ingredient_info, output):