Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Oct 18, 2024
1 parent fcec10f commit 0a6057b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
15 changes: 5 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Dataloader for JAX


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)
Expand Down Expand Up @@ -36,6 +37,8 @@ A minimum `jax-dataloader` example:
``` python
import jax_dataloader as jdl

jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility

dataloader = jdl.DataLoader(
dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
backend='jax', # Use 'jax' backend for loading data
Expand All @@ -61,9 +64,7 @@ or install directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

<div>

> **Note**
> [!NOTE]
>
> We keep `jax-dataloader`’s dependencies minimum, which only install
> `jax` and `plum-dispatch` (for backend dispatching) when installing.
Expand All @@ -75,8 +76,6 @@ pip install git+https://github.com/BirkhoffG/jax-dataloader.git
> You can also run `pip install jax-dataloader[all]` to install
> everything (not recommended).
</div>

## Usage

[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
Expand Down Expand Up @@ -161,16 +160,12 @@ ecosystems (e.g.,
built-in datasets. `jax_dataloader` supports directly passing the
pytorch Dataset.

<div>

> **Note**
> [!NOTE]
>
> Unfortuantely, the [pytorch
> Dataset](https://pytorch.org/docs/stable/data.html) can only work with
> `backend=pytorch`. See the belowing example.
</div>

``` python
from torchvision.datasets import MNIST
import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
"```python\n",
"import jax_dataloader as jdl\n",
"\n",
"jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility\n",
"\n",
"dataloader = jdl.DataLoader(\n",
" dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset\n",
" backend='jax', # Use 'jax' backend for loading data\n",
Expand Down

0 comments on commit 0a6057b

Please sign in to comment.