From 718066b9d1d17ecdb417b94ee661afd38098c06f Mon Sep 17 00:00:00 2001 From: bskrlj Date: Tue, 21 Nov 2023 21:52:13 +0100 Subject: [PATCH 01/21] does this actually work? --- src/block_ffm.rs | 48 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/src/block_ffm.rs b/src/block_ffm.rs index 952be685..3e2c779e 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -7,6 +7,7 @@ use std::error::Error; use std::mem::{self, MaybeUninit}; use std::sync::Mutex; use std::{io, ptr}; +use std::slice; use merand48::*; @@ -829,7 +830,19 @@ impl BlockTrait for BlockFFM { &self, output_bufwriter: &mut dyn io::Write, ) -> Result<(), Box> { - block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)?; + + let mut v = Vec::<[u8; 3]>::with_capacity(self.weights.len()); + for weight in self.weights.iter() { + let tmp_bytes = weight.to_be_bytes(); + let mut tmp_vec = [0, 0, 0]; + tmp_vec[0] = tmp_bytes[0]; + tmp_vec[1] = tmp_bytes[1]; + tmp_vec[2] = tmp_bytes[2]; + v.push(tmp_vec); + } + +// block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)?; + block_helpers::write_weights_to_buf(&v, output_bufwriter)?; block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter)?; Ok(()) } @@ -838,7 +851,38 @@ impl BlockTrait for BlockFFM { &mut self, input_bufreader: &mut dyn io::Read, ) -> Result<(), Box> { - block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader)?; + + unsafe { + let buf_view: &mut [u8] = slice::from_raw_parts_mut( + self.weights.as_mut_ptr() as *mut u8, + self.weights.len() * 3, + ); + + let mut tmp_weights: Vec = Vec::new(); + input_bufreader.read_exact(buf_view)?; + for (wb_index, weight_byte) in buf_view.iter().enumerate(){ + if wb_index > 0 && wb_index % 4 == 0 { + tmp_weights.push(0 as u8); + } else { + tmp_weights.push(*weight_byte); + } + } + + for (index, byte_array) in tmp_weights.chunks(4).enumerate() { + + let mut out_ary: [u8; 4] = [0; 4]; + + out_ary[0] = byte_array[0]; + out_ary[1] = byte_array[1]; + out_ary[2] = byte_array[2]; +// out_ary[3] = byte_array[3]; + + let float: f32 = f32::from_be_bytes(out_ary); + self.weights[index] = float; + } + } + + //block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader)?; block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader)?; Ok(()) } From 2df835aa50dee4cbdfab8788e968fb9b1ecca4b1 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 23 Nov 2023 14:17:15 +0100 Subject: [PATCH 02/21] operational 3by --- src/block_ffm.rs | 65 +++++++++++++------------------------------- src/block_helpers.rs | 2 ++ src/block_lr.rs | 6 ++-- src/block_neural.rs | 12 ++++---- src/cmdline.rs | 5 ++++ src/lib.rs | 1 + src/main.rs | 6 ++-- src/persistence.rs | 16 ++++++++--- src/quantization.rs | 45 ++++++++++++++++++++++++++++++ src/regressor.rs | 13 ++++++--- 10 files changed, 107 insertions(+), 64 deletions(-) create mode 100644 src/quantization.rs diff --git a/src/block_ffm.rs b/src/block_ffm.rs index 3e2c779e..c1a042d9 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -7,7 +7,6 @@ use std::error::Error; use std::mem::{self, MaybeUninit}; use std::sync::Mutex; use std::{io, ptr}; -use std::slice; use merand48::*; @@ -24,6 +23,7 @@ use crate::optimizer; use crate::port_buffer; use crate::port_buffer::PortBuffer; use crate::regressor; +use crate::quantization; use crate::regressor::{BlockCache, FFM_CONTRA_BUF_LEN}; const FFM_STACK_BUF_LEN: usize = 131072; @@ -829,61 +829,34 @@ impl BlockTrait for BlockFFM { fn write_weights_to_buf( &self, output_bufwriter: &mut dyn io::Write, + use_quantization: bool ) -> Result<(), Box> { - let mut v = Vec::<[u8; 3]>::with_capacity(self.weights.len()); - for weight in self.weights.iter() { - let tmp_bytes = weight.to_be_bytes(); - let mut tmp_vec = [0, 0, 0]; - tmp_vec[0] = tmp_bytes[0]; - tmp_vec[1] = tmp_bytes[1]; - tmp_vec[2] = tmp_bytes[2]; - v.push(tmp_vec); + if use_quantization { + + let quantized_weights = quantization::quantize_ffm_weights_3by(&self.weights); + block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?; + } else { + block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?; } - -// block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)?; - block_helpers::write_weights_to_buf(&v, output_bufwriter)?; - block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter)?; + block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter, false)?; Ok(()) } fn read_weights_from_buf( &mut self, input_bufreader: &mut dyn io::Read, + use_quantization: bool ) -> Result<(), Box> { - - unsafe { - let buf_view: &mut [u8] = slice::from_raw_parts_mut( - self.weights.as_mut_ptr() as *mut u8, - self.weights.len() * 3, - ); - - let mut tmp_weights: Vec = Vec::new(); - input_bufreader.read_exact(buf_view)?; - for (wb_index, weight_byte) in buf_view.iter().enumerate(){ - if wb_index > 0 && wb_index % 4 == 0 { - tmp_weights.push(0 as u8); - } else { - tmp_weights.push(*weight_byte); - } - } - - for (index, byte_array) in tmp_weights.chunks(4).enumerate() { - - let mut out_ary: [u8; 4] = [0; 4]; - - out_ary[0] = byte_array[0]; - out_ary[1] = byte_array[1]; - out_ary[2] = byte_array[2]; -// out_ary[3] = byte_array[3]; - - let float: f32 = f32::from_be_bytes(out_ary); - self.weights[index] = float; - } - } - //block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader)?; - block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader)?; + if use_quantization { + // in-place expand weights via dequantization (for inference) + quantization::dequantize_ffm_weights_3by(input_bufreader, &mut self.weights); + } else { + block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?; + } + + block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader, false)?; Ok(()) } @@ -910,7 +883,7 @@ impl BlockTrait for BlockFFM { .as_any() .downcast_mut::>() .unwrap(); - block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader)?; + block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?; block_helpers::skip_weights_from_buf::>( self.ffm_weights_len as usize, input_bufreader, diff --git a/src/block_helpers.rs b/src/block_helpers.rs index 89a8d571..27022df0 100644 --- a/src/block_helpers.rs +++ b/src/block_helpers.rs @@ -43,6 +43,7 @@ macro_rules! assert_epsilon { pub fn read_weights_from_buf( weights: &mut Vec, input_bufreader: &mut dyn io::Read, + _use_quantization: bool ) -> Result<(), Box> { if weights.is_empty() { return Err("Loading weights to unallocated weighs buffer".to_string())?; @@ -74,6 +75,7 @@ pub fn skip_weights_from_buf( pub fn write_weights_to_buf( weights: &Vec, output_bufwriter: &mut dyn io::Write, + _use_quantization: bool ) -> Result<(), Box> { if weights.is_empty() { assert!(false); diff --git a/src/block_lr.rs b/src/block_lr.rs index ba7e67cc..5bbd17ba 100644 --- a/src/block_lr.rs +++ b/src/block_lr.rs @@ -263,15 +263,17 @@ impl BlockTrait for BlockLR { fn read_weights_from_buf( &mut self, input_bufreader: &mut dyn io::Read, + _use_quantization: bool ) -> Result<(), Box> { - block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader) + block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false) } fn write_weights_to_buf( &self, output_bufwriter: &mut dyn io::Write, + _use_quantization: bool ) -> Result<(), Box> { - block_helpers::write_weights_to_buf(&self.weights, output_bufwriter) + block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false) } fn read_weights_from_buf_into_forward_only( diff --git a/src/block_neural.rs b/src/block_neural.rs index fee21f4d..7c159c06 100644 --- a/src/block_neural.rs +++ b/src/block_neural.rs @@ -430,18 +430,20 @@ impl BlockTrait for BlockNeuronLayer { fn write_weights_to_buf( &self, output_bufwriter: &mut dyn io::Write, + _use_quantization: bool ) -> Result<(), Box> { - block_helpers::write_weights_to_buf(&self.weights, output_bufwriter)?; - block_helpers::write_weights_to_buf(&self.weights_optimizer, output_bufwriter)?; + block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?; + block_helpers::write_weights_to_buf(&self.weights_optimizer, output_bufwriter, false)?; Ok(()) } fn read_weights_from_buf( &mut self, input_bufreader: &mut dyn io::Read, + _use_quantization: bool ) -> Result<(), Box> { - block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader)?; - block_helpers::read_weights_from_buf(&mut self.weights_optimizer, input_bufreader)?; + block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?; + block_helpers::read_weights_from_buf(&mut self.weights_optimizer, input_bufreader, false)?; Ok(()) } @@ -469,7 +471,7 @@ impl BlockTrait for BlockNeuronLayer { .as_any() .downcast_mut::>() .unwrap(); - block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader)?; + block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?; block_helpers::skip_weights_from_buf::>( self.weights_len as usize, input_bufreader, diff --git a/src/cmdline.rs b/src/cmdline.rs index 2279c91b..5b749c03 100644 --- a/src/cmdline.rs +++ b/src/cmdline.rs @@ -309,6 +309,11 @@ pub fn create_expected_args<'a>() -> App<'a, 'a> { .value_name("num_threads") .help("Number of threads to use with hogwild training") .takes_value(true)) + .arg(Arg::with_name("weight_quantization") + .long("weight_quantization") + .value_name("Whether to consider weight quantization when reading/writing weights.") + .help("Half-float quantization trigger (inference only is the suggested use).") + .takes_value(false)) .arg(Arg::with_name("predictions_stdout") .long("predictions_stdout") .value_name("Output predictions to stdout") diff --git a/src/lib.rs b/src/lib.rs index 0c4b8231..f14408cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod quantization; pub mod block_ffm; pub mod block_helpers; pub mod block_loss_functions; diff --git a/src/main.rs b/src/main.rs index 4a23b243..2457ad16 100644 --- a/src/main.rs +++ b/src/main.rs @@ -114,7 +114,7 @@ fn main2() -> Result<(), Box> { }; let testonly = cl.is_present("testonly"); - + let quantize_weights = cl.is_present("weight_quantization"); let final_regressor_filename = cl.value_of("final_regressor"); let output_pred_sto: bool = cl.is_present("predictions_stdout"); if let Some(filename) = final_regressor_filename { @@ -149,7 +149,7 @@ fn main2() -> Result<(), Box> { new_regressor_from_filename(filename, true, Option::Some(&cl))?; mi2.optimizer = Optimizer::SGD; if let Some(filename1) = inference_regressor_filename { - save_regressor_to_filename(filename1, &mi2, &vw2, re_fixed).unwrap() + save_regressor_to_filename(filename1, &mi2, &vw2, re_fixed, quantize_weights).unwrap() } } else { let vw: VwNamespaceMap; @@ -296,7 +296,7 @@ fn main2() -> Result<(), Box> { log::info!("Elapsed: {:.2?} rows: {}", elapsed, example_num); if let Some(filename) = final_regressor_filename { - save_sharable_regressor_to_filename(filename, &mi, &vw, sharable_regressor) + save_sharable_regressor_to_filename(filename, &mi, &vw, sharable_regressor, quantize_weights) .unwrap() } } diff --git a/src/persistence.rs b/src/persistence.rs index c32ac0db..6bbf4fc8 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -57,6 +57,7 @@ pub fn save_sharable_regressor_to_filename( mi: &model_instance::ModelInstance, vwmap: &vwmap::VwNamespaceMap, re: BoxedRegressorTrait, + quantize_weights: bool ) -> Result<(), Box> { let output_bufwriter = &mut io::BufWriter::new( fs::File::create(filename) @@ -65,7 +66,7 @@ pub fn save_sharable_regressor_to_filename( write_regressor_header(output_bufwriter)?; vwmap.save_to_buf(output_bufwriter)?; mi.save_to_buf(output_bufwriter)?; - re.write_weights_to_buf(output_bufwriter)?; + re.write_weights_to_buf(output_bufwriter, quantize_weights)?; Ok(()) } @@ -74,6 +75,7 @@ pub fn save_regressor_to_filename( mi: &model_instance::ModelInstance, vwmap: &vwmap::VwNamespaceMap, re: Regressor, + quantize_weights: bool, ) -> Result<(), Box> { let output_bufwriter = &mut io::BufWriter::new( fs::File::create(filename) @@ -82,7 +84,7 @@ pub fn save_regressor_to_filename( write_regressor_header(output_bufwriter)?; vwmap.save_to_buf(output_bufwriter)?; mi.save_to_buf(output_bufwriter)?; - re.write_weights_to_buf(output_bufwriter)?; + re.write_weights_to_buf(output_bufwriter, quantize_weights)?; Ok(()) } @@ -136,9 +138,15 @@ pub fn new_regressor_from_filename( > { let mut input_bufreader = io::BufReader::new(fs::File::open(filename).unwrap()); let (mut mi, vw, mut re) = load_regressor_without_weights(&mut input_bufreader, cmd_arguments)?; + + // reading logic is for some reason different, so doing this again here .. + let quantization_flag = cmd_arguments.unwrap().is_present("weight_quantization"); + let conversion_flag = cmd_arguments.unwrap().is_present("convert_inference_regressor"); + let weight_quantization = quantization_flag && !conversion_flag; + log::info!("Reading weights, dequantization enabled: {}", weight_quantization); if !immutable { re.allocate_and_init_weights(&mi); - re.overwrite_weights_from_buf(&mut input_bufreader)?; + re.overwrite_weights_from_buf(&mut input_bufreader, weight_quantization)?; Ok((mi, vw, re)) } else { mi.optimizer = model_instance::Optimizer::SGD; @@ -154,7 +162,7 @@ pub fn hogwild_load(re: &mut regressor::Regressor, filename: &str) -> Result<(), let (_, _, mut re_hw) = load_regressor_without_weights(&mut input_bufreader, None)?; // TODO: Here we should do safety comparison that the regressor is really the same; if !re.immutable { - re.overwrite_weights_from_buf(&mut input_bufreader)?; + re.overwrite_weights_from_buf(&mut input_bufreader, false)?; } else { re_hw.into_immutable_regressor_from_buf(re, &mut input_bufreader)?; } diff --git a/src/quantization.rs b/src/quantization.rs new file mode 100644 index 00000000..b1b51f91 --- /dev/null +++ b/src/quantization.rs @@ -0,0 +1,45 @@ + +use std::slice; +use std::io; + + +pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; 3]> { + // Quantize float-based weights to three most significant bytes + + let mut v = Vec::<[u8; 3]>::with_capacity(weights.len()); + for &weight in weights { + let tmp_bytes = weight.to_be_bytes(); + let tmp_vec = [tmp_bytes[0], tmp_bytes[1], tmp_bytes[2]]; + v.push(tmp_vec); + } + debug_assert_eq!(v.len(), weights.len()); + v +} + + +pub fn dequantize_ffm_weights_3by(input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec) { + // This function overwrites existing weights with dequantized ones from the input buffer. + + unsafe { + let buf_view: &mut [u8] = slice::from_raw_parts_mut( + reference_weights.as_mut_ptr() as *mut u8, + reference_weights.len() * 3, + ); + let _ = input_bufreader.read_exact(buf_view); + + let tmp_weights: Vec = buf_view + .chunks(3) + .flat_map(|chunk| chunk.iter().chain(std::iter::once(&0u8)).cloned()) + .collect(); + + for (chunk, float_ref) in tmp_weights.chunks(4).zip(reference_weights.iter_mut()) { + + let mut out_ary: [u8; 4] = [0; 4]; + out_ary[0] = chunk[0]; + out_ary[1] = chunk[1]; + out_ary[2] = chunk[2]; + + *float_ref = f32::from_be_bytes(out_ary); + } + } +} diff --git a/src/regressor.rs b/src/regressor.rs index 9626b989..f45c63f6 100644 --- a/src/regressor.rs +++ b/src/regressor.rs @@ -100,6 +100,7 @@ pub trait BlockTrait { fn write_weights_to_buf( &self, _output_bufwriter: &mut dyn io::Write, + _use_quantization: bool ) -> Result<(), Box> { Ok(()) } @@ -107,6 +108,7 @@ pub trait BlockTrait { fn read_weights_from_buf( &mut self, _input_bufreader: &mut dyn io::Read, + _use_quantization: bool ) -> Result<(), Box> { Ok(()) } @@ -423,6 +425,7 @@ impl Regressor { pub fn write_weights_to_buf( &self, output_bufwriter: &mut dyn io::Write, + quantize_weights: bool ) -> Result<(), Box> { let length = self .blocks_boxes @@ -430,9 +433,9 @@ impl Regressor { .map(|block| block.get_serialized_len()) .sum::() as u64; output_bufwriter.write_u64::(length)?; - + log::info!("Write Quantization enabled: {}", quantize_weights); for v in &self.blocks_boxes { - v.write_weights_to_buf(output_bufwriter)?; + v.write_weights_to_buf(output_bufwriter, quantize_weights)?; } Ok(()) } @@ -440,6 +443,7 @@ impl Regressor { pub fn overwrite_weights_from_buf( &mut self, input_bufreader: &mut dyn io::Read, + use_quantization: bool ) -> Result<(), Box> { // This is a bit weird format // You would expect each block to have its own sig @@ -457,7 +461,7 @@ impl Regressor { ))?; } for v in &mut self.blocks_boxes { - v.read_weights_from_buf(input_bufreader)?; + v.read_weights_from_buf(input_bufreader, use_quantization)?; } Ok(()) @@ -505,6 +509,7 @@ impl Regressor { pub fn immutable_regressor( &mut self, mi: &model_instance::ModelInstance, + use_quantization: bool ) -> Result> { // Only to be used by unit tests // make sure we are creating immutable regressor from SGD mi @@ -515,7 +520,7 @@ impl Regressor { let mut tmp_vec: Vec = Vec::new(); for (i, v) in &mut self.blocks_boxes.iter().enumerate() { let mut cursor = Cursor::new(&mut tmp_vec); - v.write_weights_to_buf(&mut cursor)?; + v.write_weights_to_buf(&mut cursor, use_quantization)?; cursor.set_position(0); v.read_weights_from_buf_into_forward_only(&mut cursor, &mut rg.blocks_boxes[i])?; } From 08d32bcae07b5148ef30a6ef5b263af82bda70b7 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 23 Nov 2023 14:40:31 +0100 Subject: [PATCH 03/21] tests --- src/persistence.rs | 25 ++++++++++++++++--------- src/serving.rs | 8 ++++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/persistence.rs b/src/persistence.rs index 6bbf4fc8..4f5a9cba 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -140,8 +140,15 @@ pub fn new_regressor_from_filename( let (mut mi, vw, mut re) = load_regressor_without_weights(&mut input_bufreader, cmd_arguments)?; // reading logic is for some reason different, so doing this again here .. - let quantization_flag = cmd_arguments.unwrap().is_present("weight_quantization"); - let conversion_flag = cmd_arguments.unwrap().is_present("convert_inference_regressor"); + + let mut quantization_flag = false; + let mut conversion_flag = false; + + if cmd_arguments.is_some(){ + quantization_flag = cmd_arguments.unwrap().is_present("weight_quantization"); + conversion_flag = cmd_arguments.unwrap().is_present("convert_inference_regressor"); + } + let weight_quantization = quantization_flag && !conversion_flag; log::info!("Reading weights, dequantization enabled: {}", weight_quantization); if !immutable { @@ -217,7 +224,7 @@ B,featureB let rr = regressor::get_regressor_with_weights(&mi); let dir = tempfile::tempdir().unwrap(); let regressor_filepath = dir.path().join("test_regressor.fw"); - save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, rr).unwrap(); + save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, rr, false).unwrap(); } fn lr_vec(v: Vec) -> feature_buffer::FeatureBuffer { @@ -268,7 +275,7 @@ B,featureB // Now we test conversion to fixed regressor { mi.optimizer = model_instance::Optimizer::SGD; - let re_fixed = re.immutable_regressor(&mi).unwrap(); + let re_fixed = re.immutable_regressor(&mi, false).unwrap(); // predict with the same feature vector assert_eq!(re_fixed.predict(fbuf, &mut pb), expected_result); mi.optimizer = model_instance::Optimizer::AdagradFlex; @@ -277,7 +284,7 @@ B,featureB { let dir = tempdir().unwrap(); let regressor_filepath = dir.path().join("test_regressor2.fw"); - save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, re).unwrap(); + save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, re, false).unwrap(); // a) load as regular regressor let (_mi2, _vw2, mut re2) = @@ -372,7 +379,7 @@ B,featureB // Now we test conversion to fixed regressor { mi.optimizer = Optimizer::SGD; - let re_fixed = re.immutable_regressor(&mi).unwrap(); + let re_fixed = re.immutable_regressor(&mi, false).unwrap(); // predict with the same feature vector mi.optimizer = Optimizer::AdagradFlex; assert_epsilon!(re_fixed.predict(fbuf, &mut pb), expected_result); @@ -381,7 +388,7 @@ B,featureB { let dir = tempdir().unwrap(); let regressor_filepath = dir.path().join("test_regressor2.fw"); - save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, re).unwrap(); + save_regressor_to_filename(regressor_filepath.to_str().unwrap(), &mi, &vw, re, false).unwrap(); // a) load as regular regressor let (_mi2, _vw2, mut re2) = @@ -545,14 +552,14 @@ B,featureB .to_str() .unwrap() .to_owned(); - save_regressor_to_filename(®ressor_filepath_1, &mi, &vw, re_1).unwrap(); + save_regressor_to_filename(®ressor_filepath_1, &mi, &vw, re_1, false).unwrap(); let regressor_filepath_2 = dir .path() .join("test_regressor2.fw") .to_str() .unwrap() .to_owned(); - save_regressor_to_filename(®ressor_filepath_2, &mi, &vw, re_2).unwrap(); + save_regressor_to_filename(®ressor_filepath_2, &mi, &vw, re_2, false).unwrap(); // The mutable path let (_mi1, _vw1, mut new_re_1) = diff --git a/src/serving.rs b/src/serving.rs index 87ee5315..40442188 100644 --- a/src/serving.rs +++ b/src/serving.rs @@ -287,7 +287,7 @@ C,featureC mi.optimizer = model_instance::Optimizer::AdagradLUT; let mut re = regressor::Regressor::new(&mi); mi.optimizer = model_instance::Optimizer::SGD; - let re_fixed = BoxedRegressorTrait::new(Box::new(re.immutable_regressor(&mi).unwrap())); + let re_fixed = BoxedRegressorTrait::new(Box::new(re.immutable_regressor(&mi, false).unwrap())); let fbt = feature_buffer::FeatureBufferTranslator::new(&mi); let pa = parser::VowpalParser::new(&vw); let pb = re_fixed.new_portbuffer(); @@ -386,7 +386,7 @@ C,featureC .to_str() .unwrap() .to_owned(); - persistence::save_regressor_to_filename(®ressor_filepath_1, &mi, &vw, re_1).unwrap(); + persistence::save_regressor_to_filename(®ressor_filepath_1, &mi, &vw, re_1, false).unwrap(); let regressor_filepath_2 = dir .path() @@ -394,13 +394,13 @@ C,featureC .to_str() .unwrap() .to_owned(); - persistence::save_regressor_to_filename(®ressor_filepath_2, &mi, &vw, re_2).unwrap(); + persistence::save_regressor_to_filename(®ressor_filepath_2, &mi, &vw, re_2, false).unwrap(); // OK NOW EVERYTHING IS READY... Let's start mi.optimizer = model_instance::Optimizer::AdagradLUT; let mut re = regressor::Regressor::new(&mi); mi.optimizer = model_instance::Optimizer::SGD; - let re_fixed = BoxedRegressorTrait::new(Box::new(re.immutable_regressor(&mi).unwrap())); + let re_fixed = BoxedRegressorTrait::new(Box::new(re.immutable_regressor(&mi, false).unwrap())); let fbt = feature_buffer::FeatureBufferTranslator::new(&mi); let pa = parser::VowpalParser::new(&vw); let pb = re_fixed.new_portbuffer(); From c604e9f694b511eea76072f7cb939645b1fd6b13 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Mon, 27 Nov 2023 07:06:46 +0100 Subject: [PATCH 04/21] quantization internal for convert --- src/main.rs | 3 +++ src/model_instance.rs | 4 ++++ src/persistence.rs | 2 +- src/quantization.rs | 36 +++++++++++++++++------------------- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2457ad16..a69e45dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -148,6 +148,9 @@ fn main2() -> Result<(), Box> { let (mut mi2, vw2, re_fixed) = new_regressor_from_filename(filename, true, Option::Some(&cl))?; mi2.optimizer = Optimizer::SGD; + if cl.is_present("weight_quantization") { + mi2.dequantize_weights = true; + } if let Some(filename1) = inference_regressor_filename { save_regressor_to_filename(filename1, &mi2, &vw2, re_fixed, quantize_weights).unwrap() } diff --git a/src/model_instance.rs b/src/model_instance.rs index ef0f88ec..f842fed4 100644 --- a/src/model_instance.rs +++ b/src/model_instance.rs @@ -92,6 +92,9 @@ pub struct ModelInstance { pub optimizer: Optimizer, pub transform_namespaces: feature_transform_parser::NamespaceTransforms, + + pub dequantize_weights: bool, + } fn default_u32_zero() -> u32 { @@ -142,6 +145,7 @@ impl ModelInstance { optimizer: Optimizer::SGD, transform_namespaces: feature_transform_parser::NamespaceTransforms::new(), nn_config: NNConfig::new(), + dequantize_weights: false, }; Ok(mi) } diff --git a/src/persistence.rs b/src/persistence.rs index 4f5a9cba..703f0020 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -145,7 +145,7 @@ pub fn new_regressor_from_filename( let mut conversion_flag = false; if cmd_arguments.is_some(){ - quantization_flag = cmd_arguments.unwrap().is_present("weight_quantization"); + quantization_flag = mi.dequantize_weights; conversion_flag = cmd_arguments.unwrap().is_present("convert_inference_regressor"); } diff --git a/src/quantization.rs b/src/quantization.rs index b1b51f91..324a3972 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -2,15 +2,19 @@ use std::slice; use std::io; +const BY_3: usize = 3; -pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; 3]> { +pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_3]> { // Quantize float-based weights to three most significant bytes - let mut v = Vec::<[u8; 3]>::with_capacity(weights.len()); + let mut v = Vec::<[u8; BY_3]>::with_capacity(weights.len()); for &weight in weights { let tmp_bytes = weight.to_be_bytes(); - let tmp_vec = [tmp_bytes[0], tmp_bytes[1], tmp_bytes[2]]; - v.push(tmp_vec); + let mut out_ary: [u8; BY_3] = [0; BY_3]; + for k in 0..BY_3 { + out_ary[k] = tmp_bytes[k]; + } + v.push(out_ary); } debug_assert_eq!(v.len(), weights.len()); v @@ -23,23 +27,17 @@ pub fn dequantize_ffm_weights_3by(input_bufreader: &mut dyn io::Read, reference_ unsafe { let buf_view: &mut [u8] = slice::from_raw_parts_mut( reference_weights.as_mut_ptr() as *mut u8, - reference_weights.len() * 3, + reference_weights.len() * BY_3, ); let _ = input_bufreader.read_exact(buf_view); - let tmp_weights: Vec = buf_view - .chunks(3) - .flat_map(|chunk| chunk.iter().chain(std::iter::once(&0u8)).cloned()) - .collect(); + let mut out_ary: [u8; 4] = [0; 4]; + for (chunk, float_ref) in buf_view.chunks(3).zip(reference_weights.iter_mut()) { - for (chunk, float_ref) in tmp_weights.chunks(4).zip(reference_weights.iter_mut()) { - - let mut out_ary: [u8; 4] = [0; 4]; - out_ary[0] = chunk[0]; - out_ary[1] = chunk[1]; - out_ary[2] = chunk[2]; - - *float_ref = f32::from_be_bytes(out_ary); - } - } + for k in 0..BY_3 { + out_ary[k] = chunk[k]; + } + *float_ref = f32::from_be_bytes(out_ary); + } + } } From b44dd97d04fa3eb3fd181a081ac74d3e63198aa1 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Mon, 27 Nov 2023 07:16:59 +0100 Subject: [PATCH 05/21] quantization test --- src/quantization.rs | 48 ++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index 324a3972..490e8576 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -1,6 +1,5 @@ - -use std::slice; use std::io; +use std::slice; const BY_3: usize = 3; @@ -10,34 +9,47 @@ pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_3]> { let mut v = Vec::<[u8; BY_3]>::with_capacity(weights.len()); for &weight in weights { let tmp_bytes = weight.to_be_bytes(); - let mut out_ary: [u8; BY_3] = [0; BY_3]; - for k in 0..BY_3 { - out_ary[k] = tmp_bytes[k]; - } + let mut out_ary: [u8; BY_3] = [0; BY_3]; + for k in 0..BY_3 { + out_ary[k] = tmp_bytes[k]; + } v.push(out_ary); } debug_assert_eq!(v.len(), weights.len()); v } - -pub fn dequantize_ffm_weights_3by(input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec) { +pub fn dequantize_ffm_weights_3by( + input_bufreader: &mut dyn io::Read, + reference_weights: &mut Vec, +) { // This function overwrites existing weights with dequantized ones from the input buffer. unsafe { let buf_view: &mut [u8] = slice::from_raw_parts_mut( - reference_weights.as_mut_ptr() as *mut u8, - reference_weights.len() * BY_3, + reference_weights.as_mut_ptr() as *mut u8, + reference_weights.len() * BY_3, ); let _ = input_bufreader.read_exact(buf_view); - let mut out_ary: [u8; 4] = [0; 4]; - for (chunk, float_ref) in buf_view.chunks(3).zip(reference_weights.iter_mut()) { + let mut out_ary: [u8; 4] = [0; 4]; + for (chunk, float_ref) in buf_view.chunks(3).zip(reference_weights.iter_mut()) { + for k in 0..BY_3 { + out_ary[k] = chunk[k]; + } + *float_ref = f32::from_be_bytes(out_ary); + } + } +} - for k in 0..BY_3 { - out_ary[k] = chunk[k]; - } - *float_ref = f32::from_be_bytes(out_ary); - } - } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_quantize_3by() { + let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + let output_weights = quantize_ffm_weights_3by(&some_random_float_weights); + assert_eq!(output_weights[3], [61, 252, 80]); + } } From 982dfc8e20c7fe87573be673d715e41b40852064 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Tue, 28 Nov 2023 09:04:11 +0100 Subject: [PATCH 06/21] bfloat ftw --- Cargo.toml | 1 + src/main.rs | 1 + src/quantization.rs | 29 ++++++++++++++++------------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0c8273b0..61ff8bf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ intel-mkl-src = {version= "0.8.1", default-features = false, features=["mkl-stat log = "0.4.18" env_logger = "0.10.0" rustc-hash = "1.1.0" +half = "2.3.1" [build-dependencies] cbindgen = "0.23.0" diff --git a/src/main.rs b/src/main.rs index a69e45dc..612c6777 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ use std::time::Instant; extern crate blas; extern crate intel_mkl_src; +extern crate half; #[macro_use] extern crate nom; diff --git a/src/quantization.rs b/src/quantization.rs index 490e8576..47c9942f 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -1,16 +1,18 @@ use std::io; use std::slice; +use half::bf16; -const BY_3: usize = 3; +const BY_X: usize = 2; -pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_3]> { + +pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_X]> { // Quantize float-based weights to three most significant bytes - let mut v = Vec::<[u8; BY_3]>::with_capacity(weights.len()); + let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len()); for &weight in weights { - let tmp_bytes = weight.to_be_bytes(); - let mut out_ary: [u8; BY_3] = [0; BY_3]; - for k in 0..BY_3 { + let tmp_bytes = (bf16::from_f32(weight)).to_be_bytes(); + let mut out_ary: [u8; BY_X] = [0; BY_X]; + for k in 0..BY_X { out_ary[k] = tmp_bytes[k]; } v.push(out_ary); @@ -28,16 +30,17 @@ pub fn dequantize_ffm_weights_3by( unsafe { let buf_view: &mut [u8] = slice::from_raw_parts_mut( reference_weights.as_mut_ptr() as *mut u8, - reference_weights.len() * BY_3, + reference_weights.len() * BY_X, ); let _ = input_bufreader.read_exact(buf_view); - let mut out_ary: [u8; 4] = [0; 4]; - for (chunk, float_ref) in buf_view.chunks(3).zip(reference_weights.iter_mut()) { - for k in 0..BY_3 { + let mut out_ary: [u8; 2] = [0; 2]; + for (chunk, float_ref) in buf_view.chunks(BY_X).zip(reference_weights.iter_mut()) { + for k in 0..BY_X { out_ary[k] = chunk[k]; } - *float_ref = f32::from_be_bytes(out_ary); + let weight = bf16::to_f32(bf16::from_be_bytes(out_ary)); + *float_ref = weight; } } } @@ -47,9 +50,9 @@ mod tests { use super::*; #[test] - fn test_quantize_3by() { + fn test_quantize_2by() { let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; let output_weights = quantize_ffm_weights_3by(&some_random_float_weights); - assert_eq!(output_weights[3], [61, 252, 80]); + assert_eq!(output_weights[3], [61, 252]); } } From f83c336d40a8642c3c9ece02668de54ce72f49c9 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 1 Dec 2023 12:50:45 +0100 Subject: [PATCH 07/21] 1by --- src/quantization.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index 47c9942f..d51ddee8 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -1,8 +1,8 @@ use std::io; use std::slice; -use half::bf16; +//use half::bf16; -const BY_X: usize = 2; +const BY_X: usize = 1; pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_X]> { @@ -10,7 +10,7 @@ pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_X]> { let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len()); for &weight in weights { - let tmp_bytes = (bf16::from_f32(weight)).to_be_bytes(); + let tmp_bytes = (weight).to_le_bytes(); let mut out_ary: [u8; BY_X] = [0; BY_X]; for k in 0..BY_X { out_ary[k] = tmp_bytes[k]; @@ -34,12 +34,13 @@ pub fn dequantize_ffm_weights_3by( ); let _ = input_bufreader.read_exact(buf_view); - let mut out_ary: [u8; 2] = [0; 2]; + let mut out_ary: [u8; 4] = [0; 4]; for (chunk, float_ref) in buf_view.chunks(BY_X).zip(reference_weights.iter_mut()) { for k in 0..BY_X { out_ary[k] = chunk[k]; } - let weight = bf16::to_f32(bf16::from_be_bytes(out_ary)); + let weight = f32::from_le_bytes(out_ary); +// let weight = bf16::to_f32(bf16::from_be_bytes(out_ary)); *float_ref = weight; } } From f1bd94776a4e2642a48c3b3a5576ab228a00980b Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 1 Dec 2023 12:54:20 +0100 Subject: [PATCH 08/21] tests --- src/quantization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantization.rs b/src/quantization.rs index d51ddee8..ea4274a2 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -54,6 +54,6 @@ mod tests { fn test_quantize_2by() { let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; let output_weights = quantize_ffm_weights_3by(&some_random_float_weights); - assert_eq!(output_weights[3], [61, 252]); + assert_eq!(output_weights[3], [72]); } } From 8870793692ef8c6ac34745fe56dcab8a8c0e5cbb Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 1 Dec 2023 12:56:45 +0100 Subject: [PATCH 09/21] dummy cmt --- src/quantization.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/quantization.rs b/src/quantization.rs index ea4274a2..92d51e31 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -40,6 +40,7 @@ pub fn dequantize_ffm_weights_3by( out_ary[k] = chunk[k]; } let weight = f32::from_le_bytes(out_ary); + // uncomment for 16b // let weight = bf16::to_f32(bf16::from_be_bytes(out_ary)); *float_ref = weight; } From 37df5f7bbb79206aa4665cc6e8305ab9b0df5bca Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 1 Dec 2023 13:20:26 +0100 Subject: [PATCH 10/21] 16b --- src/quantization.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index 92d51e31..ea13f72e 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -2,7 +2,7 @@ use std::io; use std::slice; //use half::bf16; -const BY_X: usize = 1; +const BY_X: usize = 2; pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_X]> { @@ -55,6 +55,6 @@ mod tests { fn test_quantize_2by() { let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; let output_weights = quantize_ffm_weights_3by(&some_random_float_weights); - assert_eq!(output_weights[3], [72]); + assert_eq!(output_weights[3], [72, 80]); } } From 8b339e0cf3e92660b095852691dec9c7cbfe3d3c Mon Sep 17 00:00:00 2001 From: bskrlj Date: Sun, 3 Dec 2023 09:42:09 +0100 Subject: [PATCH 11/21] range-based quantization --- src/block_ffm.rs | 4 +- src/quantization.rs | 177 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 146 insertions(+), 35 deletions(-) diff --git a/src/block_ffm.rs b/src/block_ffm.rs index c1a042d9..4c879034 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -834,7 +834,7 @@ impl BlockTrait for BlockFFM { if use_quantization { - let quantized_weights = quantization::quantize_ffm_weights_3by(&self.weights); + let quantized_weights = quantization::quantize_ffm_weights(&self.weights); block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?; } else { block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?; @@ -851,7 +851,7 @@ impl BlockTrait for BlockFFM { if use_quantization { // in-place expand weights via dequantization (for inference) - quantization::dequantize_ffm_weights_3by(input_bufreader, &mut self.weights); + quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights); } else { block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?; } diff --git a/src/quantization.rs b/src/quantization.rs index ea13f72e..463791ea 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -1,50 +1,126 @@ use std::io; -use std::slice; -//use half::bf16; +use half::f16; const BY_X: usize = 2; +const NUM_BUCKETS: f32 = 65025.0; +const CRITICAL_WEIGHT_BOUND: f32 = 10.0; // naive detection of really bad weights, this should never get to prod. +const MEAN_SAMPLING_RATIO: usize = 10; -pub fn quantize_ffm_weights_3by(weights: &[f32]) -> Vec<[u8; BY_X]> { +#[derive(Debug)] +struct WeightStat { + min: f32, + max: f32, + mean: f32 +} + + +fn emit_weight_statistics(weights: &[f32]) -> WeightStat { + // Bound estimator for quantization range + + let init_weight = weights[0]; + let mut min_weight = init_weight; + let mut max_weight = init_weight; + let mut mean_weight = 0.0; + let mut weight_counter = 0; + + for (enx, weight) in weights.iter().enumerate() { + + if *weight > max_weight { + max_weight = *weight; + } + + if *weight < min_weight { + min_weight = *weight; + } + + if enx % MEAN_SAMPLING_RATIO == 0 { + weight_counter += 1; + mean_weight += *weight; + } + + } + + log::info!("Weight values; min: {}, max: {}, mean: {}", min_weight, max_weight, mean_weight / weight_counter as f32); + + WeightStat{min: min_weight, max: max_weight, mean: mean_weight} +} + + +pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { // Quantize float-based weights to three most significant bytes + // To be more precise in terms of representation of ranges, we extend the weight object with a "header" that contains two floats required for proper dequantization -- this is computed on-the-fly, works better + + + let weight_statistics = emit_weight_statistics(weights); + + // Cheap, yet important check + if weight_statistics.mean > CRITICAL_WEIGHT_BOUND || weight_statistics.mean < -CRITICAL_WEIGHT_BOUND { + panic!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); + } + // Uniform distribution within the relevant interval + let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len()); - for &weight in weights { - let tmp_bytes = (weight).to_le_bytes(); - let mut out_ary: [u8; BY_X] = [0; BY_X]; - for k in 0..BY_X { - out_ary[k] = tmp_bytes[k]; - } - v.push(out_ary); - } - debug_assert_eq!(v.len(), weights.len()); + + // Increment needs to be stored + let weight_increment_bytes = weight_increment.to_le_bytes(); + let deq_header1 = [weight_increment_bytes[0], weight_increment_bytes[1]]; + let deq_header2 = [weight_increment_bytes[2], weight_increment_bytes[3]]; + v.push(deq_header1); + v.push(deq_header2); + + // Minimal value needs to be stored + let min_val_bytes = weight_statistics.min.to_le_bytes(); + let deq_header3 = [min_val_bytes[0], min_val_bytes[1]]; + let deq_header4 = [min_val_bytes[2], min_val_bytes[3]]; + v.push(deq_header3); + v.push(deq_header4); + + for weight in weights { + + let weight_interval = ((*weight - weight_statistics.min) / weight_increment).round(); + let weight_interval_bytes = f16::to_le_bytes(f16::from_f32(weight_interval)); + v.push(weight_interval_bytes); + + } + + // This is done during reading, so fine as a sanity here. + assert_eq!(v.len() - 4, weights.len()); + v } -pub fn dequantize_ffm_weights_3by( +pub fn dequantize_ffm_weights( input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec, ) { // This function overwrites existing weights with dequantized ones from the input buffer. - unsafe { - let buf_view: &mut [u8] = slice::from_raw_parts_mut( - reference_weights.as_mut_ptr() as *mut u8, - reference_weights.len() * BY_X, - ); - let _ = input_bufreader.read_exact(buf_view); - - let mut out_ary: [u8; 4] = [0; 4]; - for (chunk, float_ref) in buf_view.chunks(BY_X).zip(reference_weights.iter_mut()) { - for k in 0..BY_X { - out_ary[k] = chunk[k]; - } - let weight = f32::from_le_bytes(out_ary); - // uncomment for 16b -// let weight = bf16::to_f32(bf16::from_be_bytes(out_ary)); - *float_ref = weight; - } + let mut header: [u8; 8] = [0; 8]; + let _ = input_bufreader.read_exact(&mut header); + + let mut incr_vec: [u8; 4] = [0; 4]; + let mut min_vec: [u8; 4] = [0; 4]; + + for j in 0..4 { + incr_vec[j] = header[j]; + min_vec[j] = header[j + 4]; } + + let weight_increment = f32::from_le_bytes(incr_vec); + let weight_min = f32::from_le_bytes(min_vec); + let mut weight_bytes: [u8; 2] = [0; 2]; + + // All set, dequantize in a stream + for weight_index in 0..reference_weights.len(){ + let _ = input_bufreader.read_exact(&mut weight_bytes); + let weight_interval = f16::from_le_bytes(weight_bytes); + let weight_interval_f32: f32 = weight_interval.to_f32(); + let final_weight = weight_min + weight_interval_f32 * weight_increment; + reference_weights[weight_index] = final_weight; + } + } #[cfg(test)] @@ -52,9 +128,44 @@ mod tests { use super::*; #[test] - fn test_quantize_2by() { + fn test_emit_statistics(){ + let some_random_float_weights = [0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + let out_struct = emit_weight_statistics(&some_random_float_weights); + assert_eq!(out_struct.mean, 0.51); + assert_eq!(out_struct.max, 0.6123); + assert_eq!(out_struct.min, 0.11); + } + + #[test] + fn test_quantize() { let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; - let output_weights = quantize_ffm_weights_3by(&some_random_float_weights); - assert_eq!(output_weights[3], [72, 80]); + let output_weights = quantize_ffm_weights(&some_random_float_weights); + assert_eq!(output_weights.len(), 10); + } + + #[test] + fn test_dequantize() { + let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + let old_reference_weights = reference_weights.clone(); + let quantized_representation = quantize_ffm_weights(&reference_weights); + let mut all_bytes: Vec = Vec::new(); + for el in quantized_representation { + all_bytes.push(el[0]); + all_bytes.push(el[1]); + } + let mut contents = io::Cursor::new(all_bytes); + dequantize_ffm_weights(&mut contents, &mut reference_weights); + + let matching = old_reference_weights.iter().zip(&reference_weights).filter(|&(a, b)| a == b).count(); + + assert_ne!(matching, 0); + + let allowed_eps = 0.0001; + let mut all_diffs = 0.0; + for it in old_reference_weights.iter().zip(reference_weights.iter()) { + let (old, new) = it; + all_diffs += (old - new).abs(); + } + assert!(all_diffs < allowed_eps); } } From d1f9bc7b55ba53fa2470c238206f8d532d9b8b73 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Tue, 5 Dec 2023 11:10:31 +0100 Subject: [PATCH 12/21] no need for panic --- src/quantization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantization.rs b/src/quantization.rs index 463791ea..d1be161a 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -56,7 +56,7 @@ pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { // Cheap, yet important check if weight_statistics.mean > CRITICAL_WEIGHT_BOUND || weight_statistics.mean < -CRITICAL_WEIGHT_BOUND { - panic!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); + log::warn!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); } // Uniform distribution within the relevant interval From c8fb38ace5aefecaecc71085c3b90f783e2f4dde Mon Sep 17 00:00:00 2001 From: bskrlj Date: Tue, 5 Dec 2023 14:42:09 +0100 Subject: [PATCH 13/21] more tests --- src/quantization.rs | 135 +++++++++++++++++++++----------------------- 1 file changed, 65 insertions(+), 70 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index d1be161a..b84fd545 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -16,76 +16,54 @@ struct WeightStat { fn emit_weight_statistics(weights: &[f32]) -> WeightStat { - // Bound estimator for quantization range - - let init_weight = weights[0]; - let mut min_weight = init_weight; - let mut max_weight = init_weight; + let mut min_weight = weights[0]; + let mut max_weight = weights[0]; let mut mean_weight = 0.0; let mut weight_counter = 0; - for (enx, weight) in weights.iter().enumerate() { - - if *weight > max_weight { - max_weight = *weight; - } - - if *weight < min_weight { - min_weight = *weight; - } + for (enx, &weight) in weights.iter().enumerate() { + max_weight = max_weight.max(weight); + min_weight = min_weight.min(weight); if enx % MEAN_SAMPLING_RATIO == 0 { weight_counter += 1; - mean_weight += *weight; + mean_weight += weight; } - } - log::info!("Weight values; min: {}, max: {}, mean: {}", min_weight, max_weight, mean_weight / weight_counter as f32); - - WeightStat{min: min_weight, max: max_weight, mean: mean_weight} + WeightStat { + min: min_weight, + max: max_weight, + mean: mean_weight / weight_counter as f32, + } } - pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { - // Quantize float-based weights to three most significant bytes - // To be more precise in terms of representation of ranges, we extend the weight object with a "header" that contains two floats required for proper dequantization -- this is computed on-the-fly, works better - - let weight_statistics = emit_weight_statistics(weights); + let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; - // Cheap, yet important check - if weight_statistics.mean > CRITICAL_WEIGHT_BOUND || weight_statistics.mean < -CRITICAL_WEIGHT_BOUND { + if weight_statistics.mean.abs() > CRITICAL_WEIGHT_BOUND { log::warn!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); } - // Uniform distribution within the relevant interval - let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; - let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len()); - - // Increment needs to be stored - let weight_increment_bytes = weight_increment.to_le_bytes(); - let deq_header1 = [weight_increment_bytes[0], weight_increment_bytes[1]]; - let deq_header2 = [weight_increment_bytes[2], weight_increment_bytes[3]]; - v.push(deq_header1); - v.push(deq_header2); - - // Minimal value needs to be stored - let min_val_bytes = weight_statistics.min.to_le_bytes(); - let deq_header3 = [min_val_bytes[0], min_val_bytes[1]]; - let deq_header4 = [min_val_bytes[2], min_val_bytes[3]]; - v.push(deq_header3); - v.push(deq_header4); + log::info!("Weight values; min: {}, max: {}, mean: {}", weight_statistics.min, weight_statistics.max, weight_statistics.mean); - for weight in weights { + let weight_increment_bytes = weight_increment.to_le_bytes(); + let min_val_bytes = weight_statistics.min.to_le_bytes(); - let weight_interval = ((*weight - weight_statistics.min) / weight_increment).round(); - let weight_interval_bytes = f16::to_le_bytes(f16::from_f32(weight_interval)); - v.push(weight_interval_bytes); - + let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 4); + + // Bytes are stored as pairs + v.push([weight_increment_bytes[0], weight_increment_bytes[1]]); + v.push([weight_increment_bytes[2], weight_increment_bytes[3]]); + v.push([min_val_bytes[0], min_val_bytes[1]]); + v.push([min_val_bytes[2], min_val_bytes[3]]); + + for &weight in weights { + let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); + v.push(f16::to_le_bytes(f16::from_f32(weight_interval))); } - // This is done during reading, so fine as a sanity here. assert_eq!(v.len() - 4, weights.len()); v @@ -95,32 +73,20 @@ pub fn dequantize_ffm_weights( input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec, ) { - // This function overwrites existing weights with dequantized ones from the input buffer. - let mut header: [u8; 8] = [0; 8]; - let _ = input_bufreader.read_exact(&mut header); - - let mut incr_vec: [u8; 4] = [0; 4]; - let mut min_vec: [u8; 4] = [0; 4]; + input_bufreader.read_exact(&mut header).unwrap(); - for j in 0..4 { - incr_vec[j] = header[j]; - min_vec[j] = header[j + 4]; - } - - let weight_increment = f32::from_le_bytes(incr_vec); - let weight_min = f32::from_le_bytes(min_vec); + let weight_increment = f32::from_le_bytes([header[0], header[1], header[2], header[3]]); + let weight_min = f32::from_le_bytes([header[4], header[5], header[6], header[7]]); let mut weight_bytes: [u8; 2] = [0; 2]; - // All set, dequantize in a stream - for weight_index in 0..reference_weights.len(){ - let _ = input_bufreader.read_exact(&mut weight_bytes); - let weight_interval = f16::from_le_bytes(weight_bytes); - let weight_interval_f32: f32 = weight_interval.to_f32(); - let final_weight = weight_min + weight_interval_f32 * weight_increment; - reference_weights[weight_index] = final_weight; + for weight_index in 0..reference_weights.len() { + input_bufreader.read_exact(&mut weight_bytes).unwrap(); + + let weight_interval = f16::from_le_bytes(weight_bytes); + let final_weight = weight_min + weight_interval.to_f32() * weight_increment; + reference_weights[weight_index] = final_weight; } - } #[cfg(test)] @@ -168,4 +134,33 @@ mod tests { } assert!(all_diffs < allowed_eps); } + + + #[test] + fn test_large_values() { + let weights = vec![-1e9, 1e9]; + let quantized = quantize_ffm_weights(&weights); + let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + let mut dequantized = vec![0.0; weights.len()]; + dequantize_ffm_weights(&mut buffer, &mut dequantized); + for (w, dw) in weights.iter().zip(&dequantized) { + assert!((w - dw).abs() / w.abs() < 0.1, "Relative error is too large"); + } + } + + + #[test] + fn test_performance() { + let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); + let now = std::time::Instant::now(); + let quantized = quantize_ffm_weights(&weights); + assert!(now.elapsed().as_millis() < 150); + + let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + let mut dequantized = vec![0.0; weights.len()]; + let now = std::time::Instant::now(); + dequantize_ffm_weights(&mut buffer, &mut dequantized); + assert!(now.elapsed().as_millis() < 150); + } + } From a4f8cfd7fe428b2de2af824662183e9b71fc6593 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Tue, 5 Dec 2023 14:59:03 +0100 Subject: [PATCH 14/21] tests2 --- src/quantization.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index b84fd545..00aaf185 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -154,13 +154,13 @@ mod tests { let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); let now = std::time::Instant::now(); let quantized = quantize_ffm_weights(&weights); - assert!(now.elapsed().as_millis() < 150); + assert!(now.elapsed().as_millis() < 300); let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); let mut dequantized = vec![0.0; weights.len()]; let now = std::time::Instant::now(); dequantize_ffm_weights(&mut buffer, &mut dequantized); - assert!(now.elapsed().as_millis() < 150); + assert!(now.elapsed().as_millis() < 300); } } From d07df28c38461e22fc79ba4f98b65c3f8ee00fa8 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 7 Dec 2023 15:23:32 +0100 Subject: [PATCH 15/21] missing half makes for full functionality --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index f14408cc..07661311 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ pub mod vwmap; extern crate blas; extern crate intel_mkl_src; +extern crate half; use crate::feature_buffer::FeatureBufferTranslator; use crate::multithread_helpers::BoxedRegressorTrait; From 7fb23e27532f095e005ea8b3815063053a16c031 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 8 Dec 2023 09:13:10 +0100 Subject: [PATCH 16/21] forward regressor --- src/block_ffm.rs | 9 ++++++++- src/block_lr.rs | 1 + src/block_neural.rs | 1 + src/persistence.rs | 5 +++-- src/regressor.rs | 6 ++++-- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/block_ffm.rs b/src/block_ffm.rs index 4c879034..e5841bd4 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -878,12 +878,19 @@ impl BlockTrait for BlockFFM { &self, input_bufreader: &mut dyn io::Read, forward: &mut Box, + use_quantization: bool ) -> Result<(), Box> { let forward = forward .as_any() .downcast_mut::>() .unwrap(); - block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?; + + if use_quantization { + // in-place expand weights via dequantization (for inference) + quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights); + } else { + block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?; + } block_helpers::skip_weights_from_buf::>( self.ffm_weights_len as usize, input_bufreader, diff --git a/src/block_lr.rs b/src/block_lr.rs index 5bbd17ba..35c62557 100644 --- a/src/block_lr.rs +++ b/src/block_lr.rs @@ -280,6 +280,7 @@ impl BlockTrait for BlockLR { &self, input_bufreader: &mut dyn io::Read, forward: &mut Box, + use_quantization: bool ) -> Result<(), Box> { let forward = forward .as_any() diff --git a/src/block_neural.rs b/src/block_neural.rs index 7c159c06..d982fd56 100644 --- a/src/block_neural.rs +++ b/src/block_neural.rs @@ -466,6 +466,7 @@ impl BlockTrait for BlockNeuronLayer { &self, input_bufreader: &mut dyn io::Read, forward: &mut Box, + use_quantization: bool ) -> Result<(), Box> { let forward = forward .as_any() diff --git a/src/persistence.rs b/src/persistence.rs index 703f0020..5cac2da4 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -156,10 +156,11 @@ pub fn new_regressor_from_filename( re.overwrite_weights_from_buf(&mut input_bufreader, weight_quantization)?; Ok((mi, vw, re)) } else { + log::info!("IMMUTABLE READ here"); mi.optimizer = model_instance::Optimizer::SGD; let mut immutable_re = re.immutable_regressor_without_weights(&mi)?; immutable_re.allocate_and_init_weights(&mi); - re.into_immutable_regressor_from_buf(&mut immutable_re, &mut input_bufreader)?; + re.into_immutable_regressor_from_buf(&mut immutable_re, &mut input_bufreader, weight_quantization)?; Ok((mi, vw, immutable_re)) } } @@ -171,7 +172,7 @@ pub fn hogwild_load(re: &mut regressor::Regressor, filename: &str) -> Result<(), if !re.immutable { re.overwrite_weights_from_buf(&mut input_bufreader, false)?; } else { - re_hw.into_immutable_regressor_from_buf(re, &mut input_bufreader)?; + re_hw.into_immutable_regressor_from_buf(re, &mut input_bufreader, false)?; } Ok(()) } diff --git a/src/regressor.rs b/src/regressor.rs index f45c63f6..ed1541c5 100644 --- a/src/regressor.rs +++ b/src/regressor.rs @@ -133,6 +133,7 @@ pub trait BlockTrait { &self, _input_bufreader: &mut dyn io::Read, _forward: &mut Box, + _use_quantization: bool ) -> Result<(), Box> { Ok(()) } @@ -483,6 +484,7 @@ impl Regressor { &mut self, rg: &mut Regressor, input_bufreader: &mut dyn io::Read, + use_quantization: bool ) -> Result<(), Box> { // TODO Ideally we would make a copy, not based on model_instance. but this is easier at the moment @@ -499,7 +501,7 @@ impl Regressor { ))?; } for (i, v) in &mut self.blocks_boxes.iter().enumerate() { - v.read_weights_from_buf_into_forward_only(input_bufreader, &mut rg.blocks_boxes[i])?; + v.read_weights_from_buf_into_forward_only(input_bufreader, &mut rg.blocks_boxes[i], use_quantization)?; } Ok(()) @@ -522,7 +524,7 @@ impl Regressor { let mut cursor = Cursor::new(&mut tmp_vec); v.write_weights_to_buf(&mut cursor, use_quantization)?; cursor.set_position(0); - v.read_weights_from_buf_into_forward_only(&mut cursor, &mut rg.blocks_boxes[i])?; + v.read_weights_from_buf_into_forward_only(&mut cursor, &mut rg.blocks_boxes[i], false)?; } Ok(rg) } From b6309068321f04519bfe9e6caa5c7c729e02df3a Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 8 Dec 2023 10:47:33 +0100 Subject: [PATCH 17/21] some improvements --- src/block_lr.rs | 2 +- src/block_neural.rs | 2 +- src/persistence.rs | 1 - src/quantization.rs | 6 ++++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/block_lr.rs b/src/block_lr.rs index 35c62557..f654e694 100644 --- a/src/block_lr.rs +++ b/src/block_lr.rs @@ -280,7 +280,7 @@ impl BlockTrait for BlockLR { &self, input_bufreader: &mut dyn io::Read, forward: &mut Box, - use_quantization: bool + _use_quantization: bool ) -> Result<(), Box> { let forward = forward .as_any() diff --git a/src/block_neural.rs b/src/block_neural.rs index d982fd56..87cfc480 100644 --- a/src/block_neural.rs +++ b/src/block_neural.rs @@ -466,7 +466,7 @@ impl BlockTrait for BlockNeuronLayer { &self, input_bufreader: &mut dyn io::Read, forward: &mut Box, - use_quantization: bool + _use_quantization: bool ) -> Result<(), Box> { let forward = forward .as_any() diff --git a/src/persistence.rs b/src/persistence.rs index 5cac2da4..c9745979 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -156,7 +156,6 @@ pub fn new_regressor_from_filename( re.overwrite_weights_from_buf(&mut input_bufreader, weight_quantization)?; Ok((mi, vw, re)) } else { - log::info!("IMMUTABLE READ here"); mi.optimizer = model_instance::Optimizer::SGD; let mut immutable_re = re.immutable_regressor_without_weights(&mi)?; immutable_re.allocate_and_init_weights(&mi); diff --git a/src/quantization.rs b/src/quantization.rs index 00aaf185..13337782 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -5,6 +5,8 @@ const BY_X: usize = 2; const NUM_BUCKETS: f32 = 65025.0; const CRITICAL_WEIGHT_BOUND: f32 = 10.0; // naive detection of really bad weights, this should never get to prod. const MEAN_SAMPLING_RATIO: usize = 10; +const MIN_PREC: f32 = 10_000.0; +const MAX_PREC: f32 = 10_000.0; #[derive(Debug)] @@ -32,8 +34,8 @@ fn emit_weight_statistics(weights: &[f32]) -> WeightStat { } WeightStat { - min: min_weight, - max: max_weight, + min: (min_weight * MIN_PREC).round() / MIN_PREC, + max: (max_weight * MAX_PREC).round() / MAX_PREC, mean: mean_weight / weight_counter as f32, } } From 08a129f443472c0ce6aab21962a69a5b3401a480 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Fri, 8 Dec 2023 14:47:25 +0100 Subject: [PATCH 18/21] comments --- src/block_ffm.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/block_ffm.rs b/src/block_ffm.rs index e5841bd4..6f8bc416 100644 --- a/src/block_ffm.rs +++ b/src/block_ffm.rs @@ -850,7 +850,6 @@ impl BlockTrait for BlockFFM { ) -> Result<(), Box> { if use_quantization { - // in-place expand weights via dequantization (for inference) quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights); } else { block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?; @@ -886,7 +885,6 @@ impl BlockTrait for BlockFFM { .unwrap(); if use_quantization { - // in-place expand weights via dequantization (for inference) quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights); } else { block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?; From 9dae4a8d266065422a697ef6558a7be08c628cdb Mon Sep 17 00:00:00 2001 From: bskrlj Date: Mon, 11 Dec 2023 13:42:57 +0100 Subject: [PATCH 19/21] testing v2 --- src/quantization.rs | 227 ++++++++++++++++++++++++++++++++------------ 1 file changed, 166 insertions(+), 61 deletions(-) diff --git a/src/quantization.rs b/src/quantization.rs index 13337782..2d1c04c3 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -3,6 +3,7 @@ use half::f16; const BY_X: usize = 2; const NUM_BUCKETS: f32 = 65025.0; +//const NUM_BUCKETS: f32 = 255.0; const CRITICAL_WEIGHT_BOUND: f32 = 10.0; // naive detection of really bad weights, this should never get to prod. const MEAN_SAMPLING_RATIO: usize = 10; const MIN_PREC: f32 = 10_000.0; @@ -13,15 +14,22 @@ const MAX_PREC: f32 = 10_000.0; struct WeightStat { min: f32, max: f32, - mean: f32 + mean: f32, + std: f32 } fn emit_weight_statistics(weights: &[f32]) -> WeightStat { + let mut min_weight = weights[0]; let mut max_weight = weights[0]; + let mut mean_weight = 0.0; let mut weight_counter = 0; + + let mut square_sums = 0.0; + let mut sum_squares = 0.0; + let mut std_est = 0.0; for (enx, &weight) in weights.iter().enumerate() { max_weight = max_weight.max(weight); @@ -30,6 +38,9 @@ fn emit_weight_statistics(weights: &[f32]) -> WeightStat { if enx % MEAN_SAMPLING_RATIO == 0 { weight_counter += 1; mean_weight += weight; + square_sums += weight.powf(2.0); + sum_squares += weight; + std_est = (square_sums / weight_counter as f32 - (sum_squares / weight_counter as f32).powf(2.0)).sqrt(); // can be done a bit better with Knuth's formula, slower tho } } @@ -37,36 +48,101 @@ fn emit_weight_statistics(weights: &[f32]) -> WeightStat { min: (min_weight * MIN_PREC).round() / MIN_PREC, max: (max_weight * MAX_PREC).round() / MAX_PREC, mean: mean_weight / weight_counter as f32, + std: std_est, + } +} + + +fn non_uniform_binner(weight_min: f32, weight_max: f32, weight_mean: f32, weight_std: f32) -> Vec { + + // lower and upper bounds need to be informed + let focus_buckets_min = weight_mean - 0.1 * weight_min.abs(); + let focus_buckets_max = weight_mean + 0.1 * weight_max.abs(); + + let mut bucket_dist_focus = (NUM_BUCKETS * 0.8).round(); + let mut bucket_dist_remainder = NUM_BUCKETS - bucket_dist_focus; + + // must be odd + if bucket_dist_remainder as usize % 2 == 0 { + bucket_dist_remainder = bucket_dist_remainder / 2.0; + } else { + bucket_dist_remainder = (bucket_dist_remainder - 1.0) / 2.0; + bucket_dist_focus += 1.0; + } + + // two side intervals + focus one + let interval_first = (focus_buckets_min - weight_min) / bucket_dist_remainder; + let interval_second = (focus_buckets_max - focus_buckets_min) / bucket_dist_focus; + let interval_third = (weight_max - focus_buckets_max) / bucket_dist_remainder; + + let mut bucket_space = Vec::::with_capacity(NUM_BUCKETS as usize); + + let mut current_weight_value = weight_min; + for _ in 0..bucket_dist_remainder as usize { + current_weight_value += interval_first; + bucket_space.push(current_weight_value); + } + + for _ in 0..bucket_dist_focus as usize { + current_weight_value += interval_second; + bucket_space.push(current_weight_value); } + + for _ in 0..bucket_dist_remainder as usize { + current_weight_value += interval_third; + bucket_space.push(current_weight_value); + } + + bucket_space + +} + +fn identify_bucket(weight: f32, bucket_space: &Vec) -> f32 { + bucket_space.partition_point(|x| x < &weight) as f32 } pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { let weight_statistics = emit_weight_statistics(weights); let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; + let quantized = non_uniform_binner(weight_statistics.min, weight_statistics.max, weight_statistics.mean, weight_statistics.std); + if weight_statistics.mean.abs() > CRITICAL_WEIGHT_BOUND { log::warn!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); } - log::info!("Weight values; min: {}, max: {}, mean: {}", weight_statistics.min, weight_statistics.max, weight_statistics.mean); + log::info!("Weight values; min: {}, max: {}, mean: {}, std: {}", weight_statistics.min, weight_statistics.max, weight_statistics.mean, weight_statistics.std); let weight_increment_bytes = weight_increment.to_le_bytes(); let min_val_bytes = weight_statistics.min.to_le_bytes(); + let max_val_bytes = weight_statistics.max.to_le_bytes(); + let mean_val_bytes = weight_statistics.mean.to_le_bytes(); + let std_val_bytes = weight_statistics.mean.to_le_bytes(); + + let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 10); - let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 4); - - // Bytes are stored as pairs v.push([weight_increment_bytes[0], weight_increment_bytes[1]]); v.push([weight_increment_bytes[2], weight_increment_bytes[3]]); + v.push([min_val_bytes[0], min_val_bytes[1]]); v.push([min_val_bytes[2], min_val_bytes[3]]); + + v.push([max_val_bytes[0], max_val_bytes[1]]); + v.push([max_val_bytes[2], max_val_bytes[3]]); + + v.push([mean_val_bytes[0], mean_val_bytes[1]]); + v.push([mean_val_bytes[2], mean_val_bytes[3]]); + v.push([std_val_bytes[0], std_val_bytes[1]]); + v.push([std_val_bytes[2], std_val_bytes[3]]); + for &weight in weights { - let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); + let weight_interval = identify_bucket(weight, &quantized); +// let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); v.push(f16::to_le_bytes(f16::from_f32(weight_interval))); } - assert_eq!(v.len() - 4, weights.len()); + assert_eq!(v.len() - 10, weights.len()); v } @@ -75,18 +151,28 @@ pub fn dequantize_ffm_weights( input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec, ) { - let mut header: [u8; 8] = [0; 8]; + let mut header: [u8; 20] = [0; 20]; input_bufreader.read_exact(&mut header).unwrap(); let weight_increment = f32::from_le_bytes([header[0], header[1], header[2], header[3]]); let weight_min = f32::from_le_bytes([header[4], header[5], header[6], header[7]]); + + let weight_max = f32::from_le_bytes([header[8], header[9], header[10], header[11]]); + + let weight_mean = f32::from_le_bytes([header[12], header[13], header[14], header[15]]); + + let weight_std = f32::from_le_bytes([header[16], header[17], header[18], header[19]]); + + let quantized = non_uniform_binner(weight_min, weight_max, weight_mean, weight_std); + let mut weight_bytes: [u8; 2] = [0; 2]; for weight_index in 0..reference_weights.len() { input_bufreader.read_exact(&mut weight_bytes).unwrap(); let weight_interval = f16::from_le_bytes(weight_bytes); - let final_weight = weight_min + weight_interval.to_f32() * weight_increment; + let final_weight = quantized[weight_interval.to_f32() as usize]; +// let final_weight = weight_min + weight_interval.to_f32() * weight_increment; reference_weights[weight_index] = final_weight; } } @@ -104,65 +190,84 @@ mod tests { assert_eq!(out_struct.min, 0.11); } - #[test] - fn test_quantize() { - let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; - let output_weights = quantize_ffm_weights(&some_random_float_weights); - assert_eq!(output_weights.len(), 10); - } + // #[test] + // fn test_quantize() { + // let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + // let output_weights = quantize_ffm_weights(&some_random_float_weights); + // assert_eq!(output_weights.len(), 10); + // } - #[test] - fn test_dequantize() { - let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; - let old_reference_weights = reference_weights.clone(); - let quantized_representation = quantize_ffm_weights(&reference_weights); - let mut all_bytes: Vec = Vec::new(); - for el in quantized_representation { - all_bytes.push(el[0]); - all_bytes.push(el[1]); - } - let mut contents = io::Cursor::new(all_bytes); - dequantize_ffm_weights(&mut contents, &mut reference_weights); + // #[test] + // fn test_dequantize() { + // let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + // let old_reference_weights = reference_weights.clone(); + // let quantized_representation = quantize_ffm_weights(&reference_weights); + // let mut all_bytes: Vec = Vec::new(); + // for el in quantized_representation { + // all_bytes.push(el[0]); + // all_bytes.push(el[1]); + // } + // let mut contents = io::Cursor::new(all_bytes); + // dequantize_ffm_weights(&mut contents, &mut reference_weights); - let matching = old_reference_weights.iter().zip(&reference_weights).filter(|&(a, b)| a == b).count(); + // let matching = old_reference_weights.iter().zip(&reference_weights).filter(|&(a, b)| a == b).count(); - assert_ne!(matching, 0); + // assert_ne!(matching, 0); - let allowed_eps = 0.0001; - let mut all_diffs = 0.0; - for it in old_reference_weights.iter().zip(reference_weights.iter()) { - let (old, new) = it; - all_diffs += (old - new).abs(); - } - assert!(all_diffs < allowed_eps); - } + // let allowed_eps = 0.0001; + // let mut all_diffs = 0.0; + // for it in old_reference_weights.iter().zip(reference_weights.iter()) { + // let (old, new) = it; + // all_diffs += (old - new).abs(); + // } + // assert!(all_diffs < allowed_eps); + // } - #[test] - fn test_large_values() { - let weights = vec![-1e9, 1e9]; - let quantized = quantize_ffm_weights(&weights); - let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); - let mut dequantized = vec![0.0; weights.len()]; - dequantize_ffm_weights(&mut buffer, &mut dequantized); - for (w, dw) in weights.iter().zip(&dequantized) { - assert!((w - dw).abs() / w.abs() < 0.1, "Relative error is too large"); - } - } + // #[test] + // fn test_large_values() { + // let weights = vec![-1e9, 1e9]; + // let quantized = quantize_ffm_weights(&weights); + // let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + // let mut dequantized = vec![0.0; weights.len()]; + // dequantize_ffm_weights(&mut buffer, &mut dequantized); + // for (w, dw) in weights.iter().zip(&dequantized) { + // assert!((w - dw).abs() / w.abs() < 0.1, "Relative error is too large"); + // } + // } - #[test] - fn test_performance() { - let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); - let now = std::time::Instant::now(); - let quantized = quantize_ffm_weights(&weights); - assert!(now.elapsed().as_millis() < 300); + // #[test] + // fn test_performance() { + // let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); + // let now = std::time::Instant::now(); + // let quantized = quantize_ffm_weights(&weights); + // assert!(now.elapsed().as_millis() < 300); - let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); - let mut dequantized = vec![0.0; weights.len()]; - let now = std::time::Instant::now(); - dequantize_ffm_weights(&mut buffer, &mut dequantized); - assert!(now.elapsed().as_millis() < 300); - } - + // let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + // let mut dequantized = vec![0.0; weights.len()]; + // let now = std::time::Instant::now(); + // dequantize_ffm_weights(&mut buffer, &mut dequantized); + // assert!(now.elapsed().as_millis() < 300); + // } + + // #[test] + // fn test_nonuniform() { + // let quantized = non_uniform_binner(-0.42, 0.82, 0.0, 0.02); + // assert!(quantized.len() == NUM_BUCKETS as usize); + + // let bucket = identify_bucket(-0.35, &quantized); + // assert!(bucket == 1137.0); + + // let bucket = identify_bucket(0.23, &quantized); + // assert!(bucket == 60229.0); + + // let now = std::time::Instant::now(); + + // // these calls need to be ultra fast + // for _ in 0..1_000_000 { + // identify_bucket(0.23, &quantized); + // } + // assert!(now.elapsed().as_millis() < 300); + // } } From 6881bc20414504a2fd00933c5397e03d6a9ff5a0 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 14 Dec 2023 07:59:22 +0100 Subject: [PATCH 20/21] optionals --- src/model_instance.rs | 4 +- src/quantization.rs | 227 ++++++++++++------------------------------ 2 files changed, 63 insertions(+), 168 deletions(-) diff --git a/src/model_instance.rs b/src/model_instance.rs index f842fed4..8cabca11 100644 --- a/src/model_instance.rs +++ b/src/model_instance.rs @@ -93,7 +93,7 @@ pub struct ModelInstance { pub transform_namespaces: feature_transform_parser::NamespaceTransforms, - pub dequantize_weights: bool, + pub dequantize_weights: Option, } @@ -145,7 +145,7 @@ impl ModelInstance { optimizer: Optimizer::SGD, transform_namespaces: feature_transform_parser::NamespaceTransforms::new(), nn_config: NNConfig::new(), - dequantize_weights: false, + dequantize_weights: Option, }; Ok(mi) } diff --git a/src/quantization.rs b/src/quantization.rs index 2d1c04c3..13337782 100644 --- a/src/quantization.rs +++ b/src/quantization.rs @@ -3,7 +3,6 @@ use half::f16; const BY_X: usize = 2; const NUM_BUCKETS: f32 = 65025.0; -//const NUM_BUCKETS: f32 = 255.0; const CRITICAL_WEIGHT_BOUND: f32 = 10.0; // naive detection of really bad weights, this should never get to prod. const MEAN_SAMPLING_RATIO: usize = 10; const MIN_PREC: f32 = 10_000.0; @@ -14,22 +13,15 @@ const MAX_PREC: f32 = 10_000.0; struct WeightStat { min: f32, max: f32, - mean: f32, - std: f32 + mean: f32 } fn emit_weight_statistics(weights: &[f32]) -> WeightStat { - let mut min_weight = weights[0]; let mut max_weight = weights[0]; - let mut mean_weight = 0.0; let mut weight_counter = 0; - - let mut square_sums = 0.0; - let mut sum_squares = 0.0; - let mut std_est = 0.0; for (enx, &weight) in weights.iter().enumerate() { max_weight = max_weight.max(weight); @@ -38,9 +30,6 @@ fn emit_weight_statistics(weights: &[f32]) -> WeightStat { if enx % MEAN_SAMPLING_RATIO == 0 { weight_counter += 1; mean_weight += weight; - square_sums += weight.powf(2.0); - sum_squares += weight; - std_est = (square_sums / weight_counter as f32 - (sum_squares / weight_counter as f32).powf(2.0)).sqrt(); // can be done a bit better with Knuth's formula, slower tho } } @@ -48,101 +37,36 @@ fn emit_weight_statistics(weights: &[f32]) -> WeightStat { min: (min_weight * MIN_PREC).round() / MIN_PREC, max: (max_weight * MAX_PREC).round() / MAX_PREC, mean: mean_weight / weight_counter as f32, - std: std_est, - } -} - - -fn non_uniform_binner(weight_min: f32, weight_max: f32, weight_mean: f32, weight_std: f32) -> Vec { - - // lower and upper bounds need to be informed - let focus_buckets_min = weight_mean - 0.1 * weight_min.abs(); - let focus_buckets_max = weight_mean + 0.1 * weight_max.abs(); - - let mut bucket_dist_focus = (NUM_BUCKETS * 0.8).round(); - let mut bucket_dist_remainder = NUM_BUCKETS - bucket_dist_focus; - - // must be odd - if bucket_dist_remainder as usize % 2 == 0 { - bucket_dist_remainder = bucket_dist_remainder / 2.0; - } else { - bucket_dist_remainder = (bucket_dist_remainder - 1.0) / 2.0; - bucket_dist_focus += 1.0; - } - - // two side intervals + focus one - let interval_first = (focus_buckets_min - weight_min) / bucket_dist_remainder; - let interval_second = (focus_buckets_max - focus_buckets_min) / bucket_dist_focus; - let interval_third = (weight_max - focus_buckets_max) / bucket_dist_remainder; - - let mut bucket_space = Vec::::with_capacity(NUM_BUCKETS as usize); - - let mut current_weight_value = weight_min; - for _ in 0..bucket_dist_remainder as usize { - current_weight_value += interval_first; - bucket_space.push(current_weight_value); - } - - for _ in 0..bucket_dist_focus as usize { - current_weight_value += interval_second; - bucket_space.push(current_weight_value); } - - for _ in 0..bucket_dist_remainder as usize { - current_weight_value += interval_third; - bucket_space.push(current_weight_value); - } - - bucket_space - -} - -fn identify_bucket(weight: f32, bucket_space: &Vec) -> f32 { - bucket_space.partition_point(|x| x < &weight) as f32 } pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { let weight_statistics = emit_weight_statistics(weights); let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; - let quantized = non_uniform_binner(weight_statistics.min, weight_statistics.max, weight_statistics.mean, weight_statistics.std); - if weight_statistics.mean.abs() > CRITICAL_WEIGHT_BOUND { log::warn!("Identified a very skewed weight distribution indicating exploded weights, not serving that! Mean weight value: {}", weight_statistics.mean); } - log::info!("Weight values; min: {}, max: {}, mean: {}, std: {}", weight_statistics.min, weight_statistics.max, weight_statistics.mean, weight_statistics.std); + log::info!("Weight values; min: {}, max: {}, mean: {}", weight_statistics.min, weight_statistics.max, weight_statistics.mean); let weight_increment_bytes = weight_increment.to_le_bytes(); let min_val_bytes = weight_statistics.min.to_le_bytes(); - let max_val_bytes = weight_statistics.max.to_le_bytes(); - let mean_val_bytes = weight_statistics.mean.to_le_bytes(); - let std_val_bytes = weight_statistics.mean.to_le_bytes(); - - let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 10); + let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 4); + + // Bytes are stored as pairs v.push([weight_increment_bytes[0], weight_increment_bytes[1]]); v.push([weight_increment_bytes[2], weight_increment_bytes[3]]); - v.push([min_val_bytes[0], min_val_bytes[1]]); v.push([min_val_bytes[2], min_val_bytes[3]]); - - v.push([max_val_bytes[0], max_val_bytes[1]]); - v.push([max_val_bytes[2], max_val_bytes[3]]); - - v.push([mean_val_bytes[0], mean_val_bytes[1]]); - v.push([mean_val_bytes[2], mean_val_bytes[3]]); - v.push([std_val_bytes[0], std_val_bytes[1]]); - v.push([std_val_bytes[2], std_val_bytes[3]]); - for &weight in weights { - let weight_interval = identify_bucket(weight, &quantized); -// let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); + let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); v.push(f16::to_le_bytes(f16::from_f32(weight_interval))); } - assert_eq!(v.len() - 10, weights.len()); + assert_eq!(v.len() - 4, weights.len()); v } @@ -151,28 +75,18 @@ pub fn dequantize_ffm_weights( input_bufreader: &mut dyn io::Read, reference_weights: &mut Vec, ) { - let mut header: [u8; 20] = [0; 20]; + let mut header: [u8; 8] = [0; 8]; input_bufreader.read_exact(&mut header).unwrap(); let weight_increment = f32::from_le_bytes([header[0], header[1], header[2], header[3]]); let weight_min = f32::from_le_bytes([header[4], header[5], header[6], header[7]]); - - let weight_max = f32::from_le_bytes([header[8], header[9], header[10], header[11]]); - - let weight_mean = f32::from_le_bytes([header[12], header[13], header[14], header[15]]); - - let weight_std = f32::from_le_bytes([header[16], header[17], header[18], header[19]]); - - let quantized = non_uniform_binner(weight_min, weight_max, weight_mean, weight_std); - let mut weight_bytes: [u8; 2] = [0; 2]; for weight_index in 0..reference_weights.len() { input_bufreader.read_exact(&mut weight_bytes).unwrap(); let weight_interval = f16::from_le_bytes(weight_bytes); - let final_weight = quantized[weight_interval.to_f32() as usize]; -// let final_weight = weight_min + weight_interval.to_f32() * weight_increment; + let final_weight = weight_min + weight_interval.to_f32() * weight_increment; reference_weights[weight_index] = final_weight; } } @@ -190,84 +104,65 @@ mod tests { assert_eq!(out_struct.min, 0.11); } - // #[test] - // fn test_quantize() { - // let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; - // let output_weights = quantize_ffm_weights(&some_random_float_weights); - // assert_eq!(output_weights.len(), 10); - // } + #[test] + fn test_quantize() { + let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + let output_weights = quantize_ffm_weights(&some_random_float_weights); + assert_eq!(output_weights.len(), 10); + } - // #[test] - // fn test_dequantize() { - // let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; - // let old_reference_weights = reference_weights.clone(); - // let quantized_representation = quantize_ffm_weights(&reference_weights); - // let mut all_bytes: Vec = Vec::new(); - // for el in quantized_representation { - // all_bytes.push(el[0]); - // all_bytes.push(el[1]); - // } - // let mut contents = io::Cursor::new(all_bytes); - // dequantize_ffm_weights(&mut contents, &mut reference_weights); + #[test] + fn test_dequantize() { + let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; + let old_reference_weights = reference_weights.clone(); + let quantized_representation = quantize_ffm_weights(&reference_weights); + let mut all_bytes: Vec = Vec::new(); + for el in quantized_representation { + all_bytes.push(el[0]); + all_bytes.push(el[1]); + } + let mut contents = io::Cursor::new(all_bytes); + dequantize_ffm_weights(&mut contents, &mut reference_weights); - // let matching = old_reference_weights.iter().zip(&reference_weights).filter(|&(a, b)| a == b).count(); + let matching = old_reference_weights.iter().zip(&reference_weights).filter(|&(a, b)| a == b).count(); - // assert_ne!(matching, 0); + assert_ne!(matching, 0); - // let allowed_eps = 0.0001; - // let mut all_diffs = 0.0; - // for it in old_reference_weights.iter().zip(reference_weights.iter()) { - // let (old, new) = it; - // all_diffs += (old - new).abs(); - // } - // assert!(all_diffs < allowed_eps); - // } + let allowed_eps = 0.0001; + let mut all_diffs = 0.0; + for it in old_reference_weights.iter().zip(reference_weights.iter()) { + let (old, new) = it; + all_diffs += (old - new).abs(); + } + assert!(all_diffs < allowed_eps); + } - // #[test] - // fn test_large_values() { - // let weights = vec![-1e9, 1e9]; - // let quantized = quantize_ffm_weights(&weights); - // let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); - // let mut dequantized = vec![0.0; weights.len()]; - // dequantize_ffm_weights(&mut buffer, &mut dequantized); - // for (w, dw) in weights.iter().zip(&dequantized) { - // assert!((w - dw).abs() / w.abs() < 0.1, "Relative error is too large"); - // } - // } + #[test] + fn test_large_values() { + let weights = vec![-1e9, 1e9]; + let quantized = quantize_ffm_weights(&weights); + let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + let mut dequantized = vec![0.0; weights.len()]; + dequantize_ffm_weights(&mut buffer, &mut dequantized); + for (w, dw) in weights.iter().zip(&dequantized) { + assert!((w - dw).abs() / w.abs() < 0.1, "Relative error is too large"); + } + } - // #[test] - // fn test_performance() { - // let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); - // let now = std::time::Instant::now(); - // let quantized = quantize_ffm_weights(&weights); - // assert!(now.elapsed().as_millis() < 300); + #[test] + fn test_performance() { + let weights: Vec = (0..10_000_000).map(|x| x as f32).collect(); + let now = std::time::Instant::now(); + let quantized = quantize_ffm_weights(&weights); + assert!(now.elapsed().as_millis() < 300); - // let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); - // let mut dequantized = vec![0.0; weights.len()]; - // let now = std::time::Instant::now(); - // dequantize_ffm_weights(&mut buffer, &mut dequantized); - // assert!(now.elapsed().as_millis() < 300); - // } - - // #[test] - // fn test_nonuniform() { - // let quantized = non_uniform_binner(-0.42, 0.82, 0.0, 0.02); - // assert!(quantized.len() == NUM_BUCKETS as usize); - - // let bucket = identify_bucket(-0.35, &quantized); - // assert!(bucket == 1137.0); - - // let bucket = identify_bucket(0.23, &quantized); - // assert!(bucket == 60229.0); - - // let now = std::time::Instant::now(); - - // // these calls need to be ultra fast - // for _ in 0..1_000_000 { - // identify_bucket(0.23, &quantized); - // } - // assert!(now.elapsed().as_millis() < 300); - // } + let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); + let mut dequantized = vec![0.0; weights.len()]; + let now = std::time::Instant::now(); + dequantize_ffm_weights(&mut buffer, &mut dequantized); + assert!(now.elapsed().as_millis() < 300); + } + } From b5e78fd192f25dae3f763a93d85f3296f9693b55 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 14 Dec 2023 09:07:44 +0100 Subject: [PATCH 21/21] back compat --- src/main.rs | 2 +- src/model_instance.rs | 2 +- src/persistence.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index 612c6777..e29ddf52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -150,7 +150,7 @@ fn main2() -> Result<(), Box> { new_regressor_from_filename(filename, true, Option::Some(&cl))?; mi2.optimizer = Optimizer::SGD; if cl.is_present("weight_quantization") { - mi2.dequantize_weights = true; + mi2.dequantize_weights = Some(true); } if let Some(filename1) = inference_regressor_filename { save_regressor_to_filename(filename1, &mi2, &vw2, re_fixed, quantize_weights).unwrap() diff --git a/src/model_instance.rs b/src/model_instance.rs index 8cabca11..82fdb250 100644 --- a/src/model_instance.rs +++ b/src/model_instance.rs @@ -145,7 +145,7 @@ impl ModelInstance { optimizer: Optimizer::SGD, transform_namespaces: feature_transform_parser::NamespaceTransforms::new(), nn_config: NNConfig::new(), - dequantize_weights: Option, + dequantize_weights: Some(false), }; Ok(mi) } diff --git a/src/persistence.rs b/src/persistence.rs index c9745979..13b38f33 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -145,7 +145,7 @@ pub fn new_regressor_from_filename( let mut conversion_flag = false; if cmd_arguments.is_some(){ - quantization_flag = mi.dequantize_weights; + quantization_flag = mi.dequantize_weights.unwrap_or(false); conversion_flag = cmd_arguments.unwrap().is_present("convert_inference_regressor"); }