Skip to content

Commit

Permalink
try something
Browse files Browse the repository at this point in the history
  • Loading branch information
philipperemy committed Aug 13, 2024
1 parent f8f5750 commit de9c2d6
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions tcn/tcn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import List # noqa
from typing import List # noqa

import tensorflow as tf
# pylint: disable=E0611,E0401
Expand Down Expand Up @@ -270,6 +270,12 @@ def __init__(self,
def receptive_field(self):
return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)

def tolist(self, shape):
try:
return shape.as_list()
except AttributeError:
return shape

def build(self, input_shape):

# member to hold current output shape of the layer for building purposes
Expand Down Expand Up @@ -305,20 +311,17 @@ def build(self, input_shape):

self.output_slice_index = None
if self.padding == 'same':
time = self.build_output_shape.as_list()[1]
time = self.tolist(self.build_output_shape)[1]
if time is not None: # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
self.output_slice_index = int(self.build_output_shape.as_list()[1] / 2)
self.output_slice_index = int(self.tolist(self.build_output_shape)[1] / 2)
else:
# It will known at call time. c.f. self.call.
self.padding_same_and_time_dim_unknown = True

else:
self.output_slice_index = -1 # causal case.
self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
try:
self.slicer_layer.build(self.build_output_shape.as_list())
except AttributeError:
self.slicer_layer.build(self.build_output_shape)
self.slicer_layer.build(self.tolist(self.build_output_shape))

def compute_output_shape(self, input_shape):
"""
Expand Down

0 comments on commit de9c2d6

Please sign in to comment.