diff --git a/Cargo.lock b/Cargo.lock index 23307eca..7dceb1ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -200,6 +200,29 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.4.1", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.58", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -325,6 +348,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -358,6 +390,17 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "2.34.0" @@ -394,6 +437,15 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +[[package]] +name = "cmake" +version = "0.1.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +dependencies = [ + "cc", +] + [[package]] name = "constantine-core" version = "0.1.0" @@ -623,7 +675,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -753,6 +805,43 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "icicle-bls12-381" +version = "1.9.1" +source = "git+https://github.com/ArtiomTr/icicle.git?rev=2942ad9f9894119f0204325e08ddb55b8a8de227#2942ad9f9894119f0204325e08ddb55b8a8de227" +dependencies = [ + "cmake", + "icicle-core", + "icicle-cuda-runtime", +] + +[[package]] +name = "icicle-core" +version = "1.9.1" +source = "git+https://github.com/ArtiomTr/icicle.git?rev=2942ad9f9894119f0204325e08ddb55b8a8de227#2942ad9f9894119f0204325e08ddb55b8a8de227" +dependencies = [ + "icicle-cuda-runtime", + "rayon", +] + +[[package]] +name = "icicle-cuda-runtime" +version = "1.9.1" +source = "git+https://github.com/ArtiomTr/icicle.git?rev=2942ad9f9894119f0204325e08ddb55b8a8de227#2942ad9f9894119f0204325e08ddb55b8a8de227" +dependencies = [ + "bindgen", + "bitflags 1.3.2", +] + [[package]] name = "impl-codec" version = "0.6.0" @@ -791,7 +880,7 @@ checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi 0.3.3", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -838,6 +927,9 @@ name = "kzg" version = "0.1.0" dependencies = [ "blst", + "icicle-bls12-381", + "icicle-core", + "icicle-cuda-runtime", "num_cpus", "rayon", "sha2 0.10.8", @@ -866,12 +958,28 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +[[package]] +name = "libloading" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +dependencies = [ + "cfg-if", + "windows-targets 0.48.5", +] + [[package]] name = "linux-raw-sys" version = "0.4.10" @@ -899,6 +1007,22 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num-bigint" version = "0.3.3" @@ -1053,6 +1177,16 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "prettyplease" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" +dependencies = [ + "proc-macro2", + "syn 2.0.58", +] + [[package]] name = "primitive-types" version = "0.12.2" @@ -1076,18 +1210,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1139,9 +1273,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -1149,9 +1283,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1292,6 +1426,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hex" version = "2.1.0" @@ -1317,7 +1457,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1374,7 +1514,7 @@ checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.58", ] [[package]] @@ -1437,6 +1577,12 @@ dependencies = [ "opaque-debug", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "siphasher" version = "1.0.0" @@ -1474,9 +1620,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.38" +version = "2.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" dependencies = [ "proc-macro2", "quote", @@ -1613,7 +1759,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.58", "wasm-bindgen-shared", ] @@ -1635,7 +1781,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.58", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1656,6 +1802,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1693,7 +1851,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.4", ] [[package]] @@ -1702,13 +1869,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +dependencies = [ + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -1717,42 +1899,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + [[package]] name = "winnow" version = "0.5.18" @@ -1788,7 +2012,7 @@ checksum = "020f3dfe25dfc38dfea49ce62d5d45ecdd7f0d8a724fa63eb36b6eba4ec76806" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.58", ] [[package]] @@ -1808,5 +2032,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.58", ] diff --git a/arkworks/Cargo.toml b/arkworks/Cargo.toml index c73c9671..ee6e6453 100644 --- a/arkworks/Cargo.toml +++ b/arkworks/Cargo.toml @@ -15,7 +15,7 @@ ark-serialize = { version = "^0.4.2", default-features = false } hex = "0.4.3" rand = { version = "0.8.5", optional = true } libc = { version = "0.2.148", default-features = false } -rayon = { version = "1.8.0", optional = true } +rayon = { version = "1.9.0", optional = true } [dev-dependencies] criterion = "0.5.1" @@ -47,6 +47,9 @@ bgmw = [ arkmsm = [ "kzg/arkmsm" ] +cuda = [ + "kzg/cuda" +] [[bench]] name = "fft" diff --git a/arkworks/src/kzg_types.rs b/arkworks/src/kzg_types.rs index 1d77111b..67ce99f6 100644 --- a/arkworks/src/kzg_types.rs +++ b/arkworks/src/kzg_types.rs @@ -17,6 +17,7 @@ use crate::utils::{ use ark_bls12_381::{g1, g2, Fr, G1Affine, G2Affine}; use ark_ec::{models::short_weierstrass::Projective, AffineRepr, Group}; use ark_ec::{CurveConfig, CurveGroup}; +use ark_ff::BigInt; use ark_ff::{biginteger::BigInteger256, BigInteger, Field}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{One, Zero}; @@ -840,6 +841,16 @@ impl G1Fp for ArkFp { Self(default) } + fn to_limbs(&self) -> [u64; 6] { + self.0.0.0 + } + + fn from_bytes_le(bytes: &[u8; 48]) -> Self { + let storage: [u64; 6] = bytes.chunks(8).map(|it| u64::from_le_bytes(it.try_into().unwrap())).collect::>().try_into().unwrap(); + let big_int = BigInt::new(storage); + Self(ArkFpInt::from(big_int)) + } + fn neg_assign(&mut self) { self.0 = -self.0; } diff --git a/kzg/Cargo.toml b/kzg/Cargo.toml index ad5fb252..75274d8f 100644 --- a/kzg/Cargo.toml +++ b/kzg/Cargo.toml @@ -10,6 +10,9 @@ num_cpus = { version = "1.16.0", optional = true } rayon = { version = "1.8.0", optional = true } threadpool = { version = "^1.8.1", optional = true } siphasher = { version = "1.0.0", default-features = false } +icicle-bls12-381 = { git = "https://github.com/ArtiomTr/icicle.git", rev = "2942ad9f9894119f0204325e08ddb55b8a8de227", version = "1.9.1", optional = true } +icicle-core = { git = "https://github.com/ArtiomTr/icicle.git", rev = "2942ad9f9894119f0204325e08ddb55b8a8de227", version = "1.9.1", optional = true } +icicle-cuda-runtime = { git = "https://github.com/ArtiomTr/icicle.git", rev = "2942ad9f9894119f0204325e08ddb55b8a8de227", version = "1.9.1", optional = true } [features] default = [ @@ -29,3 +32,9 @@ std = [ rand = [] arkmsm = [] bgmw = [] +cuda = [ + "parallel", + "dep:icicle-bls12-381", + "dep:icicle-core", + "dep:icicle-cuda-runtime" +] diff --git a/kzg/src/lib.rs b/kzg/src/lib.rs index 3f546fd8..4d4b2116 100644 --- a/kzg/src/lib.rs +++ b/kzg/src/lib.rs @@ -201,6 +201,10 @@ pub trait G1Fp: Clone + Default + Sync + Copy + PartialEq + Debug + Send { fn set_one(&mut self) { *self = Self::ONE; } + + fn to_limbs(&self) -> [u64; 6]; + + fn from_bytes_le(bytes: &[u8; 48]) -> Self; } pub trait G1Affine: diff --git a/kzg/src/msm/cuda.rs b/kzg/src/msm/cuda.rs new file mode 100644 index 00000000..403d4adf --- /dev/null +++ b/kzg/src/msm/cuda.rs @@ -0,0 +1,122 @@ +use core::marker::PhantomData; + +use icicle_bls12_381::curve::CurveCfg; +use icicle_core::{curve::Affine, msm::{precompute_bases, MSMConfig}, traits::FieldImpl}; +use icicle_cuda_runtime::{memory::HostOrDeviceSlice, device_context::{DeviceContext, DEFAULT_DEVICE_ID}}; +use core::fmt::Debug; +use crate::{Fr, G1Affine, G1Fp, G1GetFp, G1Mul, Scalar256, G1}; + +use super::msm_impls::batch_convert; + +pub struct IcicleConfig +where + TFr: Fr, + TG1: G1 + G1Mul + G1GetFp, + TG1Fp: G1Fp, + TG1Affine: G1Affine, +{ + affines: HostOrDeviceSlice<'static, Affine>, + + g1_marker: PhantomData, + g1_fp_marker: PhantomData, + fr_marker: PhantomData, + g1_affine_marker: PhantomData +} + +impl< +TFr: Fr, +TG1Fp: G1Fp, +TG1: G1 + G1Mul + G1GetFp, +TG1Affine: G1Affine, +> Debug for IcicleConfig { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + // TODO: add formatting for affines + f.debug_struct("IcicleConfig").finish() + } +} + +impl< +TFr: Fr, +TG1Fp: G1Fp, +TG1: G1 + G1Mul + G1GetFp, +TG1Affine: G1Affine, +> Clone for IcicleConfig { + fn clone(&self) -> Self { + // FIXME: affines should be cloned actually + Self { affines: HostOrDeviceSlice::Host(vec![]), g1_marker: PhantomData, g1_fp_marker: PhantomData, fr_marker: PhantomData, g1_affine_marker: PhantomData } + } +} + +const PRECOMPUTE_FACTOR: usize = 8; + +impl< + TFr: Fr, + TG1Fp: G1Fp, + TG1: G1 + G1Mul + G1GetFp, + TG1Affine: G1Affine, + > IcicleConfig +{ + pub fn new(points: &[TG1]) -> Result, String> { + let affines_raw = batch_convert::(points).iter().map(|it| icicle_bls12_381::curve::G1Affine::from_limbs(it.x().to_limbs(), it.y().to_limbs())).collect::>(); + // let Ok(mut affines) = HostOrDeviceSlice::<'static, Affine>::cuda_malloc(affines_raw.len()) else { + // return Ok(None); + // }; + // if affines.copy_from_host(&affines_raw).is_err() { + // return Ok(None); + // } + let device_affines = HostOrDeviceSlice::on_host(affines_raw); + + let Ok(mut affines) = HostOrDeviceSlice::<'static, Affine>::cuda_malloc(points.len() * PRECOMPUTE_FACTOR) else { + return Ok(None); + }; + + if precompute_bases(&device_affines, PRECOMPUTE_FACTOR as i32, 0, &DeviceContext::default_for_device(DEFAULT_DEVICE_ID), &mut affines).is_err() { + return Ok(None); + } + + Ok(Some(Self { + affines, + + fr_marker: PhantomData, + g1_fp_marker: PhantomData, + g1_marker: PhantomData, + g1_affine_marker: PhantomData + })) + } + + pub fn multiply_sequential(&self, _scalars: &[Scalar256]) -> TG1 { + panic!("No sequential implementation for CUDA MSM"); + } + + #[cfg(feature = "parallel")] + pub fn multiply_parallel(&self, scalars: &[Scalar256]) -> TG1 { + use icicle_bls12_381::curve::ScalarField; + use icicle_core::curve::Projective; + use icicle_cuda_runtime::stream::CudaStream; + + let mut results = HostOrDeviceSlice::cuda_malloc(1).unwrap(); + let mut scalars_d = HostOrDeviceSlice::cuda_malloc(scalars.len()).unwrap(); + let stream = CudaStream::create().unwrap(); + scalars_d.copy_from_host_async(&scalars.iter().map(|it| ScalarField::from_bytes_le(it.as_u8())).collect::>(), &stream).unwrap(); + let mut config = MSMConfig::default_for_device(DEFAULT_DEVICE_ID); + config.precompute_factor = PRECOMPUTE_FACTOR as i32; + config.ctx.stream = &stream; + config.is_async = true; + + icicle_core::msm::msm(&scalars_d, &self.affines, &config, &mut results).unwrap(); + + let mut results_h = vec![Projective::::zero(); 1]; + results.copy_to_host_async(&mut results_h, &stream); + + stream.synchronize().unwrap(); + stream.destroy().unwrap(); + + let mut output = TG1::default(); + + *output.x_mut() = TG1Fp::from_bytes_le(&results_h.as_slice()[0].x.to_bytes_le().try_into().unwrap()); + *output.y_mut() = TG1Fp::from_bytes_le(&results_h.as_slice()[0].y.to_bytes_le().try_into().unwrap()); + *output.z_mut() = TG1Fp::from_bytes_le(&results_h.as_slice()[0].z.to_bytes_le().try_into().unwrap()); + + output + } +} diff --git a/kzg/src/msm/mod.rs b/kzg/src/msm/mod.rs index 9167c71e..c2fb3026 100644 --- a/kzg/src/msm/mod.rs +++ b/kzg/src/msm/mod.rs @@ -15,3 +15,11 @@ mod pippenger_utils; #[cfg(all(feature = "bgmw", any(not(feature = "arkmsm"), feature = "parallel")))] mod bgmw; + +#[cfg(feature = "cuda")] +mod cuda; + +#[cfg(all(feature = "cuda", feature = "bgmw"))] +compile_error!{"features `cuda` and `bgmw` are mutally exclusive"} +#[cfg(all(feature = "cuda", not(feature = "parallel")))] +compile_error!{"feature `cuda` requires feature `parallel`"} diff --git a/kzg/src/msm/msm_impls.rs b/kzg/src/msm/msm_impls.rs index d44a0348..d4ca4349 100644 --- a/kzg/src/msm/msm_impls.rs +++ b/kzg/src/msm/msm_impls.rs @@ -59,7 +59,7 @@ fn msm_sequential< } } -fn batch_convert + Sized>( +pub fn batch_convert + Sized>( points: &[TG1], ) -> Vec { #[cfg(feature = "parallel")] diff --git a/kzg/src/msm/precompute.rs b/kzg/src/msm/precompute.rs index 7eacbf98..9a435acc 100644 --- a/kzg/src/msm/precompute.rs +++ b/kzg/src/msm/precompute.rs @@ -9,7 +9,7 @@ pub type PrecomputationTable = super::bgmw::BgmwTable; #[cfg(any( - not(feature = "bgmw"), + all(not(feature = "bgmw"), not(feature = "cuda")), all(feature = "arkmsm", not(feature = "parallel")) ))] #[derive(Debug, Clone)] @@ -27,7 +27,7 @@ where } #[cfg(any( - not(feature = "bgmw"), + all(not(feature = "bgmw"), not(feature = "cuda")), all(feature = "arkmsm", not(feature = "parallel")) ))] impl EmptyTable @@ -52,11 +52,14 @@ where } #[cfg(any( - not(feature = "bgmw"), + all(not(feature = "bgmw"), not(feature = "cuda")), all(feature = "arkmsm", not(feature = "parallel")) ))] pub type PrecomputationTable = EmptyTable; +#[cfg(feature = "cuda")] +pub type PrecomputationTable = super::cuda::IcicleConfig; + pub fn precompute( points: &[TG1], ) -> Result>, String>