diff --git a/trax/fastmath/ops.py b/trax/fastmath/ops.py
index 144729038..e6a001d45 100644
--- a/trax/fastmath/ops.py
+++ b/trax/fastmath/ops.py
@@ -18,14 +18,12 @@
 Import these operations directly from fastmath and import fastmath.numpy as np:
 
-```
-from trax import fastmath
-from trax.fastmath import numpy as np
-
-x = np.array([1.0, 2.0])  # Use like numpy.
-y = np.exp(x)  # Common numpy ops are available and accelerated.
-z = fastmath.logsumexp(y)  # Special operations (below) available from fastmath.
-```
+>>> from trax import fastmath
+>>> from trax.fastmath import numpy as np
+>>>
+>>> x = np.array([1.0, 2.0])  # Use like numpy.
+>>> y = np.exp(x)  # Common numpy ops are available and accelerated.
+>>> z = fastmath.logsumexp(y)  # Special operations available from fastmath.
 
 Trax uses either TensorFlow 2 or JAX as backend for accelerating operations.
 You can select which one to use (e.g., for debugging) with `use_backend`.
diff --git a/trax/layers/combinators.py b/trax/layers/combinators.py
index a8ac36062..d205f1d50 100644
--- a/trax/layers/combinators.py
+++ b/trax/layers/combinators.py
@@ -27,14 +27,13 @@ class Serial(base.Layer):
   """Combinator that applies layers serially (by function composition).
 
   This combinator is commonly used to construct deep networks, e.g., like this:
-  ```
-  mlp = tl.Serial(
-      tl.Dense(128),
-      tl.Relu(),
-      tl.Dense(10),
-      tl.LogSoftmax()
-  )
-  ```
+
+  >>> mlp = tl.Serial(
+  ...     tl.Dense(128),
+  ...     tl.Relu(),
+  ...     tl.Dense(10),
+  ...     tl.LogSoftmax()
+  ... )
 
   A Serial combinator uses stack semantics to manage data for its sublayers.
   Each sublayer sees only the inputs it needs and returns only the outputs it
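
Reviewer note (not part of the patch): to sanity-check the converted docstring examples, here is a minimal runnable sketch exercising both files touched above: fastmath as a drop-in numpy with extra ops, backend selection via `use_backend` as mentioned in ops.py, and the `Serial` mlp from combinators.py applied to a dummy batch. The input shape, the `Backend.TFNP` choice, and the init-then-call sequence are illustrative assumptions, not something this patch introduces.

    import numpy as onp

    from trax import fastmath
    from trax import layers as tl
    from trax import shapes
    from trax.fastmath import numpy as np

    # fastmath.numpy mirrors the numpy API; special ops live on fastmath itself.
    x = np.array([1.0, 2.0])
    z = fastmath.logsumexp(np.exp(x))

    # Backend selection (e.g., for debugging), per the ops.py docstring.
    # Backend.TFNP (TensorFlow NumPy) is one assumed choice; JAX is the default.
    with fastmath.use_backend(fastmath.Backend.TFNP):
        w = np.exp(np.array([1.0, 2.0]))  # Same call, run by the TF backend.

    # The Serial combinator from the combinators.py docstring, applied to a
    # hypothetical batch of four 32-dimensional inputs.
    mlp = tl.Serial(tl.Dense(128), tl.Relu(), tl.Dense(10), tl.LogSoftmax())
    batch = onp.zeros((4, 32), dtype=onp.float32)
    mlp.init(shapes.signature(batch))
    logits = mlp(batch)  # Shape (4, 10): one row of log-probabilities per input.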