diff --git a/didactic/scripts/experiments.sh b/didactic/scripts/experiments.sh index 1209696..44a3ba3 100644 --- a/didactic/scripts/experiments.sh +++ b/didactic/scripts/experiments.sh @@ -277,7 +277,7 @@ for task in scratch finetune xtab-finetune; do for tau_mode in learn_sigm learn_fn; do job_path=$task/contrastive=$contrastive/$data/$time_series_tokenizer/$target/ordinal_mode=$ordinal_mode,distribution=$distribution,tau_mode=$tau_mode echo "Plotting variability of predicted continuum w.r.t. position along the continuum between $job_path models" >>$HOME/data/didactic/results/plot_models_variability.log 2>&1 - python ~/remote/didactic/didactic/scripts/plot_models_variability.py $(find ~/data/didactic/results/cardiac-multimodal-representation/$job_path -name *.ckpt | sort | tr "\n" " ") --data_roots $HOME/dataset/cardinal/v1.0/data --views A4C A2C --hue_attr=ht_severity --output_file=$HOME/data/didactic/results/cardiac-multimodal-representation/$job_path/unimodal_param_variability.svg >>$HOME/data/didactic/results/plot_models_variability.log 2>&1 + python ~/remote/didactic/didactic/scripts/plot_models_variability.py $(find ~/data/didactic/results/cardiac-multimodal-representation/$job_path -name *.ckpt | sort | tr "\n" " ") --data_roots $HOME/dataset/cardinal/v1.0/data --views A4C A2C --hue_attr=ht_severity --output_dir=$HOME/data/didactic/results/cardiac-multimodal-representation/$job_path/unimodal_param_variability >>$HOME/data/didactic/results/plot_models_variability.log 2>&1 done done # end w/ ordinal constraint diff --git a/didactic/scripts/plot_models_variability.py b/didactic/scripts/plot_models_variability.py index 33c8e0f..8be1a51 100644 --- a/didactic/scripts/plot_models_variability.py +++ b/didactic/scripts/plot_models_variability.py @@ -7,7 +7,11 @@ def plot_embeddings_variability( - embeddings: Dict[str, np.ndarray], hue: Sequence = None, hue_name: str = None, hue_order: Sequence[str] = None + embeddings: Dict[str, np.ndarray], + ref_embedding: str = None, + hue: Sequence = None, + hue_name: str = None, + hue_order: Sequence[str] = None, ) -> JointGrid: """Generates a scatter plot of the variability w.r.t. the position on the continuum, for each item in the embedding. @@ -16,6 +20,8 @@ def plot_embeddings_variability( Args: embeddings: Mapping between the name of each embedding and the corresponding (N, [1]) embedding vector. + ref_embedding: Key to a "reference" embedding (in `embeddings`) to use as the reference position on the + continuum. If not provided, the mean of the embeddings will be used. hue: (N), Optional sequence of values to use as the hue for the plot. hue_name: Name to use for the hue in the plot's legend. hue_order: Sequence of hue values to use for ordering the hue values in the plot's legend. @@ -30,11 +36,14 @@ def plot_embeddings_variability( embeddings_arr = np.stack([emb.squeeze() for emb in embeddings.values()], axis=1) # (N, M) # Compute the statistics over each set of embedding vectors - mean = np.mean(embeddings_arr, axis=1) # (N) + if ref_embedding: + val = embeddings[ref_embedding].squeeze() # (N) + else: + val = np.mean(embeddings_arr, axis=1) # (N) std = np.std(embeddings_arr, axis=1) # (N) # Scatter plot of the position along the continuum (i.e. mean) w.r.t. variability (i.e. std) and user-defined hue - data = {"mean": mean, "std": std} + data = {"val": val, "std": std} grid_kwargs = {} if hue is not None: data[hue_name] = hue @@ -42,9 +51,9 @@ def plot_embeddings_variability( data = pd.DataFrame(data=data) with sns.axes_style("darkgrid"): - grid = sns.JointGrid(data=data, x="mean", y="std", **grid_kwargs) + grid = sns.JointGrid(data=data, x="val", y="std", **grid_kwargs) grid.plot_joint(sns.scatterplot, size=std) - sns.histplot(data=data, x="mean", ax=grid.ax_marg_x, **grid_kwargs, legend=False) + sns.histplot(data=data, x="val", ax=grid.ax_marg_x, **grid_kwargs, legend=False) sns.kdeplot(data=data, y="std", ax=grid.ax_marg_y, **grid_kwargs, legend=False, clip=(0, 1)) return grid @@ -97,9 +106,9 @@ def main(): help="Name of the tabular attribute to use as the hue for the plot", ) parser.add_argument( - "--output_file", + "--output_dir", type=Path, - help="Path to the image file where to save the variability plot", + help="Path to the root directory where to save the variability plots", ) args = parser.parse_args() @@ -107,8 +116,8 @@ def main(): raise ValueError("At least 2 models are required to compute a meaningful variability") kwargs = vars(args) - pretrained_encoders, mask_tag, encoding_task, hue_attr, output_file = list( - map(kwargs.pop, ["pretrained_encoder", "mask_tag", "encoding_task", "hue_attr", "output_file"]) + pretrained_encoders, mask_tag, encoding_task, hue_attr, output_dir = list( + map(kwargs.pop, ["pretrained_encoder", "mask_tag", "encoding_task", "hue_attr", "output_dir"]) ) # Compute the embeddings of the patients for each model @@ -136,17 +145,26 @@ def main(): hue_order=TABULAR_CAT_ATTR_LABELS[hue_attr], ) - # Generate the plot of the variability - plot = plot_embeddings_variability(embeddings, **plot_kwargs) - plot.set_axis_labels(f"{encoding_task} mean", f"{encoding_task} std") - # Move the title above the plot, to avoid overlapping with the x-axis marginal plot - plot.figure.suptitle(f"{encoding_task} std w.r.t. mean and {hue_attr}", y=1.02) - - # Save the variability plot to disk - output_file.parent.mkdir(parents=True, exist_ok=True) - - plt.savefig(output_file, bbox_inches="tight") - plt.close() # Close the figure to avoid contamination between plots + output_dir.mkdir(parents=True, exist_ok=True) + # Generate a plot of the variability w.r.t. the mean embedding as well as each encoder's embedding + for ref_embedding in tqdm( + [None, *list(embeddings.keys())], + desc=f"Generating variability plots for {encoding_task} embeddings", + unit="embedding", + ): + # Generate the plot of the variability + plot = plot_embeddings_variability(embeddings, ref_embedding=ref_embedding, **plot_kwargs) + xlabel = f"{encoding_task}" + if ref_embedding is None: + xlabel += " mean" + ylabel = f"{encoding_task} std" + plot.set_axis_labels(xlabel, ylabel) + # Move the title above the plot, to avoid overlapping with the x-axis marginal plot + plot.figure.suptitle(f"{xlabel} w.r.t. std and {hue_attr}", y=1.02) + + filename = ref_embedding if ref_embedding else "mean" + plt.savefig(output_dir / f"{filename}.svg", bbox_inches="tight") + plt.close() # Close the figure to avoid contamination between plots if __name__ == "__main__":