Skip to content

Commit

Permalink
Merge pull request #8 from openmrslab/mrs_modifications
Browse files Browse the repository at this point in the history
Massive overhaul of Amares style singlet fitting and 100% test coverage of the module.
  • Loading branch information
bennyrowland authored Sep 20, 2016
2 parents e138008 + 6afc09b commit ca720a3
Show file tree
Hide file tree
Showing 2 changed files with 371 additions and 113 deletions.
183 changes: 77 additions & 106 deletions suspect/fitting/singlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy
import scipy.optimize
import numbers
import copy

import suspect.basis

Expand Down Expand Up @@ -50,18 +51,15 @@ def fit(fid, model, baseline_points=16):
:param fid: MRSData object of FID to be fit
:param model: dictionary model of fit parameters
:param baseline_points: the number of points at the start of the FID to ignore
:return: ["model": optimized model, "fit": fitting data, "err": dictionary of standard errors]
:return: Dictionary containing ["model": optimized model, "fit": fitting data, "err": dictionary of standard errors]
"""

# List of metabolite names
metabolite_name_list = []

# Get list of metabolite names.
def get_metabolites(model_input):
    """
    Collect the names of the metabolites (peaks) described in a model.

    A top-level entry counts as a metabolite when its value is a dictionary
    (including dict subclasses such as OrderedDict); every other entry, e.g.
    a plain number, is skipped.

    :param model_input: dictionary model of fit parameters
    :return: list of metabolite names, in the iteration order of model_input
    """
    return [name for name, value in model_input.items() if isinstance(value, dict)]

# Get standard errors from lmfit MinimizerResult object.
Expand All @@ -78,7 +76,7 @@ def phase_fid(fid_in, phase0, phase1):
:param fid_in: FID to be fitted.
:param phase1: phase1 value.
:param phase0: phase0 value.
:return:
:return: FID that has been shifted into phase by FFT
"""
spectrum = numpy.fft.fftshift(numpy.fft.fft(fid_in))
np = fid_in.np
Expand All @@ -94,23 +92,24 @@ def make_basis(params, time_axis):
:param time_axis: the time axis.
:return: a matrix containing the generated basis set.
"""
# metabolite_name_list = []
# for param in params.keys():
# split = param.split('_')
# if len(split) == 2:
# if split[0] not in metabolite_name_list:
# metabolite_name_list.append(split[0])

basis_matrix = numpy.matrix(numpy.zeros((len(metabolite_name_list), len(time_axis) * 2)))
for i, metabolite_name in enumerate(metabolite_name_list):
gaussian = suspect.basis.gaussian(time_axis,
params["{}_frequency".format(metabolite_name)],
params["{}_phase".format(metabolite_name)].value,
params["{}_width".format(metabolite_name)])
params["{}_fwhm".format(metabolite_name)])
real_gaussian = complex_to_real(gaussian)
basis_matrix[i, :] = real_gaussian
return basis_matrix

def unphase(data, params):
    """
    Apply the negated phase0/phase1 from params to data and convert the result to real form.

    :param data: FID to be unphased.
    :param params: mapping that provides 'phase0' and 'phase1' entries.
    :return: the output of complex_to_real for the FID returned by phase_fid
             when called with the negated phase values.
    """
    inverse_phase0 = -params['phase0']
    inverse_phase1 = -params['phase1']
    return complex_to_real(phase_fid(data, inverse_phase0, inverse_phase1))

def do_fit(params, time_axis, real_unphased_data):
"""
This function performs the fitting.
Expand All @@ -137,11 +136,9 @@ def residual(params, time_axis, data):
:param data: FID to be fitted.
:return: residual values of baseline points.
"""
# unphase the data to make it pure absorptive
unphased_data = phase_fid(data, -params['phase0'], -params['phase1'])
real_unphased_data = complex_to_real(unphased_data)

fitted_data, _ = do_fit(params, time_axis, real_unphased_data)
real_unphased_data = unphase(data, params)
fitted_data, weights = do_fit(params, time_axis, real_unphased_data)
res = fitted_data - real_unphased_data

return res[baseline_points:-baseline_points]
Expand All @@ -160,24 +157,13 @@ def fit_data(data, initial_params):
args=(data.time_axis(), data),
xtol=5e-3)

unphased_data = phase_fid(data,
-fitting_result.params['phase0'],
-fitting_result.params['phase1'])
real_unphased_data = complex_to_real(unphased_data)
real_fitted_data, fitting_weights = do_fit(fitting_result.params, data.time_axis(), real_unphased_data)
real_fitted_data, fitting_weights = do_fit(fitting_result.params, data.time_axis(), unphase(data, fitting_result.params))
fitted_data = real_to_complex(real_fitted_data)

return fitting_weights, fitted_data, fitting_result

# Convert lmfit parameters to model format
def parameters_to_model(parameters_obj, param_weights):
# metabolite_name_list = []
# for param in parameters_obj.keys():
# split = param.split('_')
# if len(split) == 2:
# if split[0] not in metabolite_name_list:
# metabolite_name_list.append(split[0])
# Create dictionary for new model.
new_model = {}
for param_name, param in parameters_obj.items():
name = param_name.split("_")
Expand All @@ -191,7 +177,6 @@ def parameters_to_model(parameters_obj, param_weights):
else:
new_model[name1][name2] = param.value

nonlocal metabolite_name_list
for i, metabolite_name in enumerate(metabolite_name_list):
new_model[metabolite_name]["amplitude"] = param_weights[i]

Expand All @@ -205,35 +190,34 @@ def model_to_parameters(model_dict):
# Calculate dependencies/references for each parameter.
depend_dict = calculate_dependencies(model_dict)

model_dict_copy = copy.deepcopy(model_dict)
params.append(("phase0", model_dict_copy.pop("phase0")))
params.append(("phase1", model_dict_copy.pop("phase1")))

# Construct lmfit Parameter input for each parameter.
for name1, value1 in model_dict.items():
if type(value1) is int: # (e.g. phase0)
params.append((name1, value1))
if type(value1) is dict:
for peak_name, peak_properties in model_dict_copy.items():
# Fix phase value to 0 by default.
if "phase" not in value1:
params.append(("{}_{}".format(name1, "phase"), None, None, None, None, "0"))
for name2, value2 in value1.items():
if "phase" not in peak_properties:
params.append(("{0}_{1}".format(peak_name, "phase"), None, None, None, None, "0"))
for property_name, property_value in peak_properties.items():
# Initialize lmfit parameter arguments.
name = "{}_{}".format(name1, name2)
name = "{0}_{1}".format(peak_name, property_name)
value = None
vary = True
lmfit_min = None
lmfit_max = None
expr = None
if type(value2) is int:
value = value2
elif type(value2) is str:
expr = value2
if type(value2) is dict:
if "value" in value2:
value = value2["value"]
# if "vary" in value2:
# vary = value2["vary"]
if "min" in value2:
lmfit_min = value2["min"]
if "max" in value2:
lmfit_max = value2["max"]
if isinstance(property_value, numbers.Number):
value = property_value
elif isinstance(property_value, str):
expr = property_value
elif isinstance(property_value, dict):
if "value" in property_value:
value = property_value["value"]
if "min" in property_value:
lmfit_min = property_value["min"]
if "max" in property_value:
lmfit_max = property_value["max"]
# Add parameter object with defined parameters.
params.append((name, value, vary, lmfit_min, lmfit_max, expr)) # (lmfit Parameter input format)

Expand Down Expand Up @@ -267,87 +251,74 @@ def model_to_parameters(model_dict):

# Check if all model input types are correct.
def check_errors(check_model):
    """
    Validate the structure of a fitting model.

    Every top-level value must be a number (e.g. phase0/phase1) or a
    dictionary describing one peak. Every peak property must be a number, a
    string expression, or a dictionary. A property dictionary must contain a
    'value' key and may only use the keys listed in allowed_keys.

    :param check_model: dictionary model of fit parameters
    :return: None
    :raises TypeError: if a top-level value or a peak property has an unsupported type
    :raises KeyError: if a property dictionary has no 'value' key or uses a key
                      that is not in allowed_keys
    """
    # Keys permitted inside a peak property dictionary.
    allowed_keys = ["min", "max", "value", "phase", "amplitude"]

    for model_property, model_values in check_model.items():
        if not isinstance(model_values, (numbers.Number, dict)):
            raise TypeError("Value of {0} must be a number (for phases), or a dictionary.".format(model_property))
        if not isinstance(model_values, dict):
            continue  # plain number (e.g. phase0): nothing further to check

        for peak_property, peak_value in model_values.items():
            if not isinstance(peak_value, (numbers.Number, dict, str)):
                raise TypeError("Value of {0}_{1} must be a value, an expression, or a dictionary."
                                .format(model_property, peak_property))
            if not isinstance(peak_value, dict):
                continue

            # Checked once per dictionary (not once per key) so that an
            # empty dictionary is rejected as well.
            if "value" not in peak_value:
                raise KeyError("Dictionary {0}_{1} is missing 'value' key."
                               .format(model_property, peak_property))
            for key in peak_value:
                if key not in allowed_keys:
                    raise KeyError("In {0}_{1}, '{2}' is not an allowed key."
                                   .format(model_property, peak_property, key))

# Calculate references to determine order for Parameters.
def calculate_dependencies(unordered_model):
    """
    Work out which parameters each string-expression parameter refers to.

    Parameter names have the form "<peak>_<property>". A name maps to None
    when its value is not a string expression, and to a list of the parameter
    names appearing in the expression otherwise. Names are matched as whole
    identifiers, so "a_fwhm" is not reported for the expression "ba_fwhm".

    :param unordered_model: dictionary model of fit parameters
    :return: dictionary mapping each parameter name to None or to a list of
             the parameter names its expression references
    :raises ReferenceError: if two parameters reference each other (or a
                            parameter references itself); longer cycles are
                            not detected here
    """
    import re  # local import: only this helper needs it

    dependencies = {}  # (name, [dependencies])

    # Compile dictionary of effective names.
    for model_property, model_values in unordered_model.items():
        if isinstance(model_values, dict):  # i.e. a peak, not a phase
            for peak_property in model_values:
                dependencies["{0}_{1}".format(model_property, peak_property)] = None

    # Find dependencies for each effective name.
    for model_property, model_values in unordered_model.items():
        if not isinstance(model_values, dict):
            continue
        for peak_property, peak_value in model_values.items():
            if not isinstance(peak_value, str):
                continue
            lmfit_name = "{0}_{1}".format(model_property, peak_property)
            # Compare against whole identifier tokens rather than raw
            # substrings to avoid false matches between similar names.
            referenced = set(re.findall(r"[^\W\d]\w*", peak_value))
            dependencies[lmfit_name] = [depend for depend in dependencies if depend in referenced]

    # Check for circular dependencies.
    for name, dependents in dependencies.items():
        if dependents is None:
            continue
        for dependent in dependents:
            if dependencies[dependent] is not None and name in dependencies[dependent]:
                raise ReferenceError("{0} and {1} reference each other, creating a circular reference."
                                     .format(name, dependent))

    return dependencies

# Do singlet fitting
def main():
# Minimize and fit 31P data.
# Check for errors in model formatting.
check_errors(model)
# Minimize and fit 31P data.

check_errors(model) # Check for errors in model formatting.

metabolite_name_list = get_metabolites(model) # Set list of metabolite names.

# Set list of metabolite names.
nonlocal metabolite_name_list
metabolite_name_list = get_metabolites(model)
parameters = model_to_parameters(model) # Convert model to lmfit Parameters object.

# Convert model to lmfit Parameters object.
parameters = model_to_parameters(model)
fitted_weights, fitted_data, fitted_results = fit_data(fid, parameters) # Fit data.

# Fit data.
fitted_weights, fitted_data, fitted_results = fit_data(fid, parameters)
final_model = parameters_to_model(fitted_results.params, fitted_weights) # Convert fit parameters to model format.

# Convert fit parameters to model format.
final_model = parameters_to_model(fitted_results.params, fitted_weights)
# Get stderr values for each parameter.
stderr = get_errors(fitted_results)
stderr = get_errors(fitted_results) # Get stderr values for each parameter.

# Compile output into a dictionary.
return_dict = {"model": final_model, "fit": fitted_data, "errors": stderr}
return return_dict
return_dict = {"model": final_model, "fit": fitted_data, "errors": stderr} # Compile output into a dictionary.
return return_dict

return main()
Loading

0 comments on commit ca720a3

Please sign in to comment.