Storing f32 to later avoid allocating a Vec of inputs #137

Merged: 2 commits, Oct 27, 2024
Changes from 1 commit
31 changes: 20 additions & 11 deletions ahnlich/ai/src/engine/ai/models.rs
@@ -9,6 +9,7 @@ use ahnlich_types::{
keyval::{StoreInput, StoreKey},
};
use image::{GenericImageView, ImageReader};
+ use ndarray::ArrayView;
use ndarray::{Array, Ix3};
use nonzero_ext::nonzero;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -115,17 +116,17 @@ impl Model {
}
}

- // TODO: model ndarray values is based on length of string or vec, so for now make sure strings
- // or vecs have different lengths
#[tracing::instrument(skip(self))]
pub fn model_ndarray(
&self,
- storeinput: &Vec<ModelInput>,
+ storeinput: Vec<ModelInput>,
action_type: &InputAction,
) -> Result<Vec<StoreKey>, AIProxyError> {
let store_keys = match &self.provider {
- ModelProviders::FastEmbed(provider) => provider.run_inference(storeinput, action_type)?,
- ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?
+ ModelProviders::FastEmbed(provider) => {
+ provider.run_inference(storeinput, action_type)?
+ }
+ ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?,
};
Ok(store_keys)
}
@@ -246,9 +247,9 @@ pub enum ModelInput {
Image(ImageArray),
}

- #[derive(Debug, Clone, PartialEq, Hash, Eq)]
+ #[derive(Debug, Clone)]
pub struct ImageArray {
- array: Array<u8, Ix3>,
+ array: Array<f32, Ix3>,
bytes: Vec<u8>,
}

@@ -273,12 +274,19 @@ impl ImageArray {
let channels = img.color().channel_count();
let shape = (height as usize, width as usize, channels as usize);
let array = Array::from_shape_vec(shape, img.into_bytes())
- .map_err(|_| AIProxyError::ImageBytesDecodeError)?;
+ .map_err(|_| AIProxyError::ImageBytesDecodeError)?
+ .mapv(f32::from);
Ok(ImageArray { array, bytes })
}

- pub fn get_array(&self) -> &Array<u8, Ix3> {
- &self.array
+ // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
+ pub fn onnx_transform(&mut self) {
+ self.array.swap_axes(1, 2);
+ self.array.swap_axes(0, 1);
}
+
+ pub fn view(&self) -> ArrayView<f32, Ix3> {
+ self.array.view()
+ }

pub fn get_bytes(&self) -> &Vec<u8> {
@@ -313,7 +321,8 @@ impl ImageArray {

let flattened_pixels = resized_img.into_bytes();
let array = Array::from_shape_vec(shape, flattened_pixels)
- .map_err(|_| AIProxyError::ImageResizeError)?;
+ .map_err(|_| AIProxyError::ImageResizeError)?
+ .mapv(f32::from);
let bytes = buffer.into_inner();
Ok(ImageArray { array, bytes })
}
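Note on the models.rs change: `get_array` (returning `&Array<u8, Ix3>`) is gone, so the ORT provider no longer runs an allocating `mapv(f32::from)` on every inference; the conversion to `f32` happens once when the `ImageArray` is built. A minimal standalone sketch of the resulting shape of the type (a hypothetical mirror, not the crate's exact code):

```rust
use ndarray::{Array, ArrayView, Ix3};

// Hypothetical mirror of ImageArray, reduced to the parts that changed.
struct ImageArray {
    // Pixels are converted to f32 once, at construction, instead of on
    // every inference via the old `get_array().mapv(f32::from)`.
    array: Array<f32, Ix3>,
}

impl ImageArray {
    // Reorder [rows, columns, channels] into [channels, rows, columns] for ONNX.
    // swap_axes only permutes strides, so no pixel data is copied here.
    fn onnx_transform(&mut self) {
        self.array.swap_axes(1, 2);
        self.array.swap_axes(0, 1);
    }

    // A cheap, non-owning view that callers can later stack into a batch.
    fn view(&self) -> ArrayView<'_, f32, Ix3> {
        self.array.view()
    }
}

fn main() {
    let mut img = ImageArray { array: Array::zeros((224, 224, 3)) };
    img.onnx_transform();
    assert_eq!(img.view().shape(), &[3, 224, 224]);
}
```

Since `swap_axes` only adjusts strides, the remaining allocation is the later stacking of views into one batch array.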
30 changes: 16 additions & 14 deletions ahnlich/ai/src/engine/ai/providers/fastembed.rs
@@ -3,16 +3,16 @@ use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput, Mode
use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait};
use crate::error::AIProxyError;
use ahnlich_types::ai::AIStoreInputType;
+ use ahnlich_types::keyval::StoreKey;
use fastembed::{EmbeddingModel, ImageEmbedding, InitOptions, TextEmbedding};
use hf_hub::{api::sync::ApiBuilder, Cache};
+ use ndarray::Array1;
+ use rayon::iter::Either;
+ use rayon::prelude::*;
use std::convert::TryFrom;
use std::fmt;
use std::path::{Path, PathBuf};
- use rayon::prelude::*;
- use rayon::iter::Either;
- use ndarray::Array1;
use tiktoken_rs::{cl100k_base, CoreBPE};
- use ahnlich_types::keyval::StoreKey;

#[derive(Default)]
pub struct FastEmbedProvider {
@@ -186,15 +186,17 @@ impl ProviderTrait for FastEmbedProvider {
// TODO (HAKSOAT): When we add model specific tokenizers, add the get tokenizer call here too.
}

- fn run_inference(&self, inputs: &[ModelInput], action_type: &InputAction) -> Result<Vec<StoreKey>, AIProxyError> {
+ fn run_inference(
+ &self,
+ inputs: Vec<ModelInput>,
+ action_type: &InputAction,
+ ) -> Result<Vec<StoreKey>, AIProxyError> {
return if let Some(fastembed_model) = &self.model {
- let (string_inputs, image_inputs): (Vec<&String>, Vec<&ImageArray>) = inputs
- .par_iter().partition_map(|input| {
- match input {
+ let (string_inputs, image_inputs): (Vec<String>, Vec<ImageArray>) =
+ inputs.into_par_iter().partition_map(|input| match input {
ModelInput::Text(value) => Either::Left(value),
ModelInput::Image(value) => Either::Right(value),
- }
- });
+ });

if !image_inputs.is_empty() {
let store_input_type: AIStoreInputType = AIStoreInputType::Image;
@@ -209,20 +211,20 @@
});
}
let FastEmbedModel::Text(model) = fastembed_model else {
- return Err(AIProxyError::AIModelNotSupported)
+ return Err(AIProxyError::AIModelNotSupported);
};
let batch_size = 16;
let store_keys = model
.embed(string_inputs, Some(batch_size))
- .map_err(|_| AIProxyError::ModelProviderRunInferenceError)?
+ .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?
.iter()
- .try_fold(Vec::new(), |mut accumulator, embedding|{
+ .try_fold(Vec::new(), |mut accumulator, embedding| {
accumulator.push(StoreKey(<Array1<f32>>::from(embedding.to_owned())));
Ok(accumulator)
});
store_keys
} else {
Err(AIProxyError::AIModelNotSupported)
- }
+ };
}
}
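The new `inputs: Vec<ModelInput>` signature is what allows `into_par_iter()` to replace `par_iter()`: each input is moved into its partition rather than borrowed and cloned later. A rough self-contained sketch of the pattern, with a stand-in enum for `ModelInput`:

```rust
use rayon::iter::Either;
use rayon::prelude::*;

// Stand-in for ModelInput; the real enum wraps String and ImageArray.
enum Input {
    Text(String),
    Image(Vec<u8>),
}

fn split(inputs: Vec<Input>) -> (Vec<String>, Vec<Vec<u8>>) {
    // Consuming the Vec means each String / image buffer is moved,
    // not cloned, into whichever side partition_map routes it to.
    inputs.into_par_iter().partition_map(|input| match input {
        Input::Text(t) => Either::Left(t),
        Input::Image(i) => Either::Right(i),
    })
}

fn main() {
    let (texts, images) = split(vec![
        Input::Text("hello".into()),
        Input::Image(vec![0u8; 4]),
    ]);
    assert_eq!((texts.len(), images.len()), (1, 1));
}
```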
4 changes: 2 additions & 2 deletions ahnlich/ai/src/engine/ai/providers/mod.rs
@@ -6,9 +6,9 @@ use crate::engine::ai::models::{InputAction, ModelInput};
use crate::engine::ai::providers::fastembed::FastEmbedProvider;
use crate::engine::ai::providers::ort::ORTProvider;
use crate::error::AIProxyError;
+ use ahnlich_types::keyval::StoreKey;
use std::path::Path;
use strum::EnumIter;
- use ahnlich_types::keyval::StoreKey;

#[derive(Debug, EnumIter)]
pub enum ModelProviders {
@@ -23,7 +23,7 @@ pub trait ProviderTrait: std::fmt::Debug + Send + Sync {
fn get_model(&self) -> Result<(), AIProxyError>;
fn run_inference(
&self,
- input: &[ModelInput],
+ input: Vec<ModelInput>,
action_type: &InputAction,
) -> Result<Vec<StoreKey>, AIProxyError>;
}
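A toy illustration of the trait change (hypothetical types; only the ownership shape matches): taking the `Vec` by value hands each element to the implementor, which can then consume it clone-free.

```rust
// Toy stand-ins for ModelInput and the provider trait.
struct Input(String);

trait Provider {
    // By-value inputs let the implementor move each element
    // (e.g. with into_par_iter) instead of cloning from a slice.
    fn run_inference(&self, inputs: Vec<Input>) -> usize;
}

struct Echo;

impl Provider for Echo {
    fn run_inference(&self, inputs: Vec<Input>) -> usize {
        // Ownership of every Input is ours; no clones needed.
        inputs.into_iter().map(|Input(s)| s.len()).sum()
    }
}

fn main() {
    assert_eq!(Echo.run_inference(vec![Input("abc".into())]), 3);
}
```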
79 changes: 37 additions & 42 deletions ahnlich/ai/src/engine/ai/providers/ort.rs
@@ -1,20 +1,20 @@
use crate::cli::server::SupportedModels;
- use crate::engine::ai::models::{InputAction, ImageArray, Model, ModelInput};
+ use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput};
use crate::engine::ai::providers::ProviderTrait;
use crate::error::AIProxyError;
use ahnlich_types::ai::AIStoreInputType;
use hf_hub::{api::sync::ApiBuilder, Cache};
use ort::Session;
- use rayon::prelude::*;
use rayon::iter::Either;
+ use rayon::prelude::*;

+ use ahnlich_types::keyval::StoreKey;
+ use ndarray::{Array1, ArrayView, Axis, Ix3};
use std::convert::TryFrom;
use std::default::Default;
use std::fmt;
use std::path::{Path, PathBuf};
use std::thread::available_parallelism;
- use ndarray::{Array, Array1, ArrayView, Axis, Ix3};
- use ahnlich_types::keyval::StoreKey;

#[derive(Default)]
pub struct ORTProvider {
@@ -90,44 +90,46 @@ impl ORTProvider {
v.par_iter().map(|&val| val / (norm + epsilon)).collect()
}

- pub fn batch_inference(&self, inputs: &[&ImageArray]) -> Result<Vec<StoreKey>, AIProxyError> {
+ pub fn batch_inference(
+ &self,
+ mut inputs: Vec<ImageArray>,
+ ) -> Result<Vec<StoreKey>, AIProxyError> {
let model = match &self.model {
Some(ORTModel::Image(model)) => model,
_ => return Err(AIProxyError::AIModelNotSupported),
};

- let array: Vec<Array<f32, Ix3>> = inputs.par_iter()
+ let array_views: Vec<ArrayView<f32, Ix3>> = inputs
+ .par_iter_mut()
.map(|image_arr| {
- let arr = image_arr.get_array();
- let mut arr = arr.mapv(f32::from);
- // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
- arr.swap_axes(1, 2);
- arr.swap_axes(0, 1);
- arr
+ image_arr.onnx_transform();
+ image_arr.view()
})
.collect();

- // TODO: Figure how to avoid this second par_iter.
- let array_views: Vec<ArrayView<f32, Ix3>> = array.par_iter()
- .map(|arr| arr.view()).collect();
-
- let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views).unwrap();
+ let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views)
+ .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?;
match &model.session {
Some(session) => {
let session_inputs = ort::inputs![
- model.input_param.as_str() => pixel_values_array.view(),
- ].map_err(|_| AIProxyError::ModelProviderPreprocessingError)?;
+ model.input_param.as_str() => pixel_values_array.view(),
+ ]
+ .map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?;

- let outputs = session.run(session_inputs)
- .map_err(|_| AIProxyError::ModelProviderRunInferenceError)?;
+ let outputs = session
+ .run(session_inputs)
+ .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?;
let last_hidden_state_key = match outputs.len() {
- 1 => outputs.keys().next().unwrap(),
+ 1 => outputs
+ .keys()
+ .next()
+ .expect("Should not happen as length was checked"),
_ => model.output_param.as_str(),
};

let output_data = outputs[last_hidden_state_key]
.try_extract_tensor::<f32>()
- .map_err(|_| AIProxyError::ModelProviderPostprocessingError)?;
+ .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?;
let store_keys = output_data
.axis_iter(Axis(0))
.into_par_iter()
@@ -138,7 +140,7 @@ impl ORTProvider {
.collect();
Ok(store_keys)
}
- None => Err(AIProxyError::AIModelNotInitialized)
+ None => Err(AIProxyError::AIModelNotInitialized),
}
}
}
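With `onnx_transform` mutating in place and `view()` borrowing, the second `par_iter` the removed TODO complained about is unnecessary: the views are stacked directly, and a shape mismatch now surfaces as an error instead of the old `unwrap` panic. A small sketch of that step, assuming every image was resized to the same dimensions:

```rust
use ndarray::{Array, Array3, ArrayView, Axis, Ix3, Ix4};

// stack borrows N views of identical shape (C, H, W) and allocates one
// owned (N, C, H, W) batch; mismatched shapes come back as Err rather
// than panicking.
fn stack_batch(views: &[ArrayView<f32, Ix3>]) -> Result<Array<f32, Ix4>, String> {
    ndarray::stack(Axis(0), views).map_err(|e| e.to_string())
}

fn main() {
    let a: Array3<f32> = Array::zeros((3, 224, 224));
    let b: Array3<f32> = Array::zeros((3, 224, 224));
    let batch = stack_batch(&[a.view(), b.view()]).unwrap();
    assert_eq!(batch.shape(), &[2, 3, 224, 224]);
}
```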
@@ -161,7 +163,6 @@ impl ProviderTrait for ORTProvider {
};
let ort_model = ORTModel::try_from(&supported_model)?;

-
let cache = Cache::new(cache_location);
let api = ApiBuilder::from_cache(cache)
.with_progress(true)
@@ -203,7 +204,8 @@ impl ProviderTrait for ORTProvider {
let Some(cache_location) = self.cache_location.clone() else {
return Err(AIProxyError::CacheLocationNotInitiailized);
};
- let supported_model = self.supported_models
+ let supported_model = self
+ .supported_models
.ok_or(AIProxyError::AIModelNotInitialized)?;
let ort_model = ORTModel::try_from(&supported_model)?;

@@ -231,14 +233,16 @@ impl ProviderTrait for ORTProvider {
}
}

- fn run_inference(&self, inputs: &[ModelInput], action_type: &InputAction) -> Result<Vec<StoreKey>, AIProxyError> {
- let (string_inputs, image_inputs): (Vec<&String>, Vec<&ImageArray>) = inputs
- .par_iter().partition_map(|input| {
- match input {
+ fn run_inference(
+ &self,
+ inputs: Vec<ModelInput>,
+ action_type: &InputAction,
+ ) -> Result<Vec<StoreKey>, AIProxyError> {
+ let (string_inputs, image_inputs): (Vec<String>, Vec<ImageArray>) =
+ inputs.into_par_iter().partition_map(|input| match input {
ModelInput::Text(value) => Either::Left(value),
ModelInput::Image(value) => Either::Right(value),
- }
- });
+ });

if !string_inputs.is_empty() {
let store_input_type: AIStoreInputType = AIStoreInputType::RawString;
@@ -252,15 +256,6 @@ impl ProviderTrait for ORTProvider {
storeinput_type: store_input_type,
});
}
-
- let batch_size = 16;
- let store_keys = image_inputs
- .chunks(batch_size)
- .try_fold(Vec::new(), |mut accumulator, batch_inputs|{
- accumulator.extend(self.batch_inference(batch_inputs)?);
- Ok(accumulator)
- });
-
- store_keys
+ self.batch_inference(image_inputs)
}
}
12 changes: 6 additions & 6 deletions ahnlich/ai/src/error.rs
@@ -101,14 +101,14 @@ pub enum AIProxyError {
#[error("Image could not be resized.")]
ImageResizeError,

#[error("Model provider failed on preprocessing the input.")]
ModelProviderPreprocessingError,
#[error("Model provider failed on preprocessing the input {0}")]
ModelProviderPreprocessingError(String),

#[error("Model provider failed on running inference.")]
ModelProviderRunInferenceError,
#[error("Model provider failed on running inference {0}")]
ModelProviderRunInferenceError(String),

#[error("Model provider failed on postprocessing the output.")]
ModelProviderPostprocessingError,
#[error("Model provider failed on postprocessing the output {0}")]
ModelProviderPostprocessingError(String),

#[error("Model provider failed on tokenization of text inputs.")]
ModelTokenizationError,
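These variants follow the usual thiserror payload pattern: `{0}` interpolates the tuple field into the rendered message. A minimal sketch (one variant, not the full enum):

```rust
use thiserror::Error;

// Cut-down version of the pattern above: the provider's error text is
// captured as a String so it shows up in the displayed message.
#[derive(Debug, Error)]
enum SketchError {
    #[error("Model provider failed on running inference {0}")]
    RunInference(String),
}

fn main() {
    let err = SketchError::RunInference("shape mismatch".into());
    assert_eq!(
        err.to_string(),
        "Model provider failed on running inference shape mismatch"
    );
}
```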