-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: new text representation for sparse vector (#466)
* feat: new text embedding for sparse vector Signed-off-by: cutecutecat <[email protected]>
* fix: use 0-based index Signed-off-by: usamoi <[email protected]>
* refactor: use sparse struct to parse Signed-off-by: cutecutecat <[email protected]>
* fix: zero-check, sort and tests Signed-off-by: cutecutecat <[email protected]>
* fix: new reject case Signed-off-by: cutecutecat <[email protected]>
* refactor: use state machine Signed-off-by: cutecutecat <[email protected]>
* fix: fsm with more checks Signed-off-by: cutecutecat <[email protected]>
* fix: by comments Signed-off-by: cutecutecat <[email protected]>
* fix: by comments Signed-off-by: cutecutecat <[email protected]>
* fix: remove funcs Signed-off-by: cutecutecat <[email protected]>
---------
Signed-off-by: cutecutecat <[email protected]>
Signed-off-by: usamoi <[email protected]>
Co-authored-by: usamoi <[email protected]>
- Loading branch information
1 parent
d192c40
commit 1a9e0b6
Showing
5 changed files
with
325 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,97 @@ | ||
use num_traits::Zero; | ||
use pgrx::error; | ||
|
||
use super::memory_svecf32::SVecf32Output; | ||
use crate::datatype::memory_svecf32::SVecf32Input; | ||
use crate::datatype::typmod::Typmod; | ||
use crate::error::*; | ||
use base::scalar::*; | ||
use base::vector::*; | ||
use num_traits::Zero; | ||
use pgrx::pg_sys::Oid; | ||
use std::ffi::{CStr, CString}; | ||
use std::fmt::Write; | ||
|
||
#[pgrx::pg_extern(immutable, strict, parallel_safe)] | ||
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { | ||
use crate::utils::parse::parse_vector; | ||
let reserve = Typmod::parse_from_i32(typmod) | ||
.unwrap() | ||
.dims() | ||
.map(|x| x.get()) | ||
.unwrap_or(0); | ||
let v = parse_vector(input.to_bytes(), reserve as usize, |s| { | ||
s.parse::<F32>().ok() | ||
}); | ||
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { | ||
use crate::utils::parse::parse_pgvector_svector; | ||
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok()); | ||
match v { | ||
Err(e) => { | ||
bad_literal(&e.to_string()); | ||
} | ||
Ok(vector) => { | ||
check_value_dims_1048575(vector.len()); | ||
let mut indexes = Vec::<u32>::new(); | ||
let mut values = Vec::<F32>::new(); | ||
for (i, &x) in vector.iter().enumerate() { | ||
if !x.is_zero() { | ||
indexes.push(i as u32); | ||
values.push(x); | ||
Ok((mut indexes, mut values, dims)) => { | ||
check_value_dims_1048575(dims); | ||
// is_sorted | ||
if !indexes.windows(2).all(|i| i[0] <= i[1]) { | ||
assert_eq!(indexes.len(), values.len()); | ||
let n = indexes.len(); | ||
let mut permutation = (0..n).collect::<Vec<_>>(); | ||
permutation.sort_unstable_by_key(|&i| &indexes[i]); | ||
for i in 0..n { | ||
if i == permutation[i] || usize::MAX == permutation[i] { | ||
continue; | ||
} | ||
let index = indexes[i]; | ||
let value = values[i]; | ||
let mut j = i; | ||
while i != permutation[j] { | ||
let next = permutation[j]; | ||
indexes[j] = indexes[permutation[j]]; | ||
values[j] = values[permutation[j]]; | ||
permutation[j] = usize::MAX; | ||
j = next; | ||
} | ||
indexes[j] = index; | ||
values[j] = value; | ||
permutation[j] = usize::MAX; | ||
} | ||
} | ||
let mut last: Option<u32> = None; | ||
for index in indexes.clone() { | ||
if last == Some(index) { | ||
error!( | ||
"Indexes need to be unique, but there are more than one same index {index}" | ||
) | ||
} | ||
if last >= Some(dims as u32) { | ||
error!("Index out of bounds: the dim is {dims} but the index is {index}"); | ||
} | ||
last = Some(index); | ||
{ | ||
let mut i = 0; | ||
let mut j = 0; | ||
while j < values.len() { | ||
if !values[j].is_zero() { | ||
indexes[i] = indexes[j]; | ||
values[i] = values[j]; | ||
i += 1; | ||
} | ||
j += 1; | ||
} | ||
indexes.truncate(i); | ||
values.truncate(i); | ||
} | ||
} | ||
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values)) | ||
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values)) | ||
} | ||
} | ||
} | ||
|
||
#[pgrx::pg_extern(immutable, strict, parallel_safe)] | ||
fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { | ||
let dims = vector.for_borrow().dims(); | ||
let mut buffer = String::new(); | ||
buffer.push('['); | ||
let vec = vector.for_borrow().to_vec(); | ||
let mut iter = vec.iter(); | ||
if let Some(x) = iter.next() { | ||
buffer.push_str(format!("{}", x).as_str()); | ||
} | ||
for x in iter { | ||
buffer.push_str(format!(", {}", x).as_str()); | ||
buffer.push('{'); | ||
let svec = vector.for_borrow(); | ||
let mut need_splitter = false; | ||
for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) { | ||
match need_splitter { | ||
false => { | ||
write!(buffer, "{}:{}", index, value).unwrap(); | ||
need_splitter = true; | ||
} | ||
true => write!(buffer, ", {}:{}", index, value).unwrap(), | ||
} | ||
} | ||
buffer.push(']'); | ||
write!(buffer, "}}/{}", dims).unwrap(); | ||
CString::new(buffer).unwrap() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.