Skip to content

Commit

Permalink
owtsne: Ensure tsne preprocessors are applied
Browse files Browse the repository at this point in the history
  • Loading branch information
pavlin-policar committed Sep 22, 2023
1 parent 360d6e4 commit ad26a7d
Showing 1 changed file with 54 additions and 8 deletions.
62 changes: 54 additions & 8 deletions Orange/widgets/unsupervised/owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Task(namespace):
data = None # type: Optional[Table]
distance_matrix = None # type: Optional[DistMatrix]

preprocessed_data = None # type: Optional[Table]

normalize = None # type: Optional[bool]
normalized_data = None # type: Optional[Table]

Expand Down Expand Up @@ -99,6 +101,10 @@ def error(msg):
return self


def apply_tsne_preprocessing(tsne, data):
return tsne.preprocess(data)


def data_normalization(data):
normalization = preprocess.Normalize()
return normalization(data)
Expand Down Expand Up @@ -132,6 +138,13 @@ def prepare_tsne_obj(n_samples: int, initialization_method: str,


class TSNERunner:
@staticmethod
def compute_tsne_preprocessing(task: Task, state: TaskState, **_) -> None:
state.set_status("Preprocessing data...")
task.preprocessed_data = apply_tsne_preprocessing(task.tsne, task.effective_data)
task.effective_data = task.preprocessed_data
state.set_partial_result(("preprocessed_data", task))

@staticmethod
def compute_normalization(task: Task, state: TaskState, **_) -> None:
state.set_status("Normalizing data...")
Expand Down Expand Up @@ -248,7 +261,7 @@ def run(cls, task: Task, state: TaskState) -> Task:
task.validate()

# Assign weights to each job indicating how much time will be spent on each
weights = {"normalization": 1, "pca": 1, "init": 1, "aff": 25, "tsne": 50}
weights = {"preprocessing": 1, "normalization": 1, "pca": 1, "init": 1, "aff": 25, "tsne": 50}
total_weight = sum(weights.values())

# Prepare the tsne object and add it to the spec
Expand All @@ -271,6 +284,8 @@ def run(cls, task: Task, state: TaskState) -> Task:
# Add the tasks that still need to be run to the job queue
if task.distance_metric != "precomputed":
task.effective_data = task.data
if task.preprocessed_data is None:
job_queue.append((cls.compute_tsne_preprocessing, weights["preprocessing"]))

if task.normalize and task.normalized_data is None:
job_queue.append((cls.compute_normalization, weights["normalization"]))
Expand All @@ -293,6 +308,8 @@ def run(cls, task: Task, state: TaskState) -> Task:
# Ensure the effective data is set to the appropriate, potentially
# precomputed matrix
task.effective_data = task.data
if task.preprocessed_data is not None:
task.effective_data = task.preprocessed_data
if task.normalize and task.normalized_data is not None:
task.effective_data = task.normalized_data
if task.use_pca_preprocessing and task.pca_projection is not None:
Expand Down Expand Up @@ -330,10 +347,12 @@ def update_coordinates(self):

class invalidated:
# pylint: disable=invalid-name
normalized_data = pca_projection = initialization = affinities = tsne_embedding = False
preprocessed_data = normalized_data = pca_projection = initialization = \
affinities = tsne_embedding = False

def __set__(self, instance, value):
# `self._invalidate = True` should invalidate everything
self.preprocessed_data = value
self.normalized_data = value
self.pca_projection = value
self.initialization = value
Expand All @@ -343,15 +362,15 @@ def __set__(self, instance, value):
def __bool__(self):
# If any of the values are invalidated, this should return true
return (
self.normalized_data or self.pca_projection or self.initialization or
self.affinities or self.tsne_embedding
self.preprocessed_data or self.normalized_data or self.pca_projection or
self.initialization or self.affinities or self.tsne_embedding
)

def __str__(self):
return "%s(%s)" % (self.__class__.__name__, ", ".join(
"=".join([k, str(getattr(self, k))])
for k in ["normalized_data", "pca_projection", "initialization",
"affinities", "tsne_embedding"]
for k in ["preprocessed_data", "normalized_data", "pca_projection",
"initialization", "affinities", "tsne_embedding"]
))


Expand Down Expand Up @@ -418,6 +437,7 @@ def __init__(self):
self.signal_data = None # type: Optional[Table]

# Intermediate results
self.preprocessed_data = None # type: Optional[Table]
self.normalized_data = None # type: Optional[Table]
self.pca_projection = None # type: Optional[Table]
self.initialization = None # type: Optional[np.ndarray]
Expand Down Expand Up @@ -538,6 +558,10 @@ def _multiscale_changed(self):
self._invalidate_affinities()

# Invalidation cascade
def _invalidate_preprocessed_data(self):
self._invalidated.preprocessed_data = True
self._invalidate_normalized_data()

def _invalidate_normalized_data(self):
self._invalidated.normalized_data = True
self._invalidate_pca_projection()
Expand Down Expand Up @@ -823,7 +847,7 @@ def enable_controls(self):
)
form.labelForField(self.distance_metric_combo).setDisabled(True)

# PCA doesn't support normalization on sparse data, as this would
# Normalization isn't supported on sparse data, as this would
# require centering and normalizing the matrix
if not has_distance_matrix and has_data and self.data.is_sparse():
self.normalize_cbx.setDisabled(True)
Expand All @@ -840,6 +864,8 @@ def enable_controls(self):

def run(self):
# Reset invalidated values as indicated by the flags
if self._invalidated.preprocessed_data:
self.preprocessed_data = None
if self._invalidated.normalized_data:
self.normalized_data = None
if self._invalidated.pca_projection:
Expand Down Expand Up @@ -875,6 +901,8 @@ def run(self):
task = Task(
data=self.data,
distance_matrix=self.distance_matrix,
# Preprocessed data
preprocessed_data=self.preprocessed_data,
# Normalization
normalize=self.normalize,
normalized_data=self.normalized_data,
Expand All @@ -896,6 +924,12 @@ def run(self):
)
return self.start(TSNERunner.run, task)

def __ensure_task_same_for_preprocessing(self, task: Task):
if task.distance_metric != "precomputed":
assert task.data is self.data
assert isinstance(task.preprocessed_data, Table) and \
len(task.preprocessed_data) == len(self.data)

def __ensure_task_same_for_normalization(self, task: Task):
assert task.normalize == self.normalize
if task.normalize and task.distance_metric != "precomputed":
Expand Down Expand Up @@ -946,24 +980,32 @@ def on_partial_result(self, value):
# type: (Tuple[str, Task]) -> None
which, task = value

if which == "normalized_data":
if which == "preprocessed_data":
self.__ensure_task_same_for_preprocessing(task)
self.preprocessed_data = task.preprocessed_data
elif which == "normalized_data":
self.__ensure_task_same_for_preprocessing(task)
self.__ensure_task_same_for_normalization(task)
self.normalized_data = task.normalized_data
elif which == "pca_projection":
self.__ensure_task_same_for_preprocessing(task)
self.__ensure_task_same_for_normalization(task)
self.__ensure_task_same_for_pca(task)
self.pca_projection = task.pca_projection
elif which == "initialization":
self.__ensure_task_same_for_preprocessing(task)
self.__ensure_task_same_for_normalization(task)
self.__ensure_task_same_for_pca(task)
self.__ensure_task_same_for_initialization(task)
self.initialization = task.initialization
elif which == "affinities":
self.__ensure_task_same_for_preprocessing(task)
self.__ensure_task_same_for_normalization(task)
self.__ensure_task_same_for_pca(task)
self.__ensure_task_same_for_affinities(task)
self.affinities = task.affinities
elif which == "tsne_embedding":
self.__ensure_task_same_for_preprocessing(task)
self.__ensure_task_same_for_normalization(task)
self.__ensure_task_same_for_pca(task)
self.__ensure_task_same_for_initialization(task)
Expand All @@ -990,6 +1032,9 @@ def on_done(self, task):
self.run_button.setText("Start")
# NOTE: All of these have already been set by on_partial_result,
# we double-check that they are aliases
if task.preprocessed_data is not None:
self.__ensure_task_same_for_preprocessing(task)
assert task.preprocessed_data is self.preprocessed_data
if task.normalized_data is not None:
self.__ensure_task_same_for_normalization(task)
assert task.normalized_data is self.normalized_data
Expand All @@ -1015,6 +1060,7 @@ def clear(self):
"""Clear widget state. Note that this doesn't clear the data."""
super().clear()
self.cancel()
self.preprocessed_data = None
self.normalized_data = None
self.pca_projection = None
self.initialization = None
Expand Down

0 comments on commit ad26a7d

Please sign in to comment.