Skip to content

Commit

Permalink
Add option to perturbation test script to force the perturbation of s…
Browse files Browse the repository at this point in the history
…pecific attributes first
  • Loading branch information
nathanpainchaud committed Jul 22, 2024
1 parent 5906865 commit 2797424
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions didactic/apps/perturbation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def run_perturbation_test(
patients: Patients,
target: TabularAttribute,
perturbation_mode: Literal["negative", "positive"] = "negative",
manual_perturbations: Sequence[TabularAttribute] = None,
mask_tag: str = CardinalTag.mask,
progress_bar: bool = False,
save_dir: Path = None,
Expand All @@ -107,6 +108,7 @@ def run_perturbation_test(
target: Target attribute w.r.t. which to compute the model's AUROC score.
perturbation_mode: Type of perturbation test to perform. If ``negative``, the test will remove the least
relevant attributes first. If ``positive``, the test will remove the most relevant attributes first.
manual_perturbations: Attributes to manually perturb before automatically selecting following attributes.
mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
progress_bar: If ``True``, enables progress bars detailing the progress of how many attributes are left to
perturb.
Expand All @@ -115,6 +117,9 @@ def run_perturbation_test(
Returns:
A series containing the AUROC score of the model for each further attribute removed.
"""
if manual_perturbations is None:
manual_perturbations = []

tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs
n_attrs = len(tab_attrs) + (len(time_series_attrs) * len(model.hparams.views))

Expand Down Expand Up @@ -190,8 +195,13 @@ def run_perturbation_test(
remaining_attrs = attrs_relevancy[~attrs_relevancy.index.isin(attrs_to_remove)]
if remaining_attrs.empty:
break

if manual_perturbations:
next_attr = manual_perturbations.pop(0)
else:
attrs_to_remove.append(remaining_attrs.index[0])
next_attr = remaining_attrs.index[0]

attrs_to_remove.append(next_attr)

return pd.Series(attrs_perturbation_scores, name="AUROC")

Expand Down Expand Up @@ -229,6 +239,12 @@ def main():
help="Type of perturbation test to perform. 'negative' removes the least relevant attributes first, "
"'positive' removes the most relevant attributes first",
)
parser.add_argument(
"--manual_perturbations",
type=TabularAttribute,
nargs="*",
help="Attributes to manually perturb before automatically selecting following attributes",
)
parser.add_argument(
"--mask_tag",
type=str,
Expand All @@ -239,10 +255,11 @@ def main():
args = parser.parse_args()
kwargs = vars(args)

encoder_ckpt, relevancy_target, perturbation_mode, mask_tag, output_dir = (
encoder_ckpt, relevancy_target, perturbation_mode, manual_perturbations, mask_tag, output_dir = (
kwargs.pop("pretrained_encoder"),
kwargs.pop("relevancy_target"),
kwargs.pop("perturbation_mode"),
kwargs.pop("manual_perturbations"),
kwargs.pop("mask_tag"),
kwargs.pop("output_dir"),
)
Expand All @@ -256,6 +273,7 @@ def main():
patients,
relevancy_target,
perturbation_mode=perturbation_mode,
manual_perturbations=manual_perturbations,
mask_tag=mask_tag,
progress_bar=True,
save_dir=output_dir,
Expand Down

0 comments on commit 2797424

Please sign in to comment.