diff --git a/Cargo.lock b/Cargo.lock index 7886609e..cf85b495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,7 +87,7 @@ dependencies = [ [[package]] name = "ai00_server" -version = "0.4.1" +version = "0.4.2" dependencies = [ "anyhow", "bit-set", @@ -668,10 +668,35 @@ dependencies = [ ] [[package]] -name = "com-rs" -version = "0.2.1" +name = "com" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e17887fd17353b65b1b2ef1c526c83e26cd72e74f598a8dc1bee13a48f3d9f6" +dependencies = [ + "com_macros", +] + +[[package]] +name = "com_macros" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d375883580a668c7481ea6631fc1a8863e33cc335bf56bfad8d7e6d4b04b13a5" +dependencies = [ + "com_macros_support", + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "com_macros_support" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf43edc576402991846b093a7ca18a3477e0ef9c588cde84964b5d3e43016642" +checksum = "ad899a1087a9296d5644792d7cb72b8e34c1bec8e7d4fbc002230169a6e8710c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] [[package]] name = "constant_time_eq" @@ -794,9 +819,9 @@ dependencies = [ [[package]] name = "d3d12" -version = "0.7.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20" +checksum = "3e3d747f100290a1ca24b752186f61f6637e1deffe3bf6320de6fcb29510a307" dependencies = [ "bitflags 2.5.0", "libloading 0.8.3", @@ -1201,11 +1226,10 @@ dependencies = [ [[package]] name = "gpu-allocator" -version = "0.23.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fe17c8a05d60c38c0a4e5a3c802f2f1ceb66b76c67d96ffb34bef0475a7fad" +checksum = "6f56f6318968d03c18e1bcf4857ff88c61157e9da8e47c5f29055d60e1228884" dependencies = [ - "backtrace", "log", "presser", "thiserror", @@ -1285,14 +1309,14 @@ dependencies = [ [[package]] name = "hassle-rs" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1397650ee315e8891a0df210707f0fc61771b0cc518c3023896064c5407cb3b0" +checksum = "af2a7e73e1f34c48da31fb668a907f250794837e08faa144fd24f0b8b741e890" dependencies = [ - "bitflags 1.3.2", - "com-rs", + "bitflags 2.5.0", + "com", "libc", - "libloading 0.7.4", + "libloading 0.8.3", "thiserror", "widestring", "winapi", @@ -1583,6 +1607,12 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.28" @@ -1826,9 +1856,9 @@ dependencies = [ [[package]] name = "naga" -version = "0.14.2" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae585df4b6514cf8842ac0f1ab4992edc975892704835b549cf818dc0191249e" +checksum = "50e3524642f53d9af419ab5e8dd29d3ba155708267667c2f3f06c88c9e130843" dependencies = [ "bit-set", "bitflags 2.5.0", @@ -1853,6 +1883,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "ndk-sys" +version = "0.5.0+25.2.9519653" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" +dependencies = [ + "jni-sys", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -2317,9 +2356,9 @@ checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" [[package]] name = "raw-window-handle" -version = "0.5.2" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +checksum = "8cc3bcbdb1ddfc11e700e62968e6b4cc9c75bb466464ad28fb61c5b2c964418b" [[package]] name = "rayon" @@ -3160,12 +3199,11 @@ dependencies = [ [[package]] name = "spirv" -version = "0.2.0+1.5.4" +version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830" +checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 1.3.2", - "num-traits", + "bitflags 2.5.0", ] [[package]] @@ -3826,9 +3864,9 @@ checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "web-rwkv" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f660f511cb5744347a37a97fa2762a6a32dfd3d85a9900db419fd2151b761456" +checksum = "22fc4bc8cbb7445bb71cb399875079b90dfa744b046d90f405785dd12677cd81" dependencies = [ "ahash 0.8.11", "anyhow", @@ -3881,9 +3919,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" dependencies = [ "js-sys", "wasm-bindgen", @@ -3901,13 +3939,13 @@ dependencies = [ [[package]] name = "wgpu" -version = "0.18.0" +version = "0.19.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e7d227c9f961f2061c26f4cb0fbd4df0ef37e056edd0931783599d6c94ef24" +checksum = "cbd7311dbd2abcfebaabf1841a2824ed7c8be443a0f29166e5d3c6a53a762c01" dependencies = [ "arrayvec", "cfg-if", - "flume", + "cfg_aliases", "js-sys", "log", "naga", @@ -3926,16 +3964,19 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "0.18.1" +version = "0.19.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef91c1d62d1e9e81c79e600131a258edf75c9531cbdbde09c44a011a47312726" +checksum = "28b94525fc99ba9e5c9a9e24764f2bc29bad0911a7446c12f446a8277369bf3a" dependencies = [ "arrayvec", "bit-vec", "bitflags 2.5.0", + "cfg_aliases", "codespan-reporting", + "indexmap 2.2.5", "log", "naga", + "once_cell", "parking_lot", "profiling", "raw-window-handle", @@ -3949,9 +3990,9 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "0.18.1" +version = "0.19.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84ecc802da3eb67b4cf3dd9ea6fe45bbb47ef13e6c49c5c3240868a9cc6cdd9" +checksum = "fc1a4924366df7ab41a5d8546d6534f1f33231aa5b3f72b9930e300f254e39c3" dependencies = [ "android_system_properties", "arrayvec", @@ -3959,6 +4000,7 @@ dependencies = [ "bit-set", "bitflags 2.5.0", "block", + "cfg_aliases", "core-graphics-types", "d3d12", "glow", @@ -3974,6 +4016,7 @@ dependencies = [ "log", "metal", "naga", + "ndk-sys", "objc", "once_cell", "parking_lot", @@ -3992,9 +4035,9 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "0.18.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d5ed5f0edf0de351fe311c53304986315ce866f394a2e6df0c4b3c70774bcdd" +checksum = "b671ff9fb03f78b46ff176494ee1ebe7d603393f42664be55b64dc8d53969805" dependencies = [ "bitflags 2.5.0", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 67361357..24469c85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" name = "ai00_server" repository = "https://github.com/cgisky1980/ai00_rwkv_server" rust-version = "1.75" -version = "0.4.1" +version = "0.4.2" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -53,7 +53,7 @@ zip-extract = "0.1" [dependencies.web-rwkv] # path = "../web-rwkv" -version = "0.7.5" +version = "0.7.6" [dependencies.salvo] default-features = true diff --git a/src/middleware.rs b/src/middleware.rs index 87e1aaff..cc7f05a1 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -10,7 +10,6 @@ use anyhow::{bail, Result}; use bnf_sampler::{utils::U8ArrayWrapper, vocabulary::Vocabulary}; use derivative::Derivative; use flume::{Receiver, Sender}; -use half::f16; use itertools::Itertools; use memmap2::Mmap; use qp_trie::Trie; @@ -250,6 +249,7 @@ fn list_adapters() -> AdapterList { let instance = Instance::new(); let list = instance .enumerate_adapters(backends) + .into_iter() .map(|adapter| adapter.get_info()) .map(|info| format!("{} ({:?})", info.name, info.backend)) .collect(); @@ -265,7 +265,7 @@ async fn create_context(adapter: AdapterOption, info: &ModelInfo) -> Result instance.select_adapter(backends, selection), }?; let context = ContextBuilder::new(adapter) - .with_auto_limits(info) + .auto_limits(info) .build() .await?; Ok(context) @@ -305,11 +305,7 @@ async fn load_runtime( ) -> Result { let ReloadRequest { model_path, - lora, - quant, - quant_type, max_batch, - embed_device, tokenizer_path, .. } = reload.clone(); @@ -320,81 +316,83 @@ async fn load_runtime( let file = File::open(model_path).await?; let data = unsafe { Mmap::map(&file) }?; - let runtime = match load { - LoadType::SafeTensors => { - let model = SafeTensors::deserialize(&data)?; - let quant = (0..quant).map(|layer| (layer, quant_type)).collect(); - let lora = { - let mut x = Vec::with_capacity(lora.len()); - for lora in lora.into_iter() { - let file = File::open(lora.path).await?; - let data = unsafe { Mmap::map(&file) }?; - let blend = LoraBlend::full(lora.alpha); - x.push((data, blend)) - } - x - }; - let lora: Vec<_> = lora - .iter() - .map(|(data, blend)| -> Result<_> { - let data = SafeTensors::deserialize(data)?; - let blend = blend.clone(); - Ok(Lora { data, blend }) - }) - .try_collect()?; - - let builder = ModelBuilder::new(context, model) - .with_quant(quant) - .with_num_batch(max_batch) - .with_embed_device(embed_device); - let builder = lora.into_iter().fold(builder, |b, x| b.add_lora(x)); - - let context = context.clone(); - let reload = reload.clone(); - match info.version { - ModelVersion::V4 => { - let builder = Build::>::build(builder).await?; - Runtime::new(context, builder, reload, tokenizer, vocab).await - } - ModelVersion::V5 => { - let builder = Build::>::build(builder).await?; - Runtime::new(context, builder, reload, tokenizer, vocab).await - } - ModelVersion::V6 => { - let builder = Build::>::build(builder).await?; - Runtime::new(context, builder, reload, tokenizer, vocab).await - } + async fn load_model( + context: &Context, + load: LoadType, + data: &[u8], + reload: &ReloadRequest, + ) -> Result + where + for<'a> ModelBuilder>: Build, + for<'de> Seed<'de, Context, M>: DeserializeSeed<'de, Value = M>, + { + match load { + LoadType::SafeTensors => { + let ReloadRequest { + lora, + quant, + quant_type, + embed_device, + .. + } = reload.clone(); + + let model = SafeTensors::deserialize(data)?; + let quant = (0..quant).map(|layer| (layer, quant_type)).collect(); + let lora = { + let mut x = Vec::with_capacity(lora.len()); + for lora in lora.into_iter() { + let file = File::open(lora.path).await?; + let data = unsafe { Mmap::map(&file) }?; + let blend = LoraBlend::full(lora.alpha); + x.push((data, blend)) + } + x + }; + let lora: Vec<_> = lora + .iter() + .map(|(data, blend)| -> Result<_> { + let data = SafeTensors::deserialize(data)?; + let blend = blend.clone(); + Ok(Lora { data, blend }) + }) + .try_collect()?; + + let builder = ModelBuilder::new(context, model) + .quant(quant) + .embed_device(embed_device); + let builder = lora.into_iter().fold(builder, |b, x| b.lora(x)); + Build::::build(builder).await } - } - LoadType::Prefab => { - use cbor4ii::{core::utils::SliceReader, serde::Deserializer}; - - let reader = SliceReader::new(&data); - let mut deserializer = Deserializer::new(reader); - - let context = context.clone(); - let reload = reload.clone(); - match info.version { - ModelVersion::V4 => { - let seed: Seed<_, v4::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v4::ModelJobBuilder::new(model, reload.max_batch); - Runtime::new(context, builder, reload, tokenizer, vocab).await - } - ModelVersion::V5 => { - let seed: Seed<_, v5::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v5::ModelJobBuilder::new(model, reload.max_batch); - Runtime::new(context, builder, reload, tokenizer, vocab).await - } - ModelVersion::V6 => { - let seed: Seed<_, v6::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v6::ModelJobBuilder::new(model, reload.max_batch); - Runtime::new(context, builder, reload, tokenizer, vocab).await - } + LoadType::Prefab => { + use cbor4ii::{core::utils::SliceReader, serde::Deserializer}; + + let reader = SliceReader::new(data); + let mut deserializer = Deserializer::new(reader); + + let seed: Seed = Seed::new(context); + Ok(seed.deserialize(&mut deserializer)?) } } + } + + let context = context.clone(); + let reload = reload.clone(); + let runtime = match info.version { + ModelVersion::V4 => { + let model = load_model::(&context, load, &data, &reload).await?; + let builder = v4::ModelJobBuilder::new(model, max_batch); + Runtime::new(context, builder, reload, tokenizer, vocab).await + } + ModelVersion::V5 => { + let model = load_model::(&context, load, &data, &reload).await?; + let builder = v5::ModelJobBuilder::new(model, max_batch); + Runtime::new(context, builder, reload, tokenizer, vocab).await + } + ModelVersion::V6 => { + let model = load_model::(&context, load, &data, &reload).await?; + let builder = v6::ModelJobBuilder::new(model, max_batch); + Runtime::new(context, builder, reload, tokenizer, vocab).await + } }; Ok(runtime)