Skip to content

Commit c739562

Browse files
Add Support for Excluded Regions in Fitting and Visualization (#58)
* Very quick and dirty implementation of excluded regions * Adds parent handling and item addition callback to Collection * Refactors excluded points handling in diffraction minimization Moves the logic for updating excluded points from minimization to ExcludedRegions, ensuring better modularity. * Extends unit test with excluded regions * Refactors excluded region logic in experiments Optimizes the process of updating the pattern's excluded points by avoiding unnecessary resets with each new excluded region. Initializes pattern's excluded points with default values on data load. * Excludes specified points from calculations for improve accuracy * Refines excluded region handling in experiments * Refines data plotting and analysis settings * Formats function signatures for readability * Removes unused assignment in test_fit_with_params * Adds 'excluded' flag to experiment mock data * Improves handling of excluded data points during minimization * Enhances exclusion handling in Plotter class * Fixes empty background and excluded regions check * Adds CIF display for experiment data * Updates the multiphase tutorial * Simplifies exclusion handling in data processing * Renames tutorial files and updates documentation * Updates tutorials and adds display for excluded regions * Refines tutorial markdown headers for consistency * Adjusts data preprocessing for stability in analysis Rounds x-values to 4 decimal places to align data size Replaces small uncertainty values with 1.0 to prevent failures in minimization algorithms
1 parent a98989d commit c739562

29 files changed

+10597
-236
lines changed

docs/mkdocs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ nav:
6767
- LBCO quick: tutorials/quick_single-fit_pd-neut-cwl_LBCO-HRPT.ipynb
6868
- LBCO basic: tutorials/basic_single-fit_pd-neut-cwl_LBCO-HRPT.ipynb
6969
- PbSO4 advanced: tutorials/advanced_joint-fit_pd-neut-xray-cwl_PbSO4.ipynb
70-
- Structure Refinement:
70+
- Standard Diffraction:
7171
- Co2SiO4 pd-neut-cwl: tutorials/cryst-struct_pd-neut-cwl_CoSiO4-D20.ipynb
7272
- HS pd-neut-cwl: tutorials/cryst-struct_pd-neut-cwl_HS-HRPT.ipynb
7373
- Si pd-neut-tof: tutorials/cryst-struct_pd-neut-tof_Si-SEPD.ipynb
74-
- NCAF pd-neut-tof: tutorials/cryst-struct_pd-neut-tof_NCAF-WISH.ipynb
74+
- NCAF pd-neut-tof: tutorials/cryst-struct_pd-neut-tof_multidata_NCAF-WISH.ipynb
7575
- LBCO+Si McStas: tutorials/cryst-struct_pd-neut-tof_multphase-LBCO-Si_McStas.ipynb
7676
- Pair Distribution Function:
7777
- Ni pd-neut-cwl: tutorials/pdf_pd-neut-cwl_Ni.ipynb

docs/tutorials/index.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ as self-contained, step-by-step **guides** to help users grasp the workflow of d
1010
analysis using EasyDiffraction.
1111

1212
Instructions on how to run the tutorials are provided in the
13-
[:material-cog-box: Installation & Setup](../installation-and-setup/index.md#running-tutorials)
13+
[:material-cog-box: Installation & Setup](../installation-and-setup/index.md#how-to-run-tutorials)
1414
section of the documentation.
1515

1616
The tutorials are organized into the following categories.
@@ -46,9 +46,14 @@ The tutorials are organized into the following categories.
4646
- [Si `pd-neut-tof`](cryst-struct_pd-neut-tof_Si-SEPD.ipynb)
4747
Demonstrates a Rietveld refinement of the Si crystal structure using
4848
time-of-flight neutron powder diffraction data from SEPD at Argonne.
49-
- [NCAF `pd-neut-tof`](cryst-struct_pd-neut-tof_NCAF-WISH.ipynb)
49+
- [NCAF `pd-neut-tof`](cryst-struct_pd-neut-tof_multidata_NCAF-WISH.ipynb)
5050
Demonstrates a Rietveld refinement of the Na2Ca3Al2F14 crystal structure
51-
using time-of-flight neutron powder diffraction data from WISH at ISIS.
51+
using two time-of-flight neutron powder diffraction datasets (from two
52+
detector banks) of the WISH instrument at ISIS.
53+
- [LBCO+Si McStas](cryst-struct_pd-neut-tof_multiphase-LBCO-Si_McStas.ipynb)
54+
Demonstrates a Rietveld refinement of the La0.5Ba0.5CoO3 crystal structure
55+
with a small amount of Si impurity as a secondary phase using time-of-flight
56+
neutron powder diffraction data simulated with McStas.
5257

5358
## Pair Distribution Function (PDF)
5459

src/easydiffraction/analysis/calculation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def set_calculator(self, engine: str) -> None:
3131
"""
3232
self._calculator = self.calculator_factory.create_calculator(engine)
3333

34-
def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> Optional[List[Any]]:
34+
def calculate_structure_factors(self,
35+
sample_models: SampleModels,
36+
experiments: Experiments) -> Optional[List[Any]]:
3537
"""
3638
Calculate HKL intensities (structure factors) for sample models and experiments.
3739
@@ -44,15 +46,17 @@ def calculate_structure_factors(self, sample_models: SampleModels, experiments:
4446
"""
4547
return self._calculator.calculate_structure_factors(sample_models, experiments)
4648

47-
def calculate_pattern(self, sample_models: SampleModels, experiment: Experiment) -> np.ndarray:
49+
def calculate_pattern(self,
50+
sample_models: SampleModels,
51+
experiment: Experiment) -> np.ndarray:
4852
"""
49-
Generate diffraction pattern based on sample models and experiment.
53+
Calculate diffraction pattern based on sample models and experiment.
5054
5155
Args:
5256
sample_models: Collection of sample models.
5357
experiment: A single experiment object.
5458
5559
Returns:
56-
Diffraction pattern generated by the backend calculator.
60+
Diffraction pattern calculated by the backend calculator.
5761
"""
5862
return self._calculator.calculate_pattern(sample_models, experiment)

src/easydiffraction/analysis/calculators/calculator_cryspy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _recreate_cryspy_dict(self,
127127

128128
cryspy_model_id = f'crystal_{sample_model.name}'
129129
cryspy_model_dict = cryspy_dict[cryspy_model_id]
130+
130131
# Cell
131132
cryspy_cell = cryspy_model_dict['unit_cell_parameters']
132133
cryspy_cell[0] = sample_model.cell.length_a.value
@@ -135,6 +136,7 @@ def _recreate_cryspy_dict(self,
135136
cryspy_cell[3] = np.deg2rad(sample_model.cell.angle_alpha.value)
136137
cryspy_cell[4] = np.deg2rad(sample_model.cell.angle_beta.value)
137138
cryspy_cell[5] = np.deg2rad(sample_model.cell.angle_gamma.value)
139+
138140
# Atomic coordinates
139141
cryspy_xyz = cryspy_model_dict['atom_fract_xyz']
140142
for idx, atom_site in enumerate(sample_model.atom_sites):

src/easydiffraction/analysis/minimization.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ def _process_fit_results(self,
7878
f_obs, f_calc = None, None
7979

8080
if self.results:
81-
self.results.display_results(y_obs=y_obs, y_calc=y_calc, y_err=y_err, f_obs=f_obs, f_calc=f_calc)
81+
self.results.display_results(y_obs=y_obs,
82+
y_calc=y_calc,
83+
y_err=y_err,
84+
f_obs=f_obs,
85+
f_calc=f_calc)
8286

8387
def _collect_free_parameters(self,
8488
sample_models: SampleModels,
@@ -140,13 +144,19 @@ def _residual_function(self,
140144
residuals: List[float] = []
141145

142146
for (expt_id, experiment), weight in zip(experiments._items.items(), _weights):
147+
148+
# Calculate the difference between measured and calculated patterns
143149
y_calc: np.ndarray = calculator.calculate_pattern(sample_models,
144150
experiment,
145-
called_by_minimizer=True) # True False
151+
called_by_minimizer=True)
146152
y_meas: np.ndarray = experiment.datastore.pattern.meas
147153
y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su
148-
diff: np.ndarray = (y_meas - y_calc) / y_meas_su
149-
diff *= np.sqrt(weight) # Residuals are squared before going into reduced chi-squared
154+
diff = ((y_meas - y_calc) / y_meas_su)
155+
156+
# Residuals are squared before going into reduced chi-squared
157+
diff *= np.sqrt(weight)
158+
159+
# Append the residuals for this experiment
150160
residuals.extend(diff)
151161

152162
return self.minimizer.tracker.track(np.array(residuals), parameters)

src/easydiffraction/analysis/reliability_factors.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from easydiffraction.experiments.experiments import Experiments
55
from easydiffraction.analysis.calculators.calculator_base import CalculatorBase
66

7-
def calculate_r_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
7+
def calculate_r_factor(y_obs: np.ndarray,
8+
y_calc: np.ndarray) -> float:
89
"""
910
Calculate the R-factor (reliability factor) between observed and calculated data.
1011
@@ -22,7 +23,9 @@ def calculate_r_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
2223
return numerator / denominator if denominator != 0 else np.nan
2324

2425

25-
def calculate_weighted_r_factor(y_obs: np.ndarray, y_calc: np.ndarray, weights: np.ndarray) -> float:
26+
def calculate_weighted_r_factor(y_obs: np.ndarray,
27+
y_calc: np.ndarray,
28+
weights: np.ndarray) -> float:
2629
"""
2730
Calculate the weighted R-factor between observed and calculated data.
2831
@@ -42,7 +45,8 @@ def calculate_weighted_r_factor(y_obs: np.ndarray, y_calc: np.ndarray, weights:
4245
return np.sqrt(numerator / denominator) if denominator != 0 else np.nan
4346

4447

45-
def calculate_rb_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
48+
def calculate_rb_factor(y_obs: np.ndarray,
49+
y_calc: np.ndarray) -> float:
4650
"""
4751
Calculate the Bragg R-factor between observed and calculated data.
4852
@@ -60,7 +64,8 @@ def calculate_rb_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
6064
return numerator / denominator if denominator != 0 else np.nan
6165

6266

63-
def calculate_r_factor_squared(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
67+
def calculate_r_factor_squared(y_obs: np.ndarray,
68+
y_calc: np.ndarray) -> float:
6469
"""
6570
Calculate the R-factor squared between observed and calculated data.
6671
@@ -78,7 +83,8 @@ def calculate_r_factor_squared(y_obs: np.ndarray, y_calc: np.ndarray) -> float:
7883
return np.sqrt(numerator / denominator) if denominator != 0 else np.nan
7984

8085

81-
def calculate_reduced_chi_square(residuals: np.ndarray, num_parameters: int) -> float:
86+
def calculate_reduced_chi_square(residuals: np.ndarray,
87+
num_parameters: int) -> float:
8288
"""
8389
Calculate the reduced chi-square statistic.
8490
@@ -99,7 +105,9 @@ def calculate_reduced_chi_square(residuals: np.ndarray, num_parameters: int) ->
99105
return np.nan
100106

101107

102-
def get_reliability_inputs(sample_models: SampleModels, experiments: Experiments, calculator: CalculatorBase) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
108+
def get_reliability_inputs(sample_models: SampleModels,
109+
experiments: Experiments,
110+
calculator: CalculatorBase) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
103111
"""
104112
Collect observed and calculated data points for reliability calculations.
105113
@@ -120,9 +128,13 @@ def get_reliability_inputs(sample_models: SampleModels, experiments: Experiments
120128
y_meas_su = experiment.datastore.pattern.meas_su
121129

122130
if y_meas is not None and y_calc is not None:
131+
# If standard uncertainty is not provided, use ones
132+
if y_meas_su is None:
133+
y_meas_su = np.ones_like(y_meas)
134+
123135
y_obs_all.extend(y_meas)
124136
y_calc_all.extend(y_calc)
125-
y_err_all.extend(y_meas_su if y_meas_su is not None else np.ones_like(y_meas))
137+
y_err_all.extend(y_meas_su)
126138

127139
return (
128140
np.array(y_obs_all),

src/easydiffraction/core/objects.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@ class Collection(ABC):
341341
def _child_class(self):
342342
return None
343343

344-
def __init__(self):
344+
def __init__(self, parent=None):
345+
self._parent = parent # Parent datablock
345346
self._datablock_id = None # Parent datablock name to be set by the parent
346347
self._items = {}
347348

@@ -373,6 +374,10 @@ def add(self, *args, **kwargs):
373374
child_obj.entry_id = child_obj.entry_id # Forcing the entry_id to be reset to update its child parameters
374375
self._items[child_obj._entry_id] = child_obj
375376

377+
# Call on_item_added if it exists, i.e. defined in the derived class
378+
if hasattr(self, "on_item_added"):
379+
self.on_item_added(child_obj)
380+
376381
def get_all_params(self):
377382
params = []
378383
for item in self._items.values():

src/easydiffraction/experiments/collections/datastore.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, experiment: Experiment) -> None:
1616
self.meas: Optional[np.ndarray] = None
1717
self.meas_su: Optional[np.ndarray] = None
1818
self.bkg: Optional[np.ndarray] = None
19+
self.excluded: Optional[np.ndarray] = None # Flags for excluded points
1920
self._calc: Optional[np.ndarray] = None # Cached calculated intensities
2021

2122
@property
@@ -33,6 +34,7 @@ class PowderPattern(Pattern):
3334
"""
3435
Specialized pattern for powder diffraction (can be extended in the future).
3536
"""
37+
# TODO: Check if this class is needed or if it can be merged with Pattern
3638
def __init__(self, experiment: Experiment) -> None:
3739
super().__init__(experiment)
3840
# Additional powder-specific initialization if needed
@@ -49,12 +51,14 @@ def __init__(self, sample_form: str, experiment: Experiment) -> None:
4951
if sample_form == "powder":
5052
self.pattern: Pattern = PowderPattern(experiment)
5153
elif sample_form == "single_crystal":
52-
self.pattern: Pattern = Pattern(experiment)
54+
self.pattern: Pattern = Pattern(experiment) # TODO: Find better name for single crystal pattern
5355
else:
5456
raise ValueError(f"Unknown sample form '{sample_form}'")
5557

5658
def load_measured_data(self, file_path: str) -> None:
5759
"""Load measured data from an ASCII file."""
60+
# TODO: Check if this method is used...
61+
# Looks like _load_ascii_data_to_experiment from experiments.py is used instead
5862
print(f"Loading measured data for {self.sample_form} diffraction from {file_path}")
5963

6064
try:
@@ -73,6 +77,7 @@ def load_measured_data(self, file_path: str) -> None:
7377
self.pattern.x = x
7478
self.pattern.meas = y
7579
self.pattern.meas_su = sy
80+
self.pattern.excluded = np.full(x.shape, fill_value=False, dtype=bool) # No excluded points by default
7681

7782
print(f"Loaded {len(x)} points for experiment '{self.pattern.experiment.name}'.")
7883

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import List, Type
2+
3+
from easydiffraction.utils.utils import render_table
4+
from easydiffraction.utils.formatting import paragraph
5+
from easydiffraction.core.objects import (
6+
Parameter,
7+
Descriptor,
8+
Component,
9+
Collection
10+
)
11+
12+
13+
class ExcludedRegion(Component):
14+
@property
15+
def category_key(self) -> str:
16+
return "excluded_region"
17+
18+
@property
19+
def cif_category_key(self) -> str:
20+
return "excluded_region"
21+
22+
def __init__(self,
23+
minimum: float,
24+
maximum: float):
25+
super().__init__()
26+
27+
self.minimum = Descriptor(
28+
value=minimum,
29+
name="minimum",
30+
cif_name="minimum"
31+
)
32+
self.maximum = Parameter(
33+
value=maximum,
34+
name="maximum",
35+
cif_name="maximum"
36+
)
37+
38+
# Select which of the input parameters is used for the
39+
# as ID for the whole object
40+
self._entry_id = f'{minimum}-{maximum}'
41+
42+
# Lock further attribute additions to prevent
43+
# accidental modifications by users
44+
self._locked = True
45+
46+
47+
class ExcludedRegions(Collection):
48+
"""
49+
Collection of ExcludedRegion instances.
50+
"""
51+
@property
52+
def _type(self) -> str:
53+
return "category" # datablock or category
54+
55+
@property
56+
def _child_class(self) -> Type[ExcludedRegion]:
57+
return ExcludedRegion
58+
59+
def on_item_added(self, item: ExcludedRegion) -> None:
60+
"""
61+
Mark excluded points in the experiment pattern when a new region is added.
62+
"""
63+
experiment = self._parent
64+
pattern = experiment.datastore.pattern
65+
66+
# Boolean mask for points within the new excluded region
67+
in_region = ((pattern.full_x >= item.minimum.value) &
68+
(pattern.full_x <= item.maximum.value))
69+
70+
# Update the exclusion mask
71+
pattern.excluded[in_region] = True
72+
73+
# Update the excluded points in the datastore
74+
pattern.x = pattern.full_x[~pattern.excluded]
75+
pattern.meas = pattern.full_meas[~pattern.excluded]
76+
pattern.meas_su = pattern.full_meas_su[~pattern.excluded]
77+
78+
def show(self) -> None:
79+
# TODO: Consider moving this to the base class
80+
# to avoid code duplication with implementations in Background, etc.
81+
# Consider using parameter names as column headers
82+
columns_headers: List[str] = ["minimum", "maximum"]
83+
columns_alignment = ["left", "left"]
84+
columns_data: List[List[float]] = []
85+
for region in self._items.values():
86+
minimum = region.minimum.value
87+
maximum = region.maximum.value
88+
columns_data.append([minimum, maximum])
89+
90+
print(paragraph("Excluded regions"))
91+
render_table(columns_headers=columns_headers,
92+
columns_alignment=columns_alignment,
93+
columns_data=columns_data)

0 commit comments

Comments
 (0)