Skip to content

Commit

Permalink
Merge pull request #1 from alonfnt/randomized
Browse files Browse the repository at this point in the history
add Halko randomized algo for large datasets
  • Loading branch information
alonfnt authored May 29, 2023
2 parents f4d6d63 + ab57c9e commit bbcd4a3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
45 changes: 38 additions & 7 deletions pcax/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,29 @@
import jax
import jax.numpy as jnp


class PCAState(NamedTuple):
    """Immutable result of a fitted PCA model, consumed by `transform`/`recover`."""

    # Principal axes: rows are components, shape (n_components, n_features)
    # (taken as Vt[:n_components] of the SVD of the centered data).
    components: jax.Array
    # Per-feature means of the training data, shape (1, n_features);
    # subtracted in `transform` and added back in `recover`.
    means: jax.Array
    # Variance explained by each retained component: S**2 / (n_samples - 1),
    # shape (n_components,).
    explained_variance: jax.Array


def fit(x, n_components, solver="full"):
def transform(state, x):
    """Project `x` onto the principal components held in `state`.

    Centers `x` with the fitted means, then multiplies by the transposed
    component matrix, yielding scores of shape (n_samples, n_components).
    """
    centered = x - state.means
    return jnp.dot(centered, jnp.transpose(state.components))


def recover(state, x):
    """Map component scores `x` back into the original feature space.

    Inverse of `transform` up to the variance discarded by truncation:
    multiplies by the component matrix and re-adds the fitted means.
    """
    reconstructed = jnp.dot(x, state.components)
    return reconstructed + state.means


def fit(x, n_components, solver="full", rng=None):
    """Fit a PCA model to the data matrix `x`.

    Args:
        x: data of shape (n_samples, n_features).
        n_components: number of principal components to keep.
        solver: "full" for an exact SVD, or "randomized" for the Halko
            randomized algorithm (better suited to large datasets).
        rng: PRNG key used only by the randomized solver. When omitted, a
            deterministic key seeded with `n_components` is used so results
            are reproducible without an explicit seed.

    Returns:
        A PCAState holding components, means, and explained variance.

    Raises:
        ValueError: if `solver` is not one of the supported names.
    """
    if solver == "full":
        return _fit_full(x, n_components)
    elif solver == "randomized":
        if rng is None:
            rng = jax.random.PRNGKey(n_components)
        return _fit_randomized(x, n_components, rng)
    else:
        # Name the bad value and the valid options instead of a vague message.
        raise ValueError(
            f"Unknown solver {solver!r}; expected 'full' or 'randomized'."
        )

Expand All @@ -29,17 +43,34 @@ def _fit_full(x, n_components):
U, S, Vt = jax.scipy.linalg.svd(x, full_matrices=False)

# Compute the explained variance
explained_variance = (S**2) / (n_samples - 1)
explained_variance = (S[:n_components] ** 2) / (n_samples - 1)

# Return the transformation matrix
A = Vt[:n_components]
return PCAState(components=A, means=means, explained_variance=explained_variance)


def transform(state, x):
x = x - state.means
return jnp.dot(x, jnp.transpose(state.components))
def _fit_randomized(x, n_components, rng, n_iter=5):
    """Randomized PCA based on Halko et al [https://doi.org/10.48550/arXiv.1007.5510].

    Args:
        x: data of shape (n_samples, n_features).
        n_components: number of principal components to keep.
        rng: PRNG key for drawing the random projection matrix.
        n_iter: number of LU-stabilized power iterations refining the
            range estimate.

    Returns:
        A PCAState holding components, means, and explained variance.
    """
    n_samples, n_features = x.shape
    means = jnp.mean(x, axis=0, keepdims=True)
    x = x - means

    # Oversampled sketch width: 2*n_components, capped at n_features.
    # Use Python's min (not jnp.minimum) so the value is a static int —
    # jax.random.normal needs a concrete shape, and a traced scalar here
    # would break under jit.
    size = min(2 * n_components, n_features)
    Q = jax.random.normal(rng, shape=(n_features, size))

    def step_fn(q, _):
        # Power iteration with LU factorizations for numerical stability
        # (cheaper than QR per step; see Halko et al., Algorithm 4.4).
        q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True)
        q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True)
        return q, None

    Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter)
    # Orthonormal basis for the approximate range of x.
    Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic")
    # Project the data into the low-dimensional subspace and SVD there.
    B = Q.T @ x

    _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False)

    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)
    A = Vt[:n_components]
    return PCAState(components=A, means=means, explained_variance=explained_variance)
16 changes: 9 additions & 7 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ def test_fit_invalid_solver():

@pytest.mark.parametrize("n_components", [1, 2, 5, 10])
@pytest.mark.parametrize("n_entries", [100, 200, 300])
@pytest.mark.parametrize("solver", ["full", "randomized"])
def test_fit_output_shapes(n_entries, n_components, solver):
    """Both solvers must return components/means/variance with the documented shapes."""
    x = jax.random.normal(KEY, shape=(n_entries, 50))
    rng, _ = jax.random.split(KEY)

    state = fit(x, n_components=n_components, solver=solver, rng=rng)

    assert state.components.shape == (n_components, x.shape[1])
    assert state.means.shape == (1, x.shape[1])
    # Explained variance is truncated to the kept components, not all features.
    assert state.explained_variance.shape == (n_components,)


def test_fit_zero_mean():
Expand All @@ -37,11 +38,12 @@ def test_fit_zero_mean():


@pytest.mark.parametrize("n_components", [50])
@pytest.mark.parametrize("n_entries", [300, 500])
@pytest.mark.parametrize("solver", ["full", "randomized"])
def test_reconstruction(n_entries, n_components, solver):
    """With all 50 components kept, transform->recover should round-trip x."""
    x = jax.random.normal(KEY, shape=(n_entries, 50))
    state = fit(x, n_components=n_components, solver=solver)
    x_pca = transform(state, x)
    x_recovered = recover(state, x_pca)
    assert x_recovered.shape == x.shape
    # Loose tolerance: the randomized solver is approximate.
    assert jnp.allclose(x, x_recovered, atol=1e-1)

0 comments on commit bbcd4a3

Please sign in to comment.