From ec4c9e1c95a05206681d122c70d9176ec68f0eae Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sun, 7 Jan 2024 12:14:05 +0000 Subject: [PATCH] Improve test, minor cleanups --- package.json | 2 +- src/quantized_mistral.rs | 14 +++++--------- tests/web.rs | 12 +++++++++++- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/package.json b/package.json index 49aca4c..88d11c7 100644 --- a/package.json +++ b/package.json @@ -10,7 +10,7 @@ "build": "wasm-pack build -s text-yoga --dev", "build:release": "wasm-pack build -s text-yoga --release", "test:server": "npx http-server --cors -p 31300 ./tests/data", - "test:chrome": "wasm-pack test --chrome --chromedriver \"$(which chromedriver)\" --headless", + "test:chrome": "RUST_LOG=wasm_bindgen_test_runner wasm-pack -vvv test --chrome --chromedriver \"$(which chromedriver)\" --headless", "test:gecko": "wasm-pack test --firefox --geckodriver \"$(which geckodriver)\" --headless" }, "keywords": [], diff --git a/src/quantized_mistral.rs b/src/quantized_mistral.rs index e811212..a1f724f 100644 --- a/src/quantized_mistral.rs +++ b/src/quantized_mistral.rs @@ -11,7 +11,7 @@ use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; use web_time as time; -use gloo::console::log; +use gloo::console::{debug, log}; use web_sys::console; #[wasm_bindgen] @@ -84,9 +84,6 @@ impl Model { let mut total_size_in_bytes = 0; for (_, tensor) in model.tensor_infos.iter() { let elem_count = tensor.shape.elem_count(); - log!("elem_count", elem_count); - log!("type_size", tensor.ggml_dtype.type_size()); - log!("blck_size", tensor.ggml_dtype.blck_size()); total_size_in_bytes += // Very important to keep the parenthesis here, otherwise might overflow (in test). 
elem_count * (tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size()); @@ -128,11 +125,10 @@ impl Model { repeat_penalty: f32, seed: u64, ) -> Result { - console::log_1(&"Initialising prompt with temperature=".into()); - console::log_2(&"temperature=".into(), &temp.into()); - console::log_2(&"top_p=".into(), &top_p.into()); - console::log_2(&"repeat_penalty=".into(), &repeat_penalty.into()); - console::log_2(&"seed=".into(), &seed.into()); + debug!( + "Initialising prompt with temperature=", + temp, "top_p=", top_p, "repeat_penalty=", repeat_penalty, "seed=", seed + ); // First reset the cache. // { // let mut cache = self.inner.cache.kvs.lock().unwrap(); diff --git a/tests/web.rs b/tests/web.rs index b97966b..624cab0 100644 --- a/tests/web.rs +++ b/tests/web.rs @@ -39,9 +39,19 @@ async fn pass() -> Result<(), JsValue> { let top_p: f64 = 1.; let repeat_penalty: f32 = 1.1; let seed: u64 = 203948203948; + let max_token = 20; let first_result: String = model.init_with_prompt(prompt, temp, top_p, repeat_penalty, seed)?; log!("first prompt result", &first_result); - assert_eq!(1 + 1, 2); + + let mut result = first_result; + for _ in 0..max_token { + let tok = model.next_token()?; + result += &tok; + log!(tok); + } + log!("\nComplete result:"); + log!(result); + assert_eq!(1 + 1, 2); Ok(()) }