-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LightningDataModule to load GeoTIFF files #52
Conversation
A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries!
Decoupling the neural network model's unit test from the LightningDataModule by implementing a standalone datapipe fixture instead.
Create a LightningDataModule to load GeoTIFF files. Uses torchdata to create the data pipeline. Using the FileLister DataPipe to iterate over *.tif files in the data/ folder, and do a random 80/20 split for the training and validation set. The GeoTIFF files are read into numpy.ndarrrays using rasterio, and converted to torch.Tensors with the default collate function. Using rasterio instead of rioxarray to reduce an extra layer of overhead in the data loading.
# GeoTIFF - Rasterio | ||
with rasterio.open(fp=filepath) as dataset: | ||
array: np.ndarray = dataset.read() | ||
tensor: torch.Tensor = torch.as_tensor(data=array.astype(dtype="float16")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is float32 to float16 a save tansformation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There will be some loss of floating point precision, but we'll likely be using 16-bit precision training (see https://lightning.ai/docs/pytorch/2.1.0/common/precision_intermediate.html) to speed up the model training, so best to pre-emptively convert the data to float16 dtype here.
""" | ||
# GeoTIFF - Rasterio | ||
with rasterio.open(fp=filepath) as dataset: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you experiment with other file openers? Maybe the loader gets more stable if we use
from skimage import io
im = io.imread(filepath)
or other tif specific loaders like
https://pypi.org/project/tifffile/
Could be worth a shot to see if that helps stabilizing the loader when compared to zarr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So skimage
's imread actually uses tifffile
behind the scenes for reading TIFF files, see https://github.com/scikit-image/scikit-image/blob/441fe68b95a86d4ae2a351311a0c39a4232b6521/skimage/io/_io.py#L16-L68, but I know that tiffile
has some issues with multiprocessing/threading, see cgohlke/tifffile#215. Will also need to see if skimage/tifffile supports reading from s3 buckets directly like rasterio/GDAL. Found this thread cgohlke/tifffile#125 which looks interesting.
Enable setting the number of subprocesses used for data loading. Default to 8 for now, but can be configured on LightningCLI using `python trainer.py fit --data.num_workers=8`.
Contains a build of torchdata that is pre-compiled with the correct AWSSDK extension, and won't result in errors like `ValueError: curlCode: 77, Problem with the SSL CA cert (path? access rights?)`.
Enable setting the path to the folder containing the GeoTIFF data files. Defaults to data/ for now, but can be configured on LightningCLI using `python trainer.py fit --data.data_path=data/56HKH`. Also setting the recursive=True flag to allow for files in nested directories.
Ensure that loading one mini-batch of data from a data folder works. Created two temporary random GeoTIFF files containing arrays of shape (3, 256, 256) in a fixture for the test.
Decided to handle reading from s3 in a separate PR, because it was about 10x slower than reading from a local disk, even from the same us-east-1 region. Specifically, a mini-batch took about 0.2it/s when reading from s3, compared to ~2it/s from a local data folder. Might need to play with some I/O or networking related settings. |
Again, merging directly in the interest of speed. Will refactor to try other drivers (following discussion at #52 (comment)) later. |
What I am changing
How I did it
torchdata
to construct the DataPiperasterio
TODO:
torchdata
dependencyNotes:
rioxarray
to read the GeoTIFFs, but seems a little slower thanrasterio
How you can test it
python trainer.py fit --trainer.max_epochs=20 --trainer.precision=16-mixed --data.data_path=data --data.batch_size=32 --data.num_workers=8
locallyRelated Issues
References: