diff --git a/README.md b/README.md
index c6c8eec..6e7be45 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
# Dataloader for JAX
+
![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)
@@ -36,6 +37,8 @@ A minimum `jax-dataloader` example:
``` python
import jax_dataloader as jdl
+jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility
+
dataloader = jdl.DataLoader(
dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
backend='jax', # Use 'jax' backend for loading data
@@ -61,9 +64,7 @@ or install directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```
-
-
-> **Note**
+> [!NOTE]
>
> We keep `jax-dataloader`’s dependencies minimum, which only install
> `jax` and `plum-dispatch` (for backend dispatching) when installing.
@@ -75,8 +76,6 @@ pip install git+https://github.com/BirkhoffG/jax-dataloader.git
> You can also run `pip install jax-dataloader[all]` to install
> everything (not recommended).
-
-
## Usage
[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
@@ -161,16 +160,12 @@ ecosystems (e.g.,
built-in datasets. `jax_dataloader` supports directly passing the
pytorch Dataset.
-
-
-> **Note**
+> [!NOTE]
>
> Unfortuantely, the [pytorch
> Dataset](https://pytorch.org/docs/stable/data.html) can only work with
> `backend=pytorch`. See the belowing example.
-
-
``` python
from torchvision.datasets import MNIST
import numpy as np
diff --git a/nbs/index.ipynb b/nbs/index.ipynb
index 9a2feed..f66d250 100644
--- a/nbs/index.ipynb
+++ b/nbs/index.ipynb
@@ -50,6 +50,8 @@
"```python\n",
"import jax_dataloader as jdl\n",
"\n",
+ "jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility\n",
+ "\n",
"dataloader = jdl.DataLoader(\n",
" dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset\n",
" backend='jax', # Use 'jax' backend for loading data\n",