-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_cnn.py
47 lines (37 loc) · 1.22 KB
/
run_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
""" Script to train the BrainAge model directly. It won't perform
any of the federated learning (Vantage6) calls or tasks (e.g.,
running the models at different organizations, merging the results)
"""
import os
import random
import tensorflow as tf
from federated_brain_age.brain_age import BrainAge
from federated_brain_age.data_loader import DataLoader
data_path = "/mnt"
parameters = {
    "USE_MASK": True,
    "EPOCHS": 2,
    "PATIENTS_PER_EPOCH": 2,
    "ROUNDS": 1,
}
seed = 1


def last_epoch_metrics(history) -> dict:
    """Return the final recorded value of every metric in a training history.

    Args:
        history: Object returned by ``BrainAge.train()``. Only its ``history``
            attribute is read: a mapping of metric name -> sequence of
            per-epoch values (presumably a Keras ``History`` — confirm).

    Returns:
        Dict mapping each metric name to its last value. A metric with no
        recorded values maps to ``None`` instead of raising ``IndexError``.
    """
    return {
        metric: values[-1] if values else None
        for metric, values in history.history.items()
    }


def main() -> dict:
    """Train BrainAge once on the local data and print the final metrics.

    Returns:
        The dict of last-epoch metrics, which is also printed.
    """
    # Keep the trailing separator the original path had ("/mnt/data/"),
    # in case BrainAge concatenates file names directly onto it.
    images_dir = os.path.join(data_path, "data", "")
    dataset_csv = os.path.join(data_path, "dataset.csv")
    # NOTE(review): remaining positional args are passed through unchanged:
    # "test" (presumably a model/run name), "CSV" (presumably the format of
    # dataset_csv) and 0.7 (presumably a train/validation split fraction).
    # TODO confirm against BrainAge.__init__.
    model = BrainAge(parameters, "test", images_dir, "CSV", dataset_csv, seed, 0.7)
    result = model.train()

    # Optional prediction smoke check (disabled):
    # img_size = model.initialize()
    # batch_size = len(model.train_loader.participants)
    # img_scale = model.get_parameter("IMG_SCALE")
    # predictions = model.model.predict(
    #     model.train_loader.data_generator(
    #         img_size, 2, img_scale, mask=model.mask, augment=False, mode=[], shuffle=False, crop=model.crop
    #     ),
    #     max_queue_size = 1,
    #     batch_size=2,
    #     steps=1
    # )
    # print("Predictions")
    # print(predictions)

    # Previously observed history keys for this model:
    # dict_keys(['loss', 'mae', 'mse', 'val_loss', 'val_mae', 'val_mse'])
    output = last_epoch_metrics(result)
    print(output)
    return output


if __name__ == "__main__":
    main()