Skip to content

Commit

Permalink
WIP: First pass at new MaskManager class
Browse files Browse the repository at this point in the history
Signed-off-by: Brianna Major <[email protected]>
  • Loading branch information
bnmajor committed Nov 16, 2023
1 parent 8fe71a5 commit 743c848
Show file tree
Hide file tree
Showing 2 changed files with 312 additions and 0 deletions.
31 changes: 31 additions & 0 deletions hexrdgui/mask_compatability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from hexrd.utils.compatibility import h5py_read_string


def load_old_mask_file(self, h5py_group):
items = {}
visible = list(h5py_read_string(h5py_group['_visible']))
for key, data in h5py_group.items():
if key == '_visible':
continue

if key == 'threshold':
values = data['values'][()].tolist()
items['threshold'] = {
'min_val': values[0],
'max_val': values[1],
'name': 'threshold',
'mtype': 'threshold',
'visible': 'threshold' in visible,
'border': False,
}
else:
for name, masks in data.items():
items.setdefault(name, {
'name': name,
'mtype': 'unknown',
'visible': name in visible,
'border': False,
})
for mask in masks.values():
# Load the numpy array from the hdf5 file
items[name].setdefault(key, []).append(mask[()])
281 changes: 281 additions & 0 deletions hexrdgui/mask_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
import math
from hexrdgui.constants import ViewType
from hexrdgui.create_polar_mask import (
convert_raw_to_polar, create_polar_mask_from_raw
)
from hexrdgui.create_raw_mask import (
apply_threshold_mask, convert_polar_to_raw, create_raw_mask
)
from hexrdgui.hexrd_config import HexrdConfig
from hexrdgui.mask_compatability import load_old_mask_file
from hexrdgui.singletons import Singleton
from hexrdgui.utils import unique_name

from abc import ABC, abstractmethod


class Mask(ABC):
def __init__(self, mtype='', name='', mask_image=True, show_border=False):
self._mask_type = mtype
self._name = name
self._mask_image = mask_image
self._show_border = show_border
self._masked_arrays = None

@property
def masked_arrays(self):
return self._masked_arrays

@masked_arrays.setter
def masked_arrays(self, arrs):
self._masked_arrays = arrs

@property
def mask_type(self):
return self._mask_type

@mask_type.setter
def mask_type(self, mtype):
self._mask_type = mtype

@property
def name(self):
return self._name

@name.setter
def name(self, new_name):
self._name = new_name

@property
def mask_image(self):
return self._mask_image

@mask_image.setter
def mask_image(self, status):
self._mask_image = status

@property
def show_border(self):
return self._show_border

@show_border.setter
def show_border(self, status):
self._show_border = status

# Abstract methods
@abstractmethod
def get_data(self):
pass

@abstractmethod
def set_data(self, data):
pass

@abstractmethod
def get_mask_arrays(self):
pass

@abstractmethod
def update_mask_array(self):
pass

@abstractmethod
def serialize(self):
pass

@abstractmethod
def deserialize(self, data):
pass


class RegionMask(Mask):
def __init__(self):
self._polar = None
self._raw = None

def get_data(self, view=ViewType.raw):
if view == ViewType.raw:
return self._raw
else:
return self._polar

def set_data(self, data, view=ViewType.raw):
if view == ViewType.raw:
raw_data = data
polar_data = []
for det, value in data:
polar_data.extend(convert_raw_to_polar(det, value))
else:
polar_data = data
raw_data = convert_polar_to_raw(data)
self._raw = raw_data
self._polar = polar_data
self.update_mask_array()

def get_mask_arrays(self, view=ViewType.raw):
if view == ViewType.raw:
return self._masked_arrays
else:
# FIXME: Function parameters changed
return create_polar_mask_from_raw(self._raw)

def update_mask_array(self):
# FIXME: Function parameters changed
self._masked_arrays = create_raw_mask(self._raw)

def serialize(self):
data = {
'name': self.name,
'mtype': self.mask_type,
'visible': self.mask_image,
'border': self.show_border,
}
for det, values in self._raw:
data.setdefault(det, []).append(values)
return data

def deserialize(self, data):
self.name = data['name']
self.mask_type = data['mtype']
self.mask_image = data['visible']
self.show_border = data['border']
raw_data = []
for det in HexrdConfig().detector_names:
raw_data.append([(det, v) for v in data[det]])
self.set_data(raw_data)


class ThresholdMask(Mask):
def __init__(self):
self._min = -math.inf
self._max = math.inf

@property
def min_val(self):
return self._min

@min_val.setter
def min_val(self, val):
self._min = val

@property
def max_val(self):
return self._max

@max_val.setter
def max_val(self, val):
self._max = val

def get_data(self):
return [self.min_val, self.max_val]

def set_data(self, data):
self.min_val = data[0]
self.max_val = data[1]
self.update_mask_array()

def get_mask_arrays(self):
return self._masked_arrays

def update_mask_array(self):
# TODO: rename apply_threshold_mask since its purpose has changed now?
# FIXME: Function parameters changed
self._masked_arrays = apply_threshold_mask(self.values)

def serialize(self):
return {
'min_val': self.min_val,
'max_val': self.max_val,
'name': self.name,
'mtype': self.mask_type,
'visible': self.mask_image,
'border': self.show_border,
}

def deserialize(self, data):
self.name = data['name']
self.mask_type = data['mtype']
self.mask_image = data['visible']
self.show_border = data['border']
self.set_data([data['min_val'], data['max_val']])


class MaskManager(metaclass=Singleton):
def __init__(self, view_mode):
self.masks = {}
self.view_mode = view_mode

@property
def visible_masks(self):
return [k for k, v in self.masks if v.mask_image]

@property
def visible_boundaries(self):
return [k for k, v in self.masks if v.show_border]

@property
def threshold_mask(self):
for mask in self.masks.values():
if mask.mask_type == 'threshold':
return mask
return None

def add_mask(self, name, data, mtype, mask_image=True, show_border=False):
# Enforce name uniqueness
name = unique_name(self.masks.keys(), name)
if mtype == 'threshold':
new_mask = ThresholdMask(name, mtype, mask_image)
else:
new_mask = RegionMask(name, mtype, mask_image, show_border)
new_mask.set_data(self.view_mode, data)
self.masks[name] = new_mask

def remove_mask(self, name):
self.masks.pop(name)

def write_all_masks(self, h5py_group=None):
d = {}
for name, mask_info in self.masks:
d[name] = mask_info.serialize()
if h5py_group:
self.write_masks_to_group(d, h5py_group)
else:
self.export_masks_to_file(d)

def save_state(self, h5py_group):
if 'masks' not in h5py_group:
h5py_group.create_group('masks')

self.write_all_masks(h5py_group['masks'])

def load_masks(self, h5py_group):
# TODO: Handle case of detector name mismatch (loading wrong mask file)
items = h5py_group.items()
if '_visible' in h5py_group.values():
# This is a file using the old format
items = load_old_mask_file(h5py_group)
actual_view_mode = self.view_mode
self.view_mode = ViewType.raw
for key, data in items:
if data['mtype'] == 'threshold':
new_mask = ThresholdMask(None, None)
new_mask.deserialize(data)
else:
new_mask = RegionMask(None, None)
new_mask.deserialize(data)
self.masks[key] = new_mask

if not HexrdConfig().loading_state:
# We're importing masks directly,
# don't wait for the state loaded signal
# FIXME: This is not connected to anything atm
self.rebuild_masks()
self.view_mode = actual_view_mode

def load_state(self, h5py_group):
self.masks = {}
if 'masks' in h5py_group:
self.load_masks(h5py_group['masks'])

def update_view_mode(self, mode):
self.view_mode = mode

0 comments on commit 743c848

Please sign in to comment.