Skip to content

Commit

Permalink
Merge pull request #14 from lee-group-cmu/issue/12/black-formatting
Browse files Browse the repository at this point in the history
Applying black code style formatting.
  • Loading branch information
biprateep authored Mar 24, 2023
2 parents 253b5f2 + e26ccdb commit 2a4a083
Show file tree
Hide file tree
Showing 12 changed files with 214 additions and 198 deletions.
25 changes: 14 additions & 11 deletions src/flexcode/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,21 @@ def evaluate_basis(responses, n_basis, basis_system):
If the basis system isn't recognized.
"""
systems = {'cosine' : cosine_basis,
'Fourier' : fourier_basis,
'db4' : wavelet_basis}
systems = {"cosine": cosine_basis, "Fourier": fourier_basis, "db4": wavelet_basis}
try:
basis_fn = systems[basis_system]
except KeyError:
raise ValueError("Basis system {} not recognized".format(basis_system))

n_dim = responses.shape[1]
if n_dim == 1:
if n_dim == 1:
return basis_fn(responses, n_basis)
else:
if len(n_basis) == 1:
n_basis = [n_basis] * n_dim
return tensor_basis(responses, n_basis, basis_fn)


def tensor_basis(responses, n_basis, basis_fn):
"""Evaluates tensor basis.
Expand Down Expand Up @@ -118,6 +117,7 @@ def cosine_basis(responses, n_basis):
basis[:, col] = np.sqrt(2) * np.cos(np.pi * col * responses)
return basis


def fourier_basis(responses, n_basis):
"""Evaluates Fourier basis.
Expand Down Expand Up @@ -149,7 +149,8 @@ def fourier_basis(responses, n_basis):
basis[:, -1] = np.sqrt(2) * np.sin(np.pi * n_basis * responses)
return basis

def wavelet_basis(responses, n_basis, family='db4'):

def wavelet_basis(responses, n_basis, family="db4"):
"""Evaluates Daubechies basis.
Arguments
Expand Down Expand Up @@ -179,6 +180,7 @@ def wavelet_basis(responses, n_basis, family='db4'):
_, wavelet, x_grid = rez
wavelet *= np.sqrt(max(x_grid) - min(x_grid))
x_grid = (x_grid - min(x_grid)) / (max(x_grid) - min(x_grid))

def _wave_fun(val):
if val < 0 or val > 1:
return 0.0
Expand All @@ -191,16 +193,16 @@ def _wave_fun(val):
loc = 0
level = 0
for col in range(1, n_basis):
basis[:, col] = [2 ** (level / 2) * _wave_fun(a * 2 ** level - loc) for a in responses]
basis[:, col] = [2 ** (level / 2) * _wave_fun(a * 2**level - loc) for a in responses]
loc += 1
if loc == 2 ** level:
if loc == 2**level:
loc = 0
level += 1
return basis


class BasisCoefs(object):
def __init__(self, coefs, basis_system, z_min, z_max, bump_threshold=None,
sharpen_alpha=None):
def __init__(self, coefs, basis_system, z_min, z_max, bump_threshold=None, sharpen_alpha=None):
self.coefs = coefs
self.basis_system = basis_system
self.z_min = z_min
Expand All @@ -209,8 +211,9 @@ def __init__(self, coefs, basis_system, z_min, z_max, bump_threshold=None,
self.sharpen_alpha = sharpen_alpha

def evaluate(self, z_grid):
basis = evaluate_basis(box_transform(z_grid, self.z_min, self.z_max),
self.coefs.shape[1], self.basis_system)
basis = evaluate_basis(
box_transform(z_grid, self.z_min, self.z_max), self.coefs.shape[1], self.basis_system
)
cdes = np.matmul(self.coefs, basis.T)

normalize(cdes)
Expand Down
57 changes: 30 additions & 27 deletions src/flexcode/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@


class FlexCodeModel(object):
def __init__(self, model, max_basis, basis_system="cosine",
z_min=None, z_max=None, regression_params={},
custom_model=None):
def __init__(
self,
model,
max_basis,
basis_system="cosine",
z_min=None,
z_max=None,
regression_params={},
custom_model=None,
):
"""Initialize FlexCodeModel object
:param model: A FlexCodeRegression object
Expand Down Expand Up @@ -55,13 +62,13 @@ def fit(self, x_train, z_train, weight=None):
if self.z_max is None:
self.z_max = max(z_train)

z_basis = evaluate_basis(box_transform(z_train, self.z_min, self.z_max),
self.max_basis, self.basis_system)
z_basis = evaluate_basis(
box_transform(z_train, self.z_min, self.z_max), self.max_basis, self.basis_system
)

self.model.fit(x_train, z_basis, weight)

def tune(self, x_validation, z_validation, bump_threshold_grid=None,
sharpen_grid=None, n_grid=1000):
def tune(self, x_validation, z_validation, bump_threshold_grid=None, sharpen_grid=None, n_grid=1000):
"""Set tuning parameters to minimize CDE loss
Sets the best_basis, bump_threshold, and sharpen_alpha attributes
Expand All @@ -85,59 +92,55 @@ def tune(self, x_validation, z_validation, bump_threshold_grid=None,

coefs = self.model.predict(x_validation)

term1 = np.mean(coefs ** 2, 0)
term1 = np.mean(coefs**2, 0)
term2 = np.mean(coefs * z_basis, 0)
# losses = np.cumsum(term1 - 2 * term2)
self.best_basis = np.where(term1 - 2 * term2 < 0.0)[0]

if bump_threshold_grid is not None or sharpen_grid is not None:
coefs = coefs[:, self.best_basis]
z_grid = make_grid(n_grid, self.z_min, self.z_max)
z_basis = evaluate_basis(box_transform(z_grid, self.z_min, self.z_max),
max(self.best_basis) + 1, self.basis_system)
z_basis = evaluate_basis(
box_transform(z_grid, self.z_min, self.z_max), max(self.best_basis) + 1, self.basis_system
)
z_basis = z_basis[:, self.best_basis]
cdes = np.matmul(coefs, z_basis.T)
normalize(cdes)

if bump_threshold_grid is not None:
self.bump_threshold = choose_bump_threshold(cdes, z_grid,
z_validation,
bump_threshold_grid)
self.bump_threshold = choose_bump_threshold(cdes, z_grid, z_validation, bump_threshold_grid)

remove_bumps(cdes, self.bump_threshold)
normalize(cdes)

if sharpen_grid is not None:
self.sharpen_alpha = choose_sharpen(cdes, z_grid, z_validation,
sharpen_grid)


self.sharpen_alpha = choose_sharpen(cdes, z_grid, z_validation, sharpen_grid)

def predict_coefs(self, x_new):
if len(x_new.shape) == 1:
x_new = x_new.reshape(-1, 1)

coefs = self.model.predict(x_new)[:, self.best_basis]
return BasisCoefs(coefs, self.basis_system, self.z_min,
self.z_max, self.bump_threshold, self.sharpen_alpha)
return BasisCoefs(
coefs, self.basis_system, self.z_min, self.z_max, self.bump_threshold, self.sharpen_alpha
)

def predict(self, x_new, n_grid):
"""Predict conditional density estimates on new data
:param x_new: A numpy matrix of covariates at which to predict
:param n_grid: int, the number of grid points at which to
predict the conditional density
:returns: A numpy matrix where each row is a conditional
density estimate at the grid points
:rtype: numpy matrix
:param x_new: A numpy matrix of covariates at which to predict
:param n_grid: int, the number of grid points at which to
predict the conditional density
:returns: A numpy matrix where each row is a conditional
density estimate at the grid points
:rtype: numpy matrix
"""
if len(x_new.shape) == 1:
x_new = x_new.reshape(-1, 1)

z_grid = make_grid(n_grid, 0.0, 1.0)
z_basis = evaluate_basis(z_grid, max(self.best_basis) + 1,
self.basis_system)
z_basis = evaluate_basis(z_grid, max(self.best_basis) + 1, self.basis_system)
z_basis = z_basis[:, self.best_basis]
coefs = self.model.predict(x_new)[:, self.best_basis]
cdes = np.matmul(coefs, z_basis.T)
Expand Down
22 changes: 13 additions & 9 deletions src/flexcode/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def box_transform(z, z_min, z_max):

return (z - z_min) / (z_max - z_min)


def make_grid(n_grid, z_min, z_max):
"""Create grid of equally spaced points
Expand All @@ -24,8 +25,9 @@ def make_grid(n_grid, z_min, z_max):
"""
return np.linspace(z_min, z_max, n_grid).reshape((n_grid, 1))


def params_dict_optim_decision(params, multi_output=False):
'''
"""
Ingests the parameter dictionary and determines whether to do CV optimization.
If one of the parameters has a list of length above 1 as its value,
then the dictionary is automatically formatted for GridSearchCV.
Expand All @@ -36,14 +38,16 @@ def params_dict_optim_decision(params, multi_output=False):
:returns: a dictionary of parameters and a boolean flag of whether CV-opt
is going to be performed. If CV-optimization is set to happen then
the parameter dictionary is correctly formatted.
'''
"""

# Determines whether there are any list in the items of the dictionary
opt_flag = False
for k, value in params.items():
if type(value) == tuple:
raise ValueError("Parameter values need to be lists or np.array, not tuple."
"Current issues with parameter %s" % (k))
raise ValueError(
"Parameter values need to be lists or np.array, not tuple."
"Current issues with parameter %s" % (k)
)
if type(value) == list or type(value) == np.ndarray:
opt_flag = True
break
Expand All @@ -55,14 +59,14 @@ def params_dict_optim_decision(params, multi_output=False):
for k, value in params.items():
out_value = value.tolist() if type(value) == np.ndarray else value
out_value = [out_value] if type(out_value) != list else out_value
out_key = 'estimator__' + k if multi_output else k
out_key = "estimator__" + k if multi_output else k
out_param_dict[out_key] = out_value

return out_param_dict, opt_flag


def params_name_format(params, str_rem):
'''
"""
Changes all the keys in the dictionary to remove a specific word from each key (``estimator__``).
This is because in order to run GridSearchCV on a MultiOutputRegressor one needs to
use ``estimator__`` in all parameters - but once the best parameters are fetched
Expand All @@ -71,9 +75,9 @@ def params_name_format(params, str_rem):
:param params: dictionary of model parameters
:param str_rem: word to be removed
:returns: dictionary of parameters in which the word has been removed in keys
'''
"""
out_dict = {}
for k, v in params.items():
new_key = k.replace(str_rem, '') if str_rem in k else k
new_key = k.replace(str_rem, "") if str_rem in k else k
out_dict[new_key] = v
return out_dict
return out_dict
2 changes: 1 addition & 1 deletion src/flexcode/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def cde_loss(cde_estimates, z_grid, true_z):

n_obs, n_grid = cde_estimates.shape

term1 = np.mean(np.trapz(cde_estimates ** 2, z_grid.flatten()))
term1 = np.mean(np.trapz(cde_estimates**2, z_grid.flatten()))

nns = [np.argmin(np.abs(z_grid - true_z[ii])) for ii in range(n_obs)]
term2 = np.mean(cde_estimates[range(n_obs), nns])
Expand Down
15 changes: 10 additions & 5 deletions src/flexcode/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def normalize(cde_estimates, tol=1e-6, max_iter=200):
if cde_estimates.ndim == 1:
_normalize(cde_estimates, tol, max_iter)
else:
np.apply_along_axis(_normalize, 1, cde_estimates, tol=tol,
max_iter=max_iter)
np.apply_along_axis(_normalize, 1, cde_estimates, tol=tol, max_iter=max_iter)


def _normalize(density, tol=1e-6, max_iter=500):
"""Normalizes a density estimate to be non-negative and integrate to
Expand Down Expand Up @@ -62,6 +62,7 @@ def _normalize(density, tol=1e-6, max_iter=500):
density -= mid
density[density < 0.0] = 0.0


def sharpen(cde_estimates, alpha):
"""Sharpens conditional density estimates.
Expand All @@ -76,6 +77,7 @@ def sharpen(cde_estimates, alpha):
cde_estimates **= alpha
normalize(cde_estimates)


def choose_sharpen(cde_estimates, z_grid, true_z, alpha_grid):
"""Chooses the sharpen parameter by minimizing cde loss.
Expand All @@ -97,6 +99,7 @@ def choose_sharpen(cde_estimates, z_grid, true_z, alpha_grid):
best_alpha = alpha
return best_alpha


def remove_bumps(cde_estimates, delta):
"""Removes bumps in conditional density estimates
Expand All @@ -111,7 +114,8 @@ def remove_bumps(cde_estimates, delta):
if cde_estimates.ndim == 1:
_remove_bumps(cde_estimates, delta)
else:
np.apply_along_axis(_remove_bumps, 1, cde_estimates, delta = delta)
np.apply_along_axis(_remove_bumps, 1, cde_estimates, delta=delta)


def _remove_bumps(density, delta):
"""Removes bumps in conditional density estimates.
Expand All @@ -131,17 +135,18 @@ def _remove_bumps(density, delta):
for right_idx, val in enumerate(density):
if val <= 0.0:
if area < delta:
density[left_idx:(right_idx + 1)] = 0.0
density[left_idx : (right_idx + 1)] = 0.0
removed_area += area
left_idx = right_idx + 1
area = 0.0
else:
area += val * bin_size
if area < delta: # final check at end
if area < delta: # final check at end
density[left_idx:] = 0.0
removed_area += area
_normalize(density)


def choose_bump_threshold(cde_estimates, z_grid, true_z, delta_grid):
"""Chooses the bump threshold which minimizes cde loss.
Expand Down
Loading

0 comments on commit 2a4a083

Please sign in to comment.