Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions src/diffpy/labpdfproc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def get_path_length(self, grid_point, angle):
return total_distance, primary_distance, secondary_distance


def _cve_brute_force(mud):
def _cve_brute_force(diffraction_data, mud):
"""
compute cve for the given mud on a global grid using the brute-force method
assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1
Expand All @@ -190,10 +190,21 @@ def _cve_brute_force(mud):
distances = np.array(distances) / abs_correction.total_points_in_grid
muls = np.array(muls) / abs_correction.total_points_in_grid
cve = 1 / muls
return cve

abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return abdo


def _cve_polynomial_interpolation(mud):
def _cve_polynomial_interpolation(diffraction_data, mud):
"""
compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6)
"""
Expand All @@ -208,32 +219,43 @@ def _cve_polynomial_interpolation(mud):
]
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
cve = 1 / muls
return cve

abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
)
return abdo


def _compute_cve(method, mud):
def _cve_method(method):
"""
compute cve for the given mud on a global grid using the specified method
retrieve the cve computation function for the given method
"""
methods = {
"brute_force": _cve_brute_force,
"polynomial_interpolation": _cve_polynomial_interpolation,
}
return methods[method](mud)
if method not in CVE_METHODS:
raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.")
return methods[method]


def interpolate_cve(diffraction_data, mud, wavelength, method="polynomial_interpolation"):
def compute_cve(diffraction_data, mud, method="polynomial_interpolation"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good practice here is to make the default as method=None then in the code do

If method is None:
     method = "polynomial_interpolation"

"""
compute and interpolate the cve for the given diffraction data, mud, and wavelength, using the selected method
compute and interpolate the cve for the given diffraction data and mud using the selected method

Parameters
----------
diffraction_data Diffraction_object
the diffraction pattern
mud float
the mu*D of the diffraction object, where D is the diameter of the circle
wavelength float
the wavelength of the diffraction object
method str
the method used to calculate cve

Expand All @@ -243,10 +265,13 @@ def interpolate_cve(diffraction_data, mud, wavelength, method="polynomial_interp

"""

cve = _compute_cve(method, mud)
cve_function = _cve_method(method)
abdo_on_global_tth = cve_function(diffraction_data, mud)
global_tth = abdo_on_global_tth.on_tth[0]
cve_on_global_tth = abdo_on_global_tth.on_tth[1]
orig_grid = diffraction_data.on_tth[0]
newcve = np.interp(orig_grid, TTH_GRID, cve)
abdo = Diffraction_object(wavelength=wavelength)
newcve = np.interp(orig_grid, global_tth, cve_on_global_tth)
abdo = Diffraction_object(wavelength=diffraction_data.wavelength)
abdo.insert_scattering_quantity(
orig_grid,
newcve,
Expand Down
8 changes: 4 additions & 4 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from argparse import ArgumentParser

from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, interpolate_cve
from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
Expand Down Expand Up @@ -115,8 +115,8 @@ def main():
for filepath in args.input_paths:
outfilestem = filepath.stem + "_corrected"
corrfilestem = filepath.stem + "_cve"
outfile = args.output_directory / (outfilestem + ".chi")
corrfile = args.output_directory / (corrfilestem + ".chi")
outfile = args.output_directory / (outfilestem + ".xy")
corrfile = args.output_directory / (corrfilestem + ".xy")

if outfile.exists() and not args.force_overwrite:
sys.exit(
Expand All @@ -140,7 +140,7 @@ def main():
metadata=load_metadata(args, filepath),
)

absorption_correction = interpolate_cve(input_pattern, args.mud, args.wavelength, args.method)
absorption_correction = compute_cve(input_pattern, args.mud, args.method)
corrected_data = apply_corr(input_pattern, absorption_correction)
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
corrected_data.dump(f"{outfile}", xtype="tth")
Expand Down
6 changes: 3 additions & 3 deletions src/diffpy/labpdfproc/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, interpolate_cve
from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, compute_cve
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object

params1 = [
Expand Down Expand Up @@ -69,13 +69,13 @@ def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"):
return test_do


def test_interpolate_cve(mocker):
def test_compute_cve(mocker):
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
expected_cve = np.array([0.5, 0.5, 0.5])
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
mocker.patch("numpy.interp", return_value=expected_cve)
input_pattern = _instantiate_test_do(xarray, yarray)
actual_abdo = interpolate_cve(input_pattern, mud=1, wavelength=1.54)
actual_abdo = compute_cve(input_pattern, mud=1)
expected_abdo = _instantiate_test_do(
xarray,
expected_cve,
Expand Down