Merge pull request #1 from geo-smart/temp-branch

Refactoring attempt

eeholmes authored Aug 21, 2024
2 parents f835e43 + 40c7c20 commit 86b230b
Showing 12 changed files with 512 additions and 78 deletions.
60 changes: 29 additions & 31 deletions book/README.md
@@ -1,31 +1,29 @@
# Jupyter Book

Everything is here.

`_config.yml` is set to not execute the notebooks, so make sure to save them with their outputs rendered before committing.

### GitHub Action

There is a GitHub Action that should build the book whenever there is a push to the `book` directory.

Set Pages to deploy from a GitHub Action. If the Action does not run, you will need to debug: open the run that failed and click on the step that had a problem.

### Build locally and push to GitHub

Do `pip install ghp-import` if needed. Then build the book and push it to GitHub. Set Pages to deploy from the `gh-pages` branch (which is going to disable deploying from the GitHub Action). Run these commands from the repository root.

```
cd book
jupyter-book build . --keep-going
ghp-import -n -p -f _build/html
```

### Building Locally

1. Open a terminal.
2. Run `jupyter-book clean book/` to remove any existing builds
3. Run `jupyter-book build book/`

A fully-rendered HTML version of the book will be built in `book/_build/html/`.


# Jupyter Book

Everything is here.

`_config.yml` is set to not execute the notebooks, so make sure to save them with their outputs rendered before committing.

### GitHub Action

There is a GitHub Action that should build the book whenever there is a push to the `book` directory. If the Action does not run, you will need to debug: open the run that failed and click on the step that had a problem.

### Build locally and push to GitHub

Do `pip install ghp-import` if needed. Then build the book and push it to GitHub. Set Pages to deploy from the `gh-pages` branch. Run these commands from the repository root.

```
cd book
jupyter-book build . --keep-going
ghp-import -n -p -f _build/html
```

### Building Locally

1. Open a terminal.
2. Run `jupyter-book clean book/` to remove any existing builds
3. Run `jupyter-book build book/`

A fully-rendered HTML version of the book will be built in `book/_build/html/`.


29 changes: 9 additions & 20 deletions book/_toc.yml
@@ -1,20 +1,9 @@
# Table of contents
# Learn more at https://jupyterbook.org/customize/toc.html

format: jb-book
root: intro
parts:
- caption: Data
  chapters:
  - file: notebooks/IO_Zarr.md
    title: Indian Ocean dataset
  - file: notebooks/background.md
    title: Background
  - file: notebooks/IO_Zarr_visualizations.ipynb
    title: Data visualizations
- caption: Models
  chapters:
  - file: notebooks/CHL_prediction_CNN.ipynb
    title: CNNs
  - file: notebooks/CHL_prediction_ConvLSTM_.ipynb
    title: ConvLSTM
# Table of contents
# Learn more at https://jupyterbook.org/customize/toc.html

format: jb-book
root: intro
chapters:
- file: myst-markdown
- file: notebooks/ipynb-notebook
- file: notebooks/myst-notebook
46 changes: 19 additions & 27 deletions book/intro.md
@@ -1,27 +1,19 @@
# Home

## Neural network models for Chlorophyll-a gap-filling for remote-sensing products

### Authors: See individual notebooks

2024 GeoSMART Hackweek:

* [Pitch slide](https://docs.google.com/presentation/d/1YfBLkspba2hRz5pTHG9OF3o9WHv-yNemZDq2QKFCme0/edit?usp=sharing)
* [Zotero library](https://www.zotero.org/groups/5595561/safs-interns-/library)
* [Google doc](https://docs.google.com/document/d/1ADjtPFMy5mDxWJ_jhFhUWaBvjSd54YAfcc3d6araPCs/edit?usp=sharing)


### Collaborators

| Name | Affiliation | Role | email |
| ------------- | ------------- | ------------- | ------------- |
| [Elizabeth Eli Holmes](https://eeholmes.github.io/) | NOAA Fisheries, University of Washington SAFS| SAFS Varanasi mentor | [email protected] |
| [Shridhar Sinha](https://www.linkedin.com/in/shridhar-sinha-5b7125184/) | University of Washington, Paul G. Allen School of Computer Science & Engineering | 2024 SAFS Varanasi Intern | [email protected] |
| Yifei Hang | University of Washington, Applied & Computational Mathematical Sciences | 2024 SAFS Varanasi Intern | [email protected] |
| [Jiarui Yu](https://www.linkedin.com/in/jiarui-yu-0b0ab522b/) | University of Washington, Applied & Computational Mathematical Sciences | 2023 SAFS Varanasi Intern | |
| [Minh Phan](https://www.linkedin.com/in/minhphan03/) | University of Washington, Applied & Computational Mathematical Sciences | 2023 SAFS Varanasi Intern | |
| Ares | | geo-smart HackWeek 2024 | |
| Gabe | | geo-smart HackWeek 2024 | |
| Qi Ge | | geo-smart HackWeek 2024 | |
| Andy Barrett | | geo-smart HackWeek 2024 | |
| Robin Clancy | | geo-smart HackWeek 2024 | |
# Project Title and Introduction

Provide a brief introduction.

* Edit `_config.yml` with your title, authors, repo name etc.
* Add new notebooks in the `notebooks` folder
* Add those notebooks into `_toc.yml`

### Collaborators


| Name | Personal goals | Can help with | Role |
| ------------- | ------------- | ------------- | ------------- |
| Katherine J. | I want to learn specific python libraries for working with these data | I can help with understanding our dataset, programming in R | Project Lead |
| Rosalind F. | Practice leading a software project | machine learning and python (scipy, scikit-learn) | Project Lead |
| Alan T. | learning about your dataset | GitHub, Jupyter, cloud computing | Project Helper |
| Rachel C. | learn to use github, resolve merge conflicts | I am familiar with our dataset | Team Member |
| ... | ... | ... | ... |
| ... | ... | ... | ... |
178 changes: 178 additions & 0 deletions notebooks/PINN_refactor_test.ipynb

Large diffs are not rendered by default.

Binary file added plots/chl_animation.mp4
Binary file not shown.
Empty file added src/__init__.py
35 changes: 35 additions & 0 deletions src/boundary_conds.py
@@ -0,0 +1,35 @@
from shapely.geometry import Point
import cartopy.feature as cfeature
import deepxde as dde


def is_in_ocean(lat, lon, coastline):
    """Return True if (lat, lon) falls outside every coastline geometry."""
    point = Point(lon, lat)
    for geometry in coastline.geometries():
        if geometry.contains(point):
            return False
    return True


# Currently not working: `coastline` is a global variable here, and deepxde
# expects boundary callables with the signature (x, on_boundary), so
# `coastline` cannot be passed in as an extra argument directly.
# def boundary_condition(x, on_boundary):
#     lat = x[0]
#     lon = x[1]
#     ocean_boundary = is_in_ocean(lat, lon, coastline)
#     return on_boundary and ocean_boundary


def get_xt_geom(lat, lon, time):
    """Build the deepxde space-time geometry and the 50m coastline feature."""
    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    time_min, time_max = time.min(), time.max()
    spatial_domain = dde.geometry.Rectangle(
        xmin=[lat_min, lon_min], xmax=[lat_max, lon_max]
    )
    temporal_domain = dde.geometry.TimeDomain(t0=time_min, t1=time_max)
    geomtime = dde.geometry.GeometryXTime(spatial_domain, temporal_domain)
    coastline = cfeature.NaturalEarthFeature("physical", "coastline", "50m")

    return geomtime, coastline
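The commented-out `boundary_condition` above fails because it needs `coastline` but deepxde only passes `(x, on_boundary)`. A minimal sketch of a closure-based fix (`make_boundary_condition` is a hypothetical helper, not part of the module, and assumes deepxde boundary callables take exactly that signature):

```
def make_boundary_condition(coastline):
    """Return an (x, on_boundary) callable with `coastline` bound via closure."""
    def boundary_condition(x, on_boundary):
        lat, lon = x[0], x[1]
        return on_boundary and is_in_ocean(lat, lon, coastline)
    return boundary_condition

# Usage:
# geomtime, coastline = get_xt_geom(lat, lon, time)
# on_ocean_boundary = make_boundary_condition(coastline)
```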
62 changes: 62 additions & 0 deletions src/data_utils.py
@@ -0,0 +1,62 @@
import numpy as np
import xarray as xr


# Load and preprocess data
def load_and_preprocess_data():
    """TODO: make the time slice an argument?"""
    print("Starting data load and preprocessing...")
    zarr_ds = xr.open_zarr(store="~/shared-public/mind_the_chl_gap/IO.zarr", consolidated=True)
    zarr_ds = zarr_ds.sel(lat=slice(32, -11.75), lon=slice(42, 101.75))

    # Drop dates where the CHL layer is entirely missing
    all_nan_dates = (
        np.isnan(zarr_ds["CHL_cmes-level3"]).all(dim=["lon", "lat"]).compute()
    )
    zarr_ds = zarr_ds.sel(time=~all_nan_dates)
    zarr_ds = zarr_ds.sortby("time")
    zarr_ds = zarr_ds.sel(time=slice("2019-01-01", "2022-12-31"))
    return zarr_ds


# Prepare data for PINN
def prepare_data_for_pinn(zarr_ds):
    print("Starting data preparation for PINN...")
    variables = [
        "CHL_cmes-level3",
        "air_temp",
        "sst",
        "curr_dir",
        "ug_curr",
        "u_wind",
        "v_wind",
        "v_curr",
    ]
    data = {var: zarr_ds[var].values for var in variables}

    # Ocean mask: grid cells where SST is defined at the first time step
    water_mask = ~np.isnan(data["sst"][0])

    for var in variables:
        data[var] = data[var][:, water_mask]  # keep ocean pixels only
        data[var] = np.nan_to_num(
            data[var],
            nan=np.nanmean(data[var]),
            posinf=np.nanmax(data[var]),
            neginf=np.nanmin(data[var]),
        )
        if var == "CHL_cmes-level3":
            data[var] = np.log(data[var])  # Use log CHL
        # Standardize; keep the mean/std for later un-normalization
        mean = np.mean(data[var])
        std = np.std(data[var])
        data[var] = (data[var] - mean) / std
        data[f"{var}_mean"] = mean
        data[f"{var}_std"] = std

    time = zarr_ds.time.values
    lat = zarr_ds.lat.values
    lon = zarr_ds.lon.values
    time_numeric = (time - time[0]).astype("timedelta64[D]").astype(float)
    lon_grid, lat_grid = np.meshgrid(lon, lat)
    lat_flat = lat_grid.flatten()[water_mask.flatten()]
    lon_flat = lon_grid.flatten()[water_mask.flatten()]

    return data, time_numeric, lat_flat, lon_flat, water_mask
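A minimal usage sketch of the two functions above, assuming the Zarr store path and variable names match the dataset:

```
zarr_ds = load_and_preprocess_data()
data, time_numeric, lat_flat, lon_flat, water_mask = prepare_data_for_pinn(zarr_ds)
# data[var] is now (n_times, n_ocean_pixels), standardized, with the
# per-variable mean/std stored under "{var}_mean" and "{var}_std"
print(data["CHL_cmes-level3"].shape)
```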
59 changes: 59 additions & 0 deletions src/models.py
@@ -0,0 +1,59 @@
import torch
from torch import nn

import deepxde as dde


class ChlorophyllDeepONet(dde.nn.pytorch.deeponet.DeepONet):
    def __init__(self, layer_sizes_branch, layer_sizes_trunk, activation):
        super().__init__(
            layer_sizes_branch, layer_sizes_trunk, activation, "Glorot normal"
        )

        self.branch_net = dde.nn.pytorch.fnn.FNN(
            layer_sizes_branch, activation, "Glorot normal"
        )
        self.trunk_net = dde.nn.pytorch.fnn.FNN(
            layer_sizes_trunk, activation, "Glorot normal"
        )

    def forward(self, inputs):
        # Both nets receive the same inputs; branch/trunk outputs are merged
        x_func = self.branch_net(inputs)
        x_loc = self.trunk_net(inputs)
        if self._output_transform is not None:
            return self._output_transform(self.merge_branch_trunk(x_func, x_loc, -1))
        return self.merge_branch_trunk(x_func, x_loc, -1)


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Two 2x downsamples (encoder, middle), then two 2x upsamples (decoder)
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )

    def forward(self, x):
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        x3 = self.decoder(x2)
        return x3
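A quick shape check for `UNet` (a sketch; 64x64 is an arbitrary size chosen to be divisible by 4 so the two 2x downsamples invert cleanly):

```
import torch
from models import UNet

net = UNet()
x = torch.randn(2, 4, 64, 64)  # (batch, [u_wind, v_wind, sst, air_temp], H, W)
print(net(x).shape)            # torch.Size([2, 1, 64, 64])
```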
26 changes: 26 additions & 0 deletions src/pdes.py
@@ -0,0 +1,26 @@
import torch
import deepxde as dde


# NOTE: the *_mean terms below are module-level globals; they must be set from
# the dataset means (the "{var}_mean" entries returned by prepare_data_for_pinn)
# before pde() is called.
def pde(x, y):
    lat, lon, t = x[:, 0:1], x[:, 1:2], x[:, 2:3]
    d2U_dlat2 = dde.grad.hessian(y, x, component=0, i=0, j=0)
    d2U_dlon2 = dde.grad.hessian(y, x, component=0, i=1, j=1)
    d2U_dt2 = dde.grad.hessian(y, x, component=0, i=2, j=2)

    # Source term: a synthetic spatial/seasonal signal plus a weighted sum of
    # the environmental-driver means
    rho = (
        0.1 * torch.sin(lat) * torch.cos(lon) * torch.exp(-0.1 * t)
        + 0.05 * torch.sin(2 * torch.pi * t / 365)
        + (
            0.5 * air_temp_mean
            + -1.0 * sst_mean
            + 0.05 * curr_dir_mean
            + 0.15 * ug_curr_mean
            + 0.4 * u_wind_mean
            + -0.2 * v_wind_mean
            + 0.3 * v_curr_mean
        )
    )

    residual = d2U_dlat2 + d2U_dlon2 + d2U_dt2 - rho
    return residual
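One way to avoid the module-level globals is to bind the means via a closure. A minimal sketch (`make_pde` is a hypothetical helper; `means` is assumed to be the dict returned by `prepare_data_for_pinn`, which holds the "{var}_mean" entries):

```
import torch
import deepxde as dde


def make_pde(means):
    """Return a pde(x, y) residual with the dataset means bound via closure."""
    # Same weighted sum of driver means as above, computed once up front
    forcing = (
        0.5 * means["air_temp_mean"] - 1.0 * means["sst_mean"]
        + 0.05 * means["curr_dir_mean"] + 0.15 * means["ug_curr_mean"]
        + 0.4 * means["u_wind_mean"] - 0.2 * means["v_wind_mean"]
        + 0.3 * means["v_curr_mean"]
    )

    def pde(x, y):
        lat, lon, t = x[:, 0:1], x[:, 1:2], x[:, 2:3]
        d2U_dlat2 = dde.grad.hessian(y, x, component=0, i=0, j=0)
        d2U_dlon2 = dde.grad.hessian(y, x, component=0, i=1, j=1)
        d2U_dt2 = dde.grad.hessian(y, x, component=0, i=2, j=2)
        rho = (
            0.1 * torch.sin(lat) * torch.cos(lon) * torch.exp(-0.1 * t)
            + 0.05 * torch.sin(2 * torch.pi * t / 365)
            + forcing
        )
        return d2U_dlat2 + d2U_dlon2 + d2U_dt2 - rho

    return pde
```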
63 changes: 63 additions & 0 deletions src/training.py
@@ -0,0 +1,63 @@
import numpy as np
import torch
from torch import nn, optim

from models import UNet


# Train the UNet model
def train_unet(data, epochs=100, batch_size=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    features = ["u_wind", "v_wind", "sst", "air_temp"]
    X = np.stack([data[feature] for feature in features], axis=1)
    y = data["CHL_cmes-level3"]

    # Debug: Print the shapes of X and y before reshaping
    print(f"Original X shape: {X.shape}")
    print(f"Original y shape: {y.shape}")

    num_elements = X.shape[2]
    nearest_square = int(np.floor(np.sqrt(num_elements)) ** 2)
    height = int(np.sqrt(nearest_square))
    width = height

    # Trim X and y to the nearest perfect square
    X = X[:, :, :nearest_square]
    y = y[:, :nearest_square]

    # Reshape X and y to match the expected input shape for UNet
    num_samples = X.shape[0]
    num_features = len(features)

    X = X.reshape(num_samples, num_features, height, width)
    y = y.reshape(num_samples, 1, height, width)

    # Debug: Print the shapes of X and y after reshaping
    print(f"Reshaped X shape: {X.shape}")
    print(f"Reshaped y shape: {y.shape}")

    X = torch.tensor(X, dtype=torch.float32).to(device)
    y = torch.tensor(y, dtype=torch.float32).to(device)

    # DataLoader shuffling expects a CPU generator
    generator = torch.Generator()
    dataset = torch.utils.data.TensorDataset(X, y)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, generator=generator
    )

    for epoch in range(epochs):
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            # Resize the outputs to match the target size
            outputs = nn.functional.interpolate(
                outputs, size=(height, width), mode="bilinear", align_corners=False
            )
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
    # was save_model(...), which was undefined; torch.save is the stock way
    torch.save(model.state_dict(), "unet_model.pth")
    return model
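A minimal usage sketch, assuming `data` comes from `prepare_data_for_pinn()` in `src/data_utils.py`:

```
from data_utils import load_and_preprocess_data, prepare_data_for_pinn
from training import train_unet

zarr_ds = load_and_preprocess_data()
data, *_ = prepare_data_for_pinn(zarr_ds)
model = train_unet(data, epochs=10, batch_size=16)
```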