Skip to content

Commit

Permalink
Merge pull request #157 from VirtualFlyBrain/dev
Browse files Browse the repository at this point in the history
making VFBTerms/VFBTerm hashable and comparable
  • Loading branch information
Robbie1977 authored Aug 27, 2024
2 parents 60bc7c7 + 8460fff commit b48271e
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/vfb_connect/schema/test/vfb_term_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ def test_vfbterms_get_colours_for_terms(self):
tp = terms.get_colours_for('types', take_first=True)
print(tp)
self.assertEqual(len(tp), 4)
tp = terms.get_colours_for('parents', take_first=True, verbose=True)
print(tp)
self.assertEqual(len(tp), 4)

if __name__ == "__main__":
unittest.main()
213 changes: 213 additions & 0 deletions src/vfb_connect/schema/vfb_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,92 @@ def get(self, key, default=None):
return getattr(self, key, default)

def __len__(self):
"""
Get the length of the term.
"""
return 1

def __eq__(self, other):
"""
Check if two terms are equal.
"""
if isinstance(other, VFBTerm):
return self.id == other.id
return False

def __hash__(self):
"""
Get the hash of the term.
"""
return hash(self.id)

def __str__(self):
"""
Get the string representation of the term.
"""
return self.name

def __lt__(self, other):
"""
Check if one term is less than another.
"""
if isinstance(other, VFBTerm):
return self.name < other.name
return False

def __gt__(self, other):
"""
Check if one term is greater than another.
"""
if isinstance(other, VFBTerm):
return self.name > other.name
return False

def __le__(self, other):
"""
Check if one term is less than or equal to another.
"""
if isinstance(other, VFBTerm):
return self.name <= other.name
return False

def __ge__(self, other):
"""
Check if one term is greater than or equal to another.
"""
if isinstance(other, VFBTerm):
return self.name >= other.name
return False

def __ne__(self, other):
"""
Check if two terms are not equal.
"""
return not self.__eq__(other)

def __contains__(self, item):
"""
Check if an item is in the term.
"""
return item in self.__dict__.keys()

def __eq__(self, value: object) -> bool:
"""
Check if two terms are equal.
"""
if isinstance(value, VFBTerm):
return self.id == value.id
if isinstance(value, str):
if self.id == value:
return True
if self.name == value:
return True
if self.core.label == value:
return True
if self.core.symbol == value:
return True
return False

def __add__(self, other):
if isinstance(other, VFBTerms):
combined_terms = [self.term] + other.terms
Expand Down Expand Up @@ -2889,6 +2973,62 @@ def append(self, vfb_term, verbose=False):
def __len__(self):
return len(self.terms)

def __eq__(self, other):
"""
Compare two VFBTerms objects for equality.
Two VFBTerms objects are considered equal if they contain the same set of term IDs.
:param other: The other VFBTerms object to compare.
:return: True if the two VFBTerms objects are equal, False otherwise.
"""
if not isinstance(other, VFBTerms):
if isinstance(other, list) and all(isinstance(term, VFBTerm) for term in other):
return set(self.get_ids()) == set([term.id for term in other])
if isinstance(other, list) and all(isinstance(term, str) for term in other):
if set(self.get_ids()) == set(other):
return True
if set(self.get_names()) == set(other):
return True
return False

# Compare the sets of IDs for equality
return set(self.get_ids()) == set(other.get_ids())

def __contains__(self, item):
"""
Check if a term is in the VFBTerms object.
"""
if isinstance(item, VFBTerm):
return item.id in self.get_ids()
if isinstance(item, str):
if item in self.get_ids():
return True
if item in self.get_names():
return True
return False

def __hash__(self):
"""
Return a hash value based on the set of term IDs.
This makes the VFBTerms object hashable and suitable for use in sets and as dictionary keys.
:return: Hash value.
"""
# Use a frozenset of IDs for hashing since frozenset is hashable and immutable
return hash(frozenset(self.get_ids()))

def __str__(self) -> str:
"""
Return a string representation of the VFBTerms object.
"""
return f"VFBTerms({self.get_names()})"

def __ne__(self, value: object) -> bool:
"""
Compare two VFBTerms objects for inequality.
"""
return not self.__eq__(value)

def __add__(self, other):
"""
Add two VFBTerms objects or a VFBTerm object and a VFBTerms object.
Expand Down Expand Up @@ -2970,6 +3110,78 @@ def __eq__(self, other):

# Compare the sets of IDs for equality
return set(self.get_ids()) == set(other.get_ids())

def __lt__(self, other):
"""
Compare two VFBTerms objects based on their term IDs.
:param other: The other VFBTerms object to compare.
:return: True if the current VFBTerms object is less than the other, False otherwise.
"""
if not isinstance(other, VFBTerms):
if isinstance(other, list) and all(isinstance(term, VFBTerm) for term in other):
return sorted(self.get_ids()) < sorted([term.id for term in other])
if isinstance(other, list) and all(isinstance(term, str) for term in other):
return sorted(self.get_ids()) < sorted(other)
return NotImplemented
# Compare based on sorted IDs
return sorted(self.get_ids()) < sorted(other.get_ids())

def __le__(self, other):
"""
Compare two VFBTerms objects based on their term IDs.
:param other: The other VFBTerms object to compare.
:return: True if the current VFBTerms object is less than or equal to the other, False otherwise.
"""
if not isinstance(other, VFBTerms):
if isinstance(other, list) and all(isinstance(term, VFBTerm) for term in other):
return sorted(self.get_ids()) <= sorted([term.id for term in other])
if isinstance(other, list) and all(isinstance(term, str) for term in other):
return sorted(self.get_ids()) <= sorted(other)
return NotImplemented
# Compare based on sorted IDs
return sorted(self.get_ids()) <= sorted(other.get_ids())

def __gt__(self, other):
"""
Compare two VFBTerms objects based on their term IDs.
:param other: The other VFBTerms object to compare.
:return: True if the current VFBTerms object is greater than the other, False otherwise.
"""
if not isinstance(other, VFBTerms):
if isinstance(other, list) and all(isinstance(term, VFBTerm) for term in other):
return sorted(self.get_ids()) > sorted([term.id for term in other])
if isinstance(other, list) and all(isinstance(term, str) for term in other):
return sorted(self.get_ids()) > sorted(other)
return NotImplemented
# Compare based on sorted IDs
return sorted(self.get_ids()) > sorted(other.get_ids())

def __ge__(self, other):
"""
Compare two VFBTerms objects based on their term IDs.
:param other: The other VFBTerms object to compare.
:return: True if the current VFBTerms object is greater than or equal to the other, False otherwise.
"""
if not isinstance(other, VFBTerms):
if isinstance(other, list) and all(isinstance(term, VFBTerm) for term in other):
return sorted(self.get_ids()) >= sorted([term.id for term in other])
if isinstance(other, list) and all(isinstance(term, str) for term in other):
return sorted(self.get_ids()) >= sorted(other)
return NotImplemented
# Compare based on sorted IDs
return sorted(self.get_ids()) >= sorted(other.get_ids())

def __iter__(self):
"""
Make VFBTerms iterable by returning an iterator over the 'terms' list.
:return: Iterator over the list of terms.
"""
return iter(self.terms)

def get_all(self, property_name='name', verbose=False, return_dict=False):
"""
Expand Down Expand Up @@ -3042,6 +3254,7 @@ def get_colours_for(self, property_name='name', verbose=False, take_first=False)

# If the property value is iterable, handle based on take_first flag
if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
print(f"Property '{property_name}' is iterable. Processing items: {value}") if verbose else None
if take_first:
value = next(iter(value), None) # Take the first value
if verbose and value is not None:
Expand Down

0 comments on commit b48271e

Please sign in to comment.