Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plots: adjustable num_samples_per_point #407

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions ndscan/plots/image_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import logging
import numpy as np
import pyqtgraph
from oitg.uncertainty_to_string import uncertainty_to_string

from .._qt import QtCore, QtGui
from . import colormaps
from .cursor import CrosshairAxisLabel, CrosshairLabel, LabeledCrosshairCursor
from .model import ScanModel
from .plot_widgets import AlternateMenuPanesWidget, add_source_id_label
from .plot_widgets import (AlternateMenuPanesWidget, add_source_id_label,
build_num_samples_per_point_context_menu)
from .utils import (extract_linked_datasets, extract_scalar_channels,
format_param_identity, get_axis_scaling_info, setup_axis_item,
enum_to_numeric)
Expand Down Expand Up @@ -60,6 +62,7 @@ def __init__(self, *args, **kwargs):
self.x_range = None
self.y_range = None
self.image_data = None
self.error_data = None
self.z_limits = None

def set_crosshair_info(self, unit_suffix: str, data_to_display_scale: float,
Expand All @@ -76,6 +79,7 @@ def set_crosshair_info(self, unit_suffix: str, data_to_display_scale: float,
def set_image_data(
self,
image_data: np.ndarray,
error_data: np.ndarray,
x_range: tuple[float, float, float],
y_range: tuple[float, float, float],
z_limits: tuple[float, float],
Expand All @@ -86,6 +90,7 @@ def set_image_data(
:param z_limits: The current colormap limits.
"""
self.image_data = image_data
self.error_data = error_data
self.x_range = x_range
self.y_range = y_range
self.z_limits = z_limits
Expand All @@ -100,10 +105,14 @@ def update_coords(self, data_coords):
shape = self.image_data.shape
if (0 <= x_idx < shape[0]) and (0 <= y_idx < shape[1]):
z = self.image_data[x_idx, y_idx]
z_err = self.error_data[x_idx, y_idx]
if np.isnan(z):
self.set_text("")
else:
self.set_value(z, self.z_limits)
if np.isnan(z_err):
self.set_value(z, self.z_limits)
else:
self.set_text(uncertainty_to_string(z, z_err))


class _ImagePlot:
Expand All @@ -130,11 +139,14 @@ def __init__(self, image_item: pyqtgraph.ImageItem,
self.x_range = None
self.y_range = None
self.image_data = None
self.error_data = None

#: Whether to average points with the same coordinates.
self.averaging_enabled = False
#: Assumed number of samples per point for calculating the combined uncertainty.
self.num_samples_per_point = 1
#: Keeps track of the running average and the number of samples therein.
self.averages_by_coords = dict[tuple[float, float], tuple[float, int]]()
self.averages_by_coords = dict[tuple[float, float], tuple[float, float, int]]()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a a frozen dataclass/… instead of a tuple? This is pretty opaque to the casual reader.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Certainly! I just pushed my WIP, but it should be using much of the same code as the 1D plot eventually.


self.z_crosshair_label = CrosshairZDataLabel(self.image_item.getViewBox())

Expand All @@ -154,13 +166,13 @@ def activate_channel(self, channel_name: str):
self.z_crosshair_label.set_crosshair_info(*crosshair_info[0])

self._invalidate_current()
self.update(self.averaging_enabled)
self.update(self.averaging_enabled, self.num_samples_per_point)

def data_changed(self, points, invalidate_previous: bool = False):
self.points = points
if invalidate_previous:
self._invalidate_current()
self.update(self.averaging_enabled)
self.update(self.averaging_enabled, self.num_samples_per_point)

def _invalidate_current(self):
self.num_shown = 0
Expand All @@ -175,7 +187,7 @@ def _active_fixed_z_limits(self) -> tuple[float, float] | None:
return None
return channel["min"], channel["max"]

def update(self, averaging_enabled):
def update(self, averaging_enabled: bool, num_samples_per_point: int):
if not self.points:
return

Expand All @@ -189,17 +201,22 @@ def update(self, averaging_enabled):
num_to_show = min(len(x_data), len(y_data), len(z_data))

if (num_to_show == self.num_shown
and averaging_enabled == self.averaging_enabled):
and averaging_enabled == self.averaging_enabled
and num_samples_per_point == self.num_samples_per_point):
return
if num_samples_per_point != self.num_samples_per_point:
self._invalidate_current()

num_skip = self.num_shown

# Update running averages.
for x, y, z in zip(x_data[num_skip:num_to_show], y_data[num_skip:num_to_show],
z_data[num_skip:num_to_show]):
avg, num = self.averages_by_coords.get((x, y), (0., 0))
avg, err, num = self.averages_by_coords.get((x, y), (0., 0., 0))
# TODO: Update error
num += 1
avg += (z - avg) / num
self.averages_by_coords[(x, y)] = (avg, num)
self.averages_by_coords[(x, y)] = (avg, err, num)

# Determine range of x/y values to show and prepare image buffer accordingly if
# it changed.
Expand All @@ -213,6 +230,7 @@ def update(self, averaging_enabled):
# TODO: Splat old data for progressively less blurry look on refining scans?
self.image_data = np.full(
(_num_points_in_range(x_range), _num_points_in_range(y_range)), np.nan)
self.error_data = np.full_like(self.image_data, np.nan)

self.image_rect = QtCore.QRectF(
QtCore.QPointF(x_range[0] - x_range[2] / 2,
Expand All @@ -233,6 +251,8 @@ def update(self, averaging_enabled):
coords, z = (x_data[data_idx], y_data[data_idx]), z_data[data_idx]
self.image_data[x_idx, y_idx] = (self.averages_by_coords[coords][0]
if averaging_enabled else z)
self.error_data[x_idx, y_idx] = (self.averages_by_coords[coords][1]
if averaging_enabled else np.nan)

cmap = colormaps.plasma
channel = self.channels[self.active_channel_name]
Expand All @@ -257,6 +277,7 @@ def update(self, averaging_enabled):

self.num_shown = num_to_show
self.averaging_enabled = averaging_enabled
self.num_samples_per_point = num_samples_per_point


class Image2DPlotWidget(AlternateMenuPanesWidget):
Expand Down Expand Up @@ -377,8 +398,14 @@ def set_both():
action = builder.append_action("Average points with same coordinates")
action.setCheckable(True)
action.setChecked(self.plot.averaging_enabled)
action.triggered.connect(
lambda *a: self.plot.update(not self.plot.averaging_enabled))
action.triggered.connect(lambda *a: self.plot.update(
not self.plot.averaging_enabled, self.plot.num_samples_per_point))

if self.plot.averaging_enabled:
build_num_samples_per_point_context_menu(
builder,
lambda num: self.plot.update(self.plot.averaging_enabled, num),
self.plot.num_samples_per_point)
builder.ensure_separator()

self.channel_menu_group = QtGui.QActionGroup(self)
Expand Down
19 changes: 19 additions & 0 deletions ndscan/plots/plot_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,22 @@ def state_changed(state, name):
checkbox.stateChanged.connect(lambda a, n=name: state_changed(a, n))
layout.addWidget(checkbox)
update_checkboxes_enabled()


def build_num_samples_per_point_context_menu(builder: ContextMenuBuilder,
value_changed_callback: Callable[[], None],
current_num_samples_per_point: int):
num_samples_box = QtWidgets.QSpinBox()
num_samples_box.setMinimum(1)
num_samples_box.setMaximum(2**16)
num_samples_box.setValue(current_num_samples_per_point)
num_samples_box.valueChanged.connect(value_changed_callback)
container = QtWidgets.QWidget()
layout = QtWidgets.QHBoxLayout()
container.setLayout(layout)
label = QtWidgets.QLabel("Samples per point:")
layout.addWidget(label)
layout.addWidget(num_samples_box)
layout.insertStretch(0)
action = builder.append_widget_action()
action.setDefaultWidget(container)
18 changes: 4 additions & 14 deletions ndscan/plots/xy_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .model.select_point import SelectPointFromScanModel
from .model.subscan import create_subscan_roots
from .plot_widgets import (SubplotMenuPanesWidget, build_channel_selection_context_menu,
build_num_samples_per_point_context_menu,
add_source_id_label)
from .utils import (extract_linked_datasets, extract_scalar_channels,
get_default_hidden_channels, format_param_identity,
Expand Down Expand Up @@ -422,20 +423,9 @@ def build_context_menu(self, pane_idx, builder):
lambda *a: self.enable_averaging(not self.averaging_enabled))

if self.averaging_enabled:
num_samples_box = QtWidgets.QSpinBox()
num_samples_box.setMinimum(1)
num_samples_box.setMaximum(2**16)
num_samples_box.setValue(self.num_samples_per_point)
num_samples_box.valueChanged.connect(self.change_num_samples_per_point)
container = QtWidgets.QWidget()
layout = QtWidgets.QHBoxLayout()
container.setLayout(layout)
label = QtWidgets.QLabel("Samples per point:")
layout.addWidget(label)
layout.addWidget(num_samples_box)
layout.insertStretch(0)
action = builder.append_widget_action()
action.setDefaultWidget(container)
build_num_samples_per_point_context_menu(
builder, self.change_num_samples_per_point,
self.num_samples_per_point)

builder.ensure_separator()

Expand Down