From 8782988e6a3ffcd10b96a9d17f13f9269fbc66a2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 13 Feb 2024 16:29:48 -0500 Subject: [PATCH] continuing --- WORKSPACE | 39 ++++++++++++-------- src/enzyme_ad/jax/.primitives.py.swp | Bin 0 -> 57344 bytes src/enzyme_ad/jax/compile_with_xla.cc | 8 +++-- src/enzyme_ad/jax/enzyme_call.cc | 5 +++ src/enzyme_ad/jax/primitives.py | 50 ++++++++++++++++++-------- 5 files changed, 72 insertions(+), 30 deletions(-) create mode 100644 src/enzyme_ad/jax/.primitives.py.swp diff --git a/WORKSPACE b/WORKSPACE index 99086d318..08c33578b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -33,13 +33,18 @@ llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) XLA_COMMIT = "c5163ff997d8be8fd32136e25050fa32c67c989f" XLA_SHA256 = "" -http_archive( - name = "xla", - sha256 = XLA_SHA256, - strip_prefix = "xla-" + XLA_COMMIT, - urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_args = ["-p1"], - patches = ["//:patches/xla.patch"], +# http_archive( +# name = "xla", +# sha256 = XLA_SHA256, +# strip_prefix = "xla-" + XLA_COMMIT, +# urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], +# patch_args = ["-p1"], +# patches = ["//:patches/xla.patch"], +# ) + +local_repository( + name = "xlae", + path = "./xla" ) PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" @@ -60,16 +65,22 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03" -ENZYME_SHA256 = "" - -http_archive( +# ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03" +# ENZYME_SHA256 = "" +# +# http_archive( +# name = "enzyme", +# sha256 = ENZYME_SHA256, +# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", +# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +# ) + +local_repository( name = "enzyme", - sha256 = ENZYME_SHA256, - strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", - urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], + path = "../Enzyme/enzyme", ) + JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" diff --git a/src/enzyme_ad/jax/.primitives.py.swp b/src/enzyme_ad/jax/.primitives.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..59871542f050b4e4f1a65ffde2de3c7fe8c21245 GIT binary patch literal 57344 zcmeI537lj_dBB@+DYt}Eh&~#K-E4Qy&T?jSz=d6qW!YW!0CZg1H{Gvio9&B!-Lo^h zA|zai1`|Rwm(x$m)Nu=~@fif1 zcu6FXNMLFSbUT&&F}tofd`|yt@py3hAnOnJTru@MiEAQ(L;{Hf5(y*{NFLl}jo7n&D%V(zas*LOAl*Tmd+G1sf+|5sz~JDKadng375 z+`VwU`xhC38gHp>&3%LB|5-72b2<5!NFb3wB7sB#i3Ab}Boas@kVqhrKq7%e0*M5E zClavBsnlVFJS>VT>;Inq|IOV~se9oTxEPvn7A%9Ma1tC2+rtwm#dpKia6X&|BT#{b za0M+yIHSPVPDV<`2P!||{S+<_9_fkWYb6#DnVTj0k` z4qOLk!cK4-lL;5VQSby44y_9x5AZh8C(dfcT8E~PrB1=<;%@AjloX4>ztdbbQ`6D z{K=G>jZvrF$(Ne-R?Qj9cba*(V>ilnyPVCHTql+2m5EMD~)bF z-*(otad8Lp{n3!T2oHBF6{lTYt8SHO*PVLXsf3a?RU^`3w`;X#$?iC16WCz?x!F*F`jPpvWBd{4Y*aqlhk6iFEUjKAp)23WY^{^8 zHuA&clB)jMbB>6nE0#m%S#8Jd);fO7wHQNDYPFm7X2*#pX(kymGs*YN?}+%2?6{V) zT?@;~%cfejVb?}P^0j8OB_)6|6)Odi08Pddrj{FbsxRTdN1alqN#3iB9S^%SY;1R8 z(ufrOu;+@zt-7|_7|DC3$qBnl0gA*;gcdKHJw-H=9B^AssamNzZ6%!AXg!pzF#%1? ztVpz+daE;@?>G&&*>-a-g~TOb3bNEy>aw3Ybwzt{sOHVKJu;=RLNPM=NCl_1O6|=z z>^c=i<*8G{YN|oSiHd6`QkY>A&{N7&2q9?T!om?rn^L8O6>pCs0jiV*;TQpw9W7^R zzwD6_>ZdCuuHGy=x^2zoCXq(BAw`<|%J~vjYi`b~B04}-3RGnYNj>zmyj!Zo#2YEt zp}R~}D_z&gJ2j{7G&*ald{-%IJ3O~kt;O=Hr+{aYfx3mtm#cMBw8^cz+bM?<;z`I| zWw)fRH{4F!t~NTR8R5~*8p2o3<$PxxPmwG%{gkN+;i|1eq&Vm=wdNZ7lEw%H!t{r} zhe&|RAFX(qo>DR~@Gn7CM;bNCLnr|{lt>V&*IiYkJ#hS`rDYRvm8J?Qvc10zY*L#2ITz%R~ab(F84FI3)2D3C%P;Qcu)F z<1(}vFRZFNooXo*A8EpUK}Uj9`JkfWyOZ~wcC(hRkj>$!M1Aq_XvF;chuclNtm<|| zz`>JL)9jEDGPY60O-UsZFF!QWNfOr?LCN^6>2jBhlVlJV9XttQ6)~iI7CP+Ir90_tyBx)WK3RH5X#egR)wn(yO(v(c>=4RaCMxGT2+|Lo}RI_D| zs19)YF-WJ&SU8)D6drrH8Y)++NT@XFCc!J5L3N14=-iyiMv=1$-KDa3s2dVj$<)C* z4O@6%uAWe-Neyg7%bVR8&r};U8P2-0Q?VHz=3CVky-vd!4A;$g*Xj*8+fJw3ZdhrO zS1h%7Pcx2t8R=mnY1rFe?fvR&7|mT9(BLCg-MQW8qL+vpI0HE1oT zWTJY_U360M)VD2M%oSr_EhVq8&~A)$8FLiGP5WB7@bptB^CfQD*9l9OOy&tE^gV{b zR>49oX|_AzN2XHd8phoPhHkBK%XY1^8!XnUd*m|Bly|2RNv(<~d}mw>^;)%EC|B(o z<4)J}%K)Fonhc%b)$>tp`%VRVEvHR`%wInQG+;jxz1-!oD!TSqQNH`Wy$96nu~d*g zKdtV#>s=P*s?%)NTzxNHB+ZU@``pI3zDy4xNt3pvU=No(_Y=r7Nn_4h=A6?N)lPd? zWmW?9JTl_!Ae~eyolY;9d#2@gV6IhZwk>UOZ_&Iq^ivA>ONst}DY~)ft)l;X>(ie^ zufG|F!G=BIA@uiuhL6HWU~^X(^Q~cqvl%8~i{hBG1%=2Brd`EIvIMhE53L??-q(%jua|Er4SM6Z>`V44wfJaD(!Z8H(h zFx{k6hD_Os{thDoZjYcnu`uYZRH{W8JZB;gSsD2yj~9WHQBGl*OqP`Aw%hi2rdd$w zK+c7-Opj#Mc&uqv8%pxZQ9V8qTr|EOWOriV~vy zWkhB7Mg*UeJmQWPs~-M0 zog3mT5$oU>ikJjX@<;@^rXJ|uhi@JEjO~?d#L7w(zjXoECES)Ii42QD-UlT>bds@% z5A_s_51L^CF}azsf=x;E9%6o(I#jJ#WKq>Ec9{rd5y2@JtU0>Mi>wIF_b@bDU)MqtaSlx$w}}bYH-QtsR341m4q^)RCYJssd#4k7D4*G$<0Pkhh$Io2}=W>ITc8Lj(E#l{w1@O<&U` zXH-0Mk12UUq47$(DV?hFc_FB(s0+Imkjm#s=n|{8U2+NtOQvinG0hT_#j50}9p}1% z4+Za64|QLqTdNhB{dS#jI!H!VRUHvC%dS$pr6w~+uB%Rv<{8AmOQT7g2{UJM-H3!* z;tdU^OOY3z%DtR)xikk_LRqE%uSQ>f7)1Z~?ECLP-@gcshJD~a(eKBh01u(d{}mh$ z&!W5E0A+Y9JcFKoExa9G3%8@A*I^dOGp@jo=>EfSGVB09!CxH~!oKhm{K`82De!&n zxgBnX>*RiT7n}(D!Y@RZhuh&sSPwVA{D8h-b5Bw5c|3UZ^Tnc{)?}IM90~Wzt z$iVZ=i+mgIfGeO04XDEz@D`X22g6SAU2Fs~lDRl$H(9*2+$OV!wd%;K&KgJlGTq1;CNovd3T`r%vN}!68IuCn zu{x`qlfe@?#zt(oJ2E1+J>3TV*x*QcxW_aB{?L&sdDh%y#eVR}-hL!3j1DTz{Av`eL2aK` zLrsZ!Ct?ipb+n_~I=oDbXzNXR zg0_3au8Bu;Pml1{RU4&Rx9sGHEB#0HuqQDg=!Z|bff)vQ%Z;}08m(9+`vQps?ukzsaB3kdJ4mF zwsAG5=G!_(?V+M^j#_VtS!66$zcEz&f>?o6M{s5}+x3UKkWx7X5ug3LQ3f4N}A8 zU(9T!%5Y(^2NcG1FmrhHsBk?We62 z3)VbC{DzB-76xr9^CO|@Io-U$P+7rX{uBHtc`Pr%>6xo{?^{L~tM@;8ydOe7#<#VtQG)N=_j znJ!{dTg|v_z;w|@N%Hr44r!w#l1wbo4!n9L_p;44y?RCVvQ0s~dPSDZOQ~6|jh>g% zD;D1DRJw0&3#Ho3-}IeXVjoulrMvG0_JYVNYTrC()TuSGkndZL$#ZEHX8Nx7H&3tR z7RFHJj7wkFZTFekme}N6ni)KdvDtCR-WT~&%2zSeX8kMfaRZ{X1k))N_3Limsdj1c zvNH=Ub0HM%W%Ew)o!#>4%@IZB6I;wq6wYo|8yV&a)p(&;#<~T&mb91tn7va-055yO z`%1HB^^Nw&B|mA7EYv%hLQ*}US+BE!X(97*QZC%Sn3#J8-EKBJAqO+Vnv%lEEUlM! z1zjvMS!;06O3$;~YpRVjc0tOU-r>;H<_=6w%YFy%MrCbO!r*Ge%_tiSxud5oP$UD( z4Kl4zNc(?FG6Ih z_WjEoz+>Czz!hi0Ne$4!nJS>G@$`+h5@kPx6A>24laVXK^~q) z-@g(*3@hQCAm;^0`aT2=D8gB=5{`z~!S?Vxa(EbSg1>^ZVL2>=rEoax3fsW5==Bf4 zJ#YhD0%Ajuh6CXMcs;xpeuz&0MX16;I0Y8K02~BAK#zX{9tY9+#eU#o5F3IzEQdp2 zJ9rA+{@ZXroDb*064(QFhhJa^@N@V!d>UFX0KaCA=TZ0=tbpxdJ2)IvnO5c6vk{|+ z4XGt+WxSKADia$qT*NHLVFQnx72t?nrC!J8dc~@W<7%hX(FRKFqhfQ;V3?=~`(QNj&Jh(L_lqB2s<7L+5#cv!Ky#cCsX zq<6*KF{zO*GS)V_rH(d#iv|<3Z6!3F%#_58Op!K38Zh4k*K5HV$|ev0+EP%PC}d5O z;pHpNGS!`_+pSE^t`C=OYixkI=+!L2wJ^?8#;n?+NClA9=?gTdZ~*2uW2uT6tc!OQ zPw_871qH?5FjY`pedaKUftYTpJzyfcVy8*Y=-YBuARNNNWKxGK_7_21V#XHiB|F`& zb)6a7S_I*Q%uae5ylf6*3@>szEK;XzR&?qYnqAUMvC&iGqp?oZWW$V2Rnep0$e0Rl zx}bIG%Hea;golkKTWsyl;dAI?Dn}g=zIN0R=MHrk7PWO&Um3%&4%I~Hg5ibZ#UZUCAbp&Ee!AT3QorOvRYj6f4Dl<{U2{c6 zQLY?y1Xm`DpxJ7#N=7KQQXe|wqyf<`<)9I#yj&eVu$*!y@mI<5N_HqcL5=N|*`-Bhir~w*ef2t(w!(oc0~R!Mi%+ zP^A6OS*&|`yMr@vom*_+BfE+xc(79PSTT8C#<>~P>3U1PP?y5PPH`_X7BAvllw)>i zcjP|0^ijO3e;7ohS)kUj=dg-4_1gQX6&Epf_r}ipehI@b)C9CS#wA<%P(nsoy1!0| zbmg-N;|M(jl$@dTdo8FP+5zC_-b$ULL9w6WrofZAi(T8J#o#Kp20wJT{&R5ilQ{Ec zCtmJTB9~X&LWInIPTmML*yS-zzpO_Dj0Hk{vs}}iuj!vB>wDGCj8La5LxfPF2z9HG z-mj_p%jkY6y=OhNDQ1otT*{y>-282=IW3d6Q{{B1c1S@!SQ=V(JTM57fi}};9fvI{ zO8>tW9rQ=&siOZc;GcY6M6bUOu7JOR3*gV-Phkj-ge)8g+rqu*_&0;B`^!Fo5)6Y4 z$HRfJ2W$&ZqT|asfS-eZfcs-mn*J3y-4b-v%W(5$3}@I06oUN6`5n2C*CX z0{lIk4;>hRMeqjr4RZiLf^UHANw^IhD8rHP9QyyW@NJN@1!d2{hhYW05gtY7zY8vb z0eBL5DVb?;C4UnMBoas@kVqhrz$;V&G76ASc$mC7lqtfKF&KmHWRf?TzJN|I`o7rwS70f4w*If< zyXgE6n)`0&|0QrL$a((Qh)(|id>-Vz0oT9To#BfgNE7_&IYaAA@ligGKN<5W9Xk zBjAUu;eQiufltCmLDu`PhKoS#{O7?jFbnnsvG4yPTn8TnIWyo`7=-6p+y52FTK}iu z@8Lsm9-IaH!uMI@zXtvSPKN(8X94K(pc)Y-BRHqMH7(PpB=3Vpl9s|>=%2VwKQv&; zkwoHM%^s@ba8`P&?EYoHxk*;2Aw_+VMXwBA}h6C^B%pYYh76~=zc>HY5Sz>gdbUX|#*!5V{8|;c=c{Z4)_gOn$ zcM8l$>-}h8m9}|waHm}4UL>e3J8*Fwtb{mjMDKibGebcbEA{<>Ktij! zx}-_U)Pk={j{FKQ(9S*noPKld*%u-BB zz6Wtti^Nj+qHDw@-o4Oqyfaqh)a%ul;5r;GBDYIixIz7UrK}Wh%H)O|7shOhvWW|F zZ7W}U4lK2*oc}d2$CPTx|4mqfHp7?AY0IV~m(|rhs09?zDSJvuQRpVU+U8A-B^~1? zbd0j#xhWlEzqV!T+4TQTdq!=Kx7j<#ph?-T1qWnKPVAo@S*xcYn?{r)Csf$Z-;2=<0;;3xQf1Vrz@3H}`9ep%;# z4eSWNAfSJQbK%`^8oUXfM1OxA9)mBydU!vqgS8;%0u2HTC;X-J` zIUsuf?jUCkz6S0`xBm>>3NEzabXWoXFdHn`0iHnbzY<2^FxVZwj;?<`%!b|I3+VKp z201TK_V)L~L9jP$1G2aO5s-chUPFOK(Um^Yhr;9~y(n&GeA| zUJk7|%2hFGW*-PKj_d(pY7vda%nJm*y_lhUrsQt<1fA@V-xaITFdFKFL$Bn0i%8BZ zjdD_KN&<;@-*PPoxJM#W;v*=rVnwKp&%C_@zS=jHcFS!W)q%$!r;-$nflS7t>jqKW z+PZVu;JLXK-KFdoZF88D%6%yTufXowUoIb26Y?8n86A`wNhAfE`l&V_P*v5h*dD8R zd5mthN=YE{iF=baOzq%GDQEupQw2C}4l9C#B}Ajf`GWD2fL2hmZ1Ygx0^Ki^DT%iK ziy}HvXAei%)7S$^ZK9^Go*#ZfP=S<_7%}_Ni$trR^qEi=Q3Z6@S5r0bW5`9lh_Jv!3RP10<3~XFbDczSCBB?0dfw&4Ip~~E(h5QuoUDyfro(D2Z(LJrgV+G<3qL^Le*zwd``{yRHCzSq zE`hUPCA<@6!S?VV`u|7aG*|)|kp2F$*Z(oN2G+n^;aBMSzl1Nt)i4TghNsZ;pM)>M z`5@>2p9}AXli);{53<((0(!r^1K@gSf}HJm7(5eN^GAo0gN8G*mZv8t%?mj!_V2Nx zQ+HV6YpB;(@NyOPst(!TE7iECBTa?84OPPI$O~E29UPn?1JMrsoc}Tm-3U*7Ecf$j zrB?i2^~w+vL4B+Fff{P!Cy&PXbZE9LCFq^#3=sIPV%unY>BFBaVT)h=9v9`V4?GRI z9LkCfyX9xA-M}*89DXZGM!jgmybCGz9u@x$qD8k|>iL?{a7gkx(UwCk)2Z^;Q0XDr z5gA^zlhsRECzEy8VczVkURh*Lml;gsl2)&NEvdJqqSueGU7sb(kREJSYcm{Y;jJ_J z-u&L!TauG5x|NH(PssDbHqpp)v#H2U@pAD2IXsN`Qt?duNoP%)p6bI*qICcq2F@e!q`gH>i9HvQM)BdaQd-V3YbN7}rpMH~Y)(DxEt3VWpC9`(J*kFX=J|MMPD zX@niLaf0ute}+A%m%?ae#!Ce$^cn}%0qEDMr6jh}d*H06C(L`Tqs*f`tAWdV+VGeuU`_J?&MZXbKyzq}G%0#afkc zZS@*L@vi1eMEk*nw80No-l6sm9;@G&#>JEbBO7-^F|s+VK;5k;c=aK%T+uIccP!qJ zQys0Uso_eK_v7$tAbS2)b|VNxm2K1LQotscSFB@=wOwLm^s=^Hs!Rkur2L$|?a~El z@*Yy$CuzHaTO)0kT#mL~+T(O>m-01r+l7}NorW%JFR$!%9Kv0PxiWnYc(+c`ffz57 zlyVaXCz^GyR1yzi7=NLmC5Yd|-BLvk(<@ntNkoXjh?xRS1zbM+)5a+iEIGK%`< zjlP}coPF*lg=%Gg!kM13B;iG!UI%-Wm85di;|h_Wu`yr{jy=|8wZ`KLxP^xCd?pIpbg67kCj| z2p52y{dYWw{x5a`yTdQh`R{~(g3DkG94N!CunXLU{(mQ23>QHaR>303!w=B;KMpR) zI|26tm6s|Px2iv0ixDh5BnwGf$%fRFevL(7loW@=od0BmKEii}YIvR18$Fq7IIte>2A?YF`_;5+oY5H)hMwMv-}* z=Hm5B>YBdpP4yV3WRJsEw8vqKE*mAgBV;yX3!U$q?2hO*I3%1`FE;|SjAVC&oZ`O4 z4*mUo?`LJUXH)vwNo+1v-)wqghFYP;CeN}|cj>`2g+KP1q0NPv?@$j41<%7mHVGTP zpD>7!n`|bO!Ij?szm&D%J<%IQ|F0UAF|PlwM*lwmocft_-2^A>$a((VAvMEL*D!|6hYSc4*`{5Ca;ozn^OW(FT$Uo(TH62n!2H6qeQY% zqSXnuJ|r6@Qc_Viu-g{P6Fv2=Cu!*FCtTX%29X(Bb85zx8kf{40#YAjmrZH__?u zg?EFj^NW4|w(uxA{a4@yxBz6$e-7ke8~8c;{LkPSkoA6f_u=huBaDO00jz-v91Tao zk+27+53Mx+y>`@*!k}d-$1|rI$Q;!^Dl~8fPa9u!x?Zo ztN;t%2ycLA(DVNXJ^>fN+h9A`7F2$1PPO;d9+uSSWMhLhCnbVNHa4tr+Vbu@&x}j@ zz>vv#xJyhnHmoVig8U?gkHJ=gVzRM81{lFE8-Gycl`;`ysFa1x940Si>ZU2|Qm4q9 z;MRyms$7nmqiT=S*`g|6427oHK7*HZmaak?8DlBAwQg*PYYqN3gSbOlZ$=Z5iT3CF zlON&D2r>fIW~;uJ;8NJ5=n}pc`4M(ZHbbOD((*9{G}#QH-cJ8J*e=?NicZIPnPf~52_Vw-7?)`8<9na31M)mh-3Mu3em4}% zy=~-9gUD6^-(r8mwhIJS_A}Tapyz_rZ0`*BQk0wlCO+EQ*fQu8W;XQ{+6+)?wNzov Jm@Yh<`hV?-D^UOd literal 0 HcmV?d00001 diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index 224ef5314..0052863db 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -30,13 +30,17 @@ #include "compile_with_xla.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" + // Compile an MHLO module given as a string to LLVM IR using XLA. std::unique_ptr compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, bool xla_runtime, const std::string &pass_pipeline) { // Parse MLIR. - mlir::MLIRContext context; + mlir::DialectRegistry registry; + mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); + mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); context.loadDialect(); @@ -132,7 +136,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, xla_computation.proto(), std::move(module_config_or_error.value()), local_client->mutable_backend(), executor.value(), {build_options.device_allocator(), build_options.compile_thread_pool(), - build_options.layout_canonicalization_callback()}, + build_options.layout_canonicalization_callback(), ®istry}, build_options.run_backend_only()); if (!executable.ok()) { throw pybind11::value_error(executable.status().ToString()); diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index 13eb423cf..64a0cdfef 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -57,6 +57,7 @@ #include "xla/service/cpu/cpu_executable.h" #include "Enzyme/FunctionUtils.h" +#include "Enzyme/MLIR/Passes/Passes.h" enum class ABI { Primal, Forward, Augmented, Reverse, Tape }; @@ -204,6 +205,7 @@ class CpuKernel { nullptr); } } + llvm::errs() << "linkMod: " << *linkMod << "\n"; } if (xla_runtime) { ss << " extern \"C\" void " << fn << "(void* exec"; @@ -831,10 +833,12 @@ class CpuKernel { #endif } + llvm::errs() << " str: " << ss.str() << "\n"; auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true, pyargv_strs, llvm_ctx.get(), std::move(linkMod)); if (!mod) throw pybind11::value_error("failed to compile C++"); + llvm::errs() << " postmod: " << *mod << "\n"; return std::make_tuple(std::move(mod), std::move(llvm_ctx), out_off, tmpBuf); } @@ -1022,6 +1026,7 @@ PYBIND11_MODULE(enzyme_call, m) { mlir::registerAsyncPasses(); mlir::arith::registerArithPasses(); mlir::memref::registerMemRefPasses(); + mlir::registerenzymePasses(); pybind11::enum_(m, "Language") .value("CPP", Language::CPP) diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 1a335522f..66f1ab754 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -39,7 +39,7 @@ def mlir_ad(self): class OldXLAPipeline: def xla_runtime(self): - raise False + return False def pass_pipeline(self): return "" @@ -195,11 +195,12 @@ def __init__(self, passes=None, mlirad=False): test-convergence=false top-down=true}, cse""" + assert len(passes) != 0 self.passes = passes self.mlirad = mlirad def xla_runtime(self): - raise False + return True def pass_pipeline(self): return self.passes @@ -207,7 +208,7 @@ def pass_pipeline(self): def mlir_ad(self): return self.mlirad -DefaultPipeline = NewXLAPipeline("", True) +DefaultPipeline = NewXLAPipeline(None, True) def pass_pipeline(options): if type(options) == type(""): @@ -369,7 +370,7 @@ def _enzyme_aug_abstract_eval( in_shapes = [absmaketup(a) for a in in_shapes] if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -453,15 +454,31 @@ def _enzyme_primal_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) + (in_tree, in_idx_map, func) = source + in_idxs = sorted(set(v for _, v in in_idx_map.items())) + avals = [ctx.avals_in[i] for i in in_idxs] + avals_in = jax.tree_util.tree_unflatten(in_tree, avals) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") source = str(mhlo) + print(in_idx_map) + print("source", source) kept = lowered_func.compile()._executable._kept_var_idx - in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - + in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept) + orig_shapes = [] + seen = [] + for (i, shape) in enumerate(in_shapes): + if in_idx_map[i] in seen: + continue + seen.append(in_idx_map[i]) + orig_shapes.append(shape) + in_shapes = [shape for (i, shape) in enumerate(orig_shapes) if i in kept] + print("in args", in_args) + print("in shapes", in_shapes) + + print(pipeline_options) + print(pipeline_options.xla_runtime()) + print(pipeline_options.pass_pipeline()) argv = argv + ("-resource-dir", resource_dir()) + cflags() identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( source, @@ -486,6 +503,7 @@ def _enzyme_primal_lowering( custom_call = stablehlo.CustomCallOp( out_types, mlir_args, call_target_name="jaxzyme.primal" ) + print(custom_call) results = custom_call.results @@ -516,7 +534,7 @@ def _enzyme_fwd_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -578,7 +596,7 @@ def _enzyme_aug_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -644,7 +662,7 @@ def _enzyme_rev_lowering( kept = None if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -785,7 +803,10 @@ def make_zero(tan, prim): for o in kwargs["out_shapes"]: outshapes2.append(o) outshapes2.append(o) - shadconv = ffi_call(*args, out_shapes=outshapes2, source=kwargs["source"], fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options) + (in_tree, in_idx_map, func) = kwargs["source"] + avals = {2*k:v for k, v in in_idx_map.items()} | {2*k+1:v for k, v in in_idx_map.items()} + source = (in_tree, avals, func) + shadconv = ffi_call(*args, out_shapes=outshapes2, source=source, fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options) else: shadconv = _enzyme_fwd_p.bind( *args, @@ -891,13 +912,14 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(*args: Any): args_flat, in_tree = jax.tree_util.tree_flatten(args) out_shape = jax.eval_shape(func, *args) + in_idxs = {i:i for i in range(len(args_flat))} out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape) out_shape_flat = [ jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat ] out_flat = ffi_call( *args_flat, - source=(in_tree, func), + source=(in_tree, in_idxs, func), fn="", out_shapes=out_shape_flat, argv=argv,