From 9add74136880fcb9d10235e9cf92d68f66f96422 Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Fri, 30 Aug 2024 13:54:46 +0200 Subject: [PATCH] chore: refactor and clean segment anything function (#243) --- src/tensor_functions/segment_anything.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/tensor_functions/segment_anything.rs b/src/tensor_functions/segment_anything.rs index aec9d67f..d96b9308 100644 --- a/src/tensor_functions/segment_anything.rs +++ b/src/tensor_functions/segment_anything.rs @@ -64,12 +64,6 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result { let str_array = as_string_array(&args[0])?; - // let api = hf_hub::api::sync::Api::new().map_err(|e| Execution(e.to_string()))?; - // let api = api.model("lmz/candle-sam".to_string()); - // let model = api - // .get(str_array.value(0)) - // .map_err(|e| Execution(e.to_string()))?; - let model = std::path::PathBuf::from(str_array.value(0)); let device = Device::Cpu; @@ -101,9 +95,9 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result { let images = as_binary_array(&args[1])?; let row_count = images.len(); - let mut rotated: Vec> = (0..row_count).map(|_| Vec::new()).collect(); + let mut segmented: Vec> = (0..row_count).map(|_| Vec::new()).collect(); - for (i, mut bytes) in (0..row_count).zip(rotated.clone()) { + for (i, mut bytes) in (0..row_count).zip(segmented.clone()) { let image = images.value(i); let (format, mut image) = binary_to_img(image)?; let tensor = get_tensor_from_image(Some(sam::IMAGE_SIZE), image.clone()) @@ -164,10 +158,10 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result { let _ = image.write_to(&mut writer, format); - rotated[i] = bytes; + segmented[i] = bytes; } - let result = rotated.iter().map(|v| v.as_slice()).collect(); + let result = segmented.iter().map(|v| v.as_slice()).collect(); Ok(Arc::new(LargeBinaryArray::from_vec(result))) }