diff --git a/README.md b/README.md index c6c8eec..6e7be45 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # Dataloader for JAX + ![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg) @@ -36,6 +37,8 @@ A minimum `jax-dataloader` example: ``` python import jax_dataloader as jdl +jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility + dataloader = jdl.DataLoader( dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset backend='jax', # Use 'jax' backend for loading data @@ -61,9 +64,7 @@ or install directly from the repository: pip install git+https://github.com/BirkhoffG/jax-dataloader.git ``` -
- -> **Note** +> [!NOTE] > > We keep `jax-dataloader`’s dependencies minimum, which only install > `jax` and `plum-dispatch` (for backend dispatching) when installing. @@ -75,8 +76,6 @@ pip install git+https://github.com/BirkhoffG/jax-dataloader.git > You can also run `pip install jax-dataloader[all]` to install > everything (not recommended). -
- ## Usage [`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader) @@ -161,16 +160,12 @@ ecosystems (e.g., built-in datasets. `jax_dataloader` supports directly passing the pytorch Dataset. -
- -> **Note** +> [!NOTE] > > Unfortuantely, the [pytorch > Dataset](https://pytorch.org/docs/stable/data.html) can only work with > `backend=pytorch`. See the belowing example. -
- ``` python from torchvision.datasets import MNIST import numpy as np diff --git a/nbs/index.ipynb b/nbs/index.ipynb index 9a2feed..f66d250 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -50,6 +50,8 @@ "```python\n", "import jax_dataloader as jdl\n", "\n", + "jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility\n", + "\n", "dataloader = jdl.DataLoader(\n", " dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset\n", " backend='jax', # Use 'jax' backend for loading data\n",