diff --git a/projects/super_res/config.py b/projects/super_res/config.py index 80ebd24c08..4d8aadc37c 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -1,6 +1,5 @@ from ml_collections import config_dict -#batch_size = 4 config = config_dict.ConfigDict() config.dim = 64 @@ -9,34 +8,36 @@ config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 6 -config.loss = "l1" +config.sampling_steps = 20 +config.loss = "l2" config.objective = "pred_v" -config.lr = 8e-5 -config.steps = 5000000 +config.lr = 1e-4 +config.steps = 700000 config.grad_acc = 1 -config.val_num_of_batch = 1 -config.save_and_sample_every = 5000 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "" +config.additional_note = "2d-nomulti-nols-ensemble" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" config.milestone = 1 +config.rollout = None +config.rollout_batch = None config.batch_size = 1 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, - #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], "channels": ["PRATEsfc_coarse"], - #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": True, - "quick": True + "logscale": False, + "multi": False, + "flow": "2d", + "minipatch": False }) config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" diff --git a/projects/super_res/config_focal.py b/projects/super_res/config_focal.py new file mode 100644 index 0000000000..fb4a988761 --- /dev/null +++ b/projects/super_res/config_focal.py @@ -0,0 +1,44 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.dim = 64 +config.dim_mults = (1, 1, 2, 2, 3, 4) +config.learned_sinusoidal_cond = True, +config.random_fourier_features = True, +config.learned_sinusoidal_dim = 32 +config.diffusion_steps = 1500 +config.sampling_steps = 20 +config.loss = "focal" +config.objective = "pred_v" +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "2d-multi-ls-focal-ensemble" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + "channels": ["PRATEsfc_coarse"], + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index 109103ceec..2520b3a1ca 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -2,41 +2,43 @@ 
config = config_dict.ConfigDict() - config.dim = 64 -config.dim_mults = (1, 1, 2, 2, 4, 4) +config.dim_mults = (1, 1, 2, 2, 3, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 6 -config.loss = "l1" +config.sampling_steps = 20 +config.loss = "l2" config.objective = "pred_v" -config.lr = 8e-5 -config.steps = 5000000 -config.grad_acc = 2 +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 config.val_num_of_batch = 5 -config.save_and_sample_every = 5000 +config.save_and_sample_every = 20000 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "" +config.additional_note = "2d-nomulti-nols-ensemble" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" -config.milestone = 1 +config.milestone = 2 +config.rollout = "partial" +config.rollout_batch = 25 -config.batch_size = 4 +config.batch_size = 2 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, - #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], "channels": ["PRATEsfc_coarse"], - #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": True + "logscale": False, + "multi": False, + "flow": "2d", + "minipatch": False }) -data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" -model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_isr.py b/projects/super_res/config_isr.py new file mode 100644 index 0000000000..9100602f51 --- /dev/null +++ b/projects/super_res/config_isr.py @@ -0,0 +1,37 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "isr" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_isr_infer.py b/projects/super_res/config_isr_infer.py new file mode 100644 index 0000000000..54285da702 --- /dev/null +++ 
b/projects/super_res/config_isr_infer.py @@ -0,0 +1,37 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "isr" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 2 +config.rollout = 'partial' +config.rollout_batch = 25 + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_rvrt_full.py b/projects/super_res/config_rvrt_full.py new file mode 100644 index 0000000000..f68fff7776 --- /dev/null +++ b/projects/super_res/config_rvrt_full.py @@ -0,0 +1,50 @@ +from ml_collections import config_dict + +#batch_size = 4 +config = config_dict.ConfigDict() + +config.dim = 120 +config.num_blocks = 6 +config.num_heads = 8 +config.depth = 8 +config.time_emb_dim = 32 +config.learned_sinusoidal_cond = True +config.diffusion_steps = 1500 +config.sampling_steps = 20 +# config.loss = "l2" +config.loss = "charbonnier" +config.objective = "pred_x0" +# config.lr = 8e-5 +config.lr = 1e-4 +# config.steps = 500000 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.999 +config.amp = False +config.split_batches = True +config.additional_note = "rvrt_full" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 6, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_rvrt_full_infer.py b/projects/super_res/config_rvrt_full_infer.py new file mode 100644 index 0000000000..12ab1a68f2 --- /dev/null +++ b/projects/super_res/config_rvrt_full_infer.py @@ -0,0 +1,50 @@ +from ml_collections import config_dict + +#batch_size = 4 +config = config_dict.ConfigDict() + +config.dim = 120 +config.num_blocks = 6 +config.num_heads = 8 +config.depth = 8 +config.time_emb_dim = 32 +config.learned_sinusoidal_cond = True +config.diffusion_steps = 1500 +config.sampling_steps = 20 +# config.loss = "l2" +config.loss = "charbonnier" +config.objective = "pred_x0" +# config.lr = 8e-5 +config.lr = 1e-4 +# config.steps = 500000 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every 
= 20000 +config.ema_decay = 0.999 +config.amp = False +config.split_batches = True +config.additional_note = "rvrt_full" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 2 +config.rollout = 'partial' +config.rollout_batch = 22 + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 6, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/data/channel_data_gen.py b/projects/super_res/data/channel_data_gen.py new file mode 100644 index 0000000000..145fc634eb --- /dev/null +++ b/projects/super_res/data/channel_data_gen.py @@ -0,0 +1,29 @@ +import xarray as xr +import numpy as np +from pathlib import Path + +channel_folder = Path('./more_channels') +channel_folder.mkdir(exist_ok = True, parents = True) + +c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) +c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + +channels = ["UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse"] +c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +split = int(c384_np.shape[1] * 0.8) + +# compute statistics on training set +c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(axis=(0,1,3,4)).reshape(1,1,4,1,1), c384_np[:, :split, :, :, :].max(axis=(0,1,3,4)).reshape(1,1,4,1,1), c48_np[:, :split, :, :, :].min(axis=(0,1,3,4)).reshape(1,1,4,1,1), c48_np[:, :split, :, :, :].max(axis=(0,1,3,4)).reshape(1,1,4,1,1) + +# normalize +c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +np.save('more_channels/c384_min.npy', c384_min) +np.save('more_channels/c384_max.npy', c384_max) +np.save('more_channels/c48_min.npy', c48_min) +np.save('more_channels/c48_max.npy', c48_max) +np.save('more_channels/c48_norm.npy', c48_norm) +np.save('more_channels/c384_norm.npy', c384_norm) \ No newline at end of file diff --git a/projects/super_res/data/dataload.sh b/projects/super_res/data/dataload.sh new file mode 100755 index 0000000000..55f69df969 --- /dev/null +++ b/projects/super_res/data/dataload.sh @@ -0,0 +1,13 @@ +#! 
/bin/sh +channel='c48_atmos_ave' +file='atmos_8xdaily_ave_coarse.zarr' +for member in $(seq -f "%04g" 1 11) +do + mkdir -p /data/prakhars/ensemble/$channel/$member + gsutil -m cp -r gs://vcm-ml-raw-flexible-retention/2023-08-14-C384-reference-ensemble/ic_$member/diagnostics/$file /data/prakhars/ensemble/$channel/$member +done +# channel --> file +# c384_precip_ave --> sfc_8xdaily_ave.zarr +# c48_precip_plus_more_ave --> sfc_8xdaily_ave_coarse.zarr +# c384_topo --> atmos_static.zarr +# c48_atmos_ave --> atmos_8xdaily_ave_coarse.zarr \ No newline at end of file diff --git a/projects/super_res/data/ensemble_c384_trainstats/chl.pkl b/projects/super_res/data/ensemble_c384_trainstats/chl.pkl new file mode 100644 index 0000000000..fa1744dbe2 Binary files /dev/null and b/projects/super_res/data/ensemble_c384_trainstats/chl.pkl differ diff --git a/projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl b/projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl new file mode 100644 index 0000000000..de65b8ebb4 Binary files /dev/null and b/projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl differ diff --git a/projects/super_res/data/ensemble_c384_trainstats/topo.pkl b/projects/super_res/data/ensemble_c384_trainstats/topo.pkl new file mode 100644 index 0000000000..7d0af0072a Binary files /dev/null and b/projects/super_res/data/ensemble_c384_trainstats/topo.pkl differ diff --git a/projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl b/projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl new file mode 100644 index 0000000000..e570cc3169 Binary files /dev/null and b/projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl differ diff --git a/projects/super_res/data/ensemble_c48_trainstats/chl.pkl b/projects/super_res/data/ensemble_c48_trainstats/chl.pkl new file mode 100644 index 0000000000..bb1926631a Binary files /dev/null and b/projects/super_res/data/ensemble_c48_trainstats/chl.pkl differ diff --git a/projects/super_res/data/ensemble_c48_trainstats/log_chl.pkl b/projects/super_res/data/ensemble_c48_trainstats/log_chl.pkl new file mode 100644 index 0000000000..aded7a514b Binary files /dev/null and b/projects/super_res/data/ensemble_c48_trainstats/log_chl.pkl differ diff --git a/projects/super_res/data/ensemblec384logtrainstats.py b/projects/super_res/data/ensemblec384logtrainstats.py new file mode 100644 index 0000000000..eaf206573e --- /dev/null +++ b/projects/super_res/data/ensemblec384logtrainstats.py @@ -0,0 +1,20 @@ +import pickle +import numpy as np +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +# load the data +with open('ensemble_c384_trainstats/chl.pkl', 'rb') as f: + chl = pickle.load(f) + +precip = chl['PRATEsfc'] +log_chl = {} +log_chl['PRATEsfc'] = {} +log_chl['PRATEsfc']['min'] = np.log(precip['min'] - precip['min'] + 1e-14) +log_chl['PRATEsfc']['max'] = np.log(precip['max'] - precip['min'] + 1e-14) + +# save the chl dictionary as pickle +with open(precip_folder / 'log_chl.pkl', 'wb') as f: + pickle.dump(log_chl, f) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec384topotrainstats.py b/projects/super_res/data/ensemblec384topotrainstats.py new file mode 100644 index 0000000000..11b11aba35 --- /dev/null +++ b/projects/super_res/data/ensemblec384topotrainstats.py @@ -0,0 +1,36 @@ +import pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') 
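+# accumulate the global min/max of C384 topography (zsurf) across ensemble members; written to topo.pkl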
+precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["zsurf"] +chl = {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + topo = xr.open_zarr(f"/extra/ucibdl0/shared/data/fv3gfs/c384_topo/{member:04d}/atmos_static.zarr") + + for channel in tqdm(channels): + channel_384 = topo[channel] + channel_384_min = channel_384.min().values + channel_384_max = channel_384.max().values + if channel_384_min < chl[channel]['min']: + chl[channel]['min'] = channel_384_min + if channel_384_max > chl[channel]['max']: + chl[channel]['max'] = channel_384_max + +# save the chl dictionary as pickle +with open(precip_folder / 'topo.pkl', 'wb') as f: + pickle.dump(chl, f) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec384trainstats.py b/projects/super_res/data/ensemblec384trainstats.py new file mode 100644 index 0000000000..8039c87ec8 --- /dev/null +++ b/projects/super_res/data/ensemblec384trainstats.py @@ -0,0 +1,45 @@ +import pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["PRATEsfc"] +chl = {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + c384 = xr.open_zarr(f"/extra/ucibdl0/shared/data/fv3gfs/c384_precip_ave/{member:04d}/sfc_8xdaily_ave.zarr") + + for channel in tqdm(channels): + + channel_384 = c384[channel] + + for idx in tqdm(range(397)): + + channel_384_slice = channel_384.isel(time = slice(idx*8, (idx+1)*8)) + channel_384_max = channel_384_slice.max().values + channel_384_min = channel_384_slice.min().values + + if channel_384_min < chl[channel]['min']: + + chl[channel]['min'] = channel_384_min + + if channel_384_max > chl[channel]['max']: + + chl[channel]['max'] = channel_384_max + +# save the chl dictionary as pickle +with open(precip_folder / 'chl.pkl', 'wb') as f: + pickle.dump(chl, f) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec48logtrainstats.py b/projects/super_res/data/ensemblec48logtrainstats.py new file mode 100644 index 0000000000..f5a66ffd61 --- /dev/null +++ b/projects/super_res/data/ensemblec48logtrainstats.py @@ -0,0 +1,20 @@ +import pickle +import numpy as np +from pathlib import Path + +precip_folder = Path('./ensemble_c48_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +# load the data +with open('ensemble_c48_trainstats/chl.pkl', 'rb') as f: + chl = pickle.load(f) + +precip = chl['PRATEsfc_coarse'] +log_chl = {} +log_chl['PRATEsfc_coarse'] = {} +log_chl['PRATEsfc_coarse']['min'] = np.log(precip['min'] - precip['min'] + 1e-14) +log_chl['PRATEsfc_coarse']['max'] = np.log(precip['max'] - precip['min'] + 1e-14) + +# save the chl dictionary as pickle +with open(precip_folder / 'log_chl.pkl', 'wb') as f: + pickle.dump(log_chl, f) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec48trainstats.py b/projects/super_res/data/ensemblec48trainstats.py new file mode 100644 index 0000000000..4b8b42e981 --- /dev/null +++ b/projects/super_res/data/ensemblec48trainstats.py @@ -0,0 +1,58 @@ +import pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c48_trainstats') 
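+# accumulate per-channel min/max over all ensemble members, for both the surface (chl.pkl) and atmosphere (atm_chl.pkl) zarr stores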
+precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] +atm_channels = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + +chl, atm_chl = {}, {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for channel in atm_channels: + + atm_chl[channel] = {} + atm_chl[channel]['min'] = np.PINF + atm_chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + c48 = xr.open_zarr(f"/extra/ucibdl0/shared/data/fv3gfs/c48_precip_plus_more_ave/{member:04d}/sfc_8xdaily_ave_coarse.zarr") + c48_atm = xr.open_zarr(f"/extra/ucibdl0/shared/data/fv3gfs/c48_atmos_ave/{member:04d}/atmos_8xdaily_ave_coarse.zarr") + + for channel in tqdm(channels): + channel_48 = c48[channel] + channel_48_min = channel_48.min().values + channel_48_max = channel_48.max().values + if channel_48_min < chl[channel]['min']: + chl[channel]['min'] = channel_48_min + if channel_48_max > chl[channel]['max']: + chl[channel]['max'] = channel_48_max + + for channel in tqdm(atm_channels): + channel_48 = c48_atm[channel] + channel_48_min = channel_48.min().values + channel_48_max = channel_48.max().values + if channel_48_min < atm_chl[channel]['min']: + atm_chl[channel]['min'] = channel_48_min + if channel_48_max > atm_chl[channel]['max']: + atm_chl[channel]['max'] = channel_48_max + +# save the chl dictionary as pickle +with open(precip_folder / 'chl.pkl', 'wb') as f: + pickle.dump(chl, f) + +# save the atm_chl dictionary as pickle +with open(precip_folder / 'atm_chl.pkl', 'wb') as f: + pickle.dump(atm_chl, f) \ No newline at end of file diff --git a/projects/super_res/data/load_data.py b/projects/super_res/data/load_data.py index dc6c6bc21f..312d872367 100644 --- a/projects/super_res/data/load_data.py +++ b/projects/super_res/data/load_data.py @@ -13,15 +13,15 @@ def load_data(data_config, batch_size, num_workers = 4, pin_memory = True): train, batch_size = batch_size, shuffle = True, - num_workers = num_workers, + num_workers = 2, pin_memory = pin_memory, ) val = DataLoader( val, - batch_size = batch_size, + batch_size = 5, shuffle = False, - num_workers = num_workers, + num_workers = 2, pin_memory = pin_memory, ) diff --git a/projects/super_res/data/load_dataset.py b/projects/super_res/data/load_dataset.py index 6678d27bdd..c4c74cb10f 100644 --- a/projects/super_res/data/load_dataset.py +++ b/projects/super_res/data/load_dataset.py @@ -1,15 +1,16 @@ -from .vsrdata import VSRDataset +# from .vsrdata import VSRDataset +# from .vsrdata_new import VSRDataset +from .vsrdata_ensemble import VSRDataset def load_dataset(data_config): - channels = data_config["channels"] length = data_config["length"] logscale = data_config["logscale"] - quick = data_config["quick"] + multi = data_config["multi"] train, val = None, None - train = VSRDataset(channels, 'train', length, logscale, quick) - val = VSRDataset(channels, 'val', length, logscale, quick) + train = VSRDataset('train', length, logscale, multi) + val = VSRDataset('val', length, logscale, multi) return train, val \ No newline at end of file diff --git a/projects/super_res/data/precip_data_gen.py b/projects/super_res/data/precip_data_gen.py new file mode 100644 index 0000000000..2d8c1709a3 --- /dev/null +++ b/projects/super_res/data/precip_data_gen.py @@ -0,0 +1,50 @@ +import xarray as xr +import numpy 
as np +from pathlib import Path + +precip_folder = Path('./only_precip') +precip_folder.mkdir(exist_ok = True, parents = True) + +c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) +c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + +channels = ["PRATEsfc_coarse"] +c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +np.save('only_precip/c384_gmin.npy', c384_np.min()) +np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# calculate split (80/20) +split = int(c384_np.shape[1] * 0.8) + +# compute statistics on training set +c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# normalize +c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +np.save('only_precip/c384_min.npy', c384_min) +np.save('only_precip/c384_max.npy', c384_max) +np.save('only_precip/c48_min.npy', c48_min) +np.save('only_precip/c48_max.npy', c48_max) +np.save('only_precip/c48_norm.npy', c48_norm) +np.save('only_precip/c384_norm.npy', c384_norm) + +c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# compute statistics on training set +c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# normalize +c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +np.save('only_precip/c384_lgmin.npy', c384_lmin) +np.save('only_precip/c384_lgmax.npy', c384_lmax) +np.save('only_precip/c48_lgmin.npy', c48_lmin) +np.save('only_precip/c48_lgmax.npy', c48_lmax) +np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/topo_data_gen.py b/projects/super_res/data/topo_data_gen.py new file mode 100644 index 0000000000..531fd3fbfe --- /dev/null +++ b/projects/super_res/data/topo_data_gen.py @@ -0,0 +1,112 @@ +import xarray as xr +import numpy as np +from typing import TypeVar, Union, Tuple, Hashable, Any, Callable +from pathlib import Path + +topo_folder = Path('./topography') +topo_folder.mkdir(exist_ok = True, parents = True) + +topo384 = xr.open_zarr('gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_static_coarse.zarr') + +wts = xr.open_zarr('gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/grid_spec_coarse.zarr') + +T_DataArray_or_Dataset = TypeVar("T_DataArray_or_Dataset", xr.DataArray, xr.Dataset) +CoordFunc = Callable[[Any, Union[int, Tuple[int]]], Any] + +def coarsen_coords_coord_func( + coordinate: np.ndarray, axis: Union[int, Tuple[int]] = -1 +) -> np.ndarray: + """xarray coarsen coord_func version of coarsen_coords. + + Note that xarray requires an axis argument for this to work, but it is not + used by this function. 
To coarsen dimension coordinates, xarray reshapes
+    the 1D coordinate into a 2D array, with the rows representing groups of
+    values to aggregate together in some way. The length of the rows
+    corresponds to the coarsening factor. The value of the coordinate sampled
+    every coarsening factor is just the first value in each row.
+
+    Args:
+        coordinate: 2D array of coordinate values
+        axis: Axes to reduce along (not used)
+
+    Returns:
+        np.array
+    """
+    return (
+        ((coordinate[:, 0] - 1) // coordinate.shape[1] + 1)
+        .astype(int)
+        .astype(np.float32)
+    )
+
+def _propagate_attrs(
+    reference_obj: T_DataArray_or_Dataset, obj: T_DataArray_or_Dataset
+) -> T_DataArray_or_Dataset:
+    """Propagate attributes from the reference object to another.
+
+    Args:
+        reference_obj: input object
+        obj: output object
+
+    Returns:
+        xr.DataArray or xr.Dataset
+    """
+    if isinstance(reference_obj, xr.Dataset):
+        for variable in reference_obj:
+            obj[variable].attrs = reference_obj[variable].attrs
+    obj.attrs = reference_obj.attrs
+    return obj
+
+
+def weighted_block_average(
+    obj: T_DataArray_or_Dataset,
+    weights: xr.DataArray,
+    coarsening_factor: int,
+    x_dim: Hashable = "xaxis_1",
+    y_dim: Hashable = "yaxis_2",
+    coord_func: Union[str, CoordFunc] = coarsen_coords_coord_func,
+) -> T_DataArray_or_Dataset:
+    """Coarsen a DataArray or Dataset through weighted block averaging.
+
+    Note that this function assumes that the x and y dimension names of the
+    input DataArray and weights are the same.
+
+    Args:
+        obj: Input Dataset or DataArray.
+        weights: Weights (e.g. area or pressure thickness).
+        coarsening_factor: Integer coarsening factor to use.
+        x_dim: x dimension name (default 'xaxis_1').
+        y_dim: y dimension name (default 'yaxis_2').
+        coord_func: function that is applied to the coordinates, or a
+            mapping from coordinate name to function. See xarray's coarsen
+            method for details.
+
+    Returns:
+        xr.Dataset or xr.DataArray.
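+
+    Example (mirrors the call made later in this script):
+        >>> topo48 = weighted_block_average(
+        ...     topo384, wts['area_coarse'], 8, 'grid_xt_coarse', 'grid_yt_coarse'
+        ... )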
+ """ + coarsen_kwargs = {x_dim: coarsening_factor, y_dim: coarsening_factor} + numerator = (obj * weights).coarsen(coarsen_kwargs, coord_func=coord_func).sum() # type: ignore # noqa + denominator = weights.coarsen(coarsen_kwargs, coord_func=coord_func).sum() # type: ignore # noqa + result = numerator / denominator + + if isinstance(obj, xr.DataArray): + result = result.rename(obj.name) + + return _propagate_attrs(obj, result) + +topo48 = weighted_block_average(topo384, wts['area_coarse'], 8, 'grid_xt_coarse', 'grid_yt_coarse') + +topo384 = topo384['zsurf_coarse'].values +topo48 = topo48['zsurf_coarse'].values + +topo384_min, topo384_max, topo48_min, topo48_max = topo384.min(), topo384.max(), topo48.min(), topo48.max() + +topo384_norm = (topo384 - topo384_min) / (topo384_max - topo384_min) +topo48_norm = (topo48 - topo48_min) / (topo48_max - topo48_min) + +np.save('topography/topo384_norm.npy', topo384_norm) +np.save('topography/topo48_norm.npy', topo48_norm) +np.save('topography/topo384_min.npy', topo384_min) +np.save('topography/topo384_max.npy', topo384_max) +np.save('topography/topo48_min.npy', topo48_min) +np.save('topography/topo48_max.npy', topo48_max) \ No newline at end of file diff --git a/projects/super_res/data/vsrdata.py b/projects/super_res/data/vsrdata.py index 1dacbfc851..9a90a4153d 100644 --- a/projects/super_res/data/vsrdata.py +++ b/projects/super_res/data/vsrdata.py @@ -1,10 +1,9 @@ -import xarray as xr import numpy as np from torch.utils.data import Dataset class VSRDataset(Dataset): - def __init__(self, channels, mode, length, logscale = False, quick = True): + def __init__(self, mode, length, logscale = False, multi = False): ''' Args: channels (list): list of channels to use @@ -20,58 +19,43 @@ def __init__(self, channels, mode, length, logscale = False, quick = True): # mode self.mode = mode - if not quick: - # load data from bucket - # shape : (tile, time, y, x) - c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) - c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) - - # convert to numpy - # shape : (tile, time, channel, y, x) - c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) - c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) - - if logscale: - c384_np = np.log(c384_np - c384_np.min() + 1e-14) - c48_np = np.log(c48_np - c48_np.min() + 1e-14) - - # calculate split (80/20) - split = int(c384_np.shape[1] * 0.8) - - # compute statistics on training set - c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + # data shape : (num_tiles, num_frames, num_channels, height, width) + # num_tiles = 6; num_frames = 2920, num_channels = 1 + if logscale: - # normalize - c384_norm= (c384_np - c384_min) / (c384_max - c384_min) - c48_norm = (c48_np - c48_min) / (c48_max - c48_min) - - if mode == 'train': - - self.X = c48_norm[:, :split, :, :, :] - self.y = c384_norm[:, :split, :, :, :] - - elif mode == 'val': - - self.X = c48_norm[:, split:, :, :, :] - self.y = c384_norm[:, split:, :, :, :] + c384_norm= np.load("data/only_precip/c384_lgnorm.npy") + c48_norm = np.load("data/only_precip/c48_lgnorm.npy") else: c384_norm= np.load("data/only_precip/c384_norm.npy") 
c48_norm = np.load("data/only_precip/c48_norm.npy") + + t, f, c, h, w = c384_norm.shape + + if multi: - # calculate split (80/20) - split = int(c384_norm.shape[1] * 0.8) + # load more channels, order : ("UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse") + c48_norm_more = np.load("data/more_channels/c48_norm.npy") + c48_norm = np.concatenate((c48_norm, c48_norm_more), axis = 2) - if mode == 'train': - - self.X = c48_norm[:, :split, :, :, :] - self.y = c384_norm[:, :split, :, :, :] - - elif mode == 'val': - - self.X = c48_norm[:, split:, :, :, :] - self.y = c384_norm[:, split:, :, :, :] + # load topography, shape : (num_tiles, height, width) + # reshaping to match data shape + topo384 = np.repeat(np.load("data/topography/topo384_norm.npy").reshape((t, 1, c, 384, 384)), f, axis = 1) + c384_norm = np.concatenate((c384_norm, topo384), axis = 2) + + # calculate split (80/20) + split = int(c384_norm.shape[1] * 0.8) + + if mode == 'train': + + self.X = c48_norm[:, :split, :, :, :] + self.y = c384_norm[:, :split, :, :, :] + + elif mode == 'val': + + self.X = c48_norm[:, split:, :, :, :] + self.y = c384_norm[:, split:, :, :, :] def __len__(self): @@ -80,13 +64,13 @@ def __len__(self): def __getitem__(self, idx): # load a random tile index - if self.mode == 'train': tile = np.random.randint(0, self.X.shape[0]) elif self.mode == 'val': tile = 0 + # tensor shape : (length, num_channels, height, width) lowres = self.X[tile, idx:idx+self.length, :, :, :] highres = self.y[tile, idx:idx+self.length, :, :, :] diff --git a/projects/super_res/data/vsrdata_ensemble.py b/projects/super_res/data/vsrdata_ensemble.py new file mode 100644 index 0000000000..95083c9305 --- /dev/null +++ b/projects/super_res/data/vsrdata_ensemble.py @@ -0,0 +1,146 @@ +import torch +import pickle +import numpy as np +import xarray as xr +from torch.utils.data import Dataset + +class VSRDataset(Dataset): + + def __init__(self, mode, length, logscale = False, multi = False): + ''' + Args: + channels (list): list of channels to use + mode (str): train or val + length (int): length of sequence + logscale (bool): whether to logscale the data + multi (bool): whether to use multi-channel data + ''' + + ENSEMBLE = 11 + + # load data + self.X, self.X_, self.y, self.topo = {}, {}, {}, {} + + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + + for member in range(1, ENSEMBLE + 1): + + self.X[member] = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/{member:04d}/sfc_8xdaily_ave_coarse.zarr") + self.X_[member] = xr.open_zarr(f"{PATH}/c48_atmos_ave/{member:04d}/atmos_8xdaily_ave_coarse.zarr") + self.y[member] = xr.open_zarr(f"{PATH}/c384_precip_ave/{member:04d}/sfc_8xdaily_ave.zarr") + self.topo[member] = xr.open_zarr(f"{PATH}/c384_topo/{member:04d}/atmos_static.zarr") + + # expected sequence length + self.length = length + + self.mode = mode + self.logscale = logscale + self.multi = multi + + self.time_steps = self.X[1].time.shape[0] + self.tiles = self.X[1].tile.shape[0] + + # load statistics + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + self.c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + self.c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + self.c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + self.c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + self.c384_log_chl = pickle.load(f) + + with 
open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + self.c384_topo = pickle.load(f) + + if multi: + + self.c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + self.c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + self.c384_channels = ["PRATEsfc"] + + else: + + self.c48_channels = ["PRATEsfc_coarse"] + self.c384_channels = ["PRATEsfc"] + + self.indices = list(range(self.time_steps - self.length + 1)) + + def __len__(self): + + return len(self.indices) + + def __getitem__(self, idx): + + time_idx = self.indices[idx] + + if self.mode == 'train': + + np.random.seed() + tile = np.random.randint(self.tiles) + member = np.random.randint(10) + 1 + + else: + + tile = idx % self.tiles + member = 11 + + X = self.X[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + X_ = self.X_[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + y = self.y[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + + if self.multi: + + X = np.stack([X[channel].values for channel in self.c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in self.c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in self.c384_channels], axis = 1) + topo = self.topo[member].isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), self.length, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in self.c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in self.c384_channels], axis = 1) + + + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - self.c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - self.c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - self.c48_log_chl["PRATEsfc_coarse"]['min']) / (self.c48_log_chl["PRATEsfc_coarse"]['max'] - self.c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - self.c384_log_chl["PRATEsfc"]['min']) / (self.c384_log_chl["PRATEsfc"]['max'] - self.c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - self.c48_chl["PRATEsfc_coarse"]['min']) / (self.c48_chl["PRATEsfc_coarse"]['max'] - self.c48_chl["PRATEsfc_coarse"]['min']) + y = (y - self.c384_chl["PRATEsfc"]['min']) / (self.c384_chl["PRATEsfc"]['max'] - self.c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - self.c48_chl[self.c48_channels[i]]['min']) / (self.c48_chl[self.c48_channels[i]]['max'] - self.c48_chl[self.c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - self.c48_atm_chl[self.c48_channels_atmos[i]]['min']) / (self.c48_atm_chl[self.c48_channels_atmos[i]]['max'] - self.c48_atm_chl[self.c48_channels_atmos[i]]['min']) + + topo = (topo - self.c384_topo["zsurf"]['min']) / (self.c384_topo["zsurf"]['max'] - self.c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + return {'LR' : X, 'HR' : y} \ No newline at end of file diff --git a/projects/super_res/data/vsrdata_new.py b/projects/super_res/data/vsrdata_new.py new file mode 100644 index 0000000000..869b87c39e --- /dev/null +++ b/projects/super_res/data/vsrdata_new.py @@ -0,0 +1,107 @@ +import numpy as np +import xarray as xr +from torch.utils.data import Dataset + +class VSRDataset(Dataset): + + def __init__(self, mode, 
length, logscale = False, multi = False):
+        '''
+        Args:
+            mode (str): train or val
+            length (int): length of sequence
+            logscale (bool): whether to logscale the data
+            multi (bool): whether to load the additional C48 channels and topography
+        '''
+
+        # load data
+        self.y = xr.open_zarr("/data/prakhars/pire_atmos_phys_3h_c384.zarr")
+        self.X = xr.open_zarr('/data/prakhars/pire_atmos_phys_3h_c48.zarr')
+
+        # expected sequence length
+        self.length = length
+
+        # mode
+        self.mode = mode
+
+        self.logscale = logscale
+
+        if logscale:
+
+            self.c384_gmin = np.load('data/only_precip/c384_gmin.npy')
+            self.c48_gmin = np.load('data/only_precip/c48_gmin.npy')
+            self.c384_lgmin = np.load('data/only_precip/c384_lgmin.npy')
+            self.c384_lgmax = np.load('data/only_precip/c384_lgmax.npy')
+            self.c48_lgmin = np.load('data/only_precip/c48_lgmin.npy')
+            self.c48_lgmax = np.load('data/only_precip/c48_lgmax.npy')
+
+        else:
+
+            self.c384_min = np.load('data/only_precip/c384_min.npy')
+            self.c384_max = np.load('data/only_precip/c384_max.npy')
+            self.c48_min = np.load('data/only_precip/c48_min.npy')
+            self.c48_max = np.load('data/only_precip/c48_max.npy')
+
+        self.time_steps = self.X.time.shape[0]
+        self.tiles = self.X.tile.shape[0]
+
+        self.multi = multi
+
+        if multi:
+
+            self.channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse"]
+            self.topo384 = np.load("data/topography/topo384_norm.npy")
+            self.c384_multimin = np.load('data/more_channels/c384_min.npy')
+            self.c384_multimax = np.load('data/more_channels/c384_max.npy')
+            self.c48_multimin = np.load('data/more_channels/c48_min.npy')
+            self.c48_multimax = np.load('data/more_channels/c48_max.npy')
+
+        else:
+
+            self.channels = ["PRATEsfc_coarse"]
+
+        if mode == 'train':
+
+            self.indices = list(range(int(self.time_steps * 0.8) - self.length + 1))
+
+        elif mode == 'val':
+
+            self.indices = list(range(int(self.time_steps * 0.8), self.time_steps - self.length + 1))
+
+    def __len__(self):
+
+        return len(self.indices)
+
+    def __getitem__(self, idx):
+
+        time_idx = self.indices[idx]
+        if self.mode == 'train':
+            tile = idx % self.tiles
+        else:
+            tile = 0
+
+        lowres = self.X.isel(time = slice(time_idx, time_idx + self.length), tile = tile)
+        lowres = np.stack([lowres[channel].values for channel in self.channels], axis = 1)
+        highres = self.y.isel(time = slice(time_idx, time_idx + self.length), tile = tile)
+        highres = np.stack([highres[channel].values for channel in self.channels[0:1]], axis = 1)
+
+        if self.logscale:
+
+            lowres[:,0:1,:,:] = np.log(lowres[:,0:1,:,:] - self.c48_gmin + 1e-14)
+            highres = np.log(highres - self.c384_gmin + 1e-14)
+            lowres[:,0:1,:,:] = (lowres[:,0:1,:,:] - self.c48_lgmin) / (self.c48_lgmax - self.c48_lgmin)
+            highres = (highres - self.c384_lgmin) / (self.c384_lgmax - self.c384_lgmin)
+
+        else:
+
+            lowres[:,0:1,:,:] = (lowres[:,0:1,:,:] - self.c48_min) / (self.c48_max - self.c48_min)
+            highres = (highres - self.c384_min) / (self.c384_max - self.c384_min)
+
+        if self.multi:
+
+            lowres[:,1:,:,:] = (lowres[:,1:,:,:] - self.c48_multimin) / (self.c48_multimax - self.c48_multimin)
+            topo = self.topo384[tile,:,:]
+            topo = np.repeat(topo.reshape((1,1,384,384)), self.length, axis = 0)
+            highres = np.concatenate((highres, topo), axis = 1)
+
+        return {'LR' : lowres, 'HR' : highres}
\ No newline at end of file
diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py
index 7090bf50be..a5e926f2c3 100644
--- a/projects/super_res/model/autoreg_diffusion_mod.py
+++ b/projects/super_res/model/autoreg_diffusion_mod.py
@@ -1,10 +1,10 @@
 import os
 import math
 from pathlib import Path
-from random import random
+from random import random, randint
 from functools import partial
 from collections import namedtuple
-from joblib import Parallel, delayed
+import xarray as xr
 
 import numpy as np
 
@@ -13,10 +13,16 @@
 import torch.nn.functional as F
 
 import wandb
+from torchvision.transforms.functional import crop
+
 import piq
+import pickle
+import cv2
+from scipy.stats import wasserstein_distance
 
 from kornia import filters
 from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
 
 from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
 
@@ -42,6 +48,14 @@
 
 # helpers functions
 
+def get_random_idx_with_difference(min_tx, max_tx, number_tx, diff):
+    times = []
+    while len(times) < number_tx:
+        new_time = randint(min_tx, max_tx)
+        if all(abs(new_time - time) >= diff for time in times):
+            times.append(new_time)
+    return times
+
 def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size):
     truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8')
     for i in range(256):
@@ -101,7 +115,7 @@ def normalize_to_neg_one_to_one(img):
 def unnormalize_to_zero_to_one(t):
     return (t + 1) * 0.5
 
-# ssf modules
+# flow modules
 
 def gaussian_pyramids(input, base_sigma = 1, m = 5):
 
@@ -165,6 +179,41 @@ def scale_space_warp(input, flow):
 
     return warped
 
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'border'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as default.
+
+    Returns:
+        Tensor: Warped image or feature map.
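+
+    Example (sketch; with align_corners=True, a zero flow is an identity warp
+    up to interpolation error):
+        >>> x = torch.randn(2, 1, 48, 48)
+        >>> out = flow_warp(x, torch.zeros(2, 48, 48, 2))  # out == x (to float precision)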
+ """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + # small helper modules class Residual(nn.Module): @@ -641,6 +690,7 @@ def __init__( flow, *, image_size, + in_ch, timesteps = 1200, sampling_timesteps = None, loss_type = 'l1', @@ -653,16 +703,13 @@ def __init__( auto_normalize = True ): super().__init__() - #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) - #assert not model.random_or_learned_sinusoidal_cond self.model = model - self.umodel = context_net(upscale=8, in_chans=1, img_size=48, window_size=8, - img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, - num_heads=[8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv') - + self.umodel = context_net(upscale = 8, in_chans = in_ch, out_chans = 1, img_size = 48, window_size = 8, + img_range = 1., depths = [6, 6, 6, 6, 6, 6, 6], embed_dim = 200, + num_heads = [8, 8, 8, 8, 8, 8, 8], + mlp_ratio = 2, upsampler = 'pixelshuffle', resi_connection = '3conv') self.flow = flow self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) @@ -766,6 +813,7 @@ def predict_start_from_v(self, x_t, t, v): ) def q_posterior(self, x_start, x_t, t): + posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t @@ -774,11 +822,8 @@ def q_posterior(self, x_start, x_t, t): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): - #model_output = self.model(x, t, x_self_cond) - #print(x.shape, l_cond.shape) model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) maybe_clip = partial(torch.clamp, min = -1., max = 1.) 
if clip_x_start else identity @@ -801,10 +846,8 @@ def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_st return ModelPrediction(pred_noise, x_start) - #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): - #preds = self.model_predictions(x, t, x_self_cond) preds = self.model_predictions(x, t, context, x_self_cond) x_start = preds.pred_x_start @@ -815,22 +858,18 @@ def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = Tru return model_mean, posterior_variance, posterior_log_variance, x_start @torch.no_grad() - #def p_sample(self, x, t: int, x_self_cond = None): def p_sample(self, x, t: int, context, x_self_cond = None): - b, *_, device = *x.shape, x.device batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) - #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 pred_img = model_mean + (0.5 * model_log_variance).exp() * noise return pred_img, x_start @torch.no_grad() - #def p_sample_loop(self, shape, return_all_timesteps = False): def p_sample_loop(self, shape, context, return_all_timesteps = False): - batch, device = shape[0], self.betas.device + device = self.betas.device img = torch.randn(shape, device = device) imgs = [img] @@ -839,19 +878,16 @@ def p_sample_loop(self, shape, context, return_all_timesteps = False): for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): self_cond = x_start if self.self_condition else None - #img, x_start = self.p_sample(img, t, self_cond) img, x_start = self.p_sample(img, t, context, self_cond) imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - #def ddim_sample(self, shape, return_all_timesteps = False): def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): - print('here!!!') + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps @@ -866,7 +902,6 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): time_cond = torch.full((batch,), time, device = device, dtype = torch.long) self_cond = x_start if self.self_condition else None - #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) imgs.append(img) @@ -890,33 +925,58 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - def sample(self, lres, return_all_timesteps = False): + def sample(self, lres, hres, multi, flow_mode, 
return_all_timesteps = False): b, f, c, h, w = lres.shape - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + if multi: + + topo = hres[:, :, 1:2, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(8*h, 8*w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) lres = self.normalize(lres) ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + r = torch.roll(l, -1, 1) ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) stack = torch.cat((l, r, m1), 2) stack = stack[:, :(f-2), :, :, :] @@ -924,9 +984,26 @@ def sample(self, lres, return_all_timesteps = False): flow, context = self.flow(stack) - warped = scale_space_warp(ures_flow, flow) + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + if multi: + + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) - res = sample_fn((b*(f-2),c,8*h,8*w), l_cond, context, return_all_timesteps = return_all_timesteps) + else: + + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped + + res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) sres = warped + res sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) @@ -934,10 +1011,17 @@ def sample(self, lres, return_all_timesteps = False): res = rearrange(res, '(b t) c h w -> b t c h w', b = b) flow = rearrange(flow, '(b t) c h w -> b t c h w', b = b) - return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) + if flow_mode == '2d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow + + elif flow_mode == '3d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) @torch.no_grad() def interpolate(self, x1, x2, t = None, lam = 0.5): + b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) @@ -953,6 +1037,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5): return img def q_sample(self, x_start, t, noise=None): + 
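+        # closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise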
noise = default(noise, lambda: torch.randn_like(x_start)) return ( @@ -962,6 +1047,7 @@ def q_sample(self, x_start, t, noise=None): @property def loss_fn(self): + if self.loss_type == 'l1': return F.l1_loss elif self.loss_type == 'l2': @@ -969,24 +1055,36 @@ def loss_fn(self): else: raise ValueError(f'invalid loss type {self.loss_type}') - def p_losses(self, stack, hres, lres, ures, t, noise = None): + def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, noise = None): - b, f, c, h, w = hres.shape + f = hres.shape[1] stack = rearrange(stack, 'b t c h w -> (b t) c h w') - ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + ures_flow = rearrange(ures[:, 1:(f - 1), :, :, :], 'b t c h w -> (b t) c h w') flow, context = self.flow(stack) - warped = scale_space_warp(ures_flow, flow) + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') x_start = x_start - warped - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + if multi: + + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) - b, c, h, w = x_start.shape + else: + + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped del f @@ -1031,37 +1129,70 @@ def p_losses(self, stack, hres, lres, ures, t, noise = None): loss2 = self.loss_fn(x_start, warped, reduction = 'none') loss2 = reduce(loss2, 'b ... 
-> b (...)', 'mean') - return loss.mean() + loss1.mean() + loss2.mean() + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 - def forward(self, lres, hres, *args, **kwargs): - - b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + def forward(self, lres, hres, multi, flow_mode, *args, **kwargs): - assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + b, f, c, h, w, device = *hres.shape, hres.device t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + if multi: + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) lres = self.normalize(lres) hres = self.normalize(hres) ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + r = torch.roll(l, -1, 1) m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) stack = torch.cat((l, r, m1), 2) stack = stack[:, :(f-2), :, :, :] - return self.p_losses(stack, hres, lres, ures, t, *args, **kwargs) + if multi: + + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, topo, *args, **kwargs) + + else: + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, None, *args, **kwargs) # trainer class @@ -1075,24 +1206,18 @@ def __init__( *, train_batch_size = 16, gradient_accumulate_every = 1, - #augment_horizontal_flip = True, train_lr = 1e-4, train_num_steps = 100000, ema_update_every = 1, ema_decay = 0.995, adam_betas = (0.9, 0.99), save_and_sample_every = 1, - #num_samples = 25, eval_folder = './evaluate', results_folder = './results', - #tensorboard_dir = './tensorboard', val_num_of_batch = 2, amp = False, fp16 = False, - #fp16 = True, - split_batches = True, - #split_batches = False, - convert_image_to = None + split_batches = True ): super().__init__() @@ -1101,17 +1226,21 @@ def __init__( mixed_precision = 'fp16' if fp16 else 'no', log_with = 'wandb', ) - self.accelerator.init_trackers("vsr-orig-autoreg-hres", + self.accelerator.init_trackers("climate", init_kwargs={ "wandb": { - "notes": "Use VSR to improve precipitation forecasting.", - # Change "name" to set the name of the run. 
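
Aside: the p_losses change above replaces the unweighted sum of the three training terms with fixed weights: 1.7 on the diffusion residual term, 1.0 on the deterministic SwinIR upsampler term, and 0.3 on the warping term. A hedged sketch of that combination with stand-in tensors (the weights follow the diff; the data here is random and the diffusion term omits the p2 reweighting applied in the real code):

import torch
import torch.nn.functional as F

pred, target_v = torch.randn(4, 1, 96, 96), torch.randn(4, 1, 96, 96)  # diffusion prediction vs target
ures, hres     = torch.rand(4, 1, 96, 96), torch.rand(4, 1, 96, 96)    # upsampler output vs high-res
warped, truth  = torch.rand(4, 1, 96, 96), torch.rand(4, 1, 96, 96)    # warped frame vs next frame

loss  = F.mse_loss(pred, target_v)   # denoising term on the residual
loss1 = F.mse_loss(ures, hres)       # deterministic upsampler term
loss2 = F.mse_loss(warped, truth)    # flow-warping term

total = loss * 1.7 + loss1 * 1.0 + loss2 * 0.3
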
"name": None, } }, ) self.config = config self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] self.model = diffusion_model @@ -1128,6 +1257,8 @@ def __init__( # optimizer self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + #self.sched = ReduceLROnPlateau(self.opt, 'min', factor = 0.5, patience = 5, min_lr = 1e-6, verbose = False) # for logging results in a folder periodically @@ -1138,7 +1269,9 @@ def __init__( self.results_folder.mkdir(exist_ok=True, parents=True) - self.eval_folder = eval_folder + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) # step counter state @@ -1146,9 +1279,9 @@ def __init__( # prepare model, dataloader, optimizer with accelerator - self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) self.train_dl = cycle(train_dl) - self.val_dl = cycle(val_dl) + self.val_dl = val_dl def save(self, milestone): if not self.accelerator.is_local_main_process: @@ -1192,13 +1325,44 @@ def train(self): cmap = mpl.colormaps['RdBu_r'] fcmap = mpl.colormaps['gray_r'] - c384_min = np.load('data/only_precip/c384_min.npy') - c384_max = np.load('data/only_precip/c384_max.npy') - c384_logmin = np.load('data/only_precip/c384_logmin.npy') + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') + + # c384_min = np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') + + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) - c48_min = np.load('data/only_precip/c48_min.npy') - c48_max = np.load('data/only_precip/c48_max.npy') - c48_logmin = np.load('data/only_precip/c48_logmin.npy') + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: @@ -1208,15 +1372,20 @@ def train(self): for _ in range(self.gradient_accumulate_every): - #data = next(self.dl).to(device) data = next(self.train_dl) lres = 
data['LR'].to(device) hres = data['HR'].to(device) + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + with self.accelerator.autocast(): - #loss = self.model(data) - loss = self.model(lres, hres) + loss = self.model(lres, hres, self.multi, self.flow) loss = loss / self.gradient_accumulate_every total_loss += loss.item() @@ -1225,13 +1394,13 @@ def train(self): accelerator.clip_grad_norm_(self.model.parameters(), 1.0) pbar.set_description(f'loss: {total_loss:.4f}') - #self.writer.add_scalar("loss", total_loss, self.step) accelerator.log({"loss": total_loss}, step = self.step) accelerator.wait_for_everyone() self.opt.step() self.opt.zero_grad() + self.sched.step() accelerator.wait_for_everyone() @@ -1244,6 +1413,14 @@ def train(self): self.ema.ema_model.eval() with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 for i, batch in enumerate(self.val_dl): @@ -1253,99 +1430,169 @@ def train(self): if i >= self.val_num_of_batch: break - num_samples = 5 - num_videos_per_batch = 1 - num_frames = 5 - img_size = 384 - img_channels = 1 - - truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) - for k in range(num_samples): - videos, base, res, flows = self.ema.ema_model.sample(lres) - pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) - - crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) - psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') - - videos_time_mean = videos.mean(dim = 1) - hres_time_mean = hres[:,2:,:,:,:].mean(dim = 1) - bias = videos_time_mean - hres_time_mean - norm = mpl.colors.Normalize(vmin = bias.min(), vmax = bias.max()) - sm = smap(norm, cmap) - b_c = [] - for l in range(num_videos_per_batch): - b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) - bias_color = np.stack(b_c, axis = 0) + # num_samples = 5 + # num_videos_per_batch = 1 + # num_frames = 5 + # img_size = 384 + # img_channels = 1 + # truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + + # for k in range(num_samples): + # videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + # pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + loss = self.model(lres, hres, self.multi, self.flow) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + bases.append(base) + ress.append(res) + flowss.append(flows) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + base = 
torch.cat(bases, dim = 0) + res = torch.cat(ress, dim = 0) + flows = torch.cat(flowss, dim = 0) + del vids, vlosses, hr, lr, bases, ress, flowss + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min - target = np.exp(target) + c384_logmin - 1e-14 - output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min - output = np.exp(output) + c384_logmin - 1e-14 - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min - coarse = np.exp(coarse) + c48_logmin - 1e-14 - - nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) - diff_output = (output - nn_upscale).flatten() - diff_target = (target - nn_upscale).flatten() - vmin = min(diff_output.min(), diff_target.min()) - vmax = max(diff_output.max(), diff_target.max()) - bins = np.linspace(vmin, vmax, 100 + 1) - - fig, ax = plt.subplots(1, 1, figsize=(6, 4)) - ax.hist( - diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True - ) - ax.hist( - diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True - ) - ax.set_xlim(vmin, vmax) - ax.legend() - ax.set_ylabel("Density") - ax.set_yscale("log") - - output1 = output.flatten() - target1 = target.flatten() - vmin1 = min(output1.min(), target1.min()) - vmax1 = max(output1.max(), target1.max()) - bins1 = np.linspace(vmin1, vmax1, 100 + 1) - - fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) - ax1.hist( - output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True - ) - ax1.hist( - target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True - ) - ax1.set_xlim(vmin1, vmax1) - ax1.legend() - ax1.set_ylabel("Density") - ax1.set_yscale("log") - - flow_d = np.zeros((1, num_samples, 3, img_size, img_size)) - for m in range(num_samples): - flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 
- target1)**2)) + pscore = np.abs(np.percentile(output1, 99.999) - np.percentile(target1, 99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) - flow_s = np.zeros((1, num_samples, 3, img_size, img_size)) + for m in range(num_frames): + + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + + if self.flow == '3d': + + flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) sm = smap(None, fcmap) - for m in range(num_samples): + + for m in range(num_frames): + flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) + + + + if self.logscale: - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) 
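
Aside: the validation code above scores the generated precipitation against the target with four distribution-level distances: chi-square, intersection, and KL divergence via cv2.compareHist on density histograms, plus the 1-D Wasserstein distance on the raw values. A self-contained sketch of the same computation (array contents are random placeholders):

import numpy as np
import cv2
from scipy.stats import wasserstein_distance

output1 = np.random.rand(10000).astype('float32')   # flattened samples (placeholder)
target1 = np.random.rand(10000).astype('float32')   # flattened ground truth (placeholder)

bins1 = np.linspace(min(output1.min(), target1.min()),
                    max(output1.max(), target1.max()), 101)
count_o, _ = np.histogram(output1, bins=bins1, density=True)
count_t, _ = np.histogram(target1, bins=bins1, density=True)
histo = count_o.ravel().astype('float32')
histt = count_t.ravel().astype('float32')

distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR)
distinter  = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT)
distkl     = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV)
distemd    = wasserstein_distance(output1, target1)  # on raw samples, not histograms
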
accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) - accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) - accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) - accelerator.log({"psnr": psnr_index.mean()}, step=self.step) - accelerator.log({"crps": crps_index}, step=self.step) + if self.flow == '3d': + accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step) milestone = self.step // self.save_and_sample_every @@ -1362,159 +1609,184 @@ def sample(self): self.ema.ema_model.eval() - cmap = mpl.colormaps['viridis'] - sm = smap(None, cmap) + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with 
open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] with torch.no_grad(): - for k, batch in enumerate(self.val_dl): + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if k >= self.val_num_of_batch: - break - - limit = lres.shape[1] - if limit < 8: - - #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) - videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") - - for i, b in enumerate(videos.clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - os.makedirs(os.path.join(self.eval_folder, "generated")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - - #videos = torch.log(videos.clamp(0.0, 1.0) + 1) - #hres = torch.log(hres + 1) - - #for i, b in enumerate(videos.clamp(0, 1)): - # for i, b in enumerate(videos): - # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - # os.makedirs(os.path.join(self.eval_folder, "generated")) - # Parallel(n_jobs=4)( - # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - # for j, f in enumerate(b.cpu()) - # ) - -# for i, b in enumerate(nsteps.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# os.makedirs(os.path.join(self.eval_folder, "residual")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(base.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not 
os.path.isdir(os.path.join(self.eval_folder, "warped")): -# os.makedirs(os.path.join(self.eval_folder, "warped")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# os.makedirs(os.path.join(self.eval_folder, "flows")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - - for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "truth")): - os.makedirs(os.path.join(self.eval_folder, "truth")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - -# else: - -# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) - -# st = 5 -# ed = st + 7 - -# while ed < limit: - -# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) -# st += 5 -# ed += 5 -# videos = torch.cat((videos, vi), 1) -# #base = torch.cat((base, ba), 1) -# #nsteps = torch.cat((nsteps, ns), 1) -# #flows = torch.cat((flows, fl), 1) - -# for i, b in enumerate(videos.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): -# os.makedirs(os.path.join(self.eval_folder, "generated")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "truth")): -# os.makedirs(os.path.join(self.eval_folder, "truth")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# # for i, b in enumerate(nsteps.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# # os.makedirs(os.path.join(self.eval_folder, "residual")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(base.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# # os.makedirs(os.path.join(self.eval_folder, "warped")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) -# # for i, b in enumerate(flows.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# # os.makedirs(os.path.join(self.eval_folder, "flows")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - 
c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + #indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 75 samples per tile + indices = list(range(0, 3176 - (seq_len + 2), 250)) # deterministic, 325 samples per tile for seq_len of 25 + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): -# os.makedirs(os.path.join(self.eval_folder, "flows_d")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) -# for j, f in enumerate(b.cpu()) -# ) -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): -# os.makedirs(os.path.join(self.eval_folder, "flows_s")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') -# for j, f in enumerate(b.cpu()) -# ) \ No newline at end of file + if self.logscale: + + 
X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion.py b/projects/super_res/model/autoreg_diffusion_mod_focal.py similarity index 53% rename from projects/super_res/model/autoreg_diffusion.py rename to projects/super_res/model/autoreg_diffusion_mod_focal.py index 3d8eabc370..3b35a73eaa 100644 --- a/projects/super_res/model/autoreg_diffusion.py +++ b/projects/super_res/model/autoreg_diffusion_mod_focal.py @@ -1,10 +1,10 @@ import os import math from pathlib import Path -from random import random +from random import random, randint from functools import partial from collections import namedtuple -from joblib import Parallel, delayed +import xarray as xr import numpy as np @@ -13,10 +13,16 @@ import torch.nn.functional as F import wandb +from torchvision.transforms.functional import crop + import piq +import pickle +import cv2 +from scipy.stats import wasserstein_distance from kornia import filters from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau from einops import rearrange, reduce from einops.layers.torch import Rearrange @@ -24,6 +30,7 @@ from PIL import Image import matplotlib as mpl +import matplotlib.pyplot as plt from matplotlib.cm import ScalarMappable as smap from tqdm.auto import tqdm @@ -33,12 +40,42 @@ from accelerate import Accelerator +from .network_swinir import SwinIR as context_net + # constants ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) # helpers functions +def focal_mse_loss(input, target, reduction = None): + weight = torch.abs(input - target) + norm = (weight - weight.min()) / (weight.max() - weight.min()) + focal = torch.pow((norm + 1e-10), 5) + return weight * (input - target) ** 2 + +def get_random_idx_with_difference(min_tx, max_tx, 
number_tx, diff): + times = [] + while len(times) < number_tx: + new_time = randint(min_tx, max_tx) + if all(abs(new_time - time) >= diff for time in times): + times.append(new_time) + return times + +def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size): + truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for i in range(256): + truth_cdf[i, :, :, :, :, :, :] = (truth <= i).astype('uint8') + pred_cdf = np.zeros((256, num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for j in range(256): + pred_cdf[j, :, :, :, :, :, :, :] = (pred <= j).astype('uint8') + red_pred_cdf = pred_cdf.mean(1) + temp = np.square(red_pred_cdf - truth_cdf) + temp_dz = temp.sum(0) + temp_dz_dd = temp_dz.mean(axis = (3, 4, 5)) + temp_dz_dd_dt = temp_dz_dd.mean(2) + return temp_dz_dd_dt.mean() + def save_image(tensor, path): im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) im.save(path) @@ -84,13 +121,14 @@ def normalize_to_neg_one_to_one(img): def unnormalize_to_zero_to_one(t): return (t + 1) * 0.5 -# ssf modules +# flow modules def gaussian_pyramids(input, base_sigma = 1, m = 5): output = [input] N, C, H, W = input.shape - kernel = filters.get_gaussian_kernel2d((5, 5), (base_sigma, base_sigma)) + + kernel = filters.get_gaussian_kernel2d((5, 5), (base_sigma, base_sigma))#.unsqueeze(0) for i in range(m): @@ -147,6 +185,41 @@ def scale_space_warp(input, flow): return warped +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + + Returns: + Tensor: Warped image or feature map. 
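
Aside: one apparent bug worth noting in the new focal_mse_loss defined earlier in this file: it computes a normalized weight and raises it to the fifth power, but the return statement multiplies by weight rather than focal, leaving the focusing factor dead code. If the focusing behaviour is intended, the fix would presumably be the following (my reading, not a change made by the diff):

import torch

# Possible intended form of focal_mse_loss: the diff returns
# weight * (input - target) ** 2 and never uses the computed focal factor.
def focal_mse_loss_fixed(input, target, reduction=None):
    weight = torch.abs(input - target)
    norm = (weight - weight.min()) / (weight.max() - weight.min())
    focal = torch.pow(norm + 1e-10, 5)       # emphasize large-error pixels
    return focal * (input - target) ** 2     # was: weight * (input - target) ** 2
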
+ """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + # small helper modules class Residual(nn.Module): @@ -623,6 +696,7 @@ def __init__( flow, *, image_size, + in_ch, timesteps = 1200, sampling_timesteps = None, loss_type = 'l1', @@ -635,11 +709,13 @@ def __init__( auto_normalize = True ): super().__init__() - #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) - #assert not model.random_or_learned_sinusoidal_cond self.model = model + self.umodel = context_net(upscale = 8, in_chans = in_ch, out_chans = 1, img_size = 48, window_size = 8, + img_range = 1., depths = [6, 6, 6, 6, 6, 6, 6], embed_dim = 200, + num_heads = [8, 8, 8, 8, 8, 8, 8], + mlp_ratio = 2, upsampler = 'pixelshuffle', resi_connection = '3conv') self.flow = flow self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) @@ -743,6 +819,7 @@ def predict_start_from_v(self, x_t, t, v): ) def q_posterior(self, x_start, x_t, t): + posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t @@ -751,11 +828,8 @@ def q_posterior(self, x_start, x_t, t): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): - #model_output = self.model(x, t, x_self_cond) - #print(x.shape, l_cond.shape) model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity @@ -778,10 +852,8 @@ def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_st return ModelPrediction(pred_noise, x_start) - #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): - #preds = self.model_predictions(x, t, x_self_cond) preds = self.model_predictions(x, t, context, x_self_cond) x_start = preds.pred_x_start @@ -792,22 +864,18 @@ def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = Tru return model_mean, posterior_variance, posterior_log_variance, x_start @torch.no_grad() - #def p_sample(self, x, t: int, x_self_cond = None): def p_sample(self, x, t: int, context, x_self_cond = None): - b, *_, device = *x.shape, x.device batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) - #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) noise = torch.randn_like(x) if t > 0 else 0. 
# no noise if t == 0 pred_img = model_mean + (0.5 * model_log_variance).exp() * noise return pred_img, x_start @torch.no_grad() - #def p_sample_loop(self, shape, return_all_timesteps = False): def p_sample_loop(self, shape, context, return_all_timesteps = False): - batch, device = shape[0], self.betas.device + device = self.betas.device img = torch.randn(shape, device = device) imgs = [img] @@ -816,17 +884,14 @@ def p_sample_loop(self, shape, context, return_all_timesteps = False): for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): self_cond = x_start if self.self_condition else None - #img, x_start = self.p_sample(img, t, self_cond) img, x_start = self.p_sample(img, t, context, self_cond) imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - #def ddim_sample(self, shape, return_all_timesteps = False): def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective @@ -843,7 +908,6 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): time_cond = torch.full((batch,), time, device = device, dtype = torch.long) self_cond = x_start if self.self_condition else None - #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) imgs.append(img) @@ -867,67 +931,103 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - #def sample(self, batch_size = 16, return_all_timesteps = False): - def sample(self, lres, hres, return_all_timesteps = False): + def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): + + b, f, c, h, w = lres.shape + + if multi: + + topo = hres[:, :, 1:2, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(8*h, 8*w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) - b, f, c, h, w, image_size, channels = *hres.shape, self.image_size, self.channels - print(b,f,c,h,w) lres = self.normalize(lres) - hres = self.normalize(hres) + ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample - l = hres.clone()[:, :1, :, :, :] - r = hres.clone()[:, 1:2, :, :, :] - hres_flow = rearrange(hres[:, 1:2, :, :, :], 'b t c h w -> (b t) c h w') - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + + r = torch.roll(l, 
-1, 1) + ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) - m1 = m1[:, :(f-2), :, :, :] - ans = [] - base = [] - nsteps = [] - flows = [] + stack = torch.cat((l, r, m1), 2) + stack = stack[:, :(f-2), :, :, :] + stack = rearrange(stack, 'b t c h w -> (b t) c h w') - for i in range(f-2): - - stack = torch.cat((l, r, m1[:, i:i+1, :, :, :]), 2) + flow, context = self.flow(stack) - stack = rearrange(stack, 'b t c h w -> (b t) c h w') + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) - flow, context = self.flow(stack) - - warped = scale_space_warp(hres_flow, flow) - batch_size = b - #res = sample_fn((batch_size, c, image_size, image_size), l_cond[i::(f-2), :, :, :], context, return_all_timesteps = return_all_timesteps) - - res = sample_fn((batch_size, c, h, w), l_cond[i::(f-2), :, :, :], context, return_all_timesteps = return_all_timesteps) - hres_flow = warped + res - - #hres_flow = warped + res[:, -1, :, :, :] + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + if multi: - l = r - r = rearrange(hres_flow, '(b t) c h w -> b t c h w', t = 1) + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + + else: - ans.append(hres_flow) - base.append(warped) - #nsteps.append(torch.cat(torch.unbind(res, 1), 3)) - nsteps.append(res) - flows.append(flow) + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped + + res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) + sres = warped + res + sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) + + warped = rearrange(warped, '(b t) c h w -> b t c h w', b = b) + res = rearrange(res, '(b t) c h w -> b t c h w', b = b) + flow = rearrange(flow, '(b t) c h w -> b t c h w', b = b) + + if flow_mode == '2d': - return self.unnormalize(torch.stack(ans, 1)), self.unnormalize(torch.stack(base, 1)), self.unnormalize(torch.stack(nsteps, 1)), self.unnormalize(torch.stack(flows, 1)) - #return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps) + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow + + elif flow_mode == '3d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) @torch.no_grad() def interpolate(self, x1, x2, t = None, lam = 0.5): + b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) @@ -943,6 +1043,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5): return img def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) return ( @@ -952,31 +1053,46 @@ def q_sample(self, x_start, t, noise=None): @property def loss_fn(self): + if self.loss_type == 'l1': return F.l1_loss elif self.loss_type == 'l2': return F.mse_loss + elif self.loss_type == 'focal': + return focal_mse_loss else: raise ValueError(f'invalid loss type 
{self.loss_type}') - #def p_losses(self, x_start, t, noise = None): - def p_losses(self, stack, hres, lres, t, noise = None): + def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, noise = None): - b, f, c, h, w = hres.shape + f = hres.shape[1] stack = rearrange(stack, 'b t c h w -> (b t) c h w') - hres_flow = rearrange(hres[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + ures_flow = rearrange(ures[:, 1:(f - 1), :, :, :], 'b t c h w -> (b t) c h w') flow, context = self.flow(stack) - #print(flow.shape, hres_flow.shape) - warped = scale_space_warp(hres_flow, flow) + + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') x_start = x_start - warped - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + if multi: + + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) - b, c, h, w = x_start.shape + else: + + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped del f @@ -1014,37 +1130,77 @@ def p_losses(self, stack, hres, lres, t, noise = None): loss = reduce(loss, 'b ... -> b (...)', 'mean') loss = loss * extract(self.p2_loss_weight, t, loss.shape) - return loss.mean() - #def forward(self, data, *args, **kwargs): - def forward(self, lres, hres, *args, **kwargs): - - #b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size - b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + loss1 = self.loss_fn(ures, hres, reduction = 'none') + loss1 = reduce(loss1, 'b ... -> b (...)', 'mean') + + loss2 = self.loss_fn(warped, rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w'), reduction = 'none') + loss2 = reduce(loss2, 'b ... 
-> b (...)', 'mean') + + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 + + def forward(self, lres, hres, multi, flow_mode, *args, **kwargs): - assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + b, f, c, h, w, device = *hres.shape, hres.device - #t = torch.randint(0, self.num_timesteps, (b,), device=device).long() t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() - #img = self.normalize(img) + if multi: + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) + lres = self.normalize(lres) hres = self.normalize(hres) + ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + + l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) - l = hres.clone() r = torch.roll(l, -1, 1) m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) - m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) stack = torch.cat((l, r, m1), 2) stack = stack[:, :(f-2), :, :, :] - #return self.p_losses(img, t, *args, **kwargs) - return self.p_losses(stack, hres, lres, t, *args, **kwargs) + if multi: + + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, topo, *args, **kwargs) + + else: + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, None, *args, **kwargs) # trainer class @@ -1058,24 +1214,18 @@ def __init__( *, train_batch_size = 16, gradient_accumulate_every = 1, - #augment_horizontal_flip = True, train_lr = 1e-4, train_num_steps = 100000, ema_update_every = 1, ema_decay = 0.995, adam_betas = (0.9, 0.99), - save_and_sample_every = 10, - #num_samples = 25, + save_and_sample_every = 1, eval_folder = './evaluate', results_folder = './results', - #tensorboard_dir = './tensorboard', val_num_of_batch = 2, amp = False, fp16 = False, - #fp16 = True, - split_batches = True, - #split_batches = False, - convert_image_to = None + split_batches = True ): super().__init__() @@ -1084,17 +1234,21 @@ def __init__( mixed_precision = 'fp16' if fp16 else 'no', log_with = 'wandb', ) - self.accelerator.init_trackers("vsr-orig-autoreg-hres", + self.accelerator.init_trackers("climate", init_kwargs={ "wandb": { - "notes": "Use VSR to improve precipitation forecasting.", - # Change "name" to set the name of the run. 
"name": None, } }, ) self.config = config self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] self.model = diffusion_model @@ -1111,6 +1265,8 @@ def __init__( # optimizer self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + #self.sched = ReduceLROnPlateau(self.opt, 'min', factor = 0.5, patience = 5, min_lr = 1e-6, verbose = False) # for logging results in a folder periodically @@ -1121,7 +1277,9 @@ def __init__( self.results_folder.mkdir(exist_ok=True, parents=True) - self.eval_folder = eval_folder + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) # step counter state @@ -1129,9 +1287,9 @@ def __init__( # prepare model, dataloader, optimizer with accelerator - self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) self.train_dl = cycle(train_dl) - self.val_dl = cycle(val_dl) + self.val_dl = val_dl def save(self, milestone): if not self.accelerator.is_local_main_process: @@ -1172,6 +1330,48 @@ def train(self): accelerator = self.accelerator device = accelerator.device + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') + + # c384_min = np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') + + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: while self.step < self.train_num_steps: @@ -1180,15 +1380,20 @@ def train(self): for _ in range(self.gradient_accumulate_every): - #data = next(self.dl).to(device) data = next(self.train_dl) lres = data['LR'].to(device) hres = data['HR'].to(device) + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + with 
self.accelerator.autocast(): - #loss = self.model(data) - loss = self.model(lres, hres) + loss = self.model(lres, hres, self.multi, self.flow) loss = loss / self.gradient_accumulate_every total_loss += loss.item() @@ -1197,13 +1402,13 @@ def train(self): accelerator.clip_grad_norm_(self.model.parameters(), 1.0) pbar.set_description(f'loss: {total_loss:.4f}') - #self.writer.add_scalar("loss", total_loss, self.step) accelerator.log({"loss": total_loss}, step = self.step) accelerator.wait_for_everyone() self.opt.step() self.opt.zero_grad() + self.sched.step() accelerator.wait_for_everyone() @@ -1216,6 +1421,14 @@ def train(self): self.ema.ema_model.eval() with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 for i, batch in enumerate(self.val_dl): @@ -1224,17 +1437,170 @@ def train(self): if i >= self.val_num_of_batch: break - - videos, base, res, flows = self.ema.ema_model.sample(lres, hres) - psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') + + # num_samples = 5 + # num_videos_per_batch = 1 + # num_frames = 5 + # img_size = 384 + # img_channels = 1 + + # truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flows": wandb.Video((flows.clamp(0.0, 1.0).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"psnr": psnr_index.mean()}, step=self.step) + # for k in range(num_samples): + # videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + # pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + loss = self.model(lres, hres, self.multi, self.flow) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + bases.append(base) + ress.append(res) + flowss.append(flows) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + base = torch.cat(bases, dim = 0) + res = torch.cat(ress, dim = 0) + flows = torch.cat(flowss, dim = 0) + del vids, vlosses, hr, lr, bases, ress, flowss + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = 
videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 99.999) - np.percentile(target1, 99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) + + for m in range(num_frames): + + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + + if self.flow == '3d': + + flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) + sm = smap(None, fcmap) + + for m in range(num_frames): + + flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 
1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) + + + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + if self.flow == '3d': + accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step) milestone = self.step // self.save_and_sample_every @@ -1251,159 +1617,183 @@ def sample(self): self.ema.ema_model.eval() - cmap = mpl.colormaps['viridis'] - sm = smap(None, cmap) + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = 
xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] with torch.no_grad(): - for k, batch in enumerate(self.val_dl): + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if k >= self.val_num_of_batch: - break - - limit = lres.shape[1] - if limit < 8: - - #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) - videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") - - for i, b in enumerate(videos.clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - os.makedirs(os.path.join(self.eval_folder, "generated")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - - #videos = torch.log(videos.clamp(0.0, 1.0) + 1) - #hres = torch.log(hres + 1) - - #for i, b in enumerate(videos.clamp(0, 1)): - # for i, b in enumerate(videos): - # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - # os.makedirs(os.path.join(self.eval_folder, "generated")) - # Parallel(n_jobs=4)( - # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - # for j, f in enumerate(b.cpu()) - # ) - -# for i, b in enumerate(nsteps.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# 
os.makedirs(os.path.join(self.eval_folder, "residual")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(base.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# os.makedirs(os.path.join(self.eval_folder, "warped")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# os.makedirs(os.path.join(self.eval_folder, "flows")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - - for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "truth")): - os.makedirs(os.path.join(self.eval_folder, "truth")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - -# else: - -# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) - -# st = 5 -# ed = st + 7 - -# while ed < limit: - -# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) -# st += 5 -# ed += 5 -# videos = torch.cat((videos, vi), 1) -# #base = torch.cat((base, ba), 1) -# #nsteps = torch.cat((nsteps, ns), 1) -# #flows = torch.cat((flows, fl), 1) - -# for i, b in enumerate(videos.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): -# os.makedirs(os.path.join(self.eval_folder, "generated")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "truth")): -# os.makedirs(os.path.join(self.eval_folder, "truth")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# # for i, b in enumerate(nsteps.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# # os.makedirs(os.path.join(self.eval_folder, "residual")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(base.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# # os.makedirs(os.path.join(self.eval_folder, "warped")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) -# # for i, b in enumerate(flows.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# # os.makedirs(os.path.join(self.eval_folder, "flows")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # 
) + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 250 samples per tile + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): -# os.makedirs(os.path.join(self.eval_folder, "flows_d")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) -# for j, f in enumerate(b.cpu()) -# ) -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): -# os.makedirs(os.path.join(self.eval_folder, "flows_s")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + 
f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') -# for j, f in enumerate(b.cpu()) -# ) \ No newline at end of file + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/denoising_diffusion_rvrt_full.py b/projects/super_res/model/denoising_diffusion_rvrt_full.py new file mode 100644 index 0000000000..3df61b513a --- /dev/null +++ b/projects/super_res/model/denoising_diffusion_rvrt_full.py @@ -0,0 +1,1611 @@ +import os +import math +from pathlib import Path +from random import random, randint +from functools import partial, reduce, lru_cache +from collections import namedtuple +from operator import mul + +import numpy as np +import cv2 +from scipy.stats import wasserstein_distance + +import xarray as xr + +import torch +from torch import nn +import torch.nn.functional as F +import wandb + +import piq +import pickle + +from torchvision.transforms.functional import crop + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.cm import ScalarMappable as smap + +from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR + +from einops import rearrange +import einops +from einops.layers.torch import Rearrange + +from PIL import Image + +from tqdm.auto import tqdm +from ema_pytorch import EMA + +from accelerate import Accelerator +from distutils.version import LooseVersion +from .op.deform_attn import deform_attn, DeformAttnPack + +# constants + +ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) + +# helpers functions + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-9): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y, reduction = None): + diff = x - y + loss = 
torch.sqrt((diff * diff) + self.eps)
+        return loss
+
+def save_image(tensor, path):
+    im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8))
+    im.save(path)
+    return None
+
+def exists(x):
+    return x is not None
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+def identity(t, *args, **kwargs):
+    return t
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+def has_int_squareroot(num):
+    return (math.sqrt(num) ** 2) == num
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+def convert_image_to_fn(img_type, image):
+    if image.mode != img_type:
+        return image.convert(img_type)
+    return image
+
+def make_layer(block, num_blocks, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        block (nn.module): nn.module class for basic block.
+        num_blocks (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_blocks):
+        layers.append(block(**kwarg))
+    return nn.Sequential(*layers)
+
+# normalization functions
+
+def normalize_to_neg_one_to_one(img):
+    return img * 2 - 1
+
+def unnormalize_to_zero_to_one(t):
+    return (t + 1) * 0.5
+
+# model helpers
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    n, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device),
+                                    torch.arange(0, w, dtype=x.dtype, device=x.device))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    return output
+
+
+class BasicModule(nn.Module):
+    """Basic Module for SpyNet.
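+
+    One pyramid level of SpyNet: five 7x7 convolutions with ReLUs in between,
+    applied to the channel-wise concatenation of the reference features, the
+    flow-warped supporting features and the upsampled flow estimate, producing
+    a 2-channel flow residual.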
+ """ + + def __init__(self): + super(BasicModule, self).__init__() + + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=26, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + return self.basic_module(tensor_input) + + +class SpyNet(nn.Module): + """SpyNet architecture. + + Args: + load_path (str): path for pretrained SpyNet. Default: None. + return_levels (list[int]): return flows of different levels. Default: [5]. + """ + + def __init__(self, load_path=None, return_levels=[5]): + super(SpyNet, self).__init__() + self.return_levels = return_levels + self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) + + + def process(self, ref, supp, w, h, w_floor, h_floor): + flow_list = [] + + ref = [ref] + supp = [supp] + + for level in range(5): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + flow = ref[0].new_zeros( + [ref[0].size(0), 2, + int(math.floor(ref[0].size(2) / 2.0)), + int(math.floor(ref[0].size(3) / 2.0))]) + + for level in range(len(ref)): + upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + + if upsampled_flow.size(2) != ref[level].size(2): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') + if upsampled_flow.size(3) != ref[level].size(3): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') + + flow = self.basic_module[level](torch.cat([ + ref[level], + flow_warp( + supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), + upsampled_flow + ], 1)) + upsampled_flow + + if level in self.return_levels: + scale = 2 ** (5 - level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) + flow_out = F.interpolate(input=flow, size=(h // scale, w // scale), mode='bilinear', + align_corners=False) + flow_out[:, 0, :, :] *= float(w // scale) / float(w_floor // scale) + flow_out[:, 1, :, :] *= float(h // scale) / float(h_floor // scale) + flow_list.insert(0, flow_out) + + return flow_list + + def forward(self, ref, supp): + assert ref.size() == supp.size() + + h, w = ref.size(2), ref.size(3) + w_floor = math.floor(math.ceil(w / 32.0) * 32.0) + h_floor = math.floor(math.ceil(h / 32.0) * 32.0) + + ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + + flow_list = self.process(ref, supp, w, h, w_floor, h_floor) + + return flow_list[0] if len(flow_list) == 1 else flow_list + + +class GuidedDeformAttnPack(DeformAttnPack): + """Guided deformable attention module. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. 
+ clip_size (int): clip size. Default: 2. + max_residue_magnitude (int): The maximum magnitude of the offset residue. Default: 10. + Ref: + Recurrent Video Restoration Transformer with Guided Deformable Attention + + """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(GuidedDeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv3d(self.in_channels * (1 + self.clip_size) + self.clip_size * 2, 64, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, self.clip_size * self.deformable_groups * self.attn_size * 2, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + ) + self.init_offset() + + # proj to a higher dimension can slightly improve the performance + self.proj_channels = int(self.in_channels * 2) + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.proj_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + + def init_offset(self): + if hasattr(self, 'conv_offset'): + self.conv_offset[-1].weight.data.zero_() + self.conv_offset[-1].bias.data.zero_() + + def forward(self, q, k, v, v_prop_warped, flows, return_updateflow): + offset1, offset2 = torch.chunk(self.max_residue_magnitude * torch.tanh( + self.conv_offset(torch.cat([q] + v_prop_warped + flows, 2).transpose(1, 2)).transpose(1, 2)), 2, dim=2) + offset1 = offset1 + flows[0].flip(2).repeat(1, 1, offset1.size(2) // 2, 1, 1) + offset2 = offset2 + flows[1].flip(2).repeat(1, 1, offset2.size(2) // 2, 1, 1) + offset = torch.cat([offset1, offset2], dim=2).flatten(0, 1) + + b, t, c, h, w = offset1.shape + q = self.proj_q(q).view(b * t, 1, self.proj_channels, h, w) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size).view(b, t, self.proj_channels, h, + w) + v = self.proj(v) + v = v + self.mlp(v) + + if return_updateflow: + return v, offset1.view(b, t, c // 2, 2, h, w).mean(2).flip(2), offset2.view(b, t, c // 2, 2, h, w).mean( + 2).flip(2) + else: + return v + +def window_partition(x, window_size): + """ Partition the input into windows. Attention will be conducted within the windows. 
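+
+    For example, with window_size=(2, 8, 8) an input of shape (B, 4, 16, 16, C)
+    is split into B*2*2*2 windows of 2*8*8 = 128 tokens each.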
+
+    Args:
+        x: (B, D, H, W, C)
+        window_size (tuple[int]): window size
+
+    Returns:
+        windows: (B*num_windows, window_size*window_size, C)
+    """
+    B, D, H, W, C = x.shape
+    x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2],
+               window_size[2], C)
+    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
+
+    return windows
+
+
+def window_reverse(windows, window_size, B, D, H, W):
+    """ Reverse windows back to the original input. Attention was conducted within the windows.
+
+    Args:
+        windows: (B*num_windows, window_size, window_size, C)
+        window_size (tuple[int]): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, D, H, W, C)
+    """
+    x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1],
+                     window_size[2], -1)
+    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
+
+    return x
+
+
+def get_window_size(x_size, window_size, shift_size=None):
+    """ Get the window size and the shift size """
+
+    use_window_size = list(window_size)
+    if shift_size is not None:
+        use_shift_size = list(shift_size)
+    for i in range(len(x_size)):
+        if x_size[i] <= window_size[i]:
+            use_window_size[i] = x_size[i]
+            if shift_size is not None:
+                use_shift_size[i] = 0
+
+    if shift_size is None:
+        return tuple(use_window_size)
+    else:
+        return tuple(use_window_size), tuple(use_shift_size)
+
+@lru_cache()
+def compute_mask(D, H, W, window_size, shift_size, device):
+    """ Compute the attention mask for input of size (D, H, W). @lru_cache caches the result for each stage. """
+
+    img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
+    cnt = 0
+    for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
+        for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
+            for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
+                img_mask[:, d, h, w, :] = cnt
+                cnt += 1
+    mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
+    mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
+    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+    return attn_mask
+
+class Mlp(nn.Module):
+    """ Multilayer perceptron.
+
+    Args:
+        x: (B, D, H, W, C)
+
+    Returns:
+        x: (B, D, H, W, C)
+    """
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        return self.fc2(self.act(self.fc1(x)))
+
+class WindowAttention(nn.Module):
+    """ Window-based multi-head self-attention.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The temporal length, height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None): + super().__init__() + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + self.register_buffer("relative_position_index", self.get_position_index(window_size)) + self.qkv_self = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + + # self attention + B_, N, C = x.shape + qkv = self.qkv_self(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + x_out = self.attention(q, k, v, mask, (B_, N, C)) + + # projection + x = self.proj(x_out) + + return x + + def attention(self, q, k, v, mask, x_shape): + B_, N, C = x_shape + attn = (q * self.scale) @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1) # Wd*Wh*Ww, Wd*Wh*Ww,nH + attn = attn + relative_position_bias.permute(2, 0, 1).unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask[:, :N, :N].unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = F.softmax(attn, -1, dtype=q.dtype) # Don't use attn.dtype after addition! + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + return x + + def get_position_index(self, window_size): + ''' Get pair-wise relative position index for each token inside the window. ''' + + coords_d = torch.arange(window_size[0]) + coords_h = torch.arange(window_size[1]) + coords_w = torch.arange(window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 2] += window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + + return relative_position_index + +class STL(nn.Module): + """ Swin Transformer Layer (STL). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for mutual and self attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + act_layer (nn.Module, optional): Activation layer. 
Default: nn.GELU.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
+        use_checkpoint_attn (bool): If True, use torch.utils.checkpoint for attention modules. Default: False.
+        use_checkpoint_ffn (bool): If True, use torch.utils.checkpoint for feed-forward modules. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 input_resolution,
+                 num_heads,
+                 window_size=(2, 8, 8),
+                 shift_size=(0, 0, 0),
+                 mlp_ratio=2.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 use_checkpoint_attn=False,
+                 use_checkpoint_ffn=False
+                 ):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.use_checkpoint_attn = use_checkpoint_attn
+        self.use_checkpoint_ffn = use_checkpoint_ffn
+
+        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
+        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
+        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias,
+                                    qk_scale=qk_scale)
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
+
+    def forward_part1(self, x, mask_matrix):
+        B, D, H, W, C = x.shape
+        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
+
+        x = self.norm1(x)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = pad_d0 = 0
+        pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
+        pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
+        pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1), mode='constant')
+
+        _, Dp, Hp, Wp, _ = x.shape
+        # cyclic shift
+        if any(i > 0 for i in shift_size):
+            shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, window_size)  # B*nW, Wd*Wh*Ww, C
+
+        # attention / shifted attention
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, *(window_size + (C,)))
+        shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp)  # B D' H' W' C
+
+        # reverse cyclic shift
+        if any(i > 0 for i in shift_size):
+            x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
+        else:
+            x = shifted_x
+
+        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
+            x = x[:, :D, :H, :W, :]
+
+        return x
+
+    def forward_part2(self, x):
+        return self.mlp(self.norm2(x))
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+
+        Args:
+            x: Input feature, tensor size (B, D, H, W, C).
+            mask_matrix: Attention mask for cyclic shift.
+        """
+
+        # attention
+        x = x + self.forward_part1(x, mask_matrix)
+
+        # feed-forward
+        x = x + self.forward_part2(x)
+
+        return x
+
+
+class STG(nn.Module):
+    """ Swin Transformer Group (STG).
+
+    Args:
+        dim (int): Number of feature channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Depth of this stage.
+        num_heads (int): Number of attention heads.
+        window_size (tuple[int]): Local window size. Default: (6,8,8).
+ shift_size (tuple[int]): Shift size for mutual and self attention. Default: None. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=[2, 8, 8], + shift_size=None, + mlp_ratio=2., + qkv_bias=False, + qk_scale=None, + norm_layer=nn.LayerNorm, + use_checkpoint_attn=False, + use_checkpoint_ffn=False, + ): + super().__init__() + self.input_resolution = input_resolution + self.window_size = window_size + self.shift_size = list(i // 2 for i in window_size) if shift_size is None else shift_size + + # build blocks + self.blocks = nn.ModuleList([ + STL( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=[0, 0, 0] if i % 2 == 0 else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=use_checkpoint_attn, + use_checkpoint_ffn=use_checkpoint_ffn + ) + for i in range(depth)]) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for attention + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + + for blk in self.blocks: + x = blk(x, attn_mask) + + x = x.view(B, D, H, W, -1) + x = rearrange(x, 'b d h w c -> b c d h w') + + return x + +class RSTB(nn.Module): + """ Residual Swin Transformer Block (RSTB). + + Args: + kwargs: Args for RSTB. + """ + + def __init__(self, groups = 8, **kwargs): + super(RSTB, self).__init__() + self.input_resolution = kwargs['input_resolution'] + + self.residual_group = STG(**kwargs) + self.linear = nn.Linear(kwargs['dim'], kwargs['dim']) + self.proj = nn.Conv3d(kwargs['dim'], + kwargs['dim'], + kernel_size=(1,3,3), + padding=(0,1,1), + groups=groups) + self.norm = nn.GroupNorm(groups, kwargs['dim']) + self.act = nn.SiLU() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + + x = self.act(x) + + return x + self.linear(self.residual_group(x).transpose(1, 4)).transpose(1, 4) + +class RSTBWithInputConv(nn.Module): + """RSTB with a convolution in front. + + Args: + in_channels (int): Number of input channels of the first conv. + kernel_size (int): Size of kernel of the first conv. + stride (int): Stride of the first conv. + group (int): Group of the first conv. + num_blocks (int): Number of residual blocks. Default: 2. + **kwarg: Args for RSTB. 
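+
+    Note:
+        forward() takes features as (n, t, c, h, w) and rearranges them to
+        (n, c, t, h, w) for the Conv3d stem before running the RSTB blocks.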
+    """
+
+    def __init__(self, in_channels=3, kernel_size=(1, 3, 3), stride=1, groups=1, num_blocks=2, **kwargs):
+        super(RSTBWithInputConv, self).__init__()
+
+        self.in_channels = in_channels
+        self.init_conv = nn.Conv3d(in_channels,
+                                   kwargs['dim'],
+                                   kernel_size=kernel_size,
+                                   stride=stride,
+                                   padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2),
+                                   groups=groups)
+
+        self.init_norm = nn.LayerNorm(kwargs['dim'])
+
+        # RSTB blocks, kept in an nn.ModuleList so their parameters are
+        # registered with the module (and therefore reach the optimizer and
+        # state_dict); device placement is left to the caller / accelerator.
+        self.main1 = nn.ModuleList([RSTB(**kwargs) for _ in range(num_blocks)])
+
+        main2 = []
+        main2 += [Rearrange('n c d h w -> n d h w c'),
+                  nn.LayerNorm(kwargs['dim']),
+                  Rearrange('n d h w c -> n d c h w')]
+
+        self.main2 = nn.Sequential(*main2)
+
+    def forward(self, x):
+        """
+        Forward function for RSTBWithInputConv.
+
+        Args:
+            feat (Tensor): Input feature with shape (n, t, in_channels, h, w)
+
+        Returns:
+            Tensor: Output feature with shape (n, t, out_channels, h, w)
+        """
+
+        x = rearrange(x, 'n d c h w -> n c d h w')
+        x = self.init_conv(x)
+
+        x = rearrange(x, 'n c d h w -> n d h w c')
+        x = self.init_norm(x)
+        x = rearrange(x, 'n d h w c -> n c d h w')
+
+        for i in range(len(self.main1)):
+            x = self.main1[i](x)
+        x = self.main2(x)
+
+        return x
+
+class Upsample(nn.Module):
+    '''Upsample module for video SR.
+
+    Args:
+        scale (int): Scale factor. The three PixelShuffle stages below give a fixed 8x upscale.
+        num_feat (int): Channel number of intermediate features.
+    '''
+
+    def __init__(self, scale, num_feat, **kwargs):
+        super(Upsample, self).__init__()
+
+        assert LooseVersion(torch.__version__) >= LooseVersion('1.8.1'), \
+            'PyTorch version >= 1.8.1 to support 5D PixelShuffle.'
+
+        self.feat1 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.feat2 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+        self.feat3 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+
+        self.upsample1 = nn.PixelShuffle(2)
+        self.upsample2 = nn.PixelShuffle(2)
+        self.upsample3 = nn.PixelShuffle(2)
+
+        self.lrelu1 = nn.LeakyReLU(negative_slope=0.1)
+        self.lrelu2 = nn.LeakyReLU(negative_slope=0.1)
+        self.lrelu3 = nn.LeakyReLU(negative_slope=0.1)
+
+        self.final = nn.Conv3d(num_feat, 1, kernel_size=(1, 3, 3), padding=(0, 1, 1))
+
+    def forward(self, x):
+        x = rearrange(x, 'n d c h w -> n c d h w')
+        x = self.feat1(x)
+        x = rearrange(x, 'n c d h w -> n d c h w')
+        x = self.upsample1(x)
+        x = rearrange(x, 'n d c h w -> n c d h w')
+        x = self.lrelu1(x)
+        x = self.feat2(x)
+        x = rearrange(x, 'n c d h w -> n d c h w')
+        x = self.upsample2(x)
+        x = rearrange(x, 'n d c h w -> n c d h w')
+        x = self.lrelu2(x)
+        x = self.feat3(x)
+        x = rearrange(x, 'n c d h w -> n d c h w')
+        x = self.upsample3(x)
+        x = rearrange(x, 'n d c h w -> n c d h w')
+        x = self.lrelu3(x)
+
+        x = self.final(x)
+        x = rearrange(x, 'n c d h w -> n d c h w')
+
+        return x
+
+class GaussianDiffusion(nn.Module):
+    def __init__(
+        self,
+        feat_ext,
+        feat_up,
+        backbone,
+        deform_align,
+        recon,
+        spynet,
+        *,
+        image_size,
+        timesteps = 1200,
+        sampling_timesteps = None,
+        loss_type = 'l1',
+        objective = 'pred_noise',
+        beta_schedule = 'sigmoid',
+        schedule_fn_kwargs = dict(),
+        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 0.,
+        auto_normalize = True
+    ):
+        super(GaussianDiffusion, self).__init__()
+        self.clip_size = 2
+        self.feat_ext = feat_ext
+        self.feat_up = feat_up
+
+        self.backbone = backbone
+
+        self.deform_align = deform_align
+
+        self.recon = recon
+
+        self.spynet = spynet
+
+        self.channels = self.feat_ext.in_channels
+
+        self.image_size = image_size
+
+        self.loss_type = loss_type
+
+        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
+        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
+
+    @property
+    def loss_fn(self):
+        if self.loss_type == 'l1':
+            return F.l1_loss
+        elif self.loss_type == 'l2':
+            return F.mse_loss
+        elif self.loss_type == 'charbonnier':
+            return CharbonnierLoss()
+        else:
+            raise ValueError(f'invalid loss type {self.loss_type}')
+
+    def compute_flow(self, lqs):
+        """Compute optical flow using SPyNet for feature alignment.
+
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+        is not needed, since it is equal to 'flows_backward.flip(1)'.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+
+        Return:
+            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
+                flows used for forward-time propagation (current to previous).
+                'flows_backward' corresponds to the flows used for
+                backward-time propagation (current to next).
+        """
+
+        n, t, c, h, w = lqs.size()
+        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
+        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
+
+        flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
+
+        return flows_forward, flows_backward
+
+    def propagate(self, feats, flows, module_name, updated_flows=None):
+        """Propagate the latent clip features throughout the sequence.
+
+        Args:
+            feats dict(list[tensor]): Features from previous branches. Each
+                component is a list of tensors with shape (n, clip_size, c, h, w).
+            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
+            module_name (str): The name of the propagation branches. Can either
+                be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
+            updated_flows dict(list[tensor]): Each component is a list of updated
+                optical flows with shape (n, clip_size, 2, h, w).
+
+        Return:
+            dict(list[tensor]): A dictionary containing all the propagated
+                features. Each key in the dictionary corresponds to a
+                propagation branch, which is represented by a list of tensors.
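+
+        Note:
+            The first-order branches ('backward_1', 'forward_1') update the
+            guiding flows through guided deformable attention and cache them
+            in 'updated_flows'; the second-order branches ('backward_2',
+            'forward_2') reuse the flows cached by their first-order
+            counterparts instead of recomputing them.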
+ """ + + n, t, _, h, w = flows.size() + if 'backward' in module_name: + flow_idx = range(0, t + 1)[::-1] + clip_idx = range(0, (t + 1) // self.clip_size)[::-1] + else: + flow_idx = range(-1, t) + clip_idx = range(0, (t + 1) // self.clip_size) + + if '_1' in module_name: + updated_flows[f'{module_name}_n1'] = [] + updated_flows[f'{module_name}_n2'] = [] + + feat_prop = torch.zeros_like(feats['shallow'][0])#.cuda() + + last_key = list(feats)[-2] + + for i in range(0, len(clip_idx)): + idx_c = clip_idx[i] + if i > 0: + if '_1' in module_name: + flow_n01 = flows[:, flow_idx[self.clip_size * i - 1], :, :, :] + flow_n12 = flows[:, flow_idx[self.clip_size * i], :, :, :] + flow_n23 = flows[:, flow_idx[self.clip_size * i + 1], :, :, :] + flow_n02 = flow_n12 + flow_warp(flow_n01, flow_n12.permute(0, 2, 3, 1)) + flow_n13 = flow_n23 + flow_warp(flow_n12, flow_n23.permute(0, 2, 3, 1)) + flow_n03 = flow_n23 + flow_warp(flow_n02, flow_n23.permute(0, 2, 3, 1)) + flow_n1 = torch.stack([flow_n02, flow_n13], 1) + flow_n2 = torch.stack([flow_n12, flow_n03], 1) + else: + module_name_old = module_name.replace('_2', '_1') + flow_n1 = updated_flows[f'{module_name_old}_n1'][i - 1] + flow_n2 = updated_flows[f'{module_name_old}_n2'][i - 1] + + + if 'backward' in module_name: + feat_q = feats[last_key][idx_c].flip(1) + feat_k = feats[last_key][clip_idx[i - 1]].flip(1) + else: + feat_q = feats[last_key][idx_c] + feat_k = feats[last_key][clip_idx[i - 1]] + + feat_prop_warped1 = flow_warp(feat_prop.flatten(0, 1), + flow_n1.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + feat_prop_warped2 = flow_warp(feat_prop.flip(1).flatten(0, 1), + flow_n2.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + + if '_1' in module_name: + feat_prop, flow_n1, flow_n2 = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + True) + updated_flows[f'{module_name}_n1'].append(flow_n1) + updated_flows[f'{module_name}_n2'].append(flow_n2) + else: + feat_prop = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + False) + + if 'backward' in module_name: + feat = [feats[k][idx_c].flip(1) for k in feats if k not in [module_name]] + [feat_prop] + else: + feat = [feats[k][idx_c] for k in feats if k not in [module_name]] + [feat_prop] + + #print(len(feat), feat[0].shape, feat[1].shape) + fp = self.backbone[module_name](torch.cat(feat, dim=2)) + #fp = self.backbone[module_name](torch.cat(feat, dim=1)) + feat_prop = feat_prop + fp + + feats[module_name].append(feat_prop) + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + feats[module_name] = [f.flip(1) for f in feats[module_name]] + + return feats + + def forward(self, lres, hres, *args, **kwargs): + + b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + + assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + lres = torch.cat([lres, topo_low], dim = 2) + + lres = self.normalize(lres) + hres = self.normalize(hres) + + flows_forward, flows_backward = self.compute_flow(lres) + + feats = {} + ff = self.feat_ext(lres) + + feats['shallow'] = list(torch.chunk(ff, f // 
self.clip_size, dim = 1)) + + updated_flows = {} + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + if direction == 'backward': + flows = flows_backward + else: + flows = flows_forward if flows_forward is not None else flows_backward.flip(1) + + module_name = f'{direction}_{iter_}' + feats[module_name] = [] + + feats = self.propagate(feats, flows, module_name, updated_flows) + + feats['shallow'] = torch.cat(feats['shallow'], 1) + feats['backward_1'] = torch.cat(feats['backward_1'], 1) + feats['forward_1'] = torch.cat(feats['forward_1'], 1) + feats['backward_2'] = torch.cat(feats['backward_2'], 1) + feats['forward_2'] = torch.cat(feats['forward_2'], 1) + upsampled = torch.cat([feats[k] for k in feats], dim=2) + upsampled = self.recon(upsampled) + upsampled = self.feat_up(upsampled) + upsampled = upsampled + F.interpolate(lres[:,:,0:1,:,:], size = (1, h, w), mode = 'trilinear', align_corners = False) + + loss = self.loss_fn(upsampled, hres, reduction = 'none') + loss = einops.reduce(loss, 'b ... -> b (...)', 'mean') + + return loss.mean(), upsampled + +class Trainer(object): + def __init__( + self, + diffusion_model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + #augment_horizontal_flip = True, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 1, + #num_samples = 25, + eval_folder = './evaluate', + results_folder = './results', + #tensorboard_dir = './tensorboard', + val_num_of_batch = 2, + amp = False, + fp16 = False, + #fp16 = True, + split_batches = True, + #split_batches = False, + convert_image_to = None + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("climate", + init_kwargs={ + "wandb": { + "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + #self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] + + self.model = diffusion_model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + self.image_size = diffusion_model.image_size + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = val_dl + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { 
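+            # everything needed to resume a run: step counter, raw and EMA
+            # weights, optimizer state, and the AMP grad scaler state (if any)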
+ 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + print('loaded') + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') + + # c384_min = np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') + + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. 
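+                # one optimizer step consumes `gradient_accumulate_every` micro-batches;
+                # each micro-batch loss is scaled by 1/gradient_accumulate_every before
+                # backward so the accumulated gradient matches one large batch, then
+                # gradients are clipped to unit norm ahead of the Adam/cosine-LR step.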
+ + for _ in range(self.gradient_accumulate_every): + + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + + with self.accelerator.autocast(): + + loss, _ = self.model(lres, hres) + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + self.sched.step() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + loss, videos = self.model(lres, hres) + + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + del vids, vlosses, hr, lr + + + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 99.999) - np.percentile(target1, 
99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + 
step=self.step)
+                            accelerator.log({"distemd": distemd}, step=self.step)
+                            accelerator.log({"vloss": vloss}, step=self.step)
+                            accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step)
+
+                        milestone = self.step // self.save_and_sample_every
+
+                        self.save(milestone)
+
+                pbar.update(1)
+
+        accelerator.print('training complete')
+
+    def sample(self):
+
+        accelerator = self.accelerator
+        device = accelerator.device
+
+        self.ema.ema_model.eval()
+
+        PATH = "/extra/ucibdl0/shared/data/fv3gfs"
+        XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr")
+        XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr")
+        yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr")
+        topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr")
+
+        with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f:
+
+            c48_chl = pickle.load(f)
+
+        with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f:
+
+            c48_atm_chl = pickle.load(f)
+
+        with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f:
+
+            c48_log_chl = pickle.load(f)
+
+        with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f:
+
+            c384_chl = pickle.load(f)
+
+        with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f:
+
+            c384_log_chl = pickle.load(f)
+
+        with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f:
+
+            c384_topo = pickle.load(f)
+
+        if self.multi:
+
+            c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"]
+            c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"]
+            c384_channels = ["PRATEsfc"]
+
+        else:
+
+            c48_channels = ["PRATEsfc_coarse"]
+            c384_channels = ["PRATEsfc"]
+
+        with torch.no_grad():
+
+            for tile in range(6):
+
+                if self.rollout == 'full':
+
+                    seq_len = self.rollout_batch
+                    st = 0
+                    en = seq_len + 2
+                    count = 0
+
+                    while en < 3176:
+
+                        print(tile, st)
+
+                        X = XX.isel(time = slice(st, en), tile = tile)
+                        X_ = XX_.isel(time = slice(st, en), tile = tile)
+                        y = yy.isel(time = slice(st, en), tile = tile)
+
+                        X = np.stack([X[channel].values for channel in c48_channels], axis = 1)
+                        X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1)
+                        y = np.stack([y[channel].values for channel in c384_channels], axis = 1)
+                        topo = topot.isel(tile = tile)
+                        topo = topo['zsurf'].values
+                        topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0)
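+
+                        # precipitation (channel 0) is shifted by the train-set minimum and
+                        # log-transformed before min-max scaling with the pickled log stats;
+                        # the remaining c48/atmos channels and the topography get plain
+                        # min-max scaling from their own train-set statistics.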
+
+                        X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14)
+                        y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14)
+                        X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min'])
+                        y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min'])
+
+                        for i in range(1, X.shape[1]):
+
+                            X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min'])
+
+                        for i in range(X_.shape[1]):
+
+                            X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min'])
+
+                        topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min'])
+
+                        X = np.concatenate((X, X_), axis = 1)
+                        y = np.concatenate((y, topo), axis = 1)
+
+                        lres = torch.from_numpy(X).unsqueeze(0).to(device)
+                        hres = torch.from_numpy(y).unsqueeze(0).to(device)
+
+                        loss, videos = self.model(lres, hres)
+
+                        torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count))
+                        torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count))
+                        torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count))
+                        count += 1
+
+                        st += seq_len
+                        en += seq_len
+
+                if self.rollout == 'partial':
+
+                    seq_len = self.rollout_batch
+                    #indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 75 samples per tile
+                    indices = list(range(0, 3176 - (seq_len + 2), 250)) # deterministic, 325 samples per tile for seq_len of 25
+
+                    for count, st in enumerate(indices):
+
+                        print(tile, count)
+
+                        X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile)
+                        X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile)
+                        y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile)
+
+                        X = np.stack([X[channel].values for channel in c48_channels], axis = 1)
+                        X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1)
+                        y = np.stack([y[channel].values for channel in c384_channels], axis = 1)
+                        topo = topot.isel(tile = tile)
+                        topo = topo['zsurf'].values
+                        topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0)
+
+                        X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14)
+                        y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14)
+                        X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min'])
+                        y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min'])
+
+                        for i in range(1, X.shape[1]):
+
+                            X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min'])
+
+                        for i in range(X_.shape[1]):
+
+                            X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min'])
+
+                        topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min'])
+
+                        X = np.concatenate((X, X_), axis = 1)
+                        y = np.concatenate((y, topo), axis = 1)
+
+                        lres = torch.from_numpy(X).unsqueeze(0).to(device)
+                        hres = torch.from_numpy(y).unsqueeze(0).to(device)
+
+                        loss, videos = self.model(lres, hres)
+
+                        torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count))
+                        torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count))
+                        torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count))
\ No newline at end of file
diff --git a/projects/super_res/model/isr_baseline.py b/projects/super_res/model/isr_baseline.py
new file mode 100644
index 0000000000..ed59f3c013
--- /dev/null
+++ b/projects/super_res/model/isr_baseline.py
@@ -0,0 +1,568 @@
+from pathlib import Path
+import os
+
+import numpy as np
+import xarray as xr
+
+import torch
+import wandb
+
+import piq
+import pickle
+import cv2
+from scipy.stats import wasserstein_distance
+
+from torch.optim import Adam
+import torch.nn.functional as F
+
+from random import randint
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from tqdm.auto import tqdm
+from ema_pytorch import EMA
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from matplotlib.cm import ScalarMappable as smap
+
+from
accelerate import Accelerator +from einops import rearrange, reduce + +def get_random_idx_with_difference(min_tx, max_tx, number_tx, diff): + times = [] + while len(times) < number_tx: + new_time = randint(min_tx, max_tx) + if all(abs(new_time - time) >= diff for time in times): + times.append(new_time) + return times + +def cycle(dl): + while True: + for data in dl: + yield data + +def exists(x): + return x is not None + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# trainer class + +class Trainer(object): + def __init__( + self, + model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + #augment_horizontal_flip = True, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 10, + #num_samples = 25, + eval_folder = './evaluate', + results_folder = './results', + #tensorboard_dir = './tensorboard', + val_num_of_batch = 2, + amp = False, + fp16 = False, + #fp16 = True, + split_batches = True, + #split_batches = False, + convert_image_to = None + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("climate", + init_kwargs={ + "wandb": { + "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] + + self.model = model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = val_dl + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { + 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + 
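+        # only the model weights, EMA state, step counter and (when present) the AMP
+        # scaler are restored here; the optimizer state is left commented out below,
+        # so a resumed run restarts Adam's moment estimates from scratch.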
self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. + + for _ in range(self.gradient_accumulate_every): + + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + with self.accelerator.autocast(): + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = 7) + + ures = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + loss = F.mse_loss(ures, rearrange(hres, 'b t c h w -> (b t) c h w'), reduction = 'none') + loss = reduce(loss, 'b ... 
-> b (...)', 'mean') + loss = loss.mean() + + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + num_frames = 5 + img_size = 384 + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = 7) + + ures = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + loss = F.mse_loss(ures, rearrange(hres, 'b t c h w -> (b t) c h w'), reduction = 'none') + + videos = rearrange(ures, '(b t) c h w -> b t c h w', t = 7) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + del vids, vlosses, hr, lr + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 
99.999) - np.percentile(target1, 99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, 
step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step) + + milestone = self.step // self.save_and_sample_every + + self.save(milestone) + + pbar.update(1) + + accelerator.print('training complete') + + def sample(self): + + accelerator = self.accelerator + device = accelerator.device + + self.ema.ema_model.eval() + + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] + + with torch.no_grad(): + + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = 
torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + topo = hres[:, :, 1:2, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = seq_len + 2) + + videos = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')).unsqueeze(0) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + #indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 75 samples per tile + indices = list(range(0, 3176 - (seq_len + 2), 250)) # deterministic, 325 samples per tile for seq_len of 25 + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + topo = hres[:, :, 1:2, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = seq_len + 2) + + videos = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')).unsqueeze(0) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/network_swinir.py b/projects/super_res/model/network_swinir.py index 461fb354ce..8c5f8537c0 100644 --- 
a/projects/super_res/model/network_swinir.py +++ b/projects/super_res/model/network_swinir.py @@ -643,7 +643,7 @@ class SwinIR(nn.Module): resi_connection: The convolutional block before residual connection. '1conv'/'3conv' """ - def __init__(self, img_size=64, patch_size=1, in_chans=3, + def __init__(self, img_size=64, patch_size=1, in_chans=3, out_chans=3, embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, @@ -652,7 +652,7 @@ def __init__(self, img_size=64, patch_size=1, in_chans=3, **kwargs): super(SwinIR, self).__init__() num_in_ch = in_chans - num_out_ch = in_chans + num_out_ch = out_chans num_feat = 64 self.img_range = img_range if in_chans == 3: @@ -666,6 +666,7 @@ def __init__(self, img_size=64, patch_size=1, in_chans=3, ##################################################################################################### ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) ##################################################################################################### diff --git a/projects/super_res/model/op/deform_attn.py b/projects/super_res/model/op/deform_attn.py new file mode 100644 index 0000000000..55da954230 --- /dev/null +++ b/projects/super_res/model/op/deform_attn.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from einops.layers.torch import Rearrange +from distutils.version import LooseVersion +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +deform_attn_ext = load( + 'deform_attn', + sources=[ + os.path.join(module_path, 'deform_attn_ext.cpp'), + os.path.join(module_path, 'deform_attn_cuda_pt110.cpp' if LooseVersion(torch.__version__) >= LooseVersion( + '1.10.0') else 'deform_attn_cuda_pt109.cpp'), + os.path.join(module_path, 'deform_attn_cuda_kernel.cu'), +], +) + + +class Mlp(nn.Module): + """ Multilayer perceptron. 
+ + Args: + x: (B, D, H, W, C) + + Returns: + x: (B, D, H, W, C) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class DeformAttnFunction(Function): + + @staticmethod + def forward(ctx, + q, + kv, + offset, + kernel_h, + kernel_w, + stride=1, + padding=0, + dilation=1, + attention_heads=1, + deformable_groups=1, + clip_size=1): + ctx.kernel_h = kernel_h + ctx.kernel_w = kernel_w + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.attention_heads = attention_heads + ctx.deformable_groups = deformable_groups + ctx.clip_size = clip_size + if q.requires_grad or kv.requires_grad or offset.requires_grad: + ctx.save_for_backward(q, kv, offset) + output = q.new_empty(q.shape) + ctx._bufs = [q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0)] + deform_attn_ext.deform_attn_forward(q, kv, offset, output, + ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + q, kv, offset = ctx.saved_tensors + grad_q = torch.zeros_like(q) + grad_kv = torch.zeros_like(kv) + grad_offset = torch.zeros_like(offset) + deform_attn_ext.deform_attn_backward(q, kv, offset, ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx._bufs[3], ctx._bufs[4], + grad_q, grad_kv, grad_offset, + grad_output, ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + + return (grad_q, grad_kv, grad_offset, None, None, None, None, None, None, None, None) + + +deform_attn = DeformAttnFunction.apply + + +class DeformAttn(nn.Module): + + def __init__(self, + in_channels, + out_channels, + attention_window=[3, 3], + deformable_groups=12, + attention_heads=12, + clip_size=1): + super(DeformAttn, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_h = attention_window[0] + self.kernel_w = attention_window[1] + self.attn_size = self.kernel_h * self.kernel_w + self.deformable_groups = deformable_groups + self.attention_heads = attention_heads + self.clip_size = clip_size + self.stride = 1 + self.padding = self.kernel_h//2 + self.dilation = 1 + + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2), + Rearrange('n d h w c -> n d c h w')) + + def forward(self, q, k, v, offset): + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = 
deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v + + +class DeformAttnPack(DeformAttn): + """A Deformable Attention Encapsulation that acts as normal attention layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. + clip_size (int): clip size. Default: 2. + """ + + def __init__(self, *args, **kwargs): + super(DeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels * (1 + self.clip_size), + self.clip_size * self.deformable_groups * self.attn_size * 2, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + bias=True) + self.init_weight() + + def init_weight(self): + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, q, k, v): + out = self.conv_offset(torch.cat([q.flatten(1, 2), k.flatten(1, 2)], 1)) + o1, o2 = torch.chunk(out, 2, dim=1) + offset = torch.cat((o1, o2), dim=1) + + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_kernel.cu b/projects/super_res/model/op/deform_attn_cuda_kernel.cu new file mode 100644 index 0000000000..6f1ccc2c91 --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
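+
+// all kernels below use the grid-stride CUDA_KERNEL_LOOP pattern: GET_BLOCKS clamps
+// the launch to kMaxGridNum blocks, and each thread then strides over any remaining
+// work items in steps of blockDim.x * gridDim.x.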
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                               const int height, const int width, scalar_t h, scalar_t w)
+{
+
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                        const int h, const int w, const int height, const int width)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                          const int height, const int width, const scalar_t *im_data,
+                                          const int data_width, const int bp_dir)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
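+
+// deformable_im2col: each thread fills one entry of the column matrix, laid out as
+// (channel * kernel_h * kernel_w, batch, h_col, w_col); the sample is taken at an
+// offset-shifted fractional position via the bilinear helper above, and positions
+// outside the image contribute zero.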
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+                                             const int height, const int width, const int kernel_h, const int kernel_w,
+                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+                                             const int batch_size, const int num_channels, const int deformable_group,
+                                             const int height_col, const int width_col,
+                                             scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index index of output matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const scalar_t map_h = i * dilation_h + offset_h;
+          //const scalar_t map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val;
+        data_col_ptr += batch_size * height_col * width_col;
+      }
+    }
+  }
+}
+
+void deformable_im2col(
+    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h, const int ksize_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  // todo: check parallel_imgs is correctly passed in
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+            height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+  }
+}
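+
+// deformable_col2im: backward pass w.r.t. the input image; each column gradient is
+// scattered back (atomicAdd, since fractional windows overlap) onto the integer
+// pixels around its sampling point, weighted by get_gradient_weight.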
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+    const int n, const scalar_t *data_col, const scalar_t *data_offset,
+    const int channels, const int height, const int width,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int channel_per_deformable_group,
+    const int batch_size, const int deformable_group,
+    const int height_col, const int width_col,
+    scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+                                                        2 * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index];
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+void deformable_col2im(
+    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group,
+    at::Tensor grad_im)
+{
+
+  // todo: make sure parallel_imgs is passed in correctly
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+            ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+  }
+}
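+
+// deformable_col2im_coord: backward pass w.r.t. the predicted offsets; every offset
+// channel accumulates the column gradients weighted by the analytic derivative of
+// the bilinear kernel (get_coordinate_weight), and out-of-range samples are pinned
+// to (-2, -2) so they contribute nothing.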
(index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, 
width_col, grad_offset_);
+      }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                         const int height, const int width, scalar_t h, scalar_t w)
+{
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                             const int h, const int w, const int height, const int width)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                               const int height, const int width, const scalar_t *im_data,
+                                               const int data_width, const int bp_dir)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight +=
(argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + 
const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 
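
The modulated (DCNv2-style) kernels differ from the plain ones only in the per-tap mask: the sampled value is scaled by the mask on the way in, the incoming gradient is scaled by the same mask on the way out, and the mask itself receives the sampled value times the column gradient (the mval accumulation in the coord kernel below). Schematically, reusing the bilinear_gather sketch from above (illustrative, not the kernel code):

    # forward (modulated im2col): one column entry per tap
    col = bilinear_gather(im, h_im, w_im) * mask
    # backward to the image (modulated col2im): mask scales the gradient
    top_grad = grad_col * mask
    # backward to the mask (the coord kernel's mval term)
    grad_mask = grad_col * bilinear_gather(im, h_im, w_im)
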
batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = 
data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != 
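
All of these launchers size their grids from num_kernels, which in turn depends on the standard convolution output size (the height_col/width_col computed by the callers; the attention wrappers below pass height_col == height, i.e. a size-preserving configuration). A quick worked check in Python, using the same formula that appears in deformable_im2col above:

    def conv_out_size(size, pad, dilation, k, stride):
        # (size + 2*pad - (dilation*(k-1) + 1)) // stride + 1
        return (size + 2 * pad - (dilation * (k - 1) + 1)) // stride + 1

    # e.g. a 3x3 kernel with pad=1, stride=1, dilation=1 preserves
    # the spatial size:
    assert conv_out_size(384, 1, 1, 3, 1) == 384
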
cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_pt109.cpp b/projects/super_res/model/op/deform_attn_cuda_pt109.cpp new file mode 100644 index 0000000000..46ef081a8f --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_pt109.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, 
const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = 
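
Stripped of the deformable sampling, the per-batch computation in deform_attn_cuda_forward is plain scaled dot-product attention over the clip_size * attn_size sampled key/value positions. A PyTorch sketch of the same contraction with the shapes from the comments (sampling elided; all sizes here are illustrative, not taken from the patch):

    import torch

    attn_head, area, attn_dim = 6, 48 * 48, 24
    n_keys = 2 * 3 * 3                       # clip_size * attn_size
    scale = attn_dim ** -0.5

    q = torch.randn(attn_head, area, 1, attn_dim) * scale
    k = torch.randn(attn_head, area, attn_dim, n_keys)   # columns[0]
    v = torch.randn(attn_head, area, attn_dim, n_keys)   # columns[1]

    attns = torch.matmul(q, k).softmax(-1)           # head x area x 1 x n_keys
    out = torch.matmul(attns, v.transpose(2, 3))     # head x area x 1 x attn_dim
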
kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); +// // for PyTorch 1.10.1 +// const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// for PyTorch 1.9.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 +// // for PyTorch 1.10.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. 
sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. kv + modulated_deformable_col2im_cuda( + columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated + } + } + + // resize gradidents back + grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width + grad_offset = grad_offset.flatten(1, 2); + grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width}); +} \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_pt110.cpp b/projects/super_res/model/op/deform_attn_cuda_pt110.cpp new file mode 100644 index 0000000000..0dd7816d80 --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_pt110.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, 
const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); 
// (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + // for PyTorch 1.10.1 + const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. 
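
Recomputing columns and attns here trades FLOPs for memory: the forward pass does not keep them, so the backward pass rebuilds both before differentiating through the softmax with at::_softmax_backward_data. That call implements the usual softmax Jacobian-vector product; as a reference (written at the PyTorch level, not the extension's code), the identity is:

    import torch

    def softmax_backward(g, s):
        # Given s = softmax(z) and upstream gradient g, returns dL/dz:
        # dz = (g - sum(g * s, dim=-1, keepdim=True)) * s
        return (g - (g * s).sum(dim=-1, keepdim=True)) * s

    z = torch.randn(4, 8, requires_grad=True)
    s = z.softmax(-1)
    g = torch.randn_like(s)
    s.backward(g)
    assert torch.allclose(z.grad, softmax_backward(g, s.detach()), atol=1e-6)
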
+ columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// // for PyTorch 1.9.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 + // for PyTorch 1.10.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. 
kv
+      modulated_deformable_col2im_cuda(
+          columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height,
+          width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+          dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated
+    }
+  }
+
+  // resize gradients back
+  grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width
+  grad_offset = grad_offset.flatten(1, 2);
+  grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width});
+}
\ No newline at end of file
diff --git a/projects/super_res/model/op/deform_attn_ext.cpp b/projects/super_res/model/op/deform_attn_ext.cpp
new file mode 100644
index 0000000000..a09d85851a
--- /dev/null
+++ b/projects/super_res/model/op/deform_attn_ext.cpp
@@ -0,0 +1,75 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include
+#include
+
+#include
+#include
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+
+void deform_attn_cuda_forward(
+    at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output,
+    at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int attn_head, const int deform_group, const int clip_size
+    );
+
+void deform_attn_cuda_backward(
+    at::Tensor q, at::Tensor kv, at::Tensor offset,
+    at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv,
+    at::Tensor grad_offset, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size
+    );
+#endif
+
+void deform_attn_forward(
+    at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output,
+    at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int attn_head, const int deform_group, const int clip_size
+    ) {
+  if (q.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_attn_cuda_forward(q, kv,
+        offset, output, columns, attns, mask_ones, kernel_h, kernel_w, stride_h,
+        stride_w, pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size);
+#else
+    AT_ERROR("modulated deform attn is not compiled with GPU support");
+#endif
+  }
+  AT_ERROR("modulated deform attn is not implemented on CPU");
+}
+
+void deform_attn_backward(
+    at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor columns,
+    at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv,
+    at::Tensor grad_offset, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size
+    ) {
+  if (q.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return deform_attn_cuda_backward(q, kv,
+        offset, columns, attns, mask_ones, grad_attns, grad_mask_ones, grad_q, grad_kv, grad_offset,
+        grad_output, kernel_h, kernel_w, stride_h, stride_w,
+        pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size);
+#else
+    AT_ERROR("modulated deform attn is not compiled 
with GPU support"); +#endif + } + AT_ERROR("modulated deform attn is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_attn_forward", + &deform_attn_forward, + "deform attn forward"); + m.def("deform_attn_backward", + &deform_attn_backward, + "deform attn backward"); +} \ No newline at end of file diff --git a/projects/super_res/sampler.py b/projects/super_res/sampler.py index 874a9e4fcb..69b3185d5a 100644 --- a/projects/super_res/sampler.py +++ b/projects/super_res/sampler.py @@ -1,63 +1,83 @@ import os -from model.autoreg_diffusion import Unet, Flow, GaussianDiffusion, Trainer -from data.load_data import load_data +from model.autoreg_diffusion_mod import Unet, Flow, GaussianDiffusion, Trainer from config_infer import config -model = Unet( - dim = config.dim, - channels = 2 * config.data_config["img_channel"], - out_dim = config.data_config["img_channel"], - dim_mults = config.dim_mults, - learned_sinusoidal_cond = config.learned_sinusoidal_cond, - random_fourier_features = config.random_fourier_features, - learned_sinusoidal_dim = config.learned_sinusoidal_dim -).cuda() - -flow = Flow( - dim = config.dim, - channels = 3 * config.data_config["img_channel"], - out_dim = 3, - dim_mults = config.dim_mults -).cuda() - -diffusion = GaussianDiffusion( - model, - flow, - image_size = config.data_config["img_size"], - timesteps = config.diffusion_steps, - sampling_timesteps = config.sampling_steps, - loss_type = config.loss, - objective = config.objective -).cuda() - -train_dl, val_dl = load_data( - config.data_config, - config.batch_size, - pin_memory = True, - num_workers = 4, +def main(): + + if config.data_config["multi"]: + + # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + + else: + + in_ch_model = 2 * config.data_config["img_channel"] + in_ch_flow = 3 * config.data_config["img_channel"] + in_ch_isr = config.data_config["img_channel"] + + if config.data_config["flow"] == "3d": + + out_ch_flow = 3 + + elif config.data_config["flow"] == "2d": + + out_ch_flow = 2 + + model = Unet( + dim = config.dim, + channels = in_ch_model, + out_dim = config.data_config["img_channel"], + dim_mults = config.dim_mults, + learned_sinusoidal_cond = config.learned_sinusoidal_cond, + random_fourier_features = config.random_fourier_features, + learned_sinusoidal_dim = config.learned_sinusoidal_dim + ).cuda() + + flow = Flow( + dim = config.dim, + channels = in_ch_flow, + out_dim = out_ch_flow, + dim_mults = config.dim_mults + ).cuda() + + diffusion = GaussianDiffusion( + model, + flow, + image_size = config.data_config["img_size"], + in_ch = in_ch_isr, + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = 
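
For the multi configuration, the channel arithmetic in main() works out as below. Note that the inline comments in the patch still describe the earlier 4-channel variant left commented out above them; the active code adds 10 conditioning channels plus one more (the exact breakdown of the extras into multi-variable inputs, topography, and noise is taken from those comments and should be treated as an assumption). A worked check:

    img_channel = 1                           # ["PRATEsfc_coarse"]

    in_ch_model = 2 * img_channel + 10 + 1    # 13
    in_ch_flow  = 3 * (img_channel + 10 + 1)  # 36
    in_ch_isr   = img_channel + 10 + 1        # 12

    # Cross-check: in_ch_isr matches SwinIR(in_chans=12) in the ISR
    # baseline scripts (trainer_isr.py / sampler_isr.py).
    assert in_ch_isr == 12
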
config.loss, + objective = config.objective + ).cuda() + + trainer = Trainer( + diffusion, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.data_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config ) -trainer = Trainer( - diffusion, - train_dl, - val_dl, - train_batch_size = config.batch_size, - train_lr = config.lr, - train_num_steps = config.steps, - gradient_accumulate_every = config.grad_acc, - val_num_of_batch = config.val_num_of_batch, - save_and_sample_every = config.save_and_sample_every, - ema_decay = config.ema_decay, - amp = config.amp, - split_batches = config.split_batches, - #eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), - eval_folder = os.path.join(config.eval_folder, f"{config.data_name}/"), - results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), - config = config - #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), -) - -trainer.load(config.milestone) - -trainer.sample() \ No newline at end of file + trainer.load(config.milestone) + + trainer.sample() + +if __name__ == "__main__": + print(config) + main() \ No newline at end of file diff --git a/projects/super_res/sampler_isr.py b/projects/super_res/sampler_isr.py new file mode 100644 index 0000000000..20c2a71992 --- /dev/null +++ b/projects/super_res/sampler_isr.py @@ -0,0 +1,38 @@ +import os + +from model.isr_baseline import Trainer +from model.network_swinir import SwinIR +from config_isr_infer import config + +def main(): + model = SwinIR(upscale=8, in_chans=12, out_chans=1, img_size=48, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, + num_heads=[8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv').cuda() + + trainer = Trainer( + model, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), + ) + + trainer.load(config.milestone) + + trainer.sample() + +if __name__ == "__main__": + print(config) + main() \ No newline at end of file diff --git a/projects/super_res/sampler_rvrt_full.py b/projects/super_res/sampler_rvrt_full.py new file mode 100644 index 0000000000..f5ec3f2a4c --- /dev/null +++ b/projects/super_res/sampler_rvrt_full.py @@ -0,0 +1,103 @@ +import os +from torch import nn +from model.denoising_diffusion_rvrt_full import RSTBWithInputConv, Upsample, GuidedDeformAttnPack, GaussianDiffusion, SpyNet, Trainer +from config_rvrt_full_infer import config + +recon = RSTBWithInputConv( + in_channels = 5 * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + 
input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_ext = RSTBWithInputConv( + in_channels = config.data_config["img_channel"]+11, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_up = Upsample( + scale = 8, + num_feat = config.dim, + in_channels = config.data_config["img_channel"] +).cuda() + +spynet = SpyNet('./spynet').cuda() + +backbone = nn.ModuleDict() +deform_align = nn.ModuleDict()\ + +modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + +for i, module in enumerate(modules): + # deformable attention + deform_align[module] = GuidedDeformAttnPack(config.dim, + config.dim, + attention_window=[3, 3], + attention_heads=6, + deformable_groups=6, + clip_size=2, + max_residue_magnitude=10).cuda() + + # feature propagation + backbone[module] = RSTBWithInputConv( + in_channels = (2 + i) * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 2, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (2,8,8) + ).cuda() + +diffusion = GaussianDiffusion( + feat_ext = feat_ext, + feat_up = feat_up, + backbone = backbone, + deform_align = deform_align, + recon = recon, + spynet = spynet, + image_size = config.data_config["img_size"], + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective +).cuda() + +trainer = Trainer( + diffusion, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config +) + +trainer.load(config.milestone) + +trainer.sample() \ No newline at end of file diff --git a/projects/super_res/trainer.py b/projects/super_res/trainer.py index 617c257d95..6a02be3e35 100755 --- a/projects/super_res/trainer.py +++ b/projects/super_res/trainer.py @@ -5,9 +5,33 @@ from config import config def main(): + + if config.data_config["multi"]: + + # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + + else: + + in_ch_model = 2 * config.data_config["img_channel"] + in_ch_flow = 3 * 
diff --git a/projects/super_res/trainer.py b/projects/super_res/trainer.py
index 617c257d95..6a02be3e35 100755
--- a/projects/super_res/trainer.py
+++ b/projects/super_res/trainer.py
@@ -5,9 +5,33 @@
 from config import config
 
 def main():
+
+    if config.data_config["multi"]:
+
+        # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise
+        # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo)
+        # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo
+        in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 10 + 1) + 1 : (precip + multi + topo) + noise
+        in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 10 + 1) : 3 * (precip + multi + topo)
+        in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 10 + 1 : precip + multi + topo
+
+    else:
+
+        in_ch_model = 2 * config.data_config["img_channel"]
+        in_ch_flow = 3 * config.data_config["img_channel"]
+        in_ch_isr = config.data_config["img_channel"]
+
+    if config.data_config["flow"] == "3d":
+
+        out_ch_flow = 3
+
+    elif config.data_config["flow"] == "2d":
+
+        out_ch_flow = 2
+
     model = Unet(
         dim = config.dim,
-        channels = 2 * config.data_config["img_channel"],
+        channels = in_ch_model,
         out_dim = config.data_config["img_channel"],
         dim_mults = config.dim_mults,
         learned_sinusoidal_cond = config.learned_sinusoidal_cond,
@@ -17,15 +41,16 @@ def main():
 
     flow = Flow(
         dim = config.dim,
-        channels = 3 * config.data_config["img_channel"],
-        out_dim = 3,
+        channels = in_ch_flow,
+        out_dim = out_ch_flow,
         dim_mults = config.dim_mults
     ).cuda()
-    
+
     diffusion = GaussianDiffusion(
         model,
         flow,
         image_size = config.data_config["img_size"],
+        in_ch = in_ch_isr,
         timesteps = config.diffusion_steps,
         sampling_timesteps = config.sampling_steps,
         loss_type = config.loss,
@@ -55,7 +80,6 @@ def main():
         eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"),
         results_folder = os.path.join(config.results_folder, f"{config.model_name}/"),
         config = config
-        #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"),
     )
 
     trainer.train()
@@ -63,4 +87,4 @@ def main():
 
 if __name__ == "__main__":
     print(config)
-    main()
+    main()
\ No newline at end of file
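
The channel bookkeeping added to trainer.py (and duplicated verbatim in trainer_focal.py below) can be read as one small function. This is an illustrative restatement of the arithmetic shown in the diff; `channel_counts` itself is hypothetical, not repo code:

```python
# Illustrative restatement of the channel arithmetic in trainer.py /
# trainer_focal.py; the function is hypothetical, not part of the repo.
def channel_counts(data_config):
    """Return (in_ch_model, in_ch_flow, in_ch_isr, out_ch_flow)."""
    img_ch = data_config["img_channel"]
    # conditioning stack per frame: precip + 10 "multi" fields + topography
    base = img_ch + 10 + 1 if data_config["multi"] else img_ch
    in_ch_model = base + img_ch      # conditioning stack plus the noised target
    in_ch_flow = 3 * base            # current low res + two past high-res frames
    in_ch_isr = base
    out_ch_flow = 3 if data_config["flow"] == "3d" else 2
    return in_ch_model, in_ch_flow, in_ch_isr, out_ch_flow

assert channel_counts({"img_channel": 1, "multi": False, "flow": "2d"}) == (2, 3, 1, 2)
assert channel_counts({"img_channel": 1, "multi": True, "flow": "2d"}) == (13, 36, 12, 2)
```
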
diff --git a/projects/super_res/trainer_focal.py b/projects/super_res/trainer_focal.py
new file mode 100755
index 0000000000..82cb803442
--- /dev/null
+++ b/projects/super_res/trainer_focal.py
@@ -0,0 +1,90 @@
+import os
+
+from model.autoreg_diffusion_mod_focal import Unet, Flow, GaussianDiffusion, Trainer
+from data.load_data import load_data
+from config_focal import config
+
+def main():
+
+    if config.data_config["multi"]:
+
+        # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise
+        # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo)
+        # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo
+        in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 10 + 1) + 1 : (precip + multi + topo) + noise
+        in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 10 + 1) : 3 * (precip + multi + topo)
+        in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 10 + 1 : precip + multi + topo
+
+    else:
+
+        in_ch_model = 2 * config.data_config["img_channel"]
+        in_ch_flow = 3 * config.data_config["img_channel"]
+        in_ch_isr = config.data_config["img_channel"]
+
+    if config.data_config["flow"] == "3d":
+
+        out_ch_flow = 3
+
+    elif config.data_config["flow"] == "2d":
+
+        out_ch_flow = 2
+
+    model = Unet(
+        dim = config.dim,
+        channels = in_ch_model,
+        out_dim = config.data_config["img_channel"],
+        dim_mults = config.dim_mults,
+        learned_sinusoidal_cond = config.learned_sinusoidal_cond,
+        random_fourier_features = config.random_fourier_features,
+        learned_sinusoidal_dim = config.learned_sinusoidal_dim
+    ).cuda()
+
+    flow = Flow(
+        dim = config.dim,
+        channels = in_ch_flow,
+        out_dim = out_ch_flow,
+        dim_mults = config.dim_mults
+    ).cuda()
+
+    diffusion = GaussianDiffusion(
+        model,
+        flow,
+        image_size = config.data_config["img_size"],
+        in_ch = in_ch_isr,
+        timesteps = config.diffusion_steps,
+        sampling_timesteps = config.sampling_steps,
+        loss_type = config.loss,
+        objective = config.objective
+    ).cuda()
+
+    train_dl, val_dl = load_data(
+        config.data_config,
+        config.batch_size,
+        pin_memory = True,
+        num_workers = 4,
+    )
+
+    trainer = Trainer(
+        diffusion,
+        train_dl,
+        val_dl,
+        train_batch_size = config.batch_size,
+        train_lr = config.lr,
+        train_num_steps = config.steps,
+        gradient_accumulate_every = config.grad_acc,
+        val_num_of_batch = config.val_num_of_batch,
+        save_and_sample_every = config.save_and_sample_every,
+        ema_decay = config.ema_decay,
+        amp = config.amp,
+        split_batches = config.split_batches,
+        eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"),
+        results_folder = os.path.join(config.results_folder, f"{config.model_name}/"),
+        config = config
+    )
+
+    trainer.train()
+
+
+if __name__ == "__main__":
+    print(config)
+    main()
\ No newline at end of file
diff --git a/projects/super_res/trainer_isr.py b/projects/super_res/trainer_isr.py
new file mode 100644
index 0000000000..18afce64bd
--- /dev/null
+++ b/projects/super_res/trainer_isr.py
@@ -0,0 +1,45 @@
+import os
+
+from model.isr_baseline import Trainer
+from model.network_swinir import SwinIR
+from data.load_data import load_data
+from config_isr import config
+
+def main():
+    model = SwinIR(upscale=8, in_chans=12, out_chans=1, img_size=48, window_size=8,
+                   img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200,
+                   num_heads=[8, 8, 8, 8, 8, 8, 8],
+                   mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv').cuda()
+
+    train_dl, val_dl = load_data(
+        config.data_config,
+        config.batch_size,
+        pin_memory = True,
+        num_workers = 4,
+    )
+
+    trainer = Trainer(
+        model,
+        train_dl,
+        val_dl,
+        train_batch_size = config.batch_size,
+        train_lr = config.lr,
+        train_num_steps = config.steps,
+        gradient_accumulate_every = config.grad_acc,
+        val_num_of_batch = config.val_num_of_batch,
+        save_and_sample_every = config.save_and_sample_every,
+        ema_decay = config.ema_decay,
+        amp = config.amp,
+        split_batches = config.split_batches,
+        eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"),
+        results_folder = os.path.join(config.results_folder, f"{config.model_name}/"),
+        config = config
+        #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"),
+    )
+
+    trainer.train()
+
+
+if __name__ == "__main__":
+    print(config)
+    main()
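
Stepping back to trainer_focal.py above: it differs from trainer.py only in its imports and config, so the actual loss switch must happen inside the diffusion model, keyed on `loss_type = config.loss`. A sketch of what such a dispatch might look like; the l1/l2 branches follow the usual denoising-diffusion convention, while the "focal" branch is a guess at what model/autoreg_diffusion_mod_focal implements (an error-weighted MSE) and is purely hypothetical:

```python
import torch.nn.functional as F

def loss_fn(pred, target, loss_type="l2", gamma=1.0):
    # "l1" and "l2" follow the usual denoising-diffusion convention; the
    # "focal" branch is a hypothetical reading of autoreg_diffusion_mod_focal,
    # which this diff does not show.
    if loss_type == "l1":
        return F.l1_loss(pred, target)
    if loss_type == "l2":
        return F.mse_loss(pred, target)
    if loss_type == "focal":
        err = (pred - target) ** 2
        # up-weight pixels with large error relative to the batch mean
        weight = (err.detach() / (err.detach().mean() + 1e-8)) ** gamma
        return (weight * err).mean()
    raise ValueError(f"unknown loss_type: {loss_type}")
```
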
diff --git a/projects/super_res/trainer_rvrt_full.py b/projects/super_res/trainer_rvrt_full.py
new file mode 100644
index 0000000000..8f626eb9aa
--- /dev/null
+++ b/projects/super_res/trainer_rvrt_full.py
@@ -0,0 +1,109 @@
+import os
+from torch import nn
+from model.denoising_diffusion_rvrt_full import RSTBWithInputConv, Upsample, GuidedDeformAttnPack, GaussianDiffusion, SpyNet, Trainer
+from data.load_data import load_data
+from config_rvrt_full import config
+
+recon = RSTBWithInputConv(
+    in_channels = 5 * config.dim,
+    kernel_size = (1, 3, 3),
+    stride = 1,
+    groups = 1,
+    num_blocks = 1,
+    dim = config.dim,
+    input_resolution = config.data_config["img_size"],
+    num_heads = 6,
+    depth = 2,
+    window_size = (1,8,8)
+).cuda()
+
+feat_ext = RSTBWithInputConv(
+    in_channels = config.data_config["img_channel"]+11,
+    kernel_size = (1, 3, 3),
+    stride = 1,
+    groups = 1,
+    num_blocks = 1,
+    dim = config.dim,
+    input_resolution = config.data_config["img_size"],
+    num_heads = 6,
+    depth = 2,
+    window_size = (1,8,8)
+).cuda()
+
+feat_up = Upsample(
+    scale = 8,
+    num_feat = config.dim,
+    in_channels = config.data_config["img_channel"]
+).cuda()
+
+spynet = SpyNet('./spynet').cuda()
+
+backbone = nn.ModuleDict()
+deform_align = nn.ModuleDict()
+
+modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+
+for i, module in enumerate(modules):
+    # deformable attention
+    deform_align[module] = GuidedDeformAttnPack(config.dim,
+                                                config.dim,
+                                                attention_window=[3, 3],
+                                                attention_heads=6,
+                                                deformable_groups=6,
+                                                clip_size=2,
+                                                max_residue_magnitude=10).cuda()
+
+    # feature propagation
+    backbone[module] = RSTBWithInputConv(
+        in_channels = (2 + i) * config.dim,
+        kernel_size = (1, 3, 3),
+        stride = 1,
+        groups = 1,
+        num_blocks = 2,
+        dim = config.dim,
+        input_resolution = config.data_config["img_size"],
+        num_heads = 6,
+        depth = 2,
+        window_size = (2,8,8)
+    ).cuda()
+
+diffusion = GaussianDiffusion(
+    feat_ext = feat_ext,
+    feat_up = feat_up,
+    backbone = backbone,
+    deform_align = deform_align,
+    recon = recon,
+    spynet = spynet,
+    image_size = config.data_config["img_size"],
+    timesteps = config.diffusion_steps,
+    sampling_timesteps = config.sampling_steps,
+    loss_type = config.loss,
+    objective = config.objective
+).cuda()
+
+train_dl, val_dl = load_data(
+    config.data_config,
+    config.batch_size,
+    pin_memory = True,
+    num_workers = 2,
+    )
+
+trainer = Trainer(
+    diffusion,
+    train_dl,
+    val_dl,
+    train_batch_size = config.batch_size,
+    train_lr = config.lr,
+    train_num_steps = config.steps,
+    gradient_accumulate_every = config.grad_acc,
+    val_num_of_batch = config.val_num_of_batch,
+    save_and_sample_every = config.save_and_sample_every,
+    ema_decay = config.ema_decay,
+    amp = config.amp,
+    split_batches = config.split_batches,
+    eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"),
+    results_folder = os.path.join(config.results_folder, f"{config.model_name}/"),
+    config = config
+)
+
+trainer.train()
\ No newline at end of file
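
Finally, how the trainer/sampler pairs fit together: each trainer checkpoints every `save_and_sample_every` steps, and the matching `*_infer` config points its sampler at one of those checkpoints via `trainer.load(config.milestone)`. The sketch below assumes the common convention that milestone k is written at step k * save_and_sample_every; the Trainer class itself is not shown in this diff, so treat the arithmetic as an assumption.

```python
# Hypothetical sketch of the checkpoint arithmetic, assuming milestone k is
# saved at step k * save_and_sample_every (Trainer is not shown in this diff).
def checkpoint_step(milestone, save_and_sample_every):
    return milestone * save_and_sample_every

# e.g. loading milestone 2 with a 20000-step save interval resumes from the
# weights written at step 40000.
assert checkpoint_step(2, 20000) == 40000
```
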