-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from ECRL/ecabc-update
Update to ECabc-based hyper-parameter tuning functions
- Loading branch information
Showing
18 changed files
with
79 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from ecnet.server import Server | ||
__version__ = '3.3.1' | ||
__version__ = '3.3.2' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/models/mlp.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "MultilayerPerceptron" (feed-forward neural network) class | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/server.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "Server" class, which handles ECNet project creation, neural | ||
|
@@ -151,7 +151,7 @@ def limit_inputs(self, limit_num: int, num_estimators: int = None, | |
def tune_hyperparameters(self, num_employers: int, num_iterations: int, | ||
shuffle: bool = None, split: list = None, | ||
validate: bool = True, eval_set: str = None, | ||
eval_fn: str = 'rmse', epochs: int = 300): | ||
eval_fn: str = 'rmse', epochs: int = 500): | ||
'''Tunes neural network learning hyperparameters using an artificial | ||
bee colony algorithm; tuned hyperparameters are saved to Server's | ||
model configuration file | ||
|
@@ -167,7 +167,7 @@ def tune_hyperparameters(self, num_employers: int, num_iterations: int, | |
`train`, `test`, None (all sets) | ||
eval_fn (str): error function used to evaluate bee fitness; | ||
`rmse`, `mean_abs_error`, `med_abs_error` | ||
epochs (int): number of training epochs per bee ANN (def: 300) | ||
epochs (int): number of training epochs per bee ANN (def: 500) | ||
''' | ||
|
||
self._vars = tune_hyperparameters( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/limit_inputs.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for selecting influential input parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/training.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains function for project training (multiprocessed training) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tasks/tuning.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/fitness functions for tuning hyperparameters | ||
|
@@ -13,7 +13,7 @@ | |
from os import name | ||
|
||
# 3rd party imports | ||
from ecabc.abc import ABC | ||
from ecabc import ABC | ||
|
||
# ECNet imports | ||
from ecnet.utils.data_utils import DataFrame | ||
|
@@ -25,7 +25,7 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, | |
num_iterations: int, num_processes: int = 1, | ||
shuffle: str = None, split: list = None, | ||
validate: bool = True, eval_set: str = None, | ||
eval_fn: str = 'rmse', epochs: int = 300) -> dict: | ||
eval_fn: str = 'rmse', epochs: int = 500) -> dict: | ||
'''Tunes neural network learning/architecture hyperparameters | ||
Args: | ||
|
@@ -41,7 +41,7 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, | |
`train`, `test`, None (all sets) | ||
eval_fn (str): error function used to evaluate bee performance; `rmse`, | ||
`mean_abs_error`, `med_abs_error` | ||
epochs (int): number of training epochs per bee ANN (def: 300) | ||
epochs (int): number of training epochs per bee ANN (def: 500) | ||
Returns: | ||
dict: tuned hyperparameters | ||
|
@@ -72,57 +72,48 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, | |
'epochs': epochs | ||
} | ||
|
||
value_ranges = [ | ||
('float', (1e-9, 1e-4)), # Learning rate decay | ||
('float', (1e-5, 0.1)), # Learning rate | ||
('int', (1, len(df.learn_set))), # Batch size | ||
('int', (64, 1024)) # Patience | ||
to_tune = [ | ||
(1e-9, 1e-4, 'decay'), | ||
(1e-5, 0.1, 'learning_rate'), | ||
(1, len(df.learn_set), 'batch_size'), | ||
(64, 1024, 'patience') | ||
] | ||
for hl in range(len(vars['hidden_layers'])): | ||
to_tune.append((1, 2 * len(df._input_names), 'hl{}'.format(hl))) | ||
|
||
for _ in range(len(vars['hidden_layers'])): | ||
value_ranges.append(('int', (1, len(df._input_names)))) | ||
|
||
abc = ABC( | ||
tune_fitness_function, | ||
num_employers=num_employers, | ||
value_ranges=value_ranges, | ||
args=fit_fn_args, | ||
processes=num_processes | ||
) | ||
|
||
abc._logger.stream_level = logger.stream_level | ||
if logger.file_level != 'disable': | ||
abc._logger.log_dir = logger.log_dir | ||
abc._logger.file_level = logger.file_level | ||
abc._logger.default_call_loc('TUNE') | ||
abc.create_employers() | ||
abc = ABC(num_employers, tune_fitness_function, fit_fn_args, num_processes) | ||
for param in to_tune: | ||
abc.add_param(param[0], param[1], name=param[2]) | ||
abc.initialize() | ||
|
||
best_ret_val = abc.best_ret_val | ||
best_params = abc.best_params | ||
for i in range(num_iterations): | ||
logger.log('info', 'Iteration {}'.format(i + 1), call_loc='TUNE') | ||
abc.run_iteration() | ||
abc.search() | ||
new_best_ret = abc.best_ret_val | ||
new_best_params = abc.best_params | ||
logger.log('info', 'Best Performer: {}, {}'.format( | ||
abc.best_performer[2], { | ||
'decay': abc.best_performer[1][0], | ||
'learning_rate': abc.best_performer[1][1], | ||
'batch_size': abc.best_performer[1][2], | ||
'patience': abc.best_performer[1][3], | ||
'hidden_layers': abc.best_performer[1][4:] | ||
} | ||
new_best_ret, new_best_ret | ||
), call_loc='TUNE') | ||
params = abc.best_performer[1] | ||
vars['decay'] = params[0] | ||
vars['learning_rate'] = params[1] | ||
vars['batch_size'] = params[2] | ||
vars['patience'] = params[3] | ||
if new_best_ret < best_ret_val: | ||
best_ret_val = new_best_ret | ||
best_params = new_best_params | ||
|
||
vars['decay'] = best_params['decay'] | ||
vars['learning_rate'] = best_params['learning_rate'] | ||
vars['batch_size'] = best_params['batch_size'] | ||
vars['patience'] = best_params['patience'] | ||
for l_idx in range(len(vars['hidden_layers'])): | ||
vars['hidden_layers'][l_idx][0] = params[4 + l_idx] | ||
vars['hidden_layers'][l_idx][0] = best_params['hl{}'.format(l_idx)] | ||
return vars | ||
|
||
|
||
def tune_fitness_function(params: dict, **kwargs): | ||
def tune_fitness_function(params: list, **kwargs) -> float: | ||
'''Fitness function used by ABC | ||
Args: | ||
params (dict): bee hyperparams | ||
params (list): bee hyperparams | ||
kwargs (dict): additional arguments | ||
Returns: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/database.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for creating ECNet-formatted databases | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/plotting.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for creating various plots | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/tools/project.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for predicting data using pre-existing .prj files | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/data_utils.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions/classes for loading data, saving data, saving results | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/error_utils.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for error calculations | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/logging.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains logger used by ECNet | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/utils/server_utils.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions used by ecnet.Server | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/workflows/ecrl_workflow.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# General workflow used by the UMass Lowell Energy and Combustion Research | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/workflows/workflow_utils.py | ||
# v.3.3.1 | ||
# v.3.3.2 | ||
# Developed in 2020 by Travis Kessler <[email protected]> | ||
# | ||
# Functions used by the ECRL workflow | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters