From 35a2375146731bd0755348e4a5d79297f60e0787 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 5 Feb 2025 12:05:33 +0100 Subject: [PATCH] feat: update gallery and fix examples. --- .gitignore | 3 +- docs/_ext/colab_extension.py | 175 ++++++++++++++++++ docs/conf.py | 17 +- examples/anatomical/example_EPI_2D.py | 2 +- examples/anatomical/example_T2s_EPI.py | 2 +- examples/anatomical/example_anat_EPI.py | 2 +- .../anatomical/example_generate_phantom.py | 2 +- .../anatomical/example_gpu_anat_spirals.py | 26 ++- .../example_gpu_anat_spirals_slice.py | 12 +- src/cli-conf/scenario3-tpz.yaml | 2 - src/snake/mrd_utils/loader.py | 2 +- src/snake/toolkit/reconstructors/pysap.py | 2 +- 12 files changed, 225 insertions(+), 22 deletions(-) create mode 100644 docs/_ext/colab_extension.py diff --git a/.gitignore b/.gitignore index 38dca291..4647cf06 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ __pycache__/ _version.py *.mrd +**/*.mrd results/ multirun/ outputs/ @@ -33,7 +34,7 @@ docs/auto_*/ docs/sg_execution_times.rst - +jupyter_execute/ htmlcov/ .tox/ .nox/ diff --git a/docs/_ext/colab_extension.py b/docs/_ext/colab_extension.py new file mode 100644 index 00000000..0955ed5f --- /dev/null +++ b/docs/_ext/colab_extension.py @@ -0,0 +1,175 @@ +"""A Sphinx extension to add a button to open a notebook in Google Colab.""" + +from docutils import nodes +from sphinx.util.docutils import SphinxDirective +from sphinx_gallery.notebook import add_code_cell, add_markdown_cell + +import os +import json + + +class ColabLinkNode(nodes.General, nodes.Element): + """A custom docutils node to represent the Colab link.""" + + +def visit_colab_link_node_html(self, node): + self.body.append(node["html"]) + + +def depart_colab_link_node_html(self, node): + pass + + +class ColabLinkDirective(SphinxDirective): + """Directive to insert a link to open a notebook in Google Colab.""" + + has_content = True + option_spec = { + "needs_gpu": int, + } + + def run(self): + """Run the directive.""" + # Determine the path of the current .rst file + rst_file_path = self.env.doc2path(self.env.docname) + rst_file_dir = os.path.dirname(rst_file_path) + + # Determine the notebook file path assuming it is in the same directory as the .rst file + notebook_filename = os.path.basename(rst_file_path).replace(".rst", ".ipynb") + + # Full path to the notebook + notebook_full_path = os.path.join(rst_file_dir, notebook_filename) + + # Convert the full path back to a relative path from the repo root + # repo_root = self.config.project_root_dir + notebook_repo_relative_path = os.path.relpath( + notebook_full_path, os.path.join(os.getcwd(), "docs") + ) + + config_ext = self.env.config["colab_notebook"] + base_colab_url = config_ext.get( + "base_colab_url", "https://colab.research.google.com" + ) + base_repo = config_ext["repo"] + branch = config_ext["branch"] + path = config_ext["path"] + # Generate the Colab URL based on GitHub repo information + self.colab_url = f"{base_colab_url}/{base_repo}/blob/{branch}/{path}/{notebook_repo_relative_path}" + + # Create the HTML button or link + self.html = f"""
+ + Open In Colab + +
+ """ + self.notebook_modifier(notebook_full_path, "\n".join(self.content)) + + # Create the node to insert the HTML + node = ColabLinkNode(html=self.html) + return [node] + + def notebook_modifier(self, notebook_path, commands): + """Modify the notebook to add a warning about GPU requirement.""" + with open(notebook_path) as f: + notebook = json.load(f) + if "cells" not in notebook: + notebook["cells"] = [] + + # Add a cell to install the required libraries at the position where we have + # colab link + idx = self.find_index_of_colab_link(notebook) + + code_lines = ["# Install libraries"] + code_lines.append(commands) + code_lines.append( + "!pip install" + " ".join(self.env.config["colab_notebook"]["dependencies"]) + ) + dummy_notebook_content = {"cells": []} + add_code_cell( + dummy_notebook_content, + "\n".join(code_lines), + ) + notebook["cells"][idx] = dummy_notebook_content["cells"][0] + + needs_GPU = self.options.get("needs_gpu", False) + if needs_GPU: + # Add a warning cell at the top of the notebook + warning_template = "\n".join( + [ + "
", + "", + "# Need GPU warning", + "", + "{message}", + "
", + self.html, + ] + ) + message_class = "warning" + message = ( + "Running this example requires a GPU, and hence is NOT " + "possible on binder currently We request you to kindly run this notebook " + "on Google Colab by clicking the link below. Additionally, please make " + "sure to set the runtime on Colab to use a GPU and install the below " + "libraries before running." + ) + idx = 0 + else: + # Add a warning cell at the top of the notebook + warning_template = "\n".join( + [ + "
", + "", + "# Install libraries needed for Colab", + "", + "{message}", + "
", + self.html, + ] + ) + message_class = "info" + message = ( + "The below installation commands are needed to be run only on " + "Google Colab." + ) + + dummy_notebook_content = {"cells": []} + add_markdown_cell( + dummy_notebook_content, + warning_template.format(message_class=message_class, message=message), + ) + notebook["cells"] = ( + notebook["cells"][:idx] + + dummy_notebook_content["cells"] + + notebook["cells"][idx:] + ) + + # Write back updated notebook + with open(notebook_path, "w", encoding="utf-8") as f: + json.dump(notebook, f, ensure_ascii=False, indent=2) + + def find_index_of_colab_link(self, notebook): + """Find the index of the cell containing the Colab link.""" + for idx, cell in enumerate(notebook["cells"]): + if cell["cell_type"] == "markdown" and ".. colab-link::" in "".join( + cell.get("source", "") + ): + return idx + return 0 + + +def setup(app): + """Set up the Sphinx extension.""" + app.add_config_value("colab_notebook", dict(), "html") + app.add_node( + ColabLinkNode, html=(visit_colab_link_node_html, depart_colab_link_node_html) + ) + app.add_directive("colab-link", ColabLinkDirective) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/conf.py b/docs/conf.py index 430c121e..d22e770a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -43,6 +43,7 @@ "myst_sphinx_gallery", "myst_nb", "scenario", + "colab_extension", "sphinx_gallery.gen_gallery", ] @@ -52,8 +53,12 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "auto_examples/anatomical/*.ipynb", +] autodoc2_packages = ["../src/snake/"] autodoc2_output_dir = "auto_api" @@ -74,6 +79,13 @@ "matplotlib": ("https://matplotlib.org/stable/", None), } +colab_notebook = { + "dependencies": ["snake-fmri", "mri-nufft[finufft,cufinufft]"], + "branch": "gh-pages", + "path": "examples", + "base_colab_url": "https://colab.research.google.com", + "repo": "https://github.com/paquiteau/snake-fmri", +} # -- MyST configuration --------------------------------------------------- # @@ -144,4 +156,5 @@ sphinx_gallery_conf = { "examples_dirs": "../examples", # path to your example scripts "gallery_dirs": "auto_examples", # path to where to save gallery generated output + "filename_pattern": "/example", } diff --git a/examples/anatomical/example_EPI_2D.py b/examples/anatomical/example_EPI_2D.py index 42d5c8cb..ded4537e 100644 --- a/examples/anatomical/example_EPI_2D.py +++ b/examples/anatomical/example_EPI_2D.py @@ -137,7 +137,7 @@ fig, ax = plt.subplots() -axis3dcut(fig, ax, image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40)) +axis3dcut(image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40), ax=ax) plt.show() # %% diff --git a/examples/anatomical/example_T2s_EPI.py b/examples/anatomical/example_T2s_EPI.py index 98bcf992..f03eeabe 100644 --- a/examples/anatomical/example_T2s_EPI.py +++ b/examples/anatomical/example_T2s_EPI.py @@ -152,7 +152,7 @@ def reconstruct_frame(filename): (image_simple, image_T2s, abs(image_simple - image_T2s)), ("simple", "T2s", "diff"), ): - axis3dcut(fig, ax, img, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4) + axis3dcut(img, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4, ax=ax) ax.set_title(title) plt.show() diff --git a/examples/anatomical/example_anat_EPI.py b/examples/anatomical/example_anat_EPI.py index 3a22342f..8ee18b1f 100644 --- a/examples/anatomical/example_anat_EPI.py +++ b/examples/anatomical/example_anat_EPI.py @@ -154,5 +154,5 @@ fig, ax = plt.subplots() -axis3dcut(fig, ax, image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40)) +axis3dcut(image_data.squeeze().T, None, None, cbar=False, cuts=(40, 60, 40), ax=ax) plt.show() diff --git a/examples/anatomical/example_generate_phantom.py b/examples/anatomical/example_generate_phantom.py index fc7f3dc5..9dab174c 100644 --- a/examples/anatomical/example_generate_phantom.py +++ b/examples/anatomical/example_generate_phantom.py @@ -33,7 +33,7 @@ import matplotlib.pyplot as plt fig, ax = plt.subplots() -axis3dcut(fig, ax, contrast_at_TE.T, None, None, cuts=(60, 60, 60), width_inches=5) +axis3dcut(contrast_at_TE.T, None, None, cuts=(60, 60, 60), ax=ax, width_inches=5) fig from ipywidgets import interact diff --git a/examples/anatomical/example_gpu_anat_spirals.py b/examples/anatomical/example_gpu_anat_spirals.py index 5526f2d8..034a9165 100644 --- a/examples/anatomical/example_gpu_anat_spirals.py +++ b/examples/anatomical/example_gpu_anat_spirals.py @@ -1,7 +1,7 @@ # %% """ Compare Fourier Model and T2* Model for Stack of Spirals trajectory -=========================================== +=================================================================== This examples walks through the elementary components of SNAKE. @@ -88,6 +88,9 @@ display_3D_trajectory(traj) +# %% +traj.shape + # %% # Adding noise in Image # ---------------------- @@ -183,12 +186,21 @@ compute_backend=COMPUTE_BACKEND, ) with NonCartesianFrameDataLoader("example_spiral.mrd") as data_loader: - adjoint_spiral = abs(zer_rec.reconstruct(data_loader, sim_conf)[0]) - cs_spiral = abs(seq_rec.reconstruct(data_loader, sim_conf)[0]) + adjoint_spiral = abs(zer_rec.reconstruct(data_loader)[0]) + cs_spiral = abs(seq_rec.reconstruct(data_loader)[0]) with NonCartesianFrameDataLoader("example_spiral_t2s.mrd") as data_loader: - adjoint_spiral_T2s = abs(zer_rec.reconstruct(data_loader, sim_conf)[0]) - cs_spiral_T2s = abs(seq_rec.reconstruct(data_loader, sim_conf)[0]) + adjoint_spiral_T2s = abs(zer_rec.reconstruct(data_loader,sim_conf)[0]) + cs_spiral_T2s = abs(seq_rec.reconstruct(data_loader)[0]) + +# %% +with NonCartesianFrameDataLoader("example_spiral.mrd") as data_loader: + traj,data = data_loader.get_kspace_frame(0) + +# %% +data.shape + +# %% # %% # Plotting the result @@ -210,7 +222,7 @@ (adjoint_spiral, adjoint_spiral_T2s, abs(adjoint_spiral - adjoint_spiral_T2s)), ("simple", "T2s", "diff"), ): - axis3dcut(fig, ax, img.T, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4) + axis3dcut(img.T, None, None, cbar=True, cuts=(40, 40, 40), ax=ax,width_inches=4) ax.set_title(title) @@ -219,7 +231,7 @@ (cs_spiral, cs_spiral_T2s, abs(cs_spiral - cs_spiral_T2s)), ("simple", "T2s", "diff"), ): - axis3dcut(fig, ax, img.T, None, None, cbar=True, cuts=(40, 40, 40), width_inches=4) + axis3dcut(img.T, None, None, cbar=True, cuts=(40, 40, 40), ax=ax,width_inches=4) ax.set_title(title + " CS") diff --git a/examples/anatomical/example_gpu_anat_spirals_slice.py b/examples/anatomical/example_gpu_anat_spirals_slice.py index 264c220a..af5afaa6 100644 --- a/examples/anatomical/example_gpu_anat_spirals_slice.py +++ b/examples/anatomical/example_gpu_anat_spirals_slice.py @@ -1,16 +1,21 @@ # %% """ -Compare Fourier Model and T2* Model for Stack of Spirals trajectory -=========================================== +Compare Fourier Model and T2* Model for 2D Stack of Spirals trajectory +====================================================================== This examples walks through the elementary components of SNAKE. Here we proceed step by step and use the Python interface. A more integrated alternative is to use the CLI ``snake-main`` + """ # %% +# .. colab-link:: +# :needs_gpu: 1 +# +# !pip install mri-nufft[gpunufft] scikit-image # Imports import matplotlib.pyplot as plt @@ -169,7 +174,6 @@ # %% shot = traj[18].copy() -print(shot) nufft = get_operator(NUFFT_BACKEND)( samples=shot[:, :2], shape=data_loader.shape[:-1], @@ -180,4 +184,4 @@ image = nufft.adj_op(kspace_data) fig, ax = plt.subplots() -axis3dcut(fig, ax, image, None, cuts=(40, 40, 40)) +axis3dcut(abs(image), None, cuts=(40, 40, 40), ax=ax) diff --git a/src/cli-conf/scenario3-tpz.yaml b/src/cli-conf/scenario3-tpz.yaml index 7d9368d7..4294abb2 100644 --- a/src/cli-conf/scenario3-tpz.yaml +++ b/src/cli-conf/scenario3-tpz.yaml @@ -63,8 +63,6 @@ reconstructors: - - hydra: job: chdir: true diff --git a/src/snake/mrd_utils/loader.py b/src/snake/mrd_utils/loader.py index 5cbf7ac5..16c3c2f5 100644 --- a/src/snake/mrd_utils/loader.py +++ b/src/snake/mrd_utils/loader.py @@ -219,7 +219,7 @@ def engine_model(self) -> str: @property def slice_2d(self) -> bool: """Is the acquisition run on 2D slices.""" - return bool(self.header.userParameters.userParameterString[1].value) + return self.header.userParameters.userParameterString[1].value == "True" ############# # Get data # diff --git a/src/snake/toolkit/reconstructors/pysap.py b/src/snake/toolkit/reconstructors/pysap.py index d2e3049a..431a2318 100644 --- a/src/snake/toolkit/reconstructors/pysap.py +++ b/src/snake/toolkit/reconstructors/pysap.py @@ -260,7 +260,7 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None: def reconstruct(self, data_loader: MRDLoader) -> np.ndarray: """Reconstruct with Sequential.""" shape = data_loader.shape - self.setup(shape) + self.setup(shape=shape) from fmri.operators.gradient import ( GradAnalysis, GradSynthesis,