diff --git a/gates.csv b/gates.csv new file mode 100644 index 0000000..1a89525 --- /dev/null +++ b/gates.csv @@ -0,0 +1,121 @@ +sample_id,marker_id,gate_value +68,Rabbit IgG,3478.531304347826 +68,Goat IgG,0.0 +68,Mouse IgG,0.0 +68,CD73,0.0 +68,CD107B,0.0 +68,MART1,0.0 +68,KI67,0.0 +68,pan-CK,0.0 +68,CD45,0.0 +68,ECAD,0.0 +68,aSMA,0.0 +68,CD56,0.0 +68,CD13,0.0 +68,CD63,0.0 +68,CD32,0.0 +68,CDKN1A,0.0 +68,CCNA2,0.0 +68,CDKN1C,0.0 +68,PCNA_1,0.0 +68,pAUR,0.0 +68,CDKN1B_1,0.0 +68,CCND1,0.0 +68,cPARP,0.0 +68,CDKN1B_2,0.0 +68,pCREB,0.0 +68,CCNB1,0.0 +68,CCNE,0.0 +68,PCNA_2,3487.7142857142853 +68,CDK2,0.0 +68,CDKN2A,0.0 +1,Rabbit IgG,0.0 +1,Goat IgG,0.0 +1,Mouse IgG,0.0 +1,CD73,0.0 +1,CD107B,0.0 +1,MART1,0.0 +1,KI67,0.0 +1,pan-CK,0.0 +1,CD45,0.0 +1,ECAD,0.0 +1,aSMA,0.0 +1,CD56,0.0 +1,CD13,0.0 +1,CD63,0.0 +1,CD32,0.0 +1,CDKN1A,0.0 +1,CCNA2,0.0 +1,CDKN1C,0.0 +1,PCNA_1,0.0 +1,pAUR,0.0 +1,CDKN1B_1,0.0 +1,CCND1,0.0 +1,cPARP,0.0 +1,CDKN1B_2,0.0 +1,pCREB,0.0 +1,CCNB1,0.0 +1,CCNE,0.0 +1,PCNA_2,0.0 +1,CDK2,0.0 +1,CDKN2A,0.0 +18,Rabbit IgG,2666.8921052631576 +18,Goat IgG,0.0 +18,Mouse IgG,0.0 +18,CD73,0.0 +18,CD107B,0.0 +18,MART1,0.0 +18,KI67,0.0 +18,pan-CK,0.0 +18,CD45,0.0 +18,ECAD,0.0 +18,aSMA,0.0 +18,CD56,0.0 +18,CD13,578.2247191011238 +18,CD63,0.0 +18,CD32,0.0 +18,CDKN1A,0.0 +18,CCNA2,0.0 +18,CDKN1C,0.0 +18,PCNA_1,0.0 +18,pAUR,0.0 +18,CDKN1B_1,0.0 +18,CCND1,0.0 +18,cPARP,0.0 +18,CDKN1B_2,0.0 +18,pCREB,0.0 +18,CCNB1,0.0 +18,CCNE,0.0 +18,PCNA_2,0.0 +18,CDK2,0.0 +18,CDKN2A,0.0 +15,Rabbit IgG,0.0 +15,Goat IgG,3723.361515151514 +15,Mouse IgG,3703.99125 +15,CD73,1577.505063291139 +15,CD107B,0.0 +15,MART1,0.0 +15,KI67,0.0 +15,pan-CK,0.0 +15,CD45,0.0 +15,ECAD,0.0 +15,aSMA,0.0 +15,CD56,0.0 +15,CD13,0.0 +15,CD63,0.0 +15,CD32,0.0 +15,CDKN1A,821.3441860465116 +15,CCNA2,0.0 +15,CDKN1C,0.0 +15,PCNA_1,0.0 +15,pAUR,0.0 +15,CDKN1B_1,0.0 +15,CCND1,0.0 +15,cPARP,0.0 +15,CDKN1B_2,0.0 +15,pCREB,0.0 +15,CCNB1,0.0 +15,CCNE,0.0 +15,PCNA_2,0.0 +15,CDK2,0.0 +15,CDKN2A,0.0 diff --git a/scratch b/scratch new file mode 100644 index 0000000..cf09230 --- /dev/null +++ b/scratch @@ -0,0 +1,6 @@ +# gates should be a dataframe where +# Column 1 is sample_id +# Column 2 is marker_id +# Column 3 is gate #default just gate +# if extra gates desired, Columns 4... are gate_1, gate_2, etc + diff --git a/src/cell_gater/model/data_model.py b/src/cell_gater/model/data_model.py index 0858845..67144de 100644 --- a/src/cell_gater/model/data_model.py +++ b/src/cell_gater/model/data_model.py @@ -7,7 +7,6 @@ import pandas as pd from napari.utils.events import EmitterGroup, Event - @dataclass class DataModel: """Model containing all necessary fields for gating.""" @@ -33,6 +32,24 @@ class DataModel: _gates: pd.DataFrame = field(default_factory=pd.DataFrame, init=False) _current_gate: float = field(default_factory=float, init=False) + @property + def gates(self): + """The gates dataframe.""" + return self._gates + + @gates.setter + def gates(self, gates: pd.DataFrame) -> None: + self._gates = gates + + @property + def current_gate(self) -> float: + """The current gate value.""" + return self._current_gate + + @current_gate.setter + def current_gate(self, value: float) -> None: + self._current_gate = value + def __post_init__(self) -> None: """Allow fields in the dataclass to emit events when changed.""" self.events = EmitterGroup(source=self, samples=Event, regionprops_df=Event, validated=Event) diff --git a/src/cell_gater/utils/misc.py b/src/cell_gater/utils/misc.py index c79387f..f7bac20 100644 --- a/src/cell_gater/utils/misc.py +++ b/src/cell_gater/utils/misc.py @@ -7,4 +7,4 @@ def napari_notification(msg, severity=NotificationSeverity.INFO): notification_ = Notification(msg, severity=severity) - notification_manager.dispatch(notification_) + notification_manager.dispatch(notification_) \ No newline at end of file diff --git a/src/cell_gater/widgets/sample_widget.py b/src/cell_gater/widgets/sample_widget.py index 291dd54..844750a 100644 --- a/src/cell_gater/widgets/sample_widget.py +++ b/src/cell_gater/widgets/sample_widget.py @@ -21,7 +21,6 @@ from cell_gater.utils.misc import napari_notification from cell_gater.widgets.scatter_widget import ScatterInputWidget - class SampleWidget(QWidget): """Sample widget for loading required data.""" @@ -53,17 +52,17 @@ def __init__(self, viewer: Viewer, model: DataModel | None = None) -> None: # Open sample directory dialog self.load_samples_button = QPushButton("Load regionprops dir") self.load_samples_button.clicked.connect(self._open_sample_dialog) - self.layout().addWidget(self.load_samples_button, 1, 0) + self.layout().addWidget(self.load_samples_button, 0, 1) # Open image directory dialog self.load_image_dir_button = QPushButton("Load image dir") self.load_image_dir_button.clicked.connect(self._open_image_dir_dialog) - self.layout().addWidget(self.load_image_dir_button, 1, 1) + self.layout().addWidget(self.load_image_dir_button, 0, 2) # Open mask directory dialog self.load_mask_dir_button = QPushButton("Load mask dir") self.load_mask_dir_button.clicked.connect(self._open_mask_dir_dialog) - self.layout().addWidget(self.load_mask_dir_button, 1, 2) + self.layout().addWidget(self.load_mask_dir_button, 0, 3) # The lower bound marker column dropdown lower_col = QLabel("Select lowerbound marker column:") @@ -71,8 +70,8 @@ def __init__(self, viewer: Viewer, model: DataModel | None = None) -> None: if len(self.model.regionprops_df) > 0: self.lower_bound_marker_col.addItems([None] + self.model.regionprops_df.columns) self.lower_bound_marker_col.currentTextChanged.connect(self._update_model_lowerbound) - self.layout().addWidget(lower_col, 2, 0) - self.layout().addWidget(self.lower_bound_marker_col, 3, 0) + self.layout().addWidget(lower_col, 1, 0, 1, 2) + self.layout().addWidget(self.lower_bound_marker_col, 1, 2, 1, 2) # The upper bound marker column dropdown upper_col = QLabel("Select upperbound marker column:") @@ -80,8 +79,8 @@ def __init__(self, viewer: Viewer, model: DataModel | None = None) -> None: if len(self.model.regionprops_df) > 0: self.upper_bound_marker_col.addItems([None] + self.model.regionprops_df.columns) self.upper_bound_marker_col.currentTextChanged.connect(self._update_model_upperbound) - self.layout().addWidget(upper_col, 2, 1) - self.layout().addWidget(self.upper_bound_marker_col, 3, 1) + self.layout().addWidget(upper_col, 2, 0, 1, 2) + self.layout().addWidget(self.upper_bound_marker_col, 2, 2, 1, 2) # Filter field for user to pass on strings to filter markers out. filter_label = QLabel("Remove markers with prefix (default: DNA,DAPI)") @@ -90,16 +89,21 @@ def __init__(self, viewer: Viewer, model: DataModel | None = None) -> None: placeholderText="Prefixes separated by commas.", ) self.filter_field.editingFinished.connect(self._update_filter) - self.layout().addWidget(filter_label, 4, 0) - self.layout().addWidget(self.filter_field, 5, 0) + self.layout().addWidget(filter_label, 3, 0, 1 ,2) + self.layout().addWidget(self.filter_field, 3, 3) # Button to start validating all the input self.validate_button = QPushButton("Validate input") self.validate_button.clicked.connect(self._validate) - self.layout().addWidget(self.validate_button, 6, 0) + self.layout().addWidget(self.validate_button, 4, 0, 1, 4) self.model.events.regionprops_df.connect(self._set_dropdown_marker_lowerbound) self.model.events.regionprops_df.connect(self._set_dropdown_marker_upperbound) + + #set default bounds + + + @property def viewer(self) -> Viewer: @@ -176,6 +180,7 @@ def _set_dropdown_marker_lowerbound(self): region_props = self.model.regionprops_df if region_props is not None and len(region_props) > 0: self.lower_bound_marker_col.addItems(region_props.columns) + self.lower_bound_marker_col.setCurrentIndex(1) # Skip the cell id column def _set_dropdown_marker_upperbound(self): """Add items to dropdown menu for upperbound marker. @@ -187,6 +192,13 @@ def _set_dropdown_marker_upperbound(self): region_props = self.model.regionprops_df if region_props is not None and len(region_props) > 0: self.upper_bound_marker_col.addItems(region_props.columns) + + #TODO set default to column before "X_centroid" + # This does not work + # if "X_centroid" in list(self.model.regionprops_df.columns): + # self.upper_bound_marker_col.setCurrentIndex( + # self.model.regionprops_df.columns.index("X_centroid")-1 ) + def _update_model_lowerbound(self): """Update the lowerbound marker in the data model upon change of text in the lowerbound marker column widget.""" diff --git a/src/cell_gater/widgets/scatter_widget.py b/src/cell_gater/widgets/scatter_widget.py index 4d53ad8..57c96ab 100644 --- a/src/cell_gater/widgets/scatter_widget.py +++ b/src/cell_gater/widgets/scatter_widget.py @@ -9,20 +9,50 @@ from matplotlib.backends.backend_qt5agg import ( NavigationToolbar2QT as NavigationToolbar, ) -from matplotlib.figure import Figure + +from napari.utils.history import ( + get_open_history, + update_open_history, +) +import napari from napari import Viewer from napari.layers import Image +from PyQt5.QtCore import Qt from qtpy.QtWidgets import ( QComboBox, QLabel, QSizePolicy, QVBoxLayout, + QPushButton, QWidget, QGridLayout, + QSlider, + QFileDialog ) +from matplotlib.widgets import Slider +from matplotlib.figure import Figure +import matplotlib.pyplot as plt + from cell_gater.model.data_model import DataModel +from cell_gater.utils.misc import napari_notification +import numpy as np +import pandas as pd +from itertools import product +import sys +import os +from loguru import logger +logger.remove() +logger.add(sys.stdout, format="{time:HH:mm:ss.SS} | {level} | {message}") + +#Good to have features +#TODO Dynamic loading of markers, without reloading masks or DNA channel, so deprecate Load Sample and Marker button + +#Ideas to maybe implement +#TODO log axis options for scatter plot +#TODO dynamic plotting of points on top of created polygons +#TODO save plots as images for QC, perhaps when saving gates run plotting function to go through all samples and markers and save plots class ScatterInputWidget(QWidget): """Widget for a scatter plot with markers on the x axis and any dtype column on the y axis.""" @@ -36,6 +66,10 @@ def __init__(self, model: DataModel, viewer: Viewer) -> None: self._model = model self._viewer = viewer + logger.debug("ScatterInputWidget initialized") + logger.debug(f"Model regionprops_df shape: {self.model.regionprops_df.shape}") + logger.debug(f"Model regionprops_df columns: {self.model.regionprops_df.columns}") + # Reason for setting current sample here as well is, so we can check whether we have to load a new mask. self._current_sample = None self._image = None @@ -52,19 +86,23 @@ def __init__(self, model: DataModel, viewer: Viewer) -> None: self.marker_selection_dropdown.addItems(self.model.markers) self.marker_selection_dropdown.currentTextChanged.connect(self._on_marker_changed) + apply_button = QPushButton("Load Sample and Marker") + apply_button.clicked.connect(self._load_images_and_scatter_plot) + choose_y_axis_label = QLabel("Choose Y-axis") self.choose_y_axis_dropdown = QComboBox() - self.choose_y_axis_dropdown.addItems([None] + self.model.regionprops_df.columns) + self.choose_y_axis_dropdown.addItems(self.model.regionprops_df.columns) self.choose_y_axis_dropdown.setCurrentText("Area") self.choose_y_axis_dropdown.currentTextChanged.connect(self._on_y_axis_changed) self.layout().addWidget(selection_label, 0, 0) - self.layout().addWidget(self.sample_selection_dropdown, 1, 0) - self.layout().addWidget(marker_label, 2, 0) - self.layout().addWidget(self.marker_selection_dropdown, 3, 0) - self.layout().addWidget(choose_y_axis_label, 4, 0) - self.layout().addWidget(self.choose_y_axis_dropdown, 5, 0) - + self.layout().addWidget(self.sample_selection_dropdown, 0, 1) + self.layout().addWidget(marker_label, 0, 2) + self.layout().addWidget(self.marker_selection_dropdown, 0, 3) + self.layout().addWidget(apply_button, 1, 0, 1, 4) + self.layout().addWidget(choose_y_axis_label, 2, 0, 1, 2) + self.layout().addWidget(self.choose_y_axis_dropdown, 2, 2, 1, 2) + # we have to do this because initially the dropdowns did not change texts yet so these variables are still None. self.model.active_sample = self.sample_selection_dropdown.currentText() self.model.active_marker = self.marker_selection_dropdown.currentText() @@ -73,36 +111,169 @@ def __init__(self, model: DataModel, viewer: Viewer) -> None: self._read_data(self.model.active_sample) self._load_layers(self.model.markers[self.model.active_marker]) - #this has to go after active sample and marker are set + # scatter plot self.scatter_canvas = PlotCanvas(self.model) - # self.layout().addWidget(NavigationToolbar(self.gate_canvas, self)) - self.layout().addWidget(self.scatter_canvas.fig, 6, 0) - # Update the plot initially - self.update_plot() - # the scatter plot is not updating when the gate is changed - # unsure what is happening here + self.layout().addWidget(self.scatter_canvas.fig, 3, 0, 1, 4) + + # slider + self.slider_figure = Figure(figsize=(5, 1)) + self.slider_canvas = FigureCanvas(self.slider_figure) + self.slider_ax = self.slider_figure.add_subplot(111) + self.update_slider() + self.layout().addWidget(self.slider_canvas, 4, 0, 1, 4) + + # plot points button + plot_points_button = QPushButton("Plot Points") + plot_points_button.clicked.connect(self.plot_points) + self.layout().addWidget(plot_points_button, 5,0,1,1) + + # Initialize gates dataframe + sample_marker_combinations = list(product( + self.model.regionprops_df['sample_id'].unique(), + self.model.markers + )) + self.model.gates = pd.DataFrame(sample_marker_combinations, columns=['sample_id', 'marker_id']) + self.model.gates['gate_value'] = float(0) + + # gate buttons + save_gate_button = QPushButton("Save Gate") + save_gate_button.clicked.connect(self.save_gate) + self.layout().addWidget(save_gate_button, 5, 1, 1, 1) + + load_gates_button = QPushButton("Load Gates Dataframe") + load_gates_button.clicked.connect(self.load_gates_dataframe) + self.layout().addWidget(load_gates_button, 5, 2, 1, 1) + + save_gates_dataframe_button = QPushButton("Save Gates Dataframe") + save_gates_dataframe_button.clicked.connect(self.save_gates_dataframe) + self.layout().addWidget(save_gates_dataframe_button, 5, 3, 1, 1) + + + ########################### FUNCTIONS ########################### + + ################### + ### PLOT POINTS ### + ################### - @property - def model(self) -> DataModel: - """The dataclass model that stores information required for cell_gating.""" - return self._model + #TODO keep adding point layers to the viewer with simple names, and hide old ones + #TODO how to list layers, filter to points layers, and hide them + #TODO dynamic plotting of points on top of created polygons - @model.setter - def model(self, model: DataModel) -> None: - self._model = model + def plot_points(self): + """Plot positive cells in Napari.""" + assert self.model.active_sample is not None + assert self.model.active_marker is not None - @property - def viewer(self) -> Viewer: - """The napari Viewer.""" - return self._viewer + df = self.model.regionprops_df + df = df[df["sample_id"] == self.model.active_sample] + + self.viewer.add_points( + df[df[self.model.active_marker] > self.model.current_gate][["Y_centroid", "X_centroid"]], + name=f"Gate: {round(self.model.current_gate)} {self.model.active_sample} {self.model.active_marker}", + face_color="#ff00ff", + edge_color="yellow", + size=8, + opacity=0.5, + ) - @viewer.setter - def viewer(self, viewer: Viewer) -> None: - self._viewer = viewer + #################################### + ### GATES DATAFRAME INPUT OUTPUT ### + #################################### + + def load_gates_dataframe(self): + file_path, _ = self._file_dialog() + if file_path: + self.model.gates = pd.read_csv(file_path) + self.model.gates['sample_id'] = self.model.gates['sample_id'].astype(str) + # check if dataframe has samples and markers + assert 'sample_id' in self.model.gates.columns + assert 'marker_id' in self.model.gates.columns + assert 'gate_value' in self.model.gates.columns + # check if dataframe has the same samples and markers as the regionprops_df + assert set(self.model.gates['sample_id'].unique()) == set(self.model.regionprops_df['sample_id'].unique()) + assert set(self.model.gates['marker_id'].unique()) == set(self.model.markers) + + def save_gates_dataframe(self): + options = QFileDialog.Options() + fileName, _ = QFileDialog.getSaveFileName(self, "Save Gates Dataframe", "", "CSV Files (*.csv);;All Files (*)", options=options) + if fileName: + self.model.gates.to_csv(fileName, index=False) + print("File saved to:", fileName) + + def save_gate(self): + if self.model.current_gate == 0: + napari_notification("Gate not saved, please select a gate value.") + if self.access_gate() == self.model.current_gate: + napari_notification("No changes detected.") + if self.access_gate() != self.model.current_gate: + napari_notification(f"Old gate {self.access_gate().round(2)} overwritten to {self.model.current_gate.round(2)}") + self.model.gates.loc[ + (self.model.gates['sample_id'] == self.model.active_sample) & + (self.model.gates['marker_id'] == self.model.active_marker), + 'gate_value'] = self.model.current_gate + assert self.access_gate() == self.model.current_gate + logger.debug(f"Gate saved: {self.model.current_gate}") + + def access_gate(self): + assert self.model.active_sample is not None + assert self.model.active_marker is not None + gate_value = self.model.gates.loc[ + (self.model.gates['sample_id'] == self.model.active_sample) & + (self.model.gates['marker_id'] == self.model.active_marker), + 'gate_value'].values[0] + assert isinstance(gate_value, float) + return gate_value + + ########################## + #### SLIDER FUNCTIONS #### + ########################## + + def get_min_max_median_step(self) -> tuple: + df = self.model.regionprops_df + df = df[df["sample_id"] == self.model.active_sample] + min = df[self.model.active_marker].min() + max = df[self.model.active_marker].max() + init = df[self.model.active_marker].median() + step = min / 100 + return min, max, init, step + + def slider_changed(self, val): + self.model._current_gate = val + self.scatter_canvas.update_vertical_line(val) + self.scatter_canvas.fig.draw() + + def update_slider(self): + min, max, init, step = self.get_min_max_median_step() + self.slider_ax.clear() + self.slider = Slider(self.slider_ax, "Gate", min, max, valinit=init, valstep=step, color="black") + self.slider.on_changed(self.slider_changed) + self.slider_canvas.draw() + + ########################## + ###### LOADING DATA ###### + ########################## + + def update_plot(self): + self.scatter_canvas.ax.clear() + self.scatter_canvas.plot_scatter_plot(self.model) + self.scatter_canvas.fig.draw() + + def _load_images_and_scatter_plot(self): + self._clear_layers(clear_all=True) + self._read_data(self.model.active_sample) + self._load_layers(self.model.markers[self.model.active_marker]) + logger.debug(f"loading index {self.model.markers[self.model.active_marker]}") + self.update_plot() + self.update_slider() + + def _read_data(self, sample: str | None) -> None: + logger.info(f"Reading data for sample {sample}.") if sample is not None: + logger.debug(f"Reading image from {self.model.sample_image_mapping[sample]}.") image_path = self.model.sample_image_mapping[sample] + logger.debug(f"Reading mask from {self.model.sample_mask_mapping[sample]}.") mask_path = self.model.sample_mask_mapping[sample] self._image = imread(image_path) @@ -110,14 +281,21 @@ def _read_data(self, sample: str | None) -> None: def _load_layers(self, marker_index): - if self.model.active_sample != self._current_sample: - self._current_sample = copy(self.model.active_sample) - self.viewer.add_labels( - self._mask, - name="mask_" + self.model.active_sample, - visible=False, opacity=0.4 - ) + # if self.model.active_sample != self._current_sample: + # self._current_sample = copy(self.model.active_sample) + #TODO let user decide which is their DNA channel + self.viewer.add_image( + self._image[0], + name="DNA_" + self.model.active_sample, + blending="additive", + visible=False + ) + self.viewer.add_labels( + self._mask, + name="mask_" + self.model.active_sample, + visible=False, opacity=0.4 + ) self.viewer.add_image( self._image[marker_index], name=self.model.active_marker + "_" + self.model.active_sample, @@ -125,19 +303,9 @@ def _load_layers(self, marker_index): ) def _on_sample_changed(self): - """Set the active sample. - - This changes the active sample and clears the layers and the marker selection dropdown. - Subsequently, the new layers are loaded. - """ self.model.active_sample = self.sample_selection_dropdown.currentText() - - self._clear_layers(clear_all=True) - self._reinitiate_marker_selection_dropdown() - - self._read_data(self.model.active_sample) - self._load_layers(self.model.markers[self.model.active_marker]) - self.update_plot() + def _on_marker_changed(self): + self.model.active_marker = self.marker_selection_dropdown.currentText() def _clear_layers(self, clear_all: bool) -> None: """Remove all layers upon changing sample.""" @@ -149,21 +317,14 @@ def _clear_layers(self, clear_all: bool) -> None: if isinstance(layer, Image): self.viewer.layers.remove(layer) - def _reinitiate_marker_selection_dropdown(self) -> None: - """Reiniatiate the marker selection dropdown after sample has changed.""" - # This is preemptively added for clearing visual completed feedback once implemented. - # We also block the outgoing signal in order not to update the layer when there is no current active marker. - self.marker_selection_dropdown.blockSignals(True) - self.marker_selection_dropdown.clear() - self.marker_selection_dropdown.addItems(self.model.markers) - self.marker_selection_dropdown.blockSignals(False) - - def _on_marker_changed(self): - """Set active marker and update the marker image layer.""" - self.model.active_marker = self.marker_selection_dropdown.currentText() - self._clear_layers(clear_all=False) - self._load_layers(self.model.markers[self.model.active_marker]) - self.update_plot() + # def _reinitiate_marker_selection_dropdown(self) -> None: + # """Reiniatiate the marker selection dropdown after sample has changed.""" + # # This is preemptively added for clearing visual completed feedback once implemented. + # # We also block the outgoing signal in order not to update the layer when there is no current active marker. + # self.marker_selection_dropdown.blockSignals(True) + # self.marker_selection_dropdown.clear() + # self.marker_selection_dropdown.addItems(self.model.markers) + # self.marker_selection_dropdown.blockSignals(False) def _set_samples_dropdown(self) -> None: """Set the items for the samples dropdown QComboBox.""" @@ -181,9 +342,37 @@ def _on_y_axis_changed(self): self.model.active_y_axis = self.choose_y_axis_dropdown.currentText() self.update_plot() - def update_plot(self): - self.scatter_canvas.plot_scatter_plot() + def _file_dialog(self): + """Open dialog for a user to select a file.""" + dlg = QFileDialog() + hist = get_open_history() + dlg.setHistory(hist) + options = QFileDialog.Options() + return dlg.getOpenFileName( + self, + "Select file", + hist[0], + "CSV Files (*.csv)", + options=options, + ) + + @property + def model(self) -> DataModel: + """The dataclass model that stores information required for cell_gating.""" + return self._model + + @model.setter + def model(self, model: DataModel) -> None: + self._model = model + + @property + def viewer(self) -> Viewer: + """The napari Viewer.""" + return self._viewer + @viewer.setter + def viewer(self, viewer: Viewer) -> None: + self._viewer = viewer class PlotCanvas(): """The canvas class for the gating scatter plot.""" @@ -191,7 +380,7 @@ class PlotCanvas(): def __init__(self, model: DataModel): self.model = DataModel() if model is None else model - self.fig = FigureCanvas(Figure()) + self.fig = FigureCanvas(Figure()) self.fig.figure.subplots_adjust(left=0.1, bottom=0.1) self.ax = self.fig.figure.subplots() self.ax.set_title("Scatter plot") @@ -208,73 +397,34 @@ def model(self, model: DataModel) -> None: self._model = model def plot_scatter_plot(self, model: DataModel) -> None: - """Plot the scatter plot.""" - - # check if sample and marker are selected assert self.model.active_marker is not None assert self.model.active_sample is not None - # get the data for the scatter plot df = self.model.regionprops_df df = df[df["sample_id"] == self.model.active_sample] + + logger.debug(f"Plotting scatter plot for {self.model.active_sample} and {self.model.active_marker}.") self.ax.scatter( x=df[self.model.active_marker], - y=df[self.model.active_y_axis], # later change to desired_y_axis + y=df[self.model.active_y_axis], color="steelblue", ec="white", lw=0.1, alpha=1.0, s=80000 / int(df.shape[0]), ) - # - self.ax.set_ylabel("Area") # later change to desired_y_axis + # Set x-axis limits + self.ax.set_xlim(df[self.model.active_marker].min(), df[self.model.active_marker].max()) + self.ax.set_ylabel(self.model.active_y_axis) self.ax.set_xlabel(f"{self.model.active_marker} intensity") - # - # # add vertical line at current gate if it exists - # if self.model.current_gate is not None: - # ax.axvline(x=self.model.current_gate, color="red", linewidth=1.0, linestyle="--") - # - # minimum = df[self.model.active_marker].min() - # maximum = df[self.model.active_marker].max() - # value_initial = df[self.model.active_marker].median() - # value_step = minimum / 100 - # - # # add slider as an axis, underneath the scatter plot - # axSlider = fig.add_axes([0.1, 0.01, 0.8, 0.03], facecolor="yellow") - # slider = Slider(axSlider, "Gate", minimum, maximum, valinit=value_initial, valstep=value_step, color="black") - # - # def update_gate(val): - # self.model.current_gate = val - # ax.axvline(x=self.model.current_gate, color="red", linewidth=1.0, linestyle="--") - # napari_notification(f"Gate set to {val}") - # - # slider.on_changed(update_gate) - # - # # TODO add the plot to a widget and display it - # - # def plot_points(self, model: DataModel) -> None: - # """Plot positive cells in Napari.""" - # if self.model.active_marker is not None: - # marker = self.model.active_marker - # if self.model.active_sample is not None: - # sample = self.model.active_sample - # - # viewer = self.model.viewer - # df = self.model.regionprops_df - # df = df[df["sample_id"] == self.model.active_sample] - # gate = self.model.gates.loc[marker, sample] - # - # viewer.add_points( - # df[df[marker] > gate][["X_centroid", "Y_centroid"]], - # name=f"{gate} and its positive cells", - # face_color="red", - # edge_color="black", - # size=15, - # ) - # - # def plot_points_button(self): - # """Plot points button.""" - # self.plot_points_button = QPushButton("Plot Points") - # self.plot_points_button.clicked.connect(self.plot_points) - # self.layout().addWidget(self.plot_points_button, 1, 2) # not sure where to put this button + + logger.debug(f"The current gate is {self.model.current_gate}.") + if self.model.current_gate > 0.0: + self.ax.axvline(x=self.model.current_gate, color="red", linewidth=1.0, linestyle="--") + else: + self.ax.axvline(x=1, color="red", linewidth=1.0, linestyle="--") + + def update_vertical_line(self, x_position): + """Update the position of the vertical line.""" + self.ax.lines[0].set_xdata(x_position) \ No newline at end of file