Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] cleaned tests #20

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 3 additions & 21 deletions clar/data/artificial.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def get_S_star(
if noise_type == "Gaussian_iid":
S_star = np.eye(n_channels)
elif noise_type == "Gaussian_multivariate":
vect = rho_noise ** np.arange(n_channels)
S_star = toeplitz(vect, vect)
S_star = toeplitz(rho_noise ** np.arange(n_channels))
else:
raise ValueError("Unknown noise type %s" % noise_type)
return S_star
Expand Down Expand Up @@ -146,7 +145,7 @@ def get_dictionary(
elif dictionary_type == 'Gaussian':
X = rng.randn(n_channels, n_sources)
else:
raise NotImplementedError("No dictionary '{}' in maxsparse"
raise NotImplementedError("No dictionary '{}' in clar"
.format(dictionary_type))
normalize(X)
return X
Expand Down Expand Up @@ -181,23 +180,6 @@ def get_toeplitz_dictionary(
The dictionary.
"""
rng = check_random_state(seed)
vect = rho ** np.arange(n_sources)
covar = toeplitz(vect, vect)
covar = toeplitz(rho ** np.arange(n_sources))
X = rng.multivariate_normal(np.zeros(n_sources), covar, n_channels)
return X


def decimate(M, n_channels, axis, seed):
if n_channels in (M.shape[0], -1):
return M

n_channels_max = M.shape[0]
rng = check_random_state(seed)

to_choose = rng.choice(np.arange(n_channels_max), n_channels)
to_choose.sort()
if axis.__contains__(1):
M = M[:, to_choose]
if axis.__contains__(0):
M = M[to_choose, :]
return M
13 changes: 0 additions & 13 deletions clar/duality_gap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
from numpy.linalg import norm, slogdet
from numba import njit

from clar.utils import l_2_inf
from clar.utils import l_2_1
Expand All @@ -24,31 +23,27 @@ def get_p_obj_mrce(
return p_obj


@njit
def get_p_obj_mtl(R, B, alpha):
n_sensors, n_times = R.shape
p_obj = (R ** 2).sum() / (2 * n_times * n_sensors) \
+ alpha * l_2_1(B)
return p_obj


@njit
def get_d_obj_mtl(Y, Theta, alpha):
n_sensors, n_times = Y.shape
d_obj = alpha * (Theta * Y).sum() - \
alpha ** 2 * n_times * n_sensors * (Theta ** 2).sum() / 2
return d_obj


@njit
def get_feasible_theta_mtl(R, X, alpha):
n_sensors, n_times = R.shape
scaling_factor = l_2_inf(X.T @ R)
scaling_factor = max(scaling_factor, alpha * n_sensors * n_times)
return R / scaling_factor


@njit
def get_p_obj_me(R_all_epochs, B, S_inv_R, S_trace, alpha):
n_epochs, n_channels, n_times = R_all_epochs.shape
p_obj = (R_all_epochs * S_inv_R).sum()
Expand All @@ -58,7 +53,6 @@ def get_p_obj_me(R_all_epochs, B, S_inv_R, S_trace, alpha):
return p_obj


@njit
def get_d_obj_me(all_epochs, Theta, sigma_min, alpha):
n_epochs, n_channels, n_times = all_epochs.shape
d_obj = alpha * (all_epochs * Theta).sum() / n_epochs
Expand All @@ -69,7 +63,6 @@ def get_d_obj_me(all_epochs, Theta, sigma_min, alpha):
return d_obj


@njit
def get_d_obj(Y, Theta, sigma_min, alpha):
n_channels, n_times = Y.shape
d_obj = alpha * (Y * Theta).sum()
Expand All @@ -78,7 +71,6 @@ def get_d_obj(Y, Theta, sigma_min, alpha):
return d_obj


@njit
def get_p_obj(R, B, S_trace, alpha, S_inv_R):
n_channels, n_times = R.shape
p_obj = (R * S_inv_R).sum() / (2. * n_channels * n_times)
Expand All @@ -87,7 +79,6 @@ def get_p_obj(R, B, S_trace, alpha, S_inv_R):
return p_obj


@njit
def get_feasible_theta(X, alpha, S_inv_R):
n_channels, n_times = S_inv_R.shape
scaling_factor = max(
Expand All @@ -97,7 +88,6 @@ def get_feasible_theta(X, alpha, S_inv_R):
return S_inv_R / scaling_factor


@njit
def get_feasible_theta_me(X, alpha, S_inv_R):
n_epochs, n_channels, n_times = S_inv_R.shape
S_inv_R_mean = np.zeros((n_channels, n_times), dtype=np.float64)
Expand All @@ -118,7 +108,6 @@ def get_feasible_theta_me(X, alpha, S_inv_R):
return S_inv_R / scaling_factor


@njit
def get_duality_gap_mtl(X, Y, B, alpha):
R = Y - X @ B
p_obj = get_p_obj_mtl(R, B, alpha)
Expand All @@ -127,7 +116,6 @@ def get_duality_gap_mtl(X, Y, B, alpha):
return p_obj, d_obj


@njit
def get_duality_gap(
R, X, Y, B, S_trace, S_inv_R, sigma_min, alpha):
p_obj = get_p_obj(R, B, S_trace, alpha, S_inv_R)
Expand All @@ -136,7 +124,6 @@ def get_duality_gap(
return p_obj, d_obj


@njit
def get_duality_gap_me(
X, all_epochs, B, S_trace, S_inv,
sigma_min, alpha):
Expand Down
43 changes: 10 additions & 33 deletions clar/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,8 @@ def get_path(
# save the results
mask = np.abs(B_hat).sum(axis=1) != 0
str_pourcentage_alpha = '%0.10f' % pourcentage_alpha
if pb_name == "MTLME":
n_sources = X.shape[1]
n_epochs, _, n_times = measurement.shape
B_reshaped = B_hat.reshape((n_sources, n_epochs, n_times))
B_reshaped = B_reshaped.mean(axis=1)
dict_masks[str_pourcentage_alpha] = mask
dict_dense_Bs[str_pourcentage_alpha] = B_reshaped[mask, :]
else:
dict_masks[str_pourcentage_alpha] = mask
dict_dense_Bs[str_pourcentage_alpha] = B_hat[mask, :]
dict_masks[str_pourcentage_alpha] = mask
dict_dense_Bs[str_pourcentage_alpha] = B_hat[mask, :]
assert len(dict_dense_Bs.keys()) == len(list_pourcentage_alpha)
return dict_masks, dict_dense_Bs

Expand Down Expand Up @@ -89,7 +81,7 @@ def solver(
S is updated every S times.
pb_name: str
choose the problem you want to solve between
"MTL", "MTLME", "SGCL", "CLAR" and "mrce"
"MTL", "SGCL", "CLAR" and "mrce"
use_accel: bool
States if you want to use accelratio while computing the dual.
heur_stop: bool
Expand All @@ -110,11 +102,7 @@ def solver(
print("--------- %s -----------------" % pb_name)

if B0 is None:
if pb_name != "MTLME":
B = np.zeros((n_sources, n_times), dtype=float)
else:
n_epochs, _, n_times = all_epochs.shape
B = np.zeros((n_sources, n_times * n_epochs), dtype=float)
B = np.zeros((n_sources, n_times), dtype=float)
else:
B = B0.copy().astype(np.float64)

Expand All @@ -125,15 +113,6 @@ def solver(
observations = all_epochs[None, :, :]
elif pb_name in ("CLAR", "mrce"):
observations = all_epochs
elif pb_name == "MTLME":
if all_epochs.ndim != 3:
raise ValueError(
"Wrong number of dimensions, expected 2, "
"got %d " % all_epochs.ndim)
observations = all_epochs.transpose((1, 0, 2))
observations = observations.reshape(observations.shape[0], -1)
observations = observations.reshape((1, *observations.shape))
n_epochs, _, n_times = all_epochs.shape
else:
raise ValueError("Unknown solver %s" % pb_name)

Expand Down Expand Up @@ -166,11 +145,9 @@ def solver_(
Y += all_epochs[l, :, :]
Y2 /= n_epochs
Y /= n_epochs
elif pb_name in ("MTL", "SGCL", "MTLME"):
elif pb_name in ("MTL", "SGCL"):
Y = all_epochs[0]
Y2 = None
elif pb_name == "MTLME":
Y = all_epochs

if use_accel:
K = 6
Expand All @@ -192,7 +169,7 @@ def solver_(
primal_first, _ = get_duality_gap(
Y, X, Y, B_first, S_trace_first,
S_inv_R, sigma_min, alpha)
elif pb_name in("MTL", "MTLME"):
elif pb_name in "MTL":
primal_first, _ = get_duality_gap_mtl(
X, Y, B_first, alpha)
elif pb_name == "mrce":
Expand Down Expand Up @@ -231,7 +208,7 @@ def solver_(
S_trace, S_inv = clp_sqrt(ZZT, sigma_min)
S_inv_R = np.asfortranarray(S_inv @ R)
S_inv_X = S_inv @ X
elif pb_name in ("MTL", "MTLME"):
elif pb_name in "MTL":
# this else case is for MTL
# dummy variables for njit to work:
S_trace = n_sensors
Expand Down Expand Up @@ -289,7 +266,7 @@ def solver_(
if verbose:
print("gap_acc: %.2e" % (p_obj - d_obj_acc))
gaps_acc.append(p_obj - d_obj_acc)
elif pb_name in ("MTL", "MTLME"):
elif pb_name in "MTL":
p_obj, d_obj = get_duality_gap_mtl(
X, Y, B, alpha)
elif pb_name == "mrce":
Expand Down Expand Up @@ -333,7 +310,7 @@ def update_S(Y, X, B, Y2, sigma_min, pb_name):
Z = Y - X @ B
ZZT = Z @ Z.T / n_times
S_trace, S_inv = clp_sqrt(ZZT, sigma_min)
elif pb_name in ("MTL", "MTLME"):
elif pb_name in "MTL":
# this else case is for MTL
# dummy variables for njit to work:
S_trace = n_sensors
Expand All @@ -349,7 +326,7 @@ def update_B(
n_sensors, n_times = Y.shape
n_sources = X.shape[1]

is_not_MTL = pb_name not in ("MTL", "MTLME")
is_not_MTL = pb_name not in "MTL"

active_set = np.ones(n_sources)

Expand Down
44 changes: 0 additions & 44 deletions clar/tests/test_clar.py

This file was deleted.

58 changes: 0 additions & 58 deletions clar/tests/test_mrce.py

This file was deleted.

Loading