Skip to content

Commit

Permalink
Update reshape.py (#93)
Browse files Browse the repository at this point in the history
- remove Crop/Pad/Flatten/Zoom
- add functional forms of crop
- add extract patches
  • Loading branch information
ASEM000 authored Dec 10, 2023
1 parent 0176520 commit dd90661
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 799 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ net = sk.tree_unmask(net)
| Densely connected | - `FNN` , <br> - `MLP` _compile time_ optimized |
| Normalization | - `{Layer,Instance,Group,Batch}Norm` |
| Pooling | - `{Avg,Max,LP}Pool{1D,2D,3D}` <br> - `Global{Avg,Max}Pool{1D,2D,3D}` <br> - `Adaptive{Avg,Max}Pool{1D,2D,3D}` |
| Reshaping | - `Flatten`, `Unflatten`, <br> - `Resize{1D,2D,3D}` <br> - `Upsample{1D,2D,3D}` <br> - `Pad{1D,2D,3D}` <br> - `{Random,Center,_}Crop{1D,2D,3D}` <br> - `{Random,_}Zoom{1D,2D,3D}` |
| Reshaping | - `Upsample{1D,2D,3D}` <br> - `{Random,Center}Crop{1D,2D,3D}` ` |
| Recurrent cells | - `{SimpleRNN,LSTM,GRU,Dense}Cell` <br> - `{Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell` |
| Activations | - `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`,<br> - `CeLU`,`ELU`,`GELU`,`GLU`<br>- `Hard{SILU,Shrink,Sigmoid,Swish,Tanh}`, <br> - `Soft{Plus,Sign,Shrink}` <br> - `LeakyReLU`,`LogSigmoid`,`LogSoftmax`,`Mish`,`PReLU`,<br> - `ReLU`,`ReLU6`,`SeLU`,`Sigmoid` <br> - `Swish`,`Tanh`,`TanhShrink`, `ThresholdedReLU`, `Snake` |

Expand Down
25 changes: 5 additions & 20 deletions docs/API/reshaping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,16 @@ Reshaping
.. autoclass:: CenterCrop1D
.. autoclass:: CenterCrop2D
.. autoclass:: CenterCrop3D
.. autoclass:: Crop1D
.. autoclass:: Crop2D
.. autoclass:: Crop3D

.. autoclass:: RandomCrop1D
.. autoclass:: RandomCrop2D
.. autoclass:: RandomCrop3D

.. autoclass:: Resize1D
.. autoclass:: Resize2D
.. autoclass:: Resize3D

.. autoclass:: Upsample1D
.. autoclass:: Upsample2D
.. autoclass:: Upsample3D

.. autoclass:: Pad1D
.. autoclass:: Pad2D
.. autoclass:: Pad3D

.. autoclass:: Flatten
.. autoclass:: Unflatten

.. autoclass:: Zoom1D
.. autoclass:: Zoom2D
.. autoclass:: Zoom3D
.. autoclass:: RandomZoom1D
.. autoclass:: RandomZoom2D
.. autoclass:: RandomZoom3D
.. autofunction:: center_crop_nd
.. autofunction:: extract_patches
.. autofunction:: random_crop_nd
.. autofunction:: upsample_nd
8 changes: 4 additions & 4 deletions serket/_src/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def tree_state(tree: T, **kwargs) -> T:

types = tuple(set(tree_state.state_dispatcher.registry) - {object})

def is_leaf(x: Any) -> bool:
return isinstance(x, types)
def is_leaf(node: Any) -> bool:
return isinstance(node, types)

def dispatch_func(leaf):
try:
Expand Down Expand Up @@ -199,8 +199,8 @@ def tree_eval(tree):

types = tuple(set(tree_eval.eval_dispatcher.registry) - {object})

def is_leaf(x: Any) -> bool:
return isinstance(x, types)
def is_leaf(node: Any) -> bool:
return isinstance(node, types)

return jax.tree_map(tree_eval.eval_dispatcher, tree, is_leaf=is_leaf)

Expand Down
Loading

0 comments on commit dd90661

Please sign in to comment.