From 2157236c7aa736415ce68ae60042d1024515df44 Mon Sep 17 00:00:00 2001 From: cryscan Date: Tue, 1 Aug 2023 14:01:51 +0800 Subject: [PATCH] Select adapter. --- Cargo.lock | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++--- Cargo.toml | 5 +++-- src/main.rs | 21 +++++++++++++++++-- 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a17fd432..551b87fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,11 +62,12 @@ dependencies = [ [[package]] name = "ai00_server" -version = "0.1.9" +version = "0.1.10" dependencies = [ "anyhow", "axum", "clap", + "dialoguer", "fastrand", "flume", "futures-util", @@ -511,6 +512,19 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf43edc576402991846b093a7ca18a3477e0ef9c588cde84964b5d3e43016642" +[[package]] +name = "console" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.45.0", +] + [[package]] name = "constant_time_eq" version = "0.1.5" @@ -609,6 +623,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "dialoguer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59c6f2989294b9a498d3ad5491a79c6deb604617378e1cdc4bfc1c1361fe2f87" +dependencies = [ + "console", + "shell-words", + "tempfile", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -626,6 +652,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "errno" version = "0.3.1" @@ -1636,6 +1668,12 @@ dependencies = [ "digest", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -2123,9 +2161,9 @@ checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "web-rwkv" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e78c58d2a71eaa5f37264059731deacefad97d16cadcd4039138def6c4b9e1b" +checksum = "abceb817a7f97c9ad12ada520f53d7790704409a3e1c1d498e2e26788afc6af9" dependencies = [ "ahash 0.8.3", "anyhow", @@ -2312,6 +2350,15 @@ dependencies = [ "windows_x86_64_msvc 0.42.2", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -2435,6 +2482,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index cb376758..9cf02f38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ai00_server" -version = "0.1.9" +version = "0.1.10" edition = "2021" authors = ["Gu ZhenNiu <448885@qq.com>", "Zhang Zhenyuan "] license = "MIT OR Apache-2.0" @@ -18,9 +18,10 @@ axum = { git = "https://github.com/cryscan/axum", branch = "sse-leading-space" } tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.4", features = ["full"] } tokio = { version = "1", features = ["full"] } -web-rwkv = "0.1.16" +web-rwkv = "0.1.17" memmap = "0.7" regex = "1.8" +dialoguer = "0.10" clap = { version = "4.3", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1" diff --git a/src/main.rs b/src/main.rs index 529df39d..6b94ddc0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use axum::{ Router, }; use clap::Parser; +use dialoguer::{theme::ColorfulTheme, Select}; use flume::Receiver; use memmap::Mmap; use qp_trie::Trie; @@ -19,7 +20,8 @@ use std::{ }; use tower_http::{cors::CorsLayer, services::ServeDir}; use web_rwkv::{ - BackedModelState, Environment, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer, + BackedModelState, Environment, Instance, LayerFlags, Model, ModelBuilder, Quantization, + Tokenizer, }; mod chat; @@ -117,6 +119,21 @@ pub struct ReloadRequest { pub quantized_layers: Vec, } +async fn create_environment() -> Result { + let instance = Instance::new(); + let adapters = instance.adapters(); + let selection = Select::with_theme(&ColorfulTheme::default()) + .with_prompt("Please select an adapter") + .default(0) + .items(&adapters) + .interact()?; + + let adapter = instance.select_adapter(selection)?; + let env = Environment::new(adapter).await?; + println!("{:#?}", env.adapter.get_info()); + Ok(env) +} + fn load_tokenizer(path: &PathBuf) -> Result { let file = File::open(path)?; let mut reader = BufReader::new(file); @@ -368,7 +385,7 @@ async fn main() -> Result<()> { ); let (sender, receiver) = flume::unbounded::(); - let env = Environment::create().await?; + let env = create_environment().await?; let tokenizer = load_tokenizer(&tokenizer_path)?; log::info!("{:#?}", env.adapter.get_info());