diff --git a/hexrdgui/mask_compatability.py b/hexrdgui/mask_compatability.py new file mode 100644 index 000000000..b7e4f3fe6 --- /dev/null +++ b/hexrdgui/mask_compatability.py @@ -0,0 +1,31 @@ +from hexrd.utils.compatibility import h5py_read_string + + +def load_old_mask_file(self, h5py_group): + items = {} + visible = list(h5py_read_string(h5py_group['_visible'])) + for key, data in h5py_group.items(): + if key == '_visible': + continue + + if key == 'threshold': + values = data['values'][()].tolist() + items['threshold'] = { + 'min_val': values[0], + 'max_val': values[1], + 'name': 'threshold', + 'mtype': 'threshold', + 'visible': 'threshold' in visible, + 'border': False, + } + else: + for name, masks in data.items(): + items.setdefault(name, { + 'name': name, + 'mtype': 'unknown', + 'visible': name in visible, + 'border': False, + }) + for mask in masks.values(): + # Load the numpy array from the hdf5 file + items[name].setdefault(key, []).append(mask[()]) diff --git a/hexrdgui/mask_manager.py b/hexrdgui/mask_manager.py new file mode 100644 index 000000000..6247b9b67 --- /dev/null +++ b/hexrdgui/mask_manager.py @@ -0,0 +1,281 @@ +import math +from hexrdgui.constants import ViewType +from hexrdgui.create_polar_mask import ( + convert_raw_to_polar, create_polar_mask_from_raw +) +from hexrdgui.create_raw_mask import ( + apply_threshold_mask, convert_polar_to_raw, create_raw_mask +) +from hexrdgui.hexrd_config import HexrdConfig +from hexrdgui.mask_compatability import load_old_mask_file +from hexrdgui.singletons import Singleton +from hexrdgui.utils import unique_name + +from abc import ABC, abstractmethod + + +class Mask(ABC): + def __init__(self, mtype='', name='', mask_image=True, show_border=False): + self._mask_type = mtype + self._name = name + self._mask_image = mask_image + self._show_border = show_border + self._masked_arrays = None + + @property + def masked_arrays(self): + return self._masked_arrays + + @masked_arrays.setter + def masked_arrays(self, arrs): + self._masked_arrays = arrs + + @property + def mask_type(self): + return self._mask_type + + @mask_type.setter + def mask_type(self, mtype): + self._mask_type = mtype + + @property + def name(self): + return self._name + + @name.setter + def name(self, new_name): + self._name = new_name + + @property + def mask_image(self): + return self._mask_image + + @mask_image.setter + def mask_image(self, status): + self._mask_image = status + + @property + def show_border(self): + return self._show_border + + @show_border.setter + def show_border(self, status): + self._show_border = status + + # Abstract methods + @abstractmethod + def get_data(self): + pass + + @abstractmethod + def set_data(self, data): + pass + + @abstractmethod + def get_mask_arrays(self): + pass + + @abstractmethod + def update_mask_array(self): + pass + + @abstractmethod + def serialize(self): + pass + + @abstractmethod + def deserialize(self, data): + pass + + +class RegionMask(Mask): + def __init__(self): + self._polar = None + self._raw = None + + def get_data(self, view=ViewType.raw): + if view == ViewType.raw: + return self._raw + else: + return self._polar + + def set_data(self, data, view=ViewType.raw): + if view == ViewType.raw: + raw_data = data + polar_data = [] + for det, value in data: + polar_data.extend(convert_raw_to_polar(det, value)) + else: + polar_data = data + raw_data = convert_polar_to_raw(data) + self._raw = raw_data + self._polar = polar_data + self.update_mask_array() + + def get_mask_arrays(self, view=ViewType.raw): + if view == ViewType.raw: + return self._masked_arrays + else: + # FIXME: Function parameters changed + return create_polar_mask_from_raw(self._raw) + + def update_mask_array(self): + # FIXME: Function parameters changed + self._masked_arrays = create_raw_mask(self._raw) + + def serialize(self): + data = { + 'name': self.name, + 'mtype': self.mask_type, + 'visible': self.mask_image, + 'border': self.show_border, + } + for det, values in self._raw: + data.setdefault(det, []).append(values) + return data + + def deserialize(self, data): + self.name = data['name'] + self.mask_type = data['mtype'] + self.mask_image = data['visible'] + self.show_border = data['border'] + raw_data = [] + for det in HexrdConfig().detector_names: + raw_data.append([(det, v) for v in data[det]]) + self.set_data(raw_data) + + +class ThresholdMask(Mask): + def __init__(self): + self._min = -math.inf + self._max = math.inf + + @property + def min_val(self): + return self._min + + @min_val.setter + def min_val(self, val): + self._min = val + + @property + def max_val(self): + return self._max + + @max_val.setter + def max_val(self, val): + self._max = val + + def get_data(self): + return [self.min_val, self.max_val] + + def set_data(self, data): + self.min_val = data[0] + self.max_val = data[1] + self.update_mask_array() + + def get_mask_arrays(self): + return self._masked_arrays + + def update_mask_array(self): + # TODO: rename apply_threshold_mask since its purpose has changed now? + # FIXME: Function parameters changed + self._masked_arrays = apply_threshold_mask(self.values) + + def serialize(self): + return { + 'min_val': self.min_val, + 'max_val': self.max_val, + 'name': self.name, + 'mtype': self.mask_type, + 'visible': self.mask_image, + 'border': self.show_border, + } + + def deserialize(self, data): + self.name = data['name'] + self.mask_type = data['mtype'] + self.mask_image = data['visible'] + self.show_border = data['border'] + self.set_data([data['min_val'], data['max_val']]) + + +class MaskManager(metaclass=Singleton): + def __init__(self, view_mode): + self.masks = {} + self.view_mode = view_mode + + @property + def visible_masks(self): + return [k for k, v in self.masks if v.mask_image] + + @property + def visible_boundaries(self): + return [k for k, v in self.masks if v.show_border] + + @property + def threshold_mask(self): + for mask in self.masks.values(): + if mask.mask_type == 'threshold': + return mask + return None + + def add_mask(self, name, data, mtype, mask_image=True, show_border=False): + # Enforce name uniqueness + name = unique_name(self.masks.keys(), name) + if mtype == 'threshold': + new_mask = ThresholdMask(name, mtype, mask_image) + else: + new_mask = RegionMask(name, mtype, mask_image, show_border) + new_mask.set_data(self.view_mode, data) + self.masks[name] = new_mask + + def remove_mask(self, name): + self.masks.pop(name) + + def write_all_masks(self, h5py_group=None): + d = {} + for name, mask_info in self.masks: + d[name] = mask_info.serialize() + if h5py_group: + self.write_masks_to_group(d, h5py_group) + else: + self.export_masks_to_file(d) + + def save_state(self, h5py_group): + if 'masks' not in h5py_group: + h5py_group.create_group('masks') + + self.write_all_masks(h5py_group['masks']) + + def load_masks(self, h5py_group): + # TODO: Handle case of detector name mismatch (loading wrong mask file) + items = h5py_group.items() + if '_visible' in h5py_group.values(): + # This is a file using the old format + items = load_old_mask_file(h5py_group) + actual_view_mode = self.view_mode + self.view_mode = ViewType.raw + for key, data in items: + if data['mtype'] == 'threshold': + new_mask = ThresholdMask(None, None) + new_mask.deserialize(data) + else: + new_mask = RegionMask(None, None) + new_mask.deserialize(data) + self.masks[key] = new_mask + + if not HexrdConfig().loading_state: + # We're importing masks directly, + # don't wait for the state loaded signal + # FIXME: This is not connected to anything atm + self.rebuild_masks() + self.view_mode = actual_view_mode + + def load_state(self, h5py_group): + self.masks = {} + if 'masks' in h5py_group: + self.load_masks(h5py_group['masks']) + + def update_view_mode(self, mode): + self.view_mode = mode