Skip to content

Commit

Permalink
chore: refactor and clean segment anything function (letsql#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored Aug 30, 2024
1 parent 58566e9 commit 9add741
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/tensor_functions/segment_anything.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

let str_array = as_string_array(&args[0])?;

// let api = hf_hub::api::sync::Api::new().map_err(|e| Execution(e.to_string()))?;
// let api = api.model("lmz/candle-sam".to_string());
// let model = api
// .get(str_array.value(0))
// .map_err(|e| Execution(e.to_string()))?;

let model = std::path::PathBuf::from(str_array.value(0));

let device = Device::Cpu;
Expand Down Expand Up @@ -101,9 +95,9 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

let images = as_binary_array(&args[1])?;
let row_count = images.len();
let mut rotated: Vec<Vec<u8>> = (0..row_count).map(|_| Vec::new()).collect();
let mut segmented: Vec<Vec<u8>> = (0..row_count).map(|_| Vec::new()).collect();

for (i, mut bytes) in (0..row_count).zip(rotated.clone()) {
for (i, mut bytes) in (0..row_count).zip(segmented.clone()) {
let image = images.value(i);
let (format, mut image) = binary_to_img(image)?;
let tensor = get_tensor_from_image(Some(sam::IMAGE_SIZE), image.clone())
Expand Down Expand Up @@ -164,10 +158,10 @@ fn segment_anything_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

let _ = image.write_to(&mut writer, format);

rotated[i] = bytes;
segmented[i] = bytes;
}

let result = rotated.iter().map(|v| v.as_slice()).collect();
let result = segmented.iter().map(|v| v.as_slice()).collect();

Ok(Arc::new(LargeBinaryArray::from_vec(result)))
}
Expand Down

0 comments on commit 9add741

Please sign in to comment.