diff --git a/book/README.md b/book/README.md index cbb41b5..6c26450 100644 --- a/book/README.md +++ b/book/README.md @@ -1,31 +1,29 @@ -# Jupyter Book - -Everything is here. - -`_config.yml` is set to not render the notebooks. So make sure to save in rendered format. - -### GitHub Action - -There is a GitHub Action that should build the book whenever there is a push to the `book` directory. - -Set Pages to use a GitHub Action. If the Action does not run, then you will need to debug. Click on the Action that did not build and click on the part that had a problem. - -### Build locally and push to GitHub - -Do `pip install ghp-import` if needed. Then build book and push to GitHub. Set Pages to use `gh-pages` branch (which is going to disable deploying from the GitHub Action). These commands are run within the `book` directory. - -``` -cd /book -jupyter-book build . --keep-going -ghp-import -n -p -f _build/html -``` - -### Building Locally - -1. Open a terminal. -2. Run `jupyter-book clean book/` to remove any existing builds -3. Run `jupyter-book build book/` - -A fully-rendered HTML version of the book will be built in `book/_build/html/`. - - +# Jupyter Book + +Everything is here. + +`_config.yml` is set to not render the notebooks. So make sure to save in rendered format. + +### GitHub Action + +There is a GitHub Action that should build the book whenever there is a push to the `book` directory. If the Action does not run, then you will need to debug. Click on the Action that did not build and click on the part that had a problem. + +### Build locally and push to GitHub + +Do `pip install ghp-import` if needed. Then build book and push to GitHub. Set Pages to use `gh-pages` branch. These commands are run within the `book` directory. + +``` +cd /book +jupyter-book build . --keep-going +ghp-import -n -p -f _build/html +``` + +### Building Locally + +1. Open a terminal. +2. Run `jupyter-book clean book/` to remove any existing builds +3. Run `jupyter-book build book/` + +A fully-rendered HTML version of the book will be built in `book/_build/html/`. + + diff --git a/book/_toc.yml b/book/_toc.yml index f01d8cf..79448d5 100644 --- a/book/_toc.yml +++ b/book/_toc.yml @@ -1,20 +1,9 @@ -# Table of contents -# Learn more at https://jupyterbook.org/customize/toc.html - -format: jb-book -root: intro -parts: - - caption: Data - chapters: - - file: notebooks/IO_Zarr.md - title: Indian Ocean dataset - - file: notebooks/background.md - title: Background - - file: notebooks/IO_Zarr_visualizations.ipynb - title: Data visualizations - - caption: Models - chapters: - - file: notebooks/CHL_prediction_CNN.ipynb - title: CNNs - - file: notebooks/CHL_prediction_ConvLSTM_.ipynb - title: ConvLSTM +# Table of contents +# Learn more at https://jupyterbook.org/customize/toc.html + +format: jb-book +root: intro +chapters: +- file: myst-markdown +- file: notebooks/ipynb-notebook +- file: notebooks/myst-notebook \ No newline at end of file diff --git a/book/intro.md b/book/intro.md index 9601099..9d8dc28 100644 --- a/book/intro.md +++ b/book/intro.md @@ -1,27 +1,19 @@ -# Home - -## Neural network models for Chloraphyll-a gap-filling for remote-sensing products - -### Authors: See individual notebooks - -2024 GeoSMART Hackweek: - -[Pitch slide](https://docs.google.com/presentation/d/1YfBLkspba2hRz5pTHG9OF3o9WHv-yNemZDq2QKFCme0/edit?usp=sharing) -[Zotero library](https://www.zotero.org/groups/5595561/safs-interns-/library) -[Google doc](https://docs.google.com/document/d/1ADjtPFMy5mDxWJ_jhFhUWaBvjSd54YAfcc3d6araPCs/edit?usp=sharing) - - -### Collaborators - -| Name | Affiliation | Role | email | -| ------------- | ------------- | ------------- | ------------- | -| [Elizabeth Eli Holmes](https://eeholmes.github.io/) | NOAA Fisheries, University of Washington SAFS| SAFS Varanasi mentor | eli.holmes@noaa.gov | -[Shridhar Sinha](https://www.linkedin.com/in/shridhar-sinha-5b7125184/) | University of Washington, Paul G. Allen School of Computer Science & Engineering | 2024 SAFS Varanasi Intern | ssinha19@uw.edu | -| Yifei Hang | University of Washington, Applied & Computational Mathematical Sciences | 2024 SAFS Varanasi Intern | yhang2@uw.edu | -| [Jiarui Yu](https://www.linkedin.com/in/jiarui-yu-0b0ab522b/) | University of Washington, Applied & Computational Mathematical Sciences | 2023 SAFS Varanasi Intern | | -| [Minh Phan](https://www.linkedin.com/in/minhphan03/) | University of Washington, Applied & Computational Mathematical Sciences | 2023 SAFS Varanasi Intern | | -| Ares | | geo-smart HackWeek 2024 | | -| Gabe | | geo-smart HackWeek 2024 | | -| Qi Ge | | geo-smart HackWeek 2024 | | -| Andy Barrett | | geo-smart HackWeek 2024 | | -| Robin Clancy | | geo-smart HackWeek 2024 | | \ No newline at end of file +# Project Title and Introduction + +Provide a brief introduction. + +* Edit `_config.yml` with your title, authors, repo name etc. +* Add new notebooks in the `notebooks` folder +* Add those notebooks into `_toc.yml` + +### Collaborators + + +| Name | Personal goals | Can help with | Role | +| ------------- | ------------- | ------------- | ------------- | +| Katherine J. | I want to learn specific python libraries for working with these data | I can help with understanding our dataset, programming in R | Project Lead | +| Rosalind F. | Practice leading a software project | machine learning and python (scipy, scikit-learn) | Project Lead | +| Alan T. | learning about your dataset | GitHub, Jupyter, cloud computing | Project Helper | +| Rachel C. | learn to use github, resolve merge conflicts | I am familiar with our dataset | Team Member | +| ... | ... | ... | ... | +| ... | ... | ... | ... | \ No newline at end of file diff --git a/notebooks/PINN_refactor_test.ipynb b/notebooks/PINN_refactor_test.ipynb new file mode 100644 index 0000000..6e03ace --- /dev/null +++ b/notebooks/PINN_refactor_test.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import os\n", + "os.environ[\"DDEBACKEND\"] = \"pytorch\"\n", + "import torch\n", + "from torch import nn, optim\n", + "import deepxde as dde\n", + "from shapely.geometry import Point\n", + "import cartopy.feature as cfeature\n", + "\n", + "import sys\n", + "sys.path.append(\"../\")\n", + "\n", + "from src.models import ChlorophyllDeepONet\n", + "from src.boundary_conds import get_xt_geom, is_in_ocean\n", + "from src.pdes import pde\n", + "from src.data_utils import *\n", + "\n", + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting data load and preprocessing...\n", + "Starting data preparation for PINN...\n" + ] + } + ], + "source": [ + "zarr_ds = load_and_preprocess_data()\n", + "data, time, lat, lon, water_mask = prepare_data_for_pinn(zarr_ds)\n", + "chl_data = data[\"CHL_cmes-level3\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "geomtime, coastline = get_xt_geom(lat, lon, time)\n", + "\n", + "# Convert numpy arrays to PyTorch tensors\n", + "air_temp = torch.tensor(data[\"air_temp\"], dtype=torch.float32)\n", + "sst = torch.tensor(data[\"sst\"], dtype=torch.float32)\n", + "curr_dir = torch.tensor(data[\"curr_dir\"], dtype=torch.float32)\n", + "ug_curr = torch.tensor(data[\"ug_curr\"], dtype=torch.float32)\n", + "u_wind = torch.tensor(data[\"u_wind\"], dtype=torch.float32)\n", + "v_wind = torch.tensor(data[\"v_wind\"], dtype=torch.float32)\n", + "v_curr = torch.tensor(data[\"v_curr\"], dtype=torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 9\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m on_boundary \u001b[38;5;129;01mand\u001b[39;00m ocean_boundary\n\u001b[1;32m 7\u001b[0m bc_robin \u001b[38;5;241m=\u001b[39m dde\u001b[38;5;241m.\u001b[39micbc\u001b[38;5;241m.\u001b[39mRobinBC(geomtime, \u001b[38;5;28;01mlambda\u001b[39;00m X, y: y, boundary_condition)\n\u001b[0;32m----> 9\u001b[0m data_pinn \u001b[38;5;241m=\u001b[39m \u001b[43mdde\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTimePDE\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeomtime\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mpde\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mbc_robin\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_domain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_boundary\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2000\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_initial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2000\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/data/pde.py:338\u001b[0m, in \u001b[0;36mTimePDE.__init__\u001b[0;34m(self, geometryxtime, pde, ic_bcs, num_domain, num_boundary, num_initial, train_distribution, anchors, exclusions, solution, num_test, auxiliary_var_function)\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 323\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 324\u001b[0m geometryxtime,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 335\u001b[0m auxiliary_var_function\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 336\u001b[0m ):\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_initial \u001b[38;5;241m=\u001b[39m num_initial\n\u001b[0;32m--> 338\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 339\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeometryxtime\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 340\u001b[0m \u001b[43m \u001b[49m\u001b[43mpde\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 341\u001b[0m \u001b[43m \u001b[49m\u001b[43mic_bcs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 342\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_domain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 343\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_boundary\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 344\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_distribution\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_distribution\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 345\u001b[0m \u001b[43m \u001b[49m\u001b[43manchors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43manchors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 346\u001b[0m \u001b[43m \u001b[49m\u001b[43mexclusions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexclusions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[43m \u001b[49m\u001b[43msolution\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msolution\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_test\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_test\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mauxiliary_var_function\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mauxiliary_var_function\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/data/pde.py:130\u001b[0m, in \u001b[0;36mPDE.__init__\u001b[0;34m(self, geometry, pde, bcs, num_domain, num_boundary, train_distribution, anchors, exclusions, solution, num_test, auxiliary_var_function)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_x, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_aux_vars, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest_aux_vars \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_next_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtest()\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/utils/internal.py:38\u001b[0m, in \u001b[0;36mrun_if_all_none..decorator..wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m x \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, a) \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m attr]\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mall\u001b[39m(i \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m x):\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(x) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m x[\u001b[38;5;241m0\u001b[39m]\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/data/pde.py:186\u001b[0m, in \u001b[0;36mPDE.train_next_batch\u001b[0;34m(self, batch_size)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;129m@run_if_all_none\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_x\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_y\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_aux_vars\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain_next_batch\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_x_all \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_points()\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbc_points\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Generate self.num_bcs and self.train_x_bc\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbcs \u001b[38;5;129;01mand\u001b[39;00m config\u001b[38;5;241m.\u001b[39mhvd \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 188\u001b[0m num_bcs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_bcs)\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/utils/internal.py:38\u001b[0m, in \u001b[0;36mrun_if_all_none..decorator..wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m x \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, a) \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m attr]\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mall\u001b[39m(i \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m x):\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(x) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m x[\u001b[38;5;241m0\u001b[39m]\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/data/pde.py:298\u001b[0m, in \u001b[0;36mPDE.bc_points\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;129m@run_if_all_none\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_x_bc\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbc_points\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 298\u001b[0m x_bcs \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[43mbc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollocation_points\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_x_all\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbc\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbcs\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_bcs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28mlen\u001b[39m, x_bcs))\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_x_bc \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 301\u001b[0m np\u001b[38;5;241m.\u001b[39mvstack(x_bcs)\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x_bcs\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m np\u001b[38;5;241m.\u001b[39mempty([\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_x_all\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]], dtype\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mreal(np))\n\u001b[1;32m 304\u001b[0m )\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/data/pde.py:298\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;129m@run_if_all_none\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain_x_bc\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbc_points\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 298\u001b[0m x_bcs \u001b[38;5;241m=\u001b[39m [\u001b[43mbc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollocation_points\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_x_all\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m bc \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbcs]\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_bcs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28mlen\u001b[39m, x_bcs))\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_x_bc \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 301\u001b[0m np\u001b[38;5;241m.\u001b[39mvstack(x_bcs)\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x_bcs\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m np\u001b[38;5;241m.\u001b[39mempty([\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_x_all\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]], dtype\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mreal(np))\n\u001b[1;32m 304\u001b[0m )\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/icbc/boundary_conditions.py:53\u001b[0m, in \u001b[0;36mBC.collocation_points\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcollocation_points\u001b[39m(\u001b[38;5;28mself\u001b[39m, X):\n\u001b[0;32m---> 53\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/icbc/boundary_conditions.py:50\u001b[0m, in \u001b[0;36mBC.filter\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfilter\u001b[39m(\u001b[38;5;28mself\u001b[39m, X):\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X[\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_boundary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgeom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_boundary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m]\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/icbc/boundary_conditions.py:41\u001b[0m, in \u001b[0;36mBC.__init__..\u001b[0;34m(x, on)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, geom, on_boundary, component):\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgeom \u001b[38;5;241m=\u001b[39m geom\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_boundary \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x, on: np\u001b[38;5;241m.\u001b[39marray(\n\u001b[0;32m---> 41\u001b[0m \u001b[43m[\u001b[49m\u001b[43mon_boundary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mon\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcomponent \u001b[38;5;241m=\u001b[39m component\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mboundary_normal \u001b[38;5;241m=\u001b[39m npfunc_range_autocache(\n\u001b[1;32m 46\u001b[0m utils\u001b[38;5;241m.\u001b[39mreturn_tensor(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgeom\u001b[38;5;241m.\u001b[39mboundary_normal)\n\u001b[1;32m 47\u001b[0m )\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/deepxde/icbc/boundary_conditions.py:41\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, geom, on_boundary, component):\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgeom \u001b[38;5;241m=\u001b[39m geom\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_boundary \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x, on: np\u001b[38;5;241m.\u001b[39marray(\n\u001b[0;32m---> 41\u001b[0m [\u001b[43mon_boundary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mon\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(x))]\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcomponent \u001b[38;5;241m=\u001b[39m component\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mboundary_normal \u001b[38;5;241m=\u001b[39m npfunc_range_autocache(\n\u001b[1;32m 46\u001b[0m utils\u001b[38;5;241m.\u001b[39mreturn_tensor(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgeom\u001b[38;5;241m.\u001b[39mboundary_normal)\n\u001b[1;32m 47\u001b[0m )\n", + "Cell \u001b[0;32mIn[9], line 4\u001b[0m, in \u001b[0;36mboundary_condition\u001b[0;34m(x, on_boundary)\u001b[0m\n\u001b[1;32m 2\u001b[0m lat \u001b[38;5;241m=\u001b[39m x[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 3\u001b[0m lon \u001b[38;5;241m=\u001b[39m x[\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m----> 4\u001b[0m ocean_boundary \u001b[38;5;241m=\u001b[39m \u001b[43mis_in_ocean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlon\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcoastline\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m on_boundary \u001b[38;5;129;01mand\u001b[39;00m ocean_boundary\n", + "File \u001b[0;32m~/mind-the-chl-gap/notebooks/../boundary_conds.py:9\u001b[0m, in \u001b[0;36mis_in_ocean\u001b[0;34m(lat, lon, coastline)\u001b[0m\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/shapely/geometry/base.py:658\u001b[0m, in \u001b[0;36mBaseGeometry.contains\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcontains\u001b[39m(\u001b[38;5;28mself\u001b[39m, other):\n\u001b[1;32m 657\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns True if the geometry contains the other, else False\"\"\"\u001b[39;00m\n\u001b[0;32m--> 658\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _maybe_unpack(\u001b[43mshapely\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontains\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mother\u001b[49m\u001b[43m)\u001b[49m)\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/shapely/decorators.py:77\u001b[0m, in \u001b[0;36mmultithreading_enabled..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m arr \u001b[38;5;129;01min\u001b[39;00m array_args:\n\u001b[1;32m 76\u001b[0m arr\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mwriteable \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m arr, old_flag \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(array_args, old_flags):\n", + "File \u001b[0;32m/srv/conda/envs/notebook/lib/python3.11/site-packages/shapely/predicates.py:540\u001b[0m, in \u001b[0;36mcontains\u001b[0;34m(a, b, **kwargs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;129m@multithreading_enabled\u001b[39m\n\u001b[1;32m 486\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcontains\u001b[39m(a, b, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 487\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns True if geometry B is completely inside geometry A.\u001b[39;00m\n\u001b[1;32m 488\u001b[0m \n\u001b[1;32m 489\u001b[0m \u001b[38;5;124;03m A contains B if no points of B lie in the exterior of A and at least one\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[38;5;124;03m False\u001b[39;00m\n\u001b[1;32m 539\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 540\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcontains\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "def boundary_condition(x, on_boundary):\n", + " lat = x[0]\n", + " lon = x[1]\n", + " ocean_boundary = is_in_ocean(lat, lon, coastline)\n", + " return on_boundary and ocean_boundary\n", + "\n", + "bc_robin = dde.icbc.RobinBC(geomtime, lambda X, y: y, boundary_condition)\n", + "\n", + "data_pinn = dde.data.TimePDE(\n", + " geomtime,\n", + " pde,\n", + " [bc_robin],\n", + " num_domain=10000,\n", + " num_boundary=2000,\n", + " num_initial=2000,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/plots/chl_animation.mp4 b/plots/chl_animation.mp4 new file mode 100644 index 0000000..2efec1e Binary files /dev/null and b/plots/chl_animation.mp4 differ diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/boundary_conds.py b/src/boundary_conds.py new file mode 100644 index 0000000..3172724 --- /dev/null +++ b/src/boundary_conds.py @@ -0,0 +1,35 @@ +from shapely.geometry import Point +import cartopy.feature as cfeature +import deepxde as dde + + +def is_in_ocean(lat, lon, coastline): + point = Point(lon, lat) + for geometry in coastline.geometries(): + if geometry.contains(point): + return False + return True + + +# Currently not working because `coastline` i a global variable +# and can't be an argument for the `boundary_condition` function +# and python is weird +# def boundary_condition(x, on_boundary): +# lat = x[0] +# lon = x[1] +# ocean_boundary = is_in_ocean(lat, lon, coastline) +# return on_boundary and ocean_boundary + + +def get_xt_geom(lat, lon, time): + lat_min, lat_max = lat.min(), lat.max() + lon_min, lon_max = lon.min(), lon.max() + time_min, time_max = time.min(), time.max() + spatial_domain = dde.geometry.Rectangle( + xmin=[lat_min, lon_min], xmax=[lat_max, lon_max] + ) + temporal_domain = dde.geometry.TimeDomain(t0=time_min, t1=time_max) + geomtime = dde.geometry.GeometryXTime(spatial_domain, temporal_domain) + coastline = cfeature.NaturalEarthFeature("physical", "coastline", "50m") + + return geomtime, coastline diff --git a/src/data_utils.py b/src/data_utils.py new file mode 100644 index 0000000..ee4b439 --- /dev/null +++ b/src/data_utils.py @@ -0,0 +1,62 @@ +import numpy as np +import xarray as xr + + +# Load and preprocess data +def load_and_preprocess_data(): + "TODO: Time slice variable?" + print("Starting data load and preprocessing...") + zarr_ds = xr.open_zarr(store="~/shared-public/mind_the_chl_gap/IO.zarr", consolidated=True) + zarr_ds = zarr_ds.sel(lat=slice(32, -11.75), lon=slice(42, 101.75)) + + all_nan_dates = ( + np.isnan(zarr_ds["CHL_cmes-level3"]).all(dim=["lon", "lat"]).compute() + ) + zarr_ds = zarr_ds.sel(time=~all_nan_dates) + zarr_ds = zarr_ds.sortby("time") + zarr_ds = zarr_ds.sel(time=slice("2019-01-01", "2022-12-31")) + return zarr_ds + + +# Prepare data for PINN +def prepare_data_for_pinn(zarr_ds): + print("Starting data preparation for PINN...") + variables = [ + "CHL_cmes-level3", + "air_temp", + "sst", + "curr_dir", + "ug_curr", + "u_wind", + "v_wind", + "v_curr", + ] + data = {var: zarr_ds[var].values for var in variables} + + water_mask = ~np.isnan(data["sst"][0]) + + for var in variables: + data[var] = data[var][:, water_mask] + data[var] = np.nan_to_num( + data[var], + nan=np.nanmean(data[var]), + posinf=np.nanmax(data[var]), + neginf=np.nanmin(data[var]), + ) + if var == "CHL_cmes-level3": + data[var] = np.log(data[var]) # Use log CHL + mean = np.mean(data[var]) + std = np.std(data[var]) + data[var] = (data[var] - mean) / std + data[f"{var}_mean"] = mean + data[f"{var}_std"] = std + + time = zarr_ds.time.values + lat = zarr_ds.lat.values + lon = zarr_ds.lon.values + time_numeric = (time - time[0]).astype("timedelta64[D]").astype(float) + lon_grid, lat_grid = np.meshgrid(lon, lat) + lat_flat = lat_grid.flatten()[water_mask.flatten()] + lon_flat = lon_grid.flatten()[water_mask.flatten()] + + return data, time_numeric, lat_flat, lon_flat, water_mask diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000..2aca8bd --- /dev/null +++ b/src/models.py @@ -0,0 +1,59 @@ +import torch +from torch import nn + +import deepxde as dde + + +class ChlorophyllDeepONet(dde.nn.pytorch.deeponet.DeepONet): + def __init__(self, layer_sizes_branch, layer_sizes_trunk, activation): + super().__init__( + layer_sizes_branch, layer_sizes_trunk, activation, "Glorot normal" + ) + + self.branch_net = dde.nn.pytorch.fnn.FNN( + layer_sizes_branch, activation, "Glorot normal" + ) + self.trunk_net = dde.nn.pytorch.fnn.FNN( + layer_sizes_trunk, activation, "Glorot normal" + ) + + def forward(self, inputs): + x_func = self.branch_net(inputs) + x_loc = self.trunk_net(inputs) + if self._output_transform is not None: + return self._output_transform(self.merge_branch_trunk(x_func, x_loc, -1)) + return self.merge_branch_trunk(x_func, x_loc, -1) + + +class UNet(nn.Module): + def __init__(self): + super(UNet, self).__init__() + self.encoder = nn.Sequential( + nn.Conv2d(4, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + ) + self.middle = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=0), + ) + self.decoder = nn.Sequential( + nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(64, 1, kernel_size=1), + ) + + def forward(self, x): + x1 = self.encoder(x) + x2 = self.middle(x1) + x3 = self.decoder(x2) + return x3 diff --git a/src/pdes.py b/src/pdes.py new file mode 100644 index 0000000..13eb40c --- /dev/null +++ b/src/pdes.py @@ -0,0 +1,26 @@ +import torch +import deepxde as dde + + +def pde(x, y): + lat, lon, t = x[:, 0:1], x[:, 1:2], x[:, 2:3] + d2U_dlat2 = dde.grad.hessian(y, x, component=0, i=0, j=0) + d2U_dlon2 = dde.grad.hessian(y, x, component=0, i=1, j=1) + d2U_dt2 = dde.grad.hessian(y, x, component=0, i=2, j=2) + + rho = ( + 0.1 * torch.sin(lat) * torch.cos(lon) * torch.exp(-0.1 * t) + + 0.05 * torch.sin(2 * torch.pi * t / 365) + + ( + 0.5 * air_temp_mean + + -1.0 * sst_mean + + 0.05 * curr_dir_mean + + 0.15 * ug_curr_mean + + 0.4 * u_wind_mean + + -0.2 * v_wind_mean + + 0.3 * v_curr_mean + ) + ) + + residual = d2U_dlat2 + d2U_dlon2 + d2U_dt2 - rho + return residual diff --git a/src/training.py b/src/training.py new file mode 100644 index 0000000..ed6c77b --- /dev/null +++ b/src/training.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +from models import UNet + + +# Train the UNet model +def train_unet(data, epochs=100, batch_size=32): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = UNet().to(device) + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + + features = ["u_wind", "v_wind", "sst", "air_temp"] + X = np.stack([data[feature] for feature in features], axis=1) + y = data["CHL_cmes-level3"] + + # Debug: Print the shapes of X and y before reshaping + print(f"Original X shape: {X.shape}") + print(f"Original y shape: {y.shape}") + + num_elements = X.shape[2] + nearest_square = int(np.floor(np.sqrt(num_elements)) ** 2) + height = int(np.sqrt(nearest_square)) + width = height + + # Trim X and y to the nearest perfect square + X = X[:, :, :nearest_square] + y = y[:, :nearest_square] + + # Reshape X and y to match the expected input shape for UNet + num_samples = X.shape[0] + num_features = len(features) + + X = X.reshape(num_samples, num_features, height, width) + y = y.reshape(num_samples, 1, height, width) + + # Debug: Print the shapes of X and y after reshaping + print(f"Reshaped X shape: {X.shape}") + print(f"Reshaped y shape: {y.shape}") + + X = torch.tensor(X, dtype=torch.float32).to(device) + y = torch.tensor(y, dtype=torch.float32).to(device) + + generator = torch.Generator(device=device) + dataset = torch.utils.data.TensorDataset(X, y) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, generator=generator + ) + + for epoch in range(epochs): + for inputs, targets in dataloader: + optimizer.zero_grad() + outputs = model(inputs) + # Resize the outputs to match the target size + outputs = nn.functional.interpolate( + outputs, size=(height, width), mode="bilinear", align_corners=False + ) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}") + save_model(model, "unet_model.pth") + return model diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..e8fd05c --- /dev/null +++ b/src/utils.py @@ -0,0 +1,32 @@ +import torch +import numpy as np + +def save_model(model, path): + torch.save(model.state_dict(), path) + + +def load_model(model, path): + model.load_state_dict(torch.load(path)) + model.eval() + return model + +def pad_img(img, water_mask): + pad_length = np.sum(water_mask) - img.shape[0] + if pad_length > 0: + img = np.pad(img, (0, pad_length), mode="constant", constant_values=np.nan) + return img + + +def data_to_img(data, water_mask, pad=False): + "Transform data slice or PINN prediction into image" + if len(data.shape) > 0: + data = data.flatten() + if pad: + data = pad_img(data, water_mask) + # Create full NaN arrays matching the water_mask shape + img_grid = np.full(water_mask.shape, np.nan) + # Assign the chlorophyll values to the water pixels + img_grid[water_mask] = data[: np.sum(water_mask)] + # Reshape grids to match the lat/lon dimensions + img_grid = img_grid.reshape(water_mask.shape) + return img_grid \ No newline at end of file