diff --git a/.bumpversion.cfg b/.bumpversion.cfg index eef065f..11269f0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 2.0.5 +current_version = 2.0.4 [bumpversion:file:setup.py] search = version="{current_version}" diff --git a/.gitignore b/.gitignore index 863634c..bd6404c 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ dmypy.json # Pyre type checker .pyre/ +*.h5 diff --git a/MetNet blog post, Valter Fallenius.pdf b/MetNet blog post, Valter Fallenius.pdf new file mode 100644 index 0000000..45d1a4d Binary files /dev/null and b/MetNet blog post, Valter Fallenius.pdf differ diff --git a/MetNet_lightning.py b/MetNet_lightning.py new file mode 100644 index 0000000..a64143e --- /dev/null +++ b/MetNet_lightning.py @@ -0,0 +1,64 @@ +from metnet.models.metnet_pylight import MetNetPylight +import torch +import torch.nn.functional as F +from data_prep.prepare_data_MetNet import load_data +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +import wandb +from pytorch_lightning.callbacks import DeviceStatsMonitor +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import ModelCheckpoint +import time + +wandb.login() + +'''model = MetNetPylight( + hidden_dim=256, #384 original paper + forecast_steps=60, #240 original paper + input_channels=15, #46 original paper, hour/day/month = 3, lat/long/elevation = 3, GOES+MRMS = 40 + output_channels=128, #512 + input_size=112, # 112 + n_samples = None, #None = All ~ 23000 + num_workers = 4, + batch_size = 8, + learning_rate = 1e-2, + num_att_layers = 4, + plot_every = None, #Plot every global_step + rain_step = 0.2, + momentum = 0.9, + att_heads=16, + keep_biggest = 1, + leadtime_spacing = 1, #1: 5 minutes, 3: 15 minutes + )''' +#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/8leads with 
agg.ckpt" +#PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/full_run_continue.ckpt" + +PATH_cp = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/lit-wandb/2yap8c0s/checkpoints/epoch=754-step=52094_0.2231.ckpt" +model = MetNetPylight.load_from_checkpoint(PATH_cp) +model.learning_rate = 1e-3 +print(model) +print(model.forecast_steps) +print(model.input_channels) +print(model.output_channels) +print(model.n_samples) +print(model.file_name) +print(model.keep_biggest) + +#model.printer = True +#model.plot_every = None +#MetNetPylight expects already preprocessed data. Can be change by uncommenting the preprocessing step. +#print(model) + +#wandb.restore("/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/wandb/run-20220331_162434-21o2u2sj" +#wandb.init(run_id = "21o2u2sj",resume="must") +wandb_logger = WandbLogger(project="lit-wandb", log_model="all") +checkpoint_callback = ModelCheckpoint(monitor="validation/loss_epoch", mode="min", save_top_k=5) + + +trainer = pl.Trainer(num_sanity_val_steps=2, track_grad_norm = 2, max_epochs=2000, gpus=-1,log_every_n_steps=50, logger = wandb_logger,strategy="ddp", callbacks=[checkpoint_callback]) +start_time = time.time() + +trainer.fit(model) +print("--- %s seconds ---" % (time.time() - start_time)) +wandb.finish() diff --git a/Precipitation Nowcasting using Deep Neural Networks - Valter Fallenius.pdf b/Precipitation Nowcasting using Deep Neural Networks - Valter Fallenius.pdf new file mode 100644 index 0000000..b047a66 Binary files /dev/null and b/Precipitation Nowcasting using Deep Neural Networks - Valter Fallenius.pdf differ diff --git a/README.md b/README.md index f1955b2..4343ece 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ +Pytorch lightning: + +Swedish dataset (the ugly and small one, I might publish the nice and big one too at some point): +https://we.tl/t-OfkV7ZnGWC + # MetNet and MetNet-2 PyTorch Implementation of Google Research's MetNet for short term weather 
forecasting (https://arxiv.org/abs/2003.12140), inspired from https://github.com/tcapelle/metnet_pytorch/tree/master/metnet_pytorch @@ -14,8 +19,6 @@ pip install -e . Alternatively, you can also install a usually older version through ```pip install metnet``` -Please ensure that you're using Python version 3.9 or above. - ## Data While the exact training data used for both MetNet and MetNet-2 haven't been released, the papers do go into some detail as to the inputs, which were GOES-16 and MRMS precipitation data, as well as the time period covered. We will be making those splits available, as well as a larger dataset that covers a longer time period, with [HuggingFace Datasets](https://huggingface.co/datasets/openclimatefix/goes-mrms)! diff --git a/Scripts/Activate.ps1 b/Scripts/Activate.ps1 new file mode 100644 index 0000000..b547581 --- /dev/null +++ b/Scripts/Activate.ps1 @@ -0,0 +1,247 @@ +<# +.Synopsis +Activate a Python virtual environment for the current PowerShell session. + +.Description +Pushes the python executable for a virtual environment to the front of the +$Env:PATH environment variable and sets the prompt to signify that you are +in a Python virtual environment. Makes use of the command line switches as +well as the `pyvenv.cfg` file values present in the virtual environment. + +.Parameter VenvDir +Path to the directory that contains the virtual environment to activate. The +default value for this is the parent of the directory that the Activate.ps1 +script is located within. + +.Parameter Prompt +The prompt prefix to display when this virtual environment is activated. By +default, this prompt is the name of the virtual environment folder (VenvDir) +surrounded by parentheses and followed by a single space (ie. '(.venv) '). + +.Example +Activate.ps1 +Activates the Python virtual environment that contains the Activate.ps1 script. 
+ +.Example +Activate.ps1 -Verbose +Activates the Python virtual environment that contains the Activate.ps1 script, +and shows extra information about the activation as it executes. + +.Example +Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv +Activates the Python virtual environment located in the specified location. + +.Example +Activate.ps1 -Prompt "MyPython" +Activates the Python virtual environment that contains the Activate.ps1 script, +and prefixes the current prompt with the specified string (surrounded in +parentheses) while the virtual environment is active. + +.Notes +On Windows, it may be required to enable this Activate.ps1 script by setting the +execution policy for the user. You can do this by issuing the following PowerShell +command: + +PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser + +For more information on Execution Policies: +https://go.microsoft.com/fwlink/?LinkID=135170 + +#> +Param( + [Parameter(Mandatory = $false)] + [String] + $VenvDir, + [Parameter(Mandatory = $false)] + [String] + $Prompt +) + +<# Function declarations --------------------------------------------------- #> + +<# +.Synopsis +Remove all shell session elements added by the Activate script, including the +addition of the virtual environment's Python executable from the beginning of +the PATH variable. + +.Parameter NonDestructive +If present, do not remove this function from the global namespace for the +session. 
+ +#> +function global:deactivate ([switch]$NonDestructive) { + # Revert to original values + + # The prior prompt: + if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) { + Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt + Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT + } + + # The prior PYTHONHOME: + if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) { + Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME + Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME + } + + # The prior PATH: + if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) { + Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH + Remove-Item -Path Env:_OLD_VIRTUAL_PATH + } + + # Just remove the VIRTUAL_ENV altogether: + if (Test-Path -Path Env:VIRTUAL_ENV) { + Remove-Item -Path env:VIRTUAL_ENV + } + + # Just remove VIRTUAL_ENV_PROMPT altogether. + if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) { + Remove-Item -Path env:VIRTUAL_ENV_PROMPT + } + + # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether: + if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) { + Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force + } + + # Leave deactivate function in the global namespace if requested: + if (-not $NonDestructive) { + Remove-Item -Path function:deactivate + } +} + +<# +.Description +Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the +given folder, and returns them in a map. + +For each line in the pyvenv.cfg file, if that line can be parsed into exactly +two strings separated by `=` (with any amount of whitespace surrounding the =) +then it is considered a `key = value` line. The left hand string is the key, +the right hand is the value. + +If the value starts with a `'` or a `"` then the first and last character is +stripped from the value before being captured. + +.Parameter ConfigDir +Path to the directory that contains the `pyvenv.cfg` file. 
+#> +function Get-PyVenvConfig( + [String] + $ConfigDir +) { + Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg" + + # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue). + $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue + + # An empty map will be returned if no config file is found. + $pyvenvConfig = @{ } + + if ($pyvenvConfigPath) { + + Write-Verbose "File exists, parse `key = value` lines" + $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath + + $pyvenvConfigContent | ForEach-Object { + $keyval = $PSItem -split "\s*=\s*", 2 + if ($keyval[0] -and $keyval[1]) { + $val = $keyval[1] + + # Remove extraneous quotations around a string value. + if ("'""".Contains($val.Substring(0, 1))) { + $val = $val.Substring(1, $val.Length - 2) + } + + $pyvenvConfig[$keyval[0]] = $val + Write-Verbose "Adding Key: '$($keyval[0])'='$val'" + } + } + } + return $pyvenvConfig +} + + +<# Begin Activate script --------------------------------------------------- #> + +# Determine the containing directory of this script +$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition +$VenvExecDir = Get-Item -Path $VenvExecPath + +Write-Verbose "Activation script is located in path: '$VenvExecPath'" +Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)" +Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)" + +# Set values required in priority: CmdLine, ConfigFile, Default +# First, get the location of the virtual environment, it might not be +# VenvExecDir if specified on the command line. +if ($VenvDir) { + Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values" +} +else { + Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir." 
+ $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/") + Write-Verbose "VenvDir=$VenvDir" +} + +# Next, read the `pyvenv.cfg` file to determine any required value such +# as `prompt`. +$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir + +# Next, set the prompt from the command line, or the config file, or +# just use the name of the virtual environment folder. +if ($Prompt) { + Write-Verbose "Prompt specified as argument, using '$Prompt'" +} +else { + Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value" + if ($pyvenvCfg -and $pyvenvCfg['prompt']) { + Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'" + $Prompt = $pyvenvCfg['prompt']; + } + else { + Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virutal environment)" + Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'" + $Prompt = Split-Path -Path $venvDir -Leaf + } +} + +Write-Verbose "Prompt = '$Prompt'" +Write-Verbose "VenvDir='$VenvDir'" + +# Deactivate any currently active virtual environment, but leave the +# deactivate function in place. +deactivate -nondestructive + +# Now set the environment variable VIRTUAL_ENV, used by many tools to determine +# that there is an activated venv. 
+$env:VIRTUAL_ENV = $VenvDir + +if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { + + Write-Verbose "Setting prompt to '$Prompt'" + + # Set the prompt to include the env name + # Make sure _OLD_VIRTUAL_PROMPT is global + function global:_OLD_VIRTUAL_PROMPT { "" } + Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT + New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt + + function global:prompt { + Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " + _OLD_VIRTUAL_PROMPT + } + $env:VIRTUAL_ENV_PROMPT = $Prompt +} + +# Clear PYTHONHOME +if (Test-Path -Path Env:PYTHONHOME) { + Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME + Remove-Item -Path Env:PYTHONHOME +} + +# Add the venv to the PATH +Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH +$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH" diff --git a/Scripts/activate b/Scripts/activate new file mode 100644 index 0000000..41bfb5c --- /dev/null +++ b/Scripts/activate @@ -0,0 +1,69 @@ +# This file must be used with "source bin/activate" *from bash* +# you cannot run it directly + +deactivate () { + # reset old environment variables + if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then + PATH="${_OLD_VIRTUAL_PATH:-}" + export PATH + unset _OLD_VIRTUAL_PATH + fi + if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then + PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" + export PYTHONHOME + unset _OLD_VIRTUAL_PYTHONHOME + fi + + # This should detect bash and zsh, which have a hash command that must + # be called to get it to forget past commands. 
Without forgetting + # past commands the $PATH changes we made may not be respected + if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then + hash -r 2> /dev/null + fi + + if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then + PS1="${_OLD_VIRTUAL_PS1:-}" + export PS1 + unset _OLD_VIRTUAL_PS1 + fi + + unset VIRTUAL_ENV + unset VIRTUAL_ENV_PROMPT + if [ ! "${1:-}" = "nondestructive" ] ; then + # Self destruct! + unset -f deactivate + fi +} + +# unset irrelevant variables +deactivate nondestructive + +VIRTUAL_ENV="C:\Users\valte\Desktop\Teoretisk Fysik\SMHI master\Network\metnet-main" +export VIRTUAL_ENV + +_OLD_VIRTUAL_PATH="$PATH" +PATH="$VIRTUAL_ENV/Scripts:$PATH" +export PATH + +# unset PYTHONHOME if set +# this will fail if PYTHONHOME is set to the empty string (which is bad anyway) +# could use `if (set -u; : $PYTHONHOME) ;` in bash +if [ -n "${PYTHONHOME:-}" ] ; then + _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" + unset PYTHONHOME +fi + +if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then + _OLD_VIRTUAL_PS1="${PS1:-}" + PS1="(metnet-main) ${PS1:-}" + export PS1 + VIRTUAL_ENV_PROMPT="(metnet-main) " + export VIRTUAL_ENV_PROMPT +fi + +# This should detect bash and zsh, which have a hash command that must +# be called to get it to forget past commands. Without forgetting +# past commands the $PATH changes we made may not be respected +if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then + hash -r 2> /dev/null +fi diff --git a/Scripts/activate.bat b/Scripts/activate.bat new file mode 100644 index 0000000..8e6d34e --- /dev/null +++ b/Scripts/activate.bat @@ -0,0 +1,34 @@ +@echo off + +rem This file is UTF-8 encoded, so we need to update the current code page while executing it +for /f "tokens=2 delims=:." 
%%a in ('"%SystemRoot%\System32\chcp.com"') do ( + set _OLD_CODEPAGE=%%a +) +if defined _OLD_CODEPAGE ( + "%SystemRoot%\System32\chcp.com" 65001 > nul +) + +set VIRTUAL_ENV=C:\Users\valte\Desktop\Teoretisk Fysik\SMHI master\Network\metnet-main + +if not defined PROMPT set PROMPT=$P$G + +if defined _OLD_VIRTUAL_PROMPT set PROMPT=%_OLD_VIRTUAL_PROMPT% +if defined _OLD_VIRTUAL_PYTHONHOME set PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME% + +set _OLD_VIRTUAL_PROMPT=%PROMPT% +set PROMPT=(metnet-main) %PROMPT% + +if defined PYTHONHOME set _OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME% +set PYTHONHOME= + +if defined _OLD_VIRTUAL_PATH set PATH=%_OLD_VIRTUAL_PATH% +if not defined _OLD_VIRTUAL_PATH set _OLD_VIRTUAL_PATH=%PATH% + +set PATH=%VIRTUAL_ENV%\Scripts;%PATH% +set VIRTUAL_ENV_PROMPT=(metnet-main) + +:END +if defined _OLD_CODEPAGE ( + "%SystemRoot%\System32\chcp.com" %_OLD_CODEPAGE% > nul + set _OLD_CODEPAGE= +) diff --git a/Scripts/deactivate.bat b/Scripts/deactivate.bat new file mode 100644 index 0000000..62a39a7 --- /dev/null +++ b/Scripts/deactivate.bat @@ -0,0 +1,22 @@ +@echo off + +if defined _OLD_VIRTUAL_PROMPT ( + set "PROMPT=%_OLD_VIRTUAL_PROMPT%" +) +set _OLD_VIRTUAL_PROMPT= + +if defined _OLD_VIRTUAL_PYTHONHOME ( + set "PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME%" + set _OLD_VIRTUAL_PYTHONHOME= +) + +if defined _OLD_VIRTUAL_PATH ( + set "PATH=%_OLD_VIRTUAL_PATH%" +) + +set _OLD_VIRTUAL_PATH= + +set VIRTUAL_ENV= +set VIRTUAL_ENV_PROMPT= + +:END diff --git a/Scripts/pip.exe b/Scripts/pip.exe new file mode 100644 index 0000000..a21b20d Binary files /dev/null and b/Scripts/pip.exe differ diff --git a/Scripts/pip3.10.exe b/Scripts/pip3.10.exe new file mode 100644 index 0000000..a21b20d Binary files /dev/null and b/Scripts/pip3.10.exe differ diff --git a/Scripts/pip3.exe b/Scripts/pip3.exe new file mode 100644 index 0000000..a21b20d Binary files /dev/null and b/Scripts/pip3.exe differ diff --git a/Scripts/python.exe b/Scripts/python.exe new file mode 100644 index 0000000..bbc056e Binary 
files /dev/null and b/Scripts/python.exe differ diff --git a/Scripts/pythonw.exe b/Scripts/pythonw.exe new file mode 100644 index 0000000..3b5c0bd Binary files /dev/null and b/Scripts/pythonw.exe differ diff --git a/a.txt b/a.txt new file mode 100644 index 0000000..e69de29 diff --git a/data_prep/__init__.py b/data_prep/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_prep/metnet_dataloader.py b/data_prep/metnet_dataloader.py new file mode 100644 index 0000000..4481de8 --- /dev/null +++ b/data_prep/metnet_dataloader.py @@ -0,0 +1,176 @@ +import torch +#from .prepare_data_MetNet import load_data +from torch.utils.data import Dataset, DataLoader +import math +import numpy as np +import os +import matplotlib.pyplot as plt + +class MetNetDataset(Dataset): + def __init__(self,ID, N=None, keep_biggest = 0.5, leadtime_spacing = 1, lead_times = 60): + + data_path = "/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/bin_sorted_data/"+ID+"/" + self.leadtime_spacing = leadtime_spacing + + self.file_names = [] + self.ID = ID + self.means = [] + self.weights = None + + file_list = os.listdir(data_path) + if not N: + N = len(file_list) + + for file_name in file_list: + + if file_name[-5] == "X": + self.file_names.append(data_path+file_name) + mean = float(file_name.split("_")[0]) + + self.means.append(mean) + + + self.means = np.array(self.means) + idx_sorted = np.argsort(-self.means) + to_keep = int(N*keep_biggest) + + + + idx_to_keep = idx_sorted[:to_keep] + self.file_names = [self.file_names[idx] for idx in idx_sorted[:to_keep]] + + if ID == "train": + minimum_rain = 5 + skipped = 0 + try: + self.rainy_leadtimes = np.load(f"leadtime_sampling_N_{N}_leads_{lead_times}_spacing_{leadtime_spacing}_minimum_{minimum_rain}.npy") + with open(f"leadtime_sampling_N_{N}_leads_{lead_times}_spacing_{leadtime_spacing}_minimum_{minimum_rain}.txt", 'r') as f: + self.file_names = [a.replace("\n","") for a in list(f.readlines())] + + + except 
FileNotFoundError: + self.rainy_leadtimes = [] + copy_to_it = self.file_names[:] + for j, file_name in enumerate(copy_to_it): + if j%100==0: + print(f"Progress {j} / {len(self.file_names)}") + + y = np.load(file_name.replace("X","Y")) + rainy_leads = [] + + for i, y_here in enumerate(y[self.leadtime_spacing-1::self.leadtime_spacing]): + if i>=lead_times: + break + if np.sum(y_here[1:])>minimum_rain: + rainy_leads.append(i) + else: + #print("skipping one") + skipped += 1 + + if not rainy_leads: + len_before = len(self.file_names) + self.file_names.remove(file_name) + assert len_before-len(self.file_names) == 1 + + print("SKIPPING ", file_name) + + else: + for i in np.random.choice(rainy_leads,(lead_times-len(rainy_leads))): + rainy_leads.append(i) + rainy_leads = np.array(rainy_leads) + assert len(rainy_leads) == lead_times + self.rainy_leadtimes.append(rainy_leads) + + assert len(self.rainy_leadtimes) == len(self.file_names) + np.save(f"leadtime_sampling_N_{N}_leads_{lead_times}_spacing_{leadtime_spacing}_minimum_{minimum_rain}.npy",np.array(self.rainy_leadtimes)) + with open(f"leadtime_sampling_N_{N}_leads_{lead_times}_spacing_{leadtime_spacing}_minimum_{minimum_rain}.txt", 'w') as f: + for item in self.file_names: + f.write(f"{item}\n") + '''n_uniques = [] + for leads in self.rainy_leadtimes: + unique = np.unique(leads) + n_uniques.append(len(unique)) + plt.hist(n_uniques, bins = lead_times) + plt.title("Number of unique leadtimes") + plt.show() + a = {} + for lead in range(lead_times): + a[lead] = len(np.where(self.rainy_leadtimes==lead)[0]) + print(a) + plt.hist(self.rainy_leadtimes.reshape(-1), bins = lead_times) + plt.title("resampling of leadtimes") + plt.show()''' + + + self.n_samples = len(self.file_names) + + def __getitem__(self, index): + # allows indexing dataset[0] + name_x = self.file_names[index] + name_y = name_x.replace("X","Y") + x = np.load(name_x) + + y = np.load(name_y) + persistence = y[0] + y = 
y[self.leadtime_spacing-1::self.leadtime_spacing] + + x = torch.from_numpy(x) + y = torch.from_numpy(y) + if self.ID == "train": + + return x, y, self.rainy_leadtimes[index] + if self.ID == "test": + return x, y, persistence + return x, y + + def __len__(self): + # Will allow len(data) + return self.n_samples + +if __name__=="__main__": + test_data = MetNetDataset("train", N = None, keep_biggest = 1) + n_to_plot = 128 + '''BINS = np.zeros((n_to_plot,)) + + for j,(x,y) in enumerate(test_data): + if j%100==0: print(f"Progress {j}/{len(test_data)}") + for i in range(n_to_plot): + BINS[i] += np.sum(y[0,i].numpy()) + np.save(f"bin_count_no_keep_biggest.npy",BINS)''' + + BINS1 = np.load("bin_count.npy") + BINS2 = np.load("bin_count_no_keep_biggest.npy") + N_1 = np.sum(BINS1) + N_2 = np.sum(BINS2) + BINS1 /= N_1 + BINS2 /= N_2 + for i,a in enumerate(BINS1): + if a==0: + BINS1[i] = BINS1[i-1] + for i,a in enumerate(BINS2): + if a==0: + BINS2[i] = BINS2[i-1] + + + + #BINS = np.load("bin_count_no_keep_biggest.npy") + + rain_mm = np.arange(128)*0.2 + plt.plot(rain_mm,100*BINS1, label =r"$Y_{15\%}$") + plt.plot(rain_mm,100*BINS2, label =r"$Y_{100\%}$") + plt.legend() + plt.yscale("log") + plt.xlabel("Rain rate [mm/h]") + plt.ylabel("Percentage [%]") + plt.title("Percentage of pixels per rain rate") + plt.show() + + plt.plot(rain_mm, 100*BINS1/BINS2, label =r"$\frac{bin_{15}}{bin_{100}}$") + plt.legend() + #plt.yscale("log") + plt.xlabel("Rain rate [mm/h]") + plt.ylabel("Percentage difference [%]") + plt.title("Impact of keeping 15% of data") + plt.show() + + diff --git a/data_prep/prepare_data_MetNet.py b/data_prep/prepare_data_MetNet.py new file mode 100644 index 0000000..4b7c528 --- /dev/null +++ b/data_prep/prepare_data_MetNet.py @@ -0,0 +1,597 @@ +import numpy as np +import h5py as h5 +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from pathlib import Path +import sys, os +from datetime import datetime, timedelta,date + +''' 
+Output: 5D tensor of shape (n_samples, time_dim, channels, width, height): + +''' + + + + +def space_to_depth(x, block_size): + x = np.asarray(x) + batch, height, width, depth = x.shape + reduced_height = height // block_size + reduced_width = width // block_size + y = x.reshape(batch, reduced_height, block_size, + reduced_width, block_size, depth) + z = np.swapaxes(y, 2, 3).reshape(batch, reduced_height, reduced_width, -1) + return z + +def date_assertion(dates,expected_delta = 5): + for date1,date2 in zip(dates[0:-1],dates[1:]): + list1 = date1.split("_") + #print(list1) + + y1, m1, d1, hour1, minute1 = [int(a) for a in list1] + + datetime1 = datetime(y1, m1, d1, hour=hour1, minute=minute1) + list2 = date2.split("_") + + y2, m2, d2, hour2, minute2 = [int(a) for a in list2] + + datetime2 = datetime(y2, m2, d2, hour=hour2, minute=minute2) + delta = datetime2-datetime1 + minutes = delta.total_seconds()/60 + + #print(datetime1) + #print(datetime2) + #print("DELTA ", minutes, "delta ", expected_delta) + assert int(minutes) == expected_delta + +def h5_iterator(h5_file,maxN = 100,spaced = 1, starter = 0): + """Iterates through the desired datafile and returns index, array and datetime""" + + with h5.File(h5_file,"r") as f: + + keys = list(f.keys()) + for i,name in enumerate(keys): + if i=maxN: break + #print(name) + date = name + y, m, d, hh, mm = date.split("_") + #title = f"Year: {y} month: {months[m]} day: {d} time: {t}" + array = np.array(obj) #flip array + + + #print(array.shape) + yield j, array, date + +def down_sampler(data, rate = 2): + """Spatial downsampling with vertical and horizontal downsampling rate = rate.""" + + down_sampled=data[:,0::rate,0::rate] + + return down_sampled + + + +def temporal_concatenation(data,dates,targets,target_dates,concat = 7, overlap = 0,spaced = 3,lead_times = 60): + """Takes the spatial 2D arrays and concatenates temporal aspect to 3D-vector (T-120min, T-105min, ..., T-0min) + concat = number of frames to encode in temporal 
dimension + overlap = how many of the spatial arrays are allowed to overlap in another datasample""" + n,x_size,y_size,channels = data.shape + n_y,x_y,y_y = targets.shape + + seq_length = spaced*concat + lead_times #5 minute increments + x_limit = n - seq_length//spaced + #concecutive time + X = [] + X_dates=[] + Y = [] + Y_dates = [] + for i,j in zip(range(0,x_limit,concat-overlap),range(concat*spaced-2,n_y,(concat-overlap)*spaced)): + + if (i+1)%1000==0: + print(f"\nTemporal concatenated samples: ",i+1) + temp_input = data[i:i+concat,:,:] + temp_target = targets[j:j+lead_times,:,:] + temp_dates = dates[i:i+concat] + temp_dates_target = target_dates[j:j+lead_times] + try: + date_assertion(temp_dates,expected_delta = 5*spaced) + date_assertion(temp_dates_target,expected_delta = 5) + fiver = [temp_dates[-1],temp_dates_target[0]] #final X date and first Y date should be 5 spaced minutes + date_assertion(fiver,expected_delta = 5) + except AssertionError: + print(f"Warning, dates are not alligned! 
Skipping: {i}:{i+seq_length}") + #print(temp_dates) + #print(temp_dates_target) + continue + X.append(temp_input) + X_dates.append(temp_dates) + Y.append(temp_target) + Y_dates.append(temp_dates_target) + X = np.array(X) + Y = np.array(Y) + + return X,Y,X_dates,Y_dates + +def extract_centercrop(data,factor_smaller=2): + + x0 = 0 + y0 = 0 + x1 = data.shape[2] + y1 = data.shape[1] + + try: + assert x1 == y1 + except AssertionError: + print(f"\nWarning: centercrop shapes ({x1}, {y1}) are not the same.") + centercrop_x_lim = slice(x0+x1//(2*factor_smaller),x1-x1//(2*factor_smaller)) + centercrop_y_lim = slice(y0+y1//(2*factor_smaller),y1-y1//(2*factor_smaller)) + + return data[:,centercrop_y_lim,centercrop_x_lim] + +def datetime_encoder(data,dates,plotter = False): + data_shape = data.shape + data_type = data.dtype + year_days = [] + day_minutes=[] + for i,date_string in enumerate(dates): + if (i+1)%1000==0: + print("Dates loaded for encoding: ",i+1) + list1 = date_string.split("_") + + + year,month,day, hour, minute = [int(a) for a in list1] + date_object = date(year,month,day) + day_of_the_year = date_object.timetuple().tm_yday + minute_of_the_day = hour*60 + minute + year_days.append(day_of_the_year) + day_minutes.append(minute_of_the_day) + year_days = np.array(year_days) + day_minutes = np.array(day_minutes) + year_days = np.repeat(year_days[:,np.newaxis],data_shape[1],axis=1) + year_days = np.repeat(year_days[:,:,np.newaxis],data_shape[2],axis=2) + day_minutes = np.repeat(day_minutes[:,np.newaxis],data_shape[1],axis=1) + day_minutes = np.repeat(day_minutes[:,:,np.newaxis],data_shape[2],axis=2) + + + + periodicals = [np.sin(2*np.pi*year_days/365,dtype =data_type), + np.cos(2*np.pi*year_days/365,dtype =data_type), + np.sin(2*np.pi*day_minutes/(60*24),dtype =data_type), + np.cos(2*np.pi*day_minutes/(60*24),dtype =data_type)] + periodicals = [np.expand_dims(a, axis=3) for a in periodicals] + date_array = np.concatenate(periodicals,axis=3) + + try: + assert 
date_array.shape[0:2] == data_shape[0:2] + except AssertionError: + print("Datetime dimensions seem wrong!") + raise + + if plotter: + fig,ax = plt.subplots(1,2) + ax[0].scatter(date_array[:,0,0,0],date_array[:,0,0,1]) + fig.suptitle(f"Periodical year and days from {dates[0]} to {dates[-1]}") + ax[1].scatter(date_array[:,0,0,2],date_array[:,0,0,3]) + ax[0].set_title("Year") + ax[1].set_title("Day") + + plt.show() + data = np.concatenate((data,date_array),axis=3) + return data + +def longlatencoding(data): + print(f"\nExtracting longitude, latitude and elevation data ...", end="") + + with h5.File("lonlatelev.h5","r") as FF: + lonlatelev = FF["lonlatelev"] + lonlatelev = np.array(lonlatelev)[:112,:112,:] + + lon = lonlatelev[:,:,0] + lat = lonlatelev[:,:,1] + + elev = lonlatelev[:,:,2] + + + lon_mean, lon_std = np.mean(lon), np.std(lon) + lat_mean, lat_std = np.mean(lat), np.std(lat) + elev /= np.max(np.abs(elev)) + elev_mean, elev_std = np.mean(elev), np.std(elev) + try: + assert lon_std != 0 + assert lat_std != 0 + assert elev_std != 0 + except AssertionError: + print("WARNING: LON LAT OR ELEV STD == 0") + lon = (lon-lon_mean)/lon_std + lat = (lat-lat_mean)/lat_std + elev = (elev-elev_mean)/elev_std + elev = np.log(elev-np.min(elev)+0.1) + + #lon = np.tanh(lon) + #lat = np.tanh(lat) + #elev = np.tanh(elev) + print("MINMAX lon: ", np.min(lon), np.max(lon)) + print("MINMAX lat: ", np.min(lat), np.max(lat)) + print("MINMAX elev: ", np.min(elev), np.max(elev)) + + lonlatelev[:,:,0] = lon + lonlatelev[:,:,1] = lat + lonlatelev[:,:,2] = elev + + lonlatelev = np.expand_dims(lonlatelev, axis=0) + lonlatelev = np.repeat(lonlatelev,data.shape[0],axis=0) + + print(f"\ndone! 
it has shape {lonlatelev.shape}") + + return np.concatenate((data,lonlatelev), axis=3, dtype = np.float32) + + + +def load_data(h5_path,N = 3000,lead_times = 60, concat = 7, square = (0,448,881-448,881), downsampling_rate = 2, overlap = 0, spaced=3,downsample = True, spacedepth =True,centercrop=True,box=2,printer=True, rain_step = 0.2, n_bins=512, keep_biggest = 0.8): + #15 minutes between datapoints is default --> spaced = 3 + snapshots = [] + dates = [] + all_snapshots = [] + Y_dates = [] + array_mean = 0 + means = [] + n = 0 + if not printer: + sys.stdout = open(os.devnull, 'w') + for i, array,date in h5_iterator(h5_path, N): + + if (i+1)%1000==0: + print("Loaded samples: ",n) + + if i%spaced==0: + snapshots.append(array) + dates.append(date) + means.append(np.mean(array)) + n+=1 + all_snapshots.append(array) + array_mean += np.mean(array) + + + Y_dates.append(date) + + '''print("MEAN", array_mean/n) + print("Done loading samples! \n") + n_snap = len(snapshots) + n_all_snaps = len(all_snapshots) + means = np.array(means)[0:(n_snap//concat)*concat].reshape(n_snap//concat, concat) + running_means = np.mean(means,axis=1) + n_runs = len(running_means) + n_to_keep = int(n_runs*keep_biggest) + idx_to_keep = np.argsort(running_means)[-n_to_keep:] + print(idx_to_keep) + temp_s = [] + temp_s_all = [] + temp_s_dates = [] + temp_s_all_dates = [] + print("LEN BEFORE", len(snapshots), len(all_snapshots)) + for i in idx_to_keep: + temp_s.extend(snapshots[i*concat:i*concat+concat]) + temp_s_all.extend(all_snapshots[i*concat*spaced:i*concat*spaced+concat*spaced]) + temp_s_dates.extend(dates[i*concat:i*concat+concat]) + temp_s_all_dates.extend(Y_dates[i*concat*spaced:i*concat*spaced+concat*spaced]) + snapshots = temp_s + all_snapshots = temp_s_all + dates = temp_s_dates + Y_dates = temp_s_all_dates + print("LEN AFTER", len(snapshots), len(all_snapshots)) + print(dates) + print(Y_dates) + input()''' + + + + data = np.array(snapshots) + del(snapshots) # MANAGE MEMORY + all_data = 
np.array(all_snapshots) + + + del(all_snapshots) # MANAGE MEMORY + print("\nDatatype data: ", data.dtype) + print("\nInput data shape: ", data.shape, " size: ", sys.getsizeof(data)) + + + x0,x1,y0,y1 = square + print(f"\nInput patch by index: xmin = {x0}, xmax = {x1}, ymin = {y0}, ymax = {y1}") + x_lim = slice(x0,x1) + y_lim = slice(y0,y1) + + center_x = (x0+x1)//2 + center_y = (y0+y1)//2 + length_x = (x1-x0)//16 #size of Y is 16 times smaller + length_y = (y1-y0)//16 #size of Y is 16 times smaller + Y_lim_x = slice(center_x-length_x//2,center_x+length_x//2) + Y_lim_y = slice(center_y-length_y//2,center_y+length_y//2) + print("SLICED x: ",Y_lim_x) + print("SLICED y: ",Y_lim_y) + Y = all_data[:,Y_lim_y,Y_lim_x] + del(all_data) #MANAGE MEMORY + print(f"\nY shape here (not ready): {Y.shape}") + + data = data[:,y_lim,x_lim] + print(f"\nSliced data to dimensions {data.shape}") + + if centercrop: #extract centercrop before downsampling, since it's high resolution + + center = extract_centercrop(data) + print(f"\nCopying centercrop with shape {center.shape}") + if downsample == True: + print("\nDownsampling with rate: ", downsampling_rate) + data = down_sampler(data) + print("\nDone downsampling!") + + + print("\nDatatype downsampled: ", data.dtype) + print("\nDownsampled data shape: ",data.shape) + if len(data.shape)<4: + data = np.expand_dims(data, axis=3) + print(f"\nAdding channel dimension to data, new shape: {data.shape}") + if centercrop: + if len(center.shape)<4: + center = np.expand_dims(center, axis=3) + print(f"\nAdding channel dimension to centercrop, new shape: {center.shape}") + if spacedepth==True: + data = space_to_depth(data,box) + + print(f"\nSpace-to-depth done! Data shape: {data.shape}") + if centercrop: + center = space_to_depth(center,box) + print(f"\nSpace-to-depth done! 
Centercrop shape: {center.shape}") + + if centercrop: + data = np.concatenate((data,center), axis=3) + print(f"\nConcatenating data and centercrop to dimenison: {data.shape} with shape [:,:,:,downsampled + centercrop]") + + + + + data = longlatencoding(data) + print(f"\nConcatenating data with long, lat and elevation. New shape: {data.shape}, dtype: {data.dtype}") + + data = datetime_encoder(data,dates,plotter=False) + print(f"\nEncoding datetime periodical variables (seasonally,hourly) and concatenating with data. New shape: {data.shape}, dtype: {data.dtype}") + + + + + + data = np.swapaxes(np.swapaxes(data,3,1),2,3) + print(f"\nData swapping axes to get channel first, now shape: {data.shape}") + X,Y, X_dates,Y_dates = temporal_concatenation(data,dates,Y,Y_dates,concat = concat, overlap = overlap, spaced = spaced,lead_times = lead_times) + + print(f"\nDone with temporal concatenation and target_split! Data shape: {X.shape}, target shape: {Y.shape}") + + GAIN = 0.4 + OFFSET = -30 + X[:,:,0:8] = X[:,:,0:8]*GAIN + OFFSET + + + maxx = np.max(X[:,:,0:8]) + print("\nMAX DBZ data(should be 72): ", maxx) + data_new = np.empty(X[:,:,0:8].shape) + N = data_new.shape[0] + runs = N//5000 + for run in range(0,N,5000): + data_new[run:run+5000,:,0:8] = np.log(X[run:run+5000,:,0:8]+0.01, dtype = np.float32)/4 + data_new[run:run+5000,:,0:8] = np.nan_to_num(data_new[run:run+5000,:,0:8]) + data_new[run:run+5000,:,0:8] = np.tanh(data_new[run:run+5000,:,0:8], dtype = np.float32) + + data_new[runs*5000:,:,0:8] = np.log(X[runs*5000:,:,0:8]+0.01, dtype = np.float32)/4 + data_new[runs*5000:,:,0:8] = np.nan_to_num(data_new[runs*5000:,:,0:8]) + data_new[runs*5000:,:,0:8] = np.tanh(data_new[runs*5000:,:,0:8], dtype = np.float32) + #data[np.where(data<0)] = 0 + '''data_new = np.log(data+0.01)/4 + data_new = np.nan_to_num(data_new) + data_new = np.tanh(data_new)''' + + + + for i in range(8): + try: + assert np.std(data_new[:,:,:,i]) != 0 + except AssertionError: + print(f"WARNING: CHANNEL {i} 
STD == 0") + data_new[:,:,i] = (data_new[:,:,i] - np.mean(data_new[:,:,i] ))/np.std(data_new[:,:,i] ) + + + print(f"\nScaling data with log(x+0.01)/4, replace NaN with 0 and apply tanh(x) and convert to data type: {data.dtype}, nbytes: {data.nbytes}, size: {data.size}") + + + Y = Y*GAIN + OFFSET + '''for i in range(0,5): + fig, ax = plt.subplots(1,2) + ax[0].imshow(X[i,0,0,:,:]) + #ax[0].imshow(np.mean(data_after_gained[i*7,42:70,42:70,4:8],axis=2)) + ax[0].set_title(X_dates[i][6]) + ax[1].imshow(Y[i,0,:,:]) + ax[1].set_title(Y_dates[i][0]) + plt.show()''' + + #print("comparing X and Y after gain:", np.mean(data_after_gained[:,:,4:8]), np.mean(Y)) + + Y_gained = np.copy(Y) + + print("MINMAX Y AFTER GAIN + OFFSET", np.min(Y), np.max(Y)) + + passer = np.mean(X[:,6,4:8,:,:],axis=1) + + Y = rain_binned(Y, n_bins = n_bins, increment = rain_step, x = passer) + + print(f"\nDone with binning targets into bins, target shape: {Y.shape}") + + + + #Remove low-rainfall data: + meaned = np.mean(X[:,:,0:4,:,:], axis=(1,2,3,4)) + idx_sorted = np.argsort(meaned) + + N = len(meaned) + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,1) + for k in range(7): + i = idx_sorted[j] + + im = ax[k].imshow(X[i,k,0,:,:]) + ax[k].set_title(f"MEAN: {meaned[i]:.2f}") + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + to_keep = int(N*keep_biggest) + idx_to_keep = idx_sorted[-to_keep:] + #print(meaned) + #print(meaned[idx_to_keep]) + + X = X[idx_to_keep] + '''print(meaned[idx_sorted]) + print(meaned[idx_to_keep]) + print(np.mean(X[:,:,0:4,:,:], axis=(1,2,3,4))) + input()''' + Y = Y[idx_to_keep] + X_dates = [X_dates[i] for i in idx_to_keep] + Y_dates = [Y_dates[i] for i in idx_to_keep] + print(f"\nOnly keeping {to_keep} out of {N} samples to reduce low rainfall events. 
New X shape: {X.shape}") + N = X.shape[0] + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,2) + for k in range(7): + + + im = ax[k,0].imshow(X[j,k,0,:,:]) + ax[k,0].set_title(f"MEAN: {np.mean(X[j,k,0,:,:]):.2f}") + ax[k,1].imshow(Y[j,k,0,:,:]) + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + #plots all channels seperately: + '''channels = X.shape[2] + for i in range(N): + fig, axs = plt.subplots(4,4) + axs = axs.reshape(-1) + for c in range(channels): + axs[c].imshow(X[i,0,c]) + plt.show()''' + + #print(f"\nOnly keeping {to_keep} out of {N} samples to reduce low rainfall events. New X shape: {X.shape}") + + Y_thresh = np.ones(Y[:,0,:,:].shape)*-1 + + Y_thresh[np.where(Y[:,0,:,:]==1)] = 1 + + + #rain_check(X, Y_gained[idx_to_keep], Y_thresh,X_dates,Y_dates,X_dict,Y_dict, meaned) + if not printer: + sys.stdout = sys.__stdout__ + return X,Y, X_dates,Y_dates + +def rain_binned(Y, n_bins = 51,increment = 2, x = None): + SHAPE = Y.shape + n,leads,w,h = SHAPE + max_fall = n_bins*increment + min_dbz = np.min(Y) + max_dbz = np.max(Y) + Y[np.where(Y>70)] = 0 + + rain = (10**(Y / 10.0) / 200.0)**(1.0 / 1.6) + + print("RAIN MINMAX: ", np.min(rain), np.max(rain)) + + '''for i in range(n): + fig,ax = plt.subplots(1,2) + rain_x = (10**(x[i] / 10.0) / 200.0)**(1.0 / 1.6) + ax[0].imshow(x[i][(112//2)-14:(112//2)+14, (112//2)-14:(112//2)+14]) + ax[0].set_title("x zoomed") + ax[1].imshow(rain[i,0,:,:]) + ax[1].set_title("y") + plt.show()''' + rain_bins = np.zeros((n,leads,n_bins,w,h)) + counter = [] + for i in range(n_bins-1): + bin_min = i*increment + bin_max = (i+1)*increment + rain_bin = np.zeros((n,leads,w,h)) #Y.shape = (None,lead_times, bin_channel, width/4,heigth/4) + idx = np.where(np.logical_and(rain>=bin_min, rain=n_bins*increment) + rain_bin[idx] = 1 + rain_bins[:,:,n_bins-1,:,:] = rain_bin + counter.append(len(idx)) + print("RAINBINS: ", rain_bins.size, " counter size: ", 
sum(counter)) + + print(counter) + + + return rain_bins +''' +def rain_check(X, Y,Y_thresh,x_dates,y_dates,X_dict,Y_dict,meaned): + X_mid = np.mean(X[:,:,4:8], axis = 2) #,42:70, 42:70 + N = min(X.shape[0], Y.shape[0]) + temps = X_mid.shape[1] + leads = min(Y.shape[1],temps) + + for n in range(N): + fig, axs = plt.subplots(5,temps) + axs = axs.reshape(-1) + #print(X_mid[n,0,:,:].reshape(-1)) + for i in range(temps): + #print("i: ", i, "\n", X_mid[n,i,:,:].reshape(-1)) + + print("MINMAX", np.min(X_mid[n,i,:,:]), np.max(X_mid[n,i,:,:])) + #print(X_mid[np.where(np.isnan(X_mid))] + np.random.random(X_mid[np.where(np.isnan(X_mid))].shape)) + im = axs[i].imshow(X_mid[n,i,:,:]) + + #axs[i].set_title(x_dates[n][i][-5:]) + for i in range(temps): + im = axs[i+temps].imshow(X_dict[x_dates[n][i]]) + axs[i+temps].set_title(x_dates[n][i][-5:]) + for j in range(leads): + im = axs[j+2*temps].imshow(Y[n,j,:,:]) + axs[j+2*temps].set_title(y_dates[n][j][-5:]) + for j in range(leads): + im = axs[j+3*temps].imshow(Y_thresh[n,j,:,:]) + axs[j+3*temps].set_title(y_dates[n][j][-5:]) + for j in range(leads): + im = axs[j+4*temps].imshow(Y_dict[y_dates[n][i]]) + axs[j+4*temps].set_title(x_dates[n][i][-5:]) + fig.suptitle("mean : " + str(meaned[n])) + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + +if __name__=="__main__": + + + data,dates = load_data("combination_all_pn157.h5",N =500,downsample=True,spacedepth=True,printer=True,) + print(data.nbytes) + + + print(np.max(data)) diff --git a/metnet/__init__.py b/metnet/__init__.py index 7eaaddc..e69de29 100644 --- a/metnet/__init__.py +++ b/metnet/__init__.py @@ -1,4 +0,0 @@ -from metnet.models.metnet import MetNet -from metnet.models.metnet2 import MetNet2 - -from .layers import * diff --git a/metnet/layers/ConditionTime.py b/metnet/layers/ConditionTime.py index 97695ae..e752274 100644 --- a/metnet/layers/ConditionTime.py +++ b/metnet/layers/ConditionTime.py 
@@ -7,11 +7,12 @@ def condition_time(x, i=0, size=(12, 16), seq_len=15): assert i < seq_len times = (torch.eye(seq_len, dtype=x.dtype, device=x.device)[i]).unsqueeze(-1).unsqueeze(-1) ones = torch.ones(1, *size, dtype=x.dtype, device=x.device) + # print((times*ones).shape) return times * ones class ConditionTime(nn.Module): - "Condition Time on a stack of images, adds `horizon` channels to image" + """Condition Time on a stack of images, adds `horizon` channels to image""" def __init__(self, horizon, ch_dim=2, num_dims=5): super().__init__() @@ -20,7 +21,7 @@ def __init__(self, horizon, ch_dim=2, num_dims=5): self.num_dims = num_dims def forward(self, x, fstep=0): - "x stack of images, fsteps" + """x stack of images, fsteps""" if self.num_dims == 5: bs, seq_len, ch, h, w = x.shape ct = condition_time(x, fstep, (h, w), seq_len=self.horizon).repeat(bs, seq_len, 1, 1, 1) diff --git a/metnet/layers/ConvGRU.py b/metnet/layers/ConvGRU.py index 9699c9f..c829d2a 100644 --- a/metnet/layers/ConvGRU.py +++ b/metnet/layers/ConvGRU.py @@ -10,7 +10,7 @@ def __init__( hidden_dim, kernel_size=(3, 3), bias=True, - activation=F.tanh, + activation=torch.tanh, batchnorm=False, ): """ @@ -71,7 +71,7 @@ def forward(self, input, h_prev=None): combined = torch.cat((input, h_prev), dim=1) # concatenate along channel axis - combined_conv = F.sigmoid(self.conv_zr(combined)) + combined_conv = torch.sigmoid(self.conv_zr(combined)) z, r = torch.split(combined_conv, self.hidden_dim, dim=1) @@ -131,7 +131,7 @@ def __init__( n_layers, batch_first=True, bias=True, - activation=F.tanh, + activation=torch.tanh, input_p=0.2, hidden_p=0.1, batchnorm=False, diff --git a/metnet/layers/DownSampler.py b/metnet/layers/DownSampler.py index d7c349d..1efa199 100644 --- a/metnet/layers/DownSampler.py +++ b/metnet/layers/DownSampler.py @@ -9,24 +9,21 @@ def __init__(self, in_channels, output_channels: int = 256, conv_type: str = "st super().__init__() conv2d = get_conv_layer(conv_type=conv_type) 
self.output_channels = output_channels - if conv_type == "antialiased": - antialiased = True - else: - antialiased = False self.module = nn.Sequential( conv2d(in_channels, 160, 3, padding=1), - nn.MaxPool2d((2, 2), stride=1 if antialiased else 2), - antialiased_cnns.BlurPool(160, stride=2) if antialiased else nn.Identity(), + nn.MaxPool2d((2, 2), stride=2), + # antialiased_cnns.BlurPool(160, stride=2) if antialiased else nn.Identity(), nn.BatchNorm2d(160), conv2d(160, output_channels, 3, padding=1), nn.BatchNorm2d(output_channels), conv2d(output_channels, output_channels, 3, padding=1), nn.BatchNorm2d(output_channels), conv2d(output_channels, output_channels, 3, padding=1), - nn.MaxPool2d((2, 2), stride=1 if antialiased else 2), - antialiased_cnns.BlurPool(output_channels, stride=2) if antialiased else nn.Identity(), + nn.MaxPool2d((2, 2), stride=2), + # antialiased_cnns.BlurPool(output_channels, stride=2) if antialiased else nn.Identity(), ) def forward(self, x): + return self.module.forward(x) diff --git a/metnet/layers/__init__.py b/metnet/layers/__init__.py index aec8617..b93b5e3 100644 --- a/metnet/layers/__init__.py +++ b/metnet/layers/__init__.py @@ -1,5 +1,6 @@ from .ConditionTime import ConditionTime from .ConvGRU import ConvGRU +from .ConvLSTM import ConvLSTM from .DownSampler import DownSampler from .Preprocessor import MetNetPreprocessor from .TimeDistributed import TimeDistributed diff --git a/metnet/models/__init__.py b/metnet/models/__init__.py index 04d0750..e69de29 100644 --- a/metnet/models/__init__.py +++ b/metnet/models/__init__.py @@ -1,2 +0,0 @@ -from .metnet import MetNet -from .metnet2 import MetNet2 diff --git a/metnet/models/metnet.py b/metnet/models/metnet.py index a1d33d0..8475f31 100644 --- a/metnet/models/metnet.py +++ b/metnet/models/metnet.py @@ -9,7 +9,7 @@ class MetNet(torch.nn.Module, PyTorchModelHubMixin): def __init__( self, - image_encoder: str = "downsampler", + image_encoder: str = "downsampler", # 4 CNN layers 
input_channels: int = 12, sat_channels: int = 12, input_size: int = 256, @@ -42,15 +42,15 @@ def __init__( self.forecast_steps = forecast_steps self.input_channels = input_channels self.output_channels = output_channels - + """ self.preprocessor = MetNetPreprocessor( sat_channels=sat_channels, crop_size=input_size, use_space2depth=True, split_input=True ) # Update number of input_channels with output from MetNetPreprocessor new_channels = sat_channels * 4 # Space2Depth new_channels *= 2 # Concatenate two of them together - input_channels = input_channels - sat_channels + new_channels - self.drop = nn.Dropout(temporal_dropout) + input_channels = input_channels - sat_channels + new_channels""" + # self.drop = nn.Dropout(temporal_dropout) if image_encoder in ["downsampler", "default"]: image_encoder = DownSampler(input_channels + forecast_steps) else: @@ -70,19 +70,26 @@ def __init__( self.head = nn.Conv2d(hidden_dim, output_channels, kernel_size=(1, 1)) # Reduces to mask def encode_timestep(self, x, fstep=1): - + print("\n shape before preprocess: ", x.shape) # Preprocess Tensor - x = self.preprocessor(x) - + # x = self.preprocessor(x) + print("\n shape after preprocess: ", x.shape) # Condition Time + x = self.ct(x, fstep) + print("\n shape after ct: ", x.shape) ##CNN x = self.image_encoder(x) + print("\n shape after image_encoder: ", x.shape) # Temporal Encoder - _, state = self.temporal_enc(self.drop(x)) - return self.temporal_agg(state) + # _, state = self.temporal_enc(self.drop(x)) + _, state = self.temporal_enc(x) + print("\n shape after temp enc: ", state.shape) + dummy = self.temporal_agg(state) + print("\n shape after temporal_agg: ", dummy.shape) + return dummy def forward(self, imgs): """It takes a rank 5 tensor diff --git a/metnet/models/metnet_pylight.py b/metnet/models/metnet_pylight.py new file mode 100644 index 0000000..ea528d3 --- /dev/null +++ b/metnet/models/metnet_pylight.py @@ -0,0 +1,781 @@ +import torch +import torch.nn as nn +from 
axial_attention import AxialAttention, AxialPositionalEmbedding +from huggingface_hub import PyTorchModelHubMixin +import pytorch_lightning as pl +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from torch import optim +from data_prep import metnet_dataloader, prepare_data_MetNet +from metnet.layers import ConditionTime, ConvGRU, DownSampler, MetNetPreprocessor, TimeDistributed, ConvLSTM +from torch.utils.data import DataLoader, random_split +import numpy as np +import math +import matplotlib.pyplot as plt +import wandb +from sklearn.metrics import f1_score +import PIL +import matplotlib as mpl +from matplotlib import colors +from metnet.layers.utils import get_conv_layer + +class MetNetPylight(pl.LightningModule, PyTorchModelHubMixin): + def __init__( + self, + image_encoder: str = "downsampler", #4 CNN layers + file_name: str = "data3_500k.h5", + input_channels: int = 12, # radar channels + longitude channels + time encoding channels = 15 (excluding lead time encoding) + n_samples: int = 1000, # number of radar snapshots to preprocess + sat_channels: int = 0, # ignore + input_size: int = 256, # height = width = input_size = 112 + output_channels: int = 512, # number of rain bins + rain_step: int = 0.2, # size of each rain bin in millimeters + hidden_dim: int = 384, # hidden dimensions in RNN layer + kernel_size: int = 3, # Kernel sizes in Downsampler + num_layers: int = 1, # ignore + num_att_layers: int = 4, #Number of attention layers, 8 original paper. 
+ forecast_steps: int = 240, # Number of lead times + temporal_dropout: float = 0.2, # Dropout + num_workers: int = 32, + batch_size: int = 8, + momentum: float = 0.9, + att_heads: int = 8, + plot_every: int = 10, + keep_biggest: float = 0.8, + + learning_rate: int = 1e-2, + leadtime_spacing: int = 1, + **kwargs, + ): + super(MetNetPylight, self).__init__() + pl.seed_everything(42, workers = True) + config = locals() + config.pop("self") + config.pop("__class__") + self.config = kwargs.pop("config", config) + sat_channels = self.config["sat_channels"] + input_size = self.config["input_size"] + input_channels = self.config["input_channels"] + temporal_dropout = self.config["temporal_dropout"] + image_encoder = self.config["image_encoder"] + forecast_steps = self.config["forecast_steps"] + hidden_dim = self.config["hidden_dim"] + kernel_size = self.config["kernel_size"] + num_layers = self.config["num_layers"] + num_att_layers = self.config["num_att_layers"] + output_channels = self.config["output_channels"] + + self.forecast_steps = forecast_steps + self.input_channels = input_channels + self.output_channels = output_channels + self.n_samples = n_samples + self.file_name = file_name + self.workers = num_workers + self.rain_step = rain_step + self.learning_rate = learning_rate + self.batch_size = batch_size + self.plot_every = plot_every + self.momentum = momentum + self.att_heads = att_heads + lead_time_keys = list(range(forecast_steps)) + lead_time_counts = [0 for i in range(forecast_steps)] + self.lead_time_histogram = dict(zip(lead_time_keys,lead_time_counts)) + self.keep_biggest = keep_biggest + self.weights = None + self.leadtime_spacing = leadtime_spacing + self.testing = False + if str(self.device)== "cuda:0": + self.printer = True + else: + self.printer = False + ''' + self.preprocessor = MetNetPreprocessor( + sat_channels=sat_channels, crop_size=input_size, use_space2depth=True, split_input=True + ) + # Update number of input_channels with output from 
MetNetPreprocessor + new_channels = sat_channels * 4 # Space2Depth + new_channels *= 2 # Concatenate two of them together + input_channels = input_channels - sat_channels + new_channels''' + self.drop = nn.Dropout(temporal_dropout) + '''if image_encoder in ["downsampler", "default"]: + image_encoder = DownSampler(input_channels + forecast_steps) + else: + raise ValueError(f"Image_encoder {image_encoder} is not recognized")''' + image_encoder = DownSampler(input_channels + forecast_steps) + self.image_encoder = TimeDistributed(image_encoder) + self.ct = ConditionTime(forecast_steps) + self.temporal_enc = TemporalEncoder( + image_encoder.output_channels, hidden_dim, ks=kernel_size, n_layers=num_layers + ) + self.position_embedding = AxialPositionalEmbedding(dim=self.temporal_enc.out_channels, shape = (input_size // 4, input_size // 4)) + self.temporal_agg = nn.Sequential( + *[ + AxialAttention(dim=hidden_dim, dim_index=1, heads=self.att_heads, num_dimensions=2) + for _ in range(num_att_layers) + ] + ) + '''conv2d = get_conv_layer(conv_type="standard") + self.conv_agg = nn.Sequential( + conv2d(hidden_dim, hidden_dim, kernel_size=(28,1), padding="same"), + conv2d(hidden_dim, hidden_dim, kernel_size=(1,28), padding="same"), + conv2d(hidden_dim, hidden_dim, kernel_size=(3,3), padding=1), + #nn.MaxPool2d((2, 2), stride=2), + # antialiased_cnns.BlurPool(160, stride=2) if antialiased else nn.Identity(), + nn.BatchNorm2d(hidden_dim), + conv2d(hidden_dim, hidden_dim, kernel_size=(28,1)), + conv2d(hidden_dim, hidden_dim, kernel_size=(1,28)), + conv2d(hidden_dim, hidden_dim, kernel_size=(3,3), padding=1), + #nn.MaxPool2d((2, 2), stride=2), + # antialiased_cnns.BlurPool(output_channels, stride=2) if antialiased else nn.Identity(), + )''' + self.head = nn.Conv2d(hidden_dim, output_channels, kernel_size=(1, 1)) # Reduces to mask + self.double() + + self.save_hyperparameters() + + def encode_timestep(self, x, fstep=1, lead_times = []): + #print("\n shape before preprocess: ", 
x.shape) + # Preprocess Tensor + #x = self.preprocessor(x) + #print("\n shape after preprocess: ", x.shape) + # Condition Time + #plot_channels(x, 1, tit_add = "INPUT") + if lead_times: + bs, t, c, w, h = x.shape + x_temp = torch.empty((bs, t, c+self.forecast_steps, w, h), device = self.device) + for i,lead_time in enumerate(lead_times): + + x_temp[i] = self.ct(x[i:i+1], lead_time) + else: + x_temp = self.ct(x, fstep) + + x = x_temp.double() + + + + #print("\n shape after ct: ", x.shape) + + ##CNN + x = self.image_encoder(x) + #plot_channels(x, 1, tit_add = "after image_encoder") + #print("\n shape after image_encoder: ", x.shape) + + # Temporal Encoder + #_, state = self.temporal_enc(self.drop(x)) + if not self.testing: + x = self.drop(x) + _, state = self.temporal_enc(x) + embedded = self.position_embedding(state) + #plot_channels(state, 1, tit_add = "after temporal_enc") + #print("\n shape after temp enc: ", state.shape) + #return state #REMOVEMOMROEMREOMREOMREORMERRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRR + agg = self.temporal_agg(state) + #agg = self.conv_agg(state) + #plot_channels(x, 1, tit_add = "after agg") + #print("\n shape after temporal_agg: ", agg.shape) + return agg + + def forward(self, imgs,lead_time = 0, lead_times = []): + """It takes a rank 5 tensor + - imgs [bs, seq_len, channels, h, w] + - lead_time #random int between 0,self.forecast_steps + """ + + # Compute all timesteps, probably can be parallelized + #print("in forward") + #print_channels(imgs[0,0]) + #plot_channels(imgs, 2, tit_add = "input") + #print(imgs.shape) + x = self.encode_timestep(imgs, lead_time, lead_times) + #plot_channels(x, 2, tit_add = "after encode") + #print("shape before head: ", x.shape) + out = self.head(x) + #print("shape after head: ", out.shape) + #plot_channels(out, 1, tit_add = "after head") + #soft = torch.softmax(out,dim=1) + #print("shape after softmax: ", soft.shape) + #plot_channels(out, 1, tit_add = "after softmax") + #plot_bins(out[0], 
" after head",soft[0], "after softmax") + return out + + + def training_step(self, batch, batch_idx): + + x, y, rainy_leads = batch + #print("training_step len: ", x.shape[0]) + + bs = x.shape[0] + lead_times = [int(np.random.choice(leads.cpu().detach().numpy())) for leads in rainy_leads] + + #w = torch.tensor(self.weights,device = self.device) + L = CrossEntropyLoss() + y_hat = self(x.float(),lead_times=lead_times) + + #y_leads = torch.empty(y[:,lead_time].shape, device = self.device) + #y_leads = torch.tensor([y[i,lead_times[i]] for i in range(self.batch_size)], device = self.device) + loss = L(y_hat, y[torch.arange(bs), lead_times], ) + + + self.log("train/loss", loss, on_step=False, on_epoch=True) + '''if batch_idx == 0: + + self.train_batch = (x[0:1],y[0:1])''' + + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + + x, y = batch + #print("validation_step len: ", x.shape[0]) + #log_img = (batch_idx == 0 and str(self.device)== "cuda:0") + lead_times = list(range(self.forecast_steps)) + + loss = 0 + f1_val = 0 + f1_count = 0 + L = CrossEntropyLoss() + for lead_time in lead_times: + y_hat = self(x.float(),lead_time) + + loss += L(y_hat, y[:,lead_time]) + ''' + # F1 score: + softed = torch.softmax(y_hat, dim=1) + n_soft = softed.shape[0] + rainy = torch.sum(y_hat[:,1:], dim=1) + idx_rain = [torch.where(rainy[j]>y_hat[j,0]) for j in range(n_soft)] + for i, soft in enumerate(softed): + if len(idx_rain[i])==0: + continue + else: + f1_count +=1 + truth = torch.sum(y[i,lead_time,1:],dim = 0) + pred = torch.zeros(truth.shape) + idx_here = idx_rain[i] + pred[idx_here] = 1 + + truth_flat = torch.flatten(truth) + pred_flat = torch.flatten(pred) + print("SUM FLAT: ", torch.sum(truth_flat), torch.sum(pred_flat)) + + f1_here = f1_score(truth_flat.cpu().detach().numpy(),pred_flat.cpu().detach().numpy(), zero_divison=0) + + f1_val += f1_here''' + + '''y_img, y_hat_img = thresh_imgs(y[0,lead_time], y_hat[0], thresh_bin = 1) + if log_img: + 
self.logger.experiment.log({f"val_{lead_time}":[wandb.Image(y_img.cpu(), caption=f"y leadtime {lead_time}"), wandb.Image(y_hat_img.cpu(), caption=f"y_hat leadtime {lead_time}")]})''' + + + loss /= self.forecast_steps + #f1_val /= f1_count + self.log("validation/loss_epoch", loss, on_step=False, on_epoch=True) + #self.log("validation/f1_score", f1_val, on_step=False, on_epoch=True) + return {"loss": loss} #, "f1_val": f1_val + + def test_step(self, batch, batch_idx): + x,y, persistence = batch + + + + # ---------- calculate test_loss ---------- + loss = 0 + L = CrossEntropyLoss() + y_hat = torch.empty(y.shape) + for lead_time in range(self.forecast_steps): + y_hat_here= self(x,lead_time) + loss += L(y_hat_here, y[:,lead_time]) + y_hat[:,lead_time] = y_hat_here + loss /= self.forecast_steps + self.log("test/loss_epoch", loss, on_step=True, on_epoch=True) + # ---------- calculate test_loss ---------- + + + + + # ---------- calculate f1_score ---------- + self.probabillity_thresh = 0 + thresh = self.thresh + rain_y = torch.sum(y[:,:,thresh:], dim=2) + softed = torch.softmax(y_hat,dim=2) + no_rain_y_hat = torch.sum(softed[:,:, 0:thresh], dim=2) + rain_y_hat = torch.sum(softed[:,:,thresh:], dim=2) + threshed_y_hat = torch.zeros(rain_y_hat.shape) + idx_above = torch.where(rain_y_hat>self.probabillity_thresh) + threshed_y_hat[idx_above] = 1 + + + + + for sample in range(y.shape[0]): + + after_five = persistence[sample] + after_five = torch.sum(after_five[thresh:],dim=0) + after_five = torch.flatten(after_five).cpu().detach().numpy() + + f1_here = [] + for lead_time in range(self.forecast_steps): + if torch.sum(rain_y[sample,lead_time])<5: + self.skipped += 1 + continue + truth = torch.flatten(rain_y[sample,lead_time]).cpu().detach().numpy() + pred = torch.flatten(threshed_y_hat[sample,lead_time]).cpu().detach().numpy() + + f1 = f1_score(truth, pred) + f1_after_five = f1_score(truth, after_five) + f1_here.append((f1,f1_after_five)) + self.f1s[lead_time].append(f1) + 
self.f1s_control[lead_time].append(f1_after_five) + + #plot_probabillity(y[sample],softed[sample],[kk for kk in range(0,self.forecast_steps,self.forecast_steps//3)],increment = 0.2, spacing = self.leadtime_spacing, f1_scores = f1_here) + + + + + + # ---------- calculate f1_score ---------- + + + #for i in range(y.shape[0]): #REWRITE + '''for i in range(1): + temp3 = x[i,-1,0:4].cpu().numpy() + temp3 = np.mean(temp3, axis = 0) + temp3 = (temp3 + np.min(temp3))/(np.max(temp3)-np.min(temp3)) + temp3 = temp3*255 + pil_im_x = PIL.Image.fromarray(np.uint8(temp3)) + self.logger.log_image(key=f"val_{lead_time}", images=[pil_im_x], caption = ["x"]) + + self.f1_count += 1 + imgs = [] + capts = [] + for lead_time in range(0,self.forecast_steps,10): + if torch.sum(y[i,lead_time,0])>28*28-5: + continue + y_img, y_hat_img, f1,f1_control = thresh_imgs(y[i,lead_time], y_hat[i,lead_time], no_rain, thresh_bin = thresh) + self.avg_y_img[lead_time]+=(torch.mean(y_img.cpu())) + self.avg_y_hat_img[lead_time]+=(torch.mean(y_hat_img.cpu())) + self.f1s[lead_time] += f1 + self.f1s_control[lead_time] += f1_control + + if lead_time in list(range(0,self.forecast_steps,self.forecast_steps//3)): + + temp1 = y_img.cpu().numpy() + pil_im_y = PIL.Image.fromarray(np.uint8(temp1)*255) + imgs.append(pil_im_y) + mean_y = str(np.mean(temp1)) + + capts.append(f"y {lead_time}, mean={mean_y[0:4]}") + temp2 = y_hat_img.cpu().numpy() + pil_im_y = PIL.Image.fromarray(np.uint8(temp2)*255) + imgs.append(pil_im_y) + capts.append(f"y_hat {lead_time}, f1={round(f1,4)}") + + + + fig,axs = plt.subplots(1,2) + im0 = axs[0].imshow(y_img, vmin=0, vmax=1) + axs[0].set_title("y") + im0 = axs[1].imshow(y_hat_img, vmin=0, vmax=1) + axs[1].set_title("y_hat") + fig.suptitle(f"Threshhold: {thresh}, lead time: {lead_time}, with f1-score: {f1}") + fig.subplots_adjust(right=0.8) + + fig.colorbar(im0, ax=axs.ravel().tolist()) + plt.show() + self.logger.log_image(key=f"val_{lead_time}", images=imgs, caption = capts) + ''' + '''for 
i in range(x.shape[0]): + softed = torch.softmax(y_hat[i],dim=1) + plot_probabillity(y[i],softed,[kk for kk in range(0,self.forecast_steps,self.forecast_steps//3)],increment = 0.2, spacing = self.leadtime_spacing)''' + #plot_categories(y[i,0],softed,increment = 0.2) + + #plot bins: + '''for i in range(x.shape[0]): + softed = torch.softmax(y_hat[i,0],dim=0) + plot_bins(x[i], softed[0:9], "x" ,y[i,0,0:9], "y") + #plot category: + for i in range(x.shape[0]): + plot_category(y[i,0],y_hat[i,0],self.output_channels,self.rain_step,self.device)''' + def on_train_epoch_end(self): + imgs = [] + capts = [] + thresh = 1 + + '''for lead_time in range(0,self.forecast_steps,10): + y_hat = self(self.train_batch[0],lead_time) + y = self.train_batch[1] + no_rain = torch.zeros(y[0,0].shape, device = self.device) + no_rain[0] = 1 + y_img, y_hat_img, f1,f1_control = thresh_imgs(y[0,lead_time], y_hat[0], no_rain, thresh_bin = thresh) + + + + temp1 = y_img.cpu().numpy() + pil_im_y = PIL.Image.fromarray(np.uint8(temp1)*255) + imgs.append(pil_im_y) + capts.append(f"y {lead_time}") + temp2 = y_hat_img.cpu().numpy() + pil_im_y = PIL.Image.fromarray(np.uint8(temp2)*255) + imgs.append(pil_im_y) + capts.append(f"y_hat {lead_time}") + self.logger.log_image(key=f"train_end_{lead_time}", images=imgs, caption = capts)''' + + def validation_epoch_end(self, val_step_outputs): + '''wombo = self.logger.experiment + hist_scores = [[s] for s in self.lead_time_histogram.values()] + table = wandb.Table(data = hist_scores, columns = ["lead_times"]) + wombo.log({"leadtimes_histogram": wombo.plot.histogram(table, "Lead times", title="Lead times histogram")})''' + outs = tuple([x["loss"] for x in val_step_outputs]) + #outs_f1 = tuple([x["f1_val"] for x in val_step_outputs]) + avg_val_loss = torch.tensor(outs, device = self.device).mean() + #avg_val_f1 = torch.tensor(outs, device = self.device).mean() + return {"val_loss":avg_val_loss} #, "val_f1_avg": avg_val_f1} + def test_epoch_end(self,test_step_outputs): + 
f1_mean = [ np.mean(f1s) for f1s in self.f1s] + lens = [len(f1s) for f1s in self.f1s] + print(f"Skipped {self.skipped} / {self.skipped + sum(lens)}") + f1_mean = np.array(f1_mean) + f1_control_mean = np.array([ np.mean(f1s) for f1s in self.f1s_control]) + np.save(f"f1_threshed_{self.probabillity_thresh}_N_{sum(lens)}_thresh_{self.thresh}.npy",f1_mean) + np.save(f"f1_control_N_{sum(lens)}_thresh_{self.thresh}.npy",f1_control_mean) + plt.plot(f1_mean,"b", label="f1 meaned") + plt.plot(f1_control_mean,"--g", label="persistence") + plt.legend() + plt.title(f"f1-scores, P(rate>0.2)>{self.probabillity_thresh}") + plt.xlabel("lead_time") + plt.ylabel("f1") + plt.show() + + '''y_means = np.array(self.avg_y_img)/self.f1_count + y_hat_means = np.array(self.avg_y_hat_img)/self.f1_count + plt.plot(y_means,"b", label="y means") + plt.plot(y_hat_means,"g", label="y_hat means") + plt.legend() + plt.title(f"y and yhat above threshhold at different leadtimes") + plt.xlabel("lead_time") + plt.ylabel("mean") + plt.show()''' + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(),lr=self.learning_rate, momentum = self.momentum) + return optimizer + + def setup(self, stage = None): + self.train_data = metnet_dataloader.MetNetDataset("train", N = self.n_samples , keep_biggest = self.keep_biggest, leadtime_spacing = self.leadtime_spacing, lead_times = self.forecast_steps) + self.val_data = metnet_dataloader.MetNetDataset("val", N = None, keep_biggest = 1, leadtime_spacing = self.leadtime_spacing) + self.test_data = metnet_dataloader.MetNetDataset("test", N = self.n_samples, keep_biggest = 0.1, leadtime_spacing = self.leadtime_spacing) + '''if self.train_data.weights is not None: + + self.weights = self.train_data.weights''' + + + + print(f"Training data samples = {len(self.train_data)}") + + print(f"Validation data samples = {len(self.val_data)}") + print(f"Test data samples = {len(self.test_data)}") + + + def train_dataloader(self): + train_loader = 
DataLoader(dataset=self.train_data,batch_size=self.batch_size, num_workers = self.workers, shuffle = True) + return train_loader + def val_dataloader(self): + val_loader = DataLoader(self.val_data,batch_size=self.batch_size, num_workers = self.workers) + return val_loader + def test_dataloader(self): + test_loader = DataLoader(self.test_data,batch_size=self.batch_size, num_workers = self.workers) + return test_loader + + '''def on_train_epoch_start(self): + print("INSIDE NOW") + self.rain_bins = torch.zeros((self.output_channels,)) + for i, batch in enumerate(self.train_data): + x,y = batch + self.rain_bins += torch.sum(y.cpu(),dim = [0,1,3,4]) + if i%50==0: + print(f"{i}/{len(self.train_data)}") + plt.plot(self.rain_bins) + plt.yscale("log") + plt.show()''' + + +def thresh_imgs(y, y_hat, after_five, thresh_bin = 1): + bins, w, h = y_hat.shape + + y_below = torch.sum(y[0:thresh_bin], dim=0) + y_above = torch.sum(y[thresh_bin:], dim=0) + y_outcome = torch.zeros((w, h)) + y_idx_less_rain = torch.where(y_belowpart_2_b) + a1[idx_1_below] = 0 + a1[idx_1_above] = 1 + a2[idx_2_below] = 0 + a2[idx_2_above] = 1 + y = a1.cpu().detach().numpy() + y_hat = a2.cpu().detach().numpy() + + fig, ax = plt.subplots(1,2) + fig.suptitle(str(title)) + ax[0].imshow(y) + ax[1].imshow(y_hat) + plt.show() + +def plot_categories(y,y_hat,increment = 0.2): + #accepts y.shape = (bins,w,h) + + _, w, h = y.shape + categories = [(0,1), (1,5), (5,10), (10, y.shape[0]-1)] + labels = {} + bounds = [i for i in range(len(categories)+1)] + rain_img_y = torch.zeros((w,h)) + rain_img_y_hat = torch.zeros((w,h)) + part_y_hat_probs = rain_img_y_hat[:] + for i, (low,high) in enumerate(categories): + low_str = str(increment*low) + high_str = str(increment*high) + labels[i] = low_str + "-" + high_str + + + + part_y = torch.sum(y[low:high], dim = 0) + + part_y_hat = torch.sum(y_hat[low:high], dim = 0) + + idx_y = torch.where(part_y==1) + idx_y_hat = torch.where(part_y_hat>part_y_hat_probs) + 
def plot_probabillity(y, y_hat, lead_times, increment=0.2, spacing=1, f1_scores=None):
    """Plot ground-truth rain next to the predicted rain probability.

    One row is drawn per entry of ``lead_times``: the left panel shows the
    rain rate recovered from ``y``, the right panel the summed probability of
    every bin except bin 0.

    Args:
        y: tensor indexed as ``y[lead_time, bin]`` giving a ``(w, h)`` map;
            a pixel is assigned ``bin * increment`` wherever the map equals 1.
        y_hat: tensor indexed the same way; ``y_hat[lead_time, 1:]`` is
            summed over the bin axis.
        lead_times: indices into the first axis of ``y`` / ``y_hat``.
        increment: rain rate represented by one bin step.
        spacing: multiplier applied to the minutes shown in the titles.
        f1_scores: optional; ``f1_scores[lead_time][0]`` is appended to the
            right-hand title. When omitted or empty the title carries no F1
            value (the former ``[]`` default raised ``IndexError``).
    """
    _, n_bins, w, h = y.shape
    # squeeze=False keeps `ax` two-dimensional even for a single lead time.
    fig, ax = plt.subplots(len(lead_times), 2, squeeze=False)

    bounds = [0, 0.2, 1, 3]
    prob_bounds = [0, 0.25, 0.5, 0.75, 1]

    # Coordinate grid matching the (w, h) maps; previously fixed at 28 x 28.
    xx, yy = np.meshgrid(np.arange(h), np.arange(w))
    divnorm_bounds = colors.TwoSlopeNorm(vmin=0, vcenter=0.1, vmax=3)
    has_f1 = f1_scores is not None and len(f1_scores) > 0

    for j, lead_time in enumerate(lead_times):
        # Same device as `y`, so the boolean masks below can index it.
        rain_img_y = torch.zeros((w, h), device=y.device)
        for i in range(n_bins):
            rain_img_y[y[lead_time, i] == 1] = i * increment

        rain_img_y_hat = torch.sum(y_hat[lead_time, 1:], dim=0)

        zz = rain_img_y.cpu().detach().numpy()
        zz_hat = rain_img_y_hat.cpu().detach().numpy()

        minutes = (lead_time * 5 + 5) * spacing
        im1 = ax[j, 0].contourf(
            xx, yy, zz, bounds, cmap="Reds", extend="both", norm=divnorm_bounds
        )
        ax[j, 0].set_title(f"Lead time:{minutes} min")

        im2 = ax[j, 1].contourf(xx, yy, zz_hat, prob_bounds, cmap="Greens")
        right_title = f"Lead time:{minutes} min"
        if has_f1:
            right_title += f", f1: {round(f1_scores[lead_time][0],3)}"
        ax[j, 1].set_title(right_title)

        for panel in (ax[j, 0], ax[j, 1]):
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

    # The colour bars use the mappables of the last row.
    cb1 = fig.colorbar(im1, ax=ax[:, 0])
    cb1.set_label("Rain [mm/h]")
    cb2 = fig.colorbar(im2, ax=ax[:, 1])
    cb2.set_label("Probabillity of rain>0.2mm/h")
    fig.suptitle("Ground truth rainfall (left) vs. Prediction probabillity (right)")
    plt.show()


def print_channels(x):
    """Print mean and standard deviation of every slice along axis 0 of ``x``.

    ``x`` is a tensor; it is moved to the CPU and converted to NumPy first.
    """
    values = x.cpu().detach().numpy()
    for channel, array in enumerate(values):
        print(f"Mean: {np.mean(array)} Std: {np.std(array)} for channel {channel}")


class RainfieldCallback(pl.Callback):
    """Log a few validation samples with the model's predictions to W&B."""

    def __init__(self, val_samples, num_samples=1):
        """Keep the first ``num_samples`` entries of an ``(inputs, targets)`` pair."""
        super().__init__()
        val_imgs, val_Y = val_samples
        self.val_imgs = val_imgs[:num_samples]
        self.val_Y = val_Y[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module, lead_times=60):
        """Run the stored samples through ``pl_module`` and log them.

        Fixes two undefined names of the original: ``logits`` (now the model
        output ``y_hat``) and ``self.val_labels`` (now ``self.val_Y``, the
        attribute ``__init__`` actually sets).
        """
        val_imgs = self.val_imgs.to(device=pl_module.device)
        # NOTE(review): the loop variable is never handed to the model, so
        # every pass repeats the same forward call and only the last result
        # is logged. Confirm whether the lead time should be an argument.
        for lead_time in range(lead_times):
            y_hat = pl_module(val_imgs)
            preds = torch.argmax(y_hat, 1)

        trainer.logger.experiment.log({
            "examples": [
                wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                for x, pred, y in zip(val_imgs, preds, self.val_Y)
            ],
            "global_step": trainer.global_step,
        })
def space_to_depth(x, block_size):
    """Fold ``block_size`` x ``block_size`` spatial patches into channels.

    Args:
        x: array-like of shape ``(batch, height, width, depth)``; ``height``
            and ``width`` must be divisible by ``block_size`` (the reshape
            raises ``ValueError`` otherwise).
        block_size: edge length of the patches.

    Returns:
        Array of shape ``(batch, height // block_size, width // block_size,
        depth * block_size ** 2)``.
    """
    x = np.asarray(x)
    batch, height, width, depth = x.shape
    reduced_height = height // block_size
    reduced_width = width // block_size
    # Split both spatial axes into (coarse, within-block) pairs ...
    blocks = x.reshape(batch, reduced_height, block_size,
                       reduced_width, block_size, depth)
    # ... then bring the two within-block axes next to the channel axis.
    return np.swapaxes(blocks, 2, 3).reshape(batch, reduced_height, reduced_width, -1)


def _parse_timestamp(stamp):
    """Turn a ``"Y_M_D_h_m"`` key into a ``datetime`` (ValueError if malformed)."""
    year, month, day, hour, minute = [int(part) for part in stamp.split("_")]
    return datetime(year, month, day, hour=hour, minute=minute)


def date_assertion(dates, expected_delta=5):
    """Check that consecutive timestamps are ``expected_delta`` minutes apart.

    Args:
        dates: sequence of ``"Y_M_D_h_m"`` strings.
        expected_delta: required gap in whole minutes between neighbours.

    Raises:
        AssertionError: if any neighbouring pair has a different gap. This is
            raised explicitly rather than through ``assert`` so the check
            still runs under ``python -O``; callers catch ``AssertionError``.
    """
    for earlier, later in zip(dates[0:-1], dates[1:]):
        gap = _parse_timestamp(later) - _parse_timestamp(earlier)
        minutes = gap.total_seconds() / 60
        if int(minutes) != expected_delta:
            raise AssertionError(
                f"expected {expected_delta} min between {earlier} and {later}, "
                f"got {minutes}"
            )


def h5_iterator(h5_file, maxN=100, spaced=1, starter=0):
    """Yield ``(index, array, key)`` for datasets of an HDF5 file.

    Args:
        h5_file: path of the HDF5 file.
        maxN: maximum number of items to yield; a falsy value disables the cap.
        spaced: keep every ``spaced``-th key.
        starter: number of leading keys to skip.

    Keys are expected to look like ``"Y_M_D_h_m"``; a key with a different
    number of ``_``-separated parts raises ``ValueError``.
    """
    with h5.File(h5_file, "r") as f:
        keys = list(f.keys())[starter:]
        for i, name in enumerate(keys):
            if i % spaced != 0:
                continue
            j = i // spaced
            # Check the limit before touching the dataset.
            if maxN and j >= maxN:
                break
            # Unpacking only validates the key layout; the parts are unused.
            year, month, day, hour, minute = name.split("_")
            yield j, np.array(f[name]), name


def down_sampler(data, rate=2):
    """Keep every ``rate``-th element along axes 1 and 2 of ``data``."""
    return data[:, 0::rate, 0::rate]
def temporal_concatenation(data, dates, targets, target_dates, concat=7, overlap=0, spaced=3, lead_times=60):
    """Cut input frames and target frames into aligned training windows.

    Each sample pairs ``concat`` consecutive input frames with the
    ``lead_times`` target frames that follow them. A window is dropped, with
    a printed warning, when its timestamps are not evenly spaced or when it
    is incomplete.

    Args:
        data: 4-D array of input frames, first axis is time.
        dates: ``"Y_M_D_h_m"`` timestamp per input frame.
        targets: 3-D array of target frames, first axis is time.
        target_dates: timestamp per target frame.
        concat: number of input frames per sample.
        overlap: number of input frames shared by neighbouring samples.
        spaced: number of target steps between two input frames; input
            frames are expected ``5 * spaced`` minutes apart, targets 5.
        lead_times: number of target frames per sample.

    Returns:
        ``(X, Y, X_dates, Y_dates)``: two arrays and two lists of timestamp
        slices, one entry per kept sample.
    """
    n, _, _, _ = data.shape      # the unpacking also enforces a 4-D input
    n_y, _, _ = targets.shape    # ... and 3-D targets

    seq_length = spaced * concat + lead_times  # in 5 minute increments
    x_limit = n - seq_length // spaced

    X, X_dates, Y, Y_dates = [], [], [], []
    stride = concat - overlap
    input_starts = range(0, x_limit, stride)
    target_starts = range(concat * spaced - 2, n_y, stride * spaced)
    for i, j in zip(input_starts, target_starts):
        if (i + 1) % 1000 == 0:
            print(f"\nTemporal concatenated samples: ", i + 1)
        temp_input = data[i:i + concat, :, :]
        temp_target = targets[j:j + lead_times, :, :]
        temp_dates = dates[i:i + concat]
        temp_dates_target = target_dates[j:j + lead_times]

        # A truncated window would make the stacked arrays ragged (and an
        # empty target window would break the hand-over check below).
        if (len(temp_input) < concat or len(temp_target) < lead_times
                or len(temp_dates) < concat or len(temp_dates_target) < lead_times):
            print(f"Warning, incomplete window! Skipping: {i}:{i+seq_length}")
            continue

        try:
            date_assertion(temp_dates, expected_delta=5 * spaced)
            date_assertion(temp_dates_target, expected_delta=5)
            # The first target must follow the last input by 5 minutes.
            date_assertion([temp_dates[-1], temp_dates_target[0]], expected_delta=5)
        except AssertionError:
            print(f"Warning, dates are not alligned! Skipping: {i}:{i+seq_length}")
            continue

        X.append(temp_input)
        X_dates.append(temp_dates)
        Y.append(temp_target)
        Y_dates.append(temp_dates_target)

    return np.array(X), np.array(Y), X_dates, Y_dates


def extract_centercrop(data, factor_smaller=2):
    """Return the central part of ``data`` along axes 1 and 2.

    On each of the two axes ``size // (2 * factor_smaller)`` elements are cut
    from both ends, so the default keeps the middle half. A warning is
    printed when the two axes differ in length; the crop is still returned.
    """
    x1 = data.shape[2]
    y1 = data.shape[1]
    if x1 != y1:
        print(f"\nWarning: centercrop shapes ({x1}, {y1}) are not the same.")

    margin_x = x1 // (2 * factor_smaller)
    margin_y = y1 // (2 * factor_smaller)
    return data[:, slice(margin_y, y1 - margin_y), slice(margin_x, x1 - margin_x)]


def datetime_encoder(data, dates, plotter=False):
    """Append four periodic time-of-year / time-of-day channels to ``data``.

    For every sample the day of the year and the minute of the day are
    encoded as sine and cosine (periods 365 days and 24 hours), and the four
    values are repeated over axes 1 and 2.

    Args:
        data: 4-D array; the new channels are concatenated on axis 3 and
            computed in ``data.dtype``.
        dates: one ``"Y_M_D_h_m"`` string per entry of ``data``'s first axis.
        plotter: if true, show a scatter plot of the encodings.

    Raises:
        AssertionError: if ``len(dates)`` differs from ``data.shape[0]``.
    """
    data_shape = data.shape
    data_type = data.dtype
    year_days = []
    day_minutes = []
    for i, date_string in enumerate(dates):
        if (i + 1) % 1000 == 0:
            print("Dates loaded for encoding: ", i + 1)
        year, month, day, hour, minute = [int(a) for a in date_string.split("_")]
        year_days.append(date(year, month, day).timetuple().tm_yday)
        day_minutes.append(hour * 60 + minute)

    if len(year_days) != data_shape[0]:
        print("Datetime dimensions seem wrong!")
        raise AssertionError(
            f"{len(year_days)} dates for {data_shape[0]} samples"
        )

    year_angle = 2 * np.pi * np.array(year_days) / 365
    day_angle = 2 * np.pi * np.array(day_minutes) / (60 * 24)
    per_sample = np.stack(
        [
            np.sin(year_angle, dtype=data_type),
            np.cos(year_angle, dtype=data_type),
            np.sin(day_angle, dtype=data_type),
            np.cos(day_angle, dtype=data_type),
        ],
        axis=-1,
    )
    # Broadcast the per-sample values over the grid; this is a view, the
    # copy is made only once by the concatenation below.
    date_array = np.broadcast_to(
        per_sample[:, np.newaxis, np.newaxis, :],
        (data_shape[0], data_shape[1], data_shape[2], 4),
    )

    if plotter:
        fig, ax = plt.subplots(1, 2)
        ax[0].scatter(date_array[:, 0, 0, 0], date_array[:, 0, 0, 1])
        fig.suptitle(f"Periodical year and days from {dates[0]} to {dates[-1]}")
        ax[1].scatter(date_array[:, 0, 0, 2], date_array[:, 0, 0, 3])
        ax[0].set_title("Year")
        ax[1].set_title("Day")
        plt.show()

    return np.concatenate((data, date_array), axis=3)


def longlatencoding(data, lonlatelev_path="lonlatelev_full.h5"):
    """Append normalised longitude, latitude and elevation channels.

    The dataset ``"lonlatelev"`` is read from ``lonlatelev_path``; its last
    axis is indexed 0, 1, 2 for longitude, latitude and elevation. Each layer
    is standardised (elevation is first scaled by its largest magnitude and
    finally log-transformed), the result is repeated for every sample and
    concatenated to ``data`` on axis 3 as ``float32``.

    Args:
        data: 4-D array whose axes 1 and 2 match the grid in the file.
        lonlatelev_path: HDF5 file holding the grid; defaults to the file
            name that used to be hard-coded.
    """
    print(f"\nExtracting longitude, latitude and elevation data ...", end="")

    with h5.File(lonlatelev_path, "r") as FF:
        lonlatelev = np.array(FF["lonlatelev"])
    print("LONLAT SHAPE:", lonlatelev.shape)

    lon = lonlatelev[:, :, 0]
    lat = lonlatelev[:, :, 1]
    elev = lonlatelev[:, :, 2]

    lon_mean, lon_std = np.mean(lon), np.std(lon)
    lat_mean, lat_std = np.mean(lat), np.std(lat)
    elev /= np.max(np.abs(elev))
    elev_mean, elev_std = np.mean(elev), np.std(elev)
    # Only a warning: a zero spread makes the divisions below produce
    # non-finite values, exactly as before.
    if lon_std == 0 or lat_std == 0 or elev_std == 0:
        print("WARNING: LON LAT OR ELEV STD == 0")

    lon = (lon - lon_mean) / lon_std
    lat = (lat - lat_mean) / lat_std
    elev = (elev - elev_mean) / elev_std
    # Shift so the minimum maps to log(0.1) before taking the logarithm.
    elev = np.log(elev - np.min(elev) + 0.1)

    print("MINMAX lon: ", np.min(lon), np.max(lon))
    print("MINMAX lat: ", np.min(lat), np.max(lat))
    print("MINMAX elev: ", np.min(elev), np.max(elev))

    lonlatelev[:, :, 0] = lon
    lonlatelev[:, :, 1] = lat
    lonlatelev[:, :, 2] = elev

    # One (read-only) view per sample instead of materialised copies.
    lonlatelev = np.broadcast_to(lonlatelev, (data.shape[0],) + lonlatelev.shape)

    print(f"\ndone! it has shape {lonlatelev.shape}")

    return np.concatenate((data, lonlatelev), axis=3, dtype=np.float32)
np.array(all_snapshots) + + + del(all_snapshots) # MANAGE MEMORY + print("\nDatatype data: ", data.dtype) + print("\nInput data shape: ", data.shape, " size: ", sys.getsizeof(data)) + + + x0,x1,y0,y1 = square + print(f"\nInput patch by index: xmin = {x0}, xmax = {x1}, ymin = {y0}, ymax = {y1}") + x_lim = slice(x0,x1) + y_lim = slice(y0,y1) + + center_x = (x0+x1)//2 + center_y = (y0+y1)//2 + length_x = (x1-x0)//16 #size of Y is 16 times smaller + length_y = (y1-y0)//16 #size of Y is 16 times smaller + Y_lim_x = slice(center_x-length_x//2,center_x+length_x//2) + Y_lim_y = slice(center_y-length_y//2,center_y+length_y//2) + print("SLICED x: ",Y_lim_x) + print("SLICED y: ",Y_lim_y) + Y = all_data[:,Y_lim_y,Y_lim_x] + del(all_data) #MANAGE MEMORY + print(f"\nY shape here (not ready): {Y.shape}") + + data = data[:,y_lim,x_lim] + print(f"\nSliced data to dimensions {data.shape}") + + if centercrop: #extract centercrop before downsampling, since it's high resolution + + center = extract_centercrop(data) + print(f"\nCopying centercrop with shape {center.shape}") + if downsample == True: + print("\nDownsampling with rate: ", downsampling_rate) + data = down_sampler(data) + print("\nDone downsampling!") + + + print("\nDatatype downsampled: ", data.dtype) + print("\nDownsampled data shape: ",data.shape) + if len(data.shape)<4: + data = np.expand_dims(data, axis=3) + print(f"\nAdding channel dimension to data, new shape: {data.shape}") + if centercrop: + if len(center.shape)<4: + center = np.expand_dims(center, axis=3) + print(f"\nAdding channel dimension to centercrop, new shape: {center.shape}") + if spacedepth==True: + data = space_to_depth(data,box) + + print(f"\nSpace-to-depth done! Data shape: {data.shape}") + if centercrop: + center = space_to_depth(center,box) + print(f"\nSpace-to-depth done! 
Centercrop shape: {center.shape}") + + if centercrop: + data = np.concatenate((data,center), axis=3) + print(f"\nConcatenating data and centercrop to dimenison: {data.shape} with shape [:,:,:,downsampled + centercrop]") + + + + + data = longlatencoding(data) + print(f"\nConcatenating data with long, lat and elevation. New shape: {data.shape}, dtype: {data.dtype}") + + data = datetime_encoder(data,dates,plotter=False) + print(f"\nEncoding datetime periodical variables (seasonally,hourly) and concatenating with data. New shape: {data.shape}, dtype: {data.dtype}") + + + + + + data = np.swapaxes(np.swapaxes(data,3,1),2,3) + print(f"\nData swapping axes to get channel first, now shape: {data.shape}") + X,Y, X_dates,Y_dates = temporal_concatenation(data,dates,Y,Y_dates,concat = concat, overlap = overlap, spaced = spaced,lead_times = lead_times) + + print(f"\nDone with temporal concatenation and target_split! Data shape: {X.shape}, target shape: {Y.shape}") + + GAIN = 0.4 + OFFSET = -30 + X[:,:,0:8] = X[:,:,0:8]*GAIN + OFFSET + + + maxx = np.max(X[:,:,0:8]) + print("\nMAX DBZ data(should be 72): ", maxx) + data_new = np.empty(X[:,:,0:8].shape) + N = data_new.shape[0] + runs = N//5000 + for run in range(0,N,5000): + data_new[run:run+5000,:,0:8] = np.log(X[run:run+5000,:,0:8]+0.01, dtype = np.float32)/4 + data_new[run:run+5000,:,0:8] = np.nan_to_num(data_new[run:run+5000,:,0:8]) + data_new[run:run+5000,:,0:8] = np.tanh(data_new[run:run+5000,:,0:8], dtype = np.float32) + + data_new[runs*5000:,:,0:8] = np.log(X[runs*5000:,:,0:8]+0.01, dtype = np.float32)/4 + data_new[runs*5000:,:,0:8] = np.nan_to_num(data_new[runs*5000:,:,0:8]) + data_new[runs*5000:,:,0:8] = np.tanh(data_new[runs*5000:,:,0:8], dtype = np.float32) + + '''data_new = np.log(data+0.01)/4 + data_new = np.nan_to_num(data_new) + data_new = np.tanh(data_new)''' + + + + for i in range(8): + try: + assert np.std(data_new[:,:,:,i]) != 0 + except AssertionError: + print(f"WARNING: CHANNEL {i} STD == 0") + data_new[:,:,i] 
= (data_new[:,:,i] - np.mean(data_new[:,:,i] ))/np.std(data_new[:,:,i] ) + + + print(f"\nScaling data with log(x+0.01)/4, replace NaN with 0 and apply tanh(x) and convert to data type: {data.dtype}, nbytes: {data.nbytes}, size: {data.size}") + + + Y = Y*GAIN + OFFSET + '''for i in range(0,5): + fig, ax = plt.subplots(1,2) + ax[0].imshow(X[i,0,0,:,:]) + #ax[0].imshow(np.mean(data_after_gained[i*7,42:70,42:70,4:8],axis=2)) + ax[0].set_title(X_dates[i][6]) + ax[1].imshow(Y[i,0,:,:]) + ax[1].set_title(Y_dates[i][0]) + plt.show()''' + + #print("comparing X and Y after gain:", np.mean(data_after_gained[:,:,4:8]), np.mean(Y)) + + Y_gained = np.copy(Y) + + print("MINMAX Y AFTER GAIN + OFFSET", np.min(Y), np.max(Y)) + + passer = np.mean(X[:,6,4:8,:,:],axis=1) + + Y = rain_binned(Y, n_bins = n_bins, increment = rain_step, x = passer) + + print(f"\nDone with binning targets into bins, target shape: {Y.shape}") + + + + #Remove low-rainfall data: + meaned = np.mean(X[:,:,0:4,:,:], axis=(1,2,3,4)) + idx_sorted = np.argsort(meaned) + + N = len(meaned) + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,1) + for k in range(7): + i = idx_sorted[j] + + im = ax[k].imshow(X[i,k,0,:,:]) + ax[k].set_title(f"MEAN: {meaned[i]:.2f}") + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + to_keep = int(N*keep_biggest) + idx_to_keep = idx_sorted[-to_keep:] + #print(meaned) + #print(meaned[idx_to_keep]) + + X = X[idx_to_keep] + '''print(meaned[idx_sorted]) + print(meaned[idx_to_keep]) + print(np.mean(X[:,:,0:4,:,:], axis=(1,2,3,4))) + input()''' + Y = Y[idx_to_keep] + X_dates = [X_dates[i] for i in idx_to_keep] + Y_dates = [Y_dates[i] for i in idx_to_keep] + print(f"\nOnly keeping {to_keep} out of {N} samples to reduce low rainfall events. 
New X shape: {X.shape}") + N = X.shape[0] + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,2) + for k in range(7): + + + im = ax[k,0].imshow(X[j,k,0,:,:]) + ax[k,0].set_title(f"MEAN: {np.mean(X[j,k,0,:,:]):.2f}") + ax[k,1].imshow(Y[j,k,0,:,:]) + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + #plots all channels seperately: + '''channels = X.shape[2] + for i in range(N): + fig, axs = plt.subplots(4,4) + axs = axs.reshape(-1) + for c in range(channels): + axs[c].imshow(X[i,0,c]) + plt.show()''' + + #print(f"\nOnly keeping {to_keep} out of {N} samples to reduce low rainfall events. New X shape: {X.shape}") + + Y_thresh = np.ones(Y[:,0,:,:].shape)*-1 + + Y_thresh[np.where(Y[:,0,:,:]==1)] = 1 + + + #rain_check(X, Y_gained[idx_to_keep], Y_thresh,X_dates,Y_dates,X_dict,Y_dict, meaned) + if not printer: + sys.stdout = sys.__stdout__ + return X,Y, X_dates,Y_dates + +def rain_binned(Y, n_bins = 51,increment = 0.2, x = None): + SHAPE = Y.shape + n,leads,w,h = SHAPE + max_fall = n_bins*increment + min_dbz = np.min(Y) + max_dbz = np.max(Y) + Y[np.where(Y>70)] = 0 + + rain = (10**(Y / 10.0) / 200.0)**(1.0 / 1.6) + + print("RAIN MINMAX: ", np.min(rain), np.max(rain)) + + '''for i in range(n): + fig,ax = plt.subplots(1,2) + rain_x = (10**(x[i] / 10.0) / 200.0)**(1.0 / 1.6) + ax[0].imshow(x[i][(112//2)-14:(112//2)+14, (112//2)-14:(112//2)+14]) + ax[0].set_title("x zoomed") + ax[1].imshow(rain[i,0,:,:]) + ax[1].set_title("y") + plt.show()''' + rain_bins = np.zeros((n,leads,n_bins,w,h), dtype = np.int8) + counter = [] + for i in range(n_bins-1): + if i%100==0: print(f"Bin progress:",i) + bin_min = i*increment + bin_max = (i+1)*increment + rain_bin = np.zeros((n,leads,w,h)) #Y.shape = (None,lead_times, bin_channel, width/4,heigth/4) + idx = np.where(np.logical_and(rain>=bin_min, rain=n_bins*increment) + rain_bin[idx] = 1 + rain_bins[:,:,n_bins-1,:,:] = rain_bin + counter.append(len(idx)) 
+ print("RAINBINS: ", rain_bins.size, " counter size: ", sum(counter)) + + print(counter) + + + return rain_bins +''' +def rain_check(X, Y,Y_thresh,x_dates,y_dates,X_dict,Y_dict,meaned): + X_mid = np.mean(X[:,:,4:8], axis = 2) #,42:70, 42:70 + N = min(X.shape[0], Y.shape[0]) + temps = X_mid.shape[1] + leads = min(Y.shape[1],temps) + + for n in range(N): + fig, axs = plt.subplots(5,temps) + axs = axs.reshape(-1) + #print(X_mid[n,0,:,:].reshape(-1)) + for i in range(temps): + #print("i: ", i, "\n", X_mid[n,i,:,:].reshape(-1)) + + print("MINMAX", np.min(X_mid[n,i,:,:]), np.max(X_mid[n,i,:,:])) + #print(X_mid[np.where(np.isnan(X_mid))] + np.random.random(X_mid[np.where(np.isnan(X_mid))].shape)) + im = axs[i].imshow(X_mid[n,i,:,:]) + + #axs[i].set_title(x_dates[n][i][-5:]) + for i in range(temps): + im = axs[i+temps].imshow(X_dict[x_dates[n][i]]) + axs[i+temps].set_title(x_dates[n][i][-5:]) + for j in range(leads): + im = axs[j+2*temps].imshow(Y[n,j,:,:]) + axs[j+2*temps].set_title(y_dates[n][j][-5:]) + for j in range(leads): + im = axs[j+3*temps].imshow(Y_thresh[n,j,:,:]) + axs[j+3*temps].set_title(y_dates[n][j][-5:]) + for j in range(leads): + im = axs[j+4*temps].imshow(Y_dict[y_dates[n][i]]) + axs[j+4*temps].set_title(x_dates[n][i][-5:]) + fig.suptitle("mean : " + str(meaned[n])) + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + +def prepare_files(h5_path = "data3_500k.h5",lead_times = 60, concat = 7, square = (0,448,881-448,881), downsampling_rate = 2, overlap = 0, spaced=3,downsample = True, spacedepth =True,centercrop=True,box=2,printer=True, rain_step = 0.2, n_bins=512): + #15 minutes between datapoints is default --> spaced = 3 + + + + batch_size = 5000 + N = 500000 + partial = 400000 + for start in range(partial,partial+100000,batch_size): + print(f"-------------------- STARTING AT {start} --------------------") + snapshots = [] + dates = [] + all_snapshots = [] + all_dates = 
[] + for i, array,date in h5_iterator(h5_path, batch_size, starter=start): + + if (i+1)%1000==0: + print("Loaded samples: ", i+1) + + if i%spaced==0: + snapshots.append(array) + dates.append(date) + + all_snapshots.append(array) + all_dates.append(date) + + data = np.array(snapshots) + del(snapshots) # MANAGE MEMORY + all_data = np.array(all_snapshots) + + + del(all_snapshots) # MANAGE MEMORY + print("\nDatatype data: ", data.dtype) + print("\nInput data shape: ", data.shape, " size: ", sys.getsizeof(data)) + + + x0,x1,y0,y1 = square + print(f"\nInput patch by index: xmin = {x0}, xmax = {x1}, ymin = {y0}, ymax = {y1}") + x_lim = slice(x0,x1) + y_lim = slice(y0,y1) + + center_x = (x0+x1)//2 + center_y = (y0+y1)//2 + length_x = (x1-x0)//16 #size of Y is 16 times smaller + length_y = (y1-y0)//16 #size of Y is 16 times smaller + if length_x==length_y: + Y_lim_x = slice(center_x-length_x//2,center_x+length_x//2) + Y_lim_y = slice(center_y-length_y//2,center_y+length_y//2) + else: + Y_lim_x = slice(center_x-length_x//2,center_x+length_x//2) + Y_lim_y = slice(center_x-length_x//2,y1-(center_x-length_x//2)) + print("SLICED Y in x: ",Y_lim_x) + print("SLICED Y in y: ",Y_lim_y) + Y = all_data[:,Y_lim_y,Y_lim_x] + del(all_data) #MANAGE MEMORY + print(f"\nY shape here (not ready): {Y.shape}") + + data = data[:,y_lim,x_lim] + print(f"\nSliced data to dimensions {data.shape}") + + if centercrop: #extract centercrop before downsampling, since it's high resolution + + center = extract_centercrop(data) + print(f"\nCopying centercrop with shape {center.shape}") + if downsample == True: + print("\nDownsampling with rate: ", downsampling_rate) + data = down_sampler(data) + print("\nDone downsampling!") + + + print("\nDatatype downsampled: ", data.dtype) + print("\nDownsampled data shape: ",data.shape) + if len(data.shape)<4: + data = np.expand_dims(data, axis=3) + print(f"\nAdding channel dimension to data, new shape: {data.shape}") + if centercrop: + if len(center.shape)<4: + center 
= np.expand_dims(center, axis=3) + print(f"\nAdding channel dimension to centercrop, new shape: {center.shape}") + if spacedepth==True: + data = space_to_depth(data,box) + + print(f"\nSpace-to-depth done! Data shape: {data.shape}") + if centercrop: + center = space_to_depth(center,box) + print(f"\nSpace-to-depth done! Centercrop shape: {center.shape}") + + if centercrop: + data = np.concatenate((data,center), axis=3) + print(f"\nConcatenating data and centercrop to dimenison: {data.shape} with shape [:,:,:,downsampled + centercrop]") + + + + + data = longlatencoding(data) + print(f"\nConcatenating data with long, lat and elevation. New shape: {data.shape}, dtype: {data.dtype}") + + data = datetime_encoder(data,dates,plotter=False) + print(f"\nEncoding datetime periodical variables (seasonally,hourly) and concatenating with data. New shape: {data.shape}, dtype: {data.dtype}") + + + + + + data = np.swapaxes(np.swapaxes(data,3,1),2,3) + print(f"\nData swapping axes to get channel first, now shape: {data.shape}") + X,Y, X_dates,Y_dates = temporal_concatenation(data,dates,Y,all_dates,concat = concat, overlap = overlap, spaced = spaced,lead_times = lead_times) + + print(f"\nDone with temporal concatenation and target_split! 
Data shape: {X.shape}, target shape: {Y.shape}") + + GAIN = 0.4 + OFFSET = -30 + X[:,:,0:8] = X[:,:,0:8]*GAIN + OFFSET + + + maxx = np.max(X[:,:,0:8]) + print("\nMAX DBZ data(should be 72): ", maxx) + data_new = np.empty(X[:,:,0:8].shape) + N = data_new.shape[0] + runs = N//5000 + for run in range(0,N,5000): + data_new[run:run+5000,:,0:8] = np.log(X[run:run+5000,:,0:8]+0.01, dtype = np.float32)/4 + data_new[run:run+5000,:,0:8] = np.nan_to_num(data_new[run:run+5000,:,0:8]) + data_new[run:run+5000,:,0:8] = np.tanh(data_new[run:run+5000,:,0:8], dtype = np.float32) + + data_new[runs*5000:,:,0:8] = np.log(X[runs*5000:,:,0:8]+0.01, dtype = np.float32)/4 + data_new[runs*5000:,:,0:8] = np.nan_to_num(data_new[runs*5000:,:,0:8]) + data_new[runs*5000:,:,0:8] = np.tanh(data_new[runs*5000:,:,0:8], dtype = np.float32) + #data[np.where(data<0)] = 0 + '''data_new = np.log(data+0.01)/4 + data_new = np.nan_to_num(data_new) + data_new = np.tanh(data_new)''' + + + + for i in range(8): + try: + assert np.std(data_new[:,:,:,i]) != 0 + except AssertionError: + print(f"WARNING: CHANNEL {i} STD == 0") + data_new[:,:,i] = (data_new[:,:,i] - np.mean(data_new[:,:,i] ))/np.std(data_new[:,:,i] ) + + + print(f"\nScaling data with log(x+0.01)/4, replace NaN with 0 and apply tanh(x) and convert to data type: {data.dtype}, nbytes: {data.nbytes}, size: {data.size}") + + Y = Y*GAIN + OFFSET + '''for i in range(0,5): + fig, ax = plt.subplots(1,2) + ax[0].imshow(X[i,0,0,:,:]) + #ax[0].imshow(np.mean(data_after_gained[i*7,42:70,42:70,4:8],axis=2)) + ax[0].set_title(X_dates[i][6]) + ax[1].imshow(Y[i,0,:,:]) + ax[1].set_title(Y_dates[i][0]) + plt.show()''' + + #print("comparing X and Y after gain:", np.mean(data_after_gained[:,:,4:8]), np.mean(Y)) + + Y_gained = np.copy(Y) + + print("MINMAX Y AFTER GAIN + OFFSET", np.min(Y), np.max(Y)) + + passer = np.mean(X[:,6,4:8,:,:],axis=1) + + Y = rain_binned(Y, n_bins = n_bins, increment = rain_step, x = passer) + + print(f"\nDone with binning targets into bins, 
target shape: {Y.shape}") + + + + #Remove low-rainfall data: + to_mean = X[:,:,0:4,:,:] + + idx_to_remove = np.where(to_mean > 70) + to_mean[idx_to_remove] = 0 + means = np.mean(to_mean, axis=(1,2,3,4)) + + + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,1) + for k in range(7): + i = idx_sorted[j] + + im = ax[k].imshow(X[i,k,0,:,:]) + ax[k].set_title(f"MEAN: {meaned[i]:.2f}") + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + '''print(meaned[idx_sorted]) + print(meaned[idx_to_keep]) + print(np.mean(X[:,:,0:4,:,:], axis=(1,2,3,4))) + input()''' + + + '''for j in range(N-1,0,-1): + fig, ax = plt.subplots(7,2) + for k in range(7): + + + im = ax[k,0].imshow(X[j,k,0,:,:]) + ax[k,0].set_title(f"MEAN: {np.mean(X[j,k,0,:,:]):.2f}") + ax[k,1].imshow(Y[j,k,0,:,:]) + + + fig.subplots_adjust(right=0.8) + cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fig.colorbar(im, cax=cbar_ax) + plt.show()''' + + #plots all channels seperately: + '''channels = X.shape[2] + for i in range(N): + fig, axs = plt.subplots(4,4) + axs = axs.reshape(-1) + for c in range(channels): + axs[c].imshow(X[i,0,c]) + plt.show()''' + + #print(f"\nOnly keeping {to_keep} out of {N} samples to reduce low rainfall events. 
def _split_for(index, train=10, val=2, test=2):
    """Name the split of item ``index`` in a repeating train/val/test cycle."""
    slot = index % (train + val + test)
    if slot < train:
        return "train"
    if slot < train + val:
        return "val"
    return "test"


def _target_name(sample_name):
    """Return the ``Y`` file name belonging to an ``X`` file name.

    Only the last ``"X"`` of the name is replaced, so the rest of the name
    (and, at the call sites, the directory) is left untouched.
    """
    stem, marker, tail = sample_name.rpartition("X")
    return stem + "Y" + tail if marker else sample_name


def _split_dirs(root):
    """Create ``root/{train,val,test}`` if needed and return their paths."""
    dirs = {}
    for split in ("train", "val", "test"):
        dirs[split] = os.path.join(root, split)
        os.makedirs(dirs[split], exist_ok=True)
    return dirs


def prepare_batches(batch_size=8,
                    data_path="/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/data_exclusive/",
                    batches_path=None):
    """Group per-sample ``.npy`` files into fixed-size batch files.

    Sample files are named ``<mean>_<date>_X.npy`` / ``..._Y.npy``. Samples
    are sorted by descending ``<mean>``, cut into consecutive groups of
    ``batch_size`` (a trailing partial group is dropped) and each group is
    written as one ``X`` and one ``Y`` array. Groups are dealt to
    ``train`` / ``val`` / ``test`` in a repeating 10 / 2 / 2 cycle.

    Args:
        batch_size: number of samples per written batch.
        data_path: directory holding the per-sample files.
        batches_path: output root; defaults to the ``batch_<batch_size>``
            directory that used to be hard-coded.
    """
    if batches_path is None:
        batches_path = f"/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/batch_{batch_size}/"

    means = []
    dates = []
    names = []
    for name in os.listdir(data_path):
        if not name[-4:] == ".npy":
            continue
        if name[-5:] == "Y.npy":
            continue
        names.append(name)
        parts = name[:-4].split("_")  # drop ".npy", then split the fields
        mean = parts[0]
        stamp = "_".join(parts[1:-1])
        print(mean, stamp)
        means.append(float(mean))
        dates.append(stamp)

    means = np.array(means)
    idx_sorted = np.argsort(-means)  # descending mean
    names_sorted = [names[idx] for idx in idx_sorted]
    means_sorted = means[idx_sorted]
    dates_sorted = [dates[idx] for idx in idx_sorted]
    print("MEANS: ", means[idx_sorted])
    print(len(means))

    out_dirs = _split_dirs(batches_path)
    n_batches = len(means) // batch_size
    for i in range(n_batches):
        if i % 10 == 0:
            print(f"batch {i} / {n_batches}")
        window = slice(i * batch_size, (i + 1) * batch_size)

        # Fixed per-sample shapes of the prepared data set.
        X = np.empty((batch_size, 7, 15, 112, 112))
        Y = np.empty((batch_size, 60, 128, 28, 28))
        for k, x_name in enumerate(names_sorted[window]):
            X[k] = np.load(os.path.join(data_path, x_name))
            Y[k] = np.load(os.path.join(data_path, _target_name(x_name)))

        batch_mean = np.mean(means_sorted[window])
        prefix = os.path.join(
            out_dirs[_split_for(i)],
            str(batch_mean)[0:5] + "_" + dates_sorted[i * batch_size],
        )
        np.save(prefix + "_X.npy", X)
        np.save(prefix + "_Y.npy", Y)


def prepare_new_mean_split(data_path="/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/data_exclusive/",
                           out_path="/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/bin_sorted_data/",
                           counts_file="/proj/berzelius-2022-18/users/sm_valfa/metnet_pylight/metnet/no_zeros.npy"):
    """Re-split the per-sample files, ordered by how much rain the target holds.

    For every ``X`` file the matching ``Y`` array is loaded and everything
    except index 0 of its second axis is summed. Samples are sorted by that
    sum in descending order, dealt to ``train`` / ``val`` / ``test`` in a
    repeating 10 / 2 / 2 cycle and saved as ``<sum>_<date>_X.npy`` /
    ``..._Y.npy``. The unsorted sums are also written to ``counts_file``.

    The sums used to be appended to a NumPy array, which raised
    ``AttributeError`` on the first sample; they are now collected in a list.

    Args:
        data_path: directory holding the per-sample files.
        out_path: output root for the three split directories.
        counts_file: destination of the array of sums.
    """
    file_list = os.listdir(data_path)
    non_zeros = []
    dates = []
    names = []
    for k, name in enumerate(file_list):
        if k % 100 == 0:
            print(f"{k} / {len(file_list)} samples meaned!")
        if not name[-4:] == ".npy":
            continue
        if name[-5:] == "Y.npy":
            continue
        Y_here = np.load(os.path.join(data_path, _target_name(name)))
        # Index 0 of axis 1 is left out of the total.
        non_zeros.append(np.sum(Y_here[:, 1:, :, :], axis=(0, 1, 2, 3)))
        names.append(name)
        dates.append("_".join(name[:-4].split("_")[1:-1]))

    non_zeros = np.array(non_zeros)
    idx_sorted = np.argsort(-non_zeros)  # descending
    np.save(counts_file, non_zeros)

    names_sorted = [names[idx] for idx in idx_sorted]
    non_zeros_sorted = non_zeros[idx_sorted]
    dates_sorted = [dates[idx] for idx in idx_sorted]

    print("MEANS: ", non_zeros[idx_sorted])
    print(len(non_zeros))
    N = len(non_zeros)

    out_dirs = _split_dirs(out_path)
    for i in range(N):
        if i % 100 == 0:
            print(f" {i} / {N} samples done!")

        X = np.load(os.path.join(data_path, names_sorted[i]))
        Y = np.load(os.path.join(data_path, _target_name(names_sorted[i])))

        prefix = os.path.join(
            out_dirs[_split_for(i)],
            str(non_zeros_sorted[i]) + "_" + dates_sorted[i],
        )
        np.save(prefix + "_X.npy", X)
        np.save(prefix + "_Y.npy", Y)
axs[1].set_title("Monthly rain rates") + plt.savefig("yearly_rates_plots") + plt.show() + listed = date.split("_") + y, m, d, hour, minute = [int(a) for a in listed] + + array = 0.4*array - 30 + rain = (10**(array / 10.0) / 200.0)**(1.0 / 1.6) + + idx = np.where(array<70) + + if len(idx)==0: continue + mean_rain = np.mean(rain[idx]) + mean_rain_month[m].append(mean_rain) + mean_rain_hour[hour].append(mean_rain) + + monthly_mean = [np.mean(month) for month in mean_rain_month] + hourly_mean = [np.mean(hour) for hour in mean_rain_hour] + print(hourly_mean) + np.save("monthly_mean.npy", monthly_mean) + np.save("hourly_mean.npy", hourly_mean) + ax = plt.subplot(111, polar=True) + equals = np.linspace(0, 360, 24, endpoint=False) #np.arange(24) + ones = np.ones(24) + ax.scatter(np.deg2rad(equals), hourly_mean) + + # Set the circumference labels + ax.set_xticks(np.linspace(0, 2*np.pi, 24, endpoint=False)) + ax.set_xticklabels(range(24)) + + # Make the labels go clockwise + ax.set_theta_direction(-1) + + # Place 0 at the top + ax.set_theta_offset(np.pi/2.0) + + plt.show() + + + ax = plt.subplot(111, polar=True) + equals = np.linspace(0, 360, 12, endpoint=False) #np.arange(24) + #ones = np.ones(24) + ax.scatter(np.deg2rad(equals), monthly) + + # Set the circumference labels + ax.set_xticks(np.linspace(0, 2*np.pi, 12, endpoint=False)) + ax.set_xticklabels(["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]) + + # Make the labels go clockwise + ax.set_theta_direction(-1) + + # Place 0 at the top + ax.set_theta_offset(np.pi/2.0) + + plt.show() + +if __name__=="__main__": + plot_rain_distribution(maxN = None) + square = (0,448,0,880) + #prepare_files(square=square) + #prepare_new_mean_split() + + diff --git a/pyvenv.cfg b/pyvenv.cfg new file mode 100644 index 0000000..479ee9f --- /dev/null +++ b/pyvenv.cfg @@ -0,0 +1,3 @@ +home = C:\Users\valte\.conda\envs\MetNet +include-system-site-packages = false +version = 3.10.0 diff --git a/setup.py 
"""Evaluate a saved MetNetPylight checkpoint with the Lightning test loop.

The script logs in to Weights & Biases, restores the model from PATH_cp,
resets the evaluation accumulators that are stored as attributes on the model,
runs `trainer.test` and prints the elapsed wall-clock time.
"""
from metnet.models.metnet_pylight import MetNetPylight
import torch
import torch.nn.functional as F
from data_prep.prepare_data_MetNet import load_data
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning import seed_everything
import time
import numpy as np
import matplotlib.pyplot as plt

# Disabled helper: F1 per lead time for one rain threshold, model vs. persistence.
if False:
    thresh = 1
    control = np.load(f"f1_control_thresh_{thresh}.npy")
    fig, ax = plt.subplots(1, 1)
    ax.plot(control, label="persistence")
    print(control)
    f1s = np.load(f"f1_threshed_0.5_thresh_{thresh}.npy")
    ax.plot(f1s, label=f"No aggregation model")
    print(f1s)
    fig.suptitle(f"F1-score for rainfall threshed at {round(thresh*0.2,3)} mm/h")
    ax.set_xlabel("Lead time")
    ax.set_ylabel("F1")
    ax.legend()
    plt.show()

# Disabled helper: F1 per lead time for several probability thresholds.
if False:
    print("hej")
    N = 3606
    prob_threshes = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
    control = np.load(f"f1_control_N_{N}.npy")
    fig, ax = plt.subplots(1, 1)
    ax.plot(control, label="persistence")
    for p in prob_threshes:
        f1s = np.load(f"f1_threshed_{p}_N_{N}.npy")
        ax.plot(f1s, label=f"P(rate>=0.2)>{p}")
    fig.suptitle("Different probabillity thresholds for precipitation")
    ax.set_xlabel("Lead time")
    ax.set_ylabel("F1")
    ax.legend()
    plt.show()


wandb.login()

# Reference configuration for building a fresh model instead of loading one
# (values from the original note; the trailing remarks compare with the paper):
#   hidden_dim=256 (384), forecast_steps=60 (240), input_channels=15 (46),
#   output_channels=128 (512), input_size=112, n_samples=None (all ~23000),
#   num_workers=4, batch_size=None (8), learning_rate=1e-2, num_att_layers=8,
#   plot_every=None, rain_step=0.2, momentum=0.9, att_heads=16, keep_biggest=0.3
#
# Other checkpoints that have been evaluated with this script:
#   "epoch=653-step=90251.ckpt", "epoch=429-step=59339.ckpt",
#   "epoch=276-step=14680.ckpt", "epoch=61-step=3285.ckpt",
#   "epoch=464-step=24644.ckpt", "epoch=471-step=25015.ckpt",
#   "epoch=33-step=3569.ckpt", "epoch=430-step=22842.ckpt",
#   "8leadtimessecond8h.ckpt", "epoch=242-step=16766.ckpt",
#   "best_60_leadtime.ckpt", "best_single_leadtime.ckpt", "no agg network.ckpt"
PATH_cp = "fullrun_1.ckpt"

model = MetNetPylight.load_from_checkpoint(PATH_cp)

# Evaluation-time overrides of attributes restored from the checkpoint.
# Further knobs used in earlier runs: forecast_steps, leadtime_spacing,
# keep_biggest, plot_every.
model.thresh = 1
model.batch_size = 8
model.n_samples = None
model.testing = True
print(model.forecast_steps)

# MetNetPylight expects already preprocessed data.

# One zero-initialised confusion counter per lead time, for the model and for
# the control series.
for counter_name in ("TPs", "FNs", "FPs", "TPs_control", "FNs_control", "FPs_control"):
    setattr(model, counter_name, np.zeros(model.forecast_steps))

# Per-lead-time F1 histories (independent lists, not aliases of one list).
for history_name in ("f1s", "f1s_control"):
    setattr(model, history_name, [[] for _ in range(model.forecast_steps)])

# Per-lead-time scalar accumulators, all starting at 0.
for accumulator_name in ("f1_count", "avg_y_img", "avg_y_hat_img"):
    setattr(model, accumulator_name, [0 for _ in range(model.forecast_steps)])

model.skipped = 0
model.not_skipped = 0

wandb_logger = WandbLogger(project="lit-wandb")

trainer = pl.Trainer(
    track_grad_norm=2,
    max_epochs=1000,
    gpus=-1,
    log_every_n_steps=10,
    logger=wandb_logger,
    strategy="ddp",
)

start_time = time.time()
trainer.test(model)
print("--- %s seconds ---" % (time.time() - start_time))
wandb.finish()