Skip to content

Commit

Permalink
move models to own files/simplify view (#123)
Browse files Browse the repository at this point in the history
* move models to own files/simplify view

in view of moving to a separate version managing widget

* make brainrender widget docs/naming consistent with changes

* adapt tests to view refactor

* Apply suggestions from code review

Co-authored-by: Igor Tatarnikov <[email protected]>

* refactor to avoid hardcoding columns

* improve/tidy tests for model header

---------

Co-authored-by: Igor Tatarnikov <[email protected]>
  • Loading branch information
alessandrofelder and IgorTatarnikov authored Dec 20, 2023
1 parent fa68926 commit ebf43dc
Show file tree
Hide file tree
Showing 9 changed files with 553 additions and 435 deletions.
44 changes: 17 additions & 27 deletions brainrender_napari/brainrender_widget.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
A napari widget to view atlases.
Atlases that are exposed by the Brainglobe atlas API are
Locally available atlases are
shown in a table view using the Qt model/view framework
[Qt Model/View framework](https://doc.qt.io/qt-6/model-view-programming.html)
Users can download and add the atlas images/structures as layers to the viewer.
Users can add the atlas images/structures as layers to the viewer.
"""
from bg_atlasapi import BrainGlobeAtlas
from bg_atlasapi.list_atlases import get_downloaded_atlases
Expand All @@ -21,16 +21,14 @@
NapariAtlasRepresentation,
)
from brainrender_napari.utils.brainglobe_logo import header_widget
from brainrender_napari.widgets.atlas_table_view import AtlasTableView
from brainrender_napari.widgets.atlas_viewer_view import AtlasViewerView
from brainrender_napari.widgets.structure_view import StructureView


class BrainrenderWidget(QWidget):
"""The purpose of this class is
* to hold atlas visualisation widgets for napari
* coordinate between these widgets and napari by
* creating appropriate signal-slot connections
* creating napari representations as requested
* coordinate between these widgets and napari
"""

def __init__(self, napari_viewer: Viewer):
Expand All @@ -43,7 +41,7 @@ def __init__(self, napari_viewer: Viewer):
self.layout().addWidget(header_widget())

# create widgets
self.atlas_table_view = AtlasTableView(parent=self)
self.atlas_viewer_view = AtlasViewerView(parent=self)

self.show_structure_names = QCheckBox()
self.show_structure_names.setChecked(False)
Expand All @@ -56,14 +54,14 @@ def __init__(self, napari_viewer: Viewer):
self.structure_view = StructureView(parent=self)

# add widgets to the layout as group boxes
self.atlas_table_group = QGroupBox("Atlas table view")
self.atlas_table_group.setToolTip(
"Double-click on row to download/add annotations and reference\n"
self.atlas_viewer_group = QGroupBox("Atlas Viewer")
self.atlas_viewer_group.setToolTip(
"Double-click on row to add annotations and reference\n"
"Right-click to add additional reference images (if any exist)"
)
self.atlas_table_group.setLayout(QVBoxLayout())
self.atlas_table_group.layout().addWidget(self.atlas_table_view)
self.layout().addWidget(self.atlas_table_group)
self.atlas_viewer_group.setLayout(QVBoxLayout())
self.atlas_viewer_group.layout().addWidget(self.atlas_viewer_view)
self.layout().addWidget(self.atlas_viewer_group)

self.structure_tree_group = QGroupBox("3D Atlas region meshes")
self.structure_tree_group.setToolTip(
Expand All @@ -79,16 +77,13 @@ def __init__(self, napari_viewer: Viewer):
self.layout().addWidget(self.structure_tree_group)

# connect atlas view widget signals
self.atlas_table_view.download_atlas_confirmed.connect(
self._on_download_atlas_confirmed
)
self.atlas_table_view.add_atlas_requested.connect(
self.atlas_viewer_view.add_atlas_requested.connect(
self._on_add_atlas_requested
)
self.atlas_table_view.additional_reference_requested.connect(
self.atlas_viewer_view.additional_reference_requested.connect(
self._on_additional_reference_requested
)
self.atlas_table_view.selected_atlas_changed.connect(
self.atlas_viewer_view.selected_atlas_changed.connect(
self._on_atlas_selection_changed
)

Expand All @@ -102,15 +97,10 @@ def __init__(self, napari_viewer: Viewer):
self._on_add_structure_requested
)

def _on_download_atlas_confirmed(self, atlas_name):
"""Ensure structure view is displayed if new atlas downloaded."""
show_structure_names = self.show_structure_names.isChecked()
self.structure_view.refresh(atlas_name, show_structure_names)

def _on_add_structure_requested(self, structure_name: str):
"""Add given structure as napari atlas representation"""
selected_atlas = BrainGlobeAtlas(
self.atlas_table_view.selected_atlas_name()
self.atlas_viewer_view.selected_atlas_name()
)
selected_atlas_representation = NapariAtlasRepresentation(
selected_atlas, self._viewer
Expand All @@ -121,7 +111,7 @@ def _on_additional_reference_requested(
self, additional_reference_name: str
):
"""Add additional reference as napari atlas representation"""
atlas = BrainGlobeAtlas(self.atlas_table_view.selected_atlas_name())
atlas = BrainGlobeAtlas(self.atlas_viewer_view.selected_atlas_name())
atlas_representation = NapariAtlasRepresentation(atlas, self._viewer)
atlas_representation.add_additional_reference(
additional_reference_name
Expand All @@ -144,6 +134,6 @@ def _on_add_atlas_requested(self, atlas_name: str):
selected_atlas_representation.add_to_viewer()

def _on_show_structure_names_clicked(self):
atlas_name = self.atlas_table_view.selected_atlas_name()
atlas_name = self.atlas_viewer_view.selected_atlas_name()
show_structure_names = self.show_structure_names.isChecked()
self.structure_view.refresh(atlas_name, show_structure_names)
96 changes: 96 additions & 0 deletions brainrender_napari/data_models/atlas_table_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from bg_atlasapi.list_atlases import (
get_all_atlases_lastversions,
get_atlases_lastversions,
get_downloaded_atlases,
get_local_atlas_version,
)
from qtpy.QtCore import QAbstractTableModel, QModelIndex, Qt

from brainrender_napari.utils.load_user_data import (
read_atlas_metadata_from_file,
)


class AtlasTableModel(QAbstractTableModel):
"""A table data model for atlases."""

def __init__(self):
super().__init__()
self.column_headers = [
"Raw name",
"Atlas",
"Local version",
"Latest version",
]
self.refresh_data()

def refresh_data(self) -> None:
"""Refresh model data by calling atlas API"""
all_atlases = get_all_atlases_lastversions()
data = []
for name, latest_version in all_atlases.items():
if name in get_atlases_lastversions().keys():
data.append(
[
name,
self._format_name(name),
get_local_atlas_version(name),
latest_version,
]
)
else:
data.append(
[name, self._format_name(name), "n/a", latest_version]
)

self._data = data

def _format_name(self, name: str) -> str:
formatted_name = name.split("_")
formatted_name[0] = formatted_name[0].capitalize()
formatted_name[-1] = f"({formatted_name[-1].split('um')[0]} \u03BCm)"
return " ".join([formatted for formatted in formatted_name])

def data(self, index: QModelIndex, role=Qt.DisplayRole):
if role == Qt.DisplayRole:
return self._data[index.row()][index.column()]
if role == Qt.ToolTipRole:
hovered_atlas_name = self._data[index.row()][0]
return AtlasTableModel._get_tooltip_text(hovered_atlas_name)

def rowCount(self, index: QModelIndex = QModelIndex()):
return len(self._data)

def columnCount(self, index: QModelIndex = QModelIndex()):
return len(self._data[0])

def headerData(
self, section: int, orientation: Qt.Orientation, role: Qt.ItemDataRole
):
"""Customises the horizontal header data of model,
and raises an error if an unexpected column is found."""
if role == Qt.DisplayRole and orientation == Qt.Orientation.Horizontal:
if section >= 0 and section < len(self.column_headers):
return self.column_headers[section]
else:
raise ValueError("Unexpected horizontal header value.")
else:
return super().headerData(section, orientation, role)

@classmethod
def _get_tooltip_text(cls, atlas_name: str):
"""Returns the atlas metadata as a formatted string,
as well as instructions on how to interact with the atlas."""
if atlas_name in get_downloaded_atlases():
metadata = read_atlas_metadata_from_file(atlas_name)
metadata_as_string = ""
for key, value in metadata.items():
metadata_as_string += f"{key}:\t{value}\n"

tooltip_text = f"{atlas_name} (double-click to add to viewer)\
\n{metadata_as_string}"
elif atlas_name in get_all_atlases_lastversions().keys():
tooltip_text = f"{atlas_name} (double-click to download)"
else:
raise ValueError("Tooltip text called with invalid atlas name.")
return tooltip_text
144 changes: 144 additions & 0 deletions brainrender_napari/data_models/structure_tree_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Dict, List

from bg_atlasapi.structure_tree_util import get_structures_tree
from qtpy.QtCore import QAbstractItemModel, QModelIndex, Qt
from qtpy.QtGui import QStandardItem


class StructureTreeItem(QStandardItem):
"""A class to hold items in a tree model."""

def __init__(self, data, parent=None):
self.parent_item = parent
self.item_data = data
self.child_items = []

def appendChild(self, item):
self.child_items.append(item)

def child(self, row):
return self.child_items[row]

def childCount(self):
return len(self.child_items)

def columnCount(self):
return len(self.item_data)

def data(self, column):
try:
return self.item_data[column]
except IndexError:
return None

def parent(self):
return self.parent_item

def row(self):
if self.parent_item:
return self.parent_item.child_items.index(self)
return 0


class StructureTreeModel(QAbstractItemModel):
"""Implementation of a read-only QAbstractItemModel to hold
the structure tree information provided by the Atlas API in a Qt Model"""

def __init__(self, data: List, parent=None):
super().__init__()
self.root_item = StructureTreeItem(data=("acronym", "name", "id"))
self.build_structure_tree(data, self.root_item)

def build_structure_tree(self, structures: List, root: StructureTreeItem):
"""Build the structure tree given a list of structures."""
tree = get_structures_tree(structures)
structure_id_dict = {}
for structure in structures:
structure_id_dict[structure["id"]] = structure

inserted_items: Dict[int, StructureTreeItem] = {}
for n_id in tree.expand_tree(): # sorts nodes by default,
# so parents will always be already in the QAbstractItemModel
# before their children
node = tree.get_node(n_id)
acronym = structure_id_dict[node.identifier]["acronym"]
name = structure_id_dict[node.identifier]["name"]
if (
len(structure_id_dict[node.identifier]["structure_id_path"])
== 1
):
parent_item = root
else:
parent_id = tree.parent(node.identifier).identifier
parent_item = inserted_items[parent_id]

item = StructureTreeItem(
data=(acronym, name, node.identifier), parent=parent_item
)
parent_item.appendChild(item)
inserted_items[node.identifier] = item

def data(self, index: QModelIndex, role=Qt.DisplayRole):
"""Provides read-only data for a given index if
intended for display, otherwise None."""
if not index.isValid():
return None

if role != Qt.DisplayRole:
return None

item = index.internalPointer()

return item.data(index.column())

def rowCount(self, parent: StructureTreeItem):
"""Returns the number of rows(i.e. children) of an item"""
if parent.column() > 0:
return 0

if not parent.isValid():
parent_item = self.root_item
else:
parent_item = parent.internalPointer()

return parent_item.childCount()

def columnCount(self, parent: StructureTreeItem):
"""The number of columns of an item."""
if parent.isValid():
return parent.internalPointer().columnCount()
else:
return self.root_item.columnCount()

def parent(self, index: QModelIndex):
"""The first-column index of parent of the item
at a given index. Returns an empty index if the root,
or an invalid index, is passed.
"""
if not index.isValid():
return QModelIndex()

child_item = index.internalPointer()
parent_item = child_item.parent()

if parent_item == self.root_item:
return QModelIndex()

return self.createIndex(parent_item.row(), 0, parent_item)

def index(self, row, column, parent=QModelIndex()):
"""The index of the item at (row, column) with a given parent.
By default, the given parent is assumed to be the root."""
if not self.hasIndex(row, column, parent):
return QModelIndex()

if not parent.isValid():
parent_item = self.root_item
else:
parent_item = parent.internalPointer()

child_item = parent_item.child(row)
if child_item:
return self.createIndex(row, column, child_item)
else:
return QModelIndex()
Loading

0 comments on commit ebf43dc

Please sign in to comment.