Skip to content

Commit

Permalink
Merge pull request #24 from ncats/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
jorgeso authored Dec 9, 2020
2 parents f330e4f + 0319b3d commit 4810ad0
Showing 1 changed file with 49 additions and 15 deletions.
64 changes: 49 additions & 15 deletions server/predictors/cyp450/cyp450_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tqdm import tqdm
import multiprocessing as mp
from copy import deepcopy
from multiprocessing import Process, Pipe

class CYP450Predictor:
"""
Expand Down Expand Up @@ -129,27 +130,31 @@ def get_predictions(self):

start = time.time()

processes_dict = {}
conns_dict = {}

for model_name in tqdm(cyp450_models_dict.keys()):

model_has_error = False
parent_conn, child_conn = Pipe()

conns_dict[model_name] = parent_conn

# create a masked array to calculate mean probabilities even when values are missing
probs_matrix = np.ma.empty((64, features.shape[0]))
probs_matrix.mask = True
params_dict = {
"model_name": model_name,
"features": features,
"error_threshold_length": len(self.predictions_df.index)
}

for model_number in tqdm(range(0, 64)):
probs = cyp450_models_dict[model_name][f'model_{model_number}'].predict_proba(features)
probs_matrix[model_number, :probs.shape[0]] = probs.T[1]
if model_has_error == False and len(self.predictions_df.index) > len(probs):
model_has_error = True
processes_dict[model_name] = Process(target=self._get_model_predictions, args=(child_conn,))
parent_conn.send(params_dict)

# pool = mp.Pool()
# probs_matrix = np.ma.array([pool.apply(self._predict_rf, args=(deepcopy(cyp450_models_dict[model_name][f'model_{model_number}']), features.copy())) for model_number in range(0, 64)])
# pool.close()
# pool.terminate()
# probs_matrix.mask = True
for model_name in processes_dict:
processes_dict[model_name].start()

mean_probs = probs_matrix.mean(axis=0)
for model_name in processes_dict:
response_dict = conns_dict[model_name].recv()
model_has_error = response_dict["model_has_error"]
mean_probs = response_dict["mean_probs"]

if model_has_error:
self.model_errors.append(self._columns_dict[model_name]['description'])
Expand All @@ -160,6 +165,9 @@ def get_predictions(self):
+pd.Series(mean_probs).round(2).astype(str)
+')'
)
conns_dict[model_name].close()
processes_dict[model_name].join()
processes_dict[model_name].close()

end = time.time()
print(f'{end - start} seconds to CYP450 predict {len(self.predictions_df.index)} molecules')
Expand All @@ -168,6 +176,32 @@ def get_predictions(self):

return self.df.merge(self.predictions_df, on=self._smi_column_name, how='left')

def _get_model_predictions(self, con):
    """
    Run every sub-model of one CYP450 model and report the averaged result over a pipe.

    Intended as the ``target`` of a ``multiprocessing.Process``. All input and
    output travels through ``con``; nothing is returned.

    Protocol:
        receives one dict with the keys
            ``model_name``             -- key into ``cyp450_models_dict``
            ``features``               -- array handed to each sub-model's ``predict_proba``
            ``error_threshold_length`` -- row count a complete prediction must reach
        sends one dict with the keys
            ``mean_probs``      -- mean over the sub-models of column 1 of ``predict_proba``
            ``model_has_error`` -- True if any sub-model returned fewer rows than
                                   ``error_threshold_length``

    :param con: this worker's end of a ``multiprocessing.Pipe``; it is always
        closed before the method exits, including when an exception propagates.
    """
    # Number of sub-models stored under the keys 'model_0' .. 'model_63'.
    n_models = 64

    try:
        params_dict = con.recv()
        features = params_dict['features']
        error_threshold_length = params_dict['error_threshold_length']
        models = cyp450_models_dict[params_dict['model_name']]

        # Start fully masked so the mean ignores every cell that no sub-model
        # fills in (e.g. when a sub-model returns fewer rows than expected).
        probs_matrix = np.ma.empty((n_models, features.shape[0]))
        probs_matrix.mask = True

        model_has_error = False
        for model_number in tqdm(range(n_models)):
            probs = models[f'model_{model_number}'].predict_proba(features)
            # Column 1 of predict_proba -- presumably the positive-class
            # probability; TODO confirm against the trained models.
            probs_matrix[model_number, :probs.shape[0]] = probs.T[1]
            if error_threshold_length > len(probs):
                model_has_error = True

        con.send({
            "mean_probs": probs_matrix.mean(axis=0),
            "model_has_error": model_has_error
        })
    finally:
        # Release this end of the pipe on failure as well as on success.
        con.close()

def get_errors(self):
return {
'has_smi_errors': self.has_smi_errors,
Expand Down

0 comments on commit 4810ad0

Please sign in to comment.