Skip to content

Commit

Permalink
Merge pull request #156 from robelgeda/depreciate
Browse files Browse the repository at this point in the history
Depreciate `PSFModel`
  • Loading branch information
robelgeda authored Aug 6, 2023
2 parents 65a5a8e + 2651600 commit d2d5c18
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 261 deletions.
185 changes: 0 additions & 185 deletions petrofit/modeling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,191 +62,6 @@ def make_grid(size, origin=(0, 0), factor=1):
return np.meshgrid(x_arange, y_arange)


class PSFModel(FittableModel):
"""
PSFModel is deprecated as of v0.4.0 and will be
removed in the next release, use `PSFConvolvedModel2D` instead.
"""
oversample = None

_cache_grid = True
_cached_grid_size = 0
_cached_grid_factor = 0
_cached_grid = None

def clear_cached_grid(self):
self._cached_grid_size = 0
self._cached_grid_factor = 0
if self._cached_grid is not None:
del self._cached_grid
self._cached_grid = None

@property
def cache_grid(self):
return self._cache_grid

@cache_grid.setter
def cache_grid(self, value):
if value == False:
self._cache_grid = False
self.clear_cached_grid()
elif value == True:
self._cache_grid = True
else:
raise ValueError("{} is not a bool, use True or False".format(value))

def evaluate(self, *args, **kwargs):
psf_p = args[-1]
args = args[:-1]

x = args[0]
y = args[1]

assert not np.any(x < 0), 'negative pixel values not supported at this time'
assert not np.any(y < 0), 'negative pixel values not supported at this time'

grid_size = max([i.max() + 1 for i in [x, y]])
grid_factor = self.oversample if isinstance(self.oversample, int) else 1

if grid_size == self._cached_grid_size and self._cached_grid_factor == grid_factor:
main_grid = self._cached_grid
else:
main_grid = make_grid(grid_size, factor=grid_factor)

if self.cache_grid:
self._cached_grid = main_grid
self._cached_grid_size = grid_size
self._cached_grid_factor = grid_factor

x_grid, y_grid = main_grid

model_image = self._model.evaluate(x_grid, y_grid, *args[self.n_inputs:])

if isinstance(self.oversample, int):
model_image = block_reduce(model_image, grid_factor) / grid_factor ** 2

elif isinstance(self.oversample, tuple):
sub_grid_x0, sub_grid_y0, sub_grid_size, sub_grid_factor = self.oversample

if isinstance(sub_grid_x0, str):
assert sub_grid_x0 in self._model.param_names, "oversample param '{}' is not in the wrapped model param list".format(
sub_grid_x0)

idx = self._model.param_names.index(sub_grid_x0)
sub_grid_x0 = args[self.n_inputs:][idx][0]

if isinstance(sub_grid_y0, str):
assert sub_grid_y0 in self._model.param_names, "oversample param '{}' is not in the wrapped model param list".format(
sub_grid_y0)

idx = self._model.param_names.index(sub_grid_y0)
sub_grid_y0 = args[self.n_inputs:][idx][0]

x_sub_grid, y_sub_grid = make_grid(sub_grid_size, factor=sub_grid_factor)

x_sub_grid += int(sub_grid_x0) - sub_grid_size // 2
y_sub_grid += int(sub_grid_y0) - sub_grid_size // 2

sub_model_oversampled_image = self._model.evaluate(x_sub_grid, y_sub_grid, *args[self.n_inputs:])

# Experimental
# over_sampled_sub_model_x0 = np.argmin(
# np.abs(x_sub_grid[0, :] - 1 / (2 * sub_grid_factor) - (sub_grid_x0 * sub_grid_factor)))
# over_sampled_sub_model_y0 = np.argmin(
# np.abs(y_sub_grid[:, 0] - 1 / (2 * sub_grid_factor) - (sub_grid_y0 * sub_grid_factor)))
#
# sub_model_oversampled_image[
# over_sampled_sub_model_y0,
# over_sampled_sub_model_x0
# ] = self._model.evaluate(sub_grid_x0, sub_grid_y0, *args[self.n_inputs:])

sub_model_image = block_reduce(sub_model_oversampled_image, sub_grid_factor) / sub_grid_factor ** 2

x_sub_min = int(x_sub_grid[0][0] - 1 / (2 * sub_grid_factor)) + 1
y_sub_min = int(y_sub_grid[0][0] - 1 / (2 * sub_grid_factor)) + 1

model_image[
y_sub_min: y_sub_min + sub_grid_size,
x_sub_min: x_sub_min + sub_grid_size
] = sub_model_image

if self.psf is None:
return model_image[y.astype(int), x.astype(int)]

else:
psf = self.psf
if psf_p[0] != 0:
psf = rotate(psf, psf_p[0], reshape=False)
return convolve(model_image, psf, mode='same')[y.astype(int), x.astype(int)]

@property
def model(self):

model = self._model.copy()
for param in model.param_names:
setattr(model, param, getattr(self, param).value)

fixed = self.fixed
del fixed['psf_pa']

bounds = self.bounds
del bounds['psf_pa']

model.fixed.update(fixed)
model.bounds.update(bounds)

return model

@staticmethod
def wrap(model, psf=None, oversample=None):
warnings.warn('''PSFModel is deprecated as of v0.4.0 and will be
removed in the next release, use `PSFConvolvedModel2D` instead.''', DeprecationWarning, stacklevel=2)

if isinstance(model, PSFModel):
raise TypeError("Can not wrap a PSFModel, try: PSFModel.wrap(psf_model.model)")

# Extract model params
params = OrderedDict(
(param_name, Parameter(param_name, default=param_val)) for param_name, param_val in
zip(model.param_names, model.parameters)
)

# Prepare class attributes
members = OrderedDict([
('__module__', '__main__'),
('__name__', 'PSFModel'),
('__doc__', 'PSF Wrapped Model\n{}'.format(model.__doc__)),
('n_inputs', model.n_inputs),
('n_outputs', model.n_outputs),
('psf', psf),
('_model', model),
])

# Add params to class attributes
members.update(params)
members.update({'psf_pa': Parameter('psf_pa', default=0)})

# Construct new model class
model_class = type('PSFModel', (PSFModel,), members)

# Init new model from new model class
new_model = model_class()

# Sync model states
new_model.fixed.update(model.fixed)
new_model.bounds.update(model.bounds)

# add oversample regions
if oversample is not None:
if isinstance(oversample, int) or isinstance(oversample, tuple):
new_model.oversample = oversample
else:
raise ValueError("oversample should be a single int factor or a tuple (x, y, size, factor).")

# Return new model
return new_model


class PSFConvolvedModel2D(FittableModel):
"""
Fittable model for converting `FittableModel` and `CompoundModel` into 2D images.
Expand Down
78 changes: 2 additions & 76 deletions petrofit/modeling/tests/test_fitting.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,10 @@
import numpy as np

from astropy.convolution import convolve
from astropy.modeling import FittableModel, Parameter, custom_model, models
from astropy.modeling import models

from petrofit.modeling.models import PSFModel, make_grid, PSFConvolvedModel2D
from petrofit.modeling.models import make_grid, PSFConvolvedModel2D
from petrofit.modeling import model_to_image, fit_model
from petrofit.segmentation import masked_segm_image

import matplotlib.pyplot as plt


def test_psfmodel():
"""Test fitting and PSF convolution"""

# Make model:

imsize = 300

sersic_model = models.Sersic2D(
amplitude=1,
r_eff=25,
n=2,
x_0=imsize / 2,
y_0=imsize / 2,
ellip=0,
theta=0,
bounds={
'amplitude': (0., None),
'r_eff': (0, None),
'n': (0, 10),
'ellip': (0, 1),
'theta': (-2 * np.pi, 2 * np.pi),
},
)

# Make model image
image = model_to_image(sersic_model, imsize)

# Make a PSF
x_grid, y_grid = make_grid(51, factor=1)
PSF = models.Moffat2D(x_0=25.0, y_0=25.0)(x_grid, y_grid)
PSF /= PSF.sum()

# Make a PSF image using model image and PSF
psf_sersic_image = convolve(image, PSF)

# Make a PSFModel
psf_sersic_model = PSFModel.wrap(sersic_model, psf=PSF, oversample=None)
psf_sersic_model.fixed['psf_pa'] = True

# Make a PSFModel image
psf_sersic_model_image = model_to_image(psf_sersic_model, imsize)

# Compare the PSF image to PSFModel image
error_arr = abs(psf_sersic_model_image - psf_sersic_image) / psf_sersic_image
assert np.max(error_arr) < 0.01 # max error less than 1% error

# Test fitting

# Change params with offset
psf_sersic_model.r_eff = 23
psf_sersic_model.x_0 = 0.6 + imsize / 2
psf_sersic_model.y_0 = 0.1 + imsize / 2
psf_sersic_model.n = 3


# Fit
fitted_model, fit_info = fit_model(
psf_sersic_image, psf_sersic_model,
maxiter=10000,
epsilon=1.4901161193847656e-10,
acc=1e-9,
)

# Generate a model image from the fitted model
fitted_model_image = model_to_image(fitted_model, imsize)

# Check if fit is close to actual
error_arr = abs(fitted_model_image - psf_sersic_image) / psf_sersic_image
assert np.max(error_arr) < 0.01 # max error less than 1% error


def test_psf_convolved_image_model():
Expand Down

0 comments on commit d2d5c18

Please sign in to comment.