Skip to content

Commit

Permalink
feat: lancedb metadata type to proper type
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinkys committed Jun 4, 2024
1 parent 7ec001d commit 9d928b1
Showing 1 changed file with 86 additions and 11 deletions.
97 changes: 86 additions & 11 deletions src/vectordbs/lancedb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use std::{
use anyhow::{anyhow, Result};
use arrow_array::{
cast::as_string_array,
types::Float32Type,
types::{self, Float32Type},
Array,
BooleanArray,
FixedSizeListArray,
PrimitiveArray,
RecordBatch,
Expand Down Expand Up @@ -91,7 +92,8 @@ async fn vector_chunk_from_batch(
let mut i = 0;
for row in as_string_array(batch.column_by_name(field_name).unwrap()) {
let row = row.ok_or(anyhow!("metadata is null"))?;
let value: serde_json::Value = serde_json::from_str(&row)?;
let value: serde_json::Value =
serde_json::from_str(&row).unwrap_or(serde_json::json!(&row));
metadatas[i].insert(field_name.to_string(), value);
i += 1;
}
Expand Down Expand Up @@ -149,11 +151,32 @@ async fn update_schema_with_missing_fields(
let mut new_fields = Vec::new();
let mut schema = tbl.schema().await?;
// Find the fields that have to be added
for (key, _) in metadata.iter() {
for (key, value) in metadata.iter() {
if schema.field_with_name(key).is_err() {
new_fields.push(Arc::new(Field::new(key, DataType::Utf8, true)));
match value {
serde_json::Value::Number(n) => {
if n.is_f64() {
new_fields.push(Field::new(key, DataType::Float64, true));
} else {
new_fields.push(Field::new(key, DataType::Int64, true));
}
}
serde_json::Value::String(_) => {
new_fields.push(Field::new(key, DataType::Utf8, true));
}
serde_json::Value::Bool(_) => {
new_fields.push(Field::new(key, DataType::Boolean, true));
}
_ => {}
}
}
}

let new_fields = new_fields
.into_iter()
.map(|f| Arc::new(f))
.collect::<Vec<_>>();

if !new_fields.is_empty() {
let mut all_fields = schema.fields().to_vec();
all_fields.extend(new_fields.clone());
Expand Down Expand Up @@ -279,13 +302,18 @@ impl VectorDb for LanceDb {
{
continue;
}
let values = chunks.iter().map(|c| {
c.metadata
.get(field.name())
.map(|v| serde_json::to_string(&v).unwrap())
});
let array = values.collect::<StringArray>();
arrays.push(Arc::new(array));

let mut some_values = vec![]; // Used to check the type.
let mut all_values = vec![];
for chunk in &chunks {
let value = chunk.metadata.get(field.name());
all_values.push(value);
if let Some(value) = value {
some_values.push(value);
}
}

arrays.push(from_serde_json_to_arrow_array(some_values[0], all_values)?);
}

let batches = RecordBatchIterator::new(
Expand Down Expand Up @@ -459,6 +487,53 @@ impl VectorDb for LanceDb {
}
}

/// Builds one nullable Arrow column from per-row optional JSON metadata values.
///
/// `match_value` decides the Arrow type of the whole column:
/// - a JSON number for which `is_f64()` is true -> `Float64`
/// - any other JSON number -> `Int64`
/// - a JSON string -> `Utf8`
/// - a JSON boolean -> `Boolean`
///
/// Every entry of `values` produces exactly one row. `None` entries, and entries
/// whose JSON kind differs from the one selected by `match_value` (e.g. a string
/// in a numeric column), become nulls.
///
/// # Errors
///
/// Returns an error when:
/// - `match_value` is a JSON null, array or object;
/// - a numeric entry cannot be represented in the selected numeric type, for
///   example a fractional number or a `u64` above `i64::MAX` in an `Int64`
///   column. This case used to panic through `unwrap()`.
fn from_serde_json_to_arrow_array(
    match_value: &serde_json::Value,
    values: Vec<Option<&serde_json::Value>>,
) -> Result<Arc<dyn Array>> {
    let iterator = values.iter();
    match match_value {
        serde_json::Value::Number(n) => {
            if n.is_f64() {
                let arr = iterator
                    .map(|v| match v {
                        // Propagate an error instead of panicking when the
                        // number has no f64 representation.
                        Some(serde_json::Value::Number(n)) => {
                            n.as_f64().map(Some).ok_or_else(|| {
                                anyhow!("metadata number {} cannot be stored as Float64", n)
                            })
                        }
                        _ => Ok(None),
                    })
                    .collect::<Result<PrimitiveArray<types::Float64Type>>>()?;
                Ok(Arc::new(arr))
            } else {
                let arr = iterator
                    .map(|v| match v {
                        // The column type is picked from a single sample value,
                        // so later rows may hold floats or out-of-range u64s;
                        // `as_i64` returns `None` for those.
                        Some(serde_json::Value::Number(n)) => {
                            n.as_i64().map(Some).ok_or_else(|| {
                                anyhow!("metadata number {} cannot be stored as Int64", n)
                            })
                        }
                        _ => Ok(None),
                    })
                    .collect::<Result<PrimitiveArray<types::Int64Type>>>()?;
                Ok(Arc::new(arr))
            }
        }
        serde_json::Value::String(_) => {
            let arr = iterator
                .map(|v| match v {
                    Some(serde_json::Value::String(s)) => Some(s.to_string()),
                    _ => None,
                })
                .collect::<StringArray>();
            Ok(Arc::new(arr))
        }
        serde_json::Value::Bool(_) => {
            let arr = iterator
                .map(|v| match v {
                    Some(serde_json::Value::Bool(b)) => Some(*b),
                    _ => None,
                })
                .collect::<BooleanArray>();
            Ok(Arc::new(arr))
        }
        _ => Err(anyhow!("unsupported metadata type for field")),
    }
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down

0 comments on commit 9d928b1

Please sign in to comment.