diff --git a/.circleci/config.yml b/.circleci/config.yml index c4d216cac..7e6cd0a25 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,7 +12,7 @@ jobs: build: docker: - image: zokrates/env:latest - resource_class: large + resource_class: xlarge steps: - checkout - run: @@ -68,7 +68,7 @@ jobs: docker: - image: zokrates/env:latest - image: trufflesuite/ganache-cli:next - resource_class: large + resource_class: xlarge steps: - checkout - run: diff --git a/.gitignore b/.gitignore index a11fa3038..3e0371edf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ verifier.sol proof.json universal_setup.dat witness +witness.json # ZoKrates source files at the root of the repository /*.zok diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ea84c10a..0c24b138c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.6] - 2023-04-13 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.6 + +### Changes +- Make ZoKrates build on stable rust (#1288, @schaeff) +- Introduce sourcemaps, introduce `inspect` command to identify costly parts of the source (#1285, @schaeff) +- Change witness format to binary, optimize backend integration code to improve proving time (#1289, @dark64) +- Fixed precedence issue on Sudoku example. (#1287, @Turupawn) +- Reduce compiled program size by deduplicating assembly solvers (#1268, @dark64) + ## [0.8.5] - 2023-03-28 ### Release diff --git a/Cargo.lock b/Cargo.lock index e780f0273..e74fd1a9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.17.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" +checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" dependencies = [ "gimli", ] @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "0.7.18" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" dependencies = [ "memchr", ] @@ -171,8 +171,8 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db02d390bf6643fb404d3d22d31aee1c4bc4459600aef9113833d17e786c6e44" dependencies = [ - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -183,8 +183,8 @@ checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" dependencies = [ "num-bigint 0.4.3", "num-traits 0.2.15", - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -334,9 +334,9 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8dd4e5f0bf8285d5ed538d27fab7411f3e297908fd93c62195de8bee3f199e82" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -409,9 +409,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.65" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" +checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" dependencies = [ "addr2line", "cc", @@ -464,9 +464,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitvec" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1489fcb93a5bb47da0462ca93ad252ad6af2145cce58d10d46a83931ba9f016b" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" dependencies = [ "funty", "radium", @@ -537,16 +537,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" dependencies = [ "block-padding 0.2.1", - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] name = "block-buffer" -version = "0.10.2" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] @@ -564,29 +564,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "bumpalo" -version = "3.10.0" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "byte-slice-cast" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87c5fdd0166095e1d463fc6cc01aa8ce547ad77a4e84d42eb6762b084e28067e" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" [[package]] name = "byte-tools" @@ -608,15 +596,15 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "camino" -version = "1.0.9" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "869119e97797867fd90f5e22af7d0bd274bd4635ebb9eb68c04f3f513ae6c412" +checksum = "c530edf18f37068ac2d977409ed5cd50d53d73bc653c7647b48eb78976ac9ae2" dependencies = [ "serde", ] @@ -638,16 +626,16 @@ checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" dependencies = [ "camino", "cargo-platform", - "semver 1.0.16", + "semver 1.0.17", "serde", "serde_json", ] [[package]] name = "cc" -version = "1.0.73" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cfg-if" @@ -702,9 +690,9 @@ checksum = "c6dd675567eb3e35787bd2583d129e85fabc7503b0a093d08c51198a307e2091" dependencies = [ "heck", "proc-macro-error", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -736,9 +724,9 @@ checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "cpufeatures" -version = "0.2.2" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" dependencies = [ "libc", ] @@ -769,12 +757,12 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.6" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" dependencies = [ "cfg-if 1.0.0", - "crossbeam-utils 0.8.14", + "crossbeam-utils 0.8.15", ] [[package]] @@ -790,13 +778,13 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" dependencies = [ "cfg-if 1.0.0", - "crossbeam-epoch 0.9.13", - "crossbeam-utils 0.8.14", + "crossbeam-epoch 0.9.14", + "crossbeam-utils 0.8.15", ] [[package]] @@ -816,14 +804,14 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.13" +version = "0.9.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" dependencies = [ "autocfg", "cfg-if 1.0.0", - "crossbeam-utils 0.8.14", - "memoffset 0.7.1", + "crossbeam-utils 0.8.15", + "memoffset 0.8.0", "scopeguard", ] @@ -851,9 +839,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" dependencies = [ "cfg-if 1.0.0", ] @@ -866,11 +854,11 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "crypto-common" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5999502d32b9c48d492abe66392408144895020ec4709e549e840799f3bb74c0" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", "typenum", ] @@ -890,19 +878,18 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" dependencies = [ - "generic-array 0.14.6", - "subtle 2.4.1", + "generic-array 0.14.7", + "subtle 2.5.0", ] [[package]] name = "csv" -version = "1.1.6" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +checksum = "0b015497079b9a9d69c02ad25de6c0a6edef051ea6360a327d0bd05802ef64ad" dependencies = [ - "bstr", "csv-core", - "itoa 0.4.8", + "itoa", "ryu", "serde", ] @@ -918,12 +905,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.1.22" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f877be4f7c9f246b183111634f75baa039715e3f46ce860677d3b19a69fb229c" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" dependencies = [ - "quote 1.0.20", - "syn 1.0.107", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -932,9 +919,9 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -970,7 +957,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" dependencies = [ - "generic-array 0.14.6", + "generic-array 0.14.7", ] [[package]] @@ -979,7 +966,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.2", + "block-buffer 0.10.4", "crypto-common", ] @@ -1005,20 +992,20 @@ dependencies = [ [[package]] name = "either" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f107b87b6afc2a64fd13cac55fe06d6c8859f12d4b14cbcdd2c67d0976781be" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "env_logger" -version = "0.9.0" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b2cf0344971ee6c64c31be0d530793fba457d322dfec2810c453d0ef228f9c3" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" dependencies = [ "atty", "humantime", "log", - "regex 1.7.1", + "regex 1.7.3", "termcolor", ] @@ -1028,6 +1015,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f4b14e20978669064c33b4c1e0fb4083412e40fe56cbea2eae80fd7591503ee" +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "error-chain" version = "0.11.0" @@ -1048,17 +1056,17 @@ dependencies = [ [[package]] name = "ethabi" -version = "17.1.0" +version = "17.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f186de076b3e77b8e6d73c99d1b52edc2a229e604f4b5eb6992c06c11d79d537" +checksum = "e4966fba78396ff92db3b817ee71143eccd98acf0f876b8d600e585a670c5d1b" dependencies = [ "ethereum-types", "hex 0.4.3", "once_cell", - "regex 1.7.1", + "regex 1.7.3", "serde", "serde_json", - "sha3 0.10.1", + "sha3 0.10.7", "thiserror", "uint", ] @@ -1112,9 +1120,9 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "synstructure", ] @@ -1126,9 +1134,9 @@ checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" [[package]] name = "fastrand" -version = "1.7.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] @@ -1154,9 +1162,9 @@ dependencies = [ "num-bigint 0.2.6", "num-integer", "num-traits 0.2.15", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1183,9 +1191,9 @@ dependencies = [ [[package]] name = "fs_extra" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2022715d62ab30faffd124d40b76f4134a550a87792276512b18d63272333394" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fuchsia-cprng" @@ -1201,9 +1209,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", @@ -1216,9 +1224,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" dependencies = [ "futures-core", "futures-sink", @@ -1226,15 +1234,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" [[package]] name = "futures-executor" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" dependencies = [ "futures-core", "futures-task", @@ -1244,27 +1252,27 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" [[package]] name = "futures-sink" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" [[package]] name = "futures-task" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" [[package]] name = "futures-util" -version = "0.3.21" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-channel", "futures-core", @@ -1288,9 +1296,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.6" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -1298,9 +1306,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if 1.0.0", "js-sys", @@ -1311,9 +1319,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.26.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" [[package]] name = "glob" @@ -1323,9 +1331,9 @@ checksum = "8be18de09a56b60ed0edf84bc9df007e30040691af7acd1c41874faac5895bfb" [[package]] name = "glob" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "half" @@ -1375,6 +1383,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "hex" version = "0.3.2" @@ -1445,9 +1459,9 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1469,6 +1483,17 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "io-lifetimes" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "itertools" version = "0.7.11" @@ -1489,15 +1514,9 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "itoa" -version = "1.0.2" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "js-sys" @@ -1516,9 +1535,12 @@ checksum = "078e285eafdfb6c4b434e0d31e8cfcb5115b651496faca5749b88fafd4f23bfd" [[package]] name = "keccak" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9b7d56ba4a8344d6be9729995e6b06f928af29998cdf79fe390cbf6b1fee838" +checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768" +dependencies = [ + "cpufeatures", +] [[package]] name = "lazy_static" @@ -1528,9 +1550,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.126" +version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" + +[[package]] +name = "linux-raw-sys" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "log" @@ -1570,18 +1598,18 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] [[package]] name = "miniz_oxide" -version = "0.5.3" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f5c75688da582b8ffc1f1799e9db273f32133c49e048f614d22ec3256773ccc" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] @@ -1677,18 +1705,18 @@ dependencies = [ [[package]] name = "object" -version = "0.28.4" +version = "0.30.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e42c982f2d955fac81dd7e1d0e1426a7d702acd9c98d19ab01083a6a0328c424" +checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "opaque-debug" @@ -1724,9 +1752,9 @@ dependencies = [ [[package]] name = "parity-scale-codec" -version = "3.1.5" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9182e4a71cae089267ab03e67c99368db7cd877baf50f931e5d6d4b71e195ac0" +checksum = "637935964ff85a605d114591d4d2c13c5d1ba2806dae97cea6bf180238a749ac" dependencies = [ "arrayvec 0.7.2", "bitvec", @@ -1738,21 +1766,21 @@ dependencies = [ [[package]] name = "parity-scale-codec-derive" -version = "3.1.3" +version = "3.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9299338969a3d2f491d65f140b00ddec470858402f888af98e8642fb5e8965cd" +checksum = "86b26a931f824dd4eca30b3e43bb4f31cd5f0d3a403c5f5ff27106b805bfde7b" dependencies = [ "proc-macro-crate", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] name = "paste" -version = "1.0.7" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" +checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" [[package]] name = "pest" @@ -1794,9 +1822,9 @@ checksum = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55" dependencies = [ "pest", "pest_meta", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -1840,9 +1868,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "pretty_assertions" @@ -1858,14 +1886,14 @@ dependencies = [ [[package]] name = "pretty_assertions" -version = "1.2.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89f989ac94207d048d92db058e4f6ec7342b0971fc58d1271ca148b799b3563" +checksum = "a25e9bcb20aa780fd0bb16b72403a9064d6b3f22f026946029acb941a50af755" dependencies = [ - "ansi_term 0.12.1", "ctor", "diff", "output_vt100", + "yansi", ] [[package]] @@ -1883,10 +1911,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "1.1.3" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9" dependencies = [ + "once_cell", "thiserror", "toml", ] @@ -1898,9 +1927,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" dependencies = [ "proc-macro-error-attr", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "version_check", ] @@ -1910,18 +1939,18 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "syn-mid", "version_check", ] [[package]] name = "proc-macro-hack" -version = "0.5.19" +version = "0.5.20+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" +checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" @@ -1934,18 +1963,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.50" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] [[package]] name = "pulldown-cmark" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34f197a544b0c9ab3ae46c359a7ec9cbbb5c7bf97054266fecb7ead794a181d6" +checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63" dependencies = [ "bitflags", "memchr", @@ -1963,11 +1992,11 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.20" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bcdf212e9776fbcb2d23ab029360416bb1706b1aea2d1a5ba002727cbcab804" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ - "proc-macro2 1.0.50", + "proc-macro2 1.0.56", ] [[package]] @@ -1997,7 +2026,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", "rand_chacha", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -2007,7 +2036,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -2027,18 +2056,18 @@ checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" [[package]] name = "rand_core" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom", ] [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -2046,13 +2075,13 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ - "crossbeam-channel 0.5.6", - "crossbeam-deque 0.8.2", - "crossbeam-utils 0.8.14", + "crossbeam-channel 0.5.8", + "crossbeam-deque 0.8.3", + "crossbeam-utils 0.8.15", "num_cpus", ] @@ -2067,9 +2096,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.13" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ "bitflags", ] @@ -2081,15 +2119,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", - "redox_syscall", + "redox_syscall 0.2.16", "thiserror", ] [[package]] name = "reduce" -version = "0.1.4" +version = "0.1.5+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16d2dc47b68ac15ea328cd7ebe01d7d512ed29787f7d534ad2a3c341328b35d7" +checksum = "feff7c275fbc4a96ccdc240a5a180487a61a31baffaff6cdd4fb2c8e9e0a2ecd" [[package]] name = "regex" @@ -2106,21 +2144,15 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ - "aho-corasick 0.7.18", + "aho-corasick 0.7.20", "memchr", - "regex-syntax 0.6.27", + "regex-syntax 0.6.29", ] -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" - [[package]] name = "regex-syntax" version = "0.5.6" @@ -2132,9 +2164,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.27" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "remove_dir_all" @@ -2157,9 +2189,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" +checksum = "d4a36c42d1873f9a77c53bde094f9664d9891bc604a45b4798fd2c389ed12e5b" [[package]] name = "rustc-hex" @@ -2176,11 +2208,25 @@ dependencies = [ "semver 0.11.0", ] +[[package]] +name = "rustix" +version = "0.37.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + [[package]] name = "ryu" -version = "1.0.10" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "same-file" @@ -2210,9 +2256,9 @@ dependencies = [ [[package]] name = "scoped-tls" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" [[package]] name = "scopeguard" @@ -2231,9 +2277,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" dependencies = [ "serde", ] @@ -2249,9 +2295,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.138" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1578c6245786b9d168c5447eeacfb96856573ca56c9d68fdcf394be134882a47" +checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" dependencies = [ "serde_derive", ] @@ -2268,23 +2314,23 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.138" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "023e9b1467aef8a10fb88f25611870ada9800ef7e22afce356bb0d2387b6f27c" +checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] name = "serde_json" -version = "1.0.82" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82c2c1fdcd807d1098552c5b9a36e425e42e9fbd7c6a37a8425f390f781f7fa7" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ "indexmap", - "itoa 1.0.2", + "itoa", "ryu", "serde", ] @@ -2338,9 +2384,9 @@ dependencies = [ [[package]] name = "sha3" -version = "0.10.1" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881bf8156c87b6301fc5ca6b27f11eeb2761224c7081e69b409d5a1951a70c86" +checksum = "54c2bb1a323307527314a36bfb73f24febb08ce2b8a554bf4ffd6f51ad15198c" dependencies = [ "digest 0.10.6", "keccak", @@ -2348,9 +2394,9 @@ dependencies = [ [[package]] name = "single" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5add732a1ab689845591a1b50339cf5310b563e08dc5813c65991f30369ea2" +checksum = "9db45bb685b486eec37e0271dcc0dac76eae5e893125f8a4f0511d0a1d29543b" dependencies = [ "failure", ] @@ -2364,7 +2410,7 @@ dependencies = [ "bytecount", "cargo_metadata", "error-chain 0.12.4", - "glob 0.3.0", + "glob 0.3.1", "pulldown-cmark", "tempfile", "walkdir", @@ -2372,9 +2418,12 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] [[package]] name = "static_assertions" @@ -2402,9 +2451,9 @@ checksum = "2d67a5a62ba6e01cb2192ff309324cb4875d0c451d55fe2319433abe7a05a8ee" [[package]] name = "subtle" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" @@ -2419,12 +2468,23 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.107" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", + "proc-macro2 1.0.56", + "quote 1.0.26", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5" +dependencies = [ + "proc-macro2 1.0.56", + "quote 1.0.26", "unicode-ident", ] @@ -2434,9 +2494,9 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baa8e7560a164edb1621a55d18a0c59abf49d360f47aa7b821061dd7eea7fac9" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -2445,9 +2505,9 @@ version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "unicode-xid 0.2.4", ] @@ -2469,23 +2529,22 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.3.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" dependencies = [ "cfg-if 1.0.0", "fastrand", - "libc", - "redox_syscall", - "remove_dir_all", - "winapi", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.45.0", ] [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] @@ -2501,22 +2560,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] @@ -2539,9 +2598,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" dependencies = [ "serde", ] @@ -2564,9 +2623,9 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", ] [[package]] @@ -2596,27 +2655,27 @@ checksum = "a9b2228007eba4120145f785df0f6c92ea538f5a3635a612ecf4e334c8c1446d" [[package]] name = "typenum" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "ucd-trie" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89570599c4fe5585de2b388aab47e99f7fa4e9238a1399f707a02e356058141c" +checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "ucd-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bfcbf611b122f2c10eb1bb6172fbc4c2e25df9970330e4d75ce2b5201c9bfc" +checksum = "abd2fc5d32b590614af8b0a20d837f32eca055edd0bbead59a9cfe80858be003" [[package]] name = "uint" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12f03af7ccf01dd611cc450a0d10dbc9b745770d096473e2faf0ca6e2d66d1e0" +checksum = "76f64bba2c53b04fcab63c01a7d7427eadc821e3bc48c34dc9ba29c501164b52" dependencies = [ "byteorder", "crunchy", @@ -2635,21 +2694,21 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.1" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bd2fe26506023ed7b5e1e315add59d6f584c621d037f9368fea9cfb988f368c" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-segmentation" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8820f5d777f6224dc4be3632222971ac30164d4a258d595640799554ebfd99" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" [[package]] name = "unicode-xid" @@ -2695,12 +2754,11 @@ checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" [[package]] name = "walkdir" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", - "winapi", "winapi-util", ] @@ -2731,9 +2789,9 @@ dependencies = [ "bumpalo", "lazy_static", "log", - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -2755,7 +2813,7 @@ version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" dependencies = [ - "quote 1.0.20", + "quote 1.0.26", "wasm-bindgen-macro-support", ] @@ -2765,9 +2823,9 @@ version = "0.2.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2798,8 +2856,8 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88ad594bf33e73cafcac2ae9062fc119d4f75f9c77e25022f91c9a64bd5b6463" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", + "proc-macro2 1.0.56", + "quote 1.0.26", ] [[package]] @@ -2843,34 +2901,171 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "wyz" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b31594f29d27036c383b53b59ed3476874d518f0efb151b27a4c275141390e" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ "tap", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "zeroize" -version = "1.5.6" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20b578acffd8516a6c3f2a1bdefc1ec37e547bb4e0fb8b6b01a4cafc886b4442" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" dependencies = [ "zeroize_derive", ] [[package]] name = "zeroize_derive" -version = "1.3.2" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ - "proc-macro2 1.0.50", - "quote 1.0.20", - "syn 1.0.107", - "synstructure", + "proc-macro2 1.0.56", + "quote 1.0.26", + "syn 2.0.14", ] [[package]] @@ -2895,7 +3090,7 @@ dependencies = [ [[package]] name = "zokrates_abi" -version = "0.1.7" +version = "0.1.8" dependencies = [ "serde", "serde_derive", @@ -2906,7 +3101,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.1" +version = "0.1.2" dependencies = [ "cfg-if 0.1.10", "csv", @@ -2929,7 +3124,7 @@ dependencies = [ [[package]] name = "zokrates_ark" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "ark-bn254", @@ -2956,12 +3151,13 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.5" +version = "0.1.6" dependencies = [ "ark-bls12-377", + "byteorder", "cfg-if 0.1.10", - "csv", "derivative", + "log", "num-bigint 0.2.6", "pairing_ce", "serde", @@ -2974,7 +3170,7 @@ dependencies = [ [[package]] name = "zokrates_bellman" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bellman_ce", "getrandom", @@ -2991,11 +3187,11 @@ dependencies = [ [[package]] name = "zokrates_circom" -version = "0.1.2" +version = "0.1.3" dependencies = [ "bellman_ce", "byteorder", - "pretty_assertions 1.2.1", + "pretty_assertions 1.3.0", "zkutil", "zokrates_ast", "zokrates_core", @@ -3004,7 +3200,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.5" +version = "0.8.6" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3019,7 +3215,7 @@ dependencies = [ "hex 0.3.2", "lazy_static", "log", - "pretty_assertions 1.2.1", + "pretty_assertions 1.3.0", "primitive-types", "rand 0.4.6", "rand 0.8.5", @@ -3041,12 +3237,13 @@ dependencies = [ "zokrates_field", "zokrates_fs_resolver", "zokrates_interpreter", + "zokrates_profiler", "zokrates_proof_systems", ] [[package]] name = "zokrates_codegen" -version = "0.1.1" +version = "0.1.2" dependencies = [ "zokrates_ast", "zokrates_common", @@ -3064,7 +3261,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.4" +version = "0.7.5" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3089,7 +3286,7 @@ dependencies = [ [[package]] name = "zokrates_core_test" -version = "0.2.10" +version = "0.2.11" dependencies = [ "zokrates_test", "zokrates_test_derive", @@ -3097,7 +3294,7 @@ dependencies = [ [[package]] name = "zokrates_embed" -version = "0.1.9" +version = "0.1.10" dependencies = [ "ark-bls12-377", "ark-bw6-761", @@ -3115,7 +3312,7 @@ dependencies = [ [[package]] name = "zokrates_field" -version = "0.5.3" +version = "0.5.4" dependencies = [ "ark-bls12-377", "ark-bls12-381", @@ -3147,7 +3344,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.3" +version = "0.1.4" dependencies = [ "ark-bls12-377", "num", @@ -3163,7 +3360,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.6" +version = "1.1.7" dependencies = [ "console_error_panic_hook", "getrandom", @@ -3193,7 +3390,7 @@ dependencies = [ [[package]] name = "zokrates_parser" -version = "0.3.3" +version = "0.3.4" dependencies = [ "glob 0.2.11", "pest", @@ -3202,7 +3399,7 @@ dependencies = [ [[package]] name = "zokrates_pest_ast" -version = "0.3.1" +version = "0.3.2" dependencies = [ "from-pest", "glob 0.2.11", @@ -3212,9 +3409,16 @@ dependencies = [ "zokrates_parser", ] +[[package]] +name = "zokrates_profiler" +version = "0.1.0" +dependencies = [ + "zokrates_ast", +] + [[package]] name = "zokrates_proof_systems" -version = "0.1.1" +version = "0.1.2" dependencies = [ "blake2 0.8.1", "byteorder", @@ -3240,7 +3444,7 @@ dependencies = [ [[package]] name = "zokrates_test" -version = "0.2.1" +version = "0.2.2" dependencies = [ "getrandom", "rand 0.8.5", @@ -3262,7 +3466,7 @@ dependencies = [ [[package]] name = "zokrates_test_derive" -version = "0.0.1" +version = "0.0.2" dependencies = [ - "glob 0.3.0", + "glob 0.3.1", ] diff --git a/Cargo.toml b/Cargo.toml index d521322d6..fd10dd820 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,9 @@ members = [ "zokrates_bellman", "zokrates_proof_systems", "zokrates_js", - "zokrates_circom" + "zokrates_circom", + "zokrates_profiler", ] -exclude = [] \ No newline at end of file +[profile.dev] +opt-level = 1 \ No newline at end of file diff --git a/clippy.toml b/clippy.toml deleted file mode 100644 index 3f707a567..000000000 --- a/clippy.toml +++ /dev/null @@ -1 +0,0 @@ -blacklisted-names = [] \ No newline at end of file diff --git a/dev.Dockerfile b/dev.Dockerfile index 2bf8fed0e..bcb73b0d8 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,9 +1,9 @@ -FROM rustlang/rust:nightly +FROM rust:latest RUN useradd -u 1000 -m zokrates -COPY ./scripts/install_foundry_deb.sh /tmp/ -RUN /tmp/install_foundry_deb.sh +COPY ./scripts/install_foundry.sh /tmp/ +RUN /tmp/install_foundry.sh USER zokrates diff --git a/rust-toolchain.toml b/rust-toolchain.toml index cbfe2704f..292fe499e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2022-07-01" +channel = "stable" diff --git a/zokrates_abi/Cargo.toml b/zokrates_abi/Cargo.toml index 953aa53f9..7ee7e15ab 100644 --- a/zokrates_abi/Cargo.toml +++ b/zokrates_abi/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_abi" -version = "0.1.7" +version = "0.1.8" authors = ["Thibaut Schaeffer "] edition = "2018" diff --git a/zokrates_abi/src/lib.rs b/zokrates_abi/src/lib.rs index a85adb219..d12f65d8f 100644 --- a/zokrates_abi/src/lib.rs +++ b/zokrates_abi/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub enum Inputs { Raw(Vec), Abi(Values), diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index f93dc7d8d..911790400 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] diff --git a/zokrates_analysis/src/assembly_transformer.rs b/zokrates_analysis/src/assembly_transformer.rs index 7c8e856e0..446f4c8f6 100644 --- a/zokrates_analysis/src/assembly_transformer.rs +++ b/zokrates_analysis/src/assembly_transformer.rs @@ -3,9 +3,13 @@ use crate::ZirPropagator; use std::fmt; +use std::ops::*; use zokrates_ast::zir::lqc::LinQuadComb; use zokrates_ast::zir::result_folder::ResultFolder; -use zokrates_ast::zir::{FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram}; +use zokrates_ast::zir::AssemblyConstraint; +use zokrates_ast::zir::{ + Expr, FieldElementExpression, Id, Identifier, ZirAssemblyStatement, ZirProgram, +}; use zokrates_field::Field; #[derive(Debug)] @@ -28,143 +32,132 @@ impl AssemblyTransformer { impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer { type Error = Error; - fn fold_assembly_statement( + fn fold_assembly_constraint( &mut self, - s: ZirAssemblyStatement<'ast, T>, + s: AssemblyConstraint<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirAssemblyStatement::Assignment(_, _) => Ok(vec![s]), - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = self.fold_field_expression(lhs)?; - let rhs = self.fold_field_expression(rhs)?; - - let (is_quadratic, lhs, rhs) = match (lhs, rhs) { - ( - lhs @ FieldElementExpression::Identifier(..), - rhs @ FieldElementExpression::Identifier(..), - ) => (true, lhs, rhs), - (FieldElementExpression::Mult(x, y), other) - | (other, FieldElementExpression::Mult(x, y)) - if other.is_linear() => - { - ( - x.is_linear() && y.is_linear(), - other, - FieldElementExpression::Mult(x, y), + let lhs = self.fold_field_expression(s.left)?; + let rhs = self.fold_field_expression(s.right)?; + + let (is_quadratic, lhs, rhs) = match (lhs, rhs) { + ( + lhs @ FieldElementExpression::Identifier(..), + rhs @ FieldElementExpression::Identifier(..), + ) => (true, lhs, rhs), + (FieldElementExpression::Mult(e), other) | (other, FieldElementExpression::Mult(e)) + if other.is_linear() => + { + ( + e.left.is_linear() && e.right.is_linear(), + other, + FieldElementExpression::Mult(e), + ) + } + (lhs, rhs) => (false, lhs, rhs), + }; + + match is_quadratic { + true => Ok(vec![ZirAssemblyStatement::constraint(lhs, rhs, s.metadata)]), + false => { + let sub = FieldElementExpression::sub(lhs, rhs); + let mut lqc = LinQuadComb::try_from(sub.clone()) + .map_err(|_| Error("Non-quadratic constraints are not allowed".to_string()))?; + + let linear = lqc + .linear + .into_iter() + .map(|(c, i)| { + FieldElementExpression::mul( + FieldElementExpression::value(c), + FieldElementExpression::identifier(i), ) - } - (lhs, rhs) => (false, lhs, rhs), - }; - - match is_quadratic { - true => Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]), - false => { - let sub = FieldElementExpression::Sub(box lhs, box rhs); - let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| { - Error("Non-quadratic constraints are not allowed".to_string()) - })?; - - let linear = lqc - .linear - .into_iter() - .map(|(c, i)| { - FieldElementExpression::Mult( - box FieldElementExpression::Number(c), - box FieldElementExpression::identifier(i), - ) - }) - .fold(FieldElementExpression::Number(T::from(0)), |acc, e| { - FieldElementExpression::Add(box acc, box e) - }); - - let lhs = FieldElementExpression::Add( - box FieldElementExpression::Number(lqc.constant), - box linear, - ); - - let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { - let common_factor = lqc - .quadratic - .iter() - .scan(None, |state: &mut Option>, (_, a, b)| { - // short circuit if we do not have any common factors anymore - if *state == Some(vec![]) { - None - } else { - match state { - // only keep factors found in this term - Some(factors) => { - factors.retain(|&x| x == a || x == b); - } - // initialisation step, start with the two factors in the first term - None => { - *state = Some(vec![a, b]); - } - }; - state.clone() + }) + .fold(FieldElementExpression::value(T::from(0)), |acc, e| { + FieldElementExpression::add(acc, e) + }); + + let lhs = FieldElementExpression::add( + FieldElementExpression::value(lqc.constant), + linear, + ); + + let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { + let common_factor = lqc + .quadratic + .iter() + .scan(None, |state: &mut Option>, (_, a, b)| { + // short circuit if we do not have any common factors anymore + if *state == Some(vec![]) { + None + } else { + match state { + // only keep factors found in this term + Some(factors) => { + factors.retain(|&x| x == a || x == b); } - }) - .last() - .and_then(|mut v| v.pop().cloned()); - - match common_factor { - Some(factor) => Ok(FieldElementExpression::Mult( - box lqc - .quadratic - .into_iter() - .map(|(c, i0, i1)| { - let c = T::zero() - c; - let e = match (i0, i1) { - (i0, i1) if factor.eq(&i0) => { - FieldElementExpression::identifier(i1) - } - (i0, i1) if factor.eq(&i1) => { - FieldElementExpression::identifier(i0) - } - _ => unreachable!(), - }; - FieldElementExpression::Mult( - box FieldElementExpression::Number(c), - box e, - ) - }) - .fold( - FieldElementExpression::Number(T::from(0)), - |acc, e| FieldElementExpression::Add(box acc, box e), - ), - box FieldElementExpression::identifier(factor), - )), - None => Err(Error( - "Non-quadratic constraints are not allowed".to_string(), - )), - }? - } else { + // initialisation step, start with the two factors in the first term + None => { + *state = Some(vec![a, b]); + } + }; + state.clone() + } + }) + .last() + .and_then(|mut v| v.pop().cloned()); + + match common_factor { + Some(factor) => Ok(FieldElementExpression::mul( lqc.quadratic - .pop() + .into_iter() .map(|(c, i0, i1)| { - FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(T::zero() - c), - box FieldElementExpression::identifier(i0), - ), - box FieldElementExpression::identifier(i1), - ) + let c = T::zero() - c; + let e = match (i0, i1) { + (i0, i1) if factor.eq(&i0) => { + FieldElementExpression::identifier(i1) + } + (i0, i1) if factor.eq(&i1) => { + FieldElementExpression::identifier(i0) + } + _ => unreachable!(), + }; + FieldElementExpression::mul(FieldElementExpression::value(c), e) }) - .unwrap_or_else(|| FieldElementExpression::Number(T::from(0))) - }; + .fold( + FieldElementExpression::value(T::from(0)), + FieldElementExpression::add, + ), + FieldElementExpression::identifier(factor), + )), + None => Err(Error( + "Non-quadratic constraints are not allowed".to_string(), + )), + }? + } else { + lqc.quadratic + .pop() + .map(|(c, i0, i1)| { + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(T::zero() - c), + FieldElementExpression::identifier(i0), + ), + FieldElementExpression::identifier(i1), + ) + }) + .unwrap_or_else(|| FieldElementExpression::value(T::from(0))) + }; - let mut propagator = ZirPropagator::default(); - let lhs = propagator - .fold_field_expression(lhs) - .map_err(|e| Error(e.to_string()))?; + let mut propagator = ZirPropagator::default(); + let lhs = propagator + .fold_field_expression(lhs) + .map_err(|e| Error(e.to_string()))?; - let rhs = propagator - .fold_field_expression(rhs) - .map_err(|e| Error(e.to_string()))?; + let rhs = propagator + .fold_field_expression(rhs) + .map_err(|e| Error(e.to_string()))?; - Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)]) - } - } + Ok(vec![ZirAssemblyStatement::constraint(lhs, rhs, s.metadata)]) } } } @@ -180,21 +173,21 @@ mod tests { fn quadratic() { // x === a * b; let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( FieldElementExpression::identifier("x".into()), - FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -208,15 +201,15 @@ mod tests { fn non_quadratic() { // x === ((a * b) * c); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ); - let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::Constraint( + let result = AssemblyTransformer.fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -229,31 +222,31 @@ mod tests { fn transform() { // x === 1 - a * b; --> (-1) + x === (((-1) * a) * b); let lhs = FieldElementExpression::identifier("x".into()); - let rhs = FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), ); - let expected = vec![ZirAssemblyStatement::Constraint( - FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("x".into()), + let expected = vec![ZirAssemblyStatement::constraint( + FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("x".into()), ), - FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -267,30 +260,30 @@ mod tests { fn factorize() { // x === (a * b) + (b * c); --> x === ((a + c) * b); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Add( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::add( + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("b".into()), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("c".into()), ), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( FieldElementExpression::identifier("x".into()), - FieldElementExpression::Mult( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("c".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), @@ -307,100 +300,100 @@ mod tests { // --> // ((((x + ((-1)*a)) + ((-1)*b)) + ((-1)*c)) + (2*mid)) === (((((-2)*b) + ((-2)*c)) + (4*mid)) * a); let lhs = FieldElementExpression::::identifier("x".into()); - let rhs = FieldElementExpression::Add( - box FieldElementExpression::Sub( - box FieldElementExpression::Sub( - box FieldElementExpression::Sub( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let rhs = FieldElementExpression::add( + FieldElementExpression::sub( + FieldElementExpression::sub( + FieldElementExpression::sub( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("mid".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::identifier("mid".into()), ), ); - let lhs_expected = FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("a".into()), + let lhs_expected = FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("a".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("b".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-1)), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-1)), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("mid".into()), ), ); - let rhs_expected = FieldElementExpression::Mult( - box FieldElementExpression::Add( - box FieldElementExpression::Add( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-2)), - box FieldElementExpression::identifier("b".into()), + let rhs_expected = FieldElementExpression::mul( + FieldElementExpression::add( + FieldElementExpression::add( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-2)), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(-2)), - box FieldElementExpression::identifier("c".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(-2)), + FieldElementExpression::identifier("c".into()), ), ), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::identifier("mid".into()), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::identifier("mid".into()), ), ), - box FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("a".into()), ); - let expected = vec![ZirAssemblyStatement::Constraint( + let expected = vec![ZirAssemblyStatement::constraint( lhs_expected, rhs_expected, SourceMetadata::default(), )]; let result = AssemblyTransformer - .fold_assembly_statement(ZirAssemblyStatement::Constraint( + .fold_assembly_statement(ZirAssemblyStatement::constraint( lhs, rhs, SourceMetadata::default(), diff --git a/zokrates_analysis/src/boolean_array_comparator.rs b/zokrates_analysis/src/boolean_array_comparator.rs index cb016c034..72c158193 100644 --- a/zokrates_analysis/src/boolean_array_comparator.rs +++ b/zokrates_analysis/src/boolean_array_comparator.rs @@ -1,10 +1,22 @@ -use zokrates_ast::typed::{ - folder::*, ArrayExpressionInner, ArrayValue, BooleanExpression, ConditionalExpression, - ConditionalKind, EqExpression, FieldElementExpression, SelectExpression, Type, TypedExpression, - TypedProgram, UExpressionInner, +use zokrates_ast::{ + common::WithSpan, + typed::{ + folder::*, ArrayExpression, ArrayType, BooleanExpression, Conditional, ConditionalKind, + Expr, FieldElementExpression, Select, Type, TypedExpression, TypedProgram, UExpression, + UExpressionInner, + }, }; + use zokrates_field::Field; +fn sum_rec + Clone>(a: &[T], default: &T) -> T { + match a.len() { + 0 => default.clone(), + 1 => a[0].clone(), + n => sum_rec(&a[..n / 2], default) + sum_rec(&a[n / 2..], default), + } +} + #[derive(Default)] pub struct BooleanArrayComparator; @@ -15,73 +27,57 @@ impl BooleanArrayComparator { } impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::ArrayEq(e) => match e.left.inner_type() { + BooleanExpression::ArrayEq(e) => match *e.left.inner_type() { Type::Boolean => { + let span = e.get_span(); + let len = e.left.size(); let len = match len.as_inner() { - UExpressionInner::Value(v) => *v as usize, + UExpressionInner::Value(v) => v.value as usize, _ => unreachable!("array size should be known"), }; - let chunk_size = T::get_required_bits() as usize - 1; + let chunk_size = T::get_required_bits() - 1; let left_elements: Vec<_> = (0..len) - .map(|i| { - BooleanExpression::Select(SelectExpression::new( - *e.left.clone(), - (i as u32).into(), - )) - }) + .map(|i| BooleanExpression::select(*e.left.clone(), i as u32).span(span)) .collect(); let right_elements: Vec<_> = (0..len) - .map(|i| { - BooleanExpression::Select(SelectExpression::new( - *e.right.clone(), - (i as u32).into(), - )) - }) + .map(|i| BooleanExpression::select(*e.right.clone(), i as u32).span(span)) .collect(); let process = |elements: &[BooleanExpression<'ast, T>]| { elements .chunks(chunk_size) .map(|chunk| { - TypedExpression::from( - chunk + TypedExpression::from(sum_rec( + &chunk .iter() .rev() .enumerate() .rev() .map(|(index, c)| { - FieldElementExpression::Conditional( - ConditionalExpression::new( - c.clone(), - FieldElementExpression::Pow( - box FieldElementExpression::Number( - T::from(2), - ), - box (index as u32).into(), - ), - T::zero().into(), - ConditionalKind::Ternary, + FieldElementExpression::conditional( + c.clone().span(span), + FieldElementExpression::pow( + FieldElementExpression::value(T::from(2)) + .span(span), + UExpression::from(index as u32).span(span), ), + FieldElementExpression::from(T::zero()).span(span), + ConditionalKind::Ternary, ) + .span(span) }) - .fold(None, |acc, e| match acc { - Some(acc) => { - Some(FieldElementExpression::Add(box acc, box e)) - } - None => Some(e), - }) - .unwrap_or_else(|| { - FieldElementExpression::Number(T::zero()) - }), - ) + .collect::>(), + &FieldElementExpression::value(T::from(0)).span(span), + )) + .span(span) .into() }) .collect() @@ -93,23 +89,28 @@ impl<'ast, T: Field> Folder<'ast, T> for BooleanArrayComparator { let chunk_count = left.len(); - BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value(ArrayValue(left)) - .annotate(Type::FieldElement, chunk_count as u32), - ArrayExpressionInner::Value(ArrayValue(right)) - .annotate(Type::FieldElement, chunk_count as u32), - )) + BooleanExpression::array_eq( + ArrayExpression::value(left) + .annotate(ArrayType::new(Type::FieldElement, chunk_count as u32)) + .span(span), + ArrayExpression::value(right) + .annotate(ArrayType::new(Type::FieldElement, chunk_count as u32)) + .span(span), + ) } - _ => fold_boolean_expression(self, BooleanExpression::ArrayEq(e)), + _ => fold_boolean_expression_cases(self, BooleanExpression::ArrayEq(e)), }, - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } #[cfg(test)] mod tests { - use zokrates_ast::typed::{BooleanExpression, EqExpression, Type}; + use zokrates_ast::{ + common::expressions::BinaryExpression, + typed::{BooleanExpression, Type}, + }; use zokrates_field::DummyCurveField; use zokrates_ast::typed::utils::{a, a_id, conditional, f, select, u_32}; @@ -123,13 +124,13 @@ mod tests { // [x[0] ? 2**1 : 0 + x[1] ? 2**0 : 0] == [y[0] ? 2**1 : 0 + y[1] ? 2**0 : 0] // a single field is sufficient, as the prime we're working with is 3 bits long, so we can pack up to 2 bits - let x = a_id("x").annotate(Type::Boolean, 2u32); - let y = a_id("y").annotate(Type::Boolean, 2u32); + let x = a_id("x").annotate(ArrayType::new(Type::Boolean, 2u32)); + let y = a_id("y").annotate(ArrayType::new(Type::Boolean, 2u32)); let e: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new(x.clone(), y.clone())); + BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); - let expected = BooleanExpression::ArrayEq(EqExpression::new( + let expected = BooleanExpression::ArrayEq(BinaryExpression::new( a([ conditional(select(x.clone(), 0u32), f(2).pow(u_32(1)), f(0)) + conditional(select(x.clone(), 1u32), f(2).pow(u_32(0)), f(0)), @@ -151,13 +152,13 @@ mod tests { // should become // [x[0] ? 2**2 : 0 + x[1] ? 2**1 : 0, x[2] ? 2**0 : 0] == [y[0] ? 2**2 : 0 + y[1] ? 2**1 : 0 y[2] ? 2**0 : 0] - let x = a_id("x").annotate(Type::Boolean, 3u32); - let y = a_id("y").annotate(Type::Boolean, 3u32); + let x = a_id("x").annotate(ArrayType::new(Type::Boolean, 3u32)); + let y = a_id("y").annotate(ArrayType::new(Type::Boolean, 3u32)); let e: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new(x.clone(), y.clone())); + BooleanExpression::ArrayEq(BinaryExpression::new(x.clone(), y.clone())); - let expected = BooleanExpression::ArrayEq(EqExpression::new( + let expected = BooleanExpression::ArrayEq(BinaryExpression::new( a([ conditional(select(x.clone(), 0u32), f(2).pow(u_32(1)), f(0)) + conditional(select(x.clone(), 1u32), f(2).pow(u_32(0)), f(0)), diff --git a/zokrates_analysis/src/branch_isolator.rs b/zokrates_analysis/src/branch_isolator.rs index ee41233a4..6b6789ec6 100644 --- a/zokrates_analysis/src/branch_isolator.rs +++ b/zokrates_analysis/src/branch_isolator.rs @@ -3,6 +3,7 @@ // `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks. +use zokrates_ast::common::{Fold, WithSpan}; use zokrates_ast::typed::folder::*; use zokrates_ast::typed::*; use zokrates_field::Field; @@ -18,17 +19,25 @@ impl Isolator { impl<'ast, T: Field> Folder<'ast, T> for Isolator { fn fold_conditional_expression< - E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + Block<'ast, T> + Fold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { - ConditionalOrExpression::Conditional(ConditionalExpression::new( - self.fold_boolean_expression(*e.condition), - E::block(vec![], e.consequence.fold(self)), - E::block(vec![], e.alternative.fold(self)), - e.kind, - )) + let span = e.get_span(); + + let consequence_span = e.consequence.get_span(); + let alternative_span = e.alternative.get_span(); + + ConditionalOrExpression::Conditional( + ConditionalExpression::new( + self.fold_boolean_expression(*e.condition), + E::block(vec![], e.consequence.fold(self)).span(consequence_span), + E::block(vec![], e.alternative.fold(self)).span(alternative_span), + e.kind, + ) + .span(span), + ) } } diff --git a/zokrates_analysis/src/condition_redefiner.rs b/zokrates_analysis/src/condition_redefiner.rs index 775271b1d..535a51375 100644 --- a/zokrates_analysis/src/condition_redefiner.rs +++ b/zokrates_analysis/src/condition_redefiner.rs @@ -1,7 +1,10 @@ -use zokrates_ast::typed::{ - folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression, - ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression, - TypedProgram, TypedStatement, Variable, +use zokrates_ast::{ + common::{Fold, WithSpan}, + typed::{ + folder::*, BlockExpression, BooleanExpression, Conditional, ConditionalExpression, + ConditionalOrExpression, CoreIdentifier, Expr, Id, Identifier, Type, TypedExpression, + TypedProgram, TypedStatement, Variable, + }, }; use zokrates_field::Field; @@ -18,14 +21,14 @@ impl<'ast, T: Field> ConditionRedefiner<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + fn fold_statement_cases(&mut self, s: TypedStatement<'ast, T>) -> Vec> { assert!(self.buffer.is_empty()); - let s = fold_statement(self, s); + let s = fold_statement_cases(self, s); let buffer = std::mem::take(&mut self.buffer); buffer.into_iter().chain(s).collect() } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, b: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { @@ -54,25 +57,32 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { b } - fn fold_conditional_expression + Conditional<'ast, T> + Fold<'ast, T>>( + fn fold_conditional_expression + Conditional<'ast, T> + Fold>( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { let condition = self.fold_boolean_expression(*e.condition); + let condition_span = condition.get_span(); let condition = match condition { condition @ BooleanExpression::Value(_) | condition @ BooleanExpression::Identifier(_) => condition, condition => { let condition_id = Identifier::from(CoreIdentifier::Condition(self.index)); - self.buffer.push(TypedStatement::definition( - Variable::immutable(condition_id.clone(), Type::Boolean).into(), - TypedExpression::from(condition), - )); + self.buffer.push( + TypedStatement::definition( + Variable::new(condition_id.clone(), Type::Boolean) + .span(condition_span) + .into(), + TypedExpression::from(condition).span(condition_span), + ) + .span(condition_span), + ); self.index += 1; - BooleanExpression::identifier(condition_id) + BooleanExpression::identifier(condition_id).span(condition_span) } - }; + } + .span(condition_span); let consequence = e.consequence.fold(self); let alternative = e.alternative.fold(self); @@ -89,6 +99,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::{ Block, BooleanExpression, Conditional, ConditionalKind, FieldElementExpression, Type, }; @@ -102,9 +113,9 @@ mod tests { let s = TypedStatement::definition( Variable::field_element("foo").into(), FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -124,8 +135,8 @@ mod tests { Variable::field_element("foo").into(), FieldElementExpression::conditional( BooleanExpression::identifier("c".into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -143,17 +154,17 @@ mod tests { // bool #CONDITION_0 = c && d; // field foo = if #CONDITION_0 { 1 } else { 2 }; - let condition = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); let s = TypedStatement::definition( Variable::field_element("foo").into(), FieldElementExpression::conditional( condition.clone(), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -164,7 +175,7 @@ mod tests { let expected = vec![ // define condition TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition.into(), ), // rewrite statement @@ -172,8 +183,8 @@ mod tests { Variable::field_element("foo").into(), FieldElementExpression::conditional( BooleanExpression::identifier(CoreIdentifier::Condition(0).into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ) .into(), @@ -202,14 +213,14 @@ mod tests { // 3 // }; - let condition_0 = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition_0 = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); - let condition_1 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_1 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); let s = TypedStatement::definition( @@ -218,11 +229,11 @@ mod tests { condition_0.clone(), FieldElementExpression::conditional( condition_1.clone(), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ) .into(), @@ -233,11 +244,11 @@ mod tests { let expected = vec![ // define conditions TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition_0.into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(1), Type::Boolean).into(), condition_1.into(), ), // rewrite statement @@ -247,11 +258,11 @@ mod tests { BooleanExpression::identifier(CoreIdentifier::Condition(0).into()), FieldElementExpression::conditional( BooleanExpression::identifier(CoreIdentifier::Condition(1).into()), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ConditionalKind::IfElse, ), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ) .into(), @@ -284,19 +295,19 @@ mod tests { // if #CONDITION_2 { 2 } : { 3 } // }; - let condition_0 = BooleanExpression::And( - box BooleanExpression::identifier("c".into()), - box BooleanExpression::identifier("d".into()), + let condition_0 = BooleanExpression::bitand( + BooleanExpression::identifier("c".into()), + BooleanExpression::identifier("d".into()), ); - let condition_1 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_1 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); - let condition_2 = BooleanExpression::And( - box BooleanExpression::identifier("e".into()), - box BooleanExpression::identifier("f".into()), + let condition_2 = BooleanExpression::bitand( + BooleanExpression::identifier("e".into()), + BooleanExpression::identifier("f".into()), ); let condition_id_0 = BooleanExpression::identifier(CoreIdentifier::Condition(0).into()); @@ -310,24 +321,24 @@ mod tests { FieldElementExpression::block( vec![TypedStatement::definition( Variable::field_element("a").into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), )], FieldElementExpression::conditional( condition_1.clone(), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), FieldElementExpression::block( vec![TypedStatement::definition( Variable::field_element("b").into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), )], FieldElementExpression::conditional( condition_2.clone(), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), @@ -341,7 +352,7 @@ mod tests { let expected = vec![ // define conditions TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(0), Type::Boolean).into(), + Variable::new(CoreIdentifier::Condition(0), Type::Boolean).into(), condition_0.into(), ), // rewrite statement @@ -353,18 +364,17 @@ mod tests { vec![ TypedStatement::definition( Variable::field_element("a").into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(1), Type::Boolean) - .into(), + Variable::new(CoreIdentifier::Condition(1), Type::Boolean).into(), condition_1.into(), ), ], FieldElementExpression::conditional( condition_id_1, - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), @@ -372,18 +382,17 @@ mod tests { vec![ TypedStatement::definition( Variable::field_element("b").into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), TypedStatement::definition( - Variable::immutable(CoreIdentifier::Condition(2), Type::Boolean) - .into(), + Variable::new(CoreIdentifier::Condition(2), Type::Boolean).into(), condition_2.into(), ), ], FieldElementExpression::conditional( condition_id_2, - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ), ), diff --git a/zokrates_analysis/src/constant_argument_checker.rs b/zokrates_analysis/src/constant_argument_checker.rs index 485cf9181..a8b772f51 100644 --- a/zokrates_analysis/src/constant_argument_checker.rs +++ b/zokrates_analysis/src/constant_argument_checker.rs @@ -1,9 +1,10 @@ use std::fmt; use zokrates_ast::common::FlatEmbed; +use zokrates_ast::typed::result_folder::*; +use zokrates_ast::typed::{result_folder::ResultFolder, Constant, EmbedCall, TypedStatement}; use zokrates_ast::typed::{ - result_folder::fold_statement, result_folder::ResultFolder, Constant, EmbedCall, TypedStatement, + DefinitionRhs, DefinitionStatement, TypedProgram, UBitwidth, UExpression, UExpressionInner, }; -use zokrates_ast::typed::{DefinitionRhs, TypedProgram}; use zokrates_field::Field; pub struct ConstantArgumentChecker; @@ -26,47 +27,83 @@ impl fmt::Display for Error { impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { type Error = Error; - fn fold_statement( + fn fold_definition_statement( &mut self, - s: TypedStatement<'ast, T>, + s: DefinitionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - match embed_call { - EmbedCall { - embed: FlatEmbed::BitArrayLe, - .. - } => { - let arguments = embed_call - .arguments - .into_iter() - .map(|a| self.fold_expression(a)) - .collect::, _>>()?; + match s.rhs { + DefinitionRhs::EmbedCall(embed_call) => match embed_call { + EmbedCall { + embed: FlatEmbed::BitArrayLe, + .. + } => { + let arguments = embed_call + .arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::, _>>()?; - if arguments[1].is_constant() { - Ok(vec![TypedStatement::Definition( - assignee, - EmbedCall { - embed: FlatEmbed::BitArrayLe, - generics: embed_call.generics, - arguments, - } - .into(), - )]) - } else { - Err(Error(format!( - "Cannot compare to a variable value, found `{}`", - arguments[1] - ))) - } + if arguments[1].is_constant() { + Ok(vec![TypedStatement::embed_call_definition( + s.assignee, + EmbedCall { + embed: FlatEmbed::BitArrayLe, + generics: embed_call.generics, + arguments, + }, + )]) + } else { + Err(Error(format!( + "Cannot compare to a variable value, found `{}`", + arguments[1] + ))) } - embed_call => Ok(vec![TypedStatement::Definition( - assignee, - embed_call.into(), - )]), + } + embed_call => Ok(vec![TypedStatement::embed_call_definition( + s.assignee, embed_call, + )]), + }, + _ => fold_definition_statement(self, s), + } + } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + + match right.as_inner() { + UExpressionInner::Value(_) => { + Ok(UExpression::left_shift(left, right).into_inner()) + } + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} << {}`", + left, + by.clone().annotate(UBitwidth::B32) + ))), + } + } + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + + match right.as_inner() { + UExpressionInner::Value(_) => { + Ok(UExpression::right_shift(left, right).into_inner()) + } + by => Err(Error(format!( + "Cannot shift by a variable value, found `{} >> {}`", + left, + by.clone().annotate(UBitwidth::B32) + ))), } } - s => fold_statement(self, s), + e => fold_uint_expression_cases(self, bitwidth, e), } } } diff --git a/zokrates_analysis/src/constant_resolver.rs b/zokrates_analysis/src/constant_resolver.rs index bf816248f..5ff42f4c2 100644 --- a/zokrates_analysis/src/constant_resolver.rs +++ b/zokrates_analysis/src/constant_resolver.rs @@ -9,18 +9,18 @@ use zokrates_field::Field; // a map of the canonical constants in this program. with all imported constants reduced to their canonical value type ProgramConstants<'ast, T> = - HashMap, TypedConstant<'ast, T>>>; + HashMap, TypedConstant<'ast, T>>>; pub struct ConstantResolver<'ast, T> { modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, constants: ProgramConstants<'ast, T>, } impl<'ast, T: Field> ConstantResolver<'ast, T> { pub fn new( modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, constants: ProgramConstants<'ast, T>, ) -> Self { ConstantResolver { @@ -35,14 +35,14 @@ impl<'ast, T: Field> ConstantResolver<'ast, T> { inliner.fold_program(p) } - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn change_location(&mut self, location: OwnedModuleId) -> OwnedModuleId { let prev = self.location.clone(); self.location = location; self.constants.entry(self.location.clone()).or_default(); prev } - fn treated(&self, id: &TypedModuleId) -> bool { + fn treated(&self, id: &ModuleId) -> bool { self.constants.contains_key(id) } @@ -67,7 +67,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> { } } - fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn fold_module_id(&mut self, id: OwnedModuleId) -> OwnedModuleId { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); @@ -109,10 +109,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::types::{DeclarationSignature, GTupleType}; use zokrates_ast::typed::{ DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression, - GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, + Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement, }; use zokrates_field::Bn128Field; @@ -130,7 +131,7 @@ mod tests { let const_id = "a"; let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from(const_id)).into(), )], signature: DeclarationSignature::new() @@ -139,6 +140,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -147,7 +149,7 @@ mod tests { TypedConstantSymbolDeclaration::new( CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(1), )), DeclarationType::FieldElement, @@ -190,7 +192,7 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( BooleanExpression::identifier(Identifier::from(const_id.clone())).into(), )], signature: DeclarationSignature::new() @@ -199,6 +201,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -207,7 +210,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_id, TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Boolean(BooleanExpression::Value(true)), + TypedExpression::Boolean(BooleanExpression::value(true)), DeclarationType::Boolean, )), ) @@ -248,7 +251,7 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( UExpression::identifier(Identifier::from(const_id.clone())) .annotate(UBitwidth::B32) .into(), @@ -259,6 +262,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -267,9 +271,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_id, TypedConstantSymbol::Here(TypedConstant::new( - UExpressionInner::Value(1u128) - .annotate(UBitwidth::B32) - .into(), + UExpression::value(1u128).annotate(UBitwidth::B32).into(), DeclarationType::Uint(UBitwidth::B32), )), ) @@ -310,20 +312,18 @@ mod tests { let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( - FieldElementExpression::Add( + statements: vec![TypedStatement::ret( + FieldElementExpression::add( FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2u32), - UExpressionInner::Value(0u128).annotate(UBitwidth::B32), - ) - .into(), + .annotate(GArrayType::new(Type::FieldElement, 2u32)), + UExpression::value(0u128).annotate(UBitwidth::B32), + ), FieldElementExpression::select( ArrayExpression::identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2u32), - UExpressionInner::Value(1u128).annotate(UBitwidth::B32), - ) - .into(), + .annotate(GArrayType::new(Type::FieldElement, 2u32)), + UExpression::value(1u128).annotate(UBitwidth::B32), + ), ) .into(), )], @@ -333,6 +333,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -342,16 +343,11 @@ mod tests { const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)) - .into(), - FieldElementExpression::Number(Bn128Field::from(2)) - .into(), - ] - .into(), - ) - .annotate(GType::FieldElement, 2u32), + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), + ]) + .annotate(GArrayType::new(Type::FieldElement, 2u32)), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, @@ -397,7 +393,7 @@ mod tests { let main: TypedFunction = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from(const_b_id.clone())).into(), )], signature: DeclarationSignature::new() @@ -406,6 +402,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -414,7 +411,7 @@ mod tests { TypedConstantSymbolDeclaration::new( const_a_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(1), )), DeclarationType::FieldElement, @@ -424,11 +421,11 @@ mod tests { TypedConstantSymbolDeclaration::new( const_b_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Add( - box FieldElementExpression::identifier(Identifier::from( + TypedExpression::FieldElement(FieldElementExpression::add( + FieldElementExpression::identifier(Identifier::from( const_a_id.clone(), )), - box FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(1)), )), DeclarationType::FieldElement, )), @@ -505,7 +502,7 @@ mod tests { TypedConstantSymbolDeclaration::new( foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(42), )), DeclarationType::FieldElement, @@ -556,7 +553,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_const_id.clone(), )) @@ -572,6 +569,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), main_module), @@ -602,7 +600,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_const_id.clone(), )) @@ -618,6 +616,7 @@ mod tests { }; let expected_program: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), expected_main_module), @@ -683,7 +682,7 @@ mod tests { TypedConstantSymbolDeclaration::new( foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( + TypedExpression::FieldElement(FieldElementExpression::value( Bn128Field::from(2), )), DeclarationType::FieldElement, @@ -693,13 +692,10 @@ mod tests { TypedConstantSymbolDeclaration::new( bar_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Array( - ArrayExpressionInner::Repeat( - box FieldElementExpression::Number(Bn128Field::from(1)).into(), - box UExpression::from(foo_const_id.clone()), - ) - .annotate(Type::FieldElement, foo_const_id.clone()), - ), + TypedExpression::Array(ArrayExpression::repeat( + FieldElementExpression::value(Bn128Field::from(1)).into(), + UExpression::from(foo_const_id.clone()), + )), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, DeclarationConstant::Constant(foo_const_id.clone()), @@ -745,8 +741,9 @@ mod tests { main_baz_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpression::identifier(main_bar_const_id.clone().into()) - .annotate(Type::FieldElement, main_foo_const_id.clone()), + ArrayExpression::identifier(main_bar_const_id.clone().into()).annotate( + ArrayType::new(Type::FieldElement, main_foo_const_id.clone()), + ), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, @@ -763,7 +760,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_foo_const_id.clone(), )) @@ -779,6 +776,7 @@ mod tests { }; let program = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), main_module), @@ -794,7 +792,7 @@ mod tests { TypedConstantSymbolDeclaration::new( main_foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), DeclarationType::FieldElement, )), ) @@ -802,13 +800,10 @@ mod tests { TypedConstantSymbolDeclaration::new( main_bar_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Array( - ArrayExpressionInner::Repeat( - box FieldElementExpression::Number(Bn128Field::from(1)).into(), - box UExpression::from(foo_const_id.clone()), - ) - .annotate(Type::FieldElement, foo_const_id.clone()), - ), + TypedExpression::Array(ArrayExpression::repeat( + FieldElementExpression::value(Bn128Field::from(1)).into(), + UExpression::from(foo_const_id.clone()), + )), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, DeclarationConstant::Constant(foo_const_id.clone()), @@ -820,8 +815,9 @@ mod tests { main_baz_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( - ArrayExpression::identifier(main_bar_const_id.into()) - .annotate(Type::FieldElement, main_foo_const_id.clone()), + ArrayExpression::identifier(main_bar_const_id.into()).annotate( + ArrayType::new(Type::FieldElement, main_foo_const_id.clone()), + ), ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, @@ -838,7 +834,7 @@ mod tests { ), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier(Identifier::from( main_foo_const_id.clone(), )) @@ -854,6 +850,7 @@ mod tests { }; let expected_program: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![ ("main".into(), expected_main_module), diff --git a/zokrates_analysis/src/dead_code.rs b/zokrates_analysis/src/dead_code.rs index 4b67b08b2..d1baba976 100644 --- a/zokrates_analysis/src/dead_code.rs +++ b/zokrates_analysis/src/dead_code.rs @@ -26,38 +26,56 @@ impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> { ZirFunction { statements, ..f } } - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { - match s { - ZirStatement::Definition(v, e) => { - // if the lhs is used later in the program - if self.used.remove(&v.id) { - // include this statement - fold_statement(self, ZirStatement::Definition(v, e)) - } else { - // otherwise remove it - vec![] - } - } - ZirStatement::IfElse(condition, consequence, alternative) => { - let condition = self.fold_boolean_expression(condition); + fn fold_if_else_statement( + &mut self, + s: zokrates_ast::zir::IfElseStatement<'ast, T>, + ) -> Vec> { + let condition = self.fold_boolean_expression(s.condition); - let mut consequence: Vec<_> = consequence - .into_iter() - .rev() - .flat_map(|e| self.fold_statement(e)) - .collect(); - consequence.reverse(); + let mut consequence: Vec<_> = s + .consequence + .into_iter() + .rev() + .flat_map(|e| self.fold_statement(e)) + .collect(); + consequence.reverse(); - let mut alternative: Vec<_> = alternative - .into_iter() - .rev() - .flat_map(|e| self.fold_statement(e)) - .collect(); - alternative.reverse(); + let mut alternative: Vec<_> = s + .alternative + .into_iter() + .rev() + .flat_map(|e| self.fold_statement(e)) + .collect(); + alternative.reverse(); + + vec![ZirStatement::if_else(condition, consequence, alternative)] + } + + fn fold_multiple_definition_statement( + &mut self, + s: zokrates_ast::zir::MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + // if the lhs is used later in the program + if s.assignees.iter().any(|a| self.used.remove(&a.id)) { + // include this statement + fold_multiple_definition_statement(self, s) + } else { + // otherwise remove it + vec![] + } + } - vec![ZirStatement::IfElse(condition, consequence, alternative)] - } - s => fold_statement(self, s), + fn fold_definition_statement( + &mut self, + s: zokrates_ast::zir::DefinitionStatement<'ast, T>, + ) -> Vec> { + // if the lhs is used later in the program + if self.used.remove(&s.assignee.id) { + // include this statement + fold_definition_statement(self, s) + } else { + // otherwise remove it + vec![] } } diff --git a/zokrates_analysis/src/expression_validator.rs b/zokrates_analysis/src/expression_validator.rs index d3a0a96ef..9dcd63413 100644 --- a/zokrates_analysis/src/expression_validator.rs +++ b/zokrates_analysis/src/expression_validator.rs @@ -1,7 +1,5 @@ use std::fmt; -use zokrates_ast::typed::result_folder::{ - fold_assembly_statement, fold_field_expression, fold_uint_expression_inner, ResultFolder, -}; +use zokrates_ast::typed::{result_folder::*, AssemblyAssignment, UExpression}; use zokrates_ast::typed::{ FieldElementExpression, TypedAssemblyStatement, TypedProgram, UBitwidth, UExpressionInner, }; @@ -27,81 +25,80 @@ impl ExpressionValidator { impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator { type Error = Error; - fn fold_assembly_statement( + // we allow more dynamic expressions in witness generation + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - // we allow more dynamic expressions in witness generation - TypedAssemblyStatement::Assignment(_, _) => Ok(vec![s]), - s => fold_assembly_statement(self, s), - } + Ok(vec![TypedAssemblyStatement::Assignment(s)]) } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Self::Error> { match e { // these should have been propagated away - FieldElementExpression::And(_, _) - | FieldElementExpression::Or(_, _) - | FieldElementExpression::Xor(_, _) - | FieldElementExpression::LeftShift(_, _) - | FieldElementExpression::RightShift(_, _) => Err(Error(format!( + FieldElementExpression::And(_) + | FieldElementExpression::Or(_) + | FieldElementExpression::Xor(_) + | FieldElementExpression::LeftShift(_) + | FieldElementExpression::RightShift(_) => Err(Error(format!( "Found non-constant bitwise operation in field element expression `{}`", e ))), - FieldElementExpression::Pow(box e, box exp) => { - let e = self.fold_field_expression(e)?; - let exp = self.fold_uint_expression(exp)?; + FieldElementExpression::Pow(e) => { + let base = self.fold_field_expression(*e.left)?; + let exp = self.fold_uint_expression(*e.right)?; match exp.as_inner() { - UExpressionInner::Value(_) => Ok(FieldElementExpression::Pow(box e, box exp)), + UExpressionInner::Value(_) => Ok(FieldElementExpression::pow(base, exp)), exp => Err(Error(format!( "Found non-constant exponent in power expression `{}**{}`", - e, + base, exp.clone().annotate(UBitwidth::B32) ))), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Error> { match e { - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; + UExpressionInner::LeftShift(e) => { + let expr = self.fold_uint_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} << {}`", - e, - by.clone().annotate(UBitwidth::B32) + UExpressionInner::Value(_) => { + Ok(UExpression::left_shift(expr, by).into_inner()) + } + _ => Err(Error(format!( + "Cannot shift by a variable value, found `{}`", + UExpression::left_shift(expr, by) ))), } } - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; + UExpressionInner::RightShift(e) => { + let expr = self.fold_uint_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; match by.as_inner() { - UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(Error(format!( - "Cannot shift by a variable value, found `{} >> {}`", - e, - by.clone().annotate(UBitwidth::B32) + UExpressionInner::Value(_) => { + Ok(UExpression::right_shift(expr, by).into_inner()) + } + _ => Err(Error(format!( + "Cannot shift by a variable value, found `{}`", + UExpression::right_shift(expr, by) ))), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } } diff --git a/zokrates_analysis/src/flat_propagation.rs b/zokrates_analysis/src/flat_propagation.rs index 155d803c8..547712ff3 100644 --- a/zokrates_analysis/src/flat_propagation.rs +++ b/zokrates_analysis/src/flat_propagation.rs @@ -5,6 +5,9 @@ //! @date 2018 use std::collections::HashMap; +use std::ops::*; +use zokrates_ast::common::expressions::IdentifierOrExpression; +use zokrates_ast::common::WithSpan; use zokrates_ast::flat::folder::*; use zokrates_ast::flat::*; use zokrates_field::Field; @@ -17,49 +20,62 @@ struct Propagator { impl<'ast, T: Field> Folder<'ast, T> for Propagator { fn fold_statement(&mut self, s: FlatStatement<'ast, T>) -> Vec> { match s { - FlatStatement::Definition(var, expr) => match self.fold_expression(expr) { - FlatExpression::Number(n) => { - self.constants.insert(var, n); + FlatStatement::Definition(s) => match self.fold_expression(s.rhs) { + FlatExpression::Value(n) => { + self.constants.insert(s.assignee, n.value); vec![] } - e => vec![FlatStatement::Definition(var, e)], + e => vec![FlatStatement::definition(s.assignee, e)], }, s => fold_statement(self, s), } } + fn fold_identifier_expression( + &mut self, + e: zokrates_ast::common::expressions::IdentifierExpression>, + ) -> IdentifierOrExpression, FlatExpression> { + match self.constants.get(&e.id) { + Some(c) => IdentifierOrExpression::Expression(FlatExpression::value(*c)), + None => IdentifierOrExpression::Identifier(e), + } + } + fn fold_expression(&mut self, e: FlatExpression) -> FlatExpression { + let span = e.get_span(); + match e { - FlatExpression::Number(n) => FlatExpression::Number(n), - FlatExpression::Identifier(id) => match self.constants.get(&id) { - Some(c) => FlatExpression::Number(c.clone()), - None => FlatExpression::Identifier(id), - }, - FlatExpression::Add(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 + n2) - } - (e1, e2) => FlatExpression::Add(box e1, box e2), + FlatExpression::Value(n) => FlatExpression::Value(n), + FlatExpression::Add(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value + n2.value) } - } - FlatExpression::Sub(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 - n2) - } - (e1, e2) => FlatExpression::Sub(box e1, box e2), + (e1, e2) => FlatExpression::add(e1, e2), + }, + FlatExpression::Sub(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value - n2.value) } - } - FlatExpression::Mult(box e1, box e2) => { - match (self.fold_expression(e1), self.fold_expression(e2)) { - (FlatExpression::Number(n1), FlatExpression::Number(n2)) => { - FlatExpression::Number(n1 * n2) - } - (e1, e2) => FlatExpression::Mult(box e1, box e2), + (e1, e2) => FlatExpression::sub(e1, e2), + }, + FlatExpression::Mult(e) => match ( + self.fold_expression(*e.left), + self.fold_expression(*e.right), + ) { + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + FlatExpression::value(n1.value * n2.value) } - } + (e1, e2) => FlatExpression::mul(e1, e2), + }, + e => fold_expression(self, e), } + .span(span) } } @@ -80,14 +96,14 @@ mod tests { fn add() { let mut propagator = Propagator::default(); - let e = FlatExpression::Add( - box FlatExpression::Number(Bn128Field::from(2)), - box FlatExpression::Number(Bn128Field::from(3)), + let e = FlatExpression::add( + FlatExpression::value(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(3)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(5)) + FlatExpression::value(Bn128Field::from(5)) ); } @@ -95,14 +111,14 @@ mod tests { fn sub() { let mut propagator = Propagator::default(); - let e = FlatExpression::Sub( - box FlatExpression::Number(Bn128Field::from(3)), - box FlatExpression::Number(Bn128Field::from(2)), + let e = FlatExpression::sub( + FlatExpression::value(Bn128Field::from(3)), + FlatExpression::value(Bn128Field::from(2)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(1)) + FlatExpression::value(Bn128Field::from(1)) ); } @@ -110,14 +126,14 @@ mod tests { fn mult() { let mut propagator = Propagator::default(); - let e = FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(3)), - box FlatExpression::Number(Bn128Field::from(2)), + let e = FlatExpression::mul( + FlatExpression::value(Bn128Field::from(3)), + FlatExpression::value(Bn128Field::from(2)), ); assert_eq!( propagator.fold_expression(e), - FlatExpression::Number(Bn128Field::from(6)) + FlatExpression::value(Bn128Field::from(6)) ); } } diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index 0b834ef8a..64705d1ad 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -1,10 +1,15 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; +use std::ops::*; +use zokrates_ast::common::expressions::{BinaryExpression, ValueExpression}; +use zokrates_ast::common::operators::OpEq; +use zokrates_ast::common::statements::LogStatement; +use zokrates_ast::common::{Span, WithSpan}; use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth}; -use zokrates_ast::typed::{self, Expr, Typed}; -use zokrates_ast::zir::IntoType as ZirIntoType; -use zokrates_ast::zir::{self, Folder, Id, Select}; +use zokrates_ast::typed::{self, Basic, Expr, Typed}; +use zokrates_ast::zir::{self, Expr as ZirExpr, Folder, Id, MultipleDefinitionStatement, Select}; +use zokrates_ast::zir::{IntoType as ZirIntoType, SourceIdentifier}; use zokrates_field::Field; #[derive(Default)] @@ -18,31 +23,30 @@ fn flatten_identifier_rec<'ast>( ) -> Vec> { match ty { typed::ConcreteType::Int => unreachable!(), - typed::ConcreteType::FieldElement => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::FieldElement, - }], - typed::types::ConcreteType::Boolean => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::Boolean, - }], - typed::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable { - id: zir::Identifier::Source(id), - _type: zir::Type::uint(bitwidth.to_usize()), - }], + typed::ConcreteType::FieldElement => vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::FieldElement, + )], + typed::types::ConcreteType::Boolean => vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::Boolean, + )], + typed::types::ConcreteType::Uint(bitwidth) => { + vec![zir::Variable::new( + zir::Identifier::Source(id), + zir::Type::uint(bitwidth.to_usize()), + )] + } typed::types::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { - flatten_identifier_rec( - zir::SourceIdentifier::Select(box id.clone(), i), - &array_type.ty, - ) + flatten_identifier_rec(SourceIdentifier::select(id.clone(), i), &array_type.ty) }) .collect(), typed::types::ConcreteType::Struct(members) => members .iter() .flat_map(|struct_member| { flatten_identifier_rec( - zir::SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), + SourceIdentifier::member(id.clone(), struct_member.id.clone()), &struct_member.ty, ) }) @@ -52,7 +56,7 @@ fn flatten_identifier_rec<'ast>( .iter() .enumerate() .flat_map(|(i, ty)| { - flatten_identifier_rec(zir::SourceIdentifier::Element(box id.clone(), i as u32), ty) + flatten_identifier_rec(SourceIdentifier::element(id.clone(), i as u32), ty) }) .collect(), } @@ -78,7 +82,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( typed::ConcreteType::Array(array_type) => (0..*array_type.size) .flat_map(|i| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Select(box id.clone(), i), + SourceIdentifier::select(id.clone(), i), &array_type.ty, ) }) @@ -87,7 +91,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .iter() .flat_map(|struct_member| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Member(box id.clone(), struct_member.id.clone()), + SourceIdentifier::member(id.clone(), struct_member.id.clone()), &struct_member.ty, ) }) @@ -98,7 +102,7 @@ fn flatten_identifier_to_expression_rec<'ast, T: Field>( .enumerate() .flat_map(|(i, ty)| { flatten_identifier_to_expression_rec( - zir::SourceIdentifier::Element(box id.clone(), i as u32), + SourceIdentifier::element(id.clone(), i as u32), ty, ) }) @@ -192,44 +196,50 @@ impl<'ast, T: Field> Flattener { &mut self, p: typed::DeclarationParameter<'ast, T>, ) -> Vec> { + let span = p.get_span(); + let private = p.private; self.fold_variable(zokrates_ast::typed::variable::try_from_g_variable(p.id).unwrap()) .into_iter() - .map(|v| zir::Parameter { id: v, private }) + .map(|v| zir::Parameter::new(v, private).span(span)) .collect() } fn fold_name(&mut self, n: typed::Identifier<'ast>) -> zir::SourceIdentifier<'ast> { - zir::SourceIdentifier::Basic(n) + SourceIdentifier::Basic(n) } fn fold_variable(&mut self, v: typed::Variable<'ast, T>) -> Vec> { + let span = v.get_span(); let ty = v.get_type(); let id = self.fold_name(v.id); let ty = typed::types::ConcreteType::try_from(ty).unwrap(); flatten_identifier_rec(id, &ty) + .into_iter() + .map(|v| v.span(span)) + .collect() } fn fold_assignee(&mut self, a: typed::TypedAssignee<'ast, T>) -> Vec> { match a { typed::TypedAssignee::Identifier(v) => self.fold_variable(v), - typed::TypedAssignee::Select(box a, box i) => { + typed::TypedAssignee::Select(a, i) => { let count = match typed::ConcreteType::try_from(a.get_type()).unwrap() { typed::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(), _ => unreachable!(), }; - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); match i.as_inner() { typed::UExpressionInner::Value(index) => { - a[*index as usize * count..(*index as usize + 1) * count].to_vec() + a[index.value as usize * count..(index.value as usize + 1) * count].to_vec() } i => unreachable!("index {:?} not allowed, should be a constant", i), } } - typed::TypedAssignee::Member(box a, m) => { + typed::TypedAssignee::Member(a, m) => { let (offset, size) = match typed::ConcreteType::try_from(a.get_type()).unwrap() { typed::ConcreteType::Struct(struct_type) => struct_type @@ -247,11 +257,11 @@ impl<'ast, T: Field> Flattener { let size = size.unwrap(); - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); a[offset..offset + size].to_vec() } - typed::TypedAssignee::Element(box a, index) => { + typed::TypedAssignee::Element(a, index) => { let tuple_ty = typed::ConcreteTupleType::try_from( typed::ConcreteType::try_from(a.get_type()).unwrap(), ) @@ -266,7 +276,7 @@ impl<'ast, T: Field> Flattener { let size = &tuple_ty.elements[index as usize].get_primitive_count(); - let a = self.fold_assignee(a); + let a = self.fold_assignee(*a); a[offset..offset + size].to_vec() } @@ -366,6 +376,20 @@ impl<'ast, T: Field> Flattener { fold_conditional_expression(self, statements_buffer, c) } + fn fold_binary_expression< + L: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + R: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + E: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + Op, + >( + &mut self, + statements_buffer: &mut Vec>, + e: BinaryExpression, + ) -> BinaryExpression + { + fold_binary_expression(self, statements_buffer, e) + } + fn fold_member_expression( &mut self, statements_buffer: &mut Vec>, @@ -393,7 +417,7 @@ impl<'ast, T: Field> Flattener { fn fold_eq_expression>( &mut self, statements_buffer: &mut Vec>, - eq: typed::EqExpression, + eq: BinaryExpression>, ) -> zir::BooleanExpression<'ast, T> { fold_eq_expression(self, statements_buffer, eq) } @@ -481,29 +505,30 @@ impl<'ast, T: Field> Flattener { // } #[derive(Default)] pub struct ArgumentFinder<'ast, T> { - pub identifiers: HashMap, zir::Type>, + pub identifiers: BTreeMap, zir::Type>, _phantom: PhantomData, } impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> { fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec> { match s { - zir::ZirStatement::Definition(assignee, expr) => { - let assignee = self.fold_assignee(assignee); - let expr = self.fold_expression(expr); + zir::ZirStatement::Definition(s) => { + let assignee = self.fold_assignee(s.assignee); + let expr = self.fold_expression(s.rhs); self.identifiers.remove(&assignee.id); - vec![zir::ZirStatement::Definition(assignee, expr)] + vec![zir::ZirStatement::definition(assignee, expr)] } - zir::ZirStatement::MultipleDefinition(assignees, list) => { - let assignees: Vec> = assignees + zir::ZirStatement::MultipleDefinition(s) => { + let assignees: Vec> = s + .assignees .into_iter() .map(|v| self.fold_assignee(v)) .collect(); - let list = self.fold_expression_list(list); + let list = self.fold_expression_list(s.rhs); for a in &assignees { self.identifiers.remove(&a.id); } - vec![zir::ZirStatement::MultipleDefinition(assignees, list)] + vec![zir::ZirStatement::multiple_definition(assignees, list)] } s => zir::folder::fold_statement(self, s), } @@ -525,12 +550,14 @@ fn fold_assembly_statement<'ast, T: Field>( statements_buffer: &mut Vec>, s: typed::TypedAssemblyStatement<'ast, T>, ) -> zir::ZirAssemblyStatement<'ast, T> { + let span = s.get_span(); + match s { - typed::TypedAssemblyStatement::Assignment(a, e) => { + typed::TypedAssemblyStatement::Assignment(s) => { let mut statements_buffer: Vec> = vec![]; - let a = f.fold_assignee(a); - let e = f.fold_expression(&mut statements_buffer, e); - statements_buffer.push(zir::ZirStatement::Return(e)); + let a = f.fold_assignee(s.assignee); + let e = f.fold_expression(&mut statements_buffer, s.expression); + statements_buffer.push(zir::ZirStatement::ret(e)); let mut finder = ArgumentFinder::default(); let mut statements_buffer: Vec> = statements_buffer @@ -547,22 +574,22 @@ fn fold_assembly_statement<'ast, T: Field>( arguments: finder .identifiers .into_iter() - .map(|(id, ty)| zir::Parameter { - id: zir::Variable::with_id_and_type(id, ty), - private: true, + .map(|(id, ty)| { + zir::Parameter::private(zir::Variable::with_id_and_type(id, ty)) }) .collect(), statements: statements_buffer, }; - zir::ZirAssemblyStatement::Assignment(a, function) + zir::ZirAssemblyStatement::assignment(a, function) } - typed::TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(statements_buffer, lhs); - let rhs = f.fold_field_expression(statements_buffer, rhs); - zir::ZirAssemblyStatement::Constraint(lhs, rhs, metadata) + typed::TypedAssemblyStatement::Constraint(s) => { + let lhs = f.fold_field_expression(statements_buffer, s.left); + let rhs = f.fold_field_expression(statements_buffer, s.right); + zir::ZirAssemblyStatement::constraint(lhs, rhs, s.metadata) } } + .span(span) } fn fold_statement<'ast, T: Field>( @@ -570,57 +597,68 @@ fn fold_statement<'ast, T: Field>( statements_buffer: &mut Vec>, s: typed::TypedStatement<'ast, T>, ) { + let span = s.get_span(); + let res = match s { - typed::TypedStatement::Assembly(statements) => { - let statements = statements + typed::TypedStatement::Return(s) => vec![zir::ZirStatement::ret( + f.fold_expression(statements_buffer, s.inner), + )], + typed::TypedStatement::Assembly(s) => { + let statements = s + .inner .into_iter() .map(|s| f.fold_assembly_statement(statements_buffer, s)) .collect(); - vec![zir::ZirStatement::Assembly(statements)] + vec![zir::ZirStatement::assembly(statements)] } - typed::TypedStatement::Return(expression) => vec![zir::ZirStatement::Return( - f.fold_expression(statements_buffer, expression), - )], - typed::TypedStatement::Definition(a, typed::DefinitionRhs::Expression(e)) => { + typed::TypedStatement::Definition(typed::DefinitionStatement { + assignee: a, + rhs: typed::DefinitionRhs::Expression(e), + .. + }) => { let a = f.fold_assignee(a); let e = f.fold_expression(statements_buffer, e); assert_eq!(a.len(), e.len()); a.into_iter() .zip(e.into_iter()) - .map(|(a, e)| zir::ZirStatement::Definition(a, e)) + .map(|(a, e)| zir::ZirStatement::definition(a, e)) .collect() } - typed::TypedStatement::Assertion(e, error) => { - let e = f.fold_boolean_expression(statements_buffer, e); - let error = match error { + typed::TypedStatement::Assertion(s) => { + let e = f.fold_boolean_expression(statements_buffer, s.expression); + let error = match s.error { typed::RuntimeError::SourceAssertion(metadata) => { zir::RuntimeError::SourceAssertion(metadata) } typed::RuntimeError::SelectRangeCheck => zir::RuntimeError::SelectRangeCheck, typed::RuntimeError::DivisionByZero => zir::RuntimeError::DivisionByZero, }; - vec![zir::ZirStatement::Assertion(e, error)] + vec![zir::ZirStatement::assertion(e, error)] } - typed::TypedStatement::Definition( - assignee, - typed::DefinitionRhs::EmbedCall(embed_call), - ) => { + typed::TypedStatement::Definition(typed::DefinitionStatement { + assignee: a, + rhs: typed::DefinitionRhs::EmbedCall(embed_call), + .. + }) => { vec![zir::ZirStatement::MultipleDefinition( - f.fold_assignee(assignee), - zir::ZirExpressionList::EmbedCall( - embed_call.embed, - embed_call.generics, - embed_call - .arguments - .into_iter() - .flat_map(|a| f.fold_expression(statements_buffer, a)) - .collect(), + MultipleDefinitionStatement::new( + f.fold_assignee(a), + zir::ZirExpressionList::EmbedCall( + embed_call.embed, + embed_call.generics, + embed_call + .arguments + .into_iter() + .flat_map(|a| f.fold_expression(statements_buffer, a)) + .collect(), + ), ), )] } - typed::TypedStatement::Log(l, e) => vec![zir::ZirStatement::Log( - l, - e.into_iter() + typed::TypedStatement::Log(e) => vec![zir::ZirStatement::Log(LogStatement::new( + e.format_string, + e.expressions + .into_iter() .map(|e| { ( e.get_type().try_into().unwrap(), @@ -628,11 +666,11 @@ fn fold_statement<'ast, T: Field>( ) }) .collect(), - )], + ))], typed::TypedStatement::For(..) => unreachable!(), }; - statements_buffer.extend(res); + statements_buffer.extend(res.into_iter().map(|s| s.span(span))); } fn fold_array_expression_inner<'ast, T: Field>( @@ -671,31 +709,32 @@ fn fold_array_expression_inner<'ast, T: Field>( typed::ArrayExpressionInner::Select(select) => { f.fold_select_expression(statements_buffer, select) } - typed::ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = f.fold_array_expression(statements_buffer, array); - let from = f.fold_uint_expression(statements_buffer, from); - let to = f.fold_uint_expression(statements_buffer, to); + typed::ArrayExpressionInner::Slice(e) => { + let array = f.fold_array_expression(statements_buffer, *e.array); + let from = f.fold_uint_expression(statements_buffer, *e.from); + let to = f.fold_uint_expression(statements_buffer, *e.to); match (from.into_inner(), to.into_inner()) { (zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => { - assert_eq!(size, to.saturating_sub(from) as u32); + assert_eq!(size, to.value.saturating_sub(from.value) as u32); let element_size = ty.get_primitive_count(); - let start = from as usize * element_size; - let end = to as usize * element_size; + let start = from.value as usize * element_size; + let end = to.value as usize * element_size; array[start..end].to_vec() } _ => unreachable!(), } } - typed::ArrayExpressionInner::Repeat(box e, box count) => { - let e = f.fold_expression(statements_buffer, e); - let count = f.fold_uint_expression(statements_buffer, count); + typed::ArrayExpressionInner::Repeat(r) => { + let e = f.fold_expression(statements_buffer, *r.e); + let count = f.fold_uint_expression(statements_buffer, *r.count); match count.into_inner() { - zir::UExpressionInner::Value(count) => { - vec![e; count as usize].into_iter().flatten().collect() - } + zir::UExpressionInner::Value(count) => vec![e; count.value as usize] + .into_iter() + .flatten() + .collect(), _ => unreachable!(), } } @@ -860,7 +899,7 @@ fn fold_select_expression<'ast, T: Field, E>( match index.as_inner() { zir::UExpressionInner::Value(v) => { - let v = *v as usize; + let v = v.value as usize; array[v * size..(v + 1) * size].to_vec() } @@ -925,6 +964,9 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( statements_buffer: &mut Vec>, c: typed::ConditionalExpression<'ast, T, E>, ) -> Vec> { + let span = c.get_span(); + let condition_span = c.condition.get_span(); + let mut consequence_statements = vec![]; let mut alternative_statements = vec![]; @@ -935,11 +977,14 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( assert_eq!(consequence.len(), alternative.len()); if !consequence_statements.is_empty() || !alternative_statements.is_empty() { - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); + statements_buffer.push( + zir::ZirStatement::if_else( + condition.clone().span(condition_span), + consequence_statements, + alternative_statements, + ) + .span(span), + ); } use zokrates_ast::zir::Conditional; @@ -949,19 +994,46 @@ fn fold_conditional_expression<'ast, T: Field, E: Flatten<'ast, T>>( .zip(alternative.into_iter()) .map(|(c, a)| match (c, a) { (zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => { - zir::FieldElementExpression::conditional(condition.clone(), c, a).into() + zir::FieldElementExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } (zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => { - zir::BooleanExpression::conditional(condition.clone(), c, a).into() + zir::BooleanExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } (zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => { - zir::UExpression::conditional(condition.clone(), c, a).into() + zir::UExpression::conditional(condition.clone(), c, a) + .span(span) + .into() } _ => unreachable!(), }) .collect() } +fn fold_binary_expression< + 'ast, + T: Field, + L: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + R: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + E: Flatten<'ast, T> + typed::Expr<'ast, T> + Basic<'ast, T>, + Op, +>( + f: &mut Flattener, + statements_buffer: &mut Vec>, + e: BinaryExpression, +) -> BinaryExpression { + let left_span = e.left.get_span(); + let right_span = e.left.get_span(); + + let left: L::ZirExpressionType = e.left.flatten(f, statements_buffer).pop().unwrap().into(); + let right: R::ZirExpressionType = e.right.flatten(f, statements_buffer).pop().unwrap().into(); + + BinaryExpression::new(left.span(left_span), right.span(right_span)).span(e.span) +} + fn fold_identifier_expression<'ast, T: Field, E: Expr<'ast, T>>( f: &mut Flattener, ty: E::ConcreteTy, @@ -975,77 +1047,56 @@ fn fold_field_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::FieldElementExpression<'ast, T>, ) -> zir::FieldElementExpression<'ast, T> { + let span = e.get_span(); + match e { - typed::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), + typed::FieldElementExpression::Value(n) => zir::FieldElementExpression::Value(n), typed::FieldElementExpression::Identifier(id) => f .fold_identifier_expression(typed::ConcreteType::FieldElement, id) .pop() .unwrap() .try_into() .unwrap(), - typed::FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Add(box e1, box e2) - } - typed::FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Sub(box e1, box e2) - } - typed::FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Mult(box e1, box e2) - } - typed::FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::FieldElementExpression::Div(box e1, box e2) - } - typed::FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::FieldElementExpression::Pow(box e1, box e2) - } - typed::FieldElementExpression::Neg(box e) => { - let e = f.fold_field_expression(statements_buffer, e); - - zir::FieldElementExpression::Sub( - box zir::FieldElementExpression::Number(T::zero()), - box e, - ) + typed::FieldElementExpression::Add(e) => { + zir::FieldElementExpression::Add(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::Pos(box e) => f.fold_field_expression(statements_buffer, e), - typed::FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::Xor(box left, box right) + typed::FieldElementExpression::Sub(e) => { + zir::FieldElementExpression::Sub(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::And(box left, box right) + typed::FieldElementExpression::Mult(e) => { + zir::FieldElementExpression::Mult(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(statements_buffer, left); - let right = f.fold_field_expression(statements_buffer, right); - - zir::FieldElementExpression::Or(box left, box right) + typed::FieldElementExpression::Div(e) => { + zir::FieldElementExpression::Div(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(statements_buffer, e); - let by = f.fold_uint_expression(statements_buffer, by); - - zir::FieldElementExpression::LeftShift(box e, box by) + typed::FieldElementExpression::Pow(e) => { + zir::FieldElementExpression::Pow(f.fold_binary_expression(statements_buffer, e)) } - typed::FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(statements_buffer, e); - let by = f.fold_uint_expression(statements_buffer, by); + typed::FieldElementExpression::And(e) => { + zir::FieldElementExpression::And(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Or(e) => { + zir::FieldElementExpression::Or(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Xor(e) => { + zir::FieldElementExpression::Xor(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::LeftShift(e) => { + zir::FieldElementExpression::LeftShift(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::RightShift(e) => { + zir::FieldElementExpression::RightShift(f.fold_binary_expression(statements_buffer, e)) + } + typed::FieldElementExpression::Neg(e) => { + let e = f.fold_field_expression(statements_buffer, *e.inner); - zir::FieldElementExpression::RightShift(box e, box by) + zir::FieldElementExpression::sub( + zir::FieldElementExpression::Value(ValueExpression::new(T::zero())), + e, + ) + } + typed::FieldElementExpression::Pos(e) => { + f.fold_field_expression(statements_buffer, *e.inner) } typed::FieldElementExpression::Conditional(c) => f .fold_conditional_expression(statements_buffer, c) @@ -1080,6 +1131,7 @@ fn fold_field_expression<'ast, T: Field>( f.fold_field_expression(statements_buffer, *block.value) } } + .span(span) } // util function to output a boolean expression representing the equality of two lists of ZirExpression. @@ -1088,27 +1140,32 @@ fn fold_field_expression<'ast, T: Field>( fn conjunction_tree<'ast, T: Field>( v: &[zir::ZirExpression<'ast, T>], w: &[zir::ZirExpression<'ast, T>], + span: Option, ) -> zir::BooleanExpression<'ast, T> { assert_eq!(v.len(), w.len()); match v.len() { - 0 => zir::BooleanExpression::Value(true), + 0 => zir::BooleanExpression::value(true), 1 => match (v[0].clone(), w[0].clone()) { (zir::ZirExpression::Boolean(v), zir::ZirExpression::Boolean(w)) => { - zir::BooleanExpression::BoolEq(box v, box w) + zir::BooleanExpression::bool_eq(v, w).span(span) } (zir::ZirExpression::FieldElement(v), zir::ZirExpression::FieldElement(w)) => { - zir::BooleanExpression::FieldEq(box v, box w) + zir::BooleanExpression::field_eq(v, w).span(span) } (zir::ZirExpression::Uint(v), zir::ZirExpression::Uint(w)) => { - zir::BooleanExpression::UintEq(box v, box w) + zir::BooleanExpression::uint_eq(v, w).span(span) } _ => unreachable!(), }, n => { let (x0, y0) = v.split_at(n / 2); let (x1, y1) = w.split_at(n / 2); - zir::BooleanExpression::And(box conjunction_tree(x0, x1), box conjunction_tree(y0, y1)) + zir::BooleanExpression::bitand( + conjunction_tree(x0, x1, span), + conjunction_tree(y0, y1, span), + ) + .span(span) } } } @@ -1116,11 +1173,18 @@ fn conjunction_tree<'ast, T: Field>( fn fold_eq_expression<'ast, T: Field, E: Flatten<'ast, T>>( f: &mut Flattener, statements_buffer: &mut Vec>, - e: typed::EqExpression, + e: zokrates_ast::common::expressions::BinaryExpression< + OpEq, + E, + E, + typed::BooleanExpression<'ast, T>, + >, ) -> zir::BooleanExpression<'ast, T> { + let span = e.get_span(); + let left = e.left.flatten(f, statements_buffer); let right = e.right.flatten(f, statements_buffer); - conjunction_tree(&left, &right) + conjunction_tree(&left, &right, span) } fn fold_boolean_expression<'ast, T: Field>( @@ -1128,6 +1192,8 @@ fn fold_boolean_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::BooleanExpression<'ast, T>, ) -> zir::BooleanExpression<'ast, T> { + let span = e.get_span(); + match e { typed::BooleanExpression::Block(block) => { block @@ -1149,59 +1215,26 @@ fn fold_boolean_expression<'ast, T: Field>( typed::BooleanExpression::StructEq(e) => f.fold_eq_expression(statements_buffer, e), typed::BooleanExpression::TupleEq(e) => f.fold_eq_expression(statements_buffer, e), typed::BooleanExpression::UintEq(e) => f.fold_eq_expression(statements_buffer, e), - typed::BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLt(box e1, box e2) - } - typed::BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLe(box e1, box e2) - } - typed::BooleanExpression::FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLt(box e2, box e1) - } - typed::BooleanExpression::FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(statements_buffer, e1); - let e2 = f.fold_field_expression(statements_buffer, e2); - zir::BooleanExpression::FieldLe(box e2, box e1) - } - typed::BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLt(box e1, box e2) - } - typed::BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLe(box e1, box e2) - } - typed::BooleanExpression::UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLt(box e2, box e1) - } - typed::BooleanExpression::UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(statements_buffer, e1); - let e2 = f.fold_uint_expression(statements_buffer, e2); - zir::BooleanExpression::UintLe(box e2, box e1) - } - typed::BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(statements_buffer, e1); - let e2 = f.fold_boolean_expression(statements_buffer, e2); - zir::BooleanExpression::Or(box e1, box e2) - } - typed::BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(statements_buffer, e1); - let e2 = f.fold_boolean_expression(statements_buffer, e2); - zir::BooleanExpression::And(box e1, box e2) - } - typed::BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(statements_buffer, e); - zir::BooleanExpression::Not(box e) + typed::BooleanExpression::FieldLt(e) => { + zir::BooleanExpression::FieldLt(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::FieldLe(e) => { + zir::BooleanExpression::FieldLe(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::UintLt(e) => { + zir::BooleanExpression::UintLt(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::UintLe(e) => { + zir::BooleanExpression::UintLe(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::Or(e) => { + zir::BooleanExpression::Or(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::And(e) => { + zir::BooleanExpression::And(f.fold_binary_expression(statements_buffer, e)) + } + typed::BooleanExpression::Not(e) => { + zir::BooleanExpression::not(f.fold_boolean_expression(statements_buffer, *e.inner)) } typed::BooleanExpression::Conditional(c) => f .fold_conditional_expression(statements_buffer, c) @@ -1229,6 +1262,7 @@ fn fold_boolean_expression<'ast, T: Field>( .try_into() .unwrap(), } + .span(span) } fn fold_uint_expression<'ast, T: Field>( @@ -1236,8 +1270,10 @@ fn fold_uint_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::UExpression<'ast, T>, ) -> zir::UExpression<'ast, T> { + let span = e.get_span(); f.fold_uint_expression_inner(statements_buffer, e.bitwidth, e.inner) .annotate(e.bitwidth.to_usize()) + .span(span) } fn fold_uint_expression_inner<'ast, T: Field>( @@ -1246,6 +1282,8 @@ fn fold_uint_expression_inner<'ast, T: Field>( bitwidth: UBitwidth, e: typed::UExpressionInner<'ast, T>, ) -> zir::UExpressionInner<'ast, T> { + let span = e.get_span(); + match e { typed::UExpressionInner::Block(block) => { block @@ -1261,95 +1299,83 @@ fn fold_uint_expression_inner<'ast, T: Field>( .unwrap() .into_inner() } - typed::UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Add(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Add(box left, box right) + zir::UExpression::add(left, right).into_inner() } - typed::UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Sub(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Sub(box left, box right) + zir::UExpression::sub(left, right).into_inner() } typed::UExpressionInner::FloorSub(..) => unreachable!(), - typed::UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Mult(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Mult(box left, box right) + zir::UExpression::mult(left, right).into_inner() } - typed::UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Div(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Div(box left, box right) + zir::UExpression::div(left, right).into_inner() } - typed::UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Rem(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Rem(box left, box right) + zir::UExpression::rem(left, right).into_inner() } - typed::UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Xor(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Xor(box left, box right) + zir::UExpression::xor(left, right).into_inner() } - typed::UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::And(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::And(box left, box right) + zir::UExpression::and(left, right).into_inner() } - typed::UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(statements_buffer, left); - let right = f.fold_uint_expression(statements_buffer, right); + typed::UExpressionInner::Or(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::Or(box left, box right) + zir::UExpression::or(left, right).into_inner() } - typed::UExpressionInner::LeftShift(box e, box by) => { - let e = f.fold_uint_expression(statements_buffer, e); - - let by = match by.as_inner() { - typed::UExpressionInner::Value(by) => by, - _ => unreachable!("static analysis should have made sure that this is constant"), - }; + typed::UExpressionInner::LeftShift(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::LeftShift(box e, *by as u32) + zir::UExpression::left_shift(left, right).into_inner() } - typed::UExpressionInner::RightShift(box e, box by) => { - let e = f.fold_uint_expression(statements_buffer, e); - - let by = match by.as_inner() { - typed::UExpressionInner::Value(by) => by, - _ => unreachable!("static analysis should have made sure that this is constant"), - }; + typed::UExpressionInner::RightShift(e) => { + let left = f.fold_uint_expression(statements_buffer, *e.left); + let right = f.fold_uint_expression(statements_buffer, *e.right); - zir::UExpressionInner::RightShift(box e, *by as u32) + zir::UExpression::right_shift(left, right).into_inner() } - typed::UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(statements_buffer, e); - - zir::UExpressionInner::Not(box e) + typed::UExpressionInner::Not(e) => { + zir::UExpression::not(f.fold_uint_expression(statements_buffer, *e.inner)).into_inner() } - typed::UExpressionInner::Neg(box e) => { - let bitwidth = e.bitwidth(); + typed::UExpressionInner::Neg(e) => { + let bitwidth = e.inner.bitwidth(); f.fold_uint_expression( statements_buffer, - typed::UExpressionInner::Value(0).annotate(bitwidth) - e, + typed::UExpression::value(0).annotate(bitwidth) - *e.inner, ) .into_inner() } - typed::UExpressionInner::Pos(box e) => { - let e = f.fold_uint_expression(statements_buffer, e); - - e.into_inner() - } + typed::UExpressionInner::Pos(e) => f + .fold_uint_expression(statements_buffer, *e.inner) + .into_inner(), typed::UExpressionInner::FunctionCall(..) => { unreachable!("function calls should have been removed") } @@ -1382,6 +1408,7 @@ fn fold_uint_expression_inner<'ast, T: Field>( .unwrap() .into_inner(), } + .span(span) } fn fold_function<'ast, T: Field>( @@ -1418,6 +1445,7 @@ fn fold_array_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::ArrayExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); let size: u32 = e.size().try_into().unwrap(); f.fold_array_expression_inner( statements_buffer, @@ -1425,6 +1453,9 @@ fn fold_array_expression<'ast, T: Field>( size, e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_struct_expression<'ast, T: Field>( @@ -1432,11 +1463,15 @@ fn fold_struct_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::StructExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); f.fold_struct_expression_inner( statements_buffer, typed::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_tuple_expression<'ast, T: Field>( @@ -1444,11 +1479,15 @@ fn fold_tuple_expression<'ast, T: Field>( statements_buffer: &mut Vec>, e: typed::TupleExpression<'ast, T>, ) -> Vec> { + let span = e.get_span(); f.fold_tuple_expression_inner( statements_buffer, typed::types::ConcreteTupleType::try_from(e.ty().clone()).unwrap(), e.into_inner(), ) + .into_iter() + .map(|e| e.span(span)) + .collect() } fn fold_program<'ast, T: Field>( @@ -1469,5 +1508,6 @@ fn fold_program<'ast, T: Field>( zir::ZirProgram { main: f.fold_function(main_function), + module_map: p.module_map, } } diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index 539fe86cb..2cf0c24fc 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - //! Module containing static analysis //! //! @file mod.rs diff --git a/zokrates_analysis/src/log_ignorer.rs b/zokrates_analysis/src/log_ignorer.rs index a5a8cd54e..2c8aa5a8c 100644 --- a/zokrates_analysis/src/log_ignorer.rs +++ b/zokrates_analysis/src/log_ignorer.rs @@ -1,4 +1,4 @@ -use zokrates_ast::typed::{folder::*, TypedProgram, TypedStatement}; +use zokrates_ast::typed::{folder::*, LogStatement, TypedProgram, TypedStatement}; use zokrates_field::Field; #[derive(Default)] @@ -11,10 +11,7 @@ impl LogIgnorer { } impl<'ast, T: Field> Folder<'ast, T> for LogIgnorer { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - match s { - TypedStatement::Log(..) => vec![], - s => fold_statement(self, s), - } + fn fold_log_statement(&mut self, _: LogStatement<'ast, T>) -> Vec> { + vec![] } } diff --git a/zokrates_analysis/src/out_of_bounds.rs b/zokrates_analysis/src/out_of_bounds.rs index 7cdc88b6c..7ed8613a7 100644 --- a/zokrates_analysis/src/out_of_bounds.rs +++ b/zokrates_analysis/src/out_of_bounds.rs @@ -46,29 +46,29 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for OutOfBoundsChecker { a: TypedAssignee<'ast, T>, ) -> Result, Error> { match a { - TypedAssignee::Select(box array, box index) => { + TypedAssignee::Select(array, index) => { use zokrates_ast::typed::Typed; - let array = self.fold_assignee(array)?; + let array = self.fold_assignee(*array)?; let size = match array.get_type() { Type::Array(array_ty) => match array_ty.size.as_inner() { - UExpressionInner::Value(size) => *size, + UExpressionInner::Value(size) => size.value, _ => unreachable!(), }, _ => unreachable!(), }; match index.as_inner() { - UExpressionInner::Value(i) if i >= &size => Err(Error(format!( + UExpressionInner::Value(i) if i.value >= size => Err(Error(format!( "Out of bounds write to `{}` because `{}` has size {}", - TypedAssignee::Select(box array.clone(), box index), + TypedAssignee::select(array.clone(), *index), array, size ))), - _ => Ok(TypedAssignee::Select( - box self.fold_assignee(array)?, - box self.fold_uint_expression(index)?, + _ => Ok(TypedAssignee::select( + array, + self.fold_uint_expression(*index)?, )), } } diff --git a/zokrates_analysis/src/panic_extractor.rs b/zokrates_analysis/src/panic_extractor.rs index d17675a83..dd2628cd5 100644 --- a/zokrates_analysis/src/panic_extractor.rs +++ b/zokrates_analysis/src/panic_extractor.rs @@ -1,6 +1,11 @@ -use zokrates_ast::zir::{ - folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, - FieldElementExpression, RuntimeError, UBitwidth, UExpressionInner, ZirProgram, ZirStatement, +use std::ops::*; +use zokrates_ast::{ + common::{Fold, WithSpan}, + zir::{ + folder::*, BooleanExpression, Conditional, ConditionalExpression, ConditionalOrExpression, + Expr, FieldElementExpression, IfElseStatement, RuntimeError, UBitwidth, UExpression, + UExpressionInner, ZirProgram, ZirStatement, + }, }; use zokrates_field::Field; @@ -18,67 +23,85 @@ impl<'ast, T: Field> PanicExtractor<'ast, T> { } impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Vec> { + let condition = self.fold_boolean_expression(s.condition); + let mut consequence_extractor = Self::default(); + let consequence = s + .consequence + .into_iter() + .flat_map(|s| consequence_extractor.fold_statement(s)) + .collect(); + assert!(consequence_extractor.panic_buffer.is_empty()); + let mut alternative_extractor = Self::default(); + let alternative = s + .alternative + .into_iter() + .flat_map(|s| alternative_extractor.fold_statement(s)) + .collect(); + assert!(alternative_extractor.panic_buffer.is_empty()); + + self.panic_buffer + .drain(..) + .chain(std::iter::once(ZirStatement::if_else( + condition, + consequence, + alternative, + ))) + .collect() + } + + fn fold_statement_cases(&mut self, s: ZirStatement<'ast, T>) -> Vec> { match s { - ZirStatement::IfElse(condition, consequence, alternative) => { - let condition = self.fold_boolean_expression(condition); - let mut consequence_extractor = Self::default(); - let consequence = consequence - .into_iter() - .flat_map(|s| consequence_extractor.fold_statement(s)) - .collect(); - assert!(consequence_extractor.panic_buffer.is_empty()); - let mut alternative_extractor = Self::default(); - let alternative = alternative - .into_iter() - .flat_map(|s| alternative_extractor.fold_statement(s)) - .collect(); - assert!(alternative_extractor.panic_buffer.is_empty()); - - self.panic_buffer - .drain(..) - .chain(std::iter::once(ZirStatement::IfElse( - condition, - consequence, - alternative, - ))) - .collect() - } + ZirStatement::IfElse(s) => self.fold_if_else_statement(s), s => { - let s = fold_statement(self, s); + let s = fold_statement_cases(self, s); self.panic_buffer.drain(..).chain(s).collect() } } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + match e { - FieldElementExpression::Div(box n, box d) => { - let n = self.fold_field_expression(n); - let d = self.fold_field_expression(d); - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::FieldEq( - box d.clone(), - box FieldElementExpression::Number(T::zero()), - )), - RuntimeError::DivisionByZero, - )); - FieldElementExpression::Div(box n, box d) + FieldElementExpression::Div(e) => { + let n = self.fold_field_expression(*e.left); + let d = self.fold_field_expression(*e.right); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::field_eq( + d.clone().span(span), + FieldElementExpression::value(T::zero()).span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::DivisionByZero, + ) + .span(span), + ); + FieldElementExpression::div(n, d) } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } fn fold_conditional_expression< - E: zokrates_ast::zir::Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: zokrates_ast::zir::Expr<'ast, T> + Fold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { + let span = e.get_span(); + let condition = self.fold_boolean_expression(*e.condition); let mut consequence_extractor = Self::default(); let consequence = e.consequence.fold(&mut consequence_extractor); @@ -89,62 +112,72 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { let alternative_panics: Vec<_> = alternative_extractor.panic_buffer.drain(..).collect(); if !(consequence_panics.is_empty() && alternative_panics.is_empty()) { - self.panic_buffer.push(ZirStatement::IfElse( - condition.clone(), - consequence_panics, - alternative_panics, - )); + self.panic_buffer.push( + ZirStatement::if_else(condition.clone(), consequence_panics, alternative_panics) + .span(span), + ); } - ConditionalOrExpression::Conditional(ConditionalExpression::new( - condition, - consequence, - alternative, - )) + ConditionalOrExpression::Conditional( + ConditionalExpression::new(condition, consequence, alternative).span(span), + ) } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, b: UBitwidth, e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + match e { - UExpressionInner::Div(box n, box d) => { - let n = self.fold_uint_expression(n); - let d = self.fold_uint_expression(d); - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::UintEq( - box d.clone(), - box UExpressionInner::Value(0).annotate(b), - )), - RuntimeError::DivisionByZero, - )); - UExpressionInner::Div(box n, box d) + UExpressionInner::Div(e) => { + let n = self.fold_uint_expression(*e.left); + let d = self.fold_uint_expression(*e.right); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::uint_eq( + d.clone().span(span), + UExpression::value(0).annotate(b).span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::DivisionByZero, + ) + .span(span), + ); + UExpression::div(n, d).into_inner() } - e => fold_uint_expression_inner(self, b, e), + e => fold_uint_expression_cases(self, b, e), } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { // constant range checks are complete, so no panic needs to be extracted - e @ BooleanExpression::FieldLt(box FieldElementExpression::Number(_), _) - | e @ BooleanExpression::FieldLt(_, box FieldElementExpression::Number(_)) => { - fold_boolean_expression(self, e) + BooleanExpression::FieldLt(b) + if matches!(b.left.as_ref(), FieldElementExpression::Value(_)) + || matches!(b.right.as_ref(), FieldElementExpression::Value(_)) => + { + fold_boolean_expression_cases(self, BooleanExpression::FieldLt(b)) } - BooleanExpression::FieldLt(box left, box right) => { - let left = self.fold_field_expression(left); - let right = self.fold_field_expression(right); + BooleanExpression::FieldLt(e) => { + let span = e.get_span(); + + let left = self.fold_field_expression(*e.left); + let right = self.fold_field_expression(*e.right); let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2) - let offset = FieldElementExpression::Number(T::from(2).pow(safe_width)); - let max = FieldElementExpression::Number(T::from(2).pow(safe_width + 1)); + let offset = FieldElementExpression::number(T::from(2).pow(safe_width)); + let max = FieldElementExpression::number(T::from(2).pow(safe_width + 1)); // `|left - right|` must be of bitwidth at most `safe_bitwidth` // this means we need to guarantee the following: `-2**(safe_width) < left - right < 2**(safe_width)` @@ -153,26 +186,40 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { // we split this check in two: // `2**(safe_width) + left - right < 2**(safe_width + 1)` - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::FieldLt( - box FieldElementExpression::Add( - box offset.clone(), - box FieldElementExpression::Sub(box left.clone(), box right.clone()), - ), - box max, - ), - RuntimeError::IncompleteDynamicRange, - )); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::field_lt( + offset.clone().span(span) + + FieldElementExpression::sub(left.clone(), right.clone()) + .span(span), + max, + ) + .span(span), + RuntimeError::IncompleteDynamicRange, + ) + .span(span), + ); // and // `2**(safe_width) + left - right != 0` - self.panic_buffer.push(ZirStatement::Assertion( - BooleanExpression::Not(box BooleanExpression::FieldEq( - box FieldElementExpression::Sub(box right.clone(), box left.clone()), - box offset, - )), - RuntimeError::IncompleteDynamicRange, - )); + self.panic_buffer.push( + ZirStatement::assertion( + BooleanExpression::not( + BooleanExpression::field_eq( + FieldElementExpression::sub( + right.clone().span(span), + left.clone().span(span), + ) + .span(span), + offset.span(span), + ) + .span(span), + ) + .span(span), + RuntimeError::IncompleteDynamicRange, + ) + .span(span), + ); // NOTE: // instead of splitting the check in two, we could have used a single `Lt` here, by simply subtracting 1 from all sides: @@ -182,9 +229,9 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> { // if we use `x - 1` here, we end up having to calculate the bits of both `x` and `x - 1`, which is expensive // by splitting, we can reuse the bits of `x` needed for this completeness check when computing the result - BooleanExpression::FieldLt(box left, box right) + BooleanExpression::field_lt(left, right) } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 7d77e86db..4fc9e6651 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -12,8 +12,13 @@ use num_bigint::BigUint; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::fmt; +use std::ops::*; use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::expressions::{ + BinaryExpression, BinaryOrExpression, EqExpression, ValueExpression, +}; +use zokrates_ast::common::operators::OpEq; +use zokrates_ast::common::{FlatEmbed, ResultFold, WithSpan}; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::Type; use zokrates_ast::typed::*; @@ -68,30 +73,28 @@ impl<'ast, T: Field> Propagator<'ast, T> { .get_mut(&var.id) .map(|c| Ok((var, c))) .unwrap_or(Err(var)), - TypedAssignee::Select(box assignee, box index) => { - match self.try_get_constant_mut(assignee) { - Ok((variable, constant)) => match index.as_inner() { - UExpressionInner::Value(n) => match constant { - TypedExpression::Array(a) => match a.as_inner_mut() { - ArrayExpressionInner::Value(value) => { - match value.0.get_mut(*n as usize) { - Some(TypedExpressionOrSpread::Expression(ref mut e)) => { - Ok((variable, e)) - } - None => Err(variable), - _ => unreachable!(), + TypedAssignee::Select(assignee, index) => match self.try_get_constant_mut(assignee) { + Ok((variable, constant)) => match index.as_inner() { + UExpressionInner::Value(n) => match constant { + TypedExpression::Array(a) => match a.as_inner_mut() { + ArrayExpressionInner::Value(value) => { + match value.value.get_mut(n.value as usize) { + Some(TypedExpressionOrSpread::Expression(ref mut e)) => { + Ok((variable, e)) } + None => Err(variable), + _ => unreachable!(), } - _ => unreachable!("should be an array value"), - }, - _ => unreachable!("should be an array expression"), + } + _ => unreachable!("should be an array value"), }, - _ => Err(variable), + _ => unreachable!("should be an array expression"), }, - e => e, - } - } - TypedAssignee::Member(box assignee, m) => match self.try_get_constant_mut(assignee) { + _ => Err(variable), + }, + e => e, + }, + TypedAssignee::Member(assignee, m) => match self.try_get_constant_mut(assignee) { Ok((v, c)) => { let ty = assignee.get_type(); @@ -106,7 +109,7 @@ impl<'ast, T: Field> Propagator<'ast, T> { match c { TypedExpression::Struct(a) => match a.as_inner_mut() { - StructExpressionInner::Value(value) => Ok((v, &mut value[index])), + StructExpressionInner::Value(value) => Ok((v, &mut value.value[index])), _ => unreachable!("should be a struct value"), }, _ => unreachable!("should be a struct expression"), @@ -114,20 +117,18 @@ impl<'ast, T: Field> Propagator<'ast, T> { } e => e, }, - TypedAssignee::Element(box assignee, index) => { - match self.try_get_constant_mut(assignee) { - Ok((v, c)) => match c { - TypedExpression::Tuple(a) => match a.as_inner_mut() { - TupleExpressionInner::Value(value) => { - Ok((v, &mut value[*index as usize])) - } - _ => unreachable!("should be a tuple value"), - }, - _ => unreachable!("should be a tuple expression"), + TypedAssignee::Element(assignee, index) => match self.try_get_constant_mut(assignee) { + Ok((v, c)) => match c { + TypedExpression::Tuple(a) => match a.as_inner_mut() { + TupleExpressionInner::Value(value) => { + Ok((v, &mut value.value[*index as usize])) + } + _ => unreachable!("should be a tuple value"), }, - e => e, - } - } + _ => unreachable!("should be a tuple expression"), + }, + e => e, + }, } } } @@ -150,7 +151,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } }) .collect::>()?, - main: p.main, + ..p }) } @@ -168,7 +169,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } fn fold_conditional_expression< - E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold<'ast, T>, + E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq + ResultFold, >( &mut self, _: &E::Ty, @@ -180,10 +181,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { e.consequence.fold(self)?, e.alternative.fold(self)?, ) { - (BooleanExpression::Value(true), consequence, _) => { + (BooleanExpression::Value(v), consequence, _) if v.value => { ConditionalOrExpression::Expression(consequence.into_inner()) } - (BooleanExpression::Value(false), _, alternative) => { + (BooleanExpression::Value(v), _, alternative) if !v.value => { ConditionalOrExpression::Expression(alternative.into_inner()) } (_, consequence, alternative) if consequence == alternative => { @@ -196,105 +197,93 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ) } - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedAssemblyStatement::Assignment(assignee, expr) => { - let assignee = self.fold_assignee(assignee)?; - let expr = self.fold_expression(expr)?; + let assignee = self.fold_assignee(s.assignee)?; + let expr = self.fold_expression(s.expression)?; - if expr.is_constant() { - match assignee { - TypedAssignee::Identifier(var) => { - let expr = expr.into_canonical_constant(); + if expr.is_constant() { + match assignee { + TypedAssignee::Identifier(var) => { + let expr = expr.into_canonical_constant(); - assert!(self.constants.insert(var.id, expr).is_none()); + assert!(self.constants.insert(var.id, expr).is_none()); - Ok(vec![]) - } - assignee => match self.try_get_constant_mut(&assignee) { - Ok((_, c)) => { - *c = expr.into_canonical_constant(); - Ok(vec![]) - } - Err(v) => match self.constants.remove(&v.id) { - // invalidate the cache for this identifier, and define the latest - // version of the constant in the program, if any - Some(c) => Ok(vec![ - TypedAssemblyStatement::Assignment(v.clone().into(), c), - TypedAssemblyStatement::Assignment(assignee, expr), - ]), - None => { - Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]) - } - }, - }, + Ok(vec![]) + } + assignee => match self.try_get_constant_mut(&assignee) { + Ok((_, c)) => { + *c = expr.into_canonical_constant(); + Ok(vec![]) } - } else { - // the expression being assigned is not constant, invalidate the cache - let v = self - .try_get_constant_mut(&assignee) - .map(|(v, _)| v) - .unwrap_or_else(|v| v); - - match self.constants.remove(&v.id) { + Err(v) => match self.constants.remove(&v.id) { + // invalidate the cache for this identifier, and define the latest + // version of the constant in the program, if any Some(c) => Ok(vec![ - TypedAssemblyStatement::Assignment(v.clone().into(), c), - TypedAssemblyStatement::Assignment(assignee, expr), + TypedAssemblyStatement::assignment(v.clone().into(), c), + TypedAssemblyStatement::assignment(assignee, expr), ]), - None => Ok(vec![TypedAssemblyStatement::Assignment(assignee, expr)]), - } - } + None => Ok(vec![TypedAssemblyStatement::assignment(assignee, expr)]), + }, + }, } - TypedAssemblyStatement::Constraint(left, right, metadata) => { - let left = self.fold_field_expression(left)?; - let right = self.fold_field_expression(right)?; - - // a bit hacky, but we use a fake boolean expression to check this - let is_equal = - BooleanExpression::FieldEq(EqExpression::new(left.clone(), right.clone())); - let is_equal = self.fold_boolean_expression(is_equal)?; - - match is_equal { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => { - Err(Error::AssertionFailed(RuntimeError::SourceAssertion( - metadata - .message(Some(format!("In asm block: `{} !== {}`", left, right))), - ))) - } - _ => Ok(vec![TypedAssemblyStatement::Constraint( - left, right, metadata, - )]), - } + } else { + // the expression being assigned is not constant, invalidate the cache + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + + match self.constants.remove(&v.id) { + Some(c) => Ok(vec![ + TypedAssemblyStatement::assignment(v.clone().into(), c), + TypedAssemblyStatement::assignment(assignee, expr), + ]), + None => Ok(vec![TypedAssemblyStatement::assignment(assignee, expr)]), } } } - fn fold_statement( + fn fold_assembly_constraint( &mut self, - s: TypedStatement<'ast, T>, - ) -> Result>, Error> { - match s { - TypedStatement::Assembly(statements) => { - let statements: Vec<_> = statements - .into_iter() - .map(|s| self.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - match statements.len() { - 0 => Ok(vec![]), - _ => Ok(vec![TypedStatement::Assembly(statements)]), - } + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + let span = s.get_span(); + + let left = self.fold_field_expression(s.left)?; + let right = self.fold_field_expression(s.right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = BooleanExpression::field_eq(left.clone(), right.clone()).span(span); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + s.metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) } + _ => Ok(vec![TypedAssemblyStatement::constraint( + left, right, s.metadata, + )]), + } + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let span = s.get_span(); + + match s.rhs { // propagation to the defined variable if rhs is a constant - TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { - let assignee = self.fold_assignee(assignee)?; - let expr = self.fold_expression(expr)?; + DefinitionRhs::Expression(e) => { + let assignee = self.fold_assignee(s.assignee)?; + let expr = self.fold_expression(e)?; if let (Ok(a), Ok(e)) = ( ConcreteType::try_from(assignee.get_type()), @@ -326,10 +315,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // invalidate the cache for this identifier, and define the latest // version of the constant in the program, if any Some(c) => Ok(vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ]), - None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]), + None => Ok(vec![TypedStatement::definition(assignee, expr)]), }, }, } @@ -342,23 +331,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match self.constants.remove(&v.id) { Some(c) => Ok(vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ]), - None => Ok(vec![TypedStatement::Definition(assignee, expr.into())]), + None => Ok(vec![TypedStatement::definition(assignee, expr)]), } } } - // we do not visit the for-loop statements - TypedStatement::For(v, from, to, statements) => { - let from = self.fold_uint_expression(from)?; - let to = self.fold_uint_expression(to)?; - - Ok(vec![TypedStatement::For(v, from, to, statements)]) - } - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let assignee = self.fold_assignee(assignee)?; - let embed_call = self.fold_embed_call(embed_call)?; + DefinitionRhs::EmbedCall(e) => { + let assignee = self.fold_assignee(s.assignee)?; + let embed_call = self.fold_embed_call(e)?; fn process_u_from_bits<'ast, T: Field>( arguments: &[TypedExpression<'ast, T>], @@ -374,7 +356,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .into_inner() { ArrayExpressionInner::Value(v) => - UExpressionInner::Value( + UExpression::value( v.into_iter() .map(|v| match v { TypedExpressionOrSpread::Expression( @@ -386,7 +368,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { }) .enumerate() .fold(0, |acc, (i, v)| { - if v { + if v.value { acc + 2u128.pow( (bitwidth.to_usize() - i - 1) .try_into() @@ -414,7 +396,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .into_inner() { UExpressionInner::Value(v) => { - let mut num = v; + let mut num = v.value; let mut res = vec![]; for i in (0..bitwidth as u32).rev() { @@ -427,13 +409,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } assert_eq!(num, 0); - ArrayExpressionInner::Value( + ArrayExpression::value( res.into_iter() - .map(|v| BooleanExpression::Value(v).into()) - .collect::>() - .into(), + .map(|v| BooleanExpression::value(v).into()) + .collect::>(), ) - .annotate(Type::Boolean, bitwidth.to_usize() as u32) + .annotate(ArrayType::new(Type::Boolean, bitwidth.to_usize() as u32)) .into() } _ => unreachable!("should be a uint value"), @@ -448,13 +429,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match FieldElementExpression::try_from_typed( embed_call.arguments[0].clone(), ) { - Ok(FieldElementExpression::Number(n)) if n == T::from(0) => { - Ok(Some(BooleanExpression::Value(false).into())) + Ok(FieldElementExpression::Value(n)) + if n.value == T::from(0) => + { + Ok(Some(BooleanExpression::value(false).span(span).into())) } - Ok(FieldElementExpression::Number(n)) if n == T::from(1) => { - Ok(Some(BooleanExpression::Value(true).into())) + Ok(FieldElementExpression::Value(n)) + if n.value == T::from(1) => + { + Ok(Some(BooleanExpression::value(true).span(span).into())) } - Ok(FieldElementExpression::Number(n)) => { + Ok(FieldElementExpression::Value(n)) => { Err(Error::InvalidValue(format!( "Cannot call `{}` with value `{}`: should be 0 or 1", embed_call.embed.id(), @@ -507,8 +492,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ) .unwrap() { - FieldElementExpression::Number(num) => { - let mut acc = num.clone(); + FieldElementExpression::Value(num) => { + let mut acc = num.value; let mut res = vec![]; for i in (0..bit_width as usize).rev() { @@ -528,13 +513,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { ))) } else { Ok(Some( - ArrayExpressionInner::Value( + ArrayExpression::value( res.into_iter() - .map(|v| BooleanExpression::Value(v).into()) - .collect::>() - .into(), + .map(|v| { + BooleanExpression::value(v) + .span(span) + .into() + }) + .collect::>(), ) - .annotate(Type::Boolean, bit_width) + .annotate(ArrayType::new(Type::Boolean, bit_width)) + .span(span) .into(), )) } @@ -562,11 +551,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } Err(v) => match self.constants.remove(&v.id) { Some(c) => vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, expr.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::definition(assignee, expr), ], None => { - vec![TypedStatement::Definition(assignee, expr.into())] + vec![TypedStatement::definition(assignee, expr)] } }, }, @@ -582,12 +571,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { match self.constants.remove(&v.id) { Some(c) => vec![ - TypedStatement::Definition(v.clone().into(), c.into()), - TypedStatement::Definition(assignee, embed_call.into()), + TypedStatement::definition(v.clone().into(), c), + TypedStatement::embed_call_definition(assignee, embed_call), ], - None => vec![TypedStatement::Definition( - assignee, - embed_call.into(), + None => vec![TypedStatement::embed_call_definition( + assignee, embed_call, )], } } @@ -596,7 +584,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { false => { // if the function arguments are not constant, invalidate the cache // for the return assignees - let def = TypedStatement::Definition(assignee.clone(), embed_call.into()); + let def = + TypedStatement::embed_call_definition(assignee.clone(), embed_call); let v = self .try_get_constant_mut(&assignee) @@ -605,401 +594,447 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(match self.constants.remove(&v.id) { Some(c) => { - vec![TypedStatement::Definition(v.clone().into(), c.into()), def] + vec![TypedStatement::definition(v.clone().into(), c), def] } None => vec![def], }) } } } - TypedStatement::Assertion(e, err) => { - let expr = self.fold_boolean_expression(e)?; - match expr { - BooleanExpression::Value(false) => Err(Error::AssertionFailed(err)), - BooleanExpression::Value(true) => Ok(vec![]), - _ => Ok(vec![TypedStatement::Assertion(expr, err)]), - } - } - s => fold_statement(self, s), } } - fn fold_uint_expression_inner( + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + Ok(fold_assembly_block(self, s)? + .into_iter() + .filter(|s| match s { + TypedStatement::Assembly(s) => !s.inner.is_empty(), + _ => true, + }) + .collect()) + } + + fn fold_for_statement( + &mut self, + s: ForStatement<'ast, T>, + ) -> Result>, Self::Error> { + // we do not visit the for-loop statements + let from = self.fold_uint_expression(s.from)?; + let to = self.fold_uint_expression(s.to)?; + + Ok(vec![TypedStatement::for_(s.var, from, to, s.statements)]) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let _e_str = s.expression.to_string(); + let expr = self.fold_boolean_expression(s.expression)?; + + match expr { + BooleanExpression::Value(v) if !v.value => Err(Error::AssertionFailed(s.error)), + BooleanExpression::Value(v) if v.value => Ok(vec![]), + _ => Ok(vec![TypedStatement::assertion(expr, s.error)]), + } + } + + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Error> { match e { - UExpressionInner::Add(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Add(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 + v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value + v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v { - 0 => Ok(e), - _ => Ok(UExpressionInner::Add( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), - }, - (e1, e2) => Ok(UExpressionInner::Add( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => { + match v.value { + 0 => Ok(e), + _ => Ok(UExpression::add( + e.annotate(bitwidth), + UExpression::value(v.value).annotate(bitwidth), + ) + .into_inner()), + } + } + (e1, e2) => { + Ok(UExpression::add(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Sub(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Sub(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value.wrapping_sub(v2.value)) + % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 0 => Ok(e), - _ => Ok(UExpressionInner::Sub( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::sub( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::FloorSub(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::FloorSub(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - v1.saturating_sub(v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + v1.value.saturating_sub(v2.value) + % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 0 => Ok(e), - _ => Ok(UExpressionInner::FloorSub( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::floor_sub( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Mult(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Mult(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 * v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value * v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v { - 0 => Ok(UExpressionInner::Value(0)), - 1 => Ok(e), - _ => Ok(UExpressionInner::Mult( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), - }, - (e1, e2) => Ok(UExpressionInner::Mult( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => { + match v.value { + 0 => Ok(UExpression::value(0)), + 1 => Ok(e), + _ => Ok(UExpression::mul( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), + } + } + (e1, e2) => { + Ok(UExpression::mul(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Div(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Div(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 / v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value / v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { + (e, UExpressionInner::Value(v)) => match v.value { 1 => Ok(e), - _ => Ok(UExpressionInner::Div( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + _ => Ok(UExpression::div( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Div( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::div(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::Rem(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Rem(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value( - (v1 % v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + Ok(UExpression::value( + (v1.value % v2.value) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )) } - (e, UExpressionInner::Value(v)) => match v { - 1 => Ok(UExpressionInner::Value(0)), - _ => Ok(UExpressionInner::Rem( - box e.annotate(bitwidth), - box UExpressionInner::Value(v).annotate(bitwidth), - )), + (e, UExpressionInner::Value(v)) => match v.value { + 1 => Ok(UExpression::value(0)), + _ => Ok(UExpression::rem( + e.annotate(bitwidth), + UExpressionInner::Value(v).annotate(bitwidth), + ) + .into_inner()), }, - (e1, e2) => Ok(UExpressionInner::Rem( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => { + Ok(UExpression::rem(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) + } }, - UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e.into_inner(), by.into_inner()) { + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { - Ok(UExpressionInner::Value(v >> by)) + Ok(UExpression::value(v.value >> by.value)) } - (e, by) => Ok(UExpressionInner::RightShift( - box e.annotate(bitwidth), - box by.annotate(UBitwidth::B32), - )), + (e, by) => Ok(UExpression::right_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e.into_inner(), by.into_inner()) { - (UExpressionInner::Value(v), UExpressionInner::Value(by)) => Ok( - UExpressionInner::Value((v << by) & (2_u128.pow(bitwidth as u32) - 1)), - ), - (e, by) => Ok(UExpressionInner::LeftShift( - box e.annotate(bitwidth), - box by.annotate(UBitwidth::B32), - )), + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { + Ok(UExpression::value( + (v.value << by.value) & (2_u128.pow(bitwidth as u32) - 1), + )) + } + (e, by) => Ok(UExpression::left_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::Xor(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::Xor(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value(v1 ^ v2)) + Ok(UExpression::value(v1.value ^ v2.value)) + } + (UExpressionInner::Value(v), e2) | (e2, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(e2) } - (UExpressionInner::Value(0), e2) => Ok(e2), - (e1, UExpressionInner::Value(0)) => Ok(e1), (e1, e2) => { if e1 == e2 { - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) } else { - Ok(UExpressionInner::Xor( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )) + Ok( + UExpression::xor(e1.annotate(bitwidth), e2.annotate(bitwidth)) + .into_inner(), + ) } } }, - UExpressionInner::And(box e1, box e2) => match ( - self.fold_uint_expression(e1)?.into_inner(), - self.fold_uint_expression(e2)?.into_inner(), + UExpressionInner::And(e) => match ( + self.fold_uint_expression(*e.left)?.into_inner(), + self.fold_uint_expression(*e.right)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(UExpressionInner::Value(v1 & v2)) + Ok(UExpression::value(v1.value & v2.value)) + } + (UExpressionInner::Value(v), _) | (_, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(UExpression::value(0)) } - (UExpressionInner::Value(0), _) | (_, UExpressionInner::Value(0)) => { - Ok(UExpressionInner::Value(0)) + (e1, e2) => { + Ok(UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner()) } - (e1, e2) => Ok(UExpressionInner::And( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), }, - UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Not(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value( - (!v) & (2_u128.pow(bitwidth as u32) - 1), + UExpressionInner::Value(v) => Ok(UExpression::value( + (!v.value) & (2_u128.pow(bitwidth as u32) - 1), )), - e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))), + e => Ok(UExpression::not(e.annotate(bitwidth)).into_inner()), } } - UExpressionInner::Neg(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Neg(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value( - (0u128.wrapping_sub(v)) + UExpressionInner::Value(v) => Ok(UExpression::value( + (0u128.wrapping_sub(v.value)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), )), - e => Ok(UExpressionInner::Neg(box e.annotate(bitwidth))), + e => Ok(UExpression::neg(e.annotate(bitwidth)).into_inner()), } } - UExpressionInner::Pos(box e) => { - let e = self.fold_uint_expression(e)?.into_inner(); + UExpressionInner::Pos(e) => { + let e = self.fold_uint_expression(*e.inner)?.into_inner(); match e { - UExpressionInner::Value(v) => Ok(UExpressionInner::Value(v)), - e => Ok(UExpressionInner::Pos(box e.annotate(bitwidth))), + UExpressionInner::Value(v) => Ok(UExpression::value(v.value)), + e => Ok(UExpression::pos(e.annotate(bitwidth)).into_inner()), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Error> { match e { - FieldElementExpression::Add(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 + n2)) - } - (e1, e2) => Ok(FieldElementExpression::Add(box e1, box e2)), - }, - FieldElementExpression::Sub(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 - n2)) - } - (e1, e2) => Ok(FieldElementExpression::Sub(box e1, box e2)), - }, - FieldElementExpression::Mult(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 * n2)) - } - (e1, e2) => Ok(FieldElementExpression::Mult(box e1, box e2)), - }, - FieldElementExpression::Div(box e1, box e2) => match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 / n2)) - } - (e1, e2) => Ok(FieldElementExpression::Div(box e1, box e2)), - }, - FieldElementExpression::Neg(box e) => match self.fold_field_expression(e)? { - FieldElementExpression::Number(n) => { - Ok(FieldElementExpression::Number(T::zero() - n)) + FieldElementExpression::Add(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value + n2.value)) + } + (e1, e2) => e1 + e2, + }) + } + FieldElementExpression::Sub(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value - n2.value)) + } + (e1, e2) => e1 - e2, + }) + } + FieldElementExpression::Mult(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value * n2.value)) + } + (e1, e2) => e1 * e2, + }) + } + FieldElementExpression::Div(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value)) + } + (e1, e2) => e1 / e2, + }) + } + FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? { + FieldElementExpression::Value(n) => { + Ok(FieldElementExpression::value(T::zero() - n.value)) } - e => Ok(FieldElementExpression::Neg(box e)), + e => Ok(FieldElementExpression::neg(e)), }, - FieldElementExpression::Pos(box e) => match self.fold_field_expression(e)? { - FieldElementExpression::Number(n) => Ok(FieldElementExpression::Number(n)), - e => Ok(FieldElementExpression::Pos(box e)), + FieldElementExpression::Pos(e) => match self.fold_field_expression(*e.inner)? { + FieldElementExpression::Value(n) => Ok(FieldElementExpression::Value(n)), + e => Ok(FieldElementExpression::pos(e)), }, - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + FieldElementExpression::Pow(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1, e2.into_inner()) { - (_, UExpressionInner::Value(ref n2)) if *n2 == 0 => { - Ok(FieldElementExpression::Number(T::from(1))) + (_, UExpressionInner::Value(ref n2)) if n2.value == 0 => { + Ok(FieldElementExpression::value(T::from(1))) } - (FieldElementExpression::Number(n1), UExpressionInner::Value(n2)) => { - Ok(FieldElementExpression::Number(n1.pow(n2 as usize))) - } - (e1, UExpressionInner::Value(n2)) => Ok(FieldElementExpression::Pow( - box e1, - box UExpressionInner::Value(n2).annotate(UBitwidth::B32), - )), - (e1, e2) => Ok(FieldElementExpression::Pow( - box e1, - box e2.annotate(UBitwidth::B32), - )), + (FieldElementExpression::Value(n1), UExpressionInner::Value(n2)) => Ok( + FieldElementExpression::value(n1.value.pow(n2.value as usize)), + ), + (e1, e2) => Ok(FieldElementExpression::pow(e1, e2.annotate(UBitwidth::B32))), } } - FieldElementExpression::Xor(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Xor(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitxor(n2.value.to_biguint())) + .unwrap(), )) } - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(0) => + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { Ok(e) } - (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), - (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::value(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::bitxor(e1, e2)), } } - - FieldElementExpression::And(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::And(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (_, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), _) - if n == T::from(0) => + (_, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), _) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(n)) + Ok(FieldElementExpression::Value(n)) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitand(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitand(e1, e2)), } } - FieldElementExpression::Or(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Or(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (e, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), e) - if n == T::from(0) => + (e, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), e) + if n.value == T::from(0) => { Ok(e) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitor(e1, e2)), } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::LeftShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. @@ -1008,46 +1043,47 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { let two = BigUint::from(2usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); - Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shl(by.value as usize).bitand(mask)) + .unwrap(), )) } - (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + (e, by) => Ok(FieldElementExpression::left_shift(e, by)), } } - FieldElementExpression::RightShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::RightShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. }, - ) => Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + ) => Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shr(by.value as usize)).unwrap(), )), - (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + (e, by) => Ok(FieldElementExpression::right_shift(e, by)), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } @@ -1128,10 +1164,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { if n < size { Ok(SelectOrExpression::Expression( - v.expression_at::(n as usize).unwrap().into_inner(), + v.expression_at::(n.value as usize).unwrap().into_inner(), )) } else { - Err(Error::OutOfBounds(n, size)) + Err(Error::OutOfBounds(n.value, size.value)) } } (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { @@ -1140,7 +1176,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { TypedExpression::Array(a) => match a.as_inner() { ArrayExpressionInner::Value(v) => { Ok(SelectOrExpression::Expression( - v.expression_at::(n as usize).unwrap().into_inner(), + v.expression_at::(n.value as usize) + .unwrap() + .into_inner(), )) } _ => unreachable!("should be an array value"), @@ -1150,7 +1188,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { None => Ok(SelectOrExpression::Expression( E::select( ArrayExpressionInner::Identifier(id) - .annotate(inner_type, size as u32), + .annotate(ArrayType::new(inner_type, size.value as u32)), UExpressionInner::Value(n).annotate(UBitwidth::B32), ) .into_inner(), @@ -1158,7 +1196,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } (a, i) => Ok(SelectOrExpression::Select(SelectExpression::new( - a.annotate(inner_type, size as u32), + a.annotate(ArrayType::new(inner_type, size.value as u32)), i.annotate(UBitwidth::B32), ))), }, @@ -1168,7 +1206,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } - fn fold_array_expression_inner( + fn fold_array_expression_cases( &mut self, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -1190,7 +1228,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { inner: ArrayExpressionInner::Value(v), .. }, - }) => v.0, + }) => v.value, e => vec![e], } }) @@ -1214,11 +1252,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { .collect(), )) } - e => fold_array_expression_inner(self, ty, e), + e => fold_array_expression_cases(self, ty, e), } } - fn fold_struct_expression_inner( + fn fold_struct_expression_cases( &mut self, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -1240,11 +1278,13 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(StructExpressionInner::Value(v)) } - e => fold_struct_expression_inner(self, ty, e), + e => fold_struct_expression_cases(self, ty, e), } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, _: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -1255,7 +1295,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } } - fn fold_tuple_expression_inner( + fn fold_tuple_expression_cases( &mut self, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -1277,16 +1317,19 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { Ok(TupleExpressionInner::Value(v)) } - e => fold_tuple_expression_inner(self, ty, e), + e => fold_tuple_expression_cases(self, ty, e), } } fn fold_eq_expression< - E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold<'ast, T>, + E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold, >( &mut self, - e: EqExpression, - ) -> Result, Self::Error> { + e: EqExpression>, + ) -> Result< + BinaryOrExpression, BooleanExpression<'ast, T>>, + Self::Error, + > { let left = e.left.fold(self)?; let right = e.right.fold(self)?; @@ -1305,22 +1348,26 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // if the two expressions are the same, we can reduce to `true`. // Note that if they are different we cannot reduce to `false`: `a == 1` may still be `true` even though `a` and `1` are different expressions if left == right { - return Ok(EqOrBoolean::Boolean(BooleanExpression::Value(true))); + return Ok(BinaryOrExpression::Expression(BooleanExpression::value( + true, + ))); } // if both expressions are constant, we can reduce the equality check after we put them in canonical form if left.is_constant() && right.is_constant() { let left = left.into_canonical_constant(); let right = right.into_canonical_constant(); - Ok(EqOrBoolean::Boolean(BooleanExpression::Value( + Ok(BinaryOrExpression::Expression(BooleanExpression::value( left == right, ))) } else { - Ok(EqOrBoolean::Eq(EqExpression::new(left, right))) + Ok(BinaryOrExpression::Binary(BinaryExpression::new( + left, right, + ))) } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Error> { @@ -1330,142 +1377,106 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { // These kind of reduction rules are easier to apply later in the process, when we have canonical representations // of expressions, ie `a + a` would always be written `2 * a` match e { - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; - - match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) - } - (e1, e2) => Ok(BooleanExpression::FieldLt(box e1, box e2)), - } - } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; - - match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) - } - (e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)), - } - } - BooleanExpression::FieldGt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + BooleanExpression::FieldLt(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 > n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1.value < n2.value)) } - (e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_lt(e1, e2)), } } - BooleanExpression::FieldGe(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + BooleanExpression::FieldLe(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 >= n2)) - } - (e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)), - } - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; - - match (e1.as_inner(), e2.as_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) - } - _ => Ok(BooleanExpression::UintLt(box e1, box e2)), - } - } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; - - match (e1.as_inner(), e2.as_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1.value <= n2.value)) } - _ => Ok(BooleanExpression::UintLe(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_le(e1, e2)), } } - BooleanExpression::UintGt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLt(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 > n2)) + Ok(BooleanExpression::value(n1.value < n2.value)) } - _ => Ok(BooleanExpression::UintGt(box e1, box e2)), + _ => Ok(BooleanExpression::uint_lt(e1, e2)), } } - BooleanExpression::UintGe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLe(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(BooleanExpression::Value(n1 >= n2)) + Ok(BooleanExpression::value(n1.value <= n2.value)) } - _ => Ok(BooleanExpression::UintGe(box e1, box e2)), + _ => Ok(BooleanExpression::uint_le(e1, e2)), } } - BooleanExpression::Or(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1)?; - let e2 = self.fold_boolean_expression(e2)?; + BooleanExpression::Or(e) => { + let e1 = self.fold_boolean_expression(*e.left)?; + let e2 = self.fold_boolean_expression(*e.right)?; match (e1, e2) { // reduction of constants (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 || v2)) + Ok(BooleanExpression::value(v1.value || v2.value)) } // x || true == true - (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { - Ok(BooleanExpression::Value(true)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if v.value => + { + Ok(BooleanExpression::value(true)) } // x || false == x - (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if !v.value => + { Ok(e) } - (e1, e2) => Ok(BooleanExpression::Or(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitor(e1, e2)), } } - BooleanExpression::And(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1)?; - let e2 = self.fold_boolean_expression(e2)?; + BooleanExpression::And(e) => { + let e1 = self.fold_boolean_expression(*e.left)?; + let e2 = self.fold_boolean_expression(*e.right)?; match (e1, e2) { // reduction of constants (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 && v2)) + Ok(BooleanExpression::value(v1.value && v2.value)) } // x && true == x - (e, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if v.value => + { Ok(e) } // x && false == false - (_, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), _) => { - Ok(BooleanExpression::Value(false)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if !v.value => + { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::And(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitand(e1, e2)), } } - BooleanExpression::Not(box e) => { - let e = self.fold_boolean_expression(e)?; + BooleanExpression::Not(e) => { + let e = self.fold_boolean_expression(*e.inner)?; match e { - BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)), - e => Ok(BooleanExpression::Not(box e)), + BooleanExpression::Value(v) => Ok(BooleanExpression::value(!v.value)), + e => Ok(BooleanExpression::not(e)), } } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } } @@ -1485,66 +1496,66 @@ mod tests { #[test] fn add() { - let e = FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + let e = FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(5))) + Ok(FieldElementExpression::value(Bn128Field::from(5))) ); } #[test] fn sub() { - let e = FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); } #[test] fn mult() { - let e = FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(6))) + Ok(FieldElementExpression::value(Bn128Field::from(6))) ); } #[test] fn div() { - let e = FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(6)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e = FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } #[test] fn pow() { - let e = FieldElementExpression::Pow( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 3u32.into(), + let e = FieldElementExpression::pow( + FieldElementExpression::value(Bn128Field::from(2)), + 3u32.into(), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); } @@ -1553,43 +1564,43 @@ mod tests { let mut propagator = Propagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::identifier("a".into()), + 0u32.into(), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 2u32.into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 2u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box ((Bn128Field::get_required_bits() - 1) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + ((Bn128Field::get_required_bits() - 1) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box ((Bn128Field::get_required_bits() - 3) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(3)), + ((Bn128Field::get_required_bits() - 3) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1598,110 +1609,106 @@ mod tests { let mut propagator = Propagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + 0u32.into(), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box 1u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(3)), + 1u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 2u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 2u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box 4u32.into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + 4u32.into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box ((Bn128Field::get_required_bits() - 1) as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + ((Bn128Field::get_required_bits() - 1) as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box (Bn128Field::get_required_bits() as u32).into(), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + (Bn128Field::get_required_bits() as u32).into(), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } #[test] fn if_else_true() { let e = FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); } #[test] fn if_else_false() { let e = FieldElementExpression::conditional( - BooleanExpression::Value(false), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(3)), + BooleanExpression::value(false), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), ConditionalKind::IfElse, ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } #[test] fn select() { let e = FieldElementExpression::select( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 3u32), - UExpressionInner::Add(box 1u32.into(), box 1u32.into()) - .annotate(UBitwidth::B32), + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 3u32)), + UExpression::add(1u32.into(), 1u32.into()), ); assert_eq!( Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); } } @@ -1713,21 +1720,21 @@ mod tests { #[test] fn not() { let e_true: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::Value(false)); + BooleanExpression::not(BooleanExpression::value(false)); let e_false: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::Value(true)); + BooleanExpression::not(BooleanExpression::value(true)); let e_default: BooleanExpression = - BooleanExpression::Not(box BooleanExpression::identifier("a".into())); + BooleanExpression::not(BooleanExpression::identifier("a".into())); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_default.clone()), @@ -1737,39 +1744,39 @@ mod tests { #[test] fn field_eq() { - let e_constant_true = BooleanExpression::FieldEq(EqExpression::new( - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(2)), + let e_constant_true = BooleanExpression::FieldEq(BinaryExpression::new( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )); - let e_constant_false = BooleanExpression::FieldEq(EqExpression::new( - FieldElementExpression::Number(Bn128Field::from(4)), - FieldElementExpression::Number(Bn128Field::from(2)), + let e_constant_false = BooleanExpression::FieldEq(BinaryExpression::new( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), )); let e_identifier_true: BooleanExpression = - BooleanExpression::FieldEq(EqExpression::new( + BooleanExpression::FieldEq(BinaryExpression::new( FieldElementExpression::identifier("a".into()), FieldElementExpression::identifier("a".into()), )); let e_identifier_unchanged: BooleanExpression = - BooleanExpression::FieldEq(EqExpression::new( + BooleanExpression::FieldEq(BinaryExpression::new( FieldElementExpression::identifier("a".into()), FieldElementExpression::identifier("b".into()), )); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), @@ -1782,151 +1789,125 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(false), - BooleanExpression::Value(false) + BooleanExpression::value(false), + BooleanExpression::value(false) )) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(true), - BooleanExpression::Value(true) + BooleanExpression::value(true), + BooleanExpression::value(true) )) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(true), + BooleanExpression::value(false) )) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( BooleanExpression::BoolEq(EqExpression::new( - BooleanExpression::Value(false), - BooleanExpression::Value(true) + BooleanExpression::value(false), + BooleanExpression::value(true) )) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn array_eq() { - let e_constant_true = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), + let e_constant_true = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); - let e_constant_false = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(4usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), + let e_constant_false = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(4usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_identifier_true: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), + BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); let e_identifier_unchanged: BooleanExpression = - BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpression::identifier("a".into()).annotate(Type::FieldElement, 1u32), - ArrayExpression::identifier("b".into()).annotate(Type::FieldElement, 1u32), + BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::identifier("b".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); - let e_non_canonical_true = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] + let e_non_canonical_true = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Spread( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), - ) - .annotate(Type::FieldElement, 1u32), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); - let e_non_canonical_false = BooleanExpression::ArrayEq(EqExpression::new( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Spread( - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(2usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32) - .into(), - )] + let e_non_canonical_false = BooleanExpression::ArrayEq(BinaryExpression::new( + ArrayExpression::value(vec![TypedExpressionOrSpread::Spread( + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(2usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), - ) - .annotate(Type::FieldElement, 1u32), - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression( - FieldElementExpression::Number(Bn128Field::from(4usize)).into(), - )] - .into(), - ) - .annotate(Type::FieldElement, 1u32), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::value(Bn128Field::from(4usize)).into(), + )]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)), )); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_constant_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), @@ -1934,99 +1915,99 @@ mod tests { ); assert_eq!( Propagator::default().fold_boolean_expression(e_non_canonical_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_non_canonical_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn lt() { - let e_true = BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let e_true = BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(4)), ); - let e_false = BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_false = BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn le() { - let e_true = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_true = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), ); - let e_false = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(2)), + let e_false = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(2)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn gt() { - let e_true = BooleanExpression::FieldGt( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let e_true = BooleanExpression::field_gt( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::value(Bn128Field::from(4)), ); - let e_false = BooleanExpression::FieldGt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_false = BooleanExpression::field_gt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(5)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } #[test] fn ge() { - let e_true = BooleanExpression::FieldGe( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_true = BooleanExpression::field_ge( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::value(Bn128Field::from(5)), ); - let e_false = BooleanExpression::FieldGe( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(5)), + let e_false = BooleanExpression::field_ge( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(5)), ); assert_eq!( Propagator::default().fold_boolean_expression(e_true), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::default().fold_boolean_expression(e_false), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -2036,75 +2017,75 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::identifier(a_bool.clone()) ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(true), ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::identifier(a_bool.clone()) ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::And( - box BooleanExpression::Value(false), - box BooleanExpression::Value(false), + BooleanExpression::bitand( + BooleanExpression::value(false), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -2114,75 +2095,75 @@ mod tests { assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::identifier(a_bool.clone()) ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::identifier(a_bool.clone()) + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::identifier(a_bool.clone()) ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::identifier(a_bool.clone()), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::identifier(a_bool.clone()), + BooleanExpression::value(false), ) ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(true), ) ), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( Propagator::::default().fold_boolean_expression( - BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(false), + BooleanExpression::bitor( + BooleanExpression::value(false), + BooleanExpression::value(false), ) ), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } } diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index f991f77fd..ca882dfdd 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -2,15 +2,13 @@ use crate::reducer::ConstantDefinitions; use zokrates_ast::typed::{ - folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType, - BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id, - Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType, - TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration, - UBitwidth, UExpression, UExpressionInner, + folder::*, identifier::FrameIdentifier, CoreIdentifier, DeclarationConstant, Expr, Id, + Identifier, IdentifierExpression, IdentifierOrExpression, TypedExpression, TypedProgram, + TypedSymbolDeclaration, UExpression, UExpressionInner, }; use zokrates_field::Field; -use std::convert::{TryFrom, TryInto}; +use std::convert::TryFrom; pub struct ConstantsReader<'a, 'ast, T> { constants: &'a ConstantDefinitions<'ast, T>, @@ -44,7 +42,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { match self.constants.get(&c).cloned() { Some(e) => match UExpression::try_from(e).unwrap().into_inner() { - UExpressionInner::Value(v) => DeclarationConstant::Concrete(v as u32), + UExpressionInner::Value(v) => DeclarationConstant::Concrete(v.value as u32), _ => unreachable!(), }, None => DeclarationConstant::Constant(c), @@ -54,179 +52,31 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { } } - fn fold_field_expression( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + From>, + >( &mut self, - e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { - match e { - FieldElementExpression::Identifier(IdentifierExpression { + ty: &E::Ty, + e: IdentifierExpression<'ast, E>, + ) -> IdentifierOrExpression<'ast, T, E> { + match e.id { + Identifier { id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, }, - .. - }) => { + version, + } => { assert_eq!(version, 0); match self.constants.get(&c).cloned() { - Some(v) => v.try_into().unwrap(), - None => FieldElementExpression::identifier(Identifier::from( - CoreIdentifier::Constant(c), + Some(v) => IdentifierOrExpression::Expression(E::from(v).into_inner()), + None => IdentifierOrExpression::Identifier(IdentifierExpression::new( + CoreIdentifier::Constant(c).into(), )), } } - e => fold_field_expression(self, e), - } - } - - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { - match e { - BooleanExpression::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => v.try_into().unwrap(), - None => { - BooleanExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_boolean_expression(self, e), - } - } - - fn fold_uint_expression_inner( - &mut self, - ty: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { - match e { - UExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => UExpression::try_from(v).unwrap().into_inner(), - None => UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))), - } - } - e => fold_uint_expression_inner(self, ty, e), - } - } - - fn fold_array_expression_inner( - &mut self, - ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { - match e { - ArrayExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(), - None => { - ArrayExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_array_expression_inner(self, ty, e), - } - } - - fn fold_tuple_expression_inner( - &mut self, - ty: &TupleType<'ast, T>, - e: TupleExpressionInner<'ast, T>, - ) -> TupleExpressionInner<'ast, T> { - match e { - TupleExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => TupleExpression::try_from(v).unwrap().into_inner(), - None => { - TupleExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_tuple_expression_inner(self, ty, e), - } - } - - fn fold_struct_expression_inner( - &mut self, - ty: &StructType<'ast, T>, - e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { - match e { - StructExpressionInner::Identifier(IdentifierExpression { - id: - Identifier { - id: - FrameIdentifier { - id: CoreIdentifier::Constant(c), - frame: _, - }, - version, - }, - .. - }) => { - assert_eq!(version, 0); - match self.constants.get(&c).cloned() { - Some(v) => StructExpression::try_from(v).unwrap().into_inner(), - None => { - StructExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - } - } - } - e => fold_struct_expression_inner(self, ty, e), + _ => fold_identifier_expression(self, ty, e), } } } diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index 50a6d25ad..a96d0c8d5 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -3,18 +3,17 @@ use crate::reducer::{ constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error, }; -use std::collections::{BTreeMap, HashSet}; +use std::collections::HashSet; use zokrates_ast::typed::{ - result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol, - TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration, - UExpression, + result_folder::*, Constant, ModuleId, OwnedModuleId, Typed, TypedConstant, TypedConstantSymbol, + TypedConstantSymbolDeclaration, TypedProgram, TypedSymbolDeclaration, UExpression, }; use zokrates_field::Field; pub struct ConstantsWriter<'ast, T> { - treated: HashSet, + treated: HashSet, constants: ConstantDefinitions<'ast, T>, - location: OwnedTypedModuleId, + location: OwnedModuleId, program: TypedProgram<'ast, T>, } @@ -28,22 +27,19 @@ impl<'ast, T: Field> ConstantsWriter<'ast, T> { } } - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn change_location(&mut self, location: OwnedModuleId) -> OwnedModuleId { let prev = self.location.clone(); self.location = location; self.treated.insert(self.location.clone()); prev } - fn treated(&self, id: &TypedModuleId) -> bool { + fn treated(&self, id: &ModuleId) -> bool { self.treated.contains(id) } fn update_program(&mut self) { - let mut p = TypedProgram { - main: "".into(), - modules: BTreeMap::default(), - }; + let mut p = TypedProgram::default(); std::mem::swap(&mut self.program, &mut p); self.program = ConstantsReader::with_constants(&self.constants).read_into_program(p); } @@ -59,10 +55,7 @@ impl<'ast, T: Field> ConstantsWriter<'ast, T> { impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { type Error = Error; - fn fold_module_id( - &mut self, - id: OwnedTypedModuleId, - ) -> Result { + fn fold_module_id(&mut self, id: OwnedModuleId) -> Result { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); @@ -114,15 +107,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { // wrap this expression in a function let wrapper = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return(c.expression)], + statements: vec![TypedStatement::ret(c.expression)], signature: DeclarationSignature::new().output(c.ty.clone()), }; let mut inlined_wrapper = reduce_function(wrapper, &self.program)?; - if let TypedStatement::Return(expression) = - inlined_wrapper.statements.pop().unwrap() - { + if let TypedStatement::Return(ret) = inlined_wrapper.statements.pop().unwrap() { + let expression = ret.inner; + if !expression.is_constant() { return Err(Error::ConstantReduction(id.id.to_string(), id.module)); }; diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 8d3727d3d..254952498 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -91,7 +91,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .map(|g| { g.as_ref() .map(|g| match g.as_inner() { - UExpressionInner::Value(v) => Ok(*v as u32), + UExpressionInner::Value(v) => Ok(v.value as u32), _ => Err(()), }) .transpose() @@ -143,12 +143,12 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| { - TypedStatement::Definition( + TypedStatement::definition( TypedAssignee::Identifier(Variable::uint( CoreIdentifier::from(identifier), UBitwidth::B32, )), - TypedExpression::from(UExpression::from(value)).into(), + TypedExpression::from(UExpression::from(value)), ) }); @@ -156,7 +156,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .arguments .into_iter() .zip(inferred_signature.inputs.clone()) - .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) + .map(|(p, t)| ConcreteVariable::new(p.id.id, t)) .map(Variable::from) .collect(); @@ -167,7 +167,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(returns.len(), 1); let return_value = match returns.pop().unwrap() { - TypedStatement::Return(e) => e, + TypedStatement::Return(e) => e.inner, _ => unreachable!(), }; diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 826cd6663..c248cccf2 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -20,18 +20,24 @@ mod shallow_ssa; use self::inline::InlineValue; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; -use zokrates_ast::typed::result_folder::*; +use zokrates_ast::common::ResultFold; +use zokrates_ast::common::WithSpan; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; +use zokrates_ast::typed::SliceExpression; +use zokrates_ast::typed::SliceOrExpression; +use zokrates_ast::typed::{result_folder::*, ArrayExpression}; +use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; + use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{ - ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression, - TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, - TypedStatement, UExpression, UExpressionInner, + BlockExpression, CoreIdentifier, Expr, FunctionCall, FunctionCallExpression, + FunctionCallOrExpression, Id, TypedExpression, TypedFunction, TypedFunctionSymbol, + TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, TypedStatement, UExpression, + UExpressionInner, }; -use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; +use zokrates_ast::untyped::OwnedModuleId; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; @@ -53,7 +59,7 @@ pub enum Error { Incompatible(String), GenericsInMain, LoopTooLarge(u128), - ConstantReduction(String, OwnedTypedModuleId), + ConstantReduction(String, OwnedModuleId), NonConstant(String), Type(String), Propagation(propagation::Error), @@ -119,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ty: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, Self::Error> { - // generics are already in ssa form + let span = e.get_span(); + // generics are already in ssa form let generics: Vec<_> = e .generics .into_iter() @@ -166,7 +173,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) - .map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a)) + .map(|(v, a)| { + TypedStatement::definition(self.ssa.fold_assignee(v.into()), a).span(span) + }) .collect(); let input_bindings = input_bindings @@ -174,7 +183,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map(|s| self.propagator.fold_statement(s)) .collect::, _>>()? .into_iter() - .flatten(); + .flatten() + .collect::>(); self.statement_buffer.extend(input_bindings); @@ -183,7 +193,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() - .flatten(); + .flatten() + .collect::>(); self.statement_buffer.extend(statements); @@ -208,18 +219,20 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { Err(InlineError::Flat(embed, generics, output_type)) => { let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); - let var = Variable::immutable(identifier.clone(), output_type); + let var = Variable::new(identifier.clone(), output_type); let v: TypedAssignee<'ast, T> = var.clone().into(); - self.statement_buffer - .push(TypedStatement::embed_call_definition( + self.statement_buffer.push( + TypedStatement::embed_call_definition( v, EmbedCall::new(embed, generics, arguments), - )); - Ok(FunctionCallOrExpression::Expression(E::identifier( - identifier, - ))) + ) + .span(span), + ); + Ok(FunctionCallOrExpression::Expression( + E::identifier(identifier).span(span), + )) } }; @@ -228,7 +241,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { res } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, b: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { @@ -258,36 +271,39 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { unreachable!("canonical constant identifiers should not be folded, they should be inlined") } + // here we implement fold_statement and not fold_statement_cases because we do not want the span of the input + // to be applied to all the outputs: a statement which contains a call which gets inline should not hide the + // inlined statements behind its own span fn fold_statement( &mut self, s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { let res = match s { - TypedStatement::For(v, from, to, statements) => { - let from = self.ssa.fold_uint_expression(from); + TypedStatement::For(s) => { + let from = self.ssa.fold_uint_expression(s.from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; let from = self.propagator.fold_uint_expression(from)?; - let to = self.ssa.fold_uint_expression(to); + let to = self.ssa.fold_uint_expression(s.to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(from), UExpressionInner::Value(to)) - if to - from > MAX_FOR_LOOP_SIZE => + if to.value - from.value > MAX_FOR_LOOP_SIZE => { - Err(Error::LoopTooLarge(to.saturating_sub(*from))) + Err(Error::LoopTooLarge(to.value.saturating_sub(from.value))) } - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from - ..*to) + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((from.value + ..to.value) .flat_map(|index| { std::iter::once(TypedStatement::definition( - v.clone().into(), + s.var.clone().into(), UExpression::from(index as u32).into(), )) - .chain(statements.clone()) + .chain(s.statements.clone()) .map(|s| self.fold_statement(s)) .collect::>() }) @@ -330,39 +346,33 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { Ok(self.statement_buffer.drain(..).chain(res).collect()) } - fn fold_array_expression_inner( + fn fold_slice_expression( &mut self, - array_ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = self.ssa.fold_array_expression(array); - let array = self.propagator.fold_array_expression(array)?; - let array = self.fold_array_expression(array)?; - let array = self.propagator.fold_array_expression(array)?; - - let from = self.ssa.fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from)?; - let from = self.fold_uint_expression(from)?; - let from = self.propagator.fold_uint_expression(from)?; - - let to = self.ssa.fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to)?; - let to = self.fold_uint_expression(to)?; - let to = self.propagator.fold_uint_expression(to)?; - - match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { - Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - } - _ => Err(Error::NonConstant(format!( - "Slice bounds must be compile time constants, found {}", - ArrayExpressionInner::Slice(box array, box from, box to) - ))), - } - } - _ => fold_array_expression_inner(self, array_ty, e), + e: zokrates_ast::typed::SliceExpression<'ast, T>, + ) -> Result, Self::Error> { + let array = self.ssa.fold_array_expression(*e.array); + let array = self.propagator.fold_array_expression(array)?; + let array = self.fold_array_expression(array)?; + let array = self.propagator.fold_array_expression(array)?; + + let from = self.ssa.fold_uint_expression(*e.from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + + let to = self.ssa.fold_uint_expression(*e.to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(..), UExpressionInner::Value(..)) => Ok( + SliceOrExpression::Slice(SliceExpression::new(array, from, to)), + ), + _ => Err(Error::NonConstant(format!( + "Slice bounds must be compile time constants, found {}", + ArrayExpression::slice(array, from, to) + ))), } } } @@ -405,6 +415,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E )] .into_iter() .collect(), + ..p }) } _ => Err(Error::GenericsInMain), @@ -426,17 +437,16 @@ mod tests { use zokrates_ast::typed::types::DeclarationSignature; use zokrates_ast::typed::types::{DeclarationConstant, GTupleType}; use zokrates_ast::typed::{ - ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, - DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier, - OwnedTypedModuleId, Select, TupleExpressionInner, TupleType, Type, TypedExpression, - TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable, + ArrayExpression, ArrayType, DeclarationFunctionKey, DeclarationType, DeclarationVariable, + FieldElementExpression, GenericIdentifier, Identifier, Select, TupleExpression, TupleType, + Type, TypedExpression, TypedExpressionOrSpread, UBitwidth, Variable, }; use zokrates_field::Bn128Field; use lazy_static::lazy_static; lazy_static! { - static ref MAIN_MODULE_ID: OwnedTypedModuleId = OwnedTypedModuleId::from("main"); + static ref MAIN_MODULE_ID: OwnedModuleId = OwnedModuleId::from("main"); } #[test] @@ -463,7 +473,7 @@ mod tests { let foo: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( FieldElementExpression::identifier("a".into()).into(), )], signature: DeclarationSignature::new() @@ -507,7 +517,7 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) @@ -515,6 +525,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -562,7 +573,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(2)).into(), FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), ], @@ -572,6 +583,7 @@ mod tests { }; let expected: TypedProgram = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -636,9 +648,9 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), @@ -659,10 +671,10 @@ mod tests { ), TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) - .annotate(Type::FieldElement, 1u32) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -672,10 +684,10 @@ mod tests { .signature(foo_signature.clone()), vec![None], vec![ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -684,11 +696,11 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -700,6 +712,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -734,31 +747,31 @@ mod tests { statements: vec![ TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) - .annotate(Type::FieldElement, 1u32) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier(Identifier::from("a").in_frame(1)) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -770,6 +783,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -834,14 +848,16 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), }; + use std::ops::Sub; + let main: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ @@ -859,17 +875,16 @@ mod tests { Variable::array( "b", Type::FieldElement, - UExpressionInner::Sub( - box UExpression::identifier("n".into()).annotate(UBitwidth::B32), - box 1u32.into(), - ) - .annotate(UBitwidth::B32), + UExpression::sub( + UExpression::identifier("n".into()).annotate(UBitwidth::B32), + 1u32.into(), + ), ) .into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), - ) - .annotate(Type::FieldElement, 1u32) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -879,10 +894,10 @@ mod tests { .signature(foo_signature.clone()), vec![None], vec![ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( @@ -891,11 +906,11 @@ mod tests { .annotate(UBitwidth::B32) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -907,6 +922,7 @@ mod tests { }; let p = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -941,32 +957,31 @@ mod tests { statements: vec![ TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), - ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier(Identifier::from("a")).into()] - .into(), - ) - .annotate(Type::FieldElement, 1u32) + ArrayExpression::value(vec![ + FieldElementExpression::identifier("a".into()).into() + ]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier(Identifier::from("a").in_frame(1)) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), - TypedStatement::Return( + TypedStatement::ret( (FieldElementExpression::identifier("a".into()) + FieldElementExpression::select( ArrayExpression::identifier(Identifier::from("b").version(1)) - .annotate(Type::FieldElement, 1u32), + .annotate(ArrayType::new(Type::FieldElement, 1u32)), 0u32, )) .into(), @@ -978,6 +993,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -1021,10 +1037,6 @@ mod tests { // expected: // def main() { - // # PUSH CALL to foo::<1> - // # PUSH CALL to bar::<2> - // # POP CALL - // # POP CALL // return; // } @@ -1056,49 +1068,42 @@ mod tests { UExpression::identifier("K".into()).annotate(UBitwidth::B32), ) .into(), - ArrayExpressionInner::Slice( - box ArrayExpression::function_call( + ArrayExpression::slice( + ArrayExpression::function_call( DeclarationFunctionKey::with_location("main", "bar") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value( - vec![ - TypedExpressionOrSpread::Spread( - ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ] - .into(), - ) - .annotate( + vec![ArrayExpression::value(vec![ + TypedExpressionOrSpread::Spread( + ArrayExpression::identifier("a".into()) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) + .into(), + ), + FieldElementExpression::value(Bn128Field::from(0)).into(), + ]) + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32) + 1u32.into(), - ) + )) .into()], ) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32) + 1u32.into(), - ), - box 0u32.into(), - box UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) - .annotate( - Type::FieldElement, + )), + 0u32.into(), UExpression::identifier("K".into()).annotate(UBitwidth::B32), ) .into(), ), - TypedStatement::Return( + TypedStatement::ret( ArrayExpression::identifier("ret".into()) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) + )) .into(), ), ], @@ -1114,12 +1119,12 @@ mod tests { DeclarationConstant::Generic(GenericIdentifier::with_name("K").with_index(0)), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate( + .annotate(ArrayType::new( Type::FieldElement, UExpression::identifier("K".into()).annotate(UBitwidth::B32), - ) + )) .into(), )], signature: bar_signature.clone(), @@ -1134,17 +1139,18 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value( - vec![FieldElementExpression::Number(Bn128Field::from(1)).into()].into(), + vec![ArrayExpression::value(vec![FieldElementExpression::value( + Bn128Field::from(1), ) - .annotate(Type::FieldElement, 1u32) + .into()]) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -1180,14 +1186,15 @@ mod tests { )] .into_iter() .collect(), + module_map: Default::default(), }; let reduced = reduce_program(p); let expected_main = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + statements: vec![TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), )], @@ -1195,6 +1202,7 @@ mod tests { }; let expected = TypedProgram { + module_map: Default::default(), main: "main".into(), modules: vec![( "main".into(), @@ -1247,9 +1255,9 @@ mod tests { GenericIdentifier::with_name("K").with_index(0), ) .into()], - statements: vec![TypedStatement::Return( + statements: vec![TypedStatement::ret( ArrayExpression::identifier("a".into()) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), )], signature: foo_signature.clone(), @@ -1264,15 +1272,15 @@ mod tests { DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), vec![None], - vec![ArrayExpressionInner::Value(vec![].into()) - .annotate(Type::FieldElement, 0u32) + vec![ArrayExpression::value(vec![]) + .annotate(ArrayType::new(Type::FieldElement, 0u32)) .into()], ) - .annotate(Type::FieldElement, 1u32) + .annotate(ArrayType::new(Type::FieldElement, 1u32)) .into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -1308,6 +1316,7 @@ mod tests { )] .into_iter() .collect(), + module_map: Default::default(), }; let reduced = reduce_program(p); diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index aaec4fd56..46f9cd962 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -103,15 +103,15 @@ impl<'ast> ShallowTransformer<'ast> { ) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), - TypedAssignee::Select(box a, box index) => TypedAssignee::Select( - box self.fold_assignee_no_ssa_increase(a), - box self.fold_uint_expression(index), + TypedAssignee::Select(a, index) => TypedAssignee::select( + self.fold_assignee_no_ssa_increase(*a), + self.fold_uint_expression(*index), ), - TypedAssignee::Member(box s, m) => { - TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m) + TypedAssignee::Member(s, m) => { + TypedAssignee::member(self.fold_assignee_no_ssa_increase(*s), m) } - TypedAssignee::Element(box s, index) => { - TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index) + TypedAssignee::Element(s, index) => { + TypedAssignee::element(self.fold_assignee_no_ssa_increase(*s), index) } } } @@ -155,16 +155,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { } } - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - match s { - // only fold bounds of for loop statements - TypedStatement::For(v, from, to, stats) => { - let from = self.fold_uint_expression(from); - let to = self.fold_uint_expression(to); - vec![TypedStatement::For(v, from, to, stats)] - } - s => fold_statement(self, s), - } + // only fold bounds of for loop statements + fn fold_for_statement(&mut self, s: ForStatement<'ast, T>) -> Vec> { + let from = self.fold_uint_expression(s.from); + let to = self.fold_uint_expression(s.to); + vec![TypedStatement::for_(s.var, from, to, s.statements)] } // retrieve the latest version @@ -185,6 +180,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { #[cfg(test)] mod tests { use super::*; + use std::ops::*; use zokrates_ast::typed::types::DeclarationSignature; use zokrates_field::Bn128Field; mod normal { @@ -205,7 +201,7 @@ mod tests { statements: vec![ TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), TypedStatement::definition( TypedAssignee::Identifier(Variable::uint( @@ -214,19 +210,19 @@ mod tests { )), UExpression::from(0u32).into(), ), - TypedStatement::For( - Variable::new("i", Type::Uint(UBitwidth::B32), false), + TypedStatement::for_( + Variable::new("i", Type::Uint(UBitwidth::B32)), UExpression::identifier("i".into()).annotate(UBitwidth::B32), 2u32.into(), vec![TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element(Identifier::from( "foo", ))), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), )], ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -254,7 +250,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), ); assert_eq!( u.fold_statement(s), @@ -262,13 +258,13 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(5)).into() + FieldElementExpression::value(Bn128Field::from(5)).into() )] ); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(6)).into(), + FieldElementExpression::value(Bn128Field::from(6)).into(), ); assert_eq!( u.fold_statement(s), @@ -276,7 +272,7 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(1) )), - FieldElementExpression::Number(Bn128Field::from(6)).into() + FieldElementExpression::value(Bn128Field::from(6)).into() )] ); @@ -301,7 +297,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), ); assert_eq!( u.fold_statement(s), @@ -309,15 +305,15 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(5)).into() + FieldElementExpression::value(Bn128Field::from(5)).into() )] ); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), ) .into(), ); @@ -327,9 +323,9 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(1) )), - FieldElementExpression::Add( - box FieldElementExpression::identifier(Identifier::from("a").version(0)), - box FieldElementExpression::Number(Bn128Field::from(1)) + FieldElementExpression::add( + FieldElementExpression::identifier(Identifier::from("a").version(0)), + FieldElementExpression::value(Bn128Field::from(1)) ) .into() )] @@ -350,7 +346,7 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ); assert_eq!( u.fold_statement(s), @@ -358,7 +354,7 @@ mod tests { TypedAssignee::Identifier(Variable::field_element( Identifier::from("a").version(0) )), - FieldElementExpression::Number(Bn128Field::from(2)).into() + FieldElementExpression::value(Bn128Field::from(2)).into() )] ); @@ -409,14 +405,11 @@ mod tests { let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ); @@ -428,24 +421,25 @@ mod tests { Type::FieldElement, 2u32 )), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into() - ] - .into() - ) - .annotate(Type::FieldElement, 2u32) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into() + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into() )] ); let s: TypedStatement = TypedStatement::definition( TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), - box UExpression::from(1u32), + Box::new(TypedAssignee::Identifier(Variable::array( + "a", + Type::FieldElement, + 2u32, + ))), + Box::new(UExpression::from(1u32)), ), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ); assert_eq!(u.fold_statement(s.clone()), vec![s]); @@ -465,31 +459,25 @@ mod tests { let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); let s = TypedStatement::definition( - TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone(), true)), - ArrayExpressionInner::Value( - vec![ - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ] + TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), + ArrayExpression::value(vec![ + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(0)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), - ) - .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) + .into(), + ]) + .annotate(ArrayType::new( + Type::array((Type::FieldElement, 2u32)), + 2u32, + )) .into(), ); @@ -499,53 +487,39 @@ mod tests { TypedAssignee::Identifier(Variable::new( Identifier::from("a").version(0), array_of_array_ty.clone(), - true, )), - ArrayExpressionInner::Value( - vec![ - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ] - .into() - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ] - .into() - ) - .annotate(Type::FieldElement, 2u32) - .into(), - ] - .into() - ) - .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + ArrayExpression::value(vec![ + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(0)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) + .into(), + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(3)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) + .into(), + ]) + .annotate(ArrayType::new( + Type::array((Type::FieldElement, 2u32)), + 2u32 + )) .into(), )] ); let s: TypedStatement = TypedStatement::definition( - TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::new( - "a", - array_of_array_ty.clone(), - true, - )), - box UExpression::from(1u32), + TypedAssignee::select( + TypedAssignee::Identifier(Variable::new("a", array_of_array_ty.clone())), + UExpression::from(1u32), ), - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(4)).into(), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ] - .into(), - ) - .annotate(Type::FieldElement, 2u32) + ArrayExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(5)).into(), + ]) + .annotate(ArrayType::new(Type::FieldElement, 2u32)) .into(), ); @@ -606,7 +580,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32) @@ -620,7 +594,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32), UExpression::identifier("n".into()).annotate(UBitwidth::B32) @@ -634,7 +608,7 @@ mod tests { Variable::field_element("a").into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) @@ -660,7 +634,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(1)).into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), @@ -677,7 +651,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(2)).into(), FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), - TypedStatement::For( + TypedStatement::for_( Variable::uint("i", UBitwidth::B32), UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), @@ -694,7 +668,7 @@ mod tests { Variable::field_element(Identifier::from("a").version(3)).into(), FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], @@ -720,6 +694,7 @@ mod tests { } mod shadowing { + use super::*; #[test] @@ -751,10 +726,10 @@ mod tests { 1, ))) .into(), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -775,10 +750,10 @@ mod tests { 1, ))) .into(), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) + TypedStatement::ret( + TupleExpression::value(vec![]) .annotate(TupleType::new(vec![])) .into(), ), @@ -860,7 +835,7 @@ mod tests { )) .into(), ), - TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), + TypedStatement::ret(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() .generics(vec![Some( @@ -924,7 +899,7 @@ mod tests { )) .into(), ), - TypedStatement::Return( + TypedStatement::ret( FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], diff --git a/zokrates_analysis/src/struct_concretizer.rs b/zokrates_analysis/src/struct_concretizer.rs index 932a92670..6e57c0484 100644 --- a/zokrates_analysis/src/struct_concretizer.rs +++ b/zokrates_analysis/src/struct_concretizer.rs @@ -70,7 +70,7 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast, T> { .collect(), generics: concrete_generics .into_iter() - .map(|g| Some(DeclarationConstant::Concrete(g as u32))) + .map(|g| Some(DeclarationConstant::Concrete(g))) .collect(), ..ty } @@ -82,9 +82,9 @@ impl<'ast, T: Field> Folder<'ast, T> for StructConcretizer<'ast, T> { ) -> DeclarationArrayType<'ast, T> { let size = ty.size.map_concrete(&self.generics).unwrap(); - DeclarationArrayType { - size: box DeclarationConstant::Concrete(size), - ty: box self.fold_declaration_type(*ty.ty), - } + DeclarationArrayType::new( + self.fold_declaration_type(*ty.ty), + DeclarationConstant::Concrete(size), + ) } } diff --git a/zokrates_analysis/src/uint_optimizer.rs b/zokrates_analysis/src/uint_optimizer.rs index ac96dbfe1..c006cdd67 100644 --- a/zokrates_analysis/src/uint_optimizer.rs +++ b/zokrates_analysis/src/uint_optimizer.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::ops::{BitAnd, Shl, Shr}; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::{FlatEmbed, Fold, WithSpan}; use zokrates_ast::zir::folder::*; use zokrates_ast::zir::*; use zokrates_field::Field; @@ -55,7 +55,7 @@ fn force_no_reduce(e: UExpression) -> UExpression { } impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { - fn fold_select_expression + Fold<'ast, T> + Select<'ast, T>>( + fn fold_select_expression + Fold + Select<'ast, T>>( &mut self, _: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -66,43 +66,45 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { SelectOrExpression::Select(SelectExpression::new(array, force_reduce(index))) } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::UintEq(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintEq(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintEq(box left, box right) + BooleanExpression::uint_eq(left, right) } - BooleanExpression::UintLt(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintLt(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintLt(box left, box right) + BooleanExpression::uint_lt(left, right) } - BooleanExpression::UintLe(box left, box right) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + BooleanExpression::UintLe(e) => { + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left = force_reduce(left); let right = force_reduce(right); - BooleanExpression::UintLe(box left, box right) + BooleanExpression::uint_le(left, right) } - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } fn fold_uint_expression(&mut self, e: UExpression<'ast, T>) -> UExpression<'ast, T> { + let span = e.get_span(); + if e.metadata.is_some() { return e; } @@ -120,7 +122,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { use self::UExpressionInner::*; let res = match inner { - Value(v) => Value(v).annotate(range).with_max(v), + Value(v) => Value(v.clone()).annotate(range).with_max(v.value), Identifier(id) => Identifier(id.clone()).annotate(range).metadata( self.ids .get(&Variable::uint(id.id.clone(), range)) @@ -151,10 +153,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::select(values, index).with_max(max_value) } - Add(box left, box right) => { + Add(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_max = right.metadata.clone().unwrap().max; @@ -170,7 +172,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + range_max)) + .unwrap_or_else(|| (true, true, range_max + range_max)) }) }); @@ -187,7 +189,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::add(left, right).with_max(max) } - Sub(box left, box right) => { + Sub(e) => { // let `target` the target bitwidth of `left` and `right` // `0 <= left <= max_left` // `0 <= right <= max_right` @@ -205,8 +207,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { // smaller or equal to N for target in {8, 16, 32} // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_bitwidth = right.metadata.clone().unwrap().bitwidth(); @@ -223,7 +225,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_add(&target_offset) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() + target_offset)) + .unwrap_or_else(|| (true, true, range_max + target_offset)) } else { left_max .checked_add(&offset) @@ -254,31 +256,31 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::sub(left, right).with_max(max) } - Xor(box left, box right) => { + Xor(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::xor(force_reduce(left), force_reduce(right)).with_max(range_max) } - And(box left, box right) => { + And(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::and(force_reduce(left), force_reduce(right)).with_max(range_max) } - Or(box left, box right) => { + Or(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::or(force_reduce(left), force_reduce(right)).with_max(range_max) } - Mult(box left, box right) => { + Mult(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); let left_max = left.metadata.clone().unwrap().max; let right_max = right.metadata.clone().unwrap().max; @@ -294,7 +296,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { left_max .checked_mul(&range_max.clone()) .map(|max| (false, true, max)) - .unwrap_or_else(|| (true, true, range_max.clone() * range_max)) + .unwrap_or_else(|| (true, true, range_max * range_max)) }) }); @@ -311,52 +313,72 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { UExpression::mult(left, right).with_max(max) } - Div(box left, box right) => { + Div(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::div(force_reduce(left), force_reduce(right)).with_max(range_max) } - Rem(box left, box right) => { + Rem(e) => { // reduce the two terms - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); UExpression::rem(force_reduce(left), force_reduce(right)).with_max(range_max) } - Not(box e) => { - let e = self.fold_uint_expression(e); + Not(e) => { + let inner = self.fold_uint_expression(*e.inner); - UExpressionInner::Not(box force_reduce(e)) - .annotate(range) - .with_max(range_max) + UExpression::not(force_reduce(inner)).with_max(range_max) } - LeftShift(box e, by) => { + LeftShift(e) => { // reduce both terms - let e = self.fold_uint_expression(e); - - let e_max: num_bigint::BigUint = e.metadata.as_ref().unwrap().max.to_biguint(); - let max = e_max - .shl(by as usize) - .bitand(&(2_u128.pow(range as u32) - 1).into()); - - let max = T::try_from(max).unwrap(); - - UExpression::left_shift(force_reduce(e), by).with_max(max) + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); + + match right.into_inner() { + UExpressionInner::Value(by) => { + let e_max: num_bigint::BigUint = + left.metadata.as_ref().unwrap().max.to_biguint(); + let max = e_max + .shl(by.value as usize) + .bitand(&(2_u128.pow(range as u32) - 1).into()); + + let max = T::try_from(max).unwrap(); + + UExpression::left_shift( + force_reduce(left), + UExpression::value(by.value).annotate(UBitwidth::B32), + ) + .with_max(max) + } + _ => unreachable!(), + } } - RightShift(box e, by) => { + RightShift(e) => { // reduce both terms - let e = self.fold_uint_expression(e); - - let e_max: num_bigint::BigUint = e.metadata.as_ref().unwrap().max.to_biguint(); - let max = e_max - .bitand(&(2_u128.pow(range as u32) - 1).into()) - .shr(by as usize); - - let max = T::try_from(max).unwrap(); - - UExpression::right_shift(force_reduce(e), by).with_max(max) + let left = self.fold_uint_expression(*e.left); + let right = self.fold_uint_expression(*e.right); + + match right.into_inner() { + UExpressionInner::Value(by) => { + let e_max: num_bigint::BigUint = + left.metadata.as_ref().unwrap().max.to_biguint(); + let max = e_max + .bitand(&(2_u128.pow(range as u32) - 1).into()) + .shr(by.value as usize); + + let max = T::try_from(max).unwrap(); + + UExpression::right_shift( + force_reduce(left), + UExpression::value(by.value).annotate(UBitwidth::B32), + ) + .with_max(max) + } + _ => unreachable!(), + } } Conditional(e) => { let condition = self.fold_boolean_expression(*e.condition); @@ -379,44 +401,52 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { assert!(res.metadata.is_some()); - res + res.span(span) } - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { - match s { - ZirStatement::Definition(a, e) => { - let e = self.fold_expression(e); + fn fold_return_statement(&mut self, s: ReturnStatement<'ast, T>) -> Vec> { + // we need to put back in range to return + vec![ZirStatement::ret( + s.inner + .into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + + let e = force_reduce(e); - let e = match e { - ZirExpression::Uint(i) => { - let i = force_no_reduce(i); - self.register(a.clone(), i.metadata.clone().unwrap()); - ZirExpression::Uint(i) + ZirExpression::Uint(e) } - e => e, - }; - vec![ZirStatement::Definition(a, e)] + e => self.fold_expression(e), + }) + .collect(), + )] + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + let e = self.fold_expression(s.rhs); + + let e = match e { + ZirExpression::Uint(i) => { + let i = force_no_reduce(i); + self.register(s.assignee.clone(), i.metadata.clone().unwrap()); + ZirExpression::Uint(i) } - // we need to put back in range to return - ZirStatement::Return(expressions) => vec![ZirStatement::Return( - expressions - .into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - let e = self.fold_uint_expression(e); - - let e = force_reduce(e); - - ZirExpression::Uint(e) - } - e => self.fold_expression(e), - }) - .collect(), - )], - ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - ) => { + e => e, + }; + vec![ZirStatement::definition(s.assignee, e)] + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + let lhs = s.assignees; + match s.rhs { + ZirExpressionList::EmbedCall(embed, generics, arguments) => { match embed { FlatEmbed::U64FromBits => { assert_eq!(lhs.len(), 1); @@ -467,74 +497,68 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { | FlatEmbed::U32ToBits | FlatEmbed::U64ToBits => { vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall( - embed, - generics, - arguments - .into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - let e = self.fold_uint_expression(e); - let e = force_reduce(e); - ZirExpression::Uint(e) - } - e => self.fold_expression(e), - }) - .collect(), + MultipleDefinitionStatement::new( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + let e = force_reduce(e); + ZirExpression::Uint(e) + } + e => self.fold_expression(e), + }) + .collect(), + ), ), )] } _ => { vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall( - embed, - generics, - arguments - .into_iter() - .map(|e| self.fold_expression(e)) - .collect(), + MultipleDefinitionStatement::new( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| self.fold_expression(e)) + .collect(), + ), ), )] } } } - ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right), metadata) => { - let left = self.fold_uint_expression(left); - let right = self.fold_uint_expression(right); - - // we can only compare two unsigned integers if they are in range - let left = force_reduce(left); - let right = force_reduce(right); - - vec![ZirStatement::Assertion( - BooleanExpression::UintEq(box left, box right), - metadata, - )] - } - ZirStatement::Log(l, e) => vec![ZirStatement::Log( - l, - e.into_iter() - .map(|(t, e)| { - ( - t, - e.into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - force_reduce(self.fold_uint_expression(e)).into() - } - e => self.fold_expression(e), - }) - .collect(), - ) - }) - .collect(), - )], - s => fold_statement(self, s), } } + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + vec![ZirStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| { + ( + t, + e.into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + force_reduce(self.fold_uint_expression(e)).into() + } + e => self.fold_expression(e), + }) + .collect(), + ) + }) + .collect(), + ))] + } + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { let id = match p.id.get_type() { Type::Uint(bitwidth) => { @@ -730,8 +754,8 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::right_shift(left.clone(), right)), - UExpression::right_shift(left_expected, right_expected).with_max(output_max) + .fold_uint_expression(UExpression::right_shift(left.clone(), right.into())), + UExpression::right_shift(left_expected, right_expected.into()).with_max(output_max) ); } @@ -753,8 +777,8 @@ mod tests { assert_eq!( UintOptimizer::new() - .fold_uint_expression(UExpression::left_shift(left.clone(), right)), - UExpression::left_shift(left_expected, right_expected).with_max(output_max) + .fold_uint_expression(UExpression::left_shift(left.clone(), right.into())), + UExpression::left_shift(left_expected, right_expected.into()).with_max(output_max) ); } @@ -777,7 +801,7 @@ mod tests { assert_eq!( UintOptimizer::new() .fold_uint_expression(UExpression::conditional( - BooleanExpression::Value(true), + BooleanExpression::value(true), consequence, alternative )) diff --git a/zokrates_analysis/src/variable_write_remover.rs b/zokrates_analysis/src/variable_write_remover.rs index b218acc33..2d9ce8685 100644 --- a/zokrates_analysis/src/variable_write_remover.rs +++ b/zokrates_analysis/src/variable_write_remover.rs @@ -6,6 +6,7 @@ use std::collections::HashSet; use std::fmt; +use zokrates_ast::common::{Span, WithSpan}; use zokrates_ast::typed::result_folder::ResultFolder; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::{MemberId, Type}; @@ -38,6 +39,7 @@ impl<'ast> VariableWriteRemover { indices: Vec>, new_expression: TypedExpression<'ast, T>, statements: &mut HashSet>, + span: Option, ) -> TypedExpression<'ast, T> { let mut indices = indices; @@ -54,26 +56,37 @@ impl<'ast> VariableWriteRemover { let tail = indices; match head { - Access::Select(box head) => { - statements.insert(TypedStatement::Assertion( - BooleanExpression::UintLt(box head.clone(), box size.into()), + Access::Select(head) => { + statements.insert(TypedStatement::assertion( + BooleanExpression::uint_lt( + head.clone(), + UExpression::from(size).span(span), + ) + .span(span), RuntimeError::SelectRangeCheck, )); - ArrayExpressionInner::Value( + ArrayExpression::value( (0..size) .map(|i| match inner_ty { Type::Int => unreachable!(), Type::Array(..) => ArrayExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - ArrayExpression::select(base.clone(), i).into(), + ArrayExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Array(e) => e, e => unreachable!( @@ -81,20 +94,32 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - ArrayExpression::select(base.clone(), i), + ArrayExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Struct(..) => StructExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - StructExpression::select(base.clone(), i).into(), + StructExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Struct(e) => e, e => unreachable!( @@ -102,20 +127,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - StructExpression::select(base.clone(), i), + StructExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Tuple(..) => TupleExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - TupleExpression::select(base.clone(), i).into(), + TupleExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Tuple(e) => e, e => unreachable!( @@ -123,21 +159,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - TupleExpression::select(base.clone(), i), + TupleExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::FieldElement => FieldElementExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - FieldElementExpression::select(base.clone(), i) - .into(), + FieldElementExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::FieldElement(e) => e, e => unreachable!( @@ -145,20 +191,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - FieldElementExpression::select(base.clone(), i), + FieldElementExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Boolean => BooleanExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - BooleanExpression::select(base.clone(), i).into(), + BooleanExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Boolean(e) => e, e => unreachable!( @@ -166,20 +223,31 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - BooleanExpression::select(base.clone(), i), + BooleanExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), Type::Uint(..) => UExpression::conditional( - BooleanExpression::UintEq(EqExpression::new( - i.into(), + BooleanExpression::uint_eq( + UExpression::from(i).span(span), head.clone(), - )), + ) + .span(span), match Self::choose_many( - UExpression::select(base.clone(), i).into(), + UExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span) + .into(), tail.clone(), new_expression.clone(), statements, + span, ) { TypedExpression::Uint(e) => e, e => unreachable!( @@ -187,15 +255,18 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - UExpression::select(base.clone(), i), + UExpression::select( + base.clone(), + UExpression::from(i).span(span), + ) + .span(span), ConditionalKind::IfElse, ) .into(), }) - .collect::>() - .into(), + .collect::>(), ) - .annotate(inner_ty.clone(), size) + .annotate(ArrayType::new(inner_ty.clone(), size)) .into() } _ => unreachable!(), @@ -228,6 +299,7 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { FieldElementExpression::member(base.clone(), member.id) @@ -242,9 +314,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - UExpression::member(base.clone(), member.id).into() + UExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Boolean => { @@ -258,6 +333,7 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { BooleanExpression::member(base.clone(), member.id) @@ -272,9 +348,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - ArrayExpression::member(base.clone(), member.id).into() + ArrayExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Struct(..) => { @@ -288,9 +367,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - StructExpression::member(base.clone(), member.id).into() + StructExpression::member(base.clone(), member.id) + .span(span) + .into() } } Type::Tuple(..) => { @@ -301,9 +383,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - TupleExpression::member(base.clone(), member.id).into() + TupleExpression::member(base.clone(), member.id) + .span(span) + .into() } } }) @@ -341,9 +426,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - FieldElementExpression::element(base.clone(), i).into() + FieldElementExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Uint(..) => { @@ -353,9 +441,10 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - UExpression::element(base.clone(), i).into() + UExpression::element(base.clone(), i).span(span).into() } } Type::Boolean => { @@ -366,9 +455,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - BooleanExpression::element(base.clone(), i).into() + BooleanExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Array(..) => { @@ -378,9 +470,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - ArrayExpression::element(base.clone(), i).into() + ArrayExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Struct(..) => { @@ -391,9 +486,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - StructExpression::element(base.clone(), i).into() + StructExpression::element(base.clone(), i) + .span(span) + .into() } } Type::Tuple(..) => { @@ -403,9 +501,12 @@ impl<'ast> VariableWriteRemover { tail.clone(), new_expression.clone(), statements, + span, ) } else { - TupleExpression::element(base.clone(), i).into() + TupleExpression::element(base.clone(), i) + .span(span) + .into() } } }) @@ -424,27 +525,28 @@ impl<'ast> VariableWriteRemover { #[derive(Clone, Debug)] enum Access<'ast, T: Field> { - Select(Box>), + Select(UExpression<'ast, T>), Member(MemberId), Element(u32), } + /// Turn an assignee into its representation as a base variable and a list accesses /// a[2][3][4] -> (a, [2, 3, 4]) fn linear(a: TypedAssignee) -> (Variable, Vec>) { match a { TypedAssignee::Identifier(v) => (v, vec![]), - TypedAssignee::Select(box array, box index) => { - let (v, mut indices) = linear(array); - indices.push(Access::Select(box index)); + TypedAssignee::Select(array, index) => { + let (v, mut indices) = linear(*array); + indices.push(Access::Select(*index)); (v, indices) } - TypedAssignee::Member(box s, m) => { - let (v, mut indices) = linear(s); + TypedAssignee::Member(s, m) => { + let (v, mut indices) = linear(*s); indices.push(Access::Member(m)); (v, indices) } - TypedAssignee::Element(box s, i) => { - let (v, mut indices) = linear(s); + TypedAssignee::Element(s, i) => { + let (v, mut indices) = linear(*s); indices.push(Access::Element(i)); (v, indices) } @@ -454,51 +556,58 @@ fn linear(a: TypedAssignee) -> (Variable, Vec>) { fn is_constant(assignee: &TypedAssignee) -> bool { match assignee { TypedAssignee::Identifier(_) => true, - TypedAssignee::Select(box assignee, box index) => match index.as_inner() { + TypedAssignee::Select(assignee, index) => match index.as_inner() { UExpressionInner::Value(_) => is_constant(assignee), _ => false, }, - TypedAssignee::Member(box assignee, _) => is_constant(assignee), - TypedAssignee::Element(box assignee, _) => is_constant(assignee), + TypedAssignee::Member(ref assignee, _) => is_constant(assignee), + TypedAssignee::Element(ref assignee, _) => is_constant(assignee), } } impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { type Error = Error; - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: TypedAssemblyStatement<'ast, T>, + s: AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedAssemblyStatement::Assignment(a, e) if is_constant(&a) => { - Ok(vec![TypedAssemblyStatement::Assignment(a, e)]) - } - TypedAssemblyStatement::Assignment(a, _) => Err(Error(format!( + match is_constant(&s.assignee) { + true => Ok(vec![TypedAssemblyStatement::Assignment(s)]), + false => Err(Error(format!( "Cannot assign to an assignee with a variable index `{}`", - a + s.assignee ))), - s => Ok(vec![s]), } } - fn fold_statement( + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + Ok(vec![TypedAssemblyStatement::Constraint(s)]) + } + + fn fold_definition_statement( &mut self, - s: TypedStatement<'ast, T>, + s: DefinitionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => { + let span = s.get_span(); + + match s.rhs { + DefinitionRhs::Expression(expr) => { + let a = s.assignee; let expr = self.fold_expression(expr)?; - if is_constant(&assignee) { - Ok(vec![TypedStatement::definition(assignee, expr)]) + if is_constant(&a) { + Ok(vec![TypedStatement::definition(a, expr).span(span)]) } else { // Note: here we redefine the whole object, ideally we would only redefine some of it // Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]` - let (variable, indices) = linear(assignee); + let (variable, indices) = linear(a); - let base = match variable.get_type() { + let base: TypedExpression<'ast, T> = match variable.get_type() { Type::Int => unreachable!(), Type::FieldElement => { FieldElementExpression::identifier(variable.id.clone()).into() @@ -508,7 +617,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { .annotate(bitwidth) .into(), Type::Array(array_type) => ArrayExpression::identifier(variable.id.clone()) - .annotate(*array_type.ty, *array_type.size) + .annotate(array_type) .into(), Type::Struct(members) => StructExpression::identifier(variable.id.clone()) .annotate(members) @@ -518,31 +627,34 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for VariableWriteRemover { .into(), }; + let base = base.span(span); + let base = self.fold_expression(base)?; let indices = indices .into_iter() .map(|a| match a { - Access::Select(box i) => { - Ok(Access::Select(box self.fold_uint_expression(i)?)) - } + Access::Select(i) => Ok(Access::Select(self.fold_uint_expression(i)?)), a => Ok(a), }) .collect::>()?; let mut range_checks = HashSet::new(); - let e = Self::choose_many(base, indices, expr, &mut range_checks); + let e = Self::choose_many(base, indices, expr, &mut range_checks, span); Ok(range_checks .into_iter() - .chain(std::iter::once(TypedStatement::definition( - TypedAssignee::Identifier(variable), - e, - ))) + .chain(std::iter::once( + TypedStatement::definition( + TypedAssignee::Identifier(variable.span(span)), + e, + ) + .span(span), + )) .collect()) } } - s => fold_statement(self, s), + _ => fold_definition_statement(self, s), } } } diff --git a/zokrates_analysis/src/zir_propagation.rs b/zokrates_analysis/src/zir_propagation.rs index 8fdaaf73c..0d3fb1c73 100644 --- a/zokrates_analysis/src/zir_propagation.rs +++ b/zokrates_analysis/src/zir_propagation.rs @@ -2,8 +2,11 @@ use num::traits::Pow; use num_bigint::BigUint; use std::collections::HashMap; use std::fmt; -use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub}; +use std::ops::*; +use zokrates_ast::common::{ResultFold, WithSpan}; use zokrates_ast::zir::types::UBitwidth; +use zokrates_ast::zir::AssertionStatement; +use zokrates_ast::zir::IfElseStatement; use zokrates_ast::zir::{ result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Constant, Expr, Id, IdentifierExpression, IdentifierOrExpression, SelectExpression, SelectOrExpression, @@ -79,158 +82,158 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { }) } - fn fold_assembly_statement( + fn fold_assembly_assignment( &mut self, - s: ZirAssemblyStatement<'ast, T>, + s: zokrates_ast::zir::AssemblyAssignment<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees: Vec<_> = assignees - .into_iter() - .map(|a| self.fold_assignee(a)) - .collect::>()?; - - let function = self.fold_function(function)?; - - match &function.statements.last().unwrap() { - ZirStatement::Return(values) => { - if values.iter().all(|v| v.is_constant()) { - self.constants.extend( - assignees - .into_iter() - .zip(values.iter()) - .map(|(a, v)| (a.id, v.clone())), - ); - Ok(vec![]) - } else { - assignees.iter().for_each(|a| { - self.constants.remove(&a.id); - }); - Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) - } - } - _ => { - assignees.iter().for_each(|a| { - self.constants.remove(&a.id); - }); - Ok(vec![ZirAssemblyStatement::Assignment(assignees, function)]) - } + let assignees: Vec<_> = s + .assignee + .into_iter() + .map(|a| self.fold_assignee(a)) + .collect::>()?; + + let function = self.fold_function(s.expression)?; + + match &function.statements.last().unwrap() { + ZirStatement::Return(s) => { + if s.inner.iter().all(|v| v.is_constant()) { + self.constants.extend( + assignees + .into_iter() + .zip(s.inner.iter()) + .map(|(a, v)| (a.id, v.clone())), + ); + Ok(vec![]) + } else { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::assignment(assignees, function)]) } } - ZirAssemblyStatement::Constraint(left, right, metadata) => { - let left = self.fold_field_expression(left)?; - let right = self.fold_field_expression(right)?; - - // a bit hacky, but we use a fake boolean expression to check this - let is_equal = BooleanExpression::FieldEq(box left.clone(), box right.clone()); - let is_equal = self.fold_boolean_expression(is_equal)?; - - match is_equal { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => { - Err(Error::AssertionFailed(RuntimeError::SourceAssertion( - metadata - .message(Some(format!("In asm block: `{} !== {}`", left, right))), - ))) - } - _ => Ok(vec![ZirAssemblyStatement::Constraint( - left, right, metadata, - )]), - } + _ => { + assignees.iter().for_each(|a| { + self.constants.remove(&a.id); + }); + Ok(vec![ZirAssemblyStatement::assignment(assignees, function)]) } } } - fn fold_statement( + fn fold_assembly_constraint( &mut self, - s: ZirStatement<'ast, T>, + s: zokrates_ast::zir::AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + let left = self.fold_field_expression(s.left)?; + let right = self.fold_field_expression(s.right)?; + + // a bit hacky, but we use a fake boolean expression to check this + let is_equal = BooleanExpression::field_eq(left.clone(), right.clone()); + let is_equal = self.fold_boolean_expression(is_equal)?; + + match is_equal { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => { + Err(Error::AssertionFailed(RuntimeError::SourceAssertion( + s.metadata + .message(Some(format!("In asm block: `{} !== {}`", left, right))), + ))) + } + _ => Ok(vec![ZirAssemblyStatement::constraint( + left, right, s.metadata, + )]), + } + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, ) -> Result>, Self::Error> { - match s { - ZirStatement::Assertion(e, error) => match self.fold_boolean_expression(e)? { - BooleanExpression::Value(true) => Ok(vec![]), - BooleanExpression::Value(false) => Err(Error::AssertionFailed(error)), - e => Ok(vec![ZirStatement::Assertion(e, error)]), - }, - ZirStatement::Definition(a, e) => { - let e = self.fold_expression(e)?; - match e { - ZirExpression::FieldElement(FieldElementExpression::Number(..)) - | ZirExpression::Boolean(BooleanExpression::Value(..)) - | ZirExpression::Uint(UExpression { - inner: UExpressionInner::Value(..), - .. - }) => { - self.constants.insert(a.id, e); - Ok(vec![]) - } - _ => { - self.constants.remove(&a.id); - Ok(vec![ZirStatement::Definition(a, e)]) - } - } + match self.fold_boolean_expression(s.expression)? { + BooleanExpression::Value(v) if v.value => Ok(vec![]), + BooleanExpression::Value(v) if !v.value => Err(Error::AssertionFailed(s.error)), + e => Ok(vec![ZirStatement::assertion(e, s.error)]), + } + } + + fn fold_definition_statement( + &mut self, + s: zokrates_ast::zir::DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + let e = self.fold_expression(s.rhs)?; + match e { + ZirExpression::FieldElement(FieldElementExpression::Value(..)) + | ZirExpression::Boolean(BooleanExpression::Value(..)) + | ZirExpression::Uint(UExpression { + inner: UExpressionInner::Value(..), + .. + }) => { + self.constants.insert(s.assignee.id, e); + Ok(vec![]) } - ZirStatement::IfElse(e, consequence, alternative) => { - match self.fold_boolean_expression(e)? { - BooleanExpression::Value(true) => Ok(consequence + _ => { + self.constants.remove(&s.assignee.id); + Ok(vec![ZirStatement::definition(s.assignee, e)]) + } + } + } + + fn fold_if_else_statement( + &mut self, + s: zokrates_ast::zir::IfElseStatement<'ast, T>, + ) -> Result>, Self::Error> { + { + match self.fold_boolean_expression(s.condition)? { + BooleanExpression::Value(v) if v.value => Ok(s + .consequence + .into_iter() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + BooleanExpression::Value(v) if !v.value => Ok(s + .alternative + .into_iter() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + e => Ok(vec![ZirStatement::IfElse(IfElseStatement::new( + e, + s.consequence .into_iter() .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() - .collect()), - BooleanExpression::Value(false) => Ok(alternative + .collect(), + s.alternative .into_iter() .map(|s| self.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() - .collect()), - e => Ok(vec![ZirStatement::IfElse( - e, - consequence - .into_iter() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - alternative - .into_iter() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - )]), - } - } - ZirStatement::MultipleDefinition(assignees, list) => { - for a in &assignees { - self.constants.remove(&a.id); - } - Ok(vec![ZirStatement::MultipleDefinition( - assignees, - self.fold_expression_list(list)?, - )]) + .collect(), + ))]), } - ZirStatement::Assembly(statements) => { - let statements: Vec<_> = statements - .into_iter() - .map(|s| self.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - match statements.len() { - 0 => Ok(vec![]), - _ => Ok(vec![ZirStatement::Assembly(statements)]), - } - } - _ => fold_statement(self, s), } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_multiple_definition_statement( + &mut self, + s: zokrates_ast::zir::MultipleDefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + for a in &s.assignees { + self.constants.remove(&a.id); + } + fold_multiple_definition_statement(self, s) + } + + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, _: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -241,169 +244,175 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { } } - fn fold_field_expression( + fn fold_field_expression_cases( &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Self::Error> { match e { - FieldElementExpression::Number(n) => Ok(FieldElementExpression::Number(n)), - FieldElementExpression::Add(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(0) => + FieldElementExpression::Value(n) => Ok(FieldElementExpression::Value(n)), + FieldElementExpression::Add(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { - Ok(e) + e } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 + n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value + n2.value) } - (e1, e2) => Ok(FieldElementExpression::Add(box e1, box e2)), + (e1, e2) => FieldElementExpression::add(e1, e2), } + .span(e.span)) } - FieldElementExpression::Sub(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (e, FieldElementExpression::Number(n)) if n == T::from(0) => Ok(e), - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 - n2)) + FieldElementExpression::Sub(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (e, FieldElementExpression::Value(n)) if n.value == T::from(0) => e, + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value - n2.value) } - (e1, e2) => Ok(FieldElementExpression::Sub(box e1, box e2)), + (e1, e2) => FieldElementExpression::sub(e1, e2), } + .span(e.span)) } - FieldElementExpression::Mult(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (FieldElementExpression::Number(n), _) - | (_, FieldElementExpression::Number(n)) - if n == T::from(0) => + FieldElementExpression::Mult(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + Ok(match (left, right) { + (FieldElementExpression::Value(n), _) + | (_, FieldElementExpression::Value(n)) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(T::from(0))) + FieldElementExpression::number(T::from(0)) } - (FieldElementExpression::Number(n), e) - | (e, FieldElementExpression::Number(n)) - if n == T::from(1) => + (FieldElementExpression::Value(n), e) + | (e, FieldElementExpression::Value(n)) + if n.value == T::from(1) => { - Ok(e) + e } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 * n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + FieldElementExpression::number(n1.value * n2.value) } - (e1, e2) => Ok(FieldElementExpression::Mult(box e1, box e2)), + (e1, e2) => FieldElementExpression::mul(e1, e2), } + .span(e.span)) } - FieldElementExpression::Div(box e1, box e2) => { - match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, - ) { - (_, FieldElementExpression::Number(n)) if n == T::from(0) => { + FieldElementExpression::Div(e) => { + let left = self.fold_field_expression(*e.left)?; + let right = self.fold_field_expression(*e.right)?; + + match (left, right) { + (_, FieldElementExpression::Value(n)) if n.value == T::from(0) => { Err(Error::DivisionByZero) } - (e, FieldElementExpression::Number(n)) if n == T::from(1) => Ok(e), - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number(n1 / n2)) + (e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::number(n1.value / n2.value).span(e.span)) } - (e1, e2) => Ok(FieldElementExpression::Div(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::div(e1, e2).span(e.span)), } } - FieldElementExpression::Pow(box e, box exponent) => { - let exponent = self.fold_uint_expression(exponent)?; - match (self.fold_field_expression(e)?, exponent.into_inner()) { - (_, UExpressionInner::Value(n2)) if n2 == 0 => { - Ok(FieldElementExpression::Number(T::from(1))) + FieldElementExpression::Pow(e) => { + let exponent = self.fold_uint_expression(*e.right)?; + match (self.fold_field_expression(*e.left)?, exponent.into_inner()) { + (_, UExpressionInner::Value(n2)) if n2.value == 0 => { + Ok(FieldElementExpression::number(T::from(1))) } - (e, UExpressionInner::Value(n2)) if n2 == 1 => Ok(e), - (FieldElementExpression::Number(n), UExpressionInner::Value(e)) => { - Ok(FieldElementExpression::Number(n.pow(e as usize))) + (e, UExpressionInner::Value(n2)) if n2.value == 1 => Ok(e), + (FieldElementExpression::Value(n), UExpressionInner::Value(e)) => Ok( + FieldElementExpression::number(n.value.pow(e.value as usize)), + ), + (e, exp) => { + Ok(FieldElementExpression::pow(e, exp.annotate(UBitwidth::B32)) + .into_inner()) } - (e, exp) => Ok(FieldElementExpression::Pow( - box e, - box exp.annotate(UBitwidth::B32), - )), } } - FieldElementExpression::Xor(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Xor(e) => { + let e1 = self.fold_field_expression(*e.right)?; + let e2 = self.fold_field_expression(*e.left)?; match (e1, e2) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitxor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), - (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), + (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::value(T::from(0))), + (e1, e2) => Ok(FieldElementExpression::bitxor(e1, e2)), } } - FieldElementExpression::And(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::And(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (_, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), _) - if n == T::from(0) => + (_, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), _) + if n.value == T::from(0) => { - Ok(FieldElementExpression::Number(n)) + Ok(FieldElementExpression::Value(n)) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitand(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitand(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::And(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitand(e1, e2)), } } - FieldElementExpression::Or(box e1, box e2) => { - let e1 = self.fold_field_expression(e1)?; - let e2 = self.fold_field_expression(e2)?; + FieldElementExpression::Or(e) => { + let e1 = self.fold_field_expression(*e.left)?; + let e2 = self.fold_field_expression(*e.right)?; match (e1, e2) { - (e, FieldElementExpression::Number(n)) - | (FieldElementExpression::Number(n), e) - if n == T::from(0) => + (e, FieldElementExpression::Value(n)) + | (FieldElementExpression::Value(n), e) + if n.value == T::from(0) => { Ok(e) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(FieldElementExpression::Number( - T::try_from(n1.to_biguint().bitor(n2.to_biguint())).unwrap(), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(FieldElementExpression::value( + T::try_from(n1.value.to_biguint().bitor(n2.value.to_biguint())) + .unwrap(), )) } - (e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)), + (e1, e2) => Ok(FieldElementExpression::bitor(e1, e2)), } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::LeftShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. @@ -412,197 +421,206 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { let two = BigUint::from(2usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); - Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(), + Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shl(by.value as usize).bitand(mask)) + .unwrap(), )) } - (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), + (expr, by) => Ok(FieldElementExpression::left_shift(expr, by)), } } - FieldElementExpression::RightShift(box e, box by) => { - let e = self.fold_field_expression(e)?; - let by = self.fold_uint_expression(by)?; - match (e, by) { + FieldElementExpression::RightShift(e) => { + let expr = self.fold_field_expression(*e.left)?; + let by = self.fold_uint_expression(*e.right)?; + match (expr, by) { ( e, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by == 0 => Ok(e), + ) if by.value == 0 => Ok(e), ( _, UExpression { inner: UExpressionInner::Value(by), .. }, - ) if by as usize >= T::get_required_bits() => { - Ok(FieldElementExpression::Number(T::from(0))) + ) if by.value as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::value(T::from(0))) } ( - FieldElementExpression::Number(n), + FieldElementExpression::Value(n), UExpression { inner: UExpressionInner::Value(by), .. }, - ) => Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shr(by as usize)).unwrap(), + ) => Ok(FieldElementExpression::value( + T::try_from(n.value.to_biguint().shr(by.value as usize)).unwrap(), )), - (e, by) => Ok(FieldElementExpression::RightShift(box e, box by)), + (expr, by) => Ok(FieldElementExpression::right_shift(expr, by)), } } - e => fold_field_expression(self, e), + e => fold_field_expression_cases(self, e), } } - fn fold_boolean_expression( + fn fold_boolean_expression_cases( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Error> { match e { BooleanExpression::Value(v) => Ok(BooleanExpression::Value(v)), - BooleanExpression::FieldLt(box e1, box e2) => { + BooleanExpression::FieldLt(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 < n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1 < n2)) } - (_, FieldElementExpression::Number(c)) if c == T::zero() => { - Ok(BooleanExpression::Value(false)) + (_, FieldElementExpression::Value(c)) if c.value == T::zero() => { + Ok(BooleanExpression::value(false)) } - (FieldElementExpression::Number(c), _) if c == T::max_value() => { - Ok(BooleanExpression::Value(false)) + (FieldElementExpression::Value(c), _) if c.value == T::max_value() => { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::FieldLt(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_lt(e1, e2)), } } - BooleanExpression::FieldLe(box e1, box e2) => { + BooleanExpression::FieldLe(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - Ok(BooleanExpression::Value(n1 <= n2)) + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { + Ok(BooleanExpression::value(n1 <= n2)) } - (e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::field_le(e1, e2)), } } - BooleanExpression::FieldEq(box e1, box e2) => { + BooleanExpression::FieldEq(e) => { match ( - self.fold_field_expression(e1)?, - self.fold_field_expression(e2)?, + self.fold_field_expression(*e.left)?, + self.fold_field_expression(*e.right)?, ) { - (FieldElementExpression::Number(v1), FieldElementExpression::Number(v2)) => { - Ok(BooleanExpression::Value(v1.eq(&v2))) + (FieldElementExpression::Value(v1), FieldElementExpression::Value(v2)) => { + Ok(BooleanExpression::value(v1.eq(&v2))) } (e1, e2) => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::FieldEq(box e1, box e2)) + Ok(BooleanExpression::field_eq(e1, e2)) } } } } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLt(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 < v2)) + Ok(BooleanExpression::value(v1 < v2)) } - _ => Ok(BooleanExpression::UintLt(box e1, box e2)), + _ => Ok(BooleanExpression::uint_lt(e1, e2)), } } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintLe(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 <= v2)) + Ok(BooleanExpression::value(v1 <= v2)) } - _ => Ok(BooleanExpression::UintLe(box e1, box e2)), + _ => Ok(BooleanExpression::uint_le(e1, e2)), } } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + BooleanExpression::UintEq(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - Ok(BooleanExpression::Value(v1 == v2)) + Ok(BooleanExpression::value(v1 == v2)) } _ => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::UintEq(box e1, box e2)) + Ok(BooleanExpression::uint_eq(e1, e2)) } } } } - BooleanExpression::BoolEq(box e1, box e2) => { + BooleanExpression::BoolEq(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 == v2)) + Ok(BooleanExpression::value(v1 == v2)) } (e1, e2) => { if e1.eq(&e2) { - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) } else { - Ok(BooleanExpression::BoolEq(box e1, box e2)) + Ok(BooleanExpression::bool_eq(e1, e2)) } } } } - BooleanExpression::Or(box e1, box e2) => { + BooleanExpression::Or(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - Ok(BooleanExpression::Value(v1 || v2)) + Ok(BooleanExpression::value(v1.value || v2.value)) } - (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { - Ok(BooleanExpression::Value(true)) + (_, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), _) + if v.value => + { + Ok(BooleanExpression::value(true)) } - (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { + (e, BooleanExpression::Value(v)) | (BooleanExpression::Value(v), e) + if !v.value => + { Ok(e) } - (e1, e2) => Ok(BooleanExpression::Or(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitor(e1, e2)), } } - BooleanExpression::And(box e1, box e2) => { + BooleanExpression::And(e) => { match ( - self.fold_boolean_expression(e1)?, - self.fold_boolean_expression(e2)?, + self.fold_boolean_expression(*e.left)?, + self.fold_boolean_expression(*e.right)?, ) { - (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => { + (BooleanExpression::Value(v), e) | (e, BooleanExpression::Value(v)) + if v.value => + { Ok(e) } - (BooleanExpression::Value(false), _) | (_, BooleanExpression::Value(false)) => { - Ok(BooleanExpression::Value(false)) + (BooleanExpression::Value(v), _) | (_, BooleanExpression::Value(v)) + if !v.value => + { + Ok(BooleanExpression::value(false)) } - (e1, e2) => Ok(BooleanExpression::And(box e1, box e2)), + (e1, e2) => Ok(BooleanExpression::bitand(e1, e2)), } } - BooleanExpression::Not(box e) => match self.fold_boolean_expression(e)? { - BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)), - e => Ok(BooleanExpression::Not(box e)), + BooleanExpression::Not(e) => match self.fold_boolean_expression(*e.inner)? { + BooleanExpression::Value(v) => Ok(BooleanExpression::value(!v.value)), + e => Ok(BooleanExpression::not(e)), }, - e => fold_boolean_expression(self, e), + e => fold_boolean_expression_cases(self, e), } } fn fold_select_expression< - E: Clone + Expr<'ast, T> + ResultFold<'ast, T> + zokrates_ast::zir::Select<'ast, T>, + E: Clone + Expr<'ast, T> + ResultFold + zokrates_ast::zir::Select<'ast, T>, >( &mut self, _: &E::Ty, @@ -617,9 +635,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { match index.as_inner() { UExpressionInner::Value(v) => array - .get(*v as usize) + .get(v.value as usize) .cloned() - .ok_or(Error::OutOfBounds(*v as usize, array.len())) + .ok_or(Error::OutOfBounds(v.value as usize, array.len())) .map(|e| SelectOrExpression::Expression(e.into_inner())), _ => Ok(SelectOrExpression::Expression( E::select(array, index).into_inner(), @@ -627,185 +645,221 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { } } - fn fold_uint_expression_inner( + fn fold_uint_expression_cases( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Self::Error> { match e { UExpressionInner::Value(v) => Ok(UExpressionInner::Value(v)), - UExpressionInner::Add(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Add(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(0), e) | (e, UExpressionInner::Value(0)) => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 + n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(v), e) | (e, UExpressionInner::Value(v)) + if v.value == 0 => + { + Ok(e) + } + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value + n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::add(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Add( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Sub(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Sub(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) => Ok(e), + (e, UExpressionInner::Value(v)) if v.value == 0 => Ok(e), (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value( - n1.wrapping_sub(n2) % 2_u128.pow(bitwidth.to_usize() as u32), + Ok(UExpression::value( + n1.value.wrapping_sub(n2.value) + % 2_u128.pow(bitwidth.to_usize() as u32), )) } - (e1, e2) => Ok(UExpressionInner::Sub( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::sub(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::Mult(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Mult(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - Ok(UExpressionInner::Value(0)) + (_, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), _) + if v.value == 0 => + { + Ok(UExpression::value(0)) + } + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) + if v.value == 1 => + { + Ok(e) } - (e, UExpressionInner::Value(1)) | (UExpressionInner::Value(1), e) => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 * n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value * n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::mult(e1.annotate(bitwidth), e2.annotate(bitwidth)) + .into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Mult( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Div(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Div(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (_, UExpressionInner::Value(n)) if n == 0 => Err(Error::DivisionByZero), - (e, UExpressionInner::Value(n)) if n == 1 => Ok(e), - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (_, UExpressionInner::Value(n)) if n.value == 0 => Err(Error::DivisionByZero), + (e, UExpressionInner::Value(n)) if n.value == 1 => Ok(e), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value / n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::div(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Div( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Rem(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Rem(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( - UExpressionInner::Value((n1 % n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + Ok(UExpression::value( + (n1.value % n2.value) % 2_u128.pow(bitwidth.to_usize() as u32), + )) + } + (e1, e2) => Ok( + UExpression::rem(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), ), - (e1, e2) => Ok(UExpressionInner::Rem( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), } } - UExpressionInner::Xor(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Xor(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 ^ n2)) + Ok(UExpression::value(n1.value ^ n2.value)) } - (e1, e2) if e1.eq(&e2) => Ok(UExpressionInner::Value(0)), - (e1, e2) => Ok(UExpressionInner::Xor( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) if e1.eq(&e2) => Ok(UExpression::value(0)), + (e1, e2) => Ok( + UExpression::xor(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::And(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::And(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { (e, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), e) - if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + if n.value == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { Ok(e) } - (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - Ok(UExpressionInner::Value(0)) + (_, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), _) + if v.value == 0 => + { + Ok(UExpression::value(0)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 & n2)) + Ok(UExpression::value(n1.value & n2.value)) } - (e1, e2) => Ok(UExpressionInner::And( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::Or(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1)?; - let e2 = self.fold_uint_expression(e2)?; + UExpressionInner::Or(e) => { + let e1 = self.fold_uint_expression(*e.left)?; + let e2 = self.fold_uint_expression(*e.right)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), e) => Ok(e), + (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) + if v.value == 0 => + { + Ok(e) + } (_, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), _) - if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + if n.value == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { - Ok(UExpressionInner::Value(n)) + Ok(UExpression::value(n.value)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - Ok(UExpressionInner::Value(n1 | n2)) + Ok(UExpression::value(n1.value | n2.value)) } - (e1, e2) => Ok(UExpressionInner::Or( - box e1.annotate(bitwidth), - box e2.annotate(bitwidth), - )), + (e1, e2) => Ok( + UExpression::or(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner(), + ), } } - UExpressionInner::LeftShift(box e, by) => { - let e = self.fold_uint_expression(e)?; - match (e.into_inner(), by) { - (e, 0) => Ok(e), - (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), - (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value( - (n << by) & (2_u128.pow(bitwidth as u32) - 1), - )), - (e, by) => Ok(UExpressionInner::LeftShift(box e.annotate(bitwidth), by)), + UExpressionInner::LeftShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (e, UExpressionInner::Value(by)) if by.value == 0 => Ok(e), + (_, UExpressionInner::Value(by)) if by.value as u32 >= bitwidth as u32 => { + Ok(UExpression::value(0)) + } + (UExpressionInner::Value(n), UExpressionInner::Value(by)) => { + Ok(UExpression::value( + (n.value << by.value) & (2_u128.pow(bitwidth as u32) - 1), + )) + } + (e, by) => Ok(UExpression::left_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::RightShift(box e, by) => { - let e = self.fold_uint_expression(e)?; - match (e.into_inner(), by) { - (e, 0) => Ok(e), - (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), - (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value(n >> by)), - (e, by) => Ok(UExpressionInner::RightShift(box e.annotate(bitwidth), by)), + UExpressionInner::RightShift(e) => { + let left = self.fold_uint_expression(*e.left)?; + let right = self.fold_uint_expression(*e.right)?; + match (left.into_inner(), right.into_inner()) { + (e, UExpressionInner::Value(by)) if by.value == 0 => Ok(e), + (_, UExpressionInner::Value(by)) if by.value as u32 >= bitwidth as u32 => { + Ok(UExpression::value(0)) + } + (UExpressionInner::Value(n), UExpressionInner::Value(by)) => { + Ok(UExpression::value(n.value >> by.value)) + } + (e, by) => Ok(UExpression::right_shift( + e.annotate(bitwidth), + by.annotate(UBitwidth::B32), + ) + .into_inner()), } } - UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e)?; + UExpressionInner::Not(e) => { + let e = self.fold_uint_expression(*e.inner)?; match e.into_inner() { - UExpressionInner::Value(n) => Ok(UExpressionInner::Value( - !n & (2_u128.pow(bitwidth as u32) - 1), + UExpressionInner::Value(n) => Ok(UExpression::value( + !n.value & (2_u128.pow(bitwidth as u32) - 1), )), - e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))), + e => Ok(UExpression::not(e.annotate(bitwidth)).into_inner()), } } - e => fold_uint_expression_inner(self, bitwidth, e), + e => fold_uint_expression_cases(self, bitwidth, e), } } fn fold_conditional_expression< - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, >( &mut self, _: &E::Ty, @@ -814,10 +868,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { let condition = self.fold_boolean_expression(*e.condition)?; match condition { - BooleanExpression::Value(true) => Ok(ConditionalOrExpression::Expression( + BooleanExpression::Value(v) if v.value => Ok(ConditionalOrExpression::Expression( e.consequence.fold(self)?.into_inner(), )), - BooleanExpression::Value(false) => Ok(ConditionalOrExpression::Expression( + BooleanExpression::Value(v) if !v.value => Ok(ConditionalOrExpression::Expression( e.alternative.fold(self)?.into_inner(), )), condition => { @@ -847,15 +901,15 @@ mod tests { #[test] fn propagation() { // assert([x, 1] == [y, 1]) - let statements = vec![ZirStatement::Assertion( - BooleanExpression::And( - box BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + let statements = vec![ZirStatement::assertion( + BooleanExpression::bitand( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(1)), ), ), RuntimeError::mock(), @@ -873,10 +927,10 @@ mod tests { assert_eq!( statements, - vec![ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + vec![ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), RuntimeError::mock() )] @@ -897,21 +951,21 @@ mod tests { assert_eq!( propagator.fold_field_expression(FieldElementExpression::select( vec![ - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::select( vec![ - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), Err(Error::OutOfBounds(3, 2)) ); @@ -922,18 +976,18 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_field_expression(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(5))) + Ok(FieldElementExpression::value(Bn128Field::from(5))) ); // a + 0 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Add( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::add( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -944,18 +998,18 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Sub( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::sub( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); // a - 0 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Sub( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::sub( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -966,27 +1020,27 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(6))) + Ok(FieldElementExpression::value(Bn128Field::from(6))) ); // a * 0 = 0 assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); // a * 1 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + propagator.fold_field_expression(FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -997,25 +1051,25 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(6)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(3))) + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Div( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), Err(Error::DivisionByZero) ); @@ -1026,27 +1080,27 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value(2).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(9))) + Ok(FieldElementExpression::value(Bn128Field::from(9))) ); // a ** 0 = 1 assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); // a ** 1 = a assert_eq!( - propagator.fold_field_expression(FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + UExpression::value(1).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); @@ -1057,44 +1111,44 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(2_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(8))) + Ok(FieldElementExpression::value(Bn128Field::from(8))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + UExpression::value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + Ok(FieldElementExpression::value(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::LeftShift( - box FieldElementExpression::Number(Bn128Field::from(1)), - box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128) + propagator.fold_field_expression(FieldElementExpression::left_shift( + FieldElementExpression::value(Bn128Field::from(1)), + UExpression::value((Bn128Field::get_required_bits()) as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1103,61 +1157,61 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(0).annotate(UBitwidth::B32), )), Ok(FieldElementExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::identifier("a".into()), - box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::identifier("a".into()), + UExpression::value(Bn128Field::get_required_bits() as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(3)), - box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(3)), + UExpression::value(1_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(2_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::from(2)), - box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32), + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::from(2)), + UExpression::value(4_u128).annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + UExpression::value((Bn128Field::get_required_bits() - 1) as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( - propagator.fold_field_expression(FieldElementExpression::RightShift( - box FieldElementExpression::Number(Bn128Field::max_value()), - box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + propagator.fold_field_expression(FieldElementExpression::right_shift( + FieldElementExpression::value(Bn128Field::max_value()), + UExpression::value(Bn128Field::get_required_bits() as u128) .annotate(UBitwidth::B32), )), - Ok(FieldElementExpression::Number(Bn128Field::from(0))) + Ok(FieldElementExpression::value(Bn128Field::from(0))) ); } @@ -1167,29 +1221,29 @@ mod tests { assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( - BooleanExpression::Value(true), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(true), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(1))) + Ok(FieldElementExpression::value(Bn128Field::from(1))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( - BooleanExpression::Value(false), - FieldElementExpression::Number(Bn128Field::from(1)), - FieldElementExpression::Number(Bn128Field::from(2)), + BooleanExpression::value(false), + FieldElementExpression::value(Bn128Field::from(1)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); assert_eq!( propagator.fold_field_expression(FieldElementExpression::conditional( BooleanExpression::identifier("a".into()), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(FieldElementExpression::Number(Bn128Field::from(2))) + Ok(FieldElementExpression::value(Bn128Field::from(2))) ); } } @@ -1208,21 +1262,21 @@ mod tests { assert_eq!( propagator.fold_boolean_expression(BooleanExpression::select( vec![ - BooleanExpression::Value(false), - BooleanExpression::Value(true), + BooleanExpression::value(false), + BooleanExpression::value(true) ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::select( vec![ - BooleanExpression::Value(false), - BooleanExpression::Value(true), + BooleanExpression::value(false), + BooleanExpression::value(true) ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), Err(Error::OutOfBounds(3, 2)) ); @@ -1233,35 +1287,35 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(0)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::max_value()), - box FieldElementExpression::identifier("a".into()), + propagator.fold_boolean_expression(BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::max_value()), + FieldElementExpression::identifier("a".into()), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1270,19 +1324,19 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(3)), + propagator.fold_boolean_expression(BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(3)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1291,19 +1345,19 @@ mod tests { let mut propagator = ZirPropagator::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_boolean_expression(BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(3)), - box FieldElementExpression::Number(Bn128Field::from(2)), + propagator.fold_boolean_expression(BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(3)), + FieldElementExpression::value(Bn128Field::from(2)), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1312,19 +1366,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLt( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_lt( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLt( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_lt( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1333,19 +1387,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLe( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_le( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintLe( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_le( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1354,19 +1408,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintEq( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_eq( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::UintEq( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + propagator.fold_boolean_expression(BooleanExpression::uint_eq( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1375,19 +1429,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::BoolEq( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bool_eq( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::BoolEq( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bool_eq( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1396,35 +1450,35 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::identifier("a".into()), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::identifier("a".into()), + BooleanExpression::value(true) )), Ok(BooleanExpression::identifier("a".into())) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::And( - box BooleanExpression::identifier("a".into()), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitand( + BooleanExpression::identifier("a".into()), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); } @@ -1433,19 +1487,19 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::bitor( + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1454,17 +1508,17 @@ mod tests { let mut propagator = ZirPropagator::::default(); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Not( - box BooleanExpression::Value(true), + propagator.fold_boolean_expression(BooleanExpression::not( + BooleanExpression::value(true) )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( - propagator.fold_boolean_expression(BooleanExpression::Not( - box BooleanExpression::Value(false), + propagator.fold_boolean_expression(BooleanExpression::not( + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } @@ -1474,29 +1528,29 @@ mod tests { assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( - BooleanExpression::Value(true), - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(true), + BooleanExpression::value(true), + BooleanExpression::value(false) )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( - BooleanExpression::Value(false), - BooleanExpression::Value(true), - BooleanExpression::Value(false) + BooleanExpression::value(false), + BooleanExpression::value(true), + BooleanExpression::value(false), )), - Ok(BooleanExpression::Value(false)) + Ok(BooleanExpression::value(false)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::conditional( BooleanExpression::identifier("a".into()), - BooleanExpression::Value(true), - BooleanExpression::Value(true) + BooleanExpression::value(true), + BooleanExpression::value(true), )), - Ok(BooleanExpression::Value(true)) + Ok(BooleanExpression::value(true)) ); } } @@ -1516,14 +1570,14 @@ mod tests { UBitwidth::B32, UExpression::select( vec![ - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ], - UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( @@ -1531,10 +1585,10 @@ mod tests { UBitwidth::B32, UExpression::select( vec![ - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ], - UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) .into_inner() ), @@ -1549,22 +1603,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Add( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::add( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(5)) + Ok(UExpression::value(5)) ); // a + 0 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Add( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::add( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1577,22 +1633,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Sub( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::sub( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); // a - 0 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Sub( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::sub( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1605,22 +1663,24 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(6)) + Ok(UExpression::value(6)) ); // a * 1 = a assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1629,12 +1689,13 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Mult( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::mult( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1645,21 +1706,23 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpressionInner::Value(6).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::div( + UExpression::value(6).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(3)) + Ok(UExpression::value(3)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpression::div( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(1).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1667,10 +1730,11 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Div( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::div( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Err(Error::DivisionByZero) ); @@ -1683,23 +1747,25 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Rem( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::rem( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Rem( - box UExpressionInner::Value(3).annotate(UBitwidth::B32), - box UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::rem( + UExpression::value(3).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); } @@ -1710,23 +1776,25 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Xor( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::xor( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Xor( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::xor( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::identifier("a".into()).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1737,32 +1805,35 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::and( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::and( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::And( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + UExpression::and( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(u32::MAX as u128).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1775,21 +1846,23 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - box UExpressionInner::Value(3).annotate(UBitwidth::B32), + UExpression::or( + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(3).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(3)) + Ok(UExpression::value(3)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(0).annotate(UBitwidth::B32), + UExpression::or( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(0).annotate(UBitwidth::B32), ) + .into_inner() ), Ok(UExpression::identifier("a".into())) ); @@ -1797,12 +1870,13 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Or( - box UExpression::identifier("a".into()).annotate(UBitwidth::B32), - box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + UExpression::or( + UExpression::identifier("a".into()).annotate(UBitwidth::B32), + UExpression::value(u32::MAX as u128).annotate(UBitwidth::B32), ) + .into_inner() ), - Ok(UExpressionInner::Value(u32::MAX as u128)) + Ok(UExpression::value(u32::MAX as u128)) ); } @@ -1813,34 +1887,37 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 3, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 3.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(16)) + Ok(UExpression::value(16)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 0, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 0.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::LeftShift( - box UExpressionInner::Value(2).annotate(UBitwidth::B32), - 32, + UExpression::left_shift( + UExpression::value(2).annotate(UBitwidth::B32), + 32.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1851,34 +1928,37 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 2, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 2.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 0, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 0.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(4)) + Ok(UExpression::value(4)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::RightShift( - box UExpressionInner::Value(4).annotate(UBitwidth::B32), - 32, + UExpression::right_shift( + UExpression::value(4).annotate(UBitwidth::B32), + 32.into(), ) + .into_inner() ), - Ok(UExpressionInner::Value(0)) + Ok(UExpression::value(0)) ); } @@ -1889,9 +1969,9 @@ mod tests { assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, - UExpressionInner::Not(box UExpressionInner::Value(2).annotate(UBitwidth::B32),) + UExpression::not(UExpression::value(2).annotate(UBitwidth::B32),).into_inner() ), - Ok(UExpressionInner::Value(4294967293)) + Ok(UExpression::value(4294967293)) ); } @@ -1903,26 +1983,26 @@ mod tests { propagator.fold_uint_expression_inner( UBitwidth::B32, UExpression::conditional( - BooleanExpression::Value(true), - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + BooleanExpression::value(true), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(1)) + Ok(UExpression::value(1)) ); assert_eq!( propagator.fold_uint_expression_inner( UBitwidth::B32, UExpression::conditional( - BooleanExpression::Value(false), - UExpressionInner::Value(1).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + BooleanExpression::value(false), + UExpression::value(1).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); assert_eq!( @@ -1930,12 +2010,12 @@ mod tests { UBitwidth::B32, UExpression::conditional( BooleanExpression::identifier("a".into()), - UExpressionInner::Value(2).annotate(UBitwidth::B32), - UExpressionInner::Value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), + UExpression::value(2).annotate(UBitwidth::B32), ) .into_inner() ), - Ok(UExpressionInner::Value(2)) + Ok(UExpression::value(2)) ); } } diff --git a/zokrates_ark/Cargo.toml b/zokrates_ark/Cargo.toml index c31dc8d3e..8aec37846 100644 --- a/zokrates_ark/Cargo.toml +++ b/zokrates_ark/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ark" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] @@ -28,7 +28,7 @@ ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = f ark-bw6-761 = { version = "^0.3.0", default-features = false } ark-gm17 = { version = "^0.3.0", default-features = false } ark-groth16 = { version = "^0.3.0", default-features = false } -ark-serialize = { version = "^0.3.0", default-features = false } +ark-serialize = { version = "^0.3.0", default-features = false, features = ["std"] } ark-relations = { version = "^0.3.0", default-features = false } ark-marlin = { git = "https://github.com/arkworks-rs/marlin", rev = "63cfd82", default-features = false } ark-poly = { version = "^0.3.0", default-features = false } diff --git a/zokrates_ark/src/gm17.rs b/zokrates_ark/src/gm17.rs index a8941ed03..4148c1361 100644 --- a/zokrates_ark/src/gm17.rs +++ b/zokrates_ark/src/gm17.rs @@ -40,11 +40,16 @@ impl NonUniversalBackend for Ark { } impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: std::io::Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -54,10 +59,9 @@ impl Backend for Ark { .map(parse_fr::) .collect::>(); - let pk = ProvingKey::<::ArkEngine>::deserialize_unchecked( - &mut proving_key.as_slice(), - ) - .unwrap(); + let pk = + ProvingKey::<::ArkEngine>::deserialize_unchecked(proving_key) + .unwrap(); let proof = ArkGM17::::prove(&pk, computation, rng).unwrap(); let proof_points = ProofPoints { @@ -120,9 +124,15 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -135,7 +145,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -145,9 +158,15 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -158,8 +177,12 @@ mod tests { .execute(program.clone(), &[Bw6_761Field::from(42)]) .unwrap(); - let proof = - >::generate_proof(program, witness, keypair.pk, rng); + let proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); let ans = >::verify(keypair.vk, proof); assert!(ans); diff --git a/zokrates_ark/src/groth16.rs b/zokrates_ark/src/groth16.rs index e9281210a..c49c8cab1 100644 --- a/zokrates_ark/src/groth16.rs +++ b/zokrates_ark/src/groth16.rs @@ -4,6 +4,7 @@ use ark_groth16::{ ProvingKey, VerifyingKey, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::io::Read; use zokrates_field::ArkFieldExtensions; use zokrates_field::Field; use zokrates_proof_systems::{Backend, NonUniversalBackend, Proof, SetupKeypair}; @@ -17,11 +18,16 @@ use zokrates_proof_systems::groth16::{ProofPoints, VerificationKey, G16}; use zokrates_proof_systems::Scheme; impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -31,12 +37,12 @@ impl Backend for Ark { .map(parse_fr::) .collect::>(); - let pk = ProvingKey::<::ArkEngine>::deserialize_unchecked( - &mut proving_key.as_slice(), - ) - .unwrap(); + let pk = + ProvingKey::<::ArkEngine>::deserialize_unchecked(proving_key) + .unwrap(); let proof = Groth16::::prove(&pk, computation, rng).unwrap(); + let proof_points = ProofPoints { a: parse_g1::(&proof.a), b: parse_g2::(&proof.b), @@ -117,9 +123,15 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -132,7 +144,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -142,9 +157,15 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -155,8 +176,12 @@ mod tests { .execute(program.clone(), &[Bw6_761Field::from(42)]) .unwrap(); - let proof = - >::generate_proof(program, witness, keypair.pk, rng); + let proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); let ans = >::verify(keypair.vk, proof); assert!(ans); diff --git a/zokrates_ark/src/lib.rs b/zokrates_ark/src/lib.rs index 425be3a8d..dba49e9a9 100644 --- a/zokrates_ark/src/lib.rs +++ b/zokrates_ark/src/lib.rs @@ -8,8 +8,8 @@ use ark_relations::r1cs::{ SynthesisError, Variable as ArkVariable, }; use std::collections::BTreeMap; -use zokrates_ast::common::Variable; -use zokrates_ast::ir::{CanonicalLinComb, ProgIterator, Statement, Witness}; +use zokrates_ast::common::flat::Variable; +use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::{ArkFieldExtensions, Field}; pub use self::parse::*; @@ -39,12 +39,13 @@ impl<'a, T, I: IntoIterator>> Computation<'a, T, I> { } fn ark_combination( - l: CanonicalLinComb, + l: LinComb, cs: &mut ConstraintSystem<<::ArkEngine as PairingEngine>::Fr>, symbols: &mut BTreeMap, witness: &mut Witness, ) -> LinearCombination<<::ArkEngine as PairingEngine>::Fr> { - l.0.into_iter() + l.value + .into_iter() .map(|(k, v)| { ( v.into_ark(), @@ -112,25 +113,10 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> })); for statement in self.program.statements { - if let Statement::Constraint(quad, lin, _) = statement { - let a = ark_combination( - quad.left.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); - let b = ark_combination( - quad.right.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); - let c = ark_combination( - lin.into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - ); + if let Statement::Constraint(s) = statement { + let a = ark_combination(s.quad.left, &mut cs, &mut symbols, &mut witness); + let b = ark_combination(s.quad.right, &mut cs, &mut symbols, &mut witness); + let c = ark_combination(s.lin, &mut cs, &mut symbols, &mut witness); cs.enforce_constraint(a, b, c)?; } @@ -150,7 +136,7 @@ impl<'a, T: Field + ArkFieldExtensions, I: IntoIterator> self.program .public_inputs_values(self.witness.as_ref().unwrap()) .iter() - .map(|v| v.clone().into_ark()) + .map(|v| v.into_ark()) .collect() } } diff --git a/zokrates_ark/src/marlin.rs b/zokrates_ark/src/marlin.rs index b0e0f75e7..5292fba94 100644 --- a/zokrates_ark/src/marlin.rs +++ b/zokrates_ark/src/marlin.rs @@ -19,6 +19,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use digest::Digest; use rand_0_8::{CryptoRng, Error, RngCore, SeedableRng}; use sha3::Keccak256; +use std::io::Read; use std::marker::PhantomData; use zokrates_field::{ArkFieldExtensions, Field}; @@ -206,11 +207,16 @@ impl UniversalBackend for Ark } impl Backend for Ark { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); @@ -220,7 +226,7 @@ impl Backend for Ark { T::ArkEngine, DensePolynomial<<::ArkEngine as PairingEngine>::Fr>, >, - >::deserialize_unchecked(&mut proving_key.as_slice()) + >::deserialize_unchecked(proving_key) .unwrap(); let public_inputs = computation.public_inputs_values(); @@ -392,18 +398,18 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( - Variable::new(0).into(), - Variable::new(0).into(), - ), + QuadComb::new(Variable::new(0).into(), Variable::new(0).into()), Variable::new(1), + None, ), - Statement::constraint(Variable::new(1), Variable::public(0)), + Statement::constraint(Variable::new(1), Variable::public(0), None), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -417,7 +423,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); @@ -427,18 +436,18 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( - Variable::new(0).into(), - Variable::new(0).into(), - ), + QuadComb::new(Variable::new(0).into(), Variable::new(0).into()), Variable::new(1), + None, ), - Statement::constraint(Variable::new(1), Variable::public(0)), + Statement::constraint(Variable::new(1), Variable::public(0), None), ], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -452,7 +461,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 60eb498c6..1e8e47c9f 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.5" +version = "0.1.6" edition = "2021" [features] @@ -9,15 +9,16 @@ bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman"] ark = ["ark-bls12-377", "zokrates_embed/ark"] [dependencies] +log = "0.4" +byteorder = "1.4.3" zokrates_pest_ast = { version = "0.3.0", path = "../zokrates_pest_ast" } cfg-if = "0.1" zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } serde = { version = "1.0", features = ["derive"] } -csv = "1" serde_cbor = "0.11.2" num-bigint = { version = "0.2", default-features = false } serde_json = { version = "1.0", features = ["preserve_order"] } zokrates_embed = { version = "0.1.0", path = "../zokrates_embed", default-features = false } pairing_ce = { version = "^0.21", optional = true } ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false, optional = true } -derivative = "2.2.0" \ No newline at end of file +derivative = "2.2.0" diff --git a/zokrates_ast/src/common/embed.rs b/zokrates_ast/src/common/embed.rs index 58294142f..de5d3ad07 100644 --- a/zokrates_ast/src/common/embed.rs +++ b/zokrates_ast/src/common/embed.rs @@ -1,5 +1,8 @@ -use crate::common::{Parameter, RuntimeError, Solver, Variable}; -use crate::flat::{flat_expression_from_bits, flat_expression_from_variable_summands}; +use crate::common::{ + flat::{Parameter, Variable}, + RuntimeError, Solver, +}; +use crate::flat::flat_expression_from_bits; use crate::flat::{FlatDirective, FlatExpression, FlatFunctionIterator, FlatStatement}; use crate::typed::types::{ ConcreteGenericsAssignment, DeclarationConstant, DeclarationSignature, DeclarationType, @@ -11,12 +14,16 @@ use crate::untyped::{ }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::ops::*; use zokrates_field::Field; +use super::ModuleMap; + cfg_if::cfg_if! { if #[cfg(feature = "bellman")] { use pairing_ce::bn256::Bn256; use zokrates_embed::{bellman::{from_bellman, generate_sha256_round_constraints}}; + use crate::flat::flat_expression_from_variable_summands; } } @@ -294,7 +301,7 @@ impl FlatEmbed { ); assert_eq!(gen.len(), assignment.0.len()); - gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect() + gen.map(|g| *assignment.0.get(&g).unwrap()).collect() } pub fn id(&self) -> &'static str { @@ -348,15 +355,11 @@ pub fn sha256_round<'ast, T: Field>( let cs_indices = 0..variable_count; // indices of the arguments to the function // apply an offset of `variable_count` to get the indice of our dummy `input` argument - let input_argument_indices: Vec<_> = input_indices - .clone() - .into_iter() - .map(|i| i + variable_count) - .collect(); + let input_argument_indices: Vec<_> = + input_indices.clone().map(|i| i + variable_count).collect(); // apply an offset of `variable_count` to get the indice of our dummy `current_hash` argument let current_hash_argument_indices: Vec<_> = current_hash_indices .clone() - .into_iter() .map(|i| i + variable_count) .collect(); // define parameters to the function based on the variables @@ -364,21 +367,18 @@ pub fn sha256_round<'ast, T: Field>( .clone() .into_iter() .chain(current_hash_argument_indices.clone()) - .map(|i| Parameter { - id: Variable::new(i), - private: true, - }) + .map(|i| Parameter::private(Variable::new(i))) .collect(); // define a binding of the first variable in the constraint system to one - let one_binding_statement = FlatStatement::Condition( + let one_binding_statement = FlatStatement::condition( Variable::new(0).into(), - FlatExpression::Number(T::from(1)), + FlatExpression::value(T::from(1)), RuntimeError::BellmanOneBinding, ); let input_binding_statements = // bind input and current_hash to inputs input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().into_iter().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| { - FlatStatement::Condition( + FlatStatement::condition( Variable::new(cs_index).into(), Variable::new(argument_index).into(), RuntimeError::BellmanInputBinding @@ -391,30 +391,30 @@ pub fn sha256_round<'ast, T: Field>( let rhs_b = flat_expression_from_variable_summands::(c.b.as_slice()); let lhs = flat_expression_from_variable_summands::(c.c.as_slice()); - FlatStatement::Condition( + FlatStatement::condition( lhs, - FlatExpression::Mult(box rhs_a, box rhs_b), + FlatExpression::mul(rhs_a, rhs_b), RuntimeError::BellmanConstraint, ) }); // define which subset of the witness is returned - let outputs = output_indices.map(|o| FlatExpression::Identifier(Variable::new(o))); + let outputs = output_indices.map(|o| FlatExpression::identifier(Variable::new(o))); // insert a directive to set the witness based on the bellman gadget and inputs - let directive_statement = FlatStatement::Directive(FlatDirective { - outputs: cs_indices.map(Variable::new).collect(), - inputs: input_argument_indices + let directive_statement = FlatStatement::Directive(FlatDirective::new( + cs_indices.map(Variable::new).collect(), + Solver::Sha256Round, + input_argument_indices .into_iter() .chain(current_hash_argument_indices) .map(|i| Variable::new(i).into()) .collect(), - solver: Solver::Sha256Round, - }); + )); // insert a statement to return the subset of the witness let return_statements = outputs .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)); + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)); let statements = std::iter::once(directive_statement) .chain(std::iter::once(one_binding_statement)) .chain(input_binding_statements) @@ -425,6 +425,7 @@ pub fn sha256_round<'ast, T: Field>( arguments, statements, return_count, + module_map: ModuleMap::default(), } } @@ -467,9 +468,9 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( .chain(vk_arguments) .collect(); - let one_binding_statement = FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(0)), - FlatExpression::Number(T::from(1)), + let one_binding_statement = FlatStatement::condition( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(T::from(1)), RuntimeError::ArkOneBinding, ); @@ -483,7 +484,7 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( .chain(vk_argument_indices.clone()), ) .map(|(cs_index, argument_index)| { - FlatStatement::Condition( + FlatStatement::condition( Variable::new(cs_index).into(), Variable::new(argument_index).into(), RuntimeError::ArkInputBinding, @@ -499,29 +500,29 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( let rhs_b = flat_expression_from_variable_summands::(c.b.as_slice()); let lhs = flat_expression_from_variable_summands::(c.c.as_slice()); - FlatStatement::Condition( + FlatStatement::condition( lhs, - FlatExpression::Mult(box rhs_a, box rhs_b), + FlatExpression::mul(rhs_a, rhs_b), RuntimeError::ArkConstraint, ) }) .collect(); - let return_statement = FlatStatement::Definition( + let return_statement = FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(out_index)), + FlatExpression::identifier(Variable::new(out_index)), ); // insert a directive to set the witness - let directive_statement = FlatStatement::Directive(FlatDirective { - outputs: cs_indices.map(Variable::new).collect(), - inputs: input_argument_indices + let directive_statement = FlatStatement::Directive(FlatDirective::new( + cs_indices.map(Variable::new).collect(), + Solver::SnarkVerifyBls12377(n), + input_argument_indices .chain(proof_argument_indices) .chain(vk_argument_indices) .map(|i| Variable::new(i).into()) .collect(), - solver: Solver::SnarkVerifyBls12377(n), - }); + )); let statements = std::iter::once(directive_statement) .chain(std::iter::once(one_binding_statement)) @@ -533,6 +534,7 @@ pub fn snark_verify_bls12_377<'ast, T: Field>( arguments, statements, return_count: 1, + module_map: ModuleMap::default(), } } @@ -562,14 +564,11 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( let mut layout = HashMap::new(); - let arguments = vec![Parameter { - id: Variable::new(0), - private: true, - }]; + let arguments = vec![Parameter::private(Variable::new(0))]; // o0, ..., o253 = ToBits(i0) - let directive_inputs = vec![FlatExpression::Identifier(use_variable( + let directive_inputs = vec![FlatExpression::identifier(use_variable( &mut layout, "i0".into(), &mut counter, @@ -585,16 +584,16 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( let outputs: Vec<_> = directive_outputs .iter() .enumerate() - .map(|(_, o)| FlatExpression::Identifier(*o)) + .map(|(_, o)| FlatExpression::identifier(*o)) .collect(); // o253, o252, ... o{253 - (bit_width - 1)} are bits let mut statements: Vec> = (0..bit_width) .map(|index| { - let bit = FlatExpression::Identifier(Variable::new(bit_width - index)); - FlatStatement::Condition( + let bit = FlatExpression::identifier(Variable::new(bit_width - index)); + FlatStatement::condition( bit.clone(), - FlatExpression::Mult(box bit.clone(), box bit.clone()), + FlatExpression::mul(bit.clone(), bit), RuntimeError::Bitness, ) }) @@ -603,39 +602,40 @@ pub fn unpack_to_bitwidth<'ast, T: Field>( // sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1) let lhs_sum = flat_expression_from_bits( (0..bit_width) - .map(|i| FlatExpression::Identifier(Variable::new(i + 1))) + .map(|i| FlatExpression::identifier(Variable::new(i + 1))) .collect(), ); - statements.push(FlatStatement::Condition( + statements.push(FlatStatement::condition( lhs_sum, - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(T::from(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(T::from(1)), ), RuntimeError::Sum, )); statements.insert( 0, - FlatStatement::Directive(FlatDirective { - inputs: directive_inputs, - outputs: directive_outputs, + FlatStatement::Directive(FlatDirective::new( + directive_outputs, solver, - }), + directive_inputs, + )), ); statements.extend( outputs .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)), + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)), ); FlatFunctionIterator { arguments, statements: statements.into_iter(), return_count: bit_width, + module_map: ModuleMap::default(), } } @@ -661,7 +661,7 @@ mod tests { .map(|i| Variable::new(i + 1)) .collect(), Solver::bits(Bn128Field::get_required_bits()), - vec![Variable::new(0)] + vec![Variable::new(0).into()] )) ); assert_eq!( @@ -709,9 +709,9 @@ mod tests { // bellman variable #0: index 0 should equal 1 assert_eq!( compiled.statements[1], - FlatStatement::Condition( + FlatStatement::condition( Variable::new(0).into(), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), RuntimeError::BellmanOneBinding ) ); @@ -719,7 +719,7 @@ mod tests { // bellman input #0: index 1 should equal zokrates input #0: index v_count assert_eq!( compiled.statements[2], - FlatStatement::Condition( + FlatStatement::condition( Variable::new(1).into(), Variable::new(26936).into(), RuntimeError::BellmanInputBinding diff --git a/zokrates_ast/src/common/expressions.rs b/zokrates_ast/src/common/expressions.rs new file mode 100644 index 000000000..01481ce6a --- /dev/null +++ b/zokrates_ast/src/common/expressions.rs @@ -0,0 +1,280 @@ +use derivative::Derivative; +use num_bigint::BigUint; +use serde::{Deserialize, Serialize}; + +use super::operators::OpEq; +use super::{operators::OperatorStr, Span, WithSpan}; +use std::fmt; +use std::marker::PhantomData; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BinaryExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub left: Box, + pub right: Box, + operator: PhantomData, + output: PhantomData, +} + +impl BinaryExpression { + pub fn new(left: L, right: R) -> Self { + Self { + span: None, + left: Box::new(left), + right: Box::new(right), + operator: PhantomData, + output: PhantomData, + } + } +} + +impl fmt::Display + for BinaryExpression +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "({} {} {})", self.left, Op::STR, self.right,) + } +} + +impl WithSpan for BinaryExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum BinaryOrExpression { + Binary(BinaryExpression), + Expression(I), +} + +pub type EqExpression = BinaryExpression; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnaryExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Box, + operator: PhantomData, + output: PhantomData, +} + +impl UnaryExpression { + pub fn new(inner: In) -> Self { + Self { + span: None, + inner: Box::new(inner), + operator: PhantomData, + output: PhantomData, + } + } +} + +impl fmt::Display + for UnaryExpression +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "({}{})", Op::STR, self.inner,) + } +} + +impl WithSpan for UnaryExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum UnaryOrExpression { + Unary(UnaryExpression), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ValueExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub value: V, +} + +impl ValueExpression { + pub fn new(value: V) -> Self { + Self { span: None, value } + } +} + +impl fmt::Display for ValueExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value,) + } +} + +pub type FieldValueExpression = ValueExpression; + +pub type BooleanValueExpression = ValueExpression; + +pub type UValueExpression = ValueExpression; + +pub type IntValueExpression = ValueExpression; + +pub enum ValueOrExpression { + Value(V), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IdentifierExpression { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: I, + pub ty: PhantomData, +} + +impl IdentifierExpression { + pub fn new(id: I) -> Self { + IdentifierExpression { + span: None, + id, + ty: PhantomData, + } + } +} + +impl fmt::Display for IdentifierExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.id,) + } +} + +impl WithSpan for IdentifierExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +pub enum IdentifierOrExpression { + Identifier(IdentifierExpression), + Expression(I), +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DefinitionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub rhs: E, +} + +impl DefinitionStatement { + pub fn new(assignee: A, rhs: E) -> Self { + DefinitionStatement { + span: None, + assignee, + rhs, + } + } +} + +impl WithSpan for DefinitionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for DefinitionStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} = {}", self.assignee, self.rhs,) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub expression: B, + pub error: E, +} + +impl AssertionStatement { + pub fn new(expression: B, error: E) -> Self { + AssertionStatement { + span: None, + expression, + error, + } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ReturnStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: E, +} + +impl ReturnStatement { + pub fn new(e: E) -> Self { + ReturnStatement { + span: None, + inner: e, + } + } +} + +impl WithSpan for ReturnStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ReturnStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "return {};", self.inner) + } +} diff --git a/zokrates_ast/src/common/flat/mod.rs b/zokrates_ast/src/common/flat/mod.rs new file mode 100644 index 000000000..d1892a6bb --- /dev/null +++ b/zokrates_ast/src/common/flat/mod.rs @@ -0,0 +1,5 @@ +pub mod parameter; +pub mod variable; + +pub use parameter::Parameter; +pub use variable::Variable; diff --git a/zokrates_ast/src/common/flat/parameter.rs b/zokrates_ast/src/common/flat/parameter.rs new file mode 100644 index 000000000..c4a86c962 --- /dev/null +++ b/zokrates_ast/src/common/flat/parameter.rs @@ -0,0 +1,63 @@ +use crate::common::{Span, WithSpan}; + +use super::variable::Variable; +use derivative::Derivative; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct Parameter { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: Variable, + pub private: bool, +} + +impl WithSpan for Parameter { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl Parameter { + pub fn new(id: Variable, private: bool) -> Self { + Parameter { + id, + private, + span: None, + } + } + + pub fn public(v: Variable) -> Self { + Self::new(v, false) + } + + pub fn private(v: Variable) -> Self { + Self::new(v, true) + } +} + +impl fmt::Display for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let visibility = if self.private { "private " } else { "" }; + write!(f, "{}{}", visibility, self.id) + } +} + +impl Parameter { + pub fn apply_substitution(self, substitution: &HashMap) -> Parameter { + Parameter { + id: *substitution.get(&self.id).unwrap(), + private: self.private, + ..self + } + } +} diff --git a/zokrates_ast/src/common/flat/variable.rs b/zokrates_ast/src/common/flat/variable.rs new file mode 100644 index 000000000..ab460623c --- /dev/null +++ b/zokrates_ast/src/common/flat/variable.rs @@ -0,0 +1,102 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; +use std::io::{Read, Write}; + +// A variable in a constraint system +// id > 0 for intermediate variables +// id == 0 for ~one +// id < 0 for public outputs +#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq, Ord, PartialOrd, Copy)] +pub struct Variable { + pub id: isize, +} + +impl Variable { + pub fn new(id: usize) -> Self { + Variable { + id: 1 + id as isize, + } + } + + pub fn one() -> Self { + Variable { id: 0 } + } + + pub fn public(id: usize) -> Self { + Variable { + id: -(id as isize) - 1, + } + } + + pub fn id(&self) -> usize { + assert!(self.id > 0); + (self.id as usize) - 1 + } + + pub fn write(&self, mut writer: W) -> std::io::Result<()> { + writer.write_all(&self.id.to_le_bytes())?; + Ok(()) + } + + pub fn read(mut reader: R) -> std::io::Result { + let mut buf = [0; std::mem::size_of::()]; + reader.read_exact(&mut buf)?; + + Ok(Variable { + id: isize::from_le_bytes(buf), + }) + } +} + +impl fmt::Display for Variable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.id { + 0 => write!(f, "~one"), + i if i > 0 => write!(f, "_{}", i - 1), + i => write!(f, "~out_{}", -(i + 1)), + } + } +} + +impl fmt::Debug for Variable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.id { + 0 => write!(f, "~one"), + i if i > 0 => write!(f, "_{}", i - 1), + i => write!(f, "~out_{}", -(i + 1)), + } + } +} + +impl Variable { + pub fn apply_substitution(self, substitution: &HashMap) -> &Self { + substitution.get(&self).unwrap() + } + + pub fn is_output(&self) -> bool { + self.id < 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn one() { + assert_eq!(format!("{}", Variable::one()), "~one"); + } + + #[test] + fn public() { + assert_eq!(format!("{}", Variable::public(0)), "~out_0"); + assert_eq!(format!("{}", Variable::public(42)), "~out_42"); + } + + #[test] + fn private() { + assert_eq!(format!("{}", Variable::new(0)), "_0"); + assert_eq!(format!("{}", Variable::new(42)), "_42"); + } +} diff --git a/zokrates_ast/src/common/fold.rs b/zokrates_ast/src/common/fold.rs new file mode 100644 index 000000000..98369bef1 --- /dev/null +++ b/zokrates_ast/src/common/fold.rs @@ -0,0 +1,7 @@ +pub trait Fold: Sized { + fn fold(self, f: &mut F) -> Self; +} + +pub trait ResultFold: Sized { + fn fold(self, f: &mut F) -> Result; +} diff --git a/zokrates_ast/src/common/mod.rs b/zokrates_ast/src/common/mod.rs index 13d23bfdb..186ea3caa 100644 --- a/zokrates_ast/src/common/mod.rs +++ b/zokrates_ast/src/common/mod.rs @@ -1,15 +1,28 @@ pub mod embed; mod error; +pub mod expressions; +pub mod flat; +mod fold; mod format_string; mod metadata; +pub mod operators; mod parameter; +mod position; mod solvers; +pub mod statements; +mod value; mod variable; pub use self::embed::FlatEmbed; pub use self::error::RuntimeError; +pub use self::fold::{Fold, ResultFold}; pub use self::metadata::SourceMetadata; pub use self::parameter::Parameter; -pub use self::solvers::Solver; +pub use self::position::{ + LocalSourceSpan, ModuleId, ModuleIdHash, ModuleMap, OwnedModuleId, Position, SourceSpan, Span, + WithSpan, +}; +pub use self::solvers::{RefCall, Solver}; +pub use self::value::Value; pub use self::variable::Variable; pub use format_string::FormatString; diff --git a/zokrates_ast/src/common/operators.rs b/zokrates_ast/src/common/operators.rs new file mode 100644 index 000000000..18a26cec2 --- /dev/null +++ b/zokrates_ast/src/common/operators.rs @@ -0,0 +1,143 @@ +pub trait OperatorStr { + const STR: &'static str; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpAdd; + +impl OperatorStr for OpAdd { + const STR: &'static str = "+"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpSub; + +impl OperatorStr for OpSub { + const STR: &'static str = "-"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpFloorSub; + +impl OperatorStr for OpFloorSub { + const STR: &'static str = "- /* FLOOR_SUB */"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpMul; + +impl OperatorStr for OpMul { + const STR: &'static str = "*"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpDiv; + +impl OperatorStr for OpDiv { + const STR: &'static str = "/"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpRem; + +impl OperatorStr for OpRem { + const STR: &'static str = "%"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpPow; + +impl OperatorStr for OpPow { + const STR: &'static str = "**"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpEq; + +impl OperatorStr for OpEq { + const STR: &'static str = "=="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLt; + +impl OperatorStr for OpLt { + const STR: &'static str = "<"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLe; + +impl OperatorStr for OpLe { + const STR: &'static str = "<="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpGt; + +impl OperatorStr for OpGt { + const STR: &'static str = ">"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpGe; + +impl OperatorStr for OpGe { + const STR: &'static str = ">="; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpXor; + +impl OperatorStr for OpXor { + const STR: &'static str = "^"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpOr; + +impl OperatorStr for OpOr { + const STR: &'static str = "|"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpAnd; + +impl OperatorStr for OpAnd { + const STR: &'static str = "&"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpLsh; + +impl OperatorStr for OpLsh { + const STR: &'static str = "<<"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpRsh; + +impl OperatorStr for OpRsh { + const STR: &'static str = ">>"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpNot; + +impl OperatorStr for OpNot { + const STR: &'static str = "!"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpNeg; + +impl OperatorStr for OpNeg { + const STR: &'static str = "-"; +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +pub struct OpPos; + +impl OperatorStr for OpPos { + const STR: &'static str = "+"; +} diff --git a/zokrates_ast/src/common/parameter.rs b/zokrates_ast/src/common/parameter.rs index 4b17395f4..e63b8be5c 100644 --- a/zokrates_ast/src/common/parameter.rs +++ b/zokrates_ast/src/common/parameter.rs @@ -1,46 +1,58 @@ -use super::variable::Variable; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt; -#[derive(Serialize, Deserialize, Hash, Eq, PartialEq, Clone, Copy)] -pub struct Parameter { - pub id: Variable, +use derivative::Derivative; +use serde::{Deserialize, Serialize}; + +use super::{Span, WithSpan}; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Hash, Eq)] +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Parameter { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: V, pub private: bool, } -impl Parameter { - fn new(id: Variable, private: bool) -> Self { - Parameter { id, private } +impl From for Parameter { + fn from(v: V) -> Self { + Self::private(v) + } +} + +impl Parameter { + pub fn new(v: V, private: bool) -> Self { + Parameter { + span: None, + id: v, + private, + } } - pub fn public(v: Variable) -> Self { + pub fn public(v: V) -> Self { Self::new(v, false) } - pub fn private(v: Variable) -> Self { + pub fn private(v: V) -> Self { Self::new(v, true) } } -impl fmt::Display for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{}", visibility, self.id) +impl WithSpan for Parameter { + fn span(mut self, span: Option) -> Self { + self.span = span; + self } -} -impl fmt::Debug for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(id: {:?})", self.id) + fn get_span(&self) -> Option { + self.span } } -impl Parameter { - pub fn apply_substitution(self, substitution: &HashMap) -> Parameter { - Parameter { - id: *substitution.get(&self.id).unwrap(), - private: self.private, - } +impl fmt::Display for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let visibility = if self.private { "private " } else { "" }; + write!(f, "{}{}", visibility, self.id) } } diff --git a/zokrates_ast/src/common/position.rs b/zokrates_ast/src/common/position.rs new file mode 100644 index 000000000..a817f0749 --- /dev/null +++ b/zokrates_ast/src/common/position.rs @@ -0,0 +1,242 @@ +use std::{ + collections::BTreeMap, + fmt, + path::{Path, PathBuf}, +}; + +use serde::{Deserialize, Serialize}; + +use super::FlatEmbed; + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct LocalSourceSpan { + pub from: Position, + pub to: Position, +} + +pub type ModuleIdHash = u64; + +pub type ModuleId = Path; + +pub type OwnedModuleId = PathBuf; + +#[derive(Clone, PartialEq, Debug, Eq, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct ModuleMap { + modules: BTreeMap, +} + +impl ModuleMap { + pub fn new>(i: I) -> Self { + Self { + modules: i.into_iter().map(|id| (hash(&id), id)).collect(), + } + } + + pub fn remap_prefix(self, prefix: &Path, to: &Path) -> Self { + Self { + modules: self + .modules + .into_iter() + .map(|(id, path)| { + ( + id, + path.strip_prefix(prefix) + .map(|path| to.join(path)) + .unwrap_or(path), + ) + }) + .collect(), + } + } +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct Position { + pub line: usize, + pub col: usize, +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, PartialOrd, Ord, Deserialize, Serialize, Debug)] +pub enum Span { + Source(SourceSpan), + Embed(FlatEmbed), +} + +impl fmt::Display for Span { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Span::Source(s) => write!(f, "{}", s), + Span::Embed(e) => write!(f, "{:?}", e), + } + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum ResolvedSpan { + Source(ResolvedSourceSpan), + Embed(FlatEmbed), +} + +impl fmt::Display for ResolvedSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ResolvedSpan::Source(s) => write!(f, "{}", s), + ResolvedSpan::Embed(e) => write!(f, "{:?}", e), + } + } +} + +impl Span { + pub fn resolve(self, map: &ModuleMap) -> ResolvedSpan { + match self { + Span::Source(s) => ResolvedSpan::Source(ResolvedSourceSpan { + module: map.modules.get(&s.module).cloned().unwrap(), + from: s.from, + to: s.to, + }), + Span::Embed(s) => ResolvedSpan::Embed(s), + } + } +} + +#[derive(Clone, PartialEq, Eq, Copy, Hash, Default, PartialOrd, Ord, Deserialize, Serialize)] +pub struct SourceSpan { + pub module: ModuleIdHash, + pub from: Position, + pub to: Position, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct ResolvedSourceSpan { + pub module: OwnedModuleId, + pub from: Position, + pub to: Position, +} + +impl From for Span { + fn from(span: SourceSpan) -> Self { + Self::Source(span) + } +} + +impl From for Span { + fn from(embed: FlatEmbed) -> Self { + Self::Embed(embed) + } +} + +impl SourceSpan { + pub fn mock() -> Self { + Self { + module: hash(&OwnedModuleId::default()), + from: Position::mock(), + to: Position::mock(), + } + } +} + +pub trait WithSpan: Sized { + fn span(self, _: Option) -> Self; + + fn with_span>(self, span: S) -> Self { + self.span(Some(span.into())) + } + + fn get_span(&self) -> Option; +} + +fn hash(id: &ModuleId) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + id.hash(&mut hasher); + hasher.finish() +} + +impl LocalSourceSpan { + pub fn in_module(self, module_id: &ModuleId) -> SourceSpan { + SourceSpan { + module: hash(module_id), + from: self.from, + to: self.to, + } + } + + pub fn mock() -> Self { + Self { + from: Position::mock(), + to: Position::mock(), + } + } +} + +impl Position { + pub fn col(&self, delta: isize) -> Position { + assert!(self.col <= isize::max_value() as usize); + assert!(self.col as isize >= delta); + Position { + line: self.line, + col: (self.col as isize + delta) as usize, + } + } + + pub fn mock() -> Self { + Position { line: 42, col: 42 } + } +} +impl fmt::Display for Position { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}", self.line, self.col) + } +} +impl fmt::Debug for Position { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for SourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.from) + } +} +impl fmt::Debug for SourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl fmt::Display for ResolvedSourceSpan { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}:{} (until {})", + self.module.display(), + self.from, + self.to + ) + } +} + +#[test] +fn position_col() { + let pos = Position { + line: 100, + col: 258, + }; + assert_eq!( + pos.col(26), + Position { + line: 100, + col: 284, + } + ); + assert_eq!( + pos.col(-23), + Position { + line: 100, + col: 235, + } + ); +} diff --git a/zokrates_ast/src/common/solvers.rs b/zokrates_ast/src/common/solvers.rs index 9b4f5c900..9c6e9bbcc 100644 --- a/zokrates_ast/src/common/solvers.rs +++ b/zokrates_ast/src/common/solvers.rs @@ -2,6 +2,12 @@ use crate::zir::ZirFunction; use serde::{Deserialize, Serialize}; use std::fmt; +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] +pub struct RefCall { + pub index: usize, + pub argument_count: usize, +} + #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver<'ast, T> { ConditionEq, @@ -14,6 +20,7 @@ pub enum Solver<'ast, T> { EuclideanDiv, #[serde(borrow)] Zir(ZirFunction<'ast, T>), + Ref(RefCall), #[cfg(feature = "bellman")] Sha256Round, #[cfg(feature = "ark")] @@ -32,6 +39,7 @@ impl<'ast, T> fmt::Display for Solver<'ast, T> { Solver::ShaCh => write!(f, "ShaCh"), Solver::EuclideanDiv => write!(f, "EuclideanDiv"), Solver::Zir(_) => write!(f, "Zir(..)"), + Solver::Ref(call) => write!(f, "Ref@{}({})", call.index, call.argument_count), #[cfg(feature = "bellman")] Solver::Sha256Round => write!(f, "Sha256Round"), #[cfg(feature = "ark")] @@ -52,6 +60,7 @@ impl<'ast, T> Solver<'ast, T> { Solver::ShaCh => (3, 1), Solver::EuclideanDiv => (2, 2), Solver::Zir(f) => (f.arguments.len(), 1), + Solver::Ref(c) => (c.argument_count, 1), #[cfg(feature = "bellman")] Solver::Sha256Round => (768, 26935), #[cfg(feature = "ark")] diff --git a/zokrates_ast/src/common/statements.rs b/zokrates_ast/src/common/statements.rs new file mode 100644 index 000000000..ba67b8283 --- /dev/null +++ b/zokrates_ast/src/common/statements.rs @@ -0,0 +1,278 @@ +use derivative::Derivative; +use serde::{Deserialize, Serialize}; + +use crate::Solver; + +use super::{FormatString, SourceMetadata}; +use super::{Span, WithSpan}; +use std::fmt; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct DefinitionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub rhs: E, +} + +impl DefinitionStatement { + pub fn new(assignee: A, rhs: E) -> Self { + DefinitionStatement { + span: None, + assignee, + rhs, + } + } +} + +impl WithSpan for DefinitionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for DefinitionStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} = {};", self.assignee, self.rhs,) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub expression: B, + pub error: E, +} + +impl AssertionStatement { + pub fn new(expression: B, error: E) -> Self { + AssertionStatement { + span: None, + expression, + error, + } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct ReturnStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: E, +} + +impl ReturnStatement { + pub fn new(e: E) -> Self { + ReturnStatement { + span: None, + inner: e, + } + } +} + +impl WithSpan for ReturnStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ReturnStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "return {};", self.inner) + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LogStatement { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub format_string: FormatString, + pub expressions: Vec, +} + +impl LogStatement { + pub fn new(format_string: FormatString, expressions: Vec) -> Self { + LogStatement { + span: None, + format_string, + expressions, + } + } +} + +impl WithSpan for LogStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for LogStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "log({}, {});", + self.format_string, + self.expressions + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct DirectiveStatement { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + pub inputs: Vec, + pub outputs: Vec, + pub solver: S, +} + +impl<'ast, T, I, O> DirectiveStatement> { + pub fn new(outputs: Vec, solver: Solver<'ast, T>, inputs: Vec) -> Self { + let (in_len, out_len) = solver.get_signature(); + assert_eq!(in_len, inputs.len()); + assert_eq!(out_len, outputs.len()); + Self { + span: None, + inputs, + outputs, + solver, + } + } +} + +impl WithSpan for DirectiveStatement { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display + for DirectiveStatement +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "# {} = {}({})", + self.outputs + .iter() + .map(|o| format!("{}", o)) + .collect::>() + .join(", "), + self.solver, + self.inputs + .iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ) + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyAssignment { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub assignee: A, + pub expression: E, +} + +impl AssemblyAssignment { + pub fn new(assignee: A, expression: E) -> Self { + Self { + span: None, + assignee, + expression, + } + } +} + +impl WithSpan for AssemblyAssignment { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyConstraint { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub left: E, + pub right: E, + pub metadata: SourceMetadata, +} + +impl AssemblyConstraint { + pub fn new(left: E, right: E, metadata: SourceMetadata) -> Self { + Self { + span: None, + left, + right, + metadata, + } + } +} + +impl WithSpan for AssemblyConstraint { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for AssemblyConstraint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} === {};", self.left, self.right) + } +} diff --git a/zokrates_ast/src/common/value.rs b/zokrates_ast/src/common/value.rs new file mode 100644 index 000000000..af204a813 --- /dev/null +++ b/zokrates_ast/src/common/value.rs @@ -0,0 +1,3 @@ +pub trait Value { + type Value: Clone; +} diff --git a/zokrates_ast/src/common/variable.rs b/zokrates_ast/src/common/variable.rs index 983e9e178..c5556805a 100644 --- a/zokrates_ast/src/common/variable.rs +++ b/zokrates_ast/src/common/variable.rs @@ -1,111 +1,42 @@ +use derivative::Derivative; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt; -// A variable in a constraint system -// id > 0 for intermediate variables -// id == 0 for ~one -// id < 0 for public outputs -#[derive(Serialize, Deserialize, Clone, PartialEq, Hash, Eq, Ord, PartialOrd, Copy)] -pub struct Variable { - pub id: isize, -} - -impl Variable { - pub fn new(id: usize) -> Self { - Variable { - id: 1 + id as isize, - } - } +use super::{Span, WithSpan}; - pub fn one() -> Self { - Variable { id: 0 } - } - - pub fn public(id: usize) -> Self { - Variable { - id: -(id as isize) - 1, - } - } +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Hash, Eq, Ord)] +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Variable { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub id: I, + pub ty: T, +} - pub fn id(&self) -> usize { - assert!(self.id > 0); - (self.id as usize) - 1 +impl WithSpan for Variable { + fn span(mut self, span: Option) -> Self { + self.span = span; + self } - pub fn try_from_human_readable(s: &str) -> Result { - if s == "~one" { - return Ok(Variable::one()); - } - - let mut public = s.split("~out_"); - match public.nth(1) { - Some(v) => { - let v = v.parse().map_err(|_| s)?; - Ok(Variable::public(v)) - } - None => { - let mut private = s.split('_'); - match private.nth(1) { - Some(v) => { - let v = v.parse().map_err(|_| s)?; - Ok(Variable::new(v)) - } - None => Err(s), - } - } - } + fn get_span(&self) -> Option { + self.span } } -impl fmt::Display for Variable { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.id { - 0 => write!(f, "~one"), - i if i > 0 => write!(f, "_{}", i - 1), - i => write!(f, "~out_{}", -(i + 1)), +impl Variable { + pub fn new>(id: J, ty: T) -> Self { + Self { + span: None, + id: id.into(), + ty, } } } -impl fmt::Debug for Variable { +impl fmt::Display for Variable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.id { - 0 => write!(f, "~one"), - i if i > 0 => write!(f, "_{}", i - 1), - i => write!(f, "~out_{}", -(i + 1)), - } - } -} - -impl Variable { - pub fn apply_substitution(self, substitution: &HashMap) -> &Self { - substitution.get(&self).unwrap() - } - - pub fn is_output(&self) -> bool { - self.id < 0 - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn one() { - assert_eq!(format!("{}", Variable::one()), "~one"); - } - - #[test] - fn public() { - assert_eq!(format!("{}", Variable::public(0)), "~out_0"); - assert_eq!(format!("{}", Variable::public(42)), "~out_42"); - } - - #[test] - fn private() { - assert_eq!(format!("{}", Variable::new(0)), "_0"); - assert_eq!(format!("{}", Variable::new(42)), "_42"); + write!(f, "{} {}", self.ty, self.id,) } } diff --git a/zokrates_ast/src/flat/folder.rs b/zokrates_ast/src/flat/folder.rs index ce50d7dac..539e02ebe 100644 --- a/zokrates_ast/src/flat/folder.rs +++ b/zokrates_ast/src/flat/folder.rs @@ -1,9 +1,18 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::{ + expressions::{BinaryOrExpression, IdentifierOrExpression}, + flat::Variable, + Fold, WithSpan, +}; use zokrates_field::Field; +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FlatExpression { + fn fold(self, f: &mut F) -> Self { + f.fold_expression(self) + } +} pub trait Folder<'ast, T: Field>: Sized { fn fold_program(&mut self, p: FlatProg<'ast, T>) -> FlatProg<'ast, T> { fold_program(self, p) @@ -25,6 +34,20 @@ pub trait Folder<'ast, T: Field>: Sized { fold_expression(self, e) } + fn fold_binary_expression, R: Fold, E>( + &mut self, + e: BinaryExpression, + ) -> BinaryOrExpression> { + fold_binary_expression(self, e) + } + + fn fold_identifier_expression( + &mut self, + e: IdentifierExpression>, + ) -> IdentifierOrExpression, FlatExpression> { + fold_identifier_expression(self, e) + } + fn fold_directive(&mut self, d: FlatDirective<'ast, T>) -> FlatDirective<'ast, T> { fold_directive(self, d) } @@ -45,7 +68,7 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - return_count: p.return_count, + ..p } } @@ -54,25 +77,27 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( s: FlatStatement<'ast, T>, ) -> Vec> { match s { - FlatStatement::Block(statements) => vec![FlatStatement::Block( + FlatStatement::Condition(s) => vec![FlatStatement::condition( + f.fold_expression(s.quad), + f.fold_expression(s.lin), + s.error, + )], + FlatStatement::Block(statements) => vec![FlatStatement::block( statements + .inner .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), )], - FlatStatement::Condition(left, right, error) => vec![FlatStatement::Condition( - f.fold_expression(left), - f.fold_expression(right), - error, - )], - FlatStatement::Definition(v, e) => vec![FlatStatement::Definition( - f.fold_variable(v), - f.fold_expression(e), + FlatStatement::Definition(s) => vec![FlatStatement::definition( + f.fold_variable(s.assignee), + f.fold_expression(s.rhs), )], FlatStatement::Directive(d) => vec![FlatStatement::Directive(f.fold_directive(d))], - FlatStatement::Log(s, e) => vec![FlatStatement::Log( - s, - e.into_iter() + FlatStatement::Log(s) => vec![FlatStatement::log( + s.format_string, + s.expressions + .into_iter() .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) .collect(), )], @@ -84,20 +109,45 @@ pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( e: FlatExpression, ) -> FlatExpression { match e { - FlatExpression::Number(n) => FlatExpression::Number(n), - FlatExpression::Identifier(id) => FlatExpression::Identifier(f.fold_variable(id)), - FlatExpression::Add(box left, box right) => { - FlatExpression::Add(box f.fold_expression(left), box f.fold_expression(right)) - } - FlatExpression::Sub(box left, box right) => { - FlatExpression::Sub(box f.fold_expression(left), box f.fold_expression(right)) - } - FlatExpression::Mult(box left, box right) => { - FlatExpression::Mult(box f.fold_expression(left), box f.fold_expression(right)) - } + FlatExpression::Value(n) => FlatExpression::Value(n), + FlatExpression::Identifier(id) => match f.fold_identifier_expression(id) { + IdentifierOrExpression::Identifier(e) => FlatExpression::Identifier(e), + IdentifierOrExpression::Expression(e) => e, + }, + FlatExpression::Add(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + }, + FlatExpression::Sub(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + }, + FlatExpression::Mult(e) => match f.fold_binary_expression(e) { + BinaryOrExpression::Binary(e) => FlatExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + }, } } +fn fold_identifier_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: IdentifierExpression>, +) -> IdentifierOrExpression, FlatExpression> { + let id = f.fold_variable(e.id); + + IdentifierOrExpression::Identifier(IdentifierExpression { id, ..e }) +} + +fn fold_binary_expression<'ast, T: Field, F: Folder<'ast, T>, Op, L: Fold, R: Fold, E>( + f: &mut F, + e: BinaryExpression, +) -> BinaryOrExpression> { + let left = e.left.fold(f); + let right = e.right.fold(f); + + BinaryOrExpression::Binary(BinaryExpression::new(left, right).span(e.span)) +} + pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ds: FlatDirective<'ast, T>, @@ -116,7 +166,7 @@ pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), - private: a.private, + ..a } } diff --git a/zokrates_ast/src/flat/mod.rs b/zokrates_ast/src/flat/mod.rs index 015e670a5..44bd42573 100644 --- a/zokrates_ast/src/flat/mod.rs +++ b/zokrates_ast/src/flat/mod.rs @@ -8,11 +8,20 @@ pub mod folder; pub mod utils; +use crate::common; +pub use crate::common::flat::Parameter; +pub use crate::common::flat::Variable; +use crate::common::statements::DirectiveStatement; use crate::common::FormatString; -pub use crate::common::Parameter; +use crate::common::ModuleMap; pub use crate::common::RuntimeError; -pub use crate::common::Variable; +use crate::common::{ + expressions::{BinaryExpression, IdentifierExpression, ValueExpression}, + operators::*, +}; +use crate::common::{Span, WithSpan}; +use derivative::Derivative; pub use utils::{ flat_expression_from_bits, flat_expression_from_expression_summands, flat_expression_from_variable_summands, @@ -32,6 +41,8 @@ pub type FlatProgIterator<'ast, T, I> = FlatFunctionIterator<'ast, T, I>; #[derive(Clone, PartialEq, Eq, Debug)] pub struct FlatFunctionIterator<'ast, T, I: IntoIterator>> { + /// The map of the modules for sourcemaps + pub module_map: ModuleMap, /// Arguments of the function pub arguments: Vec, /// Vector of statements that are executed when running the function @@ -46,6 +57,7 @@ impl<'ast, T, I: IntoIterator>> FlatFunctionIterat statements: self.statements.into_iter().collect(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, } } } @@ -70,45 +82,155 @@ impl<'ast, T: Field> fmt::Display for FlatFunction<'ast, T> { } } -/// Calculates a flattened function based on a R1CS (A, B, C) and returns that flattened function: -/// * The Rank 1 Constraint System (R1CS) is defined as: -/// * `* = ` for a witness `x` -/// * Since the matrices in R1CS are usually sparse, the following encoding is used: -/// * For each constraint (i.e., row in the R1CS), only non-zero values are supplied and encoded as a tuple (index, value). -/// -/// # Arguments -/// -/// * r1cs - R1CS in standard JSON data format +pub type DefinitionStatement = + common::expressions::DefinitionStatement>; +pub type LogStatement = common::statements::LogStatement<(ConcreteType, Vec>)>; +pub type FlatDirective<'ast, T> = + common::statements::DirectiveStatement, Variable, Solver<'ast, T>>; -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] +pub struct BlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Vec>, +} + +impl<'ast, T> BlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + BlockStatement { span: None, inner } + } +} + +impl WithSpan for AssertionStatement { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] +pub struct AssertionStatement { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub quad: FlatExpression, + pub lin: FlatExpression, + pub error: RuntimeError, +} + +impl AssertionStatement { + pub fn new(lin: FlatExpression, quad: FlatExpression, error: RuntimeError) -> Self { + AssertionStatement { + span: None, + quad, + lin, + error, + } + } +} + +impl<'ast, T> WithSpan for BlockStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] pub enum FlatStatement<'ast, T> { - Block(Vec>), - Condition(FlatExpression, FlatExpression, RuntimeError), - Definition(Variable, FlatExpression), + Condition(AssertionStatement), + Definition(DefinitionStatement), Directive(FlatDirective<'ast, T>), - Log(FormatString, Vec<(ConcreteType, Vec>)>), + Log(LogStatement), + Block(BlockStatement<'ast, T>), +} + +impl<'ast, T> FlatStatement<'ast, T> { + pub fn definition(assignee: Variable, rhs: FlatExpression) -> Self { + Self::Definition(DefinitionStatement::new(assignee, rhs)) + } + + pub fn condition(lin: FlatExpression, quad: FlatExpression, error: RuntimeError) -> Self { + Self::Condition(AssertionStatement::new(lin, quad, error)) + } + + pub fn log( + format_string: FormatString, + expressions: Vec<(ConcreteType, Vec>)>, + ) -> Self { + Self::Log(LogStatement::new(format_string, expressions)) + } + + pub fn directive( + outputs: Vec, + solver: Solver<'ast, T>, + inputs: Vec>, + ) -> Self { + Self::Directive(DirectiveStatement::new(outputs, solver, inputs)) + } + + pub fn block(inner: Vec>) -> Self { + Self::Block(BlockStatement::new(inner)) + } +} + +impl<'ast, T> WithSpan for FlatStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use FlatStatement::*; + + match self { + Condition(e) => Condition(e.span(span)), + Definition(e) => Definition(e.span(span)), + Directive(e) => Directive(e.span(span)), + Log(e) => Log(e.span(span)), + Block(e) => Block(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FlatStatement::*; + + match self { + Condition(e) => e.get_span(), + Definition(e) => e.get_span(), + Directive(e) => e.get_span(), + Log(e) => e.get_span(), + Block(e) => e.get_span(), + } + } } impl<'ast, T: Field> fmt::Display for FlatStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FlatStatement::Block(ref statements) => { + FlatStatement::Definition(ref e) => write!(f, "{}", e), + FlatStatement::Condition(ref s) => { + write!(f, "{} == {} // {}", s.lin, s.quad, s.error) + } + FlatStatement::Block(ref s) => { writeln!(f, "{{")?; - for s in statements { + for s in &s.inner { writeln!(f, "{}", s)?; } writeln!(f, "}}") } - FlatStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), - FlatStatement::Condition(ref lhs, ref rhs, ref message) => { - write!(f, "{} == {} // {}", lhs, rhs, message) - } FlatStatement::Directive(ref d) => write!(f, "{}", d), - FlatStatement::Log(ref l, ref expressions) => write!( + FlatStatement::Log(ref s) => write!( f, "log(\"{}\"), {})", - l, - expressions + s.format_string, + s.expressions .iter() .map(|(_, e)| format!( "[{}]", @@ -130,20 +252,20 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { substitution: &'ast HashMap, ) -> FlatStatement { match self { - FlatStatement::Block(statements) => FlatStatement::Block( - statements + FlatStatement::Definition(s) => FlatStatement::definition( + *s.assignee.apply_substitution(substitution), + s.rhs.apply_substitution(substitution), + ), + FlatStatement::Block(s) => FlatStatement::block( + s.inner .into_iter() .map(|s| s.apply_substitution(substitution)) .collect(), ), - FlatStatement::Definition(id, x) => FlatStatement::Definition( - *id.apply_substitution(substitution), - x.apply_substitution(substitution), - ), - FlatStatement::Condition(x, y, message) => FlatStatement::Condition( - x.apply_substitution(substitution), - y.apply_substitution(substitution), - message, + FlatStatement::Condition(s) => FlatStatement::condition( + s.quad.apply_substitution(substitution), + s.lin.apply_substitution(substitution), + s.error, ), FlatStatement::Directive(d) => { let outputs = d @@ -163,9 +285,10 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { ..d }) } - FlatStatement::Log(l, e) => FlatStatement::Log( - l, - e.into_iter() + FlatStatement::Log(s) => FlatStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() .map(|(t, e)| { ( t, @@ -175,112 +298,98 @@ impl<'ast, T: Field> FlatStatement<'ast, T> { ) }) .collect(), - ), + )), } } } -#[derive(Clone, Hash, Debug, PartialEq, Eq)] -pub struct FlatDirective<'ast, T> { - pub inputs: Vec>, - pub outputs: Vec, - pub solver: Solver<'ast, T>, +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum FlatExpression { + Value(ValueExpression), + Identifier(IdentifierExpression), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), } -impl<'ast, T> FlatDirective<'ast, T> { - pub fn new>>( - outputs: Vec, - solver: Solver<'ast, T>, - inputs: Vec, - ) -> Self { - let (in_len, out_len) = solver.get_signature(); - assert_eq!(in_len, inputs.len()); - assert_eq!(out_len, outputs.len()); - FlatDirective { - solver, - inputs: inputs.into_iter().map(|i| i.into()).collect(), - outputs, - } +impl std::ops::Add for FlatExpression { + type Output = Self; + + fn add(self, other: Self) -> Self::Output { + FlatExpression::Add(BinaryExpression::new(self, other)) } } -impl<'ast, T: Field> fmt::Display for FlatDirective<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "# {} = {}({})", - self.outputs - .iter() - .map(|o| o.to_string()) - .collect::>() - .join(", "), - self.solver, - self.inputs - .iter() - .map(|i| i.to_string()) - .collect::>() - .join(", ") - ) +impl std::ops::Sub for FlatExpression { + type Output = Self; + + fn sub(self, other: Self) -> Self::Output { + FlatExpression::Sub(BinaryExpression::new(self, other)) } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum FlatExpression { - Number(T), - Identifier(Variable), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), +impl std::ops::Mul for FlatExpression { + type Output = Self; + + fn mul(self, other: Self) -> Self::Output { + FlatExpression::Mult(BinaryExpression::new(self, other)) + } } impl From for FlatExpression { fn from(other: T) -> Self { - Self::Number(other) + Self::value(other) } } -impl FlatExpression { - pub fn apply_substitution( - self, - substitution: &HashMap, - ) -> FlatExpression { +impl BinaryExpression, FlatExpression, FlatExpression> { + fn apply_substitution(self, substitution: &HashMap) -> Self { + let left = self.left.apply_substitution(substitution); + let right = self.right.apply_substitution(substitution); + + Self::new(left, right).span(self.span) + } +} + +impl IdentifierExpression> { + fn apply_substitution(self, substitution: &HashMap) -> Self { + let id = *self.id.apply_substitution(substitution); + + IdentifierExpression { id, ..self } + } +} + +impl FlatExpression { + pub fn identifier(v: Variable) -> Self { + Self::Identifier(IdentifierExpression::new(v)) + } + + pub fn value(t: T) -> Self { + Self::Value(ValueExpression::new(t)) + } + + pub fn apply_substitution(self, substitution: &HashMap) -> Self { match self { - e @ FlatExpression::Number(_) => e, + e @ FlatExpression::Value(_) => e, FlatExpression::Identifier(id) => { - FlatExpression::Identifier(*id.apply_substitution(substitution)) + FlatExpression::Identifier(id.apply_substitution(substitution)) } - FlatExpression::Add(e1, e2) => FlatExpression::Add( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), - FlatExpression::Sub(e1, e2) => FlatExpression::Sub( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), - FlatExpression::Mult(e1, e2) => FlatExpression::Mult( - box e1.apply_substitution(substitution), - box e2.apply_substitution(substitution), - ), + FlatExpression::Add(e) => FlatExpression::Add(e.apply_substitution(substitution)), + FlatExpression::Sub(e) => FlatExpression::Sub(e.apply_substitution(substitution)), + FlatExpression::Mult(e) => FlatExpression::Mult(e.apply_substitution(substitution)), } } pub fn is_linear(&self) -> bool { match *self { - FlatExpression::Number(_) | FlatExpression::Identifier(_) => true, - FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => { - x.is_linear() && y.is_linear() - } - FlatExpression::Mult(ref x, ref y) => matches!( - (x.clone(), y.clone()), - (box FlatExpression::Number(_), box FlatExpression::Number(_)) - | ( - box FlatExpression::Number(_), - box FlatExpression::Identifier(_) - ) - | ( - box FlatExpression::Identifier(_), - box FlatExpression::Number(_) - ) + FlatExpression::Value(_) | FlatExpression::Identifier(_) => true, + FlatExpression::Add(ref e) => e.left.is_linear() && e.right.is_linear(), + FlatExpression::Sub(ref e) => e.left.is_linear() && e.right.is_linear(), + FlatExpression::Mult(ref e) => matches!( + (&*e.left, &*e.right), + (FlatExpression::Value(_), FlatExpression::Value(_)) + | (FlatExpression::Value(_), FlatExpression::Identifier(_)) + | (FlatExpression::Identifier(_), FlatExpression::Value(_)) ), } } @@ -289,18 +398,18 @@ impl FlatExpression { impl fmt::Display for FlatExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FlatExpression::Number(ref i) => write!(f, "{}", i), + FlatExpression::Value(ref i) => write!(f, "{}", i), FlatExpression::Identifier(ref var) => write!(f, "{}", var), - FlatExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FlatExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FlatExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), + FlatExpression::Add(ref e) => write!(f, "{}", e), + FlatExpression::Sub(ref e) => write!(f, "{}", e), + FlatExpression::Mult(ref e) => write!(f, "{}", e), } } } impl From for FlatExpression { fn from(v: Variable) -> FlatExpression { - FlatExpression::Identifier(v) + FlatExpression::identifier(v) } } @@ -314,3 +423,27 @@ impl fmt::Display for Error { write!(f, "{}", self.message) } } + +impl WithSpan for FlatExpression { + fn span(self, span: Option) -> Self { + use FlatExpression::*; + match self { + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Value(e) => Value(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FlatExpression::*; + match self { + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Value(e) => e.get_span(), + Identifier(e) => e.get_span(), + } + } +} diff --git a/zokrates_ast/src/flat/utils.rs b/zokrates_ast/src/flat/utils.rs index 03239687c..7d1f1ff57 100644 --- a/zokrates_ast/src/flat/utils.rs +++ b/zokrates_ast/src/flat/utils.rs @@ -1,4 +1,5 @@ use crate::flat::{FlatExpression, Variable}; +use std::ops::*; use zokrates_field::Field; // util to convert a vector of `(coefficient, expression)` to a flat_expression @@ -7,16 +8,16 @@ pub fn flat_expression_from_expression_summands FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { let (val, var) = v[0].clone(); - FlatExpression::Mult(box FlatExpression::Number(val), box var.into()) + FlatExpression::mul(FlatExpression::value(val), var.into()) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_expression_summands(u), - box flat_expression_from_expression_summands(v), + FlatExpression::add( + flat_expression_from_expression_summands(u), + flat_expression_from_expression_summands(v), ) } } @@ -34,19 +35,19 @@ pub fn flat_expression_from_bits(v: Vec>) -> FlatExp pub fn flat_expression_from_variable_summands(v: &[(T, usize)]) -> FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { - let (val, var) = v[0].clone(); - FlatExpression::Mult( - box FlatExpression::Number(val), - box FlatExpression::Identifier(Variable::new(var)), + let (val, var) = v[0]; + FlatExpression::mul( + FlatExpression::value(val), + FlatExpression::identifier(Variable::new(var)), ) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_variable_summands(u), - box flat_expression_from_variable_summands(v), + FlatExpression::add( + flat_expression_from_variable_summands(u), + flat_expression_from_variable_summands(v), ) } } diff --git a/zokrates_ast/src/ir/check.rs b/zokrates_ast/src/ir/check.rs index 41cac7b0d..e4c1652ae 100644 --- a/zokrates_ast/src/ir/check.rs +++ b/zokrates_ast/src/ir/check.rs @@ -1,5 +1,5 @@ use crate::ir::folder::Folder; -use crate::ir::Directive; +use crate::ir::DirectiveStatement; use crate::ir::Parameter; use crate::ir::ProgIterator; use crate::ir::Statement; @@ -42,8 +42,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector { self.variables.remove(&v); v } - fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { self.variables.extend(d.outputs.iter()); - d + vec![Statement::Directive(d)] } } diff --git a/zokrates_ast/src/ir/clean.rs b/zokrates_ast/src/ir/clean.rs index b4fb8f445..1c58faff6 100644 --- a/zokrates_ast/src/ir/clean.rs +++ b/zokrates_ast/src/ir/clean.rs @@ -1,3 +1,5 @@ +use crate::common::WithSpan; + use super::folder::Folder; use super::{ProgIterator, Statement}; use zokrates_field::Field; @@ -8,20 +10,27 @@ pub struct Cleaner; impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { pub fn clean(self) -> ProgIterator<'ast, T, impl IntoIterator>> { ProgIterator { + module_map: self.module_map, arguments: self.arguments, return_count: self.return_count, statements: self .statements .into_iter() .flat_map(|s| Cleaner::default().fold_statement(s)), + solvers: self.solvers, } } } impl<'ast, T: Field> Folder<'ast, T> for Cleaner { fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + if s.get_span().is_none() { + eprintln!("Internal compiler warning: found a statement without source information. Please open an issue https://github.com/Zokrates/ZoKrates/issues/new?template=bug_report.md"); + } + match s { - Statement::Block(statements) => statements + Statement::Block(s) => s + .inner .into_iter() .flat_map(|s| self.fold_statement(s)) .collect(), diff --git a/zokrates_ast/src/ir/expression.rs b/zokrates_ast/src/ir/expression.rs index a32a1293c..b3dd59478 100644 --- a/zokrates_ast/src/ir/expression.rs +++ b/zokrates_ast/src/ir/expression.rs @@ -1,23 +1,35 @@ use super::Witness; -use crate::common::Variable; +use crate::common::{flat::Variable, Span, WithSpan}; +use derivative::Derivative; use serde::{Deserialize, Serialize}; use std::collections::btree_map::{BTreeMap, Entry}; use std::fmt; -use std::hash::Hash; use std::ops::{Add, Div, Mul, Sub}; use zokrates_field::Field; -#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq)] pub struct QuadComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, pub left: LinComb, pub right: LinComb, } -impl QuadComb { - pub fn from_linear_combinations(left: LinComb, right: LinComb) -> Self { - QuadComb { left, right } +impl WithSpan for QuadComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span } +} +impl QuadComb { + #[allow(clippy::result_large_err)] pub fn try_linear(self) -> Result, Self> { // identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb) // if not, error out with the input @@ -30,7 +42,7 @@ impl QuadComb { Ok(coefficient) => Ok(self.right * &coefficient), Err(left) => match self.right.try_constant() { Ok(coefficient) => Ok(left * &coefficient), - Err(right) => Err(QuadComb::from_linear_combinations(left, right)), + Err(right) => Err(QuadComb::new(left, right)), }, } } @@ -44,7 +56,7 @@ impl From for LinComb { impl>> From for QuadComb { fn from(x: U) -> QuadComb { - QuadComb::from_linear_combinations(LinComb::one(), x.into()) + QuadComb::new(LinComb::one(), x.into()) } } @@ -54,21 +66,80 @@ impl fmt::Display for QuadComb { } } -#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct LinComb(pub Vec<(Variable, T)>); +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] +pub struct LinComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub value: Vec<(Variable, T)>, +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] +pub struct CanonicalLinComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + pub span: Option, + pub value: BTreeMap, +} + +impl WithSpan for CanonicalLinComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } -#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] -pub struct CanonicalLinComb(pub BTreeMap); + fn get_span(&self) -> Option { + self.span + } +} -#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq)] pub struct CanonicalQuadComb { + #[derivative(PartialEq = "ignore", Hash = "ignore")] + span: Option, left: CanonicalLinComb, right: CanonicalLinComb, } +impl WithSpan for CanonicalQuadComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl CanonicalQuadComb { + pub fn new(left: CanonicalLinComb, right: CanonicalLinComb) -> Self { + Self { + span: None, + left, + right, + } + } +} + +impl QuadComb { + pub fn new(left: LinComb, right: LinComb) -> Self { + Self { + span: None, + left, + right, + } + } +} + impl From> for QuadComb { fn from(q: CanonicalQuadComb) -> Self { QuadComb { + span: q.span, left: q.left.into(), right: q.right.into(), } @@ -77,42 +148,66 @@ impl From> for QuadComb { impl From> for LinComb { fn from(l: CanonicalLinComb) -> Self { - LinComb(l.0.into_iter().collect()) + LinComb { + span: l.span, + value: l.value.into_iter().collect(), + } + } +} + +impl CanonicalLinComb { + pub fn new(value: BTreeMap) -> Self { + Self { span: None, value } } } impl LinComb { + pub fn new(value: Vec<(Variable, T)>) -> Self { + Self { span: None, value } + } + pub fn summand>(mult: U, var: Variable) -> LinComb { let res = vec![(var, mult.into())]; - LinComb(res) + LinComb::new(res) } pub fn zero() -> LinComb { - LinComb(Vec::new()) + LinComb::new(Vec::new()) } pub fn is_zero(&self) -> bool { - self.0.is_empty() + self.value.is_empty() + } +} + +impl WithSpan for LinComb { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span } } impl LinComb { pub fn try_constant(self) -> Result { - match self.0.len() { + match self.value.len() { // if the lincomb is empty, it is reduceable to 0 0 => Ok(T::zero()), _ => { // take the first variable in the lincomb - let first = &self.0[0].0; + let first = &self.value[0].0; if first != &Variable::one() { return Err(self); } // all terms must contain the same variable - if self.0.iter().all(|element| element.0 == *first) { - Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1)) + if self.value.iter().all(|element| element.0 == *first) { + Ok(self.value.into_iter().fold(T::zero(), |acc, e| acc + e.1)) } else { Err(self) } @@ -121,26 +216,26 @@ impl LinComb { } pub fn is_assignee(&self, witness: &Witness) -> bool { - self.0.len() == 1 - && self.0.get(0).unwrap().1 == T::from(1) - && !witness.0.contains_key(&self.0.get(0).unwrap().0) + self.value.len() == 1 + && self.value.get(0).unwrap().1 == T::from(1) + && !witness.0.contains_key(&self.value.get(0).unwrap().0) } pub fn try_summand(self) -> Result<(Variable, T), Self> { - match self.0.len() { + match self.value.len() { // if the lincomb is empty, it is not reduceable to a summand 0 => Err(self), _ => { // take the first variable in the lincomb - let first = &self.0[0].0; + let first = &self.value[0].0; - if self.0.iter().all(|element| + if self.value.iter().all(|element| // all terms must contain the same variable element.0 == *first) { Ok(( *first, - self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1), + self.value.into_iter().fold(T::zero(), |acc, e| acc + e.1), )) } else { Err(self) @@ -156,32 +251,34 @@ impl LinComb { impl LinComb { pub fn into_canonical(self) -> CanonicalLinComb { - CanonicalLinComb( - self.0 - .into_iter() - .fold(BTreeMap::new(), |mut acc, (val, coeff)| { - // if we're adding 0 times some variable, we can ignore this term - if coeff != T::zero() { - match acc.entry(val) { - Entry::Occupied(o) => { - // if the new value is non zero, update, else remove the term entirely - if o.get().clone() + coeff.clone() != T::zero() { - *o.into_mut() = o.get().clone() + coeff; - } else { - o.remove(); - } - } - Entry::Vacant(v) => { - // We checked earlier but let's make sure we're not creating zero-coeff terms - assert!(coeff != T::zero()); - v.insert(coeff); + let span = self.get_span(); + + CanonicalLinComb::new(self.value.into_iter().fold( + BTreeMap::::new(), + |mut acc, (val, coeff)| { + // if we're adding 0 times some variable, we can ignore this term + if coeff != T::zero() { + match acc.entry(val) { + Entry::Occupied(o) => { + // if the new value is non zero, update, else remove the term entirely + if *o.get() + coeff != T::zero() { + *o.into_mut() = *o.get() + coeff; + } else { + o.remove(); } } + Entry::Vacant(v) => { + // We checked earlier but let's make sure we're not creating zero-coeff terms + assert!(coeff != T::zero()); + v.insert(coeff); + } } + } - acc - }), - ) + acc + }, + )) + .span(span) } pub fn reduce(self) -> Self { @@ -191,10 +288,8 @@ impl LinComb { impl QuadComb { pub fn into_canonical(self) -> CanonicalQuadComb { - CanonicalQuadComb { - left: self.left.into_canonical(), - right: self.right.into_canonical(), - } + CanonicalQuadComb::new(self.left.into_canonical(), self.right.into_canonical()) + .span(self.span) } pub fn reduce(self) -> Self { @@ -209,7 +304,7 @@ impl fmt::Display for LinComb { false => write!( f, "{}", - self.0 + self.value .iter() .map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k)) .collect::>() @@ -222,7 +317,7 @@ impl fmt::Display for LinComb { impl From for LinComb { fn from(v: Variable) -> LinComb { let r = vec![(v, T::one())]; - LinComb(r) + LinComb::new(r) } } @@ -230,9 +325,9 @@ impl Add> for LinComb { type Output = LinComb; fn add(self, other: LinComb) -> LinComb { - let mut res = self.0; - res.extend(other.0); - LinComb(res) + let mut res = self.value; + res.extend(other.value); + LinComb::new(res) } } @@ -241,9 +336,14 @@ impl Sub> for LinComb { fn sub(self, other: LinComb) -> LinComb { // Concatenate with second vector that have negative coeffs - let mut res = self.0; - res.extend(other.0.into_iter().map(|(var, val)| (var, T::zero() - val))); - LinComb(res) + let mut res = self.value; + res.extend( + other + .value + .into_iter() + .map(|(var, val)| (var, T::zero() - val)), + ); + LinComb::new(res) } } @@ -255,10 +355,10 @@ impl Mul<&T> for LinComb { return self; } - LinComb( - self.0 + LinComb::new( + self.value .into_iter() - .map(|(var, coeff)| (var, coeff * scalar.clone())) + .map(|(var, coeff)| (var, coeff * scalar)) .collect(), ) } @@ -299,7 +399,7 @@ mod tests { (Variable::new(42), Bn128Field::from(1)), ]; - assert_eq!(c, LinComb(expected_vec)); + assert_eq!(c, LinComb::new(expected_vec)); } #[test] fn sub() { @@ -312,7 +412,7 @@ mod tests { (Variable::new(42), Bn128Field::from(-1)), ]; - assert_eq!(c, LinComb(expected_vec)); + assert_eq!(c, LinComb::new(expected_vec)); } #[test] @@ -331,35 +431,26 @@ mod tests { fn from_linear() { let a: LinComb = LinComb::summand(3, Variable::new(42)) + LinComb::summand(4, Variable::new(33)); - let expected = QuadComb { - left: LinComb::one(), - right: a.clone(), - }; + let expected = QuadComb::new(LinComb::one(), a.clone()); assert_eq!(QuadComb::from(a), expected); } #[test] fn zero() { let a: LinComb = LinComb::zero(); - let expected: QuadComb = QuadComb { - left: LinComb::one(), - right: LinComb::zero(), - }; + let expected: QuadComb = QuadComb::new(LinComb::one(), LinComb::zero()); assert_eq!(QuadComb::from(a), expected); } #[test] fn display() { - let a: QuadComb = QuadComb { - left: LinComb::summand(3, Variable::new(42)) - + LinComb::summand(4, Variable::new(33)), - right: LinComb::summand(1, Variable::new(21)), - }; + let a: QuadComb = QuadComb::new( + LinComb::summand(3, Variable::new(42)) + LinComb::summand(4, Variable::new(33)), + LinComb::summand(1, Variable::new(21)), + ); assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)"); - let a: QuadComb = QuadComb { - left: LinComb::zero(), - right: LinComb::summand(1, Variable::new(21)), - }; + let a: QuadComb = + QuadComb::new(LinComb::zero(), LinComb::summand(1, Variable::new(21))); assert_eq!(&a.to_string(), "(0) * (1 * _21)"); } } @@ -369,7 +460,7 @@ mod tests { #[test] fn try_summand() { - let summand = LinComb(vec![ + let summand = LinComb::new(vec![ (Variable::new(42), Bn128Field::from(1)), (Variable::new(42), Bn128Field::from(2)), (Variable::new(42), Bn128Field::from(3)), @@ -379,14 +470,14 @@ mod tests { Ok((Variable::new(42), Bn128Field::from(6))) ); - let not_summand = LinComb(vec![ + let not_summand = LinComb::new(vec![ (Variable::new(41), Bn128Field::from(1)), (Variable::new(42), Bn128Field::from(2)), (Variable::new(42), Bn128Field::from(3)), ]); assert!(not_summand.try_summand().is_err()); - let empty: LinComb = LinComb(vec![]); + let empty: LinComb = LinComb::new(vec![]); assert!(empty.try_summand().is_err()); } } diff --git a/zokrates_ast/src/ir/folder.rs b/zokrates_ast/src/ir/folder.rs index 6e67c15de..3b17c722f 100644 --- a/zokrates_ast/src/ir/folder.rs +++ b/zokrates_ast/src/ir/folder.rs @@ -1,7 +1,7 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::{flat::Variable, WithSpan}; use zokrates_field::Field; pub trait Folder<'ast, T: Field>: Sized { @@ -21,6 +21,29 @@ pub trait Folder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: Statement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_constraint_statement(&mut self, s: ConstraintStatement) -> Vec> { + fold_constraint_statement(self, s) + } + + fn fold_directive_statement( + &mut self, + s: DirectiveStatement<'ast, T>, + ) -> Vec> { + fold_directive_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_block_statement(&mut self, s: BlockStatement<'ast, T>) -> Vec> { + fold_block_statement(self, s) + } + fn fold_linear_combination(&mut self, e: LinComb) -> LinComb { fold_linear_combination(self, e) } @@ -28,10 +51,6 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_quadratic_combination(&mut self, es: QuadComb) -> QuadComb { fold_quadratic_combination(self, es) } - - fn fold_directive(&mut self, d: Directive<'ast, T>) -> Directive<'ast, T> { - fold_directive(self, d) - } } pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( @@ -49,40 +68,73 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - return_count: p.return_count, + ..p } } -pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_constraint_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ConstraintStatement, +) -> Vec> { + vec![Statement::constraint( + f.fold_quadratic_combination(s.quad), + f.fold_linear_combination(s.lin), + s.error, + )] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement, +) -> Vec> { + vec![Statement::log( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| { + ( + t, + e.into_iter() + .map(|e| f.fold_linear_combination(e)) + .collect(), + ) + }) + .collect(), + )] +} + +pub fn fold_block_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: BlockStatement<'ast, T>, +) -> Vec> { + vec![Statement::block( + s.inner + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + )] +} + +fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: Statement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: Statement<'ast, T>, ) -> Vec> { match s { - Statement::Block(statements) => vec![Statement::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - )], - Statement::Constraint(quad, lin, message) => vec![Statement::Constraint( - f.fold_quadratic_combination(quad), - f.fold_linear_combination(lin), - message, - )], - Statement::Directive(dir) => vec![Statement::Directive(f.fold_directive(dir))], - Statement::Log(l, e) => vec![Statement::Log( - l, - e.into_iter() - .map(|(t, e)| { - ( - t, - e.into_iter() - .map(|e| f.fold_linear_combination(e)) - .collect(), - ) - }) - .collect(), - )], + Statement::Constraint(s) => f.fold_constraint_statement(s), + Statement::Directive(s) => f.fold_directive_statement(s), + Statement::Log(s) => f.fold_log_statement(s), + Statement::Block(s) => f.fold_block_statement(s), } } @@ -90,28 +142,31 @@ pub fn fold_linear_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: LinComb, ) -> LinComb { - LinComb( - e.0.into_iter() + LinComb::new( + e.value + .into_iter() .map(|(variable, coefficient)| (f.fold_variable(variable), coefficient)) .collect(), ) + .span(e.span) } pub fn fold_quadratic_combination<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: QuadComb, ) -> QuadComb { - QuadComb { - left: f.fold_linear_combination(e.left), - right: f.fold_linear_combination(e.right), - } + QuadComb::new( + f.fold_linear_combination(e.left), + f.fold_linear_combination(e.right), + ) + .span(e.span) } -pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_directive_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - ds: Directive<'ast, T>, -) -> Directive<'ast, T> { - Directive { + ds: DirectiveStatement<'ast, T>, +) -> Vec> { + vec![Statement::Directive(DirectiveStatement { inputs: ds .inputs .into_iter() @@ -119,13 +174,13 @@ pub fn fold_directive<'ast, T: Field, F: Folder<'ast, T>>( .collect(), outputs: ds.outputs.into_iter().map(|o| f.fold_variable(o)).collect(), ..ds - } + })] } pub fn fold_argument<'ast, T: Field, F: Folder<'ast, T>>(f: &mut F, a: Parameter) -> Parameter { Parameter { id: f.fold_variable(a.id), - private: a.private, + ..a } } diff --git a/zokrates_ast/src/ir/from_flat.rs b/zokrates_ast/src/ir/from_flat.rs index fc961cd85..f7f8c42f8 100644 --- a/zokrates_ast/src/ir/from_flat.rs +++ b/zokrates_ast/src/ir/from_flat.rs @@ -1,5 +1,8 @@ +use crate::common::statements::LogStatement; +use crate::common::WithSpan; use crate::flat::{FlatDirective, FlatExpression, FlatProgIterator, FlatStatement, Variable}; -use crate::ir::{Directive, LinComb, ProgIterator, QuadComb, Statement}; +use crate::ir::{DirectiveStatement, LinComb, ProgIterator, QuadComb, Statement}; +use std::ops::*; use zokrates_field::Field; impl QuadComb { @@ -8,9 +11,7 @@ impl QuadComb { match flat_expression.is_linear() { true => LinComb::from(flat_expression).into(), false => match flat_expression { - FlatExpression::Mult(box e1, box e2) => { - QuadComb::from_linear_combinations(e1.into(), e2.into()) - } + FlatExpression::Mult(e) => QuadComb::new((*e.left).into(), (*e.right).into()), e => unimplemented!("{}", e), }, } @@ -24,71 +25,76 @@ pub fn from_flat<'ast, T: Field, I: IntoIterator>> statements: flat_prog_iterator.statements.into_iter().map(Into::into), arguments: flat_prog_iterator.arguments, return_count: flat_prog_iterator.return_count, + module_map: flat_prog_iterator.module_map, + solvers: vec![], } } impl From> for LinComb { fn from(flat_expression: FlatExpression) -> LinComb { + let span = flat_expression.get_span(); + match flat_expression { - FlatExpression::Number(ref n) if *n == T::from(0) => LinComb::zero(), - FlatExpression::Number(n) => LinComb::summand(n, Variable::one()), - FlatExpression::Identifier(id) => LinComb::from(id), - FlatExpression::Add(box e1, box e2) => LinComb::from(e1) + LinComb::from(e2), - FlatExpression::Sub(box e1, box e2) => LinComb::from(e1) - LinComb::from(e2), - FlatExpression::Mult( - box FlatExpression::Number(n1), - box FlatExpression::Identifier(v1), - ) - | FlatExpression::Mult( - box FlatExpression::Identifier(v1), - box FlatExpression::Number(n1), - ) => LinComb::summand(n1, v1), - FlatExpression::Mult( - box FlatExpression::Number(n1), - box FlatExpression::Number(n2), - ) => LinComb::summand(n1 * n2, Variable::one()), - e => unreachable!("{}", e), + FlatExpression::Value(ref n) if n.value == T::from(0) => LinComb::zero(), + FlatExpression::Value(n) => LinComb::summand(n.value, Variable::one()), + FlatExpression::Identifier(id) => LinComb::from(id.id), + FlatExpression::Add(e) => LinComb::from(*e.left) + LinComb::from(*e.right), + FlatExpression::Sub(e) => LinComb::from(*e.left) - LinComb::from(*e.right), + FlatExpression::Mult(e) => match (*e.left, *e.right) { + (FlatExpression::Value(n1), FlatExpression::Identifier(v1)) + | (FlatExpression::Identifier(v1), FlatExpression::Value(n1)) => { + LinComb::summand(n1.value, v1.id) + } + (FlatExpression::Value(n1), FlatExpression::Value(n2)) => { + LinComb::summand(n1.value * n2.value, Variable::one()) + } + (left, right) => unreachable!("{}", FlatExpression::mul(left, right).span(e.span)), + }, } + .span(span) } } impl<'ast, T: Field> From> for Statement<'ast, T> { fn from(flat_statement: FlatStatement<'ast, T>) -> Statement<'ast, T> { + let span = flat_statement.get_span(); + match flat_statement { - FlatStatement::Block(statements) => { - Statement::Block(statements.into_iter().map(Statement::from).collect()) - } - FlatStatement::Condition(linear, quadratic, message) => match quadratic { - FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( - QuadComb::from_linear_combinations(lhs.into(), rhs.into()), - linear.into(), - Some(message), + FlatStatement::Condition(s) => match s.quad { + FlatExpression::Mult(e) => Statement::constraint( + QuadComb::new((*e.left).into(), (*e.right).into()).span(e.span), + LinComb::from(s.lin), + Some(s.error), ), - e => Statement::Constraint(LinComb::from(e).into(), linear.into(), Some(message)), + e => Statement::constraint(LinComb::from(e), s.lin, Some(s.error)), }, - FlatStatement::Definition(var, quadratic) => match quadratic { - FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint( - QuadComb::from_linear_combinations(lhs.into(), rhs.into()), - var.into(), + FlatStatement::Block(statements) => { + Statement::block(statements.inner.into_iter().map(Statement::from).collect()) + } + FlatStatement::Definition(s) => match s.rhs { + FlatExpression::Mult(e) => Statement::constraint( + QuadComb::new((*e.left).into(), (*e.right).into()).span(e.span), + s.assignee, None, ), - e => Statement::Constraint(LinComb::from(e).into(), var.into(), None), + e => Statement::constraint(LinComb::from(e), s.assignee, None), }, FlatStatement::Directive(ds) => Statement::Directive(ds.into()), - FlatStatement::Log(l, expressions) => Statement::Log( - l, - expressions + FlatStatement::Log(s) => Statement::Log(LogStatement::new( + s.format_string, + s.expressions .into_iter() .map(|(t, e)| (t, e.into_iter().map(LinComb::from).collect())) .collect(), - ), + )), } + .span(span) } } -impl<'ast, T: Field> From> for Directive<'ast, T> { - fn from(ds: FlatDirective<'ast, T>) -> Directive { - Directive { +impl<'ast, T: Field> From> for DirectiveStatement<'ast, T> { + fn from(ds: FlatDirective) -> DirectiveStatement { + DirectiveStatement { inputs: ds .inputs .into_iter() @@ -96,6 +102,7 @@ impl<'ast, T: Field> From> for Directive<'ast, T> { .collect(), solver: ds.solver, outputs: ds.outputs, + span: ds.span, } } } @@ -108,7 +115,7 @@ mod tests { #[test] fn zero() { // 0 - let zero = FlatExpression::Number(Bn128Field::from(0)); + let zero = FlatExpression::value(Bn128Field::from(0)); let expected: LinComb = LinComb::zero(); assert_eq!(LinComb::from(zero), expected); } @@ -116,7 +123,7 @@ mod tests { #[test] fn one() { // 1 - let one = FlatExpression::Number(Bn128Field::from(1)); + let one = FlatExpression::value(Bn128Field::from(1)); let expected: LinComb = Variable::one().into(); assert_eq!(LinComb::from(one), expected); } @@ -124,7 +131,7 @@ mod tests { #[test] fn forty_two() { // 42 - let one = FlatExpression::Number(Bn128Field::from(42)); + let one = FlatExpression::value(Bn128Field::from(42)); let expected: LinComb = LinComb::summand(42, Variable::one()); assert_eq!(LinComb::from(one), expected); } @@ -132,9 +139,9 @@ mod tests { #[test] fn add() { // x + y - let add = FlatExpression::Add( - box FlatExpression::Identifier(Variable::new(42)), - box FlatExpression::Identifier(Variable::new(21)), + let add = FlatExpression::add( + FlatExpression::identifier(Variable::new(42)), + FlatExpression::identifier(Variable::new(21)), ); let expected: LinComb = LinComb::summand(1, Variable::new(42)) + LinComb::summand(1, Variable::new(21)); @@ -144,14 +151,14 @@ mod tests { #[test] fn linear_combination() { // 42*x + 21*y - let add = FlatExpression::Add( - box FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(42)), - box FlatExpression::Identifier(Variable::new(42)), + let add = FlatExpression::add( + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(42)), + FlatExpression::identifier(Variable::new(42)), ), - box FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(21)), - box FlatExpression::Identifier(Variable::new(21)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(21)), + FlatExpression::identifier(Variable::new(21)), ), ); let expected: LinComb = @@ -162,14 +169,14 @@ mod tests { #[test] fn linear_combination_inverted() { // x*42 + y*21 - let add = FlatExpression::Add( - box FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(42)), - box FlatExpression::Number(Bn128Field::from(42)), + let add = FlatExpression::add( + FlatExpression::mul( + FlatExpression::identifier(Variable::new(42)), + FlatExpression::value(Bn128Field::from(42)), ), - box FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(21)), - box FlatExpression::Number(Bn128Field::from(21)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(21)), + FlatExpression::value(Bn128Field::from(21)), ), ); let expected: LinComb = diff --git a/zokrates_ast/src/ir/mod.rs b/zokrates_ast/src/ir/mod.rs index 78b48f808..e3aecebf1 100644 --- a/zokrates_ast/src/ir/mod.rs +++ b/zokrates_ast/src/ir/mod.rs @@ -1,10 +1,9 @@ -use crate::common::FormatString; +use crate::common::{FormatString, ModuleMap, Span, WithSpan}; use crate::typed::ConcreteType; use derivative::Derivative; use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt; -use std::hash::Hash; use zokrates_field::Field; mod check; @@ -14,100 +13,186 @@ pub mod folder; pub mod from_flat; mod serialize; pub mod smtlib2; +mod solver_indexer; pub mod visitor; mod witness; pub use self::expression::QuadComb; pub use self::expression::{CanonicalLinComb, LinComb}; -pub use self::serialize::ProgEnum; -pub use crate::common::Parameter; +pub use self::serialize::{ProgEnum, ProgHeader}; +pub use crate::common::flat::Parameter; +pub use crate::common::flat::Variable; pub use crate::common::RuntimeError; pub use crate::common::Solver; -pub use crate::common::Variable; pub use self::witness::Witness; +pub type LogStatement = crate::common::statements::LogStatement<(ConcreteType, Vec>)>; +pub type DirectiveStatement<'ast, T> = + crate::common::statements::DirectiveStatement, Variable, Solver<'ast, T>>; + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct ConstraintStatement { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + pub quad: QuadComb, + pub lin: LinComb, + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub error: Option, +} + +impl ConstraintStatement { + pub fn new(quad: QuadComb, lin: LinComb, error: Option) -> Self { + Self { + span: None, + quad, + lin, + error, + } + } +} + +impl WithSpan for ConstraintStatement { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl fmt::Display for ConstraintStatement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} == {}{}", + self.quad, + self.lin, + self.error + .as_ref() + .map(|e| format!(" // {}", e)) + .unwrap_or_else(|| "".to_string()) + ) + } +} + +#[derive(Derivative, Clone, Debug, Serialize, Deserialize)] +#[derivative(Hash, PartialEq, Eq)] +pub struct BlockStatement<'ast, T> { + #[derivative(Hash = "ignore", PartialEq = "ignore")] + pub span: Option, + #[serde(borrow)] + pub inner: Vec>, +} + +impl<'ast, T> BlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +impl<'ast, T> WithSpan for BlockStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: Field> fmt::Display for BlockStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{{")?; + for s in &self.inner { + writeln!(f, "{}", s)?; + } + write!(f, "}}") + } +} + +#[allow(clippy::large_enum_variant)] #[derive(Debug, Serialize, Deserialize, Clone, Derivative)] #[derivative(Hash, PartialEq, Eq)] pub enum Statement<'ast, T> { #[serde(skip)] - Block(Vec>), - Constraint( - QuadComb, - LinComb, - #[derivative(Hash = "ignore")] Option, - ), + Block(BlockStatement<'ast, T>), + Constraint(ConstraintStatement), #[serde(borrow)] - Directive(Directive<'ast, T>), - Log(FormatString, Vec<(ConcreteType, Vec>)>), + Directive(DirectiveStatement<'ast, T>), + Log(LogStatement), } pub type PublicInputs = BTreeSet; +impl<'ast, T> WithSpan for Statement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + Statement::Constraint(c) => Statement::Constraint(c.span(span)), + Statement::Directive(c) => Statement::Directive(c.span(span)), + Statement::Log(c) => Statement::Log(c.span(span)), + Statement::Block(c) => Statement::Block(c.span(span)), + } + } + + fn get_span(&self) -> Option { + match self { + Statement::Constraint(c) => c.get_span(), + Statement::Directive(c) => c.get_span(), + Statement::Log(c) => c.get_span(), + Statement::Block(c) => c.get_span(), + } + } +} + impl<'ast, T: Field> Statement<'ast, T> { pub fn definition>>(v: Variable, e: U) -> Self { - Statement::Constraint(e.into(), v.into(), None) + Statement::constraint(e, v, None) } - pub fn constraint>, V: Into>>(quad: U, lin: V) -> Self { - Statement::Constraint(quad.into(), lin.into(), None) + pub fn constraint>, V: Into>>( + quad: U, + lin: V, + error: Option, + ) -> Self { + Statement::Constraint(ConstraintStatement::new(quad.into(), lin.into(), error)) } -} -#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] -pub struct Directive<'ast, T> { - pub inputs: Vec>, - pub outputs: Vec, - #[serde(borrow)] - pub solver: Solver<'ast, T>, -} + pub fn log( + format_string: FormatString, + expressions: Vec<(ConcreteType, Vec>)>, + ) -> Self { + Statement::Log(LogStatement::new(format_string, expressions)) + } -impl<'ast, T: Field> fmt::Display for Directive<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "# {} = {}({})", - self.outputs - .iter() - .map(|o| format!("{}", o)) - .collect::>() - .join(", "), - self.solver, - self.inputs - .iter() - .map(|i| format!("{}", i)) - .collect::>() - .join(", ") - ) + pub fn block(inner: Vec>) -> Self { + Statement::Block(BlockStatement::new(inner)) + } + + pub fn directive( + outputs: Vec, + solver: Solver<'ast, T>, + inputs: Vec>, + ) -> Self { + Statement::Directive(DirectiveStatement::new(outputs, solver, inputs)) } } impl<'ast, T: Field> fmt::Display for Statement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Statement::Block(ref statements) => { - writeln!(f, "{{")?; - for s in statements { - writeln!(f, "{}", s)?; - } - write!(f, "}}") - } - Statement::Constraint(ref quad, ref lin, ref error) => write!( - f, - "{} == {}{}", - quad, - lin, - error - .as_ref() - .map(|e| format!(" // {}", e)) - .unwrap_or_else(|| "".to_string()) - ), + Statement::Constraint(ref s) => write!(f, "{}", s), + Statement::Block(ref s) => write!(f, "{}", s), Statement::Directive(ref s) => write!(f, "{}", s), - Statement::Log(ref s, ref expressions) => write!( + Statement::Log(ref s) => write!( f, "log(\"{}\", {})", - s, - expressions + s.format_string, + s.expressions .iter() .map(|(_, l)| format!( "[{}]", @@ -127,17 +212,28 @@ pub type Prog<'ast, T> = ProgIterator<'ast, T, Vec>>; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub struct ProgIterator<'ast, T, I: IntoIterator>> { + pub module_map: ModuleMap, pub arguments: Vec, pub return_count: usize, pub statements: I, + #[serde(borrow)] + pub solvers: Vec>, } impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, I> { - pub fn new(arguments: Vec, statements: I, return_count: usize) -> Self { + pub fn new( + arguments: Vec, + statements: I, + return_count: usize, + module_map: ModuleMap, + solvers: Vec>, + ) -> Self { Self { arguments, return_count, statements, + module_map, + solvers, } } @@ -146,6 +242,8 @@ impl<'ast, T, I: IntoIterator>> ProgIterator<'ast, T, statements: self.statements.into_iter().collect::>(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, + solvers: self.solvers, } } @@ -171,7 +269,7 @@ impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'a self.arguments .iter() .filter(|p| !p.private) - .map(|p| witness.0.get(&p.id).unwrap().clone()) + .map(|p| *witness.0.get(&p.id).unwrap()) .chain(witness.return_values()) .collect() } @@ -192,6 +290,8 @@ impl<'ast, T> Prog<'ast, T> { statements: self.statements.into_iter(), arguments: self.arguments, return_count: self.return_count, + module_map: self.module_map, + solvers: self.solvers, } } } @@ -233,12 +333,9 @@ mod tests { #[test] fn print_constraint() { - let c: Statement = Statement::Constraint( - QuadComb::from_linear_combinations( - Variable::new(42).into(), - Variable::new(42).into(), - ), - Variable::new(42).into(), + let c: Statement = Statement::constraint( + QuadComb::new(Variable::new(42).into(), Variable::new(42).into()), + Variable::new(42), None, ); assert_eq!(format!("{}", c), "(1 * _42) * (1 * _42) == 1 * _42") diff --git a/zokrates_ast/src/ir/serialize.rs b/zokrates_ast/src/ir/serialize.rs index 09d003900..02b2d420a 100644 --- a/zokrates_ast/src/ir/serialize.rs +++ b/zokrates_ast/src/ir/serialize.rs @@ -1,14 +1,17 @@ -use crate::ir::check::UnconstrainedVariableDetector; +use crate::ir::{check::UnconstrainedVariableDetector, solver_indexer::SolverIndexer}; use super::{ProgIterator, Statement}; +use crate::ir::ModuleMap; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use serde::Deserialize; use serde_cbor::{self, StreamDeserializer}; -use std::io::{Read, Write}; +use std::io::{Read, Seek, Write}; use zokrates_field::*; type DynamicError = Box; const ZOKRATES_MAGIC: &[u8; 4] = &[0x5a, 0x4f, 0x4b, 0]; -const ZOKRATES_VERSION_2: &[u8; 4] = &[0, 0, 0, 2]; +const FILE_VERSION: &[u8; 4] = &[3, 0, 0, 0]; #[derive(PartialEq, Eq, Debug)] pub enum ProgEnum< @@ -58,33 +61,203 @@ impl< } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(u32)] +pub enum SectionType { + Parameters = 1, + Constraints = 2, + Solvers = 3, + Modules = 4, +} + +impl TryFrom for SectionType { + type Error = String; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(SectionType::Parameters), + 2 => Ok(SectionType::Constraints), + 3 => Ok(SectionType::Solvers), + 4 => Ok(SectionType::Modules), + _ => Err("invalid section type".to_string()), + } + } +} + +#[derive(Debug, Clone)] +pub struct Section { + pub ty: SectionType, + pub offset: u64, + pub length: u64, +} + +impl Section { + pub fn new(ty: SectionType) -> Self { + Self { + ty, + offset: 0, + length: 0, + } + } + + pub fn set_offset(&mut self, offset: u64) { + self.offset = offset; + } + + pub fn set_length(&mut self, length: u64) { + self.length = length; + } +} + +#[derive(Debug, Clone)] +pub struct ProgHeader { + pub magic: [u8; 4], + pub version: [u8; 4], + pub curve_id: [u8; 4], + pub constraint_count: u32, + pub return_count: u32, + pub sections: [Section; 4], +} + +impl ProgHeader { + pub fn write(&self, mut w: W) -> std::io::Result<()> { + w.write_all(&self.magic)?; + w.write_all(&self.version)?; + w.write_all(&self.curve_id)?; + w.write_u32::(self.constraint_count)?; + w.write_u32::(self.return_count)?; + + for s in &self.sections { + w.write_u32::(s.ty as u32)?; + w.write_u64::(s.offset)?; + w.write_u64::(s.length)?; + } + + Ok(()) + } + + pub fn read(mut r: R) -> std::io::Result { + let mut magic = [0; 4]; + r.read_exact(&mut magic)?; + + let mut version = [0; 4]; + r.read_exact(&mut version)?; + + let mut curve_id = [0; 4]; + r.read_exact(&mut curve_id)?; + + let constraint_count = r.read_u32::()?; + let return_count = r.read_u32::()?; + + let parameters = Self::read_section(r.by_ref())?; + let constraints = Self::read_section(r.by_ref())?; + let solvers = Self::read_section(r.by_ref())?; + let module_map = Self::read_section(r.by_ref())?; + + Ok(ProgHeader { + magic, + version, + curve_id, + constraint_count, + return_count, + sections: [parameters, constraints, solvers, module_map], + }) + } + + fn read_section(mut r: R) -> std::io::Result
{ + let id = r.read_u32::()?; + let mut section = Section::new( + SectionType::try_from(id) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, + ); + section.set_offset(r.read_u64::()?); + section.set_length(r.read_u64::()?); + Ok(section) + } +} + impl<'ast, T: Field, I: IntoIterator>> ProgIterator<'ast, T, I> { /// serialize a program iterator, returning the number of constraints serialized /// Note that we only return constraints, not other statements such as directives - pub fn serialize(self, mut w: W) -> Result { + pub fn serialize(self, mut w: W) -> Result { use super::folder::Folder; - w.write_all(ZOKRATES_MAGIC)?; - w.write_all(ZOKRATES_VERSION_2)?; - w.write_all(&T::id())?; + // reserve bytes for the header + w.write_all(&[0u8; std::mem::size_of::()])?; + + // write parameters section + let parameters = { + let mut section = Section::new(SectionType::Parameters); + section.set_offset(w.stream_position()?); - serde_cbor::to_writer(&mut w, &self.arguments)?; - serde_cbor::to_writer(&mut w, &self.return_count)?; + serde_cbor::to_writer(&mut w, &self.arguments)?; + section.set_length(w.stream_position()? - section.offset); + section + }; + + let mut solver_indexer: SolverIndexer<'ast, T> = SolverIndexer::default(); let mut unconstrained_variable_detector = UnconstrainedVariableDetector::new(&self); + let mut count: usize = 0; - let statements = self.statements.into_iter(); + // write constraints section + let constraints = { + let mut section = Section::new(SectionType::Constraints); + section.set_offset(w.stream_position()?); - let mut count = 0; - for s in statements { - if matches!(s, Statement::Constraint(..)) { - count += 1; - } - let s = unconstrained_variable_detector.fold_statement(s); - for s in s { - serde_cbor::to_writer(&mut w, &s)?; + let statements = self.statements.into_iter(); + for s in statements { + if matches!(s, Statement::Constraint(..)) { + count += 1; + } + let s: Vec> = solver_indexer + .fold_statement(s) + .into_iter() + .flat_map(|s| unconstrained_variable_detector.fold_statement(s)) + .collect(); + for s in s { + serde_cbor::to_writer(&mut w, &s)?; + } } - } + + section.set_length(w.stream_position()? - section.offset); + section + }; + + // write solvers section + let solvers = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &solver_indexer.solvers)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + // write module map section + let module_map = { + let mut section = Section::new(SectionType::Solvers); + section.set_offset(w.stream_position()?); + + serde_cbor::to_writer(&mut w, &self.module_map)?; + + section.set_length(w.stream_position()? - section.offset); + section + }; + + let header = ProgHeader { + magic: *ZOKRATES_MAGIC, + version: *FILE_VERSION, + curve_id: T::id(), + constraint_count: count as u32, + return_count: self.return_count as u32, + sections: [parameters, constraints, solvers, module_map], + }; + + // rewind to write the header + w.rewind()?; + header.write(&mut w)?; unconstrained_variable_detector .finalize() @@ -103,11 +276,11 @@ impl<'de, R: serde_cbor::de::Read<'de>, T: serde::Deserialize<'de>> Iterator type Item = T; fn next(&mut self) -> Option { - self.s.next().transpose().unwrap() + self.s.next().and_then(|v| v.ok()) } } -impl<'de, R: Read> +impl<'de, R: Read + Seek> ProgEnum< 'de, UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bls12_381Field>>, @@ -116,125 +289,86 @@ impl<'de, R: Read> UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, Bw6_761Field>>, > { + fn read( + mut r: R, + header: &ProgHeader, + ) -> ProgIterator< + 'de, + T, + UnwrappedStreamDeserializer<'de, serde_cbor::de::IoRead, Statement<'de, T>>, + > { + let parameters = { + let section = &header.sections[0]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read parameters")) + .unwrap() + }; + + let solvers = { + let section = &header.sections[2]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + Vec::deserialize(&mut p) + .map_err(|_| String::from("Cannot read solvers")) + .unwrap() + }; + + let module_map = { + let section = &header.sections[3]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let mut p = serde_cbor::Deserializer::from_reader(r.by_ref()); + ModuleMap::deserialize(&mut p) + .map_err(|_| String::from("Cannot read module map")) + .unwrap() + }; + + let statements_deserializer = { + let section = &header.sections[1]; + r.seek(std::io::SeekFrom::Start(section.offset)).unwrap(); + + let p = serde_cbor::Deserializer::from_reader(r); + let s = p.into_iter::>(); + + UnwrappedStreamDeserializer { s } + }; + + ProgIterator::new( + parameters, + statements_deserializer, + header.return_count as usize, + module_map, + solvers, + ) + } + pub fn deserialize(mut r: R) -> Result { + let header = ProgHeader::read(&mut r).map_err(|_| String::from("Invalid header"))?; + // Check the magic number, `ZOK` - let mut magic = [0; 4]; - r.read_exact(&mut magic) - .map_err(|_| String::from("Cannot read magic number"))?; - - if &magic == ZOKRATES_MAGIC { - // Check the version, 2 - let mut version = [0; 4]; - r.read_exact(&mut version) - .map_err(|_| String::from("Cannot read version"))?; - - if &version == ZOKRATES_VERSION_2 { - // Check the curve identifier, deserializing accordingly - let mut curve = [0; 4]; - r.read_exact(&mut curve) - .map_err(|_| String::from("Cannot read curve identifier"))?; - - use serde::de::Deserializer; - let mut p = serde_cbor::Deserializer::from_reader(r); - - struct ArgumentsVisitor; - - impl<'de> serde::de::Visitor<'de> for ArgumentsVisitor { - type Value = Vec; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("seq of flat param") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = vec![]; - while let Some(e) = seq.next_element().unwrap() { - res.push(e); - } - Ok(res) - } - } + if &header.magic != ZOKRATES_MAGIC { + return Err("Invalid magic number".to_string()); + } - let arguments = p.deserialize_seq(ArgumentsVisitor).unwrap(); - - struct ReturnCountVisitor; - - impl<'de> serde::de::Visitor<'de> for ReturnCountVisitor { - type Value = usize; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("usize") - } - - fn visit_u32(self, v: u32) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u8(self, v: u8) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - - fn visit_u16(self, v: u16) -> Result - where - E: serde::de::Error, - { - Ok(v as usize) - } - } + // Check the file version + if &header.version != FILE_VERSION { + return Err("Invalid file version".to_string()); + } - let return_count = p.deserialize_u32(ReturnCountVisitor).unwrap(); - - match curve { - m if m == Bls12_381Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_381Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bn128Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bn128Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bls12_377Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bls12_377Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - m if m == Bw6_761Field::id() => { - let s = p.into_iter::>(); - - Ok(ProgEnum::Bw6_761Program(ProgIterator::new( - arguments, - UnwrappedStreamDeserializer { s }, - return_count, - ))) - } - _ => Err(String::from("Unknown curve identifier")), - } - } else { - Err(String::from("Unknown version")) + match header.curve_id { + m if m == Bls12_381Field::id() => { + Ok(ProgEnum::Bls12_381Program(Self::read(r, &header))) + } + m if m == Bn128Field::id() => Ok(ProgEnum::Bn128Program(Self::read(r, &header))), + m if m == Bls12_377Field::id() => { + Ok(ProgEnum::Bls12_377Program(Self::read(r, &header))) } - } else { - Err(String::from("Wrong magic number")) + m if m == Bw6_761Field::id() => Ok(ProgEnum::Bw6_761Program(Self::read(r, &header))), + _ => Err(String::from("Unknown curve identifier")), } } } diff --git a/zokrates_ast/src/ir/smtlib2.rs b/zokrates_ast/src/ir/smtlib2.rs index bc1188518..4d43d0d28 100644 --- a/zokrates_ast/src/ir/smtlib2.rs +++ b/zokrates_ast/src/ir/smtlib2.rs @@ -79,11 +79,11 @@ impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Block(..) => unreachable!(), - Statement::Constraint(ref quad, ref lin, _) => { + Statement::Constraint(ref s) => { write!(f, "(= (mod ")?; - quad.to_smtlib2(f)?; + s.quad.to_smtlib2(f)?; write!(f, " |~prime|) (mod ")?; - lin.to_smtlib2(f)?; + s.lin.to_smtlib2(f)?; write!(f, " |~prime|))") } Statement::Directive(ref s) => s.to_smtlib2(f), @@ -92,7 +92,7 @@ impl<'ast, T: Field> SMTLib2 for Statement<'ast, T> { } } -impl<'ast, T: Field> SMTLib2 for Directive<'ast, T> { +impl<'ast, T: Field> SMTLib2 for DirectiveStatement<'ast, T> { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "") } @@ -109,15 +109,20 @@ impl SMTLib2 for LinComb { match self.is_zero() { true => write!(f, "0"), false => { - if self.0.len() > 1 { + if self.value.len() > 1 { write!(f, "(+")?; - for expr in self.0.iter() { + for expr in self.value.iter() { write!(f, " ")?; format_prefix_op_smtlib2(f, "*", &expr.0, &expr.1.to_biguint())?; } write!(f, ")") } else { - format_prefix_op_smtlib2(f, "*", &self.0[0].0, &self.0[0].1.to_biguint()) + format_prefix_op_smtlib2( + f, + "*", + &self.value[0].0, + &self.value[0].1.to_biguint(), + ) } } } diff --git a/zokrates_ast/src/ir/solver_indexer.rs b/zokrates_ast/src/ir/solver_indexer.rs new file mode 100644 index 000000000..334da750b --- /dev/null +++ b/zokrates_ast/src/ir/solver_indexer.rs @@ -0,0 +1,60 @@ +use crate::common::RefCall; +use crate::ir::folder::Folder; +use crate::ir::Solver; +use crate::zir::ZirFunction; +use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use zokrates_field::Field; + +use super::DirectiveStatement; +use super::Statement; + +type Hash = u64; + +fn hash(f: &ZirFunction) -> Hash { + use std::hash::Hash; + use std::hash::Hasher; + let mut hasher = DefaultHasher::new(); + f.hash(&mut hasher); + hasher.finish() +} + +#[derive(Debug, Default)] +pub struct SolverIndexer<'ast, T> { + pub solvers: Vec>, + pub index_map: HashMap, +} + +impl<'ast, T: Field> Folder<'ast, T> for SolverIndexer<'ast, T> { + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { + let res = match d.solver { + Solver::Zir(f) => { + let argument_count = f.arguments.len(); + let h = hash(&f); + let index = match self.index_map.entry(h) { + Entry::Occupied(v) => *v.get(), + Entry::Vacant(entry) => { + let index = self.solvers.len(); + entry.insert(index); + self.solvers.push(Solver::Zir(f)); + index + } + }; + DirectiveStatement::new( + d.outputs, + Solver::Ref(RefCall { + index, + argument_count, + }), + d.inputs, + ) + } + _ => d, + }; + vec![Statement::Directive(res)] + } +} diff --git a/zokrates_ast/src/ir/visitor.rs b/zokrates_ast/src/ir/visitor.rs index d3894ca6b..175c48bfe 100644 --- a/zokrates_ast/src/ir/visitor.rs +++ b/zokrates_ast/src/ir/visitor.rs @@ -1,7 +1,7 @@ // Generic walk through an IR AST. Not mutating in place use super::*; -use crate::common::Variable; +use crate::common::flat::Variable; use zokrates_field::Field; pub trait Visitor: Sized { @@ -33,8 +33,8 @@ pub trait Visitor: Sized { visit_quadratic_combination(self, es) } - fn visit_directive(&mut self, d: &Directive) { - visit_directive(self, d) + fn visit_directive_statement(&mut self, d: &DirectiveStatement) { + visit_directive_statement(self, d) } fn visit_runtime_error(&mut self, e: &RuntimeError) { @@ -53,21 +53,21 @@ pub fn visit_module>(f: &mut F, p: &Prog) { pub fn visit_statement>(f: &mut F, s: &Statement) { match s { - Statement::Block(statements) => { - for s in statements { + Statement::Block(s) => { + for s in &s.inner { f.visit_statement(s); } } - Statement::Constraint(quad, lin, error) => { - f.visit_quadratic_combination(quad); - f.visit_linear_combination(lin); - if let Some(error) = error.as_ref() { + Statement::Constraint(s) => { + f.visit_quadratic_combination(&s.quad); + f.visit_linear_combination(&s.lin); + if let Some(error) = s.error.as_ref() { f.visit_runtime_error(error); } } - Statement::Directive(dir) => f.visit_directive(dir), - Statement::Log(_, expressions) => { - for (_, e) in expressions { + Statement::Directive(dir) => f.visit_directive_statement(dir), + Statement::Log(s) => { + for (_, e) in &s.expressions { for e in e { f.visit_linear_combination(e); } @@ -77,7 +77,7 @@ pub fn visit_statement>(f: &mut F, s: &Statement) { } pub fn visit_linear_combination>(f: &mut F, e: &LinComb) { - for expr in e.0.iter() { + for expr in e.value.iter() { f.visit_variable(&expr.0); f.visit_value(&expr.1); } @@ -88,7 +88,7 @@ pub fn visit_quadratic_combination>(f: &mut F, e: &QuadC f.visit_linear_combination(&e.right); } -pub fn visit_directive>(f: &mut F, ds: &Directive) { +pub fn visit_directive_statement>(f: &mut F, ds: &DirectiveStatement) { for expr in ds.inputs.iter() { f.visit_quadratic_combination(expr); } diff --git a/zokrates_ast/src/ir/witness.rs b/zokrates_ast/src/ir/witness.rs index 366c62d8d..b8117279f 100644 --- a/zokrates_ast/src/ir/witness.rs +++ b/zokrates_ast/src/ir/witness.rs @@ -1,4 +1,4 @@ -use crate::common::Variable; +use crate::common::flat::Variable; use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::io; @@ -41,54 +41,44 @@ impl Witness { Witness(BTreeMap::new()) } - pub fn write(&self, writer: W) -> io::Result<()> { - let mut wtr = csv::WriterBuilder::new() - .delimiter(b' ') - .flexible(true) - .has_headers(false) - .from_writer(writer); + pub fn write(&self, mut writer: W) -> io::Result<()> { + let length = self.0.len(); + writer.write_all(&length.to_le_bytes())?; - // Write each line of the witness to the file for (variable, value) in &self.0 { - wtr.serialize((variable.to_string(), value.to_dec_string()))?; + variable.write(&mut writer)?; + value.write(&mut writer)?; } - Ok(()) } pub fn read(mut reader: R) -> io::Result { - let mut rdr = csv::ReaderBuilder::new() - .delimiter(b' ') - .flexible(true) - .has_headers(false) - .from_reader(&mut reader); - - let map = rdr - .deserialize::<(String, String)>() - .map(|r| { - r.map(|(variable, value)| { - let variable = Variable::try_from_human_readable(&variable).map_err(|why| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid variable in witness: {}", why), - ) - })?; - let value = T::try_from_dec_str(&value).map_err(|_| { - io::Error::new( - io::ErrorKind::Other, - format!("Invalid value in witness: {}", value), - ) - })?; - Ok((variable, value)) - }) - .map_err(|e| match e.into_kind() { - csv::ErrorKind::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, format!("{:?}", e)), - })? - }) - .collect::>>()?; + let mut witness = Self::empty(); + + let mut buf = [0; std::mem::size_of::()]; + reader.read_exact(&mut buf)?; + + let length: usize = usize::from_le_bytes(buf); + + for _ in 0..length { + let var = Variable::read(&mut reader)?; + let val = T::read(&mut reader)?; - Ok(Witness(map)) + witness.insert(var, val); + } + + Ok(witness) + } + + pub fn write_json(&self, writer: W) -> io::Result<()> { + let map = self + .0 + .iter() + .map(|(k, v)| (k.to_string(), serde_json::json!(v.to_dec_string()))) + .collect::>(); + + serde_json::to_writer_pretty(writer, &map)?; + Ok(()) } } @@ -138,32 +128,29 @@ mod tests { } #[test] - fn wrong_value() { - let mut buff = Cursor::new(vec![]); - - buff.write_all("_1 123bug".as_ref()).unwrap(); - buff.set_position(0); - - assert!(Witness::::read(buff).is_err()); - } - - #[test] - fn wrong_variable() { - let mut buff = Cursor::new(vec![]); - - buff.write_all("_1bug 123".as_ref()).unwrap(); - buff.set_position(0); - - assert!(Witness::::read(buff).is_err()); - } - - #[test] - fn not_csv() { - let mut buff = Cursor::new(vec![]); - buff.write_all("whatwhat".as_ref()).unwrap(); - buff.set_position(0); + fn serialize_json() { + let w = Witness( + vec![ + (Variable::new(42), Bn128Field::from(42)), + (Variable::public(8), Bn128Field::from(8)), + (Variable::one(), Bn128Field::from(1)), + ] + .into_iter() + .collect(), + ); - assert!(Witness::::read(buff).is_err()); + let mut buf = Cursor::new(vec![]); + w.write_json(&mut buf).unwrap(); + + let output = String::from_utf8(buf.into_inner()).unwrap(); + assert_eq!( + output.as_str(), + r#"{ + "~out_8": "8", + "~one": "1", + "_42": "42" +}"# + ) } } } diff --git a/zokrates_ast/src/lib.rs b/zokrates_ast/src/lib.rs index 797b85619..084173f84 100644 --- a/zokrates_ast/src/lib.rs +++ b/zokrates_ast/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub mod common; pub mod flat; pub mod ir; diff --git a/zokrates_ast/src/typed/abi.rs b/zokrates_ast/src/typed/abi.rs index 253d27e28..88a0da9e4 100644 --- a/zokrates_ast/src/typed/abi.rs +++ b/zokrates_ast/src/typed/abi.rs @@ -22,7 +22,7 @@ impl Abi { ConcreteSignature { generics: vec![], inputs: self.inputs.iter().map(|i| i.ty.clone()).collect(), - output: box self.output.clone(), + output: Box::new(self.output.clone()), } } } @@ -49,14 +49,14 @@ mod tests { ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![ - DeclarationParameter { - id: DeclarationVariable::new("a", DeclarationType::FieldElement, true), - private: true, - }, - DeclarationParameter { - id: DeclarationVariable::new("b", DeclarationType::Boolean, false), - private: false, - }, + DeclarationParameter::private(DeclarationVariable::new( + "a", + DeclarationType::FieldElement, + )), + DeclarationParameter::public(DeclarationVariable::new( + "b", + DeclarationType::Boolean, + )), ], statements: vec![], signature: ConcreteSignature::new() @@ -73,6 +73,7 @@ mod tests { let typed_ast: TypedProgram = TypedProgram { main: "main".into(), modules, + module_map: Default::default(), }; let abi: Abi = typed_ast.abi(); diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 722dcbf87..424109bf3 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -1,47 +1,47 @@ +use crate::common::expressions::{ + BooleanValueExpression, EqExpression, FieldValueExpression, IntValueExpression, + UnaryOrExpression, ValueOrExpression, +}; // Generic walk through a typed AST. Not mutating in place - -use crate::typed::types::*; +use crate::common::{expressions::BinaryOrExpression, Fold}; use crate::typed::*; use zokrates_field::Field; use super::identifier::FrameIdentifier; +use super::types::{DeclarationStructMember, DeclarationTupleType, StructMember}; -pub trait Fold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Self; -} - -impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FieldElementExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_field_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for BooleanExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_uint_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for StructExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for StructExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_struct_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for ArrayExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for ArrayExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_array_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for TupleExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for TupleExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_tuple_expression(self) } } @@ -142,22 +142,16 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { - Variable { - id: self.fold_name(v.id), - _type: self.fold_type(v._type), - is_mutable: v.is_mutable, - } + let span = v.get_span(); + Variable::new(self.fold_name(v.id), self.fold_type(v.ty)).span(span) } fn fold_declaration_variable( &mut self, v: DeclarationVariable<'ast, T>, ) -> DeclarationVariable<'ast, T> { - DeclarationVariable { - id: self.fold_name(v.id), - _type: self.fold_declaration_type(v._type), - is_mutable: v.is_mutable, - } + let span = v.get_span(); + DeclarationVariable::new(self.fold_name(v.id), self.fold_declaration_type(v.ty)).span(span) } fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> { @@ -173,8 +167,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_array_type(&mut self, t: ArrayType<'ast, T>) -> ArrayType<'ast, T> { ArrayType { - ty: box self.fold_type(*t.ty), - size: box self.fold_uint_expression(*t.size), + ty: Box::new(self.fold_type(*t.ty)), + size: Box::new(self.fold_uint_expression(*t.size)), } } @@ -195,7 +189,7 @@ pub trait Folder<'ast, T: Field>: Sized { .members .into_iter() .map(|m| StructMember { - ty: box self.fold_type(*m.ty), + ty: Box::new(self.fold_type(*m.ty)), ..m }) .collect(), @@ -219,8 +213,8 @@ pub trait Folder<'ast, T: Field>: Sized { t: DeclarationArrayType<'ast, T>, ) -> DeclarationArrayType<'ast, T> { DeclarationArrayType { - ty: box self.fold_declaration_type(*t.ty), - size: box self.fold_declaration_constant(*t.size), + ty: Box::new(self.fold_declaration_type(*t.ty)), + size: Box::new(self.fold_declaration_constant(*t.size)), } } @@ -251,7 +245,7 @@ pub trait Folder<'ast, T: Field>: Sized { .members .into_iter() .map(|m| DeclarationStructMember { - ty: box self.fold_declaration_type(*m.ty), + ty: Box::new(self.fold_declaration_type(*m.ty)), ..m }) .collect(), @@ -263,6 +257,27 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Vec> { + fold_assembly_block(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Vec> { + fold_assembly_assignment(self, s) + } + + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Vec> { + fold_assembly_constraint(self, s) + } + fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -270,10 +285,50 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + fold_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Vec> { + fold_return_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Vec> { + fold_assertion_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_for_statement(&mut self, s: ForStatement<'ast, T>) -> Vec> { + fold_for_statement(self, s) + } + fn fold_definition_rhs(&mut self, rhs: DefinitionRhs<'ast, T>) -> DefinitionRhs<'ast, T> { fold_definition_rhs(self, rhs) } @@ -312,7 +367,7 @@ pub trait Folder<'ast, T: Field>: Sized { } } - fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn fold_module_id(&mut self, i: OwnedModuleId) -> OwnedModuleId { i } @@ -320,7 +375,7 @@ pub trait Folder<'ast, T: Field>: Sized { fold_expression(self, e) } - fn fold_block_expression>( + fn fold_block_expression>( &mut self, block: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { @@ -329,7 +384,7 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_conditional_expression< E: Expr<'ast, T> - + Fold<'ast, T> + + Fold + Block<'ast, T> + Conditional<'ast, T> + From>, @@ -341,6 +396,35 @@ pub trait Folder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } + fn fold_binary_expression< + L: Expr<'ast, T> + Fold, + R: Expr<'ast, T> + Fold, + E: Expr<'ast, T> + Fold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> BinaryOrExpression { + fold_binary_expression(self, ty, e) + } + + fn fold_eq_expression + Fold>( + &mut self, + e: EqExpression>, + ) -> BinaryOrExpression, BooleanExpression<'ast, T>> + { + fold_binary_expression(self, &Type::Boolean, e) + } + + fn fold_unary_expression + Fold, E: Expr<'ast, T> + Fold, Op>( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> UnaryOrExpression { + fold_unary_expression(self, ty, e) + } + fn fold_member_expression< E: Expr<'ast, T> + Member<'ast, T> + From>, >( @@ -351,6 +435,17 @@ pub trait Folder<'ast, T: Field>: Sized { fold_member_expression(self, ty, e) } + fn fold_slice_expression(&mut self, e: SliceExpression<'ast, T>) -> SliceOrExpression<'ast, T> { + fold_slice_expression(self, e) + } + + fn fold_repeat_expression( + &mut self, + e: RepeatExpression<'ast, T>, + ) -> RepeatOrExpression<'ast, T> { + fold_repeat_expression(self, e) + } + fn fold_identifier_expression< E: Expr<'ast, T> + Id<'ast, T> + From>, >( @@ -371,13 +466,6 @@ pub trait Folder<'ast, T: Field>: Sized { fold_element_expression(self, ty, e) } - fn fold_eq_expression + PartialEq + Constant + Fold<'ast, T>>( - &mut self, - e: EqExpression, - ) -> EqOrBoolean<'ast, T, E> { - fold_eq_expression(self, e) - } - fn fold_function_call_expression< E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, >( @@ -417,6 +505,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_int_expression(self, e) } + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + fold_field_expression_cases(self, e) + } + fn fold_field_expression( &mut self, e: FieldElementExpression<'ast, T>, @@ -424,6 +519,13 @@ pub trait Folder<'ast, T: Field>: Sized { fold_field_expression(self, e) } + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + fold_boolean_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, @@ -435,6 +537,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_uint_expression(self, e) } + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + fold_uint_expression_cases(self, bitwidth, e) + } + fn fold_uint_expression_inner( &mut self, bitwidth: UBitwidth, @@ -443,6 +553,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } + fn fold_array_expression_cases( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + fold_array_expression_cases(self, ty, e) + } + fn fold_array_expression_inner( &mut self, ty: &ArrayType<'ast, T>, @@ -451,6 +569,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_array_expression_inner(self, ty, e) } + fn fold_tuple_expression_cases( + &mut self, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, + ) -> TupleExpressionInner<'ast, T> { + fold_tuple_expression_cases(self, ty, e) + } + fn fold_tuple_expression_inner( &mut self, ty: &TupleType<'ast, T>, @@ -459,6 +585,14 @@ pub trait Folder<'ast, T: Field>: Sized { fold_tuple_expression_inner(self, ty, e) } + fn fold_struct_expression_cases( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + fold_struct_expression_cases(self, ty, e) + } + fn fold_struct_expression_inner( &mut self, ty: &StructType<'ast, T>, @@ -466,6 +600,51 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> StructExpressionInner<'ast, T> { fold_struct_expression_inner(self, ty, e) } + + fn fold_field_value_expression( + &mut self, + v: FieldValueExpression, + ) -> ValueOrExpression, FieldElementExpression<'ast, T>> { + fold_field_value_expression(self, v) + } + + fn fold_boolean_value_expression( + &mut self, + v: BooleanValueExpression, + ) -> ValueOrExpression> { + fold_boolean_value_expression(self, v) + } + + fn fold_integer_value_expression( + &mut self, + v: IntValueExpression, + ) -> ValueOrExpression> { + fold_integer_value_expression(self, v) + } + + fn fold_struct_value_expression( + &mut self, + ty: &StructType<'ast, T>, + v: StructValueExpression<'ast, T>, + ) -> ValueOrExpression, StructExpressionInner<'ast, T>> { + fold_struct_value_expression(self, ty, v) + } + + fn fold_array_value_expression( + &mut self, + ty: &ArrayType<'ast, T>, + v: ArrayValueExpression<'ast, T>, + ) -> ValueOrExpression, ArrayExpressionInner<'ast, T>> { + fold_array_value_expression(self, ty, v) + } + + fn fold_tuple_value_expression( + &mut self, + ty: &TupleType<'ast, T>, + v: TupleValueExpression<'ast, T>, + ) -> ValueOrExpression, TupleExpressionInner<'ast, T>> { + fold_tuple_value_expression(self, ty, v) + } } pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( @@ -525,58 +704,136 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_return_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Return(ReturnStatement::new( + f.fold_expression(s.inner), + ))] +} + +pub fn fold_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Vec> { + let rhs = f.fold_definition_rhs(s.rhs); + vec![TypedStatement::Definition( + DefinitionStatement::new(f.fold_assignee(s.assignee), rhs).span(s.span), + )] +} + +pub fn fold_assertion_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression), s.error).span(s.span), + )] +} + +pub fn fold_for_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ForStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::For(ForStatement::new( + f.fold_variable(s.var), + f.fold_uint_expression(s.from), + f.fold_uint_expression(s.to), + s.statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + ))] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect(), + ))] +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Vec> { + match s { + TypedStatement::Return(s) => f.fold_return_statement(s), + TypedStatement::Definition(s) => f.fold_definition_statement(s), + TypedStatement::Assertion(s) => f.fold_assertion_statement(s), + TypedStatement::For(s) => f.fold_for_statement(s), + TypedStatement::Log(s) => f.fold_log_statement(s), + TypedStatement::Assembly(s) => f.fold_assembly_block(s), + } +} + +pub fn fold_assembly_block<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Vec> { + vec![TypedStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ))] +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Vec> { + let assignee = f.fold_assignee(s.assignee); + let expression = f.fold_expression(s.expression); + vec![TypedAssemblyStatement::assignment(assignee, expression)] +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Vec> { + let left = f.fold_field_expression(s.left); + let right = f.fold_field_expression(s.right); + vec![TypedAssemblyStatement::constraint(left, right, s.metadata)] +} + +fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedAssemblyStatement<'ast, T>, ) -> Vec> { match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = f.fold_expression(e); - vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - f.fold_field_expression(lhs), - f.fold_field_expression(rhs), - metadata, - )] - } + TypedAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + TypedAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), } } -pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, ) -> Vec> { - let res = match s { - TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)), - TypedStatement::Definition(a, rhs) => { - let rhs = f.fold_definition_rhs(rhs); - TypedStatement::Definition(f.fold_assignee(a), rhs) - } - TypedStatement::Assertion(e, error) => { - TypedStatement::Assertion(f.fold_boolean_expression(e), error) - } - TypedStatement::For(v, from, to, statements) => TypedStatement::For( - f.fold_variable(v), - f.fold_uint_expression(from), - f.fold_uint_expression(to), - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - ), - TypedStatement::Log(s, e) => { - TypedStatement::Log(s, e.into_iter().map(|e| f.fold_expression(e)).collect()) - } - TypedStatement::Assembly(statements) => TypedStatement::Assembly( - statements - .into_iter() - .flat_map(|s| f.fold_assembly_statement(s)) - .collect(), - ), - }; - vec![res] + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() } pub fn fold_identifier_expression< @@ -621,7 +878,16 @@ pub fn fold_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> ArrayExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_array_expression_cases(ty, e).span(span) +} + +pub fn fold_array_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -634,16 +900,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression_or_spread(e)) - .collect(), - ), FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, }, + Value(function_call) => match f.fold_array_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, Conditional(c) => match f.fold_conditional_expression(ty, c) { ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, @@ -656,17 +920,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( MemberOrExpression::Member(m) => Member(m), MemberOrExpression::Expression(u) => u, }, - Slice(box array, box from, box to) => { - let array = f.fold_array_expression(array); - let from = f.fold_uint_expression(from); - let to = f.fold_uint_expression(to); - Slice(box array, box from, box to) - } - Repeat(box e, box count) => { - let e = f.fold_expression(e); - let count = f.fold_uint_expression(count); - Repeat(box e, box count) - } + Slice(s) => match f.fold_slice_expression(s) { + SliceOrExpression::Slice(m) => Slice(m), + SliceOrExpression::Expression(u) => u, + }, + Repeat(s) => match f.fold_repeat_expression(s) { + RepeatOrExpression::Repeat(m) => Repeat(m), + RepeatOrExpression::Expression(u) => u, + }, Element(m) => match f.fold_element_expression(ty, m) { ElementOrExpression::Element(m) => Element(m), ElementOrExpression::Expression(u) => u, @@ -674,7 +935,16 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> StructExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_struct_expression_cases(ty, e).span(span) +} + +pub fn fold_struct_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -687,7 +957,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()), + Value(function_call) => match f.fold_struct_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -711,7 +984,16 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, +) -> TupleExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_tuple_expression_cases(ty, e).span(span) +} + +pub fn fold_tuple_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -724,7 +1006,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Identifier(i) => Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()), + Value(function_call) => match f.fold_tuple_value_expression(ty, function_call) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -748,7 +1033,15 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + f.fold_field_expression_cases(e).span(span) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { @@ -760,72 +1053,58 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)), - Number(n) => Number(n), - Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Add(box e1, box e2) - } - Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Sub(box e1, box e2) - } - Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Mult(box e1, box e2) - } - Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - Div(box e1, box e2) - } - Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_uint_expression(e2); - Pow(box e1, box e2) - } - Neg(box e) => { - let e = f.fold_field_expression(e); - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_field_expression(e); - - Pos(box e) - } - And(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - Or(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - Xor(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - RightShift(box e, box by) - } + Value(value) => match f.fold_field_value_expression(value) { + ValueOrExpression::Value(value) => Value(value), + ValueOrExpression::Expression(e) => e, + }, + Add(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Pow(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&Type::FieldElement, e) { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&Type::FieldElement, e) { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c) { ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, @@ -856,19 +1135,53 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T> + From>, + E: Expr<'ast, T> + Fold + Conditional<'ast, T> + From>, F: Folder<'ast, T>, >( f: &mut F, _: &E::Ty, e: ConditionalExpression<'ast, T, E>, ) -> ConditionalOrExpression<'ast, T, E> { - ConditionalOrExpression::Conditional(ConditionalExpression::new( - f.fold_boolean_expression(*e.condition), - e.consequence.fold(f), - e.alternative.fold(f), - e.kind, - )) + ConditionalOrExpression::Conditional( + ConditionalExpression::new( + f.fold_boolean_expression(*e.condition), + e.consequence.fold(f), + e.alternative.fold(f), + e.kind, + ) + .span(e.span), + ) +} + +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + Fold + From>, + R: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> BinaryOrExpression { + BinaryOrExpression::Binary(BinaryExpression::new(e.left.fold(f), e.right.fold(f))) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> UnaryOrExpression { + UnaryOrExpression::Unary(UnaryExpression::new(e.inner.fold(f))) } pub fn fold_member_expression< @@ -887,6 +1200,27 @@ pub fn fold_member_expression< )) } +pub fn fold_slice_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: SliceExpression<'ast, T>, +) -> SliceOrExpression<'ast, T> { + SliceOrExpression::Slice(SliceExpression::new( + f.fold_array_expression(*e.array), + f.fold_uint_expression(*e.from), + f.fold_uint_expression(*e.to), + )) +} + +pub fn fold_repeat_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: RepeatExpression<'ast, T>, +) -> RepeatOrExpression<'ast, T> { + RepeatOrExpression::Repeat(RepeatExpression::new( + f.fold_expression(*e.e), + f.fold_uint_expression(*e.count), + )) +} + pub fn fold_element_expression< 'ast, T: Field, @@ -897,17 +1231,9 @@ pub fn fold_element_expression< _: &E::Ty, e: ElementExpression<'ast, T, E>, ) -> ElementOrExpression<'ast, T, E> { - ElementOrExpression::Element(ElementExpression::new( - f.fold_tuple_expression(*e.tuple), - e.index, - )) -} - -pub fn fold_eq_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>( - f: &mut F, - e: EqExpression, -) -> EqOrBoolean<'ast, T, E> { - EqOrBoolean::Eq(EqExpression::new(e.left.fold(f), e.right.fold(f))) + ElementOrExpression::Element( + ElementExpression::new(f.fold_tuple_expression(*e.tuple), e.index).span(e.span), + ) } pub fn fold_select_expression< @@ -920,10 +1246,13 @@ pub fn fold_select_expression< _: &E::Ty, e: SelectExpression<'ast, T, E>, ) -> SelectOrExpression<'ast, T, E> { - SelectOrExpression::Select(SelectExpression::new( - f.fold_array_expression(*e.array), - f.fold_uint_expression(*e.index), - )) + SelectOrExpression::Select( + SelectExpression::new( + f.fold_array_expression(*e.array), + f.fold_uint_expression(*e.index), + ) + .span(e.span), + ) } pub fn fold_int_expression<'ast, T: Field, F: Folder<'ast, T>>( @@ -933,7 +1262,15 @@ pub fn fold_int_expression<'ast, T: Field, F: Folder<'ast, T>>( unreachable!() } -pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> BooleanExpression<'ast, T> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).span(span) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { @@ -945,85 +1282,62 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => BooleanExpression::Block(f.fold_block_expression(block)), - Value(v) => BooleanExpression::Value(v), + Value(e) => match f.fold_boolean_value_expression(e) { + ValueOrExpression::Value(e) => Value(e), + ValueOrExpression::Expression(u) => u, + }, FieldEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, }, BoolEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, }, ArrayEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => ArrayEq(e), + BinaryOrExpression::Expression(u) => u, }, StructEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => StructEq(e), + BinaryOrExpression::Expression(u) => u, }, TupleEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::TupleEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => TupleEq(e), + BinaryOrExpression::Expression(u) => u, }, UintEq(e) => match f.fold_eq_expression(e) { - EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, }, - FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldLt(box e1, box e2) - } - FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldLe(box e1, box e2) - } - FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldGt(box e1, box e2) - } - FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldGe(box e1, box e2) - } - UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintLt(box e1, box e2) - } - UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintLe(box e1, box e2) - } - UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintGt(box e1, box e2) - } - UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - UintGe(box e1, box e2) - } - Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - Or(box e1, box e2) - } - And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - And(box e1, box e2) - } - Not(box e) => { - let e = f.fold_boolean_expression(e); - Not(box e) - } FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::Boolean, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => { @@ -1061,7 +1375,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).span(span) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, @@ -1075,87 +1398,62 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( }, Block(block) => Block(f.fold_block_expression(block)), Value(v) => Value(v), - Add(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Add(box left, box right) - } - Sub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Sub(box left, box right) - } - FloorSub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - FloorSub(box left, box right) - } - Mult(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Mult(box left, box right) - } - Div(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Div(box left, box right) - } - Rem(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Rem(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Xor(box left, box right) - } - And(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - Or(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - - RightShift(box e, box by) - } - Not(box e) => { - let e = f.fold_uint_expression(e); - - Not(box e) - } - Neg(box e) => { - let e = f.fold_uint_expression(e); - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_uint_expression(e); - - Pos(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + FloorSub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => FloorSub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(&ty, function_call) { FunctionCallOrExpression::FunctionCall(function_call) => FunctionCall(function_call), FunctionCallOrExpression::Expression(u) => u, @@ -1179,18 +1477,19 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>( +pub fn fold_block_expression<'ast, T: Field, E: Fold, F: Folder<'ast, T>>( f: &mut F, block: BlockExpression<'ast, T, E>, ) -> BlockExpression<'ast, T, E> { - BlockExpression { - statements: block + BlockExpression::new( + block .statements .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - value: box block.value.fold(f), - } + block.value.fold(f), + ) + .span(block.span) } pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>( @@ -1214,17 +1513,20 @@ pub fn fold_function_call_expression< _: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> FunctionCallOrExpression<'ast, T, E> { - FunctionCallOrExpression::FunctionCall(FunctionCallExpression::new( - f.fold_declaration_function_key(e.function_key), - e.generics - .into_iter() - .map(|g| g.map(|g| f.fold_uint_expression(g))) - .collect(), - e.arguments - .into_iter() - .map(|e| f.fold_expression(e)) - .collect(), - )) + FunctionCallOrExpression::FunctionCall( + FunctionCallExpression::new( + f.fold_declaration_function_key(e.function_key), + e.generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(), + e.arguments + .into_iter() + .map(|e| f.fold_expression(e)) + .collect(), + ) + .span(e.span), + ) } pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( @@ -1257,7 +1559,7 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .map(|o| f.fold_declaration_type(o)) .collect(), - output: box f.fold_declaration_type(*s.output), + output: Box::new(f.fold_declaration_type(*s.output)), } } @@ -1282,7 +1584,7 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( ArrayExpression { inner: f.fold_array_expression_inner(&ty, e.inner), - ty: box ty, + ty: Box::new(ty), } } @@ -1308,6 +1610,64 @@ pub fn fold_tuple_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_integer_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: IntValueExpression, +) -> ValueOrExpression> { + ValueOrExpression::Value(a) +} + +pub fn fold_boolean_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: BooleanValueExpression, +) -> ValueOrExpression> { + ValueOrExpression::Value(a) +} + +pub fn fold_field_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + _f: &mut F, + a: FieldValueExpression, +) -> ValueOrExpression, FieldElementExpression<'ast, T>> { + ValueOrExpression::Value(a) +} + +pub fn fold_struct_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &StructType<'ast, T>, + a: StructValueExpression<'ast, T>, +) -> ValueOrExpression, StructExpressionInner<'ast, T>> { + ValueOrExpression::Value(StructValueExpression { + value: a.value.into_iter().map(|v| f.fold_expression(v)).collect(), + ..a + }) +} + +pub fn fold_array_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &ArrayType<'ast, T>, + a: ArrayValueExpression<'ast, T>, +) -> ValueOrExpression, ArrayExpressionInner<'ast, T>> { + ValueOrExpression::Value(ArrayValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression_or_spread(v)) + .collect(), + ..a + }) +} + +pub fn fold_tuple_value_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + _ty: &TupleType<'ast, T>, + a: TupleValueExpression<'ast, T>, +) -> ValueOrExpression, TupleExpressionInner<'ast, T>> { + ValueOrExpression::Value(TupleValueExpression { + value: a.value.into_iter().map(|v| f.fold_expression(v)).collect(), + ..a + }) +} + pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, c: TypedConstant<'ast, T>, @@ -1349,12 +1709,13 @@ pub fn fold_assignee<'ast, T: Field, F: Folder<'ast, T>>( ) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(f.fold_variable(v)), - TypedAssignee::Select(box a, box index) => { - TypedAssignee::Select(box f.fold_assignee(a), box f.fold_uint_expression(index)) - } - TypedAssignee::Member(box s, m) => TypedAssignee::Member(box f.fold_assignee(s), m), - TypedAssignee::Element(box s, index) => { - TypedAssignee::Element(box f.fold_assignee(s), index) + TypedAssignee::Select(a, index) => TypedAssignee::Select( + Box::new(f.fold_assignee(*a)), + Box::new(f.fold_uint_expression(*index)), + ), + TypedAssignee::Member(s, m) => TypedAssignee::Member(Box::new(f.fold_assignee(*s)), m), + TypedAssignee::Element(s, index) => { + TypedAssignee::Element(Box::new(f.fold_assignee(*s)), index) } } } @@ -1370,5 +1731,6 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( .map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module))) .collect(), main: f.fold_module_id(p.main), + ..p } } diff --git a/zokrates_ast/src/typed/integer.rs b/zokrates_ast/src/typed/integer.rs index 507ce9174..d5ba9ad58 100644 --- a/zokrates_ast/src/typed/integer.rs +++ b/zokrates_ast/src/typed/integer.rs @@ -8,14 +8,21 @@ use crate::typed::{ ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalExpression, Expr, FieldElementExpression, Select, SelectExpression, StructExpression, StructExpressionInner, TupleExpression, TupleExpressionInner, Typed, TypedExpression, - TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, + TypedExpressionOrSpread, TypedSpread, UExpression, }; + +use crate::common::{operators::*, WithSpan}; + use num_bigint::BigUint; use std::convert::TryFrom; use std::fmt; -use std::ops::{Add, Div, Mul, Neg, Not, Rem, Sub}; +use std::ops::*; use zokrates_field::Field; +use crate::common::expressions::*; + +use super::{ArrayValueExpression, RepeatExpression}; + type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>); impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> { @@ -103,7 +110,7 @@ impl<'ast, T: Clone> IntegerInference for StructType<'ast, T> { .zip(other.members.into_iter()) .map(|(m_t, m_u)| match m_t.ty.get_common_pattern(*m_u.ty) { Ok(ty) => DeclarationStructMember { - ty: box ty, + ty: Box::new(ty), id: m_t.id, }, Err(..) => unreachable!( @@ -274,30 +281,107 @@ impl<'ast, T: Field> TypedExpression<'ast, T> { #[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] pub enum IntExpression<'ast, T> { - Value(BigUint), - Pos(Box>), - Neg(Box>), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Pow(Box>, Box>), + Value(IntValueExpression), + Pos(UnaryExpression, IntExpression<'ast, T>>), + Neg(UnaryExpression, IntExpression<'ast, T>>), + Not(UnaryExpression, IntExpression<'ast, T>>), + Add( + BinaryExpression< + OpAdd, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Sub( + BinaryExpression< + OpSub, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Mult( + BinaryExpression< + OpMul, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Div( + BinaryExpression< + OpDiv, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Rem( + BinaryExpression< + OpRem, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Pow( + BinaryExpression< + OpPow, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Xor( + BinaryExpression< + OpXor, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + And( + BinaryExpression< + OpAnd, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + Or( + BinaryExpression< + OpOr, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + LeftShift( + BinaryExpression< + OpLsh, + IntExpression<'ast, T>, + UExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), + RightShift( + BinaryExpression< + OpRsh, + IntExpression<'ast, T>, + UExpression<'ast, T>, + IntExpression<'ast, T>, + >, + ), Conditional(ConditionalExpression<'ast, T, IntExpression<'ast, T>>), Select(SelectExpression<'ast, T, IntExpression<'ast, T>>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - Not(Box>), - LeftShift(Box>, Box>), - RightShift(Box>, Box>), } impl<'ast, T> Add for IntExpression<'ast, T> { type Output = Self; fn add(self, other: Self) -> Self { - IntExpression::Add(box self, box other) + IntExpression::Add(BinaryExpression::new(self, other)) } } @@ -305,7 +389,7 @@ impl<'ast, T> Sub for IntExpression<'ast, T> { type Output = Self; fn sub(self, other: Self) -> Self { - IntExpression::Sub(box self, box other) + IntExpression::Sub(BinaryExpression::new(self, other)) } } @@ -313,7 +397,7 @@ impl<'ast, T> Mul for IntExpression<'ast, T> { type Output = Self; fn mul(self, other: Self) -> Self { - IntExpression::Mult(box self, box other) + IntExpression::Mult(BinaryExpression::new(self, other)) } } @@ -321,7 +405,7 @@ impl<'ast, T> Div for IntExpression<'ast, T> { type Output = Self; fn div(self, other: Self) -> Self { - IntExpression::Div(box self, box other) + IntExpression::Div(BinaryExpression::new(self, other)) } } @@ -329,7 +413,7 @@ impl<'ast, T> Rem for IntExpression<'ast, T> { type Output = Self; fn rem(self, other: Self) -> Self { - IntExpression::Rem(box self, box other) + IntExpression::Rem(BinaryExpression::new(self, other)) } } @@ -337,7 +421,7 @@ impl<'ast, T> Not for IntExpression<'ast, T> { type Output = Self; fn not(self) -> Self { - IntExpression::Not(box self) + IntExpression::Not(UnaryExpression::new(self)) } } @@ -345,37 +429,37 @@ impl<'ast, T> Neg for IntExpression<'ast, T> { type Output = Self; fn neg(self) -> Self { - IntExpression::Neg(box self) + IntExpression::Neg(UnaryExpression::new(self)) } } impl<'ast, T> IntExpression<'ast, T> { pub fn pow(self, other: Self) -> Self { - IntExpression::Pow(box self, box other) + IntExpression::Pow(BinaryExpression::new(self, other)) } pub fn and(self, other: Self) -> Self { - IntExpression::And(box self, box other) + IntExpression::And(BinaryExpression::new(self, other)) } pub fn xor(self, other: Self) -> Self { - IntExpression::Xor(box self, box other) + IntExpression::Xor(BinaryExpression::new(self, other)) } pub fn or(self, other: Self) -> Self { - IntExpression::Or(box self, box other) + IntExpression::Or(BinaryExpression::new(self, other)) } pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { - IntExpression::LeftShift(box self, box by) + IntExpression::LeftShift(BinaryExpression::new(self, by)) } pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { - IntExpression::RightShift(box self, box by) + IntExpression::RightShift(BinaryExpression::new(self, by)) } pub fn pos(self) -> Self { - IntExpression::Pos(box self) + IntExpression::Pos(UnaryExpression::new(self)) } } @@ -383,21 +467,21 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { IntExpression::Value(ref v) => write!(f, "{}", v), - IntExpression::Pos(ref e) => write!(f, "(+{})", e), - IntExpression::Neg(ref e) => write!(f, "(-{})", e), - IntExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - IntExpression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - IntExpression::Pow(ref lhs, ref rhs) => write!(f, "({} ** {})", lhs, rhs), + IntExpression::Pos(ref e) => write!(f, "{}", e), + IntExpression::Neg(ref e) => write!(f, "{}", e), + IntExpression::Div(ref e) => write!(f, "{}", e), + IntExpression::Rem(ref e) => write!(f, "{}", e), + IntExpression::Pow(ref e) => write!(f, "{}", e), IntExpression::Select(ref select) => write!(f, "{}", select), - IntExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - IntExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - IntExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - IntExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - IntExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - IntExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - IntExpression::Not(ref e) => write!(f, "!{}", e), + IntExpression::Add(ref e) => write!(f, "{}", e), + IntExpression::And(ref e) => write!(f, "{}", e), + IntExpression::Or(ref e) => write!(f, "{}", e), + IntExpression::Xor(ref e) => write!(f, "{}", e), + IntExpression::Sub(ref e) => write!(f, "{}", e), + IntExpression::Mult(ref e) => write!(f, "{}", e), + IntExpression::RightShift(ref e) => write!(f, "{}", e), + IntExpression::LeftShift(ref e) => write!(f, "{}", e), + IntExpression::Not(ref e) => write!(f, "{}", e), IntExpression::Conditional(ref c) => write!(f, "{}", c), } } @@ -424,48 +508,52 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { } pub fn try_from_int(i: IntExpression<'ast, T>) -> Result> { + let span = i.get_span(); + match i { - IntExpression::Value(i) => Ok(Self::Number(T::try_from(i.clone()).map_err(|_| i)?)), - IntExpression::Add(box e1, box e2) => Ok(Self::Add( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Value(i) => Ok(Self::Value(ValueExpression::new( + T::try_from(i.value.clone()).map_err(|_| IntExpression::Value(i))?, + ))), + IntExpression::Add(e) => Ok(Self::add( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Sub(box e1, box e2) => Ok(Self::Sub( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Sub(e) => Ok(Self::sub( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Mult(box e1, box e2) => Ok(Self::Mult( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Mult(e) => Ok(Self::mul( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Pow(box e1, box e2) => Ok(Self::Pow( - box Self::try_from_int(e1)?, - box UExpression::try_from_int(e2, &UBitwidth::B32)?, + IntExpression::Pow(e) => Ok(Self::pow( + Self::try_from_int(*e.left)?, + UExpression::try_from_int(*e.right, &UBitwidth::B32)?, )), - IntExpression::Div(box e1, box e2) => Ok(Self::Div( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Div(e) => Ok(Self::div( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::And(box e1, box e2) => Ok(Self::And( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::And(e) => Ok(Self::bitand( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Or(box e1, box e2) => Ok(Self::Or( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Or(e) => Ok(Self::bitor( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::Xor(box e1, box e2) => Ok(Self::Xor( - box Self::try_from_int(e1)?, - box Self::try_from_int(e2)?, + IntExpression::Xor(e) => Ok(Self::bitxor( + Self::try_from_int(*e.left)?, + Self::try_from_int(*e.right)?, )), - IntExpression::LeftShift(box e1, box e2) => { - Ok(Self::LeftShift(box Self::try_from_int(e1)?, box e2)) + IntExpression::LeftShift(e) => { + Ok(Self::left_shift(Self::try_from_int(*e.left)?, *e.right)) } - IntExpression::RightShift(box e1, box e2) => { - Ok(Self::RightShift(box Self::try_from_int(e1)?, box e2)) + IntExpression::RightShift(e) => { + Ok(Self::right_shift(Self::try_from_int(*e.left)?, *e.right)) } - IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), - IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), + IntExpression::Pos(e) => Ok(Self::pos(Self::try_from_int(*e.inner)?)), + IntExpression::Neg(e) => Ok(Self::neg(Self::try_from_int(*e.inner)?)), IntExpression::Conditional(c) => Ok(Self::Conditional(ConditionalExpression::new( *c.condition, Self::try_from_int(*c.consequence)?, @@ -481,6 +569,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { match array.into_inner() { ArrayExpressionInner::Value(values) => { let values = values + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type( @@ -501,8 +590,8 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { }) .collect::, _>>()?; Ok(FieldElementExpression::select( - ArrayExpressionInner::Value(values.into()) - .annotate(Type::FieldElement, size), + ArrayExpression::value(values) + .annotate(ArrayType::new(Type::FieldElement, size)), index, )) } @@ -511,6 +600,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { } i => Err(i), } + .map(|e| e.span(span)) } } @@ -537,54 +627,56 @@ impl<'ast, T: Field> UExpression<'ast, T> { ) -> Result> { use self::IntExpression::*; + let span = i.get_span(); + match i { Value(i) => { - if i <= BigUint::from(2u128.pow(bitwidth.to_usize() as u32) - 1) { - Ok(UExpressionInner::Value( - u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(), + if i.value <= BigUint::from(2u128.pow(bitwidth.to_usize() as u32) - 1) { + Ok(UExpression::value( + u128::from_str_radix(&i.value.to_str_radix(16), 16).unwrap(), ) .annotate(*bitwidth)) } else { Err(Value(i)) } } - Add(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? + Self::try_from_int(e2, bitwidth)?) - } - Pos(box e) => Ok(Self::pos(Self::try_from_int(e, bitwidth)?)), - Neg(box e) => Ok(Self::neg(Self::try_from_int(e, bitwidth)?)), - Sub(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? - Self::try_from_int(e2, bitwidth)?) - } - Mult(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? * Self::try_from_int(e2, bitwidth)?) - } - Div(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? / Self::try_from_int(e2, bitwidth)?) - } - Rem(box e1, box e2) => { - Ok(Self::try_from_int(e1, bitwidth)? % Self::try_from_int(e2, bitwidth)?) - } - And(box e1, box e2) => Ok(UExpression::and( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Add(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? + Self::try_from_int(*e.right, bitwidth)? + ), + Pos(e) => Ok(Self::pos(Self::try_from_int(*e.inner, bitwidth)?)), + Neg(e) => Ok(Self::neg(Self::try_from_int(*e.inner, bitwidth)?)), + Sub(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? - Self::try_from_int(*e.right, bitwidth)? + ), + Mult(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? * Self::try_from_int(*e.right, bitwidth)? + ), + Div(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? / Self::try_from_int(*e.right, bitwidth)? + ), + Rem(e) => Ok( + Self::try_from_int(*e.left, bitwidth)? % Self::try_from_int(*e.right, bitwidth)? + ), + And(e) => Ok(UExpression::and( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - Or(box e1, box e2) => Ok(UExpression::or( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Or(e) => Ok(UExpression::or( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - Not(box e) => Ok(!Self::try_from_int(e, bitwidth)?), - Xor(box e1, box e2) => Ok(UExpression::xor( - Self::try_from_int(e1, bitwidth)?, - Self::try_from_int(e2, bitwidth)?, + Not(e) => Ok(!Self::try_from_int(*e.inner, bitwidth)?), + Xor(e) => Ok(UExpression::xor( + Self::try_from_int(*e.left, bitwidth)?, + Self::try_from_int(*e.right, bitwidth)?, )), - RightShift(box e1, box e2) => Ok(UExpression::right_shift( - Self::try_from_int(e1, bitwidth)?, - e2, + RightShift(e) => Ok(UExpression::right_shift( + Self::try_from_int(*e.left, bitwidth)?, + *e.right, )), - LeftShift(box e1, box e2) => Ok(UExpression::left_shift( - Self::try_from_int(e1, bitwidth)?, - e2, + LeftShift(e) => Ok(UExpression::left_shift( + Self::try_from_int(*e.left, bitwidth)?, + *e.right, )), Conditional(c) => Ok(UExpression::conditional( *c.condition, @@ -600,6 +692,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { match array.into_inner() { ArrayExpressionInner::Value(values) => { let values = values + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type( @@ -620,8 +713,8 @@ impl<'ast, T: Field> UExpression<'ast, T> { }) .collect::, _>>()?; Ok(UExpression::select( - ArrayExpressionInner::Value(values.into()) - .annotate(Type::Uint(*bitwidth), size), + ArrayExpression::value(values) + .annotate(ArrayType::new(Type::Uint(*bitwidth), size)), index, )) } @@ -630,6 +723,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { } i => Err(i), } + .map(|e| e.span(span)) } } @@ -649,6 +743,8 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { array: Self, target_array_ty: &GArrayType, ) -> Result> { + let span = array.get_span(); + let array_ty = array.ty().clone(); // elements must fit in the target type @@ -659,6 +755,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { _ => { // try to convert all elements to the target type inline_array + .value .into_iter() .map(|v| { TypedExpressionOrSpread::align_to_type(v, target_array_ty).map_err( @@ -671,37 +768,43 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { ) }) .collect::, _>>() - .map(|v| v.into()) + .map(ArrayValueExpression::new) } }?; - let inner_ty = res.0[0].get_type().0; + let inner_ty = res.value[0].get_type().0; - Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, *array_ty.size)) + let array_ty = ArrayType::new(inner_ty, *array_ty.size); + + Ok(ArrayExpressionInner::Value(res).annotate(array_ty)) } - ArrayExpressionInner::Repeat(box e, box count) => { + ArrayExpressionInner::Repeat(r) => { match &*target_array_ty.ty { - GType::Int => Ok(ArrayExpressionInner::Repeat(box e, box count) - .annotate(Type::Int, *array_ty.size)), + GType::Int => { + let array_ty = ArrayType::new(Type::Int, *array_ty.size); + + Ok(ArrayExpressionInner::Repeat(r).annotate(array_ty)) + } // try to align the repeated element to the target type - t => TypedExpression::align_to_type(e, t) + t => TypedExpression::align_to_type(*r.e, t) .map(|e| { let ty = e.get_type().clone(); - ArrayExpressionInner::Repeat(box e, box count) - .annotate(ty, *array_ty.size) + ArrayExpressionInner::Repeat(RepeatExpression::new(e, *r.count)) + .annotate(ArrayType::new(ty, *array_ty.size)) }) .map_err(|(e, _)| e), } } a => { if *target_array_ty.ty == *array_ty.ty { - Ok(a.annotate(*array_ty.ty, *array_ty.size)) + Ok(a.annotate(array_ty)) } else { - Err(a.annotate(*array_ty.ty, *array_ty.size).into()) + Err(a.annotate(array_ty).into()) } } } + .map(|e| e.span(span)) } } @@ -710,6 +813,8 @@ impl<'ast, T: Field> StructExpression<'ast, T> { struc: Self, target_struct_ty: &GStructType, ) -> Result> { + let span = struc.get_span(); + let struct_ty = struc.ty().clone(); if struct_ty.members.len() != target_struct_ty.members.len() { @@ -724,7 +829,7 @@ impl<'ast, T: Field> StructExpression<'ast, T> { TypedExpression::align_to_type(value, &*target_member.ty) }) .collect::, _>>() - .map(|v| StructExpressionInner::Value(v).annotate(struct_ty.clone())) + .map(|v| StructExpression::value(v).annotate(struct_ty.clone())) .map_err(|(v, _)| v), s => { if struct_ty @@ -739,6 +844,7 @@ impl<'ast, T: Field> StructExpression<'ast, T> { } } } + .map(|e| e.span(span)) } pub fn try_from_typed>>( @@ -757,6 +863,8 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { tuple: Self, target_tuple_ty: >upleType, ) -> Result> { + let span = tuple.get_span(); + let tuple_ty = tuple.ty().clone(); if tuple_ty.elements.len() != target_tuple_ty.elements.len() { @@ -771,7 +879,7 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { .collect::, _>>() .map(|v| { let ty = TupleType::new(v.iter().map(|e| e.get_type()).collect()); - TupleExpressionInner::Value(v).annotate(ty) + TupleExpression::value(v).annotate(ty) }) .map_err(|(v, _)| v), s => { @@ -787,6 +895,7 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { } } } + .map(|e| e.span(span)) } pub fn try_from_typed>>( @@ -800,9 +909,9 @@ impl<'ast, T: Field> TupleExpression<'ast, T> { } } -impl<'ast, T> From for IntExpression<'ast, T> { +impl<'ast, T: Field> From for IntExpression<'ast, T> { fn from(v: BigUint) -> Self { - IntExpression::Value(v) + IntExpression::value(v) } } @@ -816,11 +925,13 @@ mod tests { fn field_from_int() { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = - ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) + .annotate(ArrayType::new(Type::Int, 1u32)); let t: FieldElementExpression = Bn128Field::from(42).into(); let t_a: ArrayExpression = - ArrayExpressionInner::Value(vec![t.clone().into()].into()) - .annotate(Type::FieldElement, 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) + .annotate(ArrayType::new(Type::FieldElement, 1u32)); + let i: UExpression = 42u32.into(); let c: BooleanExpression = true.into(); @@ -876,11 +987,12 @@ mod tests { fn uint_from_int() { let n: IntExpression = BigUint::from(42usize).into(); let n_a: ArrayExpression = - ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![n.clone().into()])) + .annotate(ArrayType::new(Type::Int, 1u32)); let t: UExpression = 42u32.into(); let t_a: ArrayExpression = - ArrayExpressionInner::Value(vec![t.clone().into()].into()) - .annotate(Type::Uint(UBitwidth::B32), 1u32); + ArrayExpressionInner::Value(ValueExpression::new(vec![t.clone().into()])) + .annotate(ArrayType::new(Type::Uint(UBitwidth::B32), 1u32)); let i: UExpression = 0u32.into(); let c: BooleanExpression = true.into(); diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index bd000d12a..77789fe06 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,16 +27,23 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; + +use crate::common::expressions::{ + self, BinaryExpression, BooleanValueExpression, FieldValueExpression, UnaryExpression, + ValueExpression, +}; +use crate::common::{self, ModuleMap, Span, Value, WithSpan}; +pub use crate::common::{ModuleId, OwnedModuleId}; use crate::typed::types::IntoType; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; -use std::path::{Path, PathBuf}; +use std::ops::Deref; pub use crate::typed::integer::IntExpression; pub use crate::typed::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; -use crate::common::{FlatEmbed, FormatString, SourceMetadata}; +use crate::common::{operators::*, FlatEmbed, FormatString, SourceMetadata}; use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; @@ -44,20 +51,17 @@ use std::fmt; pub use crate::typed::types::{ArrayType, FunctionKey, MemberId}; +use derivative::Derivative; +use num_bigint::BigUint; use zokrates_field::Field; pub use self::folder::Folder; use crate::typed::abi::{Abi, AbiInput}; -use std::ops::{Add, Div, Mul, Sub}; pub use self::identifier::Identifier; -/// An identifier for a `TypedModule`. Typically a path or uri. -pub type OwnedTypedModuleId = PathBuf; -pub type TypedModuleId = Path; - /// A collection of `TypedModule`s -pub type TypedModules<'ast, T> = BTreeMap>; +pub type TypedModules<'ast, T> = BTreeMap>; /// A collection of `TypedFunctionSymbol`s /// # Remarks @@ -80,12 +84,16 @@ pub type TypedConstantSymbols<'ast, T> = Vec<( )>; /// A typed program as a collection of modules, one of them being the main -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(PartialEq, Eq, Debug, Clone, Default)] pub struct TypedProgram<'ast, T> { + pub module_map: ModuleMap, pub modules: TypedModules<'ast, T>, - pub main: OwnedTypedModuleId, + pub main: OwnedModuleId, } +pub type IdentifierOrExpression<'ast, T, E> = + expressions::IdentifierOrExpression, E, >::Inner>; + impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn abi(&self) -> Abi { let main = &self.modules[&self.main] @@ -107,7 +115,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { crate::typed::types::try_from_g_type::< DeclarationConstant<'ast, T>, UExpression<'ast, T>, - >(p.id._type.clone()) + >(p.id.ty.clone()) .unwrap(), ) .map(|ty| AbiInput { @@ -397,6 +405,20 @@ pub enum TypedAssignee<'ast, T> { Element(Box>, u32), } +impl<'ast, T> TypedAssignee<'ast, T> { + pub fn select(self, index: UExpression<'ast, T>) -> Self { + Self::Select(Box::new(self), Box::new(index)) + } + + pub fn member(self, member: MemberId) -> Self { + Self::Member(Box::new(self), member) + } + + pub fn element(self, index: u32) -> Self { + Self::Element(Box::new(self), index) + } +} + #[derive(Clone, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)] pub struct TypedSpread<'ast, T> { pub array: ArrayExpression<'ast, T>, @@ -499,7 +521,7 @@ impl<'ast, T: Clone> TypedExpressionOrSpread<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionOrSpread<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - TypedExpressionOrSpread::Expression(e) => write!(f, "{}", e), + TypedExpressionOrSpread::Expression(ref e) => write!(f, "{}", e), TypedExpressionOrSpread::Spread(s) => write!(f, "{}", s), } } @@ -642,98 +664,242 @@ impl<'ast, T: fmt::Display> fmt::Display for DefinitionRhs<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { DefinitionRhs::EmbedCall(c) => write!(f, "{}", c), - DefinitionRhs::Expression(e) => write!(f, "{}", e), + DefinitionRhs::Expression(ref e) => write!(f, "{}", e), + } + } +} + +pub type DefinitionStatement<'ast, T> = + common::statements::DefinitionStatement, DefinitionRhs<'ast, T>>; +pub type AssertionStatement<'ast, T> = + common::statements::AssertionStatement, RuntimeError>; +pub type ReturnStatement<'ast, T> = common::statements::ReturnStatement>; +pub type LogStatement<'ast, T> = common::statements::LogStatement>; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct ForStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub var: Variable<'ast, T>, + pub from: UExpression<'ast, T>, + pub to: UExpression<'ast, T>, + pub statements: Vec>, +} + +impl<'ast, T> ForStatement<'ast, T> { + fn new( + var: Variable<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + statements: Vec>, + ) -> Self { + Self { + span: None, + var, + from, + to, + statements, } } } +impl<'ast, T> WithSpan for ForStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct AssemblyBlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub inner: Vec>, +} + +impl<'ast, T> AssemblyBlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +pub type AssemblyConstraint<'ast, T> = + crate::common::statements::AssemblyConstraint>; +pub type AssemblyAssignment<'ast, T> = + crate::common::statements::AssemblyAssignment, TypedExpression<'ast, T>>; + #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedAssemblyStatement<'ast, T> { - Assignment(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), - Constraint( - FieldElementExpression<'ast, T>, - FieldElementExpression<'ast, T>, - SourceMetadata, - ), + Assignment(AssemblyAssignment<'ast, T>), + Constraint(AssemblyConstraint<'ast, T>), +} + +impl<'ast, T> WithSpan for TypedAssemblyStatement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + TypedAssemblyStatement::Assignment(s) => { + TypedAssemblyStatement::Assignment(s.span(span)) + } + TypedAssemblyStatement::Constraint(s) => { + TypedAssemblyStatement::Constraint(s.span(span)) + } + } + } + + fn get_span(&self) -> Option { + match self { + TypedAssemblyStatement::Assignment(s) => s.get_span(), + TypedAssemblyStatement::Constraint(s) => s.get_span(), + } + } +} + +impl<'ast, T> TypedAssemblyStatement<'ast, T> { + pub fn assignment( + assignee: TypedAssignee<'ast, T>, + expression: TypedExpression<'ast, T>, + ) -> Self { + TypedAssemblyStatement::Assignment(AssemblyAssignment::new(assignee, expression)) + } + + pub fn constraint( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + metadata: SourceMetadata, + ) -> Self { + TypedAssemblyStatement::Constraint(AssemblyConstraint::new(left, right, metadata)) + } } impl<'ast, T: fmt::Display> fmt::Display for TypedAssemblyStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedAssemblyStatement::Assignment(ref lhs, ref rhs) => { - write!(f, "{} <-- {};", lhs, rhs) + TypedAssemblyStatement::Assignment(ref s) => { + write!(f, "{} <-- {};", s.assignee, s.expression) } - TypedAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { - write!(f, "{} === {};", lhs, rhs) + TypedAssemblyStatement::Constraint(ref s) => { + write!(f, "{}", s) } } } } +impl<'ast, T> WithSpan for AssemblyBlockStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A statement in a `TypedFunction` #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedStatement<'ast, T> { - Return(TypedExpression<'ast, T>), - Definition(TypedAssignee<'ast, T>, DefinitionRhs<'ast, T>), - Assertion(BooleanExpression<'ast, T>, RuntimeError), - For( - Variable<'ast, T>, - UExpression<'ast, T>, - UExpression<'ast, T>, - Vec>, - ), - Log(FormatString, Vec>), - Assembly(Vec>), + Return(ReturnStatement<'ast, T>), + Definition(DefinitionStatement<'ast, T>), + Assertion(AssertionStatement<'ast, T>), + For(ForStatement<'ast, T>), + Log(LogStatement<'ast, T>), + Assembly(AssemblyBlockStatement<'ast, T>), +} + +impl<'ast, T> WithSpan for TypedStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use TypedStatement::*; + + match self { + Return(e) => Return(e.span(span)), + Definition(e) => Definition(e.span(span)), + Assertion(e) => Assertion(e.span(span)), + For(e) => For(e.span(span)), + Log(e) => Log(e.span(span)), + Assembly(e) => Assembly(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TypedStatement::*; + + match self { + Return(e) => e.get_span(), + Definition(e) => e.get_span(), + Assertion(e) => e.get_span(), + For(e) => e.get_span(), + Log(e) => e.get_span(), + Assembly(e) => e.get_span(), + } + } } impl<'ast, T> TypedStatement<'ast, T> { pub fn definition(a: TypedAssignee<'ast, T>, e: TypedExpression<'ast, T>) -> Self { - Self::Definition(a, e.into()) + Self::Definition(DefinitionStatement::new(a, e.into())) + } + + pub fn for_( + var: Variable<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + statements: Vec>, + ) -> Self { + Self::For(ForStatement::new(var, from, to, statements)) + } + + pub fn assertion(e: BooleanExpression<'ast, T>, error: RuntimeError) -> Self { + Self::Assertion(AssertionStatement::new(e, error)) + } + + pub fn ret(e: TypedExpression<'ast, T>) -> Self { + Self::Return(ReturnStatement::new(e)) } pub fn embed_call_definition(a: TypedAssignee<'ast, T>, c: EmbedCall<'ast, T>) -> Self { - Self::Definition(a, c.into()) + Self::Definition(DefinitionStatement::new(a, c.into())) + } + + pub fn log(format_string: FormatString, expressions: Vec>) -> Self { + Self::Log(LogStatement::new(format_string, expressions)) } } impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedStatement::Return(ref e) => { - write!(f, "return {};", e) + TypedStatement::Return(ref s) => { + write!(f, "{}", s) } - TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {};", lhs, rhs), - TypedStatement::Assertion(ref e, ref error) => { - write!(f, "assert({}", e)?; - match error { - RuntimeError::SourceAssertion(metadata) => match &metadata.message { + TypedStatement::Definition(ref s) => write!(f, "{}", s), + TypedStatement::Assertion(ref s) => { + write!(f, "assert({}", s.expression)?; + match s.error { + RuntimeError::SourceAssertion(ref metadata) => match &metadata.message { Some(m) => write!(f, ", \"{}\");", m), None => write!(f, ");"), }, - error => write!(f, "); // {}", error), + ref error => write!(f, "); // {}", error), } } - TypedStatement::For(ref var, ref start, ref stop, ref list) => { - writeln!(f, "for {} in {}..{} {{", var, start, stop)?; - for l in list { + TypedStatement::For(ref s) => { + writeln!(f, "for {} in {}..{} {{", s.var, s.from, s.to)?; + for l in &s.statements { writeln!(f, "\t\t{}", l)?; } write!(f, "\t}}") } - TypedStatement::Log(ref l, ref expressions) => write!( - f, - "log({}, {})", - l, - expressions - .iter() - .map(|e| e.to_string()) - .collect::>() - .join(", ") - ), - TypedStatement::Assembly(ref statements) => { + TypedStatement::Log(ref s) => write!(f, "{}", s), + TypedStatement::Assembly(ref s) => { writeln!(f, "asm {{")?; - for s in statements { + for s in &s.inner { writeln!(f, "\t\t{}", s)?; } write!(f, "\t}}") @@ -759,9 +925,9 @@ pub enum TypedExpression<'ast, T> { Int(IntExpression<'ast, T>), } -impl<'ast, T> TypedExpression<'ast, T> { +impl<'ast, T: Field> TypedExpression<'ast, T> { pub fn empty_tuple() -> TypedExpression<'ast, T> { - TypedExpression::Tuple(TupleExpressionInner::Value(vec![]).annotate(TupleType::new(vec![]))) + TypedExpression::Tuple(TupleExpression::value(vec![]).annotate(TupleType::new(vec![]))) } } @@ -905,65 +1071,112 @@ impl<'ast, T: Clone> Typed<'ast, T> for BooleanExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct EqExpression { - pub left: Box, - pub right: Box, +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct BlockExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub statements: Vec>, + pub value: Box, } -impl EqExpression { - pub fn new(left: E, right: E) -> Self { - EqExpression { - left: box left, - right: box right, +impl<'ast, T, E> BlockExpression<'ast, T, E> { + pub fn new(statements: Vec>, value: E) -> Self { + BlockExpression { + span: None, + statements, + value: Box::new(value), } } } -impl fmt::Display for EqExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "({} == {})", self.left, self.right) +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct SliceExpression<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub array: Box>, + pub from: Box>, + pub to: Box>, +} + +impl<'ast, T> SliceExpression<'ast, T> { + pub fn new( + array: ArrayExpression<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + ) -> Self { + SliceExpression { + span: None, + array: Box::new(array), + from: Box::new(from), + to: Box::new(to), + } } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct BlockExpression<'ast, T, E> { - pub statements: Vec>, - pub value: Box, +impl<'ast, T> WithSpan for SliceExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } } -impl<'ast, T, E> BlockExpression<'ast, T, E> { - pub fn new(statements: Vec>, value: E) -> Self { - BlockExpression { - statements, - value: box value, - } +impl<'ast, T: fmt::Display> fmt::Display for SliceExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}[{}..{}]", self.array, self.from, self.to) } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] -pub struct IdentifierExpression<'ast, E> { - pub id: Identifier<'ast>, - ty: PhantomData, +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] +pub struct RepeatExpression<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + pub count: Box>, + pub e: Box>, } -impl<'ast, E> IdentifierExpression<'ast, E> { - pub fn new(id: Identifier<'ast>) -> Self { - IdentifierExpression { - id, - ty: PhantomData, +impl<'ast, T> RepeatExpression<'ast, T> { + pub fn new(e: TypedExpression<'ast, T>, count: UExpression<'ast, T>) -> Self { + RepeatExpression { + span: None, + e: Box::new(e), + count: Box::new(count), } } } -impl<'ast, E> fmt::Display for IdentifierExpression<'ast, E> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.id) +impl<'ast, T> WithSpan for RepeatExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +impl<'ast, T: fmt::Display> fmt::Display for RepeatExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}; {}]", self.e, self.count) + } +} + +pub type IdentifierExpression<'ast, E> = expressions::IdentifierExpression, E>; + +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct MemberExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub struc: Box>, pub id: MemberId, ty: PhantomData, @@ -972,7 +1185,8 @@ pub struct MemberExpression<'ast, T, E> { impl<'ast, T, E> MemberExpression<'ast, T, E> { pub fn new(struc: StructExpression<'ast, T>, id: MemberId) -> Self { MemberExpression { - struc: box struc, + span: None, + struc: Box::new(struc), id, ty: PhantomData, } @@ -985,8 +1199,12 @@ impl<'ast, T: fmt::Display, E> fmt::Display for MemberExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct SelectExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub array: Box>, pub index: Box>, ty: PhantomData, @@ -995,8 +1213,9 @@ pub struct SelectExpression<'ast, T, E> { impl<'ast, T, E> SelectExpression<'ast, T, E> { pub fn new(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { SelectExpression { - array: box array, - index: box index, + span: None, + array: Box::new(array), + index: Box::new(index), ty: PhantomData, } } @@ -1008,8 +1227,12 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct ElementExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub tuple: Box>, pub index: u32, ty: PhantomData, @@ -1018,7 +1241,8 @@ pub struct ElementExpression<'ast, T, E> { impl<'ast, T, E> ElementExpression<'ast, T, E> { pub fn new(tuple: TupleExpression<'ast, T>, index: u32) -> Self { ElementExpression { - tuple: box tuple, + span: None, + tuple: Box::new(tuple), index, ty: PhantomData, } @@ -1037,8 +1261,12 @@ pub enum ConditionalKind { Ternary, } -#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct ConditionalExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub condition: Box>, pub consequence: Box, pub alternative: Box, @@ -1053,9 +1281,10 @@ impl<'ast, T, E> ConditionalExpression<'ast, T, E> { kind: ConditionalKind, ) -> Self { ConditionalExpression { - condition: box condition, - consequence: box consequence, - alternative: box alternative, + span: None, + condition: Box::new(condition), + consequence: Box::new(consequence), + alternative: Box::new(alternative), kind, } } @@ -1080,8 +1309,12 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress } } -#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] +#[derive(Derivative)] +#[derivative(PartialOrd, PartialEq, Eq, Hash, Ord)] +#[derive(Clone, Debug)] pub struct FunctionCallExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub function_key: DeclarationFunctionKey<'ast, T>, pub generics: Vec>>, pub arguments: Vec>, @@ -1095,6 +1328,7 @@ impl<'ast, T, E> FunctionCallExpression<'ast, T, E> { arguments: Vec>, ) -> Self { FunctionCallExpression { + span: None, function_key, generics, arguments, @@ -1136,51 +1370,21 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T, #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum FieldElementExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), - Number(T), + Value(FieldValueExpression), Identifier(IdentifierExpression<'ast, Self>), - Add( - Box>, - Box>, - ), - Sub( - Box>, - Box>, - ), - Mult( - Box>, - Box>, - ), - Div( - Box>, - Box>, - ), - Pow( - Box>, - Box>, - ), - And( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - Xor( - Box>, - Box>, - ), - LeftShift( - Box>, - Box>, - ), - RightShift( - Box>, - Box>, - ), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), + Div(BinaryExpression), + Pow(BinaryExpression, Self>), + And(BinaryExpression), + Or(BinaryExpression), + Xor(BinaryExpression), + LeftShift(BinaryExpression, Self>), + RightShift(BinaryExpression, Self>), Conditional(ConditionalExpression<'ast, T, Self>), - Neg(Box>), - Pos(Box>), + Neg(UnaryExpression), + Pos(UnaryExpression), FunctionCall(FunctionCallExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>), @@ -1192,14 +1396,14 @@ impl<'ast, T: Field> From> for TupleExpression<'ast, T> { match assignee { TypedAssignee::Identifier(v) => { let inner = TupleExpression::identifier(v.id); - match v._type { + match v.ty { GType::Tuple(tuple_ty) => inner.annotate(tuple_ty), _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => TupleExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => TupleExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => TupleExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => TupleExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => TupleExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => TupleExpression::element((*a).into(), index), } } } @@ -1209,14 +1413,14 @@ impl<'ast, T: Field> From> for StructExpression<'ast, T> match assignee { TypedAssignee::Identifier(v) => { let inner = StructExpression::identifier(v.id); - match v._type { + match v.ty { GType::Struct(struct_ty) => inner.annotate(struct_ty), _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => StructExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => StructExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => StructExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => StructExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => StructExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => StructExpression::element((*a).into(), index), } } } @@ -1226,14 +1430,14 @@ impl<'ast, T: Field> From> for ArrayExpression<'ast, T> { match assignee { TypedAssignee::Identifier(v) => { let inner = ArrayExpression::identifier(v.id); - match v._type { - GType::Array(array_ty) => inner.annotate(*array_ty.ty, *array_ty.size), + match v.ty { + GType::Array(array_ty) => inner.annotate(array_ty), _ => unreachable!(), } } - TypedAssignee::Select(box a, box index) => ArrayExpression::select(a.into(), index), - TypedAssignee::Member(box a, id) => ArrayExpression::member(a.into(), id), - TypedAssignee::Element(box a, index) => ArrayExpression::element(a.into(), index), + TypedAssignee::Select(a, index) => ArrayExpression::select((*a).into(), *index), + TypedAssignee::Member(a, id) => ArrayExpression::member((*a).into(), id), + TypedAssignee::Element(a, index) => ArrayExpression::element((*a).into(), index), } } } @@ -1242,58 +1446,98 @@ impl<'ast, T: Field> From> for FieldElementExpression<'as fn from(assignee: TypedAssignee<'ast, T>) -> Self { match assignee { TypedAssignee::Identifier(v) => FieldElementExpression::identifier(v.id), - TypedAssignee::Element(box a, index) => { - FieldElementExpression::element(a.into(), index) - } - TypedAssignee::Member(box a, id) => FieldElementExpression::member(a.into(), id), - TypedAssignee::Select(box a, box index) => { - FieldElementExpression::select(a.into(), index) - } + TypedAssignee::Element(a, index) => FieldElementExpression::element((*a).into(), index), + TypedAssignee::Member(a, id) => FieldElementExpression::member((*a).into(), id), + TypedAssignee::Select(a, index) => FieldElementExpression::select((*a).into(), *index), } } } -impl<'ast, T> Add for FieldElementExpression<'ast, T> { +impl<'ast, T> std::ops::Add for FieldElementExpression<'ast, T> { type Output = Self; fn add(self, other: Self) -> Self { - FieldElementExpression::Add(box self, box other) + FieldElementExpression::Add(BinaryExpression::new(self, other)) } } -impl<'ast, T> Sub for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Sub for FieldElementExpression<'ast, T> { type Output = Self; fn sub(self, other: Self) -> Self { - FieldElementExpression::Sub(box self, box other) + FieldElementExpression::Sub(BinaryExpression::new(self, other)) } } -impl<'ast, T> Mul for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Mul for FieldElementExpression<'ast, T> { type Output = Self; fn mul(self, other: Self) -> Self { - FieldElementExpression::Mult(box self, box other) + FieldElementExpression::Mult(BinaryExpression::new(self, other)) } } -impl<'ast, T> Div for FieldElementExpression<'ast, T> { +impl<'ast, T: Field> std::ops::Div for FieldElementExpression<'ast, T> { type Output = Self; fn div(self, other: Self) -> Self { - FieldElementExpression::Div(box self, box other) + FieldElementExpression::Div(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitAnd for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + FieldElementExpression::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitOr for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + FieldElementExpression::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitXor for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + FieldElementExpression::Xor(BinaryExpression::new(self, other)) } } -impl<'ast, T> FieldElementExpression<'ast, T> { +impl<'ast, T> std::ops::Neg for FieldElementExpression<'ast, T> { + type Output = Self; + + fn neg(self) -> Self { + FieldElementExpression::Neg(UnaryExpression::new(self)) + } +} + +impl<'ast, T: Field> FieldElementExpression<'ast, T> { pub fn pow(self, other: UExpression<'ast, T>) -> Self { - FieldElementExpression::Pow(box self, box other) + FieldElementExpression::Pow(BinaryExpression::new(self, other)) + } + + pub fn pos(self) -> Self { + FieldElementExpression::Pos(UnaryExpression::new(self)) + } + + pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::LeftShift(BinaryExpression::new(self, by)) + } + + pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::RightShift(BinaryExpression::new(self, by)) } } -impl<'ast, T> From for FieldElementExpression<'ast, T> { +impl<'ast, T: Clone> From for FieldElementExpression<'ast, T> { fn from(n: T) -> Self { - FieldElementExpression::Number(n) + FieldElementExpression::Value(ValueExpression::new(n)) } } @@ -1302,42 +1546,41 @@ impl<'ast, T> From for FieldElementExpression<'ast, T> { pub enum BooleanExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), Identifier(IdentifierExpression<'ast, Self>), - Value(bool), + Value(BooleanValueExpression), FieldLt( - Box>, - Box>, + BinaryExpression< + OpLt, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldLe( - Box>, - Box>, - ), - FieldGe( - Box>, - Box>, + BinaryExpression< + OpLe, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - FieldGt( - Box>, - Box>, + UintLt(BinaryExpression, UExpression<'ast, T>, Self>), + UintLe(BinaryExpression, UExpression<'ast, T>, Self>), + FieldEq( + BinaryExpression< + OpEq, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - UintLt(Box>, Box>), - UintLe(Box>, Box>), - UintGe(Box>, Box>), - UintGt(Box>, Box>), - FieldEq(EqExpression>), - BoolEq(EqExpression>), - ArrayEq(EqExpression>), - StructEq(EqExpression>), - TupleEq(EqExpression>), - UintEq(EqExpression>), - Or( - Box>, - Box>, - ), - And( - Box>, - Box>, - ), - Not(Box>), + BoolEq(BinaryExpression, BooleanExpression<'ast, T>, Self>), + ArrayEq(BinaryExpression, ArrayExpression<'ast, T>, Self>), + StructEq(BinaryExpression, StructExpression<'ast, T>, Self>), + TupleEq(BinaryExpression, TupleExpression<'ast, T>, Self>), + UintEq(BinaryExpression, UExpression<'ast, T>, Self>), + Or(BinaryExpression, BooleanExpression<'ast, T>, Self>), + And(BinaryExpression, BooleanExpression<'ast, T>, Self>), + Not(UnaryExpression), Conditional(ConditionalExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), FunctionCall(FunctionCallExpression<'ast, T, Self>), @@ -1347,7 +1590,104 @@ pub enum BooleanExpression<'ast, T> { impl<'ast, T> From for BooleanExpression<'ast, T> { fn from(b: bool) -> Self { - BooleanExpression::Value(b) + BooleanExpression::Value(ValueExpression::new(b)) + } +} + +impl<'ast, T> std::ops::Not for BooleanExpression<'ast, T> { + type Output = Self; + + fn not(self) -> Self { + Self::Not(UnaryExpression::new(self)) + } +} + +impl<'ast, T> std::ops::BitAnd for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + Self::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> std::ops::BitOr for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + Self::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> BooleanExpression<'ast, T> { + pub fn uint_eq(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintEq(BinaryExpression::new(left, right)) + } + + pub fn bool_eq(left: BooleanExpression<'ast, T>, right: BooleanExpression<'ast, T>) -> Self { + Self::BoolEq(BinaryExpression::new(left, right)) + } + + pub fn field_eq( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldEq(BinaryExpression::new(left, right)) + } + + pub fn struct_eq(left: StructExpression<'ast, T>, right: StructExpression<'ast, T>) -> Self { + Self::StructEq(BinaryExpression::new(left, right)) + } + + pub fn array_eq(left: ArrayExpression<'ast, T>, right: ArrayExpression<'ast, T>) -> Self { + Self::ArrayEq(BinaryExpression::new(left, right)) + } + + pub fn tuple_eq(left: TupleExpression<'ast, T>, right: TupleExpression<'ast, T>) -> Self { + Self::TupleEq(BinaryExpression::new(left, right)) + } + + pub fn uint_lt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(left, right)) + } + + pub fn uint_le(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(left, right)) + } + + pub fn uint_gt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(right, left)) + } + + pub fn uint_ge(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(right, left)) + } + + pub fn field_lt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(left, right)) + } + + pub fn field_le( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(left, right)) + } + + pub fn field_gt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(right, left)) + } + + pub fn field_ge( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(right, left)) } } @@ -1362,25 +1702,34 @@ pub struct ArrayExpression<'ast, T> { pub inner: ArrayExpressionInner<'ast, T>, } -#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] -pub struct ArrayValue<'ast, T>(pub Vec>); +type ArrayValueExpression<'ast, T> = ValueExpression>>; -impl<'ast, T> From>> for ArrayValue<'ast, T> { - fn from(array: Vec>) -> Self { - Self(array) - } -} - -impl<'ast, T> IntoIterator for ArrayValue<'ast, T> { +impl<'ast, T> IntoIterator for ArrayValueExpression<'ast, T> { type Item = TypedExpressionOrSpread<'ast, T>; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() + self.value.into_iter() } } -impl<'ast, T: Field> ArrayValue<'ast, T> { +impl<'ast, T> Deref for ArrayValueExpression<'ast, T> { + type Target = [TypedExpressionOrSpread<'ast, T>]; + + fn deref(&self) -> &Self::Target { + &self.value[..] + } +} + +impl<'ast, T> std::iter::FromIterator> + for ArrayValueExpression<'ast, T> +{ + fn from_iter>>(iter: I) -> Self { + Self::new(iter.into_iter().collect()) + } +} + +impl<'ast, T: Field> ArrayValueExpression<'ast, T> { fn expression_at_aux< U: Select<'ast, T> + From> + Into>, >( @@ -1396,13 +1745,9 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { ArrayExpressionInner::Value(v) => { v.into_iter().flat_map(Self::expression_at_aux).collect() } - a => (0..size) + a => (0..size.value) .map(|i| { - Some(U::select( - a.clone() - .annotate(*array_ty.ty.clone(), *array_ty.size.clone()), - i as u32, - )) + Some(U::select(a.clone().annotate(array_ty.clone()), i as u32)) }) .collect(), } @@ -1418,8 +1763,7 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { &self, index: usize, ) -> Option { - self.0 - .iter() + self.iter() .flat_map(|v| Self::expression_at_aux(v.clone())) .take_while(|e| e.is_some()) .map(|e| e.unwrap()) @@ -1427,44 +1771,24 @@ impl<'ast, T: Field> ArrayValue<'ast, T> { } } -impl<'ast, T> ArrayValue<'ast, T> { - fn iter(&self) -> std::slice::Iter> { - self.0.iter() - } -} - -impl<'ast, T> std::iter::FromIterator> for ArrayValue<'ast, T> { - fn from_iter>>(iter: I) -> Self { - Self(iter.into_iter().collect()) - } -} - #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum ArrayExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, ArrayExpression<'ast, T>>), - Value(ArrayValue<'ast, T>), + Value(ArrayValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, ArrayExpression<'ast, T>>), Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>), Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>), Element(ElementExpression<'ast, T, ArrayExpression<'ast, T>>), - Slice( - Box>, - Box>, - Box>, - ), - Repeat(Box>, Box>), + Slice(SliceExpression<'ast, T>), + Repeat(RepeatExpression<'ast, T>), } impl<'ast, T> ArrayExpressionInner<'ast, T> { - pub fn annotate>>( - self, - ty: Type<'ast, T>, - size: S, - ) -> ArrayExpression<'ast, T> { + pub fn annotate(self, ty: ArrayType<'ast, T>) -> ArrayExpression<'ast, T> { ArrayExpression { - ty: box (ty, size.into()).into(), + ty: Box::new(ty), inner: self, } } @@ -1478,6 +1802,24 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { pub fn size(&self) -> UExpression<'ast, T> { *self.ty.size.clone() } + + pub fn slice( + array: ArrayExpression<'ast, T>, + from: UExpression<'ast, T>, + to: UExpression<'ast, T>, + ) -> Self { + let inner = array.inner_type().clone(); + let size = to.clone() - from.clone(); + let array_ty = ArrayType::new(inner, size); + ArrayExpressionInner::Slice(SliceExpression::new(array, from, to)).annotate(array_ty) + } + + pub fn repeat(e: TypedExpression<'ast, T>, count: UExpression<'ast, T>) -> Self { + let inner = e.get_type().clone(); + let size = count.clone(); + let array_ty = ArrayType::new(inner, size); + ArrayExpressionInner::Repeat(RepeatExpression::new(e, count)).annotate(array_ty) + } } #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] @@ -1486,6 +1828,31 @@ pub struct StructExpression<'ast, T> { inner: StructExpressionInner<'ast, T>, } +type StructValueExpression<'ast, T> = ValueExpression>>; + +impl<'ast, T> IntoIterator for StructValueExpression<'ast, T> { + type Item = TypedExpression<'ast, T>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.value.into_iter() + } +} + +impl<'ast, T> Deref for StructValueExpression<'ast, T> { + type Target = [TypedExpression<'ast, T>]; + + fn deref(&self) -> &Self::Target { + &self.value[..] + } +} + +impl<'ast, T> std::iter::FromIterator> for StructValueExpression<'ast, T> { + fn from_iter>>(iter: I) -> Self { + Self::new(iter.into_iter().collect()) + } +} + impl<'ast, T> StructExpression<'ast, T> { pub fn ty(&self) -> &StructType<'ast, T> { &self.ty @@ -1508,7 +1875,7 @@ impl<'ast, T> StructExpression<'ast, T> { pub enum StructExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, StructExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, StructExpression<'ast, T>>), - Value(Vec>), + Value(TupleValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, StructExpression<'ast, T>>), Member(MemberExpression<'ast, T, StructExpression<'ast, T>>), @@ -1546,11 +1913,13 @@ impl<'ast, T> TupleExpression<'ast, T> { } } +type TupleValueExpression<'ast, T> = ValueExpression>>; + #[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TupleExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, TupleExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, TupleExpression<'ast, T>>), - Value(Vec>), + Value(TupleValueExpression<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, TupleExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, TupleExpression<'ast, T>>), Member(MemberExpression<'ast, T, TupleExpression<'ast, T>>), @@ -1715,7 +2084,7 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for BlockExpression<'a .map(|s| s.to_string()) .chain(std::iter::once(self.value.to_string())) .collect::>() - .join("\n") + .join("\n"), ) } } @@ -1724,21 +2093,21 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { FieldElementExpression::Block(ref block) => write!(f, "{}", block), - FieldElementExpression::Number(ref i) => write!(f, "{}f", i), + FieldElementExpression::Value(ref i) => write!(f, "{}f", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), - FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), - FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), - FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), - FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - FieldElementExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - FieldElementExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), + FieldElementExpression::Add(ref e) => write!(f, "{}", e), + FieldElementExpression::Sub(ref e) => write!(f, "{}", e), + FieldElementExpression::Mult(ref e) => write!(f, "{}", e), + FieldElementExpression::Div(ref e) => write!(f, "{}", e), + FieldElementExpression::Pow(ref e) => write!(f, "{}", e), + FieldElementExpression::Neg(ref e) => write!(f, "{}", e), + FieldElementExpression::Pos(ref e) => write!(f, "{}", e), FieldElementExpression::Conditional(ref c) => write!(f, "{}", c), + FieldElementExpression::And(ref e) => write!(f, "{}", e), + FieldElementExpression::Or(ref e) => write!(f, "{}", e), + FieldElementExpression::Xor(ref e) => write!(f, "{}", e), + FieldElementExpression::LeftShift(ref e) => write!(f, "{}", e), + FieldElementExpression::RightShift(ref e) => write!(f, "{}", e), FieldElementExpression::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } @@ -1755,22 +2124,22 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Block(ref block) => write!(f, "{}", block,), UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), - UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::FloorSub(ref lhs, ref rhs) => { - write!(f, "(FLOOR_SUB({}, {}))", lhs, rhs) + UExpressionInner::Add(ref e) => write!(f, "{}", e), + UExpressionInner::And(ref e) => write!(f, "{}", e), + UExpressionInner::Or(ref e) => write!(f, "{}", e), + UExpressionInner::Xor(ref e) => write!(f, "{}", e), + UExpressionInner::Sub(ref e) => write!(f, "{}", e), + UExpressionInner::Mult(ref e) => write!(f, "{}", e), + UExpressionInner::FloorSub(ref e) => { + write!(f, "{}", e) } - UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - UExpressionInner::Not(ref e) => write!(f, "!{}", e), - UExpressionInner::Neg(ref e) => write!(f, "(-{})", e), - UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), + UExpressionInner::Div(ref e) => write!(f, "{}", e), + UExpressionInner::Rem(ref e) => write!(f, "{}", e), + UExpressionInner::RightShift(ref e) => write!(f, "{}", e), + UExpressionInner::LeftShift(ref e) => write!(f, "{}", e), + UExpressionInner::Not(ref e) => write!(f, "{}", e), + UExpressionInner::Neg(ref e) => write!(f, "{}", e), + UExpressionInner::Pos(ref e) => write!(f, "{}", e), UExpressionInner::Select(ref select) => write!(f, "{}", select), UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), UExpressionInner::Conditional(ref c) => write!(f, "{}", c), @@ -1782,27 +2151,23 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { + match &self { BooleanExpression::Block(ref block) => write!(f, "{}", block,), BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "({} >= {})", lhs, rhs), - BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "({} > {})", lhs, rhs), - BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "({} >= {})", lhs, rhs), - BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "({} > {})", lhs, rhs), + BooleanExpression::FieldLt(ref e) => write!(f, "{}", e), + BooleanExpression::FieldLe(ref e) => write!(f, "{}", e), + BooleanExpression::UintLt(ref e) => write!(f, "{}", e), + BooleanExpression::UintLe(ref e) => write!(f, "{}", e), BooleanExpression::FieldEq(ref e) => write!(f, "{}", e), BooleanExpression::BoolEq(ref e) => write!(f, "{}", e), BooleanExpression::ArrayEq(ref e) => write!(f, "{}", e), BooleanExpression::StructEq(ref e) => write!(f, "{}", e), BooleanExpression::TupleEq(ref e) => write!(f, "{}", e), BooleanExpression::UintEq(ref e) => write!(f, "{}", e), - BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs), - BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs), - BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), - BooleanExpression::Value(b) => write!(f, "{}", b), + BooleanExpression::Or(ref e) => write!(f, "{}", e), + BooleanExpression::And(ref e) => write!(f, "{}", e), + BooleanExpression::Not(ref exp) => write!(f, "{}", exp), + BooleanExpression::Value(ref b) => write!(f, "{}", b), BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call), BooleanExpression::Conditional(ref c) => write!(f, "{}", c), BooleanExpression::Member(ref m) => write!(f, "{}", m), @@ -1830,12 +2195,8 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { ArrayExpressionInner::Conditional(ref c) => write!(f, "{}", c), ArrayExpressionInner::Member(ref m) => write!(f, "{}", m), ArrayExpressionInner::Select(ref select) => write!(f, "{}", select), - ArrayExpressionInner::Slice(ref a, ref from, ref to) => { - write!(f, "{}[{}..{}]", a, from, to) - } - ArrayExpressionInner::Repeat(ref e, ref count) => { - write!(f, "[{}; {}]", e, count) - } + ArrayExpressionInner::Slice(ref e) => write!(f, "{}", e), + ArrayExpressionInner::Repeat(ref e) => write!(f, "{}", e), ArrayExpressionInner::Element(ref element) => write!(f, "{}", element), } } @@ -1848,9 +2209,7 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> { match v.get_type() { Type::FieldElement => FieldElementExpression::identifier(v.id).into(), Type::Boolean => BooleanExpression::identifier(v.id).into(), - Type::Array(ty) => ArrayExpression::identifier(v.id) - .annotate(*ty.ty, *ty.size) - .into(), + Type::Array(ty) => ArrayExpression::identifier(v.id).annotate(ty).into(), Type::Struct(ty) => StructExpression::identifier(v.id).annotate(ty).into(), Type::Tuple(ty) => TupleExpression::identifier(v.id).annotate(ty).into(), Type::Uint(w) => UExpression::identifier(v.id).annotate(w).into(), @@ -1859,10 +2218,507 @@ impl<'ast, T: Field> From> for TypedExpression<'ast, T> { } } +// TODO: MACROS + +impl<'ast, T> WithSpan for TypedExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use TypedExpression::*; + match self { + Boolean(e) => Boolean(e.span(span)), + FieldElement(e) => FieldElement(e.span(span)), + Uint(e) => Uint(e.span(span)), + Array(e) => Array(e.span(span)), + Struct(e) => Struct(e.span(span)), + Tuple(e) => Tuple(e.span(span)), + Int(e) => Int(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TypedExpression::*; + match self { + Boolean(e) => e.get_span(), + FieldElement(e) => e.get_span(), + Uint(e) => e.get_span(), + Array(e) => e.get_span(), + Struct(e) => e.get_span(), + Tuple(e) => e.get_span(), + Int(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for FieldElementExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use FieldElementExpression::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Add(e) => Add(e.span(span)), + Value(e) => Value(e.span(span)), + Mult(e) => Mult(e.span(span)), + Sub(e) => Sub(e.span(span)), + Pow(e) => Pow(e.span(span)), + Div(e) => Div(e.span(span)), + Pos(e) => Pos(e.span(span)), + Neg(e) => Neg(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Xor(e) => Xor(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FieldElementExpression::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Add(e) => e.get_span(), + Value(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Pow(e) => e.get_span(), + Neg(e) => e.get_span(), + Pos(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Xor(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for BooleanExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use BooleanExpression::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + FieldLt(e) => FieldLt(e.span(span)), + FieldLe(e) => FieldLe(e.span(span)), + UintLt(e) => UintLt(e.span(span)), + UintLe(e) => UintLe(e.span(span)), + FieldEq(e) => FieldEq(e.span(span)), + BoolEq(e) => BoolEq(e.span(span)), + ArrayEq(e) => ArrayEq(e.span(span)), + StructEq(e) => StructEq(e.span(span)), + TupleEq(e) => TupleEq(e.span(span)), + UintEq(e) => UintEq(e.span(span)), + Or(e) => Or(e.span(span)), + And(e) => And(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use BooleanExpression::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + FieldLt(e) => e.get_span(), + FieldLe(e) => e.get_span(), + UintLt(e) => e.get_span(), + UintLe(e) => e.get_span(), + FieldEq(e) => e.get_span(), + BoolEq(e) => e.get_span(), + ArrayEq(e) => e.get_span(), + StructEq(e) => e.get_span(), + TupleEq(e) => e.get_span(), + UintEq(e) => e.get_span(), + Or(e) => e.get_span(), + And(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for UExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use UExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + FloorSub(e) => FloorSub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Not(e) => Not(e.span(span)), + Neg(e) => Neg(e.span(span)), + Pos(e) => Pos(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use UExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + FloorSub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Not(e) => e.get_span(), + Neg(e) => e.get_span(), + Pos(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for ArrayExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use ArrayExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + Slice(e) => Slice(e.span(span)), + Repeat(e) => Repeat(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ArrayExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + Slice(e) => e.get_span(), + Repeat(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for StructExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use StructExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use StructExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for TupleExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use TupleExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Block(e) => Block(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + FunctionCall(e) => FunctionCall(e.span(span)), + Member(e) => Member(e.span(span)), + Element(e) => Element(e.span(span)), + Value(e) => Value(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use TupleExpressionInner::*; + match self { + Select(e) => e.get_span(), + Block(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + FunctionCall(e) => e.get_span(), + Member(e) => e.get_span(), + Element(e) => e.get_span(), + Value(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for IntExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use IntExpression::*; + match self { + Conditional(e) => Conditional(e.span(span)), + Select(e) => Select(e.span(span)), + Value(e) => Value(e.span(span)), + Pos(e) => Pos(e.span(span)), + Neg(e) => Neg(e.span(span)), + Not(e) => Not(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Pow(e) => Pow(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use IntExpression::*; + match self { + Conditional(e) => e.get_span(), + Select(e) => e.get_span(), + Value(e) => e.get_span(), + Pos(e) => e.get_span(), + Neg(e) => e.get_span(), + Not(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Pow(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for TupleExpression<'ast, T> { + fn span(self, span: Option) -> Self { + TupleExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for StructExpression<'ast, T> { + fn span(self, span: Option) -> Self { + StructExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for ArrayExpression<'ast, T> { + fn span(self, span: Option) -> Self { + ArrayExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for UExpression<'ast, T> { + fn span(self, span: Option) -> Self { + UExpression { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T, E> WithSpan for ConditionalExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl WithSpan for ValueExpression { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for SelectExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for ElementExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for MemberExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for FunctionCallExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T, E> WithSpan for BlockExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: Clone> Value for FieldElementExpression<'ast, T> { + type Value = T; +} + +impl<'ast, T> Value for BooleanExpression<'ast, T> { + type Value = bool; +} + +impl<'ast, T> Value for UExpression<'ast, T> { + type Value = u128; +} + +impl<'ast, T: Clone> Value for ArrayExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T: Clone> Value for StructExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T: Clone> Value for TupleExpression<'ast, T> { + type Value = Vec>; +} + +impl<'ast, T> Value for IntExpression<'ast, T> { + type Value = BigUint; +} + // Common behaviour across expressions -pub trait Expr<'ast, T>: fmt::Display + From> { - type Inner; +pub trait Expr<'ast, T>: Value + WithSpan + fmt::Display + From> { + type Inner: WithSpan; type Ty: Clone + IntoType>; type ConcreteTy: Clone + IntoType; @@ -1873,9 +2729,11 @@ pub trait Expr<'ast, T>: fmt::Display + From> { fn as_inner(&self) -> &Self::Inner; fn as_inner_mut(&mut self) -> &mut Self::Inner; + + fn value(_: Self::Value) -> Self::Inner; } -impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -1895,9 +2753,13 @@ impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for BooleanExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -1917,9 +2779,13 @@ impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for UExpression<'ast, T> { type Inner = UExpressionInner<'ast, T>; type Ty = UBitwidth; type ConcreteTy = UBitwidth; @@ -1939,9 +2805,13 @@ impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + UExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for StructExpression<'ast, T> { type Inner = StructExpressionInner<'ast, T>; type Ty = StructType<'ast, T>; type ConcreteTy = ConcreteStructType; @@ -1961,9 +2831,13 @@ impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + StructExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone + fmt::Display> Expr<'ast, T> for ArrayExpression<'ast, T> { type Inner = ArrayExpressionInner<'ast, T>; type Ty = ArrayType<'ast, T>; type ConcreteTy = ConcreteArrayType; @@ -1983,9 +2857,13 @@ impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + ArrayExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for TupleExpression<'ast, T> { type Inner = TupleExpressionInner<'ast, T>; type Ty = TupleType<'ast, T>; type ConcreteTy = ConcreteTupleType; @@ -2005,9 +2883,13 @@ impl<'ast, T: Field> Expr<'ast, T> for TupleExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + TupleExpressionInner::Value(ValueExpression::new(v)) + } } -impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> { +impl<'ast, T: fmt::Display + Clone> Expr<'ast, T> for IntExpression<'ast, T> { type Inner = Self; type Ty = Type<'ast, T>; type ConcreteTy = ConcreteType; @@ -2027,6 +2909,10 @@ impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: as Value>::Value) -> Self { + IntExpression::Value(ValueExpression::new(v)) + } } // Enums types to enable returning e.g a member expression OR another type of expression of this type @@ -2039,20 +2925,19 @@ pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> { Select(SelectExpression<'ast, T, E>), Expression(E::Inner), } - -pub enum EqOrBoolean<'ast, T, E> { - Eq(EqExpression), - Boolean(BooleanExpression<'ast, T>), -} - pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> { Member(MemberExpression<'ast, T, E>), Expression(E::Inner), } -pub enum IdentifierOrExpression<'ast, T, E: Expr<'ast, T>> { - Identifier(IdentifierExpression<'ast, E>), - Expression(E::Inner), +pub enum RepeatOrExpression<'ast, T> { + Repeat(RepeatExpression<'ast, T>), + Expression(ArrayExpressionInner<'ast, T>), +} + +pub enum SliceOrExpression<'ast, T> { + Slice(SliceExpression<'ast, T>), + Expression(ArrayExpressionInner<'ast, T>), } pub enum ElementOrExpression<'ast, T, E: Expr<'ast, T>> { @@ -2065,6 +2950,22 @@ pub enum ConditionalOrExpression<'ast, T, E: Expr<'ast, T>> { Expression(E::Inner), } +pub trait Basic<'ast, T> { + type ZirExpressionType: WithSpan + Value + From>; +} + +impl<'ast, T: Clone> Basic<'ast, T> for FieldElementExpression<'ast, T> { + type ZirExpressionType = crate::zir::FieldElementExpression<'ast, T>; +} + +impl<'ast, T> Basic<'ast, T> for BooleanExpression<'ast, T> { + type ZirExpressionType = crate::zir::BooleanExpression<'ast, T>; +} + +impl<'ast, T> Basic<'ast, T> for UExpression<'ast, T> { + type ZirExpressionType = crate::zir::UExpression<'ast, T>; +} + pub trait Conditional<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, @@ -2141,22 +3042,21 @@ impl<'ast, T> Conditional<'ast, T> for UExpression<'ast, T> { } } -impl<'ast, T: Clone> Conditional<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone + fmt::Display> Conditional<'ast, T> for ArrayExpression<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self, kind: ConditionalKind, ) -> Self { - let ty = consequence.inner_type().clone(); - let size = consequence.size(); + let ty = consequence.ty().clone(); ArrayExpressionInner::Conditional(ConditionalExpression::new( condition, consequence, alternative, kind, )) - .annotate(ty, size) + .annotate(ty) } } @@ -2245,13 +3145,9 @@ impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> { fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { - let (ty, size) = match array.inner_type() { - Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()), - _ => unreachable!(), - }; + let array_ty = array.inner_type().clone().try_into().unwrap(); - ArrayExpressionInner::Select(SelectExpression::new(array, index.into())) - .annotate(*ty, *size) + ArrayExpressionInner::Select(SelectExpression::new(array, index.into())).annotate(array_ty) } } @@ -2295,56 +3191,60 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> { impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let bitwidth = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Uint(bitwidth), - .. - }) => *bitwidth, - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let bitwidth: UBitwidth = ty.try_into().unwrap(); UExpressionInner::Member(MemberExpression::new(s, id)).annotate(bitwidth) } } impl<'ast, T: Clone> Member<'ast, T> for ArrayExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let (ty, size) = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Array(array_ty), - .. - }) => (*array_ty.ty.clone(), array_ty.size.clone()), - _ => unreachable!(), - }; - ArrayExpressionInner::Member(MemberExpression::new(s, id)).annotate(ty, *size) + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let array_ty: ArrayType<'ast, T> = ty.try_into().unwrap(); + ArrayExpressionInner::Member(MemberExpression::new(s, id)).annotate(array_ty) } } impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let struct_ty = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Struct(struct_ty), - .. - }) => struct_ty.clone(), - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let struct_ty = ty.try_into().unwrap(); StructExpressionInner::Member(MemberExpression::new(s, id)).annotate(struct_ty) } } impl<'ast, T: Clone> Member<'ast, T> for TupleExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, id: MemberId) -> Self { - let ty = s.ty().members.iter().find(|member| id == member.id); - let tuple_ty = match ty { - Some(crate::typed::types::StructMember { - ty: box Type::Tuple(tuple_ty), - .. - }) => tuple_ty.clone(), - _ => unreachable!(), - }; + let ty = *s + .ty() + .members + .iter() + .find(|member| id == member.id) + .unwrap() + .ty + .clone(); + let tuple_ty = ty.try_into().unwrap(); TupleExpressionInner::Member(MemberExpression::new(s, id)).annotate(tuple_ty) } } @@ -2378,12 +3278,9 @@ impl<'ast, T: Clone> Element<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Clone> Element<'ast, T> for ArrayExpression<'ast, T> { fn element(s: TupleExpression<'ast, T>, id: u32) -> Self { - let ty = &s.ty().elements[id as usize]; - let (ty, size) = match ty { - Type::Array(array_ty) => (*array_ty.ty.clone(), array_ty.size.clone()), - _ => unreachable!(), - }; - ArrayExpressionInner::Element(ElementExpression::new(s, id)).annotate(ty, *size) + let ty = s.ty().elements[id as usize].clone(); + let array_ty = ty.try_into().unwrap(); + ArrayExpressionInner::Element(ElementExpression::new(s, id)).annotate(array_ty) } } @@ -2543,8 +3440,7 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { let array_ty = value.ty().clone(); - ArrayExpressionInner::Block(BlockExpression::new(statements, value)) - .annotate(*array_ty.ty, *array_ty.size) + ArrayExpressionInner::Block(BlockExpression::new(statements, value)).annotate(array_ty) } } @@ -2576,7 +3472,7 @@ pub trait Constant: Sized { impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { fn is_constant(&self) -> bool { - matches!(self, FieldElementExpression::Number(..)) + matches!(self, FieldElementExpression::Value(..)) } } @@ -2595,16 +3491,14 @@ impl<'ast, T: Field> Constant for UExpression<'ast, T> { impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { fn is_constant(&self) -> bool { match self.as_inner() { - ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { + ArrayExpressionInner::Value(v) => v.iter().all(|e| match e { TypedExpressionOrSpread::Expression(e) => e.is_constant(), TypedExpressionOrSpread::Spread(s) => s.array.is_constant(), }), - ArrayExpressionInner::Slice(box a, box from, box to) => { - from.is_constant() && to.is_constant() && a.is_constant() - } - ArrayExpressionInner::Repeat(box e, box count) => { - count.is_constant() && e.is_constant() + ArrayExpressionInner::Slice(e) => { + e.from.is_constant() && e.to.is_constant() && e.array.is_constant() } + ArrayExpressionInner::Repeat(e) => e.count.is_constant() && e.e.is_constant(), _ => false, } } @@ -2620,35 +3514,35 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { .into_iter() .flat_map(into_canonical_constant_aux) .collect(), - ArrayExpressionInner::Slice(box v, box from, box to) => { - let from = match from.into_inner() { + ArrayExpressionInner::Slice(e) => { + let from = match e.from.into_inner() { UExpressionInner::Value(v) => v, _ => unreachable!(), }; - let to = match to.into_inner() { + let to = match e.to.into_inner() { UExpressionInner::Value(v) => v, _ => unreachable!(), }; - let v = match v.into_inner() { + let v = match e.array.into_inner() { ArrayExpressionInner::Value(v) => v, _ => unreachable!(), }; v.into_iter() .flat_map(into_canonical_constant_aux) - .skip(from as usize) - .take(to as usize - from as usize) + .skip(from.value as usize) + .take(to.value as usize - from.value as usize) .collect() } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { + ArrayExpressionInner::Repeat(e) => { + let count = match e.count.into_inner() { UExpressionInner::Value(count) => count, _ => unreachable!(), }; - vec![e.into_canonical_constant(); count as usize] + vec![e.e.into_canonical_constant(); count.value as usize] } a => unreachable!("{}", a), }, @@ -2658,53 +3552,49 @@ impl<'ast, T: Field> Constant for ArrayExpression<'ast, T> { let array_ty = self.ty().clone(); match self.into_inner() { - ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( + ArrayExpressionInner::Value(v) => ArrayExpression::value( v.into_iter() .flat_map(into_canonical_constant_aux) .map(|e| e.into()) - .collect::>() - .into(), + .collect::>(), ) - .annotate(*array_ty.ty, *array_ty.size), - ArrayExpressionInner::Slice(box a, box from, box to) => { - let from = match from.into_inner() { - UExpressionInner::Value(from) => from as usize, + .annotate(array_ty), + ArrayExpressionInner::Slice(e) => { + let from = match e.from.into_inner() { + UExpressionInner::Value(from) => from.value as usize, _ => unreachable!("should be a uint value"), }; - let to = match to.into_inner() { - UExpressionInner::Value(to) => to as usize, + let to = match e.to.into_inner() { + UExpressionInner::Value(to) => to.value as usize, _ => unreachable!("should be a uint value"), }; - let v = match a.into_inner() { + let v = match e.array.into_inner() { ArrayExpressionInner::Value(v) => v, _ => unreachable!("should be an array value"), }; - ArrayExpressionInner::Value( + ArrayExpression::value( v.into_iter() .flat_map(into_canonical_constant_aux) .map(|e| e.into()) .skip(from) .take(to - from) - .collect::>() - .into(), + .collect::>(), ) - .annotate(*array_ty.ty, *array_ty.size) + .annotate(array_ty) } - ArrayExpressionInner::Repeat(box e, box count) => { - let count = match count.into_inner() { - UExpressionInner::Value(from) => from as usize, + ArrayExpressionInner::Repeat(e) => { + let count = match e.count.into_inner() { + UExpressionInner::Value(from) => from.value as usize, _ => unreachable!("should be a uint value"), }; - let e = e.into_canonical_constant(); + let e = e.e.into_canonical_constant(); - ArrayExpressionInner::Value( - vec![TypedExpressionOrSpread::Expression(e); count].into(), - ) - .annotate(*array_ty.ty, *array_ty.size) + ArrayExpression::value(vec![TypedExpressionOrSpread::Expression(e); count]) + .annotate(array_ty) } _ => unreachable!(), } @@ -2723,7 +3613,7 @@ impl<'ast, T: Field> Constant for StructExpression<'ast, T> { let struct_ty = self.ty().clone(); match self.into_inner() { - StructExpressionInner::Value(expressions) => StructExpressionInner::Value( + StructExpressionInner::Value(expressions) => StructExpression::value( expressions .into_iter() .map(|e| e.into_canonical_constant()) @@ -2747,7 +3637,7 @@ impl<'ast, T: Field> Constant for TupleExpression<'ast, T> { let tuple_ty = self.ty().clone(); match self.into_inner() { - TupleExpressionInner::Value(expressions) => TupleExpressionInner::Value( + TupleExpressionInner::Value(expressions) => TupleExpression::value( expressions .into_iter() .map(|e| e.into_canonical_constant()) diff --git a/zokrates_ast/src/typed/parameter.rs b/zokrates_ast/src/typed/parameter.rs index 45b0dcae4..4b8332eb1 100644 --- a/zokrates_ast/src/typed/parameter.rs +++ b/zokrates_ast/src/typed/parameter.rs @@ -1,33 +1,5 @@ use crate::typed::types::DeclarationConstant; use crate::typed::GVariable; -use std::fmt; - -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct GParameter<'ast, S> { - pub id: GVariable<'ast, S>, - pub private: bool, -} - -impl<'ast, S> From> for GParameter<'ast, S> { - fn from(v: GVariable<'ast, S>) -> Self { - GParameter { - id: v, - private: true, - } - } -} +pub type GParameter<'ast, S> = crate::common::Parameter>; pub type DeclarationParameter<'ast, T> = GParameter<'ast, DeclarationConstant<'ast, T>>; - -impl<'ast, S: fmt::Display> fmt::Display for GParameter<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{} {}", visibility, self.id._type, self.id.id) - } -} - -impl<'ast, S: fmt::Debug> fmt::Debug for GParameter<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(variable: {:?})", self.id) - } -} diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 8ed911314..d8541add0 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -1,47 +1,57 @@ // Generic walk through a typed AST. Not mutating in place +use crate::common::expressions::{ + BinaryOrExpression, EqExpression, UnaryOrExpression, ValueOrExpression, +}; +use crate::common::ResultFold; use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; use super::identifier::FrameIdentifier; -pub trait ResultFold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Result; -} - -impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for FieldElementExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_field_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for BooleanExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Result { f.fold_uint_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for ArrayExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { - f.fold_array_expression(self) +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for StructExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { + f.fold_struct_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for StructExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { - f.fold_struct_expression(self) +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for ArrayExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { + f.fold_array_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for TupleExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for TupleExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_tuple_expression(self) } } @@ -153,7 +163,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { }) } - fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result { + fn fold_module_id(&mut self, i: OwnedModuleId) -> Result { Ok(i) } @@ -170,22 +180,19 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result, Self::Error> { - Ok(Variable { - id: self.fold_name(v.id)?, - _type: self.fold_type(v._type)?, - is_mutable: v.is_mutable, - }) + let span = v.get_span(); + Ok(Variable::new(self.fold_name(v.id)?, self.fold_type(v.ty)?).span(span)) } fn fold_declaration_variable( &mut self, v: DeclarationVariable<'ast, T>, ) -> Result, Self::Error> { - Ok(DeclarationVariable { - id: self.fold_name(v.id)?, - _type: self.fold_declaration_type(v._type)?, - is_mutable: v.is_mutable, - }) + let span = v.get_span(); + Ok( + DeclarationVariable::new(self.fold_name(v.id)?, self.fold_declaration_type(v.ty)?) + .span(span), + ) } fn fold_type(&mut self, t: Type<'ast, T>) -> Result, Self::Error> { @@ -200,7 +207,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_conditional_expression< - E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold<'ast, T>, + E: Expr<'ast, T> + PartialEq + Conditional<'ast, T> + ResultFold, >( &mut self, ty: &E::Ty, @@ -209,14 +216,55 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_block_expression>( + #[allow(clippy::type_complexity)] + fn fold_binary_expression< + L: Expr<'ast, T> + PartialEq + ResultFold, + R: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op: OperatorStr, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> Result, Self::Error> { + fold_binary_expression(self, ty, e) + } + + #[allow(clippy::type_complexity)] + fn fold_eq_expression< + E: Expr<'ast, T> + Constant + Typed<'ast, T> + PartialEq + ResultFold, + >( + &mut self, + e: EqExpression>, + ) -> Result< + BinaryOrExpression, BooleanExpression<'ast, T>>, + Self::Error, + > { + fold_binary_expression(self, &Type::Boolean, e) + } + + fn fold_unary_expression< + In: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> Result, Self::Error> { + fold_unary_expression(self, ty, e) + } + + fn fold_block_expression>( &mut self, block: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { fold_block_expression(self, block) } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, ty: &E::Ty, id: IdentifierExpression<'ast, E>, @@ -234,6 +282,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_member_expression(self, ty, e) } + fn fold_slice_expression( + &mut self, + e: SliceExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_slice_expression(self, e) + } + + fn fold_repeat_expression( + &mut self, + e: RepeatExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_repeat_expression(self, e) + } + fn fold_element_expression< E: Expr<'ast, T> + Element<'ast, T> + From>, >( @@ -244,15 +306,6 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_element_expression(self, ty, e) } - fn fold_eq_expression< - E: Expr<'ast, T> + Typed<'ast, T> + PartialEq + Constant + ResultFold<'ast, T>, - >( - &mut self, - e: EqExpression, - ) -> Result, Self::Error> { - fold_eq_expression(self, e) - } - fn fold_select_expression< E: Expr<'ast, T> + Select<'ast, T> @@ -281,8 +334,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { t: ArrayType<'ast, T>, ) -> Result, Self::Error> { Ok(ArrayType { - ty: box self.fold_type(*t.ty)?, - size: box self.fold_uint_expression(*t.size)?, + ty: Box::new(self.fold_type(*t.ty)?), + size: Box::new(self.fold_uint_expression(*t.size)?), }) } @@ -314,8 +367,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized { .into_iter() .map(|m| { let id = m.id; - self.fold_type(*m.ty) - .map(|ty| StructMember { ty: box ty, id }) + self.fold_type(*m.ty).map(|ty| StructMember { + ty: Box::new(ty), + id, + }) }) .collect::>()?, ..t @@ -341,8 +396,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { t: DeclarationArrayType<'ast, T>, ) -> Result, Self::Error> { Ok(DeclarationArrayType { - ty: box self.fold_declaration_type(*t.ty)?, - size: box self.fold_declaration_constant(*t.size)?, + ty: Box::new(self.fold_declaration_type(*t.ty)?), + size: Box::new(self.fold_declaration_constant(*t.size)?), }) } @@ -375,7 +430,10 @@ pub trait ResultFolder<'ast, T: Field>: Sized { .map(|m| { let id = m.id; self.fold_declaration_type(*m.ty) - .map(|ty| DeclarationStructMember { ty: box ty, id }) + .map(|ty| DeclarationStructMember { + ty: Box::new(ty), + id, + }) }) .collect::>()?, ..t @@ -389,6 +447,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assignee(self, a) } + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_block(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_assignment(self, s) + } + + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_constraint(self, s) + } + fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -396,6 +475,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -403,6 +489,48 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_return_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assertion_statement(self, s) + } + + fn fold_log_statement( + &mut self, + s: LogStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_log_statement(self, s) + } + + fn fold_for_statement( + &mut self, + s: ForStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_for_statement(self, s) + } + fn fold_definition_rhs( &mut self, rhs: DefinitionRhs<'ast, T>, @@ -481,12 +609,28 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_field_expression(self, e) } + + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, ) -> Result, Self::Error> { fold_boolean_expression(self, e) } + + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression_cases(self, e) + } + fn fold_uint_expression( &mut self, e: UExpression<'ast, T>, @@ -502,6 +646,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_uint_expression_inner(self, bitwidth, e) } + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_cases(self, bitwidth, e) + } + fn fold_array_expression_inner( &mut self, ty: &ArrayType<'ast, T>, @@ -509,6 +661,15 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_array_expression_inner(self, ty, e) } + + fn fold_array_expression_cases( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression_cases(self, ty, e) + } + fn fold_struct_expression_inner( &mut self, ty: &StructType<'ast, T>, @@ -517,6 +678,14 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_struct_expression_inner(self, ty, e) } + fn fold_struct_expression_cases( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression_cases(self, ty, e) + } + fn fold_tuple_expression_inner( &mut self, ty: &TupleType<'ast, T>, @@ -524,69 +693,186 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_tuple_expression_inner(self, ty, e) } + + fn fold_tuple_expression_cases( + &mut self, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_tuple_expression_cases(self, ty, e) + } + + fn fold_struct_value_expression( + &mut self, + ty: &StructType<'ast, T>, + v: StructValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, StructExpressionInner<'ast, T>>, + Self::Error, + > { + fold_struct_value_expression(self, ty, v) + } + + fn fold_array_value_expression( + &mut self, + ty: &ArrayType<'ast, T>, + v: ArrayValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, ArrayExpressionInner<'ast, T>>, + Self::Error, + > { + fold_array_value_expression(self, ty, v) + } + + fn fold_tuple_value_expression( + &mut self, + ty: &TupleType<'ast, T>, + v: TupleValueExpression<'ast, T>, + ) -> Result< + ValueOrExpression, TupleExpressionInner<'ast, T>>, + Self::Error, + > { + fold_tuple_value_expression(self, ty, v) + } } -pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + TypedStatement::Return(s) => f.fold_return_statement(s), + TypedStatement::Definition(s) => f.fold_definition_statement(s), + TypedStatement::Assertion(s) => f.fold_assertion_statement(s), + TypedStatement::For(s) => f.fold_for_statement(s), + TypedStatement::Log(s) => f.fold_log_statement(s), + TypedStatement::Assembly(s) => f.fold_assembly_block(s), + } +} + +pub fn fold_assembly_block<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Result>, F::Error> { + let assignee = f.fold_assignee(s.assignee)?; + let expression = f.fold_expression(s.expression)?; + Ok(vec![TypedAssemblyStatement::assignment( + assignee, expression, + )]) +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Result>, F::Error> { + let left = f.fold_field_expression(s.left)?; + let right = f.fold_field_expression(s.right)?; + Ok(vec![TypedAssemblyStatement::constraint( + left, right, s.metadata, + )]) +} + +fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedAssemblyStatement<'ast, T>, ) -> Result>, F::Error> { - Ok(match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = f.fold_expression(e)?; - vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - f.fold_field_expression(lhs)?, - f.fold_field_expression(rhs)?, - metadata, - )] - } - }) + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + TypedAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + TypedAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), + } } pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, ) -> Result>, F::Error> { - let res = match s { - TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?), - TypedStatement::Definition(a, e) => { - let rhs = f.fold_definition_rhs(e)?; - TypedStatement::Definition(f.fold_assignee(a)?, rhs) - } - TypedStatement::Assertion(e, error) => { - TypedStatement::Assertion(f.fold_boolean_expression(e)?, error) - } - TypedStatement::For(v, from, to, statements) => TypedStatement::For( - f.fold_variable(v)?, - f.fold_uint_expression(from)?, - f.fold_uint_expression(to)?, - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ), - TypedStatement::Log(s, e) => TypedStatement::Log( - s, - e.into_iter() - .map(|e| f.fold_expression(e)) - .collect::, _>>()?, - ), - TypedStatement::Assembly(statements) => TypedStatement::Assembly( - statements - .into_iter() - .map(|s| f.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ), - }; - Ok(vec![res]) + let span = s.get_span(); + f.fold_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_return_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Return(ReturnStatement::new( + f.fold_expression(s.inner)?, + ))]) +} + +pub fn fold_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let rhs = f.fold_definition_rhs(s.rhs)?; + Ok(vec![TypedStatement::Definition(DefinitionStatement::new( + f.fold_assignee(s.assignee)?, + rhs, + ))]) +} + +pub fn fold_assertion_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression)?, s.error).span(s.span), + )]) +} + +pub fn fold_for_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ForStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::For(ForStatement::new( + f.fold_variable(s.var)?, + f.fold_uint_expression(s.from)?, + f.fold_uint_expression(s.to)?, + s.statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_log_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![TypedStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ))]) } pub fn fold_definition_rhs<'ast, T: Field, F: ResultFolder<'ast, T>>( @@ -613,7 +899,16 @@ pub fn fold_embed_call<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_array_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_array_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, @@ -626,16 +921,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Identifier(i) => ArrayExpressionInner::Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression_or_spread(e)) - .collect::>()?, - ), FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, }, + Value(value) => match f.fold_array_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(ty, c)? { ConditionalOrExpression::Conditional(c) => Conditional(c), ConditionalOrExpression::Expression(u) => u, @@ -648,17 +941,14 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(m) => Select(m), SelectOrExpression::Expression(u) => u, }, - Slice(box array, box from, box to) => { - let array = f.fold_array_expression(array)?; - let from = f.fold_uint_expression(from)?; - let to = f.fold_uint_expression(to)?; - Slice(box array, box from, box to) - } - Repeat(box e, box count) => { - let e = f.fold_expression(e)?; - let count = f.fold_uint_expression(count)?; - Repeat(box e, box count) - } + Slice(s) => match f.fold_slice_expression(s)? { + SliceOrExpression::Slice(m) => Slice(m), + SliceOrExpression::Expression(u) => u, + }, + Repeat(s) => match f.fold_repeat_expression(s)? { + RepeatOrExpression::Repeat(m) => Repeat(m), + RepeatOrExpression::Expression(u) => u, + }, Element(element) => match f.fold_element_expression(ty, element)? { ElementOrExpression::Element(m) => Element(m), ElementOrExpression::Expression(u) => u, @@ -673,18 +963,78 @@ pub fn fold_assignee<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { match a { TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(f.fold_variable(v)?)), - TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( - box f.fold_assignee(a)?, - box f.fold_uint_expression(index)?, + TypedAssignee::Select(a, index) => Ok(TypedAssignee::Select( + Box::new(f.fold_assignee(*a)?), + Box::new(f.fold_uint_expression(*index)?), + )), + TypedAssignee::Member(s, m) => Ok(TypedAssignee::Member(Box::new(f.fold_assignee(*s)?), m)), + TypedAssignee::Element(s, index) => Ok(TypedAssignee::Element( + Box::new(f.fold_assignee(*s)?), + index, )), - TypedAssignee::Member(box s, m) => Ok(TypedAssignee::Member(box f.fold_assignee(s)?, m)), - TypedAssignee::Element(box s, index) => { - Ok(TypedAssignee::Element(box f.fold_assignee(s)?, index)) - } } } -pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_struct_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &StructType<'ast, T>, + a: StructValueExpression<'ast, T>, +) -> Result< + ValueOrExpression, StructExpressionInner<'ast, T>>, + F::Error, +> { + Ok(ValueOrExpression::Value(StructValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression(v)) + .collect::, _>>()?, + ..a + })) +} + +pub fn fold_array_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &ArrayType<'ast, T>, + a: ArrayValueExpression<'ast, T>, +) -> Result, ArrayExpressionInner<'ast, T>>, F::Error> +{ + Ok(ValueOrExpression::Value(ArrayValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression_or_spread(v)) + .collect::, _>>()?, + ..a + })) +} + +pub fn fold_tuple_value_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &TupleType<'ast, T>, + a: TupleValueExpression<'ast, T>, +) -> Result, TupleExpressionInner<'ast, T>>, F::Error> +{ + Ok(ValueOrExpression::Value(TupleValueExpression { + value: a + .value + .into_iter() + .map(|v| f.fold_expression(v)) + .collect::, _>>()?, + ..a + })) +} + +fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_struct_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_struct_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, @@ -697,12 +1047,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)?), - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ), + Value(value) => match f.fold_struct_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, @@ -727,7 +1075,16 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } -pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: &TupleType<'ast, T>, + e: TupleExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_tuple_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_tuple_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &TupleType<'ast, T>, e: TupleExpressionInner<'ast, T>, @@ -740,12 +1097,10 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Identifier(i) => Identifier(i), IdentifierOrExpression::Expression(u) => u, }, - Value(exprs) => Value( - exprs - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ), + Value(value) => match f.fold_tuple_value_expression(ty, value)? { + ValueOrExpression::Value(c) => Value(c), + ValueOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => match f.fold_function_call_expression(ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), FunctionCallOrExpression::Expression(u) => u, @@ -770,7 +1125,15 @@ pub fn fold_tuple_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } -pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_field_expression_cases(e).map(|e| e.span(span)) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> Result, F::Error> { @@ -782,72 +1145,55 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( IdentifierOrExpression::Expression(u) => u, }, Block(block) => Block(f.fold_block_expression(block)?), - Number(n) => Number(n), - Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Add(box e1, box e2) - } - Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Sub(box e1, box e2) - } - Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Mult(box e1, box e2) - } - Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - Div(box e1, box e2) - } - Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - Pow(box e1, box e2) - } - Neg(box e) => { - let e = f.fold_field_expression(e)?; - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_field_expression(e)?; - - Pos(box e) - } - And(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - Or(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - Xor(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - RightShift(box e, box by) - } + Value(n) => Value(n), + Add(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Pow(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&Type::FieldElement, e)? { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&Type::FieldElement, e)? { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, Conditional(c) => match f.fold_conditional_expression(&Type::FieldElement, c)? { ConditionalOrExpression::Conditional(c) => Conditional(c), ConditionalOrExpression::Expression(u) => u, @@ -881,12 +1227,17 @@ pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( unreachable!() } -pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>( +pub fn fold_block_expression< + 'ast, + T: Field, + E: ResultFold, + F: ResultFolder<'ast, T>, +>( f: &mut F, block: BlockExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(BlockExpression { - statements: block + Ok(BlockExpression::new( + block .statements .into_iter() .map(|s| f.fold_statement(s)) @@ -894,8 +1245,9 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo .into_iter() .flatten() .collect(), - value: box block.value.fold(f)?, - }) + block.value.fold(f)?, + ) + .span(block.span)) } pub fn fold_conditional_expression< @@ -904,7 +1256,7 @@ pub fn fold_conditional_expression< E: Expr<'ast, T> + Conditional<'ast, T> + PartialEq - + ResultFold<'ast, T> + + ResultFold + From>, F: ResultFolder<'ast, T>, >( @@ -918,7 +1270,44 @@ pub fn fold_conditional_expression< e.consequence.fold(f)?, e.alternative.fold(f)?, e.kind, - ), + ) + .span(e.span), + )) +} + +#[allow(clippy::type_complexity)] +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + PartialEq + ResultFold + From>, + R: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op: OperatorStr, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> Result, F::Error> { + Ok(BinaryOrExpression::Binary( + BinaryExpression::new(e.left.fold(f)?, e.right.fold(f)?).span(e.span), + )) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> Result, F::Error> { + Ok(UnaryOrExpression::Unary( + UnaryExpression::new(e.inner.fold(f)?).span(e.span), )) } @@ -932,9 +1321,29 @@ pub fn fold_member_expression< _: &E::Ty, e: MemberExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(MemberOrExpression::Member(MemberExpression::new( - f.fold_struct_expression(*e.struc)?, - e.id, + Ok(MemberOrExpression::Member( + MemberExpression::new(f.fold_struct_expression(*e.struc)?, e.id).span(e.span), + )) +} + +pub fn fold_slice_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: SliceExpression<'ast, T>, +) -> Result, F::Error> { + Ok(SliceOrExpression::Slice(SliceExpression::new( + f.fold_array_expression(*e.array)?, + f.fold_uint_expression(*e.from)?, + f.fold_uint_expression(*e.to)?, + ))) +} + +pub fn fold_repeat_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: RepeatExpression<'ast, T>, +) -> Result, F::Error> { + Ok(RepeatOrExpression::Repeat(RepeatExpression::new( + f.fold_expression(*e.e)?, + f.fold_uint_expression(*e.count)?, ))) } @@ -949,20 +1358,10 @@ pub fn fold_identifier_expression< e: IdentifierExpression<'ast, E>, ) -> Result, F::Error> { Ok(IdentifierOrExpression::Identifier( - IdentifierExpression::new(f.fold_name(e.id)?), + IdentifierExpression::new(f.fold_name(e.id)?).span(e.span), )) } -pub fn fold_eq_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>( - f: &mut F, - e: EqExpression, -) -> Result, F::Error> { - Ok(EqOrBoolean::Eq(EqExpression::new( - e.left.fold(f)?, - e.right.fold(f)?, - ))) -} - pub fn fold_select_expression< 'ast, T: Field, @@ -976,10 +1375,13 @@ pub fn fold_select_expression< _: &E::Ty, e: SelectExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(SelectOrExpression::Select(SelectExpression::new( - f.fold_array_expression(*e.array)?, - f.fold_uint_expression(*e.index)?, - ))) + Ok(SelectOrExpression::Select( + SelectExpression::new( + f.fold_array_expression(*e.array)?, + f.fold_uint_expression(*e.index)?, + ) + .span(e.span), + )) } pub fn fold_element_expression< @@ -992,10 +1394,9 @@ pub fn fold_element_expression< _: &E::Ty, e: ElementExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(ElementOrExpression::Element(ElementExpression::new( - f.fold_tuple_expression(*e.tuple)?, - e.index, - ))) + Ok(ElementOrExpression::Element( + ElementExpression::new(f.fold_tuple_expression(*e.tuple)?, e.index).span(e.span), + )) } pub fn fold_function_call_expression< @@ -1008,20 +1409,31 @@ pub fn fold_function_call_expression< _: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, F::Error> { - Ok(FunctionCallOrExpression::Expression(E::function_call( - f.fold_declaration_function_key(e.function_key)?, - e.generics - .into_iter() - .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) - .collect::>()?, - e.arguments - .into_iter() - .map(|e| f.fold_expression(e)) - .collect::>()?, - ))) + Ok(FunctionCallOrExpression::Expression( + E::function_call( + f.fold_declaration_function_key(e.function_key)?, + e.generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?, + e.arguments + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ) + .span(e.span), + )) } -pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).map(|e| e.span(span)) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> Result, F::Error> { @@ -1035,83 +1447,57 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( Block(block) => Block(f.fold_block_expression(block)?), Value(v) => Value(v), FieldEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => FieldEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, }, BoolEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => BoolEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, }, ArrayEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => ArrayEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => ArrayEq(e), + BinaryOrExpression::Expression(u) => u, }, StructEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => StructEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => StructEq(e), + BinaryOrExpression::Expression(u) => u, }, TupleEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => TupleEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => TupleEq(e), + BinaryOrExpression::Expression(u) => u, }, UintEq(e) => match f.fold_eq_expression(e)? { - EqOrBoolean::Eq(e) => UintEq(e), - EqOrBoolean::Boolean(u) => u, + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, }, - FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldLt(box e1, box e2) - } - FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldLe(box e1, box e2) - } - FieldGt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldGt(box e1, box e2) - } - FieldGe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldGe(box e1, box e2) - } - UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintLt(box e1, box e2) - } - UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintLe(box e1, box e2) - } - UintGt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintGt(box e1, box e2) - } - UintGe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - UintGe(box e1, box e2) - } - Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - Or(box e1, box e2) - } - And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - And(box e1, box e2) - } - Not(box e) => { - let e = f.fold_boolean_expression(e)?; - Not(box e) - } FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::Boolean, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), @@ -1135,6 +1521,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( ElementOrExpression::Expression(u) => u, }, }; + Ok(e) } @@ -1148,7 +1535,16 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, @@ -1162,87 +1558,62 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( }, Block(block) => Block(f.fold_block_expression(block)?), Value(v) => Value(v), - Add(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Add(box left, box right) - } - Sub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Sub(box left, box right) - } - FloorSub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - FloorSub(box left, box right) - } - Mult(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Mult(box left, box right) - } - Div(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Div(box left, box right) - } - Rem(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Rem(box left, box right) - } - Xor(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Xor(box left, box right) - } - And(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - And(box left, box right) - } - Or(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - Or(box left, box right) - } - LeftShift(box e, box by) => { - let e = f.fold_uint_expression(e)?; - let by = f.fold_uint_expression(by)?; - - LeftShift(box e, box by) - } - RightShift(box e, box by) => { - let e = f.fold_uint_expression(e)?; - let by = f.fold_uint_expression(by)?; - - RightShift(box e, box by) - } - Not(box e) => { - let e = f.fold_uint_expression(e)?; - - Not(box e) - } - Neg(box e) => { - let e = f.fold_uint_expression(e)?; - - Neg(box e) - } - Pos(box e) => { - let e = f.fold_uint_expression(e)?; - - Pos(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + FloorSub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => FloorSub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Pos(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Pos(e), + UnaryOrExpression::Expression(u) => u, + }, + Neg(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Neg(e), + UnaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, FunctionCall(function_call) => { match f.fold_function_call_expression(&ty, function_call)? { FunctionCallOrExpression::FunctionCall(c) => FunctionCall(c), @@ -1266,6 +1637,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( ElementOrExpression::Expression(u) => u, }, }; + Ok(e) } @@ -1317,7 +1689,7 @@ pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( .into_iter() .map(|o| f.fold_declaration_type(o)) .collect::>()?, - output: box f.fold_declaration_type(*s.output)?, + output: Box::new(f.fold_declaration_type(*s.output)?), }) } @@ -1359,7 +1731,7 @@ pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(ArrayExpression { inner: f.fold_array_expression_inner(&ty, e.inner)?, - ty: box ty, + ty: Box::new(ty), }) } @@ -1481,5 +1853,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( f.fold_module(module).map(|m| (module_id, m)) }) .collect::>()?, + ..p }) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index f2bf23d5b..ece3547e7 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -1,5 +1,6 @@ +use crate::common::expressions::ValueExpression; use crate::typed::{ - CoreIdentifier, Identifier, OwnedTypedModuleId, TypedExpression, UExpression, UExpressionInner, + CoreIdentifier, Identifier, OwnedModuleId, TypedExpression, UExpression, UExpressionInner, }; use crate::typed::{TryFrom, TryInto}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; @@ -125,7 +126,7 @@ pub type ConstantIdentifier<'ast> = &'ast str; #[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct CanonicalConstantIdentifier<'ast> { - pub module: OwnedTypedModuleId, + pub module: OwnedModuleId, #[serde(borrow)] pub id: ConstantIdentifier<'ast>, } @@ -137,7 +138,7 @@ impl<'ast> fmt::Display for CanonicalConstantIdentifier<'ast> { } impl<'ast> CanonicalConstantIdentifier<'ast> { - pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self { + pub fn new(id: ConstantIdentifier<'ast>, module: OwnedModuleId) -> Self { CanonicalConstantIdentifier { module, id } } } @@ -188,7 +189,7 @@ impl<'ast, T: PartialEq> PartialEq> for DeclarationConstant inner: UExpressionInner::Value(v), .. }, - ) => *c == *v as u32, + ) => *c == v.value as u32, (DeclarationConstant::Expression(TypedExpression::Uint(e0)), e1) => e0 == e1, (DeclarationConstant::Expression(..), _) => false, // type error _ => true, @@ -233,7 +234,7 @@ impl<'ast, T: fmt::Display> fmt::Display for DeclarationConstant<'ast, T> { impl<'ast, T> From for UExpression<'ast, T> { fn from(i: u32) -> Self { - UExpressionInner::Value(i as u128).annotate(UBitwidth::B32) + UExpressionInner::Value(ValueExpression::new(i as u128)).annotate(UBitwidth::B32) } } @@ -245,7 +246,7 @@ impl<'ast, T: Field> From> for UExpression<'ast, T> .annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { - UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) + UExpression::value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32) @@ -262,7 +263,7 @@ impl<'ast, T> TryInto for UExpression<'ast, T> { assert_eq!(self.bitwidth, UBitwidth::B32); match self.into_inner() { - UExpressionInner::Value(v) => Ok(v as u32), + UExpressionInner::Value(v) => Ok(v.value as u32), _ => Err(SpecializationError), } } @@ -281,7 +282,7 @@ impl<'ast, T> TryInto for DeclarationConstant<'ast, T> { pub type MemberId = String; -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct GStructMember { #[serde(rename = "name")] @@ -305,7 +306,7 @@ fn try_from_g_struct_member, U>( ) -> Result, SpecializationError> { Ok(GStructMember { id: t.id, - ty: box try_from_g_type(*t.ty)?, + ty: Box::new(try_from_g_type(*t.ty)?), }) } @@ -323,7 +324,7 @@ impl<'ast, T> From for StructMember<'ast, T> { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)] pub struct GArrayType { pub size: Box, @@ -371,8 +372,8 @@ fn try_from_g_array_type, U>( t: GArrayType, ) -> Result, SpecializationError> { Ok(GArrayType { - size: box (*t.size).try_into().map_err(|_| SpecializationError)?, - ty: box try_from_g_type(*t.ty)?, + size: Box::new((*t.size).try_into().map_err(|_| SpecializationError)?), + ty: Box::new(try_from_g_type(*t.ty)?), }) } @@ -390,7 +391,7 @@ impl<'ast, T> From for ArrayType<'ast, T> { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Debug)] pub struct GTupleType { pub elements: Vec>, @@ -469,6 +470,42 @@ impl TryFrom> for GTupleType { } } +impl TryFrom> for GStructType { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Struct(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + +impl TryFrom> for GArrayType { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Array(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + +impl TryFrom> for UBitwidth { + type Error = (); + + fn try_from(t: GType) -> Result { + if let GType::Uint(t) = t { + Ok(t) + } else { + Err(()) + } + } +} + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialOrd, Ord, Eq, PartialEq)] pub struct StructLocation { #[serde(skip)] @@ -642,7 +679,7 @@ impl fmt::Display for UBitwidth { } } -#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Eq, Hash, PartialOrd, Ord, Debug)] pub enum GType { FieldElement, @@ -820,8 +857,8 @@ impl<'ast, T> From for DeclarationType<'ast, T> { impl> From<(GType, U)> for GArrayType { fn from(tup: (GType, U)) -> Self { GArrayType { - ty: box tup.0, - size: box tup.1.into(), + ty: Box::new(tup.0), + size: Box::new(tup.1.into()), } } } @@ -829,8 +866,8 @@ impl> From<(GType, U)> for GArrayType { impl GArrayType { pub fn new>(ty: GType, size: U) -> Self { GArrayType { - ty: box ty, - size: box size.into(), + ty: Box::new(ty), + size: Box::new(size.into()), } } } @@ -920,7 +957,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> { match (&l.size.as_inner(), &*r.size) { // compare the sizes for concrete ones (UExpressionInner::Value(v), DeclarationConstant::Concrete(c)) => { - (*v as u32) == *c + (v.value as u32) == *c } _ => true, } @@ -968,7 +1005,7 @@ pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone, PartialOrd, Ord)] pub struct GFunctionKey<'ast, S> { - pub module: OwnedTypedModuleId, + pub module: OwnedModuleId, pub id: FunctionIdentifier<'ast>, pub signature: GSignature, } @@ -1046,7 +1083,7 @@ impl<'ast, T> From> for DeclarationFunctionKey<'ast, T } impl<'ast, S> GFunctionKey<'ast, S> { - pub fn with_location, U: Into>>( + pub fn with_location, U: Into>>( module: T, id: U, ) -> Self { @@ -1067,7 +1104,7 @@ impl<'ast, S> GFunctionKey<'ast, S> { self } - pub fn module>(mut self, module: T) -> Self { + pub fn module>(mut self, module: T) -> Self { self.module = module.into(); self } @@ -1100,7 +1137,7 @@ pub fn check_generic<'ast, T, S: Clone + PartialEq + PartialEq>( DeclarationConstant::Constant(..) => true, DeclarationConstant::Expression(e) => match e { TypedExpression::Uint(e) => match e.as_inner() { - UExpressionInner::Value(v) => *value == *v as u32, + UExpressionInner::Value(v) => *value == v.value as u32, _ => true, }, _ => unreachable!(), @@ -1207,8 +1244,12 @@ pub fn specialize_declaration_type< .into_iter() .map(|m| { let id = m.id; - specialize_declaration_type(*m.ty, &inside_generics) - .map(|ty| GStructMember { ty: box ty, id }) + specialize_declaration_type(*m.ty, &inside_generics).map(|ty| { + GStructMember { + ty: Box::new(ty), + id, + } + }) }) .collect::>()?, generics: s0 @@ -1231,7 +1272,7 @@ pub use self::signature::{ }; use super::identifier::FrameIdentifier; -use super::{Id, ShadowedIdentifier}; +use super::{Expr, Id, ShadowedIdentifier}; pub mod signature { use super::*; @@ -1249,7 +1290,7 @@ pub mod signature { Self { generics: vec![], inputs: vec![], - output: box GType::Tuple(GTupleType::new(vec![])), + output: Box::new(GType::Tuple(GTupleType::new(vec![]))), } } } @@ -1398,7 +1439,7 @@ pub mod signature { .into_iter() .map(try_from_g_type) .collect::>()?, - output: box try_from_g_type(*t.output)?, + output: Box::new(try_from_g_type(*t.output)?), }) } diff --git a/zokrates_ast/src/typed/uint.rs b/zokrates_ast/src/typed/uint.rs index c2120fc6f..a04550a4c 100644 --- a/zokrates_ast/src/typed/uint.rs +++ b/zokrates_ast/src/typed/uint.rs @@ -1,5 +1,5 @@ -use crate::typed::types::UBitwidth; use crate::typed::*; +use crate::{common::expressions::UValueExpression, typed::types::UBitwidth}; use std::ops::{Add, Div, Mul, Neg, Not, Rem, Sub}; use zokrates_field::Field; @@ -12,10 +12,12 @@ impl<'ast, T> Add for UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); + // we apply a basic simplification here which enables more precise comparison of array sizes during semantic checking + // this could be done by the caller by calling propagation, but it is deemed simple enough to be done here match (self.as_inner(), other.as_inner()) { - (UExpressionInner::Value(0), _) => other, - (_, UExpressionInner::Value(0)) => self, - _ => UExpressionInner::Add(box self, box other).annotate(bitwidth), + (UExpressionInner::Value(v), _) if v.value == 0 => other, + (_, UExpressionInner::Value(v)) if v.value == 0 => self, + _ => UExpressionInner::Add(BinaryExpression::new(self, other)).annotate(bitwidth), } } } @@ -26,7 +28,7 @@ impl<'ast, T> Sub for UExpression<'ast, T> { fn sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Sub(box self, box other).annotate(bitwidth) + UExpressionInner::Sub(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -36,7 +38,7 @@ impl<'ast, T> Mul for UExpression<'ast, T> { fn mul(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Mult(box self, box other).annotate(bitwidth) + UExpressionInner::Mult(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -46,7 +48,7 @@ impl<'ast, T> Div for UExpression<'ast, T> { fn div(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Div(box self, box other).annotate(bitwidth) + UExpressionInner::Div(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -56,7 +58,7 @@ impl<'ast, T> Rem for UExpression<'ast, T> { fn rem(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Rem(box self, box other).annotate(bitwidth) + UExpressionInner::Rem(BinaryExpression::new(self, other)).annotate(bitwidth) } } @@ -65,7 +67,7 @@ impl<'ast, T> Not for UExpression<'ast, T> { fn not(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Not(box self).annotate(bitwidth) + UExpressionInner::Not(UnaryExpression::new(self)).annotate(bitwidth) } } @@ -74,7 +76,7 @@ impl<'ast, T> Neg for UExpression<'ast, T> { fn neg(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Neg(box self).annotate(bitwidth) + UExpressionInner::Neg(UnaryExpression::new(self)).annotate(bitwidth) } } @@ -82,48 +84,48 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn xor(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Xor(box self, box other).annotate(bitwidth) + UExpressionInner::Xor(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn or(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Or(box self, box other).annotate(bitwidth) + UExpressionInner::Or(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn and(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::And(box self, box other).annotate(bitwidth) + UExpressionInner::And(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn pos(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Pos(box self).annotate(bitwidth) + UExpressionInner::Pos(UnaryExpression::new(self)).annotate(bitwidth) } pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(by.bitwidth, UBitwidth::B32); - UExpressionInner::LeftShift(box self, box by).annotate(bitwidth) + UExpressionInner::LeftShift(BinaryExpression::new(self, by)).annotate(bitwidth) } pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(by.bitwidth, UBitwidth::B32); - UExpressionInner::RightShift(box self, box by).annotate(bitwidth) + UExpressionInner::RightShift(BinaryExpression::new(self, by)).annotate(bitwidth) } pub fn floor_sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::FloorSub(box self, box other).annotate(bitwidth) + UExpressionInner::FloorSub(BinaryExpression::new(self, other)).annotate(bitwidth) } } -impl<'ast, T: Field> From for UExpressionInner<'ast, T> { +impl<'ast, T> From for UExpressionInner<'ast, T> { fn from(e: u128) -> Self { - UExpressionInner::Value(e) + UExpressionInner::Value(ValueExpression::new(e)) } } @@ -142,20 +144,20 @@ pub struct UExpression<'ast, T> { impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u16) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B16) + UExpressionInner::Value(ValueExpression::new(u as u128)).annotate(UBitwidth::B16) } } impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u8) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B8) + UExpressionInner::Value(ValueExpression::new(u as u128)).annotate(UBitwidth::B8) } } impl<'ast, T> PartialEq for UExpression<'ast, T> { fn eq(&self, other: &u32) -> bool { match self.as_inner() { - UExpressionInner::Value(v) => *v == *other as u128, + UExpressionInner::Value(v) => v.value == *other as u128, _ => true, } } @@ -165,22 +167,33 @@ impl<'ast, T> PartialEq for UExpression<'ast, T> { pub enum UExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, UExpression<'ast, T>>), Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>), - Value(u128), - Add(Box>, Box>), - Sub(Box>, Box>), - FloorSub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - Not(Box>), - Neg(Box>), - Pos(Box>), + Value(UValueExpression), + Add(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Sub(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + FloorSub( + BinaryExpression< + OpFloorSub, + UExpression<'ast, T>, + UExpression<'ast, T>, + UExpression<'ast, T>, + >, + ), + Mult(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Div(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Rem(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Xor(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + And(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Or(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Not(UnaryExpression, UExpression<'ast, T>>), + Neg(UnaryExpression, UExpression<'ast, T>>), + Pos(UnaryExpression, UExpression<'ast, T>>), FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>), - LeftShift(Box>, Box>), - RightShift(Box>, Box>), + LeftShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + RightShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>), Member(MemberExpression<'ast, T, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), diff --git a/zokrates_ast/src/typed/utils/mod.rs b/zokrates_ast/src/typed/utils/mod.rs index 6ec6885c2..ad2c406f0 100644 --- a/zokrates_ast/src/typed/utils/mod.rs +++ b/zokrates_ast/src/typed/utils/mod.rs @@ -1,13 +1,13 @@ use super::{ - ArrayExpression, ArrayExpressionInner, ArrayValue, BooleanExpression, Conditional, - ConditionalKind, Expr, FieldElementExpression, Id, Identifier, Select, Typed, TypedExpression, - TypedExpressionOrSpread, UBitwidth, UExpression, UExpressionInner, + ArrayExpression, ArrayExpressionInner, BooleanExpression, Conditional, ConditionalKind, Expr, + FieldElementExpression, GArrayType, Id, Identifier, Select, Typed, TypedExpression, + TypedExpressionOrSpread, UBitwidth, UExpression, ValueExpression, }; use zokrates_field::Field; pub fn f<'ast, T, U: TryInto>(v: U) -> FieldElementExpression<'ast, T> { - FieldElementExpression::Number(v.try_into().map_err(|_| ()).unwrap()) + FieldElementExpression::Value(ValueExpression::new(v.try_into().map_err(|_| ()).unwrap())) } pub fn a_id<'ast, T: Field, I: TryInto>>(v: I) -> ArrayExpressionInner<'ast, T> { @@ -16,24 +16,26 @@ pub fn a_id<'ast, T: Field, I: TryInto>>(v: I) -> ArrayExpressi pub fn a< 'ast, - T, + T: Field, E: Typed<'ast, T> + Expr<'ast, T> + Into>, const N: usize, >( values: [E; N], ) -> ArrayExpression<'ast, T> { let ty = values[0].get_type(); - ArrayExpressionInner::Value(ArrayValue( + + let array_ty = GArrayType::new(ty, N as u32); + ArrayExpression::value( values .into_iter() .map(|e| TypedExpressionOrSpread::Expression(e.into())) .collect(), - )) - .annotate(ty, N as u32) + ) + .annotate(array_ty) } -pub fn u_32<'ast, T, U: TryInto>(v: U) -> UExpression<'ast, T> { - UExpressionInner::Value(v.try_into().map_err(|_| ()).unwrap() as u128).annotate(UBitwidth::B32) +pub fn u_32<'ast, T: Field, U: TryInto>(v: U) -> UExpression<'ast, T> { + UExpression::value(v.try_into().map_err(|_| ()).unwrap() as u128).annotate(UBitwidth::B32) } pub fn conditional<'ast, T, E: Conditional<'ast, T>>( diff --git a/zokrates_ast/src/typed/variable.rs b/zokrates_ast/src/typed/variable.rs index d47d48728..0ad50061c 100644 --- a/zokrates_ast/src/typed/variable.rs +++ b/zokrates_ast/src/typed/variable.rs @@ -1,17 +1,11 @@ +use crate::common::WithSpan; use crate::typed::types::{DeclarationConstant, GStructType, UBitwidth}; use crate::typed::types::{GType, SpecializationError}; use crate::typed::Identifier; use crate::typed::UExpression; use crate::typed::{TryFrom, TryInto}; -use std::fmt; - -#[derive(Clone, PartialEq, Hash, Eq, PartialOrd, Ord, Debug)] -pub struct GVariable<'ast, S> { - pub id: Identifier<'ast>, - pub _type: GType, - pub is_mutable: bool, -} +pub type GVariable<'ast, S> = crate::common::Variable, GType>; pub type DeclarationVariable<'ast, T> = GVariable<'ast, DeclarationConstant<'ast, T>>; pub type ConcreteVariable<'ast> = GVariable<'ast, u32>; pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; @@ -20,84 +14,56 @@ impl<'ast, T> TryFrom> for ConcreteVariable<'ast> { type Error = SpecializationError; fn try_from(v: Variable<'ast, T>) -> Result { - let _type = v._type.try_into()?; + let span = v.get_span(); - Ok(Self { - _type, - id: v.id, - is_mutable: v.is_mutable, - }) + let ty = v.ty.try_into()?; + + Ok(Self::new(v.id, ty).span(span)) } } impl<'ast, T> From> for Variable<'ast, T> { fn from(v: ConcreteVariable<'ast>) -> Self { - let _type = v._type.into(); + let span = v.get_span(); + + let ty = v.ty.into(); - Self { - _type, - id: v.id, - is_mutable: v.is_mutable, - } + Self::new(v.id, ty).span(span) } } pub fn try_from_g_variable, U>( v: GVariable, ) -> Result, SpecializationError> { - let _type = crate::typed::types::try_from_g_type(v._type)?; + let span = v.get_span(); - Ok(GVariable { - _type, - id: v.id, - is_mutable: v.is_mutable, - }) + let ty = crate::typed::types::try_from_g_type(v.ty)?; + + Ok(GVariable::new(v.id, ty).span(span)) } impl<'ast, S: Clone> GVariable<'ast, S> { pub fn field_element>>(id: I) -> Self { - Self::immutable(id, GType::FieldElement) + Self::new(id, GType::FieldElement) } pub fn boolean>>(id: I) -> Self { - Self::immutable(id, GType::Boolean) + Self::new(id, GType::Boolean) } pub fn uint>, W: Into>(id: I, bitwidth: W) -> Self { - Self::immutable(id, GType::uint(bitwidth)) + Self::new(id, GType::uint(bitwidth)) } pub fn array>, U: Into>(id: I, ty: GType, size: U) -> Self { - Self::immutable(id, GType::array((ty, size.into()))) + Self::new(id, GType::array((ty, size.into()))) } pub fn struc>>(id: I, ty: GStructType) -> Self { - Self::immutable(id, GType::Struct(ty)) - } - - pub fn immutable>>(id: I, _type: GType) -> Self { - Self::new(id, _type, false) - } - - pub fn mutable>>(id: I, _type: GType) -> Self { - Self::new(id, _type, true) - } - - pub fn new>>(id: I, _type: GType, is_mutable: bool) -> Self { - GVariable { - id: id.into(), - _type, - is_mutable, - } + Self::new(id, GType::Struct(ty)) } pub fn get_type(&self) -> GType { - self._type.clone() - } -} - -impl<'ast, S: fmt::Display> fmt::Display for GVariable<'ast, S> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) + self.ty.clone() } } diff --git a/zokrates_ast/src/untyped/from_ast.rs b/zokrates_ast/src/untyped/from_ast.rs index 88c12d6ce..78e02b7d0 100644 --- a/zokrates_ast/src/untyped/from_ast.rs +++ b/zokrates_ast/src/untyped/from_ast.rs @@ -390,85 +390,85 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> use crate::untyped::NodeValue; match expression.op { pest::BinaryOperator::Add => untyped::Expression::Add( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Sub => untyped::Expression::Sub( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Mul => untyped::Expression::Mult( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Div => untyped::Expression::Div( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Rem => untyped::Expression::Rem( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Eq => untyped::Expression::Eq( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Lt => untyped::Expression::Lt( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Lte => untyped::Expression::Le( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Gt => untyped::Expression::Gt( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Gte => untyped::Expression::Ge( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::And => untyped::Expression::And( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Or => untyped::Expression::Or( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::Pow => untyped::Expression::Pow( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitXor => untyped::Expression::BitXor( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::LeftShift => untyped::Expression::LeftShift( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::RightShift => untyped::Expression::RightShift( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitAnd => untyped::Expression::BitAnd( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), pest::BinaryOperator::BitOr => untyped::Expression::BitOr( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ), // rewrite (a != b)` as `!(a == b)` - pest::BinaryOperator::NotEq => untyped::Expression::Not( - box untyped::Expression::Eq( - box untyped::ExpressionNode::from(*expression.left), - box untyped::ExpressionNode::from(*expression.right), + pest::BinaryOperator::NotEq => untyped::Expression::Not(Box::new( + untyped::Expression::Eq( + Box::new(untyped::ExpressionNode::from(*expression.left)), + Box::new(untyped::ExpressionNode::from(*expression.right)), ) .span(expression.span.clone()), - ), + )), } .span(expression.span) } @@ -477,22 +477,22 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::IfElseExpression<'ast>) -> untyped::ExpressionNode<'ast> { use crate::untyped::NodeValue; - untyped::Expression::Conditional(box ConditionalExpression { - condition: box untyped::ExpressionNode::from(*expression.condition), + untyped::Expression::Conditional(Box::new(ConditionalExpression { + condition: Box::new(untyped::ExpressionNode::from(*expression.condition)), consequence_statements: expression .consequence_statements .into_iter() .map(untyped::StatementNode::from) .collect(), - consequence: box untyped::ExpressionNode::from(*expression.consequence), + consequence: Box::new(untyped::ExpressionNode::from(*expression.consequence)), alternative_statements: expression .alternative_statements .into_iter() .map(untyped::StatementNode::from) .collect(), - alternative: box untyped::ExpressionNode::from(*expression.alternative), + alternative: Box::new(untyped::ExpressionNode::from(*expression.alternative)), kind: untyped::ConditionalKind::IfElse, - }) + })) .span(expression.span) } } @@ -500,14 +500,14 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> impl<'ast> From> for untyped::ExpressionNode<'ast> { fn from(expression: pest::TernaryExpression<'ast>) -> untyped::ExpressionNode<'ast> { use crate::untyped::NodeValue; - untyped::Expression::Conditional(box ConditionalExpression { - condition: box untyped::ExpressionNode::from(*expression.condition), + untyped::Expression::Conditional(Box::new(ConditionalExpression { + condition: Box::new(untyped::ExpressionNode::from(*expression.condition)), consequence_statements: vec![], - consequence: box untyped::ExpressionNode::from(*expression.consequence), + consequence: Box::new(untyped::ExpressionNode::from(*expression.consequence)), alternative_statements: vec![], - alternative: box untyped::ExpressionNode::from(*expression.alternative), + alternative: Box::new(untyped::ExpressionNode::from(*expression.alternative)), kind: untyped::ConditionalKind::Ternary, - }) + })) .span(expression.span) } } @@ -617,7 +617,8 @@ impl<'ast> From> for untyped::ExpressionN let value = untyped::ExpressionNode::from(*initializer.value); let count = untyped::ExpressionNode::from(*initializer.count); - untyped::Expression::ArrayInitializer(box value, box count).span(initializer.span) + untyped::Expression::ArrayInitializer(Box::new(value), Box::new(count)) + .span(initializer.span) } } @@ -674,18 +675,20 @@ impl<'ast> From> for untyped::ExpressionNode<'ast> ) .span(a.span), pest::Access::Select(a) => untyped::Expression::Select( - box acc, - box untyped::RangeOrExpression::from(a.expression), + Box::new(acc), + Box::new(untyped::RangeOrExpression::from(a.expression)), ) .span(a.span), pest::Access::Dot(m) => match m.inner { pest::IdentifierOrDecimal::Identifier(id) => { - untyped::Expression::Member(box acc, box id.span.as_str()).span(m.span) - } - pest::IdentifierOrDecimal::Decimal(id) => { - untyped::Expression::Element(box acc, id.span.as_str().parse().unwrap()) + untyped::Expression::Member(Box::new(acc), Box::new(id.span.as_str())) .span(m.span) } + pest::IdentifierOrDecimal::Decimal(id) => untyped::Expression::Element( + Box::new(acc), + id.span.as_str().parse().unwrap(), + ) + .span(m.span), }, }) } @@ -783,15 +786,15 @@ impl<'ast> From> for untyped::AssigneeNode<'ast> { assignee.accesses.into_iter().fold(a, |acc, s| { match s { pest::AssigneeAccess::Select(s) => untyped::Assignee::Select( - box acc, - box untyped::RangeOrExpression::from(s.expression), + Box::new(acc), + Box::new(untyped::RangeOrExpression::from(s.expression)), ), pest::AssigneeAccess::Dot(a) => match a.inner { pest::IdentifierOrDecimal::Identifier(id) => { - untyped::Assignee::Member(box acc, box id.span.as_str()) + untyped::Assignee::Member(Box::new(acc), Box::new(id.span.as_str())) } pest::IdentifierOrDecimal::Decimal(id) => { - untyped::Assignee::Element(box acc, id.span.as_str().parse().unwrap()) + untyped::Assignee::Element(Box::new(acc), id.span.as_str().parse().unwrap()) } }, } @@ -1032,29 +1035,33 @@ mod tests { ( "field[2]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::FieldElement.mock(), + Box::new(untyped::UnresolvedType::FieldElement.mock()), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), ( "field[2][3]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Array( - box untyped::UnresolvedType::FieldElement.mock(), - untyped::Expression::IntConstant(3usize.into()).mock(), - ) - .mock(), + Box::new( + untyped::UnresolvedType::Array( + Box::new(untyped::UnresolvedType::FieldElement.mock()), + untyped::Expression::IntConstant(3usize.into()).mock(), + ) + .mock(), + ), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), ( "bool[2][3u32]", untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Array( - box untyped::UnresolvedType::Boolean.mock(), - untyped::Expression::U32Constant(3u32).mock(), - ) - .mock(), + Box::new( + untyped::UnresolvedType::Array( + Box::new(untyped::UnresolvedType::Boolean.mock()), + untyped::Expression::U32Constant(3u32).mock(), + ) + .mock(), + ), untyped::Expression::IntConstant(2usize.into()).mock(), ), ), @@ -1099,59 +1106,59 @@ mod tests { ( "a[3]", untyped::Expression::Select( - box untyped::Expression::Identifier("a").into(), - box untyped::RangeOrExpression::Expression( + Box::new(untyped::Expression::Identifier("a").into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(3usize.into()).into(), ), - ), + )), ), ( "a[3][4]", untyped::Expression::Select( - box untyped::Expression::Select( - box untyped::Expression::Identifier("a").into(), - box untyped::RangeOrExpression::Expression( + Box::new(untyped::Expression::Select( + Box::new(untyped::Expression::Identifier("a").into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(3usize.into()).into(), - ), + )), ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ), ), ( "a(3)[4]", untyped::Expression::Select( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), + Box::new(untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), None, vec![untyped::Expression::IntConstant(3usize.into()).into()], ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ), ), ( "a(3)[4][5]", untyped::Expression::Select( - box untyped::Expression::Select( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), + Box::new(untyped::Expression::Select( + Box::new(untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), None, vec![untyped::Expression::IntConstant(3usize.into()).into()], ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(4usize.into()).into(), - ), + )), ) - .into(), - box untyped::RangeOrExpression::Expression( + .into()), + Box::new(untyped::RangeOrExpression::Expression( untyped::Expression::IntConstant(5usize.into()).into(), - ), + )), ), ), ]; @@ -1172,13 +1179,15 @@ mod tests { assert_eq!( untyped::Module::from(ast), wrap(untyped::Expression::FunctionCall( - box untyped::Expression::Select( - box untyped::Expression::Identifier("a").mock(), - box untyped::RangeOrExpression::Expression( - untyped::Expression::IntConstant(2u32.into()).mock() + Box::new( + untyped::Expression::Select( + Box::new(untyped::Expression::Identifier("a").mock()), + Box::new(untyped::RangeOrExpression::Expression( + untyped::Expression::IntConstant(2u32.into()).mock() + )) ) - ) - .mock(), + .mock() + ), None, vec![untyped::Expression::IntConstant(3u32.into()).mock()], )) @@ -1194,12 +1203,14 @@ mod tests { assert_eq!( untyped::Module::from(ast), wrap(untyped::Expression::FunctionCall( - box untyped::Expression::FunctionCall( - box untyped::Expression::Identifier("a").mock(), - None, - vec![untyped::Expression::IntConstant(2u32.into()).mock()] - ) - .mock(), + Box::new( + untyped::Expression::FunctionCall( + Box::new(untyped::Expression::Identifier("a").mock()), + None, + vec![untyped::Expression::IntConstant(2u32.into()).mock()] + ) + .mock() + ), None, vec![untyped::Expression::IntConstant(3u32.into()).mock()], )) @@ -1261,10 +1272,10 @@ mod tests { span: span.clone(), }), expression: pest::Expression::Postfix(pest::PostfixExpression { - base: box pest::Expression::Identifier(pest::IdentifierExpression { + base: Box::new(pest::Expression::Identifier(pest::IdentifierExpression { value: String::from("foo"), span: span.clone(), - }), + })), accesses: vec![pest::Access::Call(pest::CallAccess { explicit_generics: None, arguments: pest::Arguments { diff --git a/zokrates_ast/src/untyped/mod.rs b/zokrates_ast/src/untyped/mod.rs index 07541edec..c0f77ea64 100644 --- a/zokrates_ast/src/untyped/mod.rs +++ b/zokrates_ast/src/untyped/mod.rs @@ -8,17 +8,17 @@ mod from_ast; mod node; pub mod parameter; -mod position; pub mod types; pub mod variable; pub use self::node::{Node, NodeValue}; pub use self::parameter::{Parameter, ParameterNode}; -pub use self::position::Position; use self::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; pub use self::variable::{Variable, VariableNode}; use crate::common::FlatEmbed; -use std::path::{Path, PathBuf}; +pub use crate::common::Position; +pub use crate::common::{ModuleId, OwnedModuleId}; +use std::path::Path; use std::fmt; @@ -28,10 +28,6 @@ use std::collections::HashMap; /// An identifier of a function or a variable pub type Identifier<'ast> = &'ast str; -/// The identifier of a `Module`, typically a path or uri -pub type OwnedModuleId = PathBuf; -pub type ModuleId = Path; - /// A collection of `Module`s pub type Modules<'ast> = HashMap>; diff --git a/zokrates_ast/src/untyped/node.rs b/zokrates_ast/src/untyped/node.rs index 62ef299df..fbff15dd8 100644 --- a/zokrates_ast/src/untyped/node.rs +++ b/zokrates_ast/src/untyped/node.rs @@ -1,26 +1,25 @@ +use crate::common::LocalSourceSpan as Span; use std::fmt; -use zokrates_pest_ast::Span; +use zokrates_pest_ast::Span as PestSpan; #[derive(Clone)] pub struct Node { - pub start: Position, - pub end: Position, + pub span: Span, pub value: T, } impl Node { pub fn mock(e: T) -> Self { Self { - start: Position::mock(), - end: Position::mock(), + span: Span::mock(), value: e, } } } impl Node { - pub fn pos(&self) -> (Position, Position) { - (self.start, self.end) + pub fn span(&self) -> Span { + self.span } } @@ -37,8 +36,11 @@ impl fmt::Debug for Node { } impl Node { - pub fn new(start: Position, end: Position, value: T) -> Node { - Node { start, end, value } + pub fn new(from: Position, to: Position, value: T) -> Node { + Node { + span: Span { from, to }, + value, + } } } @@ -56,7 +58,7 @@ pub trait NodeValue: fmt::Display + fmt::Debug + Sized + PartialEq { Node::new(Position::mock(), Position::mock(), self) } - fn span(self, span: Span) -> Node { + fn span(self, span: PestSpan) -> Node { let from = span.start_pos().line_col(); let to = span.end_pos().line_col(); diff --git a/zokrates_ast/src/untyped/types.rs b/zokrates_ast/src/untyped/types.rs index 0f6c4ba45..f2a87ad86 100644 --- a/zokrates_ast/src/untyped/types.rs +++ b/zokrates_ast/src/untyped/types.rs @@ -69,7 +69,7 @@ impl<'ast> fmt::Display for UnresolvedType<'ast> { impl<'ast> UnresolvedType<'ast> { pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self { - UnresolvedType::Array(box ty, size) + UnresolvedType::Array(Box::new(ty), size) } } diff --git a/zokrates_ast/src/untyped/variable.rs b/zokrates_ast/src/untyped/variable.rs index edbc7f4cf..61f86f5dc 100644 --- a/zokrates_ast/src/untyped/variable.rs +++ b/zokrates_ast/src/untyped/variable.rs @@ -8,7 +8,7 @@ use super::Identifier; pub struct Variable<'ast> { pub is_mutable: bool, pub id: Identifier<'ast>, - pub _type: UnresolvedTypeNode<'ast>, + pub ty: UnresolvedTypeNode<'ast>, } pub type VariableNode<'ast> = Node>; @@ -22,7 +22,7 @@ impl<'ast> Variable<'ast> { Variable { is_mutable, id: id.into(), - _type: t, + ty: t, } } @@ -35,13 +35,13 @@ impl<'ast> Variable<'ast> { } pub fn get_type(&self) -> &UnresolvedType<'ast> { - &self._type.value + &self.ty.value } } impl<'ast> fmt::Display for Variable<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) + write!(f, "{} {}", self.ty, self.id,) } } @@ -50,7 +50,7 @@ impl<'ast> fmt::Debug for Variable<'ast> { write!( f, "Variable(type: {:?}, id: {:?}, is_mutable: {:?})", - self._type, self.id, self.is_mutable + self.ty, self.id, self.is_mutable ) } } diff --git a/zokrates_ast/src/zir/canonicalizer.rs b/zokrates_ast/src/zir/canonicalizer.rs new file mode 100644 index 000000000..61bf1970a --- /dev/null +++ b/zokrates_ast/src/zir/canonicalizer.rs @@ -0,0 +1,91 @@ +use super::{Folder, Identifier, Parameter, Variable, ZirAssignee}; +use std::collections::HashMap; +use zokrates_field::Field; + +#[derive(Default)] +pub struct ZirCanonicalizer<'ast> { + identifier_map: HashMap, usize>, +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> { + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(p.id.id.clone(), new_id); + + Parameter { + id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty), + ..p + } + } + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> ZirAssignee<'ast> { + let new_id = self.identifier_map.len(); + self.identifier_map.insert(a.id.clone(), new_id); + ZirAssignee::with_id_and_type(Identifier::internal(new_id), a.ty) + } + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + match self.identifier_map.get(&n) { + Some(v) => Identifier::internal(*v), + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::zir::{ + FieldElementExpression, IdentifierExpression, Signature, Type, ZirAssignee, ZirFunction, + ZirStatement, + }; + + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn canonicalize() { + let func = ZirFunction:: { + arguments: vec![Parameter::new(Variable::field_element("a"), true)], + statements: vec![ + ZirStatement::definition( + ZirAssignee::field_element("b"), + FieldElementExpression::Identifier(IdentifierExpression::new("a".into())) + .into(), + ), + ZirStatement::ret(vec![FieldElementExpression::Identifier( + IdentifierExpression::new("b".into()), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + let mut canonicalizer = ZirCanonicalizer::default(); + let result = canonicalizer.fold_function(func); + + let expected = ZirFunction:: { + arguments: vec![Parameter::new( + Variable::field_element(Identifier::internal(0usize)), + true, + )], + statements: vec![ + ZirStatement::definition( + ZirAssignee::field_element(Identifier::internal(1usize)), + FieldElementExpression::Identifier(IdentifierExpression::new( + Identifier::internal(0usize), + )) + .into(), + ), + ZirStatement::ret(vec![FieldElementExpression::Identifier( + IdentifierExpression::new(Identifier::internal(1usize)), + ) + .into()]), + ], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + }; + + assert_eq!(result, expected); + } +} diff --git a/zokrates_ast/src/zir/folder.rs b/zokrates_ast/src/zir/folder.rs index 770eb0f03..b536ccc78 100644 --- a/zokrates_ast/src/zir/folder.rs +++ b/zokrates_ast/src/zir/folder.rs @@ -1,27 +1,25 @@ // Generic walk through ZIR. Not mutating in place +use crate::common::expressions::{BinaryOrExpression, IdentifierOrExpression, UnaryOrExpression}; +use crate::common::{Fold, WithSpan}; use crate::zir::types::UBitwidth; use crate::zir::*; use zokrates_field::Field; -pub trait Fold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Self; -} - -impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for FieldElementExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_field_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for BooleanExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Self { +impl<'ast, T: Field, F: Folder<'ast, T>> Fold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Self { f.fold_uint_expression(self) } } @@ -56,6 +54,20 @@ pub trait Folder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Vec> { + fold_assembly_constraint(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Vec> { + fold_assembly_assignment(self, s) + } + fn fold_assembly_statement( &mut self, s: ZirAssemblyStatement<'ast, T>, @@ -63,19 +75,73 @@ pub trait Folder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Vec> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { fold_statement(self, s) } + fn fold_statement_cases(&mut self, s: ZirStatement<'ast, T>) -> Vec> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Vec> { + fold_definition_statement(self, s) + } + + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Vec> { + fold_if_else_statement(self, s) + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Vec> { + fold_multiple_definition_statement(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Vec> { + fold_assertion_statement(self, s) + } + + fn fold_return_statement(&mut self, s: ReturnStatement<'ast, T>) -> Vec> { + fold_return_statement(self, s) + } + + fn fold_log_statement(&mut self, s: LogStatement<'ast, T>) -> Vec> { + fold_log_statement(self, s) + } + + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Vec> { + fold_assembly_block(self, s) + } + fn fold_identifier_expression + Id<'ast, T>>( &mut self, ty: &E::Ty, e: IdentifierExpression<'ast, E>, - ) -> IdentifierOrExpression<'ast, T, E> { + ) -> IdentifierOrExpression, E, E::Inner> { fold_identifier_expression(self, ty, e) } - fn fold_conditional_expression + Fold<'ast, T> + Conditional<'ast, T>>( + fn fold_conditional_expression + Fold + Conditional<'ast, T>>( &mut self, ty: &E::Ty, e: ConditionalExpression<'ast, T, E>, @@ -83,7 +149,7 @@ pub trait Folder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_select_expression + Fold<'ast, T> + Select<'ast, T>>( + fn fold_select_expression + Fold + Select<'ast, T>>( &mut self, ty: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -91,6 +157,27 @@ pub trait Folder<'ast, T: Field>: Sized { fold_select_expression(self, ty, e) } + fn fold_binary_expression< + L: Expr<'ast, T> + Fold, + R: Expr<'ast, T> + Fold, + E: Expr<'ast, T> + Fold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> BinaryOrExpression { + fold_binary_expression(self, ty, e) + } + + fn fold_unary_expression + Fold, E: Expr<'ast, T> + Fold, Op>( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> UnaryOrExpression { + fold_unary_expression(self, ty, e) + } + fn fold_expression(&mut self, e: ZirExpression<'ast, T>) -> ZirExpression<'ast, T> { match e { ZirExpression::FieldElement(e) => self.fold_field_expression(e).into(), @@ -140,23 +227,66 @@ pub trait Folder<'ast, T: Field>: Sized { ) -> UExpressionInner<'ast, T> { fold_uint_expression_inner(self, bitwidth, e) } + + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + fold_field_expression_cases(self, e) + } + + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + fold_boolean_expression_cases(self, e) + } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + fold_uint_expression_cases(self, bitwidth, e) + } } -pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_assembly_assignment<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Vec> { + let assignees = s.assignee.into_iter().map(|a| f.fold_assignee(a)).collect(); + let expression = f.fold_function(s.expression); + vec![ZirAssemblyStatement::assignment(assignees, expression)] +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Vec> { + let left = f.fold_field_expression(s.left); + let right = f.fold_field_expression(s.right); + vec![ZirAssemblyStatement::constraint(left, right, s.metadata)] +} + +fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +fn fold_assembly_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirAssemblyStatement<'ast, T>, ) -> Vec> { match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect(); - let function = f.fold_function(function); - vec![ZirAssemblyStatement::Assignment(assignees, function)] - } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(lhs); - let rhs = f.fold_field_expression(rhs); - vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] - } + ZirAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + ZirAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), } } @@ -164,48 +294,119 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, ) -> Vec> { - let res = match s { - ZirStatement::Return(expressions) => ZirStatement::Return( - expressions - .into_iter() - .map(|e| f.fold_expression(e)) - .collect(), - ), - ZirStatement::Definition(a, e) => { - ZirStatement::Definition(f.fold_assignee(a), f.fold_expression(e)) - } - ZirStatement::IfElse(condition, consequence, alternative) => ZirStatement::IfElse( - f.fold_boolean_expression(condition), - consequence - .into_iter() - .flat_map(|e| f.fold_statement(e)) - .collect(), - alternative - .into_iter() - .flat_map(|e| f.fold_statement(e)) - .collect(), - ), - ZirStatement::Assertion(e, error) => { - ZirStatement::Assertion(f.fold_boolean_expression(e), error) - } - ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition( - variables.into_iter().map(|v| f.fold_variable(v)).collect(), - f.fold_expression_list(elist), - ), - ZirStatement::Log(l, e) => ZirStatement::Log( - l, - e.into_iter() - .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) - .collect(), - ), - ZirStatement::Assembly(statements) => ZirStatement::Assembly( - statements + let span = s.get_span(); + f.fold_statement_cases(s) + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_statement_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ZirStatement<'ast, T>, +) -> Vec> { + let span = s.get_span(); + + match s { + ZirStatement::Return(s) => f.fold_return_statement(s), + ZirStatement::Definition(s) => f.fold_definition_statement(s), + ZirStatement::IfElse(s) => f.fold_if_else_statement(s), + ZirStatement::Assertion(s) => f.fold_assertion_statement(s), + ZirStatement::MultipleDefinition(s) => f.fold_multiple_definition_statement(s), + ZirStatement::Log(s) => f.fold_log_statement(s), + ZirStatement::Assembly(s) => f.fold_assembly_block(s), + } + .into_iter() + .map(|s| s.span(span)) + .collect() +} + +pub fn fold_multiple_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: MultipleDefinitionStatement<'ast, T>, +) -> Vec> { + let expression_list = f.fold_expression_list(s.rhs); + vec![ZirStatement::MultipleDefinition( + MultipleDefinitionStatement::new( + s.assignees .into_iter() - .flat_map(|s| f.fold_assembly_statement(s)) + .map(|v| f.fold_variable(v)) .collect(), + expression_list, ), - }; - vec![res] + )] +} + +pub fn fold_if_else_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: IfElseStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::IfElse(IfElseStatement::new( + f.fold_boolean_expression(s.condition), + s.consequence + .into_iter() + .flat_map(|e| f.fold_statement(e)) + .collect(), + s.alternative + .into_iter() + .flat_map(|e| f.fold_statement(e)) + .collect(), + ))] +} + +pub fn fold_definition_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Vec> { + let rhs = f.fold_expression(s.rhs); + vec![ZirStatement::Definition(DefinitionStatement::new( + f.fold_assignee(s.assignee), + rhs, + ))] +} + +pub fn fold_assertion_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Assertion(AssertionStatement::new( + f.fold_boolean_expression(s.expression), + s.error, + ))] +} + +pub fn fold_return_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Return(ReturnStatement::new( + s.inner.into_iter().map(|e| f.fold_expression(e)).collect(), + ))] +} + +pub fn fold_log_statement<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Log(LogStatement::new( + s.format_string, + s.expressions + .into_iter() + .map(|(t, e)| (t, e.into_iter().map(|e| f.fold_expression(e)).collect())) + .collect(), + ))] +} + +pub fn fold_assembly_block<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Vec> { + vec![ZirStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .flat_map(|s| f.fold_assembly_statement(s)) + .collect(), + ))] } pub fn fold_identifier_expression< @@ -217,16 +418,24 @@ pub fn fold_identifier_expression< f: &mut F, _: &E::Ty, e: IdentifierExpression<'ast, E>, -) -> IdentifierOrExpression<'ast, T, E> { +) -> IdentifierOrExpression, E, E::Inner> { IdentifierOrExpression::Identifier(IdentifierExpression::new(f.fold_name(e.id))) } -pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> FieldElementExpression<'ast, T> { + let span = e.get_span(); + f.fold_field_expression_cases(e).span(span) +} + +pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { - FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Value(n) => FieldElementExpression::Value(n), FieldElementExpression::Identifier(id) => { match f.fold_identifier_expression(&Type::FieldElement, id) { IdentifierOrExpression::Identifier(i) => FieldElementExpression::Identifier(i), @@ -239,60 +448,49 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Expression(u) => u, } } - FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Add(box e1, box e2) - } - FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Sub(box e1, box e2) - } - FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Mult(box e1, box e2) - } - FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - FieldElementExpression::Div(box e1, box e2) - } - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_uint_expression(e2); - FieldElementExpression::Pow(box e1, box e2) - } - FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::And(box left, box right) - } - FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::Or(box left, box right) - } - FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(left); - let right = f.fold_field_expression(right); - - FieldElementExpression::Xor(box left, box right) - } - FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - FieldElementExpression::LeftShift(box e, box by) + FieldElementExpression::Add(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Sub(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Mult(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Div(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Div(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Pow(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::And(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::And(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Or(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Or(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Xor(e) => match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::Xor(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::LeftShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::LeftShift(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(e); - let by = f.fold_uint_expression(by); - - FieldElementExpression::RightShift(box e, box by) + FieldElementExpression::RightShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e) { + BinaryOrExpression::Binary(e) => FieldElementExpression::RightShift(e), + BinaryOrExpression::Expression(e) => e, + } } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c) { @@ -303,10 +501,20 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).span(span) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> BooleanExpression<'ast, T> { + use BooleanExpression::*; + match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => match f.fold_identifier_expression(&Type::Boolean, id) @@ -318,55 +526,46 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Select(s) => BooleanExpression::Select(s), SelectOrExpression::Expression(u) => u, }, - BooleanExpression::FieldEq(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldEq(box e1, box e2) - } - BooleanExpression::BoolEq(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::BoolEq(box e1, box e2) - } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintEq(box e1, box e2) - } - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldLt(box e1, box e2) - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintLt(box e1, box e2) - } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); - BooleanExpression::FieldLe(box e1, box e2) - } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1); - let e2 = f.fold_uint_expression(e2); - BooleanExpression::UintLe(box e1, box e2) - } - BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::Or(box e1, box e2) - } - BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1); - let e2 = f.fold_boolean_expression(e2); - BooleanExpression::And(box e1, box e2) - } - BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(e); - BooleanExpression::Not(box e) - } + FieldEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, + }, + BoolEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, + }, + UintEq(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, BooleanExpression::Conditional(c) => match f.fold_conditional_expression(&Type::Boolean, c) { ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s), @@ -385,86 +584,78 @@ pub fn fold_uint_expression<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).span(span) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> UExpressionInner<'ast, T> { + use UExpressionInner::*; + match e { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => match f.fold_identifier_expression(&ty, id) { + Value(v) => UExpressionInner::Value(v), + Identifier(id) => match f.fold_identifier_expression(&ty, id) { IdentifierOrExpression::Identifier(i) => UExpressionInner::Identifier(i), IdentifierOrExpression::Expression(e) => e, }, - UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e) { + Select(e) => match f.fold_select_expression(&ty, e) { SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Add(box left, box right) - } - UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Sub(box left, box right) - } - UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Mult(box left, box right) - } - UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Div(box left, box right) - } - UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Rem(box left, box right) - } - UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Xor(box left, box right) - } - UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::And(box left, box right) - } - UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(left); - let right = f.fold_uint_expression(right); - - UExpressionInner::Or(box left, box right) - } - UExpressionInner::LeftShift(box e, by) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::LeftShift(box e, by) - } - UExpressionInner::RightShift(box e, by) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::RightShift(box e, by) - } - UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(e); - - UExpressionInner::Not(box e) - } + Add(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e) { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e) { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c) { - ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s), + ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, }, } @@ -495,13 +686,14 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( ) -> ZirProgram<'ast, T> { ZirProgram { main: f.fold_function(p.main), + ..p } } pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + Fold + Conditional<'ast, T>, F: Folder<'ast, T>, >( f: &mut F, @@ -518,7 +710,7 @@ pub fn fold_conditional_expression< pub fn fold_select_expression< 'ast, T: Field, - E: Expr<'ast, T> + Fold<'ast, T> + Select<'ast, T>, + E: Expr<'ast, T> + Fold + Select<'ast, T>, F: Folder<'ast, T>, >( f: &mut F, @@ -530,3 +722,34 @@ pub fn fold_select_expression< e.index.fold(f), )) } + +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + Fold + From>, + R: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> BinaryOrExpression { + BinaryOrExpression::Binary(BinaryExpression::new(e.left.fold(f), e.right.fold(f)).span(e.span)) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + Fold + From>, + E: Expr<'ast, T> + Fold + From>, + F: Folder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> UnaryOrExpression { + UnaryOrExpression::Unary(UnaryExpression::new(e.inner.fold(f)).span(e.span)) +} diff --git a/zokrates_ast/src/zir/identifier.rs b/zokrates_ast/src/zir/identifier.rs index 249b2630e..3d56b07ef 100644 --- a/zokrates_ast/src/zir/identifier.rs +++ b/zokrates_ast/src/zir/identifier.rs @@ -4,13 +4,14 @@ use std::fmt; use crate::typed::Identifier as CoreIdentifier; -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Identifier<'ast> { #[serde(borrow)] Source(SourceIdentifier<'ast>), + Internal(usize), } -#[derive(Debug, PartialEq, Clone, Hash, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SourceIdentifier<'ast> { #[serde(borrow)] Basic(CoreIdentifier<'ast>), @@ -19,21 +20,42 @@ pub enum SourceIdentifier<'ast> { Element(Box>, u32), } -impl<'ast> fmt::Display for SourceIdentifier<'ast> { +impl<'ast> Identifier<'ast> { + pub fn internal>(id: T) -> Self { + Identifier::Internal(id.into()) + } +} + +impl<'ast> SourceIdentifier<'ast> { + pub fn select(self, index: u32) -> Self { + Self::Select(Box::new(self), index) + } + + pub fn member(self, member: MemberId) -> Self { + Self::Member(Box::new(self), member) + } + + pub fn element(self, index: u32) -> Self { + Self::Element(Box::new(self), index) + } +} + +impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - SourceIdentifier::Basic(i) => write!(f, "{}", i), - SourceIdentifier::Select(box i, index) => write!(f, "{}~{}", i, index), - SourceIdentifier::Member(box i, m) => write!(f, "{}.{}", i, m), - SourceIdentifier::Element(box i, index) => write!(f, "{}.{}", i, index), + Identifier::Source(s) => write!(f, "{}", s), + Identifier::Internal(i) => write!(f, "i{}", i), } } } -impl<'ast> fmt::Display for Identifier<'ast> { +impl<'ast> fmt::Display for SourceIdentifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Identifier::Source(s) => write!(f, "{}", s), + SourceIdentifier::Basic(i) => write!(f, "{}", i), + SourceIdentifier::Select(i, index) => write!(f, "{}~{}", i, index), + SourceIdentifier::Member(i, m) => write!(f, "{}.{}", i, m), + SourceIdentifier::Element(i, index) => write!(f, "{}.{}", i, index), } } } diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs index 121b2a396..e34877c5a 100644 --- a/zokrates_ast/src/zir/lqc.rs +++ b/zokrates_ast/src/zir/lqc.rs @@ -43,7 +43,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { linear: { let mut l = self.linear; other.linear.iter_mut().for_each(|(c, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); l.append(&mut other.linear); l @@ -51,7 +51,7 @@ impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { quadratic: { let mut q = self.quadratic; other.quadratic.iter_mut().for_each(|(c, _, _)| { - *c = T::zero() - &*c; + *c = T::zero() - *c; }); q.append(&mut other.quadratic); q @@ -68,18 +68,18 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { } Ok(Self { - constant: self.constant.clone() * rhs.constant.clone(), + constant: self.constant * rhs.constant, linear: { // lin0 * const1 + lin1 * const0 self.linear .clone() .into_iter() - .map(|(c, i)| (c * rhs.constant.clone(), i)) + .map(|(c, i)| (c * rhs.constant, i)) .chain( rhs.linear .clone() .into_iter() - .map(|(c, i)| (c * self.constant.clone(), i)), + .map(|(c, i)| (c * self.constant, i)), ) .collect() }, @@ -87,16 +87,16 @@ impl<'ast, T: Field> LinQuadComb<'ast, T> { // quad0 * const1 + quad1 * const0 + lin0 * lin1 self.quadratic .into_iter() - .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .map(|(c, i0, i1)| (c * rhs.constant, i0, i1)) .chain( rhs.quadratic .into_iter() - .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + .map(|(c, i0, i1)| (c * self.constant, i0, i1)), ) .chain(self.linear.iter().flat_map(|(cl, l)| { rhs.linear .iter() - .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + .map(|(cr, r)| (*cl * *cr, l.clone(), r.clone())) })) .collect() }, @@ -109,23 +109,23 @@ impl<'ast, T: Field> TryFrom> for LinQuadComb<'a fn try_from(e: FieldElementExpression<'ast, T>) -> Result { match e { - FieldElementExpression::Number(v) => Ok(Self { - constant: v, + FieldElementExpression::Value(v) => Ok(Self { + constant: v.value, ..Self::default() }), FieldElementExpression::Identifier(id) => Ok(Self { linear: vec![(T::one(), id.id)], ..Self::default() }), - FieldElementExpression::Add(box left, box right) => { - Ok(Self::try_from(left)? + Self::try_from(right)?) + FieldElementExpression::Add(e) => { + Ok(Self::try_from(*e.left)? + Self::try_from(*e.right)?) } - FieldElementExpression::Sub(box left, box right) => { - Ok(Self::try_from(left)? - Self::try_from(right)?) + FieldElementExpression::Sub(e) => { + Ok(Self::try_from(*e.left)? - Self::try_from(*e.right)?) } - FieldElementExpression::Mult(box left, box right) => { - let left = Self::try_from(left)?; - let right = Self::try_from(right)?; + FieldElementExpression::Mult(e) => { + let left = Self::try_from(*e.left)?; + let right = Self::try_from(*e.right)?; left.try_mul(right) } @@ -137,30 +137,31 @@ impl<'ast, T: Field> TryFrom> for LinQuadComb<'a #[cfg(test)] mod tests { use super::*; - use crate::zir::Id; + use crate::zir::{Expr, Id}; + use std::ops::*; use zokrates_field::Bn128Field; #[test] fn add() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*a*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -186,24 +187,24 @@ mod tests { #[test] fn sub() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*a*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -227,23 +228,23 @@ mod tests { } #[test] - fn mult() { + fn mul() { // (2 + 2*a) - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("a".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("a".into()), ), )) .unwrap(); // (2 + 2*b) - let b = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::identifier("b".into()), + let b = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); @@ -266,13 +267,13 @@ mod tests { } #[test] - fn mult_degree_error() { + fn mul_degree_error() { // 2*a*b - let a = LinQuadComb::try_from(FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("a".into()), - box FieldElementExpression::identifier("b".into()), + let a = LinQuadComb::try_from(FieldElementExpression::add( + FieldElementExpression::value(Bn128Field::from(2)), + FieldElementExpression::mul( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::identifier("b".into()), ), )) .unwrap(); diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 60dc1467f..a7cc08019 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,3 +1,4 @@ +pub mod canonicalizer; pub mod folder; mod from_typed; mod identifier; @@ -11,13 +12,20 @@ mod variable; pub use self::parameter::Parameter; pub use self::types::{Type, UBitwidth}; pub use self::variable::Variable; -use crate::common::{FlatEmbed, FormatString, SourceMetadata}; +use crate::common::expressions::{BooleanValueExpression, UnaryExpression}; +use crate::common::SourceMetadata; +use crate::common::{self, FlatEmbed, ModuleMap, Span, Value, WithSpan}; +use crate::common::{ + expressions::{self, BinaryExpression, ValueExpression}, + operators::*, +}; use crate::typed::ConcreteType; pub use crate::zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata}; use crate::zir::types::Signature; use std::fmt; -use std::marker::PhantomData; + +use derivative::Derivative; use zokrates_field::Field; pub use self::folder::Folder; @@ -25,9 +33,12 @@ pub use self::identifier::{Identifier, SourceIdentifier}; use serde::{Deserialize, Serialize}; /// A typed program as a collection of modules, one of them being the main -#[derive(PartialEq, Eq, Debug, Clone)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug)] pub struct ZirProgram<'ast, T> { pub main: ZirFunction<'ast, T>, + pub module_map: ModuleMap, } impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { @@ -36,10 +47,11 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { } } /// A typed function -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash, Eq)] +#[derive(Clone, Serialize, Deserialize)] pub struct ZirFunction<'ast, T> { /// Arguments of the function - #[serde(borrow)] pub arguments: Vec>, /// Vector of statements that are executed when running the function #[serde(borrow)] @@ -48,6 +60,9 @@ pub struct ZirFunction<'ast, T> { pub signature: Signature, } +pub type IdentifierOrExpression<'ast, T, E> = + expressions::IdentifierOrExpression, E, >::Inner>; + impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!( @@ -117,58 +132,263 @@ impl RuntimeError { } } +pub type AssemblyConstraint<'ast, T> = + crate::common::statements::AssemblyConstraint>; +pub type AssemblyAssignment<'ast, T> = + crate::common::statements::AssemblyAssignment>, ZirFunction<'ast, T>>; + #[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] pub enum ZirAssemblyStatement<'ast, T> { - Assignment( - #[serde(borrow)] Vec>, - ZirFunction<'ast, T>, - ), - Constraint( - FieldElementExpression<'ast, T>, - FieldElementExpression<'ast, T>, - SourceMetadata, - ), + #[serde(borrow)] + Assignment(AssemblyAssignment<'ast, T>), + Constraint(AssemblyConstraint<'ast, T>), +} + +impl<'ast, T> ZirAssemblyStatement<'ast, T> { + pub fn assignment(assignee: Vec>, expression: ZirFunction<'ast, T>) -> Self { + ZirAssemblyStatement::Assignment(AssemblyAssignment::new(assignee, expression)) + } + + pub fn constraint( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + metadata: SourceMetadata, + ) -> Self { + ZirAssemblyStatement::Constraint(AssemblyConstraint::new(left, right, metadata)) + } +} + +impl<'ast, T> WithSpan for ZirAssemblyStatement<'ast, T> { + fn span(self, span: Option) -> Self { + match self { + ZirAssemblyStatement::Assignment(s) => ZirAssemblyStatement::Assignment(s.span(span)), + ZirAssemblyStatement::Constraint(s) => ZirAssemblyStatement::Constraint(s.span(span)), + } + } + + fn get_span(&self) -> Option { + match self { + ZirAssemblyStatement::Assignment(s) => s.get_span(), + ZirAssemblyStatement::Constraint(s) => s.get_span(), + } + } } impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => { + ZirAssemblyStatement::Assignment(ref s) => { write!( f, "{} <-- {};", - lhs.iter() + s.assignee + .iter() .map(|a| a.to_string()) .collect::>() .join(", "), - rhs + s.expression ) } - ZirAssemblyStatement::Constraint(ref lhs, ref rhs, _) => { - write!(f, "{} === {};", lhs, rhs) + ZirAssemblyStatement::Constraint(ref s) => { + write!(f, "{}", s) } } } } +pub type DefinitionStatement<'ast, T> = + common::expressions::DefinitionStatement, ZirExpression<'ast, T>>; +pub type AssertionStatement<'ast, T> = + common::expressions::AssertionStatement, RuntimeError>; +pub type ReturnStatement<'ast, T> = + common::expressions::ReturnStatement>>; +pub type LogStatement<'ast, T> = + common::statements::LogStatement<(ConcreteType, Vec>)>; + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IfElseStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub condition: BooleanExpression<'ast, T>, + pub consequence: Vec>, + pub alternative: Vec>, +} + +impl<'ast, T> IfElseStatement<'ast, T> { + pub fn new( + condition: BooleanExpression<'ast, T>, + consequence: Vec>, + alternative: Vec>, + ) -> Self { + Self { + span: None, + condition, + consequence, + alternative, + } + } +} + +impl<'ast, T> WithSpan for IfElseStatement<'ast, T> { + fn span(self, span: Option) -> Self { + Self { span, ..self } + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MultipleDefinitionStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub assignees: Vec>, + pub rhs: ZirExpressionList<'ast, T>, +} + +impl<'ast, T> MultipleDefinitionStatement<'ast, T> { + pub fn new(assignees: Vec>, rhs: ZirExpressionList<'ast, T>) -> Self { + Self { + span: None, + assignees, + rhs, + } + } +} + +impl<'ast, T> WithSpan for MultipleDefinitionStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +impl<'ast, T: fmt::Display> fmt::Display for MultipleDefinitionStatement<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, id) in self.assignees.iter().enumerate() { + write!(f, "{}", id)?; + if i < self.assignees.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, " = {};", self.rhs) + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AssemblyBlockStatement<'ast, T> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, + #[serde(borrow)] + pub inner: Vec>, +} + +impl<'ast, T> AssemblyBlockStatement<'ast, T> { + pub fn new(inner: Vec>) -> Self { + Self { span: None, inner } + } +} + +impl<'ast, T> WithSpan for AssemblyBlockStatement<'ast, T> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A statement in a `ZirFunction` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum ZirStatement<'ast, T> { - Return(Vec>), - Definition(ZirAssignee<'ast>, ZirExpression<'ast, T>), - IfElse( - BooleanExpression<'ast, T>, - Vec>, - Vec>, - ), - Assertion(BooleanExpression<'ast, T>, RuntimeError), - MultipleDefinition(Vec>, ZirExpressionList<'ast, T>), - Log( - FormatString, - Vec<(ConcreteType, Vec>)>, - ), + Return(ReturnStatement<'ast, T>), + Definition(DefinitionStatement<'ast, T>), + IfElse(IfElseStatement<'ast, T>), + Assertion(AssertionStatement<'ast, T>), + MultipleDefinition(MultipleDefinitionStatement<'ast, T>), + Log(LogStatement<'ast, T>), #[serde(borrow)] - Assembly(Vec>), + Assembly(AssemblyBlockStatement<'ast, T>), +} + +impl<'ast, T> ZirStatement<'ast, T> { + pub fn definition(a: ZirAssignee<'ast>, e: ZirExpression<'ast, T>) -> Self { + Self::Definition(DefinitionStatement::new(a, e)) + } + + pub fn multiple_definition( + assignees: Vec>, + e: ZirExpressionList<'ast, T>, + ) -> Self { + Self::MultipleDefinition(MultipleDefinitionStatement::new(assignees, e)) + } + + pub fn assertion(e: BooleanExpression<'ast, T>, error: RuntimeError) -> Self { + Self::Assertion(AssertionStatement::new(e, error)) + } + + pub fn ret(e: Vec>) -> Self { + Self::Return(ReturnStatement::new(e)) + } + + pub fn assembly(s: Vec>) -> Self { + Self::Assembly(AssemblyBlockStatement::new(s)) + } + + pub fn if_else( + condition: BooleanExpression<'ast, T>, + consequence: Vec>, + alternative: Vec>, + ) -> Self { + Self::IfElse(IfElseStatement::new(condition, consequence, alternative)) + } +} + +impl<'ast, T> WithSpan for ZirStatement<'ast, T> { + fn span(self, span: Option) -> Self { + use ZirStatement::*; + + match self { + Return(e) => Return(e.span(span)), + Definition(e) => Definition(e.span(span)), + Assertion(e) => Assertion(e.span(span)), + IfElse(e) => IfElse(e.span(span)), + MultipleDefinition(e) => MultipleDefinition(e.span(span)), + Log(e) => Log(e.span(span)), + Assembly(e) => Assembly(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ZirStatement::*; + + match self { + Return(e) => e.get_span(), + Definition(e) => e.get_span(), + Assertion(e) => e.get_span(), + IfElse(e) => e.get_span(), + MultipleDefinition(e) => e.get_span(), + Log(e) => e.get_span(), + Assembly(e) => e.get_span(), + } + } } impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { @@ -180,14 +400,15 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { write!(f, "{}", "\t".repeat(depth))?; + match self { - ZirStatement::Return(ref exprs) => { + ZirStatement::Return(ref s) => { write!(f, "return")?; - if !exprs.is_empty() { + if !s.inner.is_empty() { write!( f, " {}", - exprs + s.inner .iter() .map(|e| e.to_string()) .collect::>() @@ -196,43 +417,37 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { } write!(f, ";") } - ZirStatement::Definition(ref lhs, ref rhs) => { - write!(f, "{} = {};", lhs, rhs) + ZirStatement::Definition(ref s) => { + write!(f, "{}", s) } - ZirStatement::IfElse(ref condition, ref consequence, ref alternative) => { - writeln!(f, "if {} {{", condition)?; - for s in consequence { + ZirStatement::IfElse(ref s) => { + writeln!(f, "if {} {{", s.condition)?; + for s in &s.consequence { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } writeln!(f, "{}}} else {{", "\t".repeat(depth))?; - for s in alternative { + for s in &s.alternative { s.fmt_indented(f, depth + 1)?; writeln!(f)?; } write!(f, "{}}}", "\t".repeat(depth)) } - ZirStatement::Assertion(ref e, ref error) => { - write!(f, "assert({}", e)?; - match error { + ZirStatement::Assertion(ref s) => { + write!(f, "assert({}", s.expression)?; + match &s.error { RuntimeError::SourceAssertion(message) => write!(f, ", \"{}\");", message), error => write!(f, "); // {}", error), } } - ZirStatement::MultipleDefinition(ref ids, ref rhs) => { - for (i, id) in ids.iter().enumerate() { - write!(f, "{}", id)?; - if i < ids.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, " = {};", rhs) + ZirStatement::MultipleDefinition(ref s) => { + write!(f, "{}", s) } - ZirStatement::Log(ref l, ref expressions) => write!( + ZirStatement::Log(ref e) => write!( f, "log(\"{}\"), {});", - l, - expressions + e.format_string, + e.expressions .iter() .map(|(_, e)| format!( "[{}]", @@ -244,9 +459,9 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> { .collect::>() .join(", ") ), - ZirStatement::Assembly(statements) => { + ZirStatement::Assembly(s) => { writeln!(f, "asm {{")?; - for s in statements { + for s in &s.inner { writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; } write!(f, "{}}}", "\t".repeat(depth)) @@ -259,30 +474,12 @@ pub trait Typed { fn get_type(&self) -> Type; } -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct IdentifierExpression<'ast, E> { - #[serde(borrow)] - pub id: Identifier<'ast>, - ty: PhantomData, -} - -impl<'ast, E> fmt::Display for IdentifierExpression<'ast, E> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.id) - } -} - -impl<'ast, E> IdentifierExpression<'ast, E> { - pub fn new(id: Identifier<'ast>) -> Self { - IdentifierExpression { - id, - ty: PhantomData, - } - } -} - -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ConditionalExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, #[serde(borrow)] pub condition: Box>, pub consequence: Box, @@ -292,9 +489,10 @@ pub struct ConditionalExpression<'ast, T, E> { impl<'ast, T, E> ConditionalExpression<'ast, T, E> { pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self { ConditionalExpression { - condition: box condition, - consequence: box consequence, - alternative: box alternative, + span: None, + condition: Box::new(condition), + consequence: Box::new(consequence), + alternative: Box::new(alternative), } } } @@ -304,13 +502,28 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for ConditionalExpress write!( f, "{} ? {} : {}", - self.condition, self.consequence, self.alternative + self.condition, self.consequence, self.alternative, ) } } -#[derive(Debug, Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +impl<'ast, T, E> WithSpan for ConditionalExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SelectExpression<'ast, T, E> { + #[derivative(PartialEq = "ignore", PartialOrd = "ignore", Hash = "ignore")] + pub span: Option, pub array: Vec, #[serde(borrow)] pub index: Box>, @@ -319,8 +532,9 @@ pub struct SelectExpression<'ast, T, E> { impl<'ast, T, E> SelectExpression<'ast, T, E> { pub fn new(array: Vec, index: UExpression<'ast, T>) -> Self { SelectExpression { + span: None, array, - index: box index, + index: Box::new(index), } } } @@ -340,12 +554,46 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for SelectExpression<' } } +impl<'ast, T, E> WithSpan for SelectExpression<'ast, T, E> { + fn span(mut self, span: Option) -> Self { + self.span = span; + self + } + + fn get_span(&self) -> Option { + self.span + } +} + /// A typed expression -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Serialize, Deserialize)] pub enum ZirExpression<'ast, T> { + #[serde(borrow)] Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), - Uint(#[serde(borrow)] UExpression<'ast, T>), + Uint(UExpression<'ast, T>), +} + +impl<'ast, T> WithSpan for ZirExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use ZirExpression::*; + match self { + Boolean(e) => Boolean(e.span(span)), + FieldElement(e) => FieldElement(e.span(span)), + Uint(e) => Uint(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use ZirExpression::*; + match self { + Boolean(e) => e.get_span(), + FieldElement(e) => e.get_span(), + Uint(e) => e.get_span(), + } + } } impl<'ast, T: Field> From> for ZirExpression<'ast, T> { @@ -418,7 +666,9 @@ pub trait MultiTyped { fn get_types(&self) -> &Vec; } -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Serialize, Deserialize)] pub enum ZirExpressionList<'ast, T> { EmbedCall( FlatEmbed, @@ -427,111 +677,126 @@ pub enum ZirExpressionList<'ast, T> { ), } +pub type IdentifierExpression<'ast, E> = expressions::IdentifierExpression, E>; + /// An expression of type `field` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Eq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum FieldElementExpression<'ast, T> { - Number(T), + Value(ValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), - Add( - Box>, - Box>, - ), - Sub( - Box>, - Box>, - ), - Mult( - Box>, - Box>, - ), - Div( - Box>, - Box>, - ), - Pow( - Box>, - #[serde(borrow)] Box>, - ), - And( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - Xor( - Box>, - Box>, - ), - LeftShift( - Box>, - Box>, - ), - RightShift( - Box>, - Box>, - ), + Add(BinaryExpression), + Sub(BinaryExpression), + Mult(BinaryExpression), + Div(BinaryExpression), + Pow(BinaryExpression, Self>), + And(BinaryExpression), + Or(BinaryExpression), + Xor(BinaryExpression), + LeftShift(BinaryExpression, Self>), + RightShift(BinaryExpression, Self>), Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>), } impl<'ast, T> FieldElementExpression<'ast, T> { + pub fn number(n: T) -> Self { + Self::Value(ValueExpression::new(n)) + } + + pub fn pow(self, right: UExpression<'ast, T>) -> Self { + Self::Pow(BinaryExpression::new(self, right)) + } + pub fn is_linear(&self) -> bool { match self { - FieldElementExpression::Number(_) => true, + FieldElementExpression::Value(_) => true, FieldElementExpression::Identifier(_) => true, - FieldElementExpression::Add(box left, box right) => { - left.is_linear() && right.is_linear() - } - FieldElementExpression::Sub(box left, box right) => { - left.is_linear() && right.is_linear() - } - FieldElementExpression::Mult(box left, box right) => matches!( - (left, right), - (FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_)) + FieldElementExpression::Add(e) => e.left.is_linear() && e.right.is_linear(), + FieldElementExpression::Sub(e) => e.left.is_linear() && e.right.is_linear(), + FieldElementExpression::Mult(e) => matches!( + (&*e.left, &*e.right), + (FieldElementExpression::Value(_), _) | (_, FieldElementExpression::Value(_)) ), _ => false, } } + + pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::LeftShift(BinaryExpression::new(self, by)) + } + + pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { + FieldElementExpression::RightShift(BinaryExpression::new(self, by)) + } +} + +impl<'ast, T: Field> std::ops::BitAnd for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + FieldElementExpression::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitOr for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + FieldElementExpression::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::BitXor for FieldElementExpression<'ast, T> { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + FieldElementExpression::Xor(BinaryExpression::new(self, other)) + } } /// An expression of type `bool` -#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum BooleanExpression<'ast, T> { - Value(bool), + Value(BooleanValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, Self>), Select(SelectExpression<'ast, T, Self>), FieldLt( - Box>, - Box>, + BinaryExpression< + OpLt, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldLe( - Box>, - Box>, + BinaryExpression< + OpLe, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), FieldEq( - Box>, - Box>, - ), - UintLt(Box>, Box>), - UintLe(Box>, Box>), - UintEq(Box>, Box>), - BoolEq( - Box>, - Box>, - ), - Or( - Box>, - Box>, - ), - And( - Box>, - Box>, + BinaryExpression< + OpEq, + FieldElementExpression<'ast, T>, + FieldElementExpression<'ast, T>, + Self, + >, ), - Not(Box>), + UintLt(BinaryExpression, UExpression<'ast, T>, Self>), + UintLe(BinaryExpression, UExpression<'ast, T>, Self>), + UintEq(BinaryExpression, UExpression<'ast, T>, Self>), + BoolEq(BinaryExpression), + Or(BinaryExpression), + And(BinaryExpression), + Not(UnaryExpression), Conditional(ConditionalExpression<'ast, T, BooleanExpression<'ast, T>>), } @@ -544,9 +809,9 @@ impl<'ast, T> Iterator for ConjunctionIterator> { fn next(&mut self) -> Option { self.current.pop().and_then(|n| match n { - BooleanExpression::And(box left, box right) => { - self.current.push(left); - self.current.push(right); + BooleanExpression::And(e) => { + self.current.push(*e.left); + self.current.push(*e.right); self.next() } n => Some(n), @@ -593,23 +858,19 @@ impl<'ast, T> From> for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldElementExpression::Number(ref i) => write!(f, "{}", i), + FieldElementExpression::Value(ref i) => write!(f, "{}", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), FieldElementExpression::Select(ref e) => write!(f, "{}", e), - FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - FieldElementExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - FieldElementExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), - FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), - FieldElementExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - FieldElementExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - FieldElementExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - FieldElementExpression::LeftShift(ref lhs, ref rhs) => { - write!(f, "({} << {})", lhs, rhs) - } - FieldElementExpression::RightShift(ref lhs, ref rhs) => { - write!(f, "({} >> {})", lhs, rhs) - } + FieldElementExpression::Add(ref e) => write!(f, "{}", e), + FieldElementExpression::Sub(ref e) => write!(f, "{}", e), + FieldElementExpression::Mult(ref e) => write!(f, "{}", e), + FieldElementExpression::Div(ref e) => write!(f, "{}", e), + FieldElementExpression::Pow(ref e) => write!(f, "{}", e), + FieldElementExpression::And(ref e) => write!(f, "{}", e), + FieldElementExpression::Or(ref e) => write!(f, "{}", e), + FieldElementExpression::Xor(ref e) => write!(f, "{}", e), + FieldElementExpression::LeftShift(ref e) => write!(f, "{}", e), + FieldElementExpression::RightShift(ref e) => write!(f, "{}", e), FieldElementExpression::Conditional(ref c) => { write!(f, "{}", c) } @@ -623,17 +884,17 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), UExpressionInner::Select(ref e) => write!(f, "{}", e), - UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), - UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), - UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), - UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), - UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), - UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), - UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), - UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), - UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), - UExpressionInner::Not(ref e) => write!(f, "!{}", e), + UExpressionInner::Add(ref e) => write!(f, "{}", e), + UExpressionInner::Sub(ref e) => write!(f, "{}", e), + UExpressionInner::Mult(ref e) => write!(f, "{}", e), + UExpressionInner::Div(ref e) => write!(f, "{}", e), + UExpressionInner::Rem(ref e) => write!(f, "{}", e), + UExpressionInner::Xor(ref e) => write!(f, "{}", e), + UExpressionInner::And(ref e) => write!(f, "{}", e), + UExpressionInner::Or(ref e) => write!(f, "{}", e), + UExpressionInner::LeftShift(ref e) => write!(f, "{}", e), + UExpressionInner::RightShift(ref e) => write!(f, "{}", e), + UExpressionInner::Not(ref e) => write!(f, "{}", e), UExpressionInner::Conditional(ref c) => { write!(f, "{}", c) } @@ -645,18 +906,18 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::Value(b) => write!(f, "{}", b), + BooleanExpression::Value(ref b) => write!(f, "{}", b), BooleanExpression::Select(ref e) => write!(f, "{}", e), - BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs), - BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs), - BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs), - BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs), - BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs), - BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), + BooleanExpression::FieldLt(ref e) => write!(f, "{}", e), + BooleanExpression::UintLt(ref e) => write!(f, "{}", e), + BooleanExpression::FieldLe(ref e) => write!(f, "{}", e), + BooleanExpression::UintLe(ref e) => write!(f, "{}", e), + BooleanExpression::FieldEq(ref e) => write!(f, "{}", e), + BooleanExpression::BoolEq(ref e) => write!(f, "{}", e), + BooleanExpression::UintEq(ref e) => write!(f, "{}", e), + BooleanExpression::Or(ref e) => write!(f, "{}", e), + BooleanExpression::And(ref e) => write!(f, "{}", e), + BooleanExpression::Not(ref exp) => write!(f, "{}", exp), BooleanExpression::Conditional(ref c) => { write!(f, "{}", c) } @@ -664,6 +925,69 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { } } +impl<'ast, T> std::ops::Not for BooleanExpression<'ast, T> { + type Output = Self; + + fn not(self) -> Self { + Self::Not(UnaryExpression::new(self)) + } +} + +impl<'ast, T> std::ops::BitAnd for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + Self::And(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> std::ops::BitOr for BooleanExpression<'ast, T> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + Self::Or(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T> BooleanExpression<'ast, T> { + pub fn uint_eq(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintEq(BinaryExpression::new(left, right)) + } + + pub fn bool_eq(left: BooleanExpression<'ast, T>, right: BooleanExpression<'ast, T>) -> Self { + Self::BoolEq(BinaryExpression::new(left, right)) + } + + pub fn field_eq( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldEq(BinaryExpression::new(left, right)) + } + + pub fn uint_lt(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLt(BinaryExpression::new(left, right)) + } + + pub fn uint_le(left: UExpression<'ast, T>, right: UExpression<'ast, T>) -> Self { + Self::UintLe(BinaryExpression::new(left, right)) + } + + pub fn field_lt( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLt(BinaryExpression::new(left, right)) + } + + pub fn field_le( + left: FieldElementExpression<'ast, T>, + right: FieldElementExpression<'ast, T>, + ) -> Self { + Self::FieldLe(BinaryExpression::new(left, right)) + } +} + impl<'ast, T: fmt::Display> fmt::Display for ZirExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -709,8 +1033,52 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ZirExpressionList<'ast, T> { } } +impl<'ast, T: Field> std::ops::Add for FieldElementExpression<'ast, T> { + type Output = Self; + + fn add(self, other: Self) -> Self { + FieldElementExpression::Add(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Sub for FieldElementExpression<'ast, T> { + type Output = Self; + + fn sub(self, other: Self) -> Self { + FieldElementExpression::Sub(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Mul for FieldElementExpression<'ast, T> { + type Output = Self; + + fn mul(self, other: Self) -> Self { + FieldElementExpression::Mult(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Field> std::ops::Div for FieldElementExpression<'ast, T> { + type Output = Self; + + fn div(self, other: Self) -> Self { + FieldElementExpression::Div(BinaryExpression::new(self, other)) + } +} + +impl<'ast, T: Clone> Value for FieldElementExpression<'ast, T> { + type Value = T; +} + +impl<'ast, T> Value for BooleanExpression<'ast, T> { + type Value = bool; +} + +impl<'ast, T> Value for UExpression<'ast, T> { + type Value = u128; +} + // Common behaviour across expressions -pub trait Expr<'ast, T>: fmt::Display + PartialEq + From> { +pub trait Expr<'ast, T>: Value + fmt::Display + PartialEq + From> { type Inner; type Ty: Clone + IntoType; @@ -721,6 +1089,8 @@ pub trait Expr<'ast, T>: fmt::Display + PartialEq + From> fn as_inner(&self) -> &Self::Inner; fn as_inner_mut(&mut self) -> &mut Self::Inner; + + fn value(_: Self::Value) -> Self::Inner; } impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { @@ -742,6 +1112,10 @@ impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: ::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { @@ -763,6 +1137,10 @@ impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { self } + + fn value(v: as Value>::Value) -> Self::Inner { + Self::Value(ValueExpression::new(v)) + } } impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { @@ -784,6 +1162,10 @@ impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> { fn as_inner_mut(&mut self) -> &mut Self::Inner { &mut self.inner } + + fn value(v: Self::Value) -> Self::Inner { + UExpressionInner::Value(ValueExpression::new(v)) + } } pub trait Id<'ast, T>: Expr<'ast, T> { @@ -808,11 +1190,6 @@ impl<'ast, T: Field> Id<'ast, T> for UExpression<'ast, T> { } } -pub enum IdentifierOrExpression<'ast, T, E: Expr<'ast, T>> { - Identifier(IdentifierExpression<'ast, E>), - Expression(E::Inner), -} - pub trait Conditional<'ast, T> { fn conditional( condition: BooleanExpression<'ast, T>, @@ -929,9 +1306,150 @@ impl<'ast, T: Field> Constant for ZirExpression<'ast, T> { } } +impl<'ast, T> WithSpan for FieldElementExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use FieldElementExpression::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Add(e) => Add(e.span(span)), + Value(e) => Value(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Pow(e) => Pow(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + Xor(e) => Xor(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use FieldElementExpression::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Add(e) => e.get_span(), + Value(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Pow(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + Xor(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for BooleanExpression<'ast, T> { + fn span(self, span: Option) -> Self { + use BooleanExpression::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Value(e) => Value(e.span(span)), + FieldLt(e) => FieldLt(e.span(span)), + FieldLe(e) => FieldLe(e.span(span)), + FieldEq(e) => FieldEq(e.span(span)), + UintLt(e) => UintLt(e.span(span)), + UintLe(e) => UintLe(e.span(span)), + UintEq(e) => UintEq(e.span(span)), + BoolEq(e) => BoolEq(e.span(span)), + Or(e) => Or(e.span(span)), + And(e) => And(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use BooleanExpression::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Value(e) => e.get_span(), + FieldLt(e) => e.get_span(), + FieldLe(e) => e.get_span(), + FieldEq(e) => e.get_span(), + UintLt(e) => e.get_span(), + UintLe(e) => e.get_span(), + UintEq(e) => e.get_span(), + BoolEq(e) => e.get_span(), + Or(e) => e.get_span(), + And(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + +impl<'ast, T> WithSpan for UExpression<'ast, T> { + fn span(self, span: Option) -> Self { + Self { + inner: self.inner.span(span), + ..self + } + } + + fn get_span(&self) -> Option { + self.inner.get_span() + } +} + +impl<'ast, T> WithSpan for UExpressionInner<'ast, T> { + fn span(self, span: Option) -> Self { + use UExpressionInner::*; + match self { + Select(e) => Select(e.span(span)), + Identifier(e) => Identifier(e.span(span)), + Conditional(e) => Conditional(e.span(span)), + Value(e) => Value(e.span(span)), + Add(e) => Add(e.span(span)), + Sub(e) => Sub(e.span(span)), + Mult(e) => Mult(e.span(span)), + Div(e) => Div(e.span(span)), + Rem(e) => Rem(e.span(span)), + Xor(e) => Xor(e.span(span)), + And(e) => And(e.span(span)), + Or(e) => Or(e.span(span)), + LeftShift(e) => LeftShift(e.span(span)), + RightShift(e) => RightShift(e.span(span)), + Not(e) => Not(e.span(span)), + } + } + + fn get_span(&self) -> Option { + use UExpressionInner::*; + match self { + Select(e) => e.get_span(), + Identifier(e) => e.get_span(), + Conditional(e) => e.get_span(), + Value(e) => e.get_span(), + Add(e) => e.get_span(), + Sub(e) => e.get_span(), + Mult(e) => e.get_span(), + Div(e) => e.get_span(), + Rem(e) => e.get_span(), + Xor(e) => e.get_span(), + And(e) => e.get_span(), + Or(e) => e.get_span(), + LeftShift(e) => e.get_span(), + RightShift(e) => e.get_span(), + Not(e) => e.get_span(), + } + } +} + impl<'ast, T: Field> Constant for FieldElementExpression<'ast, T> { fn is_constant(&self) -> bool { - matches!(self, FieldElementExpression::Number(..)) + matches!(self, FieldElementExpression::Value(..)) } } diff --git a/zokrates_ast/src/zir/parameter.rs b/zokrates_ast/src/zir/parameter.rs index 203a291a7..ac413e54a 100644 --- a/zokrates_ast/src/zir/parameter.rs +++ b/zokrates_ast/src/zir/parameter.rs @@ -1,33 +1,3 @@ -use crate::zir::Variable; -use serde::{Deserialize, Serialize}; -use std::fmt; +use super::Variable; -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct Parameter<'ast> { - #[serde(borrow)] - pub id: Variable<'ast>, - pub private: bool, -} - -impl<'ast> Parameter<'ast> { - #[cfg(test)] - pub fn private(v: Variable<'ast>) -> Self { - Parameter { - id: v, - private: true, - } - } -} - -impl<'ast> fmt::Display for Parameter<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id) - } -} - -impl<'ast> fmt::Debug for Parameter<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Parameter(variable: {:?})", self.id) - } -} +pub type Parameter<'ast> = crate::common::Parameter>; diff --git a/zokrates_ast/src/zir/result_folder.rs b/zokrates_ast/src/zir/result_folder.rs index 6ea741cf6..d368d972d 100644 --- a/zokrates_ast/src/zir/result_folder.rs +++ b/zokrates_ast/src/zir/result_folder.rs @@ -1,27 +1,30 @@ // Generic walk through ZIR. Not mutating in place +use crate::common::expressions::{BinaryOrExpression, IdentifierOrExpression, UnaryOrExpression}; +use crate::common::ResultFold; +use crate::common::WithSpan; use crate::zir::types::UBitwidth; use crate::zir::*; use zokrates_field::Field; -pub trait ResultFold<'ast, T: Field>: Sized { - fn fold>(self, f: &mut F) -> Result; -} - -impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for FieldElementExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_field_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold + for BooleanExpression<'ast, T> +{ + fn fold(self, f: &mut F) -> Result { f.fold_boolean_expression(self) } } -impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> { - fn fold>(self, f: &mut F) -> Result { +impl<'ast, T: Field, F: ResultFolder<'ast, T>> ResultFold for UExpression<'ast, T> { + fn fold(self, f: &mut F) -> Result { f.fold_uint_expression(self) } } @@ -61,6 +64,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized { self.fold_variable(a) } + fn fold_assembly_constraint( + &mut self, + s: AssemblyConstraint<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_constraint(self, s) + } + + fn fold_assembly_assignment( + &mut self, + s: AssemblyAssignment<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_assignment(self, s) + } + fn fold_assembly_statement( &mut self, s: ZirAssemblyStatement<'ast, T>, @@ -68,6 +85,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_assembly_statement(self, s) } + fn fold_assembly_statement_cases( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_statement_cases(self, s) + } + fn fold_statement( &mut self, s: ZirStatement<'ast, T>, @@ -75,6 +99,62 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_statement_cases( + &mut self, + s: ZirStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement_cases(self, s) + } + + fn fold_definition_statement( + &mut self, + s: DefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_definition_statement(self, s) + } + + fn fold_multiple_definition_statement( + &mut self, + s: MultipleDefinitionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_multiple_definition_statement(self, s) + } + + fn fold_return_statement( + &mut self, + s: ReturnStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_return_statement(self, s) + } + + fn fold_log_statement( + &mut self, + s: LogStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_log_statement(self, s) + } + + fn fold_assembly_block( + &mut self, + s: AssemblyBlockStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assembly_block(self, s) + } + + fn fold_assertion_statement( + &mut self, + s: AssertionStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_assertion_statement(self, s) + } + + fn fold_if_else_statement( + &mut self, + s: IfElseStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_if_else_statement(self, s) + } + fn fold_expression( &mut self, e: ZirExpression<'ast, T>, @@ -104,16 +184,18 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } } - fn fold_identifier_expression + Id<'ast, T> + ResultFold<'ast, T>>( + fn fold_identifier_expression< + E: Expr<'ast, T> + Id<'ast, T> + ResultFold, + >( &mut self, ty: &E::Ty, id: IdentifierExpression<'ast, E>, - ) -> Result, Self::Error> { + ) -> Result, E, E::Inner>, Self::Error> { fold_identifier_expression(self, ty, id) } fn fold_conditional_expression< - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, >( &mut self, ty: &E::Ty, @@ -122,7 +204,35 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_conditional_expression(self, ty, e) } - fn fold_select_expression + ResultFold<'ast, T> + Select<'ast, T>>( + #[allow(clippy::type_complexity)] + fn fold_binary_expression< + L: Expr<'ast, T> + PartialEq + ResultFold, + R: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: BinaryExpression, + ) -> Result, Self::Error> { + fold_binary_expression(self, ty, e) + } + + fn fold_unary_expression< + In: Expr<'ast, T> + PartialEq + ResultFold, + E: Expr<'ast, T> + PartialEq + ResultFold, + Op, + >( + &mut self, + ty: &E::Ty, + e: UnaryExpression, + ) -> Result, Self::Error> { + fold_unary_expression(self, ty, e) + } + + fn fold_select_expression< + E: Clone + Expr<'ast, T> + ResultFold + Select<'ast, T>, + >( &mut self, ty: &E::Ty, e: SelectExpression<'ast, T, E>, @@ -137,6 +247,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_field_expression(self, e) } + fn fold_field_expression_cases( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression_cases(self, e) + } + fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, @@ -144,6 +261,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_boolean_expression(self, e) } + fn fold_boolean_expression_cases( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression_cases(self, e) + } + fn fold_uint_expression( &mut self, e: UExpression<'ast, T>, @@ -158,102 +282,213 @@ pub trait ResultFolder<'ast, T: Field>: Sized { ) -> Result, Self::Error> { fold_uint_expression_inner(self, bitwidth, e) } + + fn fold_uint_expression_cases( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_cases(self, bitwidth, e) + } +} + +pub fn fold_assembly_assignment<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyAssignment<'ast, T>, +) -> Result>, F::Error> { + let assignees = s + .assignee + .into_iter() + .map(|a| f.fold_assignee(a)) + .collect::>()?; + let expression = f.fold_function(s.expression)?; + Ok(vec![ZirAssemblyStatement::assignment( + assignees, expression, + )]) +} + +pub fn fold_assembly_constraint<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyConstraint<'ast, T>, +) -> Result>, F::Error> { + let left = f.fold_field_expression(s.left)?; + let right = f.fold_field_expression(s.right)?; + Ok(vec![ZirAssemblyStatement::constraint( + left, right, s.metadata, + )]) } -pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + +fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: ZirAssemblyStatement<'ast, T>, ) -> Result>, F::Error> { - Ok(match s { - ZirAssemblyStatement::Assignment(assignees, function) => { - let assignees = assignees - .into_iter() - .map(|a| f.fold_assignee(a)) - .collect::>()?; - let function = f.fold_function(function)?; - vec![ZirAssemblyStatement::Assignment(assignees, function)] - } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = f.fold_field_expression(lhs)?; - let rhs = f.fold_field_expression(rhs)?; - vec![ZirAssemblyStatement::Constraint(lhs, rhs, metadata)] - } - }) + let span = s.get_span(); + f.fold_assembly_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_assembly_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirAssemblyStatement<'ast, T>, +) -> Result>, F::Error> { + match s { + ZirAssemblyStatement::Assignment(s) => f.fold_assembly_assignment(s), + ZirAssemblyStatement::Constraint(s) => f.fold_assembly_constraint(s), + } } pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, ) -> Result>, F::Error> { - let res = match s { - ZirStatement::Return(expressions) => ZirStatement::Return( - expressions + let span = s.get_span(); + f.fold_statement_cases(s) + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_statement_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirStatement<'ast, T>, +) -> Result>, F::Error> { + let span = s.get_span(); + + match s { + ZirStatement::Return(s) => f.fold_return_statement(s), + ZirStatement::Definition(s) => f.fold_definition_statement(s), + ZirStatement::IfElse(s) => f.fold_if_else_statement(s), + ZirStatement::Assertion(s) => f.fold_assertion_statement(s), + ZirStatement::MultipleDefinition(s) => f.fold_multiple_definition_statement(s), + ZirStatement::Log(s) => f.fold_log_statement(s), + ZirStatement::Assembly(s) => f.fold_assembly_block(s), + } + .map(|s| s.into_iter().map(|s| s.span(span)).collect()) +} + +pub fn fold_return_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ReturnStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Return( + ReturnStatement::new( + s.inner .into_iter() .map(|e| f.fold_expression(e)) .collect::>()?, - ), - ZirStatement::Definition(a, e) => { - ZirStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?) - } - ZirStatement::IfElse(condition, consequence, alternative) => ZirStatement::IfElse( - f.fold_boolean_expression(condition)?, - consequence + ) + .span(s.span), + )]) +} + +pub fn fold_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: DefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let rhs = f.fold_expression(s.rhs)?; + Ok(vec![ZirStatement::Definition( + DefinitionStatement::new(f.fold_assignee(s.assignee)?, rhs).span(s.span), + )]) +} + +pub fn fold_if_else_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: IfElseStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::IfElse( + IfElseStatement::new( + f.fold_boolean_expression(s.condition)?, + s.consequence .into_iter() .map(|s| f.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() .collect(), - alternative + s.alternative .into_iter() .map(|s| f.fold_statement(s)) .collect::, _>>()? .into_iter() .flatten() .collect(), - ), - ZirStatement::Assertion(e, error) => { - ZirStatement::Assertion(f.fold_boolean_expression(e)?, error) - } - ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition( - variables + ) + .span(s.span), + )]) +} + +pub fn fold_assertion_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssertionStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Assertion( + AssertionStatement::new(f.fold_boolean_expression(s.expression)?, s.error).span(s.span), + )]) +} + +pub fn fold_log_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: LogStatement<'ast, T>, +) -> Result>, F::Error> { + let expressions = s + .expressions + .into_iter() + .map(|(t, e)| { + e.into_iter() + .map(|e| f.fold_expression(e)) + .collect::, _>>() + .map(|e| (t, e)) + }) + .collect::, _>>()?; + Ok(vec![ZirStatement::Log(LogStatement::new( + s.format_string, + expressions, + ))]) +} + +pub fn fold_assembly_block<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: AssemblyBlockStatement<'ast, T>, +) -> Result>, F::Error> { + Ok(vec![ZirStatement::Assembly(AssemblyBlockStatement::new( + s.inner + .into_iter() + .map(|s| f.fold_assembly_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ))]) +} + +pub fn fold_multiple_definition_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: MultipleDefinitionStatement<'ast, T>, +) -> Result>, F::Error> { + let expression_list = f.fold_expression_list(s.rhs)?; + Ok(vec![ZirStatement::MultipleDefinition( + MultipleDefinitionStatement::new( + s.assignees .into_iter() .map(|v| f.fold_assignee(v)) .collect::>()?, - f.fold_expression_list(elist)?, + expression_list, ), - ZirStatement::Log(l, e) => { - let e = e - .into_iter() - .map(|(t, e)| { - e.into_iter() - .map(|e| f.fold_expression(e)) - .collect::, _>>() - .map(|e| (t, e)) - }) - .collect::, _>>()?; - - ZirStatement::Log(l, e) - } - ZirStatement::Assembly(statements) => { - let statements = statements - .into_iter() - .map(|s| f.fold_assembly_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(); - ZirStatement::Assembly(statements) - } - }; - Ok(vec![res]) + )]) +} + +fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_field_expression_cases(e).map(|e| e.span(span)) } -pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_field_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: FieldElementExpression<'ast, T>, ) -> Result, F::Error> { Ok(match e { - FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Value(n) => FieldElementExpression::Value(n), FieldElementExpression::Identifier(id) => { match f.fold_identifier_expression(&Type::FieldElement, id)? { IdentifierOrExpression::Identifier(i) => FieldElementExpression::Identifier(i), @@ -266,60 +501,63 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Expression(u) => u, } } - FieldElementExpression::Add(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Add(box e1, box e2) - } - FieldElementExpression::Sub(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Sub(box e1, box e2) + FieldElementExpression::Add(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Add(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Mult(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Mult(box e1, box e2) + FieldElementExpression::Sub(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Sub(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Div(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - FieldElementExpression::Div(box e1, box e2) + FieldElementExpression::Mult(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Mult(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Pow(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - FieldElementExpression::Pow(box e1, box e2) + FieldElementExpression::Div(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Div(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Xor(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::Xor(box left, box right) + FieldElementExpression::Pow(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Pow(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::And(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::And(box left, box right) + FieldElementExpression::And(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::And(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::Or(box left, box right) => { - let left = f.fold_field_expression(left)?; - let right = f.fold_field_expression(right)?; - - FieldElementExpression::Or(box left, box right) + FieldElementExpression::Or(e) => match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Or(e), + BinaryOrExpression::Expression(e) => e, + }, + FieldElementExpression::Xor(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::Xor(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::LeftShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - FieldElementExpression::LeftShift(box e, box by) + FieldElementExpression::LeftShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::LeftShift(e), + BinaryOrExpression::Expression(e) => e, + } } - FieldElementExpression::RightShift(box e, box by) => { - let e = f.fold_field_expression(e)?; - let by = f.fold_uint_expression(by)?; - - FieldElementExpression::RightShift(box e, box by) + FieldElementExpression::RightShift(e) => { + match f.fold_binary_expression(&Type::FieldElement, e)? { + BinaryOrExpression::Binary(e) => FieldElementExpression::RightShift(e), + BinaryOrExpression::Expression(e) => e, + } } FieldElementExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::FieldElement, c)? { @@ -330,10 +568,20 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_boolean_expression_cases(e).map(|e| e.span(span)) +} + +pub fn fold_boolean_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, ) -> Result, F::Error> { + use BooleanExpression::*; + Ok(match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => { @@ -346,55 +594,46 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(s) => BooleanExpression::Select(s), SelectOrExpression::Expression(u) => u, }, - BooleanExpression::FieldEq(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldEq(box e1, box e2) - } - BooleanExpression::BoolEq(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::BoolEq(box e1, box e2) - } - BooleanExpression::UintEq(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintEq(box e1, box e2) - } - BooleanExpression::FieldLt(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldLt(box e1, box e2) - } - BooleanExpression::UintLt(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintLt(box e1, box e2) - } - BooleanExpression::FieldLe(box e1, box e2) => { - let e1 = f.fold_field_expression(e1)?; - let e2 = f.fold_field_expression(e2)?; - BooleanExpression::FieldLe(box e1, box e2) - } - BooleanExpression::UintLe(box e1, box e2) => { - let e1 = f.fold_uint_expression(e1)?; - let e2 = f.fold_uint_expression(e2)?; - BooleanExpression::UintLe(box e1, box e2) - } - BooleanExpression::Or(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::Or(box e1, box e2) - } - BooleanExpression::And(box e1, box e2) => { - let e1 = f.fold_boolean_expression(e1)?; - let e2 = f.fold_boolean_expression(e2)?; - BooleanExpression::And(box e1, box e2) - } - BooleanExpression::Not(box e) => { - let e = f.fold_boolean_expression(e)?; - BooleanExpression::Not(box e) - } + FieldEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldEq(e), + BinaryOrExpression::Expression(u) => u, + }, + BoolEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => BoolEq(e), + BinaryOrExpression::Expression(u) => u, + }, + UintEq(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintEq(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLt(e), + BinaryOrExpression::Expression(u) => u, + }, + FieldLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => FieldLe(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLt(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLt(e), + BinaryOrExpression::Expression(u) => u, + }, + UintLe(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => UintLe(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&Type::Boolean, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&Type::Boolean, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, BooleanExpression::Conditional(c) => { match f.fold_conditional_expression(&Type::Boolean, c)? { ConditionalOrExpression::Conditional(s) => BooleanExpression::Conditional(s), @@ -414,86 +653,78 @@ pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( +fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let span = e.get_span(); + f.fold_uint_expression_cases(ty, e).map(|e| e.span(span)) +} + +pub fn fold_uint_expression_cases<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, F::Error> { + use UExpressionInner::*; + Ok(match e { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => match f.fold_identifier_expression(&ty, id)? { + Value(v) => UExpressionInner::Value(v), + Identifier(id) => match f.fold_identifier_expression(&ty, id)? { IdentifierOrExpression::Identifier(i) => UExpressionInner::Identifier(i), IdentifierOrExpression::Expression(e) => e, }, - UExpressionInner::Select(e) => match f.fold_select_expression(&ty, e)? { + Select(e) => match f.fold_select_expression(&ty, e)? { SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::Add(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Add(box left, box right) - } - UExpressionInner::Sub(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Sub(box left, box right) - } - UExpressionInner::Mult(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Mult(box left, box right) - } - UExpressionInner::Div(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Div(box left, box right) - } - UExpressionInner::Rem(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Rem(box left, box right) - } - UExpressionInner::Xor(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Xor(box left, box right) - } - UExpressionInner::And(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::And(box left, box right) - } - UExpressionInner::Or(box left, box right) => { - let left = f.fold_uint_expression(left)?; - let right = f.fold_uint_expression(right)?; - - UExpressionInner::Or(box left, box right) - } - UExpressionInner::LeftShift(box e, by) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::LeftShift(box e, by) - } - UExpressionInner::RightShift(box e, by) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::RightShift(box e, by) - } - UExpressionInner::Not(box e) => { - let e = f.fold_uint_expression(e)?; - - UExpressionInner::Not(box e) - } - UExpressionInner::Conditional(c) => match f.fold_conditional_expression(&ty, c)? { - ConditionalOrExpression::Conditional(s) => UExpressionInner::Conditional(s), + Add(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Add(e), + BinaryOrExpression::Expression(u) => u, + }, + Sub(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Sub(e), + BinaryOrExpression::Expression(u) => u, + }, + Mult(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Mult(e), + BinaryOrExpression::Expression(u) => u, + }, + Div(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Div(e), + BinaryOrExpression::Expression(u) => u, + }, + Rem(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Rem(e), + BinaryOrExpression::Expression(u) => u, + }, + Xor(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Xor(e), + BinaryOrExpression::Expression(u) => u, + }, + And(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => And(e), + BinaryOrExpression::Expression(u) => u, + }, + Or(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => Or(e), + BinaryOrExpression::Expression(u) => u, + }, + LeftShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => LeftShift(e), + BinaryOrExpression::Expression(u) => u, + }, + RightShift(e) => match f.fold_binary_expression(&ty, e)? { + BinaryOrExpression::Binary(e) => RightShift(e), + BinaryOrExpression::Expression(u) => u, + }, + Not(e) => match f.fold_unary_expression(&ty, e)? { + UnaryOrExpression::Unary(e) => Not(e), + UnaryOrExpression::Expression(u) => u, + }, + Conditional(c) => match f.fold_conditional_expression(&ty, c)? { + ConditionalOrExpression::Conditional(s) => Conditional(s), ConditionalOrExpression::Expression(u) => u, }, }) @@ -527,6 +758,7 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { Ok(ZirProgram { main: f.fold_function(p.main)?, + ..p }) } @@ -539,7 +771,7 @@ pub fn fold_identifier_expression< f: &mut F, _: &E::Ty, e: IdentifierExpression<'ast, E>, -) -> Result, F::Error> { +) -> Result, E, E::Inner>, F::Error> { Ok(IdentifierOrExpression::Identifier( IdentifierExpression::new(f.fold_name(e.id)?), )) @@ -548,7 +780,7 @@ pub fn fold_identifier_expression< pub fn fold_conditional_expression< 'ast, T: Field, - E: Expr<'ast, T> + ResultFold<'ast, T> + Conditional<'ast, T>, + E: Expr<'ast, T> + ResultFold + Conditional<'ast, T>, F: ResultFolder<'ast, T>, >( f: &mut F, @@ -567,7 +799,7 @@ pub fn fold_conditional_expression< pub fn fold_select_expression< 'ast, T: Field, - E: Expr<'ast, T> + ResultFold<'ast, T> + Select<'ast, T>, + E: Expr<'ast, T> + ResultFold + Select<'ast, T>, F: ResultFolder<'ast, T>, >( f: &mut F, @@ -582,3 +814,39 @@ pub fn fold_select_expression< e.index.fold(f)?, ))) } + +#[allow(clippy::type_complexity)] +pub fn fold_binary_expression< + 'ast, + T: Field, + L: Expr<'ast, T> + PartialEq + ResultFold + From>, + R: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: BinaryExpression, +) -> Result, F::Error> { + Ok(BinaryOrExpression::Binary( + BinaryExpression::new(e.left.fold(f)?, e.right.fold(f)?).span(e.span), + )) +} + +pub fn fold_unary_expression< + 'ast, + T: Field, + In: Expr<'ast, T> + PartialEq + ResultFold + From>, + E: Expr<'ast, T> + PartialEq + ResultFold + From>, + F: ResultFolder<'ast, T>, + Op, +>( + f: &mut F, + _: &E::Ty, + e: UnaryExpression, +) -> Result, F::Error> { + Ok(UnaryOrExpression::Unary( + UnaryExpression::new(e.inner.fold(f)?).span(e.span), + )) +} diff --git a/zokrates_ast/src/zir/uint.rs b/zokrates_ast/src/zir/uint.rs index 9ae30d40e..3ace2a152 100644 --- a/zokrates_ast/src/zir/uint.rs +++ b/zokrates_ast/src/zir/uint.rs @@ -1,5 +1,7 @@ -use crate::zir::types::UBitwidth; +use crate::common::expressions::{UValueExpression, UnaryExpression, ValueExpression}; use crate::zir::IdentifierExpression; +use crate::{common::expressions::BinaryExpression, common::operators::*, zir::types::UBitwidth}; +use derivative::Derivative; use serde::{Deserialize, Serialize}; use zokrates_field::Field; @@ -10,14 +12,14 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn add(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Add(box self, box other).annotate(bitwidth) + UExpressionInner::Add(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Sub(box self, box other).annotate(bitwidth) + UExpressionInner::Sub(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn select(values: Vec, index: Self) -> UExpression<'ast, T> { @@ -28,67 +30,67 @@ impl<'ast, T: Field> UExpression<'ast, T> { pub fn mult(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Mult(box self, box other).annotate(bitwidth) + UExpressionInner::Mult(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn div(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Div(box self, box other).annotate(bitwidth) + UExpressionInner::Div(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn rem(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Rem(box self, box other).annotate(bitwidth) + UExpressionInner::Rem(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn xor(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Xor(box self, box other).annotate(bitwidth) + UExpressionInner::Xor(BinaryExpression::new(self, other)).annotate(bitwidth) } #[allow(clippy::should_implement_trait)] pub fn not(self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::Not(box self).annotate(bitwidth) + UExpressionInner::Not(UnaryExpression::new(self)).annotate(bitwidth) } pub fn or(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Or(box self, box other).annotate(bitwidth) + UExpressionInner::Or(BinaryExpression::new(self, other)).annotate(bitwidth) } pub fn and(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::And(box self, box other).annotate(bitwidth) + UExpressionInner::And(BinaryExpression::new(self, other)).annotate(bitwidth) } - pub fn left_shift(self, by: u32) -> UExpression<'ast, T> { + pub fn left_shift(self, by: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::LeftShift(box self, by).annotate(bitwidth) + UExpressionInner::LeftShift(BinaryExpression::new(self, by)).annotate(bitwidth) } - pub fn right_shift(self, by: u32) -> UExpression<'ast, T> { + pub fn right_shift(self, by: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - UExpressionInner::RightShift(box self, by).annotate(bitwidth) + UExpressionInner::RightShift(BinaryExpression::new(self, by)).annotate(bitwidth) } } -impl<'ast, T: Field> From for UExpressionInner<'ast, T> { +impl<'ast, T> From for UExpressionInner<'ast, T> { fn from(e: u128) -> Self { - UExpressionInner::Value(e) + UExpressionInner::Value(ValueExpression::new(e)) } } impl<'ast, T> From for UExpression<'ast, T> { fn from(u: u32) -> Self { - UExpressionInner::Value(u as u128).annotate(UBitwidth::B32) + UExpressionInner::from(u as u128).annotate(UBitwidth::B32) } } @@ -151,7 +153,7 @@ impl UMetadata { } pub fn bitwidth(&self) -> u32 { - self.max.bits() as u32 + self.max.bits() } // issue the metadata for a parameter of a given bitwidth @@ -163,7 +165,9 @@ impl UMetadata { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct UExpression<'ast, T> { pub bitwidth: UBitwidth, pub metadata: Option>, @@ -171,23 +175,29 @@ pub struct UExpression<'ast, T> { pub inner: UExpressionInner<'ast, T>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Derivative)] +#[derivative(PartialEq, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum UExpressionInner<'ast, T> { - Value(u128), + Value(UValueExpression), #[serde(borrow)] Identifier(IdentifierExpression<'ast, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), - Add(Box>, Box>), - Sub(Box>, Box>), - Mult(Box>, Box>), - Div(Box>, Box>), - Rem(Box>, Box>), - Xor(Box>, Box>), - And(Box>, Box>), - Or(Box>, Box>), - LeftShift(Box>, u32), - RightShift(Box>, u32), - Not(Box>), + Add(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Sub(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Mult(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Div(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Rem(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Xor(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + And(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + Or(BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>), + LeftShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + RightShift( + BinaryExpression, UExpression<'ast, T>, UExpression<'ast, T>>, + ), + Not(UnaryExpression, UExpression<'ast, T>>), Conditional(ConditionalExpression<'ast, T, UExpression<'ast, T>>), } diff --git a/zokrates_ast/src/zir/variable.rs b/zokrates_ast/src/zir/variable.rs index 14d329727..62ee8bb12 100644 --- a/zokrates_ast/src/zir/variable.rs +++ b/zokrates_ast/src/zir/variable.rs @@ -1,14 +1,7 @@ use crate::zir::types::{Type, UBitwidth}; use crate::zir::Identifier; -use serde::{Deserialize, Serialize}; -use std::fmt; -#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize)] -pub struct Variable<'ast> { - #[serde(borrow)] - pub id: Identifier<'ast>, - pub _type: Type, -} +pub type Variable<'ast> = crate::common::Variable, Type>; impl<'ast> Variable<'ast> { pub fn field_element>>(id: I) -> Variable<'ast> { @@ -23,26 +16,11 @@ impl<'ast> Variable<'ast> { Self::with_id_and_type(id, Type::uint(bitwidth)) } - pub fn with_id_and_type>>(id: I, _type: Type) -> Variable<'ast> { - Variable { - id: id.into(), - _type, - } + pub fn with_id_and_type>>(id: I, ty: Type) -> Variable<'ast> { + Variable::new(id.into(), ty) } pub fn get_type(&self) -> Type { - self._type.clone() - } -} - -impl<'ast> fmt::Display for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {}", self._type, self.id,) - } -} - -impl<'ast> fmt::Debug for Variable<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,) + self.ty.clone() } } diff --git a/zokrates_bellman/Cargo.toml b/zokrates_bellman/Cargo.toml index 9d8fff446..0c832a678 100644 --- a/zokrates_bellman/Cargo.toml +++ b/zokrates_bellman/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_bellman" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] @@ -9,7 +9,7 @@ wasm = ["bellman/nolog", "bellman/wasm"] multicore = ["bellman/multicore", "phase2/multicore"] [dependencies] -zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } +zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false, features = ["bellman"] } zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } zokrates_proof_systems = { version = "0.1", path = "../zokrates_proof_systems", default-features = false } diff --git a/zokrates_bellman/src/groth16.rs b/zokrates_bellman/src/groth16.rs index d0e6c4178..2efabd991 100644 --- a/zokrates_bellman/src/groth16.rs +++ b/zokrates_bellman/src/groth16.rs @@ -20,14 +20,19 @@ use zokrates_proof_systems::groth16::{ProofPoints, VerificationKey, G16}; use zokrates_proof_systems::Scheme; impl Backend for Bellman { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ProgIterator<'a, T, I>, witness: Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof { let computation = Computation::with_witness(program, witness); - let params = Parameters::read(proving_key.as_slice(), true).unwrap(); + let params = Parameters::read(proving_key, true).unwrap(); let public_inputs: Vec = computation .public_inputs_values() @@ -108,7 +113,7 @@ impl MpcBackend for Bellman { } fn contribute( - params: &mut R, + params: R, rng: &mut G, output: &mut W, ) -> Result<[u8; 64], String> { @@ -124,8 +129,8 @@ impl MpcBackend for Bellman { Ok(hash) } - fn verify<'a, P: Read, R: Read, I: IntoIterator>>( - params: &mut P, + fn verify<'a, R: Read, I: IntoIterator>>( + params: R, program: ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String> { @@ -140,7 +145,7 @@ impl MpcBackend for Bellman { Ok(hashes) } - fn export_keypair(params: &mut R) -> Result, String> { + fn export_keypair(params: R) -> Result, String> { let params = MPCParameters::::read(params, true).map_err(|e| e.to_string())?; @@ -205,15 +210,21 @@ mod tests { use zokrates_interpreter::Interpreter; use super::*; - use zokrates_ast::common::{Parameter, Variable}; - use zokrates_ast::ir::{Prog, Statement}; + use zokrates_ast::common::flat::Parameter; + use zokrates_ast::ir::{Prog, Statement, Variable}; #[test] fn verify() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let rng = &mut StdRng::from_entropy(); @@ -226,7 +237,10 @@ mod tests { .unwrap(); let proof = >::generate_proof( - program, witness, keypair.pk, rng, + program, + witness, + keypair.pk.as_slice(), + rng, ); let ans = >::verify(keypair.vk, proof); diff --git a/zokrates_bellman/src/lib.rs b/zokrates_bellman/src/lib.rs index 9828b9745..0f8a7d511 100644 --- a/zokrates_bellman/src/lib.rs +++ b/zokrates_bellman/src/lib.rs @@ -10,8 +10,8 @@ use bellman::{ Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable as BellmanVariable, }; use std::collections::BTreeMap; -use zokrates_ast::common::Variable; -use zokrates_ast::ir::{CanonicalLinComb, ProgIterator, Statement, Witness}; +use zokrates_ast::common::flat::Variable; +use zokrates_ast::ir::{LinComb, ProgIterator, Statement, Witness}; use zokrates_field::BellmanFieldExtensions; use zokrates_field::Field; @@ -45,12 +45,13 @@ impl<'a, T: Field, I: IntoIterator>> Computation<'a, T, } fn bellman_combination>( - l: CanonicalLinComb, + l: LinComb, cs: &mut CS, symbols: &mut BTreeMap, witness: &mut Witness, ) -> LinearCombination { - l.0.into_iter() + l.value + .into_iter() .map(|(k, v)| { ( v.into_bellman(), @@ -126,20 +127,10 @@ impl<'a, T: BellmanFieldExtensions + Field, I: IntoIterator = Prog { + module_map: Default::default(), arguments: vec![Parameter::private(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -294,9 +291,15 @@ mod tests { #[test] fn public_identity() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::public(0), + None, + )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -315,9 +318,15 @@ mod tests { #[test] fn no_arguments() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![], return_count: 1, - statements: vec![Statement::constraint(Variable::one(), Variable::public(0))], + statements: vec![Statement::constraint( + Variable::one(), + Variable::public(0), + None, + )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -335,6 +344,7 @@ mod tests { // public variables must be ordered from 0 // private variables can be unordered let program: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(42)), Parameter::public(Variable::new(51)), @@ -344,12 +354,15 @@ mod tests { Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), + None, ), Statement::constraint( LinComb::from(Variable::one()) + LinComb::from(Variable::new(42)), Variable::public(1), + None, ), ], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -367,12 +380,15 @@ mod tests { #[test] fn one() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(42))], return_count: 1, statements: vec![Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::one(), Variable::public(0), + None, )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -391,6 +407,7 @@ mod tests { #[test] fn with_directives() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(42)), Parameter::public(Variable::new(51)), @@ -399,7 +416,9 @@ mod tests { statements: vec![Statement::constraint( LinComb::from(Variable::new(42)) + LinComb::from(Variable::new(51)), Variable::public(0), + None, )], + solvers: vec![], }; let interpreter = Interpreter::default(); diff --git a/zokrates_book/src/examples/sha256example.md b/zokrates_book/src/examples/sha256example.md index e970d7dc2..a933b8cb8 100644 --- a/zokrates_book/src/examples/sha256example.md +++ b/zokrates_book/src/examples/sha256example.md @@ -43,20 +43,14 @@ As a next step we can create a witness file using the following command: Using the flag `-a` we pass arguments to the program. Recall that our goal is to compute the hash for the number `5`. Consequently we set `a`, `b` and `c` to `0` and `d` to `5`. -Still here? Great! At this point, we can check the `witness` file for the return values: +Still here? Great! At this point we can check the return values. We should see the following output: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:13}} +Witness: +["263561599766550617289250058199814760685","65303172752238645975888084098459749904"] ``` -which should lead to the following output: - -```sh -~out_0 263561599766550617289250058199814760685 -~out_1 65303172752238645975888084098459749904 -``` - -Hence, by concatenating the outputs as 128 bit numbers, we arrive at the following value as the hash for our selected pre-image : +By concatenating the outputs as 128 bit numbers, we arrive at the following value as the hash for our selected pre-image : `0xc6481e22c5ff4164af680b8cfaa5e8ed3120eeff89c4f307c4a6faaae059ce10` ## Prove knowledge of pre-image @@ -78,13 +72,13 @@ Note that we now compare the result of `sha256packed` with the hard-coded correc So, having defined the program, Victor is now ready to compile the code: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:17}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:15}} ``` Based on that Victor can run the setup phase and export a verifier smart contract as a Solidity file: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:18:19}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:16:17}} ``` `setup` creates a `verification.key` file and a `proving.key` file. Victor gives the proving key to Peggy. @@ -94,13 +88,13 @@ Based on that Victor can run the setup phase and export a verifier smart contrac Peggy provides the correct pre-image as an argument to the program. ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:20}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:18}} ``` Finally, Peggy can run the command to construct the proof: ``` -{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:21}} +{{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:19}} ``` As the inputs were declared as private in the program, they do not appear in the proof thanks to the zero-knowledge property of the protocol. diff --git a/zokrates_book/src/gettingstarted.md b/zokrates_book/src/gettingstarted.md index 93caf96c2..d0aa512f3 100644 --- a/zokrates_book/src/gettingstarted.md +++ b/zokrates_book/src/gettingstarted.md @@ -25,7 +25,7 @@ You can build ZoKrates from [source](https://github.com/ZoKrates/ZoKrates/) with git clone https://github.com/ZoKrates/ZoKrates cd ZoKrates export ZOKRATES_STDLIB=$PWD/zokrates_stdlib/stdlib -cargo +nightly build -p zokrates_cli --release +cargo build -p zokrates_cli --release cd target/release ``` diff --git a/zokrates_circom/Cargo.toml b/zokrates_circom/Cargo.toml index 6a0118e31..04de4d523 100644 --- a/zokrates_circom/Cargo.toml +++ b/zokrates_circom/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_circom" -version = "0.1.2" +version = "0.1.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/zokrates_circom/src/lib.rs b/zokrates_circom/src/lib.rs index 9b16742fb..9960fc19c 100644 --- a/zokrates_circom/src/lib.rs +++ b/zokrates_circom/src/lib.rs @@ -24,26 +24,28 @@ mod tests { #[test] fn setup_and_prove() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(0)), Parameter::public(Variable::new(1)), ], return_count: 1, statements: vec![ - Statement::Constraint( - QuadComb::from_linear_combinations( + Statement::constraint( + QuadComb::new( LinComb::from(Variable::new(0)), LinComb::from(Variable::new(0)), ), LinComb::from(Variable::new(0)), None, ), - Statement::Constraint( - (LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1))).into(), - Variable::public(0).into(), + Statement::constraint( + LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1)), + Variable::public(0), None, ), ], + solvers: vec![], }; let mut r1cs = vec![]; diff --git a/zokrates_circom/src/r1cs.rs b/zokrates_circom/src/r1cs.rs index 854bc0eab..00c0d8dea 100644 --- a/zokrates_circom/src/r1cs.rs +++ b/zokrates_circom/src/r1cs.rs @@ -69,19 +69,19 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), + for s in prog.statements.iter().filter_map(|s| match s { + Statement::Constraint(s) => Some(s), Statement::Directive(..) => None, Statement::Block(..) => unreachable!(), Statement::Log(..) => None, }) { - for (k, _) in &quad.left.0 { + for (k, _) in &s.quad.left.value { ordered_variables_set.insert(k); } - for (k, _) in &quad.right.0 { + for (k, _) in &s.quad.right.value { ordered_variables_set.insert(k); } - for (k, _) in &lin.0 { + for (k, _) in &s.lin.value { ordered_variables_set.insert(k); } } @@ -95,23 +95,23 @@ pub fn r1cs_program(prog: Prog) -> (Vec, usize, Vec Some((quad, lin)), Statement::Block(..) => unreachable!(), Statement::Directive(..) => None, Statement::Log(..) => None, + Statement::Constraint(s) => Some((s.quad, s.lin)), }) { constraints.push(( quad.left - .0 + .value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), quad.right - .0 + .value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), - lin.0 + lin.value .into_iter() .map(|(k, v)| (*variables.get(&k).unwrap(), v)) .collect(), @@ -289,13 +289,15 @@ mod tests { #[test] fn return_one() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![], return_count: 1, - statements: vec![Statement::Constraint( - LinComb::one().into(), - Variable::public(0).into(), + statements: vec![Statement::constraint( + LinComb::one(), + Variable::public(0), None, )], + solvers: vec![], }; let mut buf = Vec::new(); @@ -345,26 +347,28 @@ mod tests { #[test] fn with_inputs() { let prog: Prog = Prog { + module_map: Default::default(), arguments: vec![ Parameter::private(Variable::new(0)), Parameter::public(Variable::new(1)), ], return_count: 1, statements: vec![ - Statement::Constraint( - QuadComb::from_linear_combinations( + Statement::constraint( + QuadComb::new( LinComb::from(Variable::new(0)), LinComb::from(Variable::new(0)), ), LinComb::from(Variable::new(0)), None, ), - Statement::Constraint( - (LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1))).into(), - Variable::public(0).into(), + Statement::constraint( + LinComb::from(Variable::new(0)) + LinComb::from(Variable::new(1)), + Variable::public(0), None, ), ], + solvers: vec![], }; let mut buf = Vec::new(); diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index ccf98794d..8eda11008 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.8.5" +version = "0.8.6" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" @@ -21,6 +21,7 @@ zokrates_field = { version = "0.5", path = "../zokrates_field", features = ["mul zokrates_abi = { version = "0.1", path = "../zokrates_abi" } zokrates_core = { version = "0.7", path = "../zokrates_core", default-features = false } zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } +zokrates_profiler = { version = "0.1", path = "../zokrates_profiler", default-features = false } zokrates_interpreter = { version = "0.1", path = "../zokrates_interpreter", default-features = false } zokrates_circom = { version = "0.1", path = "../zokrates_circom", default-features = false } zokrates_embed = { version = "0.1", path = "../zokrates_embed", features = ["multicore"] } diff --git a/zokrates_cli/examples/book/sha256_tutorial/test.sh b/zokrates_cli/examples/book/sha256_tutorial/test.sh index 664749ccf..07756593e 100755 --- a/zokrates_cli/examples/book/sha256_tutorial/test.sh +++ b/zokrates_cli/examples/book/sha256_tutorial/test.sh @@ -8,9 +8,7 @@ function zokrates() { } zokrates compile -i hashexample.zok -zokrates compute-witness -a 0 0 0 5 - -grep '~out' witness +zokrates compute-witness -a 0 0 0 5 --verbose cp -f hashexample_updated.zok hashexample.zok diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index d6b596f35..4c01df6ec 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -12,11 +12,11 @@ def countDuplicates(field e11, field e12, field e21, field e22) -> field { field mut duplicates = e11 == e12 ? 1 : 0; - duplicates = duplicates + e11 == e21 ? 1 : 0; - duplicates = duplicates + e11 == e22 ? 1 : 0; - duplicates = duplicates + e12 == e21 ? 1 : 0; - duplicates = duplicates + e12 == e22 ? 1 : 0; - duplicates = duplicates + e21 == e22 ? 1 : 0; + duplicates = duplicates + (e11 == e21 ? 1 : 0); + duplicates = duplicates + (e11 == e22 ? 1 : 0); + duplicates = duplicates + (e12 == e21 ? 1 : 0); + duplicates = duplicates + (e12 == e22 ? 1 : 0); + duplicates = duplicates + (e21 == e22 ? 1 : 0); return duplicates; } diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index a5f8a31f8..68e7d5ec4 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -1,5 +1,3 @@ -#![feature(panic_info_message)] -#![feature(backtrace)] // // @file bin.rs // @author Jacob Eberhardt @@ -57,7 +55,9 @@ fn cli() -> Result<(), String> { generate_smtlib2::subcommand(), print_proof::subcommand(), #[cfg(any(feature = "bellman", feature = "ark"))] - verify::subcommand()]) + verify::subcommand(), + profile::subcommand() + ]) .get_matches(); match matches.subcommand() { @@ -78,26 +78,14 @@ fn cli() -> Result<(), String> { ("print-proof", Some(sub_matches)) => print_proof::exec(sub_matches), #[cfg(any(feature = "bellman", feature = "ark"))] ("verify", Some(sub_matches)) => verify::exec(sub_matches), + ("profile", Some(sub_matches)) => profile::exec(sub_matches), _ => unreachable!(), } } fn panic_hook(pi: &std::panic::PanicInfo) { - let location = pi - .location() - .map(|l| format!("({})", l)) - .unwrap_or_default(); - - let message = pi - .message() - .map(|m| format!("{}", m)) - .or_else(|| pi.payload().downcast_ref::<&str>().map(|p| p.to_string())); - - if let Some(s) = message { - println!("{} {}", s, location); - } else { - println!("The compiler unexpectedly panicked {}", location); - } + println!("The compiler unexpectedly panicked"); + println!("{}", pi); #[cfg(debug_assertions)] { diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index 74d26b6f2..c7bc25a15 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -130,8 +130,8 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { let arena = Arena::new(); - let artifacts = - compile::(source, path, Some(&resolver), config, &arena).map_err(|e| { + let artifacts = compile::(source, path.clone(), Some(&resolver), config, &arena) + .map_err(|e| { format!( "Compilation failed:\n\n{}", e.0.iter() @@ -145,16 +145,24 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { // serialize flattened program and write to binary file log::debug!("Serialize program"); - let bin_output_file = File::create(&bin_output_path) + let bin_output_file = File::create(bin_output_path) .map_err(|why| format!("Could not create {}: {}", bin_output_path.display(), why))?; - let r1cs_output_file = File::create(&r1cs_output_path) + let r1cs_output_file = File::create(r1cs_output_path) .map_err(|why| format!("Could not create {}: {}", r1cs_output_path.display(), why))?; let mut bin_writer = BufWriter::new(bin_output_file); let mut r1cs_writer = BufWriter::new(r1cs_output_file); - let program_flattened = program_flattened.collect(); + let mut program_flattened = program_flattened.collect(); + + // hide user path + program_flattened.module_map = program_flattened + .module_map + .remap_prefix(path.parent().unwrap(), Path::new("")); + program_flattened.module_map = program_flattened + .module_map + .remap_prefix(Path::new(stdlib_path), Path::new("STDLIB")); write_r1cs(&mut r1cs_writer, program_flattened.clone()).unwrap(); @@ -162,7 +170,7 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { Ok(constraint_count) => { // serialize ABI spec and write to JSON file log::debug!("Serialize ABI"); - let abi_spec_file = File::create(&abi_spec_path) + let abi_spec_file = File::create(abi_spec_path) .map_err(|why| format!("Could not create {}: {}", abi_spec_path.display(), why))?; let mut writer = BufWriter::new(abi_spec_file); diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index ab2a959bd..466c88dfc 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -66,6 +66,10 @@ pub fn subcommand() -> App<'static, 'static> { .help("Read arguments from stdin") .conflicts_with("arguments") .required(false) + ).arg(Arg::with_name("json") + .long("json") + .help("Write witness in a json format for debugging purposes") + .required(false) ) } @@ -73,7 +77,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -102,7 +106,7 @@ fn cli_compute<'a, T: Field, I: Iterator>>( let signature = match is_abi { true => { let path = Path::new(sub_matches.value_of("abi-spec").unwrap()); - let file = File::open(&path) + let file = File::open(path) .map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -186,7 +190,7 @@ fn cli_compute<'a, T: Field, I: Iterator>>( // write witness to file let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create {}: {}", output_path.display(), why))?; let writer = BufWriter::new(output_file); @@ -195,9 +199,22 @@ fn cli_compute<'a, T: Field, I: Iterator>>( .write(writer) .map_err(|why| format!("Could not save witness: {:?}", why))?; + // write witness in the json format + if sub_matches.is_present("json") { + let json_path = Path::new(sub_matches.value_of("output").unwrap()).with_extension("json"); + let json_file = File::create(&json_path) + .map_err(|why| format!("Could not create {}: {}", json_path.display(), why))?; + + let writer = BufWriter::new(json_file); + + witness + .write_json(writer) + .map_err(|why| format!("Could not save {}: {:?}", json_path.display(), why))?; + } + // write circom witness to file let wtns_path = Path::new(sub_matches.value_of("circom-witness").unwrap()); - let wtns_file = File::create(&wtns_path) + let wtns_file = File::create(wtns_path) .map_err(|why| format!("Could not create {}: {}", output_path.display(), why))?; let mut writer = BufWriter::new(wtns_file); diff --git a/zokrates_cli/src/ops/export_verifier.rs b/zokrates_cli/src/ops/export_verifier.rs index 08b35a44a..f343a7374 100644 --- a/zokrates_cli/src/ops/export_verifier.rs +++ b/zokrates_cli/src/ops/export_verifier.rs @@ -35,7 +35,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let vk_path = Path::new(sub_matches.value_of("input").unwrap()); - let vk_file = File::open(&vk_path) + let vk_file = File::open(vk_path) .map_err(|why| format!("Could not open {}: {}", vk_path.display(), why))?; // deserialize vk to JSON @@ -84,7 +84,7 @@ fn cli_export_verifier App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let program_path = Path::new(sub_matches.value_of("input").unwrap()); - let program_file = File::open(&program_path) + let program_file = File::open(program_path) .map_err(|why| format!("Could not open {}: {}", program_path.display(), why))?; let mut reader = BufReader::new(program_file); @@ -160,30 +160,28 @@ fn cli_generate_proof< // deserialize witness let witness_path = Path::new(sub_matches.value_of("witness").unwrap()); - let witness_file = File::open(&witness_path) + let witness_file = File::open(witness_path) .map_err(|why| format!("Could not open {}: {}", witness_path.display(), why))?; - let witness = ir::Witness::read(witness_file) + let witness_reader = BufReader::new(witness_file); + + let witness = ir::Witness::read(witness_reader) .map_err(|why| format!("Could not load witness: {:?}", why))?; let pk_path = Path::new(sub_matches.value_of("proving-key-path").unwrap()); let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let pk_file = File::open(&pk_path) + let pk_file = File::open(pk_path) .map_err(|why| format!("Could not open {}: {}", pk_path.display(), why))?; - let mut pk: Vec = Vec::new(); - let mut pk_reader = BufReader::new(pk_file); - pk_reader - .read_to_end(&mut pk) - .map_err(|why| format!("Could not read {}: {}", pk_path.display(), why))?; + let pk_reader = BufReader::new(pk_file); let mut rng = sub_matches .value_of("entropy") .map(get_rng_from_entropy) .unwrap_or_else(StdRng::from_entropy); - let proof = B::generate_proof(program, witness, pk, &mut rng); + let proof = B::generate_proof(program, witness, pk_reader, &mut rng); let mut proof_file = File::create(proof_path).unwrap(); let proof = diff --git a/zokrates_cli/src/ops/generate_smtlib2.rs b/zokrates_cli/src/ops/generate_smtlib2.rs index ac58f8e56..389372df9 100644 --- a/zokrates_cli/src/ops/generate_smtlib2.rs +++ b/zokrates_cli/src/ops/generate_smtlib2.rs @@ -35,7 +35,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/inspect.rs b/zokrates_cli/src/ops/inspect.rs index b8ca37545..1db259b6a 100644 --- a/zokrates_cli/src/ops/inspect.rs +++ b/zokrates_cli/src/ops/inspect.rs @@ -31,7 +31,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/mod.rs b/zokrates_cli/src/ops/mod.rs index e82dd506a..cb21da334 100644 --- a/zokrates_cli/src/ops/mod.rs +++ b/zokrates_cli/src/ops/mod.rs @@ -9,6 +9,7 @@ pub mod inspect; #[cfg(feature = "bellman")] pub mod mpc; pub mod print_proof; +pub mod profile; #[cfg(any(feature = "bellman", feature = "ark"))] pub mod setup; #[cfg(feature = "ark")] diff --git a/zokrates_cli/src/ops/mpc/beacon.rs b/zokrates_cli/src/ops/mpc/beacon.rs index 619a95ed9..71412868b 100644 --- a/zokrates_cli/src/ops/mpc/beacon.rs +++ b/zokrates_cli/src/ops/mpc/beacon.rs @@ -73,7 +73,7 @@ fn cli_mpc_beacon, B: MpcBack ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -134,7 +134,7 @@ fn cli_mpc_beacon, B: MpcBack }; let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/contribute.rs b/zokrates_cli/src/ops/mpc/contribute.rs index fcb7e5461..96128768c 100644 --- a/zokrates_cli/src/ops/mpc/contribute.rs +++ b/zokrates_cli/src/ops/mpc/contribute.rs @@ -70,12 +70,12 @@ pub fn cli_mpc_contribute< ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/export.rs b/zokrates_cli/src/ops/mpc/export.rs index cf640fa1c..87c780052 100644 --- a/zokrates_cli/src/ops/mpc/export.rs +++ b/zokrates_cli/src/ops/mpc/export.rs @@ -66,7 +66,7 @@ pub fn cli_mpc_export, B: Mpc ) -> Result<(), String> { let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index 92a972482..72e2283c6 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -46,7 +46,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -76,7 +76,7 @@ fn cli_mpc_init< let mut radix_reader = BufReader::new(radix_file); let output_path = Path::new(sub_matches.value_of("output").unwrap()); - let output_file = File::create(&output_path) + let output_file = File::create(output_path) .map_err(|why| format!("Could not create `{}`: {}", output_path.display(), why))?; let mut writer = BufWriter::new(output_file); diff --git a/zokrates_cli/src/ops/mpc/verify.rs b/zokrates_cli/src/ops/mpc/verify.rs index fa014bd06..ef0371127 100644 --- a/zokrates_cli/src/ops/mpc/verify.rs +++ b/zokrates_cli/src/ops/mpc/verify.rs @@ -46,7 +46,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("circuit").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; let mut reader = BufReader::new(file); @@ -71,9 +71,9 @@ fn cli_mpc_verify< let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; - let mut reader = BufReader::new(file); + let reader = BufReader::new(file); let radix_path = Path::new(sub_matches.value_of("radix-path").unwrap()); let radix_file = File::open(radix_path) @@ -81,7 +81,7 @@ fn cli_mpc_verify< let mut radix_reader = BufReader::new(radix_file); - let result = B::verify(&mut reader, program, &mut radix_reader) + let result = B::verify(reader, program, &mut radix_reader) .map_err(|e| format!("Verification failed: {}", e))?; let contribution_count = result.len(); diff --git a/zokrates_cli/src/ops/print_proof.rs b/zokrates_cli/src/ops/print_proof.rs index a6df505af..cf0d25a7c 100644 --- a/zokrates_cli/src/ops/print_proof.rs +++ b/zokrates_cli/src/ops/print_proof.rs @@ -38,7 +38,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let proof_file = File::open(&proof_path) + let proof_file = File::open(proof_path) .map_err(|why| format!("Could not open {}: {}", proof_path.display(), why))?; // deserialize proof to JSON diff --git a/zokrates_cli/src/ops/profile.rs b/zokrates_cli/src/ops/profile.rs new file mode 100644 index 000000000..a866beee8 --- /dev/null +++ b/zokrates_cli/src/ops/profile.rs @@ -0,0 +1,52 @@ +use crate::cli_constants::FLATTENED_CODE_DEFAULT_PATH; +use clap::{App, Arg, ArgMatches, SubCommand}; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use zokrates_ast::ir::{self, ProgEnum}; +use zokrates_field::Field; +use zokrates_profiler::profile; + +pub fn subcommand() -> App<'static, 'static> { + SubCommand::with_name("profile") + .about("Profiles a compiled program, indicating which parts of the source yield the most constraints") + .arg( + Arg::with_name("input") + .short("i") + .long("input") + .help("Path of the binary") + .value_name("FILE") + .takes_value(true) + .required(false) + .default_value(FLATTENED_CODE_DEFAULT_PATH), + ) +} + +pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { + // read compiled program + let path = Path::new(sub_matches.value_of("input").unwrap()); + let file = + File::open(path).map_err(|why| format!("Could not open `{}`: {}", path.display(), why))?; + + let mut reader = BufReader::new(file); + + match ProgEnum::deserialize(&mut reader)? { + ProgEnum::Bn128Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bls12_377Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bls12_381Program(p) => cli_profile(p, sub_matches), + ProgEnum::Bw6_761Program(p) => cli_profile(p, sub_matches), + } +} + +fn cli_profile<'ast, T: Field, I: Iterator>>( + ir_prog: ir::ProgIterator<'ast, T, I>, + _: &ArgMatches, +) -> Result<(), String> { + let module_map = ir_prog.module_map.clone(); + + let heat_map = profile(ir_prog); + + println!("{}", heat_map.display(&module_map)); + + Ok(()) +} diff --git a/zokrates_cli/src/ops/setup.rs b/zokrates_cli/src/ops/setup.rs index 27ddc94c1..25bca25c8 100644 --- a/zokrates_cli/src/ops/setup.rs +++ b/zokrates_cli/src/ops/setup.rs @@ -95,7 +95,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { // read compiled program let path = Path::new(sub_matches.value_of("input").unwrap()); let file = - File::open(&path).map_err(|why| format!("Couldn't open {}: {}", path.display(), why))?; + File::open(path).map_err(|why| format!("Couldn't open {}: {}", path.display(), why))?; let mut reader = BufReader::new(file); let prog = ProgEnum::deserialize(&mut reader)?; @@ -146,7 +146,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { #[cfg(feature = "ark")] Parameters(BackendParameter::Ark, _, SchemeParameter::MARLIN) => { let setup_path = Path::new(sub_matches.value_of("universal-setup-path").unwrap()); - let setup_file = File::open(&setup_path) + let setup_file = File::open(setup_path) .map_err(|why| format!("Couldn't open {}: {}\nExpected an universal setup, make sure `zokrates universal-setup` was run`", setup_path.display(), why))?; let mut reader = BufReader::new(setup_file); diff --git a/zokrates_cli/src/ops/verify.rs b/zokrates_cli/src/ops/verify.rs index 24a96c387..568b13383 100644 --- a/zokrates_cli/src/ops/verify.rs +++ b/zokrates_cli/src/ops/verify.rs @@ -51,7 +51,7 @@ pub fn subcommand() -> App<'static, 'static> { pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { let vk_path = Path::new(sub_matches.value_of("verification-key-path").unwrap()); - let vk_file = File::open(&vk_path) + let vk_file = File::open(vk_path) .map_err(|why| format!("Could not open {}: {}", vk_path.display(), why))?; // deserialize vk to JSON @@ -60,7 +60,7 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { .map_err(|why| format!("Could not deserialize verification key: {}", why))?; let proof_path = Path::new(sub_matches.value_of("proof-path").unwrap()); - let proof_file = File::open(&proof_path) + let proof_file = File::open(proof_path) .map_err(|why| format!("Could not open {}: {}", proof_path.display(), why))?; // deserialize proof to JSON diff --git a/zokrates_cli/tests/code/arithmetics.expected.witness b/zokrates_cli/tests/code/arithmetics.expected.witness deleted file mode 100644 index caaf539db..000000000 --- a/zokrates_cli/tests/code/arithmetics.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 12 diff --git a/zokrates_cli/tests/code/arithmetics.expected.witness.json b/zokrates_cli/tests/code/arithmetics.expected.witness.json new file mode 100644 index 000000000..4e1ae3537 --- /dev/null +++ b/zokrates_cli/tests/code/arithmetics.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "12" +} diff --git a/zokrates_cli/tests/code/conditional_false.expected.witness b/zokrates_cli/tests/code/conditional_false.expected.witness deleted file mode 100644 index 1b8f13fa2..000000000 --- a/zokrates_cli/tests/code/conditional_false.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 0 \ No newline at end of file diff --git a/zokrates_cli/tests/code/conditional_false.expected.witness.json b/zokrates_cli/tests/code/conditional_false.expected.witness.json new file mode 100644 index 000000000..50655725f --- /dev/null +++ b/zokrates_cli/tests/code/conditional_false.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "0" +} diff --git a/zokrates_cli/tests/code/conditional_true.expected.witness b/zokrates_cli/tests/code/conditional_true.expected.witness deleted file mode 100644 index 1e61044c7..000000000 --- a/zokrates_cli/tests/code/conditional_true.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 1 \ No newline at end of file diff --git a/zokrates_cli/tests/code/conditional_true.expected.witness.json b/zokrates_cli/tests/code/conditional_true.expected.witness.json new file mode 100644 index 000000000..cd003d107 --- /dev/null +++ b/zokrates_cli/tests/code/conditional_true.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "1" +} diff --git a/zokrates_cli/tests/code/multidim_update.expected.witness b/zokrates_cli/tests/code/multidim_update.expected.witness deleted file mode 100644 index 0b054f2fe..000000000 --- a/zokrates_cli/tests/code/multidim_update.expected.witness +++ /dev/null @@ -1,4 +0,0 @@ -~out_0 0 -~out_1 0 -~out_2 0 -~out_3 42 \ No newline at end of file diff --git a/zokrates_cli/tests/code/multidim_update.expected.witness.json b/zokrates_cli/tests/code/multidim_update.expected.witness.json new file mode 100644 index 000000000..98c0fbd76 --- /dev/null +++ b/zokrates_cli/tests/code/multidim_update.expected.witness.json @@ -0,0 +1,6 @@ +{ + "~out_0": "0", + "~out_1": "0", + "~out_2": "0", + "~out_3": "42" +} diff --git a/zokrates_cli/tests/code/n_choose_k.expected.witness b/zokrates_cli/tests/code/n_choose_k.expected.witness deleted file mode 100644 index b51f241b4..000000000 --- a/zokrates_cli/tests/code/n_choose_k.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 5 \ No newline at end of file diff --git a/zokrates_cli/tests/code/n_choose_k.expected.witness.json b/zokrates_cli/tests/code/n_choose_k.expected.witness.json new file mode 100644 index 000000000..839b53c97 --- /dev/null +++ b/zokrates_cli/tests/code/n_choose_k.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "5" +} diff --git a/zokrates_cli/tests/code/no_return.expected.witness b/zokrates_cli/tests/code/no_return.expected.witness deleted file mode 100644 index e69de29bb..000000000 diff --git a/zokrates_cli/tests/code/no_return.expected.witness.json b/zokrates_cli/tests/code/no_return.expected.witness.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/zokrates_cli/tests/code/no_return.expected.witness.json @@ -0,0 +1 @@ +{} diff --git a/zokrates_cli/tests/code/return_array.expected.witness b/zokrates_cli/tests/code/return_array.expected.witness deleted file mode 100644 index f9f9cf798..000000000 --- a/zokrates_cli/tests/code/return_array.expected.witness +++ /dev/null @@ -1,8 +0,0 @@ -~out_0 2 -~out_1 1 -~out_2 1 -~out_3 1 -~out_4 3 -~out_5 3 -~out_6 3 -~out_7 3 \ No newline at end of file diff --git a/zokrates_cli/tests/code/return_array.expected.witness.json b/zokrates_cli/tests/code/return_array.expected.witness.json new file mode 100644 index 000000000..d1b0a1d25 --- /dev/null +++ b/zokrates_cli/tests/code/return_array.expected.witness.json @@ -0,0 +1,10 @@ +{ + "~out_0": "2", + "~out_1": "1", + "~out_2": "1", + "~out_3": "1", + "~out_4": "3", + "~out_5": "3", + "~out_6": "3", + "~out_7": "3" +} diff --git a/zokrates_cli/tests/code/simple_add.expected.witness b/zokrates_cli/tests/code/simple_add.expected.witness deleted file mode 100644 index 23b7f950e..000000000 --- a/zokrates_cli/tests/code/simple_add.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 3 diff --git a/zokrates_cli/tests/code/simple_add.expected.witness.json b/zokrates_cli/tests/code/simple_add.expected.witness.json new file mode 100644 index 000000000..ca9fd972e --- /dev/null +++ b/zokrates_cli/tests/code/simple_add.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "3" +} diff --git a/zokrates_cli/tests/code/simple_mul.expected.witness b/zokrates_cli/tests/code/simple_mul.expected.witness deleted file mode 100644 index 8eb3a8d7a..000000000 --- a/zokrates_cli/tests/code/simple_mul.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 24 \ No newline at end of file diff --git a/zokrates_cli/tests/code/simple_mul.expected.witness.json b/zokrates_cli/tests/code/simple_mul.expected.witness.json new file mode 100644 index 000000000..7143416e9 --- /dev/null +++ b/zokrates_cli/tests/code/simple_mul.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "24" +} diff --git a/zokrates_cli/tests/code/taxation.expected.witness b/zokrates_cli/tests/code/taxation.expected.witness deleted file mode 100644 index 1b8f13fa2..000000000 --- a/zokrates_cli/tests/code/taxation.expected.witness +++ /dev/null @@ -1 +0,0 @@ -~out_0 0 \ No newline at end of file diff --git a/zokrates_cli/tests/code/taxation.expected.witness.json b/zokrates_cli/tests/code/taxation.expected.witness.json new file mode 100644 index 000000000..50655725f --- /dev/null +++ b/zokrates_cli/tests/code/taxation.expected.witness.json @@ -0,0 +1,3 @@ +{ + "~out_0": "0" +} diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index 35aff1deb..68bc0abf6 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -16,17 +16,78 @@ mod integration { use std::fs; use std::fs::File; use std::io::{BufReader, Read, Write}; - use std::panic; use std::path::Path; use std::process::Command; use tempdir::TempDir; use zokrates_abi::{parse_strict, Encode}; + use zokrates_ast::ir::Witness; use zokrates_ast::typed::abi::Abi; use zokrates_field::Bn128Field; use zokrates_proof_systems::{ to_token::ToToken, Marlin, Proof, SolidityCompatibleScheme, G16, GM17, }; + mod helpers { + use super::*; + use zokrates_ast::common::flat::Variable; + use zokrates_field::Field; + + pub fn parse_variable(s: &str) -> Result { + if s == "~one" { + return Ok(Variable::one()); + } + + let mut public = s.split("~out_"); + match public.nth(1) { + Some(v) => { + let v = v.parse().map_err(|_| s)?; + Ok(Variable::public(v)) + } + None => { + let mut private = s.split('_'); + match private.nth(1) { + Some(v) => { + let v = v.parse().map_err(|_| s)?; + Ok(Variable::new(v)) + } + None => Err(s), + } + } + } + } + + pub fn parse_witness_json(reader: R) -> std::io::Result> { + use std::io::{Error, ErrorKind}; + + let json: serde_json::Value = serde_json::from_reader(reader)?; + let object = json + .as_object() + .ok_or_else(|| Error::new(ErrorKind::Other, "Witness must be an object"))?; + + let mut witness = Witness::empty(); + for (k, v) in object { + let variable = parse_variable(k).map_err(|why| { + Error::new( + ErrorKind::Other, + format!("Invalid variable in witness: {}", why), + ) + })?; + + let value = v + .as_str() + .ok_or_else(|| Error::new(ErrorKind::Other, "Witness value must be a string")) + .and_then(|v| { + T::try_from_dec_str(v).map_err(|_| { + Error::new(ErrorKind::Other, format!("Invalid value in witness: {}", v)) + }) + })?; + + witness.insert(variable, value); + } + Ok(witness) + } + } + macro_rules! map( { $($key:expr => $value:expr),+ } => { @@ -101,7 +162,9 @@ mod integration { let program_name = Path::new(Path::new(path.file_stem().unwrap()).file_stem().unwrap()); let prog = dir.join(program_name).with_extension("zok"); - let witness = dir.join(program_name).with_extension("expected.witness"); + let witness = dir + .join(program_name) + .with_extension("expected.witness.json"); let json_input = dir.join(program_name).with_extension("arguments.json"); test_compile_and_witness( @@ -250,33 +313,24 @@ mod integration { .unwrap(); // load the expected witness - let mut expected_witness_file = File::open(&expected_witness_path).unwrap(); - let mut expected_witness = String::new(); - expected_witness_file - .read_to_string(&mut expected_witness) - .unwrap(); + let expected_witness_file = File::open(expected_witness_path).unwrap(); + let expected_witness: Witness = + helpers::parse_witness_json(expected_witness_file).unwrap(); // load the actual witness - let mut witness_file = File::open(&witness_path).unwrap(); - let mut witness = String::new(); - witness_file.read_to_string(&mut witness).unwrap(); + let witness_file = File::open(&witness_path).unwrap(); + let witness = Witness::::read(witness_file).unwrap(); // load the actual inline witness - let mut inline_witness_file = File::open(&inline_witness_path).unwrap(); - let mut inline_witness = String::new(); - inline_witness_file - .read_to_string(&mut inline_witness) - .unwrap(); + let inline_witness_file = File::open(&inline_witness_path).unwrap(); + let inline_witness = + Witness::::read(inline_witness_file).unwrap(); assert_eq!(inline_witness, witness); - for line in expected_witness.as_str().split('\n') { - assert!( - witness.contains(line), - "Witness generation failed for {}\n\nLine \"{}\" not found in witness", - program_path.to_str().unwrap(), - line - ); + for (k, v) in expected_witness.0 { + let value = witness.0.get(&k).expect("should contain key"); + assert!(v.eq(value)); } let backends = map! { @@ -584,7 +638,7 @@ mod integration { .unwrap(); // load the expected smtlib2 - let mut expected_smtlib2_file = File::open(&expected_smtlib2_path).unwrap(); + let mut expected_smtlib2_file = File::open(expected_smtlib2_path).unwrap(); let mut expected_smtlib2 = String::new(); expected_smtlib2_file .read_to_string(&mut expected_smtlib2) diff --git a/zokrates_codegen/Cargo.toml b/zokrates_codegen/Cargo.toml index 1d9ec3bc2..30fe981c0 100644 --- a/zokrates_codegen/Cargo.toml +++ b/zokrates_codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_codegen" -version = "0.1.1" +version = "0.1.2" edition = "2021" [features] diff --git a/zokrates_codegen/src/lib.rs b/zokrates_codegen/src/lib.rs index e149a9688..46faf587e 100644 --- a/zokrates_codegen/src/lib.rs +++ b/zokrates_codegen/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - //! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs @@ -10,9 +8,12 @@ mod utils; use self::utils::flat_expression_from_bits; -use zokrates_ast::zir::{ - ConditionalExpression, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, - ZirExpressionList, +use zokrates_ast::{ + common::{expressions::ValueExpression, Span}, + zir::{ + canonicalizer::ZirCanonicalizer, ConditionalExpression, Expr, Folder, SelectExpression, + ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, ZirProgram, + }, }; use zokrates_interpreter::Interpreter; @@ -20,32 +21,73 @@ use std::collections::{ hash_map::{Entry, HashMap}, VecDeque, }; +use std::ops::*; use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; -use zokrates_ast::common::{RuntimeError, Variable}; +use zokrates_ast::common::WithSpan; +use zokrates_ast::common::{flat::Variable, RuntimeError}; use zokrates_ast::flat::*; use zokrates_ast::ir::Solver; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, - UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirFunction, - ZirStatement, + UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirStatement, }; use zokrates_common::CompileConfig; use zokrates_field::Field; -type FlatStatements<'ast, T> = VecDeque>; +/// A container for statements produced during code generation +/// New statements are registered with the span set in the container +#[derive(Default)] +pub struct FlatStatements<'ast, T> { + span: Option, + buffer: VecDeque>, +} + +impl<'ast, T> FlatStatements<'ast, T> { + fn push_back(&mut self, s: FlatStatement<'ast, T>) { + self.buffer.push_back(s.span(self.span)) + } + + fn pop_front(&mut self) -> Option> { + self.buffer.pop_front() + } + + fn extend(&mut self, i: impl IntoIterator>) { + self.buffer.extend(i.into_iter().map(|s| s.span(self.span))) + } + + fn set_span(&mut self, span: Option) { + self.span = span; + } + + fn is_empty(&self) -> bool { + self.buffer.is_empty() + } +} + +impl<'ast, T> IntoIterator for FlatStatements<'ast, T> { + type Item = FlatStatement<'ast, T>; + + type IntoIter = std::collections::vec_deque::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.buffer.into_iter() + } +} /// Flattens a function /// /// # Arguments /// * `funct` - `ZirFunction` that will be flattened -pub fn from_function_and_config( - funct: ZirFunction, +pub fn from_program_and_config( + prog: ZirProgram, config: CompileConfig, ) -> FlattenerIterator { + let funct = prog.main; + let mut flattener = Flattener::new(config); - let mut statements_flattened = FlatStatements::new(); + let mut statements_flattened = FlatStatements::default(); // push parameters let arguments_flattened = funct .arguments @@ -61,6 +103,7 @@ pub fn from_function_and_config( flattener, }, return_count: funct.signature.outputs.len(), + module_map: prog.module_map, } } @@ -168,12 +211,28 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { } #[derive(Clone, Debug)] -struct FlatUExpression { +struct FlatUExpression { field: Option>, bits: Option>>, } -impl FlatUExpression { +impl WithSpan for FlatUExpression { + fn span(self, span: Option) -> Self { + Self { + field: self.field.map(|e| e.span(span)), + bits: self + .bits + .map(|bits| bits.into_iter().map(|b| b.span(span)).collect()), + } + } + + fn get_span(&self) -> Option { + let field_span = self.field.as_ref().map(|f| f.get_span()); + field_span.unwrap_or_else(|| unimplemented!()) + } +} + +impl FlatUExpression { fn default() -> Self { FlatUExpression { field: None, @@ -182,7 +241,7 @@ impl FlatUExpression { } } -impl FlatUExpression { +impl FlatUExpression { fn field>>>(mut self, e: U) -> Self { self.field = e.into(); self @@ -200,7 +259,9 @@ impl FlatUExpression { fn with_bits>>>>(e: U) -> Self { Self::default().bits(e) } +} +impl FlatUExpression { fn get_field_unchecked(self) -> FlatExpression { match self.field { Some(f) => f, @@ -230,10 +291,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, ) -> Variable { match e { - FlatExpression::Identifier(id) => id, + FlatExpression::Identifier(id) => id.id, e => { let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(res, e)); + statements_flattened.push_back(FlatStatement::definition(res, e)); res } } @@ -290,9 +351,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { .iter() .map(|e| { let e_id = self.define(e.clone(), statements_flattened); - FlatStatement::Condition( + FlatStatement::condition( e_id.into(), - FlatExpression::Mult(box e_id.into(), box e_id.into()), + FlatExpression::mul(e_id.into(), e_id.into()), RuntimeError::Bitness, ) }) @@ -312,28 +373,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { } // init size_unknown = true - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( size_unknown[0], - FlatExpression::Number(T::from(1)), + FlatExpression::value(T::from(1)), )); let mut res = vec![]; for (i, b) in b.iter().enumerate() { if *b { - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( is_not_smaller_run[i], a[i].clone(), )); // don't need to update size_unknown in the last round if i < len - 1 { - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( size_unknown[i + 1], - FlatExpression::Mult( - box size_unknown[i].into(), - box is_not_smaller_run[i].into(), - ), + FlatExpression::mul(size_unknown[i].into(), is_not_smaller_run[i].into()), )); } } else { @@ -343,24 +401,20 @@ impl<'ast, T: Field> Flattener<'ast, T> { // sizeUnknown is not changing in this case // We sill have to assign the old value to the variable of the current run // This trivial definition will later be removed by the optimiser - FlatStatement::Definition(size_unknown[i + 1], size_unknown[i].into()), + FlatStatement::definition(size_unknown[i + 1], size_unknown[i].into()), ); } - let or_left = FlatExpression::Sub( - box FlatExpression::Number(T::from(1)), - box size_unknown[i].into(), - ); + let or_left = + FlatExpression::sub(FlatExpression::value(T::from(1)), size_unknown[i].into()); let or_right: FlatExpression<_> = - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].clone()); + FlatExpression::sub(FlatExpression::value(T::from(1)), a[i].clone()); let and_name = self.use_sym(); - let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone()); - statements_flattened.push_back(FlatStatement::Definition(and_name, and)); - let or = FlatExpression::Sub( - box FlatExpression::Add(box or_left, box or_right), - box and_name.into(), - ); + let and = FlatExpression::mul(or_left.clone(), or_right.clone()); + statements_flattened.push_back(FlatStatement::definition(and_name, and)); + let or = + FlatExpression::sub(FlatExpression::add(or_left, or_right), and_name.into()); res.push(or); } @@ -395,30 +449,30 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Y == X * M // 0 == (1-Y) * X - let x = FlatExpression::Sub(box left.into(), box right.into()); + let x = FlatExpression::sub(left.into(), right.into()); let name_y = self.use_sym(); let name_m = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( vec![name_y, name_m], Solver::ConditionEq, vec![x.clone()], - ))); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Identifier(name_y), - FlatExpression::Mult(box x.clone(), box FlatExpression::Identifier(name_m)), + )); + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::identifier(name_y), + FlatExpression::mul(x.clone(), FlatExpression::identifier(name_m)), RuntimeError::Equal, )); - let res = FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_y), + let res = FlatExpression::sub( + FlatExpression::value(T::one()), + FlatExpression::identifier(name_y), ); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::zero()), - FlatExpression::Mult(box res.clone(), box x), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::zero()), + FlatExpression::mul(res.clone(), x), RuntimeError::Equal, )); @@ -446,11 +500,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { let conditions_sum = conditions .into_iter() .fold(FlatExpression::from(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + FlatExpression::add(acc, e) }); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::from(0)), - FlatExpression::Sub(box conditions_sum, box T::from(conditions_count).into()), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(0)), + FlatExpression::sub(conditions_sum, T::from(conditions_count).into()), error, )); } @@ -507,7 +561,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) { // `e < 0` will always result in false value, so we constrain `0 == 1` if c == T::zero() { - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( T::zero().into(), T::one().into(), error, @@ -525,49 +579,58 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements .into_iter() .flat_map(|s| match s { - FlatStatement::Condition(left, right, message) => { - let mut output = VecDeque::new(); + FlatStatement::Condition(s) => { + let span = s.get_span(); + + let mut output = FlatStatements::default(); + + output.set_span(span); // we transform (a == b) into (c => (a == b)) which is (!c || (a == b)) // let's introduce new variables to make sure everything is linear - let name_left = self.define(left, &mut output); - let name_right = self.define(right, &mut output); + let name_lin = self.define(s.lin, &mut output); + let name_quad = self.define(s.quad, &mut output); // let's introduce an expression which is 1 iff `a == b` - let y = FlatExpression::Add( - box FlatExpression::Sub(box name_left.into(), box name_right.into()), - box T::one().into(), - ); - // let's introduce !c - let x = FlatExpression::Sub(box T::one().into(), box condition.clone()); - + let y = FlatExpression::add( + FlatExpression::sub(name_lin.into(), name_quad.into()), + T::one().into(), + ); // let's introduce !c + let x = FlatExpression::sub(T::one().into(), condition.clone()); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - output.push_back(FlatStatement::Directive(FlatDirective { - solver: Solver::Or, - outputs: vec![name_x_or_y], - inputs: vec![x.clone(), y.clone()], - })); - output.push_back(FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), + output.push_back(FlatStatement::directive( + vec![name_x_or_y], + Solver::Or, + vec![x.clone(), y.clone()], + )); + output.push_back(FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name_x_or_y.into()), ), - FlatExpression::Mult(box x.clone(), box y.clone()), + FlatExpression::mul(x, y), RuntimeError::BranchIsolation, )); - output.push_back(FlatStatement::Condition( + output.push_back(FlatStatement::condition( name_x_or_y.into(), T::one().into(), - message, + s.error, )); output } - s => VecDeque::from([s]), + s => { + let mut v = FlatStatements::default(); + v.push_back(s); + v + } + }) + .fold(FlatStatements::default(), |mut acc, s| { + acc.push_back(s); + acc }) - .collect() } /// Flatten an if/else expression @@ -593,14 +656,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(condition_id, condition_flat)); + statements_flattened.push_back(FlatStatement::definition(condition_id, condition_flat)); let (consequence, alternative) = if self.config.isolate_branches { - let mut consequence_statements = VecDeque::new(); + let mut consequence_statements = FlatStatements::default(); let consequence = consequence.flatten(self, &mut consequence_statements); - let mut alternative_statements = VecDeque::new(); + let mut alternative_statements = FlatStatements::default(); let alternative = alternative.flatten(self, &mut alternative_statements); @@ -608,10 +671,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), ); statements_flattened.extend(consequence_statements); @@ -629,43 +689,37 @@ impl<'ast, T: Field> Flattener<'ast, T> { let alternative = alternative.flat(); let consequence_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(consequence_id, consequence)); + statements_flattened.push_back(FlatStatement::definition(consequence_id, consequence)); let alternative_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(alternative_id, alternative)); + statements_flattened.push_back(FlatStatement::definition(alternative_id, alternative)); let term0_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( term0_id, - FlatExpression::Mult( - box condition_id.into(), - box FlatExpression::from(consequence_id), - ), + FlatExpression::mul(condition_id.into(), FlatExpression::from(consequence_id)), )); let term1_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( term1_id, - FlatExpression::Mult( - box FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), - box FlatExpression::from(alternative_id), + FlatExpression::mul( + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), + FlatExpression::from(alternative_id), ), )); let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( res, - FlatExpression::Add( - box FlatExpression::from(term0_id), - box FlatExpression::from(term1_id), + FlatExpression::add( + FlatExpression::from(term0_id), + FlatExpression::from(term1_id), ), )); FlatUExpression { - field: Some(FlatExpression::Identifier(res)), + field: Some(FlatExpression::identifier(res)), bits: None, } } @@ -736,8 +790,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { T::from(conditions.len()).into(), conditions .into_iter() - .fold(FlatExpression::Number(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + .fold(FlatExpression::value(T::zero()), |acc, e| { + FlatExpression::add(acc, e) }), ) } @@ -750,13 +804,13 @@ impl<'ast, T: Field> Flattener<'ast, T> { rhs_flattened: FlatExpression, bit_width: usize, ) -> FlatExpression { - FlatExpression::Add( - box self.eq_check( + FlatExpression::add( + self.eq_check( statements_flattened, lhs_flattened.clone(), rhs_flattened.clone(), ), - box self.lt_check( + self.lt_check( statements_flattened, lhs_flattened, rhs_flattened, @@ -774,25 +828,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { bit_width: usize, ) -> FlatExpression { match (lhs_flattened, rhs_flattened) { - (x, FlatExpression::Number(constant)) => { - self.constant_lt_check(statements_flattened, x, constant) + (x, FlatExpression::Value(constant)) => { + self.constant_lt_check(statements_flattened, x, constant.value) } // (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c) - (FlatExpression::Number(constant), x) => self.constant_lt_check( + (FlatExpression::Value(constant), x) => self.constant_lt_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box x), - T::max_value() - constant, + FlatExpression::sub(T::max_value().into(), x), + T::max_value() - constant.value, ), (lhs_flattened, rhs_flattened) => { let lhs_id = self.define(lhs_flattened, statements_flattened); let rhs_id = self.define(rhs_flattened, statements_flattened); // shifted_sub := 2**safe_width + lhs - rhs - let shifted_sub = FlatExpression::Add( - box FlatExpression::Number(T::from(2).pow(bit_width)), - box FlatExpression::Sub( - box FlatExpression::Identifier(lhs_id), - box FlatExpression::Identifier(rhs_id), + let shifted_sub = FlatExpression::add( + FlatExpression::value(T::from(2).pow(bit_width)), + FlatExpression::sub( + FlatExpression::identifier(lhs_id), + FlatExpression::identifier(rhs_id), ), ); @@ -806,9 +860,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { RuntimeError::IncompleteDynamicRange, ); - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box shifted_sub_bits_be[0].clone(), + FlatExpression::sub( + FlatExpression::value(T::one()), + shifted_sub_bits_be[0].clone(), ) } } @@ -830,21 +884,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { - match expression { + let span = expression.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + let res = match expression { BooleanExpression::Identifier(x) => { - FlatExpression::Identifier(*self.layout.get(&x.id).unwrap()) + FlatExpression::identifier(*self.layout.get(&x.id).unwrap()) } BooleanExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), - BooleanExpression::FieldLt(box lhs, box rhs) => { + BooleanExpression::FieldLt(e) => { // Get the bit width to know the size of the binary decompositions for this Field let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2) - let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs); - let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs); + let lhs_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let rhs_flattened = self.flatten_field_expression(statements_flattened, *e.right); self.lt_check( statements_flattened, lhs_flattened, @@ -852,10 +912,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { safe_width, ) } - BooleanExpression::BoolEq(box lhs, box rhs) => { + BooleanExpression::BoolEq(e) => { // lhs and rhs are booleans, they flatten to 0 or 1 - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); // Wanted: Not(X - Y)**2 which is an XNOR // We know that X and Y are [0, 1] // (X - Y) can become a negative values, which is why squaring the result is needed @@ -871,27 +931,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { // | 0 | 0 | 0 | 1 | // +---+---+-------+---------------+ - let x_sub_y = FlatExpression::Sub(box x, box y); + let x_sub_y = FlatExpression::sub(x, y); let name_x_mult_x = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( name_x_mult_x, - FlatExpression::Mult(box x_sub_y.clone(), box x_sub_y), + FlatExpression::mul(x_sub_y.clone(), x_sub_y), )); - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box FlatExpression::Identifier(name_x_mult_x), + FlatExpression::sub( + FlatExpression::value(T::one()), + FlatExpression::identifier(name_x_mult_x), ) } - BooleanExpression::FieldEq(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); + BooleanExpression::FieldEq(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); self.eq_check(statements_flattened, lhs, rhs) } - BooleanExpression::UintEq(box lhs, box rhs) => { + BooleanExpression::UintEq(e) => { // We reduce each side into range and apply the same approach as for field elements // Wanted: (Y = (X != 0) ? 1 : 0) @@ -901,39 +961,41 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Y == X * M // 0 == (1-Y) * X - assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool()); - assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.eq_check(statements_flattened, lhs, rhs) } - BooleanExpression::FieldLe(box lhs, box rhs) => { + BooleanExpression::FieldLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()), + BooleanExpression::field_lt(*e.left.clone(), *e.right.clone()).span(span), ); + let eq = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::FieldEq(box lhs, box rhs), + BooleanExpression::field_eq(*e.left, *e.right).span(span), ); - FlatExpression::Add(box eq, box lt) + + FlatExpression::add(eq, lt) } - BooleanExpression::UintLt(box lhs, box rhs) => { - let bit_width = lhs.bitwidth.to_usize(); - assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool()); - assert!(rhs.metadata.as_ref().unwrap().should_reduce.to_bool()); + BooleanExpression::UintLt(e) => { + let bit_width = e.left.bitwidth.to_usize(); + assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); + assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs_flattened = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs_flattened = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.lt_check( @@ -943,55 +1005,55 @@ impl<'ast, T: Field> Flattener<'ast, T> { bit_width, ) } - BooleanExpression::UintLe(box lhs, box rhs) => { + BooleanExpression::UintLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()), + BooleanExpression::uint_lt(*e.left.clone(), *e.right.clone()).span(span), ); let eq = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintEq(box lhs, box rhs), + BooleanExpression::uint_eq(*e.left, *e.right).span(span), ); - FlatExpression::Add(box eq, box lt) + FlatExpression::add(eq, lt) } - BooleanExpression::Or(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::Or(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective { - solver: Solver::Or, - outputs: vec![name_x_or_y], - inputs: vec![x.clone(), y.clone()], - })); - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name_x_or_y.into()), + statements_flattened.push_back(FlatStatement::directive( + vec![name_x_or_y], + Solver::Or, + vec![x.clone(), y.clone()], + )); + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name_x_or_y.into()), ), - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), RuntimeError::Or, )); name_x_or_y.into() } - BooleanExpression::And(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(statements_flattened, lhs); - let y = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::And(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.left); + let y = self.flatten_boolean_expression(statements_flattened, *e.right); let name_x_and_y = self.use_sym(); assert!(x.is_linear() && y.is_linear()); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( name_x_and_y, - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), )); - FlatExpression::Identifier(name_x_and_y) + FlatExpression::identifier(name_x_and_y) } - BooleanExpression::Not(box exp) => { - let x = self.flatten_boolean_expression(statements_flattened, exp); - FlatExpression::Sub(box FlatExpression::Number(T::one()), box x) + BooleanExpression::Not(e) => { + let x = self.flatten_boolean_expression(statements_flattened, *e.inner); + FlatExpression::sub(FlatExpression::value(T::one()), x) } - BooleanExpression::Value(b) => FlatExpression::Number(match b { + BooleanExpression::Value(b) => FlatExpression::value(match b.value { true => T::from(1), false => T::from(0), }), @@ -999,6 +1061,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), } + .span(span); + + statements_flattened.set_span(span_backup); + + res } fn u_to_bits( @@ -1089,8 +1156,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { let constants: Vec<_> = constants .into_iter() .map(|e| match e.get_field_unchecked() { - FlatExpression::Number(n) if n == T::one() => true, - FlatExpression::Number(n) if n == T::zero() => false, + FlatExpression::Value(n) if n.value == T::one() => true, + FlatExpression::Value(n) if n.value == T::zero() => false, _ => unreachable!(), }) .collect(); @@ -1106,8 +1173,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { T::from(conditions.len()).into(), conditions .into_iter() - .fold(FlatExpression::Number(T::zero()), |acc, e| { - FlatExpression::Add(box acc, box e) + .fold(FlatExpression::value(T::zero()), |acc, e| { + FlatExpression::add(acc, e) }), ), )] @@ -1157,16 +1224,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { let statements = funct.statements.into_iter().map(|stat| match stat { FlatStatement::Block(..) => unreachable!(), - FlatStatement::Definition(var, rhs) => { + FlatStatement::Definition(s) => { let new_var = self.use_sym(); - replacement_map.insert(var, new_var); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Definition(new_var, new_rhs) + replacement_map.insert(s.assignee, new_var); + let new_rhs = s.rhs.apply_substitution(&replacement_map); + FlatStatement::definition(new_var, new_rhs) } - FlatStatement::Condition(lhs, rhs, message) => { - let new_lhs = lhs.apply_substitution(&replacement_map); - let new_rhs = rhs.apply_substitution(&replacement_map); - FlatStatement::Condition(new_lhs, new_rhs, message) + FlatStatement::Condition(s) => { + let new_quad = s.quad.apply_substitution(&replacement_map); + let new_lin = s.lin.apply_substitution(&replacement_map); + FlatStatement::condition(new_lin, new_quad, s.error) } FlatStatement::Directive(d) => { let new_outputs = d @@ -1183,15 +1250,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|i| i.apply_substitution(&replacement_map)) .collect(); - FlatStatement::Directive(FlatDirective { - outputs: new_outputs, - solver: d.solver, - inputs: new_inputs, - }) + FlatStatement::directive(new_outputs, d.solver, new_inputs) } - FlatStatement::Log(l, expressions) => FlatStatement::Log( - l, - expressions + FlatStatement::Log(s) => FlatStatement::Log(LogStatement::new( + s.format_string, + s.expressions .into_iter() .map(|(t, e)| { ( @@ -1202,7 +1265,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ) }) .collect(), - ), + )), }); statements_flattened.extend(statements); @@ -1253,12 +1316,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { self.define(e, statements_flattened).into() - } else if n == T::from(1) { + } else if n.value == T::from(1) { self.define( - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box e), + FlatExpression::sub(FlatExpression::value(T::from(1)), e), statements_flattened, ) .into() @@ -1270,20 +1333,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![name], Solver::Xor, vec![x.clone(), y.clone()], - )), - FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name.into()), - ), - FlatExpression::Mult( - box FlatExpression::Add(box x.clone(), box x), - box y, + ), + FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name.into()), ), + FlatExpression::mul(FlatExpression::add(x.clone(), x), y), RuntimeError::Xor, ), ]); @@ -1313,26 +1373,26 @@ impl<'ast, T: Field> Flattener<'ast, T> { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let d = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; // introduce the quotient and remainder let q = self.use_sym(); let r = self.use_sym(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective { - inputs: vec![n.clone(), d.clone()], - outputs: vec![q, r], - solver: Solver::EuclideanDiv, - })); + statements_flattened.push_back(FlatStatement::directive( + vec![q, r], + Solver::EuclideanDiv, + vec![n.clone(), d.clone()], + )); let target_bitwidth = target_bitwidth.to_usize(); @@ -1356,9 +1416,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { // r < d <=> r - d + 2**w < 2**w let _ = self.get_bits_unchecked( - &FlatUExpression::with_field(FlatExpression::Add( - box FlatExpression::Sub(box r.into(), box d.clone()), - box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth as u32))), + &FlatUExpression::with_field(FlatExpression::add( + FlatExpression::sub(r.into(), d.clone()), + FlatExpression::value(T::from(2_u128.pow(target_bitwidth as u32))), )), target_bitwidth, target_bitwidth, @@ -1367,9 +1427,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); // q*d == n - r - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Sub(box n, box r.into()), - FlatExpression::Mult(box q.into(), box d), + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::sub(n, r.into()), + FlatExpression::mul(q.into(), d), RuntimeError::Euclidean, )); @@ -1387,6 +1447,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expr: UExpression<'ast, T>, ) -> FlatUExpression { + let span = expr.as_inner().get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + // the bitwidth for this type of uint (8, 16 or 32) let target_bitwidth = expr.bitwidth; @@ -1402,10 +1468,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { let res = match expr.into_inner() { UExpressionInner::Value(x) => { - FlatUExpression::with_field(FlatExpression::Number(T::from(x))) + FlatUExpression::with_field(FlatExpression::value(T::from(x.value))) } // force to be a field element UExpressionInner::Identifier(x) => { - let field = FlatExpression::Identifier(*self.layout.get(&x.id).unwrap()); + let field = FlatExpression::identifier(*self.layout.get(&x.id).unwrap()); let bits = self.bits_cache.get(&field).map(|bits| { assert_eq!(bits.len(), target_bitwidth.to_usize()); bits.clone() @@ -1413,8 +1479,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_field(field).bits(bits) } UExpressionInner::Select(e) => self.flatten_select_expression(statements_flattened, e), - UExpressionInner::Not(box e) => { - let e = self.flatten_uint_expression(statements_flattened, e); + UExpressionInner::Not(e) => { + let e = self.flatten_uint_expression(statements_flattened, *e.inner); let e_bits = e.bits.unwrap(); @@ -1424,7 +1490,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .map(|bit| { self.define( - FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box bit), + FlatExpression::sub(FlatExpression::value(T::from(1)), bit), statements_flattened, ) .into() @@ -1433,66 +1499,73 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(name_not) } - UExpressionInner::Add(box left, box right) => { + UExpressionInner::Add(e) => { let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatUExpression::with_field(FlatExpression::Add(box new_left, box new_right)) + FlatUExpression::with_field(FlatExpression::add(new_left, new_right)) } - UExpressionInner::Sub(box left, box right) => { + UExpressionInner::Sub(e) => { // see uint optimizer for the reasoning here - let offset = FlatExpression::Number(T::from(2).pow(std::cmp::max( - right.metadata.as_ref().unwrap().bitwidth() as usize, + let offset = FlatExpression::value(T::from(2).pow(std::cmp::max( + e.right.metadata.as_ref().unwrap().bitwidth() as usize, target_bitwidth as usize, ))); let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatUExpression::with_field(FlatExpression::Add( - box offset, - box FlatExpression::Sub(box new_left, box new_right), + FlatUExpression::with_field(FlatExpression::add( + offset, + FlatExpression::sub(new_left, new_right), )) } - UExpressionInner::LeftShift(box e, by) => { + UExpressionInner::LeftShift(e) => { + let by = match e.right.into_inner() { + UExpressionInner::Value(v) => v.value as u32, + _ => unreachable!(), + }; + + let e = *e.left; + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1505,12 +1578,19 @@ impl<'ast, T: Field> Flattener<'ast, T> { .skip(by as usize) .chain( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) - .map(|_| FlatExpression::Number(T::from(0))), + .map(|_| FlatExpression::value(T::from(0))), ) .collect::>(), ) } - UExpressionInner::RightShift(box e, by) => { + UExpressionInner::RightShift(e) => { + let by = match e.right.into_inner() { + UExpressionInner::Value(v) => v.value as u32, + _ => unreachable!(), + }; + + let e = *e.left; + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1519,7 +1599,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) - .map(|_| FlatExpression::Number(T::from(0))) + .map(|_| FlatExpression::value(T::from(0))) .chain(e_bits.into_iter().take( target_bitwidth.to_usize() - std::cmp::min(by as usize, target_bitwidth.to_usize()), @@ -1527,59 +1607,75 @@ impl<'ast, T: Field> Flattener<'ast, T> { .collect::>(), ) } - UExpressionInner::Mult(box left, box right) => { + UExpressionInner::Mult(e) => { let left_flattened = self - .flatten_uint_expression(statements_flattened, left) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(statements_flattened, right) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; let res = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( res, - FlatExpression::Mult(box new_left, box new_right), + FlatExpression::mul(new_left, new_right), )); - FlatUExpression::with_field(FlatExpression::Identifier(res)) + FlatUExpression::with_field(FlatExpression::identifier(res)) } - UExpressionInner::Div(box left, box right) => { - let (q, _) = - self.euclidean_division(statements_flattened, target_bitwidth, left, right); + UExpressionInner::Div(e) => { + let (q, _) = self.euclidean_division( + statements_flattened, + target_bitwidth, + *e.left, + *e.right, + ); FlatUExpression::with_field(q) } - UExpressionInner::Rem(box left, box right) => { - let (_, r) = - self.euclidean_division(statements_flattened, target_bitwidth, left, right); + UExpressionInner::Rem(e) => { + let (_, r) = self.euclidean_division( + statements_flattened, + target_bitwidth, + *e.left, + *e.right, + ); FlatUExpression::with_field(r) } UExpressionInner::Conditional(e) => { self.flatten_conditional_expression(statements_flattened, e) } - UExpressionInner::Xor(box left, box right) => { - let left_metadata = left.metadata.clone().unwrap(); - let right_metadata = right.metadata.clone().unwrap(); + UExpressionInner::Xor(e) => { + let left_metadata = e.left.metadata.clone().unwrap(); + let right_metadata = e.right.metadata.clone().unwrap(); + + let left_span = e.left.get_span(); + let right_span = e.right.get_span(); + + match (e.left.into_inner(), e.right.into_inner()) { + (UExpressionInner::And(e), UExpressionInner::And(ee)) => { + let a = *e.left; + let b = *e.right; + let aa = *ee.left; + let c = *ee.right; - match (left.into_inner(), right.into_inner()) { - (UExpressionInner::And(box a, box b), UExpressionInner::And(box aa, box c)) => { - if aa.clone().into_inner() == UExpressionInner::Not(box a.clone()) { + if aa.as_inner() == UExpression::not(a.clone()).as_inner() { let a_flattened = self.flatten_uint_expression(statements_flattened, a); let b_flattened = self.flatten_uint_expression(statements_flattened, b); let c_flattened = self.flatten_uint_expression(statements_flattened, c); @@ -1598,17 +1694,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { let ch = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![ch], Solver::ShaCh, vec![a.clone(), b.clone(), c.clone()], - )), - FlatStatement::Condition( - FlatExpression::Sub(box ch.into(), box c.clone()), - FlatExpression::Mult( - box a, - box FlatExpression::Sub(box b, box c), - ), + ), + FlatStatement::condition( + FlatExpression::sub(ch.into(), c.clone()), + a * (b - c), RuntimeError::ShaXor, ), ]); @@ -1620,25 +1713,32 @@ impl<'ast, T: Field> Flattener<'ast, T> { } else { self.default_xor( statements_flattened, - UExpressionInner::And(box a, box b) - .annotate(target_bitwidth) - .metadata(left_metadata), - UExpressionInner::And(box aa, box c) - .annotate(target_bitwidth) - .metadata(right_metadata), + UExpression::and(a, b).metadata(left_metadata), + UExpression::and(aa, c).metadata(right_metadata), ) } } - (UExpressionInner::Xor(box a, box b), c) => { - let a_metadata = a.metadata.clone().unwrap(); - let b_metadata = b.metadata.clone().unwrap(); + (UExpressionInner::Xor(e), c) => { + let a_metadata = e.left.metadata.clone().unwrap(); + let b_metadata = e.right.metadata.clone().unwrap(); + + let a_span = e.left.get_span(); + let b_span = e.right.get_span(); + let c_span = right_span; - match (a.into_inner(), b.into_inner(), c) { + match (e.left.into_inner(), e.right.into_inner(), c) { ( - UExpressionInner::And(box a, box b), - UExpressionInner::And(box aa, box c), - UExpressionInner::And(box bb, box cc), + UExpressionInner::And(e0), + UExpressionInner::And(e1), + UExpressionInner::And(e2), ) => { + let a = *e0.left; + let b = *e0.right; + let aa = *e1.left; + let c = *e1.right; + let bb = *e2.left; + let cc = *e2.right; + if (aa == a) && (bb == b) && (cc == c) { let a_flattened = self.flatten_uint_expression(statements_flattened, a); @@ -1663,33 +1763,27 @@ impl<'ast, T: Field> Flattener<'ast, T> { let bc = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![maj], Solver::ShaAndXorAndXorAnd, vec![a.clone(), b.clone(), c.clone()], - )), - FlatStatement::Condition( + ), + FlatStatement::condition( bc.into(), - FlatExpression::Mult( - box b.clone(), - box c.clone(), - ), + FlatExpression::mul(b.clone(), c.clone()), RuntimeError::ShaXor, ), - FlatStatement::Condition( - FlatExpression::Sub( - box bc.into(), - box maj.into(), - ), - FlatExpression::Mult( - box FlatExpression::Sub( - box FlatExpression::Add( - box bc.into(), - box bc.into(), + FlatStatement::condition( + FlatExpression::sub(bc.into(), maj.into()), + FlatExpression::mul( + FlatExpression::sub( + FlatExpression::add( + bc.into(), + bc.into(), ), - box FlatExpression::Add(box b, box c), + FlatExpression::add(b, c), ), - box a, + a, ), RuntimeError::ShaXor, ), @@ -1702,45 +1796,57 @@ impl<'ast, T: Field> Flattener<'ast, T> { } else { self.default_xor( statements_flattened, - UExpressionInner::Xor( - box UExpressionInner::And(box a, box b) - .annotate(target_bitwidth) - .metadata(a_metadata), - box UExpressionInner::And(box aa, box c) - .annotate(target_bitwidth) - .metadata(b_metadata), + UExpression::xor( + UExpression::and(a, b) + .metadata(a_metadata) + .span(a_span), + UExpression::and(aa, c) + .metadata(b_metadata) + .span(b_span), ) - .annotate(target_bitwidth) - .metadata(left_metadata), - UExpressionInner::And(box bb, box cc) - .annotate(target_bitwidth) - .metadata(right_metadata), + .metadata(left_metadata) + .span(left_span), + UExpression::and(bb, cc) + .metadata(right_metadata) + .span(c_span), ) } } (a, b, c) => self.default_xor( statements_flattened, - UExpressionInner::Xor( - box a.annotate(target_bitwidth).metadata(a_metadata), - box b.annotate(target_bitwidth).metadata(b_metadata), + UExpression::xor( + a.annotate(target_bitwidth) + .metadata(a_metadata) + .span(a_span), + b.annotate(target_bitwidth) + .metadata(b_metadata) + .span(b_span), ) - .annotate(target_bitwidth) - .metadata(left_metadata), - c.annotate(target_bitwidth).metadata(right_metadata), + .metadata(left_metadata) + .span(left_span), + c.annotate(target_bitwidth) + .metadata(right_metadata) + .span(c_span), ), } } (left_i, right_i) => self.default_xor( statements_flattened, - left_i.annotate(target_bitwidth).metadata(left_metadata), - right_i.annotate(target_bitwidth).metadata(right_metadata), + left_i + .annotate(target_bitwidth) + .metadata(left_metadata) + .span(left_span), + right_i + .annotate(target_bitwidth) + .metadata(right_metadata) + .span(right_span), ), } } - UExpressionInner::And(box left, box right) => { - let left_flattened = self.flatten_uint_expression(statements_flattened, left); + UExpressionInner::And(e) => { + let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); - let right_flattened = self.flatten_uint_expression(statements_flattened, right); + let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1752,26 +1858,26 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { - FlatExpression::Number(T::from(0)) - } else if n == T::from(1) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { + FlatExpression::value(T::from(0)) + } else if n.value == T::from(1) { e } else { unreachable!(); } } (x, y) => self - .define(FlatExpression::Mult(box x, box y), statements_flattened) + .define(FlatExpression::mul(x, y), statements_flattened) .into(), }) .collect(); FlatUExpression::with_bits(and) } - UExpressionInner::Or(box left, box right) => { - let left_flattened = self.flatten_uint_expression(statements_flattened, left); - let right_flattened = self.flatten_uint_expression(statements_flattened, right); + UExpressionInner::Or(e) => { + let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1786,11 +1892,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { - (FlatExpression::Number(n), e) | (e, FlatExpression::Number(n)) => { - if n == T::from(0) { + (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { + if n.value == T::from(0) { self.define(e, statements_flattened).into() - } else if n == T::from(1) { - FlatExpression::Number(T::from(1)) + } else if n.value == T::from(1) { + FlatExpression::value(T::from(1)) } else { unreachable!() } @@ -1799,17 +1905,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name = self.use_sym(); statements_flattened.extend(vec![ - FlatStatement::Directive(FlatDirective::new( + FlatStatement::directive( vec![name], Solver::Or, vec![x.clone(), y.clone()], - )), - FlatStatement::Condition( - FlatExpression::Add( - box x.clone(), - box FlatExpression::Sub(box y.clone(), box name.into()), + ), + FlatStatement::condition( + FlatExpression::add( + x.clone(), + FlatExpression::sub(y.clone(), name.into()), ), - FlatExpression::Mult(box x, box y), + FlatExpression::mul(x, y), RuntimeError::Or, ), ]); @@ -1820,7 +1926,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(or) } - }; + } + .span(span); let res = match should_reduce { true => { @@ -1834,15 +1941,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { let field = if actual_bitwidth > target_bitwidth.to_usize() { bits.iter().enumerate().fold( - FlatExpression::Number(T::from(0)), + FlatExpression::value(T::from(0)), |acc, (index, bit)| { - FlatExpression::Add( - box acc, - box FlatExpression::Mult( - box FlatExpression::Number( + FlatExpression::add( + acc, + FlatExpression::mul( + FlatExpression::value( T::from(2).pow(target_bitwidth.to_usize() - index - 1), ), - box bit.clone(), + bit.clone(), ), ) }, @@ -1856,6 +1963,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { false => res, }; + statements_flattened.set_span(span_backup); + res } @@ -1884,11 +1993,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(to <= T::get_required_bits()); // constants do not require directives - if let Some(FlatExpression::Number(ref x)) = e.field { - let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()]) + if let Some(FlatExpression::Value(ref x)) = e.field { + let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.value], &[]) .unwrap() .into_iter() - .map(FlatExpression::Number) + .map(FlatExpression::value) .collect(); assert_eq!(bits.len(), to); @@ -1915,28 +2024,28 @@ impl<'ast, T: Field> Flattener<'ast, T> { res } else { (0..to - res.len()) - .map(|_| FlatExpression::Number(T::zero())) + .map(|_| FlatExpression::value(T::zero())) .chain(res) .collect() } } Entry::Vacant(_) => { let bits = (0..from).map(|_| self.use_sym()).collect::>(); - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( bits.clone(), Solver::Bits(from), vec![e.field.clone().unwrap()], - ))); + )); - let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect(); + let bits: Vec<_> = bits.into_iter().map(FlatExpression::identifier).collect(); // decompose to the actual bitwidth // bit checks statements_flattened.extend(bits.iter().take(from).map(|bit| { - FlatStatement::Condition( + FlatStatement::condition( bit.clone(), - FlatExpression::Mult(box bit.clone(), box bit.clone()), + FlatExpression::mul(bit.clone(), bit.clone()), RuntimeError::Bitness, ) })); @@ -1944,7 +2053,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let sum = flat_expression_from_bits(bits.clone()); // sum check - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e.field.clone().unwrap(), sum.clone(), error, @@ -1974,6 +2083,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, e: SelectExpression<'ast, T, U>, ) -> FlatUExpression { + let span = e.get_span(); + let array = e.array; let index = *e.index; @@ -1983,15 +2094,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { .map(|(i, e)| { let condition = self.flatten_boolean_expression( statements_flattened, - BooleanExpression::UintEq( - box UExpressionInner::Value(i as u128) + BooleanExpression::uint_eq( + UExpression::value(i as u128) .annotate(UBitwidth::B32) .metadata(UMetadata { should_reduce: ShouldReduce::True, max: T::from(i), - }), - box index.clone(), - ), + }) + .span(span), + index.clone(), + ) + .span(span), ); let element = e.flatten(self, statements_flattened); @@ -2002,26 +2115,29 @@ impl<'ast, T: Field> Flattener<'ast, T> { .into_iter() .fold( ( - FlatExpression::Number(T::zero()), - FlatExpression::Number(T::zero()), + FlatExpression::value(T::zero()), + FlatExpression::value(T::zero()), ), |(mut range_check, mut result), (condition, element)| { - range_check = FlatExpression::Add(box range_check, box condition.clone()); + range_check = FlatExpression::add(range_check, condition.clone()).span(span); let conditional_element_id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( - conditional_element_id, - FlatExpression::Mult(box condition, box element.flat()), - )); + statements_flattened.push_back( + FlatStatement::definition( + conditional_element_id, + FlatExpression::mul(condition, element.flat()).span(span), + ) + .span(span), + ); - result = FlatExpression::Add(box result, box conditional_element_id.into()); + result = FlatExpression::add(result, conditional_element_id.into()).span(span); (range_check, result) }, ); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( range_check, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), RuntimeError::SelectRangeCheck, )); FlatUExpression::with_field(result) @@ -2038,84 +2154,90 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { - match expr { - FieldElementExpression::Number(x) => FlatExpression::Number(x), // force to be a field element - FieldElementExpression::Identifier(x) => FlatExpression::Identifier( + let span = expr.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + let res = match expr { + FieldElementExpression::Value(x) => FlatExpression::Value(x), // force to be a field element + FieldElementExpression::Identifier(x) => FlatExpression::identifier( *self.layout.get(&x.id).unwrap_or_else(|| panic!("{}", x)), ), FieldElementExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), - FieldElementExpression::Add(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Add(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Add(box new_left, box new_right) + FlatExpression::add(new_left, new_right) } - FieldElementExpression::Sub(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Sub(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Sub(box new_left, box new_right) + FlatExpression::sub(new_left, new_right) } - FieldElementExpression::Mult(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Mult(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); + FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); - FlatExpression::Identifier(id) + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); + FlatExpression::identifier(id) }; - FlatExpression::Mult(box new_left, box new_right) + FlatExpression::mul(new_left, new_right) } - FieldElementExpression::Div(box left, box right) => { - let left_flattened = self.flatten_field_expression(statements_flattened, left); - let right_flattened = self.flatten_field_expression(statements_flattened, right); + FieldElementExpression::Div(e) => { + let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); + let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left: FlatExpression = { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, left_flattened)); + statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); id.into() }; let new_right: FlatExpression = { let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(id, right_flattened)); + statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); id.into() }; @@ -2125,34 +2247,34 @@ impl<'ast, T: Field> Flattener<'ast, T> { let inverse = self.use_sym(); // # c = a/b - statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + statements_flattened.push_back(FlatStatement::directive( vec![inverse], Solver::Div, vec![new_left.clone(), new_right.clone()], - ))); + )); // assert(c * b == a) - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( new_left, - FlatExpression::Mult(box new_right, box inverse.into()), + FlatExpression::mul(new_right, inverse.into()), RuntimeError::Division, )); inverse.into() } - FieldElementExpression::Pow(box base, box exponent) => { - match exponent.into_inner() { - UExpressionInner::Value(ref e) => { + FieldElementExpression::Pow(e) => { + match e.right.into_inner() { + UExpressionInner::Value(ref exp) => { // flatten the base expression let base_flattened = - self.flatten_field_expression(statements_flattened, base.clone()); + self.flatten_field_expression(statements_flattened, *e.left.clone()); // we require from the base to be linear // TODO change that assert!(base_flattened.is_linear()); // convert the exponent to bytes, big endian - let ebytes_be = e.to_be_bytes(); + let ebytes_be = exp.value.to_be_bytes(); // convert the bytes to bits, remove leading zeroes (we only need powers up to the highest non-zero bit) #[allow(clippy::needless_collect)] @@ -2181,17 +2303,14 @@ impl<'ast, T: Field> Flattener<'ast, T> { // introduce a new variable let id = self.use_sym(); // set it to the square of the previous one, stored in state - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( id, - FlatExpression::Mult( - box previous.clone(), - box previous.clone(), - ), + FlatExpression::mul(previous.clone(), previous.clone()), )); // store it in the state for later squaring - *state = Some(FlatExpression::Identifier(id)); + *state = Some(FlatExpression::identifier(id)); // return it for later use constructing the result - Some(FlatExpression::Identifier(id)) + Some(FlatExpression::identifier(id)) } } }) @@ -2199,16 +2318,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { // construct the result iterating through the bits, multiplying by the associated power iff the bit is true ebits_le.into_iter().zip(powers).fold( - FlatExpression::Number(T::from(1)), // initialise the result at 1. If we have no bits to itegrate through, we're computing x**0 == 1 + FlatExpression::value(T::from(1)), // initialise the result at 1. If we have no bits to iterate through, we're computing x**0 == 1 |acc, (bit, power)| match bit { true => { // update the result by introducing a new variable let id = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition( + statements_flattened.push_back(FlatStatement::definition( id, - FlatExpression::Mult(box acc.clone(), box power), // set the new result to the current result times the current power + FlatExpression::mul(acc, power), // set the new result to the current result times the current power )); - FlatExpression::Identifier(id) + FlatExpression::identifier(id) } false => acc, // this bit is false, keep the previous result }, @@ -2221,7 +2340,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), _ => unreachable!(), - } + }; + + statements_flattened.set_span(span_backup); + + res } fn flatten_assembly_statement( @@ -2229,31 +2352,46 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements<'ast, T>, stat: ZirAssemblyStatement<'ast, T>, ) { + let span = stat.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + match stat { - ZirAssemblyStatement::Assignment(assignees, function) => { - let inputs: Vec> = function + ZirAssemblyStatement::Assignment(s) => { + let inputs: Vec> = s + .expression .arguments .iter() .cloned() .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into()) .collect(); - let outputs: Vec = assignees + + let outputs: Vec = s + .assignee .into_iter() .map(|assignee| self.use_variable(&assignee)) .collect(); + + let mut canonicalizer = ZirCanonicalizer::default(); + let function = canonicalizer.fold_function(s.expression); + let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs); statements_flattened.push_back(FlatStatement::Directive(directive)); } - ZirAssemblyStatement::Constraint(lhs, rhs, metadata) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); - statements_flattened.push_back(FlatStatement::Condition( + ZirAssemblyStatement::Constraint(s) => { + let lhs = self.flatten_field_expression(statements_flattened, s.left); + let rhs = self.flatten_field_expression(statements_flattened, s.right); + statements_flattened.push_back(FlatStatement::condition( lhs, rhs, - RuntimeError::SourceAssemblyConstraint(metadata), + RuntimeError::SourceAssemblyConstraint(s.metadata), )); } - } + }; + + statements_flattened.set_span(span_backup); } /// Flattens a statement @@ -2265,21 +2403,21 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn flatten_statement( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, - stat: ZirStatement<'ast, T>, + s: ZirStatement<'ast, T>, ) { - match stat { - ZirStatement::Assembly(statements) => { - let mut block_statements = VecDeque::new(); - for s in statements { - self.flatten_assembly_statement(&mut block_statements, s); - } - statements_flattened.push_back(FlatStatement::Block(block_statements.into())); - } - ZirStatement::Return(exprs) => { + let span = s.get_span(); + + let span_backup = statements_flattened.span; + + statements_flattened.set_span(span); + + match s { + ZirStatement::Return(s) => { #[allow(clippy::needless_collect)] // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator, // so we cannot borrow again when extending - let flat_expressions: Vec<_> = exprs + let flat_expressions: Vec<_> = s + .inner .into_iter() .map(|expr| self.flatten_expression(statements_flattened, expr)) .map(|x| x.get_field_unchecked()) @@ -2289,25 +2427,35 @@ impl<'ast, T: Field> Flattener<'ast, T> { flat_expressions .into_iter() .enumerate() - .map(|(index, e)| FlatStatement::Definition(Variable::public(index), e)), + .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)), ); } - ZirStatement::IfElse(condition, consequence, alternative) => { + ZirStatement::Assembly(s) => { + let mut block_statements = FlatStatements::default(); + block_statements.set_span(s.get_span()); + + for s in s.inner { + self.flatten_assembly_statement(&mut block_statements, s); + } + statements_flattened + .push_back(FlatStatement::block(block_statements.buffer.into())); + } + ZirStatement::IfElse(s) => { let condition_flat = - self.flatten_boolean_expression(statements_flattened, condition.clone()); + self.flatten_boolean_expression(statements_flattened, s.condition.clone()); let condition_id = self.use_sym(); statements_flattened - .push_back(FlatStatement::Definition(condition_id, condition_flat)); + .push_back(FlatStatement::definition(condition_id, condition_flat)); if self.config.isolate_branches { - let mut consequence_statements = VecDeque::new(); - let mut alternative_statements = VecDeque::new(); + let mut consequence_statements = FlatStatements::default(); + let mut alternative_statements = FlatStatements::default(); - consequence + s.consequence .into_iter() .for_each(|s| self.flatten_statement(&mut consequence_statements, s)); - alternative + s.alternative .into_iter() .for_each(|s| self.flatten_statement(&mut alternative_statements, s)); @@ -2315,41 +2463,41 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, - FlatExpression::Sub( - box FlatExpression::Number(T::one()), - box condition_id.into(), - ), + FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), ); statements_flattened.extend(consequence_statements); statements_flattened.extend(alternative_statements); } else { - consequence + s.consequence .into_iter() .for_each(|s| self.flatten_statement(statements_flattened, s)); - alternative + s.alternative .into_iter() .for_each(|s| self.flatten_statement(statements_flattened, s)); } } - ZirStatement::Definition(assignee, expr) => { + ZirStatement::Definition(s) => { // define n variables with n the number of primitive types for v_type // assign them to the n primitive types for expr + let assignee = s.assignee; + let expr = s.rhs; + let rhs = self.flatten_expression(statements_flattened, expr); let bits = rhs.bits.clone(); let var = match rhs.get_field_unchecked() { FlatExpression::Identifier(id) => { - self.use_variable_with_existing(&assignee, id); - id + self.use_variable_with_existing(&assignee, id.id); + id.id } e => { let var = self.use_variable(&assignee); // handle return of function call - statements_flattened.push_back(FlatStatement::Definition(var, e)); + statements_flattened.push_back(FlatStatement::definition(var, e)); var } @@ -2358,22 +2506,25 @@ impl<'ast, T: Field> Flattener<'ast, T> { // register bits if let Some(bits) = bits { self.bits_cache - .insert(FlatExpression::Identifier(var), bits); + .insert(FlatExpression::identifier(var), bits); } } - ZirStatement::Assertion(e, error) => { + ZirStatement::Assertion(s) => { + let e = s.expression; + let error = s.error; + match e { BooleanExpression::And(..) => { for boolean in e.into_conjunction_iterator() { self.flatten_statement( statements_flattened, - ZirStatement::Assertion(boolean, error.clone()), + ZirStatement::assertion(boolean, error.clone()).span(span), ) } } - BooleanExpression::FieldEq(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldEq(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); self.flatten_equality_assertion( statements_flattened, @@ -2382,106 +2533,106 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - BooleanExpression::FieldLt(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldLt(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_lt_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_lt_check( statements_flattened, e, - c, + c.value, error.into(), ), // c < e <=> p - 1 - e < p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_lt_check( + (FlatExpression::Value(c), e) => self.enforce_constant_lt_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.lt_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::FieldLe(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(statements_flattened, lhs); - let rhs = self.flatten_field_expression(statements_flattened, rhs); + BooleanExpression::FieldLe(e) => { + let lhs = self.flatten_field_expression(statements_flattened, *e.left); + let rhs = self.flatten_field_expression(statements_flattened, *e.right); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_le_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_le_check( statements_flattened, e, - c, + c.value, error.into(), ), // c <= e <=> p - 1 - e <= p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_le_check( + (FlatExpression::Value(c), e) => self.enforce_constant_le_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.le_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::UintLe(box lhs, box rhs) => { + BooleanExpression::UintLe(e) => { let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); match (lhs, rhs) { - (e, FlatExpression::Number(c)) => self.enforce_constant_le_check( + (e, FlatExpression::Value(c)) => self.enforce_constant_le_check( statements_flattened, e, - c, + c.value, error.into(), ), // c <= e <=> p - 1 - e <= p - 1 - c - (FlatExpression::Number(c), e) => self.enforce_constant_le_check( + (FlatExpression::Value(c), e) => self.enforce_constant_le_check( statements_flattened, - FlatExpression::Sub(box T::max_value().into(), box e), - T::max_value() - c, + FlatExpression::sub(T::max_value().into(), e), + T::max_value() - c.value, error.into(), ), (lhs, rhs) => { let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete let e = self.le_check(statements_flattened, lhs, rhs, safe_width); - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( e, - FlatExpression::Number(T::one()), + FlatExpression::value(T::one()), error.into(), )); } } } - BooleanExpression::UintEq(box lhs, box rhs) => { + BooleanExpression::UintEq(e) => { let lhs = self - .flatten_uint_expression(statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.flatten_equality_assertion( @@ -2491,9 +2642,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - BooleanExpression::BoolEq(box lhs, box rhs) => { - let lhs = self.flatten_boolean_expression(statements_flattened, lhs); - let rhs = self.flatten_boolean_expression(statements_flattened, rhs); + BooleanExpression::BoolEq(e) => { + let lhs = self.flatten_boolean_expression(statements_flattened, *e.left); + let rhs = self.flatten_boolean_expression(statements_flattened, *e.right); self.flatten_equality_assertion( statements_flattened, @@ -2502,106 +2653,78 @@ impl<'ast, T: Field> Flattener<'ast, T> { error.into(), ) } - // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(box BooleanExpression::UintEq( - box UExpression { - inner: UExpressionInner::Value(0), - .. - }, - box x, - )) - | BooleanExpression::Not(box BooleanExpression::UintEq( - box x, - box UExpression { - inner: UExpressionInner::Value(0), - .. - }, - )) => { - let x = self - .flatten_uint_expression(statements_flattened, x) - .get_field_unchecked(); - - // introduce intermediate variable - let x_id = self.define(x, statements_flattened); - - // check that `x` is not 0 by giving its inverse - let invx = self.use_sym(); - - // # invx = 1/x - statements_flattened.push_back(FlatStatement::Directive( - FlatDirective::new( - vec![invx], - Solver::Div, - vec![FlatExpression::Number(T::one()), x_id.into()], - ), - )); - - // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::one()), - FlatExpression::Mult(box invx.into(), box x_id.into()), - RuntimeError::Inverse, - )); - } - // `!(x == 0)` can be asserted by giving the inverse of `x` - BooleanExpression::Not(box BooleanExpression::FieldEq( - box FieldElementExpression::Number(zero), - box x, - )) - | BooleanExpression::Not(box BooleanExpression::FieldEq( - box x, - box FieldElementExpression::Number(zero), - )) if zero == T::from(0) => { - let x = self.flatten_field_expression(statements_flattened, x); - - // introduce intermediate variable - let x_id = self.define(x, statements_flattened); - - // check that `x` is not 0 by giving its inverse - let invx = self.use_sym(); - - // # invx = 1/x - statements_flattened.push_back(FlatStatement::Directive( - FlatDirective::new( - vec![invx], - Solver::Div, - vec![FlatExpression::Number(T::one()), x_id.into()], + BooleanExpression::Not(u) => { + let inner_span = u.get_span(); + + match *u.inner { + BooleanExpression::UintEq(b) => { + if let UExpressionInner::Value(ValueExpression { + value: 0, .. + }) = b.left.inner + { + let x = self + .flatten_uint_expression(statements_flattened, *b.right) + .get_field_unchecked(); + self.enforce_not_zero_assertion(statements_flattened, x) + } else if let UExpressionInner::Value(ValueExpression { + value: 0, + .. + }) = b.right.inner + { + let x = self + .flatten_uint_expression(statements_flattened, *b.left) + .get_field_unchecked(); + self.enforce_not_zero_assertion(statements_flattened, x) + } else { + self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not(BooleanExpression::UintEq(b)), + error, + ); + } + } + BooleanExpression::FieldEq(b) => match (*b.left, *b.right) { + ( + FieldElementExpression::Value(ValueExpression { + value: zero, + .. + }), + x, + ) + | ( + x, + FieldElementExpression::Value(ValueExpression { + value: zero, + .. + }), + ) if zero == T::from(0) => { + let x = self.flatten_field_expression(statements_flattened, x); + self.enforce_not_zero_assertion(statements_flattened, x) + } + (left, right) => self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not( + BooleanExpression::field_eq(left, right).span(inner_span), + ) + .span(span), + error, + ), + }, + e => self.enforce_naive_assertion( + statements_flattened, + BooleanExpression::not(e.span(inner_span)).span(span), + error, ), - )); - - // assert(invx * x == 1) - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::one()), - FlatExpression::Mult(box invx.into(), box x_id.into()), - RuntimeError::Inverse, - )); - } - e => { - // naive approach: flatten the boolean to a single field element and constrain it to 1 - let e = self.flatten_boolean_expression(statements_flattened, e); - - if e.is_linear() { - statements_flattened.push_back(FlatStatement::Condition( - e, - FlatExpression::Number(T::from(1)), - error.into(), - )); - } else { - // swap so that left side is linear - statements_flattened.push_back(FlatStatement::Condition( - FlatExpression::Number(T::from(1)), - e, - error.into(), - )); } } + e => self.enforce_naive_assertion(statements_flattened, e, error), } } - ZirStatement::MultipleDefinition(vars, rhs) => { + ZirStatement::MultipleDefinition(s) => { // flatten the right side to p = sum(var_i.type.primitive_count) expressions // define p new variables to the right side expressions - match rhs { + match s.rhs { ZirExpressionList::EmbedCall(embed, generics, exprs) => { let rhs_flattened = self.flatten_embed_call( statements_flattened, @@ -2612,20 +2735,21 @@ impl<'ast, T: Field> Flattener<'ast, T> { let rhs = rhs_flattened.into_iter(); - assert_eq!(vars.len(), rhs.len()); + assert_eq!(s.assignees.len(), rhs.len()); - let vars: Vec<_> = vars + let assignees: Vec<_> = s + .assignees .into_iter() .zip(rhs) .map(|(v, r)| match r.get_field_unchecked() { FlatExpression::Identifier(id) => { - self.use_variable_with_existing(&v, id); - id + self.use_variable_with_existing(&v, id.id); + id.id } e => { let id = self.use_variable(&v); statements_flattened - .push_back(FlatStatement::Definition(id, e)); + .push_back(FlatStatement::definition(id, e)); id } }) @@ -2643,15 +2767,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { .get_field_unchecked() }) .collect(); - self.bits_cache.insert(vars[0].into(), bits); + self.bits_cache.insert(assignees[0].into(), bits); } _ => {} } } } } - ZirStatement::Log(l, expressions) => { - let expressions = expressions + ZirStatement::Log(s) => { + let expressions = s + .expressions .into_iter() .map(|(t, e)| { ( @@ -2666,11 +2791,73 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); - statements_flattened.push_back(FlatStatement::Log(l, expressions)); + statements_flattened.push_back(FlatStatement::Log(LogStatement::new( + s.format_string, + expressions, + ))); } + }; + + statements_flattened.set_span(span_backup); + } + + fn enforce_naive_assertion( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + e: BooleanExpression<'ast, T>, + error: zokrates_ast::zir::RuntimeError, + ) { + // naive approach: flatten the boolean to a single field element and constrain it to 1 + let e = self.flatten_boolean_expression(statements_flattened, e); + + if e.is_linear() { + statements_flattened.push_back(FlatStatement::condition( + e, + FlatExpression::value(T::from(1)), + error.into(), + )); + } else { + // swap so that left side is linear + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::from(1)), + e, + error.into(), + )); } } + /// Enforce that x is not zero + /// + /// # Arguments + /// + /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. + /// * `x` - `FlatExpression` The expression to be constrained to not be zero. + fn enforce_not_zero_assertion( + &mut self, + statements_flattened: &mut FlatStatements<'ast, T>, + x: FlatExpression, + ) { + // introduce intermediate variable + let x_id = self.define(x, statements_flattened); + + // check that `x` is not 0 by giving its inverse + let invx = self.use_sym(); + + // # invx = 1/x + statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( + vec![invx], + Solver::Div, + vec![FlatExpression::value(T::one()), x_id.into()], + ))); + + // assert(invx * x == 1) + statements_flattened.push_back(FlatStatement::condition( + FlatExpression::value(T::one()), + FlatExpression::mul(invx.into(), x_id.into()), + RuntimeError::Inverse, + )); + } + /// Flattens an equality assertion, enforcing it in the circuit. /// /// # Arguments @@ -2686,22 +2873,22 @@ impl<'ast, T: Field> Flattener<'ast, T> { error: RuntimeError, ) { let (lhs, rhs) = match (lhs, rhs) { - (FlatExpression::Mult(box x, box y), z) | (z, FlatExpression::Mult(box x, box y)) => ( + (FlatExpression::Mult(e), z) | (z, FlatExpression::Mult(e)) => ( self.identify_expression(z, statements_flattened), - FlatExpression::Mult( - box self.identify_expression(x, statements_flattened), - box self.identify_expression(y, statements_flattened), + FlatExpression::mul( + self.identify_expression(*e.left, statements_flattened), + self.identify_expression(*e.right, statements_flattened), ), ), (x, z) => ( self.identify_expression(z, statements_flattened), - FlatExpression::Mult( - box self.identify_expression(x, statements_flattened), - box FlatExpression::Number(T::from(1)), + FlatExpression::mul( + self.identify_expression(x, statements_flattened), + FlatExpression::value(T::from(1)), ), ), }; - statements_flattened.push_back(FlatStatement::Condition(lhs, rhs, error)); + statements_flattened.push_back(FlatStatement::condition(lhs, rhs, error)); } /// Identifies a non-linear expression by assigning it to a new identifier. @@ -2719,8 +2906,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { true => e, false => { let sym = self.use_sym(); - statements_flattened.push_back(FlatStatement::Definition(sym, e)); - FlatExpression::Identifier(sym) + statements_flattened.push_back(FlatStatement::definition(sym, e)); + FlatExpression::identifier(sym) } } } @@ -2754,6 +2941,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { parameter: &ZirParameter<'ast>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Parameter { + let span = parameter.get_span(); + + let backup_span = statements_flattened.span; + + statements_flattened.set_span(span); + let variable = self.use_variable(¶meter.id); match parameter.id.get_type() { @@ -2761,7 +2954,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // to constrain unsigned integer inputs to be in range, we get their bit decomposition. // it will be cached self.get_bits_unchecked( - &FlatUExpression::with_field(FlatExpression::Identifier(variable)), + &FlatUExpression::with_field(FlatExpression::identifier(variable)), bitwidth.to_usize(), bitwidth.to_usize(), statements_flattened, @@ -2769,19 +2962,18 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); } Type::Boolean => { - statements_flattened.push_back(FlatStatement::Condition( + statements_flattened.push_back(FlatStatement::condition( variable.into(), - FlatExpression::Mult(box variable.into(), box variable.into()), + FlatExpression::mul(variable.into(), variable.into()), RuntimeError::ArgumentBitness, )); } Type::FieldElement => {} } - Parameter { - id: variable, - private: parameter.private, - } + statements_flattened.set_span(backup_span); + + Parameter::new(variable, parameter.private).span(span) } fn issue_new_variable(&mut self) -> Variable { @@ -2803,10 +2995,18 @@ mod tests { use zokrates_ast::zir::types::Signature; use zokrates_ast::zir::types::Type; use zokrates_ast::zir::Id; + use zokrates_ast::zir::ZirFunction; use zokrates_field::Bn128Field; fn flatten_function(f: ZirFunction) -> FlatProg { - from_function_and_config(f, CompileConfig::default()).collect() + from_program_and_config( + ZirProgram { + main: f, + module_map: Default::default(), + }, + CompileConfig::default(), + ) + .collect() } #[test] @@ -2826,22 +3026,22 @@ mod tests { let function = ZirFunction:: { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::boolean("x".into()), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::boolean("y".into()), - BooleanExpression::Value(true).into(), + BooleanExpression::value(true).into(), ), - ZirStatement::Assertion( - BooleanExpression::BoolEq( - box BooleanExpression::identifier("x".into()), - box BooleanExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::bool_eq( + BooleanExpression::identifier("x".into()), + BooleanExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -2851,22 +3051,23 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -2893,25 +3094,25 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::value(Bn128Field::from(1)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Add( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::add( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::value(Bn128Field::from(1)), ), - box FieldElementExpression::identifier("y".into()), + FieldElementExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -2922,25 +3123,26 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), - ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Add( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(2)), + ), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::add( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), - box FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -2969,24 +3171,24 @@ mod tests { let function = ZirFunction:: { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::uint("x".into(), 32), ZirExpression::Uint( - UExpressionInner::Value(42) + UExpression::value(42) .annotate(32) .metadata(metadata.clone()), ), ), - ZirStatement::Assertion( - BooleanExpression::UintEq( - box UExpression::identifier("x".into()) + ZirStatement::assertion( + BooleanExpression::uint_eq( + UExpression::identifier("x".into()) .annotate(32) .metadata(metadata.clone()), - box UExpressionInner::Value(42).annotate(32).metadata(metadata), + UExpression::value(42).annotate(32).metadata(metadata), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -2997,18 +3199,19 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(42)), + FlatExpression::value(Bn128Field::from(42)), ), - FlatStatement::Condition( - FlatExpression::Number(Bn128Field::from(42)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::value(Bn128Field::from(42)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -3035,22 +3238,22 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3061,22 +3264,23 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(1)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Number(Bn128Field::from(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::value(Bn128Field::from(1)), ), zir::RuntimeError::mock().into(), ), @@ -3105,29 +3309,29 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box FieldElementExpression::identifier("z".into()), + FieldElementExpression::identifier("z".into()), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3138,26 +3342,27 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(2)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3186,29 +3391,29 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), - ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::identifier("z".into()), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + FieldElementExpression::value(Bn128Field::from(4)).into(), + ), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::identifier("z".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), ), zir::RuntimeError::mock(), ), - ZirStatement::Return(vec![]), + ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], @@ -3219,26 +3424,27 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(2)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3270,31 +3476,31 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("x"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("y"), - FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::value(Bn128Field::from(4)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("z"), - FieldElementExpression::Number(Bn128Field::from(8)).into(), + FieldElementExpression::value(Bn128Field::from(8)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("t"), - FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::value(Bn128Field::from(2)).into(), ), - ZirStatement::Assertion( - BooleanExpression::FieldEq( - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("x".into()), - box FieldElementExpression::identifier("y".into()), + ZirStatement::assertion( + BooleanExpression::field_eq( + FieldElementExpression::mul( + FieldElementExpression::identifier("x".into()), + FieldElementExpression::identifier("y".into()), ), - box FieldElementExpression::Mult( - box FieldElementExpression::identifier("z".into()), - box FieldElementExpression::identifier("t".into()), + FieldElementExpression::mul( + FieldElementExpression::identifier("z".into()), + FieldElementExpression::identifier("t".into()), ), ), zir::RuntimeError::mock(), @@ -3309,37 +3515,38 @@ mod tests { let flat = flatten_function(function); let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 0, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(4)), + FlatExpression::value(Bn128Field::from(4)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Number(Bn128Field::from(8)), + FlatExpression::value(Bn128Field::from(8)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(3), - FlatExpression::Number(Bn128Field::from(2)), + FlatExpression::value(Bn128Field::from(2)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(4), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(2)), - box FlatExpression::Identifier(Variable::new(3)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::identifier(Variable::new(3)), ), ), - FlatStatement::Condition( - FlatExpression::Identifier(Variable::new(4)), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(1)), + FlatStatement::condition( + FlatExpression::identifier(Variable::new(4)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(1)), ), zir::RuntimeError::mock().into(), ), @@ -3365,19 +3572,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 0u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 0u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3386,20 +3593,21 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Number(Bn128Field::from(1)), + FlatExpression::value(Bn128Field::from(1)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ], }; @@ -3426,19 +3634,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 1u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 1u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3447,23 +3655,24 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(1)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(1)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ], }; @@ -3507,19 +3716,19 @@ mod tests { let function = ZirFunction { arguments: vec![], statements: vec![ - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Number(Bn128Field::from(7)).into(), + FieldElementExpression::value(Bn128Field::from(7)).into(), ), - ZirStatement::Definition( + ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Pow( - box FieldElementExpression::identifier("a".into()), - box 13u32.into(), + FieldElementExpression::pow( + FieldElementExpression::identifier("a".into()), + 13u32.into(), ) .into(), ), - ZirStatement::Return(vec![FieldElementExpression::identifier("b".into()).into()]), + ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]), ], signature: Signature { inputs: vec![], @@ -3528,58 +3737,59 @@ mod tests { }; let expected = FlatFunction { + module_map: Default::default(), arguments: vec![], return_count: 1, statements: vec![ - FlatStatement::Definition( + FlatStatement::definition( Variable::new(0), - FlatExpression::Number(Bn128Field::from(7)), + FlatExpression::value(Bn128Field::from(7)), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(1), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(0)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(0)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(2), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(1)), - box FlatExpression::Identifier(Variable::new(1)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(1)), + FlatExpression::identifier(Variable::new(1)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(3), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(2)), - box FlatExpression::Identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(2)), + FlatExpression::identifier(Variable::new(2)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(4), - FlatExpression::Mult( - box FlatExpression::Number(Bn128Field::from(1)), - box FlatExpression::Identifier(Variable::new(0)), + FlatExpression::mul( + FlatExpression::value(Bn128Field::from(1)), + FlatExpression::identifier(Variable::new(0)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(5), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(4)), - box FlatExpression::Identifier(Variable::new(2)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(4)), + FlatExpression::identifier(Variable::new(2)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::new(6), - FlatExpression::Mult( - box FlatExpression::Identifier(Variable::new(5)), - box FlatExpression::Identifier(Variable::new(3)), + FlatExpression::mul( + FlatExpression::identifier(Variable::new(5)), + FlatExpression::identifier(Variable::new(3)), ), ), - FlatStatement::Definition( + FlatStatement::definition( Variable::public(0), - FlatExpression::Identifier(Variable::new(6)), + FlatExpression::identifier(Variable::new(6)), ), ], }; @@ -3593,28 +3803,28 @@ mod tests { fn if_else() { let config = CompileConfig::default(); let expression = FieldElementExpression::conditional( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(32)), - box FieldElementExpression::Number(Bn128Field::from(4)), + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(32)), + FieldElementExpression::value(Bn128Field::from(4)), ), - FieldElementExpression::Number(Bn128Field::from(12)), - FieldElementExpression::Number(Bn128Field::from(51)), + FieldElementExpression::value(Bn128Field::from(12)), + FieldElementExpression::value(Bn128Field::from(51)), ); let mut flattener = Flattener::new(config); - flattener.flatten_field_expression(&mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::default(), expression); } #[test] fn geq_leq() { let config = CompileConfig::default(); let mut flattener = Flattener::new(config); - let expression_le = BooleanExpression::FieldLe( - box FieldElementExpression::Number(Bn128Field::from(32)), - box FieldElementExpression::Number(Bn128Field::from(4)), + let expression_le = BooleanExpression::field_le( + FieldElementExpression::value(Bn128Field::from(32)), + FieldElementExpression::value(Bn128Field::from(4)), ); - flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le); + flattener.flatten_boolean_expression(&mut FlatStatements::default(), expression_le); } #[test] @@ -3623,21 +3833,21 @@ mod tests { let mut flattener = Flattener::new(config); let expression = FieldElementExpression::conditional( - BooleanExpression::And( - box BooleanExpression::FieldEq( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(4)), + BooleanExpression::bitand( + BooleanExpression::field_eq( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(4)), ), - box BooleanExpression::FieldLt( - box FieldElementExpression::Number(Bn128Field::from(4)), - box FieldElementExpression::Number(Bn128Field::from(20)), + BooleanExpression::field_lt( + FieldElementExpression::value(Bn128Field::from(4)), + FieldElementExpression::value(Bn128Field::from(20)), ), ), - FieldElementExpression::Number(Bn128Field::from(12)), - FieldElementExpression::Number(Bn128Field::from(51)), + FieldElementExpression::value(Bn128Field::from(12)), + FieldElementExpression::value(Bn128Field::from(51)), ); - flattener.flatten_field_expression(&mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::default(), expression); } #[test] @@ -3645,21 +3855,21 @@ mod tests { // a = 5 / b / b let config = CompileConfig::default(); let mut flattener = Flattener::new(config); - let mut statements_flattened = FlatStatements::new(); + let mut statements_flattened = FlatStatements::default(); - let definition = ZirStatement::Definition( + let definition = ZirStatement::definition( zir::Variable::field_element("b"), - FieldElementExpression::Number(Bn128Field::from(42)).into(), + FieldElementExpression::value(Bn128Field::from(42)).into(), ); - let statement = ZirStatement::Definition( + let statement = ZirStatement::definition( zir::Variable::field_element("a"), - FieldElementExpression::Div( - box FieldElementExpression::Div( - box FieldElementExpression::Number(Bn128Field::from(5)), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::div( + FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(5)), + FieldElementExpression::identifier("b".into()), ), - box FieldElementExpression::identifier("b".into()), + FieldElementExpression::identifier("b".into()), ) .into(), ); @@ -3683,35 +3893,35 @@ mod tests { let sym_2 = Variable::new(6); assert_eq!( - statements_flattened, + statements_flattened.buffer, vec![ - FlatStatement::Definition(b, FlatExpression::Number(Bn128Field::from(42))), + FlatStatement::definition(b, FlatExpression::value(Bn128Field::from(42))), // inputs to first div (5/b) - FlatStatement::Definition(five, FlatExpression::Number(Bn128Field::from(5))), - FlatStatement::Definition(b0, b.into()), + FlatStatement::definition(five, FlatExpression::value(Bn128Field::from(5))), + FlatStatement::definition(b0, b.into()), // execute div FlatStatement::Directive(FlatDirective::new( vec![sym_0], Solver::Div, - vec![five, b0] + vec![five.into(), b0.into()] )), - FlatStatement::Condition( + FlatStatement::condition( five.into(), - FlatExpression::Mult(box b0.into(), box sym_0.into()), + FlatExpression::mul(b0.into(), sym_0.into()), RuntimeError::Division ), // inputs to second div (res/b) - FlatStatement::Definition(sym_1, sym_0.into()), - FlatStatement::Definition(b1, b.into()), + FlatStatement::definition(sym_1, sym_0.into()), + FlatStatement::definition(b1, b.into()), // execute div FlatStatement::Directive(FlatDirective::new( vec![sym_2], Solver::Div, - vec![sym_1, b1] + vec![sym_1.into(), b1.into()] )), - FlatStatement::Condition( + FlatStatement::condition( sym_1.into(), - FlatExpression::Mult(box b1.into(), box sym_2.into()), + FlatExpression::mul(b1.into(), sym_2.into()), RuntimeError::Division ), ] diff --git a/zokrates_codegen/src/utils.rs b/zokrates_codegen/src/utils.rs index 6fc257921..ab53c4464 100644 --- a/zokrates_codegen/src/utils.rs +++ b/zokrates_codegen/src/utils.rs @@ -1,3 +1,4 @@ +use std::ops::*; use zokrates_ast::flat::*; use zokrates_field::Field; @@ -6,16 +7,16 @@ pub fn flat_expression_from_bits(v: Vec>) -> FlatExp v: Vec<(T, FlatExpression)>, ) -> FlatExpression { match v.len() { - 0 => FlatExpression::Number(T::zero()), + 0 => FlatExpression::value(T::zero()), 1 => { let (coeff, var) = v[0].clone(); - FlatExpression::Mult(box FlatExpression::Number(coeff), box var) + FlatExpression::mul(FlatExpression::value(coeff), var) } n => { let (u, v) = v.split_at(n / 2); - FlatExpression::Add( - box flat_expression_from_bits_aux(u.to_vec()), - box flat_expression_from_bits_aux(v.to_vec()), + FlatExpression::add( + flat_expression_from_bits_aux(u.to_vec()), + flat_expression_from_bits_aux(v.to_vec()), ) } } diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 2b2530047..e443f1768 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.4" +version = "0.7.5" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 0c7583fa2..4a0505b7d 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -18,7 +18,7 @@ use zokrates_ast::ir::{self, from_flat::from_flat}; use zokrates_ast::typed::abi::Abi; use zokrates_ast::untyped::{Module, OwnedModuleId, Program}; use zokrates_ast::zir::ZirProgram; -use zokrates_codegen::from_function_and_config; +use zokrates_codegen::from_program_and_config; use zokrates_common::{CompileConfig, Resolver}; use zokrates_field::Field; use zokrates_pest_ast as pest; @@ -153,18 +153,12 @@ impl fmt::Display for CompileErrorInner { CompileErrorInner::ParserError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::MacroError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::SemanticError(ref e) => { - let location = e - .pos() - .map(|p| format!("{}", p.0)) - .unwrap_or_else(|| "".to_string()); + let location = e.pos().map(|p| format!("{}", p.from)).unwrap_or_default(); write!(f, "{}\n\t{}", location, e.message()) } CompileErrorInner::ReadError(ref e) => write!(f, "\n\t{}", e), CompileErrorInner::ImportError(ref e) => { - let location = e - .pos() - .map(|p| format!("{}", p.0)) - .unwrap_or_else(|| "".to_string()); + let location = e.span().map(|p| format!("{}", p.from)).unwrap_or_default(); write!(f, "{}\n\t{}", location, e.message()) } CompileErrorInner::AnalysisError(ref e) => write!(f, "\n\t{}", e), @@ -189,7 +183,7 @@ pub fn compile<'ast, T: Field, E: Into>( // flatten input program log::debug!("Flatten"); - let program_flattened = from_function_and_config(typed_ast.main, config); + let program_flattened = from_program_and_config(typed_ast, config); // convert to ir log::debug!("Convert to IR"); @@ -327,7 +321,7 @@ mod test { assert!(e.0[0] .value() .to_string() - .contains(&"Cannot resolve import without a resolver")); + .contains("Cannot resolve import without a resolver")); } #[test] @@ -448,15 +442,15 @@ struct Bar { field a; } vec![], vec![ConcreteStructMember { id: "b".into(), - ty: box ConcreteType::Struct(ConcreteStructType::new( + ty: Box::new(ConcreteType::Struct(ConcreteStructType::new( "bar".into(), "Bar".into(), vec![], - vec![ConcreteStructMember { - id: "a".into(), - ty: box ConcreteType::FieldElement - }] - )) + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] + ))) }] )) }], diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index f7fa95f13..3bc08c5b8 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -13,43 +13,43 @@ use std::path::{Path, PathBuf}; use zokrates_ast::untyped::*; use typed_arena::Arena; -use zokrates_ast::common::FlatEmbed; +use zokrates_ast::common::{FlatEmbed, SourceSpan}; use zokrates_ast::untyped::types::UnresolvedType; use zokrates_common::Resolver; use zokrates_field::Field; #[derive(PartialEq, Eq, Debug)] pub struct Error { - pos: Option<(Position, Position)>, + span: Option, message: String, } impl Error { pub fn new>(message: T) -> Error { Error { - pos: None, + span: None, message: message.into(), } } - pub fn pos(&self) -> &Option<(Position, Position)> { - &self.pos + pub fn span(&self) -> &Option { + &self.span } pub fn message(&self) -> &str { &self.message } - fn with_pos(self, pos: Option<(Position, Position)>) -> Error { - Error { pos, ..self } + fn with_span(self, span: Option) -> Error { + Error { span, ..self } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let location = self - .pos - .map(|p| format!("{}", p.0)) + .span + .map(|p| format!("{}", p.from)) .unwrap_or_else(|| "?".to_string()); write!(f, "{}\n\t{}", location, self.message) } @@ -58,7 +58,7 @@ impl fmt::Display for Error { impl From for Error { fn from(error: io::Error) -> Self { Error { - pos: None, + span: None, message: format!("I/O Error: {}", error), } } @@ -95,7 +95,7 @@ impl Importer { modules: &mut HashMap>, arena: &'ast Arena, ) -> Result, CompileErrors> { - let pos = import.pos(); + let span = import.span().in_module(location); let module_id = import.value.source; let symbol = import.value.id; @@ -111,7 +111,7 @@ impl Importer { Bn128Field::name(), T::name() )) - .with_pos(Some(pos)), + .with_span(Some(span)), ) .in_file(location) .into()); @@ -132,7 +132,7 @@ impl Importer { Bw6_761Field::name(), T::name() )) - .with_pos(Some(pos)), + .with_span(Some(span)), ) .in_file(location) .into()); @@ -195,12 +195,12 @@ impl Importer { expression: Expression::U32Constant(T::get_required_bits() as u32) .into(), } - .start_end(pos.0, pos.1), + .start_end(span.from, span.to), )), }, s => { return Err(CompileErrorInner::ImportError( - Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)), + Error::new(format!("Embed {} not found", s)).with_span(Some(span)), ) .in_file(location) .into()); @@ -245,16 +245,16 @@ impl Importer { id: alias, symbol: Symbol::There( SymbolImport::with_id_in_module(symbol.id, new_location) - .start_end(pos.0, pos.1), + .start_end(span.from, span.to), ), } } Err(err) => { - return Err( - CompileErrorInner::ImportError(err.into().with_pos(Some(pos))) - .in_file(location) - .into(), - ); + return Err(CompileErrorInner::ImportError( + err.into().with_span(Some(span)), + ) + .in_file(location) + .into()); } }, None => { @@ -267,6 +267,6 @@ impl Importer { }, }; - Ok(symbol_declaration.start_end(pos.0, pos.1)) + Ok(symbol_declaration.start_end(span.from, span.to)) } } diff --git a/zokrates_core/src/lib.rs b/zokrates_core/src/lib.rs index 51de21db8..874b996ac 100644 --- a/zokrates_core/src/lib.rs +++ b/zokrates_core/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(box_patterns, box_syntax)] - pub mod compile; pub mod imports; mod macros; diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index 4d140637a..dd8159013 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -28,24 +28,34 @@ impl<'ast, T: Field> Folder<'ast, T> for DirectiveOptimizer<'ast, T> { *self.substitution.get(&v).unwrap_or(&v) } - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - match s { - Statement::Directive(d) => { - let d = self.fold_directive(d); + fn fold_directive_statement( + &mut self, + d: DirectiveStatement<'ast, T>, + ) -> Vec> { + let d = DirectiveStatement { + inputs: d + .inputs + .into_iter() + .map(|e| self.fold_quadratic_combination(e)) + .collect(), + outputs: d + .outputs + .into_iter() + .map(|o| self.fold_variable(o)) + .collect(), + ..d + }; - match self.calls.entry((d.solver.clone(), d.inputs.clone())) { - Entry::Vacant(e) => { - e.insert(d.outputs.clone()); - vec![Statement::Directive(d)] - } - Entry::Occupied(e) => { - self.substitution - .extend(d.outputs.into_iter().zip(e.get().iter().cloned())); - vec![] - } - } + match self.calls.entry((d.solver.clone(), d.inputs.clone())) { + Entry::Vacant(e) => { + e.insert(d.outputs.clone()); + vec![Statement::Directive(d)] + } + Entry::Occupied(e) => { + self.substitution + .extend(d.outputs.into_iter().zip(e.get().iter().cloned())); + vec![] } - s => fold_statement(self, s), } } } diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 664cfc2db..9f12e1e04 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -39,14 +39,22 @@ impl<'ast, T: Field> Folder<'ast, T> for DuplicateOptimizer { } fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - let hashed = hash(&s); - let result = match self.seen.get(&hashed) { - Some(_) => vec![], - None => vec![s], - }; - - self.seen.insert(hashed); - result + match s { + Statement::Block(s) => s + .inner + .into_iter() + .flat_map(|s| self.fold_statement(s)) + .collect(), + s => { + let hashed = hash(&s); + let result = match self.seen.get(&hashed) { + Some(_) => vec![], + None => vec![s], + }; + self.seen.insert(hashed); + result + } + } } } @@ -59,24 +67,28 @@ mod tests { #[test] fn identity() { let p: Prog = Prog { + module_map: Default::default(), statements: vec![ Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(3)), LinComb::summand(3, Variable::new(3)), ), LinComb::one(), + None, ), Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = p.clone(); @@ -90,44 +102,51 @@ mod tests { #[test] fn remove_duplicates() { let constraint = Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(3)), LinComb::summand(3, Variable::new(3)), ), LinComb::one(), + None, ); let p: Prog = Prog { + module_map: Default::default(), statements: vec![ constraint.clone(), constraint.clone(), Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), constraint.clone(), constraint.clone(), ], return_count: 0, arguments: vec![], + solvers: vec![], }; let expected = Prog { + module_map: Default::default(), statements: vec![ constraint, Statement::constraint( - QuadComb::from_linear_combinations( + QuadComb::new( LinComb::summand(3, Variable::new(42)), LinComb::summand(3, Variable::new(3)), ), LinComb::zero(), + None, ), ], return_count: 0, arguments: vec![], + solvers: vec![], }; assert_eq!( diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index 1f94740a9..08a6c1b7d 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -54,6 +54,8 @@ pub fn optimize<'ast, T: Field, I: IntoIterator>>( .flat_map(move |s| directive_optimizer.fold_statement(s)) .flat_map(move |s| duplicate_optimizer.fold_statement(s)), return_count: p.return_count, + module_map: p.module_map, + solvers: p.solvers, }; log::debug!("Done"); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index b0877fd02..55880465a 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -37,8 +37,9 @@ // - otherwise return `c_0` use std::collections::{HashMap, HashSet}; +use zokrates_ast::common::WithSpan; use zokrates_ast::flat::Variable; -use zokrates_ast::ir::folder::{fold_statement, Folder}; +use zokrates_ast::ir::folder::{fold_statement_cases, Folder}; use zokrates_ast::ir::LinComb; use zokrates_ast::ir::*; use zokrates_field::Field; @@ -62,29 +63,28 @@ impl RedefinitionOptimizer { .into_iter() .chain(p.arguments.iter().map(|p| p.id)) .chain(p.returns()) - .into_iter() .collect(), } } - fn fold_statement<'ast>( + fn fold_statement_cases<'ast>( &mut self, s: Statement<'ast, T>, aggressive: bool, ) -> Vec> { match s { - Statement::Constraint(quad, lin, message) => { - let quad = self.fold_quadratic_combination(quad); - let lin = self.fold_linear_combination(lin); + Statement::Constraint(s) => { + let quad = self.fold_quadratic_combination(s.quad); + let lin = self.fold_linear_combination(s.lin); if lin.is_zero() { - return vec![Statement::Constraint(quad, lin, message)]; + return vec![Statement::constraint(quad, lin, s.error)]; } - let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0) - || self.substitution.contains_key(&lin.0[0].0) + let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.value[0].0) + || self.substitution.contains_key(&lin.value[0].0) { - true => (Some(Statement::Constraint(quad, lin, message)), None, None), + true => (Some(Statement::constraint(quad, lin, s.error)), None, None), false => match lin.try_summand() { // if the right side is a single variable Ok((variable, coefficient)) => match quad.try_linear() { @@ -92,16 +92,16 @@ impl RedefinitionOptimizer { Ok(l) => (None, Some((variable, l / &coefficient)), None), // if the left side isn't linear Err(quad) => ( - Some(Statement::Constraint( + Some(Statement::constraint( quad, LinComb::summand(coefficient, variable), - message, + s.error, )), None, Some(variable), ), }, - Err(l) => (Some(Statement::Constraint(quad, l, message)), None, None), + Err(l) => (Some(Statement::constraint(quad, l, s.error)), None, None), }, }; @@ -122,7 +122,19 @@ impl RedefinitionOptimizer { } } Statement::Directive(d) => { - let d = self.fold_directive(d); + let d = DirectiveStatement { + inputs: d + .inputs + .into_iter() + .map(|e| self.fold_quadratic_combination(e)) + .collect(), + outputs: d + .outputs + .into_iter() + .map(|o| self.fold_variable(o)) + .collect(), + ..d + }; // check if the inputs are constants, ie reduce to the form `coeff * ~one` let inputs: Vec<_> = d @@ -146,7 +158,7 @@ impl RedefinitionOptimizer { // unwrap inputs to their constant value let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = Interpreter::execute_solver(&d.solver, &inputs).unwrap(); + let outputs = Interpreter::execute_solver(&d.solver, &inputs, &[]).unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution @@ -171,24 +183,30 @@ impl RedefinitionOptimizer { self.ignore.insert(o); } } - vec![Statement::Directive(Directive { inputs, ..d })] + vec![Statement::directive(d.outputs, d.solver, inputs)] } } } - s => fold_statement(self, s), + s => fold_statement_cases(self, s), } } } impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { + fn fold_statement_cases(&mut self, s: Statement<'ast, T>) -> Vec> { match s { Statement::Block(statements) => { #[allow(clippy::needless_collect)] // optimize aggressively and clean up in a second pass (we need to collect here) let statements: Vec<_> = statements + .inner .into_iter() - .flat_map(|s| self.fold_statement(s, true)) + .flat_map(|s| { + let span = s.get_span(); + self.fold_statement_cases(s, true) + .into_iter() + .map(move |s| s.span(span)) + }) .collect(); // clean up @@ -203,22 +221,23 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer { }) .collect(); - vec![Statement::Block(statements)] + vec![Statement::block(statements)] } - s => self.fold_statement(s, false), + s => self.fold_statement_cases(s, false), } } fn fold_linear_combination(&mut self, lc: LinComb) -> LinComb { match lc - .0 + .value .iter() .any(|(variable, _)| self.substitution.get(variable).is_some()) { true => // for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is { - lc.0.into_iter() + lc.value + .into_iter() .map(|(variable, coefficient)| { self.substitution .get(&variable) @@ -250,18 +269,22 @@ mod tests { let out = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), Statement::definition(out, y), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![Statement::definition(out, x.id)], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -277,9 +300,11 @@ mod tests { let x = Parameter::public(Variable::new(0)); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![Statement::definition(one, x.id)], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); @@ -308,23 +333,27 @@ mod tests { let out = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), Statement::definition(z, y), - Statement::constraint(z, y), + Statement::constraint(z, y, None), Statement::definition(out, z), ], return_count: 1, + solvers: vec![], }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ - Statement::constraint(x.id, x.id), + Statement::constraint(x.id, x.id, None), Statement::definition(out, x.id), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -355,6 +384,7 @@ mod tests { let out_0 = Variable::public(1); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(y, x.id), @@ -365,15 +395,18 @@ mod tests { Statement::definition(out_1, w), ], return_count: 2, + solvers: vec![], }; let optimized: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ Statement::definition(out_0, x.id), Statement::definition(out_1, Bn128Field::from(1)), ], return_count: 2, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -404,6 +437,7 @@ mod tests { let r = Variable::public(0); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ Statement::definition(a, LinComb::from(x.id) + LinComb::from(y.id)), @@ -418,18 +452,22 @@ mod tests { Statement::constraint( LinComb::summand(2, c), LinComb::summand(6, x.id) + LinComb::summand(6, y.id), + None, ), Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], return_count: 1, + solvers: vec![], }; let expected: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ Statement::constraint( LinComb::summand(6, x.id) + LinComb::summand(6, y.id), LinComb::summand(6, x.id) + LinComb::summand(6, y.id), + None, ), Statement::definition( r, @@ -442,6 +480,7 @@ mod tests { ), ], return_count: 1, + solvers: vec![], }; let mut optimizer = RedefinitionOptimizer::init(&p); @@ -470,15 +509,14 @@ mod tests { let z = Variable::new(2); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x, y], statements: vec![ - Statement::definition( - z, - QuadComb::from_linear_combinations(LinComb::from(x.id), LinComb::from(y.id)), - ), + Statement::definition(z, QuadComb::new(LinComb::from(x.id), LinComb::from(y.id))), Statement::definition(z, LinComb::from(x.id)), ], return_count: 0, + solvers: vec![], }; let optimized = p.clone(); @@ -501,12 +539,14 @@ mod tests { let x = Parameter::public(Variable::new(0)); let p: Prog = Prog { + module_map: Default::default(), arguments: vec![x], statements: vec![ - Statement::constraint(x.id, Bn128Field::from(1)), - Statement::constraint(x.id, Bn128Field::from(2)), + Statement::constraint(x.id, Bn128Field::from(1), None), + Statement::constraint(x.id, Bn128Field::from(2), None), ], return_count: 1, + solvers: vec![], }; let optimized = p.clone(); diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index 855efa11d..d33ddeeda 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -5,7 +5,6 @@ // // This makes the assumption that ~one has value 1, as should be guaranteed by the verifier -use zokrates_ast::ir::folder::fold_statement; use zokrates_ast::ir::folder::Folder; use zokrates_ast::ir::*; use zokrates_field::Field; @@ -14,19 +13,16 @@ use zokrates_field::Field; pub struct TautologyOptimizer; impl<'ast, T: Field> Folder<'ast, T> for TautologyOptimizer { - fn fold_statement(&mut self, s: Statement<'ast, T>) -> Vec> { - match s { - Statement::Constraint(quad, lin, message) => match quad.try_linear() { - Ok(l) => { - if l == lin { - vec![] - } else { - vec![Statement::Constraint(l.into(), lin, message)] - } + fn fold_constraint_statement(&mut self, s: ConstraintStatement) -> Vec> { + match s.quad.try_linear() { + Ok(l) => { + if l == s.lin { + vec![] + } else { + vec![Statement::constraint(l, s.lin, s.error)] } - Err(quad) => vec![Statement::Constraint(quad, lin, message)], - }, - _ => fold_statement(self, s), + } + Err(quad) => vec![Statement::constraint(quad, s.lin, s.error)], } } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 4e006eed3..5be636cbb 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -8,7 +8,8 @@ use num_bigint::BigUint; use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; -use zokrates_ast::common::{FormatString, SourceMetadata}; +use zokrates_ast::common::expressions::ValueExpression; +use zokrates_ast::common::{FormatString, ModuleMap, SourceMetadata, SourceSpan, WithSpan}; use zokrates_ast::typed::types::{GGenericsAssignment, GTupleType, GenericsAssignment}; use zokrates_ast::typed::SourceIdentifier; use zokrates_ast::typed::*; @@ -17,6 +18,8 @@ use zokrates_ast::untyped::Identifier; use zokrates_ast::untyped::*; use zokrates_field::Field; +use std::ops::*; + use zokrates_ast::untyped::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; use std::hash::Hash; @@ -29,7 +32,7 @@ use zokrates_ast::typed::types::{ #[derive(PartialEq, Eq, Debug)] pub struct ErrorInner { - pos: Option<(Position, Position)>, + span: Option, message: String, } @@ -40,8 +43,8 @@ pub struct Error { } impl ErrorInner { - pub fn pos(&self) -> &Option<(Position, Position)> { - &self.pos + pub fn pos(&self) -> &Option { + &self.span } pub fn message(&self) -> &str { @@ -369,6 +372,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(TypedProgram { main: program.main, + module_map: ModuleMap::new(state.typed_modules.keys().cloned()), modules: state.typed_modules, }) } @@ -379,7 +383,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, Vec> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; let mut errors = vec![]; @@ -395,7 +399,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -410,7 +414,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -432,7 +436,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for declared_generic in generics_map.keys() { if !used_generics.contains(declared_generic) { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Generic parameter {} must be used", declared_generic), }); } @@ -458,7 +462,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, ErrorInner> { - let pos = c.pos(); + let span = c.span().in_module(module_id); let ty = self.check_declaration_type( c.value.ty.clone(), @@ -492,7 +496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { DeclarationType::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to constant `{}` of type `{}`", e, @@ -511,7 +515,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &State<'ast, T>, ) -> Result, Vec> { - let pos = s.pos(); + let span = s.span().in_module(module_id); let s = s.value; let mut errors = vec![]; @@ -529,7 +533,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -544,7 +548,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -569,7 +573,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(f) => match fields_set.insert(f.0.clone()) { true => fields.push(f), false => errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Duplicate key {} in struct definition", f.0,), }), }, @@ -583,7 +587,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for declared_generic in generics_map.keys() { if !used_generics.contains(declared_generic) { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Generic parameter {} must be used", declared_generic), }); } @@ -614,7 +618,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result<(), Vec> { let mut errors: Vec = vec![]; - let pos = declaration.pos(); + let span = declaration.span().in_module(module_id); let declaration = declaration.value; match declaration.symbol.clone() { @@ -629,7 +633,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_type(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -671,7 +675,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_constant(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -722,7 +726,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_type(declaration.id) { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -753,7 +757,7 @@ impl<'ast, T: Field> Checker<'ast, T> { { false => errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -789,7 +793,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } Symbol::There(import) => { - let pos = import.pos(); + let span = import.span().in_module(module_id); let import = import.value; match Checker::default().check_module(&import.module_id, state) { @@ -800,7 +804,6 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&import.module_id) .unwrap() .functions_iter() - .into_iter() .filter(|d| d.key.id == import.symbol_id) .map(|d| DeclarationFunctionKey { module: import.module_id.to_path_buf(), @@ -850,7 +853,7 @@ impl<'ast, T: Field> Checker<'ast, T> { errors.push(Error { module_id: module_id.to_path_buf(), inner: ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -871,7 +874,7 @@ impl<'ast, T: Field> Checker<'ast, T> { errors.push(Error { module_id: module_id.to_path_buf(), inner: ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -909,7 +912,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } (0, None, None) => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Could not find symbol {} in module {}", import.symbol_id, import.module_id.display(), @@ -923,7 +926,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match symbol_unifier.insert_function(declaration.id, candidate.signature.clone()) { false => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id, @@ -957,7 +960,7 @@ impl<'ast, T: Field> Checker<'ast, T> { false => { errors.push( ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} conflicts with another symbol", declaration.id @@ -1048,17 +1051,16 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module .functions_iter() - .into_iter() .filter(|d| d.key.id == "main") .count() { 1 => Ok(()), 0 => Err(ErrorInner { - pos: None, + span: None, message: "No main function found".into(), }), n => Err(ErrorInner { - pos: None, + span: None, message: format!("Only one main function allowed, found {}", n), }), } @@ -1070,14 +1072,14 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { - let pos = var.pos(); + let span = var.span().in_module(module_id); let var = self.check_variable(var, module_id, types)?; match var.get_type() { Type::Uint(UBitwidth::B32) => Ok(()), t => Err(vec![ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Variable in for loop cannot have type {}", t), }]), }?; @@ -1106,7 +1108,7 @@ impl<'ast, T: Field> Checker<'ast, T> { self.enter_scope(); - let pos = funct_node.pos(); + let span = funct_node.span().in_module(module_id); let mut errors = vec![]; let funct = funct_node.value; @@ -1142,14 +1144,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { - let pos = arg.pos(); + let span = arg.span().in_module(module_id); let arg = arg.value; // parameters defined on a non-entrypoint function should not have visibility modifiers if (state.main_id != module_id || id != "main") && arg.is_private.is_some() { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Visibility modifiers on arguments are only allowed on the entrypoint function" .into(), @@ -1159,12 +1161,11 @@ impl<'ast, T: Field> Checker<'ast, T> { let decl_v = DeclarationVariable::new( self.id_in_this_scope(arg.id.value.id), decl_ty.clone(), - arg.id.value.is_mutable, ); let is_mutable = arg.id.value.is_mutable; - let ty = specialize_declaration_type(decl_v.clone()._type, &generics).unwrap(); + let ty = specialize_declaration_type(decl_v.clone().ty, &generics).unwrap(); assert_eq!(self.scope.level, 1); @@ -1178,27 +1179,27 @@ impl<'ast, T: Field> Checker<'ast, T> { false => {} true => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Duplicate name in function definition: `{}` was previously declared as an argument, a generic parameter or a constant", arg.id.value.id) }); } }; - arguments_checked.push(DeclarationParameter { - id: decl_v, - private: arg.is_private.unwrap_or(false), - }); + arguments_checked.push( + DeclarationParameter::new(decl_v, arg.is_private.unwrap_or(false)) + .with_span(span), + ); } let mut found_return = false; for stat in funct.statements.into_iter() { - let pos = Some(stat.pos()); + let span = Some(stat.span().in_module(module_id)); if let Statement::Return(..) = stat.value { if found_return { errors.push(ErrorInner { - pos, + span, message: "Expected a single return statement".to_string(), }); } @@ -1217,12 +1218,13 @@ impl<'ast, T: Field> Checker<'ast, T> { } if !found_return { - match (&*s.output).is_empty_tuple() { - true => statements_checked - .push(TypedStatement::Return(TypedExpression::empty_tuple())), + match (*s.output).is_empty_tuple() { + true => statements_checked.push( + TypedStatement::ret(TypedExpression::empty_tuple()).with_span(span), + ), false => { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected a return statement".to_string(), }); } @@ -1271,7 +1273,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .is_some() { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!( "Generic parameter {p} conflicts with constant symbol {p}", p = g.value @@ -1286,7 +1288,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } false => { errors.push(ErrorInner { - pos: Some(g.pos()), + span: Some(g.span().in_module(module_id)), message: format!("Generic parameter {} is already declared", g.value), }); } @@ -1348,7 +1350,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; match ty { @@ -1364,7 +1366,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(e) => match e.bitwidth() { UBitwidth::B32 => Ok(e), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", e, ty @@ -1374,7 +1376,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Int(v) => { UExpression::try_from_int(v.clone(), &UBitwidth::B32).map_err(|_| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", v, ty @@ -1383,7 +1385,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant, found {} of type {}", size, ty @@ -1411,7 +1413,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&id) .cloned() .ok_or_else(|| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type {}", id), })?; @@ -1438,13 +1440,13 @@ impl<'ast, T: Field> Checker<'ast, T> { UExpression::try_from_typed(e, &UBitwidth::B32) .map(|e| (GenericIdentifier::with_name(g).with_index(i), e)) .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected u32 expression, but got expression of type {}", e.get_type()), }) }) }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected u32 constant or identifier, but found `_`. Generic inference is not supported yet." .into(), @@ -1456,7 +1458,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(specialize_declaration_type(declaration_type, &assignment).unwrap()) } false => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} generic argument{} on type {}, but got {}", generic_identifiers.len(), @@ -1481,7 +1483,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, ) -> Result, ErrorInner> { - let pos = expr.pos(); + let span = expr.span().in_module(module_id); match expr.value { Expression::U32Constant(c) => Ok(DeclarationConstant::from(c)), @@ -1492,7 +1494,7 @@ impl<'ast, T: Field> Checker<'ast, T> { )) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {}", Expression::IntConstant(c) @@ -1508,7 +1510,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match ty { DeclarationType::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {} of type {}", name, ty @@ -1518,13 +1520,13 @@ impl<'ast, T: Field> Checker<'ast, T> { } (None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier::with_name(name).with_index(*index))), _ => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undeclared symbol `{}`", name) }) } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array dimension to be a u32 constant or an identifier, found {}", e @@ -1541,7 +1543,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, ) -> Result, ErrorInner> { - let pos = ty.pos(); + let span = ty.span().in_module(module_id); let ty = ty.value; match ty { @@ -1587,7 +1589,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .get(&id) .cloned() .ok_or_else(|| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type {}", id), })?; @@ -1605,7 +1607,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .map(Some), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Expected u32 constant or identifier, but found `_`".into(), }), }) @@ -1639,7 +1641,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(res) } false => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} generic argument{} on type {}, but got {}", ty.generics.len(), @@ -1660,17 +1662,13 @@ impl<'ast, T: Field> Checker<'ast, T> { types: &TypeMap<'ast, T>, ) -> Result, Vec> { let ty = self - .check_type(v.value._type, module_id, types) + .check_type(v.value.ty, module_id, types) .map_err(|e| vec![e])?; // insert into the scope and ignore whether shadowing happened self.insert_into_scope(v.value.id, ty.clone(), v.value.is_mutable); - Ok(Variable::new( - self.id_in_this_scope(v.value.id), - ty, - v.value.is_mutable, - )) + Ok(Variable::new(self.id_in_this_scope(v.value.id), ty)) } fn check_for_loop( @@ -1678,7 +1676,7 @@ impl<'ast, T: Field> Checker<'ast, T> { var: zokrates_ast::untyped::VariableNode<'ast>, range: (ExpressionNode<'ast>, ExpressionNode<'ast>), statements: Vec>, - pos: (Position, Position), + span: SourceSpan, module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { @@ -1693,7 +1691,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(from) => match from.bitwidth() { UBitwidth::B32 => Ok(from), bitwidth => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", Type::::Uint(bitwidth) @@ -1702,7 +1700,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }, TypedExpression::Int(v) => { UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", Type::::Int @@ -1710,7 +1708,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } from => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected lower loop bound to be of type u32, found {}", from.get_type() @@ -1723,7 +1721,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Uint(to) => match to.bitwidth() { UBitwidth::B32 => Ok(to), bitwidth => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", Type::::Uint(bitwidth) @@ -1732,7 +1730,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }, TypedExpression::Int(v) => { UExpression::try_from_int(v, &UBitwidth::B32).map_err(|_| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", Type::::Int @@ -1740,7 +1738,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }) } to => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected upper loop bound to be of type u32, found {}", to.get_type() @@ -1756,7 +1754,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|s| self.check_statement(s, module_id, types)) .collect::, _>>()?; - Ok(TypedStatement::For(var, from, to, checked_statements)) + Ok(TypedStatement::for_(var, from, to, checked_statements).with_span(span)) } // the assignee is already checked to be defined and mutable @@ -1769,9 +1767,9 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result, ErrorInner> { match expr.value { // for function calls, check the rhs with the expected type - Expression::FunctionCall(box fun_id_expression, generics, arguments) => self + Expression::FunctionCall(fun_id_expression, generics, arguments) => self .check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, Some(return_type), @@ -1789,7 +1787,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result>, ErrorInner> { - let pos = stat.pos(); + let span = stat.span().in_module(module_id); match stat.value { AssemblyStatement::Assignment(assignee, expression, constrained) => { @@ -1797,7 +1795,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = self.check_expression(expression, module_id, types)?; let e = FieldElementExpression::try_from_typed(e).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected right hand side of an assembly assignment to be of type field, found {}", e.get_type(), @@ -1809,25 +1807,29 @@ impl<'ast, T: Field> Checker<'ast, T> { let e = FieldElementExpression::block(vec![], e); match assignee.get_type() { Type::FieldElement => Ok(vec![ - TypedAssemblyStatement::Assignment( + TypedAssemblyStatement::assignment( assignee.clone(), e.clone().into(), - ), - TypedAssemblyStatement::Constraint( + ) + .with_span(span), + TypedAssemblyStatement::constraint( assignee.into(), e, - SourceMetadata::new(module_id.display().to_string(), pos.0), - ), + SourceMetadata::new(module_id.display().to_string(), span.from), + ) + .with_span(span), ]), ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Assignee must be of type field, found {}", ty), }), } } false => { let e = FieldElementExpression::block(vec![], e); - Ok(vec![TypedAssemblyStatement::Assignment(assignee, e.into())]) + Ok(vec![ + TypedAssemblyStatement::assignment(assignee, e.into()).with_span(span) + ]) } } } @@ -1836,7 +1838,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let rhs = self.check_expression(rhs, module_id, types)?; let lhs = FieldElementExpression::try_from_typed(lhs).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected left hand side of a constraint to be of type field, found {}", e.get_type(), @@ -1844,18 +1846,19 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; let rhs = FieldElementExpression::try_from_typed(rhs).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected right hand side of a constraint to be of type field, found {}", e.get_type(), ), })?; - Ok(vec![TypedAssemblyStatement::Constraint( + Ok(vec![TypedAssemblyStatement::constraint( lhs, rhs, - SourceMetadata::new(module_id.display().to_string(), pos.0), - )]) + SourceMetadata::new(module_id.display().to_string(), span.from), + ) + .with_span(span)]) } } } @@ -1866,7 +1869,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, Vec> { - let pos = stat.pos(); + let span = stat.span().in_module(module_id); match stat.value { Statement::Assembly(statements) => { @@ -1877,7 +1880,9 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?, ); } - Ok(TypedStatement::Assembly(checked_statements)) + Ok(TypedStatement::Assembly( + AssemblyBlockStatement::new(checked_statements).with_span(span), + )) } Statement::Log(l, expressions) => { let l = FormatString::from(l); @@ -1893,7 +1898,7 @@ impl<'ast, T: Field> Checker<'ast, T> { for e in &expressions { if let TypedExpression::Int(e) = e { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot determine type for expression `{}`", e), }); } @@ -1901,7 +1906,7 @@ impl<'ast, T: Field> Checker<'ast, T> { if expressions.len() != l.len() { errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Wrong argument count in log call: expected {}, got {}", l.len(), @@ -1914,7 +1919,7 @@ impl<'ast, T: Field> Checker<'ast, T> { return Err(errors); } - Ok(TypedStatement::Log(l, expressions)) + Ok(TypedStatement::log(l, expressions).with_span(span)) } Statement::Return(e) => { let mut errors = vec![]; @@ -1925,11 +1930,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let e_checked = e .map(|e| { match e.value { - Expression::FunctionCall( - box fun_id_expression, - generics, - arguments, - ) => { + Expression::FunctionCall(fun_id_expression, generics, arguments) => { let ty = zokrates_ast::typed::types::try_from_g_type( return_type.clone(), ) @@ -1937,7 +1938,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap(); self.check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, ty, @@ -1954,7 +1955,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let res = match TypedExpression::align_to_type(e_checked.clone(), &return_type) .map_err(|e| { vec![ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected return value to be of type `{}`, found `{}` of type `{}`", e.1, @@ -1967,7 +1968,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match e.get_type() == return_type { true => {} false => errors.push(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected `{}` in return statement, found `{}`", return_type, @@ -1975,11 +1976,11 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }), } - TypedStatement::Return(e) + TypedStatement::ret(e).with_span(span) } Err(err) => { errors.extend(err); - TypedStatement::Return(e_checked) + TypedStatement::ret(e_checked).with_span(span) } }; @@ -1992,7 +1993,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::Definition(var, expr) => { // get the lhs type let var_ty = self - .check_type(var.value._type, module_id, types) + .check_type(var.value.ty, module_id, types) .map_err(|e| vec![e])?; // check the rhs based on the lhs type @@ -2003,11 +2004,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // insert the lhs into the scope and ignore whether shadowing happened self.insert_into_scope(var.value.id, var_ty.clone(), var.value.is_mutable); - let var = Variable::new( - self.id_in_this_scope(var.value.id), - var_ty.clone(), - var.value.is_mutable, - ); + let var = Variable::new(self.id_in_this_scope(var.value.id), var_ty.clone()); match var_ty { Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr) @@ -2032,7 +2029,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, @@ -2041,7 +2038,7 @@ impl<'ast, T: Field> Checker<'ast, T> { var_ty ), }) - .map(|e| TypedStatement::Definition(var.into(), e.into())) + .map(|e| TypedStatement::definition(var.into(), e).with_span(span)) .map_err(|e| vec![e]) } Statement::Assignment(assignee, expr) => { @@ -2079,7 +2076,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => Err(checked_expr), // Integers cannot be assigned } .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expression `{}` of type `{}` cannot be assigned to `{}` of type `{}`", e, @@ -2088,7 +2085,7 @@ impl<'ast, T: Field> Checker<'ast, T> { assignee_ty ), }) - .map(|e| TypedStatement::Definition(assignee, e.into())) + .map(|e| TypedStatement::definition(assignee, e).with_span(span)) .map_err(|e| vec![e]) } Statement::Assertion(e, message) => { @@ -2097,15 +2094,16 @@ impl<'ast, T: Field> Checker<'ast, T> { .map_err(|e| vec![e])?; match e { - TypedExpression::Boolean(e) => Ok(TypedStatement::Assertion( + TypedExpression::Boolean(e) => Ok(TypedStatement::assertion( e, RuntimeError::SourceAssertion( - SourceMetadata::new(module_id.display().to_string(), pos.0) + SourceMetadata::new(module_id.display().to_string(), span.from) .message(message), ), - )), + ) + .with_span(span)), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} to be of type bool, found {}", e, @@ -2118,7 +2116,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Statement::For(var, from, to, statements) => { self.enter_scope(); - let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types); + let res = self.check_for_loop(var, (from, to), statements, span, module_id, types); self.exit_scope(); @@ -2133,33 +2131,32 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = assignee.pos(); + let span = assignee.span().in_module(module_id); // check that the assignee is declared match assignee.value { Assignee::Identifier(variable_name) => match self.scope.get(variable_name) { Some(info) => match info.is_mutable { false => Err(ErrorInner { - pos: Some(assignee.pos()), + span: Some(assignee.span().in_module(module_id)), message: format!("Assignment to an immutable variable `{}`", variable_name), }), _ => Ok(TypedAssignee::Identifier(Variable::new( info.id, info.ty.clone(), - info.is_mutable, ))), }, None => Err(ErrorInner { - pos: Some(assignee.pos()), + span: Some(assignee.span().in_module(module_id)), message: format!("Variable `{}` is undeclared", variable_name), }), }, - Assignee::Select(box assignee, box index) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Select(assignee, index) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match ty { Type::Array(..) => { - let checked_index = match index { + let checked_index = match *index { RangeOrExpression::Expression(e) => { self.check_expression(e, module_id, types)? } @@ -2172,7 +2169,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let checked_typed_index = UExpression::try_from_typed(checked_index, &UBitwidth::B32).map_err( |e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array {} index to have type u32, found {}", checked_assignee, @@ -2181,13 +2178,10 @@ impl<'ast, T: Field> Checker<'ast, T> { }, )?; - Ok(TypedAssignee::Select( - box checked_assignee, - box checked_typed_index, - )) + Ok(TypedAssignee::select(checked_assignee, checked_typed_index)) } ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element at index {} on {} of type {}", index, checked_assignee, ty, @@ -2195,15 +2189,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Assignee::Member(box assignee, box member) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Member(assignee, member) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match &ty { - Type::Struct(members) => match members.iter().find(|m| m.id == member) { - Some(_) => Ok(TypedAssignee::Member(box checked_assignee, member.into())), + Type::Struct(members) => match members.iter().find(|m| m.id == *member) { + Some(_) => Ok(TypedAssignee::member(checked_assignee, (*member).into())), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} {{{}}} doesn't have member {}", ty, @@ -2217,7 +2211,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access field {} on {} of type {}", @@ -2226,15 +2220,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Assignee::Element(box assignee, index) => { - let checked_assignee = self.check_assignee(assignee, module_id, types)?; + Assignee::Element(assignee, index) => { + let checked_assignee = self.check_assignee(*assignee, module_id, types)?; let ty = checked_assignee.get_type(); match &ty { Type::Tuple(tuple_ty) => match tuple_ty.elements.get(index as usize) { - Some(_) => Ok(TypedAssignee::Element(box checked_assignee, index)), + Some(_) => Ok(TypedAssignee::element(checked_assignee, index)), None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Tuple of size {} cannot be accessed at index {}", tuple_ty.elements.len(), @@ -2243,7 +2237,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), }, ty => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element {} on {} of type {}", @@ -2263,7 +2257,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { - let pos = s.pos(); + let span = s.span().in_module(module_id); let checked_expression = self.check_expression(s.value.expression, module_id, types)?; @@ -2271,7 +2265,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match checked_expression { TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected spread operator to apply on array, found {}", e.get_type() @@ -2294,11 +2288,11 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = function_id.pos(); + let span = function_id.span().in_module(module_id); let fun_id = match function_id.value { Expression::Identifier(id) => Ok(id), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected function in function call to be an identifier, found `{}`", e @@ -2313,11 +2307,11 @@ impl<'ast, T: Field> Checker<'ast, T> { .into_iter() .map(|g| { g.map(|g| { - let pos = g.pos(); + let span = g.span().in_module(module_id); self.check_expression(g, module_id, types).and_then(|g| { UExpression::try_from_typed(g, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected {} to be of type u32, found {}", e, @@ -2361,7 +2355,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let signature = f.signature; let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.iter()).map(|(a, t)| TypedExpression::align_to_type(a, t)).collect::, _>>().map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected function call argument to be of type `{}`, found `{}` of type `{}`", e.1, e.0, e.0.get_type()) })?; @@ -2371,7 +2365,7 @@ impl<'ast, T: Field> Checker<'ast, T> { generics_checked.clone(), arguments_checked.iter().map(|a| a.get_type()).collect() ).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Failed to infer value for generic parameter `{}`, try providing an explicit value", e, @@ -2384,7 +2378,7 @@ impl<'ast, T: Field> Checker<'ast, T> { signature: signature.clone(), }; - match output_type { + let res: Result, _> = match output_type { Type::Int => unreachable!(), Type::FieldElement => Ok(FieldElementExpression::function_call( function_key, @@ -2410,23 +2404,25 @@ impl<'ast, T: Field> Checker<'ast, T> { function_key, generics_checked, arguments_checked, - ).annotate(*array_ty.ty, *array_ty.size).into()), + ).annotate(array_ty).into()), Type::Tuple(tuple_ty) => Ok(TupleExpression::function_call( function_key, generics_checked, arguments_checked, ).annotate(tuple_ty).into()), - } + }; + + res.map(|e| e.with_span(span)) } 0 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Function definition for function {} with signature {} not found.", fun_id, query ), }), n => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) }), } @@ -2438,11 +2434,11 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { - let pos = expr.pos(); + let span: SourceSpan = expr.span().in_module(module_id); match expr.value { - Expression::IntConstant(v) => Ok(IntExpression::Value(v).into()), - Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), + Expression::IntConstant(v) => Ok(IntExpression::value(v).with_span(span).into()), + Expression::BooleanConstant(b) => Ok(BooleanExpression::value(b).with_span(span).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope match self.scope.get(name) { @@ -2457,7 +2453,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(FieldElementExpression::identifier(id.into()).into()) } Type::Array(array_type) => Ok(ArrayExpression::identifier(id.into()) - .annotate(*array_type.ty, *array_type.size) + .annotate(array_type) .into()), Type::Struct(members) => Ok(StructExpression::identifier(id.into()) .annotate(members) @@ -2469,14 +2465,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } } None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Identifier \"{}\" is undefined", name), }), } } - Expression::Add(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Add(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2484,14 +2480,15 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `+` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Add(box e1, box e2).into()), + (Int(e1), Int(e2)) => Ok(IntExpression::add(e1, e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Add(box e1, box e2).into()) + Ok(FieldElementExpression::add(e1, e2) + .into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => @@ -2499,7 +2496,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 + e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `+` to {}, {}", @@ -2509,9 +2506,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Sub(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Sub(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2519,18 +2516,16 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `-` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Sub(box e1, box e2).into()), - (FieldElement(e1), FieldElement(e2)) => { - Ok(FieldElementExpression::Sub(box e1, box e2).into()) - } + (Int(e1), Int(e2)) => Ok(IntExpression::sub(e1, e2).into()), + (FieldElement(e1), FieldElement(e2)) => Ok((e1 - e2).into()), (Uint(e1), Uint(e2)) if e1.get_type() == e2.get_type() => Ok((e1 - e2).into()), (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected only field elements, found {}, {}", @@ -2540,9 +2535,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Mult(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Mult(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2550,14 +2545,14 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `*` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Mult(box e1, box e2).into()), + (Int(e1), Int(e2)) => Ok(IntExpression::mul(e1, e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Mult(box e1, box e2).into()) + Ok((e1 * e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => @@ -2565,7 +2560,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 * e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `*` to {}, {}", @@ -2575,9 +2570,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Div(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Div(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; use self::TypedExpression::*; @@ -2585,22 +2580,20 @@ impl<'ast, T: Field> Checker<'ast, T> { e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `/` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { - (Int(e1), Int(e2)) => Ok(IntExpression::Div(box e1, box e2).into()), - (FieldElement(e1), FieldElement(e2)) => { - Ok(FieldElementExpression::Div(box e1, box e2).into()) - } + (Int(e1), Int(e2)) => Ok(IntExpression::div(e1, e2).into()), + (FieldElement(e1), FieldElement(e2)) => Ok((e1 / e2).into()), (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { Ok((e1 / e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `/` to {}, {}", @@ -2610,15 +2603,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Rem(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Rem(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `%` to {}, {}", e1.get_type(), e2.get_type()), })?; @@ -2629,7 +2622,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok((e1 % e2).into()) } (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `%` to {}, {}", @@ -2639,9 +2632,9 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Pow(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Pow(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) { Ok(e) => e.into(), @@ -2654,10 +2647,10 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::Uint(e2)) => Ok( - TypedExpression::FieldElement(FieldElementExpression::Pow(box e1, box e2)), + TypedExpression::FieldElement(FieldElementExpression::pow(e1, e2)), ), (t1, t2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected `field` and `u32`, found {}, {}", @@ -2667,17 +2660,17 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Neg(box e) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Neg(e) => { + let e = self.check_expression(*e, module_id, types)?; match e { - TypedExpression::Int(e) => Ok(IntExpression::Neg(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::neg(e).into()), TypedExpression::FieldElement(e) => { - Ok(FieldElementExpression::Neg(box e).into()) + Ok(FieldElementExpression::neg(e).into()) } TypedExpression::Uint(e) => Ok((-e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Unary operator `-` cannot be applied to {} of type {}", e, @@ -2686,17 +2679,17 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Pos(box e) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Pos(e) => { + let e = self.check_expression(*e, module_id, types)?; match e { - TypedExpression::Int(e) => Ok(IntExpression::Pos(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::pos(e).into()), TypedExpression::FieldElement(e) => { - Ok(FieldElementExpression::Pos(box e).into()) + Ok(FieldElementExpression::pos(e).into()) } TypedExpression::Uint(e) => Ok(UExpression::pos(e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Unary operator `+` cannot be applied to {} of type {}", e, @@ -2705,7 +2698,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Conditional(box conditional) => { + Expression::Conditional(conditional) => { let condition_checked = self.check_expression(*conditional.condition, module_id, types)?; @@ -2713,7 +2706,7 @@ impl<'ast, T: Field> Checker<'ast, T> { || !conditional.alternative_statements.is_empty() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Statements are not supported in conditional branches".to_string(), }); } @@ -2729,7 +2722,7 @@ impl<'ast, T: Field> Checker<'ast, T> { alternative_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()), })?; @@ -2767,13 +2760,13 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(IntExpression::conditional(condition, consequence, alternative, kind).into()) }, (c, a) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("{{consequence}} and {{alternative}} in conditional expression should have the same type, found {}, {}", c.get_type(), a.get_type()) }) } } c => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{{condition}} should be a boolean, found {}", c.get_type() @@ -2781,39 +2774,41 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::FieldConstant(n) => Ok(FieldElementExpression::Number( - T::try_from(n).map_err(|_| ErrorInner { - pos: Some(pos), - message: format!( - "Field constant not in the representable range [{}, {}]", - T::min_value(), - T::max_value() - ), - })?, - ) + Expression::FieldConstant(n) => Ok(FieldElementExpression::Value( + T::try_from(n) + .map(ValueExpression::new) + .map_err(|_| ErrorInner { + span: Some(span), + message: format!( + "Field constant not in the representable range [{}, {}]", + T::min_value(), + T::max_value() + ), + })?, + ).with_span(span) .into()), - Expression::U8Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(8).into()), - Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()), - Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()), - Expression::U64Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(64).into()), - Expression::FunctionCall(box fun_id_expression, generics, arguments) => self + Expression::U8Constant(n) => Ok(UExpression::value(n.into()).annotate(8).with_span(span).into()), + Expression::U16Constant(n) => Ok(UExpression::value(n.into()).annotate(16).with_span(span).into()), + Expression::U32Constant(n) => Ok(UExpression::value(n.into()).annotate(32).with_span(span).into()), + Expression::U64Constant(n) => Ok(UExpression::value(n.into()).annotate(64).with_span(span).into()), + Expression::FunctionCall(fun_id_expression, generics, arguments) => self .check_function_call_expression( - fun_id_expression, + *fun_id_expression, generics, arguments, None, module_id, types, ), - Expression::Lt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Lt(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2825,14 +2820,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldLt(box e1, box e2).into()) + Ok(BooleanExpression::field_lt(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintLt(box e1, box e2).into()) + Ok(BooleanExpression::uint_lt(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2844,7 +2839,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2855,15 +2850,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Le(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Le(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2875,14 +2870,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldLe(box e1, box e2).into()) + Ok(BooleanExpression::field_le(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintLe(box e1, box e2).into()) + Ok(BooleanExpression::uint_le(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2894,7 +2889,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2905,15 +2900,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Eq(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Eq(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2925,27 +2920,27 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::field_eq(e1, e2).into()) } (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::BoolEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::bool_eq(e1, e2).into()) } (TypedExpression::Array(e1), TypedExpression::Array(e2)) => { - Ok(BooleanExpression::ArrayEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::array_eq(e1, e2).into()) } (TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => { - Ok(BooleanExpression::StructEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::struct_eq(e1, e2).into()) } (TypedExpression::Tuple(e1), TypedExpression::Tuple(e2)) => { - Ok(BooleanExpression::TupleEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::tuple_eq(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(BooleanExpression::UintEq(EqExpression::new(e1, e2)).into()) + Ok(BooleanExpression::uint_eq(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2956,15 +2951,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Ge(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Ge(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2976,14 +2971,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldGe(box e1, box e2).into()) + Ok(BooleanExpression::field_ge(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintGe(box e1, box e2).into()) + Ok(BooleanExpression::uint_ge(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -2995,7 +2990,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3006,15 +3001,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Gt(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Gt(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3026,14 +3021,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::FieldGt(box e1, box e2).into()) + Ok(BooleanExpression::field_gt(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { if e1.get_type() == e2.get_type() { - Ok(BooleanExpression::UintGt(box e1, box e2).into()) + Ok(BooleanExpression::uint_gt(e1, e2).into()) } else { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3045,7 +3040,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot compare {} of type {} to {} of type {}", e1, @@ -3056,10 +3051,10 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Select(box array, box index) => { - let array = self.check_expression(array, module_id, types)?; + Expression::Select(array, index) => { + let array = self.check_expression(*array, module_id, types)?; - match index { + match *index { RangeOrExpression::Range(r) => { match array { TypedExpression::Array(array) => { @@ -3081,7 +3076,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap_or_else(|| Ok(array_size.clone().into()))?; let from = UExpression::try_from_typed(from, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the lower bound of the range to be a u32, found {} of type {}", e, @@ -3090,7 +3085,7 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; let to = UExpression::try_from_typed(to, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the upper bound of the range to be a u32, found {} of type {}", e, @@ -3099,15 +3094,16 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; Ok(ArrayExpressionInner::Slice( - box array, - box from.clone(), - box to.clone(), - ) - .annotate(inner_type, UExpression::floor_sub(to, from)) + SliceExpression::new( + array, + from.clone(), + to.clone(), + )) + .annotate(ArrayType::new(inner_type, UExpression::floor_sub(to, from))) .into()) } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access slice of expression {} of type {}", e, @@ -3122,7 +3118,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let index = UExpression::try_from_typed(index, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected index to be of type u32, found {}", e @@ -3145,7 +3141,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } a => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access element at index {} of type {} on expression {} of type {}", index, @@ -3158,8 +3154,8 @@ impl<'ast, T: Field> Checker<'ast, T> { } } } - Expression::Element(box e, index) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Element(e, index) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Tuple(t) => { let ty = t.ty().elements.get(index as usize); @@ -3177,7 +3173,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Tuple(..) => Ok(TupleExpression::element(t, index).into()), }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Tuple of size {} cannot be accessed at index {}", t.ty().elements.len(), @@ -3187,7 +3183,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access tuple element {} on expression of type {}", index, @@ -3196,13 +3192,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Member(box e, box id) => { - let e = self.check_expression(e, module_id, types)?; + Expression::Member(e, id) => { + let e = self.check_expression(*e, module_id, types)?; match e { TypedExpression::Struct(s) => { // check that the struct has that field and return the type if it does - let ty = s.ty().iter().find(|m| m.id == id).map(|m| *m.ty.clone()); + let ty = s.ty().iter().find(|m| m.id == *id).map(|m| *m.ty.clone()); match ty { Some(ty) => match ty { @@ -3225,7 +3221,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } }, None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "{} {{{}}} doesn't have member {}", s.get_type(), @@ -3241,7 +3237,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } } e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot access member {} on expression of type {}", id, @@ -3260,7 +3256,7 @@ impl<'ast, T: Field> Checker<'ast, T> { if expressions_or_spreads_checked.is_empty() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: "Empty arrays are not allowed".to_string(), }); } @@ -3280,14 +3276,14 @@ impl<'ast, T: Field> Checker<'ast, T> { Type::Int => expressions_or_spreads_checked, t => { let target_array_ty = - ArrayType::new(t, UExpressionInner::Value(0).annotate(UBitwidth::B32)); + ArrayType::new(t, UExpression::value(0).annotate(UBitwidth::B32)); expressions_or_spreads_checked .into_iter() .map(|e| { TypedExpressionOrSpread::align_to_type(e, &target_array_ty).map_err( |(e, ty)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Expected {} to have type {}", e, ty,), }, ) @@ -3308,11 +3304,11 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|e| e.size()) .fold(None, |acc, e| match acc { Some((c_acc, e_acc)) => match e.as_inner() { - UExpressionInner::Value(e) => Some(((c_acc + *e as u32), e_acc)), + UExpressionInner::Value(e) => Some(((c_acc + e.value as u32), e_acc)), _ => Some((c_acc, e_acc + e)), }, None => match e.as_inner() { - UExpressionInner::Value(e) => Some((*e as u32, 0u32.into())), + UExpressionInner::Value(e) => Some((e.value as u32, 0u32.into())), _ => Some((0u32, e)), }, }) @@ -3320,8 +3316,8 @@ impl<'ast, T: Field> Checker<'ast, T> { .unwrap_or_else(|| 0u32.into()); Ok( - ArrayExpressionInner::Value(unwrapped_expressions_or_spreads.into()) - .annotate(inferred_type, size) + ArrayExpression::value(unwrapped_expressions_or_spreads) + .annotate(ArrayType::new(inferred_type, size)) .into(), ) } @@ -3331,17 +3327,17 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|e| self.check_expression(e, module_id, types)) .collect::>()?; let ty = TupleType::new(elements.iter().map(|e| e.get_type()).collect()); - Ok(TupleExpressionInner::Value(elements).annotate(ty).into()) + Ok(TupleExpression::value(elements).annotate(ty).into()) } - Expression::ArrayInitializer(box e, box count) => { - let e = self.check_expression(e, module_id, types)?; + Expression::ArrayInitializer(e, count) => { + let e = self.check_expression(*e, module_id, types)?; let ty = e.get_type(); - let count = self.check_expression(count, module_id, types)?; + let count = self.check_expression(*count, module_id, types)?; let count = UExpression::try_from_typed(count, &UBitwidth::B32).map_err(|e| { ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected array initializer count to be a u32, found {} of type {}", e, @@ -3350,14 +3346,14 @@ impl<'ast, T: Field> Checker<'ast, T> { } })?; - Ok(ArrayExpressionInner::Repeat(box e, box count.clone()) - .annotate(ty, count) + Ok(ArrayExpressionInner::Repeat(RepeatExpression::new(e, count.clone())) + .annotate(ArrayType::new(ty, count)) .into()) } Expression::InlineStruct(id, inline_members) => { let ty = match types.get(module_id).unwrap().get(&id).cloned() { None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Undefined type `{}`", id), }), Some(ty) => Ok(ty), @@ -3379,7 +3375,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // check that we provided the required number of values if declared_struct_type.members_count() != inline_members.len() { return Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Inline struct {} does not match {} {{{}}}", Expression::InlineStruct(id, inline_members), @@ -3414,7 +3410,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let expression_checked = TypedExpression::align_to_type(expression_checked, &*member.ty) .map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Member {} of struct {} has type {}, found {} of type {}", member.id, @@ -3428,7 +3424,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(expression_checked) } None => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Member {} of struct {} {{{}}} not found in value {}", member.id, @@ -3454,7 +3450,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .map(|(m, v)| { if !check_type(&m.ty, &v.get_type(), &mut generics_map) { Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Value `{}` doesn't match the expected type `{}` because of conflict in generic values", Expression::InlineStruct(id.clone(), inline_members.clone()), @@ -3462,10 +3458,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ), }) } else { - Ok(StructMember { - id: m.id.clone(), - ty: box v.get_type().clone(), - }) + Ok(StructMember::new(m.id.clone(), v.get_type().clone())) } }) .collect::, _>>()?; @@ -3479,19 +3472,19 @@ impl<'ast, T: Field> Checker<'ast, T> { members, }; - Ok(StructExpressionInner::Value(inferred_values) + Ok(StructExpression::value(inferred_values) .annotate(inferred_struct_type) .into()) } - Expression::And(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::And(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply boolean operators to {} and {}", e1.get_type(), @@ -3501,13 +3494,13 @@ impl<'ast, T: Field> Checker<'ast, T> { match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::And(box e1, box e2).into()) + Ok(IntExpression::and(e1, e2).into()) } (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::And(box e1, box e2).into()) + Ok(BooleanExpression::bitand(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply boolean operators to {} and {}", @@ -3517,15 +3510,15 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Or(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::Or(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; match (e1_checked, e2_checked) { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { - Ok(BooleanExpression::Or(box e1, box e2).into()) + Ok(BooleanExpression::bitor(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `||` to {}, {}", e1.get_type(), @@ -3534,13 +3527,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::LeftShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, types)?; - let e2 = self.check_expression(e2, module_id, types)?; + Expression::LeftShift(e1, e2) => { + let e1 = self.check_expression(*e1, module_id, types)?; + let e2 = self.check_expression(*e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the left shift right operand to have type `u32`, found {}", e @@ -3548,13 +3541,13 @@ impl<'ast, T: Field> Checker<'ast, T> { })?; match e1 { - TypedExpression::Int(e1) => Ok(IntExpression::LeftShift(box e1, box e2).into()), + TypedExpression::Int(e1) => Ok(IntExpression::left_shift(e1, e2).into()), TypedExpression::Uint(e1) => Ok(UExpression::left_shift(e1, e2).into()), TypedExpression::FieldElement(e1) => { - Ok(FieldElementExpression::LeftShift(box e1, box e2).into()) + Ok(FieldElementExpression::left_shift(e1, e2).into()) } e1 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot left-shift {} by {}", e1.get_type(), @@ -3563,13 +3556,13 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::RightShift(box e1, box e2) => { - let e1 = self.check_expression(e1, module_id, types)?; - let e2 = self.check_expression(e2, module_id, types)?; + Expression::RightShift(e1, e2) => { + let e1 = self.check_expression(*e1, module_id, types)?; + let e2 = self.check_expression(*e2, module_id, types)?; let e2 = UExpression::try_from_typed(e2, &UBitwidth::B32).map_err(|e| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Expected the right shift right operand to be of type `u32`, found {}", e @@ -3578,14 +3571,14 @@ impl<'ast, T: Field> Checker<'ast, T> { match e1 { TypedExpression::Int(e1) => { - Ok(IntExpression::RightShift(box e1, box e2).into()) + Ok(IntExpression::right_shift(e1, e2).into()) } TypedExpression::Uint(e1) => Ok(UExpression::right_shift(e1, e2).into()), TypedExpression::FieldElement(e1) => { - Ok(FieldElementExpression::RightShift(box e1, box e2).into()) + Ok(FieldElementExpression::right_shift(e1, e2).into()) } e1 => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot right-shift {} by {}", e1.get_type(), @@ -3594,24 +3587,24 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitOr(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitOr(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `|` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::Or(box e1, box e2).into()) + Ok(IntExpression::or(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Or(box e1, box e2).into()) + Ok(FieldElementExpression::bitor(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3619,7 +3612,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::or(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `|` to {}, {}", e1.get_type(), @@ -3628,24 +3621,24 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitAnd(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitAnd(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `&` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::And(box e1, box e2).into()) + Ok(IntExpression::and(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::And(box e1, box e2).into()) + Ok(FieldElementExpression::bitand(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3653,7 +3646,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::and(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `&` to {}, {}", e1.get_type(), @@ -3662,24 +3655,24 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::BitXor(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, types)?; - let e2_checked = self.check_expression(e2, module_id, types)?; + Expression::BitXor(e1, e2) => { + let e1_checked = self.check_expression(*e1, module_id, types)?; + let e2_checked = self.check_expression(*e2, module_id, types)?; let (e1_checked, e2_checked) = TypedExpression::align_without_integers( e1_checked, e2_checked, ) .map_err(|(e1, e2)| ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot apply `^` to {}, {}", e1.get_type(), e2.get_type()), })?; match (e1_checked, e2_checked) { (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { - Ok(IntExpression::Xor(box e1, box e2).into()) + Ok(IntExpression::xor(e1, e2).into()) } (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(FieldElementExpression::Xor(box e1, box e2).into()) + Ok(FieldElementExpression::bitxor(e1, e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => @@ -3687,7 +3680,7 @@ impl<'ast, T: Field> Checker<'ast, T> { Ok(UExpression::xor(e1, e2).into()) } (e1, e2) => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!( "Cannot apply `^` to {}, {}", e1.get_type(), @@ -3696,19 +3689,19 @@ impl<'ast, T: Field> Checker<'ast, T> { }), } } - Expression::Not(box e) => { - let e_checked = self.check_expression(e, module_id, types)?; + Expression::Not(e) => { + let e_checked = self.check_expression(*e, module_id, types)?; match e_checked { - TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()), - TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), + TypedExpression::Int(e) => Ok(IntExpression::not(e).into()), + TypedExpression::Boolean(e) => Ok(BooleanExpression::not(e).into()), TypedExpression::Uint(e) => Ok((!e).into()), e => Err(ErrorInner { - pos: Some(pos), + span: Some(span), message: format!("Cannot negate {}", e.get_type()), }), } } - } + }.map(|e| e.with_span(span)) } fn insert_into_scope>>( @@ -3751,19 +3744,17 @@ mod tests { use zokrates_field::Bn128Field; lazy_static! { - static ref MODULE_ID: OwnedModuleId = OwnedModuleId::from(""); + static ref MODULE_ID: OwnedModuleId = OwnedModuleId::default(); } mod constants { use super::*; - use std::ops::Add; - #[test] fn field_in_range() { // The value of `P - 1` is a valid field literal let expr = Expression::FieldConstant(Bn128Field::max_value().to_biguint()).mock(); assert!(Checker::::default() - .check_expression(expr, &*MODULE_ID, &TypeMap::new()) + .check_expression(expr, &MODULE_ID, &TypeMap::new()) .is_ok()); } @@ -3774,7 +3765,7 @@ mod tests { let expr = Expression::FieldConstant(value).mock(); assert!(Checker::::default() - .check_expression(expr, &*MODULE_ID, &TypeMap::new()) + .check_expression(expr, &MODULE_ID, &TypeMap::new()) .is_err()); } } @@ -3797,7 +3788,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_err()); // [[0f], [0f, 0f]] @@ -3817,7 +3808,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_ok()); // [[0f], true] @@ -3833,7 +3824,7 @@ mod tests { ]) .mock(); assert!(Checker::::default() - .check_expression(a, &*MODULE_ID, &types) + .check_expression(a, &MODULE_ID, &types) .is_err()); } } @@ -3994,7 +3985,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&OwnedTypedModuleId::from("bar"), &mut state), + checker.check_module(&OwnedModuleId::from("bar"), &mut state), Ok(()) ); assert_eq!( @@ -4048,7 +4039,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4131,7 +4122,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } mod generics { @@ -4173,7 +4164,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -4230,7 +4221,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "Undeclared symbol `P`" @@ -4270,7 +4261,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert_eq!(checker.check_module(&*MODULE_ID, &mut state), Ok(())); + assert_eq!(checker.check_module(&MODULE_ID, &mut state), Ok(())); assert!(state .typed_modules .get(&*MODULE_ID) @@ -4321,7 +4312,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4365,7 +4356,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4417,7 +4408,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4466,7 +4457,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] + checker.check_module(&MODULE_ID, &mut state).unwrap_err()[0] .inner .message, "foo conflicts with another symbol" @@ -4495,14 +4486,14 @@ mod tests { let state = State::new(modules, (*MODULE_ID).clone()); let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), + Box::new(UnresolvedType::FieldElement.mock()), Expression::Identifier("K").mock(), ) .mock()]); assert_eq!( - Checker::::default().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Undeclared symbol `K`".to_string() }]) ); @@ -4517,27 +4508,31 @@ mod tests { let signature = UnresolvedSignature::new() .generics(vec!["K".mock(), "L".mock(), "M".mock()]) .inputs(vec![UnresolvedType::Array( - box UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), - Expression::Identifier("K").mock(), - ) - .mock(), + Box::new( + UnresolvedType::Array( + Box::new(UnresolvedType::FieldElement.mock()), + Expression::Identifier("K").mock(), + ) + .mock(), + ), Expression::Identifier("L").mock(), ) .mock()]) .output( UnresolvedType::Array( - box UnresolvedType::Array( - box UnresolvedType::FieldElement.mock(), - Expression::Identifier("L").mock(), - ) - .mock(), + Box::new( + UnresolvedType::Array( + Box::new(UnresolvedType::FieldElement.mock()), + Expression::Identifier("L").mock(), + ) + .mock(), + ), Expression::Identifier("K").mock(), ) .mock(), ); assert_eq!( - Checker::::default().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &MODULE_ID, &state), Ok(DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::array(( @@ -4572,9 +4567,9 @@ mod tests { checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), + checker.check_statement(statement, &MODULE_ID, &TypeMap::new()), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"b\" is undefined".into() }]) ); @@ -4603,7 +4598,7 @@ mod tests { let mut checker: Checker = new_with_args(scope, HashSet::new()); checker.enter_scope(); assert_eq!( - checker.check_statement(statement, &*MODULE_ID, &TypeMap::new()), + checker.check_statement(statement, &MODULE_ID, &TypeMap::new()), Ok(TypedStatement::definition( typed::Variable::field_element("a").into(), FieldElementExpression::identifier("b".into()).into() @@ -4672,10 +4667,10 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state), + checker.check_module(&MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"a\" is undefined".into() }, module_id: (*MODULE_ID).clone() @@ -4768,7 +4763,7 @@ mod tests { ); let mut checker: Checker = Checker::default(); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -4802,9 +4797,9 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_function("foo", foo, &*MODULE_ID, &state), + checker.check_function("foo", foo, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"i\" is undefined".into() }]) ); @@ -4851,7 +4846,7 @@ mod tests { )]; let foo_statements_checked = vec![ - TypedStatement::For( + TypedStatement::for_( typed::Variable::uint( CoreIdentifier::Source(ShadowedIdentifier::shadow("i".into(), 1)), UBitwidth::B32, @@ -4860,7 +4855,7 @@ mod tests { 10u32.into(), for_statements_checked, ), - TypedStatement::Return(TypedExpression::empty_tuple()), + TypedStatement::ret(TypedExpression::empty_tuple()), ]; let foo = Function { @@ -4881,7 +4876,7 @@ mod tests { let mut checker: Checker = Checker::default(); assert_eq!( - checker.check_function("foo", foo, &*MODULE_ID, &state), + checker.check_function("foo", foo, &MODULE_ID, &state), Ok(foo_checked) ); } @@ -4899,8 +4894,12 @@ mod tests { let bar_statements: Vec = vec![ Statement::Definition( untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -4928,9 +4927,9 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), functions); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> field not found." .into() @@ -4948,8 +4947,12 @@ mod tests { let bar_statements: Vec = vec![ Statement::Definition( untyped::Variable::immutable("a", UnresolvedType::FieldElement.mock()).mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -4967,9 +4970,9 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> field not found." @@ -5004,8 +5007,12 @@ mod tests { let main_statements: Vec = vec![ Statement::Assignment( Assignee::Identifier("a").mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), ) .mock(), Statement::Return(None).mock(), @@ -5042,10 +5049,10 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_module(&*MODULE_ID, &mut state), + checker.check_module(&MODULE_ID, &mut state), Err(vec![Error { inner: ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Variable `a` is undeclared".into() }, module_id: (*MODULE_ID).clone() @@ -5096,14 +5103,18 @@ mod tests { .mock(), Statement::Assignment( Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( untyped::Expression::IntConstant(0usize.into()).mock(), - ), + )), + ) + .mock(), + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], ) .mock(), - Expression::FunctionCall(box Expression::Identifier("foo").mock(), None, vec![]) - .mock(), ) .mock(), Statement::Return(None).mock(), @@ -5139,7 +5150,7 @@ mod tests { ); let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); - assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(checker.check_module(&MODULE_ID, &mut state).is_ok()); } #[test] @@ -5153,13 +5164,15 @@ mod tests { let bar_statements: Vec = vec![ Statement::Assertion( Expression::Eq( - box Expression::IntConstant(1usize.into()).mock(), - box Expression::FunctionCall( - box Expression::Identifier("foo").mock(), - None, - vec![], - ) - .mock(), + Box::new(Expression::IntConstant(1usize.into()).mock()), + Box::new( + Expression::FunctionCall( + Box::new(Expression::Identifier("foo").mock()), + None, + vec![], + ) + .mock(), + ), ) .mock(), None, @@ -5180,9 +5193,9 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Function definition for function foo with signature () -> _ not found." .into() @@ -5213,9 +5226,9 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( - checker.check_function("bar", bar, &*MODULE_ID, &state), + checker.check_function("bar", bar, &MODULE_ID, &state), Err(vec![ErrorInner { - pos: Some((Position::mock(), Position::mock())), + span: Some(SourceSpan::mock()), message: "Identifier \"a\" is undefined".into() }]) ); @@ -5251,7 +5264,7 @@ mod tests { let mut checker: Checker = new_with_args(Scope::default(), HashSet::new()); assert_eq!( checker - .check_function("main", f, &*MODULE_ID, &state) + .check_function("main", f, &MODULE_ID, &state) .unwrap_err()[0] .message, "Duplicate name in function definition: `a` was previously declared as an argument, a generic parameter or a constant" @@ -5326,7 +5339,7 @@ mod tests { checker.check_program(program), Err(vec![Error { inner: ErrorInner { - pos: None, + span: None, message: "Only one main function allowed, found 2".into() }, module_id: (*MODULE_ID).clone() @@ -5353,7 +5366,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); let s2_checked: Result, Vec> = checker @@ -5364,7 +5377,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); assert!(s2_checked.is_ok()); @@ -5385,7 +5398,7 @@ mod tests { untyped::Expression::IntConstant(2usize.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); let s2_checked: Result, Vec> = checker @@ -5395,7 +5408,7 @@ mod tests { untyped::Expression::BooleanConstant(true).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ); assert!(s2_checked.is_ok()); @@ -5459,16 +5472,14 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(2u32.into()).into(), + FieldElementExpression::value(2u32.into()).into(), ), - TypedStatement::For( + TypedStatement::for_( typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("i".into(), 1)), Type::Uint(UBitwidth::B32), - false, ), 0u32.into(), 0u32.into(), @@ -5477,19 +5488,17 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(3u32.into()).into(), + FieldElementExpression::value(3u32.into()).into(), ), TypedStatement::definition( typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 1)), Type::FieldElement, - false, ) .into(), - FieldElementExpression::Number(4u32.into()).into(), + FieldElementExpression::value(4u32.into()).into(), ), ], ), @@ -5497,10 +5506,9 @@ mod tests { typed::Variable::new( CoreIdentifier::from(ShadowedIdentifier::shadow("a".into(), 0)), Type::FieldElement, - true, ) .into(), - FieldElementExpression::Number(5u32.into()).into(), + FieldElementExpression::value(5u32.into()).into(), ), ]; @@ -5509,7 +5517,7 @@ mod tests { .into_iter() .map(|s| { checker - .check_statement(s, &*MODULE_ID, &TypeMap::default()) + .check_statement(s, &MODULE_ID, &TypeMap::default()) .unwrap() }) .collect(); @@ -5541,7 +5549,7 @@ mod tests { let mut checker: Checker = Checker::default(); - checker.check_module(&*MODULE_ID, &mut state).unwrap(); + checker.check_module(&MODULE_ID, &mut state).unwrap(); (checker, state) } @@ -5569,7 +5577,7 @@ mod tests { Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ), Ok(expected_type) @@ -5613,7 +5621,7 @@ mod tests { Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ), Ok(expected_type) @@ -5648,7 +5656,7 @@ mod tests { .check_struct_type_declaration( "Foo".into(), declaration, - &*MODULE_ID, + &MODULE_ID, &state ) .unwrap_err()[0] @@ -5705,7 +5713,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_ok()); assert_eq!( state @@ -5765,7 +5773,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } @@ -5799,7 +5807,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } @@ -5851,7 +5859,7 @@ mod tests { ); assert!(Checker::default() - .check_module(&*MODULE_ID, &mut state) + .check_module(&MODULE_ID, &mut state) .is_err()); } } @@ -5881,7 +5889,7 @@ mod tests { assert_eq!( checker.check_type( UnresolvedType::User("Foo".into(), None).mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(Type::Struct(StructType::new( @@ -5896,7 +5904,7 @@ mod tests { checker .check_type( UnresolvedType::User("Bar".into(), None).mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -5929,19 +5937,21 @@ mod tests { assert_eq!( checker.check_expression( Expression::Member( - box Expression::InlineStruct( - "Foo".into(), - vec![("foo", Expression::IntConstant(42usize.into()).mock())] - ) - .mock(), + Box::new( + Expression::InlineStruct( + "Foo".into(), + vec![("foo", Expression::IntConstant(42usize.into()).mock())] + ) + .mock() + ), "foo".into() ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), Ok(FieldElementExpression::member( - StructExpressionInner::Value(vec![FieldElementExpression::Number( + StructExpression::value(vec![FieldElementExpression::value( Bn128Field::from(42u32) ) .into()]) @@ -5977,15 +5987,20 @@ mod tests { checker .check_expression( Expression::Member( - box Expression::InlineStruct( - "Foo".into(), - vec![("foo", Expression::IntConstant(42usize.into()).mock())] - ) - .mock(), + Box::new( + Expression::InlineStruct( + "Foo".into(), + vec![( + "foo", + Expression::IntConstant(42usize.into()).mock() + )] + ) + .mock() + ), "bar".into() ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6020,7 +6035,7 @@ mod tests { vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6062,12 +6077,12 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), - Ok(StructExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(42u32)).into(), - BooleanExpression::Value(true).into() + Ok(StructExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(42u32)).into(), + BooleanExpression::value(true).into() ]) .annotate(StructType::new( "".into(), @@ -6115,12 +6130,12 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ), - Ok(StructExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(42u32)).into(), - BooleanExpression::Value(true).into() + Ok(StructExpression::value(vec![ + FieldElementExpression::value(Bn128Field::from(42u32)).into(), + BooleanExpression::value(true).into() ]) .annotate(StructType::new( "".into(), @@ -6166,7 +6181,7 @@ mod tests { vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6214,7 +6229,7 @@ mod tests { )] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ).unwrap_err() .message, @@ -6232,7 +6247,7 @@ mod tests { ] ) .mock(), - &*MODULE_ID, + &MODULE_ID, &state.types ) .unwrap_err() @@ -6292,7 +6307,7 @@ mod tests { main.value.statements = vec![Statement::Return(Some( Expression::FunctionCall( - box Expression::Identifier("foo").mock(), + Box::new(Expression::Identifier("foo").mock()), None, vec![Expression::IntConstant(0usize.into()).mock()], ) @@ -6356,17 +6371,16 @@ mod tests { Expression::FieldConstant(42u32.into()).mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Identifier(typed::Variable::new( "a", Type::FieldElement, - true ))) ); } @@ -6376,8 +6390,10 @@ mod tests { // field[3] a = [1, 2, 3] // a[2] = 42 let a = Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression(Expression::IntConstant(2usize.into()).mock()), + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(2usize.into()).mock(), + )), ) .mock(); @@ -6408,20 +6424,19 @@ mod tests { .mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( - box TypedAssignee::Identifier(typed::Variable::new( + Box::new(TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::FieldElement, 3u32)), - true, - )), - box 2u32.into() + ))), + Box::new(2u32.into()) )) ); } @@ -6431,14 +6446,18 @@ mod tests { // field[1][1] a = [[1]] // a[0][0] let a: AssigneeNode = Assignee::Select( - box Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( - Expression::IntConstant(0usize.into()).mock(), - ), - ) - .mock(), - box RangeOrExpression::Expression(Expression::IntConstant(0usize.into()).mock()), + Box::new( + Assignee::Select( + Box::new(Assignee::Identifier("a").mock()), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(0usize.into()).mock(), + )), + ) + .mock(), + ), + Box::new(RangeOrExpression::Expression( + Expression::IntConstant(0usize.into()).mock(), + )), ) .mock(); @@ -6469,23 +6488,22 @@ mod tests { .mock(), ) .mock(), - &*MODULE_ID, + &MODULE_ID, &TypeMap::new(), ) .unwrap(); assert_eq!( - checker.check_assignee(a, &*MODULE_ID, &TypeMap::new()), + checker.check_assignee(a, &MODULE_ID, &TypeMap::new()), Ok(TypedAssignee::Select( - box TypedAssignee::Select( - box TypedAssignee::Identifier(typed::Variable::new( + Box::new(TypedAssignee::Select( + Box::new(TypedAssignee::Identifier(typed::Variable::new( "a", Type::array((Type::array((Type::FieldElement, 1u32)), 1u32)), - true, - )), - box 0u32.into() - ), - box 0u32.into() + ))), + Box::new(0u32.into()) + )), + Box::new(0u32.into()) )) ); } diff --git a/zokrates_core_test/Cargo.toml b/zokrates_core_test/Cargo.toml index 19c802668..3730edd67 100644 --- a/zokrates_core_test/Cargo.toml +++ b/zokrates_core_test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core_test" -version = "0.2.10" +version = "0.2.11" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_core_test/tests/tests/uint/ch.json b/zokrates_core_test/tests/tests/uint/ch.json index 0039cf2a6..435893ece 100644 --- a/zokrates_core_test/tests/tests/uint/ch.json +++ b/zokrates_core_test/tests/tests/uint/ch.json @@ -1,14 +1,14 @@ { "entry_point": "./tests/tests/uint/ch.zok", - "max_constraint_count": 200, + "max_constraint_count": 132, "tests": [ { "input": { - "values": ["0x00000000", "0x00000000", "0x00000000"] + "values": ["0x0000000f", "0x0000000f", "0x0000000f"] }, "output": { "Ok": { - "value": "0x00000000" + "value": "0x0000000f" } } } diff --git a/zokrates_embed/Cargo.toml b/zokrates_embed/Cargo.toml index 147a2f0cc..fcff9510e 100644 --- a/zokrates_embed/Cargo.toml +++ b/zokrates_embed/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_embed" -version = "0.1.9" +version = "0.1.10" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_embed/src/ark.rs b/zokrates_embed/src/ark.rs index 40ce41ddd..fa8544962 100644 --- a/zokrates_embed/src/ark.rs +++ b/zokrates_embed/src/ark.rs @@ -279,8 +279,8 @@ fn var_to_index(var: &FpVar, offset: usize) -> usize { fn new_g1(flat: &[T]) -> G1 { assert_eq!(flat.len(), 2); G1::new( - BLS12Fq::from_str(&*flat[0].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[1].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[0].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[1].to_dec_string()).unwrap(), false, ) } @@ -290,12 +290,12 @@ fn new_g2(flat: &[T]) -> G2 { assert_eq!(flat.len(), 4); G2::new( BLS12Fq2::new( - BLS12Fq::from_str(&*flat[0].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[1].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[0].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[1].to_dec_string()).unwrap(), ), BLS12Fq2::new( - BLS12Fq::from_str(&*flat[2].to_dec_string()).unwrap(), - BLS12Fq::from_str(&*flat[3].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[2].to_dec_string()).unwrap(), + BLS12Fq::from_str(&flat[3].to_dec_string()).unwrap(), ), false, ) diff --git a/zokrates_field/Cargo.toml b/zokrates_field/Cargo.toml index bb420c7f9..c73d38591 100644 --- a/zokrates_field/Cargo.toml +++ b/zokrates_field/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_field" -version = "0.5.3" +version = "0.5.4" authors = ["Thibaut Schaeffer ", "Guillaume Ballet "] edition = "2018" @@ -29,7 +29,7 @@ ark-bn254 = { version = "^0.3.0", features = ["curve"], default-features = false ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false } ark-bls12-381 = { version = "^0.3.0", features = ["curve"] } ark-bw6-761 = { version = "^0.3.0", default-features = false } -ark-serialize = { version = "^0.3.0", default-features = false } +ark-serialize = { version = "^0.3.0", default-features = false, features = ["std"] } [dev-dependencies] rand = "0.4" diff --git a/zokrates_field/src/bn128.rs b/zokrates_field/src/bn128.rs index 4b6c9cbb3..a30b3a45f 100644 --- a/zokrates_field/src/bn128.rs +++ b/zokrates_field/src/bn128.rs @@ -48,7 +48,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65484493"), - FieldPrime::from("65416358") + &FieldPrime::from("68135") + FieldPrime::from("65416358") + FieldPrime::from("68135") ); } @@ -60,7 +60,7 @@ mod tests { ); assert_eq!( FieldPrime::from("3"), - FieldPrime::from("5") + &FieldPrime::from(-2) + FieldPrime::from("5") + FieldPrime::from(-2) ); } @@ -72,7 +72,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65348223"), - FieldPrime::from("65416358") + &FieldPrime::from(-68135) + FieldPrime::from("65416358") + FieldPrime::from(-68135) ); } @@ -84,7 +84,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65348223"), - FieldPrime::from("65416358") - &FieldPrime::from("68135") + FieldPrime::from("65416358") - FieldPrime::from("68135") ); } @@ -96,7 +96,7 @@ mod tests { ); assert_eq!( FieldPrime::from("65484493"), - FieldPrime::from("65416358") - &FieldPrime::from(-68135) + FieldPrime::from("65416358") - FieldPrime::from(-68135) ); } @@ -112,7 +112,7 @@ mod tests { FieldPrime::from( "21888242871839275222246405745257275088548364400416034343698204186575743147394" ), - FieldPrime::from("68135") - &FieldPrime::from("65416358") + FieldPrime::from("68135") - FieldPrime::from("65416358") ); } @@ -124,7 +124,7 @@ mod tests { ); assert_eq!( FieldPrime::from("13472"), - FieldPrime::from("32") * &FieldPrime::from("421") + FieldPrime::from("32") * FieldPrime::from("421") ); } @@ -140,7 +140,7 @@ mod tests { FieldPrime::from( "21888242871839275222246405745257275088548364400416034343698204186575808014369" ), - FieldPrime::from("54") * &FieldPrime::from(-8912) + FieldPrime::from("54") * FieldPrime::from(-8912) ); } @@ -152,7 +152,7 @@ mod tests { ); assert_eq!( FieldPrime::from("648"), - FieldPrime::from(-54) * &FieldPrime::from(-12) + FieldPrime::from(-54) * FieldPrime::from(-12) ); } @@ -172,7 +172,7 @@ mod tests { ), FieldPrime::from( "21888242871839225222246405785257275088694311157297823662689037894645225727" - ) * &FieldPrime::from("218882428715392752222464057432572755886923") + ) * FieldPrime::from("218882428715392752222464057432572755886923") ); } @@ -184,7 +184,7 @@ mod tests { ); assert_eq!( FieldPrime::from(4), - FieldPrime::from(48) / &FieldPrime::from(12) + FieldPrime::from(48) / FieldPrime::from(12) ); } @@ -318,7 +318,7 @@ mod tests { let a: Fr = rng.gen(); // now test idempotence let a = FieldPrime::from_bellman(a); - assert_eq!(FieldPrime::from_bellman(a.clone().into_bellman()), a); + assert_eq!(FieldPrime::from_bellman(a.into_bellman()), a); } } diff --git a/zokrates_field/src/dummy_curve.rs b/zokrates_field/src/dummy_curve.rs index 5d3aed4a3..55460f036 100644 --- a/zokrates_field/src/dummy_curve.rs +++ b/zokrates_field/src/dummy_curve.rs @@ -10,7 +10,9 @@ use std::ops::{Add, Div, Mul, Sub}; const _PRIME: u8 = 7; -#[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq)] +#[derive( + Default, Debug, Hash, Clone, Copy, PartialOrd, Ord, Serialize, Deserialize, PartialEq, Eq, +)] pub struct FieldPrime { v: u8, } @@ -250,4 +252,12 @@ impl Field for FieldPrime { fn to_biguint(&self) -> num_bigint::BigUint { unimplemented!() } + + fn read(_: R) -> std::io::Result { + unimplemented!() + } + + fn write(&self, _: W) -> std::io::Result<()> { + unimplemented!() + } } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index 38f76905b..a767f7d08 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -8,7 +8,6 @@ extern crate num_bigint; #[cfg(feature = "bellman")] use bellman_ce::pairing::{ff::ScalarEngine, Engine}; - use num_bigint::BigUint; use num_traits::{CheckedDiv, One, Zero}; use serde::{Deserialize, Serialize}; @@ -16,6 +15,7 @@ use std::convert::{From, TryFrom}; use std::fmt; use std::fmt::{Debug, Display}; use std::hash::Hash; +use std::io::{Read, Write}; use std::ops::{Add, Div, Mul, Sub}; pub trait Pow { @@ -70,6 +70,7 @@ pub trait Field: + Zero + One + Clone + + Copy + PartialEq + Eq + Hash @@ -95,6 +96,10 @@ pub trait Field: + num_traits::CheckedMul { const G2_TYPE: G2Type = G2Type::Fq2; + // Read field from the reader + fn read(reader: R) -> std::io::Result; + // Write field to the writer + fn write(&self, writer: W) -> std::io::Result<()>; /// Returns this `Field`'s contents as little-endian byte vector fn to_byte_vector(&self) -> Vec; /// Returns an element of this `Field` from a little-endian byte vector @@ -144,11 +149,12 @@ mod prime_field { use std::convert::TryFrom; use std::fmt; use std::fmt::{Debug, Display}; + use std::io::{Read, Write}; use std::ops::{Add, Div, Mul, Sub}; type Fr = <$v as ark_ec::PairingEngine>::Fr; - #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash)] + #[derive(PartialEq, PartialOrd, Clone, Copy, Eq, Ord, Hash)] pub struct FieldPrime { v: Fr, } @@ -186,9 +192,21 @@ mod prime_field { self.v.into_repr().to_bytes_le() } - fn from_byte_vector(bytes: Vec) -> Self { + fn read(reader: R) -> std::io::Result { use ark_ff::FromBytes; + Ok(FieldPrime { + v: Fr::read(reader)?, + }) + } + fn write(&self, mut writer: W) -> std::io::Result<()> { + use ark_ff::ToBytes; + self.v.write(&mut writer)?; + Ok(()) + } + + fn from_byte_vector(bytes: Vec) -> Self { + use ark_ff::FromBytes; FieldPrime { v: Fr::from(::BigInt::read(&bytes[..]).unwrap()), } @@ -591,9 +609,12 @@ mod prime_field { } fn into_bellman(self) -> ::Fr { - use bellman_ce::pairing::ff::PrimeField; - let s = self.to_dec_string(); - ::Fr::from_str(&s).unwrap() + use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr}; + let bytes = self.to_byte_vector(); + let mut repr = + <::Fr as PrimeField>::Repr::default(); + repr.read_le(bytes.as_slice()).unwrap(); + ::Fr::from_repr(repr).unwrap() } fn new_fq2( diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 270436685..30c9f272d 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_interpreter" -version = "0.1.3" +version = "0.1.4" edition = "2021" [features] diff --git a/zokrates_interpreter/src/lib.rs b/zokrates_interpreter/src/lib.rs index 9f90e00ca..44fabfd73 100644 --- a/zokrates_interpreter/src/lib.rs +++ b/zokrates_interpreter/src/lib.rs @@ -5,7 +5,7 @@ use zokrates_abi::{Decode, Value}; use zokrates_ast::ir::{ LinComb, ProgIterator, QuadComb, RuntimeError, Solver, Statement, Variable, Witness, }; -use zokrates_ast::zir; +use zokrates_ast::zir::{self, Expr}; use zokrates_field::Field; pub type ExecutionResult = Result, Error>; @@ -50,22 +50,22 @@ impl Interpreter { witness.insert(Variable::one(), T::one()); for (arg, value) in program.arguments.iter().zip(inputs.iter()) { - witness.insert(arg.id, value.clone()); + witness.insert(arg.id, *value); } for statement in program.statements.into_iter() { match statement { Statement::Block(..) => unreachable!(), - Statement::Constraint(quad, lin, error) => match lin.is_assignee(&witness) { + Statement::Constraint(s) => match s.lin.is_assignee(&witness) { true => { - let val = evaluate_quad(&witness, &quad).unwrap(); - witness.insert(lin.0.get(0).unwrap().0, val); + let val = evaluate_quad(&witness, &s.quad).unwrap(); + witness.insert(s.lin.value.get(0).unwrap().0, val); } false => { - let lhs_value = evaluate_quad(&witness, &quad).unwrap(); - let rhs_value = evaluate_lin(&witness, &lin).unwrap(); + let lhs_value = evaluate_quad(&witness, &s.quad).unwrap(); + let rhs_value = evaluate_lin(&witness, &s.lin).unwrap(); if lhs_value != rhs_value { - return Err(Error::UnsatisfiedConstraint { error }); + return Err(Error::UnsatisfiedConstraint { error: s.error }); } } }, @@ -83,21 +83,21 @@ impl Interpreter { inputs.pop().unwrap(), )) } - _ => Self::execute_solver(&d.solver, &inputs), + _ => Self::execute_solver(&d.solver, &inputs, &program.solvers), } .map_err(Error::Solver)?; for (i, o) in d.outputs.iter().enumerate() { - witness.insert(*o, res[i].clone()); + witness.insert(*o, res[i]); } } - Statement::Log(l, expressions) => { - let mut parts = l.parts.into_iter(); + Statement::Log(s) => { + let mut parts = s.format_string.parts.into_iter(); write!(log_stream, "{}", parts.next().unwrap()) .map_err(|_| Error::LogStream)?; - for ((t, e), part) in expressions.into_iter().zip(parts) { + for ((t, e), part) in s.expressions.into_iter().zip(parts) { let values: Vec<_> = e .iter() .map(|e| evaluate_lin(&witness, e).unwrap()) @@ -164,7 +164,15 @@ impl Interpreter { pub fn execute_solver<'ast, T: Field>( solver: &Solver<'ast, T>, inputs: &[T], + solvers: &[Solver<'ast, T>], ) -> Result, String> { + let solver = match solver { + Solver::Ref(call) => solvers + .get(call.index) + .ok_or_else(|| format!("Could not get solver at index {}", call.index))?, + s => s, + }; + let (expected_input_count, expected_output_count) = solver.get_signature(); assert_eq!(inputs.len(), expected_input_count); @@ -177,26 +185,26 @@ impl Interpreter { .arguments .iter() .zip(inputs) - .map(|(a, v)| match &a.id._type { + .map(|(a, v)| match &a.id.ty { zir::Type::FieldElement => Ok(( a.id.id.clone(), - zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(), + zokrates_ast::zir::FieldElementExpression::value(*v).into(), )), zir::Type::Boolean => match v { v if *v == T::from(0) => Ok(( a.id.id.clone(), - zokrates_ast::zir::BooleanExpression::Value(false).into(), + zokrates_ast::zir::BooleanExpression::value(false).into(), )), v if *v == T::from(1) => Ok(( a.id.id.clone(), - zokrates_ast::zir::BooleanExpression::Value(true).into(), + zokrates_ast::zir::BooleanExpression::value(true).into(), )), v => Err(format!("`{}` has unexpected value `{}`", a.id, v)), }, zir::Type::Uint(bitwidth) => match v.bits() <= bitwidth.to_usize() as u32 { true => Ok(( a.id.id.clone(), - zokrates_ast::zir::UExpressionInner::Value( + zokrates_ast::zir::UExpression::value( v.to_dec_string().parse::().unwrap(), ) .annotate(*bitwidth) @@ -222,11 +230,12 @@ impl Interpreter { if let zokrates_ast::zir::ZirStatement::Return(v) = folded_function.statements[0].clone() { - v.into_iter() + v.inner + .into_iter() .map(|v| match v { zokrates_ast::zir::ZirExpression::FieldElement( - zokrates_ast::zir::FieldElementExpression::Number(n), - ) => n, + zokrates_ast::zir::FieldElementExpression::Value(n), + ) => n.value, _ => unreachable!(), }) .collect() @@ -256,41 +265,38 @@ impl Interpreter { .collect() } Solver::Xor => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - T::from(2) * x * y] + vec![x + y - T::from(2) * x * y] } Solver::Or => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); + let x = inputs[0]; + let y = inputs[1]; - vec![x.clone() + y.clone() - x * y] + vec![x + y - x * y] } // res = b * c - (2b * c - b - c) * (a) Solver::ShaAndXorAndXorAnd => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![b * c - (T::from(2) * b * c - b - c) * a] } // res = a(b - c) + c Solver::ShaCh => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![a * (b - c.clone()) + c] + let a = inputs[0]; + let b = inputs[1]; + let c = inputs[2]; + vec![a * (b - c) + c] } - Solver::Div => vec![inputs[0] - .clone() - .checked_div(&inputs[1]) - .unwrap_or_else(T::one)], + Solver::Div => vec![inputs[0].checked_div(&inputs[1]).unwrap_or_else(T::one)], Solver::EuclideanDiv => { use num::CheckedDiv; - let n = inputs[0].clone().to_biguint(); - let d = inputs[1].clone().to_biguint(); + let n = inputs[0].to_biguint(); + let d = inputs[1].to_biguint(); let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into()); let r = n - d * &q; @@ -334,6 +340,7 @@ impl Interpreter { &inputs[*n + 8usize..], ) } + _ => unreachable!("unexpected solver"), }; assert_eq!(res.len(), expected_output_count); @@ -354,14 +361,11 @@ pub enum Error { } fn evaluate_lin(w: &Witness, l: &LinComb) -> Result { - l.0.iter() - .map(|(var, mult)| { - w.0.get(var) - .map(|v| v.clone() * mult) - .ok_or(EvaluationError) - }) // get each term - .collect::, _>>() // fail if any term isn't found - .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum + l.value.iter().try_fold(T::from(0), |acc, (var, mult)| { + w.0.get(var) + .map(|v| acc + (*v * mult)) + .ok_or(EvaluationError) // fail if any term isn't found + }) } pub fn evaluate_quad(w: &Witness, q: &QuadComb) -> Result { @@ -434,6 +438,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -450,6 +455,7 @@ mod tests { .iter() .map(|&i| Bn128Field::from(i)) .collect::>(), + &[], ) .unwrap(); let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); @@ -460,9 +466,12 @@ mod tests { #[test] fn bits_of_one() { let inputs = vec![Bn128Field::from(1)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(1)); for r in &res[0..253] { assert_eq!(*r, Bn128Field::from(0)); @@ -472,9 +481,12 @@ mod tests { #[test] fn bits_of_42() { let inputs = vec![Bn128Field::from(42)]; - let res = - Interpreter::execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) - .unwrap(); + let res = Interpreter::execute_solver( + &Solver::Bits(Bn128Field::get_required_bits()), + &inputs, + &[], + ) + .unwrap(); assert_eq!(res[253], Bn128Field::from(0)); assert_eq!(res[252], Bn128Field::from(1)); assert_eq!(res[251], Bn128Field::from(0)); @@ -487,11 +499,51 @@ mod tests { #[test] fn five_hundred_bits_of_1() { let inputs = vec![Bn128Field::from(1)]; - let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs).unwrap(); + let res = Interpreter::execute_solver(&Solver::Bits(500), &inputs, &[]).unwrap(); let mut expected = vec![Bn128Field::from(0); 500]; expected[499] = Bn128Field::from(1); assert_eq!(res, expected); } + + #[test] + fn solver_ref() { + use std::ops::Mul; + use zir::{ + types::{Signature, Type}, + FieldElementExpression, Identifier, IdentifierExpression, Parameter, Variable, + ZirFunction, ZirStatement, + }; + use zokrates_ast::common::RefCall; + + let id = IdentifierExpression::new(Identifier::internal(0usize)); + + // (field i0) -> i0 * i0 + let solvers = vec![Solver::Zir(ZirFunction { + arguments: vec![Parameter::new(Variable::field_element(id.id.clone()), true)], + statements: vec![ZirStatement::ret(vec![FieldElementExpression::mul( + FieldElementExpression::Identifier(id.clone()), + FieldElementExpression::Identifier(id.clone()), + ) + .into()])], + signature: Signature::new() + .inputs(vec![Type::FieldElement]) + .outputs(vec![Type::FieldElement]), + })]; + + let inputs = vec![Bn128Field::from(2)]; + let res = Interpreter::execute_solver( + &Solver::Ref(RefCall { + index: 0, + argument_count: 1, + }), + &inputs, + &solvers, + ) + .unwrap(); + + let expected = vec![Bn128Field::from(4)]; + assert_eq!(res, expected); + } } diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 02374289b..4840d1c48 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.6" +version = "1.1.7" authors = ["Darko Macesic"] edition = "2018" @@ -11,7 +11,7 @@ crate-type = ["cdylib"] js-sys = "0.3.33" serde = { version = "^1.0.59", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"] } -wasm-bindgen = { version = "0.2.46", features = ["serde-serialize"] } +wasm-bindgen = { version = "=0.2.81", features = ["serde-serialize"] } typed-arena = "1.4.1" lazy_static = "1.4.0" zokrates_field = { path = "../zokrates_field" } diff --git a/zokrates_js/index.d.ts b/zokrates_js/index.d.ts index d2755a2c2..ff7cba246 100644 --- a/zokrates_js/index.d.ts +++ b/zokrates_js/index.d.ts @@ -42,7 +42,7 @@ declare module "zokrates-js" { } export interface ComputationResult { - witness: string; + witness: Uint8Array; output: string; snarkjs?: { witness: Uint8Array; @@ -90,7 +90,7 @@ declare module "zokrates-js" { setupWithSrs(srs: Uint8Array, program: Uint8Array): SetupKeypair; generateProof( program: Uint8Array, - witness: string, + witness: Uint8Array, provingKey: Uint8Array, entropy?: string ): Proof; diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 02b6e919e..14ca7ef7e 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.6", + "version": "1.1.7", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates", diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 2b4dccebc..bc7a666c6 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -79,15 +79,17 @@ pub struct ResolverResult { #[wasm_bindgen] pub struct ComputationResult { - witness: String, + witness: Vec, output: String, snarkjs_witness: Option>, } #[wasm_bindgen] impl ComputationResult { - pub fn witness(&self) -> JsValue { - JsValue::from_str(&self.witness) + pub fn witness(&self) -> js_sys::Uint8Array { + let arr = js_sys::Uint8Array::new_with_length(self.witness.len() as u32); + arr.copy_from(&self.witness); + arr } pub fn output(&self) -> JsValue { @@ -151,7 +153,7 @@ impl<'a> Resolver for JsResolver<'a> { Some(Component::Normal(_)) => { let path_normalized = normalize_path(path); let source = STDLIB - .get(&path_normalized.to_str().unwrap()) + .get(path_normalized.to_str().unwrap()) .ok_or_else(|| { Error::new(format!( "module `{}` not found in stdlib", @@ -360,8 +362,14 @@ mod internal { buffer.into_inner() }); + let witness = { + let mut buffer = Cursor::new(vec![]); + witness.write(&mut buffer).unwrap(); + buffer.into_inner() + }; + Ok(ComputationResult { - witness: format!("{}", witness), + witness, output: to_string_pretty(&return_values).unwrap(), snarkjs_witness, }) @@ -416,15 +424,14 @@ mod internal { pub fn generate_proof, B: Backend, R: RngCore + CryptoRng>( prog: ir::Prog, - witness: JsValue, + witness: &[u8], pk: &[u8], rng: &mut R, ) -> Result { - let str_witness = witness.as_string().unwrap(); - let ir_witness: ir::Witness = ir::Witness::read(str_witness.as_bytes()) + let ir_witness: ir::Witness = ir::Witness::read(witness) .map_err(|err| JsValue::from_str(&format!("Could not read witness: {}", err)))?; - let proof = B::generate_proof(prog, ir_witness, pk.to_vec(), rng); + let proof = B::generate_proof(prog, ir_witness, pk, rng); Ok(JsValue::from_serde(&TaggedProof::::new(proof.proof, proof.inputs)).unwrap()) } @@ -513,7 +520,8 @@ pub fn compute_witness( config: JsValue, log_callback: &js_sys::Function, ) -> Result { - let prog = ir::ProgEnum::deserialize(program) + let cursor = Cursor::new(program); + let prog = ir::ProgEnum::deserialize(cursor) .map_err(|err| JsValue::from_str(&err))? .collect(); match prog { @@ -575,7 +583,8 @@ pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result Result Result"] edition = "2018" [dependencies] -pest = "2.0" +pest = "=2.1" pest_derive = "2.0" [dev-dependencies] diff --git a/zokrates_parser/src/lib.rs b/zokrates_parser/src/lib.rs index b4f9cb85c..eb61ce2d8 100644 --- a/zokrates_parser/src/lib.rs +++ b/zokrates_parser/src/lib.rs @@ -12,6 +12,7 @@ use pest::Parser; #[grammar = "zokrates.pest"] struct ZoKratesParser; +#[allow(clippy::result_large_err)] pub fn parse(input: &str) -> Result, Error> { ZoKratesParser::parse(Rule::file, input) } diff --git a/zokrates_pest_ast/Cargo.toml b/zokrates_pest_ast/Cargo.toml index 372eb132d..a71ef66cd 100644 --- a/zokrates_pest_ast/Cargo.toml +++ b/zokrates_pest_ast/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "zokrates_pest_ast" -version = "0.3.1" +version = "0.3.2" authors = ["schaeff "] edition = "2018" [dependencies] zokrates_parser = { version = "0.3.0", path = "../zokrates_parser" } -pest = "2.0" +pest = "=2.1" pest-ast = "0.3.3" from-pest = "0.3.1" lazy_static = "1.3.0" diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 65268b304..8772e82f4 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -1173,6 +1173,7 @@ impl fmt::Display for Error { } } +#[allow(clippy::result_large_err)] pub fn generate_ast(input: &str) -> Result { let parse_tree = parse(input).map_err(Error)?; Ok(Prog::from(parse_tree).0) diff --git a/zokrates_profiler/Cargo.toml b/zokrates_profiler/Cargo.toml new file mode 100644 index 000000000..00c0295f3 --- /dev/null +++ b/zokrates_profiler/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "zokrates_profiler" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +zokrates_ast = { version = "0.1", path = "../zokrates_ast", default-features = false } \ No newline at end of file diff --git a/zokrates_profiler/src/lib.rs b/zokrates_profiler/src/lib.rs new file mode 100644 index 000000000..126178572 --- /dev/null +++ b/zokrates_profiler/src/lib.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +use zokrates_ast::{ + common::{ModuleMap, Span}, + ir::{ProgIterator, Statement}, +}; + +#[derive(Default, Debug)] +pub struct HeatMap { + /// the total number of constraints + count: usize, + /// for each span, how many constraints are linked to it + map: HashMap, usize>, +} + +impl HeatMap { + pub fn display(&self, module_map: &ModuleMap) -> String { + let count = self.count; + + let mut stats: Vec<_> = self.map.iter().collect(); + + stats.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + + stats + .iter() + .map(|(span, c)| { + format!( + "{:>4.2}% : {}", + (**c as f64) / (count as f64) * 100.0, + span.map(|s| s.resolve(module_map).to_string()) + .unwrap_or_else(|| String::from("???")), + ) + }) + .collect::>() + .join("\n") + } +} + +pub fn profile<'ast, T, I: IntoIterator>>( + prog: ProgIterator<'ast, T, I>, +) -> HeatMap { + prog.statements + .into_iter() + .fold(HeatMap::default(), |mut heat_map, s| match s { + Statement::Constraint(s) => { + heat_map.count += 1; + *heat_map.map.entry(s.span).or_default() += 1; + heat_map + } + _ => heat_map, + }) +} diff --git a/zokrates_proof_systems/Cargo.toml b/zokrates_proof_systems/Cargo.toml index 55fc91c82..b0b245101 100644 --- a/zokrates_proof_systems/Cargo.toml +++ b/zokrates_proof_systems/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_proof_systems" -version = "0.1.1" +version = "0.1.2" edition = "2021" [dependencies] diff --git a/zokrates_proof_systems/src/lib.rs b/zokrates_proof_systems/src/lib.rs index 69a9740df..80432e3f6 100644 --- a/zokrates_proof_systems/src/lib.rs +++ b/zokrates_proof_systems/src/lib.rs @@ -96,11 +96,16 @@ impl ToString for G2AffineFq2 { } pub trait Backend> { - fn generate_proof<'a, I: IntoIterator>, R: RngCore + CryptoRng>( + fn generate_proof< + 'a, + I: IntoIterator>, + R: Read, + G: RngCore + CryptoRng, + >( program: ir::ProgIterator<'a, T, I>, witness: ir::Witness, - proving_key: Vec, - rng: &mut R, + proving_key: R, + rng: &mut G, ) -> Proof; fn verify(vk: S::VerificationKey, proof: Proof) -> bool; @@ -129,16 +134,16 @@ pub trait MpcBackend> { ) -> Result<(), String>; fn contribute( - params: &mut R, + params: R, rng: &mut G, output: &mut W, ) -> Result<[u8; 64], String>; - fn verify<'a, P: Read, R: Read, I: IntoIterator>>( - params: &mut P, + fn verify<'a, R: Read, I: IntoIterator>>( + params: R, program: ir::ProgIterator<'a, T, I>, phase1_radix: &mut R, ) -> Result, String>; - fn export_keypair(params: &mut R) -> Result, String>; + fn export_keypair(params: R) -> Result, String>; } diff --git a/zokrates_proof_systems/src/rng.rs b/zokrates_proof_systems/src/rng.rs index 9e2d51df5..814217f27 100644 --- a/zokrates_proof_systems/src/rng.rs +++ b/zokrates_proof_systems/src/rng.rs @@ -5,7 +5,7 @@ use rand_0_8::{rngs::StdRng, SeedableRng}; pub fn get_rng_from_entropy(entropy: &str) -> StdRng { let h = { let mut h = Blake2b::default(); - h.input(&entropy.as_bytes()); + h.input(entropy.as_bytes()); h.result() }; diff --git a/zokrates_proof_systems/src/to_token.rs b/zokrates_proof_systems/src/to_token.rs index fbefdc400..a2be8a963 100644 --- a/zokrates_proof_systems/src/to_token.rs +++ b/zokrates_proof_systems/src/to_token.rs @@ -8,8 +8,8 @@ use super::{ /// Helper methods for parsing group structure pub fn encode_g1_element(g: &G1Affine) -> (U256, U256) { ( - U256::from(&hex::decode(&g.0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1.trim_start_matches("0x")).unwrap()[..]), ) } @@ -17,12 +17,12 @@ pub fn encode_g2_element(g: &G2Affine) -> ((U256, U256), (U256, U256)) { match g { G2Affine::Fq2(g) => ( ( - U256::from(&hex::decode(&g.0 .0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.0 .1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0 .0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.0 .1.trim_start_matches("0x")).unwrap()[..]), ), ( - U256::from(&hex::decode(&g.1 .0.trim_start_matches("0x")).unwrap()[..]), - U256::from(&hex::decode(&g.1 .1.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1 .0.trim_start_matches("0x")).unwrap()[..]), + U256::from(&hex::decode(g.1 .1.trim_start_matches("0x")).unwrap()[..]), ), ), _ => unreachable!(), @@ -30,7 +30,7 @@ pub fn encode_g2_element(g: &G2Affine) -> ((U256, U256), (U256, U256)) { } pub fn encode_fr_element(f: &Fr) -> U256 { - U256::from(&hex::decode(&f.trim_start_matches("0x")).unwrap()[..]) + U256::from(&hex::decode(f.trim_start_matches("0x")).unwrap()[..]) } pub trait ToToken: SolidityCompatibleScheme { diff --git a/zokrates_test/Cargo.toml b/zokrates_test/Cargo.toml index 0a601c703..45e01ce8a 100644 --- a/zokrates_test/Cargo.toml +++ b/zokrates_test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_test" -version = "0.2.1" +version = "0.2.2" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_test/tests/wasm.rs b/zokrates_test/tests/wasm.rs index 8b81f3bf0..6253b39f5 100644 --- a/zokrates_test/tests/wasm.rs +++ b/zokrates_test/tests/wasm.rs @@ -17,9 +17,15 @@ use zokrates_proof_systems::groth16::G16; #[wasm_bindgen_test] fn generate_proof() { let program: Prog = Prog { + module_map: Default::default(), arguments: vec![Parameter::public(Variable::new(0))], return_count: 1, - statements: vec![Statement::constraint(Variable::new(0), Variable::new(0))], + statements: vec![Statement::constraint( + Variable::new(0), + Variable::new(0), + None, + )], + solvers: vec![], }; let interpreter = Interpreter::default(); @@ -29,6 +35,10 @@ fn generate_proof() { let rng = &mut StdRng::from_entropy(); let keypair = >::setup(program.clone(), rng); - let _proof = - >::generate_proof(program, witness, keypair.pk, rng); + let _proof = >::generate_proof( + program, + witness, + keypair.pk.as_slice(), + rng, + ); } diff --git a/zokrates_test_derive/Cargo.toml b/zokrates_test_derive/Cargo.toml index 753297675..5d67294b4 100644 --- a/zokrates_test_derive/Cargo.toml +++ b/zokrates_test_derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_test_derive" -version = "0.0.1" +version = "0.0.2" authors = ["schaeff "] edition = "2018" diff --git a/zokrates_test_derive/src/lib.rs b/zokrates_test_derive/src/lib.rs index 5f7f10f9a..110a02753 100644 --- a/zokrates_test_derive/src/lib.rs +++ b/zokrates_test_derive/src/lib.rs @@ -9,7 +9,7 @@ pub fn write_tests(base: &str) { let base = Path::new(&base); let out_dir = env::var("OUT_DIR").unwrap(); let destination = Path::new(&out_dir).join("tests.rs"); - let test_file = File::create(&destination).unwrap(); + let test_file = File::create(destination).unwrap(); let mut writer = BufWriter::new(test_file); for p in glob(base.join("**/*.json").to_str().unwrap()).unwrap() {