Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try to switch from numpy/scipy to jax in some part of basis interp #166

Merged
merged 15 commits into from
Jul 30, 2024
Merged
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ jobs:
# https://github.com/python-pillow/Pillow/issues/7259
pip install -U nose coverage "pytest<8" nbval ipykernel "pillow<10"

- name: Install Jax
if: matrix.py == 3.9 || matrix.py == 3.10 || matrix.py == 3.11 || matrix.py == 3.12
run: |
pip install -U jax

- name: Install Pixmappy (not on pip)
run: |
git clone https://github.com/gbernstein/pixmappy.git
Expand Down
122 changes: 101 additions & 21 deletions piff/basis_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,54 @@
from .interp import Interp
from .star import Star

try:
import jax
from jax import jit
from jax import numpy as jnp
from jax import vmap

except ImportError:
CAN_USE_JAX = False
# define dummy functions for jax
def jit(f):
return f
else:
CAN_USE_JAX = True
jax.config.update("jax_enable_x64", True)

# Bellow are implementations of _solve_direct using JAX.
# if jax.config.update("jax_enable_x64", True) it will give the
# same results as the original code in double precision, but will run
# slower, but still faster than the numpy/scipy version.

@jit
def jax_solve(ATA, ATb):
# Original code:
# dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False)
# New code:
(factor, lower) = (jax.scipy.linalg.cholesky(ATA, overwrite_a=True, lower=False), False)
dq = jax.scipy.linalg.cho_solve((factor, lower), ATb, overwrite_b=False)
return dq

@jit
def build_ATA_ATb(alpha, beta, K):
ATb = (beta[:, jnp.newaxis] * K).flatten()
tmp1 = alpha[:, :, jnp.newaxis] * K
ATA = K[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] * tmp1[:, jnp.newaxis, :, :]
return ATA, ATb

@jit
def vmap_build_ATA_ATb(Ks, alphas, betas):
# Use vmap to vectorize build_ATA_ATb across the first dimension of Ks, alphas, and betas
vmapped_build_ATA_ATb = vmap(build_ATA_ATb, in_axes=(0, 0, 0))
# Get the vectorized results
ATAs, ATbs = vmapped_build_ATA_ATb(alphas, betas, Ks)
# Sum the results along the first axis
ATb = jnp.sum(ATbs, axis=0)
ATA = jnp.sum(ATAs, axis=0)
return ATA, ATb


class BasisInterp(Interp):
r"""An Interp class that works whenever the interpolating functions are
linear sums of basis functions. Does things the "slow way" to be stable to
Expand Down Expand Up @@ -55,6 +103,7 @@ def __init__(self):
self.use_qr = False # The default. May be overridden by subclasses.
self.q = None
self.set_num(None)
self._use_jax = False # The default. May be overridden by subclasses.

def initialize(self, stars, logger=None):
"""Initialize both the interpolator to some state prefatory to any solve iterations and
Expand Down Expand Up @@ -252,23 +301,42 @@ def _solve_direct(self, stars, logger):
ATA = np.zeros((nq, nq), dtype=float)
ATb = np.zeros(nq, dtype=float)

for s in stars:
# Get the basis function values at this star
K = self.basis(s)
# Sum contributions into ATA, ATb
if True:
ATb += (s.fit.beta[:,np.newaxis] * K).flatten()
tmp1 = s.fit.alpha[:,:,np.newaxis] * K
tmp2 = K[np.newaxis,:,np.newaxis,np.newaxis] * tmp1[:,np.newaxis,:,:]
ATA += tmp2.reshape(nq,nq)
else: # pragma: no cover
# This is equivalent, but slower.
# It is here to make more explicit the connection between this calculation
# and the corresponding part of the QR code above.
A1 = (s.fit.A[:,:,np.newaxis] * K[np.newaxis,:]).reshape(
s.fit.A.shape[0], s.fit.A.shape[1] * len(K))
ATb += A1.T.dot(s.fit.b)
ATA += A1.T.dot(A1)
if self.use_jax:
Ks = []
alphas = []
betas = []
for s in stars:
# Get the basis function values at this star
K = self.basis(s)
Ks.append(K)
alphas.append(s.fit.alpha)
betas.append(s.fit.beta)
alphas = np.array(alphas).reshape((len(alphas), alphas[0].shape[0], alphas[0].shape[1]))
betas = np.array(betas).reshape((len(betas), betas[0].shape[0]))
Ks = np.array(Ks).reshape((len(Ks), Ks[0].shape[0]))
ATA, ATb = vmap_build_ATA_ATb(Ks, alphas, betas)
ATA = ATA.reshape(nq,nq)
else:
for s in stars:
# Get the basis function values at this star
K = self.basis(s)
# Sum contributions into ATA, ATb

if True:
alpha = s.fit.alpha
beta = s.fit.beta
ATb += (beta[:,np.newaxis] * K).flatten()
tmp1 = alpha[:,:,np.newaxis] * K
tmp2 = K[np.newaxis,:,np.newaxis,np.newaxis] * tmp1[:,np.newaxis,:,:]
ATA += tmp2.reshape(nq,nq)
else: # pragma: no cover
# This is equivalent, but slower.
# It is here to make more explicit the connection between this calculation
# and the corresponding part of the QR code above.
A1 = (s.fit.A[:,:,np.newaxis] * K[np.newaxis,:]).reshape(
s.fit.A.shape[0], s.fit.A.shape[1] * len(K))
ATb += A1.T.dot(s.fit.b)
ATA += A1.T.dot(A1)

logger.info('Beginning solution of matrix size %s',ATA.shape)
try:
Expand All @@ -278,7 +346,10 @@ def _solve_direct(self, stars, logger):
# assuming just 'sym' instead (which does an LDL decomposition rather than
# Cholesky) would help. If this fails, the matrix is usually high enough
# condition that it is functionally singular, and switching to SVD is warranted.
dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False)
if self.use_jax:
dq = jax_solve(ATA, ATb)
else:
dq = scipy.linalg.solve(ATA, ATb, assume_a='pos', check_finite=False)

if len(w) > 0:
# scipy likes to warn about high condition. They aren't actually a problem
Expand Down Expand Up @@ -352,15 +423,17 @@ class BasisPolynomial(BasisInterp):
:param use_qr: Use QR decomposition for the solution rather than the more direct least
squares solution. QR decomposition requires more memory than the default
and is somewhat slower (nearly a factor of 2); however, it is significantly
less susceptible to numerical errors from high condition matrices.
:param use_jax: Use JAX for solving the linear algebra equations rather than numpy/scipy.
Therefore, it may be preferred for some use cases. [default: False]
:param logger: A logger object for logging debug info. [default: None]
"""
_type_name = 'BasisPolynomial'

def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, logger=None):
def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, use_jax=False, logger=None):
super(BasisPolynomial, self).__init__()

logger = galsim.config.LoggerWrapper(logger)

self._keys = keys
if hasattr(order,'__len__'):
if not len(order)==len(keys):
Expand All @@ -375,14 +448,21 @@ def __init__(self, order, keys=('u','v'), max_order=None, use_qr=False, logger=N
self._max_order = max_order
self.use_qr = use_qr

self.use_jax = use_jax

if not CAN_USE_JAX and self.use_jax:
logger.warning("JAX not installed. Reverting to numpy/scipy.")
self.use_jax = False

if self._max_order<0 or np.any(np.array(self._orders) < 0):
# Exception if we have any requests for negative orders
raise ValueError('Negative polynomial order specified')

self.kwargs = {
'order' : order,
'keys' : keys,
'use_qr' : use_qr
'use_qr' : use_qr,
'use_jax' : use_jax,
}

# Now build a mask that picks the desired polynomial products
Expand Down
86 changes: 52 additions & 34 deletions tests/test_pixel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
import os
import fitsio
import copy

from piff_test_helper import get_script_name, timer, CaptureLog

Expand Down Expand Up @@ -1025,53 +1026,70 @@ def test_single_image():
},
'interp' : {
'type' : 'BasisPolynomial',
'order' : order
'order' : order,
'use_jax' : False,
},
},
}

if __name__ == '__main__':
config['verbose'] = 2
else:
config['verbose'] = 0

print("Running piffify function")
piff.piffify(config)
psf = piff.read(psf_file)
test_star = psf.drawStar(target_star)
print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array)))
np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3)
try:
import jax
jax.config.update("jax_enable_x64", True)
except ImportError:
print("JAX not installed.")

# Test using the piffify executable
with open('pixel_moffat.yaml','w') as f:
f.write(yaml.dump(config, default_flow_style=False))
if __name__ == '__main__':
print("Running piffify executable")
if os.path.exists(psf_file):
os.remove(psf_file)
piffify_exe = get_script_name('piffify')
p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] )
p.communicate()
configJax = copy.deepcopy(config)
configJax['psf']['interp']['use_jax'] = True
test_star_jax = []

for conf in [config, configJax]:
print("Running piffify function")
piff.piffify(conf)
psf = piff.read(psf_file)
test_star = psf.drawStar(target_star)
test_star_jax.append(test_star)
print("Max abs diff = ",np.max(np.abs(test_star.image.array - test_im.array)))
np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3)

# test copy_image property of draw
target_star_copy = psf.interp.interpolate(target_star)
test_star_copy = psf.model.draw(target_star_copy, copy_image=True)
test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False)
# if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy
target_star_copy.image.array[0,0] = 23456
assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0]
assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0]
# however the other pixels SHOULD still be all the same value
assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1]
assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1]

# check that drawing onto an image does not return a copy
image = psf.draw(x=x0, y=y0)
image_reference = psf.draw(x=x0, y=y0, image=image)
image_reference.array[0,0] = 123456
assert image.array[0,0] == image_reference.array[0,0]
# Test using the piffify executable
with open('pixel_moffat.yaml','w') as f:
f.write(yaml.dump(conf, default_flow_style=False))
if __name__ == '__main__':
print("Running piffify executable")
if os.path.exists(psf_file):
os.remove(psf_file)
piffify_exe = get_script_name('piffify')
p = subprocess.Popen( [piffify_exe, 'pixel_moffat.yaml'] )
p.communicate()
psf = piff.read(psf_file)
test_star = psf.drawStar(target_star)
np.testing.assert_almost_equal(test_star.image.array/2., test_im.array/2., decimal=3)

# test copy_image property of draw
target_star_copy = psf.interp.interpolate(target_star)
test_star_copy = psf.model.draw(target_star_copy, copy_image=True)
test_star_nocopy = psf.model.draw(target_star_copy, copy_image=False)
# if we modify target_star_copy, then test_star_nocopy should be modified, but not test_star_copy
target_star_copy.image.array[0,0] = 23456
assert test_star_nocopy.image.array[0,0] == target_star_copy.image.array[0,0]
assert test_star_copy.image.array[0,0] != target_star_copy.image.array[0,0]
# however the other pixels SHOULD still be all the same value
assert test_star_nocopy.image.array[1,1] == target_star_copy.image.array[1,1]
assert test_star_copy.image.array[1,1] == target_star_copy.image.array[1,1]

# check that drawing onto an image does not return a copy
image = psf.draw(x=x0, y=y0)
image_reference = psf.draw(x=x0, y=y0, image=image)
image_reference.array[0,0] = 123456
assert image.array[0,0] == image_reference.array[0,0]

# check that the two versions (jax vs numpy/scipy) of the test star are the same
np.testing.assert_allclose(test_star_jax[0].image.array, test_star_jax[1].image.array)

@timer
def test_des_image():
Expand Down