Skip to content

Commit

Permalink
feat: update gallery and fix examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 5, 2025
1 parent bf18a00 commit 35a2375
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 22 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ __pycache__/
_version.py

*.mrd
**/*.mrd
results/
multirun/
outputs/
Expand All @@ -33,7 +34,7 @@ docs/auto_*/
docs/sg_execution_times.rst



jupyter_execute/
htmlcov/
.tox/
.nox/
Expand Down
175 changes: 175 additions & 0 deletions docs/_ext/colab_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""A Sphinx extension to add a button to open a notebook in Google Colab."""

from docutils import nodes
from sphinx.util.docutils import SphinxDirective
from sphinx_gallery.notebook import add_code_cell, add_markdown_cell

import os
import json


class ColabLinkNode(nodes.General, nodes.Element):
"""A custom docutils node to represent the Colab link."""


def visit_colab_link_node_html(self, node):
self.body.append(node["html"])


def depart_colab_link_node_html(self, node):
pass


class ColabLinkDirective(SphinxDirective):
"""Directive to insert a link to open a notebook in Google Colab."""

has_content = True
option_spec = {
"needs_gpu": int,
}

def run(self):
"""Run the directive."""
# Determine the path of the current .rst file
rst_file_path = self.env.doc2path(self.env.docname)
rst_file_dir = os.path.dirname(rst_file_path)

# Determine the notebook file path assuming it is in the same directory as the .rst file
notebook_filename = os.path.basename(rst_file_path).replace(".rst", ".ipynb")

# Full path to the notebook
notebook_full_path = os.path.join(rst_file_dir, notebook_filename)

# Convert the full path back to a relative path from the repo root
# repo_root = self.config.project_root_dir
notebook_repo_relative_path = os.path.relpath(
notebook_full_path, os.path.join(os.getcwd(), "docs")
)

config_ext = self.env.config["colab_notebook"]
base_colab_url = config_ext.get(
"base_colab_url", "https://colab.research.google.com"
)
base_repo = config_ext["repo"]
branch = config_ext["branch"]
path = config_ext["path"]
# Generate the Colab URL based on GitHub repo information
self.colab_url = f"{base_colab_url}/{base_repo}/blob/{branch}/{path}/{notebook_repo_relative_path}"

# Create the HTML button or link
self.html = f"""<div class="colab-button">
<a href="{self.colab_url}" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg"
alt="Open In Colab"/>
</a>
</div>
"""
self.notebook_modifier(notebook_full_path, "\n".join(self.content))

# Create the node to insert the HTML
node = ColabLinkNode(html=self.html)
return [node]

def notebook_modifier(self, notebook_path, commands):
"""Modify the notebook to add a warning about GPU requirement."""
with open(notebook_path) as f:
notebook = json.load(f)
if "cells" not in notebook:
notebook["cells"] = []

# Add a cell to install the required libraries at the position where we have
# colab link
idx = self.find_index_of_colab_link(notebook)

code_lines = ["# Install libraries"]
code_lines.append(commands)
code_lines.append(
"!pip install" + " ".join(self.env.config["colab_notebook"]["dependencies"])
)
dummy_notebook_content = {"cells": []}
add_code_cell(
dummy_notebook_content,
"\n".join(code_lines),
)
notebook["cells"][idx] = dummy_notebook_content["cells"][0]

needs_GPU = self.options.get("needs_gpu", False)
if needs_GPU:
# Add a warning cell at the top of the notebook
warning_template = "\n".join(
[
"<div class='alert alert-{message_class}'>",
"",
"# Need GPU warning",
"",
"{message}",
"</div>",
self.html,
]
)
message_class = "warning"
message = (
"Running this example requires a GPU, and hence is NOT "
"possible on binder currently We request you to kindly run this notebook "
"on Google Colab by clicking the link below. Additionally, please make "
"sure to set the runtime on Colab to use a GPU and install the below "
"libraries before running."
)
idx = 0
else:
# Add a warning cell at the top of the notebook
warning_template = "\n".join(
[
"<div class='alert alert-{message_class}'>",
"",
"# Install libraries needed for Colab",
"",
"{message}",
"</div>",
self.html,
]
)
message_class = "info"
message = (
"The below installation commands are needed to be run only on "
"Google Colab."
)

dummy_notebook_content = {"cells": []}
add_markdown_cell(
dummy_notebook_content,
warning_template.format(message_class=message_class, message=message),
)
notebook["cells"] = (
notebook["cells"][:idx]
+ dummy_notebook_content["cells"]
+ notebook["cells"][idx:]
)

# Write back updated notebook
with open(notebook_path, "w", encoding="utf-8") as f:
json.dump(notebook, f, ensure_ascii=False, indent=2)

def find_index_of_colab_link(self, notebook):
"""Find the index of the cell containing the Colab link."""
for idx, cell in enumerate(notebook["cells"]):
if cell["cell_type"] == "markdown" and ".. colab-link::" in "".join(
cell.get("source", "")
):
return idx
return 0


def setup(app):
"""Set up the Sphinx extension."""
app.add_config_value("colab_notebook", dict(), "html")
app.add_node(
ColabLinkNode, html=(visit_colab_link_node_html, depart_colab_link_node_html)
)
app.add_directive("colab-link", ColabLinkDirective)

return {
"version": "0.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}
17 changes: 15 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"myst_sphinx_gallery",
"myst_nb",
"scenario",
"colab_extension",
"sphinx_gallery.gen_gallery",
]

Expand All @@ -52,8 +53,12 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

exclude_patterns = [
"_build",
"Thumbs.db",
".DS_Store",
"auto_examples/anatomical/*.ipynb",
]

autodoc2_packages = ["../src/snake/"]
autodoc2_output_dir = "auto_api"
Expand All @@ -74,6 +79,13 @@
"matplotlib": ("https://matplotlib.org/stable/", None),
}

colab_notebook = {
"dependencies": ["snake-fmri", "mri-nufft[finufft,cufinufft]"],
"branch": "gh-pages",
"path": "examples",
"base_colab_url": "https://colab.research.google.com",
"repo": "https://github.com/paquiteau/snake-fmri",
}

# -- MyST configuration ---------------------------------------------------
#
Expand Down Expand Up @@ -144,4 +156,5 @@
sphinx_gallery_conf = {
"examples_dirs": "../examples", # path to your example scripts
"gallery_dirs": "auto_examples", # path to where to save gallery generated output
"filename_pattern": "/example",
}
2 changes: 1 addition & 1 deletion examples/anatomical/example_EPI_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@

fig, ax = plt.subplots()

axis3dcut(fig, ax, image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40))
axis3dcut(image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40), ax=ax)
plt.show()

# %%
Expand Down
2 changes: 1 addition & 1 deletion examples/anatomical/example_T2s_EPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def reconstruct_frame(filename):
(image_simple, image_T2s, abs(image_simple - image_T2s)),
("simple", "T2s", "diff"),
):
axis3dcut(fig, ax, img, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4)
axis3dcut(img, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4, ax=ax)
ax.set_title(title)

plt.show()
Expand Down
2 changes: 1 addition & 1 deletion examples/anatomical/example_anat_EPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,5 @@

fig, ax = plt.subplots()

axis3dcut(fig, ax, image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40))
axis3dcut(image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40), ax=ax)
plt.show()
2 changes: 1 addition & 1 deletion examples/anatomical/example_generate_phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
axis3dcut(fig, ax, contrast_at_TE.T, None, None, cuts=(60, 60, 60), width_inches=5)
axis3dcut(contrast_at_TE.T, None, None, cuts=(60, 60, 60), ax=ax, width_inches=5)
fig

from ipywidgets import interact
Expand Down
26 changes: 19 additions & 7 deletions examples/anatomical/example_gpu_anat_spirals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# %%
"""
Compare Fourier Model and T2* Model for Stack of Spirals trajectory
===========================================
===================================================================
This examples walks through the elementary components of SNAKE.
Expand Down Expand Up @@ -88,6 +88,9 @@

display_3D_trajectory(traj)

# %%
traj.shape

# %%
# Adding noise in Image
# ----------------------
Expand Down Expand Up @@ -183,12 +186,21 @@
compute_backend=COMPUTE_BACKEND,
)
with NonCartesianFrameDataLoader("example_spiral.mrd") as data_loader:
adjoint_spiral = abs(zer_rec.reconstruct(data_loader, sim_conf)[0])
cs_spiral = abs(seq_rec.reconstruct(data_loader, sim_conf)[0])
adjoint_spiral = abs(zer_rec.reconstruct(data_loader)[0])
cs_spiral = abs(seq_rec.reconstruct(data_loader)[0])
with NonCartesianFrameDataLoader("example_spiral_t2s.mrd") as data_loader:
adjoint_spiral_T2s = abs(zer_rec.reconstruct(data_loader, sim_conf)[0])
cs_spiral_T2s = abs(seq_rec.reconstruct(data_loader, sim_conf)[0])
adjoint_spiral_T2s = abs(zer_rec.reconstruct(data_loader,sim_conf)[0])
cs_spiral_T2s = abs(seq_rec.reconstruct(data_loader)[0])


# %%
with NonCartesianFrameDataLoader("example_spiral.mrd") as data_loader:
traj,data = data_loader.get_kspace_frame(0)

# %%
data.shape

# %%

# %%
# Plotting the result
Expand All @@ -210,7 +222,7 @@
(adjoint_spiral, adjoint_spiral_T2s, abs(adjoint_spiral - adjoint_spiral_T2s)),
("simple", "T2s", "diff"),
):
axis3dcut(fig, ax, img.T, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4)
axis3dcut(img.T, None, None, cbar=True, cuts=(40, 40, 40), ax=ax,width_inches=4)
ax.set_title(title)


Expand All @@ -219,7 +231,7 @@
(cs_spiral, cs_spiral_T2s, abs(cs_spiral - cs_spiral_T2s)),
("simple", "T2s", "diff"),
):
axis3dcut(fig, ax, img.T, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4)
axis3dcut(img.T, None, None, cbar=True, cuts=(40, 40, 40), ax=ax,width_inches=4)
ax.set_title(title + " CS")


Expand Down
12 changes: 8 additions & 4 deletions examples/anatomical/example_gpu_anat_spirals_slice.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# %%
"""
Compare Fourier Model and T2* Model for Stack of Spirals trajectory
===========================================
Compare Fourier Model and T2* Model for 2D Stack of Spirals trajectory
======================================================================
This examples walks through the elementary components of SNAKE.
Here we proceed step by step and use the Python interface. A more integrated
alternative is to use the CLI ``snake-main``
"""

# %%
# .. colab-link::
# :needs_gpu: 1
#
# !pip install mri-nufft[gpunufft] scikit-image

# Imports
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -169,7 +174,6 @@

# %%
shot = traj[18].copy()
print(shot)
nufft = get_operator(NUFFT_BACKEND)(
samples=shot[:, :2],
shape=data_loader.shape[:-1],
Expand All @@ -180,4 +184,4 @@
image = nufft.adj_op(kspace_data)

fig, ax = plt.subplots()
axis3dcut(fig, ax, image, None, cuts=(40, 40, 40))
axis3dcut(abs(image), None, cuts=(40, 40, 40), ax=ax)
2 changes: 0 additions & 2 deletions src/cli-conf/scenario3-tpz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ reconstructors:





hydra:
job:
chdir: true
Expand Down
2 changes: 1 addition & 1 deletion src/snake/mrd_utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def engine_model(self) -> str:
@property
def slice_2d(self) -> bool:
"""Is the acquisition run on 2D slices."""
return bool(self.header.userParameters.userParameterString[1].value)
return self.header.userParameters.userParameterString[1].value == "True"

#############
# Get data #
Expand Down
2 changes: 1 addition & 1 deletion src/snake/toolkit/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None:
def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
"""Reconstruct with Sequential."""
shape = data_loader.shape
self.setup(shape)
self.setup(shape=shape)
from fmri.operators.gradient import (
GradAnalysis,
GradSynthesis,
Expand Down

0 comments on commit 35a2375

Please sign in to comment.