From 4be294c885d53fba7b8a56079d95f9fb89539a62 Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 15:21:35 -0500 Subject: [PATCH 1/6] Removing numba -> actual slows things down substantially, poor use case for JIT without pre-defined iterations --- src/clearwater_modules/tsm/processes.py | 34 +++++++------------------ 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/src/clearwater_modules/tsm/processes.py b/src/clearwater_modules/tsm/processes.py index f8093c2..cdc6e4d 100644 --- a/src/clearwater_modules/tsm/processes.py +++ b/src/clearwater_modules/tsm/processes.py @@ -1,6 +1,5 @@ """JIT compiled processes for the heat model.""" import warnings -import numba import numpy as np import xarray as xr from clearwater_modules.shared.processes import ( @@ -8,7 +7,6 @@ ) -@numba.njit def air_temp_k( air_temp_c: xr.DataArray, ) -> xr.DataArray: @@ -20,7 +18,6 @@ def air_temp_k( return celsius_to_kelvin(air_temp_c) -@numba.njit def water_temp_k( water_temp_c: xr.DataArray, ) -> xr.DataArray: @@ -32,7 +29,6 @@ def water_temp_k( return celsius_to_kelvin(water_temp_c) -@numba.njit def mixing_ratio_air( eair_mb: xr.DataArray, pressure_mb: xr.DataArray, @@ -46,7 +42,6 @@ def mixing_ratio_air( return 0.622 * eair_mb / (pressure_mb - eair_mb) -@numba.njit def density_air( pressure_mb: xr.DataArray, air_temp_k: xr.DataArray, @@ -67,7 +62,6 @@ def density_air( ) -@numba.njit def emissivity_air( air_temp_k: xr.DataArray, ) -> xr.DataArray: @@ -79,7 +73,6 @@ def emissivity_air( return 0.00000937 * air_temp_k**2.0 -@numba.njit def wind_function( ri_function: xr.DataArray, wind_a: xr.DataArray, @@ -105,7 +98,6 @@ def wind_function( ) -@numba.njit def q_latent( pressure_mb: xr.DataArray, density_water: xr.DataArray, @@ -133,7 +125,6 @@ def q_latent( ) -@numba.njit def q_sensible( wind_kh_kw: xr.DataArray, cp_air: xr.DataArray, @@ -163,7 +154,6 @@ def q_sensible( ) -@numba.njit def q_sediment( use_sed_temp: xr.DataArray, pb: xr.DataArray, @@ -172,7 +162,7 @@ def q_sediment( h2: xr.DataArray, sed_temp_c: xr.DataArray, water_temp_c: xr.DataArray, -) -> xr.DataArray: +) -> np.ndarray: """Sediment heat flux (W/m^2). Args: @@ -195,14 +185,13 @@ def q_sediment( ) -@numba.njit def dTdt_sediment_c( use_sed_temp: xr.DataArray, alphas: xr.DataArray, h2: xr.DataArray, water_temp_c: xr.DataArray, sed_temp_c: xr.DataArray, -) -> xr.DataArray: +) -> np.ndarray: """Sediments temperature change (C). Args: @@ -219,7 +208,6 @@ def dTdt_sediment_c( ) -@numba.njit def mf_d_esat_dT( water_temp_k: xr.DataArray, a1: xr.DataArray, @@ -250,7 +238,7 @@ def mf_d_esat_dT( # Define functions to be used in the latent heat formulation # ----------------------------------------------------------------------------------- -@numba.njit + def mf_q_longwave_down( air_temp_k: xr.DataArray, emissivity_air: xr.DataArray, @@ -272,7 +260,7 @@ def mf_q_longwave_down( return (1.0 + 0.17 * cloudiness**2) * emissivity_air * stefan_boltzmann * air_temp_k**4.0 -@numba.njit + def mf_q_longwave_up( water_temp_k: xr.DataArray, emissivity_water: xr.DataArray, @@ -287,7 +275,7 @@ def mf_q_longwave_up( return emissivity_water * stefan_boltzmann * water_temp_k**4.0 -@numba.njit + def mf_esat_mb( water_temp_k: xr.DataArray, a0: xr.DataArray, @@ -325,7 +313,7 @@ def mf_esat_mb( # Temperature conversion functions -@numba.njit + def ri_number( gravity: xr.DataArray, density_air: xr.DataArray, @@ -395,7 +383,6 @@ def ri_function(ri_number: xr.DataArray) -> np.ndarray: # ))))) -@numba.njit def mf_latent_heat_vaporization(water_temp_k: xr.DataArray) -> xr.DataArray: """ Compute the latent heat of vaporization (W/m2) as a function of water temperature (Kelvin) @@ -404,7 +391,7 @@ def mf_latent_heat_vaporization(water_temp_k: xr.DataArray) -> xr.DataArray: return 2499999 - 2385.74 * water_temp_k -@numba.njit + def mf_density_water(water_temp_c: xr.DataArray) -> xr.DataArray: """ Compute density of water (kg/m3) as a function of water temperature (Celsius) @@ -428,7 +415,6 @@ def mf_density_water(water_temp_c: xr.DataArray) -> xr.DataArray: ) -@numba.njit def mf_density_air_sat(water_temp_k: xr.DataArray, esat_mb: float, pressure_mb: float) -> xr.DataArray: """ Compute the density of saturated air at water surface temperature. @@ -448,7 +434,7 @@ def mf_density_air_sat(water_temp_k: xr.DataArray, esat_mb: float, pressure_mb: return 0.348 * (pressure_mb / water_temp_k) * (1.0 + mixing_ratio_sat) / (1.0 + 1.61 * mixing_ratio_sat) -def mf_cp_water(water_temp_c: xr.DataArray) -> xr.DataArray: +def mf_cp_water(water_temp_c: xr.DataArray) -> np.ndarray: """ Compute the specific heat of water (J/kg/K) as a function of water temperature (Celsius). This is used in computing the source/sink term. @@ -483,7 +469,7 @@ def mf_cp_water(water_temp_c: xr.DataArray) -> xr.DataArray: ) -@numba.njit + def q_net( q_sensible: xr.DataArray, q_latent: xr.DataArray, @@ -512,7 +498,6 @@ def q_net( ) -@numba.njit def dTdt_water_c( q_net: xr.DataArray, surface_area: xr.DataArray, @@ -536,7 +521,6 @@ def dTdt_water_c( ) -@numba.njit def t_water_c( water_temp_c: xr.DataArray, dTdt_water_c: xr.DataArray, From 5b3f1695a757da5e638d2e88cf1539acc148d44e Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 15:53:53 -0500 Subject: [PATCH 2/6] Revert "Removing numba -> actual slows things down substantially, poor use case for JIT without pre-defined iterations" This reverts commit 4be294c885d53fba7b8a56079d95f9fb89539a62. --- src/clearwater_modules/tsm/processes.py | 34 ++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/clearwater_modules/tsm/processes.py b/src/clearwater_modules/tsm/processes.py index cdc6e4d..f8093c2 100644 --- a/src/clearwater_modules/tsm/processes.py +++ b/src/clearwater_modules/tsm/processes.py @@ -1,5 +1,6 @@ """JIT compiled processes for the heat model.""" import warnings +import numba import numpy as np import xarray as xr from clearwater_modules.shared.processes import ( @@ -7,6 +8,7 @@ ) +@numba.njit def air_temp_k( air_temp_c: xr.DataArray, ) -> xr.DataArray: @@ -18,6 +20,7 @@ def air_temp_k( return celsius_to_kelvin(air_temp_c) +@numba.njit def water_temp_k( water_temp_c: xr.DataArray, ) -> xr.DataArray: @@ -29,6 +32,7 @@ def water_temp_k( return celsius_to_kelvin(water_temp_c) +@numba.njit def mixing_ratio_air( eair_mb: xr.DataArray, pressure_mb: xr.DataArray, @@ -42,6 +46,7 @@ def mixing_ratio_air( return 0.622 * eair_mb / (pressure_mb - eair_mb) +@numba.njit def density_air( pressure_mb: xr.DataArray, air_temp_k: xr.DataArray, @@ -62,6 +67,7 @@ def density_air( ) +@numba.njit def emissivity_air( air_temp_k: xr.DataArray, ) -> xr.DataArray: @@ -73,6 +79,7 @@ def emissivity_air( return 0.00000937 * air_temp_k**2.0 +@numba.njit def wind_function( ri_function: xr.DataArray, wind_a: xr.DataArray, @@ -98,6 +105,7 @@ def wind_function( ) +@numba.njit def q_latent( pressure_mb: xr.DataArray, density_water: xr.DataArray, @@ -125,6 +133,7 @@ def q_latent( ) +@numba.njit def q_sensible( wind_kh_kw: xr.DataArray, cp_air: xr.DataArray, @@ -154,6 +163,7 @@ def q_sensible( ) +@numba.njit def q_sediment( use_sed_temp: xr.DataArray, pb: xr.DataArray, @@ -162,7 +172,7 @@ def q_sediment( h2: xr.DataArray, sed_temp_c: xr.DataArray, water_temp_c: xr.DataArray, -) -> np.ndarray: +) -> xr.DataArray: """Sediment heat flux (W/m^2). Args: @@ -185,13 +195,14 @@ def q_sediment( ) +@numba.njit def dTdt_sediment_c( use_sed_temp: xr.DataArray, alphas: xr.DataArray, h2: xr.DataArray, water_temp_c: xr.DataArray, sed_temp_c: xr.DataArray, -) -> np.ndarray: +) -> xr.DataArray: """Sediments temperature change (C). Args: @@ -208,6 +219,7 @@ def dTdt_sediment_c( ) +@numba.njit def mf_d_esat_dT( water_temp_k: xr.DataArray, a1: xr.DataArray, @@ -238,7 +250,7 @@ def mf_d_esat_dT( # Define functions to be used in the latent heat formulation # ----------------------------------------------------------------------------------- - +@numba.njit def mf_q_longwave_down( air_temp_k: xr.DataArray, emissivity_air: xr.DataArray, @@ -260,7 +272,7 @@ def mf_q_longwave_down( return (1.0 + 0.17 * cloudiness**2) * emissivity_air * stefan_boltzmann * air_temp_k**4.0 - +@numba.njit def mf_q_longwave_up( water_temp_k: xr.DataArray, emissivity_water: xr.DataArray, @@ -275,7 +287,7 @@ def mf_q_longwave_up( return emissivity_water * stefan_boltzmann * water_temp_k**4.0 - +@numba.njit def mf_esat_mb( water_temp_k: xr.DataArray, a0: xr.DataArray, @@ -313,7 +325,7 @@ def mf_esat_mb( # Temperature conversion functions - +@numba.njit def ri_number( gravity: xr.DataArray, density_air: xr.DataArray, @@ -383,6 +395,7 @@ def ri_function(ri_number: xr.DataArray) -> np.ndarray: # ))))) +@numba.njit def mf_latent_heat_vaporization(water_temp_k: xr.DataArray) -> xr.DataArray: """ Compute the latent heat of vaporization (W/m2) as a function of water temperature (Kelvin) @@ -391,7 +404,7 @@ def mf_latent_heat_vaporization(water_temp_k: xr.DataArray) -> xr.DataArray: return 2499999 - 2385.74 * water_temp_k - +@numba.njit def mf_density_water(water_temp_c: xr.DataArray) -> xr.DataArray: """ Compute density of water (kg/m3) as a function of water temperature (Celsius) @@ -415,6 +428,7 @@ def mf_density_water(water_temp_c: xr.DataArray) -> xr.DataArray: ) +@numba.njit def mf_density_air_sat(water_temp_k: xr.DataArray, esat_mb: float, pressure_mb: float) -> xr.DataArray: """ Compute the density of saturated air at water surface temperature. @@ -434,7 +448,7 @@ def mf_density_air_sat(water_temp_k: xr.DataArray, esat_mb: float, pressure_mb: return 0.348 * (pressure_mb / water_temp_k) * (1.0 + mixing_ratio_sat) / (1.0 + 1.61 * mixing_ratio_sat) -def mf_cp_water(water_temp_c: xr.DataArray) -> np.ndarray: +def mf_cp_water(water_temp_c: xr.DataArray) -> xr.DataArray: """ Compute the specific heat of water (J/kg/K) as a function of water temperature (Celsius). This is used in computing the source/sink term. @@ -469,7 +483,7 @@ def mf_cp_water(water_temp_c: xr.DataArray) -> np.ndarray: ) - +@numba.njit def q_net( q_sensible: xr.DataArray, q_latent: xr.DataArray, @@ -498,6 +512,7 @@ def q_net( ) +@numba.njit def dTdt_water_c( q_net: xr.DataArray, surface_area: xr.DataArray, @@ -521,6 +536,7 @@ def dTdt_water_c( ) +@numba.njit def t_water_c( water_temp_c: xr.DataArray, dTdt_water_c: xr.DataArray, From 4f1a5e43154021b952a362165e38a82b83cc7c60 Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 16:02:04 -0500 Subject: [PATCH 3/6] Create prof.py --- examples/dev_sandbox/prof.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 examples/dev_sandbox/prof.py diff --git a/examples/dev_sandbox/prof.py b/examples/dev_sandbox/prof.py new file mode 100644 index 0000000..069ec89 --- /dev/null +++ b/examples/dev_sandbox/prof.py @@ -0,0 +1,26 @@ +"""A script to allow for debugging of the TSM module.""" +import clearwater_modules +import time + +def main(): + ti = time.time() + # define starting state values + state_i = { + 'water_temp_c': 40.0, + 'surface_area': 1.0, + 'volume': 1.0, + } + + # instantiate the TSM module + tsm = clearwater_modules.tsm.EnergyBudget( + initial_state_values=state_i, + meteo_parameters={'wind_c': 1.0}, + ) + print(tsm.static_variable_values) + t2 = time.time() + for _ in range(100): + tsm.increment_timestep() + print(f'Increment timestep speed (average of 100): {(time.time() - t2) / 100}') + print(f'Run time: {time.time() - ti}') +if __name__ == '__main__': + main() From 44592b370722259524bee82fcb4f644db411a974 Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 16:02:08 -0500 Subject: [PATCH 4/6] Create improve_performance_v2.ipynb --- .../dev_sandbox/improve_performance_v2.ipynb | 233 ++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 examples/dev_sandbox/improve_performance_v2.ipynb diff --git a/examples/dev_sandbox/improve_performance_v2.ipynb b/examples/dev_sandbox/improve_performance_v2.ipynb new file mode 100644 index 0000000..49d285c --- /dev/null +++ b/examples/dev_sandbox/improve_performance_v2.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "52b6fe1e-6174-4855-8230-e44c637b754e", + "metadata": {}, + "source": [ + "# Numba Performance exploration notebook\n", + "\n", + "**Overview:** After realizing that removing numba speed up performance 8 fold, I wanted to make sure this persists for longer runtimes where JIT is designed to be effective." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a01ac108-0980-4d68-89ac-3b49013fb939", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import numba\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "33e20862-ac39-46be-aa1d-0e1ee51a8994", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "@numba.njit\n", + "def jitadd(x, y) -> np.ndarray:\n", + " return (x * 0.2) + (y * 0.3)\n", + "\n", + "def add(x, y) -> np.ndarray:\n", + " return (x * 0.2) + (y * 0.3)\n", + "\n", + "@numba.njit\n", + "def jitsub(x) -> np.ndarray:\n", + " return x - (0.15 * x)\n", + "\n", + "def sub(x) -> np.ndarray:\n", + " return x - (0.15 * x)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3bf02ded-b09c-4791-9d3a-3f299c29bad5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.25289281, 0.3012049 , 0.81652952, ..., 0.08600542, 0.59860995,\n", + " 0.7160782 ],\n", + " [0.23049757, 0.7084206 , 0.008605 , ..., 0.12755473, 0.76453943,\n", + " 0.70370063],\n", + " [0.34352444, 0.21580305, 0.24908259, ..., 0.54471826, 0.42611405,\n", + " 0.9940653 ],\n", + " ...,\n", + " [0.73328428, 0.46605497, 0.29684926, ..., 0.38692169, 0.1791475 ,\n", + " 0.24509559],\n", + " [0.22376018, 0.713524 , 0.47561131, ..., 0.0983168 , 0.99624008,\n", + " 0.96001501],\n", + " [0.17163253, 0.21930795, 0.68329088, ..., 0.34515248, 0.92341599,\n", + " 0.65903902]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input = np.ones(shape=(100, 100)) * np.random.random_sample((100, 100))\n", + "input" + ] + }, + { + "cell_type": "markdown", + "id": "46b28de0-95a5-49b9-8a85-3f468b34993f", + "metadata": {}, + "source": [ + "**No JIT**" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8760934b-879b-480f-b8f1-6acd8438ef4f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "61.3 ms ± 5.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "TIMESTEPS = 1000\n", + "out_dict = {}\n", + "for i in range(TIMESTEPS):\n", + " val = add(input, -input)\n", + " out_dict[i] = sub(val)" + ] + }, + { + "cell_type": "markdown", + "id": "00898ada-0ca7-4b5a-baf1-ccad07219545", + "metadata": { + "tags": [] + }, + "source": [ + "**Just the function being JIT**" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "cb8e6fd5-bfc7-4c08-bc3f-a1e4ce0eac06", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "45.2 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "TIMESTEPS = 1000\n", + "out_dict = {}\n", + "for i in range(TIMESTEPS):\n", + " val = jitadd(input, -input)\n", + " out_dict[i] = jitsub(val)" + ] + }, + { + "cell_type": "markdown", + "id": "dde04c99-8b06-47fb-ba8f-838f83501755", + "metadata": { + "tags": [] + }, + "source": [ + "**The functions and loop being JIT**" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "d61c90a2-476d-4769-84ee-40f8c8e57482", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "477 ms ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "TIMESTEPS = 1000\n", + "\n", + "@numba.njit\n", + "def loop(input_var, time_steps):\n", + " out_dict = {}\n", + " for i in range(time_steps):\n", + " val = jitadd(input_var, -input_var)\n", + " out_dict[i] = jitsub(val)\n", + " return out_dict\n", + "\n", + "out_dict = loop(input, TIMESTEPS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a86bda76-3aec-4855-8101-9dca521bb5e3", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa488e31-835f-424c-b050-a2b9e7d68683", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From c94a2e4396cebec7ed2610e3f3073f9b3ecbb7e0 Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 16:06:38 -0500 Subject: [PATCH 5/6] 1000 iters for prof --- examples/dev_sandbox/prof.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dev_sandbox/prof.py b/examples/dev_sandbox/prof.py index 069ec89..e0eb859 100644 --- a/examples/dev_sandbox/prof.py +++ b/examples/dev_sandbox/prof.py @@ -18,7 +18,7 @@ def main(): ) print(tsm.static_variable_values) t2 = time.time() - for _ in range(100): + for _ in range(1000): tsm.increment_timestep() print(f'Increment timestep speed (average of 100): {(time.time() - t2) / 100}') print(f'Run time: {time.time() - ti}') From 9b9788224d24c5c85dcad87717007df07c717c80 Mon Sep 17 00:00:00 2001 From: xaviernogueira Date: Thu, 14 Dec 2023 16:14:07 -0500 Subject: [PATCH 6/6] CLI control over number of iterations to check --- examples/dev_sandbox/prof.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/dev_sandbox/prof.py b/examples/dev_sandbox/prof.py index e0eb859..f4cb2a5 100644 --- a/examples/dev_sandbox/prof.py +++ b/examples/dev_sandbox/prof.py @@ -1,8 +1,9 @@ """A script to allow for debugging of the TSM module.""" import clearwater_modules import time +import sys -def main(): +def main(iters: int): ti = time.time() # define starting state values state_i = { @@ -18,9 +19,20 @@ def main(): ) print(tsm.static_variable_values) t2 = time.time() - for _ in range(1000): + for _ in range(iters): tsm.increment_timestep() - print(f'Increment timestep speed (average of 100): {(time.time() - t2) / 100}') + print(f'Increment timestep speed (average of {iters}): {(time.time() - t2) / 100}') print(f'Run time: {time.time() - ti}') + if __name__ == '__main__': - main() + if len(sys.argv) > 1: + try: + iters = int(sys.argv[1]) + print(f'Running {iters} iterations.') + except ValueError: + raise ValueError('Argument must be an integer # of iterations.') + else: + print('No argument given, defaulting to 100 iteration.') + iters = 100 + + main(iters=iters)