-
Notifications
You must be signed in to change notification settings - Fork 1
/
baseline_convolutional_model.py
54 lines (44 loc) · 1.51 KB
/
baseline_convolutional_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Default simple convolutional model."""
import tensorflow as tf
def simple_convolutional_model(
    time_window,
    filters=1,
    kernel_size=16,
    channels=64,
):
    """Construct the simple convolutional model.

    The EEG input is projected with a single 1-D convolution. The projection
    is compared, via cosine similarity along the time axis, with each of the
    two single-channel inputs ``env1`` and ``env2``. A dense sigmoid layer
    maps the concatenated similarities to a single output.

    Parameters
    ----------
    time_window: int
        Time window of input data in samples
    filters: int
        Number of filters for the convolutional layer
    kernel_size: int
        Kernel size for the convolutional layer
    channels: int
        Number of channels in the EEG

    Returns
    -------
    tf.keras.Model
        Simple convolutional model taking inputs ``[eeg, env1, env2]`` and
        producing one sigmoid output. It is already compiled with the Adam
        optimizer, binary cross-entropy loss and an accuracy metric.
    """
    # Model inputs; adapt these here if a different input layout is needed.
    eeg = tf.keras.layers.Input([time_window, channels])
    env1 = tf.keras.layers.Input([time_window, 1])
    env2 = tf.keras.layers.Input([time_window, 1])

    # Conv1D uses "valid" padding by default, so its output is
    # kernel_size - 1 samples shorter than its input along the time axis.
    eeg_proj = tf.keras.layers.Conv1D(filters, kernel_size=kernel_size)(eeg)

    # Drop the same number of trailing samples from each envelope so the time
    # axes line up with eeg_proj. Cropping1D handles kernel_size == 1 (crop 0)
    # correctly; the previous slice t[:, :-(kernel_size-1), :] turned into
    # t[:, :0, :] in that case and produced an empty tensor.
    cut_layer = tf.keras.layers.Cropping1D(cropping=(0, kernel_size - 1))

    # normalize=True makes Dot return the cosine similarity along axis 1
    # (time) between each convolution filter output and the envelope.
    dot_layer = tf.keras.layers.Dot(1, normalize=True)
    cos1 = dot_layer([eeg_proj, cut_layer(env1)])
    cos2 = dot_layer([eeg_proj, cut_layer(env2)])

    all_cos = tf.keras.layers.Concatenate()([cos1, cos2])
    flat = tf.keras.layers.Flatten()(all_cos)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(flat)

    model = tf.keras.Model(inputs=[eeg, env1, env2], outputs=[out])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["acc"],
        loss=["binary_crossentropy"],
    )
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit an extra "None" line.
    model.summary()
    return model