From 88498127fbff8d7fbb9089be2f5b39447a5d73de Mon Sep 17 00:00:00 2001 From: phisn Date: Fri, 3 May 2024 02:31:04 +0200 Subject: [PATCH] Fix implementation of SAC --- packages/editor/models/export-model.ts | 2 +- packages/learning/actor/model.json | 1 + packages/learning/actor/weights.bin | Bin 0 -> 5128 bytes packages/learning/critic/model.json | 1 + packages/learning/critic/weights.bin | Bin 0 -> 4996 bytes packages/learning/index.html | 2 - packages/learning/package.json | 20 +- packages/learning/src/main.ts | 656 +++++++++++++++--- packages/learning/src/old-main.ts | 283 ++++++++ packages/learning/src/ppo/ppo.ts | 204 +++--- packages/learning/src/soft-actor-critic.zip | Bin 0 -> 5275 bytes .../learning/src/soft-actor-critic/actor.ts | 112 +-- .../learning/src/soft-actor-critic/critic.ts | 61 +- .../soft-actor-critic/gaussian-likelihood.ts | 25 - .../learning/src/soft-actor-critic/mlp.ts | 4 +- .../src/soft-actor-critic/replay-buffer.ts | 47 +- .../soft-actor-critic/soft-actor-critic.ts | 281 ++++++-- .../runtime-framework/src/message-store.ts | 2 - .../rocket/systems/rocket-death-system.ts | 2 - packages/server/package.json | 2 +- packages/web-game/package.json | 2 +- .../game/modules/module-input/module-input.ts | 200 +++++- packages/web-game/src/main.ts | 10 +- 23 files changed, 1510 insertions(+), 407 deletions(-) create mode 100644 packages/learning/actor/model.json create mode 100644 packages/learning/actor/weights.bin create mode 100644 packages/learning/critic/model.json create mode 100644 packages/learning/critic/weights.bin create mode 100644 packages/learning/src/old-main.ts create mode 100644 packages/learning/src/soft-actor-critic.zip delete mode 100644 packages/learning/src/soft-actor-critic/gaussian-likelihood.ts diff --git a/packages/editor/models/export-model.ts b/packages/editor/models/export-model.ts index a134313c..4035ea24 100644 --- a/packages/editor/models/export-model.ts +++ b/packages/editor/models/export-model.ts @@ -1,4 +1,4 @@ -import { EntityType, ShapeModel, WorldModel } from "runtime/proto/world" +import { EntityType, ShapeModel } from "runtime/proto/world" import { verticesToBytes } from "runtime/src/model/world/shape-model" import { ShapeState } from "../entities/shape/shape-state" import { WorldState } from "./world-state" diff --git a/packages/learning/actor/model.json b/packages/learning/actor/model.json new file mode 100644 index 00000000..9b76bea6 --- /dev/null +++ b/packages/learning/actor/model.json @@ -0,0 +1 @@ 
+{"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_3","layers":[{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true,"batch_input_shape":[null,4],"dtype":"float32"}},{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense2","trainable":true}}]}},{"class_name":"Dense","config":{"units":2,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense5","trainable":true}}]},"keras_version":"tfjs-layers 4.19.0-rc.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense1/kernel","shape":[4,32],"dtype":"float32"},{"name":"dense_Dense1/bias","shape":[32],"dtype":"float32"},{"name":"dense_Dense2/kernel","shape":[32,32],"dtype":"float32"},{"name":"dense_Dense2/bias","shape":[32],"dtype":"float32"},{"name":"dense_Dense5/kernel","shape":[32,2],"dtype":"float32"},{"name":"dense_Dense5/bias","shape":[2],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.19.0-rc.0","convertedBy":null} \ No newline at end of file diff --git a/packages/learning/actor/weights.bin b/packages/learning/actor/weights.bin new file mode 100644 index 0000000000000000000000000000000000000000..839240a709a301c6387fcfd518ee1c1352f710d0 GIT binary patch literal 5128 zcmWO9ha;A27=Uq+s7S~N8L1>ACmQS zfbw4{&3AGp++R5&nLbQMXD&v5>oEsf%(Pyb`j^!BA1^ z?)!;ot<6S{#ad|hE*_jBtI^}e4tmh!3NfB!Pt)N%u~t+;QK@Jg$o@r){Wg#!Lse#m z_&DtTbpxFhG*H3k3`jX-;hA$D_}jUTitE&1_>>jcwR|60xjT#muQ>*(yQKNPe-7}y zey``$&weAH`ZGaK=_qIReGqCr3^|_3BF;=#n=eR7h4j(|e3^`1teIiL9d!`mpG_*9 zbNgLrte(pMI%yf-+n@lB+j?*=tF^hh7v<hR(v9se3&A_n8`GZ$ zf-rE9>^!v?qtrLRKTSh$W&aJ<-ZBf@dHGm%K$08r+RQ50xM0YqAYwRANFrX0{b;(;tDV${_TXICZ)E<{mGHhlm&T^7$0K@+Aqp=`*fge0s_BB5z;U1zfK#$ybx_N4<W~DYiS^5Y0KRhZ^AK~-eCTDHI(Qqhgsk6 zVYR9f$2hPUR~HN!cVg-O+X+lb(I#v@<^XPLAK1OC)3LIs4U30+sZ4h(m%ZKu*5YevsQga2W%ZQEcfCi85fE{VHH5P= zZ(seYEU(PH~lH=5z z3&V-&l8OZFF6;@5x5{ET?GjJ6d+WWit< z`^rKN&Ud-P?WYS^^f$%1^~PX3WjcQMRwWbO=;Qr(BbXoJ57nm*V}^JpOnRz^=i@4< zUUf3rFfSAnUJAh9I0;xe39{%uWfjog@rih@mBE;v7P#pn2K_N{AnRR*Ejm*n zYKv#p)R|+by>|jvy3HSMPaCM}hyrHK&ZUqX|BU?Vm*8gl-h*bLHSS#y1U|YBaN?s4 ziY*$TinHF(A5v=Eo^vg%y5mFa<(J?tzdSfmGRUZmOoObU?Km|=l3rfJ(QENFRnm(T zV9VqExbE00-qW3rP((xkKa!J))a*03|KF|rnC^J&>E)60hb1sHU6jA%^b7obw1Y~2 zDaWy0i*eHOR~V=~fJJfvDBoO;e;pJ#b^Yah$LZJTlRH*?dv|&M>kq12T<<(seqb>Q zbRB8h@c^1kWVoSl1^QmQ6TdwSMYnPpOqL(Vo%<%m$$hMNAFZ9}@C4h$;vJm#fZ z0pxyl#FCsJOyoQdY+2^YT)%&w?AxWzeI8XMMpsI4RI&uNe`tdyDHjwK#?is*AhcLJ 
z2lQUAVoo?D;YVWwI8auF*2iAr;D|Wq`wnTATL&+vU6lLV9gHXT+fnC}^H*Fj2a&o|-|A z`lW>KvAKpPL^i`5xDKPsc9WLmLENesK`az=8JUYQXtnqSmGx32pF{(oy5I>F-_pz2 z<<6i#TiS^3X)9bX7Rz?zEPe51Eu|OgO%o*tRKgVzPZCV|z2X z{VACY?1}_`P7yR`SAxyByL8}6HtY+vuL}642jgcaVq>clR4zOMd+OY1%=Hv9eKili z%r$1BV-KVa1nKy@7;^GdxoL4XsKLP;c3Z(GDw`NhXjc@p z&Jx9bs=zr6{l(S3iuh+n0oft=#9kWeVd71KAV;N$&C6Vcq9tauYj!!Kzd;t#74vCq z-y9Sf8^ZR!P8u0rNPUD$F?~$&e@09Ndx045RK_ z5$vvwH_5`5FTCY%imC9$PxkJSXkvEYEHo%bV3gnyVTNbosZ?3c@2m-yOr1gvD&1kV zjvczXX5ktdfhAXJseVNh35X8Fsk>SFJmo48m-t09tA)gWt}&=Ygh8?21n@}d2Gdy= zXv?-y`gqp`rfGC4j`Nxf2jVtCU?-q|^fc_u&;gyC793aq8j|%A>9s@BVEi^6JnHR< z8xxFEr0ik;LP@IJWI$}cnZY=peD-vACK$AfVpw<(wyUqfm~L_Sr740_)xF_yQx48g z{6zALUf~+eD`2`i2n?Q<;VTJ481DUo2c>Spj?dL3tH>M6oTMp#=s3QL45ZmVpMY&r z35te?k;ITboOAamD#d3oUjIyiH-~BKzlC;nF%fK ztcG<2DSyk+hSdw%^s94-R+}sij+0>TkL^O8CM9sL*b5SJm&qYP3XHu;07wtPeZz8? ztG)#nI?m!2LmYMJI?2k1j1dhfHRelxEieCW5;^@ch;C2_VViUFc{XXQX^BA$TRHD6 zEU+b_US%<5!Fm=O*wx_*q~q5`Jnh%CzFg%g(_Q6xOqkv9391FtbYlwA+U zO#8LKnEXrBTu;*duj5D#i3Dx!YEs!XLS5}-*pBt-F&4SIJX{#H9jyhmp#BW zI1E0%`^`wJB|t>agIkjG3(2JdRXW<=Pf|}Uf)TeuVs=vsn*9Uuw%K-=b>D>-^u3k@ zT^prq?mQ)v^j9-e)NhgjLpQi$u?3V{S?0mHbZX1~gHgfvN$n2Xt2309zYIzQ-dneM$F&5i2He&@|WA1+c zNqGZdz>7%)wN-}X<3JXqIoIP_s}|~+aE%rk43T-~^_Z85;;8ZY0{SP%vmfU^BF=ja z@l{L_T09(}22F;nLyr>MXc~*exN0n#p@eT&6Qs%f?TD7=I(w_`ri0 z-IFMK+7a5!Phf&{9u*g4!meU{d@Uh{nOT0k70q^jX|})}wGER8P4_{wZDnEBo#-o4tcb-EL*7JZMey|7owfqqUB38%v0F zn;zB_7t=0FQD|;62czf-n6^6-WW_JA@3mi$v%hN@e(DRF{v;OKe;2cR!UJINSO+VV zG$G?d%z3>{MmTkO5_$e46qY&1z@LCrddX`R{g#&kU3CI@)KEi4-@35R?`7iiML_pI zUJc*bWfuo-#6PA3%z0Xb0mGy$9cB$Ad-r!Z+(9y*4cMGbW|lpQv}g;pQ2*+d!yha+Lk zJfCQK&!&QwQEEq2sJF%nke5otQ^f|Lr}&g-Z5@ac>Iq!W`Oo8RTyyG0#>6RsqCURyo+0tt9-lf zVQ{uQ*`ZWV+zSg?iTr6;TP+GlH*l4c3I+7KT^?)a6M%o417KZKT7^%X5)}JN;~|J= zFJ;U{w{)ed&*~GgZpl$pEBA+C3LST;tZ=-HV!W= ztp%A4e`uL!2A#NkB5GC2Q|S8#_id1&)v=CX^X&*(xAr1du2F)V8y+xVR7Y!`YtVD< zS#aZK1WEsyOpecb#Eg}#0sYyAFp?ch8b8m&uNRGB>2^twj;*0KV*Bae8JZ~me1u+! 
zUIa0#=F&*pa#r|h8l>OwLJOr+@Vjvflm#Auv?Ckw_?_ceP_zK2DTh;|(|5=z%k$KA zh9wH$dJ(7J1cDp?lJ2|LiK2}YRX_WHYFl3-OYQT?1m9iI@>WPx+>@bHuAUm1>OiAT z6eM|0hI6gD;91&2Rx)Ra@gF@}scJ;}a`o`NmJ_%q)RNOdq2x=T4us+Ym>$0btWv6} zv&I^bF4_vh6>=aG^NJqWWCq!rMYzpv2{3nC0o!i33PzJyEdJlP7eXUxWW5GqXYRrF zfHLen)F z_YehUB3H4_l0SP^E>x%-#^)w8x%b--{r~CnK;gkX)<%ra+21}3pUuDEeGNX`DzoNX zo~ZD34<^yZT7M`!vkDe96yb}TdR(VgE6H>z!_fB>l-l}XPxAl`^2vt}2S)I!LMOSV H&`JLTl-1^w literal 0 HcmV?d00001 diff --git a/packages/learning/critic/model.json b/packages/learning/critic/model.json new file mode 100644 index 00000000..0cd31e71 --- /dev/null +++ b/packages/learning/critic/model.json @@ -0,0 +1 @@ +{"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_4","layers":[{"class_name":"Sequential","config":{"name":"sequential_2","layers":[{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense3","trainable":true,"batch_input_shape":[null,4],"dtype":"float32"}},{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense4","trainable":true}}]}},{"class_name":"Dense","config":{"units":1,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense6","trainable":true}}]},"keras_version":"tfjs-layers 4.19.0-rc.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense3/kernel","shape":[4,32],"dtype":"float32"},{"name":"dense_Dense3/bias","shape":[32],"dtype":"float32"},{"name":"dense_Dense4/kernel","shape":[32,32],"dtype":"float32"},{"name":"dense_Dense4/bias","shape":[32],"dtype":"float32"},{"name":"dense_Dense6/kernel","shape":[32,1],"dtype":"float32"},{"name":"dense_Dense6/bias","shape":[1],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.19.0-rc.0","convertedBy":null} \ No newline at end of file diff --git a/packages/learning/critic/weights.bin b/packages/learning/critic/weights.bin new file mode 100644 index 0000000000000000000000000000000000000000..3a9ae3a6e97f19a945d7e0a76bd13613265f2b14 GIT binary patch literal 4996 zcmWMp_gl|x8Cuuw>MUkkGy01^h z8xa{%JXuME%F2${`3IcGasF_w(QAj|6FEyz^7sJ85hLJylNfc^RKbah57|Tg(=e?g zl=`IXg!XOyl#{$0Ciom??^FAklU*ilsT_eN0XBk7T^t%W4G|J_5oQb+2}}PwM=P$6 zgBrCGnA>%GswZq&g^K=$r+$Ngs|-vt$T%WqUJh9K$Svi|F`pVDjdAQ-KjG%Z11OHoX{xN2=cCT`!7sbJXe?hKD%2azu*Yd+Y^Y>|CfO~HN8pO-%EJ9 z$danMLeTU}5P$vH3KsU&(tzD_?y2iWYv+2x)D?lKk-n1Rrreh*7wW! 
zs0bf^cExV&F^~}I?nctVlys`HuVU>UPubY{cJ#FRHuM`xVR5t)4i32i6E)K5@M|l) zxmpt=QWv3l)NZs&kE4yNTG@8FCD^>i3S~;|P+xZ^9<^6vs~iu|7%iamIVw;z8Skw&@4%<=QL4E}m?eQhHqh~vvu$YV&O|!+_BUaP+2s7|V>4(j}t|WtZ$fo`wXWdR@JF=J- z1kMGSZ$&I;>2|SXmIS1qD~I*Hu9WshhOD%W=-|S(+O=a^pkTcmCFE_Tv$4I*KsJm= zP8~w(WiH~7fQ>x9@-RHk8H&ez@A8up#jH_t5MJH#i5)6B4Hfmu5WHKFGJE1-)lqBK zvGtfZb?G>26z?GEcqj5-J{(^Ro{G8Q*<^6li^@H9&@D`dd&M{KxYSVk>X*ij8jZ)^ z9~oRwjUbIOAN1>1p}|57Es!@9>$DD{#a;L5LR=%1@p2yhuL_x5al@}0WC=y{@KV7a zekrp@Y`K0Vm* z8HOZ3U@fA#+*Gz7PFW@MZ%N~6%JS84cF8V?2%20MQ|L*mp=;>JcPDx`>>U`r@MdvI z$D!ta0IE4?l4D*tRESQp>RU0iOM58uci#XNmiGXxRih-c_R7gHOeUJ#u(%&=-jNx`rtS z=pa?}uq)%dxmIiU6XmyQMX#a#5i0X}w* zFFI@je!bAYuEOmjUniPDLAG=GK3xx-s5BWZL|f`++)^gl94VA9KFs$Q9Om<3GO5b; zvW`h@u(+?A9daB^-gb|`<7ObOo18+kG{4rAB^c5FhBq+v^cpCZwdY^WuF~0uC*k$l zv%LJ_NPODZ&LwUpP-~VXh9=G>)omw9z1p5`9Bk!gA_Ee~cY(f77{+c_gV7UCQfs^^ zmRD}$_NWObjE{16RHwDIJ8_Ct0vl?+iJqT{!^M3$lvGnexh=}ja3a>&zN zp0d5=>4(fC>i$uLm+X7Q`yRcf@HeHnYoH9KUFEnhF%^>Yin;#2I7~6q2dPJL^sO_N zmhrjBP#$LSIh`ugUsQF17ZR;P2a_4UIw~cRM@O^VUr_#aFhs?V-*So#!dLb<7(@Z;fkmMEu&d2P`Y7u*dmj|pIKQxg4Q4G8wvA+;_6Zdsp! z+9_eMAUPi-vvWyy(s!6yngbT@`H&cKoH_@|qwF$?IG#dT)CWb|CDnn`&i7*|HUh$2FTh?if+k_q3&%PC}D{h6WU_Y zLiz#EIc-NF&%^Ln{1m?QbR@O>dy;KE+gm#=J{h}TMswHp?@WC{5dOJ3nAG-trPH-K zkgb+Vi!7e-VM{Xcs7MvQJz~r{CIbxw8Tcp!;NGdLSe~Al$ZClfx*4y4g&2mq^i&8J*Rz{? z`ItQ+nXkWajGB zH4jis#NOUy?rCNND(d;D>=sC?|N8FX2Se_?dm;P$JCh!cF2!8uJf_n(g6%(?iFf}K zcywzn!qR-Yv~@05s4K^mQSaE$E9xZqNR{`Dj24>|dD9h}0^DM{44UI(NbliTC_drK z6n%j-`uu5j=Q(b{KQZNVSs+Ps=xB~%olQjsM5|n3vb`!;IpYxP?#pL_a}Thl80hjV zqAUBAF|6q|n~>`bD_Sl=_^O{wd*vYBc;o;(BTfO!fEv&!pG?JfwnBoOA^)|ol(`=1 zC)ZBKx10~dgjWlMuHxC~>|9B%nWw37+{L#d{`x2!-)$yDtEllD#~?hO8zzLM4a2*kkwVpvML6&MRh}wc!7SBn(6Y-BhdC{P zPRkwGCSgyKKHek|7%5n|P2~Pjr_gWLB4Ld5-#!kS%7e{EL*?@Qe3wn8U^zLIj~x1f zJ<1LhN_}$i@!UID^`U`%vvtS1&Gl&iB3C$jE}qiT!mx8_pzzsn7#1ly2s$5DVEpz` zyz7CRpc1NrD@vpBQOglhA5w%n+I2vA`9H#1&EatV%sOUPG7`;)jK#}^!?4<|nF+Sm ze715EcL`_(tx9$3bFris&(h(6&pS9SVT{MFsnTMP!MOIGDoUu0p^&>lJUv7S+w2d) z`j;%o>+J<6bhBO|r8o{nKso-6`L$oAowOCHu1ZS=a=F7`e zK)_=6R#p+HE&^;kEDV_2?vuBWVg%NCVF(H3 ziC5<#JvtTSF8%a^sHKZ^A9fg-L~ z@W6l1YSRVPAUshL&Uaoj;CFWAqG-o2^YVbx^z6oId|5Pzg2z`fpC^VieCAWWU}qhQ zbZpqulIgUeWvpI41Xn5i6h`$piVCe8fnrXS2exDqR#j8D7 zv${Kf{N0;3TsNRc#}zR^^Az2kG?%6s7^1DXoSJfvi8_k%`LD~zx%|gRwLN}T_>a32 z_(#byQ(Gl;*ElI&Hhv8{{g7Z^>pR7|FO%6pmjKo~Hl4a6s>JeNC)A3(YWS0GD<7a68D0ThLK>VDnn+>b06V?Mwl~ z?iDa)`b73Npa7Dtd*InsQe5+?9y@FriHnO~!nvkN^h9evy?yLNL;O9l@aIvIF-fFf zo6Ru!MJk*f2m;ekE|w1do6pq~=)%Z#c*J%m(`>E;$;I=re*Qvm>eI!$-D}~{CL0#t zxg8h3h=v}&;V}M;DV_Nm2*(Z_pit`o)c@+jV{$U;bpCGtn{OD|j5-Bv-(Rv_1+m=V z)@>GaA&1?FuxI1`+~&2$iY(#V2Pks&Lj^;3%Jq`K3yeWjS!g@1Nobm;j~VHl8yJ_mP3kymYn-!RMA5M+Y@}*5 z?z-3q9zN;Vxz&)(sn><0v88xKN*(H51Qww&g7lL&liWFH_A_=HdURxgQg#YW{5%Gy zy|`Z2aOfXQS!T{MJ{%SwU1v&1q$D7-Jb>N>s({kOLGUZY3OgE?V{xw*J$$c)=FaP} zq*W3sX1bE0+h7ctmPGHL4?&~6oZ64lk~plhg1%lU0=?;N5Mdz3>(l)y?%HhhkQa$% zHRgieuZ>{hD+`zIRIzp2)t#9rk{8vj&4A9HSP@5xMhny z&q=Yy$S(FF-=76$o}x_;b@+p0S~%!0-mEkx$zclUpJ_n_<7GI{&Ef~%F<$FA43}CJ z^N#qB++xdkycZ^-*7{iTt9w(Yp1GY&Tn^EMcuTbSA<3r?*pPD{V#E0X{%Er$URk3; z4wcPZCs7_xIm(dHu08brw>3{w*h_mBH1a*o#<=X-UUVFq1{+H~1l2?X%+-EE{N^OK z=2jU@aXk!X8vF3b5J%Qws0s0hqR?0^QYhZ!BSb0;7aJBj3qCP1f?laF8tX(u)v^E~ zDKHrCN$U%z;#T44o3es?LNGKbTcdKdqfqFLLUVW^B;7R;X8v;u{&`c3Ps=XA{{al( BMT`Id literal 0 HcmV?d00001 diff --git a/packages/learning/index.html b/packages/learning/index.html index a53dce46..842a931c 100644 --- a/packages/learning/index.html +++ b/packages/learning/index.html @@ -2,12 +2,10 @@ - Polyburn - diff --git a/packages/learning/package.json b/packages/learning/package.json index 
36737a35..e1a070a9 100644 --- a/packages/learning/package.json +++ b/packages/learning/package.json @@ -2,8 +2,6 @@ "name": "learning", "private": true, "version": "1.0.0", - "type": "commonjs", - "moduleResolution": "node", "scripts": { "rl-dev": "vite", "rl-local": "npx tsx ./src/main.ts", @@ -12,33 +10,35 @@ }, "devDependencies": { "@types/node": "^20.12.7", - "@types/three": "^0.162.0", "autoprefixer": "^10.4.19", "postcss": "^8.4.38", - "runtime": "*", - "runtime-framework": "*", - "shared": "*", "tailwindcss": "^3.4.3", "tslib": "^2.6.2", "tsx": "^4.7.3", "typescript": "^5.4.2", - "vite": "^5.1.6" + "vite": "^5.2.10" }, "dependencies": { + "@dimforge/rapier2d": "^0.12.0", + "@tensorflow/tfjs": "^4.19.0", "@tensorflow/tfjs-backend-webgl": "^4.19.0-rc.0", "@tensorflow/tfjs-backend-webgpu": "^4.19.0-rc.0", - "@tensorflow/tfjs-node": "^4.19.0-rc.0", - "@tensorflow/tfjs-node-gpu": "^4.19.0-rc.0", "@types/prompts": "^2.4.9", "@types/sat": "^0.0.35", + "@types/three": "^0.164.0", "eslint-config-custom": "*", "lil-gui": "^0.19.2", "poly-decomp-es": "^0.4.2", "ppo-tfjs": "^0.0.2", "prompts": "^2.4.2", + "protobufjs": "^7.2.6", + "runtime": "*", + "runtime-framework": "*", "sat": "^0.9.0", - "three": "^0.162.0", + "shared": "*", + "three": "^0.164.1", "tsconfig": "*", + "vite-node": "^1.5.3", "vite-plugin-top-level-await": "^1.4.1", "vite-plugin-wasm": "^3.3.0" }, diff --git a/packages/learning/src/main.ts b/packages/learning/src/main.ts index 990ead65..6682083e 100644 --- a/packages/learning/src/main.ts +++ b/packages/learning/src/main.ts @@ -1,54 +1,179 @@ -function getReward(got: number, expected: number) { - function f() { - const gotRounded = Math.round(got) +import * as tf from "@tensorflow/tfjs" +import { Buffer } from "buffer" +import { EntityWith, MessageCollector } from "runtime-framework" +import { WorldModel } from "runtime/proto/world" +import { LevelCapturedMessage } from "runtime/src/core/level-capture/level-captured-message" +import { RocketDeathMessage } from "runtime/src/core/rocket/rocket-death-message" +import { RuntimeComponents } from "runtime/src/core/runtime-components" +import { Runtime, newRuntime } from "runtime/src/runtime" +import { Environment, PPO } from "./ppo/ppo" + +export class PolyburnEnvironment implements Environment { + private runtime: Runtime + private currentRotation: number + private nearestLevel: EntityWith + + private captureMessages: MessageCollector + private deathMessages: MessageCollector + + private bestDistance: number + private maxTime = 60 * 30 + private remainingTime = 60 * 30 + + private worldModel: any + + private touchedFlag = false + + constructor() { + const worldStr = + 
"CqAJCgZOb3JtYWwSlQkKDw0fhZ3BFR+FB0Id2w/JQBItDR+FtsEVgZUDQh3bD8lAJQAAEMItpHBhQjWuR9lBPR+Fm0FFAAAAQE0AAABAEi0Nrkc/QRVt5wZCHdsPyUAlAAD4QC2kcBZCNezRjUI94KMwP0UAAABATQAAAEASLQ2k8B5CFX9qWEEd2w/JQCUAAP5BLaRwFkI17NG9Qj3gozA/RQAAAEBNAAAAQBItDeyRm0IVPzWGQR3bD8lAJQCAjUItSOHsQTX26AVDPYTr6cBFAAAAQE0AAABAEi0Nw0XwQhUcd4lAHTMeejwlAIDnQi2kcA5CNfboMkM9EK6nv0UAAABATQAAAEASLQ2PYhxDFT813EEd2w/JQCUAAM9CLaRwbEI1AMAmQz0fhbFBRQAAAEBNAAAAQBItDcM15UIVYxBJQh3bD8lAJQAAeUItUrijQjXs0fpCPZDCM0JFAAAAQE0AAABAEi0N9WiFQhXVeIhCHdsPyUAlw7WBQi3sUY9CNcO1kUI9AACBQkUAAABATQAAAEAaTgpMpHA9wXE9ukHAwP8AAEAAPYCA/wAAtIBDAAD/AIDFAEBAQP8AgMgAAICA/wBAxgC+oKD/AABGAMf///8AV0dxQry8+QBSQPHA////ABpOCkyuR3FBSOHKQf/++ABAxgAA//3wAAA/QMT/++AAQEoAQv/3wAAAPkBF/++AAADHAD//3gAAgMYAAP/vgAAAAIDD////AKxGCq////8AGpcCCpQC9qjBQpqZJEL///8AMNEAOv///wDqy9pH////AOzHNML///8AAMIAx////wAAQkDE////AABFAL3///8AAELAx////wCARgBF////AEBGgMb///8AwEYAv////wAgSQBF////AOBIgMP///8A4EjAR////wAARYDE////AAC+oMj///8AAD8AAP///wAAAODK////AGBJAEf///8AwMTASP///wAgSQAA////AEBEwMb///8AAEOAQ////wBASQC/////AAA+wEj///8AwEqAw////wAAvMBL////AODIAAD///8AQMoAQP///wAAPgBI////ACDIAAD///8AgMCARv///wCAyQAA////AEBFgMb///8AGqcCCqQCpHAZQqRwOcH///8AmFgAwP///wCAxwhU////AGDK4E3///8AwM1gyf///wAAv+DI////AKBLAMP///8AADpgyf///wCARgAA////AAA6YMv///8AQMgAAP///wAAvuDJ////AIBFYMj///8AQMyAwf///wAAtMDG////AGDLAL3///8AOMAMSP///wAkxgCu////AADC4Mj///8AAMNARv///wBgyQAA////AEDHgMP///8AwMeAQf///wAAAEBM////ACDJAAD///8AgMMAx////wAAyoBC////AAC9AMb///8AgMTARf///wCAwIDB////AABFAML///8AAMgANP///wBAxEBG////AADHAAD///8AAMFAyP///wBgyEDE////ABomCiSPQopCcT2DQv/AjQAAxAAA/+R0AAAAAMT/kwAAAEQAAP+bAAASEgoGTm9ybWFsEggKBk5vcm1hbA==" + this.worldModel = WorldModel.decode(Buffer.from(worldStr, "base64")) + + this.runtime = newRuntime(this.worldModel, "Normal") + + this.currentRotation = 0 + + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + const rocketPosition = rocket.components.rigidBody.translation() + + this.captureMessages = this.runtime.factoryContext.messageStore.collect("levelCaptured") + this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") + + this.nearestLevel = this.runtime.factoryContext.store + .find("level") + .filter(level => level.components.level.captured === false) + .sort( + (a, b) => + Math.abs(a.components.level.flag.x - rocketPosition.x) - + Math.abs(b.components.level.flag.y - rocketPosition.x), + )[0] + + const { distance } = this.state() + this.bestDistance = distance + } + + inputFromAction(action: number[]) { + const input = { + rotation: this.currentRotation + action[0], + thrust: action[1] > 0 ? true : false, + } + + return input + } + + step(action: number | number[]): [number[], number, boolean] { + if (typeof action === "number") { + throw new Error("Wrong action type") + } + + this.remainingTime-- + + const input = this.inputFromAction(action) + this.currentRotation += action[0] + + this.runtime.step(input) - if (gotRounded === expected) { - return 0 + const { distance, observation, velMag, angDiff } = this.state() + + let newTouch = false + + if (this.nearestLevel.components.level.inCapture) { + if (!this.touchedFlag) { + newTouch = true + } + + this.touchedFlag = true } - if (gotRounded === 0) { - return expected === -1 ? 1 : -1 + const captureMessage = [...this.captureMessages].at(-1) + + if (captureMessage) { + const reward = 10000 + (this.maxTime - this.remainingTime) * 100 + return [observation, reward, true] } - if (gotRounded === 1) { - return expected === 0 ? 
1 : -1 + const deathMessage = [...this.deathMessages].at(-1) + + if (deathMessage) { + const reward = -(velMag + angDiff) + return [observation, reward, true] } - return expected === 1 ? 1 : -1 + const reward = Math.max(0, this.bestDistance - distance) + this.bestDistance = Math.min(this.bestDistance, distance) + + const done = this.remainingTime <= 0 + + return [observation, reward * 10 + (newTouch ? 100 : 0), done] } - return (f() + 1) / 2 -} + state() { + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + + const rocketPosition = rocket.components.rigidBody.translation() + const rocketRotation = rocket.components.rigidBody.rotation() + const rocketVelocity = rocket.components.rigidBody.linvel() + + const dx = this.nearestLevel.components.level.flag.x - rocketPosition.x + const dy = this.nearestLevel.components.level.flag.y - rocketPosition.y + + const distanceToLevel = Math.sqrt(dx * dx + dy * dy) -const observationSize = 8 -const actionSize = 1 - -const observations = [ - [[-1, -1, -1, -1, -1, -1, -1, -1], [-1]], - [[0, 0, 0, 0, 0, 0, 0, 0], [0]], - [[1, 1, 1, 1, 1, 1, 1, 1], [1]], - [[-1, 0, 1, 0, -1, 0, 1, 0], [-1]], - [[0, 1, 0, -1, 0, 1, 0, -1], [0]], - [[1, 0, -1, 0, 1, 0, -1, 0], [1]], - [[-1, 1, -1, 1, -1, 1, -1, 1], [-1]], - [[1, -1, 1, -1, 1, -1, 1, -1], [1]], -] - -export class CartPole { - actionSpace = { - class: "Box", - shape: [1], - dtype: "float32", - low: [-1], - high: [1], + const angDiff = + (this.nearestLevel.components.level.flagRotation - + rocket.components.rigidBody.rotation()) % + (Math.PI * 2) + + const velMag = Math.sqrt( + rocketVelocity.x * rocketVelocity.x + rocketVelocity.y * rocketVelocity.y, + ) + + return { + distance: distanceToLevel, + observation: [ + this.nearestLevel.components.level.flag.x - rocketPosition.x, + this.nearestLevel.components.level.flag.y - rocketPosition.y, + rocketRotation, + rocketVelocity.x, + rocketVelocity.y, + ], + touched: this.touchedFlag, + angDiff, + velMag, + } } - observationSpace = { - class: "Box", - shape: [4], - dtype: "float32", + reset(): number[] { + this.runtime = newRuntime(this.worldModel, "Normal") + + this.currentRotation = 0 + + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + const rocketPosition = rocket.components.rigidBody.translation() + + this.captureMessages = this.runtime.factoryContext.messageStore.collect("levelCaptured") + this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") + + this.nearestLevel = this.runtime.factoryContext.store + .find("level") + .filter(level => level.components.level.captured === false) + .sort( + (a, b) => + Math.abs(a.components.level.flag.x - rocketPosition.x) - + Math.abs(b.components.level.flag.y - rocketPosition.x), + )[0] + + const { distance, observation } = this.state() + + this.bestDistance = distance + this.remainingTime = this.maxTime + this.touchedFlag = false + + return observation } +} +export class CartPole implements Environment { private gravity: number private massCart: number private massPole: number @@ -98,14 +223,21 @@ export class CartPole { return [this.x, this.xDot, this.theta, this.thetaDot] } + private i = 0 + private max = 0 + /** * Update the cart-pole system using an action. * @param {number} action Only the sign of `action` matters. * A value > 0 leads to a rightward force of a fixed magnitude. * A value <= 0 leads to a leftward force of the same fixed magnitude. 
*/ - step(action: number) { - let force = action * this.forceMag + step(action: number | number[]): [number[], number, boolean] { + if (Array.isArray(action)) { + action = action[0] + } + + const force = action * this.forceMag const cosTheta = Math.cos(this.theta) const sinTheta = Math.sin(this.theta) @@ -131,6 +263,7 @@ export class CartPole { * Set the state of the cart-pole system randomly. */ reset() { + this.i = 0 // The control-theory state variables of the cart-pole system. // Cart position, meters. this.x = Math.random() - 0.5 @@ -162,96 +295,399 @@ export class CartPole { } } -const tf = require("@tensorflow/tfjs-node-gpu") - -tf.setBackend("cpu").then(() => { - const env = new CartPole() - - const PPO = require("./ppo/base-ppo.js") - - const ppo = new PPO(env, { - nSteps: 2048, - nEpochs: 25, - policyLearningRate: 2e-3, - valueLearningRate: 2e-3, - clipRatio: 0.2, - targetKL: 0.01, - gamma: 0.99, - lam: 0.95, - netArch: { - pi: [32, 32], - vf: [32, 32], - }, - activation: "relu", - verbose: 1, +import "@tensorflow/tfjs-backend-webgl" +import "@tensorflow/tfjs-backend-webgpu" +import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic" + +if (true) { + tf.setBackend("cpu").then(() => { + const sac = new SoftActorCritic({ + mlpSpec: { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }, + + actionSize: 1, + observationSize: 4, + + maxEpisodeLength: 1000, + bufferSize: 10000, + batchSize: 100, + updateAfter: 10000, + updateEvery: 50, + + learningRate: 0.01, + alpha: 0.2, + gamma: 0.99, + polyak: 0.995, + }) + + sac.test() + + /* + const actor = new Actor(4, 2, { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }) + + actor.trainableWeights.forEach(w => { + w.write(tf.zeros(w.shape, w.dtype)) + }) + + /* + x = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32) + x = actor(x, True) + + const x = tf.tensor2d([[0.1, 0.2, 0.3, 0.4]]) + const r = actor.apply(x, { deterministic: true }) as tf.Tensor[] + + console.log(r[0].dataSync()) + console.log(r[1].dataSync()) + */ }) +} + +if (false) { + tf.setBackend("webgpu").then(() => { + const env = new CartPole() + + const sac = new SoftActorCritic({ + mlpSpec: { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }, + + actionSize: 1, + observationSize: 4, + + maxEpisodeLength: 1000, + bufferSize: 100000, + batchSize: 8096, + updateAfter: 8096, + updateEvery: 25, - function possibleLifetime() { - let acc = [] + learningRate: 1e-3, + alpha: 0.2, + gamma: 0.99, + polyak: 0.995, + }) - for (let j = 0; j < 100; ++j) { - env.reset() + function currentReward() { + const acc = [] - let t = 0 + for (let j = 0; j < 25; ++j) { + env.reset() - while (!env.isDone() && t < 1000) { - const [, action] = ppo.sampleAction(tf.tensor([env.getStateTensor()]), true) - env.step(action.arraySync()) + let t = 0 + + while (!env.isDone() && t < 1000) { + env.step(sac.act(env.getStateTensor(), true)) + t++ + } + + acc.push(t) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } + } + + let t = 0 + let k = 8 + + function iteration() { + for (let i = 0; i < 128; ++i) { t++ + + const observation = env.getStateTensor() + + let action: number[] + + if (t < 300) { + action = [Math.random()] + } else { + action = sac.act(observation, false) + } + + 
const [nextObservation, reward, done] = env.step(action) + + sac.observe({ + observation, + action, + reward, + nextObservation, + done, + }) + + if (done) { + env.reset() + } } - acc.push(t) + --k + console.log(k) + + if (k < 0) { + k = 10 + + const { avg, best10avg, worst10avg } = currentReward() + + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + } + + requestAnimationFrame(iteration) } - // average of top 10% lifetimes - acc.sort((a, b) => b - a) + console.log("Start") + requestAnimationFrame(iteration) + + /* + const ppo = new PPO( + { + steps: 512, + epochs: 15, + policyLearningRate: 1e-3, + valueLearningRate: 1e-3, + clipRatio: 0.1, + targetKL: 0.01, + gamma: 0.99, + lambda: 0.95, + observationDimension: 4, + actionSpace: { + class: "Discrete", + len: 2, + }, + }, + env, + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + ) + + function possibleLifetime() { + const acc = [] - const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 - const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 - const avg = acc.reduce((a, b) => a + b, 0) / acc.length + for (let j = 0; j < 25; ++j) { + env.reset() - return `10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}` - } + let t = 0 - console.log(possibleLifetime()) - ;(async () => { - for (let i = 0; i < 500; ++i) { - await ppo.learn({ - totalTimesteps: 1000 * i, - }) + while (!env.isDone() && t < 1000) { + env.step(ppo.act(env.getStateTensor()) as number[]) + t++ + } - console.log(possibleLifetime()) + acc.push(t) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } } - })().then(() => { - console.log(possibleLifetime()) - }) - /* -import { WorldModel } from "runtime/proto/world" -import { Game } from "./game/game" -import { GameLoop } from "./game/game-loop" -import { GameInstanceType, GameSettings } from "./game/game-settings" -import * as tf from '@tensorflow/tfjs'; + let currentAverage = 0 + let i = 0 -function base64ToBytes(base64: string) { - return Uint8Array.from(atob(base64), c => c.charCodeAt(0)) -} + function iteration() { + ppo.learn(512 * i) + + const { avg, best10avg, worst10avg } = possibleLifetime() + + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + + if (avg > currentAverage) { + // await ppo.save() + currentAverage = avg + console.log("Saved") + } + + i++ + + requestAnimationFrame(iteration) + } + + console.log("Initial: ", possibleLifetime()) -const world = - 
"CscCCgZOb3JtYWwSvAIKCg2F65XBFTXTGkISKA2kcLrBFZfjFkIlAAAAwi1SuIlCNa5H+UE9H4X/QUUAAABATQAAAEASKA1SuMFBFZmRGkIlhetRQS3NzFJCNSlcp0I9zcxEQUUAAABATQAAAEASKA0AgEVCFfIboEElAAAoQi0K189BNaRw4UI9rkdZwUUAAABATQAAAEASKA171MBCFcubHcElmpm5Qi0K189BNY/CI0M9rkdZwUUAAABATQAAAEASLQ1syOFCFToytkEdVGuzOiWamblCLSlcZUI1XI8jQz3NzIhBRQAAAEBNAAAAQBItDR/lAUMVk9VNQh2fUDa1JaRw9UItexRsQjWF60FDPQAAlEFFAAAAQE0AAABAEigNw1UzQxVpqkFCJdejJEMtBW94QjXXo0JDPQVvAEJFAAAAQE0AAABACu4KCg1Ob3JtYWwgU2hhcGVzEtwKGt8GCtwGP4UAws3MNEGgEEAAZjYAAP///wB1PAAU////AF5PABT///8AyUtPxP///wAzSg3L////AMBJAcj///8AE0Umzf///wCMVAo5////AJNRpDr///8AVE0WVP///wD0vlZLAAD/AEPI7Bn///8AhcPlOAAA/wAFQZrF////ADS9F8f///8AJMIuwf///wC5xvvF////AOrJ1rf///8Ac8ikQP///wBAxfRF////AGkxi0n///8Aj0LxQgAA/wB1xWY9////AJ/HZAlQUP4AzcUBvQAA/wDwQFzE////ADDGR73///8As8eZPoiI8QBxxWQ3rKz/AFw3LMQAAP8AwkNRtP///wC2RKO4////AEhBe8EAAP8AS0WPPP///wAdSaSx////AMw/Ucj///8A7MBNxv///wDmxnG9////AELCFLr///8Aw8UOof///wAKxCg4AAD/ALg8OMDZ2fsA4j9NwP///wCkxB+/AADwAHGwrr54ePgAVERcwv///wAPwXbA////APW0H0EAAPgASLtnv////wALM67DJSX/AFJApL////8AZj4uwP///wBcu+HATU3/AIU7+8H///8AXMK8Lf///wB7wjM/AAD4AHDCx8D///8AFEH7wP///wAAvnvE////AOTGChL///8A6bncRP///wCAQddAAAD4AB/AxLH///8AIL9RPQAA+ACZwqvG////AOLCLkQAAPgAIcTrwP///wDtwQPH////AOLJbqz///8ALsR6QwAA+AD+x8zA////APtF90kyMv8AH7mZQCcn/wCNxHo8tbX/AIDAiETKyv8AXEAgSgAA+AClyAqS////AH9EG0n///8AS0ypRP///wAxSIK7MDToANjBdUf///8A58yjxP///wCByD1EMDToAIzCYMv///8AnMq3MzA06AC+QenF////ANzGT0T///8AtMFSR////wBzRb85lpj/AFJALEQwNOgAqMIpPjA06AAgyiCF////AAPEE77///8AzT4FSnN1/wAzxWFCMDToAA23PcKXl/8AGcLmQDA06ADMPUnJu77/AFrGxsL///8A1TRGSjA06ACKwik8MDToAE3Apcn///8Ar8SawP///wBsygqP////ABHI8z0wNOgAAABTzv///wAa9wMK9APNzJNCj8JlQP///wBmtly8////ABa2jsg2Nv8AO0SENwAA+ACkvrtEvLz/AG0uOEX///8A4UaHPv///wA+QlXFAAD4AApB2L4AAPgAeDLVRP///wATSHHAAAD4ADhA3EP///8As0MKvAAA8ADOPxM4AAD4AEjBTUD///8Arj5TP3B0+ACyKw9DaGz4ALm6eDz///8AKT4MSP///wDhPy5CAAD/APS/XEL///8A+EV6PwAA/wAdsXtBp6f/AGzEpEEAAP8AisfEuf///wDXwVJI////AJpEaUf///8AhUfxQP///wB7RA3FAAD/ANdBTzUAAP8AC8C9Rv///wBGQoVE////APRMpDz///8A7kS3yAAA/wDLR9HB////AFLHNscAAP8AR0HNwf///wDsvtLGAAD/AABE5kD///8AD0JIRv///wD0RNJA////AEVFqcD///8A3ESpwwAA/wAuwgtJ////AARBqEj///8ALUdbSf///wA01Hks////AHjCAL3///8AF8s5x////wC4vlPP////AME1O8f///8AhsIAPgAA+ABcxZXC7e3/AIrEpUMAAPgAjcbDxcvL/wBdQFzF////AEjI+8EAAOAAQ0GZvf///wAGN77AFRX/APlFXDz///8AikEzwkhI+ADcQmoy////AArNAgoHUmV2ZXJzZRLBAgoPDRydLkMVk5lFQh2z7Zk2EigNpHC6wRWX4xZCJQAAAMItAABMQjUAAEDBPR+F/0FFAAAAQE0AAABAEigNUrjBQRWZkRpCJR+FAMItZuaJQjUAAPpBPQAAAEJFAAAAQE0AAABAEigNAIBFQhXyG6BBJQAAUEEthetRQjWkcKdCPVK4TkFFAAAAQE0AAABAEigNe9TAQhXLmx3BJTQzKEItCtfPQTUeBeJCPa5HWcFFAAAAQE0AAABAEi0NbMjhQhU6MrZBHVRrszolmpm5Qi1SuNRBNVyPI0M9ZmZawUUAAABATQAAAEASLQ0f5QFDFZPVTUIdn1A2tSWk8LlCLXsUZUI1hSskQz0AAIZBRQAAAEBNAAAAQBIoDcNVM0MVaapBQiUAgPVCLQAAbEI1AABCQz0AAJRBRQAAAEBNAAAAQBIhCgZOb3JtYWwSFwoNTm9ybWFsIFNoYXBlcwoGTm9ybWFsEiMKB1JldmVyc2USGAoNTm9ybWFsIFNoYXBlcwoHUmV2ZXJzZQ==" -const worldModel = WorldModel.decode(base64ToBytes(world)) + console.log("Start") + requestAnimationFrame(iteration) -const settings: GameSettings = { - instanceType: GameInstanceType.Play, - world: worldModel, - gamemode: "Normal", + */ + }) } -try { - const loop = new GameLoop(new Game(settings)) - loop.start() -} catch (e) { - console.error(e) +if (false) { + tf.setBackend("cpu").then(() => { + const env = new PolyburnEnvironment() + + const inputDim = 5 + + const ppo = new PPO( + { + steps: 512, + epochs: 15, + policyLearningRate: 1e-3, + valueLearningRate: 1e-3, + clipRatio: 0.2, + targetKL: 0.01, + gamma: 0.99, + lambda: 0.95, + observationDimension: inputDim, + actionSpace: { + class: 
"Box", + len: 2, + low: -1, + high: 1, + }, + }, + env, + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: inputDim, + units: 64, + activation: "relu", + }), + tf.layers.dense({ + units: 64, + activation: "relu", + }), + ], + }), + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: inputDim, + units: 64, + activation: "relu", + }), + tf.layers.dense({ + units: 64, + activation: "relu", + }), + ], + }), + ) + + function possibleLifetime() { + let observation = env.reset() + + let totalReward = 0 + const inputs = [] + + while (true) { + const action = ppo.act(observation) + inputs.push(env.inputFromAction(action as number[])) + + const [nextObservation, reward, done] = env.step(action) + + totalReward += reward + observation = nextObservation + + if (done) { + break + } + } + + return { + totalReward, + touched: env.state().touched, + distance: env.state().distance, + inputs, + } + } + + let currentAverage = 0 + let i = 0 + + const previousTwenty: number[] = [] + + function iteration() { + ppo.learn(512 * i) + const info = possibleLifetime() + + console.log( + `Reward ${i}: reward(${info.totalReward}), distance(${info.distance}), touched(${info.touched})`, + ) + + if (info.totalReward > currentAverage && previousTwenty.length === 20) { + currentAverage = info.totalReward + console.log("Saved") + ppo.save() + } + + if (previousTwenty.length === 20) { + previousTwenty.shift() + } + + previousTwenty.push(info.totalReward) + + const avgPreviousTwenty = + previousTwenty.reduce((a, b) => a + b, 0) / previousTwenty.length + + ++i + + if ( + avgPreviousTwenty < 50 && + avgPreviousTwenty < Math.max(currentAverage, 10) * 0.5 && + previousTwenty.length === 20 + ) { + console.log("Restoring") + + ppo.restore().finally(() => { + requestAnimationFrame(iteration) + }) + } else { + requestAnimationFrame(iteration) + } + } + + ppo.restore().finally(() => { + const { totalReward, inputs } = possibleLifetime() + currentAverage = totalReward + + console.log(JSON.stringify(inputs)) + + console.log("Start with: ", currentAverage) + requestAnimationFrame(iteration) + }) + }) } -*/ -}) diff --git a/packages/learning/src/old-main.ts b/packages/learning/src/old-main.ts new file mode 100644 index 00000000..0409c6cc --- /dev/null +++ b/packages/learning/src/old-main.ts @@ -0,0 +1,283 @@ +import { Environment, PPO } from "./ppo/ppo" + +function getReward(got: number, expected: number) { + function f() { + const gotRounded = Math.round(got) + + if (gotRounded === expected) { + return 0 + } + + if (gotRounded === 0) { + return expected === -1 ? 1 : -1 + } + + if (gotRounded === 1) { + return expected === 0 ? 1 : -1 + } + + return expected === 1 ? 
1 : -1 + } + + return (f() + 1) / 2 +} + +const observationSize = 8 +const actionSize = 1 + +const observations = [ + [[-1, -1, -1, -1, -1, -1, -1, -1], [-1]], + [[0, 0, 0, 0, 0, 0, 0, 0], [0]], + [[1, 1, 1, 1, 1, 1, 1, 1], [1]], + [[-1, 0, 1, 0, -1, 0, 1, 0], [-1]], + [[0, 1, 0, -1, 0, 1, 0, -1], [0]], + [[1, 0, -1, 0, 1, 0, -1, 0], [1]], + [[-1, 1, -1, 1, -1, 1, -1, 1], [-1]], + [[1, -1, 1, -1, 1, -1, 1, -1], [1]], +] + +export class CartPole implements Environment { + private gravity: number + private massCart: number + private massPole: number + private totalMass: number + private cartWidth: number + private cartHeight: number + private length: number + private poleMoment: number + private forceMag: number + private tau: number + + private xThreshold: number + private thetaThreshold: number + + private x: number = 0 + private xDot: number = 0 + private theta: number = 0 + private thetaDot: number = 0 + + /** + * Constructor of CartPole. + */ + constructor() { + // Constants that characterize the system. + this.gravity = 9.8 + this.massCart = 1.0 + this.massPole = 0.1 + this.totalMass = this.massCart + this.massPole + this.cartWidth = 0.2 + this.cartHeight = 0.1 + this.length = 0.5 + this.poleMoment = this.massPole * this.length + this.forceMag = 10.0 + this.tau = 0.02 // Seconds between state updates. + + // Threshold values, beyond which a simulation will be marked as failed. + this.xThreshold = 2.4 + this.thetaThreshold = (12 / 360) * 2 * Math.PI + + this.reset() + } + + /** + * Get current state as a tf.Tensor of shape [1, 4]. + */ + getStateTensor() { + return [this.x, this.xDot, this.theta, this.thetaDot] + } + + /** + * Update the cart-pole system using an action. + * @param {number} action Only the sign of `action` matters. + * A value > 0 leads to a rightward force of a fixed magnitude. + * A value <= 0 leads to a leftward force of the same fixed magnitude. + */ + step(action: number | number[]): [number[], number, boolean] { + if (Array.isArray(action)) { + action = action[0] + } + + const force = action === 0 ? this.forceMag : -this.forceMag + + const cosTheta = Math.cos(this.theta) + const sinTheta = Math.sin(this.theta) + + const temp = + (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass + const thetaAcc = + (this.gravity * sinTheta - cosTheta * temp) / + (this.length * (4 / 3 - (this.massPole * cosTheta * cosTheta) / this.totalMass)) + const xAcc = temp - (this.poleMoment * thetaAcc * cosTheta) / this.totalMass + + // Update the four state variables, using Euler's method. + this.x += this.tau * this.xDot + this.xDot += this.tau * xAcc + this.theta += this.tau * this.thetaDot + this.thetaDot += this.tau * thetaAcc + + const reward = this.isDone() ? -100 : 1 + return [this.getStateTensor(), reward, this.isDone()] + } + + /** + * Set the state of the cart-pole system randomly. + */ + reset() { + // The control-theory state variables of the cart-pole system. + // Cart position, meters. + this.x = Math.random() - 0.5 + // Cart velocity. + this.xDot = (Math.random() - 0.5) * 1 + // Pole angle, radians. + this.theta = (Math.random() - 0.5) * 2 * ((6 / 360) * 2 * Math.PI) + // Pole angle velocity. + this.thetaDot = (Math.random() - 0.5) * 0.5 + + return this.getStateTensor() + } + + /** + * Determine whether this simulation is done. + * + * A simulation is done when `x` (position of the cart) goes out of bound + * or when `theta` (angle of the pole) goes out of bound. + * + * @returns {bool} Whether the simulation is done. 
+ */ + isDone() { + return ( + this.x < -this.xThreshold || + this.x > this.xThreshold || + this.theta < -this.thetaThreshold || + this.theta > this.thetaThreshold + ) + } +} + +import * as tf from "@tensorflow/tfjs" + +tf.setBackend("cpu").then(() => { + const env = new CartPole() + + const ppo = new PPO( + { + steps: 2048, + epochs: 15, + policyLearningRate: 1e-3, + valueLearningRate: 1e-3, + clipRatio: 0.1, + targetKL: 0.01, + gamma: 0.99, + lambda: 0.95, + observationDimension: 4, + actionSpace: { + class: "Discrete", + len: 2, + }, + }, + env, + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + ) + + function possibleLifetime() { + const acc = [] + + for (let j = 0; j < 100; ++j) { + env.reset() + + let t = 0 + + while (!env.isDone() && t < 1000) { + env.step(ppo.act(env.getStateTensor())) + t++ + } + + acc.push(t) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } + } + + ;(async () => { + // await ppo.restore() + let currentAverage = possibleLifetime().avg + + for (let i = 0; i < 500; ++i) { + ppo.learn(1000 * i) + + const { avg, best10avg, worst10avg } = possibleLifetime() + + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + + if (avg > currentAverage) { + // await ppo.save() + currentAverage = avg + console.log("Saved") + } + } + })().then(() => { + console.log(possibleLifetime()) + }) + + /* +import { WorldModel } from "runtime/proto/world" +import { Game } from "./game/game" +import { GameLoop } from "./game/game-loop" +import { GameInstanceType, GameSettings } from "./game/game-settings" +import * as tf from '@tensorflow/tfjs'; + +function base64ToBytes(base64: string) { + return Uint8Array.from(atob(base64), c => c.charCodeAt(0)) +} + +const world = + 
"CscCCgZOb3JtYWwSvAIKCg2F65XBFTXTGkISKA2kcLrBFZfjFkIlAAAAwi1SuIlCNa5H+UE9H4X/QUUAAABATQAAAEASKA1SuMFBFZmRGkIlhetRQS3NzFJCNSlcp0I9zcxEQUUAAABATQAAAEASKA0AgEVCFfIboEElAAAoQi0K189BNaRw4UI9rkdZwUUAAABATQAAAEASKA171MBCFcubHcElmpm5Qi0K189BNY/CI0M9rkdZwUUAAABATQAAAEASLQ1syOFCFToytkEdVGuzOiWamblCLSlcZUI1XI8jQz3NzIhBRQAAAEBNAAAAQBItDR/lAUMVk9VNQh2fUDa1JaRw9UItexRsQjWF60FDPQAAlEFFAAAAQE0AAABAEigNw1UzQxVpqkFCJdejJEMtBW94QjXXo0JDPQVvAEJFAAAAQE0AAABACu4KCg1Ob3JtYWwgU2hhcGVzEtwKGt8GCtwGP4UAws3MNEGgEEAAZjYAAP///wB1PAAU////AF5PABT///8AyUtPxP///wAzSg3L////AMBJAcj///8AE0Umzf///wCMVAo5////AJNRpDr///8AVE0WVP///wD0vlZLAAD/AEPI7Bn///8AhcPlOAAA/wAFQZrF////ADS9F8f///8AJMIuwf///wC5xvvF////AOrJ1rf///8Ac8ikQP///wBAxfRF////AGkxi0n///8Aj0LxQgAA/wB1xWY9////AJ/HZAlQUP4AzcUBvQAA/wDwQFzE////ADDGR73///8As8eZPoiI8QBxxWQ3rKz/AFw3LMQAAP8AwkNRtP///wC2RKO4////AEhBe8EAAP8AS0WPPP///wAdSaSx////AMw/Ucj///8A7MBNxv///wDmxnG9////AELCFLr///8Aw8UOof///wAKxCg4AAD/ALg8OMDZ2fsA4j9NwP///wCkxB+/AADwAHGwrr54ePgAVERcwv///wAPwXbA////APW0H0EAAPgASLtnv////wALM67DJSX/AFJApL////8AZj4uwP///wBcu+HATU3/AIU7+8H///8AXMK8Lf///wB7wjM/AAD4AHDCx8D///8AFEH7wP///wAAvnvE////AOTGChL///8A6bncRP///wCAQddAAAD4AB/AxLH///8AIL9RPQAA+ACZwqvG////AOLCLkQAAPgAIcTrwP///wDtwQPH////AOLJbqz///8ALsR6QwAA+AD+x8zA////APtF90kyMv8AH7mZQCcn/wCNxHo8tbX/AIDAiETKyv8AXEAgSgAA+AClyAqS////AH9EG0n///8AS0ypRP///wAxSIK7MDToANjBdUf///8A58yjxP///wCByD1EMDToAIzCYMv///8AnMq3MzA06AC+QenF////ANzGT0T///8AtMFSR////wBzRb85lpj/AFJALEQwNOgAqMIpPjA06AAgyiCF////AAPEE77///8AzT4FSnN1/wAzxWFCMDToAA23PcKXl/8AGcLmQDA06ADMPUnJu77/AFrGxsL///8A1TRGSjA06ACKwik8MDToAE3Apcn///8Ar8SawP///wBsygqP////ABHI8z0wNOgAAABTzv///wAa9wMK9APNzJNCj8JlQP///wBmtly8////ABa2jsg2Nv8AO0SENwAA+ACkvrtEvLz/AG0uOEX///8A4UaHPv///wA+QlXFAAD4AApB2L4AAPgAeDLVRP///wATSHHAAAD4ADhA3EP///8As0MKvAAA8ADOPxM4AAD4AEjBTUD///8Arj5TP3B0+ACyKw9DaGz4ALm6eDz///8AKT4MSP///wDhPy5CAAD/APS/XEL///8A+EV6PwAA/wAdsXtBp6f/AGzEpEEAAP8AisfEuf///wDXwVJI////AJpEaUf///8AhUfxQP///wB7RA3FAAD/ANdBTzUAAP8AC8C9Rv///wBGQoVE////APRMpDz///8A7kS3yAAA/wDLR9HB////AFLHNscAAP8AR0HNwf///wDsvtLGAAD/AABE5kD///8AD0JIRv///wD0RNJA////AEVFqcD///8A3ESpwwAA/wAuwgtJ////AARBqEj///8ALUdbSf///wA01Hks////AHjCAL3///8AF8s5x////wC4vlPP////AME1O8f///8AhsIAPgAA+ABcxZXC7e3/AIrEpUMAAPgAjcbDxcvL/wBdQFzF////AEjI+8EAAOAAQ0GZvf///wAGN77AFRX/APlFXDz///8AikEzwkhI+ADcQmoy////AArNAgoHUmV2ZXJzZRLBAgoPDRydLkMVk5lFQh2z7Zk2EigNpHC6wRWX4xZCJQAAAMItAABMQjUAAEDBPR+F/0FFAAAAQE0AAABAEigNUrjBQRWZkRpCJR+FAMItZuaJQjUAAPpBPQAAAEJFAAAAQE0AAABAEigNAIBFQhXyG6BBJQAAUEEthetRQjWkcKdCPVK4TkFFAAAAQE0AAABAEigNe9TAQhXLmx3BJTQzKEItCtfPQTUeBeJCPa5HWcFFAAAAQE0AAABAEi0NbMjhQhU6MrZBHVRrszolmpm5Qi1SuNRBNVyPI0M9ZmZawUUAAABATQAAAEASLQ0f5QFDFZPVTUIdn1A2tSWk8LlCLXsUZUI1hSskQz0AAIZBRQAAAEBNAAAAQBIoDcNVM0MVaapBQiUAgPVCLQAAbEI1AABCQz0AAJRBRQAAAEBNAAAAQBIhCgZOb3JtYWwSFwoNTm9ybWFsIFNoYXBlcwoGTm9ybWFsEiMKB1JldmVyc2USGAoNTm9ybWFsIFNoYXBlcwoHUmV2ZXJzZQ==" +const worldModel = WorldModel.decode(base64ToBytes(world)) + +const settings: GameSettings = { + instanceType: GameInstanceType.Play, + world: worldModel, + gamemode: "Normal", +} + +try { + const loop = new GameLoop(new Game(settings)) + loop.start() +} catch (e) { + console.error(e) +} +*/ +}) diff --git a/packages/learning/src/ppo/ppo.ts b/packages/learning/src/ppo/ppo.ts index 71188139..525594cc 100644 --- a/packages/learning/src/ppo/ppo.ts +++ b/packages/learning/src/ppo/ppo.ts @@ -5,7 +5,7 @@ class ReplayBuffer { private lambda: number private observationBuffer: number[][] = [] - private actionBuffer: number[][] = [] + private actionBuffer: (number | number[])[] = [] private advantageBuffer: 
number[] = [] private rewardBuffer: number[] = [] private returnBuffer: number[] = [] @@ -24,7 +24,7 @@ class ReplayBuffer { add( observation: number[], - action: number[], + action: number | number[], reward: number, criticPrediction: number, logProbability: number, @@ -83,13 +83,13 @@ class ReplayBuffer { advantage => (advantage - advantageMean) / advantageStd, ) - return [ - this.observationBuffer, - this.actionBuffer, - this.advantageBuffer, - this.returnBuffer, - this.logProbabilityBuffer, - ] + return { + observationBuffer: this.observationBuffer, + actionBuffer: this.actionBuffer, + advantageBuffer: this.advantageBuffer, + returnBuffer: this.returnBuffer, + logProbabilityBuffer: this.logProbabilityBuffer, + } } reset() { @@ -108,14 +108,11 @@ class ReplayBuffer { interface DiscreteSpace { class: "Discrete" - dtype?: "int32" - len: number } interface BoxSpace { class: "Box" - dtype?: "float32" low: number high: number @@ -141,33 +138,12 @@ interface PPOConfig { actionSpace: Space } -interface Environment { +export interface Environment { reset(): number[] step(action: number | number[]): [number[], number, boolean] } -const ppo = new PPO( - {} as PPOConfig, - {} as Space, - [ - { - class: "Box", - len: 2, - low: [0, 0], - high: [1, 1], - }, - { - class: "Discrete", - len: 2, - }, - ], - {} as tf.LayersModel, - {} as tf.LayersModel, -) - -ppo.act([1, 2, 3]) - -class PPO { +export class PPO { private numTimeSteps: number private lastObservation: number[] @@ -186,39 +162,26 @@ class PPO { private env: Environment, - private actorModel: tf.LayersModel, - private criticModel: tf.LayersModel, + actorModel: tf.LayersModel, + criticModel: tf.LayersModel, ) { this.numTimeSteps = 0 this.lastObservation = [] this.buffer = new ReplayBuffer(config.gamma, config.lambda) - if (config.actionSpace.class === "Discrete") { - this.actor = tf.sequential({ - layers: [ - actorModel, - tf.layers.dense({ - units: config.actionSpace.len, - }), - ], - }) - } else if (config.actionSpace.class === "Box") { - this.actor = tf.sequential({ - layers: [ - actorModel, - tf.layers.dense({ - units: config.actionSpace.len, - }), - ], - }) - } else { - throw new Error("Unsupported action space") - } + this.actor = tf.sequential({ + layers: [ + actorModel, + tf.layers.dense({ + units: config.actionSpace.len, + }), + ], + }) this.critic = tf.sequential({ layers: [ - actorModel, + criticModel, tf.layers.dense({ units: 1, activation: "linear", @@ -234,29 +197,112 @@ class PPO { this.optimizerValue = tf.train.adam(config.valueLearningRate) } - act(observation: number[]): GetPPOSpaceType {} + async save() { + await this.actor.save("localstorage://actor") + await this.critic.save("localstorage://critic") + } + + async restore() { + this.actor = await tf.loadLayersModel("localstorage://actor") + this.critic = await tf.loadLayersModel("localstorage://critic") + } + + act(observation: number[]): number | number[] { + return tf.tidy(() => { + const [, , actionSynced] = this.sampleAction(tf.tensor([observation])) + return actionSynced + }) + } + + learn(upToTimesteps: number) { + while (this.numTimeSteps < upToTimesteps) { + this.collectRollouts() + this.train() + } + } + + private train() { + const batch = this.buffer.get() + + tf.tidy(() => { + const observationBuffer = tf.tensor2d(batch.observationBuffer) + + const actionBuffer = tf.tensor( + batch.actionBuffer, + undefined, + this.config.actionSpace.class === "Discrete" ? 
"int32" : "float32", + ) + + const advantageBuffer = tf.tensor1d(batch.advantageBuffer) + const returnBuffer = tf.tensor1d(batch.returnBuffer).reshape([-1, 1]) as tf.Tensor1D + const logProbabilityBuffer = tf.tensor1d(batch.logProbabilityBuffer) + + for (let epoch = 0; epoch < this.config.epochs; ++epoch) { + const kl = this.trainPolicy( + observationBuffer, + actionBuffer, + logProbabilityBuffer, + advantageBuffer, + ) + + if (kl > 1.5 * this.config.targetKL) { + break + } + } + + for (let epoch = 0; epoch < this.config.epochs; ++epoch) { + this.trainValue(observationBuffer, returnBuffer) + } + }) + } private collectRollouts() { - this.buffer.reset() + if (this.lastObservation.length === 0) { + this.lastObservation = this.env.reset() + } - let sumReturn = 0 - let sumReward = 0 - let numEpisodes = 0 + this.buffer.reset() for (let step = 0; step < this.config.steps; ++step) { tf.tidy(() => { - const observation = tf.tensor2d(this.lastObservation) + const observation = tf.tensor([this.lastObservation]) as tf.Tensor2D - const [predictions, action, actionSynced] = this.sampleAction(observation) - const value = this.critic.predict(observation) as tf.Tensor1D + const [predictions, action, actionClipped] = this.sampleAction(observation) + const value = this.critic.predict(observation) as tf.Tensor2D + const valueSynced = value.arraySync()[0][0] + const actionSynced = action.arraySync() // TODO verify types - const logProbability = this.logProb(predictions as any, action as any) + const logProbability = this.logProb(predictions, action) + const logProbabilitySynced = logProbability.arraySync() + + const [nextObservation, reward, done] = this.env.step(actionClipped) + this.numTimeSteps++ + + this.buffer.add( + this.lastObservation, + actionSynced, + reward, + valueSynced, + logProbabilitySynced, + ) + + this.lastObservation = nextObservation - const [nextObservation, reward, done] = this.env.step(actionSynced) + if (done || step === this.config.steps - 1) { + let lastValue = 0 - sumReturn += reward - sumReward += reward + if (!done) { + const lastValueTensor = this.critic.predict( + tf.tensor([nextObservation]), + ) as tf.Tensor2D + + lastValue = lastValueTensor.arraySync()[0][0] + } + + this.buffer.finishTrajectory(lastValue) + this.lastObservation = this.env.reset() + } }) } } @@ -275,7 +321,7 @@ class PPO { private trainPolicy( observationBuffer: tf.Tensor2D, - actionBuffer: tf.Tensor2D, + actionBuffer: tf.Tensor, logProbabilityBuffer: tf.Tensor1D, advantageBuffer: tf.Tensor1D, ) { @@ -314,13 +360,13 @@ class PPO { actionBuffer, ), ), - ) + ) as tf.Scalar return kl.arraySync() }) } - private logProb(predictions: tf.Tensor2D, actions: tf.Tensor2D) { + private logProb(predictions: tf.Tensor2D, actions: tf.Tensor) { if (this.config.actionSpace.class === "Discrete") { return this.logProbCategorical(predictions, actions) } else if (this.config.actionSpace.class === "Box") { @@ -330,7 +376,7 @@ class PPO { } } - private logProbCategorical(predictions: tf.Tensor2D, actions: tf.Tensor2D) { + private logProbCategorical(predictions: tf.Tensor2D, actions: tf.Tensor) { return tf.tidy(() => { const numActions = predictions.shape[predictions.shape.length - 1] const logprobabilitiesAll = tf.logSoftmax(predictions) @@ -338,11 +384,11 @@ class PPO { return tf.sum( tf.mul(tf.oneHot(actions, numActions), logprobabilitiesAll), logprobabilitiesAll.shape.length - 1, - ) + ) as tf.Scalar }) } - private logProbNormal(predictions: tf.Tensor2D, actions: tf.Tensor2D) { + private logProbNormal(predictions: tf.Tensor2D, 
actions: tf.Tensor) {
         return tf.tidy(() => {
             if (this.logStd === undefined) {
                 throw new Error("logStd is not initialized")
             }
@@ -360,7 +406,7 @@
             return tf.sum(
                 tf.sub(logUnnormalized, logNormalization),
                 logUnnormalized.shape.length - 1,
-            )
+            ) as tf.Scalar
         })
     }
@@ -368,7 +414,7 @@
         return tf.tidy(() => {
             const predictions = tf.squeeze(
                 this.actor.predict(observation) as tf.Tensor2D,
-            ) as tf.Tensor1D
+            ) as tf.Tensor2D
 
             const actionSpace = this.config.actionSpace
diff --git a/packages/learning/src/soft-actor-critic.zip b/packages/learning/src/soft-actor-critic.zip
new file mode 100644
index 0000000000000000000000000000000000000000..6b6bd02102d9216518e4c90bb166c04381a86541
GIT binary patch
literal 5275
[... base85-encoded binary archive data omitted; not human-readable ...]

literal 0
HcmV?d00001

diff --git a/packages/learning/src/soft-actor-critic/actor.ts b/packages/learning/src/soft-actor-critic/actor.ts
index 044939d1..df17e8b9 100644
--- a/packages/learning/src/soft-actor-critic/actor.ts
+++ b/packages/learning/src/soft-actor-critic/actor.ts
@@ -1,73 +1,99 @@
-import * as tf from "@tensorflow/tfjs-node-gpu"
-import { GaussianLikelihood } from "./gaussian-likelihood"
+import * as tf from "@tensorflow/tfjs"
+import { Kwargs } from "@tensorflow/tfjs-layers/dist/types"
 import { MlpSpecification, mlp } from "./mlp"
 
 const LOG_STD_MIN = -20
 const LOG_STD_MAX = 2
 
-export class Actor extends tf.layers.Layer {
-    private gaussianLikelihood: tf.layers.Layer
+export function actor(observationSize: number, actionSize: number, mlpSpec: MlpSpecification) {
+    const net = mlp({
+        ...mlpSpec,
+        sizes: [observationSize, ...mlpSpec.sizes],
+    })
 
-    private net: tf.Sequential
-    private meanLayer: tf.layers.Layer
-    private stdevLayer: tf.layers.Layer
+    const meanLayer = tf.layers.dense({
+        inputDim: mlpSpec.sizes[mlpSpec.sizes.length - 1],
+        units: actionSize,
+        activation: "linear",
+    })
 
-    constructor(observationSize: number, actionSize: number, mlpSpec: MlpSpecification) {
-        super()
+    const stdevLayer = tf.layers.dense({
+        inputDim: mlpSpec.sizes[mlpSpec.sizes.length - 1],
+        units: actionSize,
+        activation: "linear",
+    })
 
-        this.net = mlp({
-            ...mlpSpec,
-            sizes: [observationSize, ...mlpSpec.sizes],
-        })
+    const innerActorLayer = new InnerActorLayer()
 
-        this.net.predict
+    const input = tf.input({ shape: [observationSize] })
+    const netResult = net.apply(input) as tf.SymbolicTensor
 
-        this.meanLayer = tf.layers.dense({
-            units: actionSize,
-        })
+    const mu = meanLayer.apply(netResult) as tf.SymbolicTensor
+    const logstd = stdevLayer.apply(netResult) as tf.SymbolicTensor
 
-        this.stdevLayer = tf.layers.dense({
-            units: 
actionSize,
-        })

+    const [action, logprob] = innerActorLayer.apply([mu, logstd]) as tf.SymbolicTensor[]

-        this.gaussianLikelihood = new GaussianLikelihood()
+    return tf.model({ inputs: [input], outputs: [action, logprob] })
+}
+
+class InnerActorLayer extends tf.layers.Layer {
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+        // Inputs are [mu, logstd], each [batch, actionSize]; the outputs are the
+        // squashed action [batch, actionSize] and its log-probability, which call()
+        // reduces over the action dimension to shape [batch].
+        const [muShape] = inputShape as tf.Shape[]
+        return [muShape, [muShape[0]]]
     }

-    call(x: tf.Tensor): tf.Tensor[] {
-        x = this.net.apply(x) as tf.Tensor
-        const mu = this.meanLayer.apply(x) as tf.Tensor
+    call([mu, logstd]: tf.Tensor[], kwargs: Kwargs): tf.Tensor[] {
+        const logstdClipped = tf.clipByValue(logstd, LOG_STD_MIN, LOG_STD_MAX)
+        const sigma = tf.exp(logstdClipped)

-        let logSigma = this.stdevLayer.apply(x) as tf.Tensor
-        logSigma = tf.clipByValue(logSigma, LOG_STD_MIN, LOG_STD_MAX)
-        const sigma = tf.exp(logSigma)
+        let action: tf.Tensor

-        let action = tf.mul(tf.randomNormal(mu.shape), sigma)
-        action = tf.tanh(action)
+        if (kwargs.deterministic) {
+            action = mu
+        } else {
+            action = tf.add(
+                mu,
+                tf.mul(sigma, tf.randomNormal(mu.shape, 0, 1, "float32") as tf.Tensor),
+            )
+        }

-        let logpPi = this.gaussianLikelihood.apply([action, mu, logSigma]) as tf.Tensor
+        let actionLogProb = this.logProb(action, mu, sigma).sum(-1)

-        logpPi = tf.sub(
-            logpPi,
+        actionLogProb = actionLogProb.sub(
             tf.sum(
-                tf.mul(2, tf.sub(tf.sub(Math.log(2), action), tf.softplus(tf.mul(-2, action)))),
+                tf.mul(2, tf.sub(Math.log(2), tf.add(action, tf.softplus(tf.mul(-2, action))))),
                 1,
             ),
         )

-        return [action, logpPi]
+        action = tf.tanh(action)
+
+        return [action, actionLogProb]
     }

-    get trainableWeights(): tf.LayerVariable[] {
-        return [
-            ...this.net.trainableWeights,
-            ...this.meanLayer.trainableWeights,
-            ...this.stdevLayer.trainableWeights,
-        ]
+    private logProb(x: tf.Tensor, mu: tf.Tensor, sigma: tf.Tensor) {
+        /*
+        var = self.scale**2
+        log_scale = (
+            math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
+        )
+        return (
+            -((value - self.loc) ** 2) / (2 * var)
+            - log_scale
+            - math.log(math.sqrt(2 * math.pi))
+        )
+        */
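+        // The snippet above is torch.distributions.Normal.log_prob; in closed form:
+        //   log N(x; mu, sigma) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi))
+        // call() sums this over the action dimension and then subtracts the stable
+        // tanh correction 2 * (log 2 - u - softplus(-2u)) before squashing u.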
+
+        const variance = tf.square(sigma)
+        const logScale = tf.log(sigma)
+
+        const s1 = tf.div(tf.square(tf.sub(x, mu)), tf.mul(2, variance))
+        const s2 = logScale
+        const s3 = Math.log(Math.sqrt(2 * Math.PI))
+
+        return tf.neg(s1).sub(s2).sub(s3)
     }

     static get className() {
-        return "Actor"
+        return "InnerActorLayer"
     }
 }
-
-tf.serialization.registerClass(Actor)
diff --git a/packages/learning/src/soft-actor-critic/critic.ts b/packages/learning/src/soft-actor-critic/critic.ts
index 3ec9f83a..fe366d09 100644
--- a/packages/learning/src/soft-actor-critic/critic.ts
+++ b/packages/learning/src/soft-actor-critic/critic.ts
@@ -1,32 +1,59 @@
-import * as tf from "@tensorflow/tfjs-node-gpu"
+import * as tf from "@tensorflow/tfjs"
 import { MlpSpecification, mlp } from "./mlp"

-export class Critic extends tf.layers.Layer {
-    private q: tf.Sequential
+export function critic(observationSize: number, actionSize: number, mlpSpec: MlpSpecification) {
+    const q = mlp({
+        ...mlpSpec,
+        sizes: [observationSize + actionSize, ...mlpSpec.sizes, 1],
+        outputActivation: "linear",
+    })

-    constructor(observationSize: number, actionSize: number, mlpSpec: MlpSpecification) {
+    return q
+}
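+
+// Usage sketch (hypothetical tensors; this mirrors predictQ1/predictQ2 in
+// soft-actor-critic.ts): the Q-network consumes concat([obs, act], -1) and the
+// caller squeezes the trailing unit dimension to get a [batch] value estimate:
+//   const qValue = tf.squeeze(q.apply(tf.concat([obs, act], -1)) as tf.Tensor, [-1])
+// The two layers below wrap the same concat/squeeze steps but are not referenced
+// anywhere else in this patch.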
+
+class BeforeCriticLayer extends tf.layers.Layer {
+    constructor() {
         super()
+    }

-        this.q = mlp({
-            ...mlpSpec,
-            sizes: [observationSize + actionSize, ...mlpSpec.sizes, 1],
-            outputActivation: undefined,
-        })
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+        return [inputShape[0][0], inputShape[0][1] + inputShape[1][1]]
     }

     call([obs, act]: tf.Tensor[]): tf.Tensor {
-        let x = tf.concat([obs, act], 1)
-        x = this.q.apply(x) as tf.Tensor
-        return tf.squeeze(x, [1])
-    }
-
-    get trainableWeights(): tf.LayerVariable[] {
-        return this.q.trainableWeights
+        return tf.concat([obs, act], -1)
     }

     static get className() {
-        return "Critic"
+        return "BeforeCriticLayer"
     }
 }

-tf.serialization.registerClass(Critic)
+class AfterCriticLayer extends tf.layers.Layer {
+    constructor() {
+        super()
+    }
+
+    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+        return [inputShape[0]]
+    }
+
+    call(x: tf.Tensor): tf.Tensor {
+        return tf.squeeze(x, [-1])
+    }
+
+    static get className() {
+        return "AfterCriticLayer"
+    }
+}
diff --git a/packages/learning/src/soft-actor-critic/gaussian-likelihood.ts b/packages/learning/src/soft-actor-critic/gaussian-likelihood.ts
deleted file mode 100644
index 1074b28c..00000000
--- a/packages/learning/src/soft-actor-critic/gaussian-likelihood.ts
+++ /dev/null
@@ -1,25 +0,0 @@
-import * as tf from "@tensorflow/tfjs-node-gpu"
-
-export class GaussianLikelihood extends tf.layers.Layer {
-    computeOutputShape(inputShape: tf.Shape[]): tf.Shape | tf.Shape[] {
-        return [inputShape[0][0], 1]
-    }
-
-    call([x, mu, logstd]: tf.Tensor[]): tf.Tensor {
-        const preSum = tf.mul(
-            -0.5,
-            tf.add(
-                tf.pow(tf.div(tf.sub(x, mu), tf.exp(logstd)), 2),
-                tf.add(tf.mul(2, logstd), Math.log(2 * Math.PI)),
-            ),
-        )
-
-        return tf.sum(preSum, 1)
-    }
-
-    static get className() {
-        return "GaussianLikelihood"
-    }
-}
-
-tf.serialization.registerClass(GaussianLikelihood)
diff --git a/packages/learning/src/soft-actor-critic/mlp.ts b/packages/learning/src/soft-actor-critic/mlp.ts
index 20680742..fde36256 100644
--- a/packages/learning/src/soft-actor-critic/mlp.ts
+++ b/packages/learning/src/soft-actor-critic/mlp.ts
@@ -1,5 +1,5 @@
+import * as tf from "@tensorflow/tfjs"
 import { ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config"
-import * as tf from "@tensorflow/tfjs-node-gpu"

 export interface MlpSpecification {
     sizes: number[]
diff --git a/packages/learning/src/soft-actor-critic/replay-buffer.ts b/packages/learning/src/soft-actor-critic/replay-buffer.ts
index 3550e338..5067c235 100644
--- a/packages/learning/src/soft-actor-critic/replay-buffer.ts
+++ b/packages/learning/src/soft-actor-critic/replay-buffer.ts
@@ -1,4 +1,4 @@
-import * as tf from "@tensorflow/tfjs-node-gpu"
+import * as tf from "@tensorflow/tfjs"

 export interface Experience {
     observation: number[]
@@ -20,12 +20,6 @@
     private buffer: Experience[] = []
     private bufferIndex = 0

-    private tensorObservation: tf.TensorBuffer
-    private tensorAction: tf.TensorBuffer
-    private tensorReward: tf.TensorBuffer
-    private tensorNextObservation: tf.TensorBuffer
-    private tensorDone: tf.TensorBuffer
-
     constructor(
         private capacity: number,
         private batchSize: number,
@@ -35,12 +29,6 @@
         if (batchSize > capacity) {
             throw new Error("Batch size must be less than or equal to capacity")
         }
-
-        this.tensorObservation = tf.buffer([batchSize, observationSize], "float32")
-        this.tensorAction = tf.buffer([batchSize, actionSize], "float32")
-        this.tensorReward = tf.buffer([batchSize], "float32")
-        this.tensorNextObservation = tf.buffer([batchSize, observationSize], "float32")
-        this.tensorDone = tf.buffer([batchSize], "bool")
     }

     push(experience: Experience) {
@@ -48,7 +36,7 @@
             this.buffer[this.bufferIndex] = experience
             this.bufferIndex = (this.bufferIndex + 1) % this.capacity
         } else {
-            this.buffer.push(experience)
+            this.buffer.push({ ...experience })
         }
     }

@@ -57,29 +45,16 @@
             throw new Error("Buffer does not have enough experiences")
         }

-        const indices = tf.util.createShuffledIndices(this.buffer.length)
-
-        for (let i = 0; i < this.batchSize; i++) {
-            const experience = this.buffer[indices[i]]
-
-            for (let j = 0; j < this.observationSize; j++) {
-                this.tensorObservation.set(experience.observation[j], i, j)
-                this.tensorNextObservation.set(experience.nextObservation[j], i, j)
-            }
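+        // Shuffle indices over the whole buffer and keep the first batchSize, so
+        // every stored experience is equally likely to be drawn.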
+        const indices = [...tf.util.createShuffledIndices(this.buffer.length)].slice(
+            0,
+            this.batchSize,
+        )

-            for (let j = 0; j < this.actionSize; j++) {
-                this.tensorAction.set(experience.action[j], i, j)
-            }
+        const observation = tf.tensor2d(indices.map(x => this.buffer[x].observation))
+        const action = tf.tensor2d(indices.map(x => this.buffer[x].action))
+        const reward = tf.tensor1d(indices.map(x => this.buffer[x].reward))
+        const nextObservation = tf.tensor2d(indices.map(x => this.buffer[x].nextObservation))
+        const done = tf.tensor1d(indices.map(x => (this.buffer[x].done ? 1 : 0)))

-            this.tensorReward.set(experience.reward, i)
-        }
-
-        return {
-            observation: this.tensorObservation.toTensor(),
-            action: this.tensorAction.toTensor(),
-            reward: this.tensorReward.toTensor(),
-            nextObservation: this.tensorNextObservation.toTensor(),
-            done: this.tensorDone.toTensor(),
-        }
+        return { observation, action, reward, nextObservation, done }
     }
 }
diff --git a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
index 17fa010d..e414c350 100644
--- a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
+++ b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
@@ -1,6 +1,6 @@
-import * as tf from "@tensorflow/tfjs-node-gpu"
-import { Actor } from "./actor"
-import { Critic } from "./critic"
+import * as tf from "@tensorflow/tfjs"
+import { actor } from "./actor"
+import { critic } from "./critic"
 import { MlpSpecification } from "./mlp"
 import { Experience, ExperienceTensor, ReplayBuffer } from "./replay-buffer"

@@ -25,13 +25,13 @@ export interface Config {
 export class SoftActorCritic {
     private replayBuffer: ReplayBuffer

-    private policy: Actor
+    private policy: tf.LayersModel
     private policyOptimizer: tf.Optimizer

-    private q1: Critic
-    private q2: Critic
-    private targetQ1: Critic
-    private targetQ2: Critic
+    private q1: tf.LayersModel
+    private q2: tf.LayersModel
+    private targetQ1: tf.LayersModel
+    private targetQ2: tf.LayersModel
     private qOptimizer: tf.Optimizer

     private episodeReturn: number
@@ -39,8 +39,6 @@
     private t: number

-    private actTensor: tf.TensorBuffer
-
     constructor(private config: Config) {
         this.replayBuffer = new ReplayBuffer(
             config.bufferSize,
             config.batchSize,
             config.observationSize,
             config.actionSize,
         )

@@ -49,33 +47,44 @@
-        this.policy = new Actor(config.observationSize, config.actionSize, config.mlpSpec)
-        this.q1 = new Critic(config.observationSize, config.actionSize, config.mlpSpec)
-        this.q2 = new Critic(config.observationSize, config.actionSize, config.mlpSpec)
-        this.targetQ1 = new Critic(config.observationSize, config.actionSize, config.mlpSpec)
-        this.targetQ2 = new Critic(config.observationSize, config.actionSize, config.mlpSpec)
+        this.policy = actor(config.observationSize, config.actionSize, config.mlpSpec)
+        this.q1 = critic(config.observationSize, config.actionSize, config.mlpSpec)
+        this.q2 = critic(config.observationSize, config.actionSize, config.mlpSpec)
+
+        this.targetQ1 = critic(config.observationSize, config.actionSize, config.mlpSpec)
+        this.targetQ2 = critic(config.observationSize, config.actionSize, config.mlpSpec)

         this.episodeReturn = 0
         this.episodeLength = 0

         this.t = 0

-        this.actTensor = tf.buffer([1, config.observationSize], "float32")

         this.policyOptimizer = tf.train.adam(config.learningRate)
         this.qOptimizer = tf.train.adam(config.learningRate)

-    }
-
-    act(observation: number[]) {
-        for (let i = 0; i < this.config.observationSize; i++) {
-            this.actTensor.set(observation[i], 0, i)
+        for (let i = 0; i < this.targetQ1.trainableWeights.length; i++) {
+            const targetWeight = this.targetQ1.trainableWeights[i]
+            const weight = this.q1.trainableWeights[i]
+
+            targetWeight.write(weight.read())
+        }
+
+        for (let i = 0; i < this.targetQ2.trainableWeights.length; i++) {
+            const targetWeight = this.targetQ2.trainableWeights[i]
+            const weight = this.q2.trainableWeights[i]
+
+            targetWeight.write(weight.read())
         }
+    }

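+    // deterministic = true evaluates the mean action tanh(mu); deterministic = false
+    // samples from the squashed Gaussian policy for exploration.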
+    act(observation: number[], deterministic: boolean) {
         return tf.tidy(() => {
-            const [action] = this.policy.apply(this.actTensor.toTensor(), {
-                training: false,
+            const [action] = this.policy.apply(tf.tensor2d([observation]), {
+                deterministic,
             }) as tf.Tensor[]

-            return action
+            // action has shape [1, actionSize]; drop the batch dimension.
+            return action.squeeze([0]).arraySync() as number[]
         })
     }

@@ -97,78 +106,208 @@
         }

         if (this.t > this.config.updateAfter && this.t % this.config.updateEvery === 0) {
             for (let i = 0; i < this.config.updateEvery; i++) {
-                this.update()
+                tf.tidy(() => {
+                    this.update(this.replayBuffer.sample())
+                })
             }
         }
+    }
+
+    predictQ1(observation: tf.Tensor2D, action: tf.Tensor2D) {
+        return tf.squeeze(
+            this.q1.apply(tf.concat([observation, action], 1)) as tf.Tensor,
+            [-1],
+        ) as tf.Tensor
+    }

-        console.log(this.t)
+    predictQ2(observation: tf.Tensor2D, action: tf.Tensor2D) {
+        return tf.squeeze(
+            this.q2.apply(tf.concat([observation, action], 1)) as tf.Tensor,
+            [-1],
+        ) as tf.Tensor
     }

-    private update() {
-        tf.tidy(() => {
-            const batch = this.replayBuffer.sample()
-            const backup = this.computeBackup(batch)
+    test() {
+        this.deterministic = true
+
+        this.q1.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
+        this.q2.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
+        this.policy.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
+        this.targetQ1.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
+        this.targetQ2.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
+
+        let seedi = 9
+
+        function randomNumber(seed: number, min: number, max: number) {
+            const a = 1103515245
+            const c = 721847
+
+            seed = (a * seed + c) % 2 ** 31
+            return min + (seed % (max - min))
+        }
+
+        /*
+        seedi = 9
+
+        for i in range(10):
+            seedi += 4
+            observation = torch.tensor([[
+                randomNumber(seedi, -10, 10),
+                randomNumber(seedi + 1, -10, 10),
+                randomNumber(seedi + 2, -10, 10),
+                randomNumber(seedi + 3, -10, 10)]
+            ])
+
+            print("R", i, ": ", ac.pi(observation, True))
+        */
+
+        /*
+        seedi = 9
+
+        for (let i = 0; i < 10; i++) {
+            seedi += 4
+            const observation = tf.tensor2d([
+                [
+                    randomNumber(seedi, -10, 10),
+                    randomNumber(seedi + 1, -10, 10),
+                    randomNumber(seedi + 2, -10, 10),
+                    randomNumber(seedi + 3, -10, 10),
+                ],
+            ])
+
+            console.log(observation.dataSync())
+            const [a, b] = this.policy.apply(observation, {
+                deterministic: true,
+            })
+            console.log("R", i, ": ", a.dataSync(), b.dataSync())
+        }
+        */
+
+        function randomData() {
+            /*
+            def randomNumber(seed, min, max):
+                a = 1103515245
+                c = 721847
+
+                seed = (a * seed + c) % 2**31
+                return (float) (min + (seed % (max - min)))
+
+            seedi = 9
+
+            def randomData():
+                global seedi
+
+                seedi += 5
+                data = {
+                    'obs': torch.tensor([[randomNumber(seedi, -10, 10), randomNumber(2, -10, 10), randomNumber(3, -10, 10), randomNumber(4, -10, 10)]]),
+                    'act': torch.tensor([[randomNumber(seedi + 1, -1, 1)]]),
+                    'rew': torch.tensor([[randomNumber(seedi + 2, -100, 100)]]),
+                    'obs2': torch.tensor([[randomNumber(seedi + 3, -10, 10), randomNumber(8, -10, 10), randomNumber(9, -10, 10), randomNumber(10, -10, 10)]]),
+                    'done': torch.tensor([[randomNumber(seedi + 4, 0, 1)]])
+                }
+
+                return data
+            */
+
+            seedi += 5
+
+            return {
+                observation: tf.tensor2d([
+                    [
+                        randomNumber(seedi, -10, 10),
+                        randomNumber(2, -10, 10),
+                        randomNumber(3, -10, 10),
+                        randomNumber(4, -10, 10),
+                    ],
+                ]),
+                action: tf.tensor2d([[randomNumber(seedi + 1, -1, 1)]]),
+                reward: tf.tensor1d([randomNumber(seedi + 2, -100, 100)]),
+                nextObservation: tf.tensor2d([
+                    [
+                        randomNumber(seedi + 3, -10, 10),
+                        randomNumber(8, -10, 10),
+                        randomNumber(9, -10, 10),
+                        randomNumber(10, -10, 10),
+                    ],
+                ]),
+                done: tf.tensor1d([randomNumber(seedi + 4, 0, 1)]),
+            }
+        }
+
+        let data = randomData()
+        for (let i = 0; i < 1000; i++) {
+            console.log("Action: ", this.act(data.observation.arraySync()[0], true)[0])
+            this.update(data)
+            data = randomData()
+        }
+
+        this.q1.trainableWeights.forEach(w => {
+            console.log(w.read().dataSync())
+        })
+
+        console.log("Verify: ", randomNumber(seedi, 0, 1000))
+    }
+
+    private deterministic = false
+
+    private update(batch: ExperienceTensor) {
+        tf.tidy(() => {
             const lossQ = () => {
-                const errorQ1 = tf.losses.meanSquaredError(
-                    backup,
-                    this.q1.apply([batch.observation, batch.action], {
-                        training: true,
-                    }) as tf.Tensor,
-                )
+                const q1 = this.predictQ1(batch.observation, batch.action)
+                const q2 = this.predictQ2(batch.observation, batch.action)

-                const errorQ2 = tf.losses.meanSquaredError(
-                    backup,
-                    this.q2.apply([batch.observation, batch.action], {
-                        training: true,
-                    }) as tf.Tensor,
-                )
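+                // arraySync() copies the target values out of the autodiff graph, so
+                // the backup acts as a stop-gradient: the Q loss below cannot
+                // backpropagate into the target networks or the policy.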
+                const backup = tf.tensor1d(this.computeBackup(batch).arraySync())
+
+                const errorQ1 = tf.mean(tf.square(tf.sub(q1, backup))) as tf.Scalar
+                const errorQ2 = tf.mean(tf.square(tf.sub(q2, backup))) as tf.Scalar

                 return tf.add(errorQ1, errorQ2) as tf.Scalar
             }

             const gradsQ = tf.variableGrads(lossQ)
             this.qOptimizer.applyGradients(gradsQ.grads)

             const lossPolicy = () => {
                 const [pi, logpPi] = this.policy.apply(batch.observation, {
-                    training: true,
+                    deterministic: this.deterministic,
                 }) as tf.Tensor[]

-                const piQ1 = this.q1.apply([batch.observation, pi], {
-                    training: false,
-                }) as tf.Tensor
-                const piQ2 = this.q2.apply([batch.observation, pi], {
-                    training: false,
-                }) as tf.Tensor
+                const piQ1 = this.predictQ1(batch.observation, pi)
+                const piQ2 = this.predictQ2(batch.observation, pi)

                 const minPiQ = tf.minimum(piQ1, piQ2)

-                return tf.mean(tf.sub(tf.mul(this.config.alpha, logpPi), minPiQ)) as tf.Scalar
+                return tf.mean(logpPi.mul(this.config.alpha).sub(minPiQ)) as tf.Scalar
             }

-            const gradsPolicy = tf.variableGrads(lossPolicy)
+            // Restrict the policy step to the policy's own variables; lossPolicy also
+            // touches q1/q2, which must not move here.
+            const gradsPolicy = tf.variableGrads(
+                lossPolicy,
+                this.policy.getWeights() as tf.Variable[],
+            )
             this.policyOptimizer.applyGradients(gradsPolicy.grads)

-            for (let i = 0; i < this.q1.trainableWeights.length; ++i) {
-                const targetQ1 = this.targetQ1.trainableWeights[i]
-                const q1 = this.q1.trainableWeights[i]
+            for (let i = 0; i < this.targetQ1.trainableWeights.length; i++) {
+                const targetWeight = this.targetQ1.trainableWeights[i]
+                const weight = this.q1.trainableWeights[i]

-                targetQ1.write(
+                targetWeight.write(
                     tf.add(
-                        tf.mul(this.config.polyak, targetQ1.read()),
-                        tf.mul(1 - this.config.polyak, q1.read()),
+                        tf.mul(this.config.polyak, targetWeight.read()),
+                        tf.mul(1 - this.config.polyak, weight.read()),
                     ),
                 )
+            }

-                const targetQ2 = this.targetQ2.trainableWeights[i]
-                const q2 = this.q2.trainableWeights[i]
+            for (let i = 0; i < this.targetQ2.trainableWeights.length; i++) {
+                const targetWeight = this.targetQ2.trainableWeights[i]
+                const weight = this.q2.trainableWeights[i]

-                targetQ2.write(
+                targetWeight.write(
                     tf.add(
-                        tf.mul(this.config.polyak, targetQ2.read()),
-                        tf.mul(1 - this.config.polyak, q2.read()),
+                        tf.mul(this.config.polyak, targetWeight.read()),
+                        tf.mul(1 - this.config.polyak, weight.read()),
                     ),
                 )
             }
         })
     }

     private computeBackup(batch: ExperienceTensor) {
-        const [action, logpPi] = this.policy.apply(batch.nextObservation) as tf.Tensor[]
+        const [action, logpPi] = this.policy.apply(batch.nextObservation, {
+            deterministic: this.deterministic,
+        }) as tf.Tensor[]

-        const targetQ1 = this.targetQ1.apply([
-            batch.nextObservation,
-            action,
-        ]) as tf.Tensor
+        const targetQ1 = tf.squeeze(
+            this.targetQ1.apply(tf.concat([batch.nextObservation, action], 1)) as tf.Tensor,
+            [-1],
+        ) as tf.Tensor

-        const targetQ2 = this.targetQ2.apply([
-            batch.nextObservation,
-            action,
-        ]) as tf.Tensor
+        const targetQ2 = tf.squeeze(
+            this.targetQ2.apply(tf.concat([batch.nextObservation, action], 1)) as tf.Tensor,
+            [-1],
+        ) as tf.Tensor

         const minTargetQ = tf.minimum(targetQ1, targetQ2)
         const softQTarget = tf.sub(minTargetQ, tf.mul(this.config.alpha, logpPi))
diff --git a/packages/runtime-framework/src/message-store.ts b/packages/runtime-framework/src/message-store.ts
index 4493f184..1f3c2c8b 100644
--- a/packages/runtime-framework/src/message-store.ts
+++ b/packages/runtime-framework/src/message-store.ts
@@ -55,8 +55,6 @@ export const createMessageStore = <
         messageName: T,
         message: WithoutTarget[T],
     ) {
-        console.log("publish", messageName, message)
-
         for (const callback of listenerMap.get(messageName.toString()) ?? []) {
             callback(message)
         }
diff --git a/packages/runtime/src/core/rocket/systems/rocket-death-system.ts b/packages/runtime/src/core/rocket/systems/rocket-death-system.ts
index 2b637ada..481bae22 100644
--- a/packages/runtime/src/core/rocket/systems/rocket-death-system.ts
+++ b/packages/runtime/src/core/rocket/systems/rocket-death-system.ts
@@ -125,8 +125,6 @@ export const newRocketDeathSystem: RuntimeSystemFactory = ({
                    ?
contact.solverContactPoint(0) : contact.solverContactPoint(0) - console.log("rocket death", contactPoint) - messageStore.publish("rocketDeath", { position: rocket.components.rigidBody.translation(), rotation: rocket.components.rigidBody.rotation(), diff --git a/packages/server/package.json b/packages/server/package.json index 59cda5e8..d443ed7b 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -4,7 +4,7 @@ "description": "", "main": "index.js", "scripts": { - "_postinstall": "wrangler d1 migrations apply polyburn --local && node ./tools/postprocess-rapier2d.mjs", + "postinstall": "wrangler d1 migrations apply polyburn --local && node ./tools/postprocess-rapier2d.mjs", "generate": "drizzle-kit generate:sqlite", "dev": "wrangler dev src/index.ts --live-reload", "lint": "eslint \"src/**/*.{tsx,ts}\"" diff --git a/packages/web-game/package.json b/packages/web-game/package.json index aa0d16a4..fb838820 100644 --- a/packages/web-game/package.json +++ b/packages/web-game/package.json @@ -9,6 +9,7 @@ "web-game-preview": "vite preview" }, "devDependencies": { + "@types/sat": "^0.0.35", "@types/three": "^0.162.0", "autoprefixer": "^10.4.19", "postcss": "^8.4.38", @@ -20,7 +21,6 @@ "vite": "^5.1.6" }, "dependencies": { - "@types/sat": "^0.0.35", "eslint-config-custom": "*", "lil-gui": "^0.19.2", "poly-decomp-es": "^0.4.2", diff --git a/packages/web-game/src/game/modules/module-input/module-input.ts b/packages/web-game/src/game/modules/module-input/module-input.ts index 645c8479..6e7f794e 100644 --- a/packages/web-game/src/game/modules/module-input/module-input.ts +++ b/packages/web-game/src/game/modules/module-input/module-input.ts @@ -5,6 +5,197 @@ import { Mouse } from "./mouse" const CHARCODE_ONE = "1".charCodeAt(0) const CHARCODE_NINE = "9".charCodeAt(0) +const inputs = [ + { rotation: 0.928602397441864, thrust: true }, + { rotation: 0.13935142755508423, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: -1.8606485724449158, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: -1.8606485724449158, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: -1.8606485724449158, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: 0.13935142755508423, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: 0.13935142755508423, thrust: true }, + { rotation: -0.8606485724449158, thrust: true }, + { rotation: -0.523856908082962, thrust: true }, + { rotation: -1.523856908082962, thrust: true }, + { rotation: -2.523856908082962, thrust: true }, + { rotation: -3.523856908082962, thrust: true }, + { rotation: -4.523856908082962, thrust: true }, + { rotation: -5.523856908082962, thrust: true }, + { rotation: -4.523856908082962, thrust: true }, + { rotation: -5.523856908082962, thrust: true }, + { rotation: -5.0757246017456055, thrust: true }, + { rotation: -6.0757246017456055, thrust: true }, + { rotation: -5.0757246017456055, thrust: true }, + { rotation: -6.0757246017456055, thrust: true }, + { rotation: -5.0757246017456055, thrust: true }, + { rotation: -5.493886172771454, thrust: true }, + { rotation: -4.493886172771454, thrust: true }, + { rotation: -5.4650625586509705, thrust: true }, + { rotation: -4.4650625586509705, thrust: true }, + { rotation: -5.4650625586509705, thrust: true }, + { rotation: -5.493640601634979, thrust: true }, + { rotation: -4.493640601634979, thrust: true }, + { rotation: -5.493640601634979, thrust: true }, + { 
rotation: -4.493640601634979, thrust: true }, + { rotation: -5.493640601634979, thrust: true }, + { rotation: -5.574220359325409, thrust: true }, + { rotation: -4.574220359325409, thrust: true }, + { rotation: -5.574220359325409, thrust: true }, + { rotation: -5.371117532253265, thrust: true }, + { rotation: -5.393860101699829, thrust: true }, + { rotation: -5.179872035980225, thrust: true }, + { rotation: -4.758336663246155, thrust: true }, + { rotation: -4.888206958770752, thrust: true }, + { rotation: -5.888206958770752, thrust: true }, + { rotation: -4.888206958770752, thrust: true }, + { rotation: -5.888206958770752, thrust: true }, + { rotation: -4.888206958770752, thrust: true }, + { rotation: -3.888206958770752, thrust: true }, + { rotation: -4.888206958770752, thrust: true }, + { rotation: -5.8682825565338135, thrust: true }, + { rotation: -4.8682825565338135, thrust: true }, + { rotation: -5.8682825565338135, thrust: true }, + { rotation: -4.8682825565338135, thrust: true }, + { rotation: -5.779716372489929, thrust: true }, + { rotation: -4.779716372489929, thrust: true }, + { rotation: -5.779716372489929, thrust: true }, + { rotation: -5.314650177955627, thrust: true }, + { rotation: -5.377287566661835, thrust: true }, + { rotation: -5.50549703836441, thrust: true }, + { rotation: -4.565630733966827, thrust: true }, + { rotation: -5.565630733966827, thrust: true }, + { rotation: -5.236084401607513, thrust: true }, + { rotation: -4.80553811788559, thrust: true }, + { rotation: -5.80553811788559, thrust: true }, + { rotation: -4.80553811788559, thrust: true }, + { rotation: -5.507085978984833, thrust: true }, + { rotation: -4.507085978984833, thrust: true }, + { rotation: -5.507085978984833, thrust: true }, + { rotation: -5.149120390415192, thrust: true }, + { rotation: -6.149120390415192, thrust: true }, + { rotation: -5.149120390415192, thrust: true }, + { rotation: -6.036867916584015, thrust: true }, + { rotation: -5.036867916584015, thrust: true }, + { rotation: -6.036867916584015, thrust: true }, + { rotation: -5.036867916584015, thrust: true }, + { rotation: -5.05073469877243, thrust: true }, + { rotation: -5.590804040431976, thrust: true }, + { rotation: -6.352537512779236, thrust: true }, + { rotation: -5.352537512779236, thrust: true }, + { rotation: -4.628996133804321, thrust: true }, + { rotation: -5.4285489320755005, thrust: true }, + { rotation: -4.843723893165588, thrust: true }, + { rotation: -5.4131468534469604, thrust: true }, + { rotation: -4.4131468534469604, thrust: true }, + { rotation: -5.4131468534469604, thrust: true }, + { rotation: -4.4131468534469604, thrust: true }, + { rotation: -5.4131468534469604, thrust: true }, + { rotation: -6.4131468534469604, thrust: true }, + { rotation: -5.716241121292114, thrust: true }, + { rotation: -6.401003062725067, thrust: true }, + { rotation: -5.401003062725067, thrust: true }, + { rotation: -5.035071134567261, thrust: true }, + { rotation: -5.827503323554993, thrust: true }, + { rotation: -4.827503323554993, thrust: true }, + { rotation: -5.827503323554993, thrust: true }, + { rotation: -4.827503323554993, thrust: true }, + { rotation: -5.827503323554993, thrust: true }, + { rotation: -4.827503323554993, thrust: true }, + { rotation: -5.827503323554993, thrust: true }, + { rotation: -4.827503323554993, thrust: true }, + { rotation: -5.827503323554993, thrust: true }, + { rotation: -5.135900259017944, thrust: true }, + { rotation: -6.135900259017944, thrust: true }, + { rotation: -5.149014234542847, thrust: true }, + 
{ rotation: -4.686952769756317, thrust: true }, + { rotation: -5.686952769756317, thrust: false }, + { rotation: -4.686952769756317, thrust: true }, + { rotation: -5.686952769756317, thrust: true }, + { rotation: -5.333105087280273, thrust: true }, + { rotation: -5.489454597234726, thrust: true }, + { rotation: -4.489454597234726, thrust: true }, + { rotation: -5.489454597234726, thrust: true }, + { rotation: -6.159207135438919, thrust: true }, + { rotation: -5.159207135438919, thrust: true }, + { rotation: -6.159207135438919, thrust: true }, + { rotation: -5.159207135438919, thrust: true }, + { rotation: -5.264431670308113, thrust: true }, + { rotation: -5.5160354524850845, thrust: true }, + { rotation: -4.5160354524850845, thrust: true }, + { rotation: -5.5160354524850845, thrust: true }, + { rotation: -4.5160354524850845, thrust: true }, + { rotation: -5.5160354524850845, thrust: true }, + { rotation: -4.5160354524850845, thrust: true }, + { rotation: -5.5160354524850845, thrust: true }, + { rotation: -4.5160354524850845, thrust: true }, + { rotation: -5.210378661751747, thrust: true }, + { rotation: -5.789670959115028, thrust: true }, + { rotation: -5.821384325623512, thrust: true }, + { rotation: -5.038382187485695, thrust: true }, + { rotation: -6.038382187485695, thrust: true }, + { rotation: -5.038382187485695, thrust: true }, + { rotation: -6.038382187485695, thrust: true }, + { rotation: -5.038382187485695, thrust: true }, + { rotation: -5.197831615805626, thrust: true }, + { rotation: -4.203378781676292, thrust: true }, + { rotation: -5.203378781676292, thrust: true }, + { rotation: -4.203378781676292, thrust: true }, + { rotation: -5.203378781676292, thrust: true }, + { rotation: -4.203378781676292, thrust: true }, + { rotation: -5.203378781676292, thrust: true }, + { rotation: -5.559658095240593, thrust: true }, + { rotation: -4.559658095240593, thrust: true }, + { rotation: -5.559658095240593, thrust: true }, + { rotation: -4.742914721369743, thrust: true }, + { rotation: -5.742914721369743, thrust: true }, + { rotation: -4.742914721369743, thrust: true }, + { rotation: -5.742914721369743, thrust: true }, + { rotation: -4.742914721369743, thrust: true }, + { rotation: -5.742914721369743, thrust: true }, + { rotation: -4.749851331114769, thrust: true }, + { rotation: -5.749851331114769, thrust: true }, + { rotation: -4.749851331114769, thrust: true }, + { rotation: -5.749851331114769, thrust: true }, + { rotation: -4.749851331114769, thrust: true }, + { rotation: -5.749851331114769, thrust: true }, + { rotation: -4.749851331114769, thrust: true }, + { rotation: -5.749851331114769, thrust: true }, + { rotation: -5.938547268509865, thrust: true }, + { rotation: -6.0961340218782425, thrust: true }, + { rotation: -5.0961340218782425, thrust: true }, + { rotation: -5.992851212620735, thrust: true }, + { rotation: -5.3073021322488785, thrust: true }, + { rotation: -6.3073021322488785, thrust: true }, + { rotation: -5.3073021322488785, thrust: true }, + { rotation: -6.3073021322488785, thrust: true }, + { rotation: -6.077244356274605, thrust: true }, + { rotation: -6.318330153822899, thrust: true }, + { rotation: -5.318330153822899, thrust: true }, + { rotation: -6.318330153822899, thrust: true }, + { rotation: -5.438394114375114, thrust: true }, + { rotation: -5.514149233698845, thrust: true }, + { rotation: -6.514149233698845, thrust: true }, + { rotation: -6.433003231883049, thrust: true }, + { rotation: -5.433003231883049, thrust: true }, + { rotation: -6.433003231883049, thrust: 
true }, + { rotation: -5.61981026828289, thrust: true }, + { rotation: -6.61981026828289, thrust: true }, + { rotation: -6.839142426848412, thrust: true }, + { rotation: -5.839142426848412, thrust: true }, + { rotation: -6.839142426848412, thrust: true }, + { rotation: -6.279265388846397, thrust: true }, + { rotation: -7.279265388846397, thrust: true }, + { rotation: -6.315085455775261, thrust: true }, + { rotation: -7.315085455775261, thrust: true }, + { rotation: -6.7484916895627975, thrust: true }, + { rotation: -7.7484916895627975, thrust: true }, + { rotation: -6.7484916895627975, thrust: true }, +] + export class ModuleInput { private keyboard: Keyboard private mouse: Mouse @@ -12,6 +203,9 @@ export class ModuleInput { private rotationSpeed = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.5, 3.0] private rotationSpeedIndex = 2 + private c: { rotation: number; thrust: boolean } = { rotation: 0, thrust: false } + private i = 0 + constructor(runtime: ExtendedRuntime) { this.keyboard = new Keyboard() this.mouse = new Mouse(runtime) @@ -32,15 +226,17 @@ export class ModuleInput { } rotation() { - return this.keyboard.rotation() + this.mouse.rotation() + return this.c.rotation } thrust() { - return this.keyboard.thrust() || this.mouse.thrust() + return this.c.thrust } onPreFixedUpdate(delta: number) { this.keyboard.onPreFixedUpdate(delta) + this.c = inputs[this.i % inputs.length] + ++this.i } onKeyboardDown(event: KeyboardEvent) { diff --git a/packages/web-game/src/main.ts b/packages/web-game/src/main.ts index d8c16cea..44c26762 100644 --- a/packages/web-game/src/main.ts +++ b/packages/web-game/src/main.ts @@ -1,4 +1,3 @@ -/* import { WorldModel } from "runtime/proto/world" import { Game } from "./game/game" import { GameLoop } from "./game/game-loop" @@ -8,14 +7,16 @@ function base64ToBytes(base64: string) { return Uint8Array.from(atob(base64), c => c.charCodeAt(0)) } -const world = - 
"CscCCgZOb3JtYWwSvAIKCg2F65XBFTXTGkISKA2kcLrBFZfjFkIlAAAAwi1SuIlCNa5H+UE9H4X/QUUAAABATQAAAEASKA1SuMFBFZmRGkIlhetRQS3NzFJCNSlcp0I9zcxEQUUAAABATQAAAEASKA0AgEVCFfIboEElAAAoQi0K189BNaRw4UI9rkdZwUUAAABATQAAAEASKA171MBCFcubHcElmpm5Qi0K189BNY/CI0M9rkdZwUUAAABATQAAAEASLQ1syOFCFToytkEdVGuzOiWamblCLSlcZUI1XI8jQz3NzIhBRQAAAEBNAAAAQBItDR/lAUMVk9VNQh2fUDa1JaRw9UItexRsQjWF60FDPQAAlEFFAAAAQE0AAABAEigNw1UzQxVpqkFCJdejJEMtBW94QjXXo0JDPQVvAEJFAAAAQE0AAABACu4KCg1Ob3JtYWwgU2hhcGVzEtwKGt8GCtwGP4UAws3MNEGgEEAAZjYAAP///wB1PAAU////AF5PABT///8AyUtPxP///wAzSg3L////AMBJAcj///8AE0Umzf///wCMVAo5////AJNRpDr///8AVE0WVP///wD0vlZLAAD/AEPI7Bn///8AhcPlOAAA/wAFQZrF////ADS9F8f///8AJMIuwf///wC5xvvF////AOrJ1rf///8Ac8ikQP///wBAxfRF////AGkxi0n///8Aj0LxQgAA/wB1xWY9////AJ/HZAlQUP4AzcUBvQAA/wDwQFzE////ADDGR73///8As8eZPoiI8QBxxWQ3rKz/AFw3LMQAAP8AwkNRtP///wC2RKO4////AEhBe8EAAP8AS0WPPP///wAdSaSx////AMw/Ucj///8A7MBNxv///wDmxnG9////AELCFLr///8Aw8UOof///wAKxCg4AAD/ALg8OMDZ2fsA4j9NwP///wCkxB+/AADwAHGwrr54ePgAVERcwv///wAPwXbA////APW0H0EAAPgASLtnv////wALM67DJSX/AFJApL////8AZj4uwP///wBcu+HATU3/AIU7+8H///8AXMK8Lf///wB7wjM/AAD4AHDCx8D///8AFEH7wP///wAAvnvE////AOTGChL///8A6bncRP///wCAQddAAAD4AB/AxLH///8AIL9RPQAA+ACZwqvG////AOLCLkQAAPgAIcTrwP///wDtwQPH////AOLJbqz///8ALsR6QwAA+AD+x8zA////APtF90kyMv8AH7mZQCcn/wCNxHo8tbX/AIDAiETKyv8AXEAgSgAA+AClyAqS////AH9EG0n///8AS0ypRP///wAxSIK7MDToANjBdUf///8A58yjxP///wCByD1EMDToAIzCYMv///8AnMq3MzA06AC+QenF////ANzGT0T///8AtMFSR////wBzRb85lpj/AFJALEQwNOgAqMIpPjA06AAgyiCF////AAPEE77///8AzT4FSnN1/wAzxWFCMDToAA23PcKXl/8AGcLmQDA06ADMPUnJu77/AFrGxsL///8A1TRGSjA06ACKwik8MDToAE3Apcn///8Ar8SawP///wBsygqP////ABHI8z0wNOgAAABTzv///wAa9wMK9APNzJNCj8JlQP///wBmtly8////ABa2jsg2Nv8AO0SENwAA+ACkvrtEvLz/AG0uOEX///8A4UaHPv///wA+QlXFAAD4AApB2L4AAPgAeDLVRP///wATSHHAAAD4ADhA3EP///8As0MKvAAA8ADOPxM4AAD4AEjBTUD///8Arj5TP3B0+ACyKw9DaGz4ALm6eDz///8AKT4MSP///wDhPy5CAAD/APS/XEL///8A+EV6PwAA/wAdsXtBp6f/AGzEpEEAAP8AisfEuf///wDXwVJI////AJpEaUf///8AhUfxQP///wB7RA3FAAD/ANdBTzUAAP8AC8C9Rv///wBGQoVE////APRMpDz///8A7kS3yAAA/wDLR9HB////AFLHNscAAP8AR0HNwf///wDsvtLGAAD/AABE5kD///8AD0JIRv///wD0RNJA////AEVFqcD///8A3ESpwwAA/wAuwgtJ////AARBqEj///8ALUdbSf///wA01Hks////AHjCAL3///8AF8s5x////wC4vlPP////AME1O8f///8AhsIAPgAA+ABcxZXC7e3/AIrEpUMAAPgAjcbDxcvL/wBdQFzF////AEjI+8EAAOAAQ0GZvf///wAGN77AFRX/APlFXDz///8AikEzwkhI+ADcQmoy////AArNAgoHUmV2ZXJzZRLBAgoPDRydLkMVk5lFQh2z7Zk2EigNpHC6wRWX4xZCJQAAAMItAABMQjUAAEDBPR+F/0FFAAAAQE0AAABAEigNUrjBQRWZkRpCJR+FAMItZuaJQjUAAPpBPQAAAEJFAAAAQE0AAABAEigNAIBFQhXyG6BBJQAAUEEthetRQjWkcKdCPVK4TkFFAAAAQE0AAABAEigNe9TAQhXLmx3BJTQzKEItCtfPQTUeBeJCPa5HWcFFAAAAQE0AAABAEi0NbMjhQhU6MrZBHVRrszolmpm5Qi1SuNRBNVyPI0M9ZmZawUUAAABATQAAAEASLQ0f5QFDFZPVTUIdn1A2tSWk8LlCLXsUZUI1hSskQz0AAIZBRQAAAEBNAAAAQBIoDcNVM0MVaapBQiUAgPVCLQAAbEI1AABCQz0AAJRBRQAAAEBNAAAAQBIhCgZOb3JtYWwSFwoNTm9ybWFsIFNoYXBlcwoGTm9ybWFsEiMKB1JldmVyc2USGAoNTm9ybWFsIFNoYXBlcwoHUmV2ZXJzZQ==" -const worldModel = WorldModel.decode(base64ToBytes(world)) +const worldStr = + 
"CqAJCgZOb3JtYWwSlQkKDw0fhZ3BFR+FB0Id2w/JQBItDR+FtsEVgZUDQh3bD8lAJQAAEMItpHBhQjWuR9lBPR+Fm0FFAAAAQE0AAABAEi0Nrkc/QRVt5wZCHdsPyUAlAAD4QC2kcBZCNezRjUI94KMwP0UAAABATQAAAEASLQ2k8B5CFX9qWEEd2w/JQCUAAP5BLaRwFkI17NG9Qj3gozA/RQAAAEBNAAAAQBItDeyRm0IVPzWGQR3bD8lAJQCAjUItSOHsQTX26AVDPYTr6cBFAAAAQE0AAABAEi0Nw0XwQhUcd4lAHTMeejwlAIDnQi2kcA5CNfboMkM9EK6nv0UAAABATQAAAEASLQ2PYhxDFT813EEd2w/JQCUAAM9CLaRwbEI1AMAmQz0fhbFBRQAAAEBNAAAAQBItDcM15UIVYxBJQh3bD8lAJQAAeUItUrijQjXs0fpCPZDCM0JFAAAAQE0AAABAEi0N9WiFQhXVeIhCHdsPyUAlw7WBQi3sUY9CNcO1kUI9AACBQkUAAABATQAAAEAaTgpMpHA9wXE9ukHAwP8AAEAAPYCA/wAAtIBDAAD/AIDFAEBAQP8AgMgAAICA/wBAxgC+oKD/AABGAMf///8AV0dxQry8+QBSQPHA////ABpOCkyuR3FBSOHKQf/++ABAxgAA//3wAAA/QMT/++AAQEoAQv/3wAAAPkBF/++AAADHAD//3gAAgMYAAP/vgAAAAIDD////AKxGCq////8AGpcCCpQC9qjBQpqZJEL///8AMNEAOv///wDqy9pH////AOzHNML///8AAMIAx////wAAQkDE////AABFAL3///8AAELAx////wCARgBF////AEBGgMb///8AwEYAv////wAgSQBF////AOBIgMP///8A4EjAR////wAARYDE////AAC+oMj///8AAD8AAP///wAAAODK////AGBJAEf///8AwMTASP///wAgSQAA////AEBEwMb///8AAEOAQ////wBASQC/////AAA+wEj///8AwEqAw////wAAvMBL////AODIAAD///8AQMoAQP///wAAPgBI////ACDIAAD///8AgMCARv///wCAyQAA////AEBFgMb///8AGqcCCqQCpHAZQqRwOcH///8AmFgAwP///wCAxwhU////AGDK4E3///8AwM1gyf///wAAv+DI////AKBLAMP///8AADpgyf///wCARgAA////AAA6YMv///8AQMgAAP///wAAvuDJ////AIBFYMj///8AQMyAwf///wAAtMDG////AGDLAL3///8AOMAMSP///wAkxgCu////AADC4Mj///8AAMNARv///wBgyQAA////AEDHgMP///8AwMeAQf///wAAAEBM////ACDJAAD///8AgMMAx////wAAyoBC////AAC9AMb///8AgMTARf///wCAwIDB////AABFAML///8AAMgANP///wBAxEBG////AADHAAD///8AAMFAyP///wBgyEDE////ABomCiSPQopCcT2DQv/AjQAAxAAA/+R0AAAAAMT/kwAAAEQAAP+bAAASEgoGTm9ybWFsEggKBk5vcm1hbA==" +const worldModel = WorldModel.decode(base64ToBytes(worldStr)) const settings: GameSettings = { instanceType: GameInstanceType.Play, world: worldModel, gamemode: "Normal", + canvas: document.getElementById("scene") as HTMLCanvasElement, + worldname: "Test", } try { @@ -24,4 +25,3 @@ try { } catch (e) { console.error(e) } -*/