Skip to content

Commit

Permalink
GETKEY and DELKEY command inclusions
Browse files Browse the repository at this point in the history
  • Loading branch information
deven96 committed Sep 25, 2024
1 parent 2bd59cd commit 9335829
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 12 deletions.
1 change: 1 addition & 0 deletions ahnlich/dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ahnlich_types = { path = "../types", version = "*" }
pest = "2.7.13"
pest_derive = "2.7.13"
thiserror.workspace = true
ndarray.workspace = true

85 changes: 82 additions & 3 deletions ahnlich/dsl/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use ahnlich_types::{
db::DBQuery, keyval::StoreName, metadata::MetadataKey, similarity::NonLinearAlgorithm,
db::DBQuery,
keyval::{StoreKey, StoreName},
metadata::MetadataKey,
similarity::NonLinearAlgorithm,
};
use ndarray::Array1;
use pest::Parser;
use pest_derive::Parser;

Expand All @@ -17,6 +21,19 @@ fn to_non_linear(input: &str) -> Option<NonLinearAlgorithm> {
}
}

fn parse_multi_f32_array(f32_arrays_pair: pest::iterators::Pair<Rule>) -> Vec<StoreKey> {
f32_arrays_pair.into_inner().map(parse_f32_array).collect()
}

fn parse_f32_array(pair: pest::iterators::Pair<Rule>) -> StoreKey {
StoreKey(Array1::from_iter(pair.into_inner().map(|f32_pair| {
f32_pair
.as_str()
.parse::<f32>()
.expect("Cannot parse single f32 num")
})))
}

// Parse raw strings separated by ; into a Vec<DBQuery>. Examples include but are not restricted
// to
//
Expand All @@ -29,12 +46,12 @@ fn to_non_linear(input: &str) -> Option<NonLinearAlgorithm> {
// DROPPREDINDEX IF EXISTS (key1, key2) in store_name
// CREATENONLINEARALGORITHMINDEX (kdtree) in store_name
// DROPNONLINEARALGORITHMINDEX IF EXISTS (kdtree) in store_name
// GETKEY ((1.0, 2.0), (3.0, 4.0)) IN my_store
// DELKEY ((1.2, 3.0), (5.6, 7.8)) IN my_store
//
// #TODO
// SET
// DELKEY
// CREATESTORE
// GETKEY
// GETPRED
// GETSIMN
pub fn parse_db_query(input: &str) -> Result<Vec<DBQuery>, DslError> {
Expand All @@ -49,6 +66,28 @@ pub fn parse_db_query(input: &str) -> Result<Vec<DBQuery>, DslError> {
Rule::list_clients => DBQuery::ListClients,
Rule::list_stores => DBQuery::ListStores,
Rule::info_server => DBQuery::InfoServer,
Rule::get_key => {
let mut inner_pairs = statement.into_inner().peekable();
let f32_arrays_pair = inner_pairs.next().unwrap();
let keys = parse_multi_f32_array(f32_arrays_pair);

let store = inner_pairs.next().unwrap().as_str();
DBQuery::GetKey {
store: StoreName(store.to_string()),
keys,
}
}
Rule::del_key => {
let mut inner_pairs = statement.into_inner().peekable();
let f32_arrays_pair = inner_pairs.next().unwrap();
let keys = parse_multi_f32_array(f32_arrays_pair);

let store = inner_pairs.next().unwrap().as_str();
DBQuery::DelKey {
store: StoreName(store.to_string()),
keys,
}
}
Rule::create_non_linear_algorithm_index => {
let mut inner_pairs = statement.into_inner();
let index_name_pairs = inner_pairs
Expand Down Expand Up @@ -324,4 +363,44 @@ mod tests {
}]
);
}

#[test]
fn test_get_key_parse() {
let input = r#"getkey ((a, b, c), (3.0, 4.0)) in 1234"#;
let DslError::UnexpectedSpan((start, end)) = parse_db_query(input).unwrap_err() else {
panic!("Unexpected error pattern found")
};
assert_eq!((start, end), (0, 38));
let input = r#"getkey ((1, 2, 3), (3.0, 4.0)) in 1234"#;
assert_eq!(
parse_db_query(input).expect("Could not parse query input"),
vec![DBQuery::GetKey {
store: StoreName("1234".to_string()),
keys: vec![
StoreKey(Array1::from_iter([1.0, 2.0, 3.0])),
StoreKey(Array1::from_iter([3.0, 4.0])),
],
}]
);
}

#[test]
fn test_del_key_parse() {
let input = r#"DELKEY ((a, b, c), (3.0, 4.0)) in 1234"#;
let DslError::UnexpectedSpan((start, end)) = parse_db_query(input).unwrap_err() else {
panic!("Unexpected error pattern found")
};
assert_eq!((start, end), (0, 38));
let input = r#"DELKEY ((1, 2, 3), (3.0, 4.0)) in 1234"#;
assert_eq!(
parse_db_query(input).expect("Could not parse query input"),
vec![DBQuery::DelKey {
store: StoreName("1234".to_string()),
keys: vec![
StoreKey(Array1::from_iter([1.0, 2.0, 3.0])),
StoreKey(Array1::from_iter([3.0, 4.0])),
],
}]
);
}
}
31 changes: 22 additions & 9 deletions ahnlich/dsl/src/syntax/db.pest
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ whitespace = _{ " " | "\t" }
query = _{ statement ~ (";" ~ statement) * } // Matches multiple statements separated by ;

statement = _{
ping |
info_server |
list_stores |
list_clients |
drop_store |
create_pred_index |
drop_pred_index |
create_non_linear_algorithm_index |
drop_non_linear_algorithm_index |
ping |
info_server |
list_stores |
list_clients |
drop_store |
create_pred_index |
drop_pred_index |
create_non_linear_algorithm_index |
drop_non_linear_algorithm_index |
get_key |
del_key |
invalid_statement
}

Expand All @@ -24,6 +26,8 @@ create_pred_index = { whitespace* ~ ^"createpredindex" ~ whitespace* ~ "(" ~ ind
create_non_linear_algorithm_index = { whitespace* ~ ^"createnonlinearalgorithmindex" ~ whitespace* ~ "(" ~ non_linear_algorithms ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name}
drop_pred_index = { whitespace* ~ ^"droppredindex" ~ whitespace* ~ (if_exists)? ~ "(" ~ index_names ~ ")" ~ whitespace* ~ ^"in" ~whitespace* ~ store_name }
drop_non_linear_algorithm_index = { whitespace* ~ ^"dropnonlinearalgorithmindex" ~ whitespace* ~ (if_exists)? ~ "(" ~ non_linear_algorithms ~ ")" ~ whitespace* ~ ^"in" ~whitespace* ~ store_name }
get_key = { whitespace* ~ ^"getkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name }
del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name }

if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* }

Expand All @@ -34,6 +38,15 @@ non_linear_algorithm = { ^"kdtree" }
non_linear_algorithms = { non_linear_algorithm ~ (whitespace* ~ "," ~ whitespace* ~ non_linear_algorithm)* }
index_names = { index_name ~ (whitespace* ~ "," ~ whitespace* ~ index_name)* }

// Floating point number
f32 = { ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? }

// Array of floating-point numbers
f32_array = { "(" ~ f32 ~ (whitespace* ~ "," ~ whitespace* ~ f32)* ~ ")"}

// List of f32 arrays (comma-separated)
f32_arrays = { f32_array ~ (whitespace* ~ "," ~ whitespace* ~ f32_array)* }

// Catch-all rule for invalid statements
invalid_statement = { whitespace* ~ (!";" ~ ANY)+ } // Match anything that isn't a valid statement

0 comments on commit 9335829

Please sign in to comment.