diff --git a/src/linkml_map/datamodel/transformer_model.py b/src/linkml_map/datamodel/transformer_model.py
index 159a391..ac0d032 100644
--- a/src/linkml_map/datamodel/transformer_model.py
+++ b/src/linkml_map/datamodel/transformer_model.py
@@ -233,6 +233,7 @@ class TransformationSpecification(SpecificationComponent):
             }
         },
     )
+    copy_directives: Optional[Dict[str, CopyDirective]] = Field(default_factory=dict)
     description: Optional[str] = Field(
         None,
         description="""description of the specification component""",
diff --git a/src/linkml_map/datamodel/transformer_model.yaml b/src/linkml_map/datamodel/transformer_model.yaml
index d573cb6..eaf6eeb 100644
--- a/src/linkml_map/datamodel/transformer_model.yaml
+++ b/src/linkml_map/datamodel/transformer_model.yaml
@@ -81,6 +81,10 @@ classes:
         multivalued: true
         inlined: true
         slot_uri: sh:declare
+      copy_directives:
+        range: CopyDirective
+        multivalued: true
+        inlined: true
       source_schema:
         description: name of the schema that describes the source (input) objects
       target_schema:
diff --git a/src/linkml_map/inference/schema_mapper.py b/src/linkml_map/inference/schema_mapper.py
index 966baab..bc617f6 100644
--- a/src/linkml_map/inference/schema_mapper.py
+++ b/src/linkml_map/inference/schema_mapper.py
@@ -51,6 +51,138 @@ class SchemaMapper:
 
     slot_info: Dict[Tuple[str, str], Any] = field(default_factory=lambda: {})
 
+    @staticmethod
+    def _copy_directive_values(copy_directives):
+        """
+        Normalise copy directives to a list.
+
+        Directives arrive either as a mapping keyed by name (the inlined form
+        declared in the datamodel) or as a plain list; ``None`` and empty
+        collections yield an empty list.
+
+        :param copy_directives: dict, list or ``None``
+        :return: list of copy directives, possibly empty
+        """
+        if not copy_directives:
+            return []
+        if isinstance(copy_directives, dict):
+            return list(copy_directives.values())
+        return list(copy_directives)
+
+    def _copy_dict(
+        self,
+        copy_directive: CopyDirective,
+        src_elements,
+        tgt_elements,
+    ):
+        """
+        Apply one copy directive to a name-keyed collection of elements.
+
+        The flags are applied in a fixed order: ``copy_all``, ``exclude``,
+        ``exclude_all``, ``include``. ``tgt_elements`` is modified in place and
+        the copied values are shared with ``src_elements`` (no deep copy).
+
+        :param copy_directive: directive selecting the elements to copy
+        :param src_elements: dict of source elements keyed by name
+        :param tgt_elements: dict of target elements keyed by name
+        """
+        if copy_directive.copy_all:
+            tgt_elements.update(src_elements)
+        if copy_directive.exclude:
+            for element in src_elements:
+                if element in copy_directive.exclude:
+                    # pop() rather than del: without copy_all the element
+                    # was never copied, and del would raise KeyError
+                    tgt_elements.pop(element, None)
+        if copy_directive.exclude_all:
+            tgt_elements.clear()
+        if copy_directive.include:
+            for element in copy_directive.include:
+                if element in src_elements:
+                    tgt_elements[element] = src_elements[element]
+
+    def _copy_list(
+        self,
+        copy_directive: CopyDirective,
+        src_elements,
+        tgt_elements,
+    ):
+        """
+        Apply one copy directive to a list of element names.
+
+        Mirrors ``_copy_dict`` for list-valued collections such as class
+        slots. ``tgt_elements`` is modified in place and a name is never added
+        twice.
+
+        :param copy_directive: directive selecting the elements to copy
+        :param src_elements: list of source element names
+        :param tgt_elements: list of target element names
+        """
+        if copy_directive.copy_all:
+            for element in src_elements:
+                if element not in tgt_elements:
+                    tgt_elements.append(element)
+        if copy_directive.exclude:
+            # drop only the excluded names; rebuilding the list avoids the
+            # ValueError remove() raises for a name that was never copied
+            tgt_elements[:] = [
+                element
+                for element in tgt_elements
+                if not (element in src_elements and element in copy_directive.exclude)
+            ]
+        if copy_directive.exclude_all:
+            tgt_elements.clear()
+        if copy_directive.include:
+            for element in copy_directive.include:
+                if element in src_elements and element not in tgt_elements:
+                    tgt_elements.append(element)
+
+    def _copy_schema(
+        self,
+        copy_directives: list[CopyDirective],
+        source: SchemaDefinition,
+        target: SchemaDefinition,
+    ) -> SchemaDefinition:
+        """
+        Copy classes, slots and enums from source to target schema.
+
+        :param copy_directives: directives as a list, or a dict keyed by name
+        :param source: schema to copy elements from
+        :param target: schema to copy elements into (modified in place)
+        :return: the target schema
+        """
+        for copy_directive in self._copy_directive_values(copy_directives):
+            for element_type in ["classes", "slots", "enums"]:
+                if not hasattr(source, element_type):
+                    continue
+                src_elements = getattr(source, element_type)
+                tgt_elements = getattr(target, element_type)
+                self._copy_dict(copy_directive, src_elements, tgt_elements)
+        return target
+
+    def _copy_class(
+        self,
+        copy_directives: list[CopyDirective],
+        source: ClassDefinition,
+        target: ClassDefinition,
+    ) -> ClassDefinition:
+        """
+        Copy attributes and slots from source to target class.
+
+        :param copy_directives: directives as a list, or a dict keyed by name
+        :param source: class to copy elements from
+        :param target: class to copy elements into (modified in place)
+        :return: the target class
+        """
+        for copy_directive in self._copy_directive_values(copy_directives):
+            if hasattr(source, "attributes"):
+                # attributes are a dict keyed by attribute name
+                self._copy_dict(copy_directive, source.attributes, target.attributes)
+            if hasattr(source, "slots"):
+                # slots are a list of slot names
+                self._copy_list(copy_directive, source.slots, target.slots)
+        return target
+
     def derive_schema(
         self,
         specification: Optional[TransformationSpecification] = None,
@@ -73,6 +205,12 @@ def derive_schema(
         if target_schema_name is None:
             target_schema_name = source_schema.name + suffix
         target_schema = SchemaDefinition(id=target_schema_id, name=target_schema_name)
+        if hasattr(specification, "copy_directives"):
+            target_schema = self._copy_schema(
+                specification.copy_directives,
+                source_schema,
+                target_schema,
+            )
         for im in source_schema.imports:
             target_schema.imports.append(im)
         for prefix in source_schema.prefixes.values():
@@ -112,6 +250,12 @@ def _derive_class(self, class_derivation: ClassDerivation) -> ClassDefinition:
         target_class.slots = []
         target_class.attributes = {}
         target_class.slot_usage = {}
+        if hasattr(class_derivation, "copy_directives"):
+            target_class = self._copy_class(
+                class_derivation.copy_directives,
+                source_class,
+                target_class,
+            )
         for slot_derivation in class_derivation.slot_derivations.values():
             slot_definition = self._derive_slot(slot_derivation)
             target_class.attributes[slot_definition.name] = slot_definition
diff --git a/tests/test_schema_mapper/test_schema_mapper.py b/tests/test_schema_mapper/test_schema_mapper.py
index 000314d..86e22e3 100644
--- a/tests/test_schema_mapper/test_schema_mapper.py
+++ b/tests/test_schema_mapper/test_schema_mapper.py
@@ -6,6 +6,7 @@
 
 from linkml_map.datamodel.transformer_model import (
     ClassDerivation,
+    CopyDirective,
     SlotDerivation,
     TransformationSpecification,
 )
@@ -186,6 +187,200 @@ def test_rewire(self):
         self.assertEqual(["tr_salary"], list(emp.attributes.keys()))
         # self.assertEqual("Person", emp.is_a)
 
+    def test_full_copy_specification(self):
+        """tests copy isomorphism"""
+        tr = self.mapper
+        copy_all_directive = {"*": CopyDirective(element_name="*", copy_all=True)}
+        specification = TransformationSpecification(id="test", copy_directives=copy_all_directive)
+        source_schema = tr.source_schemaview.schema
+
+        target_schema = tr.derive_schema(specification)
+        # classes, slots and enums must be exactly the same
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.classes), yaml_dumper.dumps(target_schema.classes)
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.slots), yaml_dumper.dumps(target_schema.slots)
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.enums), yaml_dumper.dumps(target_schema.enums)
+        )
+
+    def test_partial_copy_specification(self):
+        """tests copy isomorphism excluding derivations"""
+        tr = self.mapper
+        copy_all_directive = {"*": CopyDirective(element_name="*", copy_all=True)}
+        specification = TransformationSpecification(id="test", copy_directives=copy_all_directive)
+        source_schema = tr.source_schemaview.schema
+
+        derivations = [
+            ClassDerivation(name="Agent", populated_from="Person"),
+        ]
+        for derivation in derivations:
+            specification.class_derivations[derivation.name] = derivation
+        target_schema = tr.derive_schema(specification)
+        # classes must be the same with addition
+        for schema_class in source_schema.classes.keys():
+            self.assertIn(
+                schema_class,
+                target_schema.classes.keys(),
+                f"Class '{schema_class}' is missing in target",
+            )
+        self.assertIn(
+            "Agent", target_schema.classes.keys(), "Derived class 'Agent' is missing in target"
+        )
+        # slots and enums must be exactly the same
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.slots), yaml_dumper.dumps(target_schema.slots)
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.enums), yaml_dumper.dumps(target_schema.enums)
+        )
+
+    def test_full_copy_class(self):
+        """tests copy isomorphism with class derivation"""
+        tr = self.mapper
+        copy_all_directive = {"*": CopyDirective(element_name="*", copy_all=True)}
+        specification = TransformationSpecification(id="test", copy_directives=copy_all_directive)
+        source_schema = tr.source_schemaview.schema
+
+        derivations = [
+            ClassDerivation(
+                name="Agent", populated_from="Person", copy_directives=copy_all_directive
+            ),
+        ]
+        for derivation in derivations:
+            specification.class_derivations[derivation.name] = derivation
+        target_schema = tr.derive_schema(specification)
+        # classes must be the same with addition
+        for schema_class in source_schema.classes.keys():
+            self.assertIn(
+                schema_class,
+                target_schema.classes.keys(),
+                f"Class '{schema_class}' is missing in target",
+            )
+        self.assertIn(
+            "Agent", target_schema.classes.keys(), "Derived class 'Agent' is missing in target"
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.classes["Person"].slots),
+            yaml_dumper.dumps(target_schema.classes["Agent"].slots),
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.classes["Person"].attributes),
+            yaml_dumper.dumps(target_schema.classes["Agent"].attributes),
+        )
+        # slots and enums must be exactly the same
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.slots), yaml_dumper.dumps(target_schema.slots)
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.enums), yaml_dumper.dumps(target_schema.enums)
+        )
+
+    def test_copy_blacklisting(self):
+        """tests copy on a blacklist approach"""
+        tr = self.mapper
+        blacklist = ["Person"]
+        copy_all_directive = {
+            "*": CopyDirective(element_name="*", copy_all=True, exclude=blacklist)
+        }
+        specification = TransformationSpecification(id="test", copy_directives=copy_all_directive)
+        source_schema = tr.source_schemaview.schema
+
+        derivations = [
+            ClassDerivation(name="Agent", populated_from="Person"),
+        ]
+        for derivation in derivations:
+            specification.class_derivations[derivation.name] = derivation
+        target_schema = tr.derive_schema(specification)
+        # classes must be the same with addition
+        for schema_class in source_schema.classes.keys():
+            if schema_class in blacklist:
+                self.assertNotIn(
+                    schema_class,
+                    target_schema.classes.keys(),
+                    f"Class '{schema_class}' should not be in target",
+                )
+            else:
+                self.assertIn(
+                    schema_class,
+                    target_schema.classes.keys(),
+                    f"Class '{schema_class}' is missing in target",
+                )
+        self.assertIn(
+            "Agent", target_schema.classes.keys(), "Derived class 'Agent' is missing in target"
+        )
+        # slots and enums must be exactly the same
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.slots), yaml_dumper.dumps(target_schema.slots)
+        )
+        self.assertEqual(
+            yaml_dumper.dumps(source_schema.enums), yaml_dumper.dumps(target_schema.enums)
+        )
+
+    def test_copy_whitelisting(self):
+        """tests copy on a whitelist approach"""
+        tr = self.mapper
+        whitelist = ["NamedThing"]
+        whitelist_directive = {
+            "Whitelist": CopyDirective(
+                element_name="*", copy_all=True, exclude_all=True, include=whitelist
+            )
+        }
+        specification = TransformationSpecification(id="test", copy_directives=whitelist_directive)
+        source_schema = tr.source_schemaview.schema
+
+        derivations = [
+            ClassDerivation(name="Agent", populated_from="Person"),
+        ]
+        for derivation in derivations:
+            specification.class_derivations[derivation.name] = derivation
+        target_schema = tr.derive_schema(specification)
+        # classes, slots and enums must have only what explicitly included
+        for schema_class in source_schema.classes.keys():
+            if schema_class in whitelist:
+                self.assertIn(
+                    schema_class,
+                    target_schema.classes.keys(),
+                    f"Class '{schema_class}' is missing in target",
+                )
+            else:
+                self.assertNotIn(
+                    schema_class,
+                    target_schema.classes.keys(),
+                    f"Class '{schema_class}' should not be in target",
+                )
+        self.assertIn(
+            "Agent", target_schema.classes.keys(), "Derived class 'Agent' is missing in target"
+        )
+        for schema_slot in source_schema.slots.keys():
+            if schema_slot in whitelist:
+                self.assertIn(
+                    schema_slot,
+                    target_schema.slots.keys(),
+                    f"Slot '{schema_slot}' is missing in target",
+                )
+            else:
+                self.assertNotIn(
+                    schema_slot,
+                    target_schema.slots.keys(),
+                    f"Slot '{schema_slot}' should not be in target",
+                )
+        for schema_enum in source_schema.enums.keys():
+            if schema_enum in whitelist:
+                self.assertIn(
+                    schema_enum,
+                    target_schema.enums.keys(),
+                    f"Enum '{schema_enum}' is missing in target",
+                )
+            else:
+                self.assertNotIn(
+                    schema_enum,
+                    target_schema.enums.keys(),
+                    f"Enum '{schema_enum}' should not be in target",
+                )
+
 
 if __name__ == "__main__":
     unittest.main()