Skip to content

Commit

Permalink
feat: new text representation for sparse vector (#466)
Browse files Browse the repository at this point in the history
* feat: new text embedding for sparse vector

Signed-off-by: cutecutecat <[email protected]>

* fix: use 0-based index

Signed-off-by: usamoi <[email protected]>

* refactor: use sparse struct to parse

Signed-off-by: cutecutecat <[email protected]>

* fix: zero-check, sort and tests

Signed-off-by: cutecutecat <[email protected]>

* fix: new reject case

Signed-off-by: cutecutecat <[email protected]>

* refactor: use state machine

Signed-off-by: cutecutecat <[email protected]>

* fix: fsm with more checks

Signed-off-by: cutecutecat <[email protected]>

* fix: by comments

Signed-off-by: cutecutecat <[email protected]>

* fix: by comments

Signed-off-by: cutecutecat <[email protected]>

* fix: remove funcs

Signed-off-by: cutecutecat <[email protected]>

---------

Signed-off-by: cutecutecat <[email protected]>
Signed-off-by: usamoi <[email protected]>
Co-authored-by: usamoi <[email protected]>
  • Loading branch information
cutecutecat and usamoi authored Jun 18, 2024
1 parent d192c40 commit 1a9e0b6
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 93 deletions.
102 changes: 72 additions & 30 deletions src/datatype/text_svecf32.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,97 @@
use num_traits::Zero;
use pgrx::error;

use super::memory_svecf32::SVecf32Output;
use crate::datatype::memory_svecf32::SVecf32Input;
use crate::datatype::typmod::Typmod;
use crate::error::*;
use base::scalar::*;
use base::vector::*;
use num_traits::Zero;
use pgrx::pg_sys::Oid;
use std::ffi::{CStr, CString};
use std::fmt::Write;

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
use crate::utils::parse::parse_vector;
let reserve = Typmod::parse_from_i32(typmod)
.unwrap()
.dims()
.map(|x| x.get())
.unwrap_or(0);
let v = parse_vector(input.to_bytes(), reserve as usize, |s| {
s.parse::<F32>().ok()
});
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
use crate::utils::parse::parse_pgvector_svector;
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
match v {
Err(e) => {
bad_literal(&e.to_string());
}
Ok(vector) => {
check_value_dims_1048575(vector.len());
let mut indexes = Vec::<u32>::new();
let mut values = Vec::<F32>::new();
for (i, &x) in vector.iter().enumerate() {
if !x.is_zero() {
indexes.push(i as u32);
values.push(x);
Ok((mut indexes, mut values, dims)) => {
check_value_dims_1048575(dims);
// is_sorted
if !indexes.windows(2).all(|i| i[0] <= i[1]) {
assert_eq!(indexes.len(), values.len());
let n = indexes.len();
let mut permutation = (0..n).collect::<Vec<_>>();
permutation.sort_unstable_by_key(|&i| &indexes[i]);
for i in 0..n {
if i == permutation[i] || usize::MAX == permutation[i] {
continue;
}
let index = indexes[i];
let value = values[i];
let mut j = i;
while i != permutation[j] {
let next = permutation[j];
indexes[j] = indexes[permutation[j]];
values[j] = values[permutation[j]];
permutation[j] = usize::MAX;
j = next;
}
indexes[j] = index;
values[j] = value;
permutation[j] = usize::MAX;
}
}
let mut last: Option<u32> = None;
for index in indexes.clone() {
if last == Some(index) {
error!(
"Indexes need to be unique, but there are more than one same index {index}"
)
}
if last >= Some(dims as u32) {
error!("Index out of bounds: the dim is {dims} but the index is {index}");
}
last = Some(index);
{
let mut i = 0;
let mut j = 0;
while j < values.len() {
if !values[j].is_zero() {
indexes[i] = indexes[j];
values[i] = values[j];
i += 1;
}
j += 1;
}
indexes.truncate(i);
values.truncate(i);
}
}
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values))
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values))
}
}
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString {
let dims = vector.for_borrow().dims();
let mut buffer = String::new();
buffer.push('[');
let vec = vector.for_borrow().to_vec();
let mut iter = vec.iter();
if let Some(x) = iter.next() {
buffer.push_str(format!("{}", x).as_str());
}
for x in iter {
buffer.push_str(format!(", {}", x).as_str());
buffer.push('{');
let svec = vector.for_borrow();
let mut need_splitter = false;
for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) {
match need_splitter {
false => {
write!(buffer, "{}:{}", index, value).unwrap();
need_splitter = true;
}
true => write!(buffer, ", {}:{}", index, value).unwrap(),
}
}
buffer.push(']');
write!(buffer, "}}/{}", dims).unwrap();
CString::new(buffer).unwrap()
}
195 changes: 194 additions & 1 deletion src/utils/parse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use num_traits::Zero;
use thiserror::Error;

#[derive(Debug, Error)]
#[derive(Debug, Error, PartialEq)]
pub enum ParseVectorError {
#[error("The input string is empty.")]
EmptyString {},
Expand Down Expand Up @@ -83,3 +84,195 @@ where
}
Ok(vector)
}

#[derive(PartialEq, Debug, Clone)]
enum ParseState {
Start,
LeftBracket,
Index,
Colon,
Value,
Comma,
RightBracket,
Splitter,
Dims,
}

#[inline(always)]
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
input: &[u8],
f: F,
) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError>
where
F: Fn(&str) -> Option<T>,
{
use arrayvec::ArrayVec;
if input.is_empty() {
return Err(ParseVectorError::EmptyString {});
}
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
let mut indexes = Vec::<u32>::new();
let mut values = Vec::<T>::new();

let mut state = ParseState::Start;
for (position, c) in input.iter().copied().enumerate() {
state = match (&state, c) {
(_, b' ') => state,
(ParseState::Start, b'{') => ParseState::LeftBracket,
(
ParseState::LeftBracket | ParseState::Index | ParseState::Comma,
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-',
) => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Index
}
(ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Value
}
(ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket,
(ParseState::Index, b':') => {
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let index = s
.parse::<u32>()
.map_err(|_| ParseVectorError::BadParsing { position })?;
indexes.push(index);
token.clear();
ParseState::Colon
}
(ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Value
}
(ParseState::Value, b',') => {
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
values.push(num);
token.clear();
ParseState::Comma
}
(ParseState::Value, b'}') => {
if token.is_empty() {
return Err(ParseVectorError::TooShortNumber { position });
}
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
values.push(num);
token.clear();
ParseState::RightBracket
}
(ParseState::RightBracket, b'/') => ParseState::Splitter,
(ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Dims
}
(_, _) => {
return Err(ParseVectorError::BadCharacter { position });
}
}
}
if state != ParseState::Dims {
return Err(ParseVectorError::BadParsing {
position: input.len(),
});
}
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let dims = s
.parse::<usize>()
.map_err(|_| ParseVectorError::BadParsing {
position: input.len(),
})?;
Ok((indexes, values, dims))
}

#[cfg(test)]
mod tests {
use base::scalar::F32;

use super::*;

#[test]
fn test_svector_parse_accept() {
let exprs: Vec<(&str, (Vec<u32>, Vec<F32>, usize))> = vec![
("{}/1", (vec![], vec![], 1)),
("{0:1}/1", (vec![0], vec![F32(1.0)], 1)),
(
"{0:1, 1:-2, }/2",
(vec![0, 1], vec![F32(1.0), F32(-2.0)], 2),
),
("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)),
(
"{0:+3, 2:-4.1}/3",
(vec![0, 2], vec![F32(3.0), F32(-4.1)], 3),
),
(
"{0:0, 1:0, 2:0}/3",
(vec![0, 1, 2], vec![F32(0.0), F32(0.0), F32(0.0)], 3),
),
(
"{3:3, 2:2, 1:1, 0:0}/4",
(
vec![3, 2, 1, 0],
vec![F32(3.0), F32(2.0), F32(1.0), F32(0.0)],
4,
),
),
];
for (e, parsed) in exprs {
let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret);
assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e);
}
}

#[test]
fn test_svector_parse_reject() {
let exprs: Vec<(&str, ParseVectorError)> = vec![
("{", ParseVectorError::BadParsing { position: 1 }),
("}", ParseVectorError::BadCharacter { position: 0 }),
("{:", ParseVectorError::BadCharacter { position: 1 }),
(":}", ParseVectorError::BadCharacter { position: 0 }),
(
"{0:1, 1:2, 2:3}",
ParseVectorError::BadParsing { position: 15 },
),
(
"{0:1, 1:2, 2:3",
ParseVectorError::BadParsing { position: 14 },
),
("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 11 }),
("{0}/5", ParseVectorError::BadCharacter { position: 2 }),
("{0:}/5", ParseVectorError::BadCharacter { position: 3 }),
("{:0}/5", ParseVectorError::BadCharacter { position: 1 }),
(
"{0:, 1:2}/5",
ParseVectorError::BadCharacter { position: 3 },
),
("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }),
("/2", ParseVectorError::BadCharacter { position: 0 }),
("{}/1/2", ParseVectorError::BadCharacter { position: 4 }),
(
"{0:1, 1:2}/4/2",
ParseVectorError::BadCharacter { position: 12 },
),
("{}/-4", ParseVectorError::BadCharacter { position: 3 }),
(
"{1,2,3,4}/5",
ParseVectorError::BadCharacter { position: 2 },
),
];
for (e, err) in exprs {
let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
assert!(ret.is_err(), "at expr {:?}: {:?}", e, ret);
assert_eq!(ret.unwrap_err(), err, "parsed at expr {:?}", e);
}
}
}
13 changes: 5 additions & 8 deletions tests/sqllogictest/sparse.slt
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops)
WITH (options = "[indexing.hnsw]");

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

Expand All @@ -40,21 +40,18 @@ DROP TABLE t;
query I
SELECT to_svector(5, '{1,2}', '{1,2}');
----
[0, 1, 2, 0, 0]
{1:1, 2:2}/5

query I
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
----
[0, 2, 0, 0, 0]
{1:2}/5

statement error Lengths of index and value are not matched.
SELECT to_svector(5, '{1,2,3}', '{1,2}');

statement error Duplicated index.
SELECT to_svector(5, '{1,1}', '{1,2}');

statement ok
SELECT replace(replace(array_agg(RANDOM())::real[]::text, '{', '['), '}', ']')::svector FROM generate_series(1, 100000);

statement ok
SELECT to_svector(200000, array_agg(val)::integer[], array_agg(val)::real[]) FROM generate_series(1, 100000) AS VAL;
Loading

0 comments on commit 1a9e0b6

Please sign in to comment.