From 1f45245600c960ecc3c4e4e35f9cb52b25aa1bea Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Thu, 3 Oct 2024 14:22:38 -0400 Subject: [PATCH 01/23] randomized svd draft --- python/tskit/trees.py | 152 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7d3bbca7b9..09025de8dc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -40,6 +40,7 @@ from typing import NamedTuple import numpy as np +import scipy.sparse import _tskit import tskit @@ -8592,6 +8593,156 @@ def genetic_relatedness_vector( ) return out + def pca( + self, + n_components: int = 10, + iterated_power: int = 3, + n_oversamples: int = 10, + indices: np.ndarray = None, + centre: bool = True, + windows = None, + random_state: np.random.Generator = None, + ): + """ + Run randomized singular value decomposition (rSVD) to obtain principal components. + API partially adopted from `scikit-learn`: + https://scikit-learn.org/dev/modules/generated/sklearn.decomposition.PCA.html + + :param int n_components: Number of principal components + :param int iterated_power: Number of power iteration of range finder + :param int n_oversamples: Number of additional test vectors + :param np.ndarray indices: Indcies of individuals to perform rSVD + :param bool centre: Centre the genetic relatedness matrix + :param windows: ??? + :param np.random.Generator random_state: Random number generator + """ + + def _rand_pow_range_finder( + operator: Callable, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + ) -> np.ndarray: + """ + Algorithm 9 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0 + test_vectors = rng.normal(size=(operator_dim, num_vectors)) + Q = test_vectors + for i in range(depth): + Q = np.linalg.qr(Q).Q + Q = operator(Q) + Q = np.linalg.qr(Q).Q + return Q[:, :rank] + + def _rand_svd( + operator: Callable, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + ) -> (np.ndarray, np.ndarray, np.ndarray): + """ + Algorithm 8 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0 + Q = _rand_pow_range_finder( + operator, + operator_dim, + num_vectors, + depth, + num_vectors, + rng + ) + C = operator(Q).T + U_hat, D, V = np.linalg.svd(C, full_matrices=False) + U = Q @ U_hat + return U[:,:rank], D[:rank], V[:rank] + + def _genetic_relatedness_vector( + ts: tskit.Treesequence, + arr: np.ndarray, + rows: np.ndarray, + cols: np.ndarray, + centre: bool = False, + windows = None, + ) -> np.ndarray: + """ + Wrapper around `tskit.TreeSequence.genetic_relatedness_vector` to support centering in respect to individuals. + Multiplies an array to the genetic relatedness matrix of :class:`tskit.TreeSequence`. + + :param tskit.TreeSequence ts: A tree sequence. + :param numpy.ndarray arr: The array to multiply. Either a vector or a matrix. + :param numpy.ndarray rows: Index of rows of the genetic relatedness matrix to be selected. + :param numpy.ndarray cols: Index of cols of the genetic relatedness matrix to be selected. The size should match the row length of `arr`. + :param bool centre: Centre the genetic relatedness matrix. Centering happens respect to the `rows` and `cols`. + :param windows: An increasing list of breakpoints between the windows to compute the genetic relatedness matrix in. + :return: An array that is the matrix-array product of the genetic relatedness matrix and the array. + :rtype: `np.ndarray` + """ + + # maps samples to individuals + def sample_individual_sparray(ts: tskit.TreeSequence) -> scipy.sparse.sparray: + samples_individual = ts.nodes_individual[ts.samples()] + return scipy.sparse.csr_array( + ( + np.ones(ts.num_samples), + (np.arange(ts.num_samples), samples_individual) + ), + shape=(ts.num_samples, ts.num_individuals) + ) + + # maps values in idx to num_individuals + def individual_idx_sparray(n: int, idx: np.ndarray) -> scipy.sparse.sparray: + return scipy.sparse.csr_array( + ( + np.ones(idx.size), + (idx, np.arange(idx.size)) + ), + shape=(n, idx.size) + ) + + assert cols.size == arr.shape[0], "Dimension mismatch" + # centering + x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows + x = individual_idx_sparray(ts.num_individuals, cols).dot(x) + x = sample_individual_sparray(ts).dot(x) + x = ts.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False) + x = sample_individual_sparray(ts).T.dot(x) + x = individual_idx_sparray(ts.num_individuals, rows).T.dot(x) + x = x - x.mean(axis=0) if centre else x # centering within index in cols + + return x + + + if indices is None: indices = np.array([i.id for i in self.individuals()]) + if random_state is None: random_state = np.random.default_rng() + + def _G(x): + return _genetic_relatedness_vector( + self.ts, + x, + indices, + indices, + centre, + windows + ) + + U, D, _ = _rand_svd( + operator=_G, + operator_dim=indices.size, + rank=n_components, + depth=iterated_power, + num_vectors=n_components+n_oversamples, + rng=random_state + ) + + return U, D + + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W`` @@ -10171,3 +10322,4 @@ def write_ms( ) else: print(file=output) + From e408ab3cffc8ee4ad9544fb957ea220b5bdac6cd Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 4 Oct 2024 12:46:37 -0400 Subject: [PATCH 02/23] modified api remove scipy --- python/tskit/trees.py | 65 ++++++++++++++----------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 09025de8dc..54f9767620 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8598,11 +8598,12 @@ def pca( n_components: int = 10, iterated_power: int = 3, n_oversamples: int = 10, - indices: np.ndarray = None, + samples: np.ndarray = None, + individuals: np.ndarray = None, centre: bool = True, windows = None, random_state: np.random.Generator = None, - ): + ) -> (np.ndarray, np.ndarray): """ Run randomized singular value decomposition (rSVD) to obtain principal components. API partially adopted from `scikit-learn`: @@ -8611,7 +8612,8 @@ def pca( :param int n_components: Number of principal components :param int iterated_power: Number of power iteration of range finder :param int n_oversamples: Number of additional test vectors - :param np.ndarray indices: Indcies of individuals to perform rSVD + :param np.ndarray samples: Samples to perform PCA + :param np.ndarray individuals: Individuals to perform PCA :param bool centre: Centre the genetic relatedness matrix :param windows: ??? :param np.random.Generator random_state: Random number generator @@ -8663,18 +8665,16 @@ def _rand_svd( return U[:,:rank], D[:rank], V[:rank] def _genetic_relatedness_vector( - ts: tskit.Treesequence, arr: np.ndarray, rows: np.ndarray, cols: np.ndarray, - centre: bool = False, + centre: bool = True, windows = None, ) -> np.ndarray: """ Wrapper around `tskit.TreeSequence.genetic_relatedness_vector` to support centering in respect to individuals. Multiplies an array to the genetic relatedness matrix of :class:`tskit.TreeSequence`. - :param tskit.TreeSequence ts: A tree sequence. :param numpy.ndarray arr: The array to multiply. Either a vector or a matrix. :param numpy.ndarray rows: Index of rows of the genetic relatedness matrix to be selected. :param numpy.ndarray cols: Index of cols of the genetic relatedness matrix to be selected. The size should match the row length of `arr`. @@ -8684,56 +8684,33 @@ def _genetic_relatedness_vector( :rtype: `np.ndarray` """ - # maps samples to individuals - def sample_individual_sparray(ts: tskit.TreeSequence) -> scipy.sparse.sparray: - samples_individual = ts.nodes_individual[ts.samples()] - return scipy.sparse.csr_array( - ( - np.ones(ts.num_samples), - (np.arange(ts.num_samples), samples_individual) - ), - shape=(ts.num_samples, ts.num_individuals) - ) - - # maps values in idx to num_individuals - def individual_idx_sparray(n: int, idx: np.ndarray) -> scipy.sparse.sparray: - return scipy.sparse.csr_array( - ( - np.ones(idx.size), - (idx, np.arange(idx.size)) - ), - shape=(n, idx.size) - ) - assert cols.size == arr.shape[0], "Dimension mismatch" - # centering + ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes]) + samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows - x = individual_idx_sparray(ts.num_individuals, cols).dot(x) - x = sample_individual_sparray(ts).dot(x) - x = ts.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False) - x = sample_individual_sparray(ts).T.dot(x) - x = individual_idx_sparray(ts.num_individuals, rows).T.dot(x) + x = self.genetic_relatedness_vector(W=x[sample_individuals], windows=windows, mode="branch", centre=False, nodes=samples) + bincount_fn = lambda w: np.bincount(sample_individuals, w) + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) # I think it should be axis=1, but axis=0 gives the correct values why? x = x - x.mean(axis=0) if centre else x # centering within index in cols return x - if indices is None: indices = np.array([i.id for i in self.individuals()]) if random_state is None: random_state = np.random.default_rng() + if samples is None and individuals is None: samples = self.samples() - def _G(x): - return _genetic_relatedness_vector( - self.ts, - x, - indices, - indices, - centre, - windows - ) + if samples is not None and individuals is not None: + raise ValueError("samples and individuals cannot be used at the same time") + elif samples is not None: + _G = lambda x: self.genetic_relatedness_vector(x, windows=windows, mode="branch", centre=centre, nodes=samples) + dim = samples.size + elif individuals is not None: + _G = lambda x: _genetic_relatedness_vector(x, individuals, individuals, centre=centre, windows=windows) + dim = individuals.size U, D, _ = _rand_svd( operator=_G, - operator_dim=indices.size, + operator_dim=dim, rank=n_components, depth=iterated_power, num_vectors=n_components+n_oversamples, From a1761322721b411c0da15e19abe8b8155f40086b Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Mon, 7 Oct 2024 10:55:41 -0400 Subject: [PATCH 03/23] remove scipy --- python/tskit/trees.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 54f9767620..2269697658 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -40,7 +40,6 @@ from typing import NamedTuple import numpy as np -import scipy.sparse import _tskit import tskit From 8c662c881c2a8a2c9c1463552c1304d7ddbce1f0 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Mon, 7 Oct 2024 11:02:22 -0400 Subject: [PATCH 04/23] correct docstring and comments --- python/tskit/trees.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2269697658..530ac9f071 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8600,7 +8600,7 @@ def pca( samples: np.ndarray = None, individuals: np.ndarray = None, centre: bool = True, - windows = None, + windows: list = None, random_state: np.random.Generator = None, ) -> (np.ndarray, np.ndarray): """ @@ -8608,14 +8608,14 @@ def pca( API partially adopted from `scikit-learn`: https://scikit-learn.org/dev/modules/generated/sklearn.decomposition.PCA.html - :param int n_components: Number of principal components - :param int iterated_power: Number of power iteration of range finder - :param int n_oversamples: Number of additional test vectors - :param np.ndarray samples: Samples to perform PCA - :param np.ndarray individuals: Individuals to perform PCA - :param bool centre: Centre the genetic relatedness matrix - :param windows: ??? - :param np.random.Generator random_state: Random number generator + :param int n_components: Number of principal components. + :param int iterated_power: Number of power iteration of range finder. + :param int n_oversamples: Number of additional test vectors. + :param np.ndarray samples: Samples to perform PCA. + :param np.ndarray individuals: Individuals to perform PCA. + :param bool centre: Centre the genetic relatedness matrix. + :param list windows: An increasing list of breakpoints between the windows to compute the principal components in. Currently not working. + :param np.random.Generator random_state: Random number generator. """ def _rand_pow_range_finder( @@ -8689,7 +8689,7 @@ def _genetic_relatedness_vector( x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows x = self.genetic_relatedness_vector(W=x[sample_individuals], windows=windows, mode="branch", centre=False, nodes=samples) bincount_fn = lambda w: np.bincount(sample_individuals, w) - x = np.apply_along_axis(bincount_fn, axis=0, arr=x) # I think it should be axis=1, but axis=0 gives the correct values why? + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) x = x - x.mean(axis=0) if centre else x # centering within index in cols return x From aa13613560792143377fbedb779a965cf01b5ed6 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Mon, 7 Oct 2024 11:32:47 -0400 Subject: [PATCH 05/23] space remove --- python/tskit/trees.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 530ac9f071..745b932e6d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8718,7 +8718,6 @@ def _genetic_relatedness_vector( return U, D - def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W`` @@ -10298,4 +10297,3 @@ def write_ms( ) else: print(file=output) - From 5bf405a4887349c4d9305b4b4e314df0b411db8e Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Mon, 7 Oct 2024 23:16:13 -0400 Subject: [PATCH 06/23] rng to random seed --- python/tskit/trees.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 745b932e6d..174eec8a60 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8601,7 +8601,7 @@ def pca( individuals: np.ndarray = None, centre: bool = True, windows: list = None, - random_state: np.random.Generator = None, + seed: int = None, ) -> (np.ndarray, np.ndarray): """ Run randomized singular value decomposition (rSVD) to obtain principal components. @@ -8614,8 +8614,11 @@ def pca( :param np.ndarray samples: Samples to perform PCA. :param np.ndarray individuals: Individuals to perform PCA. :param bool centre: Centre the genetic relatedness matrix. - :param list windows: An increasing list of breakpoints between the windows to compute the principal components in. Currently not working. - :param np.random.Generator random_state: Random number generator. + :param list windows: An increasing list of breakpoints between the windows + to compute the principal components in. + :param int random_seed: The random seed. If this is None, a random seed will + be automatically generated. Valid random seeds must be between 1 and + :math:`2^32 − 1`. """ def _rand_pow_range_finder( @@ -8694,8 +8697,7 @@ def _genetic_relatedness_vector( return x - - if random_state is None: random_state = np.random.default_rng() + random_state = np.random.default_rng(random_seed) if samples is None and individuals is None: samples = self.samples() if samples is not None and individuals is not None: From 6e415e5fda598674249ceb8fce2f4091b7e2777d Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Mon, 7 Oct 2024 23:54:33 -0400 Subject: [PATCH 07/23] add windows feature --- python/tskit/trees.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 174eec8a60..e5af8d8733 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8601,7 +8601,7 @@ def pca( individuals: np.ndarray = None, centre: bool = True, windows: list = None, - seed: int = None, + random_seed: int = None, ) -> (np.ndarray, np.ndarray): """ Run randomized singular value decomposition (rSVD) to obtain principal components. @@ -8690,7 +8690,7 @@ def _genetic_relatedness_vector( ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes]) samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows - x = self.genetic_relatedness_vector(W=x[sample_individuals], windows=windows, mode="branch", centre=False, nodes=samples) + x = self.genetic_relatedness_vector(W=x[sample_individuals], windows=windows, mode="branch", centre=False, nodes=samples)[0] bincount_fn = lambda w: np.bincount(sample_individuals, w) x = np.apply_along_axis(bincount_fn, axis=0, arr=x) x = x - x.mean(axis=0) if centre else x # centering within index in cols @@ -8703,20 +8703,36 @@ def _genetic_relatedness_vector( if samples is not None and individuals is not None: raise ValueError("samples and individuals cannot be used at the same time") elif samples is not None: - _G = lambda x: self.genetic_relatedness_vector(x, windows=windows, mode="branch", centre=centre, nodes=samples) + mode = 'node' dim = samples.size elif individuals is not None: - _G = lambda x: _genetic_relatedness_vector(x, individuals, individuals, centre=centre, windows=windows) + mode = 'individual' dim = individuals.size + + drop_windows = windows is None + if drop_windows: + windows = [0, self.sequence_length] + + U = np.empty((len(windows)-1, dim, n_components)) + D = np.empty((len(windows)-1, n_components)) + for i in range(len(windows)-1): + if mode == 'node': + _G = lambda x: self.genetic_relatedness_vector( + x, windows=windows[i:i+2], mode="branch", centre=centre, nodes=samples)[0] + elif mode == 'individual': + _G = lambda x: _genetic_relatedness_vector( + x, individuals, individuals, centre=centre, windows=windows[i:i+2]) + U[i], D[i], _ = _rand_svd( + operator=_G, + operator_dim=dim, + rank=n_components, + depth=iterated_power, + num_vectors=n_components+n_oversamples, + rng=random_state + ) - U, D, _ = _rand_svd( - operator=_G, - operator_dim=dim, - rank=n_components, - depth=iterated_power, - num_vectors=n_components+n_oversamples, - rng=random_state - ) + if drop_windows or len(windows) == 2: + U, D = U[0], D[0] return U, D From 45ac61eac4778a15148ce8b7078fecdc08685640 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 11 Oct 2024 14:03:23 -0400 Subject: [PATCH 08/23] output shape change when windows=wholegnome --- python/tskit/trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e5af8d8733..f0fe415114 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8731,7 +8731,7 @@ def _genetic_relatedness_vector( rng=random_state ) - if drop_windows or len(windows) == 2: + if drop_windows: U, D = U[0], D[0] return U, D From 2cdb9dd70cf57feed53ca2c315ad0fc54491eccd Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 11 Oct 2024 14:28:44 -0400 Subject: [PATCH 09/23] make centre work with nodes --- python/tskit/trees.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f0fe415114..f00b3946a4 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8666,26 +8666,13 @@ def _rand_svd( U = Q @ U_hat return U[:,:rank], D[:rank], V[:rank] - def _genetic_relatedness_vector( + def _genetic_relatedness_vector_individual( arr: np.ndarray, rows: np.ndarray, cols: np.ndarray, centre: bool = True, windows = None, ) -> np.ndarray: - """ - Wrapper around `tskit.TreeSequence.genetic_relatedness_vector` to support centering in respect to individuals. - Multiplies an array to the genetic relatedness matrix of :class:`tskit.TreeSequence`. - - :param numpy.ndarray arr: The array to multiply. Either a vector or a matrix. - :param numpy.ndarray rows: Index of rows of the genetic relatedness matrix to be selected. - :param numpy.ndarray cols: Index of cols of the genetic relatedness matrix to be selected. The size should match the row length of `arr`. - :param bool centre: Centre the genetic relatedness matrix. Centering happens respect to the `rows` and `cols`. - :param windows: An increasing list of breakpoints between the windows to compute the genetic relatedness matrix in. - :return: An array that is the matrix-array product of the genetic relatedness matrix and the array. - :rtype: `np.ndarray` - """ - assert cols.size == arr.shape[0], "Dimension mismatch" ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes]) samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes @@ -8697,6 +8684,20 @@ def _genetic_relatedness_vector( return x + def _genetic_relatedness_vector_node( + arr: np.ndarray, + rows: np.ndarray, + cols: np.ndarray, + centre: bool = True, + windows = None, + ) -> np.ndarray: + assert cols.size == arr.shape[0], "Dimension mismatch" + x = arr - arr.mean(axis=0) if centre else arr + x = self.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False, nodes=cols)[0] + x = x - x.mean(axis=0) if centre else x + + return x + random_state = np.random.default_rng(random_seed) if samples is None and individuals is None: samples = self.samples() @@ -8717,10 +8718,10 @@ def _genetic_relatedness_vector( D = np.empty((len(windows)-1, n_components)) for i in range(len(windows)-1): if mode == 'node': - _G = lambda x: self.genetic_relatedness_vector( - x, windows=windows[i:i+2], mode="branch", centre=centre, nodes=samples)[0] + _G = lambda x: _genetic_relatedness_vector_node( + x, samples, samples, centre=centre, windows=windows[i:i+2]) elif mode == 'individual': - _G = lambda x: _genetic_relatedness_vector( + _G = lambda x: _genetic_relatedness_vector_individual( x, individuals, individuals, centre=centre, windows=windows[i:i+2]) U[i], D[i], _ = _rand_svd( operator=_G, From 1f194aa398e864634da4ee80b8b32a0d6ad62cc8 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 11 Oct 2024 14:40:11 -0400 Subject: [PATCH 10/23] remove redundant options from internal functions --- python/tskit/trees.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f00b3946a4..32031412e5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8668,12 +8668,9 @@ def _rand_svd( def _genetic_relatedness_vector_individual( arr: np.ndarray, - rows: np.ndarray, - cols: np.ndarray, centre: bool = True, windows = None, ) -> np.ndarray: - assert cols.size == arr.shape[0], "Dimension mismatch" ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes]) samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows @@ -8686,14 +8683,11 @@ def _genetic_relatedness_vector_individual( def _genetic_relatedness_vector_node( arr: np.ndarray, - rows: np.ndarray, - cols: np.ndarray, centre: bool = True, windows = None, ) -> np.ndarray: - assert cols.size == arr.shape[0], "Dimension mismatch" x = arr - arr.mean(axis=0) if centre else arr - x = self.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False, nodes=cols)[0] + x = self.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False, nodes=samples)[0] x = x - x.mean(axis=0) if centre else x return x @@ -8719,10 +8713,10 @@ def _genetic_relatedness_vector_node( for i in range(len(windows)-1): if mode == 'node': _G = lambda x: _genetic_relatedness_vector_node( - x, samples, samples, centre=centre, windows=windows[i:i+2]) + x, centre=centre, windows=windows[i:i+2]) elif mode == 'individual': _G = lambda x: _genetic_relatedness_vector_individual( - x, individuals, individuals, centre=centre, windows=windows[i:i+2]) + x, centre=centre, windows=windows[i:i+2]) U[i], D[i], _ = _rand_svd( operator=_G, operator_dim=dim, From fdc5842adaa81af320a0449a69efc7365c6e4a6e Mon Sep 17 00:00:00 2001 From: peter Date: Sun, 13 Oct 2024 08:53:27 -0700 Subject: [PATCH 11/23] start at testing --- python/tests/test_relatedness_vector.py | 146 ++++++++++++++- python/tskit/trees.py | 226 +++++++++++++++--------- 2 files changed, 283 insertions(+), 89 deletions(-) diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index f765c75c9f..6488fa7b24 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -460,7 +460,7 @@ def check_relatedness_vector( return R -class TestExamples: +class TestRelatednessVector: def test_bad_weights(self): n = 5 @@ -737,3 +737,147 @@ def test_disconnected_non_sample_topology(self, centre): ts2, internal_checks=True, centre=centre, do_nodes=False ) np.testing.assert_array_almost_equal(D1, D2) + + +def pca(ts, windows, centre): + drop_dimension = windows is None + if drop_dimension: + windows = [0, ts.sequence_length] + Sigma = relatedness_matrix(ts=ts, windows=windows, centre=centre) + U, S, _ = np.linalg.svd(Sigma, hermitian=True) + if drop_dimension: + U = U[0] + S = S[0] + return U, S + + +def allclose_up_to_sign(x, y, **kwargs): + # check if two vectors are the same up to sign + x_const = np.isclose(np.std(x), 0) + y_const = np.isclose(np.std(y), 0) + if x_const or y_const: + if np.allclose(x, 0): + r = 1.0 + else: + r = np.mean(x / y) + else: + r = np.sign(np.corrcoef(x, y)[0, 1]) + return np.allclose(x, r * y, **kwargs) + + +def assert_pcs_equal(U, D, U_full, D_full, rtol=1e-05, atol=1e-08): + # check that the PCs in U, D occur in U_full, D_full + # accounting for sign and ordering + assert len(D) <= len(D_full) + assert U.shape[0] == U_full.shape[0] + assert U.shape[1] == len(D) + for k in range(len(D)): + u = U[:, k] + d = D[k] + (ii,) = np.where(np.isclose(D_full, d, rtol=rtol, atol=atol)) + assert len(ii) > 0, f"{k}th singular value {d} not found in {D_full}." + found_it = False + for i in ii: + if allclose_up_to_sign(u, U_full[:, i], rtol=rtol, atol=atol): + found_it = True + break + assert found_it, f"{k}th singular vector {u} not found in {U_full}." + + +class TestPCA: + + def verify_pca(self, ts, num_windows, n_components, centre): + if num_windows == 0: + windows = None + elif num_windows % 2 == 0: + windows = np.linspace( + 0.2 * ts.sequence_length, 0.8 * ts.sequence_length, num_windows + 1 + ) + else: + windows = np.linspace(0, ts.sequence_length, num_windows + 1) + ts_U, ts_D = ts.pca( + windows=windows, n_components=n_components, centre=centre, random_seed=123 + ) + num_rows = ts.num_samples + if windows is None: + assert ts_U.shape == (num_rows, n_components) + assert ts_D.shape == (n_components,) + else: + assert ts_U.shape == (num_windows, num_rows, n_components) + assert ts_D.shape == (num_windows, n_components) + U, D = pca(ts=ts, windows=windows, centre=centre) + if windows is None: + np.testing.assert_allclose(ts_D, D[:n_components], atol=1e-8) + assert_pcs_equal(ts_U, ts_D, U, D) + else: + for w in range(num_windows): + np.testing.assert_allclose(ts_D[w], D[w, :n_components], atol=1e-8) + assert_pcs_equal(ts_U[w], ts_D[w], U[w], D[w]) + + def test_bad_windows(self): + ts = msprime.sim_ancestry( + 3, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_w in ([], [1]): + with pytest.raises(ValueError, match="Number of windows"): + ts.pca(n_components=2, windows=bad_w) + for bad_w in ([1, 0], [-3, 10]): + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_WINDOWS"): + ts.pca(n_components=2, windows=bad_w) + + def test_bad_num_components(self): + ts = msprime.sim_ancestry( + 3, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + with pytest.raises(ValueError, match="Number of components"): + ts.pca(n_components=ts.num_samples + 1) + with pytest.raises(ValueError, match="Number of components"): + ts.pca(n_components=4, samples=[0, 1, 2]) + with pytest.raises(ValueError, match="Number of components"): + ts.pca(n_components=4, individuals=[0, 1]) + + def test_indivs_and_samples(self): + ts = msprime.sim_ancestry( + 3, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + with pytest.raises(ValueError, match="Samples and individuals"): + ts.pca(n_components=2, samples=[0, 1, 2, 3], individuals=[0, 1, 2]) + + def test_modes(self): + ts = msprime.sim_ancestry( + 3, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_mode in ("site", "node"): + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE" + ): + ts.pca(n_components=2, mode=bad_mode) + + @pytest.mark.parametrize("n", [2, 3, 5, 15]) + @pytest.mark.parametrize("centre", (True, False)) + @pytest.mark.parametrize("num_windows", (0, 1, 2, 3)) + @pytest.mark.parametrize("n_components", (1, 3)) + def test_simple_sims(self, n, centre, num_windows, n_components): + ploidy = 1 + nc = min(n_components, n * ploidy) + ts = msprime.sim_ancestry( + n, + ploidy=ploidy, + population_size=20, + sequence_length=100, + recombination_rate=0.01, + random_seed=12345, + ) + self.verify_pca(ts, num_windows=num_windows, n_components=nc, centre=centre) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 32031412e5..cbbe0811b3 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8593,138 +8593,188 @@ def genetic_relatedness_vector( return out def pca( - self, - n_components: int = 10, - iterated_power: int = 3, - n_oversamples: int = 10, - samples: np.ndarray = None, - individuals: np.ndarray = None, - centre: bool = True, - windows: list = None, - random_seed: int = None, - ) -> (np.ndarray, np.ndarray): - """ - Run randomized singular value decomposition (rSVD) to obtain principal components. - API partially adopted from `scikit-learn`: - https://scikit-learn.org/dev/modules/generated/sklearn.decomposition.PCA.html + self, + n_components: int = 10, + windows: list = None, + samples: np.ndarray = None, + individuals: np.ndarray = None, + mode: str = "branch", + centre: bool = True, + iterated_power: int = 3, + n_oversamples: int = 10, + random_seed: int = None, + ) -> (np.ndarray, np.ndarray): + """ + Run randomized singular value decomposition (rSVD) to obtain principal + components. + API partially adopted from `scikit-learn`'s + `sklearn.decomposition.PCA.html` + + By default, performs PCA for the samples, so output has one coordinate + for each sample), but alternatively either a list of sample IDs or a + list of individual IDs can be provided (but not both). + + TODO: say exactly what is returned (and relationship to + :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`). + + TODO: say what algorithms are used. :param int n_components: Number of principal components. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param np.ndarray samples: Samples to perform PCA with. + :param np.ndarray individuals: Individuals to perform PCA with. Cannot specify + both `samples` and `individuals`. + :param str mode: A string giving the "type" of relatedness to be computed + (defaults to "branch"; see + :meth:`genetic_relatedness_vector + <.TreeSequence.genetic_relatedness_vector>`) + :param bool centre: Centre the genetic relatedness matrix. :param int iterated_power: Number of power iteration of range finder. :param int n_oversamples: Number of additional test vectors. - :param np.ndarray samples: Samples to perform PCA. - :param np.ndarray individuals: Individuals to perform PCA. - :param bool centre: Centre the genetic relatedness matrix. - :param list windows: An increasing list of breakpoints between the windows - to compute the principal components in. :param int random_seed: The random seed. If this is None, a random seed will be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. + :return: A tuple (U, D) of ndarrays, with the principal component loadings in U + and the principal values in D. """ + if samples is None and individuals is None: + samples = self.samples() + + if samples is not None and individuals is not None: + raise ValueError("Samples and individuals cannot be used at the same time") + elif samples is not None: + output_type = "node" + dim = len(samples) + else: + assert individuals is not None + output_type = "individual" + dim = len(individuals) + + if n_components > dim: + raise ValueError( + "Number of components must be less than or equal to " + "the number of samples (or individuals, if specified)." + ) + + random_state = np.random.default_rng(random_seed) + def _rand_pow_range_finder( - operator: Callable, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - ) -> np.ndarray: + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + ) -> np.ndarray: """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ assert num_vectors >= rank > 0 test_vectors = rng.normal(size=(operator_dim, num_vectors)) Q = test_vectors - for i in range(depth): + for _ in range(depth): Q = np.linalg.qr(Q).Q Q = operator(Q) Q = np.linalg.qr(Q).Q return Q[:, :rank] def _rand_svd( - operator: Callable, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - ) -> (np.ndarray, np.ndarray, np.ndarray): + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + ) -> (np.ndarray, np.ndarray, np.ndarray): """ Algorithm 8 in https://arxiv.org/pdf/2002.01387 """ assert num_vectors >= rank > 0 Q = _rand_pow_range_finder( - operator, - operator_dim, - num_vectors, - depth, - num_vectors, - rng - ) + operator, operator_dim, num_vectors, depth, num_vectors, rng + ) C = operator(Q).T U_hat, D, V = np.linalg.svd(C, full_matrices=False) U = Q @ U_hat - return U[:,:rank], D[:rank], V[:rank] + return U[:, :rank], D[:rank], V[:rank] def _genetic_relatedness_vector_individual( - arr: np.ndarray, - centre: bool = True, - windows = None, - ) -> np.ndarray: - ij = np.vstack([[n,k] for k, i in enumerate(individuals) for n in self.individual(i).nodes]) - samples, sample_individuals = ij[:,0], ij[:,1] # sample node index, individual of those nodes - x = arr - arr.mean(axis=0) if centre else arr # centering within index in rows - x = self.genetic_relatedness_vector(W=x[sample_individuals], windows=windows, mode="branch", centre=False, nodes=samples)[0] - bincount_fn = lambda w: np.bincount(sample_individuals, w) + arr: np.ndarray, + centre: bool = True, + windows=None, + ) -> np.ndarray: + ij = np.vstack( + [ + [n, k] + for k, i in enumerate(individuals) + for n in self.individual(i).nodes + ] + ) + samples, sample_individuals = ( + ij[:, 0], + ij[:, 1], + ) # sample node index, individual of those nodes + x = ( + arr - arr.mean(axis=0) if centre else arr + ) # centering within index in rows + x = self.genetic_relatedness_vector( + W=x[sample_individuals], + windows=windows, + mode=mode, + centre=False, + nodes=samples, + )[0] + + def bincount_fn(w): + np.bincount(sample_individuals, w) + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) - x = x - x.mean(axis=0) if centre else x # centering within index in cols + x = x - x.mean(axis=0) if centre else x # centering within index in cols return x def _genetic_relatedness_vector_node( - arr: np.ndarray, - centre: bool = True, - windows = None, - ) -> np.ndarray: + arr: np.ndarray, + centre: bool = True, + windows=None, + ) -> np.ndarray: x = arr - arr.mean(axis=0) if centre else arr - x = self.genetic_relatedness_vector(W=x, windows=windows, mode="branch", centre=False, nodes=samples)[0] + x = self.genetic_relatedness_vector( + W=x, windows=windows, mode=mode, centre=False, nodes=samples + )[0] x = x - x.mean(axis=0) if centre else x return x - random_state = np.random.default_rng(random_seed) - if samples is None and individuals is None: samples = self.samples() - - if samples is not None and individuals is not None: - raise ValueError("samples and individuals cannot be used at the same time") - elif samples is not None: - mode = 'node' - dim = samples.size - elif individuals is not None: - mode = 'individual' - dim = individuals.size - drop_windows = windows is None - if drop_windows: - windows = [0, self.sequence_length] - - U = np.empty((len(windows)-1, dim, n_components)) - D = np.empty((len(windows)-1, n_components)) - for i in range(len(windows)-1): - if mode == 'node': - _G = lambda x: _genetic_relatedness_vector_node( - x, centre=centre, windows=windows[i:i+2]) - elif mode == 'individual': - _G = lambda x: _genetic_relatedness_vector_individual( - x, centre=centre, windows=windows[i:i+2]) + windows = self.parse_windows(windows) + num_windows = len(windows) - 1 + if num_windows < 1: + raise ValueError("Number of windows must be at least 1.") + + U = np.empty((num_windows, dim, n_components)) + D = np.empty((num_windows, n_components)) + for i in range(num_windows): + this_window = windows[i : i + 2] + _f = ( + _genetic_relatedness_vector_node + if output_type == "node" + else _genetic_relatedness_vector_individual + ) + + def _G(x): + _f(x, centre=centre, windows=this_window) # NOQA: B023 + U[i], D[i], _ = _rand_svd( - operator=_G, - operator_dim=dim, - rank=n_components, - depth=iterated_power, - num_vectors=n_components+n_oversamples, - rng=random_state - ) + operator=_G, + operator_dim=dim, + rank=n_components, + depth=iterated_power, + num_vectors=n_components + n_oversamples, + rng=random_state, + ) if drop_windows: U, D = U[0], D[0] From c0a285463f18f5e0d518390ec21adff1bbb8cdfd Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 12:42:29 -0400 Subject: [PATCH 12/23] change variable name to n_* to num_* --- python/tskit/trees.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index cbbe0811b3..d686cd7522 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8594,14 +8594,14 @@ def genetic_relatedness_vector( def pca( self, - n_components: int = 10, + num_components: int, windows: list = None, samples: np.ndarray = None, individuals: np.ndarray = None, mode: str = "branch", centre: bool = True, - iterated_power: int = 3, - n_oversamples: int = 10, + iterated_power: int = 5, + num_oversamples: int = 10, random_seed: int = None, ) -> (np.ndarray, np.ndarray): """ @@ -8619,7 +8619,7 @@ def pca( TODO: say what algorithms are used. - :param int n_components: Number of principal components. + :param int num_components: Number of principal components. :param list windows: An increasing list of breakpoints between the windows to compute the statistic in. :param np.ndarray samples: Samples to perform PCA with. @@ -8631,7 +8631,7 @@ def pca( <.TreeSequence.genetic_relatedness_vector>`) :param bool centre: Centre the genetic relatedness matrix. :param int iterated_power: Number of power iteration of range finder. - :param int n_oversamples: Number of additional test vectors. + :param int num_oversamples: Number of additional test vectors. :param int random_seed: The random seed. If this is None, a random seed will be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. @@ -8652,7 +8652,7 @@ def pca( output_type = "individual" dim = len(individuals) - if n_components > dim: + if num_components > dim: raise ValueError( "Number of components must be less than or equal to " "the number of samples (or individuals, if specified)." @@ -8754,8 +8754,8 @@ def _genetic_relatedness_vector_node( if num_windows < 1: raise ValueError("Number of windows must be at least 1.") - U = np.empty((num_windows, dim, n_components)) - D = np.empty((num_windows, n_components)) + U = np.empty((num_windows, dim, num_components)) + D = np.empty((num_windows, num_components)) for i in range(num_windows): this_window = windows[i : i + 2] _f = ( @@ -8770,9 +8770,9 @@ def _G(x): U[i], D[i], _ = _rand_svd( operator=_G, operator_dim=dim, - rank=n_components, + rank=num_components, depth=iterated_power, - num_vectors=n_components + n_oversamples, + num_vectors=num_components + num_oversamples, rng=random_state, ) From 667d60ffe8ec17f5f758424a7099a272aede4297 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 13:15:20 -0400 Subject: [PATCH 13/23] return range sketch matrix Q --- python/tskit/trees.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index d686cd7522..39a3d5e1d0 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8603,7 +8603,8 @@ def pca( iterated_power: int = 5, num_oversamples: int = 10, random_seed: int = None, - ) -> (np.ndarray, np.ndarray): + range_sketch: list = None, + ) -> (np.ndarray, np.ndarray, np.ndarray): """ Run randomized singular value decomposition (rSVD) to obtain principal components. @@ -8635,6 +8636,7 @@ def pca( :param int random_seed: The random seed. If this is None, a random seed will be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. + :param list range_sketch: Sketch matrix for each window. Default is None. :return: A tuple (U, D) of ndarrays, with the principal component loadings in U and the principal values in D. """ @@ -8651,6 +8653,8 @@ def pca( assert individuals is not None output_type = "individual" dim = len(individuals) + if range_sketch is not None: + assert len(range_sketch) == len(windows) if num_components > dim: raise ValueError( @@ -8667,13 +8671,17 @@ def _rand_pow_range_finder( depth: int, num_vectors: int, rng: np.random.Generator, + range_sketch: np.ndarray = None, ) -> np.ndarray: """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ assert num_vectors >= rank > 0 - test_vectors = rng.normal(size=(operator_dim, num_vectors)) - Q = test_vectors + if range_sketch is None: + test_vectors = rng.normal(size=(operator_dim, num_vectors)) + Q = test_vectors + else: + Q = range_sketch for _ in range(depth): Q = np.linalg.qr(Q).Q Q = operator(Q) @@ -8687,18 +8695,19 @@ def _rand_svd( depth: int, num_vectors: int, rng: np.random.Generator, + range_sketch: np.ndarray = None, ) -> (np.ndarray, np.ndarray, np.ndarray): """ Algorithm 8 in https://arxiv.org/pdf/2002.01387 """ assert num_vectors >= rank > 0 Q = _rand_pow_range_finder( - operator, operator_dim, num_vectors, depth, num_vectors, rng + operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch ) C = operator(Q).T U_hat, D, V = np.linalg.svd(C, full_matrices=False) U = Q @ U_hat - return U[:, :rank], D[:rank], V[:rank] + return U[:, :rank], D[:rank], V[:rank], Q def _genetic_relatedness_vector_individual( arr: np.ndarray, @@ -8756,6 +8765,7 @@ def _genetic_relatedness_vector_node( U = np.empty((num_windows, dim, num_components)) D = np.empty((num_windows, num_components)) + Q = np.empty((num_windows, num_components + num_oversamples)) for i in range(num_windows): this_window = windows[i : i + 2] _f = ( @@ -8767,19 +8777,20 @@ def _genetic_relatedness_vector_node( def _G(x): _f(x, centre=centre, windows=this_window) # NOQA: B023 - U[i], D[i], _ = _rand_svd( + U[i], D[i], _, Q[i] = _rand_svd( operator=_G, operator_dim=dim, rank=num_components, depth=iterated_power, num_vectors=num_components + num_oversamples, rng=random_state, + range_sketch=range_sketch[i], ) if drop_windows: - U, D = U[0], D[0] + U, D, Q = U[0], D[0], Q[0] - return U, D + return U, D, Q def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ From 36c10e9a0ba424ed0edc2413b2e74bf78dff2933 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 13:17:35 -0400 Subject: [PATCH 14/23] fix random sketch option to handle None --- python/tskit/trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 39a3d5e1d0..13e639cc64 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8784,7 +8784,7 @@ def _G(x): depth=iterated_power, num_vectors=num_components + num_oversamples, rng=random_state, - range_sketch=range_sketch[i], + range_sketch=None if random_sketch is None else range_sketch[i], ) if drop_windows: From 0ac88b8289d44db315d4a511f483d7e2b84f2213 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 13:32:21 -0400 Subject: [PATCH 15/23] random_sketch needs windows specified --- python/tskit/trees.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 13e639cc64..f2f000457f 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8603,7 +8603,7 @@ def pca( iterated_power: int = 5, num_oversamples: int = 10, random_seed: int = None, - range_sketch: list = None, + range_sketch: np.ndarray = None, ) -> (np.ndarray, np.ndarray, np.ndarray): """ Run randomized singular value decomposition (rSVD) to obtain principal @@ -8636,7 +8636,7 @@ def pca( :param int random_seed: The random seed. If this is None, a random seed will be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. - :param list range_sketch: Sketch matrix for each window. Default is None. + :param np.ndarray range_sketch: Sketch matrix for each window. Default is None. :return: A tuple (U, D) of ndarrays, with the principal component loadings in U and the principal values in D. """ @@ -8653,8 +8653,11 @@ def pca( assert individuals is not None output_type = "individual" dim = len(individuals) + + if windows is None and range_sketch is not None: + raise ValueError("Windows should be given to supply range_sketch") if range_sketch is not None: - assert len(range_sketch) == len(windows) + assert range_sketch.shape == len(windows) - 1 if num_components > dim: raise ValueError( From a6fa8ab64dae83bfa697fa62d9a049905599dc23 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 13:47:05 -0400 Subject: [PATCH 16/23] input checking for range_sketch to align with windows --- python/tskit/trees.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f2f000457f..aab8687c50 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8653,11 +8653,12 @@ def pca( assert individuals is not None output_type = "individual" dim = len(individuals) - - if windows is None and range_sketch is not None: - raise ValueError("Windows should be given to supply range_sketch") + if range_sketch is not None: - assert range_sketch.shape == len(windows) - 1 + if windows is not None: + assert range_sketch.shape[0] == len(windows) - 1 + elif windows is None: + range_sketch = np.expand_dims(range_sketch, 0) if num_components > dim: raise ValueError( From 791c7dacaef38636d82614b1dcdaa1525c329ea7 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 13:51:56 -0400 Subject: [PATCH 17/23] docstring change to reflect range_sketch --- python/tskit/trees.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index aab8687c50..2a5cbe4d96 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8637,8 +8637,8 @@ def pca( be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. :param np.ndarray range_sketch: Sketch matrix for each window. Default is None. - :return: A tuple (U, D) of ndarrays, with the principal component loadings in U - and the principal values in D. + :return: A tuple (U, D, Q) of ndarrays, with the principal component loadings in U + and the principal values in D. Q is the range sketch array for each window. """ if samples is None and individuals is None: From 2f4ce2bac92cf5356d908d2a34e90f750138e1cc Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Sun, 13 Oct 2024 20:22:17 -0400 Subject: [PATCH 18/23] linting has a bug; when converting lambda to ordinary function definition, it omits return in the end of the function --- python/tskit/trees.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2a5cbe4d96..6064e7556f 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8680,7 +8680,7 @@ def _rand_pow_range_finder( """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ - assert num_vectors >= rank > 0 + assert num_vectors >= rank > 0, "num_vectors should be larger than rank" if range_sketch is None: test_vectors = rng.normal(size=(operator_dim, num_vectors)) Q = test_vectors @@ -8700,7 +8700,7 @@ def _rand_svd( num_vectors: int, rng: np.random.Generator, range_sketch: np.ndarray = None, - ) -> (np.ndarray, np.ndarray, np.ndarray): + ) -> (np.ndarray, np.ndarray, np.ndarray, float): """ Algorithm 8 in https://arxiv.org/pdf/2002.01387 """ @@ -8711,7 +8711,12 @@ def _rand_svd( C = operator(Q).T U_hat, D, V = np.linalg.svd(C, full_matrices=False) U = Q @ U_hat - return U[:, :rank], D[:rank], V[:rank], Q + + error_factor = np.power( + 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), + 1 / (2 * depth + 1)) + error_bound = D[-1] * (2 + error_factor) + return U[:, :rank], D[:rank], V[:rank], Q, error_bound def _genetic_relatedness_vector_individual( arr: np.ndarray, @@ -8741,7 +8746,7 @@ def _genetic_relatedness_vector_individual( )[0] def bincount_fn(w): - np.bincount(sample_individuals, w) + return np.bincount(sample_individuals, w) x = np.apply_along_axis(bincount_fn, axis=0, arr=x) x = x - x.mean(axis=0) if centre else x # centering within index in cols @@ -8769,7 +8774,8 @@ def _genetic_relatedness_vector_node( U = np.empty((num_windows, dim, num_components)) D = np.empty((num_windows, num_components)) - Q = np.empty((num_windows, num_components + num_oversamples)) + Q = np.empty((num_windows, dim, num_components + num_oversamples)) + E = np.empty(num_windows) for i in range(num_windows): this_window = windows[i : i + 2] _f = ( @@ -8779,22 +8785,22 @@ def _genetic_relatedness_vector_node( ) def _G(x): - _f(x, centre=centre, windows=this_window) # NOQA: B023 + return _f(x, centre=centre, windows=this_window) # NOQA: B023 - U[i], D[i], _, Q[i] = _rand_svd( + U[i], D[i], _, Q[i], E[i] = _rand_svd( operator=_G, operator_dim=dim, rank=num_components, depth=iterated_power, num_vectors=num_components + num_oversamples, rng=random_state, - range_sketch=None if random_sketch is None else range_sketch[i], + range_sketch=None if range_sketch is None else range_sketch[i], ) if drop_windows: U, D, Q = U[0], D[0], Q[0] - return U, D, Q + return U, D, Q, E def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ From be2f7360ede40b15041a7a675dc103f73aa86369 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Thu, 17 Oct 2024 14:32:59 -0400 Subject: [PATCH 19/23] now output is a dataclass --- python/tskit/trees.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6064e7556f..0636cab5f7 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8714,8 +8714,9 @@ def _rand_svd( error_factor = np.power( 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), - 1 / (2 * depth + 1)) - error_bound = D[-1] * (2 + error_factor) + 1 / (2 * depth + 1) + ) + error_bound = D[-1] * (1 + error_factor) return U[:, :rank], D[:rank], V[:rank], Q, error_bound def _genetic_relatedness_vector_individual( @@ -8765,7 +8766,14 @@ def _genetic_relatedness_vector_node( x = x - x.mean(axis=0) if centre else x return x - + + @dataclass + class PCAResult: + U: np.ndarray + D: np.ndarray + Q: np.ndarray + E: np.ndarray + drop_windows = windows is None windows = self.parse_windows(windows) num_windows = len(windows) - 1 @@ -8800,7 +8808,9 @@ def _G(x): if drop_windows: U, D, Q = U[0], D[0], Q[0] - return U, D, Q, E + pca_result = PCAResult(U, D, Q, E) + + return pca_result def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ From bcbbcf63055f7580b78150efe865e6103315bbe7 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Thu, 17 Oct 2024 15:33:32 -0400 Subject: [PATCH 20/23] docstring change --- python/tskit/trees.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0636cab5f7..6cace1b057 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8637,8 +8637,11 @@ def pca( be automatically generated. Valid random seeds must be between 1 and :math:`2^32 − 1`. :param np.ndarray range_sketch: Sketch matrix for each window. Default is None. - :return: A tuple (U, D, Q) of ndarrays, with the principal component loadings in U - and the principal values in D. Q is the range sketch array for each window. + :return: A class object with attributes U, D, Q and E. + The principal component loadings are in U + and the principal values are in D. + Q is the range sketch array for each window. + E is the error bound of the singular values.. """ if samples is None and individuals is None: From 1a8ff7a777d09e2cf8973d379bd2d55088798abd Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Thu, 17 Oct 2024 23:56:06 -0400 Subject: [PATCH 21/23] change variable name of PCAResult class --- python/tskit/trees.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6cace1b057..fe5f4d9ab5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8769,14 +8769,7 @@ def _genetic_relatedness_vector_node( x = x - x.mean(axis=0) if centre else x return x - - @dataclass - class PCAResult: - U: np.ndarray - D: np.ndarray - Q: np.ndarray - E: np.ndarray - + drop_windows = windows is None windows = self.parse_windows(windows) num_windows = len(windows) - 1 @@ -8811,7 +8804,7 @@ def _G(x): if drop_windows: U, D, Q = U[0], D[0], Q[0] - pca_result = PCAResult(U, D, Q, E) + pca_result = PCAResult(loadings=U, eigen_values=D, range_sketch=Q, error_bound=E) return pca_result From af163400e8248db3eea2dd1c5feb540b4b0ce99f Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Fri, 18 Oct 2024 13:40:03 -0400 Subject: [PATCH 22/23] move internal function of PCA out --- python/tskit/trees.py | 239 ++++++++++++++++++++++++------------------ 1 file changed, 137 insertions(+), 102 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index fe5f4d9ab5..c5435900ba 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8670,106 +8670,6 @@ def pca( ) random_state = np.random.default_rng(random_seed) - - def _rand_pow_range_finder( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> np.ndarray: - """ - Algorithm 9 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0, "num_vectors should be larger than rank" - if range_sketch is None: - test_vectors = rng.normal(size=(operator_dim, num_vectors)) - Q = test_vectors - else: - Q = range_sketch - for _ in range(depth): - Q = np.linalg.qr(Q).Q - Q = operator(Q) - Q = np.linalg.qr(Q).Q - return Q[:, :rank] - - def _rand_svd( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> (np.ndarray, np.ndarray, np.ndarray, float): - """ - Algorithm 8 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0 - Q = _rand_pow_range_finder( - operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch - ) - C = operator(Q).T - U_hat, D, V = np.linalg.svd(C, full_matrices=False) - U = Q @ U_hat - - error_factor = np.power( - 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), - 1 / (2 * depth + 1) - ) - error_bound = D[-1] * (1 + error_factor) - return U[:, :rank], D[:rank], V[:rank], Q, error_bound - - def _genetic_relatedness_vector_individual( - arr: np.ndarray, - centre: bool = True, - windows=None, - ) -> np.ndarray: - ij = np.vstack( - [ - [n, k] - for k, i in enumerate(individuals) - for n in self.individual(i).nodes - ] - ) - samples, sample_individuals = ( - ij[:, 0], - ij[:, 1], - ) # sample node index, individual of those nodes - x = ( - arr - arr.mean(axis=0) if centre else arr - ) # centering within index in rows - x = self.genetic_relatedness_vector( - W=x[sample_individuals], - windows=windows, - mode=mode, - centre=False, - nodes=samples, - )[0] - - def bincount_fn(w): - return np.bincount(sample_individuals, w) - - x = np.apply_along_axis(bincount_fn, axis=0, arr=x) - x = x - x.mean(axis=0) if centre else x # centering within index in cols - - return x - - def _genetic_relatedness_vector_node( - arr: np.ndarray, - centre: bool = True, - windows=None, - ) -> np.ndarray: - x = arr - arr.mean(axis=0) if centre else arr - x = self.genetic_relatedness_vector( - W=x, windows=windows, mode=mode, centre=False, nodes=samples - )[0] - x = x - x.mean(axis=0) if centre else x - - return x - drop_windows = windows is None windows = self.parse_windows(windows) num_windows = len(windows) - 1 @@ -8787,9 +8687,13 @@ def _genetic_relatedness_vector_node( if output_type == "node" else _genetic_relatedness_vector_individual ) - + indices = ( + samples + if output_type == "node" + else individuals + ) def _G(x): - return _f(x, centre=centre, windows=this_window) # NOQA: B023 + return _f(tree_sequence=self, arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023 U[i], D[i], _, Q[i], E[i] = _rand_svd( operator=_G, @@ -10387,3 +10291,134 @@ def write_ms( ) else: print(file=output) + +def _rand_pow_range_finder( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> np.ndarray: + """ + Algorithm 9 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0, "num_vectors should be larger than rank" + if range_sketch is None: + test_vectors = rng.normal(size=(operator_dim, num_vectors)) + Q = test_vectors + else: + Q = range_sketch + for _ in range(depth): + Q = np.linalg.qr(Q).Q + Q = operator(Q) + Q = np.linalg.qr(Q).Q + return Q[:, :rank] + +def _rand_svd( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> (np.ndarray, np.ndarray, np.ndarray, float): + """ + Algorithm 8 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0 + Q = _rand_pow_range_finder( + operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch + ) + C = operator(Q).T + U_hat, D, V = np.linalg.svd(C, full_matrices=False) + U = Q @ U_hat + + error_factor = np.power( + 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), + 1 / (2 * depth + 1) + ) + error_bound = D[-1] * (1 + error_factor) + return U[:, :rank], D[:rank], V[:rank], Q, error_bound + +def _genetic_relatedness_vector_individual( + tree_sequence: tskit.TreeSequence, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + ij = np.vstack( + [ + [n, k] + for k, i in enumerate(indices) + for n in tree_sequence.individual(i).nodes + ] + ) + samples, sample_individuals = ( + ij[:, 0], + ij[:, 1], + ) # sample node index, individual of those nodes + x = ( + arr - arr.mean(axis=0) if centre else arr + ) # centering within index in rows + x = tree_sequence.genetic_relatedness_vector( + W=x[sample_individuals], + windows=windows, + mode=mode, + centre=False, + nodes=samples, + )[0] + + def bincount_fn(w): + return np.bincount(sample_individuals, w) + + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) + x = x - x.mean(axis=0) if centre else x # centering within index in cols + + return x + +def _genetic_relatedness_vector_node( + tree_sequence: tskit.TreeSequence, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + x = arr - arr.mean(axis=0) if centre else arr + x = tree_sequence.genetic_relatedness_vector( + W=x, windows=windows, mode=mode, centre=False, nodes=indices, + )[0] + x = x - x.mean(axis=0) if centre else x + + return x + +@dataclass +class PCAResult: + """ + The result of a call to TreeSequence.pca() capturing the output values + and algorithm convergence details. + + + """ + loadings: np.ndarray + """ + The principal component loadings. It is an orthogonal matrix. + """ + eigen_values: np.ndarray + """ + Eigenvalues of the genetic relatedness matrix. + """ + range_sketch: np.ndarray + """ + Range sketch matrix. Can be used as an input for .pca() call with range_sketch option + to further improve precision.. + """ + error_bound: np.ndarray + """ + Error bounds for the eigenvalues. + """ From e81db153403340a97d3b847e3577f27e451ee0c7 Mon Sep 17 00:00:00 2001 From: Hanbin Lee Date: Tue, 22 Oct 2024 14:22:45 -0400 Subject: [PATCH 23/23] function rearrangement --- python/tskit/trees.py | 217 +++++++++++++++++++++--------------------- 1 file changed, 109 insertions(+), 108 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index c5435900ba..08c081a6a5 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8592,6 +8592,60 @@ def genetic_relatedness_vector( ) return out + def _genetic_relatedness_vector_node( + self, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + x = arr - arr.mean(axis=0) if centre else arr + x = self.genetic_relatedness_vector( + W=x, windows=windows, mode=mode, centre=False, nodes=indices, + )[0] + x = x - x.mean(axis=0) if centre else x + + return x + + def _genetic_relatedness_vector_individual( + self, + arr: np.ndarray, + indices: np.ndarray, + mode: str, + centre: bool = True, + windows = None, + ) -> np.ndarray: + ij = np.vstack( + [ + [n, k] + for k, i in enumerate(indices) + for n in self.individual(i).nodes + ] + ) + samples, sample_individuals = ( + ij[:, 0], + ij[:, 1], + ) # sample node index, individual of those nodes + x = ( + arr - arr.mean(axis=0) if centre else arr + ) # centering within index in rows + x = self.genetic_relatedness_vector( + W=x[sample_individuals], + windows=windows, + mode=mode, + centre=False, + nodes=samples, + )[0] + + def bincount_fn(w): + return np.bincount(sample_individuals, w) + + x = np.apply_along_axis(bincount_fn, axis=0, arr=x) + x = x - x.mean(axis=0) if centre else x # centering within index in cols + + return x + def pca( self, num_components: int, @@ -8669,6 +8723,58 @@ def pca( "the number of samples (or individuals, if specified)." ) + def _rand_pow_range_finder( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> np.ndarray: + """ + Algorithm 9 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0, "num_vectors should be larger than rank" + if range_sketch is None: + test_vectors = rng.normal(size=(operator_dim, num_vectors)) + Q = test_vectors + else: + Q = range_sketch + for _ in range(depth): + Q = np.linalg.qr(Q).Q + Q = operator(Q) + Q = np.linalg.qr(Q).Q + return Q[:, :rank] + + def _rand_svd( + operator, + operator_dim: int, + rank: int, + depth: int, + num_vectors: int, + rng: np.random.Generator, + range_sketch: np.ndarray = None, + ) -> (np.ndarray, np.ndarray, np.ndarray, float): + """ + Algorithm 8 in https://arxiv.org/pdf/2002.01387 + """ + assert num_vectors >= rank > 0 + Q = _rand_pow_range_finder( + operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch + ) + C = operator(Q).T + U_hat, D, V = np.linalg.svd(C, full_matrices=False) + U = Q @ U_hat + + error_factor = np.power( + 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), + 1 / (2 * depth + 1) + ) + error_bound = D[-1] * (1 + error_factor) + return U[:, :rank], D[:rank], V[:rank], Q, error_bound + + random_state = np.random.default_rng(random_seed) drop_windows = windows is None windows = self.parse_windows(windows) @@ -8683,9 +8789,9 @@ def pca( for i in range(num_windows): this_window = windows[i : i + 2] _f = ( - _genetic_relatedness_vector_node + self._genetic_relatedness_vector_node if output_type == "node" - else _genetic_relatedness_vector_individual + else self._genetic_relatedness_vector_individual ) indices = ( samples @@ -8693,7 +8799,7 @@ def pca( else individuals ) def _G(x): - return _f(tree_sequence=self, arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023 + return _f(arr=x, indices=indices, mode=mode, centre=centre, windows=this_window) # NOQA: B023 U[i], D[i], _, Q[i], E[i] = _rand_svd( operator=_G, @@ -10292,111 +10398,6 @@ def write_ms( else: print(file=output) -def _rand_pow_range_finder( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> np.ndarray: - """ - Algorithm 9 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0, "num_vectors should be larger than rank" - if range_sketch is None: - test_vectors = rng.normal(size=(operator_dim, num_vectors)) - Q = test_vectors - else: - Q = range_sketch - for _ in range(depth): - Q = np.linalg.qr(Q).Q - Q = operator(Q) - Q = np.linalg.qr(Q).Q - return Q[:, :rank] - -def _rand_svd( - operator, - operator_dim: int, - rank: int, - depth: int, - num_vectors: int, - rng: np.random.Generator, - range_sketch: np.ndarray = None, - ) -> (np.ndarray, np.ndarray, np.ndarray, float): - """ - Algorithm 8 in https://arxiv.org/pdf/2002.01387 - """ - assert num_vectors >= rank > 0 - Q = _rand_pow_range_finder( - operator, operator_dim, num_vectors, depth, num_vectors, rng, range_sketch - ) - C = operator(Q).T - U_hat, D, V = np.linalg.svd(C, full_matrices=False) - U = Q @ U_hat - - error_factor = np.power( - 1 + 4 * np.sqrt(2 * operator_dim / (rank - 1)), - 1 / (2 * depth + 1) - ) - error_bound = D[-1] * (1 + error_factor) - return U[:, :rank], D[:rank], V[:rank], Q, error_bound - -def _genetic_relatedness_vector_individual( - tree_sequence: tskit.TreeSequence, - arr: np.ndarray, - indices: np.ndarray, - mode: str, - centre: bool = True, - windows = None, - ) -> np.ndarray: - ij = np.vstack( - [ - [n, k] - for k, i in enumerate(indices) - for n in tree_sequence.individual(i).nodes - ] - ) - samples, sample_individuals = ( - ij[:, 0], - ij[:, 1], - ) # sample node index, individual of those nodes - x = ( - arr - arr.mean(axis=0) if centre else arr - ) # centering within index in rows - x = tree_sequence.genetic_relatedness_vector( - W=x[sample_individuals], - windows=windows, - mode=mode, - centre=False, - nodes=samples, - )[0] - - def bincount_fn(w): - return np.bincount(sample_individuals, w) - - x = np.apply_along_axis(bincount_fn, axis=0, arr=x) - x = x - x.mean(axis=0) if centre else x # centering within index in cols - - return x - -def _genetic_relatedness_vector_node( - tree_sequence: tskit.TreeSequence, - arr: np.ndarray, - indices: np.ndarray, - mode: str, - centre: bool = True, - windows = None, - ) -> np.ndarray: - x = arr - arr.mean(axis=0) if centre else arr - x = tree_sequence.genetic_relatedness_vector( - W=x, windows=windows, mode=mode, centre=False, nodes=indices, - )[0] - x = x - x.mean(axis=0) if centre else x - - return x - @dataclass class PCAResult: """