Issue with TimeDistributed + LSTM layer #18941

Closed
innat opened this issue Dec 14, 2023 · 2 comments · Fixed by #18968
Labels: keras-team-review-pending (Pending review by a Keras team member.) · To investigate (Looks like a bug. It needs someone to investigate.)

innat commented Dec 14, 2023

From here.

from numpy import array
import keras
from keras.models import Sequential
from keras.layers import Dense, LSTM, TimeDistributed

# prepare sequence
length = 5
seq = array([i/float(length) for i in range(length)])
X = seq.reshape(1, length, 1)
y = seq.reshape(1, length, 1)

# define LSTM configuration
n_neurons = length
n_batch = 1
n_epoch = 10

# create LSTM
model = Sequential()
model.add(keras.layers.InputLayer((length, 1)))
model.add(LSTM(n_neurons, return_sequences=True))
model.add(TimeDistributed(Dense(1)))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X, y, epochs=n_epoch, batch_size=n_batch, verbose=2)
Epoch 1/10
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 1
----> 1 model.fit(X, y, epochs=n_epoch, batch_size=n_batch, verbose=2)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/keras/src/backend/common/variables.py:394, in standardize_dtype(dtype)
    391     dtype = str(dtype).split(".")[-1]
    393 if dtype not in ALLOWED_DTYPES:
--> 394     raise ValueError(f"Invalid dtype: {dtype}")
    395 return dtype

ValueError: Exception encountered when calling TimeDistributed.call().

Invalid dtype: <class 'NoneType'>

Arguments received by TimeDistributed.call():
  • inputs=tf.Tensor(shape=(1, None, 5), dtype=float32)
  • training=True
  • mask=None

However, compiling the model in eager mode runs properly.
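
For reference, a minimal sketch of that eager-mode workaround, using the standard run_eagerly flag of compile() (same model and data as in the reproduction above):

# Workaround: execute fit() step by step instead of tracing a graph,
# which avoids the dtype error shown in the traceback.
model.compile(
    loss="mean_squared_error",
    optimizer="adam",
    run_eagerly=True,
)
model.fit(X, y, epochs=n_epoch, batch_size=n_batch, verbose=2)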

SuryanarayanaY (Contributor) commented:

Hi @innat,

I have replicated the reported error and attached the gist here.

The error occurs when the dtype is not in ALLOWED_DTYPES in Keras 3. The inputs X and y have dtype "float64" after applying standardize_dtype(), which is fine. The issue seems to be with the TimeDistributed layer's output. Needs investigation.

Thanks!
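
For context, a short sketch of what standardize_dtype() accepts and rejects, importing it from the private path shown in the traceback (this path may change between versions):

import numpy as np
from keras.src.backend.common.variables import standardize_dtype

print(standardize_dtype("float64"))          # "float64" -- in ALLOWED_DTYPES, OK
print(standardize_dtype(np.zeros(1).dtype))  # "float64" -- numpy dtypes resolve via .name
standardize_dtype(type(None))                # ValueError: Invalid dtype: <class 'NoneType'>

The last call reproduces the message from the traceback, suggesting that a NoneType dtype, rather than a valid dtype string, reaches standardize_dtype() somewhere in the TimeDistributed output path.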

SuryanarayanaY added the To investigate and keras-team-review-pending labels Dec 15, 2023
hertschuh added a commit to hertschuh/keras that referenced this issue Dec 19, 2023: "This allows the static shape to propagate in RNNs in the case when the number of time steps is fixed and known at build time."

fchollet pushed a commit that referenced this issue Dec 19, 2023 (same message).
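
To see what "static shape propagation" means here (a hypothetical check, not from the thread): in the traceback above, TimeDistributed received inputs of shape (1, None, 5), i.e. the known number of time steps had been dropped. With the fix, the fixed time dimension should survive the LSTM:

import keras
from keras.layers import Input, LSTM

# Symbolic input with a fixed, known number of time steps (5).
x = Input(shape=(5, 1), batch_size=1)
y = LSTM(5, return_sequences=True)(x)

# With the fix this prints (1, 5, 5); before it, the time dimension
# was lost, giving (1, None, 5) as in the traceback.
print(y.shape)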
hertschuh (Collaborator) commented Dec 19, 2023

@innat,

The fix will be available in tomorrow's nightly build, i.e. you will need to pip install keras-nightly.
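
A quick way to confirm you are on a nightly build after upgrading (the comment about the version suffix reflects how keras-nightly is versioned):

# After: pip install --upgrade keras-nightly
import keras
print(keras.__version__)  # nightly builds carry a .dev version suffix

Re-running the reproduction above should then train without the Invalid dtype error.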
