From dd5d0a93827e7178ab617e22c916227bfcc35f49 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Thu, 27 Jun 2024 01:05:02 -0400 Subject: [PATCH 01/15] try to switch from numpy/scipy to jax in some part of basis interp --- piff/basis_interp.py | 112 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 92 insertions(+), 20 deletions(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 81581e94..740ce49a 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -24,6 +24,42 @@ from .interp import Interp from .star import Star +import jax +# If uncommented, the following line will make the code run in double precision +# and have identical results as running the original code in double precision. +# jax.config.update("jax_enable_x64", True) +from jax import jit +from jax import numpy as jnp +from jax import vmap +import time + +@jit +def jax_solve(ATA, ATb): + # Original code: + # dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) + # New code: + dq = jax.scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) + return dq + +@jit +def build_ATA_ATb(alpha, beta, K): + ATb = (beta[:, jnp.newaxis] * K).flatten() + tmp1 = alpha[:, :, jnp.newaxis] * K + tmp2 = K[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] * tmp1[:, jnp.newaxis, :, :] + return tmp2, ATb + +@jit +def vmap_build_ATA_ATb(Ks, alphas, betas): + # Use vmap to vectorize build_ATA_ATb across the first dimension of Ks, alphas, and betas + vmapped_build_ATA_ATb = vmap(build_ATA_ATb, in_axes=(0, 0, 0)) + # Get the vectorized results + tmp2s, ATbs = vmapped_build_ATA_ATb(alphas, betas, Ks) + # Sum the results along the first axis + ATb = jnp.sum(ATbs, axis=0) + tmp2 = jnp.sum(tmp2s, axis=0) + return tmp2, ATb + + class BasisInterp(Interp): r"""An Interp class that works whenever the interpolating functions are linear sums of basis functions. Does things the "slow way" to be stable to @@ -55,6 +91,7 @@ def __init__(self): self.use_qr = False # The default. May be overridden by subclasses. self.q = None self.set_num(None) + self._use_jax = False # The default. May be overridden by subclasses. def initialize(self, stars, logger=None): """Initialize both the interpolator to some state prefatory to any solve iterations and @@ -249,26 +286,49 @@ def _solve_direct(self, stars, logger): # Build ATA and ATb by accumulating the chunks for each star as we go. nq = np.prod(self.q.shape) + logger.info(f'nq = {nq}') ATA = np.zeros((nq, nq), dtype=float) ATb = np.zeros(nq, dtype=float) - for s in stars: - # Get the basis function values at this star - K = self.basis(s) - # Sum contributions into ATA, ATb - if True: - ATb += (s.fit.beta[:,np.newaxis] * K).flatten() - tmp1 = s.fit.alpha[:,:,np.newaxis] * K - tmp2 = K[np.newaxis,:,np.newaxis,np.newaxis] * tmp1[:,np.newaxis,:,:] - ATA += tmp2.reshape(nq,nq) - else: # pragma: no cover - # This is equivalent, but slower. - # It is here to make more explicit the connection between this calculation - # and the corresponding part of the QR code above. - A1 = (s.fit.A[:,:,np.newaxis] * K[np.newaxis,:]).reshape( - s.fit.A.shape[0], s.fit.A.shape[1] * len(K)) - ATb += A1.T.dot(s.fit.b) - ATA += A1.T.dot(A1) + start = time.time() + if self.use_jax: + Ks = [] + alphas = [] + betas = [] + for s in stars: + # Get the basis function values at this star + K = self.basis(s) + Ks.append(K) + alphas.append(s.fit.alpha) + betas.append(s.fit.beta) + alphas = np.array(alphas).reshape((len(alphas), alphas[0].shape[0], alphas[0].shape[1])) + betas = np.array(betas).reshape((len(betas), betas[0].shape[0])) + Ks = np.array(Ks).reshape((len(Ks), Ks[0].shape[0])) + tmp2, ATb = vmap_build_ATA_ATb(Ks, alphas, betas) + ATA = tmp2.reshape(nq,nq) + else: + for s in stars: + # Get the basis function values at this star + K = self.basis(s) + # Sum contributions into ATA, ATb + + if True: + alpha = s.fit.alpha + beta = s.fit.beta + ATb += (beta[:,np.newaxis] * K).flatten() + tmp1 = alpha[:,:,np.newaxis] * K + tmp2 = K[np.newaxis,:,np.newaxis,np.newaxis] * tmp1[:,np.newaxis,:,:] + ATA += tmp2.reshape(nq,nq) + else: # pragma: no cover + # This is equivalent, but slower. + # It is here to make more explicit the connection between this calculation + # and the corresponding part of the QR code above. + A1 = (s.fit.A[:,:,np.newaxis] * K[np.newaxis,:]).reshape( + s.fit.A.shape[0], s.fit.A.shape[1] * len(K)) + ATb += A1.T.dot(s.fit.b) + ATA += A1.T.dot(A1) + end = time.time() + logger.info('PF time to compute ATb and ATA: %f | use jax: %s', end-start, str(self.use_jax)) logger.info('Beginning solution of matrix size %s',ATA.shape) try: @@ -278,7 +338,16 @@ def _solve_direct(self, stars, logger): # assuming just 'sym' instead (which does an LDL decomposition rather than # Cholesky) would help. If this fails, the matrix is usually high enough # condition that it is functionally singular, and switching to SVD is warranted. - dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) + if self.use_jax: + start = time.time() + dq = jax_solve(ATA, ATb) + end = time.time() + logger.info('PF: Use JAX to solve the linear system | Time: %f', end-start) + else: + start = time.time() + dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) + end = time.time() + logger.info('PF: Not using JAX to solve the linear system | Time: %f', end-start) if len(w) > 0: # scipy likes to warn about high condition. They aren't actually a problem @@ -358,7 +427,7 @@ class BasisPolynomial(BasisInterp): """ _type_name = 'BasisPolynomial' - def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, logger=None): + def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax=False, logger=None): super(BasisPolynomial, self).__init__() self._keys = keys @@ -375,6 +444,8 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, logger=N self._max_order = max_order self.use_qr = use_qr + self.use_jax = use_jax + if self._max_order<0 or np.any(np.array(self._orders) < 0): # Exception if we have any requests for negative orders raise ValueError('Negative polynomial order specified') @@ -382,7 +453,8 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, logger=N self.kwargs = { 'order' : order, 'keys' : keys, - 'use_qr' : use_qr + 'use_qr' : use_qr, + 'use_jax' : use_jax, } # Now build a mask that picks the desired polynomial products From 0b8dde044eb852b36330da7aa4d37bd6ac313b20 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Fri, 28 Jun 2024 14:21:51 -0400 Subject: [PATCH 02/15] add jax in requirments --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index d03b7b19..8567f464 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ matplotlib>=3.3 galsim>=2.3 treegp>=0.6 threadpoolctl>=3.1 +jax>=0.4.28 From 8bd41683e9dea8808203e0a7d66428f2217e8b57 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 1 Jul 2024 19:31:43 -0400 Subject: [PATCH 03/15] fix requirment for jax --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8567f464..f975d670 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ matplotlib>=3.3 galsim>=2.3 treegp>=0.6 threadpoolctl>=3.1 -jax>=0.4.28 +jax>=0.3 From 0c1b1a63a15673cf81a4a3a5f7c8143095ed97c4 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 2 Jul 2024 00:20:28 -0400 Subject: [PATCH 04/15] deprecate test in python 3.7 and python 3.8 --- .github/workflows/ci.yml | 2 +- requirements.txt | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b6b3c276..ab84de76 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: matrix: # First all python versions in basic linux os: [ ubuntu-latest ] - py: [ 3.7, 3.8, 3.9, "3.10", 3.11, 3.12 ] + py: [ 3.9, "3.10", 3.11, 3.12 ] CC: [ gcc ] CXX: [ g++ ] FFTW_DIR: [ "/usr/local/lib/" ] diff --git a/requirements.txt b/requirements.txt index f975d670..72258d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ matplotlib>=3.3 galsim>=2.3 treegp>=0.6 threadpoolctl>=3.1 -jax>=0.3 +jax>=0.4 \ No newline at end of file diff --git a/setup.py b/setup.py index 6b60f6b0..1406b3fb 100644 --- a/setup.py +++ b/setup.py @@ -274,7 +274,7 @@ def run(self): install_scripts.run(self) self.distribution.script_install_dir = self.install_dir -dependencies = ['galsim>=2.3', 'numpy>=1.17', 'scipy>=1.2', 'pyyaml>=5.1', 'treecorr>=4.3.1', 'fitsio>=1.0', 'matplotlib>=3.3', 'LSSTDESC.Coord>=1.0', 'treegp>=0.6', 'threadpoolctl>=3.1'] +dependencies = ['galsim>=2.3', 'numpy>=1.17', 'scipy>=1.2', 'pyyaml>=5.1', 'treecorr>=4.3.1', 'fitsio>=1.0', 'matplotlib>=3.3', 'LSSTDESC.Coord>=1.0', 'treegp>=0.6', 'jax>=0.4', 'threadpoolctl>=3.1'] with open('README.rst') as file: long_description = file.read() From 008362b33559c17a9852d28aa235d44308098b34 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 2 Jul 2024 00:32:17 -0400 Subject: [PATCH 05/15] add some comment and remove some logging --- piff/basis_interp.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 740ce49a..37d4c30f 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -31,7 +31,12 @@ from jax import jit from jax import numpy as jnp from jax import vmap -import time + + +# Bellow are implementations of _solve_direct using JAX. +# if jax.config.update("jax_enable_x64", True) it will give the +# same results as the original code in double precision, but will run +# slower, but still faster than the numpy/scipy version. @jit def jax_solve(ATA, ATb): @@ -290,7 +295,6 @@ def _solve_direct(self, stars, logger): ATA = np.zeros((nq, nq), dtype=float) ATb = np.zeros(nq, dtype=float) - start = time.time() if self.use_jax: Ks = [] alphas = [] @@ -327,8 +331,6 @@ def _solve_direct(self, stars, logger): s.fit.A.shape[0], s.fit.A.shape[1] * len(K)) ATb += A1.T.dot(s.fit.b) ATA += A1.T.dot(A1) - end = time.time() - logger.info('PF time to compute ATb and ATA: %f | use jax: %s', end-start, str(self.use_jax)) logger.info('Beginning solution of matrix size %s',ATA.shape) try: @@ -339,15 +341,9 @@ def _solve_direct(self, stars, logger): # Cholesky) would help. If this fails, the matrix is usually high enough # condition that it is functionally singular, and switching to SVD is warranted. if self.use_jax: - start = time.time() dq = jax_solve(ATA, ATb) - end = time.time() - logger.info('PF: Use JAX to solve the linear system | Time: %f', end-start) else: - start = time.time() dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) - end = time.time() - logger.info('PF: Not using JAX to solve the linear system | Time: %f', end-start) if len(w) > 0: # scipy likes to warn about high condition. They aren't actually a problem From fd95179b10b063ffeac5ce076ead902f513d05a7 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 2 Jul 2024 00:47:19 -0400 Subject: [PATCH 06/15] add doc for use_jax and test for use_jax --- piff/basis_interp.py | 1 + tests/test_pixel.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 37d4c30f..ee69e122 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -419,6 +419,7 @@ class BasisPolynomial(BasisInterp): and is somewhat slower (nearly a factor of 2); however, it is significantly less susceptible to numerical errors from high condition matrices. Therefore, it may be preferred for some use cases. [default: False] + :param use_jax: Use JAX in _solve_direct compared to classic numpy/scipy. [default: False] :param logger: A logger object for logging debug info. [default: None] """ _type_name = 'BasisPolynomial' diff --git a/tests/test_pixel.py b/tests/test_pixel.py index 87842b54..aa2655ab 100644 --- a/tests/test_pixel.py +++ b/tests/test_pixel.py @@ -1025,7 +1025,8 @@ def test_single_image(): }, 'interp' : { 'type' : 'BasisPolynomial', - 'order' : order + 'order' : order, + 'use_jax' : True, }, }, } From 5ea4f5edc9b9a8127a2c36cd70fa0369d0442038 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 2 Jul 2024 15:13:32 -0400 Subject: [PATCH 07/15] add a better coverage and assert jax and non jax computation give same result --- tests/test_pixel.py | 81 ++++++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/tests/test_pixel.py b/tests/test_pixel.py index aa2655ab..f5ef4cf4 100644 --- a/tests/test_pixel.py +++ b/tests/test_pixel.py @@ -21,6 +21,7 @@ import time import os import fitsio +import copy from piff_test_helper import get_script_name, timer, CaptureLog @@ -1026,53 +1027,65 @@ def test_single_image(): 'interp' : { 'type' : 'BasisPolynomial', 'order' : order, - 'use_jax' : True, + 'use_jax' : False, }, }, } + if __name__ == '__main__': config['verbose'] = 2 else: config['verbose'] = 0 - print("Running piffify function") - piff.piffify(config) - psf = piff.read(psf_file) - test_star = psf.drawStar(target_star) - print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) - np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) + configJax = copy.deepcopy(config) + configJax['psf']['interp']['use_jax'] = True + import jax + jax.config.update("jax_enable_x64", True) + test_star_jax = [] - # Test using the piffify executable - with open('pixel_moffat.yaml','w') as f: - f.write(yaml.dump(config, default_flow_style=False)) - if __name__ == '__main__': - print("Running piffify executable") - if os.path.exists(psf_file): - os.remove(psf_file) - piffify_exe = get_script_name('piffify') - p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) - p.communicate() + for conf in [config, configJax]: + print("Running piffify function") + piff.piffify(conf) psf = piff.read(psf_file) test_star = psf.drawStar(target_star) + test_star_jax.append(test_star) + print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) - # test copy_image property of draw - target_star_copy = psf.interp.interpolate(target_star) - test_star_copy = psf.model.draw(target_star_copy, copy_image=True) - test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) - # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy - target_star_copy.image.array[0,0] = 23456 - assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] - assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] - # however the other pixels SHOULD still be all the same value - assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] - assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] - - # check that drawing onto an image does not return a copy - image = psf.draw(x=x0, y=y0) - image_reference = psf.draw(x=x0, y=y0, image=image) - image_reference.array[0,0] = 123456 - assert image.array[0,0] == image_reference.array[0,0] + # Test using the piffify executable + with open('pixel_moffat.yaml','w') as f: + f.write(yaml.dump(conf, default_flow_style=False)) + if __name__ == '__main__': + print("Running piffify executable") + if os.path.exists(psf_file): + os.remove(psf_file) + piffify_exe = get_script_name('piffify') + p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) + p.communicate() + psf = piff.read(psf_file) + test_star = psf.drawStar(target_star) + np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) + + # test copy_image property of draw + target_star_copy = psf.interp.interpolate(target_star) + test_star_copy = psf.model.draw(target_star_copy, copy_image=True) + test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) + # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy + target_star_copy.image.array[0,0] = 23456 + assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] + assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] + # however the other pixels SHOULD still be all the same value + assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] + assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] + + # check that drawing onto an image does not return a copy + image = psf.draw(x=x0, y=y0) + image_reference = psf.draw(x=x0, y=y0, image=image) + image_reference.array[0,0] = 123456 + assert image.array[0,0] == image_reference.array[0,0] + + # check that the two versions (jax vs numpy/scipy) of the test star are the same + np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array) @timer def test_des_image(): From c76e3ac494ed45ad0adb12a9593959096ae51d88 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 15 Jul 2024 14:11:40 -0400 Subject: [PATCH 08/15] replace with cholesky the solver --- piff/basis_interp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index ee69e122..9e10d1ca 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -43,7 +43,8 @@ def jax_solve(ATA, ATb): # Original code: # dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) # New code: - dq = jax.scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) + factor = (jax.scipy.linalg.cholesky(ATA, overwrite_a=True, lower=False), False) + dq = jax.scipy.linalg.cho_solve(factor, ATb, overwrite_b=False) return dq @jit From 49453fa905b55c997a628b332b09be0941c74c1b Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 22:35:37 -0400 Subject: [PATCH 09/15] add request changes on basis interp --- piff/basis_interp.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 9e10d1ca..89876485 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -24,14 +24,17 @@ from .interp import Interp from .star import Star -import jax -# If uncommented, the following line will make the code run in double precision -# and have identical results as running the original code in double precision. -# jax.config.update("jax_enable_x64", True) -from jax import jit -from jax import numpy as jnp -from jax import vmap - +try: + import jax + from jax import jit + from jax import numpy as jnp + from jax import vmap + +except ImportError: + CAN_USE_JAX = False +else: + CAN_USE_JAX = True + jax.config.update("jax_enable_x64", True) # Bellow are implementations of _solve_direct using JAX. # if jax.config.update("jax_enable_x64", True) it will give the @@ -51,19 +54,19 @@ def jax_solve(ATA, ATb): def build_ATA_ATb(alpha, beta, K): ATb = (beta[:, jnp.newaxis] * K).flatten() tmp1 = alpha[:, :, jnp.newaxis] * K - tmp2 = K[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] * tmp1[:, jnp.newaxis, :, :] - return tmp2, ATb + ATA = K[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] * tmp1[:, jnp.newaxis, :, :] + return ATA, ATb @jit def vmap_build_ATA_ATb(Ks, alphas, betas): # Use vmap to vectorize build_ATA_ATb across the first dimension of Ks, alphas, and betas vmapped_build_ATA_ATb = vmap(build_ATA_ATb, in_axes=(0, 0, 0)) # Get the vectorized results - tmp2s, ATbs = vmapped_build_ATA_ATb(alphas, betas, Ks) + ATAs, ATbs = vmapped_build_ATA_ATb(alphas, betas, Ks) # Sum the results along the first axis ATb = jnp.sum(ATbs, axis=0) - tmp2 = jnp.sum(tmp2s, axis=0) - return tmp2, ATb + ATA = jnp.sum(ATAs, axis=0) + return ATA, ATb class BasisInterp(Interp): @@ -292,7 +295,6 @@ def _solve_direct(self, stars, logger): # Build ATA and ATb by accumulating the chunks for each star as we go. nq = np.prod(self.q.shape) - logger.info(f'nq = {nq}') ATA = np.zeros((nq, nq), dtype=float) ATb = np.zeros(nq, dtype=float) @@ -309,8 +311,8 @@ def _solve_direct(self, stars, logger): alphas = np.array(alphas).reshape((len(alphas), alphas[0].shape[0], alphas[0].shape[1])) betas = np.array(betas).reshape((len(betas), betas[0].shape[0])) Ks = np.array(Ks).reshape((len(Ks), Ks[0].shape[0])) - tmp2, ATb = vmap_build_ATA_ATb(Ks, alphas, betas) - ATA = tmp2.reshape(nq,nq) + ATA, ATb = vmap_build_ATA_ATb(Ks, alphas, betas) + ATA = ATA.reshape(nq,nq) else: for s in stars: # Get the basis function values at this star @@ -418,9 +420,8 @@ class BasisPolynomial(BasisInterp): :param use_qr: Use QR decomposition for the solution rather than the more direct least squares solution. QR decomposition requires more memory than the default and is somewhat slower (nearly a factor of 2); however, it is significantly - less susceptible to numerical errors from high condition matrices. + :param use_jax: Use JAX for solving the linear algebra equations rather than numpy/scipy. Therefore, it may be preferred for some use cases. [default: False] - :param use_jax: Use JAX in _solve_direct compared to classic numpy/scipy. [default: False] :param logger: A logger object for logging debug info. [default: None] """ _type_name = 'BasisPolynomial' @@ -444,6 +445,10 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax= self.use_jax = use_jax + if not CAN_USE_JAX and self.use_jax: + logger.warning("JAX is not installed. Switching to numpy/scipy.") + self.use_jax = False + if self._max_order<0 or np.any(np.array(self._orders) < 0): # Exception if we have any requests for negative orders raise ValueError('Negative polynomial order specified') From d6475e712dcb2d28326c8f52963e6fbcfea4a031 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 23:23:18 -0400 Subject: [PATCH 10/15] make sure can work with/without jax --- .github/workflows/ci.yml | 7 ++- piff/basis_interp.py | 7 ++- requirements.txt | 3 +- tests/test_pixel.py | 94 +++++++++++++++++++++------------------- 4 files changed, 61 insertions(+), 50 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ab84de76..06798c19 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: matrix: # First all python versions in basic linux os: [ ubuntu-latest ] - py: [ 3.9, "3.10", 3.11, 3.12 ] + py: [ 3.7, 3.8, 3.9, "3.10", 3.11, 3.12 ] CC: [ gcc ] CXX: [ g++ ] FFTW_DIR: [ "/usr/local/lib/" ] @@ -115,6 +115,11 @@ jobs: # https://github.com/python-pillow/Pillow/issues/7259 pip install -U nose coverage "pytest<8" nbval ipykernel "pillow<10" + - name: Install Jax + if: matrix.py == 3.9 || matrix.py == 3.10 || matrix.py == 3.11 || matrix.py == 3.12 + run: | + pip install -U jax + - name: Install Pixmappy (not on pip) run: | git clone https://github.com/gbernstein/pixmappy.git diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 89876485..bd95abc6 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -32,6 +32,9 @@ except ImportError: CAN_USE_JAX = False + # define dummy functions for jax + def jit(f): + return f else: CAN_USE_JAX = True jax.config.update("jax_enable_x64", True) @@ -46,8 +49,8 @@ def jax_solve(ATA, ATb): # Original code: # dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False) # New code: - factor = (jax.scipy.linalg.cholesky(ATA, overwrite_a=True, lower=False), False) - dq = jax.scipy.linalg.cho_solve(factor, ATb, overwrite_b=False) + (factor, lower) = (jax.scipy.linalg.cholesky(ATA, overwrite_a=True, lower=False), False) + dq = jax.scipy.linalg.cho_solve((factor, lower), ATb, overwrite_b=False) return dq @jit diff --git a/requirements.txt b/requirements.txt index 72258d68..0d6a5a87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ fitsio>=1.0 matplotlib>=3.3 galsim>=2.3 treegp>=0.6 -threadpoolctl>=3.1 -jax>=0.4 \ No newline at end of file +threadpoolctl>=3.1 \ No newline at end of file diff --git a/tests/test_pixel.py b/tests/test_pixel.py index f5ef4cf4..0ad91a50 100644 --- a/tests/test_pixel.py +++ b/tests/test_pixel.py @@ -1037,55 +1037,59 @@ def test_single_image(): else: config['verbose'] = 0 - configJax = copy.deepcopy(config) - configJax['psf']['interp']['use_jax'] = True - import jax - jax.config.update("jax_enable_x64", True) - test_star_jax = [] - - for conf in [config, configJax]: - print("Running piffify function") - piff.piffify(conf) - psf = piff.read(psf_file) - test_star = psf.drawStar(target_star) - test_star_jax.append(test_star) - print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) - np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) - - # Test using the piffify executable - with open('pixel_moffat.yaml','w') as f: - f.write(yaml.dump(conf, default_flow_style=False)) - if __name__ == '__main__': - print("Running piffify executable") - if os.path.exists(psf_file): - os.remove(psf_file) - piffify_exe = get_script_name('piffify') - p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) - p.communicate() + try: + import jax + except ImportError: + print("JAX not installed. Skipping JAX tests.") + else: + configJax = copy.deepcopy(config) + configJax['psf']['interp']['use_jax'] = True + jax.config.update("jax_enable_x64", True) + test_star_jax = [] + + for conf in [config, configJax]: + print("Running piffify function") + piff.piffify(conf) psf = piff.read(psf_file) test_star = psf.drawStar(target_star) + test_star_jax.append(test_star) + print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) - # test copy_image property of draw - target_star_copy = psf.interp.interpolate(target_star) - test_star_copy = psf.model.draw(target_star_copy, copy_image=True) - test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) - # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy - target_star_copy.image.array[0,0] = 23456 - assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] - assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] - # however the other pixels SHOULD still be all the same value - assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] - assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] - - # check that drawing onto an image does not return a copy - image = psf.draw(x=x0, y=y0) - image_reference = psf.draw(x=x0, y=y0, image=image) - image_reference.array[0,0] = 123456 - assert image.array[0,0] == image_reference.array[0,0] - - # check that the two versions (jax vs numpy/scipy) of the test star are the same - np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array) + # Test using the piffify executable + with open('pixel_moffat.yaml','w') as f: + f.write(yaml.dump(conf, default_flow_style=False)) + if __name__ == '__main__': + print("Running piffify executable") + if os.path.exists(psf_file): + os.remove(psf_file) + piffify_exe = get_script_name('piffify') + p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) + p.communicate() + psf = piff.read(psf_file) + test_star = psf.drawStar(target_star) + np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) + + # test copy_image property of draw + target_star_copy = psf.interp.interpolate(target_star) + test_star_copy = psf.model.draw(target_star_copy, copy_image=True) + test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) + # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy + target_star_copy.image.array[0,0] = 23456 + assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] + assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] + # however the other pixels SHOULD still be all the same value + assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] + assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] + + # check that drawing onto an image does not return a copy + image = psf.draw(x=x0, y=y0) + image_reference = psf.draw(x=x0, y=y0, image=image) + image_reference.array[0,0] = 123456 + assert image.array[0,0] == image_reference.array[0,0] + + # check that the two versions (jax vs numpy/scipy) of the test star are the same + np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array) @timer def test_des_image(): From b915ccf25d06a9b329c45bebd571b7fa395c8759 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 23:26:22 -0400 Subject: [PATCH 11/15] fix typo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0d6a5a87..d03b7b19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ fitsio>=1.0 matplotlib>=3.3 galsim>=2.3 treegp>=0.6 -threadpoolctl>=3.1 \ No newline at end of file +threadpoolctl>=3.1 From 9f629ca2e36ce262dadd5e34ed9e5f35875b8a74 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 23:31:58 -0400 Subject: [PATCH 12/15] remove jax from requirment in setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1406b3fb..6b60f6b0 100644 --- a/setup.py +++ b/setup.py @@ -274,7 +274,7 @@ def run(self): install_scripts.run(self) self.distribution.script_install_dir = self.install_dir -dependencies = ['galsim>=2.3', 'numpy>=1.17', 'scipy>=1.2', 'pyyaml>=5.1', 'treecorr>=4.3.1', 'fitsio>=1.0', 'matplotlib>=3.3', 'LSSTDESC.Coord>=1.0', 'treegp>=0.6', 'jax>=0.4', 'threadpoolctl>=3.1'] +dependencies = ['galsim>=2.3', 'numpy>=1.17', 'scipy>=1.2', 'pyyaml>=5.1', 'treecorr>=4.3.1', 'fitsio>=1.0', 'matplotlib>=3.3', 'LSSTDESC.Coord>=1.0', 'treegp>=0.6', 'threadpoolctl>=3.1'] with open('README.rst') as file: long_description = file.read() From 6147229959f09bac9768af8afe0274a8e14272c6 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 23:46:23 -0400 Subject: [PATCH 13/15] test jax not installed but use jax --- tests/test_pixel.py | 90 ++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/test_pixel.py b/tests/test_pixel.py index 0ad91a50..dd2374c7 100644 --- a/tests/test_pixel.py +++ b/tests/test_pixel.py @@ -1039,57 +1039,57 @@ def test_single_image(): try: import jax - except ImportError: - print("JAX not installed. Skipping JAX tests.") - else: - configJax = copy.deepcopy(config) - configJax['psf']['interp']['use_jax'] = True jax.config.update("jax_enable_x64", True) - test_star_jax = [] + except ImportError: + print("JAX not installed.") + + configJax = copy.deepcopy(config) + configJax['psf']['interp']['use_jax'] = True + test_star_jax = [] - for conf in [config, configJax]: - print("Running piffify function") - piff.piffify(conf) + for conf in [config, configJax]: + print("Running piffify function") + piff.piffify(conf) + psf = piff.read(psf_file) + test_star = psf.drawStar(target_star) + test_star_jax.append(test_star) + print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) + np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) + + # Test using the piffify executable + with open('pixel_moffat.yaml','w') as f: + f.write(yaml.dump(conf, default_flow_style=False)) + if __name__ == '__main__': + print("Running piffify executable") + if os.path.exists(psf_file): + os.remove(psf_file) + piffify_exe = get_script_name('piffify') + p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) + p.communicate() psf = piff.read(psf_file) test_star = psf.drawStar(target_star) - test_star_jax.append(test_star) - print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array))) np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) - # Test using the piffify executable - with open('pixel_moffat.yaml','w') as f: - f.write(yaml.dump(conf, default_flow_style=False)) - if __name__ == '__main__': - print("Running piffify executable") - if os.path.exists(psf_file): - os.remove(psf_file) - piffify_exe = get_script_name('piffify') - p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] ) - p.communicate() - psf = piff.read(psf_file) - test_star = psf.drawStar(target_star) - np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3) - - # test copy_image property of draw - target_star_copy = psf.interp.interpolate(target_star) - test_star_copy = psf.model.draw(target_star_copy, copy_image=True) - test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) - # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy - target_star_copy.image.array[0,0] = 23456 - assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] - assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] - # however the other pixels SHOULD still be all the same value - assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] - assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] - - # check that drawing onto an image does not return a copy - image = psf.draw(x=x0, y=y0) - image_reference = psf.draw(x=x0, y=y0, image=image) - image_reference.array[0,0] = 123456 - assert image.array[0,0] == image_reference.array[0,0] - - # check that the two versions (jax vs numpy/scipy) of the test star are the same - np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array) + # test copy_image property of draw + target_star_copy = psf.interp.interpolate(target_star) + test_star_copy = psf.model.draw(target_star_copy, copy_image=True) + test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False) + # if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy + target_star_copy.image.array[0,0] = 23456 + assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0] + assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0] + # however the other pixels SHOULD still be all the same value + assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1] + assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1] + + # check that drawing onto an image does not return a copy + image = psf.draw(x=x0, y=y0) + image_reference = psf.draw(x=x0, y=y0, image=image) + image_reference.array[0,0] = 123456 + assert image.array[0,0] == image_reference.array[0,0] + + # check that the two versions (jax vs numpy/scipy) of the test star are the same + np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array) @timer def test_des_image(): From 55e753bbfc0407383c3e029a916fc0e6c4dfaa0a Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Mon, 29 Jul 2024 23:54:49 -0400 Subject: [PATCH 14/15] remove logger --- piff/basis_interp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index bd95abc6..cfe0cfdf 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -449,7 +449,6 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax= self.use_jax = use_jax if not CAN_USE_JAX and self.use_jax: - logger.warning("JAX is not installed. Switching to numpy/scipy.") self.use_jax = False if self._max_order<0 or np.any(np.array(self._orders) < 0): From 11c49b1d02a7663f15bb0dc8cf545e1fec507b75 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 30 Jul 2024 03:46:52 -0400 Subject: [PATCH 15/15] add logger properly --- piff/basis_interp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/piff/basis_interp.py b/piff/basis_interp.py index cfe0cfdf..3782acc5 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -432,6 +432,8 @@ class BasisPolynomial(BasisInterp): def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax=False, logger=None): super(BasisPolynomial, self).__init__() + logger = galsim.config.LoggerWrapper(logger) + self._keys = keys if hasattr(order,'__len__'): if not len(order)==len(keys): @@ -449,6 +451,7 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax= self.use_jax = use_jax if not CAN_USE_JAX and self.use_jax: + logger.warning("JAX not installed. Reverting to numpy/scipy.") self.use_jax = False if self._max_order<0 or np.any(np.array(self._orders) < 0):