diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index aa1e200..d2b20cc 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -282,7 +282,6 @@ class Conv1D(ConvND): groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``jax.random.PRNGKey(0)``. dtype: dtype of the weights. defaults to ``jax.numpy.float32`` - dtype: dtype to use for the weights and bias. defaults to ``jnp.float32``. Example: >>> import jax.numpy as jnp