From 69a4d8e55e5429d5998ed2a2c61703be51d3e636 Mon Sep 17 00:00:00 2001 From: Mahmoud Asem <48389287+ASEM000@users.noreply.github.com> Date: Sun, 31 Mar 2024 02:16:42 +0900 Subject: [PATCH] v0.12 (#16) * Shard info (#10) * add shard info in `tree_summary` for jax arrays (if exists) * add export xla flag in jax ci * shorter syntax * Update test_pprint.py * Simplify tree pprint (#11) * Update CHANGELOG.md * remove tree_mermaid tree_graph * simplify the masking API. (#12) * simplify the masking API. - remove freeze/unfreeze - rename is_frozen to is_masked. - restrict the cond in tree_mask to callable only * fix failing tests * remove partial is_tree_equal from public API * Update CHANGELOG.md * add __format__ * Update CHANGELOG.md * tree mask edits * add broadcast_to to bcmap (#13) * Remove distinction between regex and string match key (#14) * changelog edit * revert `__format__` * [AtIndexer] make string key points to regex by default, remove BaseKey * tree_*** -> *** * print tracer type in tree repr/str * fix `is_leaf` typing * revert #14 plus some mods - in favor of more explicit - fails if dicts haves keys similar to re.Pattern * remove __format__ * Add `fill_value` for `at[...].get(fill_value=...)` * bump version * changelog * Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`. * define arraylib.array_equal * fix numpy test failing * AtIndexer -> at * tuple -> tuple[type1, ... typen] in tree_summary type row * add def_rule for at indexer * remove is_nondiff * docs organization * list tree summary pp rule * Update tree_mask.py * nits * fix at docstring * fix no leaf match error * fix doctest errors * docs * nits * typing --- .github/workflows/test_jax.yml | 1 + CHANGELOG.md | 97 ++++- docs/API/constructor.rst | 10 + docs/API/core.rst | 31 -- docs/API/masking.rst | 5 +- docs/API/module.rst | 10 + docs/API/pretty_print.rst | 2 - docs/API/sepes.rst | 4 +- docs/API/tree.rst | 17 + docs/_static/tree_graph.svg | 67 ---- docs/_static/tree_graph_stylized.svg | 67 ---- docs/_static/tree_mermaid.jpg | Bin 100963 -> 0 bytes sepes/__init__.py | 64 +--- sepes/_src/backend/arraylib/__init__.py | 42 ++- sepes/_src/backend/arraylib/jax.py | 10 +- sepes/_src/backend/arraylib/numpy.py | 4 +- sepes/_src/backend/arraylib/torch.py | 4 +- sepes/_src/backend/treelib/__init__.py | 10 +- sepes/_src/backend/treelib/jax.py | 10 +- sepes/_src/backend/treelib/optree.py | 10 +- sepes/_src/code_build.py | 2 + sepes/_src/tree_base.py | 111 ++---- sepes/_src/tree_index.py | 475 +++++++++--------------- sepes/_src/tree_mask.py | 276 +++++--------- sepes/_src/tree_pprint.py | 220 +++-------- sepes/_src/tree_util.py | 134 +++---- tests/test_index.py | 68 ++-- tests/test_mask.py | 145 ++++---- tests/test_operator.py | 20 +- tests/test_pprint.py | 51 +-- tests/test_treeclass.py | 23 +- 31 files changed, 769 insertions(+), 1221 deletions(-) create mode 100644 docs/API/constructor.rst delete mode 100644 docs/API/core.rst create mode 100644 docs/API/module.rst create mode 100644 docs/API/tree.rst delete mode 100644 docs/_static/tree_graph.svg delete mode 100644 docs/_static/tree_graph_stylized.svg delete mode 100644 docs/_static/tree_mermaid.jpg diff --git a/.github/workflows/test_jax.yml b/.github/workflows/test_jax.yml index 644a14b..2b04bd8 100644 --- a/.github/workflows/test_jax.yml +++ b/.github/workflows/test_jax.yml @@ -29,6 +29,7 @@ jobs: run: | export SEPES_TEST_ARRAYLIB=jax export SEPES_BACKEND=jax + export XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m pip install . coverage run -m pytest tests diff --git a/CHANGELOG.md b/CHANGELOG.md index a6fd830..82fa907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,51 @@ # Changelog +## V0.12 + +### Deprecations + +- Reduce the core API size by removing: + 1) `tree_graph` (for graphviz) + 2) `tree_mermaid` (mermaidjs) + 3) `Partial/partial` -> Use `jax.tree_util.Partial` instead. + 4) `is_tree_equal` -> Use `bcmap(numpy.testing.*)(pytree1, pytree2)` instead. + 5) `freeze` -> Use `ft.partial(tree_mask, lambda _: True)` instead. + 6) `unfreeze` -> Use `tree_unmask` instead. + 7) `is_nondiff` + 8) `BaseKey` + + +### Changes + +- `tree_{mask,unmask}` now accepts only callable `cond` argument. + + For masking using pytree boolean mask use the following pattern: + + ```python + import jax + import sepes as sp + import functools as ft + tree = [[1, 2], 3] # the nested tree + where = [[True, False], True] # mask tree[0][1] and tree[1] + mask = ft.partial(sp.tree_mask, cond=lambda _: True) + sp.at(tree)[where].apply(mask) # apply using `at` + # [[#1, 2], #3] + # or simply apply to the node directly + tree = [[mask(1), 2], mask(3)] + # [[#1, 2], #3] + ``` + +- Rename `is_frozen` to `is_masked` + - frozen could mean non-trainable array, however the masking is not only for arrays but also for other types that will be hidden across jax transformations. + +- Rename `AtIndexer` to `at` for shorter syntax. + +### Additions + +- Add `fill_value` in `at[...].get(fill_value=...)` to add default value for non + selected leaves. Useful for arrays under `jax.jit` to avoid variable size related errors. +- Add `jax.tree_util.{SequenceKey,GetAttrKey,DictKey}` as valid path keys in `at[...]`. + ## V0.11.3 - Raise error if `autoinit` is used with `__init__` method defined. @@ -7,20 +53,43 @@ - Add `at` as an alias for `AtIndexer` for shorter syntax. - Deprecate `AtIndexer.__call__` in favor of `value_and_tree` to apply function in a functional manner by copying the input argument. -```python -import sepes as sp -class Counter(sp.TreeClass): - def __init__(self, count: int): - self.count = count - def increment(self, value): - self.count += value - return self.count -counter = Counter(0) -# the function follow jax.value_and_grad semantics where the tree is the -# copied mutated input argument, if the function mutates the input arguments -sp.value_and_tree(lambda C: C.increment(1))(counter) -# (1, Counter(count=1)) -``` + ```python + import sepes as sp + class Counter(sp.TreeClass): + def __init__(self, count: int): + self.count = count + def increment(self, value): + self.count += value + return self.count + counter = Counter(0) + # the function follow jax.value_and_grad semantics where the tree is the + # copied mutated input argument, if the function mutates the input arguments + sp.value_and_tree(lambda C: C.increment(1))(counter) + # (1, Counter(count=1)) + ``` + +- Add sharding info in `tree_summary`, `G` for global, `S` for sharded shape. + + ```python + import jax + import sepes as sp + from jax.sharding import Mesh, NamedSharding as N, PartitionSpec as P + import numpy as np + import os + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + x = jax.numpy.ones([4 * 4, 2 * 2]) + mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"]) + sharding = N(mesh=mesh, spec=P("i", "j")) + x = jax.device_put(x, device=sharding) + + print(sp.tree_summary(x)) + ┌────┬───────────┬─────┬───────┐ + │Name│Type │Count│Size │ + ├────┼───────────┼─────┼───────┤ + │Σ │G:f32[16,4]│64 │256.00B│ + │ │S:f32[4,2] │ │ │ + └────┴───────────┴─────┴───────┘ + ``` - Updated docstrings. e.g. How to construct flops counter in `tree_summary` using `jax.jit` diff --git a/docs/API/constructor.rst b/docs/API/constructor.rst new file mode 100644 index 0000000..3afe172 --- /dev/null +++ b/docs/API/constructor.rst @@ -0,0 +1,10 @@ +🏗️ Constructor utils API +============================= + + +.. currentmodule:: sepes + +.. autofunction:: field +.. autofunction:: fields +.. autofunction:: autoinit +.. autofunction:: leafwise \ No newline at end of file diff --git a/docs/API/core.rst b/docs/API/core.rst deleted file mode 100644 index e23f8c7..0000000 --- a/docs/API/core.rst +++ /dev/null @@ -1,31 +0,0 @@ -🎯 Core API -============================= - - -.. currentmodule:: sepes - -.. autoclass:: TreeClass - :members: - at -.. autoclass:: Partial -.. autoclass:: partial -.. autoclass:: AtIndexer - :members: - get, - set, - apply, - scan, - reduce, - pluck, - -.. autoclass:: at -.. autoclass:: BaseKey - :members: - __eq__ -.. autofunction:: autoinit -.. autofunction:: leafwise -.. autofunction:: field -.. autofunction:: fields -.. autofunction:: bcmap -.. autofunction:: is_tree_equal -.. autofunction:: value_and_tree \ No newline at end of file diff --git a/docs/API/masking.rst b/docs/API/masking.rst index c1be176..3414824 100644 --- a/docs/API/masking.rst +++ b/docs/API/masking.rst @@ -3,9 +3,6 @@ .. currentmodule:: sepes -.. autofunction:: is_nondiff -.. autofunction:: freeze -.. autofunction:: unfreeze -.. autofunction:: is_frozen +.. autofunction:: is_masked .. autofunction:: tree_mask .. autofunction:: tree_unmask diff --git a/docs/API/module.rst b/docs/API/module.rst new file mode 100644 index 0000000..4b56db8 --- /dev/null +++ b/docs/API/module.rst @@ -0,0 +1,10 @@ +📍 Module API +============================= + + +.. currentmodule:: sepes + +.. autoclass:: TreeClass + :members: + at + diff --git a/docs/API/pretty_print.rst b/docs/API/pretty_print.rst index d863ea9..627fdf4 100644 --- a/docs/API/pretty_print.rst +++ b/docs/API/pretty_print.rst @@ -4,8 +4,6 @@ .. currentmodule:: sepes .. autofunction:: tree_diagram -.. autofunction:: tree_graph -.. autofunction:: tree_mermaid .. autofunction:: tree_repr .. autofunction:: tree_str .. autofunction:: tree_summary \ No newline at end of file diff --git a/docs/API/sepes.rst b/docs/API/sepes.rst index a857ed8..3e396bd 100644 --- a/docs/API/sepes.rst +++ b/docs/API/sepes.rst @@ -5,7 +5,9 @@ :maxdepth: 2 :caption: API Documentation - core + module masking + tree + constructor pretty_print backend diff --git a/docs/API/tree.rst b/docs/API/tree.rst new file mode 100644 index 0000000..0cb4f9a --- /dev/null +++ b/docs/API/tree.rst @@ -0,0 +1,17 @@ +🌲 Tree utils API +============================= + + +.. currentmodule:: sepes + +.. autoclass:: at + :members: + get, + set, + apply, + scan, + reduce, + pluck, + +.. autofunction:: value_and_tree +.. autofunction:: bcmap \ No newline at end of file diff --git a/docs/_static/tree_graph.svg b/docs/_static/tree_graph.svg deleted file mode 100644 index 380a167..0000000 --- a/docs/_static/tree_graph.svg +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - -G - - - -5353602176 - -list - - - -5353602432 - -[0]=1 - - - -5353602176->5353602432 - - - - - -5353602496 - -[1]=2 - - - -5353602176->5353602496 - - - - - -5353602816 - -[2]:dict - - - -5353602176->5353602816 - - - - - -5353602560 - -['a']=3 - - - -5353602816->5353602560 - - - - - diff --git a/docs/_static/tree_graph_stylized.svg b/docs/_static/tree_graph_stylized.svg deleted file mode 100644 index f6a8d7b..0000000 --- a/docs/_static/tree_graph_stylized.svg +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - -G - - - -5345369024 - -list - - - -5353442880 - -[0]=1 - - - -5345369024->5353442880 - - - - - -5353442496 - -[1]=2 - - - -5345369024->5353442496 - - - - - -5353171392 - -[2]:dict - - - -5345369024->5353171392 - - - - - -5353173184 - -['a']=3 - - - -5353171392->5353173184 - - - - - diff --git a/docs/_static/tree_mermaid.jpg b/docs/_static/tree_mermaid.jpg deleted file mode 100644 index 07f1d829aa9395c551fe5b86609916d74a3a7336..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 100963 zcmeEv2|QHm-}o7WvF}^L$P%R}OUN=Iq>?3(C6nxgkTn@4YnZ4g6@{pzscgwI)+Cke zYuPHYHDeuS`#-w3dw;jz`}wW+ec#`G|M%3H^F7abp65K@XZh~WLGPoFKs@`6OpG8H z41z9!e-M2XN;3@gbb}yMQ)nv$K`am>OcY`Q5Xb`Q!Z!W_8^Pou__uTh2#WKB7=O-l z2wc|>P}{dMKV9Kzu@C~J@PRY*HpBPa%+PK4cbEy(13`L+4NOeH^{{h*tE+FIhu>-X zETmx$u72R=l%$H{4oE}a#FTOUec<>8%YFlSy2O+EvCvb-FHx-5W=Q&x^fB&tCI$v* zb4xQLlY@re(B!UlI(6zK0}llG_?`~1G}f0qYHKHnm;iVQLY$ByB z->={5|K(@7=No2F=l1oweo6kv0A3f@Kxfc3<{44j(kTEyKfUWikoS@G^!57ZJzPzG zt{=`7?qOvE(%B#gE)(o&y}n+!0IV43eR#b+pbc>Ca1Vp^^ly0rf-KkDF9XsoU44!9 z0Ss_~2fGCrt>eC)KhY-;K;LkO-}dk{SuYFv0si!~ht)m+gZ_axIy-%9Kfne4&ed1% z=RSc4y9ORzZyU(ZQ0C&g*ATz}cle~M?~!$U*XtpjI<3F{E>M<{?bOL{YXhV+O1Yji zUQd?-uxeoN;h)PgTAvQET2BYSh{Y>PLH?5BQ;vOra5g(NBHo3i|gi`hsyYs8fwmpHTPHZUHofWr9D(_HMfuME%vYrQm7A(H;4VcjDU(=$8AxH;Lr_(Ec zP1}!$peO2J`B(g!Ca(cOyzCJ4vcfqiAo!>H7}qaMfJ$(Km9_zrfTSULfR8e?3(|u2 zK!(tM$PBWAY#~Q5Up>G)4}e0T2B0xYfPzQE>TYYcD(4h8`RafZzdN(>qddJG2`tQd|l zxHI@Ogfm=VNMX3fki+nhp^Tx9p_QSVVT56tVF?bw+2KO)P4Mk-4Y&c^3~mQ^g9pGP z;fe5ccnv19RN zInR>L@_?m=rGsUHh0MysD$BZ?)r{4JHH;O@n#cN_wT<-?E18X#O^!{6&5F&7?Ht=R zwnuDDY(s1d?Ck7P?3(Nr?4ImV?AO>Iv%h2?Wnbms-|{9KA$CR{FD7_J*!Wn3LxbKI=lvfO&yj@;qg>D~pU2W1(RpLy#-|$x#2Cc3idl-C5z7^OEw&^sA#N;wO8mNb zgZPYuki;GdFNrG>H4bw+A>b*3Zw1%{s^cCrPX@U$=##AOkrch=;mQ_|m))T~6?p|<1Xj=UX% zJNb6*-+6xL^PTg`Ta}%aZzy-EaH<%oU{s!|5LFdaT~)JG`_%Z=Ox5Dl>eZ;b)OQ8! zdbI0{x{Ug9^_%Lw8Uh;T8c7;2HJLQ^H8Gl1nyb5a?LNKx>F%#u+qJy49%xN!%W1o6 z=V_1YNa;A~+|e1)mDD}1ds}w|wF%{f%0Z3ok>2C7Cx6d`-WEMCy@z^p`aARk^h@S8+3XKTHYQ|@bYxgnjJFqWxU$2R{iL=Rl z6T*J={b%<#9^g1|_(0}?v4dL=1{{2D%3!+R^s?#DA=yK|hsw<$GZVARX7A1A%>B(P zEtoCLEpAwRw%lnMVcC3`|1kRSgTu>K`c_!0!6WiVf{xT%b6Y!D-?v_|*=v(#Gjdev z=-H#KwxYJ4w$JPkcGhOj8|{QHKF)O9 z=J~Xo`^2JrhHO+OxO~)Q4wb5Re}T3p^J1?6d$7%ads3crZpe4E_zPh;MV){i&l5^6l@p7qi!Hd*1H5V|=G9XLC+^E^}^h z?&mz~yym+ackkzm=3mUG7Wfv7+_Svbcwhbg{Rd(Xu!RhTL4{u)Iy`(=w72NlqwSCG z;Dm4&A489W9#1_v{-o#Wfv5Gw8pV%FWJ_+A@|Px-!OFtQ2<0B-BhPG~wO8z~sDG~Y zytGoWvY={H)s1SwYHST#&4n5|J{-SPd#ZM(&b@BD9$i1!aI~ST(W3Ef(}AX!FN|I^ zH0w3vU+TQ9YT4aV@k--WdF!s$ve#;_OW&xyDSf-^ZCRUo+p~9?@1D17x7T!_I_f(2 zcE0E`>1yph)ZN~5q^Gafp?9RurEjwTWIu5rcz`?@HN-TO_@3u|`UkNOIm26qaU;7% zsz>*Zz8*U~HaPA)KK(K9BjwYD&zzsrCy*2OCzU6wz8HOZH)TKdY5LSOc_wz2clPF- z{9MV`JzrlFjuJi*1Bmqbq=gL&1&gYS^-JbUAC`TV$t#JZ4WxUk8mrA~Hfs~)5DJ2F zjk<;UoOXaVME9Z7PdNoReUk;&HE{+vptsK#KoF-b1o6KE@`L?P`ro%4u%C2R0KgRTVAahO#veJYYR8)RBzZC&G^*6n8 z{e~In5VuXJCwxx){OC8ApL9-;^u66bUI8{LDyr*cesiVQLcAJ-vPX1EXW(A3uGbnEWzDoL^X6T3#WouC3z* z19*Nve!}brym$d#42+C$M#MT^FoqCt!g(2)Htl5Q)3ZQ0`SVLDpJNfwPrdV`j#XO4 zk|1~@pq))fMs-w{xQ^O4%>FUN&i_l8{esx9cnt#8g5lc%XJCLc!QpTwW+rejvjEA9 znVE%+<=esb^TF}$;95Vpe>v!&5*Vn1k&zJr{`0W2vh)1Y2mL)TJsqbHKpb!wm`re9 z2nA6oIf?SnpXdKtAHeDF{)5J#@cprYzQOAEGcA6M&f8^@cm>W}j#79X#GQg>DQP*V z?zr`dnnFHoa*YB=c7yLJSXe1i_f=h8$s zjYdjU}njA%nfh zl{r6}6n4hJowAxiW1vGXb#7Klv>AprQ&iRKve!f<=@32;LlnR?7({HKLw9nx$t&t~ zh}U>oQH2)KN{3*Wj}mW>Vcm18oPXQ&zru_x3@Xrdr9)Bpd2V;b=Hhn8oj*4I2i&L^ zlju;kG99W#uOf;wockwnB7fERAF=zj%cA>RpUf=UB>h$6e?;!YIBMcGo*+!qoB7nY z9R7EQ1@*)~d*L&h#7>7K{~VUzYgi7M4e*b2KDHRUc*^X-e#B?XALnJOVL%O%*cKgO z6-@H8FYe0Tu!V7Qp8Em(S)|nS2P1|1Pel}4N4k^qXGJI!uBV$SDvm^|xhN+ze8+89etl?7tvtROME?47`Cmqr|zv>1!nVG>8!n8x{fO6 zc}Jl8X77pK39 zy+5=1rm@J0NDgJIu`&~->AAhFf=b1Ahp+T4}LiC;|_LI6#(4MD@IG zpheQyDbdp^a}j=Yh$8_|I|J=(5(|nnizK{EM=z?h0QwhAhf;d`1uSS~Zpeid0{SD; zHal<}LvlHww6@_Rpmv}mn>_%b+xrxd+$=K+Ba#wa^D+m!RJ;SwL;QI_cGg~#vp3SA zzihPeZxAY7HWa`VU5}$sDZ;CW9`_;A@NXK7`YkrNr}DYfldn+|Uq|pSb&GW;UeKZ6 zwz&NVgwCcKB7f;3y`}3bx_~kKrpegfV&pxO2w0E33im3Rq%_a{u&v_nje+r6|20h8 za2@8-**`}BO8MtUAd?O~6~AN2-^8`?$*$wr);EI7N5HJ8Sf3TrwDmzl=~Nk#cZSlT zJxD*or=i{W%L_2x9S#gtpN^!cEJmzZH#cAg?3>ZtrD}w)jjGvJog;pVN1idF?Va1R znVwL4C)#7qjyswVZAKO2h;sYSzD$4qdTZpHV1%YP-@$;Fn4MiAh#D+!XOoiC^}78^ zMGt$kDQY!tsf}opZ2i!X#&D&Js@^`Bh82@A=itiRRFt`~c`ZTjqXef<@2B2cr36uY zy&2MEB%6OO`T+hJS*^?Rba}82cj7qG5L{Qfva?A*D&uc)a;dErB5X` zlhu(0jp*m~qo>;Y4N*5}*&qB8+Ko0D3fR09J7*|h4*NGxM16K?$c53TJtb`-cDK>9 z0^VnmLehBHcCN_#FIu2|DT2gscyI ztp8=hXRPIg=%qw+?;2|A-7sC=txsO2&_3DpMZ~RMMjyBpNtU^5X>P?5>b6VTI9vv2?I<^`w*tDNu8GB zMHR+%k%~{ENW3))w=E0>=Mj>>!}5hKpl1)$A(??eTKx{pWQBJr^%5QWn3hduallM& zaj^MYPtm4WI+V&rhc@&hN0AoohJvg2-va6a{Si(Kz3Y!W>}VER55BZF9_6mY?_tN4 zJW4+1q~@(&5%TQnxV*LgjXiq;G~PsH&>?OkVjGoZ&5y?Th9=miZ6!6+Hxt=zfxd_* zEc%H76&ssRhei$#kD~K1#i$pi#^d<s;egB_3Ns_aAf z1u!$ux^L^FzEl;7YOEwZ!-<#Eq4PNuUGG?X9FtdSU6#1qxgl(j)*d=kQ;0}7-!cAV zEx|CNnh>{KPuf6CUmo5|pgmsmEe}A26iiuAvqJplUi5q9`5Bt9JKwbFQhVr@`TPW% zJnqf#fb;S=5mlR=n1SX#jpMMLN?%^?SxEStwKu^1Sfi8CjxARAa&VGQ)(j0J6@j4G zUKL1(^mW0SJEww)LH^JwN6!Xl`}a|Ea;iE3eEq zugwr9owD>LsN+_5BbfO+*yMAfemS?KwTkVm>eNWlsB?4EU_;2Hg~QOIp4sJ?+FO2O zY8eg3IG4orS1h_tbZ(y;El=thME(=3*CIO2* z>nBo;){UAsEyB-8BWM@XE7ASjG?OiK=y6Ol?fliI$iMM$v5yX|^5BTKQfa(P;($h3@8iiMHx(;=q4bSUlO2%gMjN{4daM0$}h23--RIM14m9Xtmxg$@m#zDack zrdsjugtQ^sIJtA^yJ&)y*{XaEFsVGY8eY@xts7> zYV-rXWS9zjh$lD!HB1m)M2DWxq57}F6bX`o>j+LvPrR&<^V87l)cwWsCTN8PMdSTz zv5AXcoN^JT4>+M?@a5dIuJnI2L~9C8`-@jUG~Je*Qs8y|6hbBo`glpp8ZqMs2 z#V|OYO_BZZ`NM%T6@tc=C(6@JAaS$S+K;`{?las3vxB5SLO^q~aM`N}!*PwCdE09Y z#D3JScIhCtmV~1rZWCHI$e=FE+JcQO(W#Wg+$!IvMoqT+B%{x>OzPp@bN5Lv-|V7o znan`4SA_&EXY-OGnbLcvj9!FB#cf&mLX=6iN)iH7lp$cLy?EdDmn^^OJW@DpWX2xZtmIW z5Z!~fpTyY}$*Q^ScVBQKsxRx2h6xwz)pa$9kFNXpMBvrlI{|l4-115YH=0i39W;D4dgK2xg8x7mMeq5) zMN=ch3C6So#!;A=;~jV~0m#iXwgfsPD_nOT!rag&>5e3IYg+WP51(1w#J$Jqt;bxqU(K};$Xsl2D zl;Wlj#^WEcukb?ND@Yn!SUF7q7hAl7ACST68Pl4Ufz{F^>lgd|sUL|KfOzyqONQ zlv8xpSsMCvZ(f8AA^*-eE>>bY!TPd{`LOrEk=Z0}P1gk1ikMvePe-oFZ&TL4oD#K& zJEZ>6dd6fGm=C#4dohz)_Lz|tAkRo(N#VYgbZGDVARS^h1FLaq0!f(h!=e`@2)#5W zM>;ew{#OZMl)m#1vBq>LE`J)oXto_saG^sKtT9DL{TLm}2?K~l%K>F5d;v=eR&U0w zafbpOmjQ=a<@&3Hl#_dGk%uD~srVp2V5Sv1jf~B3hI?K}HzKO}{+p{$8uBlC%6+{V#PB&tBPl%3rHp&nk_Fc{u(d#6IYagQ z#7G-O;BbB927+-To`vLBFX~#8nr{~xxl_Ele#L~iKGFSbUfo&{p zl`u?uVG244$RSVgBL(<*+bR`0besk7oJV6dn$#@R$C?OW|9o=UgbsB)n{IuG54S_z zdmx`vox9^_R|NbsexQJ6Qd&;EaGVac2lRTAH-QNvi(2M+2J{olXSl!lxaJ0u4oL~p zRGy;hb#rqP6&4QP{7%;Y)35a2|6%G|6*q=x9tMVT_G5R5D@i5qr9SCxem}t-c#aMo zlzy)SO#H&6aT*(zpsbTTi4W#<=5hq#sf3+A%jvud2^(BJ1eUipN4?VdAmHZe|~4aUVY< zLbM_m>V3Q$DP`lBdmNVhJc*NXtqViw#uU3#1*+k|WHkyrzlE!Za-vX$NRk)6znGTj z2s+~&Apb-wOv4e{o9so)MQ#dr91uWopLQog)P@eVTMCWnl9S=fF(Kl#Tlc(iI@OKa zBs5Q?S-zZXIRB7KPjKxpIt5+KO^JyV#|&oB&MPeM!XF-c^BCybMx;mQezEc)#tJBc z{IvZ_U`%EN)oEBYLP#BQb-M%%ESl*3-xkeRBx$Q0L*k59L_DoArtv=e!txyE8b%CLABRcqmStE4)+D&0TV)l?jwD0gL?>I^i9j!Jt zF71_6f^kR>u|V7Do$(DZJatiu_q~G)@@048>iF_u(m8LU^ayu!X7QHnDpdt5mFz?0!avl+;Z7(&KN!6L2Mqb%Lo_F4L*C^DhpCRYh!P&Z3LwM$)c}k<9n9rgYsd&VuVKq{of+l32<5h2Z zfNoSoIM_9j2ct=QQ@ooSYrK-BdYTtY2@OUW4R}s%+0oTrm*;M&PF@Y|cgpulF9;5` z8=l$16JY~*J?O;C$BEq3vz~Nl?T9!=dx^?{sRM?HkA2z(bO>#SpDH{fg6_OH{2gciv^+$;aC_Qdf`p@N)tddW7am)fCN*gGeL($bN?y@r!{?5OYl|X z7^WeT)|$Pj3V7xD={A~*F-86B3=lmGfQCv^tOAZdHBty2VuO$4sr(o|Fl)0G=uk3H zgZ?UM>D8_1E7TqS0|{ zSJC1v+yHYV(rSGOj{f zVFbk(D)jCsNvUC3RPdR3+EV&Z}~DWfTK^R=Qbduq;X@06&94ax{*NBs^zewbEE<4-}I+6w1XTJd{Io$;ZqRRWRU>b_MoO z8ACE_WJ^1TQp0p0ZR!0qM)S=9;FjM>29EW1ol+1wa05@%LED0?K=)|ih5`^XCQKbRC!Frl_sgRxH<)YG}hitYN%$v4Y?Kzq*2_a-MsXK9R zV*wAz1+wasL_exn_nH#=KGCfYLy(TJ*#qR!k6hqa9!@;`cQw%!?BA6Ini?^J>DZ|N zKvg`V(~0;Cs(wh4!BK)uyz|28S+5>@+bD9Ki-Ygs{wt$bmL)}L=*TS*szg-1lGs!h zoakDqccx&UzNW~5coSU*Vpc$!l600Pf`bLweJoV2*{{EKw0Rbu)9f_9(K_0)*jmW$ zY6W|sul-FI!laqiDP#ZG)xr$HxCtBkWEBE-If!~MB8t=b#IvKOnDW%CQp(VGP6~Jw#5yKLy1w#c*8;_tV z{Mh;$&MqH{2@zu+*JN`EYRPZdsF;<;HLo%{0e^bc5Y_ll1J^0DH(5xdOG*;#8HU6m z!SVUxxw$8bkfOehsW;B}&83O?^M3C4v z1GInHg;6{;DbtL0wHf^gS*AV{LzG*NA=0p8c&4dv)2olN^S56~xM%7w7HgPgtDGY* z%j7RPrS~=z2PYi?3;HXma|Dy{fsdGvD+-OY7e`n>!QYassPwc`@a(W=5pB16!hS9t zuCoufwTSh3#=eraelMWvS+4WdKY8nadpO&(He`f8!Mz@+cke@6n=>Eo)=hO!UN{)J zfP5aGyKQc`{sF5_Or#(U*n|n@KtiNgsz-8>4!RR?F-NNc9Af+Lp^c0>E3VK!??Tqb z`6I)gT)oZx4g|U^(}1!4kUf?yR85+QxtO!+D3aH7H5bzvmZ5bsBw2Ijbu6AUawi2t zDkKgc6)kV{@jcKt=JU?B8CZ2Acarbfd_sr`X*$184t*0DANdwdFheGwg-?_cjT?n8 z*VPDLLzpRRPYGM3UOEx;EZM%eWr?LeN|%e!91~$OZt6SOkj2-jsCL$pc)rQUd*cCV zKK~UfG1yQR!qOjoA53@DNcxA(-L$r%h*LcodrN&DoqnV7u-D~OaEm6he4BYNt8Nbd zXbm@onPg8;+DmI9`R5oPcM-|@LWjaaPKAHAz46&MjBQ6Q$;#%jIRZ=BN;^FR#ce$#h80*dOR<0>f*X zg^P~(-5Y6#QpRa7w%`{7zlDD^-Eab|DVcZ-tPR{H&2KBD$?OMZZ=rH@BECp_6A(Q) zNKa|%nZ1TxGdS0MQOc^pop}$1{o(%X=hB$Z^iC;J1VToe8&MlW7wXlsI0!ZOae99F ztzu6-mk+;}@pY3?5J^3MGcX!`Gn;Umwi~cUO`;r)O$kNn4t;ion(HZq?E&I{*hC#_ zO^9Lu!lHdptD9?DN!y9h4R|pME25E`ec+hw(IHb+2CXK;YpUBz6L>2harUUA)qu%ysDY z205db;*oabt$0q1FCB_O3RKyX^}R&fVj{#JPVbIL-?hJ7$nI>@YuPmI9o(YIA0=|N zZFD#1#*5+Tfr5rMj>+dux2TIfPi4we2Omo6DZlLB`F1u<82)J* zogp}l{;|K1KPH5!$!b=X9K%j3Et>*qz>fqpl;ijC*Kv2I>oIOR($?=Hp-^Gt-fj7v zxLMV{H19L<16y)VnZ24EQO-W0;yCO09u>3UaZcja+P`JA`v2EY@!uE*hReo#C=Hm6 zHqT6q1{Q_Bnr`cwyNWMfd?vE7umy2@!sR%1)&7%XbfBR1!tlW0as{D$rgQX;rqNYISe!QM=1Q0g6;H0@o@77P#R!{Wt#wlz%!`uj?XfiaoG`+Yq5 zZNJ;n3uvT9f}&0=Z$k4AkWjU|4-VAaKOLH~X+CEqOl;Iy@BOtX9q&%y1M4psQ<_DN z1-Dmif=WAPUfrNt2BF+a3-4al^$E(PM;yb2H7Eqjt+Q zbI$IMwE^?bW>7JRSGl)aWTGPOlNB*_C~bbixsw!E;tqo`tjN~FhbJp%w#iJB^PI^2 z2+E>=9r*#b+~SG9zX=FF|7V=j|6q(1=D#bv4VQ%R#to1gP{ko}n5T6%X`Scf**Vlo#223&fTqd=u$e<}aXSe6>Ky|M1XzjXv zPK@s9bxd!0_sM|3Zk0$yqF;h`*w``S7ZOTcvfXO8?A)?Gf3{TLH#-9zG2ijv0mFl+ z2MkdV1F5o!wwd@fAqXi#YPQHY(i3ERuvcA0tg>R{)DA%_=zC6uX|sO*<=sTECW?~% zY*jpHjGZ3`Ln3ao`W_40wOn})SXQe(3MK{Uo!rt#O+h(}hb&*{&LxIiv^=I%E2KPT z`@-^={$7--P$MwuXqzY*ezMtq3*U~4)7HEbYOMc+&=$MxM5k)eF4@=*15B)9{%5)U=BkeD5`{}jOYl~y9T))1L3o=A9(!dUcWw|Rc;+3du_fk%e&}jNA%w! z{_D>ud48+AwXkH7e89fnpmUa1-wn^74nZ#O1Z*ca6K^=Zq`Pg9o5tvuI0=H0BF?@1 zJ%aHfddXJ6nzpM9xa^XUYl?QW7#Gy4OaZN47`et~BltI;b~*!Nsy_ZjurbNF>$HVo zq*q|f-w>SsQi%HBhLQgrL4r6YY7PxWI=mcA3iBUF@ud2n*}3CO5A1wPSYT3&c)gM> z5-GnVtM|2LbA$;2TW`kX<)}lfo_yM|_;O|E)O^LLR01Z^7_}9LbKpp&;3MDRB_e?2 zL)lwNEKhLpD6-ksf)M-xqnu1GSzJG6* zna60-*_S5{-b%eaYD1KHB11JMi$=cO-8{=brZ=#XWY8*!DGKT=2D#yMrA z-Hh}%TZw-JU$&xNk1!^ikd{tgQ|}LTKd?oy=pL-mg1^yXB3kv?CS#{`O*mrBo!}Qu zQ6J&v@ch6@NKW(|FuKau05%22SaeMU}?#?7lw0))uHA>bfp{gHB zkSs5c(Ki;bNZR@!1!5@uCxRzXTYeCM+C)ItBVjej zqY)@VgiFIo?baI^kZDLjZP;a#lo?}q&OJ5_!ReBpB0v98f{0AVl*DF)jbKx_#|~Hx zlqdV#AIusYuYcTzyS_p09WOfEBeOakpJ(eG9{9DaCT25`GoOxY=#JkJPX;P$38wyj zjSOkWF%M!`=7&T(Z&&!|>muCBo91d6@UxIs?<;ii@Cy-b<%!y??KfUNPHpRVyBt%V zihZ;3@`?pk5*Q%dwZ*VqWB8J%&x5vjTc#%mm18O*-S>{SG4@Q7UnCm3265!^80m?# z+{InesQ=kMw5|dFYp2578kq=_cQqR)vG=AdS(Rigw-@QHevTXpHji8gUKis!fFPO+J`xE&S9ss!pY{OaF6r)$eElx;Z5YLV zL{B`wbd!?*$wnD<rQ0R8YbekJCb=^nq?QbJQ9(-on^d6b^L4s63 zv>Ev{Op)H*T@uUgtJ`o-n&Bb{^~jE8|8oBDJDc(^!=>f1ALGUC^()ISi^WT5W}vL+ z+rFSrd1r}%5Qu~#-MX>uyM`^xXEJ!V-R2+iMIeR(9aZi{zD6>A{iMRlS!7$GYx?>X zex)S*@{i+YpNWj&*zM=3pwO2dSHW-ieQsHwL!d5dP*x~w@cjUvHXeJ<1UJ18)^^0QE!Fy$>-B@)S|f6KC6KOeQqdA8)-}P%WtG&&cO9ZF0cv|Ce24 z{6pPKSduF;8okI39QU>NQ7?YHojn-0Do2fISUGn6^1xF3tNT)AwAYAj*|Fk=fX3nMDA~! zh<{Ukcr!&ATZdai41q`ltL$;GrS~_t%lQVB%d{FN@}t-Dm>;GB*ungps&Pj_w_1?Y zz<86hDQuT19zWjGx?W<^Fuku)e@?WYCh`B_iI!hJ>pF70!0qb2^cl8##Kh;!CERaU z)m4nEQuO}XTf0Z2W_<59PZ}csf_8jH^3!^DzZo*~8`Y80VrXDE6Sr{l%mtVg^1joo zL)$qT8|4tG`ynnt+hW^t=IM_c{Jz45rMTW1`9BL5fgx~SR5)R{&M&5h|EjO58JAt& zmshPj)cLk7SzoI-Jja4Fe^^Mo#i1>x{7Sosdfy9BnYSMdmY-m2Q=IE)-0!qmZ`+gn z8iMNw-VU90G)9HJ$0yvVpfE$wX|TmJySs2p%Uj>sxz`&(JbLjj!TkS*g7ZJIWf=B* zMn&Q#;%ICzvW5B!lvsA0E_bCHAtWZ>bz9uW8oWc{n`SHQ73`DM8>Cp20czw9vZR(| z0QN-)w#6uxg_bwHdoHAJ(EnvZVOzhr#35%ajj`dXAceTK`0-xGVlhZlCKFmjUl?Ym zm=UqfYsL+2EcLGVZdW#k^Oh`e%BGvd`P5oamp-0D0~ba;GA2@lq(oTRM+zF%Snl`q zclxxIF@s-S*=+Of$}*Eh5qOlsCsd#M@^S2ld&4N;nF97FiMYf&CW;2sX)S}e#J+HK zqYUapB?24y!>?P+{t3@RkCw@r3hwh!*9YSq3*6L825j1fykcN;gD#bt?-EJl2gqjczkPgBkH+}bB{NY&)- zPl4}0p0LH5cOO7V+m1nGM{pxApK8L_Aiid z)bH_O4BU%s(w2!?E-PDFTb5Ae^rsKgMH_?3JIxZ54*$Sg1HOk9sY$el99SRsosb>heo zM18V!?e0($rPRpw`wq982mD;&Zz976cJS)(8b|bg1|eC(%OUpDt+)aHCZWMQ#h%+1 zRf8B+pDFkUs8H*!U0TbSV~Xs4_rGlt{iDABclaj8D*tj9iKYvS5UuV~y-&gOhenmg zUV$~pq~NQbf%vv8`4BjDKDru;w7}FE;ywI{hsREld`4SSJMincG__5#w!uTF0m#%!I$Bhm4XRKE9uj7Xfj(0z{=X5!-M zdSdQlv5FWl1EqyuM!yT8@ZVSt@rwa!J$ek2u)PY;OX`tHlMUK5yEt6UcYIMrZ2yNr zgt-U&D-IEQK%an&IgMiuJuzhGX{S~dTt-Z35vi3Srt2JCk;z)>X#Zn;;o|pj{Qcf_ z{JUS>V9<0;WXTG9@0_Cc>xU<}^6}TugG%Fzcw~+8@YLGrrYiM2bB{Z0NOmlJU$avd zbkRD&vxvKq#UOsz2;(75s2ug8s1uf2E;c%<8^O13!jRE~kwrRXhu{$-L90^`JPN{H!8@a)cq1T^rgNSbFB5TdIVfxs zIjx0jLFn8uyDYBmD_wNf`P(-h{!vT+pZR|48A1L#UZ{r{9`6C;n3C-VjnApCY`;8r z;By#yJ;5$!vBlkak)cds{s4{bEOBu%68IB1M_X@|XZG2@cgPJ$j~}kMFhYk~Z7g4N z6YkA)!B)+HN7lJJQj22St1C67m3w~9Z0ieVX{y94rF0R}xK^aw(uwBkv%r$cI6;DR zw?qKDR7@OhFf2=)U?5w|1v6O^`ustg_a#=sJDewq(7j5|`(l?;{h)s`Q>XIS@D#tX zpC21XN!qJAUSSaVqp=r36ILhvjsdIxd`too_DL=xK`vdzTvaUQTey9`lL_T zm3QprEgyNC7f?zrcXT4S zj{?r?KAIIx7}ta9at9xn$U__EHh#`@s}-wE`_KR{4L|#b`1sTB&DMU;`Y`{h`Rfs4 zEThqseH7I&8!{#rKkSndcXHwB1hdT-u>-WOlq14Jy=T$!c!HG@R{BV<>Tpz+M8%u3 zG~hC^ksnfPedIH861nOY_xkvm=PPy)rHnX{tkxF)LFK2JI&0?e9L&@iQ3EhjwJ`%V^q~aNJ?*9Chwc-`QReQp8?emGoVX?_9 zh%glCmiy?vv5#=h97?6!?5co}-*NA?W|IGi4)IaO2~Z6mg|vqa(OZ0=G(u99@n zYF@yx>XzGFJ%1V@{zc#Tuf`Paf6)d8f32$$Vcs6eK~e6dNl)jAoIVpAIwk|xX^41S zUm=}w=>eG|tDBM?THGYsn_Ol3AQayi!aRMr$Jh1wvrCs3OcaDpx@&+xMD(~8*M%Z& zMKt@7_G#PZQ#jr!h8!iG@tM2-q{gQDDkg}xu5Lcpr;&FhyW4x=%&1ud*snL-Xo?u^ zjk|K=jqj&T!%6YE&2REO0z? zN$POQTdm%ipnwkuLU0^`TrY|`udU9Vn(vYR#ICk_+349&bLGy~opUB?U4JTzzk@u0 z_*bl6I0aHa(ae~%oq%sd@~Ndg%rvc3-N^^b`b66&C5FXsg2r#R%sSpj$6yGDzorZm z(n4Dgk&w3?<~k{AWnT~X4R6bsFPIzG(9k_-nsrm}XW(UNEruPFNU zaQB?-^NOn}2U=Bxo!cr#g2ilRhADhsN6eOeFFuF}@0}(3Kk(~ZOdh{lZbXOnT@R$L zFzhgP@qK}p(dCU$?V;_6*!NCX%!4hmyU5|8>1I}*dR@=ECovrUvCfOQY5qEzK)W{k z=rOsv8nGyQ&&=&5)@74B8~x?Y5A96XXqEVQd5^g4A0+PaKV7Ln{tL!0_^X;-6srzh z_6S8HqA~e+q2guVH;f1UCx-AXj0Tq?!&nev=PnQ6FMs7F#n}-3;?MWFL3T;v;oks6DR3V^(Jw{ z$N-8o$)_m7k(7EnLi3#!wb1_gz5^qv90s4$t(ytWB5tVt5H%bSuwxW9QRBEEvy`=^ zGR37^Ymt|$!|aa^u7Gc}QA$$v8MycQzCk|{d>d*||6J)@)%HA+R>LKB9~Hy0#EuU; zQ=}90lg2kVCp`glto=_q_B+tAb;I31Z3M$9k~)dT#v~O&!#W+yDH6h-Y!lcU<^Ies zO(YI*$$Q`ESS~si>@!zN$sVlP+CgLXP}h1`!dX6$|6=s3D|)E04>kX4^N^mn%aQbp}MtB%jT%v;s1iM`oF`U_kW>I z)IS+t4u4JCNjcPB!xfZUOV&Q{ zAhfd4W?c+ZbgitoH1RpzZCl^lQ_BC3y)S`>vR(flA(SOaA{s?el7z@MNw!K85t7Cd z(S%CMI!}>g-$D^n6qT4{U&k&aWGnkv%5KIw%*^9|dw=Kr&Z+nPopauE-t+sN^FQzB zlZJU_o_Xee?)$p0>w8_-_j_$^*1Uwu4b+vPt8P&zsC-X*zP4(0Lupr?gr!@TXCg`HnbDrFE@XbT;BN>?L~PZK+xlJhi!AGFPhhqR7_FspbnA4O(B)-7{g?ABZzbyuNTl=zXH)#6z--?@f?!|mWfr%>~C!S0A1m|um3Kk@PXwKj;~kN+2M;a?kvUjv2_l&A$12`x|n(T|~q5O^0Y zHn)G}2S_`ys<44YaOE@lqT9q^G`VDQ4_(uPm5QSb1N$^=n_ZCxv2bm z?=_~!JwT&MDg_G6Aqcyro}+|4_@=0wnB^kfTjj4>C$&}SGK%M-^Ch|QR7NbHck;vXzY!~9ui2QFc4W=NmuLF?FJAca zzfoNCTYd)9uK>cBo8Wz3$6IiGbk)w%U2bXo=ieM>HW0yg~|>cD@`#GToF4>ovRDmgAlJ zCWGh8Gv}oPdfc1r6yoNhFjZ}Y(_u&w2kZ7A0M0Fjo;3LJ}%|;>It{-;j;(r59 zQCbd8kv$dTd;^>P4+rCaU!nV-g8Bd7-p})2z*jj%CSYEMZL|{<$p*-5uCMuQ`y4ks zZ_EsC96~QG_4L`uVB>Cw_gXrsfI$TLHJ_x!~^2gU2`bHY73`t3=3+ zGAlk~3!F2iUZ*S#@~qs-h@E?0Jl*DW!gcJyh+~l)6OFxVwWPr`yjMWa~;5UfdRNS|vYY4i*d>jHOeuLFM&bcnef1O9wb=;$b{> z6RIL?L6H^kgtgF!G|NQ;*k-{W%n3su9?dvsHoxc(- z*@sRN(k~KglN_rfG+;L4T2O-O;^Pu0$g%M+Di7Q`KNi@iu6Ev1^+V{$EuJftn7c?J zx*i;#R#x4g(3yATBQDwHE&UpH+JVe>?`(~=n0{$Iu@NVrxVV*?5Z0kT#vo~fB}u58 zq1vGWF>^Jopcb8tYVgCNZvB3OFs#7R-{Op(*+2&ra^El9zc8r>W z9+KG=mQ{XOESfYqFQZ^*UFnnD+TGziq;Rmgojr(|7eN-SB64|bkzY|}l7u=tWVTF& z*ka4!7dBWP@i+J@*@)Y=>vjF<>?SS!pqmmKBz?|48y#@1HuWRbhWwaRBg-l z<`5nm6VtJ}s%NBI+n0#Z7K>3G@Jd#Lvl$u1wu=7nj&ACW#zpkt-VaVXY?EZGVdn8_ zrbTfRF4_U=9ubL(chvf?sQE_Q$GfbLnA^uYvC*bE)SpV@qAKQh`jnrPi!1IP@gE0j zevxy^duf^9)Wk^m@)=IABkw|Wt<8Ha2{X=@FnEcV!NQXmtGa1@kKQdNeu33S>57JL zGQKqIJv?0zD4B8N*x!(v<7AlVwSqRvn~l~TDh+1$cN{S~gfCl~SJU$1pu`X4bXIyk zpxX?5*)4S9ed>L6*Cz$b?APlqFgnJ$nU0jBp1JCrZM~Y4p!`j*Es3wIx2EMpNK5p6 zb^>Aj!JE)aREhS9r zO79Ws`u$@&@$1#qG*ZwzX{xvV(N7YYyL!TTyHk%EqF)j^_RZzjN~Zc%y=sAlI}b5T z7Fyyud`?f;fkdxY#Dm``j%50V--6re(9c;c;yxHe7LP{VcH~f+%Mca`m>-Byd&q~0 z!)+qtgHiU#?G_VpjPB*_m33>$|8tsR(Uc9Pp-hHv_D2egH_B`P5HoH=tI0`nZ{+#ovL$yuL>@$X=%{9{z}tjP&_IAH zdSCsG8}UV#x8H8}+hTA=H0iH%zx+?lk4)^G8{WPLQz4xawdmEF)1XJ!8Nk|Rx7)A| z=L9U0!_8z4c#&0a$pzrKBsL((g^ zwDz5vz7KHtm!M^EBg8)0b9-ohA8|kAAj1w-hz-SK_(#jSYv0&;rEWh!5-nP%ytO{B zQ(;Ax++VF(1GBcM8eac;W$XJ-cbiWU#h-ezo#St>aob=RqFV2heZ{HhB+N~8T}e=g zbkKKIxgwcV%KmTE(Az?bU2H9kpj(h8Hbb}hgCm({#GfM>2$MD>clxeOGMm4&^){o{ zbZE1+A$m|xVU{qJyowO(1Feyk{fzY8-q?Y#<<0XvG`6l)#OMn`Yi8Fa>ha>n35Z() zEDe+TCwG4Z1kd&jHKpQ|i`i1z&Q`a(I-$5KwCLeG=4DYrJC`~sV^TT;Kcm0-$jUok zOe``7O@A2gOk>jkWGA}`tqdOcXd%S#tOivu8&O`3pV+=OdCO$l-wGz-MSFieT?W0usZ(X zl0M`?Tzf}E!vL2$Ijm3PWew9R!ger9Y>s)BaDow;@jRc3a^CnPJ$>W@c5)*dd=-0# z@fsSAH-|oJ^DqRYZxP=i>5*rFy2-{c2CLLq7S|1a9O$}~ONahIQG9CMZ!}c@?DO!f zexf-qf)?F~8BIV(;|zC1mgbOk=m8#V4Bshe-mtR{lN&}4^FT6up~dXz1%eN1_+&F` z5((5Bl46MH52#^-d&}`1)LpJ?=KR^OV#vX{t*L7nIua*E*x+Drwt3X@y0KNn!i)fX zM{17wC!v@qN3c>xQS+B_R-k8q4PO8h`4QZR^Gp~`FjHq`$~{oclds6f+t$J$4yeuY8_g!?88%I_S~oYHO5zy4GtK^?_)*<9n{xg!g<15Z z6i|su8*pB7MKY8wbau42>-X(_XIACbI*6&9@js~HGStLtYQOh}uK!dl$Ho&T8P(Q; znGsKzLl>TrP(oDA+BnQ&zfVq@{*B7(nhH$I=ba$Z>?c9;AeOXuox9py+IMF#el5ln*~oC3P+;5 zFx63+k?e7%PeQdv!Vx^%I4f1d-Nk~7m#-R!^>;}o6$Sh^rXe-t9Jq?eLzh8*or)f+ z9{2iocchCXMYihePY4H|h>pB1wW*xt;5j>yj$}p@>WaS2VoaN=!|rLn^?o*_p;ShU z(ul)|)$PR)`0Bk~Pu!(KUgp<*bbC{h?Ot%U?Us5YSj7f^oNdU&zqm)y%ht^D^KR&~ zG7m#g8iyZ+W?2C$G0`Uj_TOd-L=Pfg44YTk=f$h=+wz3`oD@lso*+`o2tEsqekHVc zFwjAVK@R2a_Q5~wT5I13sqS@&k*dz`#om6Cv3-!Wu}whviS%{$ygY}yA?8-%q5f)V zHGn>05R+v>Uqy^#sOuIe{XdXcFZ&AL&KJSTTP%Uv!ry81P zEdj`z1Bx2Qpn&=b@6Q;A{w!RA<4f~8?7=pf(7H3pmUZQoh%4mJk58Wqm-o*3iCol} z;Cywy`$!I#w%e4WW5n|qGgI;FNt{`7hu6Injy%n3+>w(mRvBm9QAFsgOm*mx&CcpM zJ*bqVHU6RT4l5we{fx`Da%V)o9>h4G_zT0w)K;wp;XRz1G-r_E6=O*{6EiBg_7R z)!rJIyZz)dXr75K|EZlXdp(Fi=t=E^Y`~j$7GLVyg1|Jy?M+|JxhL z4-=rDq`H$CHqOG-*wf9Y${uwrj!W)1syu08*%)St9pLFcx!BUCE;KHH;=u9i1`RT8 zf}7jkAt%)>-{0pScBkp?uf4~D*f@(1@+i%EyQ#rV%8I5 zv!GCo5lWg8IiU+j*)TGQb_JKQuGG6J{%3v65=_X}&(H(d&&AS1L>B7TFBx8q$Z|r0c-C87mz|ub73pv{Ayxfnv1ktFh9{;@6U)m#!jDL|H|{4 z`|L_b67dMp8tPsT7)O}%DrEu;#oRJFw#IxSvj$(;RJ+t3aV_Z*Qnm|W%QC<(4hMH8 zpH7M_6bB3Y?V5E`$NnWBE$}ZL9LofL#7@kRiWpsNTp75!3cs1F``>r$pX3knthW6V z_DwgoDL_!Z=Np9wE?Yvaa6yNR8HeF^%&ZdxcvvQRTtI}-Ik)@bIn0avk5@+(?nMp5fuHd+Muu65?$4_#BQ;Sa? zJ}luF_qL@>EJ{+%Dcue+FDY0ZS?>~xwlriKuw_4aY4=^(G z%d81{KE;srn|R%Pjj5D0^DA))?ZaY~OQ$wFBx>~>a6GDHt?jPbDZT&rZiVUb!(lxq zq~y+b$oKQuFLlA0%3Yrw-kx;1F6I51U7tv|(6}NV*59y-*zD$cBC4)k;iUNyrDeMW zF;$u@Kd#}-+cxXDdD^3l%d>GNsQ@(*Xg#P7P{Nunjf;Y+`;0<}h1DD4^g0BFz&M;WQ(a9VvVJhMxFvq(>xH z=X6I-2=W75M@($^2==&LCSkXFhE#Mv9=eV{KkHyTf|tDt=h_p7zrlRyZ9u2&)+z#I z7eJD4Fl{A@dR;uqOO;&QDtp^|U)|$J!;zN9^dw#>jFTwppH?=td;wzFpVkl)=DEth z5;8pciw?!g4lp+&a2s1-e_OQAo);kCO!&>D*V#tIzE>dszf!umLwF2Q`!ZdgfJ_`0 zW~H8Ebv^A$*ZcU%t>IR@AGY6qY<1giI4=eB3wn(6APV>eDXU={_|PD9iuk>} zD@GOfoo?e4ZVzQg6zE>#zLF@_`Qo5Mp~9Qa*h8$B_j0mZyfjBQILCz%y! zd07B9b|ywFanz;&EIrr$RmApWcn33025PPaR@P^N2rHhQBK7R|PX60otdr+|Epsu1 ze~ztkrkiI#rkY1!3}z+Y%2Vshd@abrs}@~#VogJ;p^$akahtkDLhAUC{K6@n;nXRB z1oe&Qb){CLea3e8jtsx+yP~X;>ApKsTV{W5a8=}os>}z8?w>t&n=`@y_L0t1jKWf6 zhC(AdV>>H9WVU3qy}71^k1>BVt5_6z-0ARZ{Nt0t5)-Wa{_R$cA5;=uA+j$?JvBuH z)FPA@s&;6daNkqWZ^=e`D$y{h%|3;>3vwP|cZ6tb6 zX)IWqc7VhgY9~%wfLz_DwDV+Qw+wvRrM7>8hJ>T4F9IBk1ymD9t|HQ)&qY8*=Eb?3 zk+Yc6<>+RHA|uDH&5{vY?6zWeyr!$yoDEsvnSag+KmQ>YyPQv$P^`uYep0sw^cc#-28 zea~Ru`|Sr&gFz35@_Q%-!OT40hp=J;q5GPJo=HhQo1*@M*jWN|LJ`I&)l6^%Rt#L9FDQj`CP zXOM;Tr9JM5&HCc+wZ}L*WruTlSvbwLgG_BjP zwd?TPA@R8JV*a5p4!Rb3gb?J#_o^~&OfgxuiR1nI@i(cQp(TQ*>)Cu*U0p{d)6Abq z%mz^SPQUx4F6xHYYqj2V-t0>3Tg`1BH!FQ})e@{LLu{bU!5sZ$h?H6lQB9HLOjVR~ z27r#iKkwM4UK#;4g|q8@L-e%%HilvMItOlk@(_~7IgVU8v#@d<%!fCth&ZtEUBfU+ z2uRMT*HO~=8&D+@ZbZ(Dg6gJa0U(=3NZNnZQNBUS`*6ihuAJu%TMP)tLF2oTCqnkz zOM>;mlb4PH5u~GQU6TL*;itHLppqIjoJ5s@%K?@$hhg*L1;f}ffTV4cCjiCqQec{W zlAt;*GKS%k0%+U6mU2_YfG?U{hFU!6fLSpSamx=IeAmJ_4Mu5d6MvX1iLrrZ@~(5{ z=7eVmT3^qwLSkS-9Cv?fBNQrSO>WS>h#0IAs&o)LJj(7o`<<5sCT(n0aCipko;iBJNh#fOh^;? z3mVZEwNXO{B<34_N#)WN|C7pPkeM^#Pb!zue^j}=hG2;*%h35zGsWVQyo0#q>!*== zWeC?jB}!7|GP}?tO|T$ssFT6xym?KJa_PVHC^;FF2{<4x03xG3fevQciZ|T(3Hr9k zL!%OhbQVja0JQWL>Ib;H(9p^~&#*_ZC6RSyDT#}nH^VlL#`cMI0*)vB)}5X~0mDkBgCQs7kWZ0kyK;HwYNxCyv#kTj zitcpl4kLB(_Pp~Bn{qkiBIZt+6dp^xK52gCi4#A)tLFmxwrZ@2p?vG9>fW9hJG&$M zt7i37Be8ub2V`g*v34dPwp6CDe|_|N$Lu0))VTDfgaYQULz;8%-G_%%jIX068_7~< zsSJvGLq&wYzjk(YN7{{KhH=utdf#Tfy~}((G36Vu)dYjKga`*yhA^1ij$SCDSO71lF|bfJY;XU_x%?glD2QW`(Oj|wnIw&a|SF|8#I3RK+cCB2W?{VBQ505EaEeA^nP^)9HhZ?lm#O` zyO+D`R1ei@aq-hlHKRs3i=DMgCmPEaSZ*mSsd;R%7znsLhoOO;24Lkhm}mn$4!5*k z)WBonEM~!(CIBRx=3jRd%Pay+A?gWbiGizKMLegN6hh^4&@gelfAK&s0B4NT-LJul zA2R%>N37j~DlhJ!+=@`YJD6~||6Hrn{oQ$m0k#1phu*NipY-s3pPqEg=fgFi%1AG3 zfd*cv^ZE{&Jdo$vIwlQtJD!*jxF5eZWfo67e);la)l>b-1+0}iK3n`g*TyYeF$%`* zx_XWZ+#i;9a_-wa<<4?PgY1))u|l?U@*P*9BH^9}0s6)dNR8InjWz1r6q!4!`Y+v3 z)r&GI<13dEjjJ|WMk}#wG77J+*OXePxsF|22oAZ5!p?k2Gk7Z$yaUx5xl7UUQ5`2S z*l<^=cW&ok=Ts`iw;U2GmBaJ7ZM<1%9ExYXQ{dJmlcS2>A0ny4R55te_Zo}eM`e?Q zpj3US>W86VH(#cBnvi8`;?e8$r1YM-HMnNKjae2eG?r$mFHziFYPN9hn~D~WVKS#i z6Q|xbq$o~_3~X1EH7!X=<(l_%+sG9=GrZq(-Af%cJ?0YVltRlL=Ws(*G%~P%*2oLo zC94EYpsBlqs1FT9zY~obyeUqfl(%F4g6ra$<6r(l)36t~xQ#k8)Vr0Un&*GM)OAOu zzT8PZs4U_ktT2JfY?|Ut#Wg{^2WIFeIkTv@FXow>H|^Ba+bDKF@LKA2A6s9^Bm1g? z!F3(@q@-k0b+if5_E=7IM2^4S7c+|c9xw&25pFUJU>~ghREb+IQ`}t>( z%2qYv`tu5SJLbH0VYo+#Cj!!GvUiB0_e@3(9dJ|J_Va~WQQt} ze1d?u60Y;AZYV)tXzF;iRyGT>XwT7ax!6%9BCJJ~(OD6RR%geD5XqVFPJjiqyoLrq z@d@1mcwFB?9_gFVkI%T$WI@3GWKCyG^DB7So#r)?>#-((;=KA5Je7-$+RQeDqVIN%kDvn?tG^P9c@F_7^*D&&fNdacc4kjPsI5QWy*DAwQcgs6X2(IiLRYiYT1Yi%=GTFHPE6EKxa};z^^BM z_T1bYSA+rTIFKp@$vY+L99+B~_(0gs$zI$cs2sQ`!`CU_7;YdbG1^J^QrHbt!}frf z*_EA2fH$y#>{j_8a%H7;&2LhM-6vK+*1-n=kacg}{NYm2iNe?yN=)>a1^a^FbyN!G ziz2K~1Y3(w5^Pt9?2?0Z*{Z^505q z*{+zq>pOL{S3!fCPxdxS)Z=rauH?OJqI}h}{Pa zu$R{Yu_m}F_C`V2NPQrj43QGybs*v*cfuT#AfdGuazRWe5OQIGzqJU_uY^r5-{!U< z$1j2}74-^uoFKeg(YTIjsdg{*B|DP@{BrCSW;m&0nx7O@Ec*1@Q{o+2C^Lp3?{tSa z3(oDk=FSp&zhqa@!L|!xb|q%9Fs2WLOsZ_$U9#=@o~aO-PE|KmR-dp zQzf4HI-wH%1wR;D39wzvCX+#)kooQSO&1(m@ROsKZY{eKs1A7OoH0p_qqj{O7Y0?f z6f>_Pj-5^g?gB`6`q|ujC^n-vei)Z*fQdzs=baJ2;fMlW!oYD931f}>h|k^{Lerw2 zHK1G!c;4f{>(H4}mz;m|mL2c&_`9cN{cA}V?@3jhu8Xo6-6xrpyG9-V*R-unrp^L^ zj;?OF;vh6yfX!iqg9#mpr?FpvM)zX$0sbBV>iz}4AW29Ako{VI^>Gak`giyNj9mYO zfXqc*o?H?eLPut}$r;+6w76CCwdNeA_>~La@egWcSHnDD8(7SXkD5RxRGC~9rr_?n zI`1w{kx$rlJ2xL$L9TzAmd35kt#gcd9ovGmmP=>lYu{6JMTW2X7P$EqW^+2-B)Qg; zl9qP-r}80rq$T7~Pd5%?$j$!L@xp0AjnLNsbGwk2`d zXkWJ${}X=+qYD&dR;h9Qe9Ri6RshxsC4G4bLCWYx`9q%Yr*osYoER z*H76xErr=Yl`~A(y^4qpc#6fIQ9CtvHP~D6(&BpJ%M(%wwi9VcS<2{D z^}tI@{(Nfif)f2zYS5nTrW`EgOqLvVh@$X~mI#C}BUOvPmt8V#7rK$yZL_G{4j6)=C5<`B{}_+L^+=!Zy6cPF0eR1x9v0;CyQ7^Hg!j8rw*rn7}E&w8g&d1{qEyy6Zc^=hx^ar0Yd$DOm8 z01)^h1sYvf>K_CaHpXWj zdiLe)K)@?dSoKXu>9=Z$Sl0bP&T{SnZHEA8ciqKKO1^t%WIxE@#Oy(!teyWsF z+w|s+k4@&Sc5x1%B&=HG!`0~zZX2F%)*F?ou7dbu!p4&VDG=3u284%jAj+RezU&B8 zjHEt!DGU-h-I`A5->y*3$;e67{Sf4G`4F1pl9HtehalU_M$Zj;lkXKrlLDV^fUn>r z=p=0SYPmRGPwTJSlIfu6Ec{@J+c+X-m!PSz4>$X;$$Yw&)kU9#4Ed%y zh`I+bA3zUq>nSE72fQ3f-e)Vxy=LddWnP-0?cW`GKYUMa*R-f>?Foi`QedIcvRM6H zzbz*5)c7HR3W#+){TY>5b6z{>Wl~yB&dB`1Q~b+@20aH^ZVtuee(KTr2>k-4izy)H zvDG-=!^tudavGT{aB?jR2`4<`x1PkRj9z~1b3sg&J=F9LlSC*hPS7MG)S>BE3*!(d zmzV|P!UIk63(P6Zt9em=z=WF@##^ZdYnklKwYdDJ*!(UgXXef-!X6!f*+D<~nrFS6 zuj!{ZduTfz@6L^Nm42J`z^nA^^Q&(JgJeaYWTG_Y6zOtHwS<*DePA`0tIYstmoTmz zNJ?Zq+->WEYzX_=VW8Ob19EJ#AfhERNnPT$S-uWHZya@D6>+?~ib$0Jg@r@LG+yAB zjFB*0G}Y-Nw6bu)p8-UMvUaYjED6Lh|He z*`HR~gSA!mpDWef{7={||5iWYKOnyp+s-&oO(Wqe@gf&A+=rYuI^aYT+c~f6y)@Ij zNqKf;U-~{-MB+86XfUtcsUU7R-d_n(_ksXn$oN}Fd4XR2$R6lGG=l{yR{@s1>t;WV z9aM3Q*01oug5L&^3_gepc#0_?&6j z*|u(T%5C49j?M)$m_d`&}y(CO!@;~_u(qy^RXe&>3?F7 z_?M5bT!rWlOsF5<0~mTG6tvtop1ua^F#e}@xPPVRnO9qWU(UbCN%ecX%^!F0|BXjL z`S|-h{Cyt&!;O^x=v@6n+ZP+mm(B^=Clmpdfga{ORoHicz!R~iDT!um$(bWgr9V3f z1l*bek}uS#(vI*Xi@eiDA@Q>WQ!0{BJ&3$7v}N>5?FrOraU{Tn?kYOkvnv0fOuwi6Lhi2V}a0zL`sUWqM*?)xur5oe+c8wgF+#W5_5`uI0+DHwN zs6=k*b30poH(j1q=k==UiFeGnT$T%WB^A@^+iWsk;uTHPz%^|gYLcr_=QUvpt+Ecx01x|PIID2j+rZzA6`Ga@b<m0F2s5&V%opqiLhy zcObSpqn8yr><-bgY*ilzu+MD>v+In?C?wJ9wZL8ZgSc#<=NGZ(+q@)K2EINt7-VM`MlOvz!yQDr4N#7l%A8_^~&L8T6`gQ1%aAOtJRN(zV?Mcy& zcrEwX?+*TWSzq$QmX~A-{Igurru}?)72$sW$apyGCYd=XZCqNwF+tS#e;zTTXQ-xh<5f{*?fZT*9fOq8Xm>*NlI z*UhQB*5lMg+diK+vHO&Bl7&I4Dp$XW_NF%HI+O)fV=c~}@F&4}hZ~Po>txoLR67Hi=?Y#v-)yw2Aj?15x+?yQN zpISVn?$~m7PTq2seK~4W4wFRgyR6B7AO29?5S?x%)4I~{RC`?Lp}V+uo|c!)08baV zj??$m6l-%sAR$E3Z2!8e+$gtuelF7+@Ye2xq?M(OwLFBVZMp|uTN#v_cq~oi}BB~82^sLoIJM9BpgR`-luOjul>L z8k)+i>(_bIET<@(yZJ$LHGVPi=yv3Bg*lOqXB6~eAZ5RY>%;YKq$9dE_2)pRjl4W}Ss`x5?5R!0{R*2M%8TZJ?q2{2cc+i7*A-p(N$ zcwZh}tA#m!At8Ajy{F_xsb}EaXY28r1lB8O&V?;UpG)$P4hLn?v-fxYAC_ePno>}< zKX5KH`%zLW(ady>JDP)yd!#y#8t)mA(11U@)MQG${=_42Ir%&z$ZM~Q6sTt0Uc~Bj z488fn-N_A)Le+^0pmoag2(-!OT3U?Xxmd@}p6Q?vJ?Y0}G+mYtYW|8MV^$HpfOV1= z#o)F9FdNAIw_rRCE)rH209Mu1Pme(jV>%>ys1oqD008Z|AgMB|i0MES=>Z=0mj!hL zAd4|dD0!L8@Mwg-&A=y-vpjFm1wVZNaB}{~Ptf_Z4?vpYT?X(}YYE`we1HastB68L z06Ff0R)lNejjIS+Qt4t+*usPY*zAeR`=R&hnd;xlQRFH#a2p;0Fb`Tq+%y36a-i5^ z{zw>IU=^{Xk345zUhQ<0L#o=MTo9c z(eNxFUGnE+qXWk6PhX3FDQ62rFeA_zpm@Rt)>sHwQs?;o%JZQB%)4_%19&EEsSJ5B zYZbx458AhOe{vY?o@~69$@MW!B7VArJ9(GK}IMau?l7TyR+wqK*K66Frm|zvmSy znl!^V)v6$jk6?MD`Z$b@Vf08Gv~T34|7=)0Wxl(2`x^5m^{&2Sb{}GkR3B8?q`9S5 zSoV#9tu%i$bO!LJ)PuENwBHizWUf)YPG%$d zhuKCh8GV9lAsK1cOMS=Rpbg8`1$%$dBQUk0tQ`;NopiXkKdw)`K_Tfw|I(*AlT#ID0c3oofnui-_1_s_mkJ8ugwJdvlU_T`;z-_bM&j1T-bNB zQEKK2FeVqXFv~Phid9SZo#31P8*YaITw+I@1GPiD!>+wal6f=bnD?B3xa|=avcyxlp5D7g}^Hs^bpo zD$&Y06hkhA*E{rgMJ<=^BQBz`AQw&P};D#C>7UlaQn6_g1n^5Mh`-M!}f9i98ptk?pHao&S;=ofp zGm7Ly=Pl@*_k~<7%mj-^@i(>~ljvN(cY7ggLSaMWo5!VjY;(Ws{8tlV! zMhgK6DKHM>Dv82G(rqu!Rwo5LdjBRW*6H!9OZ4d$$yqy9o`uYTClC?4(bLexpwJTR7Vl=v&>K|eJ?)r)?tq~u!(;$ zChJsb&$Wp6Q*emv5LH}7ZtN+isrhbi!{_pZywqGgN$UFd1B`=o|spcrw zJ(SZSj%$`?hb?`(iH!6!<|zZtb?Emjcer23_7R_4(>=n^h^YgJ4f(^}p=-v`_chpj z*wr)$BMpnCRHQ^16w=jne?Lif^n3u^{3#wO^F_ki>+QAmR@Yh>TXxJd(b8yqUN?}! zxb1i_m`2Re;^N*&7{*~U3{Iz0?ggdB#B-A=U=yrv^ zHL6BPAuoQAhk2$n=*q~*J5le~2bd%_u-x%@CNeRz5^8UqA25)xRrk5!n_#AMh6mTC z&$uKrwe8#`%TZn1BP`c6_!%f%4N}--D79=#!n92S@cn5q>$fwZRM6`}vOp75X&Clf z%2e4+LOG#pXy6`^sr2x~#VxshQO8T1^TtkTlxO9Ozh=H0#kU?nxJ8KzL_23&X=Zd@{D4eNiF3uK;|)o@irv|% zqNFz&9Fk85E>)E$t4|~X)AWrdOuvo}?Ps2cHepWh?wk*opEBZ`+x}>awc#Q3OnBP2 zltD1-UP~=k04&6dehU~UEy`;iF07)uiNC7j$ma{wS9SP1pccC$qMRz*>eTq_rZ6p2 z{}$!ye<`W?uW}4pZDr7Zt-DWX!c!#)Wq9g&Xnl80*Y2OJlXJ0tP;5;39h#BXy3@@p z$*owM8Iw^39r=SjmAgx-u=@>PnP@Xb1FCpYvWCc-gY4KN1cT?7fu1|RFgR{B50oW)P+HlwMSpHiPDYE%uJ{750h6(X#@{>-w z)I4K8wQ@H6+SPBsMBU*C@YRNTn>^-cJ~I#H(f7`T#B}Y+yguUQQiHpga66y0D_!H){{49u?zsim!){XJUy4I)jRU!#=x=PQu_I`9UbLF*urqu18>>hsdTD#@_8$2MQRP@MY+ zdF7tvX51NZgzMhWg{A64x}S{fUz`~oxNmp}_}Xkaz#LfsKU@*6!ieb9%G_OX?0W@V zX&G+ERs~gjLf`HkM&Itjg{qr9oVis@NVQWzU0Zts_rah1@`YE!g;CIAWhM|}|-WnQuV z^DpxI@4skX{%9`#r^jCd1pRmpW)jp*8$i(ypnPJL`$7OP6jq*S_sk*Nb|h~j9F3U1 zE!xL+>!?cSoHOQ&JJa1YV2BG)zU4NRMUiEOZoRNc(HL>NU8RyEEyFybJ7V+oC@4dE z6}6AeA|4H_O+3RFR^f8l-nzIyFmWQ1{<+qI9GFQSFyw@g9K2Bh>jwt7w)SY*S zH|dEioWaeaQg4$xir>s_HgE>=RwD#A(T4xU`1zM-%7x%qP69`Z+q&Z(|L zvs*Cz=Ak+(!`ql?6RBI9(e*`?X3SAEsc3h1bsnwP~=pcY{Q`AhVFWDQwV2lZnsVySBHxW8-i)>V1&*R zhIu-UjQ&Cx(tbu37=#0cXx5E3+ZlKRKWHQ;xQZio+YzZs0_}t(^Gbfw%Q26H`#CA+ zvC2jhuC@^8?M)~ArNj+i-gH$De$H$1^2ZhO``>?ihK%`FqD4QBE%P^?uCxK$)LPJV z97pu!--JCpG!`WgLV{Apr9`n2OE`xhZKg5t-F zqy4Sf%lR2modb&WEm;G+WQi&_^@H|XyKkZeWd-7O8n{z@51>)Az3VA(Gjs*9k_B_b zQmAf4WlI>AV$B=d&yin0o*x``iamX6gzbi9L?%4(Qdk<7(m9JF+YH9!f&ZL*DM$_R z#^)EL8a9b@Om&aMZ7mm8-mDf1RA1X??x~tQqfS(GnFJRrTsx3mV%MdT=Dz+&sBsSy zvyxQ{xIQ&W=Hx#vo!=jS^?nC>LH>SIwKm>_J=7s`ycV^WTx3DLo_;6jRT-{HR`BqW z0loQ@lElb0xzMzTz7o4Y1guhQObWr4U)aYO%Mi0e*RG${00}x{An}?Y>|NH(wOW zt+=plQsM0#gm_IiHVP&)T7_LX#<8;|6&Lq*(0K>FWomY8&p#ZO66PY(Ui=!}Abx7w zZ{O?Q#u>Boa)6Yy<429^ME%V?aZ#%fBZtATsCT@K?-y@;|OLD)G(Qfu+6>h7m<|s zm${dfbQKOBuJOJ%@t}#Op*FF?LKrgS_?jE167L|iH_38z^u)V&Pkh*iH@UNz%+Jw; z;N~-MGm33;!SUD`uanOv-?q0FoeG*oH#xL9&G+VWKV({n{tZN>{*}P$|E+5Q{Tu(7 z4O_EgFAtsiw5Zr`~sDg_SM&4(*W|A1HMO*U!+T_OF-+yt;jdcr(RRu2Mda9vo+Hl@HxWb`YLS} zRDyBj(6kAmbkNVjT5RB2bdmhr{Ed+I_6+%+0mjVgpTzsbGrxMj_;FeOC;t8S z!c#vRD?-@}b+5o=j~DZhF|~!HqMwMQ7!^4KIvmHGT^r>R%bgI?FxB~t)KX18CII45 zJwh=rD{QJ*ic=-ONWB2KqKw@H_`XX2Aye1 zH8Z-ralf~C0K@S3?(LsfIu~rnS%c{N@Bu2bHr1b`cdrnaxvpM7G~19{+>g=%@`fd~ zFQ9$>TT8dOMGbtblDi_ zP)RVCTgFTSc`0y&%X&7G!|~N7exq9(uOT$y&uBw?=B19`0M%LdA9jwv8^!tsXP9}; z8RTzhhmnEm2uxriU-q0Tn#BG%)wJ#Lf%%r0n`Pt8Pu>$JXGC-f-}(e! z!x{QKO=|CZ(NGPxXN1Z0Xp0j{tfrQ}??pT3ZqwXwjh6>Zs=iX#?oSmH*{GSr`C07T z4d%gLv4$Q9j^1rBsf?C#2W5e>mgzbz)5T1>d8=!xA}?4~KC}DE6;tPg>kl&S%s=E* zOuj#<<(-VF1Fhv0O3_6s6Gi7F^=@)3K;36uHyVp>Oxw#oXMJQyb^S39rS-&Dpn%gY z?HHx@{>>%ll)}`6!@i<=JaYIo;`)seN)x#x29uEd+;_`9EwJ}!+4boXC>7Vrl&Oa&* zk($pd4Igoy!^fl6lqF4C|2a|)!r7qwWzgKysmcl z)UtYtY8c4?^By zPQl;btfa8%gQA5_u+FVyLQ}Bb=)gQ@&f>`5m`KFP;#%ea2@&0NG<9qM@~d}r)r(s)-Y2-< zk^E_3{t(}6je4s14JJYmU6Xdf>r7{cKG82deg2ZmgAU2kcF^3o?m0t@KY*M(aEn0M zqd7RU?Ly=FSe4kms+{Vf%U^fKD;ate)Tx!|^z;dlFZ&fEuk5EQrf6EO3y#60P2D!rG-$Q(_ z{*L&NJ_&BVKkW?EWZ}GuK!EEB0mbPXc@_Z+Y+_+QjL2K#nQ2i5WvON_K{2_w`Uof; zkEW-7yZrt2gi%eQ95(tw$0~x%N~E^q0$Oxd8g<~iJC|fYl|5ku_5av=^Khu!_J4dN zNw$!(k3!0lk}YJCY)PUd#8i?HDp^7^MaUjP5hH8artC$=E@aKVG#L9jV;N(XcfZT? z`QD%7c<%f8JfHh_Ki|*s`+dL1^T!->+*h}mYu?v&o!5E3&e!=mHGV&vE9lxf#(Ts} z0r-xmplGl$cl}~0qR$Dj&c^*2{ql)D`sfxA9{Sy6(tr0{kZF%hp&~Hp)a_>%upql- z(%$jA$-saAU>JOWqn{YX*rBtGoP=8w0hGOxuKY71?+iK?V6< z-@SK#!LAXzyS>$vXSdD3m-?;jYIv#28?=%Yy8g71F$xReD3A^Fo)R*ACBO7$tJHN)^K?<9_jr@obN2~~7i(0e%xAX> z<=lXUIfiv-l|h@WFeYc+{ItfU!rnaj4PzDi&&3rH_Mul!S?CL1W$g^GkX>`s&A#If zf@p(($>&X6j!rj}VA0(fa}mb;;czlC30G{=WUSw)UX17Rk`{ z3~zU`qcD7THV6nSlFl;+}ghxB29?9hgYU0~_3P+&woz$D({(`1@?LrHY1J$f_|lFmWaU}&t)m2m9TfZ+JS5-G zDAg7>crJLc>{Gb!>nvM6gLy9VpeNM~E{ZY`AW>jWP@3ZQKE)e1yFox5(wPxzWlFbhJPPif&XTT z(Thpfl?m40sTloIdw%!n+^tH+mtU|udK6BQ2qmCSWmxmT#2YbZ@u!Q$m=z)!v-_=vJI|VVMy<6p2PB(_q9{AnN;pM zuFz1ycb|ha!XEFx2BGSx%zjZ4FBtuk-9sz%T{2Ele+rp#KZaV0e9@#oL9vL( zPgfDrlBNQp$x@Yjt!hN`i=Fo!e}7I_z}#E_#?;Bwoe4C~OWTkCgJu73G%cvG!u2WxAN3`@BPr()!FaTr6UhP5>Vrnuf*Cf7&CLp zlO$w3y0Ev0s?a$5C>2lIV42D#lt)~^wI4)&ctu%Z5jyq!txBR9ZIf=^9)8Y69F&q0 zF6{Uhpg!*1jC9S%3Qr+H+YUSFg!fK)axMe#Hy2V99zcsXKl+~FUoBfKBa+!aA47@H zN3-rEEaoiZ3Z-ac_2hPWJ~ z9TDtwRTths-od8(4@L_RJ|@$>R8MY3tYlblgLws27W;$ba1JAXT0JG=M+*E-|l^)>&HrrweEp9=q|?Gmt$Qj6QqqR z7o$Y-c0MG0WOWYcchz@CrYmm|Xl8TGH%W)_mjb4o7X<8YscGGICCxxv=Gaf;M{^3Y zc4gP7;c&Q(LD6d#o)c6|&XubQK^v^Sl^U7zg-;fSuN_8M?mYo{lpy0Taw$G|!f`hA z$cKNGvG9*}HuHqrumjnql<^J)2i>`5dE`(^#7VbfZz-^mrp=VDj`R&sOTbRu)W)*%7Tw0bNO^-jiFFpNTo$S$?ojm^B;{D{y$Y><-Qpey9lf5-wq_y$5r!IPf z9EnS>WYTc|;J9$3?y9B9vdnyE@PFI?frlYvI z5AN~M@NS|`ZM)V*)3B_J81$q?KYD?L zSigq#1~9|q3rYZE0l|<7t%Q>efzlY#26UgT2hDD7;S8H`lR6f3G(Dv27wp)Zjbxzt zrSxF?HUhE5|HuoT0xw#jzcmkLc!ACd=X$Bx$T-aJMr`pf@GC$)s2X=lv>X!xrhtMjiY*=gw73^V5Zv^&|K2a|F1>b{?R%8yD{Vrs1LoW4#HD1H=L?+ zo;F{_!pH*>2amwIEQ%r4@5OES9e{qp;QCvlYRGT~50RVxA_a6~`{;s!LaY{?V}3ud z_SS)v4S-cesd6Bw@!jxWC3F0votRaa`$3-xknJ`7In_bCjU9~2UogAJsC`alhyzt3 zR`EpFY!x>bJNPR5=T~HA$Y#=RHRVOq&FZ7bw?4aY`Q8@Rg(P^eg8h#YW0fl@LlgUt zN}3&PDU60pbdUpDMd^;Vhq~RAL_NDT7#f#C#>&gy=rqZhapr|*gXF6+{U>S=Q>c|s zJ4MNGM!^l;G+Fa6=teZRmSkkhlMRPy)jRC$HhbU`SF2eURTtf8*Bs%+!2wAoePXf- zBC`(^$Z@H)A%t<=N%o&AlB~=O>zCXuUGX$6ToitX_eukrQ`0ir8($;zJkho&vw2!u z)0+^kSok(q;)6^}b?#AdK+iP57kVwZj4KRT*XfPYze+TIxqf~qKA|G$dLYfbV8W>S z`xyzTv|Bu@9B$&Q{jdHBdlK}t2dJl{em#5z^aPo^Vm|f@R%k*GrLHhg(oh)ahXr89 zX!u_AR|nu4n_2`w&4^Xm1>6P&LzT0bJEcm0;q(g@1rE46V(N!h5x6I(Bmb9&ur+Nm z)I`1?q!()PLdfb1Z^q+be>Fs%VY7TSRVZ|87!v9yyMU7r0K9H|BWP;wL{6ZYLLG^H znU3jmD2*SkZf;qTyAN4%%}K@ir$r^?o`7AU8EmqPZR+kbe`>1xu&hdhsF(CmZUeq| zuCI(W*>8r86r0=IE#ocj0*juWgtiRAevqvg%_lQ#M@Z$0Y8-*AtWNY}1}r7%(2-q}zR#{1Gdj1^m!7H_|< zlVk3P;MHTgbLE!Gw`{d0tzOL*a=BQnF2CN|A3yd|H};fcAI#>LWvCpuyz483VM-|_P;T&m63bF2Ex>#H)d3!aBLPz7~QN+EMxfR;nvSue8>M&Gm5*|i2jzu zwniWP>NZ}#bQk!jm;cs>=SYEFqnHvA`&8*@r7_|3MNtW(XVGF4?j`&gH*(l!KB(~% zaKt`bksD*5r~(A3{(|wDO~)^Cr`^RSPQCN4LV$b+2Q9zF(;`*Ge$(# zki3o73WLfyyRwJhqnFvs{aFRYFJ&JnVi$V{_T@jW;X`VTLcQps{D_`D3zXK6Ln)&} zr}ln(SP;y?-8v8&@W-QYI>R?G3QPsFh<1t^B$$0I`;q78n`P|oaY&a}F8J8;*_^0^ z57PYWEM_sj+uvFJcan^GPb}(f;hh$!&mV^7TJ+rA70;enSc^7)ggZbny8cop&Dm4^ zdBits-`AOvzB}#Ck3G6#s(buD;}Xh&Z9SUkANs)7dxJ~6&;oBxR9FHwKq3pAup1bRK&NrTDrq*)Km22IpPc>;2x0LIp1%6J>ho)$hQPU1j1q~TD}%^f4yb$ zzk!4G54Toz9Q4kl4~ zKw0jA`jMZ?LVm#<2==#Wa5 z2^$3XiXQuJk7OBB<&y8xT#>uDuLoV-nJc>^FC%wHMkw^vPkWrf*We?y5=1w&-=CQ| ze?m!h|Jr-+sffls9}$yc`pWbTs2}%nCInP0+EVb<_rVvl41XUvDb<)L0J>wBFL(Lh zFQnek;+A8gCeS3IAuV3UyY+9H$`j07A%umyMW^GJKH2URbj)kzHK-ox;q&D-x0AcVe5I5d_F|)tHM|lJsu9K ziwAr!fM%fY7uNM5w%BnKc1ohci<$bw(`m1t_6XM1m_q7im(3E+qxa@5@@b)MEny5I|FJ@9ioOoaQb^8#XUhJ1m9Ys9?p2@WhNcOd?JGya+1uYj+T6ya4M=ESELn?kp)!yci&SFVJ}%$IGpG>iTEjB$s)~on%-#< z$Dujz`#pC6@}wgs+;}J^myE4zshNGzzJe?c#c#%5+A-dLQ>K2RNuz##na;+}0re>@ z)e))gd9G5jReNWQT>IH3(a!5fuS(y21=DV%`&;@E$OJ}or7+P;=Og6}v0199+qIw} zc^V@bp;*0|tNHNkO2TeW{u|#1q&!J45lT63>SiBZ(Kd0@aADkJ7=;EkFf4w-3?=h0 zgM^07DoqM}48uSF@Nrq$kPU7~Qd(*{#P5R$t>~cif-UZWo12%LRDO@M5}&MO^p_xg zJpq|&qPEChpZ+4!A@0dLp->)T+?;+K2#VsIpwXo~S5pJU3FgVQKz@{6xO)xLi&{0=k>YkEGT5fEX+$`~&jvjwvcWiH0azz|=;khO zZm8-QZavZmik8JP5fP>aN76orxe0qKZeCHBO8J`I@xFkmG;;p%TZioBy8R54YWMFi`a|+ zcFLg%LXeR4EBtuxj>d#lC!u6o{2fmzJ2S&;X=WF{{ZKO(VM6xS$$WWrf6vDnK=8x1 zAVM(23%T-c4PI6qb?AM%X>%5Z#aWH>gPHq>o8>-!IX)7grfa{Gi(RawKRolORLgmL zEak9U1NVOKU_%n3F2Zx(dgDgFKSON=txP*YC61F|qk5}5S8b1pr zX^gpecd{GkUsvjVYNT2kaKuLl>^ke$|+eiGBca$eI zgj+pHaf2$S!PmyEchtFUN!tPj#7FZ4=*9W1q<`vn%?)&H@!ixyC#mMs)spi;JS_8Y zJ06Dfp=&_QFgdW%L8XKFuD^5MTExzpKzq2slWr*Jv^Mv!yCgZJ-i0RwCEjoEn5_1`8 zc(;m+7!@=fGgJ3PDl$PS_tVe5p(BC`r)rL*MyEWmzErE!vwXNc@Rj;2&0PY7fHw3I zi>f1}sG@N(CkVT+b%BDhXS@NOkrF*mkZo)sIiGvw9?+{h z!r~z`N5C7zCzyYsu)m)dDjDgIv2pId7o`+&aQ8>+ET&uDBY9{c8dP-zoVO>}7_sD= zoySf@m5$EyJ$dkugkFpzfAtuLc5XLQyeGcx4P)!El%uxU6QaTdli=we%CbpD&dED? z3I$V?*Lde=9@M@(>0d%~MuuaESCSM!NdR-t1E{KYRSsNfM@@0$%!dR-P-E0y#&Zue zpEJ66B`~eB*L7Jwf1*b!=9F}F*VCdmyR`3#oMhg>&dT6+0$O-H!kLj~U*3pnNc-8g z{{1Du$cy?{o?K)d6N#E!Yl5$;Qv#sc18AypZaj9r9T^1Ybcmb5)cTx@F)%u*9QgT; zU9;q?X373Z2l~FIT1b%CjvG@ZPDeXrvWGdZ9Vsr0giYCsR#cW;CrAs$bFTPqlng0{md~?X&2DgO82JmQPJ0*W}Rb@$b zfO3s=>u%VcanPY0=t>ji5A3Qw%C9n&*bvFjxD)9DZ`ibPlj1xD6?=rHnPto`A zgIsHb;xGE)elYo^N^>XDkd5GQOXx&fQm`kCPUMSz7JpV>@=(9lVZ{fVyP|42P0V{X za_2lw-&u90HMieRoY0H6(8^mjl%98rDUx74YyN7Ir)H-YKZDs%j`0+=lOk*w=x3dc z5m?Y=O?LNgf9ZPq+|}6{e)k_JGSK3GYr#&_fvSN(VM)|ZC$P6`BaTzN>V591ZF@&X_AX$FFa7|s1{>1|X z#5q6aU$8^c!H7j+U_{fdei@qFZU*^xBlHDP>KetsPfP^d=Sis3j;k=cA=EXU*?vH_4S&!~G5wG3T zd|N$EbbKBWpeas&oG6%|_xa@b(Br^i}bZf|Dkqc4mK*A6@gho&Byf z`pbL3>oNxP_ea@QjAaUvtnmo*Syyl@CQ7U|)^!D+E3eSDaQfY;ftNPcHYDdEeX}+z-yc2RQjD3?eJECm{2uE_X1H-jaoX90`|M5c1P~ zhw|yT_xZNkeAhB7Eex5X#YTBH7WIOv?hG;a%PWrpCc{*IM(z-8u-Iok=OLUoe{tK#q%eJQzGE zBfa28=PhHH)Z;gc-KHKR3zWA+fhEuX1-ndAWh@DA*;*__JQ_y?J>QNxpK)AE*Yi5J zr(5I`Gs_Es=8Z8X6W}cP#ADJO`>m>S@5jeqJ)ybQc5o;fbXE1b<--Dtxy!uf3+-_s zgc&Q0y9Lu(gae)w%p;uaI}Qysb4}>qYd&Om`u-C&`dMr$+7TP!g%s10n@1_ybj@9w z=(rGAe$;_YgH1H?td*W5nQ;Tumtc@XlFhGm7j+EW!IjJG~Fod6BkMuMU}4vU@# z40iko)=8$No&}A-5KT?CP8qGxb`1bk-d(shP5hu`N8E8RGi!mG*#do|Z&~@@hY<;f# zX*#h;Uv{|W(|%Z8n!)uqu*pE?E zVVfJbk**ipWr}z~KkhJi_xt-}OW8v<{7Eo_yr-#>JJ`80EC}+!fOU;G>pO6!gn}*z z>${P3`^QYk0clue%%L;|+J%}Y3rT`(?h)`3%tvIS0uSAaqAX{7{EIUKS*IFtqu=br zV^{IjZUft^nne-cpKl!0mi%2uhG2Lw79tnuDq84gI5l~@uEzi z2xcE;>B$5R&K;1*KWM{Vry;g=iC!iYR$@6xP zt1#=SF}G?mbEb!_2qUXXi(+}~;15lY9AlR6#vCtyw!6JMGWKDI2KL02mp<=Vf8xwg zg6JBhD#2Om^#kyI9!9vmPVK3!{23Vg!iXsI>)gtu9}nz1Z2Uj?-2|@CefsgFItyw? zov7Z*nV!5`F&;tlIv$QkB8>KF1SdWx1aS;Nb4zgZoVX@;k`~SbG4>k2s_n5t6LiI` zhqmMPENote*dhP~eE=-Nr+6S|Y>+p1qw2UrE55X>2_OEGnUT7I<@DBQpaw-}4i8%@ zd6HKbBdm-b-4VC@ESb24fP^Z@k>FMe%thltp-~BH&=06yqqzOh0pK9u+@ z@CzWgGj4CmPZuSlJ8x6f6q{=scw?&_C1nCeLRacqHysNtCx5|;^_O(UxCaUPSNoD^ zm+Ms>t2zvy)wan@K9h_H;?aGnx5_5DH>X|%rA9YkRG{y=!EM@=814{{jmm}RCEc&i zwtNjK7jL%~HSiFeJc4~gM-%XTbYl`KlMp3flhkH$iLWnq-r5!)aJ~%|)#r*!n z$m39^S-HG;Rf+!)hjPGFS5K&wrv?X}7}p!J>+g}b`~8VK8R_iP|| zXW27<-gVb?ksAGsu&apf`Mla2a{H06(FeI*Q{rpSD08$kba%?v0t&tXXo40uL|`;ht1+_{nvlIQ&$)P-3&MAN{z+pbylPe~z6_w8PIl zyw9d#!QQ!t!bh%PL~C(VEXIQ)Aw@?Xr1XpMg(nH0N4b+9`wAp%+M|{w?{Sc9_~+Vw zSW!dZ;rvUsa7H>x-TSg`F>wXRetnpu+5;*`S?DcZC6PeW0gKT7MUZqf8L%5OsaFU{ z;Uks3nCQV=?z?9;G1lPqN%hmIp?7NhHD}*uF#1&dmS_i2oj74XXPU0JPMw%6DLy^N z(W1j!x-`Y6>l<&omrb4M8*Q-(oJa0tR1ofyNew(uOXY_jdR)zerMEct=+h3*+u?$@ zysc%69j4qPefs-07H*CoY3!B?+-GY2x2~deA_Ie-oXA5gRHcJ+br{>LPS5+p5SgwA zy3}MtEL-1}>gfRL2&FK;{>`!fEARBBf1VFrQ(ejh_%kBt7i^*3jUkAd>$5MuwdfZ5o}26?Wt{cD5B zP7Hk?95Rl&4idC&2oRR;$ECFV*4iFBXapHIjn6Rr<-sAxrsmGze?QWbO-#E68Usb3 z$KE2i>COas-$qp}xt^sv{n%yfPl1W&87>~O=S>X-YvZds)UaJSN0QZAC(eGJMzU%N zIAV5H3D4p=A31+^%U8O2+%RiT;i;6z(zj~7Y27>xhNR$nt)s+*+Kzpc+?q>u5AyF) z{WjxCBNB$f0{gyI=O}*CU>agti2N$s?IaxLeQar}QiNm+?>j-aao;!$UhuFlzAQ4`npQ}Ri=*|f{vb(SaL zTX#L!Bk(Eo{zVJ$(8k{2cCUZPDknvL{@^^7{QZG$Y0JK|gO`T=ngw6oaQ&IO>KS=r z&H_U;K@Ubw_b(FfCd9|wh*Fn7JBI7$J(-0bQ=Vd6&Fw8*MrnHslfN?3{FpuV4S9vH zcpQ;9r81P1M6p4C@nf{7Ga4_?>Y`eq7EDP>08Q61LFno`r$p7T_40V5T;+QWnzxru z4##VhGCyq@JYpdHdav`UKDm$)hs8OFx|w|b;fA*$HkMUdjhxRZ(pB;+Rt(%S9MX|} z`tI|V-cLU)*@Lu!uQQ3L;wE_%r(D19#G~UL`?Rl2y7?`fc4tw~XD_!?qDS^czrtMw z=u76SD@zhId2eTOj2pKWE^l;>bA0@fgvtluE&qj<_N6ff{)nAeR#jTsS4Uk#%a?v0 zG)+_^#KTCY-$WDS*=@h}7P#e!RcE&7<$S9A)N<{fVEe5dRSjvO2U+-UGa)14#483x5UrqKY2-~|3~nzuDn zTZ*8bmZg^GDKcJ9K17`$&GMKvS?orCit6{pubStMx>}h`z zXkTy<)qz+yeT1NMh5dqsVqf9LveB#OG}X}ax4^<1EawutbE@QPuYcH~zfJ<*c8aD6 z?m68~P;{cM0LesA96JVQpwMF=uz1#mCYjUgx~_p3!|3Cd$s7R^awk6q<|FJK*h7zy zbaO8)b_#~LGI63(jF}`FRQ2c=Okrt9%cIu~)60W6N43$CPbuZf^lXa-G@Z8s(n3iA z1v2ntG`{o&ooGKV&z8PH1}^;S*~%sYYdxORPm7{-@{@>O}(KzLzed+yyWiL zLsudxgjXU7M(vWfuh7`F&WA1!B&40&qzr zve`+PG%etHM)TES$ebV?>you`d1(9l5tMs4q{Ki@pDP2F^K%H|FDl>t!_3@&5#BS0 zlkjyN%yi>Udty^mv9*i)g`_;yqZN+O>hU@WuJ0K<3*fjE!|8knkX@V|AyrJHrA4<= zEADQ(?q5F6B&mI(wnc-*C4YGucmI7VHzDbIr*a2}Q|xSt#E)gAuV;HhTix9GbYyr2 zqpJ@UR~{sw1-zNBMInh7i@e^=pRnVzy3qNVxyVK0AvNHoZzzOCN&1OW>`1=GPzxn$ zyw&)j)K0xM{g8=vbnk_V+fg5Ad9tgU6Te_|-oozj>0Yhf-U*AJTRBy)wH|)Zt?~ZV zN3mxFH?LIg0~9-Op7kiW@mKzJ6IxOB6SQl!xfe0Bsu^5ta?jH2@C>ubO>AdW5D3~3 zgB48^vNlM7fl;5o-X(7(AMLsl_9CqQ!04b>9B@Y!$X6)ful{{qUv6-l-i?Zcc*3>E@h^m?-m-TeL~ESWimuoPkAtBtUsf}$R*bG3SG2H0JP^M z&Egr@c*-D*FLwc#t}vi^1cyVt`1O-kz$BCc=s2R2bs7232`F1(Gj&{{-kR&!AK3pgoe1v#|3nwW{|C_p7L*>k<{ZvZ zG?l)C=+ehFc&tLnWjHwWheUsfBq+NpKvA&e2q$FJ7c_ACh%GZnJwrd853VKM^25{k zhZb)veO`~lPm-mbZtv%(56vfLz005X@@dE3-}LHVtpvEl$f}X-elNR`Pm+majs1m_ zx{}(CjviBiA+-s@klQ7 zDb5WDG1wBnDr&~Bp|2)q6a2Hmj}3$~i;TCwU}YulBtk=JMe&l1UYKq~6Q51NRin3J zd+IhXt`tvsM3ec1Zi@QE46 zN{7b=>j!U|V)P_FV_TgoEb|Jo3P0P#3|Ht8x?YY&eYaQ$rt|)=DPbeA)*Y*HHT_CFctE;QBn z*%z(f!o_gvLuU^lHvkQ0T>{z1KuZcm(Vri~EdYmiR4*0(g&H;_XCUo~k&>SsKc zN337A06-YaF%tjJIHH14=MwVPg3ouJcEnc>9wQaR%EUXm>~nDafik3T2?Q){w-tIe zD<8-^Kp06Jo1bisOibwz{*}S3=7OU^2ZNS_tmXpA3bZMrXy(~sl?a`m==FS&aBXPV z*xwMK-lA=_A=zlkzJhw+8y2fe$X-kxr_YZDI450pLSq-+?Zcd>#qK8=f!3^@;bW0a zp{)v|&m*Tn%Z*Dffn(wF3$~?<1;AXtFoGfPUzPvIv7&t9iS(^%>|6pFctEkP6kotI zd5?R8op!=yAqJ0vRlS1`Zn&sVzHUe2sD?^rZfG680w^+Rp&JIj;gSeVAa$~(mdj0D zc@G|4Nbbf(9J7`L{TwN(7PEvArLbOtTKkn2682tyY-;k+SOF}qug57WhHK$&-Rt8~S;HO-=fC4TNl>q+;SxJ2-U-``7@bFOiX6bz2ukR> zdYm(YDhy~2HTHleBZz%^Dr)XxDCCS7pFoyi7@VExvBD;h69XUvH}5vra@I8ftYSQZ ze&;D7Q(blo3Zu)-UUXc=E(5qSObP*&bMFZQe zbCeZgmqSz9;6#!Fa0h~B5kyOH=o;@bQ!oFyX@ZJ{e^2uNv?r;nK;MTTm&PcD5RRlt zQxYzG1QwzyRboAryT@(V0@6c+=%D{Co~l^&h62u?|}Da`%*%YiK&An$M;L#f7fX(~@zrm8G-+5N=v<@DZA| zj(1Dvxo_+1K1V6742eS=j1d(rQoCB{NOXC(Ix-ASG{b}lF-FJ;wnr4r6Il*Cqc(a8 z7K^NUGQ+7CHf>_;XpL1D`VdYgL$cT7vc6J6^7T4;tVX}4LF}$}%tZcHx*)}9!XyH> z_t8o)Byg0!Vca3{F`{>PxaoV$$Fbz}w8(2OHTswI1JED{C~68WUJt16E2}zMIlp#& zwDn#FRnjwme=k|CiIKB*Q>et_vy}wtKm%gG)?vzwL4Y<18|opCYV32 zcSXQ?%U}+Dt2Ls*DG?rI9Wl($Y@yP8FiH4KMJi=ETa=>vGa*ath!9Qu6%7|!QE<9DOKWv8M0Bg&vq ztJz$<-py@(#h(sV+W^~L2%%gh46v4SfWR4e554YVvGD-eg7`y*m;-m`6bt{>7BGLp zB$E)j6w&Prj#%*Nak`BWR)??E{()3wgkZ_%TY>5JQV*GG2;xF?VAcYA0a_T`hW-WH zm4Ku_EXfbMlpy~+nqks~Y-Ob%!-iwlwVZ#!C`urwL3o9(ou>ewL1!8V1SgHK&?!sw z-D7oAh!0fRts8&%Fq|2BEKy69gb6>Ubx0x1gP~)+KO|E?ccZ6F>*|_Eg+pRnnP=Bq zkLr%|-SHd$*(VbddxRaX)vuV+!VBC z_maOZ_9;{QDQueph_#8v^60uD$)S~6o5z(pQaMG<*|vHgN7RpWEZ^4ogl7a@>A2Iu zN)8Hb%o5`p<(gTQmvrMRHv6D|R_CYykO&X>i36QTR zR1qTZ#=b~zH}NyVyY91!$mBbTv^zLW#Jwe9qfVS-hv>lUY>4aTrR)4>_|KJH?J`ki z4u2yWvbmYhJ%Bb%GG@H5!AKnE3nz?$G&|beVa=MY#2@S?4ayCMS%fnMXjpe#|cs;#^75#ny0|6K8#@Smh|;Y z{eECZt%G3F3*9!_?%ek}_OHri{}*BAKMhH-tN`FxtRpf{_^mu9AekNe%O1`=`EHtI zl+fXsIM5`4l6zoF0#qJ1Z!E;-OVeDio+2(y7mW{^j9zm!o8Z~7elgm?p!2VlRsb(F zzPJ018wXv_|8A5_S)%*7x`ASMHH1z29J1ufCDyT$emZX~!Hkcru;1Iqboje(`7c;p zMc+QKGmN;04l8W<^ao}Af_bhkqGzw(!lgPLVcr7Qx)FZpg2aotzhMi;pp*_M?B}y102HxUI%^bl4W`(!(K*jbirF>F^*Kt*hYn82cz2J)lSkH&q30| z$>~`SQe^jp673I<4!)zcjYkdR$fiWE;Nq1%l!wM=TQ78{8u@yg)exf_kF#En9f@r| zB^_&xBa(o8F@!ErpfjOK*qPPPdFpyY3)&D5;x1=g08&WPs%j5Ww|DYOQjJu#<&u@; z;C@xV`i<$T6l^tuB7Iz%{IaacMTWK|g@P>J1?9{L7h3(f*sDX7$R&+I)#Fzm*)oGL zB@D}uNWo>Chroq@%WX?t0YE_g4gFAb2eMo0BJ&vPoy3&M5RRNs$aw1~?D*K>P9o;T zMpKNu`jFXG@k6j?&5FlI#ioR{S;Anu1#fQY0ISTx#mLO}c=ut{DPA?c;#HEM+`((d zxy6-k-kHg!)*-C>Yfl!@;d;*Bn(2P1jEK$N5`YiXE!^80*Zj-tzSEA)L4~ zhTqD`^I~;00|9m5gddqi;W?KyA7`Aom%S#|NA4_LbMe~C^22l2luXO~eZ@RQrf*`y z$P^6GxPxV(N{Axy-g)S|qP5P1vGHUv+$n$=n6?cU?)C!i1zHwb{nVq`#i_wpSfPh4g~tTAr5 zqSX_wWEPoq^tdY1H}T``$I&q82%VFXNp!m$dyzzb_P|c)t-fcDebH+NSd;n`7y66A zP0SnZiOM`(!L!#}Z^a^fwM3yoOi@SuX+_~;l07zfc~C#eG51@k;<xxRQ=)#7>I-m0*as;}jihdkj8vs3K}7 zU7565@B85%o%@}QGq-Ez^j(R=N6hsM<1fD;2#8t+-u?8$sjuhMMX`Sv&ssoik-I5u z1u-WuyrVoTP_TokKcxnfwn2aCgjn4R?%x8y!r5kSENqGQJQIrOl-POx7mUi2zs6m7 z1XBp?8G4JC4xoUW>?VwTy=46*Q)1x^fwdaX>#cLan3`5!7B3+v@9tEqtGy!5DR8>% zPLo|A71CjfMuB~f0c>X96=*pzYd}~M)jMjD+o3OOh;cp*tsX@9h3J^DCF=F$ZI6jC| z?!kgrxIM1hCjQ8B#_2oH1_BMQtWSM9rRG5|NnL^zfK}gD8 z{xP@!&P~A|$+SM zt|VVQ0VuA2urNCo{Cz3ngi+XQqLSDG_9oaG>cX${5?7#IBj{CmIB@|zqyhQcf)A`7j)HWt zv+CP9RtO|>`v85wi?Op(17u#x;Pfzkx;lQF1d*B$j3%lO03zi~sz`!XUxFzrb z+6r92VRQl5!=Y`r;dk?c0swA4(*%jS(e%-sMs+xf@B`q;6{G9Gg?)F0%H|EQCT3T} z)_zd&-^vhp`1@8*V8(+hz_|yb_sKd3BfbYlpM=29Z*d`b*z%5tIOS?hsA`=L#P@ zd6m+}iLFNmKai}F=^k4>CeOx3Ua*&nwc|a|79JZhFp@3U@?NvN#r0C>UHU8|SWAwg zKhE8>vrp4St>}mH6RjsoqvhW{yNx8xG9rBf56cHuxHi}N>#sb=@Kp)xW*@^UgekrH zc(eSvGJVIP6K{hR%Cf(r<*cTk6;SLhC+U(7CqKw^+CRf?+=;%Z_ufInMo;)ImX#zr8!GhF_WwwmR=cS$GvMBoct(5PgnzX!H)Gys zf2Mo7^VOIE{9uKy65qQ>gHdmpmI3i79m}f^-4cx+q*Jpfjvs3B{nS7d?=)fHR-7x? z6@0evOgQ(Xb14>R6r`ZEND1LuY{$YTj1{jHgU07Aa>Yx!?fX|niUfL15VQ2dphbFDDXb^8*a#* zyhzgbuTmKVq)hXARUy^q8Qvx5?hA?PO_j&<4#?&17S=h+`hDc$89|19?>*$@?iHnh zW&1+s^4l*=nAnx{bY;X8wVC=)nQ2cx!p>w9$`&n1VdMMLz1|O8GoI++a+4QJw-=1j zz1%@P$Nt4&z{~YABad{Q^4VB_R0v^SC4~A3FWQWH^ewZ`BYIr^^4*hGPG@p*A88ze zttPv7#^gtyk8CDMlSvFrC3}X+gEVKIhQ!pJIu6;g$L}oKdYbZP`Mmr-Fq@C;kd>9# zci;TgSU}BAPo*IOmA-WA^F?fv*2#XYgA|+v2^V6|JMYt5H`0E^LE9X5o-Y@a~+pm9DKT0jNSdoub;%E zp+FX;My^19JZ$1dqIr0H*Qd`qgM{hL7_>U1FWU`qpqpg)TW?^>qbd_nkbaVif|WdLeG7jY%$b9&;d zOWW?}iAWZWw7Ck1{DQrYx%wjNRylbAX=Z8@wBX`&1TlW*AHj z#^s-p7Mi*z^6pL0gtr_>L7p*>duo5f$De!Wl~S=8ahta4%5tJ>nA}3Q{*U}O zP4wMOiA#BT_lk$_E*39%FS$$z7)+dbk=`lul*KA7W?MKii|IGrp(`bvCyjgaBT{l5 z`Ijml+9fJSbDcJS#_t&lTg###h;cKt%M@uH$0$;=%0Su7)O3ZY!nv+TYFe#n*R1D= zzhDo4N`{~%VOKNv#hd{b#c`q8n1kDm%8zRc?i8!HV0u7|9XFrq7(PZ9OzXIBsAV_T z_kHX1MS`I6tdg12*v-4p0b+k}8Ty8!O5z<-5KWG!E_&sYu4|XB#@`kgEHcy9I86(n z*$vmw{ZI6N!Oj^Yb3C>z|2VcpiGXM(boRCA;SkD8Abvt*J^tH9`3vt%Tf;N}@qX(V z$5Gdk0J+EF2Z9qfnGDEYHCK_;gR;~AU5+Q)4l!x`d<(kwvy3?5Sq!QPKl}}?^8;WE z3oFR5+VvwA)=U8NKM*@ph_HnGFSwkG)8byTJWRwrrE^bko;_`COI%L=fYgnM9U0P5%8lG5+VhdV zkeiltM#s&KCF7`zvrrw!8JR<{oYY4!F4%|n|Dm@glQb+PWi@eJ-?S1Pfl>#iz;wnR zy_c@FdfQTEXkKtgyU`YYp594nqDNMVgS=EA9*T6D%5BDmPlrW>87yfD|B0-|J=~q6 zph(j8*}>m-Ie2H4%J*iiu9A$NCLIV?8Z|T2TC=gNJLUGl<5bohoV$``d@F{aP3(J< z)e0Z=bymA+Pu+J$*JIsm>J0xHzG?Kl{m1TgT>zclUo{qQ`gtO7_sO}n6B7NbB~RFp zN^}$uVEKB00rDJ)SJe`hPo&Qvis~3Ow$0<*{_2=;Kb5M3pT79>9E`RL)xGyAJi3)* zjQ50@xBFOd7~ME+*f~CSu!dzQUN=#nyG(88b3GA;`vTIEu1fq0k|vc>&Z+Ycb>-_7 zWKW*8jOCgNVwXt??yhXVW4YB#_C}Y>{($x(JF>hdTG;JMWygpu>FLL6>qs0-^d_Gr zVX`}^+)Q^mrJl>*@*!F+_c_JGO)9(IMzQbSRUQcOaQBLZ$EV6?dpV{n9$Z=aNvF)qbj2sO^I+?HU1Q7u?n=iI9fv{{9)Fd7^ zk0B7T{jk*lsuw8-CPG!OYpOrrD?Bjl6BwXYH}$rlUtvC6uE9HtkQ=yPpWhTM&u)PFC;m&IoNNWo;8hLECQP{e<+&1)`oZ(So+ z`yJQ+Tld61>G==R*dJB9qiDi(5aF@&!A`ZVyNs656YyNv4{~il|LhG)i30GB+OA33 zOaGDjhyhSCx0ogIMO8Imbg4eXg~%$wQnE{|00H;NGLk`=WI6_ta>X#r#J!i+rp2Vsg zzg)UnH+Mn+7S{hMZfNL&$DMX(7O3jj4PPUc#2If>o1ndIy18CLw%umS3*(Ohk#Z4 zw=41|QD$VX$mu56$y)#g^Za=A-a#h~w@_{lF&@L*K!?AxEPgF3t$X^2DDh?anM$M? z!987B$FavyQ*3rg?Y_n`%}D0+xgUbqRQ$Tat5nc!^)}j75e}<x=qXdm;JA@rAf zVqAvGF!dD=I=5|o;yc%jzkfapFpbG-xHW! zMsRIHu(9d%=-(EZ?fy^nA%3Ca_??e`kPqt1ZY6`pW9AJ(u2u*mnv zWP};#tI6genF8#gz+t`&7Ns@a6?n^_6nUN720H>WB1JUg#Bc4FZldhEAIHg4V*9#Q=y@ox@bY*42;=8SIM4g=vsgEF-?k`0_Q~ax066ij@=vpfuG1?Fi@2AA^y&~zu1YZ)^r8fR&2L-FO!L<66!V~#p;n- zM}ML=8zIHXkj2dGQ6)`cAI#UEwjp0IA{*~bU6XfgR| z!%EbbE7(Un2WNmLpXB{la5H|ViS#ctg?{JF3(SkJ1fc)*o0wgY!rq`BBt%Y`&&`T2 zF)on26#F2;$%%KYA{_6{4fFjZwxAGpw z5Gp3ZIgsNSuDF#R5b#}@bU;X7Fv9}(K@^~*l>j0oN{`S_I$Sw|$W3p_pI$xlp=E53 zOY}480JbevPpLT+>&7Zx%GG4GSZVpI`AtNvOEKqsIy`2Q!_Xtd%=ZZfbXv=iVD*nDhbxo@v*;c`qT3Vq*Ix9jq>+f6nT3G+{p7`3J5mi*0)^1 z63bD{eZdHHb??G*MPlK*xy+){V_b&HDjqDt!d8$)`K(@0*E44MN-*(=VtZWSZL9r7 z2W236vo7Ry!c*!2(n4dkkh_PL4~K2nV_(3?9jh!9)984W#{lZ^&fhwNZbUI`Qry+B zt18@eCf(pv1I6~m=sVcsb`9?`P4;<8lz$xTzN)y=>9_OA&6dT9oudXPyweP7r>=b< zuBv9cjlu*_w#|M{&rlb1WLu5Ji}Iqhby+21GyRhq%p-P-Q3$uM63g=wVZ1i?6=fr= z_6+Eq`6S4dYM?Xz;Pu!4 zTbX>pYdfIMJEb9J^-i%P+d|P6E!gyFSW-*-Cr0BYB+0lE zIrM^IuKT{oD*Ac~f>hJGO~-3j!#VEKVK%Q+FO<_vkS1Xn7dBvc*L^o~PV`}}=}QUu zjXqj_b!(LcU~hY^GZx|goRF%n7TE?*41sX#Xg1r0mQGacLIg8XFqYZFq6=JhCcU;c z4KiU~Z=yUnTecx@#2uUK$_+F`i7-1x7FDlcSH`i+KUnu$N+SPNM26oXOrcF6KC3Nv z(SJ05WuP)_LNBe@dwbGTQ3{VTvyC zUWVmPjnZ9`T>v(IqchSinOo-bb$)ZRcUP~w38`wUIh7Wd>wG2C&2N#+v+gYFeW_6w zDVy=PoS}c}{MlmpjX-&`^1cI|9~4M_8HfmxL^RNpam zQS(E#&?j)+JwnZEKGX8m@-?Qr>{J{Gi_>jc+%66+pyiCIXtdkdk?9GaRe86HY?!^{ zZtJD^YKxh1^M1imdev@nAR%f+X!RmGB&lVqx8e_-y>p)q$qce;&&-r}#duNh4^0zX zGvO83#+Mj9yaY&K9RI>U{#Th2w|vKW7;wehqvcTQbpxcRCjGt&oB>UoeQYVnau&;4 zj4sCA8Wt~&D6}uObr!F;{iFcjm_(Y)FtyQJP>m0kkU_^_MvjwB(#QEBm`y$Z5bSMGw|s}L*aAVuU}ytRoaAHeRFmX<+R7l zn;~X(NH;-M`v@HU)zAD&&GK8|%n6)u_4x7EASWmvJJpDXP;~%1l$0L$9*>>M!lyA0 z44B^_Dju`VTjGs0fablsX9}Z6l1c^pO>4q0WU}qY zoiwu-8bd{)r^aZ6j;f=S0O|R+J&d5+V^2gbp{OVvPxWMR zHG{@)K04X>7#F_o)cKe@IqBIa&!avjr^1R39OPTKehL!GOi|cwHYA+`N!N#<=VRST z=69+NU!L`l3bQVA@GkT?7)&8faE9t`cXa5fJU&KRYrwNi$D;#)8kpWW5X@zic|Hk{ zjy~_!X&KL!*NOFg5yl80<{%`y^>;tjkp4?h1`T_p2p zJM!bOSSWLvA96x^TZ{PdBKN2gt=-xhPq`~dkhk<3cTL>_yVb7~Gi=>7zPPf+iTYX**iY+A5&G$^!%yCOL6PF zT5+R@20}NVwKEiMy~9Rzs=}CGSYO}Ic-3AaDn=(0-zBwi0OJ`j>9DL8Ne2vDe?AUde+UBx<5WQM6o|l4%aQrl zz&zNOd9a*jT5R`#{8AfJ`0iL1otDL9YuZ?Lm~Y^!`ak|weK z8vBs3<5{&<_2|Q@BiOEFEb%Iqei2S?M~NdRJCI~GB=bP*tQb3ktsX?P#>4)2oIl>q zAB*ju**h<_iPw7+HBF;s+Mo)?k>`}-GrC-n5CsfFhmT(u5fQeDZAnv7mAqR5Vi_0^2SA9|%!V=e5ej8m*? z-dkp42nKhvEEop?vo#B?9#K1fJ?Q}!NBPN`>?swf(-<|zzQetwSVUm>pn{9OogAe44 zr%z&Xa)X+2JBw(ht|hOIJUK?sFzPm@=H$JlYU0aRA0dv8M4LOvlm}qSgXFvW>Yfyv zH4pPyPbPg4(}!wa0vY{Pvi}v$g0&|ir_|u#AReXkhl2P+-~Inr-$7^Mrp1kFtw6TR z(-_ig9~>At9d=_CR8ofY0EA2l$=EQVHUFb)A^TobFvIpxv5UH^hnNmi3)EAfl?Ll1 zfZi_~(q#cBTqy7sMvCxHuz$}sT5&r75cQTII6W>v2#Dh48{f=LXYz3;0b9#&6S6R1 z-HH6QS@maD+In#kw%hTGA*UUMeYh@X9z38W>y%{kh}Tb{C75OYq@Q3^E@(b@M3|p$ zu`kQN?4!GLL2F~BaL9*Ln&vWFnGwJUu+i2rQsu8oqOp~V1In=*lU@w@_gYdLf~|W^ zW5wNaCbC&`XZhlXhkb(`$hs@(8&dny76=X52Y_K2JTk?3Utkeo(XE2J19)7ECRG=%aa6TpLbDZtS%+lb{t=+ykoRSr5a%Dt+n8`Y{ zZZ`+HBzUX|1WG*_mkAnDC$iR+v>wM-t2( zF(8bxosb6JrSlk*k!cJ1*w7V4*(%;G>;r_Ii@ks+r#O&bAI;`082(x{jcQeMIeEv= z1GOpbY^H}DL?6W%BhT0t0qXC60~^qvoSfh49Q~)v!{Y#1XV+=+>OH4tpOuSWmuDw+ z+(Td53E>R;=;e|A1m?faR%OZnJ9KR07ywdkz&m82pER8k{X{MgW!&pDJtUnlf70+` z%p>wfQ*uX?P2wYvf9s~PSNhb!nw{K-=j8x{kzDm1?DCc}dN{S49Y^R!LbZHyZP3p? z)|DJCTK_Ta6Qpow(|&&8TC>tY1CGy~Ze60Io(+3XH|vzu2#@L@<8lB5oQ|&A=la$p zQ+#in^1V=w=tYi|r?77SZ~9y!|DC10|0(bPw>}3x9@t-4Dzh6olhV9ovffj5nvpr@ zJo3~o?=2)=DF<~IguaCny=WW(Vc}eJ2zAi77o7};L zn$0tI)^^!sn;y~v$speTbQ}9)`P9}>HDP1~bQ+XU2Y&+~vTXj}r6d2>NtzdxF8V#M zb+Xzr-PGRPa63r5w+VSD6=(PZhg3nXi7d1}lY!G>k&gPjT6#Z~-+3$S`Gm}pDx7|} zE4X?xbAiD`F(oVCM?%PWstb+<+hsZeY7^-C{z^CL+eUG$`f{!URc8}6Auk?$cpbw+ i6wrn&3&CSmKm59f;_uyZ_`mKo{#L>8%{=I4|Nj62mm9AD diff --git a/sepes/__init__.py b/sepes/__init__.py index 75f69f1..d4ff6a3 100644 --- a/sepes/__init__.py +++ b/sepes/__init__.py @@ -15,69 +15,37 @@ from sepes._src.backend import backend_context from sepes._src.code_build import autoinit, field, fields from sepes._src.tree_base import TreeClass -from sepes._src.tree_index import AtIndexer, BaseKey, at -from sepes._src.tree_mask import ( - freeze, - is_frozen, - is_nondiff, - tree_mask, - tree_unmask, - unfreeze, -) -from sepes._src.tree_pprint import ( - tree_diagram, - tree_graph, - tree_mermaid, - tree_repr, - tree_str, - tree_summary, -) -from sepes._src.tree_util import ( - Partial, - bcmap, - is_tree_equal, - leafwise, - partial, - value_and_tree, -) +from sepes._src.tree_index import at +from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask +from sepes._src.tree_pprint import tree_diagram, tree_repr, tree_str, tree_summary +from sepes._src.tree_util import bcmap, leafwise, value_and_tree -__all__ = ( - # general utils +__all__ = [ + # module utils "TreeClass", - "is_tree_equal", - "field", - "fields", - "autoinit", # pprint utils "tree_diagram", - "tree_graph", - "tree_mermaid", "tree_repr", "tree_str", "tree_summary", # masking utils - "is_nondiff", - "is_frozen", - "freeze", - "unfreeze", + "is_masked", "tree_unmask", "tree_mask", - # indexing utils - "AtIndexer", - "at", - "BaseKey", # tree utils + "at", "bcmap", - "Partial", - "partial", - "leafwise", "value_and_tree", + # construction utils + "field", + "fields", + "autoinit", + "leafwise", # backend utils "backend_context", -) +] -__version__ = "0.11.3" +__version__ = "0.12.0" -AtIndexer.__module__ = "sepes" +at.__module__ = "sepes" TreeClass.__module__ = "sepes" -Partial.__module__ = "sepes" diff --git a/sepes/_src/backend/arraylib/__init__.py b/sepes/_src/backend/arraylib/__init__.py index bbdd925..2bfeeaa 100644 --- a/sepes/_src/backend/arraylib/__init__.py +++ b/sepes/_src/backend/arraylib/__init__.py @@ -14,20 +14,32 @@ """Backend tools for sepes.""" from __future__ import annotations + import functools as ft +from typing import Callable, NamedTuple + + +class NoImplError(NamedTuple): + op: Callable + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"No implementation for {self.op}" + f" with {args=} {kwargs=}") + -tobytes = ft.singledispatch(lambda array: ...) -where = ft.singledispatch(lambda condition, x, y: ...) -nbytes = ft.singledispatch(lambda array: ...) -shape = ft.singledispatch(lambda array: ...) -dtype = ft.singledispatch(lambda array: ...) -min = ft.singledispatch(lambda array: ...) -max = ft.singledispatch(lambda array: ...) -mean = ft.singledispatch(lambda array: ...) -std = ft.singledispatch(lambda array: ...) -all = ft.singledispatch(lambda array: ...) -is_floating = ft.singledispatch(lambda array: ...) -is_integer = ft.singledispatch(lambda array: ...) -is_inexact = ft.singledispatch(lambda array: ...) -is_bool = ft.singledispatch(lambda array: ...) -ndarrays: tuple[type, ...] = () +tobytes = ft.singledispatch(NoImplError("tobytes")) +where = ft.singledispatch(NoImplError("where")) +nbytes = ft.singledispatch(NoImplError("nbytes")) +shape = ft.singledispatch(NoImplError("shape")) +dtype = ft.singledispatch(NoImplError("dtype")) +min = ft.singledispatch(NoImplError("min")) +max = ft.singledispatch(NoImplError("max")) +mean = ft.singledispatch(NoImplError("mean")) +std = ft.singledispatch(NoImplError("std")) +all = ft.singledispatch(NoImplError("all")) +array_equal = ft.singledispatch(NoImplError("array_equal")) +is_floating = ft.singledispatch(NoImplError("is_floating")) +is_integer = ft.singledispatch(NoImplError("is_integer")) +is_inexact = ft.singledispatch(NoImplError("is_inexact")) +is_bool = ft.singledispatch(NoImplError("is_bool")) +ndarrays: list[type] = [] diff --git a/sepes/_src/backend/arraylib/jax.py b/sepes/_src/backend/arraylib/jax.py index 1022f8e..c494b78 100644 --- a/sepes/_src/backend/arraylib/jax.py +++ b/sepes/_src/backend/arraylib/jax.py @@ -14,12 +14,13 @@ from __future__ import annotations - -from jax import Array import jax.numpy as jnp +import numpy as np +from jax import Array + import sepes._src.backend.arraylib as arraylib -arraylib.tobytes.register(Array, lambda x: jnp.array(x).tobytes()) +arraylib.tobytes.register(Array, lambda x: np.array(x).tobytes()) arraylib.where.register(Array, jnp.where) arraylib.nbytes.register(Array, lambda x: x.nbytes) arraylib.shape.register(Array, jnp.shape) @@ -29,8 +30,9 @@ arraylib.mean.register(Array, jnp.mean) arraylib.std.register(Array, jnp.std) arraylib.all.register(Array, jnp.all) +arraylib.array_equal.register(Array, np.array_equal) # NOTE: not traceable arraylib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating)) arraylib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer)) arraylib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact)) arraylib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_)) -arraylib.ndarrays += (Array,) +arraylib.ndarrays.append(Array) diff --git a/sepes/_src/backend/arraylib/numpy.py b/sepes/_src/backend/arraylib/numpy.py index 285c916..1bf8d55 100644 --- a/sepes/_src/backend/arraylib/numpy.py +++ b/sepes/_src/backend/arraylib/numpy.py @@ -16,6 +16,7 @@ import numpy as np from numpy import ndarray + import sepes._src.backend.arraylib as arraylib arraylib.tobytes.register(ndarray, lambda x: np.array(x).tobytes()) @@ -28,8 +29,9 @@ arraylib.mean.register(ndarray, np.mean) arraylib.std.register(ndarray, np.std) arraylib.all.register(ndarray, np.all) +arraylib.array_equal.register(ndarray, np.array_equal) arraylib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating)) arraylib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer)) arraylib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact)) arraylib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_)) -arraylib.ndarrays += (ndarray,) +arraylib.ndarrays.append(ndarray) diff --git a/sepes/_src/backend/arraylib/torch.py b/sepes/_src/backend/arraylib/torch.py index 696a309..6ddac57 100644 --- a/sepes/_src/backend/arraylib/torch.py +++ b/sepes/_src/backend/arraylib/torch.py @@ -17,6 +17,7 @@ import numpy as np import torch from torch import Tensor + import sepes._src.backend.arraylib as arraylib floatings = [torch.float16, torch.float32, torch.float64] @@ -33,8 +34,9 @@ arraylib.mean.register(Tensor, torch.mean) arraylib.std.register(Tensor, torch.std) arraylib.all.register(Tensor, torch.all) +arraylib.array_equal.register(Tensor, torch.equal) arraylib.is_floating.register(Tensor, lambda x: x.dtype in floatings) arraylib.is_integer.register(Tensor, lambda x: x.dtype in integers) arraylib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes) arraylib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool) -arraylib.ndarrays += (Tensor,) +arraylib.ndarrays.append(Tensor) diff --git a/sepes/_src/backend/treelib/__init__.py b/sepes/_src/backend/treelib/__init__.py index 90655b7..a6cb835 100644 --- a/sepes/_src/backend/treelib/__init__.py +++ b/sepes/_src/backend/treelib/__init__.py @@ -61,7 +61,7 @@ class AbstractTreeLib(abc.ABC): @staticmethod @abc.abstractmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -72,7 +72,7 @@ def tree_map( @staticmethod @abc.abstractmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -83,7 +83,7 @@ def tree_path_map( @staticmethod @abc.abstractmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -92,7 +92,7 @@ def tree_flatten( @staticmethod @abc.abstractmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_path_flatten( @staticmethod @abc.abstractmethod - def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: Any, leaves: Iterable[Any]) -> Any: ... @staticmethod diff --git a/sepes/_src/backend/treelib/jax.py b/sepes/_src/backend/treelib/jax.py index cf49f8a..43bc9d7 100644 --- a/sepes/_src/backend/treelib/jax.py +++ b/sepes/_src/backend/treelib/jax.py @@ -36,7 +36,7 @@ def __str__(self): class JaxTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -51,7 +51,7 @@ def tree_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -66,7 +66,7 @@ def tree_path_map( return jtu.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -74,7 +74,7 @@ def tree_flatten( return jtu.tree_flatten(tree, is_leaf=is_leaf) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -82,7 +82,7 @@ def tree_path_flatten( return jtu.tree_flatten_with_path(tree, is_leaf=is_leaf) @staticmethod - def tree_unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: jtu.PyTreeDef, leaves: Iterable[Any]) -> Any: return jtu.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/backend/treelib/optree.py b/sepes/_src/backend/treelib/optree.py index 78015ad..4747494 100644 --- a/sepes/_src/backend/treelib/optree.py +++ b/sepes/_src/backend/treelib/optree.py @@ -61,7 +61,7 @@ def __str__(self) -> str: class OpTreeTreeLib(AbstractTreeLib): @staticmethod - def tree_map( + def map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -76,7 +76,7 @@ def tree_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_path_map( + def path_map( func: Callable[..., Any], tree: Any, *rest: Any, @@ -92,7 +92,7 @@ def tree_path_map( return ot.tree_unflatten(treedef, concurrent_map(func, flat, **config)) @staticmethod - def tree_flatten( + def flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -101,7 +101,7 @@ def tree_flatten( return (leaves, treedef) @staticmethod - def tree_path_flatten( + def path_flatten( tree: Any, *, is_leaf: Callable[[Any], bool] | None = None, @@ -110,7 +110,7 @@ def tree_path_flatten( return (list(zip(ot.treespec_paths(treedef), leaves)), treedef) @staticmethod - def tree_unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: + def unflatten(treedef: ot.PyTreeDef, leaves: Iterable[Any]) -> Any: return ot.tree_unflatten(treedef, leaves) @staticmethod diff --git a/sepes/_src/code_build.py b/sepes/_src/code_build.py index 4889a8c..e2e44bd 100644 --- a/sepes/_src/code_build.py +++ b/sepes/_src/code_build.py @@ -294,6 +294,7 @@ def field( Buffer creation using :attr:`on_getattr`: >>> import sepes as sp + >>> import jax >>> import jax.numpy as jnp >>> @sp.autoinit ... class Tree(sp.TreeClass): @@ -308,6 +309,7 @@ def field( Parameterization using :attr:`on_getattr`: >>> import sepes as sp + >>> import jax >>> import jax.numpy as jnp >>> def symmetric(array: jax.Array) -> jax.Array: ... triangle = jnp.triu(array) # upper triangle diff --git a/sepes/_src/tree_base.py b/sepes/_src/tree_base.py index 978ada2..1f5a0b3 100644 --- a/sepes/_src/tree_base.py +++ b/sepes/_src/tree_base.py @@ -19,14 +19,13 @@ import abc from typing import Any, Hashable, TypeVar -from typing_extensions import Unpack +from typing_extensions import Self, Unpack import sepes from sepes._src.code_build import fields +from sepes._src.tree_index import at from sepes._src.tree_pprint import PPSpec, tree_repr, tree_str from sepes._src.tree_util import is_tree_equal, tree_copy, tree_hash, value_and_tree -from typing_extensions import Self -from sepes._src.tree_index import AtIndexer T = TypeVar("T", bound=Hashable) S = TypeVar("S") @@ -148,11 +147,11 @@ class TreeClass(metaclass=TreeClassMeta): the tree. for example: >>> @sp.leafwise - ... @sp.autoinit ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree + 1 # will add 1 to each leaf Tree(a=2, b=3.0) @@ -161,45 +160,16 @@ class TreeClass(metaclass=TreeClassMeta): used to ``get``, ``set``, or ``apply`` a function to a leaf or a group of leaves using ``leaf`` name, index or by a boolean mask. - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at[0].get() Tree(a=1, b=None) - Note: - - Under ``jax.tree_util.***`` or ``optree`` all :class:`.TreeClass` - attributes are treated as leaves. - - To hide/ignore a specific attribute from the tree leaves, during - ``jax.tree_util.***`` operations, freeze the leaf using :func:`.freeze` - or :func:`.tree_mask`. - - >>> # freeze(exclude) a leaf from the tree leaves: - >>> import jax - >>> import sepes as sp - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() - >>> tree = tree.at["a"].apply(sp.freeze) - >>> jax.tree_util.tree_leaves(tree) - [2.0] - - >>> # undo the freeze - >>> tree = tree.at["a"].apply(sp.unfreeze, is_leaf=sp.is_frozen) - >>> jax.tree_util.tree_leaves(tree) - [1, 2.0] - - >>> # using `tree_mask` to exclude a leaf from the tree leaves - >>> freeze_mask = Tree(a=True, b=False) - >>> jax.tree_util.tree_leaves(sp.tree_mask(tree, freeze_mask)) - [2.0] - Note: ``AttributeError`` is raised, If a method that mutates the instance is called directly. Instead use :func:`.value_and_tree` to call @@ -236,23 +206,23 @@ def __init_subclass__(klass: type[T], **k): if "__delattr__" in vars(klass): raise TypeError(f"Reserved method `__delattr__` defined in `{klass}`.") super().__init_subclass__(**k) - # register the class with the proper tree backend. - # the registration envolves defining two rules: how to flatten the nested - # structure of the class and how to unflatten the flattened structure. - # The flatten rule for `TreeClass` is equivalent to vars(self). and the - # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten - # rule is exactly same as the flatten rule for normal dictionaries. + # - register the class with the proper tree backend. + # - the registration envolves defining two rules: how to flatten the nested + # structure of the class and how to unflatten the flattened structure. + # The flatten rule for `TreeClass` is equivalent to vars(self). and the + # unflatten rule is equivalent to `klass(**flat_tree)`. The flatten/unflatten + # rule is exactly same as the flatten rule for normal dictionaries. treelib = sepes._src.backend.treelib treelib.register_treeclass(klass) def __setattr__(self, key: str, value: Any) -> None: - # implements the controlled mutability behavior. - # In essence, setattr is allowed to set attributes during initialization - # and during functional call using .at["method"](*, **) by marking the - # instnace as mutable. Otherwise, setattr is disallowed. - # recall that during the functional call using .at["method"](*, **) - # the tree is always copied and the copy is marked as mutable, thus - # setattr is allowed to set attributes on the copy not the original. + # - implements the controlled mutability behavior. + # - In essence, setattr is allowed to set attributes during initialization + # and during functional call using `value_and_tree(method)(*, **)` by marking the + # instnace as mutable. Otherwise, setattr is disallowed. + # - recall that during the functional call using `value_and_tree(method)(*, **)` + # the tree is always copied and the copy is marked as mutable, thus + # setattr is allowed to set attributes on the copy not the original. if id(self) not in _mutable_instance_registry: raise AttributeError( f"Cannot set attribute {value=} to `{key=}` " @@ -262,13 +232,13 @@ def __setattr__(self, key: str, value: Any) -> None: getattr(object, "__setattr__")(self, key, value) def __delattr__(self, key: str) -> None: - # same as __setattr__ but for delattr. - # both __setattr__ and __delattr__ are used to implement the - # controlled mutability behavior during initialization and - # during functional call using .at["method"](*, **). - # recall that during the functional call using .at["method"](*, **) - # the tree is always copied and the copy is marked as mutable, thus - # setattr is allowed to set attributes on the copy not the original. + # - same as __setattr__ but for delattr. + # - both __setattr__ and __delattr__ are used to implement the + # - controlled mutability behavior during initialization and + # during functional call using `value_and_tree(method)(*, **)`. + # - recall that during the functional call using `value_and_tree(method)(*, **)` + # the tree is always copied and the copy is marked as mutable, thus + # setattr is allowed to set attributes on the copy not the original. if id(self) not in _mutable_instance_registry: raise AttributeError( f"Cannot delete attribute `{key}` " @@ -277,7 +247,7 @@ def __delattr__(self, key: str) -> None: getattr(object, "__delattr__")(self, key) @property - def at(self) -> AtIndexer[Self]: + def at(self) -> at[Self]: """Immutable out-of-place indexing. - ``.at[***].get()``: @@ -292,20 +262,18 @@ def at(self) -> AtIndexer[Self]: - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. - a tuple of the above types to index multiple keys at same level. Example: >>> import sepes as sp - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a: int = 1 - ... b: float = 2.0 + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b ... def add(self, x: int) -> int: ... self.a += x ... return self.a - >>> tree = Tree() + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at["a"].set(100) @@ -317,7 +285,10 @@ def at(self) -> AtIndexer[Self]: - ``pytree.at[*][**]`` is equivalent to selecting pytree.*.** . - ``pytree.at[*, **]`` is equivalent selecting pytree.* and pytree.** """ - return AtIndexer(self) + # NOTE: use `at` as a property to enable chaining syntax. + # instead of at(at(tree)[...].apply(...))[...].set(...) + # chaining syntax is tree.at[...].apply(...).at[...].set(...) + return at(self) def __repr__(self) -> str: return tree_repr(self) diff --git a/sepes/_src/tree_index.py b/sepes/_src/tree_index.py index 721c388..4bb98f4 100644 --- a/sepes/_src/tree_index.py +++ b/sepes/_src/tree_index.py @@ -12,31 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Define lens-like indexing/masking for pytrees.""" - -# enable get/set/apply/scan/reduce operations on selected parts of a nested -# structure -pytree- in out-of-place manner. this process invovles defining two -# parts: 1) *where* to select the parts of the pytree and 2) *what* to do with -# the selected parts. the *where* part is defined either by a path or a boolean -# mask. the *what* part is defined by a set value, or a function to apply to -# the selected parts. once we have a *final* boolean mask that encompasses all -# path and the boolean mask, we can use `tree_map` to apply the *what* part to -# the *where* part. for example, for a tree = [[1, 2], 3, 4] and boolean mask -# [[True, False], False, True] and path mask [0][1], then we select only leaf -# 1 that is at the intersection of the boolean mask and the path mask. then we -# apply the *what* part to the *where* part. +"""Define lens-like indexing for pytrees + +This module provides a way to index and mask pytrees (e.g. TreeClass) in an +out-of-place manner.Out-of-place means that the original pytree is not modified, +instead a new pytree with the selected leaves are modified. + +The indexing is done through two concepts: + +1) Selection (Where): Determines parts of the pytree for manipulation via a path or a boolean mask. +2) Operation (What): Defines actions on selected parts, such as setting values or applying functions. + +For example, the following code defines a dict pytree with where of same structure +as the tree. The where (Selection) defines which parts of the tree to select and +the set (Operation) operation sets the selected parts to 100. + +>>> import sepes as sp +>>> tree = {"a": 1, "b": [1, 2, 3]} +>>> where = {"a": True, "b": [False, True, False]} +>>> sp.at(tree)[where].set(100) +{'a': 100, 'b': [1, 100, 3]} +""" from __future__ import annotations import abc import functools as ft import re -from typing import Any, Callable, Hashable, Tuple, TypeVar, Generic +from typing import Any, Callable, Generic, Hashable, TypeVar, Sequence from typing_extensions import Self import sepes import sepes._src.backend.arraylib as arraylib +from sepes._src.backend import is_package_avaiable from sepes._src.backend.treelib import ParallelConfig from sepes._src.tree_pprint import tree_repr @@ -44,241 +53,21 @@ S = TypeVar("S") PyTree = Any EllipsisType = TypeVar("EllipsisType") -KeyEntry = TypeVar("KeyEntry", bound=Hashable) -KeyPath = Tuple[KeyEntry, ...] +PathKeyEntry = TypeVar("PathKeyEntry", bound=Hashable) _no_initializer = object() +_no_fill_value = object() class BaseKey(abc.ABC): - """Parent class for all match classes. - - - Subclass this class to create custom match keys by implementing - the `__eq__` method. The ``__eq__`` method should return True if the - key matches the given path entry and False otherwise. The path entry - refers to the entry defined in the ``tree_flatten_with_keys`` method of - the pytree class. - - - Typical path entries in ``jax`` are: - - - ``jax.tree_util.GetAttrKey`` for attributes - - ``jax.tree_util.DictKey`` for mapping keys - - ``jax.tree_util.SequenceKey`` for sequence indices - - - When implementing the ``__eq__`` method you can use the ``singledispatchmethod`` - to unpack the path entry for example: - - - ``jax.tree_util.GetAttrKey`` -> `key.name` - - ``jax.tree_util.DictKey`` -> `key.key` - - ``jax.tree_util.SequenceKey`` -> `key.index` - - - See Examples for more details. - - Example: - >>> # define an match strategy to match a leaf with a given name and type - >>> import sepes as sp - >>> from typing import NamedTuple - >>> import jax - >>> class NameTypeContainer(NamedTuple): - ... name: str - ... type: type - >>> @jax.tree_util.register_pytree_with_keys_class - ... class Tree: - ... def __init__(self, a, b) -> None: - ... self.a = a - ... self.b = b - ... def tree_flatten_with_keys(self): - ... ak = (NameTypeContainer("a", type(self.a)), self.a) - ... bk = (NameTypeContainer("b", type(self.b)), self.b) - ... return (ak, bk), None - ... @classmethod - ... def tree_unflatten(cls, aux_data, children): - ... return cls(*children) - ... @property - ... def at(self): - ... return sp.at(self) - >>> tree = Tree(1, 2) - >>> class MatchNameType(sp.BaseKey): - ... def __init__(self, name, type): - ... self.name = name - ... self.type = type - ... def __eq__(self, other): - ... if isinstance(other, NameTypeContainer): - ... return other == (self.name, self.type) - ... return False - >>> tree = tree.at[MatchNameType("a", int)].get() - >>> assert jax.tree_util.tree_leaves(tree) == [1] - - Note: - - use ``BaseKey.def_alias(type, func)`` to define an index type alias - for `BaseKey` subclasses. This is useful for convience when - creating new match strategies. - - >>> import sepes as sp - >>> import functools as ft - >>> from types import FunctionType - >>> import jax.tree_util as jtu - >>> # lets define a new match strategy called `FuncKey` that applies - >>> # a function to the path entry and returns True if the function - >>> # returns True and False otherwise. - >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match - >>> # all leaves that start with "a". - >>> class FuncKey(sp.BaseKey): - ... def __init__(self, func): - ... self.func = func - ... @ft.singledispatchmethod - ... def __eq__(self, key): - ... return self.func(key) - ... @__eq__.register(jtu.GetAttrKey) - ... def _(self, key: jtu.GetAttrKey): - ... # unpack the GetAttrKey - ... return self.func(key.name) - ... @__eq__.register(jtu.DictKey) - ... def _(self, key: jtu.DictKey): - ... # unpack the DictKey - ... return self.func(key.key) - ... @__eq__.register(jtu.SequenceKey) - ... def _(self, key: jtu.SequenceKey): - ... return self.func(key.index) - >>> # instead of using ``FuncKey(function)`` we can define an alias - >>> # for `FuncKey`, for this example we will define any FunctionType - >>> # as a `FuncKey` by default. - >>> @sp.BaseKey.def_alias(FunctionType) - ... def _(func): - ... return FuncKey(func) - >>> # create a simple pytree - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a: int - ... b: str - >>> tree = Tree(1, "string") - >>> # now we can use the `FuncKey` alias to match all leaves that - >>> # are strings and start with "a" - >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get() - Tree(a=1, b=None) - """ + """Parent class for all match classes.""" @abc.abstractmethod - def __eq__(self, entry: KeyEntry) -> bool: + def __eq__(self, entry: PathKeyEntry) -> bool: pass - broadcastable: bool = False - - -class IndexKey(BaseKey): - """Match a leaf with a given index.""" - - def __init__(self, idx: int) -> None: - self.idx = idx - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, int): - return self.idx == key - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.sequence_key(0))): - return self.idx == key.idx - return False - - def __repr__(self) -> str: - return f"{self.idx}" - - -class NameKey(BaseKey): - """Match a leaf with a given key.""" - - def __init__(self, name: str) -> None: - self.name = name - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, str): - return self.name == key - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.attribute_key(""))): - return self.name == key.name - if isinstance(key, type(treelib.dict_key(""))): - return self.name == key.key - return False - - def __repr__(self) -> str: - return f"{self.name}" - - -class EllipsisKey(BaseKey): - """Match all leaves.""" - - broadcastable = True - - def __init__(self, _): - del _ - - def __eq__(self, _: KeyEntry) -> bool: - return True - - def __repr__(self) -> str: - return "..." - - -class MultiKey(BaseKey): - """Match a leaf with multiple keys at the same level.""" - - def __init__(self, *keys: tuple[BaseKey, ...]): - self.keys = tuple(keys) - - def __eq__(self, entry) -> bool: - return any(entry == key for key in self.keys) - - def __repr__(self) -> str: - return f"({', '.join(map(repr, self.keys))})" - - -class RegexKey(BaseKey): - """Match a leaf with a regex pattern inside 'at' property. - - Args: - pattern: regex pattern to match. - - Example: - >>> import sepes as sp - >>> import re - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... weight_1: float = 1.0 - ... weight_2: float = 2.0 - ... weight_3: float = 3.0 - ... bias: float = 0.0 - >>> tree = Tree() - >>> tree.at[re.compile(r"weight_.*")].set(100.0) # set all weights to 100.0 - Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0) - """ - - def __init__(self, pattern: str) -> None: - self.pattern = pattern - - def __eq__(self, key: KeyEntry) -> bool: - if isinstance(key, str): - return re.fullmatch(self.pattern, key) is not None - treelib = sepes._src.backend.treelib - if isinstance(key, type(treelib.attribute_key(""))): - return re.fullmatch(self.pattern, key.name) is not None - if isinstance(key, type(treelib.dict_key(""))): - return re.fullmatch(self.pattern, key.key) is not None - return False - - def __repr__(self) -> str: - return f"{self.pattern}" - - -# dispatch on type of indexer to convert input item to at indexer -# `__getitem__` to the appropriate key -# avoid using container pytree types to avoid conflict between -# matching as a mask or as an instance of `BaseKey` -indexer_dispatcher = ft.singledispatch(lambda x: x) -indexer_dispatcher.register(type(...), EllipsisKey) -indexer_dispatcher.register(int, IndexKey) -indexer_dispatcher.register(str, NameKey) -indexer_dispatcher.register(re.Pattern, RegexKey) - -BaseKey.def_alias = indexer_dispatcher.register + @property + @abc.abstractmethod + def broadcast(self): ... _INVALID_INDEXER = """\ @@ -286,14 +75,13 @@ def __repr__(self) -> str: - `str` for mapping keys or class attributes. - `int` for positional indexing for sequences. - `...` to select all leaves. + - ``re.Pattern`` to match a leaf level path with a regex pattern. - Boolean mask of a compatible structure as the pytree. - - `re.Pattern` to index all keys matching a regex pattern. - - Instance of `BaseKey` with custom logic to index a pytree. - `tuple` of the above types to match multiple leaves at the same level. """ _NO_LEAF_MATCH = """\ -No leaf match is found for where={where}. Available keys are {names}. +No leaf match is found for where={where}, Available keys are {names} Check the following: - If where is `str` then check if the key exists as a key or attribute. - If where is `int` then check if the index is in range. @@ -321,9 +109,9 @@ def is_leaf_func(node) -> bool: return False return True - return treelib.tree_path_map(func, tree, is_leaf=is_leaf_func) + return treelib.path_map(func, tree, is_leaf=is_leaf_func) - if any(mask.broadcastable for mask in where): + if any(where_i.broadcast for where_i in where): # should the selected subtree be broadcasted to the full tree # e.g. tree = [[1, 2], 3, 4] and where = [0], then # broadcast with True will be [[True, True], False, False] @@ -334,8 +122,8 @@ def is_leaf_func(node) -> bool: # and without broadcast the result will be [100, 3, 4] def bool_tree(value: bool, tree: Any): - leaves, treedef = treelib.tree_flatten(tree, is_leaf=is_leaf) - return treelib.tree_unflatten(treedef, [value] * len(leaves)) + leaves, treedef = treelib.flatten(tree, is_leaf=is_leaf) + return treelib.unflatten(treedef, [value] * len(leaves)) true_tree = ft.partial(bool_tree, True) false_tree = ft.partial(bool_tree, False) @@ -380,9 +168,10 @@ def path_map_func(path, leaf): mask = one_level_tree_path_map(path_map_func, tree) if not match: - path_leaf, _ = treelib.tree_path_flatten(tree, is_leaf=is_leaf) + path_leaf, _ = treelib.path_flatten(tree, is_leaf=is_leaf) + path = "/".join(str(where_i.input) for where_i in where) names = "".join("\n - " + treelib.keystr(path) for path, _ in path_leaf) - raise LookupError(_NO_LEAF_MATCH.format(where=where, names=names)) + raise LookupError(_NO_LEAF_MATCH.format(where=path, names=names)) return mask @@ -390,9 +179,10 @@ def path_map_func(path, leaf): def resolve_where( where: list[Any], tree: T, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ): treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def combine_bool_leaves(*leaves): # given a list of boolean leaves, combine them using `and` @@ -404,7 +194,7 @@ def combine_bool_leaves(*leaves): return verdict def is_bool_leaf(leaf: Any) -> bool: - if isinstance(leaf, arraylib.ndarrays): + if isinstance(leaf, ndarrays): return arraylib.is_bool(leaf) return isinstance(leaf, bool) @@ -423,7 +213,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: nonlocal seen_tuple, level_paths, bool_masks # used to check if a pytree is a valid indexing pytree # used with `is_leaf` argument of any `tree_*` function - leaves, _ = treelib.tree_flatten(node) + leaves, _ = treelib.flatten(node) if all(map(is_bool_leaf, leaves)): # if all leaves are boolean then this is maybe a boolean mask. @@ -442,7 +232,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: bool_masks += [node] return True - if isinstance(resolved_key := indexer_dispatcher(node), BaseKey): + if isinstance(resolved_key := at.dispatcher(node), BaseKey): # valid resolution of `BaseKey` is a valid indexing leaf # makes it possible to dispatch on multi-leaf pytree level_paths += [resolved_key] @@ -463,7 +253,7 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: # each for loop iteration is a level in the where path # this means that if where = ("a", "b", "c") then this means # we are travering the tree at level "a" then level "b" then level "c" - treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) + treelib.flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) # if len(level_paths) > 1 then this means that we have multiple keys # at the same level, for example where = ("a", ("b", "c")) then this # means that for a parent "a", select "b" and "c". @@ -476,17 +266,14 @@ def verify_and_aggregate_is_leaf(node: Any) -> bool: if bool_masks: all_masks = [mask, *bool_masks] if mask else bool_masks - mask = treelib.tree_map(combine_bool_leaves, *all_masks) + mask = treelib.map(combine_bool_leaves, *all_masks) return mask -class AtIndexer(Generic[T]): +class at(Generic[T]): """Operate on a pytree at a given path using a path or mask in out-of-place manner. - Note: - Use :class:`.at` as a shorter alias for this class. - Args: tree: pytree to operate on. where: one of the following: @@ -495,13 +282,9 @@ class AtIndexer(Generic[T]): - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - - ``re.Pattern`` to index all keys matching a regex pattern. - - an instance of ``BaseKey`` with custom logic to index a pytree. + - ``re.Pattern`` to match a leaf level path with a regex pattern. - a tuple of the above to match multiple keys at the same level. - Note: - Alternatively, use ``at(tree)[where]`` to index a pytree. - Example: >>> import jax >>> import sepes as sp @@ -514,26 +297,23 @@ class AtIndexer(Generic[T]): >>> sp.at(tree)[mask].set(100) {'a': 1, 'b': [1, 100, 100]} """ - def __init__(self, tree: T, where: list[Any] | None = None) -> None: - vars(self)["tree"] = tree - vars(self)["where"] = [] if where is None else where - - def __setattr__(self, key: str, _: Any) -> None: - raise AttributeError(f"Cannot set {key=} on {type(self).__name__} instance") + self.tree = tree + self.where = [] if where is None else where def __getitem__(self, where: Any) -> Self: """Index a pytree at a given path using a path or mask.""" return type(self)(self.tree, [*self.where, where]) def __repr__(self) -> str: - return f"{type(self).__name__}(tree={tree_repr(self.tree)}, where={self.where})" + return f"{type(self).__name__}({tree_repr(self.tree)}, where={self.where})" def get( self, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, + fill_value: Any = _no_fill_value, ): """Get the leaf values at the specified location. @@ -547,6 +327,10 @@ def get( - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. + fill_value: the value to fill the non-selected leaves with. + Useful to use with ``jax.jit`` to avoid variable size arrays + leaves related errors. + Returns: A _new_ pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array. @@ -558,19 +342,25 @@ def get( {'a': None, 'b': [1, None, None]} """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_get(where: Any, leaf: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1]) - if isinstance(where, arraylib.ndarrays) and len(arraylib.shape(where)): + # because of the variable resultant size of the output + if isinstance(where, ndarrays) and len(arraylib.shape(where)): + if fill_value is not _no_fill_value: + return arraylib.where(where, leaf, fill_value) return leaf[where] # non-array boolean mask we select the leaf if the mask is True # and `None` otherwise + if fill_value is not _no_fill_value: + return leaf if where else fill_value return leaf if where else None - return treelib.tree_map( + return treelib.map( leaf_get, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -582,7 +372,7 @@ def set( self, set_value: Any, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Set the leaf values at the specified location. @@ -609,6 +399,7 @@ def set( {'a': 1, 'b': [100, 2, 3]} """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_set(where: Any, leaf: Any, set_value: Any): # support both array and non-array leaves @@ -616,12 +407,12 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100]) # with set_value = 100 - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, set_value, leaf) return set_value if where else leaf - _, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf) - _, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf) + _, lhsdef = treelib.flatten(self.tree, is_leaf=is_leaf) + _, rhsdef = treelib.flatten(set_value, is_leaf=is_leaf) if lhsdef == rhsdef: # do not broadcast set_value if it is a pytree of same structure @@ -629,7 +420,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): # to tree2 leaves if tree2 is a pytree of same structure as tree # instead of making each leaf of tree a copy of tree2 # is design is similar to ``numpy`` design `np.at[...].set(Array)` - return treelib.tree_map( + return treelib.map( leaf_set, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -638,7 +429,7 @@ def leaf_set(where: Any, leaf: Any, set_value: Any): is_parallel=is_parallel, ) - return treelib.tree_map( + return treelib.map( ft.partial(leaf_set, set_value=set_value), resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -650,7 +441,7 @@ def apply( self, func: Callable[[Any], Any], *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ): """Apply a function to the leaf values at the specified location. @@ -685,19 +476,19 @@ def apply( >>> is_parallel = dict(max_workers=2) >>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP """ - treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) def leaf_apply(where: Any, leaf: Any): # same as `leaf_set` but with `func` applied to the leaf # one thing to note is that, the where mask select an array # then the function needs work properly when applied to the selected # array elements - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, func(leaf), leaf) return func(leaf) if where else leaf - return treelib.tree_map( + return treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -710,7 +501,7 @@ def scan( func: Callable[[Any, S], tuple[Any, S]], state: S, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> tuple[Any, S]: """Apply a function while carrying a state. @@ -746,6 +537,7 @@ def scan( leaf values while carrying a state and returning a single value. """ treelib = sepes._src.backend.treelib + ndarrays = tuple(arraylib.ndarrays) running_state = state def stateless_func(leaf): @@ -754,11 +546,11 @@ def stateless_func(leaf): return leaf def leaf_apply(where: Any, leaf: Any): - if isinstance(where, arraylib.ndarrays): + if isinstance(where, ndarrays): return arraylib.where(where, stateless_func(leaf), leaf) return stateless_func(leaf) if where else leaf - out_tree = treelib.tree_map( + out_tree = treelib.map( leaf_apply, resolve_where(self.where, self.tree, is_leaf), self.tree, @@ -771,7 +563,7 @@ def reduce( func: Callable[[Any, Any], Any], *, initializer: Any = _no_initializer, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> Any: """Reduce the leaf values at the specified location. @@ -799,7 +591,7 @@ def reduce( """ treelib = sepes._src.backend.treelib tree = self.get(is_leaf=is_leaf) # type: ignore - leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf) + leaves, _ = treelib.flatten(tree, is_leaf=is_leaf) if initializer is _no_initializer: return ft.reduce(func, leaves) return ft.reduce(func, leaves, initializer) @@ -808,7 +600,7 @@ def pluck( self, count: int | None = None, *, - is_leaf: Callable[[Any], None] | None = None, + is_leaf: Callable[[Any], bool] | None = None, is_parallel: bool | ParallelConfig = False, ) -> list[Any]: """Extract subtrees at the specified location. @@ -880,7 +672,7 @@ def aggregate_subtrees(node: Any) -> bool: # for example if tree = dict(a=1) and mask is dict(a=True) # then returns [1] and not [dict(a=1)] return False - leaves, _ = treelib.tree_flatten(node, is_leaf=lambda x: x is None) + leaves, _ = treelib.flatten(node, is_leaf=lambda x: x is None) # in essence if the subtree does not contain any None leaves # then it is a valid subtree to be plucked # this because `get` sets the non-selected leaves to None @@ -890,9 +682,102 @@ def aggregate_subtrees(node: Any) -> bool: count -= 1 return True - treelib.tree_flatten(tree, is_leaf=aggregate_subtrees) + treelib.flatten(tree, is_leaf=aggregate_subtrees) return subtrees -# shorter alias -at = AtIndexer +# pass through for boolean pytrees masks and tuple of keys +at.dispatcher = ft.singledispatch(lambda x: x) + + +def def_rule( + user_type: type[T], + path_compare_func: Callable[[T, PathKeyEntry], bool], + *, + broadcastable: bool = False, +) -> None: + # remove the BaseKey abstraction from the user-facing function + class UserKey(BaseKey): + broadcast: bool = broadcastable + + def __init__(self, input: T): + self.input = input + + def __eq__(self, key: PathKeyEntry) -> bool: + return path_compare_func(self.input, key) + + at.dispatcher.register(user_type, UserKey) + + +at.def_rule = def_rule + + +# key rules to match user input to with the path entry + + +def str_compare(name: str, key: PathKeyEntry): + """Match a leaf with a given name.""" + if isinstance(key, str): + return name == key + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.attribute_key(""))): + return name == key.name + if isinstance(key, type(treelib.dict_key(""))): + return name == key.key + return False + + +def int_compare(idx: int, key: PathKeyEntry) -> bool: + """Match a leaf with a given index.""" + if isinstance(key, int): + return idx == key + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.sequence_key(0))): + return idx == key.idx + return False + + +def regex_compare(pattern: re.Pattern, key: PathKeyEntry) -> bool: + """Match a path with a regex pattern inside 'at' property.""" + if isinstance(key, str): + return re.fullmatch(pattern, key) is not None + treelib = sepes._src.backend.treelib + if isinstance(key, type(treelib.attribute_key(""))): + return re.fullmatch(pattern, key.name) is not None + if isinstance(key, type(treelib.dict_key(""))): + return re.fullmatch(pattern, key.key) is not None + return False + + +def ellipsis_compare(_, __): + return True + + +at.def_rule(str, str_compare, broadcastable=False) +at.def_rule(int, int_compare, broadcastable=False) +at.def_rule(re.Pattern, regex_compare, broadcastable=False) +at.def_rule(type(...), ellipsis_compare, broadcastable=True) + + +class MultiKey(BaseKey): + """Match a leaf with multiple keys at the same level.""" + + def __init__(self, *keys): + self.keys = tuple(keys) + + def __eq__(self, entry: PathKeyEntry) -> bool: + return any(entry == key for key in self.keys) + + broadcast: bool = False + + +if is_package_avaiable("jax"): + import jax.tree_util as jtu + + def jax_key_compare(input, key: PathKeyEntry) -> bool: + """Enable indexing with jax keys directly in `at`.""" + return input == key + + at.def_rule(jtu.SequenceKey, jax_key_compare, broadcastable=False) + at.def_rule(jtu.GetAttrKey, jax_key_compare, broadcastable=False) + at.def_rule(jtu.DictKey, jax_key_compare, broadcastable=False) diff --git a/sepes/_src/tree_mask.py b/sepes/_src/tree_mask.py index e113e98..5f14370 100644 --- a/sepes/_src/tree_mask.py +++ b/sepes/_src/tree_mask.py @@ -30,23 +30,19 @@ MaskType = Union[T, Callable[[Any], bool]] -class _FrozenError(NamedTuple): +class _MaskedError(NamedTuple): opname: str def __call__(self, *a, **k): raise NotImplementedError( - f"Cannot apply `{self.opname}` operation to a frozen object " + f"Cannot apply `{self.opname}` operation on a masked object " f"{', '.join(map(str, a))} " f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n" - "Unfreeze the object first by unmasking the frozen mask:\n" - "Example:\n" - ">>> import jax\n" - ">>> import sepes as sp\n" - ">>> tree = sp.tree_unmask(tree)" + "Unmask the object first using `tree_unmask`" ) -class _FrozenBase(Static): +class _MaskBase(Static[T]): # the objective of this class is to wrap a pytree node with a custom wrapper # that yields no leaves when flattened. This is useful to avoid updating # the node by effectivly *hiding it* from function transformations that operates @@ -69,43 +65,44 @@ def __repr__(self) -> str: def __str__(self) -> str: return "#" + tree_str(self.__wrapped__) - def __copy__(self) -> _FrozenBase[T]: + def __copy__(self) -> _MaskBase[T]: return type(self)(tree_copy(self.__wrapped__)) # raise helpful error message when trying to interact with frozen object - __add__ = __radd__ = __iadd__ = _FrozenError("+") - __sub__ = __rsub__ = __isub__ = _FrozenError("-") - __mul__ = __rmul__ = __imul__ = _FrozenError("*") - __matmul__ = __rmatmul__ = __imatmul__ = _FrozenError("@") - __truediv__ = __rtruediv__ = __itruediv__ = _FrozenError("/") - __floordiv__ = __rfloordiv__ = __ifloordiv__ = _FrozenError("//") - __mod__ = __rmod__ = __imod__ = _FrozenError("%") - __pow__ = __rpow__ = __ipow__ = _FrozenError("**") - __lshift__ = __rlshift__ = __ilshift__ = _FrozenError("<<") - __rshift__ = __rrshift__ = __irshift__ = _FrozenError(">>") - __and__ = __rand__ = __iand__ = _FrozenError("and") - __xor__ = __rxor__ = __ixor__ = _FrozenError("") - __or__ = __ror__ = __ior__ = _FrozenError("or") - __neg__ = __pos__ = __abs__ = __invert__ = _FrozenError("unary operation") - __call__ = _FrozenError("__call__") - - -@tree_summary.def_type(_FrozenBase) + __add__ = __radd__ = __iadd__ = _MaskedError("+") + __sub__ = __rsub__ = __isub__ = _MaskedError("-") + __mul__ = __rmul__ = __imul__ = _MaskedError("*") + __matmul__ = __rmatmul__ = __imatmul__ = _MaskedError("@") + __truediv__ = __rtruediv__ = __itruediv__ = _MaskedError("/") + __floordiv__ = __rfloordiv__ = __ifloordiv__ = _MaskedError("//") + __mod__ = __rmod__ = __imod__ = _MaskedError("%") + __pow__ = __rpow__ = __ipow__ = _MaskedError("**") + __lshift__ = __rlshift__ = __ilshift__ = _MaskedError("<<") + __rshift__ = __rrshift__ = __irshift__ = _MaskedError(">>") + __and__ = __rand__ = __iand__ = _MaskedError("and") + __xor__ = __rxor__ = __ixor__ = _MaskedError("") + __or__ = __ror__ = __ior__ = _MaskedError("or") + __neg__ = __pos__ = __abs__ = __invert__ = _MaskedError("unary") + __lt__ = __le__ = __gt__ = __ge__ = _MaskedError("comparison") + __call__ = _MaskedError("__call__") + + +@tree_summary.def_type(_MaskBase) def _(node) -> str: return f"#{tree_summary.type_dispatcher(node.__wrapped__)}" -class _FrozenHashable(_FrozenBase): +class _MaskedHashable(_MaskBase): def __hash__(self) -> int: return tree_hash(self.__wrapped__) def __eq__(self, rhs: Any) -> bool: - if not isinstance(rhs, _FrozenHashable): + if not isinstance(rhs, _MaskedHashable): return False return is_tree_equal(self.__wrapped__, rhs.__wrapped__) -class _FrozenArray(_FrozenBase): +class _MaskedArray(_MaskBase): # wrap arrays with a custom wrapper that implements hash and equality # using the wrapped array's bytes representation and sha256 hash function # this is useful to select some array to hold without updating in the process @@ -115,7 +112,7 @@ def __hash__(self) -> int: return int(hashlib.sha256(bytes).hexdigest(), 16) def __eq__(self, other) -> bool: - if not isinstance(other, _FrozenArray): + if not isinstance(other, _MaskedArray): return False lhs, rhs = self.__wrapped__, other.__wrapped__ # fast path to avoid calling `all` on large arrays @@ -123,136 +120,62 @@ def __eq__(self, other) -> bool: return False if arraylib.dtype(lhs) != arraylib.dtype(rhs): return False - return arraylib.all(lhs == rhs) + return arraylib.array_equal(lhs, rhs) -def freeze(value: T) -> _FrozenBase[T]: - """Freeze a value to avoid updating it by through function transformations. - - Args: - value: A value to freeze. - - Note: - - :func:`.freeze` is idempotent, i.e. ``freeze(freeze(x)) == freeze(x)``. - - Example: - >>> import jax - >>> import sepes as sp - >>> import jax.tree_util as jtu - >>> # Usage with `jax.tree_util.tree_leaves` - >>> # no leaves for a wrapped value - >>> jtu.tree_leaves(sp.freeze(2.)) - [] - - >>> # retrieve the frozen wrapper value using `is_leaf=sp.is_frozen` - >>> jtu.tree_leaves(sp.freeze(2.), is_leaf=sp.is_frozen) - [#2.0] - - >>> # Usage with `jax.tree_util.tree_map` - >>> a= [1,2,3] - >>> a[1] = sp.freeze(a[1]) - >>> jtu.tree_map(lambda x:x+100, a) - [101, #2, 103] - """ +def mask(value: T) -> _MaskBase[T]: # dispatching is used to customize the type of the wrapper based on the type # of the value. For instance, hashable values dont need custom hash and # equality implementations, so they are wrapped with a simpler wrapper. # this approach avoids type logic in the wrapper equality and hash methods, # thus effectively improving performance of the wrapper. - return freeze.type_dispatcher(value) + return mask.type_dispatcher(value) -freeze.type_dispatcher = ft.singledispatch(_FrozenHashable) -freeze.def_type = freeze.type_dispatcher.register +mask.type_dispatcher = ft.singledispatch(_MaskedHashable) +mask.def_type = mask.type_dispatcher.register for ndarray in arraylib.ndarrays: - @freeze.def_type(ndarray) - def freeze_array(value: T) -> _FrozenArray[T]: + @mask.def_type(ndarray) + def mask_array(value: T) -> _MaskedArray[T]: # wrap arrays with a custom wrapper that implements hash and equality # arrays can be hashed by converting them to bytes and hashing the bytes - return _FrozenArray(value) + return _MaskedArray(value) -@freeze.def_type(_FrozenBase) -def _(value: _FrozenBase[T]) -> _FrozenBase[T]: - # idempotent freeze operation, meaning that freeze(freeze(x)) == freeze(x) +@mask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> _MaskBase[T]: + # idempotent mask operation, meaning that mask(mask(x)) == mask(x) # this is useful to avoid recursive unwrapping of frozen values, plus its - # meaningless to freeze a frozen value. + # meaningless to mask a frozen value. return value -def is_frozen(value: Any) -> bool: +def is_masked(value: Any) -> bool: """Returns True if the value is a frozen wrapper.""" - return isinstance(value, _FrozenBase) - - -def unfreeze(value: T) -> T: - """Unfreeze :func:`.freeze` value, otherwise return the value itself. + return isinstance(value, _MaskBase) - Args: - value: A value to unfreeze. - - Note: - - use ``is_leaf=sp.is_frozen`` with ``tree_map`` to unfreeze a tree.** - Example: - >>> import sepes as sp - >>> import jax - >>> frozen_value = sp.freeze(1) - >>> sp.unfreeze(frozen_value) - 1 - >>> # usage with `jax.tree_map` - >>> frozen_tree = jax.tree_map(sp.freeze, {"a": 1, "b": 2}) - >>> unfrozen_tree = jax.tree_map(sp.unfreeze, frozen_tree, is_leaf=sp.is_frozen) - >>> unfrozen_tree - {'a': 1, 'b': 2} - """ - return unfreeze.type_dispatcher(value) +def unmask(value: T) -> T: + return unmask.type_dispatcher(value) -unfreeze.type_dispatcher = ft.singledispatch(lambda x: x) -unfreeze.def_type = unfreeze.type_dispatcher.register +unmask.type_dispatcher = ft.singledispatch(lambda x: x) +unmask.def_type = unmask.type_dispatcher.register -@unfreeze.def_type(_FrozenBase) -def _(value: _FrozenBase[T]) -> T: +@unmask.def_type(_MaskBase) +def _(value: _MaskBase[T]) -> T: return getattr(value, "__wrapped__") def is_nondiff(value: Any) -> bool: - """Returns True for non-inexact types, False otherwise. - - Args: - value: A value to check. - - Note: - - :func:`.is_nondiff` uses single dispatch to support custom types. To define - a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``. - - Example: - >>> import sepes as sp - >>> import jax.numpy as jnp - >>> sp.is_nondiff(jnp.array(1)) # int array is non-diff type - True - >>> sp.is_nondiff(jnp.array(1.)) # float array is diff type - False - >>> sp.is_nondiff(1) # int is non-diff type - True - >>> sp.is_nondiff(1.) # float is diff type - False - - Note: - This function is meant to be used with ``jax.tree_map`` to - create a mask for non-differentiable nodes in a tree, that can be used - to freeze the non-differentiable nodes before passing the tree to a - ``jax`` transformation. - """ return is_nondiff.type_dispatcher(value) -is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True) +is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True) is_nondiff.def_type = is_nondiff.type_dispatcher.register @@ -274,78 +197,53 @@ def _(_: float | complex) -> bool: def _tree_mask_map( tree: T, - mask: MaskType, + cond: Callable[[Any], bool], func: type | Callable[[Any], Any], *, is_leaf: Callable[[Any], None] | None = None, ): - treelib = sepes._src.backend.treelib - # apply func to leaves satisfying mask pytree/condtion - _, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf) - _, rhsdef = treelib.tree_flatten(mask, is_leaf=is_leaf) - - if (lhsdef == rhsdef) and (type(mask) is type(tree)): - # a tree with the same structure as tree with boolean values - # and also a callable. - def map_func(x, y): - return func(x) if y else x - return treelib.tree_map(map_func, tree, mask, is_leaf=is_leaf) - - if isinstance(mask, Callable): + if not isinstance(cond, Callable): # a callable that accepts a leaf and returns a boolean # but *not* a tree with the same structure as tree with boolean values. - def map_func(x): - return func(x) if mask(x) else x + raise TypeError( + f"`cond` must be a callable that accepts a leaf and returns a boolean " + f" Got {cond=} and {tree=}." + ) - return treelib.tree_map(map_func, tree, is_leaf=is_leaf) + treelib = sepes._src.backend.treelib - raise ValueError( - f"`mask` must be a callable that accepts a leaf and returns a boolean " - f"or a tree with the same structure as tree with boolean values." - f" Got {mask=} and {tree=}." - ) + def map_func(x): + return func(x) if cond(x) else x + + return treelib.map(map_func, tree, is_leaf=is_leaf) def tree_mask( tree: T, - mask: MaskType = is_nondiff, + cond: Callable[[Any], bool] = is_nondiff, *, is_leaf: Callable[[Any], None] | None = None, ): """Mask leaves of a pytree based on ``mask`` boolean pytree or callable. + Masked leaves are wrapped with a wrapper that yields no leaves when + ``tree_flatten`` is called on it. + Args: tree: A pytree of values. - mask: A pytree of boolean values or a callable that accepts a leaf and - returns a boolean. If a leaf is ``True`` either in the mask or the - callable, the leaf is wrapped by with a wrapper that yields no - leaves when ``tree_flatten`` is called on it, otherwise - it is unchanged. defaults to :func:`.is_nondiff` which returns true for - non-differentiable nodes. + cond: A callable that accepts a leaf and returns a boolean to mark the leaf + for masking. Defaults to masking non-differentiable leaf nodes that + are not instances of of python float, python complex, or inexact + array types. is_leaf: A callable that accepts a leaf and returns a boolean. If provided, it is used to determine if a value is a leaf. for example, ``is_leaf=lambda x: isinstance(x, list)`` will treat lists as leaves and will not recurse into them. - Note: - - Masked leaves are wrapped with a wrapper that yields no leaves when - ``tree_flatten`` is called on it. - - Masking is equivalent to applying :func:`.freeze` to the masked leaves. - - >>> import sepes as sp - >>> import jax - >>> tree = [1, 2, {"a": 3, "b": 4.}] - >>> # mask all non-differentiable nodes by default - >>> def mask_if_nondiff(x): - ... return sp.freeze(x) if sp.is_nondiff(x) else x - >>> masked_tree = jax.tree_map(mask_if_nondiff, tree) - - - Use masking on tree containing non-differentiable nodes before passing - the tree to a ``jax`` transformation. - Example: >>> import sepes as sp + >>> import jax >>> tree = [1, 2, {"a": 3, "b": 4.}] >>> # mask all non-differentiable nodes by default >>> masked_tree = sp.tree_mask(tree) @@ -357,32 +255,32 @@ def tree_mask( [1, 2, {'a': 3, 'b': 4.0}] Example: - >>> # pass non-differentiable values to `jax.grad` + Pass non-differentiable values to ``jax.grad`` + >>> import sepes as sp >>> import jax >>> @jax.grad ... def square(tree): ... tree = sp.tree_unmask(tree) - ... return tree[0]**2 + ... return tree[0] ** 2 >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) """ - return _tree_mask_map(tree, mask=mask, func=freeze, is_leaf=is_leaf) + return _tree_mask_map(tree, cond=cond, func=mask, is_leaf=is_leaf) -def tree_unmask(tree: T, mask: MaskType = lambda _: True): - """Undo the masking of tree leaves according to ``mask``. defaults to unmasking all leaves. +def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True): + """Undo the masking of tree leaves according to ``cond``. defaults to unmasking all leaves. Args: tree: A pytree of values. - mask: A pytree of boolean values or a callable that accepts a leaf and - returns a boolean. If a leaf is True either in the mask or the - callable, the leaf is unfrozen, otherwise it is unchanged. defaults - unmasking all nodes. + cond: A callable that accepts a leaf and returns a boolean to mark the + leaf to be unmasked. Defaults to always unmask. Example: >>> import sepes as sp + >>> import jax >>> tree = [1, 2, {"a": 3, "b": 4.}] >>> # mask all non-differentiable nodes by default >>> masked_tree = sp.tree_mask(tree) @@ -394,27 +292,19 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True): [1, 2, {'a': 3, 'b': 4.0}] Example: - >>> # pass non-differentiable values to `jax.grad` + Pass non-differentiable values to ``jax.grad`` + >>> import sepes as sp >>> import jax >>> @jax.grad ... def square(tree): ... tree = sp.tree_unmask(tree) - ... return tree[0]**2 + ... return tree[0] ** 2 >>> tree = (1., 2) # contains a non-differentiable node >>> square(sp.tree_mask(tree)) (Array(2., dtype=float32, weak_type=True), #2) - - Note: - - Unmasking is equivalent to applying :func:`.unfreeze` on the masked leaves. - - >>> import sepes as sp - >>> import jax - >>> tree = [1, 2, {"a": 3, "b": 4.}] - >>> # unmask all nodes - >>> tree = jax.tree_map(sp.unfreeze, tree, is_leaf=sp.is_frozen) """ - return _tree_mask_map(tree, mask=mask, func=unfreeze, is_leaf=is_frozen) + return _tree_mask_map(tree, cond=cond, func=unmask, is_leaf=is_masked) if is_package_avaiable("jax"): @@ -424,6 +314,6 @@ def tree_unmask(tree: T, mask: MaskType = lambda _: True): # otherwise calling `freeze` inside a jax transformation on # a tracer will hide the tracer from jax and will cause leaked tracer # error. - @freeze.def_type(jax.core.Tracer) + @mask.def_type(jax.core.Tracer) def _(value: jax.core.Tracer) -> jax.core.Tracer: return value diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py index 27ea5f2..fb24e65 100644 --- a/sepes/_src/tree_pprint.py +++ b/sepes/_src/tree_pprint.py @@ -31,7 +31,6 @@ from sepes._src.backend import is_package_avaiable from sepes._src.tree_util import ( Node, - Partial, construct_tree, is_path_leaf_depth_factory, tree_type_path_leaves, @@ -178,13 +177,12 @@ def _(func: Callable, **spec: Unpack[PPSpec]) -> str: return f"{name}({', '.join(header)})" -@tree_str.def_type(Partial) @tree_str.def_type(ft.partial) def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str: func = tree_str.pp(node.func, **spec) args = tree_str.pps(tree_str.pp, node.args, **spec) keywords = tree_str.pps(tree_str.kv_pp, node.keywords, **spec) - return f"Partial(" + ",".join([func, args, keywords]) + ")" + return "partial(" + ",".join([func, args, keywords]) + ")" @tree_str.def_type(list) @@ -242,7 +240,6 @@ def array_pp(node, **spec: Unpack[PPSpec]) -> str: return f"{base}(μ={mean}, σ={std}, ∈{interval})" -@tree_repr.def_type(Partial) @tree_repr.def_type(ft.partial) def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str: func = tree_repr.pp(node.func, **spec) @@ -363,139 +360,6 @@ def step( return (text if tabwidth is None else text.expandtabs(tabwidth)).rstrip() -def tree_mermaid( - tree: PyTree, - depth: int | float = float("inf"), - is_leaf: Callable[[Any], None] | None = None, - tabwidth: int | None = 4, -) -> str: - """Generate a mermaid diagram syntax for arbitrary pytrees. - - Args: - tree: PyTree - depth: depth of the tree to print. default is max depth - is_leaf: function to determine if a node is a leaf. default is None - tabwidth: tab width of the repr string. default is 4. - - Example: - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> # as rendered by mermaid - >>> print(sp.tree_mermaid(tree)) # doctest: +SKIP - - .. image:: ../_static/tree_mermaid.jpg - :width: 300px - :align: center - - Note: - - Copy the output and paste it in the mermaid live editor to interact with - the diagram. https://mermaid.live - """ - - def step(node: Node, depth: int = 0) -> str: - if len(node.children) == 0: - (key, _), value = node.data - ppstr = f"{key}=" if key is not None else "" - ppstr += tree_repr(value, depth=0) - ppstr = "" + ppstr + "" - return f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n' - - (key, type), _ = node.data - ppstr = f"{key}:" if key is not None else "" - ppstr += f"{type.__name__}" - ppstr = "" + ppstr + "" - - if node.parent is None: - text = f'\tid{id(node)}("{ppstr}")\n' - else: - text = f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n' - - for child in node.children.values(): - text += step(child, depth=depth + 1) - return text - - is_path_leaf = is_path_leaf_depth_factory(depth) - root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf) - text = "flowchart LR\n" + step(root) - return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip() - - -# dispatcher for dot nodestyles -dot_dispatcher = ft.singledispatch(lambda _: dict(shape="box")) - - -def tree_graph( - tree: PyTree, - depth: int | float = float("inf"), - is_leaf: Callable[[Any], None] | None = None, - tabwidth: int | None = 4, -) -> str: - """Generate a dot diagram syntax for arbitrary pytrees. - - Args: - tree: pytree - depth: depth of the tree to print. default is max depth - is_leaf: function to determine if a node is a leaf. default is None - tabwidth: tab width of the repr string. default is 4. - - Returns: - str: dot diagram syntax - - Example: - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> # as rendered by graphviz - - .. image:: ../_static/tree_graph.svg - - Example: - >>> # define custom style for a node by dispatching on the value - >>> # the defined function should return a dict of attributes - >>> # that will be passed to graphviz. - >>> import sepes as sp - >>> tree = [1, 2, dict(a=3)] - >>> @sp.tree_graph.def_nodestyle(list) - ... def _(_) -> dict[str, str]: - ... return dict(shape="circle", style="filled", fillcolor="lightblue") - - .. image:: ../_static/tree_graph_stylized.svg - """ - - def step(node: Node, depth: int = 0) -> str: - (key, type), value = node.data - - # dispatch node style - style = ", ".join(f"{k}={v}" for k, v in dot_dispatcher(value).items()) - - if len(node.children) == 0: - ppstr = f"{key}=" if key is not None else "" - ppstr += tree_repr(value, depth=0) - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - text += f"\t{id(node.parent)} -> {id(node)};\n" - return text - - ppstr = f"{key}:" if key is not None else "" - ppstr += f"{type.__name__}" - - if node.parent is None: - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - else: - text = f'\t{id(node)} [label="{ppstr}", {style}];\n' - text += f"\t{id(node.parent)} -> {id(node)};\n" - - for child in node.children.values(): - text += step(child, depth=depth + 1) - return text - - is_path_leaf = is_path_leaf_depth_factory(depth) - root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf) - text = "digraph G {\n" + step(root) + "}" - return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip() - - -tree_graph.def_nodestyle = dot_dispatcher.register - - def format_width(string, width=60): """Strip newline/tab characters if less than max width.""" children_length = len(string) - string.count("\n") - string.count("\t") @@ -570,39 +434,19 @@ def tree_summary( >>> import sepes as sp >>> import jax.numpy as jnp >>> print(sp.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])])) - ┌─────────┬──────┬─────┬──────┐ - │Name │Type │Count│Size │ - ├─────────┼──────┼─────┼──────┤ - │[0] │int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[1][0] │int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[1][1][0]│int │1 │ │ - ├─────────┼──────┼─────┼──────┤ - │[2] │i32[3]│3 │12.00B│ - ├─────────┼──────┼─────┼──────┤ - │Σ │list │6 │12.00B│ - └─────────┴──────┴─────┴──────┘ - - Example: - Set custom type display for ``jax`` jaxprs - - >>> import jax - >>> import sepes as sp - >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1)) - >>> @sp.tree_summary.def_type(ClosedJaxprType) - ... def _(expr: ClosedJaxprType) -> str: - ... jaxpr = expr.jaxpr - ... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})" - >>> def func(x, y): - ... return x - >>> jaxpr = jax.make_jaxpr(func)(1, 2) - >>> print(sp.tree_summary(jaxpr)) - ┌────┬──────────────────┬─────┬────┐ - │Name│Type │Count│Size│ - ├────┼──────────────────┼─────┼────┤ - │Σ │Jaxpr([a, b], [a])│1 │ │ - └────┴──────────────────┴─────┴────┘ + ┌─────────┬────────────────────────────────────┬─────┬──────┐ + │Name │Type │Count│Size │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[0] │int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[1][0] │int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[1][1][0]│int │1 │ │ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │[2] │i32[3] │3 │12.00B│ + ├─────────┼────────────────────────────────────┼─────┼──────┤ + │Σ │list[int,list[int,list[int]],i32[3]]│6 │12.00B│ + └─────────┴────────────────────────────────────┴─────┴──────┘ Example: Display flops of a function in tree summary @@ -662,14 +506,14 @@ def tree_size(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.size_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) def tree_count(tree: PyTree) -> int: def reduce_func(acc, node): return acc + tree_summary.count_dispatcher(node) - leaves, _ = treelib.tree_flatten(tree) + leaves, _ = treelib.flatten(tree) return ft.reduce(reduce_func, leaves, 0) traces_leaves = tree_type_path_leaves( @@ -725,6 +569,21 @@ def _(node: Any) -> str: dtype = arraylib.dtype(node) return tree_repr(ShapeDTypePP(shape, dtype)) +@tree_summary.def_type(list) +@tree_summary.def_type(tuple) +def _(node: tuple) -> str: + # - output Container[types,...] instead of just container type in the type col. + # - usually this encounterd if the tree_summary depth is not inf + # so the tree leaves could contain non-atomic types. + treelib = sepes._src.backend.treelib + + one_level_types = treelib.map( + tree_summary.type_dispatcher, + node, + is_leaf=lambda x: False if id(x) == id(node) else True, + ) + return f"{type(node).__name__}[{','.join(one_level_types)}]" + if is_package_avaiable("jax"): # jax pretty printing extra handlers @@ -764,4 +623,19 @@ def _(node, **spec: Unpack[PPSpec]) -> str: shape = node.aval.shape dtype = node.aval.dtype string = tree_repr.dispatch(ShapeDTypePP(shape, dtype), **spec) - return f"Tracer({string})" + return f"{type(node).__name__}({string})" + + # handle the sharding info if it is sharded + @tree_summary.def_type(jax.Array) + def _(node: Any) -> str: + """Return the type repr of the node.""" + # global shape + global_shape = arraylib.shape(node) + shard_shape = node.sharding.shard_shape(global_shape) + dtype = arraylib.dtype(node) + global_info = tree_repr(ShapeDTypePP(global_shape, dtype)) + + if global_shape == shard_shape: + return global_info + shard_info = tree_repr(ShapeDTypePP(shard_shape, dtype)) + return f"G:{global_info}\nS:{shard_info}" diff --git a/sepes/_src/tree_util.py b/sepes/_src/tree_util.py index 88fb24d..429b1b6 100644 --- a/sepes/_src/tree_util.py +++ b/sepes/_src/tree_util.py @@ -16,15 +16,17 @@ from __future__ import annotations +import copy import functools as ft import operator as op -import copy from math import ceil, floor, trunc -from typing import Any, Callable, Hashable, Iterator, Sequence, Tuple, TypeVar, Generic -import sepes._src.backend.arraylib as arraylib -from sepes._src.backend import is_package_avaiable +from typing import Any, Callable, Generic, Hashable, Iterator, Sequence, Tuple, TypeVar + from typing_extensions import ParamSpec + import sepes +import sepes._src.backend.arraylib as arraylib +from sepes._src.backend import is_package_avaiable T = TypeVar("T") T1 = TypeVar("T1") @@ -42,7 +44,7 @@ def tree_hash(*trees: PyTree) -> int: treelib = sepes._src.backend.treelib - leaves, treedef = treelib.tree_flatten(trees) + leaves, treedef = treelib.flatten(trees) return hash((*leaves, treedef)) @@ -57,7 +59,7 @@ def tree_copy(tree: T) -> T: def is_leaf(node) -> bool: return isinstance(node, types) - return treelib.tree_map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) + return treelib.map(tree_copy.copy_dispatcher, tree, is_leaf=is_leaf) # default behavior is to copy the tree elements except for registered types @@ -66,18 +68,31 @@ def is_leaf(node) -> bool: tree_copy.def_type = tree_copy.copy_dispatcher.register +@tree_copy.def_type(int) +@tree_copy.def_type(float) +@tree_copy.def_type(complex) +@tree_copy.def_type(str) +@tree_copy.def_type(bytes) +def _(x: T) -> T: + # skip applying `copy.copy` on immutable atom types + return x + + def is_array_like(node) -> bool: return hasattr(node, "shape") and hasattr(node, "dtype") -def _is_leaf_rhs_equal(leaf, rhs) -> bool: +def _is_leaf_rhs_equal(leaf, rhs): if is_array_like(leaf): if is_array_like(rhs): if leaf.shape != rhs.shape: return False if leaf.dtype != rhs.dtype: return False - verdict = arraylib.all(leaf == rhs) + try: + verdict = arraylib.all(leaf == rhs) + except NotImplementedError: + verdict = leaf == rhs try: return bool(verdict) except Exception: @@ -94,11 +109,11 @@ def is_tree_equal(*trees: Any) -> bool: """ treelib = sepes._src.backend.treelib tree0, *rest = trees - leaves0, treedef0 = treelib.tree_flatten(tree0) + leaves0, treedef0 = treelib.flatten(tree0) verdict = True for tree in rest: - leaves, treedef = treelib.tree_flatten(tree) + leaves, treedef = treelib.flatten(tree) if (treedef != treedef0) or verdict is False: return False verdict = ft.reduce(op.and_, map(_is_leaf_rhs_equal, leaves0, leaves), verdict) @@ -116,73 +131,28 @@ def __init_subclass__(klass, **k) -> None: treelib.register_static(klass) -class Partial(Static): - """``Partial`` function with support for positional partial application. - - Args: - func: The function to be partially applied. - args: Positional arguments to be partially applied. use ``...`` as a - placeholder for positional arguments. - kwargs: Keyword arguments to be partially applied. - - Example: - >>> import sepes as sp - >>> def f(a, b, c): - ... print(f"a: {a}, b: {b}, c: {c}") - ... return a + b + c - - >>> # positional arguments using `...` placeholder - >>> f_a = sp.Partial(f, ..., 2, 3) - >>> f_a(1) - a: 1, b: 2, c: 3 - 6 - - >>> # keyword arguments - >>> f_b = sp.Partial(f, b=2, c=3) - >>> f_a(1) - a: 1, b: 2, c: 3 - 6 - - Note: - - The ``...`` is used to indicate a placeholder for positional arguments. - - https://stackoverflow.com/a/7811270 - """ - - __slots__ = ["func", "args", "keywords"] # type: ignore - - def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any): - self.func = func - self.args = args - self.keywords = kwargs - +class partial(ft.partial): def __call__(self, *args: Any, **kwargs: Any) -> Any: iargs = iter(args) args = (next(iargs) if arg is ... else arg for arg in self.args) # type: ignore return self.func(*args, *iargs, **{**self.keywords, **kwargs}) - def __repr__(self) -> str: - return f"Partial({self.func}, {self.args}, {self.keywords})" - - def __hash__(self) -> int: - return tree_hash(self) - - def __eq__(self, other: Any) -> bool: - return is_tree_equal(self, other) - - -# to match python -partial = Partial - def bcmap( func: Callable[P, T], + broadcast_to: int | str | None = None, *, is_leaf: Callable[[Any], bool] | None = None, ) -> Callable[P, T]: """Map a function over pytree leaves with automatic broadcasting for scalar arguments. Args: - func: the function to be mapped over the pytree + func: the function to be mapped over the pytree. + broadcast_to: Accepts integer for broadcasting to a specific argument + or string for broadcasting to a specific keyword argument. + If ``None``, then the function is broadcasted to the first argument + or the first keyword argument if no positional arguments are provided. + Defaults to ``None``. is_leaf: a predicate function that returns True if the node is a leaf. Example: @@ -199,7 +169,6 @@ def bcmap( >>> print(sp.tree_str(tree_add(tree_of_arrays, 1))) dict(a=[2 3 4], b=[5 6 7]) """ - # add broadcasting argnum/argname to the function later treelib = sepes._src.backend.treelib @ft.wraps(func) @@ -209,23 +178,29 @@ def wrapper(*args, **kwargs): leaves = [] kwargs_keys: list[str] = [] + bdcst_to = ( + (0 if len(args) else next(iter(kwargs))) + if broadcast_to is None + else broadcast_to + ) + treedef0 = ( # reference treedef is the first positional argument - treelib.tree_flatten(args[0], is_leaf=is_leaf)[1] + treelib.flatten(args[bdcst_to], is_leaf=is_leaf)[1] if len(args) # reference treedef is the first keyword argument - else treelib.tree_flatten(kwargs[next(iter(kwargs))], is_leaf=is_leaf)[1] + else treelib.flatten(kwargs[bdcst_to], is_leaf=is_leaf)[1] ) for arg in args: - if treedef0 == treelib.tree_flatten(arg, is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(arg, is_leaf=is_leaf)[1]: cargs += [...] leaves += [treedef0.flatten_up_to(arg)] else: cargs += [arg] for key in kwargs: - if treedef0 == treelib.tree_flatten(kwargs[key], is_leaf=is_leaf)[1]: + if treedef0 == treelib.flatten(kwargs[key], is_leaf=is_leaf)[1]: ckwargs[key] = ... leaves += [treedef0.flatten_up_to(kwargs[key])] kwargs_keys += [key] @@ -239,7 +214,7 @@ def wrapper(*args, **kwargs): args = args_kwargs_values[:split_index] kwargs = dict(zip(kwargs_keys, args_kwargs_values[split_index:])) all_leaves += [bfunc(*args, **kwargs)] - return treelib.tree_unflatten(treedef0, all_leaves) + return treelib.unflatten(treedef0, all_leaves) return wrapper @@ -266,7 +241,8 @@ def leafwise(klass: type[T]) -> type[T]: The decorated class. Example: - >>> # use ``numpy`` functions on :class:`TreeClass`` classes decorated with ``leafwise`` + Use ``numpy`` functions on :class:`TreeClass`` classes decorated with :func:`leafwise` + >>> import sepes as sp >>> import jax.numpy as jnp >>> @sp.leafwise @@ -321,15 +297,15 @@ def leafwise(klass: type[T]) -> type[T]: def uop(func): def wrapper(self): - return treelib.tree_map(func, self) + return treelib.map(func, self) return ft.wraps(func)(wrapper) def bop(func): def wrapper(leaf, rhs=None): if isinstance(rhs, type(leaf)): - return treelib.tree_map(func, leaf, rhs) - return treelib.tree_map(lambda x: func(x, rhs), leaf) + return treelib.map(func, leaf, rhs) + return treelib.map(lambda x: func(x, rhs), leaf) return ft.wraps(func)(wrapper) @@ -391,7 +367,7 @@ def tree_type_path_leaves( is_path_leaf: Callable[[KeyTypePath], bool] | None = None, ) -> Sequence[tuple[KeyTypePath, Any]]: treelib = sepes._src.backend.treelib - _, atomicdef = treelib.tree_flatten(1) + _, atomicdef = treelib.flatten(1) # mainly used for visualization def flatten_one_level(type_path: KeyTypePath, tree: PyTree): @@ -407,7 +383,7 @@ def one_level_is_leaf(node) -> bool: return False return True - path_leaf, treedef = treelib.tree_path_flatten(tree, is_leaf=one_level_is_leaf) + path_leaf, treedef = treelib.path_flatten(tree, is_leaf=one_level_is_leaf) if treedef == atomicdef: yield type_path, tree @@ -501,7 +477,7 @@ def construct_tree( return root -def value_and_tree(func, argnums: int | Sequence[int] = 0): +def value_and_tree(func: Callable[..., T], argnums: int | Sequence[int] = 0): """Call a function on copied input argument and return the value and the tree. Input arguments are copied before calling the function, and the argument @@ -614,15 +590,15 @@ def immutate_is_leaf(node): return False @ft.wraps(func) - def stateless_func(*args, **kwargs) -> tuple[Any, PyTree | tuple[PyTree, ...]]: + def stateless_func(*args, **kwargs) -> tuple[T, PyTree | tuple[PyTree, ...]]: # copy the incoming inputs (args, kwargs) = tree_copy((args, kwargs)) # and edit the node/record to make it mutable (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf) output = func(*args, **kwargs) # traverse each node in the tree depth-first manner # to undo the mutation (if there is a rule for it) - treelib.tree_map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) + treelib.map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf) out_args = tuple(a for i, a in enumerate(args) if i in argnums) out_args = out_args[0] if is_int_argnum else out_args return output, out_args diff --git a/tests/test_index.py b/tests/test_index.py index 097543e..0728380 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -23,7 +23,7 @@ from sepes._src.backend import arraylib, backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass, _mutable_instance_registry -from sepes._src.tree_index import AtIndexer, BaseKey +from sepes._src.tree_index import at, BaseKey from sepes._src.tree_util import is_tree_equal, leafwise, value_and_tree test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") @@ -117,7 +117,7 @@ def __init__(self, c: int, d: int): ], ) def test_indexer_get(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.get(), expected) assert is_tree_equal(indexer.get(is_parallel=True), expected) @@ -150,11 +150,33 @@ def test_indexer_get(tree, expected, where): ], ) def test_array_indexer_get(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.get(), expected) assert is_tree_equal(indexer.get(is_parallel=True), expected) +@pytest.mark.skipif(backend != "jax", reason="test jax jit with get") +def test_get_fill_value(): + import jax + import jax.numpy as jnp + + tree = dict(a=jnp.array([1, 2, 3]), b=jnp.array([4, 5, 6])) + mask = dict( + a=jnp.array([False, True, False]), + b=jnp.array([False, True, False]), + ) + + @jax.jit + def jit_func(tree): + return at(tree)[mask].get(fill_value=0) + + out = jit_func(tree) + a = out["a"] + b = out["b"] + assert jnp.all(a == jnp.array([0, 2, 0])) + assert jnp.all(b == jnp.array([0, 5, 0])) + + @pytest.mark.parametrize( ["tree", "expected", "where", "set_value"], [ @@ -191,7 +213,7 @@ def test_array_indexer_get(tree, expected, where): ], ) def test_indexer_set(tree, expected, where, set_value): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.set(set_value), expected) assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected) @@ -233,7 +255,7 @@ def test_indexer_set(tree, expected, where, set_value): ], ) def test_array_indexer_set(tree, expected, where, set_value): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.set(set_value), expected) assert is_tree_equal(indexer.set(set_value, is_parallel=True), expected) @@ -268,7 +290,7 @@ def test_array_indexer_set(tree, expected, where, set_value): ], ) def test_indexer_apply(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.apply(lambda _: _X), expected) assert is_tree_equal( indexer.apply(lambda _: _X, is_parallel=True), @@ -307,7 +329,7 @@ def test_indexer_apply(tree, expected, where): ], ) def test_array_indexer_apply(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal(indexer.apply(lambda _: _X), expected) assert is_tree_equal( indexer.apply(lambda _: _X, is_parallel=True), @@ -343,7 +365,7 @@ def test_array_indexer_apply(tree, expected, where): ], ) def test_indexer_reduce(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.reduce(lambda x, y: x + y, initializer=0), expected, @@ -378,7 +400,7 @@ def test_indexer_reduce(tree, expected, where): ], ) def test_array_indexer_reduce(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.reduce(lambda x, y: x + y, initializer=0), expected, @@ -405,7 +427,7 @@ def test_array_indexer_reduce(tree, expected, where): ], ) def test_indexer_scan(tree, expected, where): - indexer = AtIndexer(tree, where=where) + indexer = at(tree, where=where) assert is_tree_equal( indexer.scan(lambda x, s: (x + s, x), state=0), expected, @@ -451,8 +473,8 @@ def __call__(self, x): a = A(1) _, b = value_and_tree(lambda A: A(2))(a) - assert treelib.tree_flatten(a)[0] == [1] - assert treelib.tree_flatten(b)[0] == [3] + assert treelib.flatten(a)[0] == [1] + assert treelib.flatten(b)[0] == [3] with pytest.raises(TypeError): a.at[0](1) @@ -480,7 +502,7 @@ def delete(self, name): def test_unsupported_where(where): t = namedtuple("a", ["x", "y"])(1, 2) with pytest.raises(NotImplementedError): - AtIndexer(t, where=where).get() + at(t, where=where).get() @pytest.mark.skipif(backend != "jax", reason="jax backend needed") @@ -496,7 +518,7 @@ def __init__(self, a, b) -> None: @property def at(self): - return AtIndexer(self) + return at(self) if backend == "jax": import jax.tree_util as jtu @@ -533,7 +555,7 @@ def __init__(self, a, b) -> None: @property def at(self): - return AtIndexer(self) + return at(self) import optree as ot @@ -575,26 +597,26 @@ class Tree(TreeClass): t = Tree() - assert repr(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert str(t.at["a"]) == "AtIndexer(tree=Tree(a=1, b=2), where=['a'])" - assert repr(t.at[...]) == "AtIndexer(tree=Tree(a=1, b=2), where=[Ellipsis])" + assert repr(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])" + assert str(t.at["a"]) == "at(Tree(a=1, b=2), where=['a'])" + assert repr(t.at[...]) == "at(Tree(a=1, b=2), where=[Ellipsis])" def test_compat_mask(): tree = [1, 2, [3, 4]] - tree_ = AtIndexer(tree)[[False, False, True]].set(10) + tree_ = at(tree)[[False, False, True]].set(10) assert tree_ == [1, 2, 10] def test_pluck(): tree = [1, 2, [3, 4]] - subtrees = AtIndexer(tree)[2].pluck() + subtrees = at(tree)[2].pluck() assert subtrees[0] == [3, 4] - assert AtIndexer(tree)[0, 1].pluck(1) == [1] - assert AtIndexer(tree)[0, 1].pluck(2) == [1, 2] + assert at(tree)[0, 1].pluck(1) == [1] + assert at(tree)[0, 1].pluck(2) == [1, 2] tree = dict(a=1, b=2) - assert AtIndexer(tree)[...].pluck() == [1, 2] + assert at(tree)[...].pluck() == [1, 2] @pytest.mark.skipif(backend != "jax", reason="jax backend needed") diff --git a/tests/test_mask.py b/tests/test_mask.py index d03cf8b..9b8de81 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +import functools as ft +import os from typing import Any import pytest @@ -20,17 +22,12 @@ from sepes._src.backend import backend, treelib from sepes._src.code_build import autoinit from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import ( - freeze, - is_frozen, - tree_mask, - tree_unmask, - unfreeze, -) -import os +from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask from sepes._src.tree_util import is_tree_equal, leafwise, tree_hash test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") +freeze = ft.partial(tree_mask, cond=lambda _: True) +unfreeze = ft.partial(tree_unmask, cond=lambda _: True) if test_arraylib == "jax": import jax.numpy as arraylib @@ -54,14 +51,14 @@ class A(TreeClass): b = a.at[...].apply(freeze) c = ( a.at["a"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) .at["b"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] assert unfreeze(freeze(1.0)) == 1.0 @autoinit @@ -80,17 +77,17 @@ class A(TreeClass): b: int a = A(1, 2) - b = treelib.tree_map(freeze, a) + b = treelib.map(freeze, a) c = ( a.at["a"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) .at["b"] - .apply(unfreeze, is_leaf=is_frozen) + .apply(unfreeze, is_leaf=is_masked) ) - assert treelib.tree_flatten(a)[0] == [1, 2] - assert treelib.tree_flatten(b)[0] == [] - assert treelib.tree_flatten(c)[0] == [1, 2] + assert treelib.flatten(a)[0] == [1, 2] + assert treelib.flatten(b)[0] == [] + assert treelib.flatten(c)[0] == [1, 2] @autoinit class L0(TreeClass): @@ -104,11 +101,11 @@ class L1(TreeClass): class L2(TreeClass): c: L1 = L1() - t = treelib.tree_map(freeze, L2()) + t = treelib.map(freeze, L2()) - assert treelib.tree_flatten(t)[0] == [] - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + assert treelib.flatten(t)[0] == [] + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] class L1(TreeClass): def __init__(self): @@ -118,9 +115,9 @@ class L2(TreeClass): def __init__(self): self.c = L1() - t = treelib.tree_map(freeze, L2()) - assert treelib.tree_flatten(t.c)[0] == [] - assert treelib.tree_flatten(t.c.b)[0] == [] + t = treelib.map(freeze, L2()) + assert treelib.flatten(t.c)[0] == [] + assert treelib.flatten(t.c.b)[0] == [] def test_freeze_errors(): @@ -160,25 +157,25 @@ class Test(TreeClass): c: str = freeze("test") t = Test() - assert treelib.tree_flatten(t)[0] == [1] + assert treelib.flatten(t)[0] == [1] with pytest.raises(AttributeError): - treelib.tree_map(freeze, t).a = 1 + treelib.map(freeze, t).a = 1 with pytest.raises(AttributeError): - treelib.tree_map(unfreeze, t).a = 1 + treelib.map(unfreeze, t).a = 1 hash(t) t = Test() - treelib.tree_map(unfreeze, t, is_leaf=is_frozen) - treelib.tree_map(freeze, t) + treelib.map(unfreeze, t, is_leaf=is_masked) + treelib.map(freeze, t) @autoinit class Test(TreeClass): a: int - t = treelib.tree_map(freeze, (Test(100))) + t = treelib.map(freeze, (Test(100))) class Test(TreeClass): def __init__(self, x): @@ -223,7 +220,7 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff(): @@ -234,10 +231,10 @@ class Test(TreeClass): t = Test() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] - assert treelib.tree_flatten( - (treelib.tree_map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_frozen) + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] + assert treelib.flatten( + (treelib.map(freeze, t)).at["b"].apply(unfreeze, is_leaf=is_masked) )[0] == ["a"] @autoinit @@ -246,11 +243,11 @@ class T0(TreeClass): t = T0() - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] - assert treelib.tree_flatten(t)[0] == ["a"] - assert treelib.tree_flatten(treelib.tree_map(freeze, t))[0] == [] + assert treelib.flatten(t)[0] == ["a"] + assert treelib.flatten(treelib.map(freeze, t))[0] == [] def test_freeze_nondiff_with_mask(): @@ -278,11 +275,11 @@ class L2(TreeClass): t = t.at["d"]["d"]["a"].apply(freeze) t = t.at["d"]["d"]["b"].apply(freeze) - assert treelib.tree_flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] + assert treelib.flatten(t)[0] == [10, 20, 30, 1, 2, 3, 3] def test_non_dataclass_input_to_freeze(): - assert treelib.tree_flatten(freeze(1))[0] == [] + assert treelib.flatten(freeze(1))[0] == [] def test_tree_mask(): @@ -299,18 +296,18 @@ class L1(TreeClass): tree = L1() - assert treelib.tree_flatten(tree)[0] == [1, 2, 3] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(treelib.tree_map(freeze, tree))[0] == [] - assert treelib.tree_flatten(tree.at[...].apply(freeze))[0] == [] - assert treelib.tree_flatten(tree.at[tree > 1].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] + assert treelib.flatten(tree)[0] == [1, 2, 3] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(treelib.map(freeze, tree))[0] == [] + assert treelib.flatten(tree.at[...].apply(freeze))[0] == [] + assert treelib.flatten(tree.at[tree > 1].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at[tree == 1].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at[tree < 1].apply(freeze))[0] == [1, 2, 3] - assert treelib.tree_flatten(tree.at["a"].apply(freeze))[0] == [2, 3] - assert treelib.tree_flatten(tree.at["b"].apply(freeze))[0] == [1] - assert treelib.tree_flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] - assert treelib.tree_flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] + assert treelib.flatten(tree.at["a"].apply(freeze))[0] == [2, 3] + assert treelib.flatten(tree.at["b"].apply(freeze))[0] == [1] + assert treelib.flatten(tree.at["b"]["x"].apply(freeze))[0] == [1, 3] + assert treelib.flatten(tree.at["b"]["y"].apply(freeze))[0] == [1, 2] def test_tree_unmask(): @@ -328,21 +325,21 @@ class L1(TreeClass): tree = L1() frozen_tree = tree.at[...].apply(freeze) - assert treelib.tree_flatten(frozen_tree)[0] == [] + assert treelib.flatten(frozen_tree)[0] == [] mask = tree == tree - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] mask = tree > 1 - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [2, 3] - unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1] + unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [1] - # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [2, 3] + # unfrozen_tree = frozen_tree.at["b"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [2, 3] def test_tree_mask_unfreeze(): @@ -361,12 +358,12 @@ class L1(TreeClass): mask = tree == tree frozen_tree = tree.at[...].apply(freeze) - unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_frozen) - assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + unfrozen_tree = frozen_tree.at[mask].apply(unfreeze, is_leaf=is_masked) + assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] # frozen_tree = tree.at["a"].apply(freeze) - # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_frozen) - # assert treelib.tree_flatten(unfrozen_tree)[0] == [1, 2, 3] + # unfrozen_tree = frozen_tree.at["a"].apply(unfreeze, is_leaf=is_masked) + # assert treelib.flatten(unfrozen_tree)[0] == [1, 2, 3] def test_wrapper(): @@ -403,18 +400,16 @@ def test_wrapper(): @pytest.mark.skipif(backend == "default", reason="no array backend installed") def test_tree_mask_tree_unmask(): tree = [1, 2, 3.0] - assert treelib.tree_flatten(tree_mask(tree))[0] == [3.0] - assert treelib.tree_flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] + assert treelib.flatten(tree_mask(tree))[0] == [3.0] + assert treelib.flatten(tree_unmask(tree_mask(tree)))[0] == [1, 2, 3.0] mask_func = lambda x: x < 2 - assert treelib.tree_flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] + assert treelib.flatten(tree_mask(tree, mask_func))[0] == [2, 3.0] assert freeze(freeze(1)) == freeze(1) - assert tree_mask({"a": 1}, mask={"a": True}) == {"a": freeze(1)} - - with pytest.raises(ValueError): - tree_mask({"a": 1}, mask=1.0) + with pytest.raises(TypeError): + tree_mask({"a": 1}, cond=1.0) assert copy.copy(freeze(1)) == freeze(1) @@ -424,7 +419,7 @@ def test_tree_mask_tree_unmask(): @pytest.mark.skipif(backend == "default", reason="no array backend installed") def test_array_tree_mask_tree_unmask(): - frozen_array = tree_mask(arraylib.ones((5, 5)), mask=lambda _: True) + frozen_array = tree_mask(arraylib.ones((5, 5)), cond=lambda _: True) assert frozen_array == frozen_array assert not (frozen_array == freeze(arraylib.ones((5, 6)))) diff --git a/tests/test_operator.py b/tests/test_operator.py index 8a89205..b8908c1 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -15,6 +15,7 @@ from __future__ import annotations import math +import os from typing import Any import pytest @@ -22,9 +23,10 @@ from sepes._src.backend import backend from sepes._src.code_build import autoinit, field from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import freeze +from sepes._src.tree_mask import tree_mask from sepes._src.tree_util import bcmap, is_tree_equal, leafwise -import os + +freeze = lambda x: tree_mask(x, cond=lambda _: True) test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") if test_arraylib == "jax": @@ -171,3 +173,17 @@ def test_bcmap(tree, expected): def test_math_operations_errors(): with pytest.raises(TypeError): tree1 + "s" + + +def test_bcmap_int_argnum_broadcast_to(): + def func(x, y): + return x + y + + assert bcmap(func, broadcast_to=1)(1, [2, 3, 4]) == [3, 4, 5] + + +def test_bcmap_key_argnum_broadcast_to(): + def func(x, y): + return x + y + + assert bcmap(func, broadcast_to="y")(x=1, y=[2, 3, 4]) == [3, 4, 5] diff --git a/tests/test_pprint.py b/tests/test_pprint.py index 622f654..8eb8003 100644 --- a/tests/test_pprint.py +++ b/tests/test_pprint.py @@ -15,12 +15,11 @@ from __future__ import annotations import dataclasses as dc -import re +import os from collections import namedtuple from typing import Any import pytest -import os test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") backend = os.environ.get("SEPES_BACKEND", "jax") @@ -29,8 +28,6 @@ from sepes._src.tree_pprint import ( _table, tree_diagram, - tree_graph, - tree_mermaid, tree_repr, tree_str, tree_summary, @@ -144,7 +141,7 @@ def test_tree_summary(): assert ( tree_summary(r1, depth=1) # trunk-ignore(flake8/E501) - == "┌────┬────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼────────────┼─────┼───────┤\n│.e │list │5 │ │\n├────┼────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼────────────┼─────┼───────┤\n│.j │f32[1,1,4,5]│20 │80.00B │\n├────┼────────────┼─────┼───────┤\n│.k │tuple │3 │ │\n├────┼────────────┼─────┼───────┤\n│.l │a │2 │ │\n├────┼────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴────────────┴─────┴───────┘" + == "┌────┬─────────────────────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼─────────────────────────┼─────┼───────┤\n│.a │int │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.b │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.c │float │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.d │str │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.e │list[int,int,int,int,int]│5 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.f │set │1 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.g │dict │27 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.h │f32[5,1] │5 │20.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.i │f32[1,6] │6 │24.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.j │f32[1,1,4,5] │20 │80.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.k │tuple[int,int,int] │3 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.l │a[int,int] │2 │ │\n├────┼─────────────────────────┼─────┼───────┤\n│.m │f32[5,5] │25 │100.00B│\n├────┼─────────────────────────┼─────┼───────┤\n│.n │bool[] │1 │1.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│.o │c64[2] │2 │16.00B │\n├────┼─────────────────────────┼─────┼───────┤\n│Σ │Repr1 │101 │341.00B│\n└────┴─────────────────────────┴─────┴───────┘" ) assert ( @@ -165,20 +162,6 @@ def test_tree_diagram(): assert tree_diagram(r1, depth=1) == out -@pytest.mark.skipif(backend != "jax", reason="jax is not installed") -def test_tree_mermaid(): - assert ( - re.sub(r"id\d*", "***", tree_mermaid(r1, depth=1)) - # trunk-ignore(flake8/E501) - == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e=[...]")\n *** --- ***(".f={...}")\n *** --- ***(".g=dict(...)")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k=(...)")\n *** --- ***(".l=a(...)")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")' - ) - assert ( - re.sub(r"id\d*", "***", tree_mermaid(r1, depth=2)) - # trunk-ignore(flake8/E501) - == 'flowchart LR\n ***("Repr1")\n *** --- ***(".a=1")\n *** --- ***(".b=string")\n *** --- ***(".c=1.0")\n *** --- ***(".d=aaaaa")\n *** --- ***(".e:list")\n *** --- ***("[0]=10")\n *** --- ***("[1]=10")\n *** --- ***("[2]=10")\n *** --- ***("[3]=10")\n *** --- ***("[4]=10")\n *** --- ***(".f={...}")\n *** --- ***(".g:dict")\n *** --- ***("[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")\n *** --- ***("[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")\n *** --- ***("[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".k:tuple")\n *** --- ***("[0]=1")\n *** --- ***("[1]=2")\n *** --- ***("[2]=3")\n *** --- ***(".l:a")\n *** --- ***(".b=1")\n *** --- ***(".c=2")\n *** --- ***(".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])")\n *** --- ***(".n=bool[]")\n *** --- ***(".o=c64[2]")' - ) - - @pytest.mark.skipif(backend != "jax", reason="jax is not installed") def test_misc(): x = (1, 2, 3) @@ -251,16 +234,6 @@ def test_invalid_depth(): tree_diagram(1, depth="a") with pytest.raises(TypeError): tree_summary(1, depth="a") - with pytest.raises(TypeError): - tree_mermaid(1, depth="a") - - -@pytest.mark.skipif(backend != "jax", reason="jax is not installed") -def test_tree_graph(): - assert ( - re.sub(r"\b\d{10,}", "***", tree_graph(r1)) - == 'digraph G {\n *** [label="Repr1", shape=box];\n *** [label=".a=1", shape=box];\n *** -> ***;\n *** [label=".b=string", shape=box];\n *** -> ***;\n *** [label=".c=1.0", shape=box];\n *** -> ***;\n *** [label=".d=aaaaa", shape=box];\n *** -> ***;\n *** [label=".e:list", shape=box];\n *** -> ***;\n *** [label="[0]=10", shape=box];\n *** -> ***;\n *** [label="[1]=10", shape=box];\n *** -> ***;\n *** [label="[2]=10", shape=box];\n *** -> ***;\n *** [label="[3]=10", shape=box];\n *** -> ***;\n *** [label="[4]=10", shape=box];\n *** -> ***;\n *** [label=".f={...}", shape=box];\n *** -> ***;\n *** [label=".g:dict", shape=box];\n *** -> ***;\n *** [label="[\'a\']=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", shape=box];\n *** -> ***;\n *** [label="[\'b\']=bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", shape=box];\n *** -> ***;\n *** [label="[\'c\']=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".h=f32[5,1](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".i=f32[1,6](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".j=f32[1,1,4,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".k:tuple", shape=box];\n *** -> ***;\n *** [label="[0]=1", shape=box];\n *** -> ***;\n *** [label="[1]=2", shape=box];\n *** -> ***;\n *** [label="[2]=3", shape=box];\n *** -> ***;\n *** [label=".l:a", shape=box];\n *** -> ***;\n *** [label=".b=1", shape=box];\n *** -> ***;\n *** [label=".c=2", shape=box];\n *** -> ***;\n *** [label=".m=f32[5,5](μ=1.00, σ=0.00, ∈[1.00,1.00])", shape=box];\n *** -> ***;\n *** [label=".n=bool[]", shape=box];\n *** -> ***;\n *** [label=".o=c64[2]", shape=box];\n *** -> ***;\n}' - ) @pytest.mark.skipif(backend != "jax", reason="jax is not installed") @@ -270,9 +243,25 @@ def test_tracer_repr(): @jax.jit def f(x): out = tree_repr(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" out = tree_str(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" return x f(jax.numpy.ones((10, 10))) + + +@pytest.mark.skipif(backend != "jax", reason="testing jax specific sharding info") +def test_jax_sharding_tree_summary(): + import jax + import numpy as np + from jax.sharding import Mesh, NamedSharding, PartitionSpec + + x = jax.numpy.ones([4 * 4, 2 * 2]) + mesh = Mesh(devices=np.array(jax.devices()).reshape(4, 2), axis_names=["i", "j"]) + sharding = NamedSharding(mesh=mesh, spec=PartitionSpec("i", "j")) + x = jax.device_put(x, device=sharding) + assert ( + tree_summary(x) + == "┌────┬───────────┬─────┬───────┐\n│Name│Type │Count│Size │\n├────┼───────────┼─────┼───────┤\n│Σ │G:f32[16,4]│64 │256.00B│\n│ │S:f32[4,2] │ │ │\n└────┴───────────┴─────┴───────┘" + ) diff --git a/tests/test_treeclass.py b/tests/test_treeclass.py index 67da1f3..61887e3 100644 --- a/tests/test_treeclass.py +++ b/tests/test_treeclass.py @@ -15,11 +15,12 @@ import copy import dataclasses as dc import inspect +import os from typing import Any import numpy.testing as npt import pytest -import os + from sepes._src.backend import backend, treelib from sepes._src.code_build import ( autoinit, @@ -29,8 +30,10 @@ fields, ) from sepes._src.tree_base import TreeClass -from sepes._src.tree_mask import freeze -from sepes._src.tree_util import Partial, is_tree_equal, value_and_tree +from sepes._src.tree_mask import tree_mask +from sepes._src.tree_util import is_tree_equal, partial, value_and_tree + +freeze = lambda x: tree_mask(x, cond=lambda _: True) test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy") if test_arraylib == "jax": @@ -147,7 +150,7 @@ def __init__( test = Test() - assert treelib.tree_flatten(test)[0] == [] + assert treelib.flatten(test)[0] == [] class Test(TreeClass): def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): @@ -155,7 +158,7 @@ def __init__(self, a=arraylib.array([1, 2, 3]), b=arraylib.array([4, 5, 6])): self.b = b test = Test() - npt.assert_allclose(treelib.tree_flatten(test)[0][0], arraylib.array([4, 5, 6])) + npt.assert_allclose(treelib.flatten(test)[0][0], arraylib.array([4, 5, 6])) def test_post_init(): @@ -200,7 +203,7 @@ def inc(self, x): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5, 5] assert l1.inc(10) == 20 assert l1.sub(10) == 0 assert l1.d == 5 @@ -212,7 +215,7 @@ class L1(L0): l1 = L1() - assert treelib.tree_flatten(l1)[0] == [2, 4, 5] + assert treelib.flatten(l1)[0] == [2, 4, 5] def test_registering_state(): @@ -414,7 +417,7 @@ class Test(TreeClass): t = Test(1) assert t.a == freeze(1) - assert treelib.tree_flatten(t)[0] == [] + assert treelib.flatten(t)[0] == [] def test_super(): @@ -522,10 +525,10 @@ def test_partial(): def f(a, b, c): return a + b + c - f_a = Partial(f, ..., 2, 3) + f_a = partial(f, ..., 2, 3) assert f_a(1) == 6 - f_b = Partial(f, 1, ..., 3) + f_b = partial(f, 1, ..., 3) assert f_b(2) == 6 assert f_b == f_b