Skip to content

Commit

Permalink
added basic inference scripts in inference/ and animate_using_ffmpeg_…
Browse files Browse the repository at this point in the history
…scripts
  • Loading branch information
Rishabh Gupta committed Sep 29, 2023
1 parent 93360c1 commit 886f26d
Show file tree
Hide file tree
Showing 9 changed files with 974 additions and 422 deletions.
111 changes: 111 additions & 0 deletions animate_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import h5py
import numpy as np
import sys, os


import random
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import matplotlib.animation as animation
import logging

# from matplotlib import animation
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np


plt.rcParams['figure.dpi'] = 400
plt.ioff()
plt.rcParams['animation.ffmpeg_path'] = '/scratch/gilbreth/gupt1075/FourCastNet/'



path = "/depot/gdsp/data/gupt1075/fourcastnet/data/FCN_ERA5_data_v0/train/"
filename2 = "/scratch/gilbreth/gupt1075/ERA5_expts_gtc/autoregressive_predictions_z500.h5"
filename = os.path.join(path,"1979.h5")


list1 = []
with h5py.File(filename2, "r") as hf:
logging.warning( f" {hf.keys()} " )
ndarray = np.array(hf["fields"][:1400, 0])
list1.append(ndarray)

data = list1[0]
logging.warning(f"data_shape {data.shape}")



fig = plt.figure( figsize=(12,12) )

a = data[0]
im = plt.imshow(a, interpolation='none', aspect='auto')

def animate_func(i):
if i % fps == 0:
print( '.', end ='' )
im.set_array(data[i])
return [im]

anim = animation.FuncAnimation(fig, animate_func, frames = nSeconds * fps,interval = 1000 / fps, repeat=True )
writergif = animation.PillowWriter(fps=30)

# extra_args=['-vcodec', 'libx264']
anim.save( "input.gif" , writer=writergif)
# anim.save('input_gif.mp4', fps=fps )

logging.warning('Done!')








# with h5py.File(filename, "r") as f:
# # Print all root level object names (aka keys)
# # these can be group or dataset names
# print(f"Keys: {f.keys()} ")
# # get first object name/key; may or may NOT be a group
# a_group_key = list(f.keys())[0]

# # get the object type for a_group_key: usually group or dataset
# print(type(f[a_group_key]))

# # If a_group_key is a group name,
# # this gets the object names in the group and returns as a list
# data = list(f[a_group_key])

# # If a_group_key is a dataset name,
# # this gets the dataset values and returns as a list
# data = list(f[a_group_key])
# # preferred methods to get dataset values:
# ds_obj = f[a_group_key] # returns as a h5py dataset object
# ds_arr = f[a_group_key][()] # returns as a numpy array




# dset = f['key'][:]
# img = Image.fromarray(dset.astype("uint8"), "RGB")
# img.save("./test.png")

# print(f" ds_obj: {ds_obj} ds_arr: {ds_arr} ")



fps = 30
nSeconds = 5

# hf = h5py.File(filename, 'r')
# ndarray = np.array(hf["fields"][:])
# print(ndarray.shape)
# First set up the figure, the axis, and the plot element we want to animate


# plt.show() # Not required, it seems!
76 changes: 76 additions & 0 deletions bash_scripts/inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/bin/bash
# change the directory path of model run-time output and error messages to your own
#SBATCH --output=/scratch/gilbreth/gupt1075/run_infer_fourcastnet.out
#SBATCH --error=/scratch/gilbreth/gupt1075/run_infer_fourcastnet.err
# The file name of this submission file, so it's easier to track jobs
# filename: submit_run_model_example.sub
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=32
#SBATCH --gres=gpu:1
#SBATCH --time=23:00:00
# partner queue has a 24-hour limit
#SBATCH -A gdsp-k
#SBATCH -C "v100|a100"
# Job name, it will show up when you track this job
#SBATCH -J fourcastnet_job
# Use your email address so that you will receive email notifications about the job begin, end, or fail status
# To submit the job via command line:$ sbatch submit_run_model_example.sub
# To check status of the submitted job:$ squeue -u yourUserID

module --force purge
unset PYTHONPATH
module load anaconda/5.3.1-py37
module load cuda/11.7.0
module load cudnn/cuda-11.7_8.6
module use /depot/gdsp/etc/modules
module load utilities monitor
module load rcac

module list
export PRECXX11ABI=1
export CUDA="11.7"

echo $PYTHONPATH



# track per-code GPU load
monitor gpu percent --all-cores >gpu-percent.log &
GPU_PID=$!


# track memory usage
monitor gpu memory >gpu-memory.log &
MEM_PID=$!


# track per-code CPU load
monitor cpu percent --all-cores >cpu-percent.log &
CPU_PID=$!

# track memory usage
monitor cpu memory >cpu-memory.log &
MEM_PID=$!




# Loading anaconda environment
source /apps/spack/gilbreth/apps/anaconda/5.3.1-py37-gcc-4.8.5-7vvmykn/etc/profile.d/conda.sh
conda activate pytorch


# Change this directory to where you save the model-related files such as run_model.py
# cd /scratch/gilbreth/wwtung/FourCastNet/

python /scratch/gilbreth/gupt1075/FourCastNet/inference/inference.py \
--config=AFNO.yaml
--run_num=0 \
--weights '/scratch/gilbreth/gupt1075/model_weights/FCN_weights_v0/backbone.ckpt' \
--override_dir '/scratch/gilbreth/gupt1075/ERA5_expts_gtc/'






74 changes: 74 additions & 0 deletions bash_scripts/inference_precip.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/bin/bash
# change the directory path of model run-time output and error messages to your own
#SBATCH --output=/scratch/gilbreth/gupt1075/run_infer_fourcastnet.out
#SBATCH --error=/scratch/gilbreth/gupt1075/run_infer_fourcastnet.err
# The file name of this submission file, so it's easier to track jobs
# filename: submit_run_model_example.sub
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=32
#SBATCH --gres=gpu:1
#SBATCH --time=23:00:00
# partner queue has a 24-hour limit
#SBATCH -A gdsp-k
#SBATCH -C "v100|a100"
# Job name, it will show up when you track this job
#SBATCH -J fourcastnet_job
# Use your email address so that you will receive email notifications about the job begin, end, or fail status
# To submit the job via command line:$ sbatch submit_run_model_example.sub
# To check status of the submitted job:$ squeue -u yourUserID

module --force purge
unset PYTHONPATH
module load anaconda/5.3.1-py37
module load cuda/11.7.0
module load cudnn/cuda-11.7_8.6
module use /depot/gdsp/etc/modules
module load utilities monitor
module load rcac

module list
export PRECXX11ABI=1
export CUDA="11.7"

echo $PYTHONPATH



# track per-code GPU load
monitor gpu percent --all-cores >gpu-percent.log &
GPU_PID=$!


# track memory usage
monitor gpu memory >gpu-memory.log &
MEM_PID=$!


# track per-code CPU load
monitor cpu percent --all-cores >cpu-percent.log &
CPU_PID=$!

# track memory usage
monitor cpu memory >cpu-memory.log &
MEM_PID=$!




# Loading anaconda environment
source /apps/spack/gilbreth/apps/anaconda/5.3.1-py37-gcc-4.8.5-7vvmykn/etc/profile.d/conda.sh
conda activate pytorch


# Change this directory to where you save the model-related files such as run_model.py
cd /scratch/gilbreth/wwtung/FourCastNet/



python /scratch/gilbreth/gupt1075/FourCastNet/inference/inference_precip.py \
--config=precip \
--run_num=0 \
-weights '/scratch/gilbreth/gupt1075/model_weights/FCN_weights_v0/precip.ckpt' \
--override_dir '/scratch/gilbreth/gupt1075/ERA5_expts_gtc/precip/'


27 changes: 17 additions & 10 deletions config/AFNO.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,27 @@ afno_backbone: &backbone
scheduler: 'CosineAnnealingLR'
in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
orography: !!bool False
orography_path: None
exp_dir: '/pscratch/sd/s/shas1693/results/era5_wind'
train_data_path: '/pscratch/sd/s/shas1693/data/era5/train'
valid_data_path: '/pscratch/sd/s/shas1693/data/era5/test'
inf_data_path: '/pscratch/sd/s/shas1693/data/era5/out_of_sample'
time_means_path: '/pscratch/sd/s/shas1693/data/era5/time_means.npy'
global_means_path: '/pscratch/sd/s/shas1693/data/era5/global_means.npy'
global_stds_path: '/pscratch/sd/s/shas1693/data/era5/global_stds.npy'

orography: !!bool True
orography_path: '/scratch/gilbreth/wwtung/FourCastNet/data/static/orography.h5'

exp_dir: '/scratch/gilbreth/gupt1075/ERA5_expts_gtc/'
train_data_path: '/scratch/gilbreth/wwtung/FourCastNet/data/train'
valid_data_path: '/scratch/gilbreth/wwtung/FourCastNet/data/test'
inf_data_path: '/scratch/gilbreth/wwtung/FourCastNet/data/out_of_sample' # test set path for inference
time_means_path: '/scratch/gilbreth/wwtung/FourCastNet/additional/stats_v0/time_means.npy'
global_means_path: '/scratch/gilbreth/wwtung/FourCastNet/additional/stats_v0/global_means.npy'
global_stds_path: '/scratch/gilbreth/wwtung/FourCastNet/additional/stats_v0/global_stds.npy'






afno_backbone_orography: &backbone_orography
<<: *backbone
orography: !!bool True
orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5'
orography_path: '/scratch/gilbreth/wwtung/FourCastNet/data/static/orography.h5'

afno_backbone_finetune:
<<: *backbone
Expand Down
Loading

0 comments on commit 886f26d

Please sign in to comment.