
PyTorch backend: LSTM training is more than 10x slower than with TensorFlow (does batch size with the PyTorch backend have different semantics from the traditional Keras semantics?) #20902

Open
@mw66


https://stackoverflow.com/questions/78717341/keras-training-speed-with-pytorch-backend-is-a-lot-slower-than-with-tensorflow

"""
I am on native Windows and previously used old Keras with TensorFlow 2.10 (GPU accelerated). I wanted to try Keras 3 with the PyTorch backend. Can someone please help me understand why this model trains 10x slower with Keras 3.4.1 and the PyTorch 2.3.1 backend? With my GPU a single epoch takes a little more than 2 minutes with TF, and over 20 minutes with PyTorch.

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before keras is imported
import numpy as np  # needed for the np.float32 casts below
import torch
torch.cuda.is_available() # <-- returns True

import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras import optimizers
from keras.regularizers import l2

x_train, y_train = np.float32(x_train), np.float32(y_train)
x_val, y_val = np.float32(x_val), np.float32(y_val)

model=Sequential()
reg=0.00001
model.add(LSTM( 80, return_sequences=True , dropout=0.0, kernel_regularizer=l2(reg), recurrent_regularizer=l2(reg), input_shape=(x_train.shape[1], x_train.shape[2]) ))
model.add(LSTM( 80, return_sequences=False, dropout=0.0, kernel_regularizer=l2(reg), recurrent_regularizer=l2(reg) ))
model.add(Dense(40))
model.add(Dense(40))
model.add(Dense(1))
opt = optimizers.Adam(learning_rate=lrate)
model.compile(optimizer=opt, loss='mean_squared_error')

from keras.callbacks import ModelCheckpoint
from keras.callbacks import BackupAndRestore
savecallback = ModelCheckpoint(basefolder+"/"+modelfile, save_best_only=False, monitor='val_loss', mode='min', verbose=1)
backupcallback = BackupAndRestore(basefolder+"/tmp/backup_"+modelfile)

hist=model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=batchsize, epochs=20, callbacks=[savecallback, backupcallback])

I verified GPU acceleration with both backends.
"""
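One way to test the batch-size hypothesis from the title is to compare the step count that `model.fit` shows in its progress bar against the expected value. For NumPy array inputs, Keras runs ceil(num_samples / batch_size) steps per epoch. The sketch below computes that number and converts the reported epoch times into a per-step cost. `NUM_SAMPLES` and `BATCH_SIZE` are hypothetical placeholders, since the post does not state them, and the 120 s and 1200 s figures are round approximations of "a little more than 2 minutes" and "over 20 minutes".

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Expected number of steps per epoch for array inputs.

    The last batch may be partial, hence the ceiling.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    return math.ceil(num_samples / batch_size)


def seconds_per_step(epoch_seconds: float, num_samples: int, batch_size: int) -> float:
    """Average wall-clock time of one step, given the duration of one epoch."""
    return epoch_seconds / steps_per_epoch(num_samples, batch_size)


# Hypothetical values: the original post does not give the dataset size
# or the value of `batchsize`. Replace them with the real ones.
NUM_SAMPLES = 100_000
BATCH_SIZE = 64

# Approximate epoch durations taken from the post (seconds).
TF_EPOCH_SECONDS = 120.0
TORCH_EPOCH_SECONDS = 1200.0

steps = steps_per_epoch(NUM_SAMPLES, BATCH_SIZE)
tf_step = seconds_per_step(TF_EPOCH_SECONDS, NUM_SAMPLES, BATCH_SIZE)
torch_step = seconds_per_step(TORCH_EPOCH_SECONDS, NUM_SAMPLES, BATCH_SIZE)

print(f"expected steps per epoch: {steps}")
print(f"TensorFlow: {tf_step * 1000:.1f} ms/step")
print(f"PyTorch:    {torch_step * 1000:.1f} ms/step")
print(f"slowdown:   {torch_step / tf_step:.1f}x")
```

If both backends report the same number of steps per epoch in the progress bar, the batch size is being interpreted the same way and the slowdown lies in the per-step cost rather than in the number of steps.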
