Commit

Update linear.py
ASEM000 committed Sep 3, 2023
1 parent 4663bb5 commit 12b3a96
Showing 1 changed file with 22 additions and 23 deletions.
serket/nn/linear.py (45 changes: 22 additions & 23 deletions)
@@ -109,9 +109,9 @@ class Multilinear(sk.TreeClass):
     Args:
         in_features: number of input features for each input
         out_features: number of output features
-        weight_init: function to initialize the weights
-        bias_init: function to initialize the bias
-        key: key for the random number generator
+        weight_init: function to initialize the weights. Defaults to ``glorot_uniform``.
+        bias_init: function to initialize the bias. Defaults to ``zeros``.
+        key: key for the random number generator. Defaults to ``jax.random.PRNGKey(0)``.
         dtype: dtype of the weights and bias. defaults to ``jnp.float32``.

     Example:
@@ -161,8 +161,8 @@ def __init__(
         in_features: tuple[int, ...] | None,
         out_features: int,
         *,
-        weight_init: InitType = "he_normal",
-        bias_init: InitType = "ones",
+        weight_init: InitType = "glorot_uniform",
+        bias_init: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
         dtype: DType = jnp.float32,
     ):
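
For context, a minimal usage sketch of ``Multilinear`` under the new defaults. This is a hedged example, not part of the diff: it assumes the class is reachable as ``sk.nn.Multilinear`` and that the layer takes one positional array per entry of ``in_features``:

>>> import jax.numpy as jnp
>>> import serket as sk  # assumed import path for this sketch
>>> # weights now default to glorot_uniform, biases to zeros
>>> layer = sk.nn.Multilinear((3, 4), 5)
>>> layer(jnp.ones([3]), jnp.ones([4])).shape
(5,)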
@@ -192,9 +192,9 @@ class Linear(Multilinear):
     Args:
         in_features: number of input features
         out_features: number of output features
-        weight_init: function to initialize the weights
-        bias_init: function to initialize the bias
-        key: key for the random number generator
+        weight_init: function to initialize the weights. Defaults to ``glorot_uniform``.
+        bias_init: function to initialize the bias. Defaults to ``zeros``.
+        key: key for the random number generator. Defaults to ``jax.random.PRNGKey(0)``.
         dtype: data type of the weights and biases. defaults to ``jnp.float32``.

     Example:
@@ -236,8 +236,8 @@ def __init__(
         in_features: int | None,
         out_features: int,
         *,
-        weight_init: InitType = "he_normal",
-        bias_init: InitType = "ones",
+        weight_init: InitType = "glorot_uniform",
+        bias_init: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
         dtype: DType = jnp.float32,
     ):
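
Likewise for ``Linear``, which per the hierarchy above is ``Multilinear`` specialized to a single input. A hedged sketch, assuming the ``sk.nn.Linear`` entry point and broadcasting over leading axes:

>>> import jax.numpy as jnp
>>> import serket as sk  # assumed import path for this sketch
>>> layer = sk.nn.Linear(3, 5)  # weight_init="glorot_uniform", bias_init="zeros"
>>> layer(jnp.ones([10, 3])).shape
(10, 5)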
@@ -266,8 +266,8 @@ class GeneralLinear(sk.TreeClass):
         in_features: number of input features corresponding to in_axes
         out_features: number of output features
         in_axes: axes to apply the linear layer to
-        weight_init: weight initialization function
-        bias_init: bias initialization function
+        weight_init: weight initialization function. Defaults to ``glorot_uniform``.
+        bias_init: bias initialization function. Defaults to ``zeros``.
         key: key to use for initializing the weights. defaults to ``jax.random.PRNGKey(0)``.
         dtype: dtype of the weights and biases. defaults to ``jnp.float32``.
@@ -310,8 +310,8 @@ def __init__(
         out_features: int,
         *,
         in_axes: tuple[int, ...],
-        weight_init: InitType = "he_normal",
-        bias_init: InitType = "ones",
+        weight_init: InitType = "glorot_uniform",
+        bias_init: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
         dtype: DType = jnp.float32,
     ):
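
``GeneralLinear`` contracts over caller-chosen axes. A hedged sketch, assuming ``sk.nn.GeneralLinear`` and that the axes named in ``in_axes`` are consumed while the remaining axes are kept, with ``out_features`` appended:

>>> import jax.numpy as jnp
>>> import serket as sk  # assumed import path for this sketch
>>> # contract axes 0 and 1 (sizes 4 and 5); axis of size 10 is carried through
>>> layer = sk.nn.GeneralLinear(in_features=(4, 5), out_features=8, in_axes=(0, 1))
>>> layer(jnp.ones([4, 5, 10])).shape
(10, 8)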
@@ -402,11 +402,10 @@ class FNN(sk.TreeClass):
         layers: Sequence of layer sizes
         act: a single Activation function to be applied between layers or
             ``len(layers)-2`` Sequence of activation functions applied between
-            layers.
-        weight_init: Weight initializer function.
-        bias_init: Bias initializer function. Defaults to lambda key,
-            shape: jnp.ones(shape).
-        key: Random key for weight and bias initialization.
+            layers. Defaults to ``tanh``.
+        weight_init: Weight initializer function. Defaults to ``glorot_uniform``.
+        bias_init: Bias initializer function. Defaults to ``zeros``.
+        key: Random key for weight and bias initialization. Defaults to ``jax.random.PRNGKey(0)``.
         dtype: dtype of the weights and biases. defaults to ``jnp.float32``.

     Example:
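
In ``FNN``, the ``layers`` argument encodes the full width sequence: ``[1, 64, 64, 1]`` means one input feature, two hidden layers of width 64, and one output feature, with ``len(layers)-2 = 2`` activations (now ``tanh`` by default) between them. A hedged sketch, assuming the ``sk.nn.FNN`` entry point:

>>> import jax.numpy as jnp
>>> import serket as sk  # assumed import path for this sketch
>>> fnn = sk.nn.FNN([1, 64, 64, 1])  # tanh activations, zero-initialized biases
>>> fnn(jnp.ones([10, 1])).shape
(10, 1)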
@@ -543,10 +542,10 @@ class MLP(sk.TreeClass):
         out_features: Number of output features.
         hidden_size: Number of hidden units in each hidden layer.
         num_hidden_layers: Number of hidden layers including the output layer.
-        act: Activation function.
-        weight_init: Weight initialization function.
-        bias_init: Bias initialization function.
-        key: Random number generator key.
+        act: Activation function. Defaults to ``tanh``.
+        weight_init: Weight initialization function. Defaults to ``glorot_uniform``.
+        bias_init: Bias initialization function. Defaults to ``zeros``.
+        key: Random number generator key. Defaults to ``jax.random.PRNGKey(0)``.
         dtype: dtype of the weights and biases. defaults to ``jnp.float32``.

     Example:
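
``MLP`` builds the same kind of fixed-width stack from scalar hyperparameters rather than an explicit width list. A hedged sketch, assuming ``sk.nn.MLP`` with the keyword arguments listed above:

>>> import jax.numpy as jnp
>>> import serket as sk  # assumed import path for this sketch
>>> mlp = sk.nn.MLP(1, 1, hidden_size=64, num_hidden_layers=2)
>>> mlp(jnp.ones([10, 1])).shape
(10, 1)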
