From be426c1de2844401202132c0ebbed7092424fc2d Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 28 Nov 2023 20:25:09 +1300 Subject: [PATCH] LightningDataModule to load GeoTIFF files (#52) * :heavy_plus_sign: Add torchdata A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries! * :recycle: Refactor test_model_vit to use datapipe fixture Decoupling the neural network model's unit test from the LightningDataModule by implementing a standalone datapipe fixture instead. * :sparkles: Implement GeoTIFFDataPipeModule Create a LightningDataModule to load GeoTIFF files. Uses torchdata to create the data pipeline. Using the FileLister DataPipe to iterate over *.tif files in the data/ folder, and do a random 80/20 split for the training and validation set. The GeoTIFF files are read into numpy.ndarrrays using rasterio, and converted to torch.Tensors with the default collate function. Using rasterio instead of rioxarray to reduce an extra layer of overhead in the data loading. * :thread: Allow configuring num_workers in DataLoader Enable setting the number of subprocesses used for data loading. Default to 8 for now, but can be configured on LightningCLI using `python trainer.py fit --data.num_workers=8`. * :pushpin: Install torchdata=0.7.1 from conda-forge instead of PyPI Contains a build of torchdata that is pre-compiled with the correct AWSSDK extension, and won't result in errors like `ValueError: curlCode: 77, Problem with the SSL CA cert (path? access rights?)`. * :wrench: Allow configuring data path containing the GeoTIFF files Enable setting the path to the folder containing the GeoTIFF data files. Defaults to data/ for now, but can be configured on LightningCLI using `python trainer.py fit --data.data_path=data/56HKH`. Also setting the recursive=True flag to allow for files in nested directories. * :white_check_mark: Add unit test for GeoTIFFDataModule Ensure that loading one mini-batch of data from a data folder works. Created two temporary random GeoTIFF files containing arrays of shape (3, 256, 256) in a fixture for the test. --- conda-lock.yml | 249 ++++++++++++++++++----------------- environment.yml | 1 + src/README.md | 2 +- src/datamodule.py | 74 ++++++++--- src/tests/test_datamodule.py | 56 ++++++++ src/tests/test_model.py | 28 +++- trainer.py | 4 +- 7 files changed, 266 insertions(+), 148 deletions(-) create mode 100644 src/tests/test_datamodule.py diff --git a/conda-lock.yml b/conda-lock.yml index 79a06285..2c9c3372 100644 --- a/conda-lock.yml +++ b/conda-lock.yml @@ -13,7 +13,7 @@ version: 1 metadata: content_hash: - linux-64: 236c8d893324750c2fa101feb8777197e44eaebbdf3ca40815a0776206ae5163 + linux-64: f3e19d072b1b3c728c23e5d3907d2cb4145ab3042eacb616e10e3a8109f36243 channels: - url: conda-forge used_env_vars: [] @@ -59,7 +59,7 @@ package: category: main optional: false - name: aiohttp - version: 3.9.0 + version: 3.9.1 manager: conda platform: linux-64 dependencies: @@ -71,10 +71,10 @@ package: python: '>=3.11,<3.12.0a0' python_abi: 3.11.* yarl: '>=1.0,<2.0' - url: https://conda.anaconda.org/conda-forge/linux-64/aiohttp-3.9.0-py311h459d7ec_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/aiohttp-3.9.1-py311h459d7ec_0.conda hash: - md5: f8622107430b609e3956250ed601ad30 - sha256: f1e0233e814200c3bcfa081f1d0adc3200d334475aeac91917187026434a4b3f + md5: a51ceb9a9219e3c11af56b2b77794839 + sha256: 3f16a6ff0ce0c137de2bc63ac3758616f7232f75ff630a3955f02af2452b0490 category: main optional: false - name: aiosignal @@ -516,22 +516,22 @@ package: category: main optional: false - name: boto3 - version: 1.29.6 + version: 1.29.7 manager: conda platform: linux-64 dependencies: - botocore: '>=1.32.6,<1.33.0' + botocore: '>=1.32.7,<1.33.0' jmespath: '>=0.7.1,<2.0.0' python: '>=3.7' - s3transfer: '>=0.7.0,<0.8.0' - url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.29.6-pyhd8ed1ab_0.conda + s3transfer: '>=0.8.0,<0.9.0' + url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.29.7-pyhd8ed1ab_0.conda hash: - md5: 0cbc42e6f9557edfea7f552c644027d7 - sha256: 7e3c31d99afff810f0d68b4d7c957be34917d1d4bfc76a34620dee0bc35eec1d + md5: f20e114fa86dbdca7534aa0af7664c0e + sha256: 44fe95ed89d0db0c421968d1e79c8372a15136a9146b9a9bd0c66795194eb81d category: main optional: false - name: botocore - version: 1.32.6 + version: 1.32.7 manager: conda platform: linux-64 dependencies: @@ -539,10 +539,10 @@ package: python: '>=3.7' python-dateutil: '>=2.1,<3.0.0' urllib3: '>=1.25.4,<1.27' - url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.32.6-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.32.7-pyhd8ed1ab_0.conda hash: - md5: a6747e9f4cb2ca858735017cf783fe08 - sha256: 534d61c7d2c2184d59b828dc582600482ed12c08922125f07f454f5d91d85573 + md5: 9e5a1d24c1fcd8017ac713c28dffc871 + sha256: cf405020da251ff2007d5fbc5f1ee61966e925e8d7e9a12525a7ac042afb038d category: main optional: false - name: brotli-python @@ -716,20 +716,20 @@ package: category: main optional: false - name: cfitsio - version: 4.3.0 + version: 4.3.1 manager: conda platform: linux-64 dependencies: bzip2: '>=1.0.8,<2.0a0' - libcurl: '>=8.2.0,<9.0a0' + libcurl: '>=8.4.0,<9.0a0' libgcc-ng: '>=12' libgfortran-ng: '' libgfortran5: '>=12.3.0' libzlib: '>=1.2.13,<1.3.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/cfitsio-4.3.0-hbdc6101_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/cfitsio-4.3.1-hbdc6101_0.conda hash: - md5: 797554b8b7603011e8677884381fbcc5 - sha256: c74938f1ade9b8f37b9fa8cc98a5b9262b325506f41d7492ad1d00146e0f1d08 + md5: dcea02841b33a9c49f74ca9328de919a + sha256: b91003bff71351a0132c84d69fbb5afcfa90e57d83f76a180c6a5a0289099fb1 category: main optional: false - name: charset-normalizer @@ -1384,10 +1384,10 @@ package: manager: conda platform: linux-64 dependencies: {} - url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-hab24e00_0.tar.bz2 + url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_1.conda hash: - md5: 19410c3df09dfb12d1206132a1d357c5 - sha256: 470d5db54102bd51dbb0c5990324a2f4a0bc976faa493b22193338adb9882e2e + md5: 6185f640c43843e5ad6fd1c5372c3f80 + sha256: 056c85b482d58faab5fd4670b6c1f5df0986314cca3bc831d458b22e4ef2c792 category: main optional: false - name: fontconfig @@ -1512,13 +1512,13 @@ package: libstdcxx-ng: '>=12' libxml2: '>=2.11.6,<2.12.0a0' numpy: '>=1.23.5,<2.0a0' - openssl: '>=3.1.4,<4.0a0' + openssl: '>=3.2.0,<4.0a0' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/gdal-3.7.3-py311h815a124_6.conda + url: https://conda.anaconda.org/conda-forge/linux-64/gdal-3.7.3-py311h815a124_8.conda hash: - md5: f92a59621633603d7617c3c5305dc28d - sha256: a248a8b4439bb5a0bd22842ad7dc1177b3acda538d5c011277c4c625909ab8bb + md5: e46623f8e642a7995e5dae80e399e4be + sha256: c7ca6100fbefb72b87feb4ebc7461f498a919e00c7afbf46f4316299a182e8d4 category: main optional: false - name: geopandas-base @@ -1538,16 +1538,16 @@ package: category: main optional: false - name: geos - version: 3.12.0 + version: 3.12.1 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' libstdcxx-ng: '>=12' - url: https://conda.anaconda.org/conda-forge/linux-64/geos-3.12.0-h59595ed_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/geos-3.12.1-h59595ed_0.conda hash: - md5: 3fdf79ef322c8379ae83be491d805369 - sha256: c80ff0ed71db0d56567ee87df28bc442b596330ac241ab86f488e3139f0e2cae + md5: 8c0f4f71f5a59ceb0c6fa9f51501066d + sha256: 2593b255cb9c4639d6ea261c47aaed1380216a366546f0468e95c36c2afd1c1a category: main optional: false - name: geotiff @@ -1801,15 +1801,15 @@ package: category: main optional: false - name: idna - version: '3.4' + version: '3.6' manager: conda platform: linux-64 dependencies: python: '>=3.6' - url: https://conda.anaconda.org/conda-forge/noarch/idna-3.4-pyhd8ed1ab_0.tar.bz2 + url: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda hash: - md5: 34272b248891bddccc64479f9a7fffed - sha256: 9887c35c374ec1847f167292d3fde023cb4c994a4ceeec283072b95440131f09 + md5: 1a76f09108576397c41c0b0c5bd84134 + sha256: 6ee4c986d69ce61e60a20b2459b6f2027baeba153f0a64995fd3cb47c2cc7e07 category: main optional: false - name: importlib-metadata @@ -1905,27 +1905,27 @@ package: category: main optional: false - name: ipython - version: 8.17.2 + version: 8.18.1 manager: conda platform: linux-64 dependencies: - __linux: '' + __unix: '' decorator: '' exceptiongroup: '' jedi: '>=0.16' matplotlib-inline: '' pexpect: '>4.3' pickleshare: '' - prompt_toolkit: '>=3.0.30,<3.1.0,!=3.0.37' + prompt-toolkit: '>=3.0.30,<3.1.0,!=3.0.37' pygments: '>=2.4.0' python: '>=3.9' stack_data: '' traitlets: '>=5' typing_extensions: '' - url: https://conda.anaconda.org/conda-forge/noarch/ipython-8.17.2-pyh41d4057_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/ipython-8.18.1-pyh31011fe_1.conda hash: - md5: f39d0b60e268fe547f1367edbab457d4 - sha256: 31322d58f412787f5beeb01db4d16f10f8ae4e0cc2ec99fafef1e690374fe298 + md5: ac2f9c2e10c2e90e8d135cef51f9753a + sha256: 67490e640faa372d663a5c5cd2d61f417cce22a019a4de82a9e5ddb1cf2ee181 category: main optional: false - name: isoduration @@ -2156,17 +2156,17 @@ package: category: main optional: false - name: jupyter-lsp - version: 2.2.0 + version: 2.2.1 manager: conda platform: linux-64 dependencies: importlib-metadata: '>=4.8.3' jupyter_server: '>=1.1.2' python: '>=3.8' - url: https://conda.anaconda.org/conda-forge/noarch/jupyter-lsp-2.2.0-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/jupyter-lsp-2.2.1-pyhd8ed1ab_0.conda hash: - md5: 38589f4104d11f2a59ff01a9f4e3bfb3 - sha256: 16fc7b40024adece716ba7227e5c123a2deccc13f946a10d9a3270493908d11c + md5: d1a5efc65bfabc3bfebf4d3a204da897 + sha256: 0f995f60609fb50db74bed3637165ad202cf091ec0804519c11b6cffce901e88 category: main optional: false - name: jupyter_client @@ -2222,7 +2222,7 @@ package: category: main optional: false - name: jupyter_server - version: 2.10.1 + version: 2.11.1 manager: conda platform: linux-64 dependencies: @@ -2245,10 +2245,10 @@ package: tornado: '>=6.2.0' traitlets: '>=5.6.0' websocket-client: '' - url: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.10.1-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.11.1-pyhd8ed1ab_0.conda hash: - md5: 7d15498584d83de3b357425e37086397 - sha256: b8b55ee57785b39a9096884bfd1da3858da8f27764572321d51a3dd0a990de86 + md5: 0699b715659c026f7f81c27d0e744205 + sha256: 605825c0e2d5af7935b37319b9a46ff39e081e7a0f4dc973f0dd583f41c69ce5 category: main optional: false - name: jupyter_server_terminals @@ -2608,14 +2608,14 @@ package: category: main optional: false - name: libboost-headers - version: 1.82.0 + version: 1.83.0 manager: conda platform: linux-64 dependencies: {} - url: https://conda.anaconda.org/conda-forge/linux-64/libboost-headers-1.82.0-ha770c72_6.conda + url: https://conda.anaconda.org/conda-forge/linux-64/libboost-headers-1.83.0-ha770c72_0.conda hash: - md5: a943dcb8fd22cf23ce901ac84f6538c2 - sha256: c996950b85808115ea833e577a0af2969dbb0378c299560c2b945401a7770823 + md5: 1fc57b3ba24d18cc75f431d7feb2c785 + sha256: aaa194e8b7ba401e6507a2f6dc0714d2f8f5a9951f8be18b96c250b0a1175982 category: main optional: false - name: libbrotlicommon @@ -2872,9 +2872,9 @@ package: dependencies: __glibc: '>=2.17,<3.0.a0' blosc: '>=1.21.5,<2.0a0' - cfitsio: '>=4.3.0,<4.3.1.0a0' + cfitsio: '>=4.3.1,<4.3.2.0a0' freexl: '>=2.0.0,<3.0a0' - geos: '>=3.12.0,<3.12.1.0a0' + geos: '>=3.12.1,<3.12.2.0a0' geotiff: '>=1.7.1,<1.8.0a0' giflib: '>=5.2.1,<5.3.0a0' hdf4: '>=4.2.15,<4.2.16.0a0' @@ -2894,7 +2894,7 @@ package: libpng: '>=1.6.39,<1.7.0a0' libpq: '>=16.1,<17.0a0' libspatialite: '>=5.1.0,<5.2.0a0' - libsqlite: '>=3.44.0,<4.0a0' + libsqlite: '>=3.44.2,<4.0a0' libstdcxx-ng: '>=12' libtiff: '>=4.6.0,<4.7.0a0' libuuid: '>=2.38.1,<3.0a0' @@ -2903,7 +2903,7 @@ package: libzlib: '>=1.2.13,<1.3.0a0' lz4-c: '>=1.9.3,<1.10.0a0' openjpeg: '>=2.5.0,<3.0a0' - openssl: '>=3.1.4,<4.0a0' + openssl: '>=3.2.0,<4.0a0' pcre2: '>=10.42,<10.43.0a0' poppler: '>=23.11.0,<23.12.0a0' postgresql: '' @@ -2912,10 +2912,10 @@ package: xerces-c: '>=3.2.4,<3.3.0a0' xz: '>=5.2.6,<6.0a0' zstd: '>=1.5.5,<1.6.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/libgdal-3.7.3-h5cd9125_6.conda + url: https://conda.anaconda.org/conda-forge/linux-64/libgdal-3.7.3-h11296eb_8.conda hash: - md5: b46b5dbb938860bf6a88f658f89cac42 - sha256: c248efda55029a93cefb8d1b88eb6c7ccc97dc2b205bf5c5ffcc57c3f5022fce + md5: af74cfb737d4b1f8baba662976769111 + sha256: 3574fa8737fcb077eed391952a7275b1c8c0fdfab1812e38b0bf8161bc036717 category: main optional: false - name: libgfortran-ng @@ -3296,13 +3296,13 @@ package: manager: conda platform: linux-64 dependencies: - geos: '>=3.12.0,<3.12.1.0a0' + geos: '>=3.12.1,<3.12.2.0a0' libgcc-ng: '>=12' libstdcxx-ng: '>=12' - url: https://conda.anaconda.org/conda-forge/linux-64/librttopo-1.1.0-hb58d41b_14.conda + url: https://conda.anaconda.org/conda-forge/linux-64/librttopo-1.1.0-h8917695_15.conda hash: - md5: 264f9a3a4ea52c8f4d3e8ae1213a3335 - sha256: a87307e9c8fb446eb7a1698d9ab40e590ba7e55de669b59f5751c48c2b320827 + md5: 20c3c14bc491f30daecaa6f73e2223ae + sha256: 03e248787162a1804683c614c0681c2488fa6d9f353cb32e2f8c1158157165ea category: main optional: false - name: libsodium @@ -3323,33 +3323,33 @@ package: platform: linux-64 dependencies: freexl: '>=2.0.0,<3.0a0' - geos: '>=3.12.0,<3.12.1.0a0' + geos: '>=3.12.1,<3.12.2.0a0' libgcc-ng: '>=12' librttopo: '>=1.1.0,<1.2.0a0' - libsqlite: '>=3.44.0,<4.0a0' + libsqlite: '>=3.44.1,<4.0a0' libstdcxx-ng: '>=12' - libxml2: '>=2.11.5,<2.12.0a0' + libxml2: '>=2.11.6,<2.12.0a0' libzlib: '>=1.2.13,<1.3.0a0' proj: '>=9.3.0,<9.3.1.0a0' sqlite: '' zlib: '' - url: https://conda.anaconda.org/conda-forge/linux-64/libspatialite-5.1.0-h090f1da_1.conda + url: https://conda.anaconda.org/conda-forge/linux-64/libspatialite-5.1.0-h7385560_2.conda hash: - md5: 9a2d6acaa8ce6d53a150248e7b11165e - sha256: c00eb70e8cf3778bffd04a9551e205e399d16e83a04f55ec392c3163b93d4feb + md5: 4260750b280f6f7c38a4459bf0d919ff + sha256: 5efbd5bc05ebaa6bb66b7e408d2998a4116029c2e3c2b0336e29267488d93a45 category: main optional: false - name: libsqlite - version: 3.44.1 + version: 3.44.2 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' libzlib: '>=1.2.13,<1.3.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.1-h2797004_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.44.2-h2797004_0.conda hash: - md5: b4ad86d2527b890e43ff2efc68b239f4 - sha256: c37bb6ec8b09f690d84e8f14fabb75e00c221d11a256137d5b206e26f37e9483 + md5: 3b6a9f225c3dbe0d24f4fedd4625c5bf + sha256: ee2c4d724a3ed60d5b458864d66122fb84c6ce1df62f735f90d8db17b66cd88a category: main optional: false - name: libssh2 @@ -3814,7 +3814,7 @@ package: category: main optional: false - name: msgpack-python - version: 1.0.6 + version: 1.0.7 manager: conda platform: linux-64 dependencies: @@ -3822,10 +3822,10 @@ package: libstdcxx-ng: '>=12' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/msgpack-python-1.0.6-py311h9547e67_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/msgpack-python-1.0.7-py311h9547e67_0.conda hash: - md5: e826b71bf3dc8c91ee097663e2bcface - sha256: da765eabe27d8adec5bcce30ea1a0b9308d01640089d039f06bef2cc5ef63f46 + md5: 3ac85c6c226e2a2e4b17864fc2ca88ff + sha256: b12070ce86f108d3dcf2f447dfa76906c4bc15f2d2bf6cef19703ee42768b74a category: main optional: false - name: multidict @@ -4006,20 +4006,20 @@ package: category: main optional: false - name: nss - version: '3.94' + version: '3.95' manager: conda platform: linux-64 dependencies: __glibc: '>=2.17,<3.0.a0' libgcc-ng: '>=12' - libsqlite: '>=3.43.0,<4.0a0' + libsqlite: '>=3.44.2,<4.0a0' libstdcxx-ng: '>=12' libzlib: '>=1.2.13,<1.3.0a0' nspr: '>=4.35,<5.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/nss-3.94-h1d7d5a4_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/nss-3.95-h1d7d5a4_0.conda hash: - md5: 7caef74bbfa730e014b20f0852068509 - sha256: c9b7910fc554c6550905b9150f4c8230e973ca63f41b42f2c18a49e8aa458e78 + md5: d3a8067adcc45a923f4b1987c91d69da + sha256: 02d8e38b4708ce707e51084d0dff7286e6e6d24d1bf32ebbda7710fac4a0581e category: main optional: false - name: numcodecs @@ -4040,7 +4040,7 @@ package: category: main optional: false - name: numpy - version: 1.26.0 + version: 1.26.2 manager: conda platform: linux-64 dependencies: @@ -4051,10 +4051,10 @@ package: libstdcxx-ng: '>=12' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.0-py311h64a7726_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.2-py311h64a7726_0.conda hash: - md5: bf16a9f625126e378302f08e7ed67517 - sha256: 0aab5cef67cc2a1cd584f6e9cc6f2065c7a28c142d7defcb8096e8f719d9b3bf + md5: fd2f142dcd680413b5ede5d0fb799205 + sha256: c68b2c0ce95b79913134ec6ba2a2f1c10adcd60133afd48e4a57fdd128b694b7 category: main optional: false - name: openjpeg @@ -4074,16 +4074,16 @@ package: category: main optional: false - name: openssl - version: 3.1.4 + version: 3.2.0 manager: conda platform: linux-64 dependencies: ca-certificates: '' libgcc-ng: '>=12' - url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.1.4-hd590300_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.0-hd590300_0.conda hash: - md5: 412ba6938c3e2abaca8b1129ea82e238 - sha256: d15b3e83ce66c6f6fbb4707f2f5c53337124c01fb03bfda1cf25c5b41123efc7 + md5: 68223671a2b68cdf7241eb4679ab2dd4 + sha256: a8ca7c31be33894bd70bb34786d1a8c26ae650382411250b61f6b5249b69a23e category: main optional: false - name: orc @@ -4497,18 +4497,6 @@ package: sha256: e26a5554883a0eada3641b6d861d8cb4895e2c7fcc17a587de07b8b1ecbfff0f category: main optional: false -- name: prompt_toolkit - version: 3.0.41 - manager: conda - platform: linux-64 - dependencies: - prompt-toolkit: '>=3.0.41,<3.0.42.0a0' - url: https://conda.anaconda.org/conda-forge/noarch/prompt_toolkit-3.0.41-hd8ed1ab_0.conda - hash: - md5: b1387bd091fa0420557f801a78587678 - sha256: dd2fea25930d258159441ad4a45e5d3274f0d2f1dea92fe25b44b48c486aa969 - category: main - optional: false - name: psutil version: 5.9.5 manager: conda @@ -5266,16 +5254,16 @@ package: category: main optional: false - name: s3transfer - version: 0.7.0 + version: 0.8.0 manager: conda platform: linux-64 dependencies: - botocore: '>=1.12.36,<2.0a.0' + botocore: '>=1.32.7,<2.0a.0' python: '>=3.7' - url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.7.0-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.8.0-pyhd8ed1ab_0.conda hash: - md5: 5fe335cb1420d13a818fe01310af2b80 - sha256: 5ed09d013ad7f2c2f65d1637c04ee19da242ef9bed0d86aa9faae2c48aaa255d + md5: 9d4e095f2a2e84d0a3f54e3d9f13f9b2 + sha256: cedf5d2e5da3dcd14d7da767a0cee8ef18938af724fdcf2fec682d44024cc2e8 category: main optional: false - name: sacremoses @@ -5355,15 +5343,15 @@ package: manager: conda platform: linux-64 dependencies: - geos: '>=3.12.0,<3.12.1.0a0' + geos: '>=3.12.1,<3.12.2.0a0' libgcc-ng: '>=12' numpy: '>=1.23.5,<2.0a0' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/shapely-2.0.2-py311he06c224_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/shapely-2.0.2-py311h2032efe_1.conda hash: - md5: c90e2469d7512f3bba893533a82d7a02 - sha256: 2a02e516c57a2122cf9acaec54b75a821ad5f959a7702b17cb8df2c3fe31ef20 + md5: 4ba860ff851768615b1a25b788022750 + sha256: 5406be99410c471db7ce7bb59f238371525425acd7a7f5180387a7a16ae78b96 category: main optional: false - name: shellingham @@ -5467,19 +5455,19 @@ package: category: main optional: false - name: sqlite - version: 3.44.1 + version: 3.44.2 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' - libsqlite: 3.44.1 + libsqlite: 3.44.2 libzlib: '>=1.2.13,<1.3.0a0' ncurses: '>=6.4,<7.0a0' readline: '>=8.2,<9.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.44.1-h2c6b66d_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.44.2-h2c6b66d_0.conda hash: - md5: cf535736bb0de7bf388dbfd2d6a50f53 - sha256: a0a2fc6c9d7e170c5738ad134f8b71f51a1c982c4496c47f8caa73ef4e5b17c8 + md5: 4f2892c672829693fd978d065db4e8be + sha256: bae479520fe770fe11996b4c240923ed097f851fbd2401d55540e551c9dbbef7 category: main optional: false - name: stack_data @@ -5685,6 +5673,25 @@ package: sha256: 90229da7665175b0185183ab7b53f50af487c7f9b0f47cf09c184cbc139fd24b category: main optional: false +- name: torchdata + version: 0.7.1 + manager: conda + platform: linux-64 + dependencies: + aws-sdk-cpp: '>=1.11.182,<1.11.183.0a0' + libgcc-ng: '>=12' + libstdcxx-ng: '>=12' + python: '>=3.11,<3.12.0a0' + python_abi: 3.11.* + pytorch: '>=2.1.0,<2.2.0a0' + requests: '' + urllib3: '>=1.25' + url: https://conda.anaconda.org/conda-forge/linux-64/torchdata-0.7.1-py311ha8bf654_0.conda + hash: + md5: fa7532ed041ab78a5eee52d2e6b80b12 + sha256: 2c64eace1fddfc07d2307d34711b7c1af167a44b73a0b2649e30704035348356 + category: main + optional: false - name: torchmetrics version: 1.2.0 manager: conda @@ -5730,15 +5737,15 @@ package: category: main optional: false - name: traitlets - version: 5.13.0 + version: 5.14.0 manager: conda platform: linux-64 dependencies: python: '>=3.8' - url: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.13.0-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.0-pyhd8ed1ab_0.conda hash: - md5: 8a9953c15e1e5a7c1baddbbf4511a567 - sha256: 7ac67960ba2e8c16818043cc65ac6190fa4fd95f5b24357df58e4f73d5e60a10 + md5: 886f4a84ddb49b943b1697ac314e85b3 + sha256: c32412029033264140926be474d327d7fd57c0d11db9b1745396b3d4db78a799 category: main optional: false - name: transformers diff --git a/environment.yml b/environment.yml index 9f203776..e96b61bd 100644 --- a/environment.yml +++ b/environment.yml @@ -14,6 +14,7 @@ dependencies: - pytorch~=2.1.0=*cuda120* - python=3.11 - stackstac~=0.5.0 + - torchdata~=0.7.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - zarr~=2.16.1 diff --git a/src/README.md b/src/README.md index d71077e4..cd719680 100644 --- a/src/README.md +++ b/src/README.md @@ -4,7 +4,7 @@ This folder contains several LightningDataModule and LightningModule classes. ## DataModules (data pipeline) -- datamodule.py - Base data pipeline to read in Earth Observation chips +- datamodule.py - Data pipeline to read in Earth Observation chips from GeoTIFF files ## LightningModule (model architecture) diff --git a/src/datamodule.py b/src/datamodule.py index 71bb979e..582f28a6 100644 --- a/src/datamodule.py +++ b/src/datamodule.py @@ -1,43 +1,54 @@ """ -LightningDataModule to loads Earth Observation data from using -. +LightningDataModule to load Earth Observation data from GeoTIFF files using +rasterio. """ import lightning as L +import numpy as np +import rasterio import torch +import torchdata # %% -class RandomDataset(torch.utils.data.Dataset): +def _array_to_torch(filepath: str) -> torch.Tensor: """ - Torch Dataset that returns tensors of size (13, 256, 256) with random - values. + Read a GeoTIFF file using rasterio into a numpy.ndarray, and convert it + to a torch.Tensor (float16 dtype). """ + # GeoTIFF - Rasterio + with rasterio.open(fp=filepath) as dataset: + array: np.ndarray = dataset.read() + tensor: torch.Tensor = torch.as_tensor(data=array.astype(dtype="float16")) - def __init__(self): - super().__init__() - - def __len__(self): - return 2048 - - def __getitem__(self, idx: int): - return torch.randn(13, 256, 256) + return tensor -class BaseDataModule(L.LightningDataModule): +class GeoTIFFDataPipeModule(L.LightningDataModule): """ - LightningDataModule for loading files. + LightningDataModule for loading GeoTIFF files. - Uses + Uses torchdata. """ - def __init__(self, batch_size: int = 32): + def __init__( + self, + data_path: str = "data/", + batch_size: int = 32, + num_workers: int = 8, + ): """ Go from datacubes to 256x256 chips! Parameters ---------- + data_path : str + Path to the data folder where the GeoTIFF files are stored. Default + is 'data/'. batch_size : int Size of each mini-batch. Default is 32. + num_workers : int + How many subprocesses to use for data loading. 0 means that the + data will be loaded in the main process. Default is 8. Returns ------- @@ -45,21 +56,42 @@ def __init__(self, batch_size: int = 32): A torch DataPipe that can be passed into a torch DataLoader. """ super().__init__() + self.data_path: str = data_path self.batch_size: int = batch_size + self.num_workers: int = num_workers def setup(self, stage: str | None = None): """ Data operations to perform on every GPU. Split data into training and test sets, etc. """ - self.dataset = RandomDataset() + # Step 1 - Get list of GeoTIFF filepaths from data/ folder + dp_paths = torchdata.datapipes.iter.FileLister( + root=self.data_path, masks="*.tif", recursive=True, length=423 + ) + + # Step 2 - Split GeoTIFF chips into train/val sets (80%/20%) + # https://pytorch.org/data/0.7/generated/torchdata.datapipes.iter.RandomSplitter.html + dp_train, dp_val = dp_paths.random_split( + weights={"train": 0.8, "validation": 0.2}, total_length=423, seed=42 + ) + + # Step 3 - Read GeoTIFF into numpy.ndarray, batch and convert to torch.Tensor + self.datapipe_train = ( + dp_train.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate() + ) + self.datapipe_val = ( + dp_val.map(fn=_array_to_torch).batch(batch_size=self.batch_size).collate() + ) def train_dataloader(self) -> torch.utils.data.DataLoader: """ Loads the data used in the training loop. """ return torch.utils.data.DataLoader( - dataset=self.dataset, batch_size=self.batch_size + dataset=self.datapipe_train, + batch_size=None, # handled in datapipe already + num_workers=self.num_workers, ) def val_dataloader(self) -> torch.utils.data.DataLoader: @@ -67,5 +99,7 @@ def val_dataloader(self) -> torch.utils.data.DataLoader: Loads the data used in the validation loop. """ return torch.utils.data.DataLoader( - dataset=self.dataset, batch_size=self.batch_size + dataset=self.datapipe_val, + batch_size=None, # handled in datapipe already + num_workers=self.num_workers, ) diff --git a/src/tests/test_datamodule.py b/src/tests/test_datamodule.py new file mode 100644 index 00000000..49ed0db2 --- /dev/null +++ b/src/tests/test_datamodule.py @@ -0,0 +1,56 @@ +""" +Tests for GeoTIFFDataPipeModule. + +Integration test for the entire data pipeline from loading the data and +pre-processing steps, up to the DataLoader producing mini-batches. +""" +import tempfile + +import lightning as L +import numpy as np +import pytest +import rasterio +import torch + +from src.datamodule import GeoTIFFDataPipeModule + + +# %% +@pytest.fixture(scope="function", name="geotiff_folder") +def fixture_geotiff_folder(): + """ + Create a temporary folder containing two GeoTIFF files with random data to + use in the tests. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + for filename in ["one", "two"]: + array: np.ndarray = np.ones(shape=(3, 256, 256)) + with rasterio.open( + fp=f"{tmpdirname}/{filename}.tif", + mode="w", + width=256, + height=256, + count=3, + dtype=rasterio.uint16, + ) as dst: + dst.write(array) + + yield tmpdirname + + +# %% +def test_geotiffdatapipemodule(geotiff_folder): + """ + Ensure that GeoTIFFDataPipeModule works to load data from a GeoTIFF file + into torch.Tensor objects. + """ + datamodule: L.LightningDataModule = GeoTIFFDataPipeModule( + data_path=geotiff_folder, batch_size=2 + ) + datamodule.setup() + + it = iter(datamodule.train_dataloader()) + image = next(it) + + assert image.shape == torch.Size([2, 3, 256, 256]) + assert image.dtype == torch.float16 diff --git a/src/tests/test_model.py b/src/tests/test_model.py index 06bf59da..ba4813c8 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -5,24 +5,44 @@ https://github.com/Lightning-AI/lightning/blob/2.1.0/.github/CONTRIBUTING.md#how-to-add-new-tests """ import lightning as L +import pytest +import torch +import torchdata +import torchdata.dataloader2 -from src.datamodule import BaseDataModule from src.model_vit import ViTLitModule # %% -def test_model_vit(): +@pytest.fixture(scope="function", name="datapipe") +def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe: + """ + A torchdata DataPipe with random data to use in the tests. + """ + datapipe = torchdata.datapipes.iter.IterableWrapper( + iterable=[ + torch.randn(2, 13, 256, 256).to(dtype=torch.float16), + torch.randn(2, 13, 256, 256).to(dtype=torch.float16), + ] + ) + return datapipe + + +# %% +def test_model_vit(datapipe): """ Run a full train, val, test and prediction loop using 1 batch. """ # Get some random data - dataloader: L.LightningDataModule = BaseDataModule() + dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe) # Initialize model model: L.LightningModule = ViTLitModule() # Training - trainer: L.Trainer = L.Trainer(accelerator="auto", devices=1, fast_dev_run=True) + trainer: L.Trainer = L.Trainer( + accelerator="auto", devices=1, precision="16-mixed", fast_dev_run=True + ) trainer.fit(model=model, train_dataloaders=dataloader) # Test/Evaluation diff --git a/trainer.py b/trainer.py index c9051b81..9fa4e681 100644 --- a/trainer.py +++ b/trainer.py @@ -11,7 +11,7 @@ """ from lightning.pytorch.cli import ArgsType, LightningCLI -from src.datamodule import BaseDataModule +from src.datamodule import GeoTIFFDataPipeModule from src.model_vit import ViTLitModule @@ -27,7 +27,7 @@ def cli_main( """ cli = LightningCLI( model_class=ViTLitModule, - datamodule_class=BaseDataModule, + datamodule_class=GeoTIFFDataPipeModule, save_config_callback=save_config_callback, seed_everything_default=seed_everything_default, trainer_defaults=trainer_defaults,