Skip to content

Commit

Permalink
Merge pull request #1068 from rhayes777/feature/sensitivity_updates
Browse files Browse the repository at this point in the history
Feature/sensitivity updates
  • Loading branch information
Jammy2211 authored Nov 1, 2024
2 parents 7c3c84c + ab3517c commit 80d8624
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
42 changes: 31 additions & 11 deletions autofit/non_linear/grid/sensitivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
base_fit_cls: Callable,
perturb_fit_cls: Callable,
job_cls: ClassVar = Job,
visualizer_cls: Optional[Callable] = None,
perturb_model_prior_func: Optional[Callable] = None,
number_of_steps: Union[Tuple[int, ...], int] = 4,
mask: Optional[List[bool]] = None,
Expand Down Expand Up @@ -62,6 +63,9 @@ def __init__(
The class which fits the base model to each simulated dataset of the sensitivity map.
perturb_fit_cls
The class which fits the perturb model to each simulated dataset of the sensitivity map.
visualizer_cls
A class which can be used to visualize the results of the sensitivity mapping after each fit is performed,
therefore providing visualization on the fly.
number_of_steps
The number of steps for each dimension of the sensitivity grid. If input as a float the dimensions are
all that value. If input as a tuple of length the number of dimensions, each tuple value is the number of
Expand Down Expand Up @@ -97,6 +101,7 @@ def __init__(
self.perturb_model_prior_func = perturb_model_prior_func

self.job_cls = job_cls
self.visualizer_cls = visualizer_cls

self.number_of_steps = number_of_steps
self.mask = None
Expand Down Expand Up @@ -142,16 +147,16 @@ def run(self) -> SensitivityResult:
jobs = []

for number in range(len(self._perturb_instances)):
if self._should_bypass(number=number):
model = self.model.copy()
model.perturb = self._perturb_models[number]
results.append(
MaskedJobResult(
number=number,
model=model,
)
model = self.model.copy()
model.perturb = self._perturb_models[number]
results.append(
MaskedJobResult(
number=number,
model=model,
)
else:
)

if not self._should_bypass(number=number):
jobs.append(self._make_job(number))

for result in process_class.run_jobs(
Expand All @@ -160,8 +165,21 @@ def run(self) -> SensitivityResult:
if isinstance(result, Exception):
raise result

results.append(result)
results = sorted(results)
results[result.number] = result

sensitivity_result = SensitivityResult(
samples=[result.result.samples_summary for result in results],
perturb_samples=[
result.perturb_result.samples_summary for result in results
],
shape=self.shape,
path_values=self.path_values,
)

if self.visualizer_cls is not None:
self.visualizer_cls(
sensitivity_result=sensitivity_result, paths=self.paths
)

os.makedirs(self.paths.output_path, exist_ok=True)

Expand All @@ -179,6 +197,8 @@ def run(self) -> SensitivityResult:
filename=self.results_path,
)

# TODO : Had to repeat this code block to get certain unit tests to pass which presumably bypass run_jobs.

sensitivity_result = SensitivityResult(
samples=[result.result.samples_summary for result in results],
perturb_samples=[
Expand Down
1 change: 1 addition & 0 deletions autofit/non_linear/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _generate_unit_parameter_list(self, model: AbstractPriorModel) -> List[float

unit_parameter_list = []
for prior in model.priors_ordered_by_id:

try:
lower, upper = map(prior.unit_value_for, self.parameter_dict[prior])
value = random.uniform(lower, upper)
Expand Down
2 changes: 1 addition & 1 deletion autofit/non_linear/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def model_relative(self, r: float) -> AbstractPriorModel:

def model_bounded(self, b: float) -> AbstractPriorModel:
"""
Returns a model where every free parameter is a `UniformPrior` with `lower_limit` and `upper_limit the previous
Returns a model where every free parameter is a `UniformPrior` with `lower_limit` and `upper_limit` the previous
result's inferred maximum log likelihood parameter values minus and plus the bound `b`.
For example, a previous result may infer a parameter to have a maximum log likelihood value of 2.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def test_perturbed_physical_centres_list_from(masked_result):
0.75,
0.75,
]


def test_visualise(sensitivity):
def visualiser(sensitivity_result, **_):
assert len(sensitivity_result.samples) == 8

sensitivity.visualizer_cls = visualiser
sensitivity.run()

0 comments on commit 80d8624

Please sign in to comment.