Skip to content

Commit

Permalink
Merge pull request #26 from jmborr/fitting
Browse files Browse the repository at this point in the history
🎆
  • Loading branch information
jmborr authored Jan 17, 2018
2 parents 03fb82c + 2bee884 commit 17866ba
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ install:
conda install --yes -n testenv flake8;
fi
- source activate testenv
- pip install codecov lmfit sphinx sphinx_rtd_theme
- pip install codecov lmfit matplotlib sphinx sphinx_rtd_theme

script:
- py.test --cov=qef qef tests
Expand Down
1 change: 1 addition & 0 deletions docs/qef/models/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ Models

deltadirac
strexpft
tabulatedmodel
8 changes: 8 additions & 0 deletions docs/qef/models/tabulatedmodel.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
===============
TabulatedModel
===============

.. automodule:: qef.models.tabulatedmodel
:members:
:undoc-members:
:show-inheritance:
5 changes: 0 additions & 5 deletions qef/models/deltadirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ def delta_dirac(x, amplitude=1.0, center=0.0):
Integrated intensity of the curve
center : float
position of the peak
Returns
-------
values: :class:`~numpy:numpy.ndarray`
function values over the domain
"""
dx = (x[-1] - x[0]) / (len(x) - 1) # domain spacing
y = np.zeros(len(x))
Expand Down
63 changes: 63 additions & 0 deletions qef/models/tabulatedmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from scipy.interpolate import interp1d
from lmfit import Model, models


class TabulatedModel (Model):
"""fitting the tabulated Model to some arbitrary points
Parameters
----------
xs: :class:`~numpy:numpy.ndarray`
given domain of the function, energy
ys: :class:`~numpy:numpy.ndarray`
given domain of the function, intensity
amplitude : float
peak intensity of the curve
center : float
position of the peak
"""

def __init__(self, xs, ys, *args, **kwargs):
self._interp = interp1d(xs, ys, fill_value='extrapolate', kind='cubic')

def interpolator(x, amplitude, center):
return amplitude * self._interp(x - center)

super(TabulatedModel, self).__init__(interpolator, *args, **kwargs)

def guess(self, data, x, **kwargs):

"""Guess starting values for the parameters of a model.
Parameters
----------
data: :class:`~numpy:numpy.ndarray`
data to be fitted
x: :class:`~numpy:numpy.ndarray`
energy domain where the interpolation required
kwargs : dict
additional optional arguments, passed to model function.
Returns
-------
:class:`~lmfit.parameter.Parameters`
parameters with guessed values
"""
params = self.make_params()

def pset(param, value, min):
params["%s%s" % (self.prefix, param)].set(value=value, min=min)

x_at_max = x[models.index_of(data, max(data))]
ysim = self.eval(x=x_at_max, amplitude=1, center=x_at_max)
amplitude = max(data) / ysim
pset("amplitude", amplitude, min=0.0)
pset("center", x_at_max, min=-1000)
return models.update_param_vals(params, self.prefix, **kwargs)
28 changes: 28 additions & 0 deletions tests/models/test_tabulatedmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import (absolute_import, division, print_function)

import os
import pytest
import numpy as np
from lmfit.lineshapes import lorentzian
from qef.models.tabulatedmodel import TabulatedModel


def test_tabulatedmodel():
x_sim = np.arange(-1.0, 1.0, 0.0003) # energy domain, in meV
y_sim = lorentzian(x_sim, amplitude=1, center=0, sigma=0.042)
intensity = 42.0
peak_center = 0.0002
x_exp = np.arange(-0.1, 0.5, 0.0004)
y_exp = lorentzian(x_exp, amplitude=intensity, center=peak_center,
sigma=0.042)

model = TabulatedModel(x_sim, y_sim)
params = model.guess(x_exp, y_exp)
fit = model.fit(y_exp, params, x=x_exp, fit_kws={'nan_policy': 'omit'})

assert abs(fit.best_values['amplitude'] - intensity) < 0.0001
assert abs(fit.best_values['center'] - peak_center) < 0.0001


if __name__ == '__main__':
pytest.main([os.path.abspath(__file__)])

0 comments on commit 17866ba

Please sign in to comment.