Skip to content

Commit

Permalink
Submitted version
Browse files Browse the repository at this point in the history
  • Loading branch information
tobifinn committed Sep 21, 2023
1 parent 24c0004 commit fe49fbf
Show file tree
Hide file tree
Showing 23 changed files with 132 additions and 142 deletions.
49 changes: 30 additions & 19 deletions assimilate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main_assimilate(cfg: DictConfig):
curr_state = integrator.integrate(curr_state)
if (burn_time % cfg.obs_every) == 0:
# Assimilate
curr_state, curr_ens = assimilation.assimilate(
curr_state, _, curr_ens = assimilation.assimilate(
curr_state, obs[:, [burn_time]]
)
# Estimate statistics
Expand Down Expand Up @@ -142,10 +142,12 @@ def main_assimilate(cfg: DictConfig):
# Estimate statistics
total_steps = cfg.obs_every * cfg.n_cycles
n_stat_steps = 0
mse = torch.zeros(cfg.obs_every+1, 3)
spread = torch.zeros(cfg.obs_every+1, 3)
mse = torch.zeros(curr_state.size(0), cfg.obs_every+1, 3)
spread = torch.zeros(curr_state.size(0), cfg.obs_every+1, 3)
ana_mse = 0
ana_spread = 0
bg_mse = 0
bg_spread = 0
cov_ana = torch.zeros(3, 3)
cov_bg = torch.zeros(3, 3)
curr_traj = [curr_state.clone()]
Expand All @@ -164,12 +166,12 @@ def main_assimilate(cfg: DictConfig):
# Update MSE and spread
curr_mse = (
curr_traj.mean(dim=-2)-truth[:, t-cfg.obs_every:t+1]
).pow(2).mean(dim=0)
).pow(2)
mse = mse * old_weight + curr_mse / n_stat_steps

if cfg.n_ens > 1:
spread = spread * old_weight \
+ curr_traj.var(dim=-2).mean(dim=0) / n_stat_steps
+ curr_traj.var(dim=-2) / n_stat_steps

# Update bg cov
bg_mean = curr_traj[:, -1].mean(dim=-2, keepdims=True)
Expand All @@ -179,10 +181,16 @@ def main_assimilate(cfg: DictConfig):
cov_bg = cov_bg * old_weight + curr_cov / n_stat_steps

# Assimilate
analysis, ana_ens = assimilation.assimilate(
analysis, bg_ens, ana_ens = assimilation.assimilate(
curr_traj[:, -1], obs[:, [t]]
)

# Update background scores
curr_bg_mse = (bg_ens.mean(dim=-2)-truth[:, t]).pow(2).mean(dim=0)
bg_mse = bg_mse * old_weight + curr_bg_mse / n_stat_steps
curr_bg_spread = bg_ens.var(dim=-2).mean(dim=0)
bg_spread = bg_spread * old_weight + curr_bg_spread / n_stat_steps

# Update analysis scores
curr_ana_mse = (analysis.mean(dim=-2)-truth[:, t]).pow(2).mean(dim=0)
ana_mse = ana_mse * old_weight + curr_ana_mse / n_stat_steps
Expand All @@ -203,12 +211,16 @@ def main_assimilate(cfg: DictConfig):
curr_nspread = (
curr_ana_spread/clim_scaling.pow(2)
).mean().sqrt().item()
bg_nrmse = (bg_mse/clim_scaling.pow(2)).mean().sqrt().item()
bg_nspread = (bg_spread/clim_scaling.pow(2)).mean().sqrt().item()
ana_nrmse = (ana_mse/clim_scaling.pow(2)).mean().sqrt().item()
ana_nspread = (ana_spread/clim_scaling.pow(2)).mean().sqrt().item()

wandb.log({
"assim/curr_rmse": curr_nrmse,
"assim/curr_spread": curr_nspread,
"assim/bg_rmse": bg_nrmse,
"assim/bg_spread": bg_nspread,
"assim/ana_rmse": ana_nrmse,
"assim/ana_spread": ana_nspread,
},)
Expand All @@ -220,25 +232,24 @@ def main_assimilate(cfg: DictConfig):
)

wandb.define_metric("lead_time")
wandb.define_metric("assim/mse*", step_metric="lead_time")
wandb.define_metric("assim/spread*", step_metric="lead_time")
all_scores = zip(mse, spread)
wandb.define_metric("assim/rmse_mean", step_metric="lead_time")
wandb.define_metric("assim/rmse_std", step_metric="lead_time")
wandb.define_metric("assim/spread_mean", step_metric="lead_time")
rmse = (mse / clim_scaling.pow(2)).mean(dim=-1).sqrt()
spread = (spread / clim_scaling.pow(2)).mean(dim=-1).sqrt()
rmse_mean = rmse.mean(dim=0)
rmse_std = rmse.std(dim=0)
spread_mean = spread.mean(dim=0)
all_scores = zip(rmse_mean, rmse_std, spread_mean)
for ld, scores in enumerate(all_scores):
score_dict = {
"assim/mse_norm": (scores[0]/clim_scaling.pow(2)).mean().sqrt(),
"assim/mse_x": scores[0][0],
"assim/mse_y": scores[0][1],
"assim/mse_z": scores[0][2],
"assim/spread_norm": (scores[1]/clim_scaling.pow(2)).mean().sqrt(),
"assim/spread_x": scores[1][0],
"assim/spread_y": scores[1][1],
"assim/spread_z": scores[1][2],
"assim/rmse_mean": scores[0],
"assim/rmse_std": scores[1],
"assim/spread_mean": scores[2],
"lead_time": ld
}
wandb.log(score_dict)

wandb.log({"assim/ana_mse": ana_mse, "assim/ana_spread": ana_spread})

wandb.run.summary["cov_bg"] = cov_bg
wandb.run.summary["cov_ana"] = cov_ana
wandb.finish()
Expand Down
12 changes: 0 additions & 12 deletions configs/assimilation/etkf_ddim.yaml

This file was deleted.

6 changes: 6 additions & 0 deletions configs/experiment/assimilation/ddim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ assimilation:
n_steps: 13
sampler:
timesteps: 100

logger:
tags:
- assimilation
- ddim
- scores
20 changes: 20 additions & 0 deletions configs/experiment/assimilation/ddim_10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# @package _global_

defaults:
- override /assimilation: enoi_ddim

exp_name: ddim
n_ens: 1

assimilation:
sampler:
n_ens: 10
n_steps: 15
sampler:
timesteps: 100

logger:
tags:
- assimilation
- ddim
- scores
18 changes: 0 additions & 18 deletions configs/experiment/assimilation/ddim_noise.yaml

This file was deleted.

8 changes: 7 additions & 1 deletion configs/experiment/assimilation/ddim_scaling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ hydra:
mode: "MULTIRUN"
sweeper:
params:
+combined: "{dt:1,steps:3},{dt:5,steps:7},{dt:10,steps:12},{dt:15,steps:17},{dt:20,steps:23},{dt:25,steps:31}"
+combined: "{dt:1,steps:3},{dt:5,steps:7},{dt:10,steps:13},{dt:15,steps:17},{dt:20,steps:23},{dt:25,steps:31}"

exp_name: ddim_${combined.dt}
obs_every: ${combined.dt}
n_ens: 1
assimilation:
sampler:
n_steps: ${combined.steps}

logger:
tags:
- assimilation
- ddim
- scaling
7 changes: 6 additions & 1 deletion configs/experiment/assimilation/ddim_spread.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ hydra:
assimilation.sampler.n_steps: range(5, 101, 5)

exp_name: ddim_${assimilation.sampler.n_steps}
output_path: "data/predictions/assimilation_ddim/${exp_name}.pt"

n_ens: 1

logger:
tags:
- assimilation
- ddim
- spread
18 changes: 0 additions & 18 deletions configs/experiment/assimilation/enoi_cov_noise.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions configs/experiment/assimilation/enoi_diag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ n_ens: 1
assimilation:
sampler:
cov_inf: 2.22

logger:
tags:
- assimilation
- enoi
- diag
- scores
18 changes: 0 additions & 18 deletions configs/experiment/assimilation/enoi_diag_noise.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions configs/experiment/assimilation/enoi_diag_scaling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ n_ens: 1
assimilation:
sampler:
cov_inf: ${combined.factor}

logger:
tags:
- assimilation
- enoi
- diag
- scaling
7 changes: 7 additions & 0 deletions configs/experiment/assimilation/enoi_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ n_ens: 1
assimilation:
sampler:
cov_inf: 1.82

logger:
tags:
- assimilation
- enoi
- full
- scores
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- override /assimilation: enoi_cov
- override /assimilation: enoi_full

hydra:
mode: "MULTIRUN"
Expand All @@ -15,3 +15,10 @@ n_ens: 1
assimilation:
sampler:
cov_inf: ${combined.factor}

logger:
tags:
- assimilation
- enoi
- full
- scaling
17 changes: 0 additions & 17 deletions configs/experiment/assimilation/enoi_sqrt_scaling.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions configs/experiment/assimilation/etkf_12.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ n_ens: 12

assimilation:
inf_factor: 1.07

logger:
tags:
- assimilation
- etkf
- ens_12
- scores
7 changes: 7 additions & 0 deletions configs/experiment/assimilation/etkf_3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ n_ens: 3

assimilation:
inf_factor: 1.34

logger:
tags:
- assimilation
- etkf
- ens_3
- scores
14 changes: 0 additions & 14 deletions configs/experiment/assimilation/etkf_ddim.yaml

This file was deleted.

17 changes: 0 additions & 17 deletions configs/experiment/assimilation/etkf_noise.yaml

This file was deleted.

Loading

0 comments on commit fe49fbf

Please sign in to comment.