Skip to content

Commit

Permalink
Improve plots of unimodal param variability to produce multiple w.r.t. each individual embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Nov 21, 2023
1 parent fd1ffdf commit 4a082ef
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
2 changes: 1 addition & 1 deletion didactic/scripts/experiments.sh
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ for task in scratch finetune xtab-finetune; do
for tau_mode in learn_sigm learn_fn; do
job_path=$task/contrastive=$contrastive/$data/$time_series_tokenizer/$target/ordinal_mode=$ordinal_mode,distribution=$distribution,tau_mode=$tau_mode
echo "Plotting variability of predicted continuum w.r.t. position along the continuum between $job_path models" >>$HOME/data/didactic/results/plot_models_variability.log 2>&1
python ~/remote/didactic/didactic/scripts/plot_models_variability.py $(find ~/data/didactic/results/cardiac-multimodal-representation/$job_path -name *.ckpt | sort | tr "\n" " ") --data_roots $HOME/dataset/cardinal/v1.0/data --views A4C A2C --hue_attr=ht_severity --output_file=$HOME/data/didactic/results/cardiac-multimodal-representation/$job_path/unimodal_param_variability.svg >>$HOME/data/didactic/results/plot_models_variability.log 2>&1
python ~/remote/didactic/didactic/scripts/plot_models_variability.py $(find ~/data/didactic/results/cardiac-multimodal-representation/$job_path -name *.ckpt | sort | tr "\n" " ") --data_roots $HOME/dataset/cardinal/v1.0/data --views A4C A2C --hue_attr=ht_severity --output_dir=$HOME/data/didactic/results/cardiac-multimodal-representation/$job_path/unimodal_param_variability >>$HOME/data/didactic/results/plot_models_variability.log 2>&1
done
done
# end w/ ordinal constraint
Expand Down
58 changes: 38 additions & 20 deletions didactic/scripts/plot_models_variability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@


def plot_embeddings_variability(
embeddings: Dict[str, np.ndarray], hue: Sequence = None, hue_name: str = None, hue_order: Sequence[str] = None
embeddings: Dict[str, np.ndarray],
ref_embedding: str = None,
hue: Sequence = None,
hue_name: str = None,
hue_order: Sequence[str] = None,
) -> JointGrid:
"""Generates a scatter plot of the variability w.r.t. the position on the continuum, for each item in the embedding.
Expand All @@ -16,6 +20,8 @@ def plot_embeddings_variability(
Args:
embeddings: Mapping between the name of each embedding and the corresponding (N, [1]) embedding vector.
ref_embedding: Key to a "reference" embedding (in `embeddings`) to use as the reference position on the
continuum. If not provided, the mean of the embeddings will be used.
hue: (N), Optional sequence of values to use as the hue for the plot.
hue_name: Name to use for the hue in the plot's legend.
hue_order: Sequence of hue values to use for ordering the hue values in the plot's legend.
Expand All @@ -30,21 +36,24 @@ def plot_embeddings_variability(
embeddings_arr = np.stack([emb.squeeze() for emb in embeddings.values()], axis=1) # (N, M)

# Compute the statistics over each set of embedding vectors
mean = np.mean(embeddings_arr, axis=1) # (N)
if ref_embedding:
val = embeddings[ref_embedding].squeeze() # (N)
else:
val = np.mean(embeddings_arr, axis=1) # (N)
std = np.std(embeddings_arr, axis=1) # (N)

# Scatter plot of the position along the continuum (i.e. the reference embedding if provided, otherwise the mean) w.r.t. variability (i.e. std) and user-defined hue
data = {"mean": mean, "std": std}
data = {"val": val, "std": std}
grid_kwargs = {}
if hue is not None:
data[hue_name] = hue
grid_kwargs.update(hue=hue_name, hue_order=hue_order)
data = pd.DataFrame(data=data)

with sns.axes_style("darkgrid"):
grid = sns.JointGrid(data=data, x="mean", y="std", **grid_kwargs)
grid = sns.JointGrid(data=data, x="val", y="std", **grid_kwargs)
grid.plot_joint(sns.scatterplot, size=std)
sns.histplot(data=data, x="mean", ax=grid.ax_marg_x, **grid_kwargs, legend=False)
sns.histplot(data=data, x="val", ax=grid.ax_marg_x, **grid_kwargs, legend=False)
sns.kdeplot(data=data, y="std", ax=grid.ax_marg_y, **grid_kwargs, legend=False, clip=(0, 1))

return grid
Expand Down Expand Up @@ -97,18 +106,18 @@ def main():
help="Name of the tabular attribute to use as the hue for the plot",
)
parser.add_argument(
"--output_file",
"--output_dir",
type=Path,
help="Path to the image file where to save the variability plot",
help="Path to the root directory where to save the variability plots",
)
args = parser.parse_args()

if len(args.pretrained_encoder) < 2:
raise ValueError("At least 2 models are required to compute a meaningful variability")

kwargs = vars(args)
pretrained_encoders, mask_tag, encoding_task, hue_attr, output_file = list(
map(kwargs.pop, ["pretrained_encoder", "mask_tag", "encoding_task", "hue_attr", "output_file"])
pretrained_encoders, mask_tag, encoding_task, hue_attr, output_dir = list(
map(kwargs.pop, ["pretrained_encoder", "mask_tag", "encoding_task", "hue_attr", "output_dir"])
)

# Compute the embeddings of the patients for each model
Expand Down Expand Up @@ -136,17 +145,26 @@ def main():
hue_order=TABULAR_CAT_ATTR_LABELS[hue_attr],
)

# Generate the plot of the variability
plot = plot_embeddings_variability(embeddings, **plot_kwargs)
plot.set_axis_labels(f"{encoding_task} mean", f"{encoding_task} std")
# Move the title above the plot, to avoid overlapping with the x-axis marginal plot
plot.figure.suptitle(f"{encoding_task} std w.r.t. mean and {hue_attr}", y=1.02)

# Save the variability plot to disk
output_file.parent.mkdir(parents=True, exist_ok=True)

plt.savefig(output_file, bbox_inches="tight")
plt.close() # Close the figure to avoid contamination between plots
output_dir.mkdir(parents=True, exist_ok=True)
# Generate a plot of the variability w.r.t. the mean embedding as well as each encoder's embedding
for ref_embedding in tqdm(
[None, *list(embeddings.keys())],
desc=f"Generating variability plots for {encoding_task} embeddings",
unit="embedding",
):
# Generate the plot of the variability
plot = plot_embeddings_variability(embeddings, ref_embedding=ref_embedding, **plot_kwargs)
xlabel = f"{encoding_task}"
if ref_embedding is None:
xlabel += " mean"
ylabel = f"{encoding_task} std"
plot.set_axis_labels(xlabel, ylabel)
# Move the title above the plot, to avoid overlapping with the x-axis marginal plot
plot.figure.suptitle(f"{xlabel} w.r.t. std and {hue_attr}", y=1.02)

filename = ref_embedding if ref_embedding else "mean"
plt.savefig(output_dir / f"{filename}.svg", bbox_inches="tight")
plt.close() # Close the figure to avoid contamination between plots


if __name__ == "__main__":
Expand Down

0 comments on commit 4a082ef

Please sign in to comment.