diff --git a/ahnlich/Cargo.lock b/ahnlich/Cargo.lock
index 49c4ec6d..79ec5947 100644
--- a/ahnlich/Cargo.lock
+++ b/ahnlich/Cargo.lock
@@ -107,6 +107,7 @@ dependencies = [
  "futures",
  "hf-hub",
  "image",
+ "itertools 0.10.5",
  "log",
  "moka",
  "ndarray",
diff --git a/ahnlich/ai/Cargo.toml b/ahnlich/ai/Cargo.toml
index e146873a..e07b09d9 100644
--- a/ahnlich/ai/Cargo.toml
+++ b/ahnlich/ai/Cargo.toml
@@ -46,6 +46,7 @@ moka = { version = "0.12.8", features = ["future"] }
 tracing-opentelemetry.workspace = true
 futures.workspace = true
 tiktoken-rs = "0.5.9"
+itertools.workspace = true
 [dev-dependencies]
 db = { path = "../db", version = "*" }
 pretty_assertions.workspace = true
diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs
index 59069812..6a0afc66 100644
--- a/ahnlich/ai/src/engine/ai/models.rs
+++ b/ahnlich/ai/src/engine/ai/models.rs
@@ -9,6 +9,7 @@ use ahnlich_types::{
     keyval::{StoreInput, StoreKey},
 };
 use image::{GenericImageView, ImageReader};
+use ndarray::ArrayView;
 use ndarray::{Array, Ix3};
 use nonzero_ext::nonzero;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -115,17 +116,17 @@ impl Model {
         }
     }
 
-    // TODO: model ndarray values is based on length of string or vec, so for now make sure strings
-    // or vecs have different lengths
     #[tracing::instrument(skip(self))]
     pub fn model_ndarray(
         &self,
-        storeinput: &Vec<ModelInput>,
+        storeinput: Vec<ModelInput>,
         action_type: &InputAction,
     ) -> Result<Vec<StoreKey>, AIProxyError> {
         let store_keys = match &self.provider {
-            ModelProviders::FastEmbed(provider) => provider.run_inference(storeinput, action_type)?,
-            ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?
+            ModelProviders::FastEmbed(provider) => {
+                provider.run_inference(storeinput, action_type)?
+            }
+            ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?,
         };
         Ok(store_keys)
     }
@@ -246,9 +247,9 @@ pub enum ModelInput {
     Image(ImageArray),
 }
 
-#[derive(Debug, Clone, PartialEq, Hash, Eq)]
+#[derive(Debug, Clone)]
 pub struct ImageArray {
-    array: Array<u8, Ix3>,
+    array: Array<f32, Ix3>,
     bytes: Vec<u8>,
 }
 
@@ -273,12 +274,19 @@ impl ImageArray {
         let channels = img.color().channel_count();
         let shape = (height as usize, width as usize, channels as usize);
         let array = Array::from_shape_vec(shape, img.into_bytes())
-            .map_err(|_| AIProxyError::ImageBytesDecodeError)?;
+            .map_err(|_| AIProxyError::ImageBytesDecodeError)?
+            .mapv(f32::from);
         Ok(ImageArray { array, bytes })
     }
 
-    pub fn get_array(&self) -> &Array<u8, Ix3> {
-        &self.array
+    // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
+    pub fn onnx_transform(&mut self) {
+        self.array.swap_axes(1, 2);
+        self.array.swap_axes(0, 1);
+    }
+
+    pub fn view(&self) -> ArrayView<f32, Ix3> {
+        self.array.view()
     }
 
     pub fn get_bytes(&self) -> &Vec<u8> {
         &self.bytes
     }
@@ -313,7 +321,8 @@ impl ImageArray {
         let flattened_pixels = resized_img.into_bytes();
 
         let array = Array::from_shape_vec(shape, flattened_pixels)
-            .map_err(|_| AIProxyError::ImageResizeError)?;
+            .map_err(|_| AIProxyError::ImageResizeError)?
+            .mapv(f32::from);
         let bytes = buffer.into_inner();
         Ok(ImageArray { array, bytes })
     }
diff --git a/ahnlich/ai/src/engine/ai/providers/fastembed.rs b/ahnlich/ai/src/engine/ai/providers/fastembed.rs
index cd19e18c..f97f43b6 100644
--- a/ahnlich/ai/src/engine/ai/providers/fastembed.rs
+++ b/ahnlich/ai/src/engine/ai/providers/fastembed.rs
@@ -3,16 +3,16 @@ use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput, Mode
 use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait};
 use crate::error::AIProxyError;
 use ahnlich_types::ai::AIStoreInputType;
+use ahnlich_types::keyval::StoreKey;
 use fastembed::{EmbeddingModel, ImageEmbedding, InitOptions, TextEmbedding};
 use hf_hub::{api::sync::ApiBuilder, Cache};
+use ndarray::Array1;
+use rayon::iter::Either;
+use rayon::prelude::*;
 use std::convert::TryFrom;
 use std::fmt;
 use std::path::{Path, PathBuf};
-use rayon::prelude::*;
-use rayon::iter::Either;
-use ndarray::Array1;
 use tiktoken_rs::{cl100k_base, CoreBPE};
-use ahnlich_types::keyval::StoreKey;
 
 #[derive(Default)]
 pub struct FastEmbedProvider {
@@ -186,15 +186,17 @@ impl ProviderTrait for FastEmbedProvider {
         // TODO (HAKSOAT): When we add model specific tokenizers, add the get tokenizer call here too.
     }
 
-    fn run_inference(&self, inputs: &[ModelInput], action_type: &InputAction) -> Result<Vec<StoreKey>, AIProxyError> {
+    fn run_inference(
+        &self,
+        inputs: Vec<ModelInput>,
+        action_type: &InputAction,
+    ) -> Result<Vec<StoreKey>, AIProxyError> {
         return if let Some(fastembed_model) = &self.model {
-            let (string_inputs, image_inputs): (Vec<&String>, Vec<&ImageArray>) = inputs
-                .par_iter().partition_map(|input| {
-                    match input {
+            let (string_inputs, image_inputs): (Vec<String>, Vec<ImageArray>) =
+                inputs.into_par_iter().partition_map(|input| match input {
                     ModelInput::Text(value) => Either::Left(value),
                     ModelInput::Image(value) => Either::Right(value),
-                    }
-                });
+                });
 
             if !image_inputs.is_empty() {
                 let store_input_type: AIStoreInputType = AIStoreInputType::Image;
@@ -209,20 +211,20 @@ impl ProviderTrait for FastEmbedProvider {
                 });
             }
             let FastEmbedModel::Text(model) = fastembed_model else {
-                return Err(AIProxyError::AIModelNotSupported)
+                return Err(AIProxyError::AIModelNotSupported);
             };
             let batch_size = 16;
             let store_keys = model
                 .embed(string_inputs, Some(batch_size))
-                .map_err(|_| AIProxyError::ModelProviderRunInferenceError)?
+                .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?
                 .iter()
-                .try_fold(Vec::new(), |mut accumulator, embedding|{
+                .try_fold(Vec::new(), |mut accumulator, embedding| {
                     accumulator.push(StoreKey(<Array1<f32>>::from(embedding.to_owned())));
                     Ok(accumulator)
                 });
             store_keys
         } else {
             Err(AIProxyError::AIModelNotSupported)
-        }
+        };
     }
 }
diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs
index fa5b4b0e..f2086f67 100644
--- a/ahnlich/ai/src/engine/ai/providers/mod.rs
+++ b/ahnlich/ai/src/engine/ai/providers/mod.rs
@@ -6,9 +6,9 @@ use crate::engine::ai::models::{InputAction, ModelInput};
 use crate::engine::ai::providers::fastembed::FastEmbedProvider;
 use crate::engine::ai::providers::ort::ORTProvider;
 use crate::error::AIProxyError;
+use ahnlich_types::keyval::StoreKey;
 use std::path::Path;
 use strum::EnumIter;
-use ahnlich_types::keyval::StoreKey;
 
 #[derive(Debug, EnumIter)]
 pub enum ModelProviders {
@@ -23,7 +23,7 @@ pub trait ProviderTrait: std::fmt::Debug + Send + Sync {
     fn get_model(&self) -> Result<(), AIProxyError>;
     fn run_inference(
         &self,
-        input: &[ModelInput],
+        input: Vec<ModelInput>,
         action_type: &InputAction,
     ) -> Result<Vec<StoreKey>, AIProxyError>;
 }
diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs
index 83eb8429..0b805272 100644
--- a/ahnlich/ai/src/engine/ai/providers/ort.rs
+++ b/ahnlich/ai/src/engine/ai/providers/ort.rs
@@ -1,20 +1,22 @@
 use crate::cli::server::SupportedModels;
-use crate::engine::ai::models::{InputAction, ImageArray, Model, ModelInput};
+use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput};
 use crate::engine::ai::providers::ProviderTrait;
 use crate::error::AIProxyError;
 use ahnlich_types::ai::AIStoreInputType;
+use fallible_collections::FallibleVec;
 use hf_hub::{api::sync::ApiBuilder, Cache};
+use itertools::Itertools;
 use ort::Session;
-use rayon::prelude::*;
 use rayon::iter::Either;
+use rayon::prelude::*;
+use ahnlich_types::keyval::StoreKey;
+use ndarray::{Array1, ArrayView, Axis, Ix3};
 use std::convert::TryFrom;
 use std::default::Default;
 use std::fmt;
 use std::path::{Path, PathBuf};
 use std::thread::available_parallelism;
-use ndarray::{Array, Array1, ArrayView, Axis, Ix3};
-use ahnlich_types::keyval::StoreKey;
 
 #[derive(Default)]
 pub struct ORTProvider {
@@ -90,44 +92,46 @@ impl ORTProvider {
         v.par_iter().map(|&val| val / (norm + epsilon)).collect()
     }
 
-    pub fn batch_inference(&self, inputs: &[&ImageArray]) -> Result<Vec<StoreKey>, AIProxyError> {
+    pub fn batch_inference(
+        &self,
+        mut inputs: Vec<ImageArray>,
+    ) -> Result<Vec<StoreKey>, AIProxyError> {
         let model = match &self.model {
             Some(ORTModel::Image(model)) => model,
             _ => return Err(AIProxyError::AIModelNotSupported),
         };
 
-        let array: Vec<Array<f32, Ix3>> = inputs.par_iter()
+        let array_views: Vec<ArrayView<f32, Ix3>> = inputs
+            .par_iter_mut()
             .map(|image_arr| {
-                let arr = image_arr.get_array();
-                let mut arr = arr.mapv(f32::from);
-                // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
-                arr.swap_axes(1, 2);
-                arr.swap_axes(0, 1);
-                arr
+                image_arr.onnx_transform();
+                image_arr.view()
             })
             .collect();
 
-        // TODO: Figure how to avoid this second par_iter.
-        let array_views: Vec<ArrayView<f32, Ix3>> = array.par_iter()
-            .map(|arr| arr.view()).collect();
-
-        let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views).unwrap();
+        let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views)
+            .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?;
 
         match &model.session {
             Some(session) => {
                 let session_inputs = ort::inputs![
-                    model.input_param.as_str() => pixel_values_array.view(),
-                ].map_err(|_| AIProxyError::ModelProviderPreprocessingError)?;
+                    model.input_param.as_str() => pixel_values_array.view(),
+                ]
+                .map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?;
 
-                let outputs = session.run(session_inputs)
-                    .map_err(|_| AIProxyError::ModelProviderRunInferenceError)?;
+                let outputs = session
+                    .run(session_inputs)
+                    .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?;
 
                 let last_hidden_state_key = match outputs.len() {
-                    1 => outputs.keys().next().unwrap(),
+                    1 => outputs
+                        .keys()
+                        .next()
+                        .expect("Should not happen as length was checked"),
                     _ => model.output_param.as_str(),
                 };
 
                 let output_data = outputs[last_hidden_state_key]
                     .try_extract_tensor::<f32>()
-                    .map_err(|_| AIProxyError::ModelProviderPostprocessingError)?;
+                    .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?;
                 let store_keys = output_data
                     .axis_iter(Axis(0))
                     .into_par_iter()
@@ -138,7 +142,7 @@ impl ORTProvider {
                 .collect();
                 Ok(store_keys)
             }
-            None => Err(AIProxyError::AIModelNotInitialized)
+            None => Err(AIProxyError::AIModelNotInitialized),
         }
     }
 }
@@ -161,7 +165,6 @@ impl ProviderTrait for ORTProvider {
         };
 
         let ort_model = ORTModel::try_from(&supported_model)?;
-
        let cache = Cache::new(cache_location);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(true)
@@ -203,7 +206,8 @@ impl ProviderTrait for ORTProvider {
         let Some(cache_location) = self.cache_location.clone() else {
             return Err(AIProxyError::CacheLocationNotInitiailized);
         };
-        let supported_model = self.supported_models
+        let supported_model = self
+            .supported_models
             .ok_or(AIProxyError::AIModelNotInitialized)?;
         let ort_model = ORTModel::try_from(&supported_model)?;
@@ -231,14 +235,16 @@ impl ProviderTrait for ORTProvider {
         }
     }
 
-    fn run_inference(&self, inputs: &[ModelInput], action_type: &InputAction) -> Result<Vec<StoreKey>, AIProxyError> {
-        let (string_inputs, image_inputs): (Vec<&String>, Vec<&ImageArray>) = inputs
-            .par_iter().partition_map(|input| {
-                match input {
+    fn run_inference(
+        &self,
+        inputs: Vec<ModelInput>,
+        action_type: &InputAction,
+    ) -> Result<Vec<StoreKey>, AIProxyError> {
+        let (string_inputs, image_inputs): (Vec<String>, Vec<ImageArray>) =
+            inputs.into_par_iter().partition_map(|input| match input {
                 ModelInput::Text(value) => Either::Left(value),
                 ModelInput::Image(value) => Either::Right(value),
-                }
-            });
+            });
 
         if !string_inputs.is_empty() {
             let store_input_type: AIStoreInputType = AIStoreInputType::RawString;
@@ -252,15 +258,12 @@ impl ProviderTrait for ORTProvider {
                 storeinput_type: store_input_type,
             });
         }
 
         let batch_size = 16;
-        let store_keys = image_inputs
-            .chunks(batch_size)
-            .try_fold(Vec::new(), |mut accumulator, batch_inputs|{
-                accumulator.extend(self.batch_inference(batch_inputs)?);
-                Ok(accumulator)
-            });
-
-        store_keys
+        let mut store_keys: Vec<_> = FallibleVec::try_with_capacity(image_inputs.len())?;
+        for batch_inputs in image_inputs.into_iter().chunks(batch_size).into_iter() {
+            store_keys.extend(self.batch_inference(batch_inputs.collect())?);
+        }
+        Ok(store_keys)
     }
 }
diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs
index 17719cb6..35f4d8b2 100644
--- a/ahnlich/ai/src/error.rs
+++ b/ahnlich/ai/src/error.rs
@@ -101,14 +101,14 @@ pub enum AIProxyError {
     #[error("Image could not be resized.")]
     ImageResizeError,
 
-    #[error("Model provider failed on preprocessing the input.")]
-    ModelProviderPreprocessingError,
+    #[error("Model provider failed on preprocessing the input {0}")]
+    ModelProviderPreprocessingError(String),
 
-    #[error("Model provider failed on running inference.")]
-    ModelProviderRunInferenceError,
+    #[error("Model provider failed on running inference {0}")]
+    ModelProviderRunInferenceError(String),
 
-    #[error("Model provider failed on postprocessing the output.")]
-    ModelProviderPostprocessingError,
+    #[error("Model provider failed on postprocessing the output {0}")]
+    ModelProviderPostprocessingError(String),
 
     #[error("Model provider failed on tokenization of text inputs.")]
     ModelTokenizationError,
diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs
index c7de6732..d7a3b3a0 100644
--- a/ahnlich/ai/src/manager/mod.rs
+++ b/ahnlich/ai/src/manager/mod.rs
@@ -67,7 +67,7 @@ impl ModelThread {
     ) -> ModelThreadResponse {
         let mut response: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
         let processed_inputs = self.preprocess_store_input(process_action, inputs)?;
-        let mut store_key = self.model.model_ndarray(&processed_inputs, &action_type)?;
+        let mut store_key = self.model.model_ndarray(processed_inputs, &action_type)?;
         response.append(&mut store_key);
         Ok(response)
     }
@@ -80,9 +80,9 @@ impl ModelThread {
     ) -> Result<Vec<ModelInput>, AIProxyError> {
         let preprocessed_inputs = inputs
             .into_par_iter()
-            .try_fold(Vec::new, | mut accumulator, input| {
+            .try_fold(Vec::new, |mut accumulator, input| {
                 let model_input = ModelInput::try_from(input)?;
-                let processed_input = match (process_action, &model_input) {
+                let processed_input = match (process_action, model_input) {
                     (PreprocessAction::Image(image_action), ModelInput::Image(image_array)) => {
                         let output = self.process_image(image_array, image_action)?;
                         Ok(ModelInput::Image(output))
@@ -91,27 +91,24 @@ impl ModelThread {
                         let output = self.preprocess_raw_string(string, string_action)?;
                         Ok(ModelInput::Text(output))
                     }
-                    _ => {
-                        let input_type: AIStoreInputType = (&model_input).into();
-                        Err(AIProxyError::PreprocessingMismatchError {
-                            input_type,
-                            preprocess_action: process_action,
-                        })
-                    }
+                    (_, model_input) => Err(AIProxyError::PreprocessingMismatchError {
+                        input_type: (&model_input).into(),
+                        preprocess_action: process_action,
+                    }),
                 }?;
-                accumulator.push(processed_input);
-                Ok::<Vec<ModelInput>, AIProxyError>(accumulator)
+                accumulator.push(processed_input);
+                Ok::<Vec<ModelInput>, AIProxyError>(accumulator)
             })
             .try_reduce(Vec::new, |mut accumulator, mut item| {
-                accumulator.append(&mut item);
-                Ok(accumulator)
-            })?;
+                accumulator.append(&mut item);
+                Ok(accumulator)
+            })?;
         Ok(preprocessed_inputs)
     }
 
     #[tracing::instrument(skip(self, input))]
     fn preprocess_raw_string(
         &self,
-        input: &str,
+        input: String,
         string_action: StringAction,
     ) -> Result<String, AIProxyError> {
         let max_token_size = self.model.max_input_token().unwrap_or_else(|| {
@@ -129,7 +126,7 @@ impl ModelThread {
             return Err(AIProxyError::TokenTruncationNotSupported);
         };
 
-        let tokens = provider.encode_str(input)?;
+        let tokens = provider.encode_str(&input)?;
 
         if tokens.len() > max_token_size.into() {
             if let StringAction::ErrorIfTokensExceed = string_action {
@@ -142,13 +139,13 @@ impl ModelThread {
                 return Ok(processed_input);
             }
         };
-        Ok(input.to_owned())
+        Ok(input)
     }
 
     #[tracing::instrument(skip(self, input))]
     fn process_image(
         &self,
-        input: &ImageArray,
+        input: ImageArray,
         image_action: ImageAction,
     ) -> Result<ImageArray, AIProxyError> {
         // process image, return error if max dimensions exceeded
@@ -177,7 +174,7 @@ impl ModelThread {
             }
         }
 
-        Ok(input.clone())
+        Ok(input)
     }
 }
 
@@ -201,9 +198,7 @@ impl Task for ModelThread {
         child_span.set_parent(trace_span.context());
 
         let responses = self.input_to_response(inputs, preprocess_action, action_type);
-        if let Err(e) =
-            response.send(responses)
-        {
+        if let Err(e) = response.send(responses) {
             log::error!("{} could not send response to channel {e:?}", self.name());
         }
         return TaskState::Continue;
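
A note on the `onnx_transform` change: ONNX vision models generally expect channels-first (CHW) input, while the `image` crate decodes to channels-last (HWC), so the two `swap_axes` calls relayout the array without copying pixel data (only strides are adjusted). A minimal standalone sketch of the same transform; the 2x4x3 shape here is illustrative, not from the PR:

```rust
use ndarray::{Array, Ix3};

fn main() {
    // Dummy 2x4 image with 3 channels in [rows, columns, channels] (HWC) layout.
    let mut arr: Array<f32, Ix3> = Array::zeros((2, 4, 3));

    // Same two swaps as `ImageArray::onnx_transform`:
    // [rows, columns, channels] -> [rows, channels, columns]
    arr.swap_axes(1, 2);
    // [rows, channels, columns] -> [channels, rows, columns] (CHW)
    arr.swap_axes(0, 1);

    // Only the strides changed; no pixel data was moved.
    assert_eq!(arr.shape(), &[3, 2, 4]);
}
```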
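The batch tensor is then assembled with `ndarray::stack`, which stacks `Ix3` views along a new leading axis into an `Ix4` array and returns `Err` on any shape mismatch; that error is now surfaced as `EmbeddingShapeError` instead of being unwrapped. A sketch of the shape behaviour, assuming ndarray 0.15-style `stack` semantics; the 3x224x224 dimensions are hypothetical and depend on the model:

```rust
use ndarray::{Array, ArrayView, Axis, Ix3};

fn main() {
    // Two already-transformed [channels, rows, columns] images.
    let a: Array<f32, Ix3> = Array::zeros((3, 224, 224));
    let b: Array<f32, Ix3> = Array::ones((3, 224, 224));
    let views: Vec<ArrayView<f32, Ix3>> = vec![a.view(), b.view()];

    // Stacking along a new leading axis yields [batch, channels, rows, columns].
    // This errors on shape mismatch, which the diff maps into an error variant
    // rather than unwrapping.
    let batch = ndarray::stack(Axis(0), &views).expect("shapes must match");
    assert_eq!(batch.shape(), &[2, 3, 224, 224]);
}
```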
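Finally, batching in `run_inference` moves from slice `chunks` to `Itertools::chunks` (hence the new `itertools 0.10.5` entry in Cargo.lock): the inputs are now owned, and each batch must be collected into the owned `Vec<ImageArray>` that `batch_inference` takes by value. A minimal sketch of that pattern, with integers standing in for images:

```rust
use itertools::Itertools;

fn main() {
    let inputs: Vec<u32> = (0..10).collect();
    let batch_size = 4;

    // `Itertools::chunks` works on any owned iterator, yielding sub-iterators
    // that can each be collected into an owned Vec batch.
    let chunks = inputs.into_iter().chunks(batch_size);
    let batches: Vec<Vec<u32>> = chunks.into_iter().map(|c| c.collect()).collect();

    assert_eq!(batches, vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9]]);
}
```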