Skip to content

Commit

Permalink
sklearn tuning 5 trial
Browse files Browse the repository at this point in the history
  • Loading branch information
nicdemon committed Sep 15, 2023
1 parent 3988b07 commit 890cc2b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
1 change: 1 addition & 0 deletions data/subset_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ def verify_extract_ids(file):
# Merge and save the classes subset file
################################################################################
cls_out = pd.merge(cls,ids, on = 'id', how = 'inner')
cls_out = cls_out.drop_duplicates(subset = 'id', keep = 'first')

cls_out.to_csv(opt['output'], index = False)
34 changes: 29 additions & 5 deletions src/supplement/sklearn_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from ray import tune
from ray.tune import Tuner, TuneConfig
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.skopt import SkOptSearch
from ray.air.config import RunConfig, ScalingConfig


warnings.simplefilter(action='ignore')

# Functions
Expand Down Expand Up @@ -302,15 +304,37 @@ def split_val_test_ds(ds, data):
),
)

# Ray + SkOpt
# print('tuner')
# tuner = Tuner(
# trainer,
# param_space = tune_params,
# tune_config = TuneConfig(
# num_samples = 5,
# max_concurrent_trials = int((os.cpu_count() * 0.8)),
# search_alg = SkOptSearch(
# metric = 'test/test_score', # mean accuracy according to scikit-learn's doc
# mode = 'max'
# ),
# scheduler = ASHAScheduler(
# metric = 'test/test_score', # mean accuracy according to scikit-learn's doc
# mode = 'max'
# )

# )
# )

# Basic Ray tuner using GridSearch algo
print('tuner')
tuner = Tuner(
trainer,
param_space=tune_params,
tune_config=TuneConfig(
max_concurrent_trials=int((os.cpu_count() * 0.8)),
scheduler=ASHAScheduler(
param_space = tune_params,
tune_config = TuneConfig(
num_samples = 5,
max_concurrent_trials = int((os.cpu_count() * 0.8)),
scheduler = ASHAScheduler(
metric = 'test/test_score', # mean accuracy according to scikit-learn's doc
mode='max'
mode = 'max'
)
)
)
Expand Down

0 comments on commit 890cc2b

Please sign in to comment.