Skip to content
/ DOFA Public

Code for Neural Plasticity-Inspired Foundation Model for Observing the Earth Crossing Modalities

License

Notifications You must be signed in to change notification settings

zhu-xlab/DOFA

Repository files navigation

🚨 !!! A new version of DOFA weights has been released !!!

🚨 Please use the new version of codes and weights. The performance is much better!

🚨 If you are using the new version of weights, please make sure that the new version of wave_dynamic_layer.py is used.


Dynamic One-For-All foundation model for Remote sensing and Earth observation

What is DOFA: DOFA is a unified multimodal foundation model for different data modalities in remote sensing and Earth observation.

Differences with existing foundation models: DOFA is pre-trained using five different data modalities in remote sensing and Earth observation. It can handle images with any number of input channels.

DOFA is inspired by neuroplasticity Neuroplasticity is an important brain mechanism for adjusting to new experiences or environmental shifts. Inspired by this concept, we design DOFA to emulate this mechanism for processing multimodal EO data.

For more details, please take a look at the paper Neural Plasticity-Inspired Foundation Model for Observing the Earth Crossing Modalities.

Why develop DOFA

  • The learned multimodal representation may not effectively capture such an intersensor relationship.

  • The performance of foundation models will degrade when downstream tasks require the utilization of data from unseen sensors with varying numbers of spectral bands and spatial resolutions or different wavelength regimes.

  • The development of individual, customized foundation models requires considerably more computing resources and human efforts.

  • The increasing number of specialized foundation models makes it difficult to select the most appropriate one for a specific downstream task.

Installation

Requirements

The requirements of DOFA can be installed as follows:

> pip install -r requirements.txt

Weights

Pre-trained model weights can be downloaded from HuggingFace.

Usage

Please refer to demo.ipynb for more details.

DOFA supports input images with any number of channels using our pre-trained foundation models. The following examples show how to use DOFA for Sentinel-1 (SAR), Sentinel-2, NAIP RGB. We will add example usage for Gaofen Multispectral, and Hyperspectral data soon.


The following examples show that how to use a single DOFA model to process image data from different modalities with any number of channels!


Download the pre-trained weights for DOFA from huggingface

python download_weights.py

Load the pre-trained weights of DOFA base model

from models_dwv import vit_base_patch16

check_point = torch.load('./checkpoints/DOFA_ViT_base_e100.pth')
vit_model = vit_base_patch16()
msg = vit_model.load_state_dict(check_point, strict=False)
# missing_keys=['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'norm.weight', 'norm.bias', 'projector.weight', 'projector.bias']
vit_model = vit_model.cuda()

Now you can use the loaded single DOFA model to process image data from different modalities with any number of channels!

Preprare for the data loading and preprocessing

# Step 1: Data preprocessing (normalization and resize)

import torch
import rasterio
import kornia as K
import numpy as np
# vh,vv
S1_MEAN = [166.36275909, 88.45542715]# / 255.0
S1_STD = [64.83126309, 43.07350145]# /255.0

S2_MEAN = [114.1099739 , 114.81779093, 126.63977424,  84.33539309,
        97.84789168, 103.94461911, 101.435633  ,  72.32804172,
        56.66528851]
S2_STD = [77.84352553, 69.96844919, 67.42465279, 64.57022983, 61.72545487,
       61.34187099, 60.29744676, 47.88519516, 42.55886798]

NAIP_MEAN = [123.675, 116.28, 103.53] # ImageNet stats for now
NAIP_STD = [58.395, 57.12, 57.375] # ImageNet stats for now

Gaufen_MEAN = [123.94924583,  92.58088583,  97.28130189,  90.31526596]
Gaufen_STD = [67.34487297, 62.8271046 , 60.5856767 , 60.3946299]
class DataAugmentation(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.transform = torch.nn.Sequential(
            K.augmentation.RandomResizedCrop(size=(224,224), scale=(0.2,1.0)),
            K.augmentation.Normalize(mean=mean,std=std)
        )
    @torch.no_grad()
    def forward(self,x):
        x_out = self.transform(x)
        return x_out

Load Sentinel-1 data with 2 channels

transform = DataAugmentation(mean=S1_MEAN,std=S1_STD)

def preprocess_s1(vh_path, vv_path):
    with rasterio.open(vh_path) as f1:
        vh = f1.read()
    with rasterio.open(vv_path) as f2:
        vv = f2.read()
    s1_img = np.concatenate((vh,vv),0).astype('float32')
    s1_img = torch.from_numpy(s1_img)
    s1_img = transform(s1_img).squeeze(0)
    return s1_img
import matplotlib.pyplot as plt
# Load Sentinel-1 images from the given example data
C = 2  # can be 2,3,4,6,9,12,13,202 or any number if you can provide the wavelengths of them

image1 = './data/s1/vv/1869_3575.png'
image2 = './data/s1/vh/1869_3575.png'
s1_img = preprocess_s1(image1,image2)

fig, ax = plt.subplots(nrows=1, ncols=C, figsize=(10, 10))

for i,row in enumerate(ax):
    row.imshow(s1_img[i])

s1_img = s1_img.view([1,2,224,224]).cuda()

png

wavelengths = [3.75, 3.75]
out_feat = vit_model.forward_features(s1_img, wave_list=wavelengths)
out_logits = vit_model.forward(s1_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)

Load Sentinel-2 data with 9 channels

import glob

C = 9

transform = DataAugmentation(mean=S2_MEAN,std=S2_STD)

def preprocess_s2(img_path):
    chs = []
    s2_files = glob.glob(f"{img_path}/*/*.png")
    for path in s2_files:
        with rasterio.open(path) as f:
            ch = f.read()
        chs.append(ch)
    s2_img = np.concatenate(chs, 0).astype("float32")
    s2_img = torch.from_numpy(s2_img)
    s2_img = transform(s2_img).squeeze(0)
    return s2_img

Visualize the Sentinel-2 imagery

s2_img = preprocess_s2("data/s2/S2A_MSIL1C_20170528T050611_N0205_R076_T44NMM_20170528T050606")
fig, ax = plt.subplots(nrows=1, ncols=C, figsize=(10, 10))

for i,row in enumerate(ax):
    row.imshow(s2_img[i])
    row.axis("off")
s2_img = s2_img.view([1,C,224,224]).cuda()
/opt/miniconda3/lib/python3.11/site-packages/rasterio/__init__.py:304: NotGeoreferencedWarning: Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned.
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)

png

We use the same DOFA model to inference the Sentinel-2 image

wavelengths = [0.665, 0.56, 0.49, 0.705, 0.74, 0.783, 0.842, 1.61, 2.19]

out_feat = vit_model.forward_features(s2_img, wave_list=wavelengths)
out_logits = vit_model.forward(s2_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)

What if I only want to use a subset of Sentinel-2 data?

# Let's only keep the first 5 channels
wavelengths = [0.665, 0.56, 0.49, 0.705, 0.74]

out_feat = vit_model.forward_features(s2_img[:,:5,...], wave_list=wavelengths)
out_logits = vit_model.forward(s2_img[:,:5,...], wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)

Usage for RGB optical data

C = 3

transform = DataAugmentation(mean=NAIP_MEAN, std=NAIP_STD)


def preprocess_rgb(img_path):
    with rasterio.open(img_path) as f:
        rgb_img = f.read().astype("float32")
    rgb_img = torch.from_numpy(rgb_img)
    rgb_img = transform(rgb_img).squeeze(0)
    return rgb_img
import cv2

rgb_path = 'data/naip/36861_49963.png'
naip_img = preprocess_rgb(rgb_path)

plt.imshow(cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB))
plt.axis("off")

naip_img = naip_img.view([1,C,224,224]).cuda()

png

wavelengths = [0.665, 0.56, 0.49]

out_feat = vit_model.forward_features(naip_img, wave_list=wavelengths)
out_logits = vit_model.forward(naip_img, wave_list=wavelengths)
print(out_feat.shape)
print(out_logits.shape)

Usage for Hyperspectral images is similar to other images.

TorchGeo

Alternatively, DOFA can be used via the TorchGeo library:

import torch
from torchgeo.models import DOFABase16_Weights, dofa_base_patch16_224

# Example NAIP image (wavelengths in $\mu$m)
x = torch.rand(2, 4, 224, 224)
wavelengths = [0.48, 0.56, 0.64, 0.81]

# Use pre-trained model weights
model = dofa_base_patch16_224(weights=DOFABase16_Weights.DOFA_MAE)

# Make a prediction (model may need to be fine-tuned first)
y = model(x, wavelengths)

If you find the DOFA useful in your research, please kindly cite our paper:

@article{xiong2024neural,
  title={Neural Plasticity-Inspired Foundation Model for Observing the {Earth} Crossing Modalities},
  author={Xiong, Zhitong and Wang, Yi and Zhang, Fahong and Stewart, Adam J and Hanna, Jo{\"e}lle and Borth, Damian and Papoutsis, Ioannis and Saux, Bertrand Le and Camps-Valls, Gustau and Zhu, Xiao Xiang},
  journal={arXiv preprint arXiv:2403.15356},
  year={2024}
}