Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds hmget command #487

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: adds hmget command
Adds `hmget` command to resp protocol implementation.
hderms authored and twitter-dermot committed Nov 16, 2022
commit 397ee49bbff395720c57b8600c9051f69c8ef264
137 changes: 137 additions & 0 deletions src/protocol/resp/src/request/hmget.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use super::*;
use std::io::{Error, ErrorKind};
use std::sync::Arc;

type ArcByteSlice = Arc<Box<[u8]>>;
#[derive(Debug, PartialEq, Eq)]
pub struct HmGetRequest {
key: ArcByteSlice,
fields: Arc<Box<[ArcByteSlice]>>,
}

impl HmGetRequest {
pub fn key(&self) -> &[u8] {
&self.key
}

pub fn fields(&self) -> Box<[&[u8]]> {
self.fields
.iter()
.map(|f| &***f)
.collect::<Vec<&[u8]>>()
.into_boxed_slice()
}
}

impl TryFrom<Message> for HmGetRequest {
type Error = Error;

fn try_from(other: Message) -> Result<Self, Error> {
if let Message::Array(array) = other {
if array.inner.is_none() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let mut array = array.inner.unwrap();

if array.len() <= 2 {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let key = take_bulk_string(&mut array)?;
if key.is_empty() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let mut fields = Vec::with_capacity(array.len());
while array.len() >= 2 {
let field = take_bulk_string(&mut array)?;
if field.is_empty() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

fields.push(field);
}

let f = Arc::new(Box::<[ArcByteSlice]>::from(fields));
Ok(Self { key, fields: f })
} else {
Err(Error::new(ErrorKind::Other, "malformed command"))
}
}
}

impl From<&HmGetRequest> for Message {
fn from(other: &HmGetRequest) -> Message {
let mut v = vec![
Message::bulk_string(b"HMGET"),
Message::BulkString(BulkString::from(other.key.clone())),
];
for kv in (*other.fields).iter() {
v.push(Message::BulkString(BulkString::from(kv.clone())));
}

Message::Array(Array { inner: Some(v) })
}
}

impl Compose for HmGetRequest {
fn compose(&self, buf: &mut dyn BufMut) -> usize {
let message = Message::from(self);
message.compose(buf)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn parser() {
let parser = RequestParser::new();

//1 field
if let Request::HmGet(request) = parser.parse(b"hmget key field1\r\n").unwrap().into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 1);
assert_eq!(request.fields()[0], b"field1");
} else {
panic!("invalid parse result");
}

//2 fields
if let Request::HmGet(request) = parser
.parse(b"hmget key field1 field2\r\n")
.unwrap()
.into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 2);
assert_eq!(request.fields()[0], b"field1");
assert_eq!(request.fields()[1], b"field2");
} else {
panic!("invalid parse result");
}

//3 fields
if let Request::HmGet(request) = parser
.parse(b"hmget key field1 field2 42\r\n")
.unwrap()
.into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 3);
assert_eq!(request.fields()[0], b"field1");
assert_eq!(request.fields()[1], b"field2");
assert_eq!(request.fields()[2], b"42");
} else {
panic!("invalid parse result");
}

//insufficient whitespace delimited strings
parser
.parse(b"hmget key\r\n")
.expect_err("malformed command");
}
}
15 changes: 15 additions & 0 deletions src/protocol/resp/src/request/mod.rs
Original file line number Diff line number Diff line change
@@ -12,10 +12,12 @@ use std::sync::Arc;

mod badd;
mod get;
mod hmget;
mod set;

pub use badd::BAddRequest;
pub use get::GetRequest;
pub use hmget::HmGetRequest;
pub use set::SetRequest;

#[derive(Default)]
@@ -95,6 +97,9 @@ impl Parse<Request> for RequestParser {
Some(b"get") | Some(b"GET") => {
GetRequest::try_from(message).map(Request::from)
}
Some(b"hmget") | Some(b"HMGET") => {
HmGetRequest::try_from(message).map(Request::from)
}
Some(b"set") | Some(b"SET") => {
SetRequest::try_from(message).map(Request::from)
}
@@ -120,6 +125,7 @@ impl Compose for Request {
match self {
Self::BAdd(r) => r.compose(buf),
Self::Get(r) => r.compose(buf),
Self::HmGet(r) => r.compose(buf),
Self::Set(r) => r.compose(buf),
}
}
@@ -129,6 +135,7 @@ impl Compose for Request {
pub enum Request {
BAdd(BAddRequest),
Get(GetRequest),
HmGet(HmGetRequest),
Set(SetRequest),
}

@@ -144,6 +151,12 @@ impl From<GetRequest> for Request {
}
}

impl From<HmGetRequest> for Request {
fn from(other: HmGetRequest) -> Self {
Self::HmGet(other)
}
}

impl From<SetRequest> for Request {
fn from(other: SetRequest) -> Self {
Self::Set(other)
@@ -154,6 +167,7 @@ impl From<SetRequest> for Request {
pub enum Command {
BAdd,
Get,
HmGet,
Set,
}

@@ -164,6 +178,7 @@ impl TryFrom<&[u8]> for Command {
match other {
b"badd" | b"BADD" => Ok(Command::BAdd),
b"get" | b"GET" => Ok(Command::Get),
b"hmget" | b"HMGET" => Ok(Command::HmGet),
b"set" | b"SET" => Ok(Command::Set),
_ => Err(()),
}