From e3fd1efea8ac1d74c503d680f420104e32efa5ac Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Fri, 5 Apr 2024 13:34:20 -0400 Subject: [PATCH] refactor Visualizer to internally use ClassConfig --- .../dataset/visualizer/visualizer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py index 70c4adf6a..423a6d68a 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py @@ -10,13 +10,13 @@ import matplotlib.pyplot as plt from rastervision.pipeline.file_system import make_dir +from rastervision.core.data import ClassConfig from rastervision.pytorch_learner.utils import ( deserialize_albumentation_transform, validate_albumentation_transform, MinMaxNormalize) from rastervision.pytorch_learner.learner_config import ( RGBTuple, ChannelInds, - ensure_class_colors, validate_channel_display_groups, get_default_channel_display_groups, ) @@ -60,14 +60,21 @@ def __init__(self, title is a string that will be used as the title of the subplot for that group. """ - self.class_names = class_names - self.class_colors = ensure_class_colors(self.class_names, class_colors) + self.class_config = ClassConfig(names=class_names, colors=class_colors) if transform is None: transform = A.to_dict(MinMaxNormalize()) self.transform = validate_albumentation_transform(transform) self._channel_display_groups = validate_channel_display_groups( channel_display_groups) + @property + def class_names(self): + return self.class_config.names + + @property + def class_colors(self): + return self.class_config.colors + @abstractmethod def plot_xyz(self, axs,