Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightningDataModule to load GeoTIFF files #52

Merged
merged 7 commits into from
Nov 28, 2023
Merged

LightningDataModule to load GeoTIFF files #52

merged 7 commits into from
Nov 28, 2023

Conversation

weiji14
Copy link
Contributor

@weiji14 weiji14 commented Nov 24, 2023

What I am changing

How I did it

  • Using torchdata to construct the DataPipe
  • GeoTIFF files are read using rasterio
  • Train/validation split is 80%/20%

TODO:

  • Install torchdata dependency
  • Initial implementation of GeoTIFFDataPipeModule
  • Add extra parameters to control DataLoader (e.g. num_workers)
  • Add unit tests
  • Refactor to load GeoTIFF data from s3 bucket instead of local drive
  • etc

Notes:

  • Have tried using rioxarray to read the GeoTIFFs, but seems a little slower than rasterio
  • Also experimented with loading from NetCDF files using xarray's h5netcdf engine (about same speed as rioxarray loading GeoTIFF)
  • Fastest seems to be loading from Zarr, but would require re-formatting of data, so leaving that to a future PR.

How you can test it

  • Download the GeoTIFF files from the s3 bucket (TODO add instructions)
  • Run python trainer.py fit --trainer.max_epochs=20 --trainer.precision=16-mixed --data.data_path=data --data.batch_size=32 --data.num_workers=8 locally

Related Issues

References:

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries!
Decoupling the neural network model's unit test from the LightningDataModule by implementing a standalone datapipe fixture instead.
Create a LightningDataModule to load GeoTIFF files. Uses torchdata to create the data pipeline. Using the FileLister DataPipe to iterate over *.tif files in the data/ folder, and do a random 80/20 split for the training and validation set. The GeoTIFF files are read into numpy.ndarrrays using rasterio, and converted to torch.Tensors with the default collate function. Using rasterio instead of rioxarray to reduce an extra layer of overhead in the data loading.
@weiji14 weiji14 added the data-pipeline Pull Requests about the data pipeline label Nov 24, 2023
@weiji14 weiji14 self-assigned this Nov 24, 2023
@weiji14 weiji14 changed the title Implement GeoTIFFDataPipeModule LightningDataModule to load GeoTIFF files Nov 24, 2023
# GeoTIFF - Rasterio
with rasterio.open(fp=filepath) as dataset:
array: np.ndarray = dataset.read()
tensor: torch.Tensor = torch.as_tensor(data=array.astype(dtype="float16"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is float32 to float16 a save tansformation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be some loss of floating point precision, but we'll likely be using 16-bit precision training (see https://lightning.ai/docs/pytorch/2.1.0/common/precision_intermediate.html) to speed up the model training, so best to pre-emptively convert the data to float16 dtype here.

"""
# GeoTIFF - Rasterio
with rasterio.open(fp=filepath) as dataset:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you experiment with other file openers? Maybe the loader gets more stable if we use

from skimage import io
im = io.imread(filepath)

or other tif specific loaders like

https://pypi.org/project/tifffile/

Could be worth a shot to see if that helps stabilizing the loader when compared to zarr.

Copy link
Contributor Author

@weiji14 weiji14 Nov 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So skimage's imread actually uses tifffile behind the scenes for reading TIFF files, see https://github.com/scikit-image/scikit-image/blob/441fe68b95a86d4ae2a351311a0c39a4232b6521/skimage/io/_io.py#L16-L68, but I know that tiffile has some issues with multiprocessing/threading, see cgohlke/tifffile#215. Will also need to see if skimage/tifffile supports reading from s3 buckets directly like rasterio/GDAL. Found this thread cgohlke/tifffile#125 which looks interesting.

Enable setting the number of subprocesses used for data loading. Default to 8 for now, but can be configured on LightningCLI using `python trainer.py fit --data.num_workers=8`.
Contains a build of torchdata that is pre-compiled with the correct AWSSDK extension, and won't result in errors like `ValueError: curlCode: 77, Problem with the SSL CA cert (path? access rights?)`.
Enable setting the path to the folder containing the GeoTIFF data files. Defaults to data/ for now, but can be configured on LightningCLI using `python trainer.py fit --data.data_path=data/56HKH`. Also setting the recursive=True flag to allow for files in nested directories.
Ensure that loading one mini-batch of data from a data folder works. Created two temporary random GeoTIFF files containing arrays of shape (3, 256, 256) in a fixture for the test.
@weiji14 weiji14 marked this pull request as ready for review November 28, 2023 07:17
@weiji14
Copy link
Contributor Author

weiji14 commented Nov 28, 2023

  • Refactor to load GeoTIFF data from s3 bucket instead of local drive

Decided to handle reading from s3 in a separate PR, because it was about 10x slower than reading from a local disk, even from the same us-east-1 region. Specifically, a mini-batch took about 0.2it/s when reading from s3, compared to ~2it/s from a local data folder. Might need to play with some I/O or networking related settings.

@weiji14
Copy link
Contributor Author

weiji14 commented Nov 28, 2023

Again, merging directly in the interest of speed. Will refactor to try other drivers (following discussion at #52 (comment)) later.

@weiji14 weiji14 merged commit be426c1 into main Nov 28, 2023
1 check passed
@weiji14 weiji14 deleted the geotiff-datapipe branch November 28, 2023 07:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
data-pipeline Pull Requests about the data pipeline
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants