Skip to content

Commit

Permalink
plots.image_2d: WIP adding uncertainty to z cursor label
Browse files Browse the repository at this point in the history
  • Loading branch information
pmldrmota committed May 8, 2024
1 parent 8f5ebb3 commit db3089b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 25 deletions.
49 changes: 38 additions & 11 deletions ndscan/plots/image_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import logging
import numpy as np
import pyqtgraph
from oitg.uncertainty_to_string import uncertainty_to_string

from .._qt import QtCore, QtGui
from . import colormaps
from .cursor import CrosshairAxisLabel, CrosshairLabel, LabeledCrosshairCursor
from .model import ScanModel
from .plot_widgets import AlternateMenuPanesWidget, add_source_id_label
from .plot_widgets import (AlternateMenuPanesWidget, add_source_id_label,
build_num_samples_per_point_context_menu)
from .utils import (extract_linked_datasets, extract_scalar_channels,
format_param_identity, get_axis_scaling_info, setup_axis_item,
enum_to_numeric)
Expand Down Expand Up @@ -60,6 +62,7 @@ def __init__(self, *args, **kwargs):
self.x_range = None
self.y_range = None
self.image_data = None
self.error_data = None
self.z_limits = None

def set_crosshair_info(self, unit_suffix: str, data_to_display_scale: float,
Expand All @@ -76,6 +79,7 @@ def set_crosshair_info(self, unit_suffix: str, data_to_display_scale: float,
def set_image_data(
self,
image_data: np.ndarray,
error_data: np.ndarray,
x_range: tuple[float, float, float],
y_range: tuple[float, float, float],
z_limits: tuple[float, float],
Expand All @@ -86,6 +90,7 @@ def set_image_data(
:param z_limits: The current colormap limits.
"""
self.image_data = image_data
self.error_data = error_data
self.x_range = x_range
self.y_range = y_range
self.z_limits = z_limits
Expand All @@ -100,10 +105,14 @@ def update_coords(self, data_coords):
shape = self.image_data.shape
if (0 <= x_idx < shape[0]) and (0 <= y_idx < shape[1]):
z = self.image_data[x_idx, y_idx]
z_err = self.error_data[x_idx, y_idx]
if np.isnan(z):
self.set_text("")
else:
self.set_value(z, self.z_limits)
if np.isnan(z_err):
self.set_value(z, self.z_limits)
else:
self.set_text(uncertainty_to_string(z, z_err))


class _ImagePlot:
Expand All @@ -130,11 +139,14 @@ def __init__(self, image_item: pyqtgraph.ImageItem,
self.x_range = None
self.y_range = None
self.image_data = None
self.error_data = None

#: Whether to average points with the same coordinates.
self.averaging_enabled = False
#: Assumed number of samples per point for calculating the combined uncertainty.
self.num_samples_per_point = 1
#: Keeps track of the running average and the number of samples therein.
self.averages_by_coords = dict[tuple[float, float], tuple[float, int]]()
self.averages_by_coords = dict[tuple[float, float], tuple[float, float, int]]()

self.z_crosshair_label = CrosshairZDataLabel(self.image_item.getViewBox())

Expand All @@ -154,13 +166,13 @@ def activate_channel(self, channel_name: str):
self.z_crosshair_label.set_crosshair_info(*crosshair_info[0])

self._invalidate_current()
self.update(self.averaging_enabled)
self.update(self.averaging_enabled, self.num_samples_per_point)

def data_changed(self, points, invalidate_previous: bool = False):
self.points = points
if invalidate_previous:
self._invalidate_current()
self.update(self.averaging_enabled)
self.update(self.averaging_enabled, self.num_samples_per_point)

def _invalidate_current(self):
self.num_shown = 0
Expand All @@ -175,7 +187,7 @@ def _active_fixed_z_limits(self) -> tuple[float, float] | None:
return None
return channel["min"], channel["max"]

def update(self, averaging_enabled):
def update(self, averaging_enabled: bool, num_samples_per_point: int):
if not self.points:
return

Expand All @@ -189,17 +201,22 @@ def update(self, averaging_enabled):
num_to_show = min(len(x_data), len(y_data), len(z_data))

if (num_to_show == self.num_shown
and averaging_enabled == self.averaging_enabled):
and averaging_enabled == self.averaging_enabled
and num_samples_per_point == self.num_samples_per_point):
return
if num_samples_per_point != self.num_samples_per_point:
self._invalidate_current()

num_skip = self.num_shown

# Update running averages.
for x, y, z in zip(x_data[num_skip:num_to_show], y_data[num_skip:num_to_show],
z_data[num_skip:num_to_show]):
avg, num = self.averages_by_coords.get((x, y), (0., 0))
avg, err, num = self.averages_by_coords.get((x, y), (0., 0., 0))
# TODO: Update error
num += 1
avg += (z - avg) / num
self.averages_by_coords[(x, y)] = (avg, num)
self.averages_by_coords[(x, y)] = (avg, err, num)

# Determine range of x/y values to show and prepare image buffer accordingly if
# it changed.
Expand All @@ -213,6 +230,7 @@ def update(self, averaging_enabled):
# TODO: Splat old data for progressively less blurry look on refining scans?
self.image_data = np.full(
(_num_points_in_range(x_range), _num_points_in_range(y_range)), np.nan)
self.error_data = np.full_like(self.image_data, np.nan)

self.image_rect = QtCore.QRectF(
QtCore.QPointF(x_range[0] - x_range[2] / 2,
Expand All @@ -233,6 +251,8 @@ def update(self, averaging_enabled):
coords, z = (x_data[data_idx], y_data[data_idx]), z_data[data_idx]
self.image_data[x_idx, y_idx] = (self.averages_by_coords[coords][0]
if averaging_enabled else z)
self.error_data[x_idx, y_idx] = (self.averages_by_coords[coords][1]
if averaging_enabled else np.nan)

cmap = colormaps.plasma
channel = self.channels[self.active_channel_name]
Expand All @@ -257,6 +277,7 @@ def update(self, averaging_enabled):

self.num_shown = num_to_show
self.averaging_enabled = averaging_enabled
self.num_samples_per_point = num_samples_per_point


class Image2DPlotWidget(AlternateMenuPanesWidget):
Expand Down Expand Up @@ -377,8 +398,14 @@ def set_both():
action = builder.append_action("Average points with same coordinates")
action.setCheckable(True)
action.setChecked(self.plot.averaging_enabled)
action.triggered.connect(
lambda *a: self.plot.update(not self.plot.averaging_enabled))
action.triggered.connect(lambda *a: self.plot.update(
not self.plot.averaging_enabled, self.plot.num_samples_per_point))

if self.plot.averaging_enabled:
build_num_samples_per_point_context_menu(
builder,
lambda num: self.plot.update(self.plot.averaging_enabled, num),
self.plot.num_samples_per_point)
builder.ensure_separator()

self.channel_menu_group = QtGui.QActionGroup(self)
Expand Down
19 changes: 19 additions & 0 deletions ndscan/plots/plot_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,22 @@ def state_changed(state, name):
checkbox.stateChanged.connect(lambda a, n=name: state_changed(a, n))
layout.addWidget(checkbox)
update_checkboxes_enabled()


def build_num_samples_per_point_context_menu(builder: ContextMenuBuilder,
value_changed_callback: Callable[[], None],
current_num_samples_per_point: int):
num_samples_box = QtWidgets.QSpinBox()
num_samples_box.setMinimum(1)
num_samples_box.setMaximum(2**16)
num_samples_box.setValue(current_num_samples_per_point)
num_samples_box.valueChanged.connect(value_changed_callback)
container = QtWidgets.QWidget()
layout = QtWidgets.QHBoxLayout()
container.setLayout(layout)
label = QtWidgets.QLabel("Samples per point:")
layout.addWidget(label)
layout.addWidget(num_samples_box)
layout.insertStretch(0)
action = builder.append_widget_action()
action.setDefaultWidget(container)
18 changes: 4 additions & 14 deletions ndscan/plots/xy_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .model.select_point import SelectPointFromScanModel
from .model.subscan import create_subscan_roots
from .plot_widgets import (SubplotMenuPanesWidget, build_channel_selection_context_menu,
build_num_samples_per_point_context_menu,
add_source_id_label)
from .utils import (extract_linked_datasets, extract_scalar_channels,
get_default_hidden_channels, format_param_identity,
Expand Down Expand Up @@ -422,20 +423,9 @@ def build_context_menu(self, pane_idx, builder):
lambda *a: self.enable_averaging(not self.averaging_enabled))

if self.averaging_enabled:
num_samples_box = QtWidgets.QSpinBox()
num_samples_box.setMinimum(1)
num_samples_box.setMaximum(2**16)
num_samples_box.setValue(self.num_samples_per_point)
num_samples_box.valueChanged.connect(self.change_num_samples_per_point)
container = QtWidgets.QWidget()
layout = QtWidgets.QHBoxLayout()
container.setLayout(layout)
label = QtWidgets.QLabel("Samples per point:")
layout.addWidget(label)
layout.addWidget(num_samples_box)
layout.insertStretch(0)
action = builder.append_widget_action()
action.setDefaultWidget(container)
build_num_samples_per_point_context_menu(
builder, self.change_num_samples_per_point,
self.num_samples_per_point)

builder.ensure_separator()

Expand Down

0 comments on commit db3089b

Please sign in to comment.