*Serket is the goddess of magic in Egyptian mythology
Install the development version:

```bash
pip install git+https://github.com/ASEM000/serket
```
`serket` aims to be the most intuitive and easy-to-use machine learning library in `jax`. `serket` is fully transparent to `jax` transformations (e.g. `vmap`, `grad`, `jit`, ...).
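A rough sketch of what "transparent to `jax` transformations" means in practice, assuming (as serket's design suggests) that layers are registered pytrees. The `Linear` class below is a hypothetical hand-written stand-in, not `sk.nn.Linear`; it relies only on `jax.tree_util.register_pytree_node_class`, `jax.vmap`, `jax.jit`, and `jax.grad`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical minimal dense layer stored as a pytree (not sk.nn.Linear)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # the arrays are the children that transformations trace; no static data
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(jax.random.normal(jax.random.PRNGKey(0), (4, 2)), jnp.zeros(2))

# vmap over the batch axis of x only (the layer is shared, in_axes=None), then jit
batched = jax.jit(jax.vmap(lambda m, x: m(x), in_axes=(None, 0)))
y = batched(layer, jnp.ones((8, 4)))

# grad with respect to the layer returns a Linear whose fields hold the gradients
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(layer, jnp.ones(4))
```

Here `y` has shape `(8, 2)`, and because the loss is a plain sum, both `grads.weight` and `grads.bias` are all ones for an all-ones input. No wrapper such as `filter_jit` is needed, since the layer is just a pytree of arrays.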
- Full documentation
- Train MNIST, UNet, ConvLSTM, PINN
- Model surgery, Parallelism, Mixed precision
- Optimizers, Augmentation composition
- Interoperability with Keras, TensorFlow
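```python
import functools as ft
import jax, jax.numpy as jnp
import serket as sk

x_train, y_train = ..., ...  # placeholder: iterables of image batches and integer label batches

def softmax_cross_entropy(logits, onehot):
    # per-example cross-entropy between softmax(logits) and one-hot targets
    return -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)

def accuracy_func(logits, y):
    # fraction of examples whose arg-max logit matches the integer label
    return jnp.mean(jnp.argmax(logits, axis=-1) == y)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# mask non-array leaves (e.g. jnp.ravel, jax.nn.relu) so jax transforms only see arrays
net = sk.tree_mask(sk.Sequential(
    jnp.ravel,
    sk.nn.Linear(28 * 28, 64, key=k1),
    jax.nn.relu,
    sk.nn.Linear(64, 10, key=k2),
))

@ft.partial(jax.grad, has_aux=True)
def loss_func(net, x, y):
    # unmask inside the transformed function, then vmap the per-example model over the batch
    logits = jax.vmap(sk.tree_unmask(net))(x)
    onehot = jax.nn.one_hot(y, 10)
    loss = jnp.mean(softmax_cross_entropy(logits, onehot))
    return loss, (loss, logits)

@jax.jit
def train_step(net, x, y):
    grads, (loss, logits) = loss_func(net, x, y)
    # plain SGD step with learning rate 1e-3
    net = jax.tree_util.tree_map(lambda p, g: p - g * 1e-3, net, grads)
    return net, (loss, logits)

for j, (xb, yb) in enumerate(zip(x_train, y_train)):
    net, (loss, logits) = train_step(net, xb, yb)
    # logits only cover the current batch, so score against yb
    accuracy = accuracy_func(logits, yb)

# unmask to recover the plain Sequential model
net = sk.tree_unmask(net)
```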
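The `sk.tree_mask` / `sk.tree_unmask` pair is what lets the `Sequential` above carry plain functions such as `jnp.ravel` through `jax.grad` and `jax.jit`, which otherwise reject a Python function as an argument leaf. Below is a simplified sketch of the idea, assuming masking works by hiding non-floating-point leaves as static pytree metadata; `Frozen`, `mask`, and `unmask` are hypothetical names, not serket's implementation.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper that hides a non-array value from JAX transforms."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # no children: the wrapped value travels as static (hashable) aux data
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def mask(tree):
    # wrap every leaf that is not a floating-point array
    def wrap(leaf):
        if isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.inexact):
            return leaf
        return Frozen(leaf)

    return jax.tree_util.tree_map(wrap, tree)


def unmask(tree):
    # treat Frozen nodes as leaves and replace them with their wrapped values
    return jax.tree_util.tree_map(
        lambda x: x.value if isinstance(x, Frozen) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Frozen),
    )


def act(x):
    return jnp.tanh(x)


model = {"act": act, "w": jnp.ones(3)}  # a function leaf next to a float array


@jax.jit
@jax.grad
def loss(masked):
    m = unmask(masked)  # restore the function inside the traced code
    return jnp.sum(m["act"](m["w"]))


g = loss(mask(model))  # g["w"] holds d/dw sum(tanh(w)); g["act"] stays Frozen
```

The gradient tree keeps the same structure as the masked input, so an update such as `tree_map(lambda p, g: p - g * 1e-3, ...)` only ever touches the float arrays.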
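For reference, the objective and update computed by `loss_func` and `train_step`, written out under the assumption that `softmax_cross_entropy` is the standard per-example cross-entropy. With logits $z_i = f_\theta(x_i)$ for a batch of $B$ examples and one-hot labels $y_{ic}$ over 10 classes:

$$
\mathcal{L}(\theta) = -\frac{1}{B}\sum_{i=1}^{B}\sum_{c=1}^{10} y_{ic}\,\log\operatorname{softmax}(z_i)_c,
\qquad
\theta \leftarrow \theta - \eta\,\nabla_\theta\mathcal{L}(\theta),\quad \eta = 10^{-3}
$$

The mean over $i$ corresponds to `jnp.mean`, and the update is applied leaf by leaf to every unmasked array in `net`.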
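📚 Layers catalog

Layer names use brace shorthand: each `{...}` group lists alternatives and `_` means that part is omitted, so `{FFT,_}Conv{1D,2D,3D}` covers `FFTConv1D`, `FFTConv2D`, `FFTConv3D`, `Conv1D`, `Conv2D`, and `Conv3D`.
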
Group | Layers |
---|---|
Containers | - Sequential, Random{Choice} |

Group | Layers |
---|---|
Attention | - MultiHeadAttention |
Convolution | - {FFT,_}Conv{1D,2D,3D} - {FFT,_}Conv{1D,2D,3D}Transpose - Depthwise{FFT,_}Conv{1D,2D,3D} - Separable{FFT,_}Conv{1D,2D,3D} - Conv{1D,2D,3D}Local - SpectralConv{1D,2D,3D} |
Dropout | - Dropout - Dropout{1D,2D,3D} - RandomCutout{1D,2D,3D} |
Linear | - Linear, MLP, Identity |
Normalization | - {Layer,Instance,Group,Batch}Norm |
Pooling | - {Avg,Max,LP}Pool{1D,2D,3D} - Global{Avg,Max}Pool{1D,2D,3D} - Adaptive{Avg,Max}Pool{1D,2D,3D} |
Reshaping | - Upsample{1D,2D,3D} - {Random,Center}Crop{1D,2D,3D} |
Recurrent cells | - {SimpleRNN,LSTM,GRU,Dense}Cell - {Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell |
Activations | - Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh} - CeLU, ELU, GELU, GLU - Hard{SILU,Shrink,Sigmoid,Swish,Tanh} - Soft{Plus,Sign,Shrink} - LeakyReLU, LogSigmoid, LogSoftmax, Mish, PReLU - ReLU, ReLU6, SeLU, Sigmoid - Swish, Tanh, TanhShrink, ThresholdedReLU, Snake |

Group | Layers |
---|---|
Filter | - {FFT,_}{Avg,Box,Gaussian,Motion}Blur2D - {JointBilateral,Bilateral,Median}Blur2D - {FFT,_}{UnsharpMask}2D - {FFT,_}{Sobel,Laplacian}2D - {FFT,_}BlurPool2D |
Augment | - Adjust{Sigmoid,Log}2D - {Adjust,Random}{Brightness,Contrast,Hue,Saturation}2D - RandomJigSaw2D, PixelShuffle2D - Pixelate2D, Posterize2D, Solarize2D - FourierDomainAdapt2D |
Geometric | - {Random,_}{Horizontal,Vertical}{Translate,Flip,Shear}2D - {Random,_}{Rotate}2D - RandomPerspective2D - {FFT,_}ElasticTransform2D |
Color | - RGBToGrayscale2D, GrayscaleToRGB2D - RGBToHSV2D, HSVToRGB2D |
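To list the concrete class names a catalog entry stands for, the shorthand can be expanded mechanically. This sketch assumes the braces follow shell-style brace expansion (no nesting) with `_` as the empty string; `expand` is a hypothetical helper, not part of serket.

```python
import itertools
import re


def expand(pattern):
    """Expand catalog shorthand such as '{FFT,_}Conv{1D,2D}' into class names.

    Each {a,b,...} group is one choice and '_' stands for the empty string.
    Nested braces are not supported.
    """
    # split into literal text and brace groups, keeping the groups
    parts = re.split(r"(\{[^{}]*\})", pattern)
    choices = []
    for part in parts:
        if part.startswith("{") and part.endswith("}"):
            options = part[1:-1].split(",")
            choices.append(["" if o == "_" else o for o in options])
        else:
            choices.append([part])
    # cartesian product over every group, joined back into names
    return ["".join(combo) for combo in itertools.product(*choices)]
```

For example, `expand("{FFT,_}Conv{1D,2D,3D}")` yields `FFTConv1D`, `FFTConv2D`, `FFTConv3D`, `Conv1D`, `Conv2D`, `Conv3D`, and `expand("{Random,_}{Horizontal,Vertical}{Translate,Flip,Shear}2D")` yields 12 names.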