Skip to content

Commit

Permalink
add regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
Eh2406 committed Apr 5, 2018
1 parent a5b3838 commit cd9a255
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
14 changes: 12 additions & 2 deletions urbansim/models/dcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ class MNLDiscreteChoiceModel(DiscreteChoiceModel):
in output.
normalize : bool, optional default False
if True, subtract the mean and divide by the standard deviation before fitting the coefficients
l1 : float, optional default 0.0
strength of the L1 (lasso) regularization penalty applied when fitting the coefficients
l2 : float, optional default 0.0
strength of the L2 (ridge) regularization penalty applied when fitting the coefficients
"""
def __init__(
Expand All @@ -254,7 +258,7 @@ def __init__(
estimation_sample_size=None,
prediction_sample_size=None,
choice_column=None, name=None,
normalize=False):
normalize=False, l1=0.0, l2=0.0):
self._check_prob_choice_mode_compat(probability_mode, choice_mode)
self._check_prob_mode_interaction_compat(
probability_mode, interaction_predict_filters)
Expand All @@ -274,6 +278,8 @@ def __init__(
self.name = name if name is not None else 'MNLDiscreteChoiceModel'
self.sim_pdf = None
self.normalize = normalize
self.l1 = l1
self.l2 = l2

self.log_likelihoods = None
self.fit_parameters = None
Expand Down Expand Up @@ -314,6 +320,8 @@ def from_yaml(cls, yaml_str=None, str_or_buffer=None):
choice_column=cfg.get('choice_column', None),
name=cfg.get('name', None),
normalize=cfg.get('normalize', False),
l1=cfg.get('l1', 0.0),
l2=cfg.get('l2', 0.0),
)

if cfg.get('log_likelihoods', None):
Expand Down Expand Up @@ -425,7 +433,7 @@ def fit(self, choosers, alternatives, current_choice):
'the input columns.')

self.log_likelihoods, self.fit_parameters = mnl.mnl_estimate(
model_design.as_matrix(), chosen, self.sample_size, self.normalize)
model_design.as_matrix(), chosen, self.sample_size, self.normalize, self.l1, self.l2)
self.fit_parameters.index = model_design.columns

logger.debug('finish: fit LCM model {}'.format(self.name))
Expand Down Expand Up @@ -702,6 +710,8 @@ def to_dict(self):
'fit_parameters': (yamlio.frame_to_yaml_safe(self.fit_parameters)
if self.fitted else None),
'normalize': self.normalize,
'l1': self.l1,
'l2': self.l2,
}

def to_yaml(self, str_or_buffer=None):
Expand Down
12 changes: 9 additions & 3 deletions urbansim/urbanchoice/mnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_standard_error(hessian):


def mnl_loglik(beta, data, chosen, numalts, weights=None, lcgrad=False,
stderr=0):
stderr=0, l1=0.0, l2=0.0):
logger.debug('start: calculate MNL log-likelihood')
numvars = beta.size
numobs = data.size() // numvars // numalts
Expand Down Expand Up @@ -114,6 +114,12 @@ def mnl_loglik(beta, data, chosen, numalts, weights=None, lcgrad=False,
loglik = loglik.get_mat()[0, 0]
gradarr = np.reshape(gradarr.get_mat(), (1, gradarr.size()))[0]

loglik -= l1 * np.abs(beta.get_mat()).sum()
gradarr -= l1 * np.sign(beta.get_mat())

loglik -= l2 * np.square(beta.get_mat()).sum()
gradarr -= l1 * beta.get_mat()

logger.debug('finish: calculate MNL log-likelihood')
return -1 * loglik, -1 * gradarr

Expand Down Expand Up @@ -178,7 +184,7 @@ def mnl_simulate(data, coeff, numalts, normalization_mean=0.0, normalization_std


def mnl_estimate(data, chosen, numalts, GPU=False, coeffrange=(-3, 3),
weights=None, lcgrad=False, beta=None, normalize=False):
weights=None, lcgrad=False, beta=None, normalize=False, l1=0.0, l2=0.0):
"""
Calculate coefficients of the MNL model.
Expand Down Expand Up @@ -249,7 +255,7 @@ def mnl_estimate(data, chosen, numalts, GPU=False, coeffrange=(-3, 3),
bounds = [coeffrange] * numvars

with log_start_finish('scipy optimization for MNL fit', logger):
args = (data, chosen, numalts, weights, lcgrad)
args = (data, chosen, numalts, weights, lcgrad, l1, l2)
bfgs_result = scipy.optimize.fmin_l_bfgs_b(mnl_loglik,
beta,
args=args,
Expand Down

0 comments on commit cd9a255

Please sign in to comment.