forked from commaai/speedchallenge
-
Notifications
You must be signed in to change notification settings - Fork 1
/
keras_predictor.py
108 lines (83 loc) · 5.09 KB
/
keras_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import tensorflow as tf
from tensorflow.keras import layers
import loader
import keras_resnet
import keras_regnet
import time
import depth_and_motion_net
# Global batch size used for both the Input layer and the tfrecord loaders.
BATCH_SIZE = 32
# TF1-style flags API (script predates absl-only flag handling).
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
# Enable mixed precision globally: layers compute in float16, variables stay
# float32. Must run before any layer below is constructed. NOTE(review):
# downstream Lambda layers cast to float32 before atan2 because of this.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
# physical_devices = tf.config.list_physical_devices('GPU')
# print(physical_devices)
# try:
# tf.config.experimental.set_memory_growth(physical_devices[0], True)
# except:
# # Invalid device or cannot modify virtual devices once initialized.
# pass
# --checkpoint_dir: resume training/eval from an existing checkpoint directory.
flags.DEFINE_string("checkpoint_dir", None, "Directory to load model state from to resume training.")
# --experiment_name: names the checkpoint directory for a fresh run.
flags.DEFINE_string("experiment_name", None, "Name of the experiment being run.")
# The commented lines below are previously-used dataset locations kept for
# quick switching between machines/augmentation variants.
# training_dataset = loader.load_tfrecord("/mnt/Datasets/2k19_train_augmented/*.tfrecord", BATCH_SIZE, True)
# validation_dataset = loader.load_tfrecord("/mnt/Datasets/2k19_val_augmented/*.tfrecord", BATCH_SIZE, False)
# training_dataset = loader.load_tfrecord("/mnt/e/commaai/2k19_train_augmented/*.tfrecord", BATCH_SIZE, True)
# training_dataset = loader.load_tfrecord("/mnt/e/commaai/2k19_train_augmented_dedupped/*.tfrecord", BATCH_SIZE, True)
# training_dataset = loader.load_tfrecord("/mnt/e/commaai/2k19_train_more_augmented/*.tfrecord", BATCH_SIZE, True)
# validation_dataset = loader.load_tfrecord("/mnt/e/commaai/2k19_val_augmented/*.tfrecord", BATCH_SIZE, False)
# training_dataset = loader.load_tfrecord("/mnt/Bulk Storage/commaai/monolithic_train.tfrecord", BATCH_SIZE, True)
# test_dataset = loader.load_tfrecord("/mnt/Datasets/monolithic_test.tfrecord", BATCH_SIZE, False)
# test_dataset = loader.load_tfrecord("/mnt/Datasets/actual_test.tfrecord", BATCH_SIZE, False)
# Active dataset: evaluation-only (shuffle flag False). Third arg presumably
# controls shuffling/training mode inside loader — TODO confirm.
test_dataset = loader.load_tfrecord("/mnt/Datasets/calib_0.tfrecord", BATCH_SIZE, False)
# training_dataset = dali_loader.load_tfrecord("/mnt/Datasets/2k19_train/", BATCH_SIZE, True)
# validation_dataset = dali_loader.load_tfrecord("/mnt/Datasets/2k19_val/", BATCH_SIZE, False)
# --- Pose-regression network -------------------------------------------------
# Input: two consecutive RGB frames stacked along the channel axis (3 + 3 = 6).
frame_pair = tf.keras.Input(shape=(256, 640, 6), batch_size=BATCH_SIZE, name='frames')

# Backbone: ResNet-18 encoder, then two extra stride-2 residual stages for
# additional spatial downsampling before pooling.
features = keras_resnet.resnet18_encoder(frame_pair)
features = keras_resnet.res_block_first(features, 512, stride=2)
features = keras_resnet.res_block(features, 512)
features = keras_resnet.res_block_first(features, 512, stride=2)
features = keras_resnet.res_block(features, 512)

# Global average pool over height and width -> one 512-d vector per sample.
pooled = tf.keras.layers.Lambda(lambda t: tf.reduce_mean(t, [1, 2]))(features)

# Two 3-DoF heads (rotation, translation), each scaled by 1e-3 so the network
# starts out predicting near-zero motion.
rot_head = tf.keras.layers.Dense(3, name='rot_fc')(pooled)
trans_head = tf.keras.layers.Dense(3, name='trans_fc')(pooled)
rotation = depth_and_motion_net.Scale(0.001)(rot_head)
translation = depth_and_motion_net.Scale(0.001)(trans_head)

# 6-DoF pose output: translation first, then rotation.
pose = tf.keras.layers.Concatenate(axis=1, name='pose')([translation, rotation])

# Derived outputs. Speed = 20 * ||translation||. Pitch uses atan2 on the first
# two pose components; the float32 casts are required because the global
# mixed-precision policy makes layer outputs float16.
speed = tf.keras.layers.Lambda(
    lambda p: tf.expand_dims(20 * tf.norm(p[:, :3], axis=1), -1),
    name='speed')(pose)
pitch = tf.keras.layers.Lambda(
    lambda p: tf.expand_dims(
        20 * tf.math.atan2(tf.cast(p[:, 1], tf.float32),
                           tf.cast(p[:, 0], tf.float32)), -1),
    name='pitch')(pose)

model = tf.keras.Model(inputs=frame_pair, outputs=[pose, speed, pitch])
# Only the pose MSE drives optimization; speed and pitch are attached with
# zero loss weight so they are merely reported during fit/evaluate.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss={'pose': 'mse', 'speed': 'mse', 'pitch': 'mse'},
              loss_weights={'pose': 1.0, 'speed': 0.0, 'pitch': 0.0})
if FLAGS.checkpoint_dir:
    # Resume from the newest checkpoint in the requested directory.
    # BUG FIX: `latest` was previously computed and then ignored in favor of a
    # hard-coded debug path ("checkpoints/resnet18 large/cp-0019.ckpt"), which
    # silently defeated the --checkpoint_dir flag.
    checkpoint_dir = FLAGS.checkpoint_dir
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        # Fail loudly rather than evaluating randomly-initialized weights.
        raise FileNotFoundError(
            'No checkpoint found in {}'.format(checkpoint_dir))
    print('Restoring weights from {}'.format(latest))
    model.load_weights(latest)
elif FLAGS.experiment_name:
    checkpoint_dir = 'checkpoints/{}'.format(FLAGS.experiment_name)
else:
    # No name given: use a timestamped directory so runs never collide.
    checkpoint_dir = 'checkpoints/{}'.format(time.strftime("%m_%d_%y-%H_%M"))
# Save weights after every epoch as cp-NNNN.ckpt in the checkpoint directory.
checkpoint_path = checkpoint_dir + "/cp-{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
tb_callback = tf.keras.callbacks.TensorBoard(checkpoint_dir, update_freq=1)
buar_callback = tf.keras.callbacks.experimental.BackupAndRestore(checkpoint_dir)
# model.fit(training_dataset, epochs=30, validation_data=validation_dataset, callbacks=[cp_callback, tb_callback, buar_callback])
model.evaluate(test_dataset)
# velocities = model.predict(test_dataset)[1]
# for v in velocities:
# print(v[0])
# model.evaluate(validation_dataset, callbacks=[cp_callback, tb_callback])
# model.fit(training_dataset, epochs=30, callbacks=[cp_callback])