diff --git a/api/Pipfile b/api/Pipfile index a1c6c74..d3195bd 100644 --- a/api/Pipfile +++ b/api/Pipfile @@ -25,6 +25,7 @@ requests = "*" alembic = "*" sqlalchemy = "*" mysql-connector-python = "*" +pytest = "*" [dev-packages] diff --git a/api/Pipfile.lock b/api/Pipfile.lock index 930c2d4..6fc3b52 100644 --- a/api/Pipfile.lock +++ b/api/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "2b8bbf71251550eb71f393b52f3351e7a35e9ebe9bdb5d621d7f0bb2ec9bc145" + "sha256": "61d634f858b70f3e04ed9c267d51cde6e5c2f172526441a7e7f3b53a8413ff79" }, "pipfile-spec": 6, "requires": { @@ -357,6 +357,14 @@ "markers": "python_version >= '3.5'", "version": "==3.4" }, + "iniconfig": { + "hashes": [ + "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", + "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374" + ], + "markers": "python_version >= '3.7'", + "version": "==2.0.0" + }, "itsdangerous": { "hashes": [ "sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19", @@ -377,11 +385,11 @@ }, "loguru": { "hashes": [ - "sha256:1612053ced6ae84d7959dd7d5e431a0532642237ec21f7fd83ac73fe539e03e1", - "sha256:b93aa30099fa6860d4727f1b81f8718e965bb96253fa190fab2077aaad6d15d3" + "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb", + "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac" ], "markers": "python_version >= '3.5'", - "version": "==0.7.0" + "version": "==0.7.2" }, "mako": { "hashes": [ @@ -602,14 +610,30 @@ "markers": "python_full_version >= '3.7.1'", "version": "==0.27.2" }, + "packaging": { + "hashes": [ + "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61", + "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f" + ], + "markers": "python_version >= '3.7'", + "version": "==23.1" + }, "pinecone-client": { "hashes": [ - "sha256:21fddb752668efee4d3c6b706346d9580e36a8b06b8d97afd60bd33ef2536e7e", - "sha256:391fe413754efd4e0ef00154b44271d63c4cdd4bedf088d23111a5725d863210" + "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168", + "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec" ], "index": "pypi", "markers": "python_version >= '3.8'", - "version": "==2.2.2" + "version": "==2.2.4" + }, + "pluggy": { + "hashes": [ + "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12", + "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7" + ], + "markers": "python_version >= '3.8'", + "version": "==1.3.0" }, "protobuf": { "hashes": [ @@ -631,6 +655,15 @@ "markers": "python_version >= '3.7'", "version": "==4.21.12" }, + "pytest": { + "hashes": [ + "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002", + "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069" + ], + "index": "pypi", + "markers": "python_version >= '3.7'", + "version": "==7.4.2" + }, "python-dateutil": { "hashes": [ "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", @@ -809,11 +842,11 @@ }, "setuptools": { "hashes": [ - "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d", - "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b" + "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87", + "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a" ], "markers": "python_version >= '3.8'", - "version": "==68.1.2" + "version": "==68.2.2" }, "six": { "hashes": [ @@ -825,51 +858,51 @@ }, "sqlalchemy": { 
"hashes": [ - "sha256:1506e988ebeaaf316f183da601f24eedd7452e163010ea63dbe52dc91c7fc70e", - "sha256:1a58052b5a93425f656675673ef1f7e005a3b72e3f2c91b8acca1b27ccadf5f4", - "sha256:1b74eeafaa11372627ce94e4dc88a6751b2b4d263015b3523e2b1e57291102f0", - "sha256:1be86ccea0c965a1e8cd6ccf6884b924c319fcc85765f16c69f1ae7148eba64b", - "sha256:1d35d49a972649b5080557c603110620a86aa11db350d7a7cb0f0a3f611948a0", - "sha256:243d0fb261f80a26774829bc2cee71df3222587ac789b7eaf6555c5b15651eed", - "sha256:26a3399eaf65e9ab2690c07bd5cf898b639e76903e0abad096cd609233ce5208", - "sha256:27d554ef5d12501898d88d255c54eef8414576f34672e02fe96d75908993cf53", - "sha256:3364b7066b3c7f4437dd345d47271f1251e0cfb0aba67e785343cdbdb0fff08c", - "sha256:3423dc2a3b94125094897118b52bdf4d37daf142cbcf26d48af284b763ab90e9", - "sha256:3c6aceebbc47db04f2d779db03afeaa2c73ea3f8dcd3987eb9efdb987ffa09a3", - "sha256:3ce5e81b800a8afc870bb8e0a275d81957e16f8c4b62415a7b386f29a0cb9763", - "sha256:411e7f140200c02c4b953b3dbd08351c9f9818d2bd591b56d0fa0716bd014f1e", - "sha256:4cde2e1096cbb3e62002efdb7050113aa5f01718035ba9f29f9d89c3758e7e4e", - "sha256:5768c268df78bacbde166b48be788b83dddaa2a5974b8810af422ddfe68a9bc8", - "sha256:599ccd23a7146e126be1c7632d1d47847fa9f333104d03325c4e15440fc7d927", - "sha256:5ed61e3463021763b853628aef8bc5d469fe12d95f82c74ef605049d810f3267", - "sha256:63a368231c53c93e2b67d0c5556a9836fdcd383f7e3026a39602aad775b14acf", - "sha256:63e73da7fb030ae0a46a9ffbeef7e892f5def4baf8064786d040d45c1d6d1dc5", - "sha256:6eb6d77c31e1bf4268b4d61b549c341cbff9842f8e115ba6904249c20cb78a61", - "sha256:6f8a934f9dfdf762c844e5164046a9cea25fabbc9ec865c023fe7f300f11ca4a", - "sha256:6fe7d61dc71119e21ddb0094ee994418c12f68c61b3d263ebaae50ea8399c4d4", - "sha256:759b51346aa388c2e606ee206c0bc6f15a5299f6174d1e10cadbe4530d3c7a98", - "sha256:76fdfc0f6f5341987474ff48e7a66c3cd2b8a71ddda01fa82fedb180b961630a", - "sha256:77d37c1b4e64c926fa3de23e8244b964aab92963d0f74d98cbc0783a9e04f501", - "sha256:79543f945be7a5ada9943d555cf9b1531cfea49241809dd1183701f94a748624", - "sha256:79fde625a0a55220d3624e64101ed68a059c1c1f126c74f08a42097a72ff66a9", - "sha256:7d3f175410a6db0ad96b10bfbb0a5530ecd4fcf1e2b5d83d968dd64791f810ed", - "sha256:8dd77fd6648b677d7742d2c3cc105a66e2681cc5e5fb247b88c7a7b78351cf74", - "sha256:a3f0dd6d15b6dc8b28a838a5c48ced7455c3e1fb47b89da9c79cc2090b072a50", - "sha256:bcb04441f370cbe6e37c2b8d79e4af9e4789f626c595899d94abebe8b38f9a4d", - "sha256:c3d99ba99007dab8233f635c32b5cd24fb1df8d64e17bc7df136cedbea427897", - "sha256:ca8a5ff2aa7f3ade6c498aaafce25b1eaeabe4e42b73e25519183e4566a16fc6", - "sha256:cb0d3e94c2a84215532d9bcf10229476ffd3b08f481c53754113b794afb62d14", - "sha256:d1b09ba72e4e6d341bb5bdd3564f1cea6095d4c3632e45dc69375a1dbe4e26ec", - "sha256:d32b5ffef6c5bcb452723a496bad2d4c52b346240c59b3e6dba279f6dcc06c14", - "sha256:d3793dcf5bc4d74ae1e9db15121250c2da476e1af8e45a1d9a52b1513a393459", - "sha256:dd81466bdbc82b060c3c110b2937ab65ace41dfa7b18681fdfad2f37f27acdd7", - "sha256:e4e571af672e1bb710b3cc1a9794b55bce1eae5aed41a608c0401885e3491179", - "sha256:ea8186be85da6587456c9ddc7bf480ebad1a0e6dcbad3967c4821233a4d4df57", - "sha256:eefebcc5c555803065128401a1e224a64607259b5eb907021bf9b175f315d2a6" + "sha256:014794b60d2021cc8ae0f91d4d0331fe92691ae5467a00841f7130fe877b678e", + "sha256:0268256a34806e5d1c8f7ee93277d7ea8cc8ae391f487213139018b6805aeaf6", + "sha256:05b971ab1ac2994a14c56b35eaaa91f86ba080e9ad481b20d99d77f381bb6258", + "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce", + "sha256:1e7dc99b23e33c71d720c4ae37ebb095bebebbd31a24b7d99dfc4753d2803ede", + 
"sha256:2e617727fe4091cedb3e4409b39368f424934c7faa78171749f704b49b4bb4ce", + "sha256:3cf229704074bce31f7f47d12883afee3b0a02bb233a0ba45ddbfe542939cca4", + "sha256:3eb7c03fe1cd3255811cd4e74db1ab8dca22074d50cd8937edf4ef62d758cdf4", + "sha256:3f7d57a7e140efe69ce2d7b057c3f9a595f98d0bbdfc23fd055efdfbaa46e3a5", + "sha256:419b1276b55925b5ac9b4c7044e999f1787c69761a3c9756dec6e5c225ceca01", + "sha256:44ac5c89b6896f4740e7091f4a0ff2e62881da80c239dd9408f84f75a293dae9", + "sha256:4615623a490e46be85fbaa6335f35cf80e61df0783240afe7d4f544778c315a9", + "sha256:50a69067af86ec7f11a8e50ba85544657b1477aabf64fa447fd3736b5a0a4f67", + "sha256:513fd5b6513d37e985eb5b7ed89da5fd9e72354e3523980ef00d439bc549c9e9", + "sha256:6ff3dc2f60dbf82c9e599c2915db1526d65415be323464f84de8db3e361ba5b9", + "sha256:73c079e21d10ff2be54a4699f55865d4b275fd6c8bd5d90c5b1ef78ae0197301", + "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8", + "sha256:785e2f2c1cb50d0a44e2cdeea5fd36b5bf2d79c481c10f3a88a8be4cfa2c4615", + "sha256:7ca38746eac23dd7c20bec9278d2058c7ad662b2f1576e4c3dbfcd7c00cc48fa", + "sha256:7f0c4ee579acfe6c994637527c386d1c22eb60bc1c1d36d940d8477e482095d4", + "sha256:87bf91ebf15258c4701d71dcdd9c4ba39521fb6a37379ea68088ce8cd869b446", + "sha256:89e274604abb1a7fd5c14867a412c9d49c08ccf6ce3e1e04fffc068b5b6499d4", + "sha256:8c323813963b2503e54d0944813cd479c10c636e3ee223bcbd7bd478bf53c178", + "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af", + "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b", + "sha256:b19ae41ef26c01a987e49e37c77b9ad060c59f94d3b3efdfdbf4f3daaca7b5fe", + "sha256:b4eae01faee9f2b17f08885e3f047153ae0416648f8e8c8bd9bc677c5ce64be9", + "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09", + "sha256:b977bfce15afa53d9cf6a632482d7968477625f030d86a109f7bdfe8ce3c064a", + "sha256:bf8eebccc66829010f06fbd2b80095d7872991bfe8415098b9fe47deaaa58063", + "sha256:c111cd40910ffcb615b33605fc8f8e22146aeb7933d06569ac90f219818345ef", + "sha256:c2d494b6a2a2d05fb99f01b84cc9af9f5f93bf3e1e5dbdafe4bed0c2823584c1", + "sha256:c9cba4e7369de663611ce7460a34be48e999e0bbb1feb9130070f0685e9a6b66", + "sha256:cca720d05389ab1a5877ff05af96551e58ba65e8dc65582d849ac83ddde3e231", + "sha256:ccb99c3138c9bde118b51a289d90096a3791658da9aea1754667302ed6564f6e", + "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec", + "sha256:e36339a68126ffb708dc6d1948161cea2a9e85d7d7b0c54f6999853d70d44430", + "sha256:ea7da25ee458d8f404b93eb073116156fd7d8c2a776d8311534851f28277b4ce", + "sha256:f9fefd6298433b6e9188252f3bff53b9ff0443c8fde27298b8a2b19f6617eeb9", + "sha256:fb87f763b5d04a82ae84ccff25554ffd903baafba6698e18ebaf32561f2fe4aa", + "sha256:fc6b15465fabccc94bf7e38777d665b6a4f95efd1725049d6184b3a39fd54880" ], "index": "pypi", "markers": "python_version >= '3.7'", - "version": "==2.0.20" + "version": "==2.0.21" }, "stampy-chat": { "editable": true, @@ -886,39 +919,39 @@ }, "tiktoken": { "hashes": [ - "sha256:00d662de1e7986d129139faf15e6a6ee7665ee103440769b8dedf3e7ba6ac37f", - "sha256:08efa59468dbe23ed038c28893e2a7158d8c211c3dd07f2bbc9a30e012512f1d", - "sha256:176cad7f053d2cc82ce7e2a7c883ccc6971840a4b5276740d0b732a2b2011f8a", - "sha256:1b6bce7c68aa765f666474c7c11a7aebda3816b58ecafb209afa59c799b0dd2d", - "sha256:1e8fa13cf9889d2c928b9e258e9dbbbf88ab02016e4236aae76e3b4f82dd8288", - "sha256:2ca30367ad750ee7d42fe80079d3092bd35bb266be7882b79c3bd159b39a17b0", - "sha256:329f548a821a2f339adc9fbcfd9fc12602e4b3f8598df5593cfc09839e9ae5e4", - 
"sha256:3dc3df19ddec79435bb2a94ee46f4b9560d0299c23520803d851008445671197", - "sha256:450d504892b3ac80207700266ee87c932df8efea54e05cefe8613edc963c1285", - "sha256:4d980fa066e962ef0f4dad0222e63a484c0c993c7a47c7dafda844ca5aded1f3", - "sha256:55e251b1da3c293432179cf7c452cfa35562da286786be5a8b1ee3405c2b0dd2", - "sha256:5727d852ead18b7927b8adf558a6f913a15c7766725b23dbe21d22e243041b28", - "sha256:59b20a819969735b48161ced9b92f05dc4519c17be4015cfb73b65270a243620", - "sha256:5a73286c35899ca51d8d764bc0b4d60838627ce193acb60cc88aea60bddec4fd", - "sha256:64e1091c7103100d5e2c6ea706f0ec9cd6dc313e6fe7775ef777f40d8c20811e", - "sha256:8d1d97f83697ff44466c6bef5d35b6bcdb51e0125829a9c0ed1e6e39fb9a08fb", - "sha256:9c15d9955cc18d0d7ffcc9c03dc51167aedae98542238b54a2e659bd25fe77ed", - "sha256:9c6dd439e878172dc163fced3bc7b19b9ab549c271b257599f55afc3a6a5edef", - "sha256:9ec161e40ed44e4210d3b31e2ff426b4a55e8254f1023e5d2595cb60044f8ea6", - "sha256:b1a038cee487931a5caaef0a2e8520e645508cde21717eacc9af3fbda097d8bb", - "sha256:ba16698c42aad8190e746cd82f6a06769ac7edd415d62ba027ea1d99d958ed93", - "sha256:bb2341836b725c60d0ab3c84970b9b5f68d4b733a7bcb80fb25967e5addb9920", - "sha256:c06cd92b09eb0404cedce3702fa866bf0d00e399439dad3f10288ddc31045422", - "sha256:c835d0ee1f84a5aa04921717754eadbc0f0a56cf613f78dfc1cf9ad35f6c3fea", - "sha256:d0394967d2236a60fd0aacef26646b53636423cc9c70c32f7c5124ebe86f3093", - "sha256:dae2af6f03ecba5f679449fa66ed96585b2fa6accb7fd57d9649e9e398a94f44", - "sha256:e063b988b8ba8b66d6cc2026d937557437e79258095f52eaecfafb18a0a10c03", - "sha256:e87751b54eb7bca580126353a9cf17a8a8eaadd44edaac0e01123e1513a33281", - "sha256:f3020350685e009053829c1168703c346fb32c70c57d828ca3742558e94827a9" + "sha256:1f2b3b253e22322b7f53a111e1f6d7ecfa199b4f08f3efdeb0480f4033b5cdc6", + "sha256:1fe99953b63aabc0c9536fbc91c3c9000d78e4755edc28cc2e10825372046a2d", + "sha256:27e773564232004f4f810fd1f85236673ec3a56ed7f1206fc9ed8670ebedb97a", + "sha256:2b0bae3fd56de1c0a5874fb6577667a3c75bf231a6cef599338820210c16e40a", + "sha256:2b756a65d98b7cf760617a6b68762a23ab8b6ef79922be5afdb00f5e8a9f4e76", + "sha256:323cec0031358bc09aa965c2c5c1f9f59baf76e5b17e62dcc06d1bb9bc3a3c7c", + "sha256:426e7def5f3f23645dada816be119fa61e587dfb4755de250e136b47a045c365", + "sha256:43ce0199f315776dec3ea7bf86f35df86d24b6fcde1babd3e53c38f17352442f", + "sha256:46b8554b9f351561b1989157c6bb54462056f3d44e43aa4e671367c5d62535fc", + "sha256:5abd9436f02e2c8eda5cce2ff8015ce91f33e782a7423de2a1859f772928f714", + "sha256:5d5a187ff9c786fae6aadd49f47f019ff19e99071dc5b0fe91bfecc94d37c686", + "sha256:709a5220891f2b56caad8327fab86281787704931ed484d9548f65598dea9ce4", + "sha256:714efb2f4a082635d9f5afe0bf7e62989b72b65ac52f004eb7ac939f506c03a4", + "sha256:74c90d2be0b4c1a2b3f7dde95cd976757817d4df080d6af0ee8d461568c2e2ad", + "sha256:779c4dea5edd1d3178734d144d32231e0b814976bec1ec09636d1003ffe4725f", + "sha256:7ef730db4097f5b13df8d960f7fdda2744fe21d203ea2bb80c120bb58661b155", + "sha256:8079ac065572fe0e7c696dbd63e1fdc12ce4cdca9933935d038689d4732451df", + "sha256:92ed3bbf71a175a6a4e5fbfcdb2c422bdd72d9b20407e00f435cf22a68b4ea9b", + "sha256:9b180a22db0bbcc447f691ffc3cf7a580e9e0587d87379e35e58b826ebf5bc7b", + "sha256:a10488d1d1a5f9c9d2b2052fdb4cf807bba545818cb1ef724a7f5d44d9f7c3d4", + "sha256:a84657c083d458593c0235926b5c993eec0b586a2508d6a2020556e5347c2f0d", + "sha256:b5dcfcf9bfb798e86fbce76d40a1d5d9e3f92131aecfa3d1e5c9ea1a20f1ef1a", + "sha256:ba9873c253ca1f670e662192a0afcb72b41e0ba3e730f16c665099e12f4dac2d", + "sha256:c008375c0f3d97c36e81725308699116cd5804fdac0f9b7afc732056329d2790", + 
"sha256:dcdc630461927718b317e6f8be7707bd0fc768cee1fdc78ddaa1e93f4dc6b2b1", + "sha256:e21840043dbe2e280e99ad41951c00eff8ee3b63daf57cd4c1508a3fd8583ea2", + "sha256:e4c73d47bdc1a3f1f66ffa019af0386c48effdc6e8797e5e76875f6388ff72e9", + "sha256:e529578d017045e2f0ed12d2e00e7e99f780f477234da4aae799ec4afca89f37", + "sha256:edd2ffbb789712d83fee19ab009949f998a35c51ad9f9beb39109357416344ff" ], "index": "pypi", "markers": "python_version >= '3.8'", - "version": "==0.4.0" + "version": "==0.5.1" }, "tqdm": { "hashes": [ @@ -930,19 +963,19 @@ }, "typing-extensions": { "hashes": [ - "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36", - "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2" + "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0", + "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef" ], - "markers": "python_version >= '3.7'", - "version": "==4.7.1" + "markers": "python_version >= '3.8'", + "version": "==4.8.0" }, "urllib3": { "hashes": [ - "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11", - "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4" + "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594", + "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e" ], "markers": "python_version >= '3.7'", - "version": "==2.0.4" + "version": "==2.0.5" }, "werkzeug": { "hashes": [ diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index 59a4845..07b7501 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -29,41 +29,60 @@ ENCODER = tiktoken.get_encoding("cl100k_base") +SOURCE_PROMPT = ( + "You are a helpful assistant knowledgeable about AI Alignment and Safety. " + "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") " + "using the following sources. Each source is labeled with a letter. Feel free to " + "use the sources in any order, and try to use multiple sources in your answers.\n\n" +) +SOURCE_PROMPT_SUFFIX = ( + "\n\n" + "Before the question (\"Q: \"), there will be a history of previous questions and answers. " + "These sources only apply to the last question. Any sources used in previous answers " + "are invalid." +) + +QUESTION_PROMPT = ( + "In your answer, please cite any claims you make back to each source " + "using the format: [a], [b], etc. If you use multiple sources to make a claim " + "cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n" +) +PROMPT_MODES = { + 'default': "", + "concise": ( + "Answer very concisely, getting to the crux of the matter in as " + "few words as possible. Limit your answer to 1-2 sentences.\n\n" + ), + "rookie": ( + "This user is new to the field of AI Alignment and Safety - don't " + "assume they know any technical terms or jargon. Still give a complete answer " + "without patronizing the user, but take any extra time needed to " + "explain new concepts or to illustrate your answer with examples. " + "Put extra effort into explaining the intuition behind concepts " + "rather than just giving a formal definition.\n\n" + ), +} + # --------------------------------- prompt code -------------------------------- # limit a string to a certain number of tokens def cap(text: str, max_tokens: int) -> str: - - if max_tokens <= 0: return "..." + if max_tokens <= 0: + return "..." 
encoded_text = ENCODER.encode(text) - if len(encoded_text) <= max_tokens: return text - else: return ENCODER.decode(encoded_text[:max_tokens]) + " ..." + if len(encoded_text) <= max_tokens: + return text + return ENCODER.decode(encoded_text[:max_tokens]) + " ..." Prompt = List[Dict[str, str]] -def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt: - - prompt = [] - - # History takes the format: history=[ - # {"role": "user", "content": "Die monster. You don’t belong in this world!"}, - # {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."}, - # {"role": "user", "content": "Tribute!?! You steal men's souls and make them your slaves!"}, - # {"role": "assistant", "content": "Perhaps the same could be said of all religions..."}, - # {"role": "user", "content": "Your words are as empty as your soul! Mankind ill needs a savior such as you!"}, - # {"role": "assistant", "content": "What is a man? A miserable little pile of secrets. But enough talk... Have at you!"}, - # ] - - source_prompt = "You are a helpful assistant knowledgeable about AI Alignment and Safety. " \ - "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") " \ - "using the following sources. Each source is labeled with a letter. Feel free to " \ - "use the sources in any order, and try to use multiple sources in your answers.\n\n" +def prompt_context(source_prompt: str, context: List[Block], max_tokens: int) -> str: token_count = len(ENCODER.encode(source_prompt)) # Context from top-k blocks @@ -71,75 +90,70 @@ def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block block_str = f"[{chr(ord('a') + i)}] {block.title} - {','.join(block.authors)} - {block.date}\n{block.text}\n\n" block_tc = len(ENCODER.encode(block_str)) - if token_count + block_tc > int(NUM_TOKENS * CONTEXT_FRACTION): - source_prompt += cap(block_str, int(NUM_TOKENS * CONTEXT_FRACTION) - token_count) + if token_count + block_tc > max_tokens: + source_prompt += cap(block_str, max_tokens - token_count) break else: source_prompt += block_str token_count += block_tc + return source_prompt.strip() - source_prompt = source_prompt.strip(); - if len(history) > 0: - source_prompt += "\n\n"\ - "Before the question (\"Q: \"), there will be a history of previous questions and answers. " \ - "These sources only apply to the last question. Any sources used in previous answers " \ - "are invalid." - prompt.append({"role": "system", "content": source_prompt.strip()}) - - - # Write a version of the last 10 messages into history, cutting things off when we hit the token limit. +def prompt_history(history: Prompt, max_tokens: int, n_items=10) -> Prompt: token_count = 0 - history_trnc = [] - for message in history[:-10:-1]: + prompt = [] + + # Get the n_items last messages, starting from the last one. This is because it's assumed + # that more recent messages are more important. 
The `-1` is needed because a slice's end bound is exclusive. + messages = history[:-n_items - 1:-1] + for message in messages: if message["role"] == "user": - history_trnc.append({"role": "user", "content": "Q: " + message["content"]}) + prompt.append({"role": "user", "content": "Q: " + message["content"]}) token_count += len(ENCODER.encode("Q: " + message["content"])) else: - content = cap(message["content"], int(NUM_TOKENS * HISTORY_FRACTION) - token_count) - + content = message["content"] # censor all numeric source citations (e.g. [1]) into [x] content = re.sub(r"\[[0-9]+\]", "[x]", content) + content = cap(content, max_tokens - token_count) - history_trnc.append({"role": "assistant", "content": content}) + prompt.append({"role": "assistant", "content": content}) token_count += len(ENCODER.encode(content)) - if token_count > int(NUM_TOKENS * HISTORY_FRACTION): + if token_count > max_tokens: break + return prompt[::-1] - prompt.extend(history_trnc[::-1]) - - - question_prompt = "In your answer, please cite any claims you make back to each source " \ - "using the format: [a], [b], etc. If you use multiple sources to make a claim " \ - "cite all of them. For example: \"AGI is concerning [c, d, e].\"\n\n" - - if mode == "concise": - question_prompt += "Answer very concisely, getting to the crux of the matter in as " \ - "few words as possible. Limit your answer to 1-2 sentences.\n\n" - - elif mode == "rookie": - question_prompt += "This user is new to the field of AI Alignment and Safety - don't " \ - "assume they know any technical terms or jargon. Still give a complete answer " \ - "without patronizing the user, but take any extra time needed to " \ - "explain new concepts or to illustrate your answer with examples. "\ - "Put extra effort into explaining the intuition behind concepts " \ - "rather than just giving a formal definition.\n\n" - elif mode != "default": raise ValueError("Invalid mode: " + mode) +def construct_prompt(query: str, mode: str, history: Prompt, context: List[Block]) -> Prompt: + if mode not in PROMPT_MODES: + raise ValueError("Invalid mode: " + mode) + # History takes the format: history=[ + # {"role": "user", "content": "Die monster. You don’t belong in this world!"}, + # {"role": "assistant", "content": "It was not by my hand I am once again given flesh. I was called here by humans who wished to pay me tribute."}, + # {"role": "user", "content": "Tribute!?! You steal men's souls and make them your slaves!"}, + # {"role": "assistant", "content": "Perhaps the same could be said of all religions..."}, + # {"role": "user", "content": "Your words are as empty as your soul! Mankind ill needs a savior such as you!"}, + # {"role": "assistant", "content": "What is a man? A miserable little pile of secrets. But enough talk... Have at you!"}, + # ] - question_prompt += "Q: " + query + # Context from top-k blocks + source_prompt = prompt_context(SOURCE_PROMPT, context, int(NUM_TOKENS * CONTEXT_FRACTION)) + if history: + source_prompt += SOURCE_PROMPT_SUFFIX + source_prompt = [{"role": "system", "content": source_prompt.strip()}] - prompt.append({"role": "user", "content": question_prompt}) + # Write a version of the last 10 messages into history, cutting things off when we hit the token limit.
+ history_prompt = prompt_history(history, int(NUM_TOKENS * HISTORY_FRACTION)) + question_prompt = [{"role": "user", "content": QUESTION_PROMPT + PROMPT_MODES[mode] + "Q: " + query}] - return prompt + return source_prompt + history_prompt + question_prompt # ------------------------------- completion code ------------------------------- def check_openai_moderation(prompt: Prompt, query: str): prompt_string = '\n\n'.join([message["content"] for message in prompt]) - mod_res = openai.Moderation.create( input = [ query, prompt_string ]) + mod_res = openai.Moderation.create(input=[query, prompt_string]) if any(map(lambda x: x["flagged"], mod_res["results"])): logger.moderation_issue(query, prompt_string, mod_res) @@ -153,7 +167,7 @@ def remaining_tokens(prompt: Prompt): len(ENCODER.encode(message["content"]) + ENCODER.encode(message["role"])) for message in prompt ]) - return NUM_TOKENS - used_tokens - TOKENS_BUFFER + return max(0, NUM_TOKENS - used_tokens - TOKENS_BUFFER) def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int = STANDARD_K): @@ -188,11 +202,10 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int temperature=0, # may or may not be a good idea ): res = chunk["choices"][0]["delta"] - if res is not None and res.get("content") is not None: + if res and res.get("content"): response += res["content"] yield {"state": "streaming", "content": res["content"]} - t2 = time.time() logger.debug(f'Time to get response: {time.time() - t1:.2f}s') if logger.is_debug(): @@ -218,10 +231,12 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int logger.error(e) yield {'state': 'error', 'error': str(e)} + # convert talk_to_robot_internal from dict generator into json generator def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K): yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k)) + # wayyy simplified api def talk_to_robot_simple(index, query: str): res = {'response': ''} diff --git a/api/tests/stampy_chat/test_chat.py b/api/tests/stampy_chat/test_chat.py new file mode 100644 index 0000000..38f0a32 --- /dev/null +++ b/api/tests/stampy_chat/test_chat.py @@ -0,0 +1,309 @@ +import pytest +from unittest.mock import patch, MagicMock + +from stampy_chat.followups import Followup +from stampy_chat.get_blocks import Block +from stampy_chat.chat import ( + cap, construct_prompt, check_openai_moderation, remaining_tokens, talk_to_robot_internal, + talk_to_robot, talk_to_robot_simple, prompt_context, prompt_history, ENCODER, logger +) + + +@pytest.fixture +def history(): + return [ + {"role": "user", "content": "Die monster. You don’t belong in this world!"}, + {"role": "assistant", "content": "It was not by my hand[1] I am once again given flesh. I was called here by humans who wished to pay me tribute."}, + {"role": "user", "content": "Tribute!?! You steal men's souls and make them your slaves!"}, + {"role": "assistant", "content": "Perhaps the same could be said[321] of all religions..."}, + {"role": "user", "content": "Your words are as empty as your soul! Mankind ill needs a savior such as you!"}, + {"role": "assistant", "content": "What is a man? A[4234] miserable little pile of secrets. But enough talk... 
Have at you!"}, + ] + + +@pytest.fixture +def context(): + return [ + Block( + id=i, + url=f"http://bla.bla/{i}", + tags=[], + title=f"Block{i}", + authors=[f"Author{i}"], + date=f"2021-01-0{i + 1}", + text=f"Block text {i}" + ) for i in range(5) + ] + + +@pytest.mark.parametrize( + "text, max_tokens, expected", + [ + ("", 10, ""), # case when input text is empty + ("Hello, world!", -1, "..."), # case when max_tokens are negative + ("Hello, world!", 0, "..."), # case when max_tokens is zero + ("Hello, world!", 10, "Hello, world!"), # case when input text is less than max_tokens + ("Hello, world! This is a long text string that exceeds the token limit.", 5, + ENCODER.decode(ENCODER.encode("Hello, world! This is a long text string that exceeds the token limit.")[:5]) + " ..."), # case when input text is more than max_tokens + ], +) +def test_cap(text, max_tokens, expected): + assert cap(text, max_tokens) == expected + + +EXPECTED_CONTEXT = """bla bla: [a] Block0 - Author0 - 2021-01-01 +Block text 0 + +[b] Block1 - Author1 - 2021-01-02 +Block text 1 + +[c] Block2 - Author2 - 2021-01-03 +Block text 2 + +[d] Block3 - Author3 - 2021-01-04 +Block text 3 + +[e] Block4 - Author4 - 2021-01-05 +Block text 4""" + + +def test_prompt_context(context): + assert prompt_context("bla bla: ", context, 1000) == EXPECTED_CONTEXT + + +def test_prompt_context_cutoff(context): + formatted = prompt_context("bla bla: ", context, 50) + + assert len(ENCODER.encode(formatted)) == 50 + 1 # the "..." is a single token + assert prompt_context("bla bla: ", context, 50) == EXPECTED_CONTEXT[:116] + '...' + + +def test_prompt_history(history): + assert prompt_history(history, 1000) == [ + {'content': 'Q: Die monster. You don’t belong in this world!', 'role': 'user'}, + {'content': 'It was not by my hand[x] I am once again given flesh. I was called here by humans who wished to pay me tribute.', 'role': 'assistant'}, + {'content': "Q: Tribute!?! You steal men's souls and make them your slaves!", 'role': 'user'}, + {'content': 'Perhaps the same could be said[x] of all religions...', 'role': 'assistant'}, + {'content': 'Q: Your words are as empty as your soul! Mankind ill needs a savior such as you!', 'role': 'user'}, + {'content': 'What is a man? A[x] miserable little pile of secrets. But enough talk... Have at you!', 'role': 'assistant'}, + ] + + +def test_prompt_history_cutoffs(history): + assert prompt_history(history, 50) == [ + {'content': 'Perhaps the same could be said ...', 'role': 'assistant'}, + {'content': 'Q: Your words are as empty as your soul! Mankind ill needs a savior such as you!', 'role': 'user'}, + {'content': 'What is a man? A[x] miserable little pile of secrets. But enough talk... Have at you!', 'role': 'assistant'}, + ] + + +def test_prompt_history_limit_items(): + history = [{'content': f'content {i}', 'role': 'assistant'} for i in range(30)] + + assert len(prompt_history(history, 1000)) == 10 + assert prompt_history(history, 1000) == history[-10:] + + +def test_construct_prompt(history, context): + assert construct_prompt("to be or not to be?", "default", history, context) == [ + { + 'content': ( + 'You are a helpful assistant knowledgeable about AI Alignment and ' + "Safety. Please give a clear and coherent answer to the user's " + 'questions.(written after "Q:") using the following sources. Each ' + 'source is labeled with a letter. 
Feel free to use the sources in ' + 'any order, and try to use multiple sources in your answers.\n' + '\n' + '[a] Block0 - Author0 - 2021-01-01\n' + 'Block text 0\n' + '\n' + '[b] Block1 - Author1 - 2021-01-02\n' + 'Block text 1\n' + '\n' + '[c] Block2 - Author2 - 2021-01-03\n' + 'Block text 2\n' + '\n' + '[d] Block3 - Author3 - 2021-01-04\n' + 'Block text 3\n' + '\n' + '[e] Block4 - Author4 - 2021-01-05\n' + 'Block text 4\n' + '\n' + 'Before the question ("Q: "), there will be a history of previous ' + 'questions and answers. These sources only apply to the last ' + 'question. Any sources used in previous answers are invalid.' + ), + 'role': 'system' + }, { + 'content': 'Q: Die monster. You don’t belong in this world!', 'role': 'user' + }, { + 'content': 'It was not by my hand[x] I am once again given flesh. I was called' + ' here by humans who wished to pay me tribute.', + 'role': 'assistant' + }, + {'content': "Q: Tribute!?! You steal men's souls and make them your slaves!", 'role': 'user'}, + {'content': 'Perhaps the same could be said[x] of all religions...', 'role': 'assistant'}, + {'content': 'Q: Your words are as empty as your soul! Mankind ill needs a savior such as you!', 'role': 'user'}, + { + 'content': 'What is a man? A[x] miserable little pile of secrets. But enough ' + 'talk... Have at you!', + 'role': 'assistant' + }, + { + 'content': ( + 'In your answer, please cite any claims you make back to each ' + 'source using the format: [a], [b], etc. If you use multiple ' + 'sources to make a claim cite all of them. For example: "AGI is ' + 'concerning [c, d, e]."\n' + '\n' + 'Q: to be or not to be?' + ), + 'role': 'user' + }, + ] + + +def test_construct_prompt_no_history_or_context(): + assert construct_prompt("to be or not to be?", "default", [], []) == [ + { + 'content': ( + 'You are a helpful assistant knowledgeable about AI Alignment and ' + "Safety. Please give a clear and coherent answer to the user's " + 'questions.(written after "Q:") using the following sources. Each ' + 'source is labeled with a letter. Feel free to use the sources in ' + 'any order, and try to use multiple sources in your answers.' + ), + 'role': 'system' + }, + { + 'content': ( + 'In your answer, please cite any claims you make back to each ' + 'source using the format: [a], [b], etc. If you use multiple ' + 'sources to make a claim cite all of them. For example: "AGI is ' + 'concerning [c, d, e]."\n' + '\n' + 'Q: to be or not to be?' 
+ ), + 'role': 'user' + }, + ] + + + +def test_check_openai_moderation_flagged(): + prompt = [{"content": "message 1"}, {"content": "message 2"}] + query = "test query" + + # Create a mock for openai.Moderation.create return value + results = { + 'results': [ + {'flagged': False, 'text': 'bla bla 1'}, + {'flagged': True, 'text': 'bla bla 2'}, + {'flagged': False, 'text': 'bla bla 3'}, + ] + } + + # Patch openai.Moderation.create and logger.moderation_issue + with patch('openai.Moderation.create', return_value=results), patch.object(logger, 'moderation_issue'): + with pytest.raises(ValueError): + check_openai_moderation(prompt, query) + + +def test_check_openai_moderation_not_flagged(): + prompt = [{"content": "message 1"}, {"content": "message 2"}] + query = "test query" + + results = { + 'results': [ + {'flagged': False, 'text': 'bla bla 1'}, + {'flagged': False, 'text': 'bla bla 2'}, + {'flagged': False, 'text': 'bla bla 3'}, + ] + } + + # Patch openai.Moderation.create and logger.moderation_issue + with patch('openai.Moderation.create', return_value=results), patch.object(logger, 'moderation_issue'): + assert check_openai_moderation(prompt, query) is None + + +@pytest.mark.parametrize('prompt, remaining', ( + ([{'role': 'system', 'content': 'bla'}], 4043), + ( + [ + {'role': 'system', 'content': 'bla'}, + {'role': 'user', 'content': 'message 1'}, + {'role': 'assistant', 'content': 'response 1'}, + ], + 4035 + ), + ( + [ + {'role': 'system', 'content': 'bla'}, + {'role': 'user', 'content': 'message 1'}, + {'role': 'assistant', 'content': 'response 1'}, + ] * 1999, + 0 + ), +)) +def test_remaining_tokens(prompt, remaining): + assert remaining_tokens(prompt) == remaining + + +@patch('stampy_chat.chat.check_openai_moderation') +@patch('stampy_chat.chat.logger') +def test_talk_to_robot_internal(history, context): + chunks = [ + {'choices': [{'delta': {'content': f"response 1"}}]}, + {'choices': [{'delta': {'content': f"response 2"}}]}, + {'choices': [{'delta': {'content': f"response 3"}}]}, + {'choices': [{'delta': {}}]}, + {'choices': [{'delta': {'content': None}}]}, + {'choices': [{'delta': {'content': f"response 4"}}]}, + ] + followups = [ + Followup('followup 1', '1', 0.231), + Followup('followup 2', '2', 0.231), + Followup('followup 3', '3', 0.231), + ] + with patch('stampy_chat.chat.get_top_k_blocks', return_value=context): + with patch('stampy_chat.chat.multisearch_authored', return_value=followups): + with patch('openai.ChatCompletion.create', return_value=chunks): + assert list(talk_to_robot_internal("index", "what is this about?", "default", history)) == [ + {'phase': 'semantic', 'state': 'loading'}, + {'citations': [], 'phase': 'semantic', 'state': 'loading'}, + {'phase': 'prompt', 'state': 'loading'}, + {'phase': 'llm', 'state': 'loading'}, + {'content': 'response 1', 'state': 'streaming'}, + {'content': 'response 2', 'state': 'streaming'}, + {'content': 'response 3', 'state': 'streaming'}, + {'content': 'response 4', 'state': 'streaming'}, + { + 'followup_0': {'pageid': '1', 'score': 0.231, 'text': 'followup 1'}, + 'followup_1': {'pageid': '2', 'score': 0.231, 'text': 'followup 2'}, + 'followup_2': {'pageid': '3', 'score': 0.231, 'text': 'followup 3'}, + 'state': 'done' + }, + ] + + +@patch('stampy_chat.chat.check_openai_moderation') +@patch('stampy_chat.chat.logger') +def test_talk_to_robot_internal_error(history, context): + chunks = [ + {'choices': [{'delta': {'content': f"response 1"}}]}, + {'choices': [{'delta': {'content': f"response 2"}}]}, + {'choices': [{'delta': 
{'content': f"response 3"}}]}, + {'choices': []}, + ] + with patch('stampy_chat.chat.get_top_k_blocks', return_value=context): + with patch('openai.ChatCompletion.create', return_value=chunks): + assert list(talk_to_robot_internal("index", "what is this about?", "default", history)) == [ + {'phase': 'semantic', 'state': 'loading'}, + {'citations': [], 'phase': 'semantic', 'state': 'loading'}, + {'phase': 'prompt', 'state': 'loading'}, + {'phase': 'llm', 'state': 'loading'}, + {'content': 'response 1', 'state': 'streaming'}, + {'content': 'response 2', 'state': 'streaming'}, + {'content': 'response 3', 'state': 'streaming'}, + {'error': 'list index out of range', 'state': 'error'}, + ]