Skip to content

Commit

Permalink
Merge branches 'feature/weighted_gradient_combination' and 'main' of …
Browse files Browse the repository at this point in the history
…github.com:mesoscope/cellpack into feature/weighted_gradient_combination
  • Loading branch information
mogres committed Jul 8, 2024
2 parents fa36f88 + 12276f5 commit c386236
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
run: |
pytest --cov cellpack/tests/
- name: Upload codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4

lint:
runs-on: ubuntu-latest
Expand Down
11 changes: 8 additions & 3 deletions cellpack/autopack/Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,17 @@ def update_distance_distribution_dictionaries(

get_angles = False
if ingr.packing_mode == "gradient" and self.env.use_gradient:
if not isinstance(ingr.gradient, list):
if isinstance(ingr.gradient, list):
if len(ingr.gradient) > 1 or len(ingr.gradient) == 0:
self.center = center
else:
self.center = center = self.env.gradients[
ingr.gradient[0]
].mode_settings.get("center", center)
else:
self.center = center = self.env.gradients[
ingr.gradient
].mode_settings.get("center", center)
else:
self.center = center
get_angles = True

# get angles wrt gradient
Expand Down
25 changes: 3 additions & 22 deletions cellpack/autopack/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,28 +1700,9 @@ def getPointToDrop(
# get the most probable point using the gradient
# use the gradient weighted map and get mot probabl point
self.log.info("pick point from gradients %d", (len(allIngrPts)))
if isinstance(ingr.gradient, list) and len(ingr.gradient) > 1:
if not hasattr(ingr, "combined_weight"):
gradient_list = []
gradient_weights = []
for (gradient_name, gradient), gradient_weight in zip(
self.gradients.items(), ingr.gradient_weights
):
if gradient_name in ingr.gradient:
gradient_list.append(gradient)
gradient_weights.append(gradient_weight)

combined_weight = Gradient.get_combined_gradient_weight(
gradient_list, gradient_weights
)
ingr.combined_weight = combined_weight

ptInd = Gradient.pick_point_from_weight(
ingr.combined_weight, allIngrPts
)

else:
ptInd = self.gradients[ingr.gradient].pickPoint(allIngrPts)
ptInd = Gradient.pick_point_for_ingredient(
ingr, allIngrPts, self.gradients
)
else:
# pick a point randomly among free points
# random or uniform?
Expand Down
50 changes: 47 additions & 3 deletions cellpack/autopack/Gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def scale_between_0_and_1(values):
return (values - min_value) / (max_value - min_value)

@staticmethod
def get_combined_gradient_weight(gradient_list, gradient_weights=None):
def get_combined_gradient_weight(gradient_list):
"""
Combine the gradient weights
Expand All @@ -119,7 +119,7 @@ def get_combined_gradient_weight(gradient_list, gradient_weights=None):
for i in range(len(gradient_list)):
weight_list[i] = Gradient.scale_between_0_and_1(gradient_list[i].weight)

combined_weight = numpy.average(weight_list, axis=0, weights=gradient_weights)
combined_weight = numpy.mean(weight_list, axis=0)
combined_weight = Gradient.scale_between_0_and_1(combined_weight)

return combined_weight
Expand All @@ -143,15 +143,59 @@ def pick_point_from_weight(weight, points):
the index of the picked point
"""
weights_to_use = numpy.take(weight, points)
weights_to_use[numpy.isnan(weights_to_use)] = 0
weights_to_use = Gradient.scale_between_0_and_1(weights_to_use)
weights_to_use[numpy.isnan(weights_to_use)] = 0

point_probabilities = weights_to_use / numpy.sum(weights_to_use)

point = numpy.random.choice(points, p=point_probabilities)

return point

@staticmethod
def pick_point_for_ingredient(ingr, allIngrPts, all_gradients):
"""
Picks a point for an ingredient according to the gradient
Parameters
----------
ingr: Ingredient
the ingredient object
allIngrPts: numpy.ndarray
list of grid point indices
all_gradients: dict
dictionary of all gradient objects
Returns
----------
int
the index of the picked point
"""
if isinstance(ingr.gradient, list):
if len(ingr.gradient) > 1:
if not hasattr(ingr, "combined_weight"):
gradient_list = [
gradient
for gradient_name, gradient in all_gradients.items()
if gradient_name in ingr.gradient
]
combined_weight = Gradient.get_combined_gradient_weight(
gradient_list
)
ingr.combined_weight = combined_weight

ptInd = Gradient.pick_point_from_weight(
ingr.combined_weight, allIngrPts
)
else:
ptInd = all_gradients[ingr.gradient[0]].pickPoint(allIngrPts)
else:
ptInd = all_gradients[ingr.gradient].pickPoint(allIngrPts)

return ptInd

def get_center(self):
"""get the center of the gradient grid"""
center = [0.0, 0.0, 0.0]
Expand Down
10 changes: 0 additions & 10 deletions cellpack/autopack/loaders/recipe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,6 @@ def _sanitize_format_version(recipe_data):
format_version = recipe_data["format_version"]
return format_version

def get_only_recipe_metadata(self):
recipe_meta_data = {
"format_version": self.recipe_data["format_version"],
"version": self.recipe_data["version"],
"name": self.recipe_data["name"],
"bounding_box": self.recipe_data["bounding_box"],
"composition": {},
}
return recipe_meta_data

def _migrate_version(self, old_recipe):
converted = False
if old_recipe["format_version"] == "1.0":
Expand Down
5 changes: 4 additions & 1 deletion cellpack/autopack/upy/simularium/simularium_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ def add_grid_data_to_scene(self, incoming_name, positions, values, radius=0.5):

positions, values = self.sort_values(positions, values)

colormap = matplotlib.cm.Reds(values)
normalized_values = (values - np.min(values)) / (
np.max(values) - np.min(values)
)
colormap = matplotlib.cm.Reds(normalized_values)

for index, value in enumerate(values):
name = f"{incoming_name}#{value:.3f}"
Expand Down
21 changes: 20 additions & 1 deletion cellpack/bin/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@
from cellpack.autopack.loaders.recipe_loader import RecipeLoader


def get_recipe_metadata(loader):
"""
Extracts and returns essential metadata from a recipe for uploading
"""
try:
recipe_meta_data = {
"format_version": loader.recipe_data["format_version"],
"version": loader.recipe_data["version"],
"name": loader.recipe_data["name"],
"bounding_box": loader.recipe_data["bounding_box"],
"composition": {},
}
if "grid_file_path" in loader.recipe_data:
recipe_meta_data["grid_file_path"] = loader.recipe_data["grid_file_path"]
return recipe_meta_data
except KeyError as e:
sys.exit(f"Recipe metadata is missing. {e}")


def upload(
recipe_path,
db_id=DATABASE_IDS.FIREBASE,
Expand All @@ -23,7 +42,7 @@ def upload(
if FirebaseHandler._initialized:
recipe_loader = RecipeLoader(recipe_path)
recipe_full_data = recipe_loader._read(resolve_inheritance=False)
recipe_meta_data = recipe_loader.get_only_recipe_metadata()
recipe_meta_data = get_recipe_metadata(recipe_loader)
recipe_db_handler = DBUploader(db_handler)
recipe_db_handler.upload_recipe(recipe_meta_data, recipe_full_data)
else:
Expand Down

0 comments on commit c386236

Please sign in to comment.