Feature/run inherited objects #198

Merged
115 changes: 67 additions & 48 deletions cellpack/autopack/DBRecipeHandler.py
@@ -52,6 +52,7 @@ class CompositionDoc(DataDoc):
 
     SHALLOW_MATCH = ["object", "count", "molarity"]
     DEFAULT_VALUES = {"object": None, "count": None, "regions": {}, "molarity": None}
+    KEY_TO_DICT_MAPPING = {"gradient": "gradients", "inherit": "objects"}
 
     def __init__(
         self,
@@ -79,12 +80,10 @@ def as_dict(self):
         return data
 
     @staticmethod
-    def get_gradient_reference(downloaded_data, db):
-        if "gradient" in downloaded_data and db.is_reference(
-            downloaded_data["gradient"]
-        ):
-            gradient_key = downloaded_data["gradient"]
-            downloaded_data["gradient"], _ = db.get_doc_by_ref(gradient_key)
+    def get_reference_in_obj(downloaded_data, db):
+        for key in CompositionDoc.KEY_TO_DICT_MAPPING:
+            if key in downloaded_data and db.is_reference(downloaded_data[key]):
+                downloaded_data[key], _ = db.get_doc_by_ref(downloaded_data[key])
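A minimal, runnable sketch of what the new `KEY_TO_DICT_MAPPING`-driven lookup above does: any field named in the mapping (`gradient` or `inherit`) that still holds a database reference is swapped for the referenced document before the object is used. `FakeDB` and the `firebase:` path format are invented stand-ins for illustration, not the real Firebase handler.

```python
class FakeDB:
    """Stand-in with only the two calls the helper relies on."""

    def __init__(self, docs):
        self.docs = docs

    def is_reference(self, value):
        # Hypothetical convention: references look like "firebase:collection/id".
        return isinstance(value, str) and value.startswith("firebase:")

    def get_doc_by_ref(self, ref):
        return self.docs[ref], ref


KEY_TO_DICT_MAPPING = {"gradient": "gradients", "inherit": "objects"}


def get_reference_in_obj(downloaded_data, db):
    # Same loop as the PR: resolve every mapped key that is still a reference.
    for key in KEY_TO_DICT_MAPPING:
        if key in downloaded_data and db.is_reference(downloaded_data[key]):
            downloaded_data[key], _ = db.get_doc_by_ref(downloaded_data[key])


db = FakeDB(
    {
        "firebase:gradients/grad_a": {"name": "grad_a", "mode": "surface"},
        "firebase:objects/base_sphere": {"name": "base_sphere", "radius": 5},
    }
)
obj = {
    "name": "child_sphere",
    "gradient": "firebase:gradients/grad_a",
    "inherit": "firebase:objects/base_sphere",
}
get_reference_in_obj(obj, db)
print(obj["gradient"]["name"], obj["inherit"]["name"])  # grad_a base_sphere
```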

     @staticmethod
     def get_reference_data(key_or_dict, db):
@@ -96,14 +95,14 @@ def get_reference_data(key_or_dict, db):
         if DataDoc.is_key(key_or_dict) and db.is_reference(key_or_dict):
             key = key_or_dict
             downloaded_data, _ = db.get_doc_by_ref(key)
-            CompositionDoc.get_gradient_reference(downloaded_data, db)
+            CompositionDoc.get_reference_in_obj(downloaded_data, db)
             return downloaded_data, None
         elif key_or_dict and isinstance(key_or_dict, dict):
             object_dict = key_or_dict
             if "object" in object_dict and db.is_reference(object_dict["object"]):
                 key = object_dict["object"]
                 downloaded_data, _ = db.get_doc_by_ref(key)
-                CompositionDoc.get_gradient_reference(downloaded_data, db)
+                CompositionDoc.get_reference_in_obj(downloaded_data, db)
                 return downloaded_data, key
         return {}, None
 
@@ -141,6 +140,15 @@ def gradient_list_to_dict(prep_recipe_data):
             gradient_dict[gradient["name"]] = gradient
         prep_recipe_data["gradients"] = gradient_dict
 
+    def resolve_object_data(self, object_data, prep_recipe_data):
+        """
+        Resolve the object data from the local data.
+        """
+        for key in CompositionDoc.KEY_TO_DICT_MAPPING:
+            if key in object_data and isinstance(object_data[key], str):
+                target_dict = CompositionDoc.KEY_TO_DICT_MAPPING[key]
+                object_data[key] = prep_recipe_data[target_dict][object_data[key]]
+
     def resolve_local_regions(self, local_data, recipe_data, db):
         """
         Recursively resolves the regions of a composition from local data.
@@ -156,12 +164,7 @@ def resolve_local_regions(self, local_data, recipe_data, db):
         else:
             key_name = local_data["object"]["name"]
             local_data["object"] = prep_recipe_data["objects"][key_name]
-        if "gradient" in local_data["object"] and isinstance(
-            local_data["object"]["gradient"], str
-        ):
-            local_data["object"]["gradient"] = prep_recipe_data["gradients"][
-                local_data["object"]["gradient"]
-            ]
+        self.resolve_object_data(local_data["object"], prep_recipe_data)
         for region_name in local_data["regions"]:
             for index, key_or_dict in enumerate(local_data["regions"][region_name]):
                 if not DataDoc.is_key(key_or_dict):
@@ -174,12 +177,9 @@ def resolve_local_regions(self, local_data, recipe_data, db):
                         local_data["regions"][region_name][index][
                             "object"
                         ] = prep_recipe_data["objects"][obj_item["name"]]
-                    # replace gradient reference with gradient data
+                    # replace reference in obj with actual data
                     obj_data = local_data["regions"][region_name][index]["object"]
-                    if "gradient" in obj_data and isinstance(obj_data["gradient"], str):
-                        local_data["regions"][region_name][index]["object"][
-                            "gradient"
-                        ] = prep_recipe_data["gradients"][obj_data["gradient"]]
+                    self.resolve_object_data(obj_data, prep_recipe_data)
                 else:
                     comp_name = local_data["regions"][region_name][index]
                     prep_comp_data = prep_recipe_data["composition"][comp_name]
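A small standalone sketch of what `resolve_object_data` does in the region resolution above: a string value under `gradient` or `inherit` is treated as a name and replaced with the full entry from the prepared recipe data. The recipe content below is invented for illustration.

```python
KEY_TO_DICT_MAPPING = {"gradient": "gradients", "inherit": "objects"}


def resolve_object_data(object_data, prep_recipe_data):
    # Same logic as CompositionDoc.resolve_object_data, shown without the class.
    for key, target_dict in KEY_TO_DICT_MAPPING.items():
        if key in object_data and isinstance(object_data[key], str):
            object_data[key] = prep_recipe_data[target_dict][object_data[key]]


# Hypothetical prepared recipe data.
prep_recipe_data = {
    "gradients": {"grad_a": {"name": "grad_a", "mode": "surface"}},
    "objects": {"base_sphere": {"name": "base_sphere", "radius": 5}},
}
obj = {"name": "child_sphere", "gradient": "grad_a", "inherit": "base_sphere"}
resolve_object_data(obj, prep_recipe_data)
print(obj["gradient"]["mode"], obj["inherit"]["radius"])  # surface 5
```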
@@ -336,6 +336,12 @@ def convert_representation(doc, db):
                     ] = ObjectDoc.convert_positions_in_representation(position_value)
         return convert_doc
 
+    @staticmethod
+    def _object_contains_grad_or_inherit(obj_data):
+        return (
+            "gradient" in obj_data and isinstance(obj_data["gradient"], dict)
+        ) or "inherit" in obj_data
+
     def should_write(self, db):
         docs = db.get_doc_by_name("objects", self.name)
         if docs and len(docs) >= 1:
@@ -377,6 +383,7 @@ def __init__(self, db_handler):
         self.objects_to_path_map = {}
         self.comp_to_path_map = {}
         self.grad_to_path_map = {}
+        self.objects_with_inherit_key = []
 
     @staticmethod
     def prep_data_for_db(data):
@@ -435,24 +442,38 @@ def upload_gradients(self, gradients):
             _, grad_path = self.upload_data("gradients", gradient_doc.settings)
             self.grad_to_path_map[gradient_name] = grad_path
 
+    def upload_single_object(self, obj_name, obj_data):
+        # replace gradient name with path to check if gradient exists in db
+        if "gradient" in obj_data[obj_name]:
+            grad_name = obj_data[obj_name]["gradient"]
+            obj_data[obj_name]["gradient"] = self.grad_to_path_map[grad_name]
+        object_doc = ObjectDoc(name=obj_name, settings=obj_data[obj_name])
+        _, doc_id = object_doc.should_write(self.db)
+        if doc_id:
+            print(f"objects/{object_doc.name} is already in firestore")
+            obj_path = self.db.create_path("objects", doc_id)
+            self.objects_to_path_map[obj_name] = obj_path
+        else:
+            _, obj_path = self.upload_data("objects", object_doc.as_dict())
+            self.objects_to_path_map[obj_name] = obj_path
+
     def upload_objects(self, objects):
+        # modify a copy of objects to avoid key error when resolving local regions
+        modify_objects = copy.deepcopy(objects)
         for obj_name in objects:
             objects[obj_name]["name"] = obj_name
-            # modify a copy of objects to avoid key error when resolving local regions
-            modify_objects = copy.deepcopy(objects)
-            # replace gradient name with path to check if gradient exists in db
-            if "gradient" in modify_objects[obj_name]:
-                grad_name = modify_objects[obj_name]["gradient"]
-                modify_objects[obj_name]["gradient"] = self.grad_to_path_map[grad_name]
-            object_doc = ObjectDoc(name=obj_name, settings=modify_objects[obj_name])
-            _, doc_id = object_doc.should_write(self.db)
-            if doc_id:
-                print(f"objects/{object_doc.name} is already in firestore")
-                obj_path = self.db.create_path("objects", doc_id)
-                self.objects_to_path_map[obj_name] = obj_path
+            if "inherit" not in objects[obj_name]:
+                self.upload_single_object(obj_name, modify_objects)
             else:
-                _, obj_path = self.upload_data("objects", object_doc.as_dict())
-                self.objects_to_path_map[obj_name] = obj_path
+                self.objects_with_inherit_key.append(obj_name)
+
+        # upload objs having `inherit` key only after all their base objs are uploaded
+        for obj_name in self.objects_with_inherit_key:
+            inherited_from = objects[obj_name]["inherit"]
+            modify_objects[obj_name]["inherit"] = self.objects_to_path_map[
+                inherited_from
+            ]
+            self.upload_single_object(obj_name, modify_objects)
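The two-pass ordering above matters because a child object can only store a database path for its `inherit` key once its base object has been uploaded. A rough, self-contained sketch of that control flow; the uploader, paths, and object names are invented stand-ins, not the real `DBUploader`.

```python
import copy

# Invented input: "membrane_child" inherits from "membrane_base".
objects = {
    "membrane_base": {"radius": 10},
    "membrane_child": {"inherit": "membrane_base", "radius": 12},
}

objects_to_path_map = {}
objects_with_inherit_key = []
modify_objects = copy.deepcopy(objects)


def upload_single_object(obj_name, obj_data):
    # Stand-in for the real upload: pretend the write happened and record a path.
    objects_to_path_map[obj_name] = f"firebase:objects/{obj_name}"


# Pass 1: upload everything that has no `inherit` key.
for obj_name in objects:
    objects[obj_name]["name"] = obj_name
    if "inherit" not in objects[obj_name]:
        upload_single_object(obj_name, modify_objects)
    else:
        objects_with_inherit_key.append(obj_name)

# Pass 2: every base now has a path, so the `inherit` name can be replaced
# with that path before the child itself is uploaded.
for obj_name in objects_with_inherit_key:
    inherited_from = objects[obj_name]["inherit"]
    modify_objects[obj_name]["inherit"] = objects_to_path_map[inherited_from]
    upload_single_object(obj_name, modify_objects)

print(modify_objects["membrane_child"]["inherit"])  # firebase:objects/membrane_base
```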

     def upload_compositions(self, compositions, recipe_to_save, recipe_data):
         references_to_update = {}
@@ -597,15 +618,17 @@ def collect_docs_by_id(self, collection, id):
 
     @staticmethod
     def _get_grad_and_obj(obj_data, obj_dict, grad_dict):
-        try:
-            grad_name = obj_data["gradient"]["name"]
-            obj_name = obj_data["name"]
-        except KeyError as e:
-            print(f"Missing keys in object: {e}")
-            return obj_dict, grad_dict
-
-        grad_dict[grad_name] = obj_data["gradient"]
-        obj_dict[obj_name]["gradient"] = grad_name
+        """
+        Collect gradient and inherited object data from the downloaded object data
+        return object data dict and gradient data dict with name as key
+        """
+        obj_name = obj_data["name"]
+        for key, target_dict in CompositionDoc.KEY_TO_DICT_MAPPING.items():
+            if key in obj_data:
+                item_name = obj_data[key]["name"]
+                target_dict = grad_dict if key == "gradient" else obj_dict
+                target_dict[item_name] = obj_data[key]
+                obj_dict[obj_name][key] = item_name
         return obj_dict, grad_dict
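A hedged example of what the rewritten `_get_grad_and_obj` produces (names invented): nested `gradient`/`inherit` dicts are lifted into the top-level gradient and object dicts, and the parent object keeps only the name.

```python
KEY_TO_DICT_MAPPING = {"gradient": "gradients", "inherit": "objects"}


def _get_grad_and_obj(obj_data, obj_dict, grad_dict):
    # Behaviorally the same split as the PR version, shown without the class.
    obj_name = obj_data["name"]
    for key in KEY_TO_DICT_MAPPING:
        if key in obj_data:
            item_name = obj_data[key]["name"]
            target_dict = grad_dict if key == "gradient" else obj_dict
            target_dict[item_name] = obj_data[key]
            obj_dict[obj_name][key] = item_name
    return obj_dict, grad_dict


# Invented downloaded object: gradient and base object arrive as full dicts.
downloaded = {
    "name": "child_sphere",
    "gradient": {"name": "grad_a", "mode": "surface"},
    "inherit": {"name": "base_sphere", "radius": 5},
}
objects, gradients = _get_grad_and_obj(downloaded, {"child_sphere": downloaded}, {})
print(objects["child_sphere"]["gradient"])  # grad_a
print(objects["child_sphere"]["inherit"])   # base_sphere
print(sorted(objects), sorted(gradients))   # ['base_sphere', 'child_sphere'] ['grad_a']
```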

     @staticmethod
@@ -626,9 +649,7 @@ def collect_and_sort_data(comp_data):
                 composition[comp_name]["object"] = comp_value["object"]["name"]
                 object_copy = copy.deepcopy(comp_value["object"])
                 objects[object_copy["name"]] = object_copy
-                if "gradient" in object_copy and isinstance(
-                    object_copy["gradient"], dict
-                ):
+                if ObjectDoc._object_contains_grad_or_inherit(object_copy):
                     objects, gradients = DBRecipeLoader._get_grad_and_obj(
                         object_copy, objects, gradients
                     )
@@ -645,9 +666,7 @@ def collect_and_sort_data(comp_data):
                         )
                         object_copy = copy.deepcopy(region_item["object"])
                         objects[object_copy["name"]] = object_copy
-                        if "gradient" in object_copy and isinstance(
-                            object_copy["gradient"], dict
-                        ):
+                        if ObjectDoc._object_contains_grad_or_inherit(object_copy):
                             objects, gradients = DBRecipeLoader._get_grad_and_obj(
                                 object_copy, objects, gradients
                             )
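For orientation, a rough sketch (invented names, not an exact Firebase payload) of the three dicts `collect_and_sort_data` is expected to hand back once `_object_contains_grad_or_inherit` routes objects carrying either a gradient dict or an `inherit` dict through `_get_grad_and_obj`:

```python
# Hypothetical result for a recipe whose only ingredient object carries both
# a gradient and an inherited base object.
objects = {
    "child_sphere": {"name": "child_sphere", "gradient": "grad_a", "inherit": "base_sphere"},
    "base_sphere": {"name": "base_sphere", "radius": 5},
}
gradients = {
    "grad_a": {"name": "grad_a", "mode": "surface"},
}
composition = {
    "space": {"regions": {"interior": ["child_sphere_comp"]}},
    "child_sphere_comp": {"object": "child_sphere", "count": 100},
}
```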
9 changes: 5 additions & 4 deletions cellpack/autopack/loaders/recipe_loader.py
@@ -166,7 +166,7 @@ def _migrate_version(self, old_recipe):
                 f"{old_recipe['format_version']} is not a format version we support"
             )
 
-    def _read(self):
+    def _read(self, resolve_inheritance=True):
         new_values, database_name = autopack.load_file(self.file_path, cache="recipes")
         if database_name == "firebase":
             objects, gradients, composition = DBRecipeLoader.collect_and_sort_data(
@@ -185,9 +185,10 @@ def _read(self):
 
         # TODO: request any external data before returning
         if "objects" in recipe_data:
-            recipe_data["objects"] = RecipeLoader.resolve_inheritance(
-                recipe_data["objects"]
-            )
+            if resolve_inheritance:
+                recipe_data["objects"] = RecipeLoader.resolve_inheritance(
+                    recipe_data["objects"]
+                )
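A simplified stand-in for `RecipeLoader.resolve_inheritance`, only to show what the new flag toggles; it merges a single level of inheritance and is not the real implementation. With the default `resolve_inheritance=True` the child is flattened before representations are built; with `False` (used by the upload path) the raw `inherit` key is kept so it can be stored as a reference.

```python
def resolve_inheritance_sketch(objects):
    # Children take any fields they do not define from their named base object.
    resolved = {}
    for name, obj in objects.items():
        if "inherit" in obj:
            base = objects[obj["inherit"]]
            resolved[name] = {**base, **{k: v for k, v in obj.items() if k != "inherit"}}
        else:
            resolved[name] = dict(obj)
    return resolved


objects = {
    "base_sphere": {"radius": 5, "color": [1, 0, 0]},
    "child_sphere": {"inherit": "base_sphere", "radius": 8},
}
print(resolve_inheritance_sketch(objects)["child_sphere"])  # {'radius': 8, 'color': [1, 0, 0]}
# _read(resolve_inheritance=False) would instead leave `inherit: "base_sphere"` untouched.
```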
             for _, obj in recipe_data["objects"].items():
                 reps = obj["representations"] if "representations" in obj else {}
                 obj["representations"] = Representations(
1 change: 1 addition & 0 deletions cellpack/autopack/writers/__init__.py
@@ -190,6 +190,7 @@ def save_as_simularium(self, env, all_ingr_as_array, compartments):
         ):
             autopack.helper.post_and_open_file(file_name)
 
+
     def save_Mixed_asJson(
         self,
         env,
2 changes: 1 addition & 1 deletion cellpack/bin/upload.py
@@ -19,7 +19,7 @@ def upload(
         # fetch the service key json file
         db_handler = FirebaseHandler()
         recipe_loader = RecipeLoader(recipe_path)
-        recipe_full_data = recipe_loader.recipe_data
+        recipe_full_data = recipe_loader._read(resolve_inheritance=False)
        recipe_meta_data = recipe_loader.get_only_recipe_metadata()
         recipe_db_handler = DBUploader(db_handler)
         recipe_db_handler.upload_recipe(recipe_meta_data, recipe_full_data)
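A short illustration (recipe content invented) of why the upload path now reads with `resolve_inheritance=False`: the raw `inherit` key survives into `recipe_full_data`, so `upload_objects` can upload the base object first and store the child with a path reference instead of a flattened copy.

```python
# What recipe_full_data is expected to look like when inheritance is NOT resolved.
recipe_full_data = {
    "name": "example_recipe",
    "objects": {
        "base_sphere": {"radius": 5},
        "child_sphere": {"inherit": "base_sphere", "radius": 8},
    },
}
assert "inherit" in recipe_full_data["objects"]["child_sphere"]
# With resolve_inheritance=True the child would already be flattened here and the
# inheritance relationship would be lost before it reaches the database.
```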