From 4a2f2a3b79e1abd4a5d7ee0bf58dff88cdd4ebd4 Mon Sep 17 00:00:00 2001 From: NicoNeureiter Date: Thu, 3 Nov 2022 14:52:05 +0100 Subject: [PATCH] added feature_plot to the plotting options --- sbayes/config/config.py | 3 +- sbayes/config/default_config_plot.json | 30 ++- sbayes/plot.py | 262 +++++++++++++++++-------- sbayes/results.py | 20 +- 4 files changed, 218 insertions(+), 97 deletions(-) diff --git a/sbayes/config/config.py b/sbayes/config/config.py index 51d6c50e..42efe5dc 100644 --- a/sbayes/config/config.py +++ b/sbayes/config/config.py @@ -395,7 +395,8 @@ def from_config_file( with open(path, "r") as f: path_str = str(path).lower() if path_str.endswith(".yaml"): - config_dict = yaml.load(f, Loader=yaml.Loader) + yaml_loader = yaml.YAML(typ='safe') + config_dict = yaml_loader.load(f) else: config_dict = json.load(f) diff --git a/sbayes/config/default_config_plot.json b/sbayes/config/default_config_plot.json index 606cb683..387a5ced 100644 --- a/sbayes/config/default_config_plot.json +++ b/sbayes/config/default_config_plot.json @@ -127,8 +127,8 @@ "labels": { "add": true, "names": [ - "U", "C", + "U", "I" ], "font_size": 6 @@ -178,6 +178,34 @@ "n_columns": 5 } }, + "feature_plot": { + "legend": { + "labels": { + "add": true, + "font_size": 8, + "names": [ + "contact", + "universal", + "inheritance" + ] + }, + "title": { + "add": true, + "font_size": 6, + "position": [ + 0, + 1 + ] + } + }, + "output": { + "width_subplot": 3, + "height_subplot": 3, + "format": "pdf", + "resolution": 300, + "n_columns": 5 + } + }, "dic_plot": { "content": { "model": [], diff --git a/sbayes/plot.py b/sbayes/plot.py index 7ba29b4e..f9f126ca 100644 --- a/sbayes/plot.py +++ b/sbayes/plot.py @@ -1,20 +1,22 @@ from __future__ import annotations + import json import logging import math import os +from argparse import Namespace +from enum import Enum +from functools import lru_cache from itertools import compress from pathlib import Path -from os.path import basename import typing as typ -import pandas as pd - try: import importlib.resources as pkg_resources # PYTHON >= 3.7 except ImportError: import importlib_resources as pkg_resources # PYTHON < 3.7 +import pandas as pd import geopandas as gpd import matplotlib as mpl import matplotlib.pyplot as plt @@ -53,12 +55,17 @@ class Plot: + # Attributes config: dict[str, ...] config_file: Path base_directory: Path all_cluster_paths: list[Path] all_stats_paths: list[Path] + # Constant class attributes + pref_color_map = sns.cubehelix_palette(light=1, start=.5, rot=-.75, as_cmap=True) + # pref_color_map = sns.color_palette("rocket_r", as_cmap=True) + def __init__(self): # Config variables @@ -947,6 +954,7 @@ def posterior_map(self, results: Results, file_name='mst_posterior'): # Probability simplex, grid plot #################################### @staticmethod + @lru_cache(maxsize=128) def get_corner_points(n, offset=0.5 * np.pi): """Generate corner points of a equal sided ´n-eck´.""" angles = np.linspace(0, 2 * np.pi, n, endpoint=False) + offset @@ -990,8 +998,8 @@ def fill_outside(polygon, color, ax=None): top_y.append(polygon[i, 1]) ymin, ymax = ax.get_ylim() - plt.fill_between(bot_x, ymin, bot_y, color=color) - plt.fill_between(top_x, ymax, top_y, color=color) + ax.fill_between(bot_x, ymin, bot_y, color=color) + ax.fill_between(top_x, ymax, top_y, color=color) # Transform weights into needed format # def transform_weights(self, feature, b_in): @@ -1071,26 +1079,37 @@ def fill_outside(polygon, color, ax=None): # ordering = sorted(sort_by, key=sort_by.get, reverse=True) # return ordering - # Probability simplex (for one feature) + @staticmethod - def plot_weight(samples, feature, cfg_legend, ax=None, mean_weights=False, plot_samples=False): + def plot_weight( + samples: NDArray[float], + feature: str, + cfg_legend: dict, + ax: plt.Axes | None = None, + mean_weights: bool = False, + plot_samples: bool = False, + lw: float | None = None, + ): """Plot a set of weight vectors in a 2D representation of the probability simplex. Args: - samples (np.array): Sampled weight vectors to plot. - feature (str): Name of the feature for which weights are being plotted - ax (plt.Axis): The pyplot axis. - cfg_legend(dict): legend info from the config plot file - mean_weights (bool): Plot the mean of the weights? - plot_samples (bool): Add a scatter plot overlay of the actual samples? + samples: Sampled weight vectors to plot. + feature: Name of the feature for which weights are being plotted + ax: The pyplot axis. + cfg_legend: legend info from the config plot file + mean_weights: Plot the mean of the weights? + plot_samples: Add a scatter plot overlay of the actual samples? + lw: Line width of the triangular border delineating the probability simplex. """ if ax is None: ax = plt.gca() + n_samples, n_weights = samples.shape # Compute corners corners = Plot.get_corner_points(n_weights) + # Bounding box xmin, ymin = np.min(corners, axis=0) xmax, ymax = np.max(corners, axis=0) @@ -1098,48 +1117,47 @@ def plot_weight(samples, feature, cfg_legend, ax=None, mean_weights=False, plot_ # Project the samples samples_projected = samples.dot(corners) - # color map - cmap = sns.cubehelix_palette(light=1, start=.5, rot=-.75, as_cmap=True) - # Density and scatter plot title = cfg_legend['title'] if title['add']: - plt.text(title['position'][0], - title['position'][1], - str(feature), fontdict={'fontweight': 'bold', - 'fontsize': title['font_size']}, transform=ax.transAxes) + # ax.set_title(str(feature), pad=15, + # fontdict={'fontweight': 'bold', 'fontsize': title['font_size']}) + ax.text( + title['position'][0], title['position'][1], str(feature), + fontdict={'fontweight': 'bold', 'fontsize': title['font_size']}, + transform=ax.transAxes + ) x = samples_projected.T[0] y = samples_projected.T[1] - sns.kdeplot(x=x, y=y, shade=True, cut=30, n_levels=100, - clip=([xmin, xmax], [ymin, ymax]), cmap=cmap) + sns.kdeplot(x=x, y=y, shade=True, cut=30, n_levels=20, + clip=([xmin, xmax], [ymin, ymax]), cmap=Plot.pref_color_map, ax=ax) if plot_samples: - plt.scatter(x, y, color='k', lw=0, s=1, alpha=0.2) + ax.scatter(x, y, color='k', lw=0, s=1, alpha=0.2) # Draw simplex and crop outside - plt.fill(*corners.T, edgecolor='k', fill=False) + ax.fill(*corners.T, edgecolor='k', fill=False, lw=lw) Plot.fill_outside(corners, color='w', ax=ax) if mean_weights: mean_projected = np.mean(samples, axis=0).dot(corners) - plt.scatter(*mean_projected.T, color="#ed1696", lw=0, s=50, marker="o") + ax.scatter(*mean_projected.T, color="#ed1696", lw=0, s=50, marker="o") labels = cfg_legend['labels'] if labels['add']: for xy, label in zip(corners, labels['names']): - xy *= 1.08 # Stretch, s.t. labels don't overlap with corners - plt.text(*xy, label, ha='center', va='center', fontdict={'fontsize': labels['font_size']}) + xy = xy*1.15 - 0.05 # Stretch, s.t. labels don't overlap with corners + ax.text(*xy, label, ha='center', va='center', fontdict={'fontsize': labels['font_size']}) - plt.xlim(xmin - 0.1, xmax + 0.1) - plt.ylim(ymin - 0.1, ymax + 0.1) - plt.axis('off') - plt.plot() + ax.set_xlim([xmin - 0.1, xmax + 0.1]) + ax.set_ylim([ymin - 0.1, ymax + 0.1]) + ax.axis('off') @staticmethod - def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_samples=False): + def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_samples=False, color=None): """Plot a set of weight vectors in a 2D representation of the probability simplex. Args: @@ -1152,16 +1170,19 @@ def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_sam """ if ax is None: ax = plt.gca() + if color is None: + color = 'g' + n_samples, n_p = samples.shape # color map - cmap = sns.cubehelix_palette(light=1, start=.5, rot=-.75, as_cmap=True) title = cfg_legend['title'] labels = cfg_legend['labels'] if n_p == 2: x = samples.T[1] - sns.distplot(x, rug=True, hist=False, kde_kws={"shade": True, "lw": 0, "clip": (0, 1)}, color="g", - rug_kws={"color": "k", "alpha": 0.01, "height": 0.03}) + sns.kdeplot(x, color=color, ax=ax, fill=True, lw=1, clip=(0, 1)) + # if cfg_legend['rug']: # TODO Make the rug plot option? + # sns.rugplot(x, color="k", alpha=0.02, height=-0.03, ax=ax, clip_on=False) ax.axes.get_yaxis().set_visible(False) @@ -1171,18 +1192,18 @@ def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_sam x = 0.05 if x == 1: x = 0.95 - plt.text(x, -0.05, label, ha='center', va='top', + ax.text(x, -0.05, label, ha='center', va='top', fontdict={'fontsize': labels['font_size']}, transform=ax.transAxes) if title['add']: - plt.text(title['position'][0], title['position'][1], + ax.text(title['position'][0], title['position'][1], str(feature), fontsize=title['font_size'], fontweight='bold', transform=ax.transAxes) - plt.plot([0, 1], [0, 0], c="k", lw=0.5) + ax.plot([0, 1], [0, 0], lw=1, color=color, clip_on=False) - ax.axes.set_ylim([-0.2, 5]) - ax.axes.set_xlim([0, 1]) + ax.set_ylim([0, None]) + ax.set_xlim([-0.01, 1.01]) - plt.axis('off') + ax.axis('off') elif n_p > 2: # Compute corners @@ -1197,20 +1218,20 @@ def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_sam # Density and scatter plot if title['add']: - plt.text(title['position'][0], title['position'][1], + ax.text(title['position'][0], title['position'][1], str(feature), fontsize=title['font_size'], fontweight='bold', transform=ax.transAxes) x = samples_projected.T[0] y = samples_projected.T[1] - sns.kdeplot(x=x, y=y, shade=True, thresh=0, cut=30, n_levels=100, - clip=([xmin, xmax], [ymin, ymax]), cmap=cmap) + sns.kdeplot(x=x, y=y, shade=True, thresh=0, cut=30, n_levels=100, ax=ax, + clip=([xmin, xmax], [ymin, ymax]), cmap=Plot.pref_color_map) if plot_samples: - plt.scatter(x, y, color='k', lw=0, s=1, alpha=0.05) + ax.scatter(x, y, color='k', lw=0, s=1, alpha=0.05) # Draw simplex and crop outside - plt.fill(*corners.T, edgecolor='k', fill=False) + ax.fill(*corners.T, edgecolor='k', fill=False) Plot.fill_outside(corners, color='w', ax=ax) if labels['add']: @@ -1223,36 +1244,21 @@ def plot_preference(samples, feature, cfg_legend, label_names, ax=None, plot_sam break_label = min(white_or_dash, key=lambda x: abs(x - mid_point)) label = label[:break_label] + "\n" + label[break_label:] - plt.text(*xy, label, ha='center', va='center', + ax.text(*xy, label, ha='center', va='center', fontdict={'fontsize': labels['font_size']}) - plt.xlim(xmin - 0.1, xmax + 0.1) - plt.ylim(ymin - 0.1, ymax + 0.1) - plt.axis('off') - - plt.plot() - - @staticmethod - def filter_weights( - weights: typ.Dict[str, NDArray[float]], - features_subset: typ.Optional[list] = None, - ): - """Return the subset of weights specificied by the features in `features_subset`. - If no features_subset is specified, return all weights - """ - if not features_subset: - return weights - else: - return {f: weights[f] for f in features_subset} + ax.set_xlim([xmin - 0.1, xmax + 0.1]) + ax.set_ylim([ymin - 0.1, ymax + 0.1]) + ax.axis('off') def plot_weights(self, results: Results, file_name: PathLike): print('Plotting weights...') cfg_weights = self.config['weight_plot'] - weights = self.filter_weights( - weights=results.weights, - features_subset=[results.feature_names[i-1] for i in cfg_weights['content']['features']], - ) + feature_subset = cfg_weights['content']['features'] + weights = results.weights + if feature_subset: + weights = {f: weights[f] for f in feature_subset} features = weights.keys() n_plots = len(features) @@ -1287,8 +1293,70 @@ def plot_weights(self, results: Results, file_name: PathLike): dpi=resolution, format=file_format) plt.close(fig) + def plot_weights_and_prefs( + self, + results: Results, + feature_name: str, + # file_name: PathLike, + ): + n_components = 1 + results.n_confounders + max_groups = max(len(groups) for groups in results.groups_by_confounders.values()) + fig, axes = plt.subplots(nrows=n_components, ncols=2 + max_groups, + figsize=(4 + max_groups, 2 + n_components), + gridspec_kw={'width_ratios': [1.8, .8] + [1]*max_groups}) + + plt.tight_layout() + axes[0, 0].text( + 0, 0, f"feature {feature_name}", + fontsize=10, fontweight='bold', ha='center', + bbox=dict(boxstyle='round, pad=1.0, rounding_size=0.5', facecolor='#eeeeee', lw=0.8, edgecolor='k') + ) + axes[0, 0].set_xlim([-1, 1]) + axes[0, 0].set_ylim([-2, 3]) + + for ax in axes.flatten(): + ax.axis('off') + + self.plot_weight( + samples=results.weights[feature_name], + # feature='weights', + feature='', + cfg_legend=self.config['feature_plot']['legend'], + mean_weights=True, + ax=axes[n_components // 2, 0], + lw=.8, + ) + + preferences = { + **results.confounding_effects, + 'cluster': results.areal_effect, + } + + for i, (component, prefs_by_group) in enumerate(preferences.items()): + axes[i, 1].text(0, 0, + component.replace('cluster', 'contact').replace('family', 'inheritance'), + fontsize=8) + axes[i, 1].set_ylim([-2, 3]) + for j, (group, pref_by_feat) in enumerate(prefs_by_group.items()): + axes[i, j + 2].get_shared_y_axes().join(axes[i, 2], axes[i, j + 2]) + + if component == "cluster": + group = "Area " + group[1:] + self.plot_preference( + pref_by_feat[feature_name], + feature='' if group == '' else group, + label_names=results.get_states_for_feature_name(feature_name), + cfg_legend=self.config['feature_plot']['legend'], + ax=axes[i, j + 2], + color='#005570', + ) + + + + plt.show() + # This is not changed yet - def plot_preferences(self, results: Results, file_name): + def plot_preferences(self, results: Results, file_name: str): """Creates preference plots for universal, clusters and families Args: @@ -1351,11 +1419,11 @@ def plot_preferences(self, results: Results, file_name): bbox_inches='tight', dpi=resolution, format=file_format) plt.close(fig) - def plot_dic(self, models, file_name): + def plot_dic(self, models: dict, file_name: str): """This function plots the dics. What did you think? Args: - file_name (str): name of the output file - models(dict): A dict of different models for which the DIC is evaluated + models: A dict of different models for which the DIC is evaluated + file_name: name of the output file """ print('Plotting DIC...') cfg_dic = self.config['dic_plot'] @@ -1716,47 +1784,66 @@ def plot_pies(self, results: Results, file_name: PathLike): plt.close(fig) -ALL_PLOT_TYPES = ['map', 'weights_plot', 'preference_plot', 'pie_plot'] +class PlotType(Enum): + map = 'map' + weights_plot = 'weights_plot' + preference_plot = 'preference_plot' + pie_plot = 'pie_plot' + feature_plot = 'feature_plot' + dic_plot = 'dic_plot' + @classmethod + def values(cls) -> list[str]: + return [str(e.value) for e in cls] -def main(config, plot_types=None): + +def main(config, plot_types: list[PlotType] = None, args: Namespace = None): # TODO adapt paths according to experiment_name (if provided) # If no plot type is specified, plot everything in the config file + if plot_types is None: - plot_types = ALL_PLOT_TYPES + plot_types = list(PlotType) plot = Plot() plot.load_config(config_file=config) plot.read_data() - def should_be_plotted(plot_type): + def should_be_plotted(plot_type: PlotType): """A plot type should only be generated if it 1) is specified in the config file and 2) is in the requested list of plot types.""" - return (plot_type in plot.config) and (plot_type in plot_types) + return (plot_type.value in plot.config) and (plot_type in plot_types) for m, results in plot.iterate_over_models(): print('Plotting model', m) # Plot the reconstructed clusters on a map - if should_be_plotted('map'): + if should_be_plotted(PlotType.map): plot_map(plot, results, m) # Plot the reconstructed mixture weights in simplex plots - if should_be_plotted('weights_plot'): + if should_be_plotted(PlotType.weights_plot): plot.plot_weights(results, file_name='weights_grid_' + m) # Plot the reconstructed probability vectors in simplex plots - if should_be_plotted('preference_plot'): + if should_be_plotted(PlotType.preference_plot): plot.plot_preferences(results, file_name=f'prob_grid_{m}') # Plot the reconstructed clusters in pie-charts # (one per language, showing how likely the language is to be in each cluster) - if should_be_plotted('pie_plot'): + if should_be_plotted(PlotType.pie_plot): plot.plot_pies(results, file_name= 'plot_pies_' + m) + # if should_be_plotted(PlotType.feature_plot): + if should_be_plotted(PlotType.feature_plot): + if args.feature_name is None: + logging.warning("Skipping 'feature_plot', since not feature_name was provided.") + # TODO: If feature_name is None, iterate over all features and save to file. + else: + plot.plot_weights_and_prefs(results, args.feature_name) + # Plot DIC over all models - if should_be_plotted('dic_plot'): + if should_be_plotted(PlotType.dic_plot): plot.plot_dic(plot.results, file_name='dic') @@ -1785,12 +1872,13 @@ def plot_map(plot: Plot, results: Results, m: str): parser = argparse.ArgumentParser(description='Plot the results of a sBayes run.') parser.add_argument('config', type=Path, help='The JSON configuration file') parser.add_argument('type', nargs='?', type=str, help='The type of plot to generate') + parser.add_argument('feature_name', nargs='?', type=str, help='The feature to show in a `feature_plot`') args = parser.parse_args() plot_types = None if args.type is not None: - if args.type not in ALL_PLOT_TYPES: - raise ValueError('Unknown plot type: ' + args.type) - plot_types = [args.type] + if args.type not in PlotType.values(): + raise ValueError(f"Unknown plot type: '{args.type}'. Choose from {PlotType.values()}.") + plot_types = [PlotType(args.type)] - main(args.config, plot_types=plot_types) + main(args.config, plot_types=plot_types, args=args) diff --git a/sbayes/results.py b/sbayes/results.py index 56f9ca80..a6b3284f 100644 --- a/sbayes/results.py +++ b/sbayes/results.py @@ -21,6 +21,7 @@ class Results: shape: (n_clusters, n_samples, n_sites) parameters (pd.DataFrame): Data-frame containing sample information about parameters and likelihood, prior and posterior probabilities. + groups_by_confounders (dict[str, list[str]): A list of groups for each confounder. """ @staticmethod @@ -82,30 +83,30 @@ def __init__( self.prior_single_clusters = Results.read_dictionary(self.parameters, "prior_") @property - def n_features(self): + def n_features(self) -> int: return len(self.feature_names) @property - def n_clusters(self): + def n_clusters(self) -> int: return self.clusters.shape[0] @property - def n_samples(self): + def n_samples(self) -> int: return self.clusters.shape[1] @property - def n_objects(self): + def n_objects(self) -> int: return self.clusters.shape[2] @property - def confounders(self): + def confounders(self) -> list[str]: return list(self.groups_by_confounders.keys()) @property - def n_confounders(self): + def n_confounders(self) -> int: return len(self.groups_by_confounders) - def __getitem__(self, item): + def __getitem__(self, item: str): if item in [ "feature_names", "sample_id", @@ -127,7 +128,10 @@ def __getitem__(self, item): @classmethod def from_csv_files( - cls, clusters_path: PathLike, parameters_path: PathLike, burn_in: float = 0.1 + cls: type[TResults], + clusters_path: PathLike, + parameters_path: PathLike, + burn_in: float = 0.1 ) -> TResults: clusters = cls.read_clusters(clusters_path) parameters = cls.read_stats(parameters_path)