Skip to content

Commit

Permalink
Select adapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Aug 1, 2023
1 parent c413f38 commit 2157236
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
59 changes: 56 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ai00_server"
version = "0.1.9"
version = "0.1.10"
edition = "2021"
authors = ["Gu ZhenNiu <[email protected]>", "Zhang Zhenyuan <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand All @@ -18,9 +18,10 @@ axum = { git = "https://github.com/cryscan/axum", branch = "sse-leading-space" }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.4", features = ["full"] }
tokio = { version = "1", features = ["full"] }
web-rwkv = "0.1.16"
web-rwkv = "0.1.17"
memmap = "0.7"
regex = "1.8"
dialoguer = "0.10"
clap = { version = "4.3", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
Expand Down
21 changes: 19 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use axum::{
Router,
};
use clap::Parser;
use dialoguer::{theme::ColorfulTheme, Select};
use flume::Receiver;
use memmap::Mmap;
use qp_trie::Trie;
Expand All @@ -19,7 +20,8 @@ use std::{
};
use tower_http::{cors::CorsLayer, services::ServeDir};
use web_rwkv::{
BackedModelState, Environment, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer,
BackedModelState, Environment, Instance, LayerFlags, Model, ModelBuilder, Quantization,
Tokenizer,
};

mod chat;
Expand Down Expand Up @@ -117,6 +119,21 @@ pub struct ReloadRequest {
pub quantized_layers: Vec<usize>,
}

async fn create_environment() -> Result<Environment> {
let instance = Instance::new();
let adapters = instance.adapters();
let selection = Select::with_theme(&ColorfulTheme::default())
.with_prompt("Please select an adapter")
.default(0)
.items(&adapters)
.interact()?;

let adapter = instance.select_adapter(selection)?;
let env = Environment::new(adapter).await?;
println!("{:#?}", env.adapter.get_info());
Ok(env)
}

fn load_tokenizer(path: &PathBuf) -> Result<Tokenizer> {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
Expand Down Expand Up @@ -368,7 +385,7 @@ async fn main() -> Result<()> {
);

let (sender, receiver) = flume::unbounded::<ThreadRequest>();
let env = Environment::create().await?;
let env = create_environment().await?;
let tokenizer = load_tokenizer(&tokenizer_path)?;

log::info!("{:#?}", env.adapter.get_info());
Expand Down

0 comments on commit 2157236

Please sign in to comment.