From 2f85fa3619fbaa23f60ea5947c59cc6bf384e101 Mon Sep 17 00:00:00 2001 From: hoslo Date: Fri, 23 Feb 2024 16:39:39 +0800 Subject: [PATCH] feat(services/memcached): change to binary protocol (#4252) * feat(services/memcached): change to binary protocol * feat(services/memcached): change to binary protocol --- core/src/services/memcached/ascii.rs | 172 --------------- core/src/services/memcached/backend.rs | 43 +++- core/src/services/memcached/binary.rs | 286 +++++++++++++++++++++++++ core/src/services/memcached/docs.md | 7 +- core/src/services/memcached/mod.rs | 2 +- 5 files changed, 331 insertions(+), 179 deletions(-) delete mode 100644 core/src/services/memcached/ascii.rs create mode 100644 core/src/services/memcached/binary.rs diff --git a/core/src/services/memcached/ascii.rs b/core/src/services/memcached/ascii.rs deleted file mode 100644 index 6a790889396c..000000000000 --- a/core/src/services/memcached/ascii.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::net::TcpStream; - -use crate::raw::*; -use crate::*; - -pub struct Connection { - io: BufReader, - buf: Vec, -} - -impl Connection { - pub fn new(io: TcpStream) -> Self { - Self { - io: BufReader::new(io), - buf: Vec::new(), - } - } - - pub async fn get(&mut self, key: &str) -> Result>> { - // Send command - let writer = self.io.get_mut(); - writer - .write_all(&[b"get ", key.as_bytes(), b"\r\n"].concat()) - .await - .map_err(new_std_io_error)?; - writer.flush().await.map_err(new_std_io_error)?; - - // Read response header - let header = self.read_header().await?; - - // Check response header and parse value length - if header.contains("ERROR") { - return Err( - Error::new(ErrorKind::Unexpected, "unexpected data received") - .with_context("message", header), - ); - } else if header.starts_with("END") { - return Ok(None); - } - - // VALUE []\r\n - let length: usize = header - .split(' ') - .nth(3) - .and_then(|len| len.trim_end().parse().ok()) - .ok_or_else(|| Error::new(ErrorKind::Unexpected, "invalid data received"))?; - - // Read value - let mut buffer: Vec = vec![0; length]; - self.io - .read_exact(&mut buffer) - .await - .map_err(new_std_io_error)?; - - // Read the trailing header - self.read_line().await?; // \r\n - self.read_line().await?; // END\r\n - - Ok(Some(buffer)) - } - - pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { - let header = format!("set {} 0 {} {}\r\n", key, expiration, val.len()); - self.io - .write_all(header.as_bytes()) - .await - .map_err(new_std_io_error)?; - self.io.write_all(val).await.map_err(new_std_io_error)?; - self.io.write_all(b"\r\n").await.map_err(new_std_io_error)?; - self.io.flush().await.map_err(new_std_io_error)?; - - // Read response header - let header = self.read_header().await?; - - // Check response header and make sure we got a `STORED` - if 
header.contains("STORED") { - return Ok(()); - } else if header.contains("ERROR") { - return Err( - Error::new(ErrorKind::Unexpected, "unexpected data received") - .with_context("message", header), - ); - } - Ok(()) - } - - pub async fn delete(&mut self, key: &str) -> Result<()> { - let header = format!("delete {}\r\n", key); - self.io - .write_all(header.as_bytes()) - .await - .map_err(new_std_io_error)?; - self.io.flush().await.map_err(new_std_io_error)?; - - // Read response header - let header = self.read_header().await?; - - // Check response header and parse value length - if header.contains("NOT_FOUND") || header.starts_with("END") { - return Ok(()); - } else if header.contains("ERROR") || !header.contains("DELETED") { - return Err( - Error::new(ErrorKind::Unexpected, "unexpected data received") - .with_context("message", header), - ); - } - Ok(()) - } - - pub async fn version(&mut self) -> Result { - self.io - .write_all(b"version\r\n") - .await - .map_err(new_std_io_error)?; - self.io.flush().await.map_err(new_std_io_error)?; - - // Read response header - let header = self.read_header().await?; - - if !header.starts_with("VERSION") { - return Err( - Error::new(ErrorKind::Unexpected, "unexpected data received") - .with_context("message", header), - ); - } - let version = header.trim_start_matches("VERSION ").trim_end(); - Ok(version.to_string()) - } - - async fn read_line(&mut self) -> Result<&[u8]> { - let Self { io, buf } = self; - buf.clear(); - io.read_until(b'\n', buf).await.map_err(new_std_io_error)?; - if buf.last().copied() != Some(b'\n') { - return Err(Error::new( - ErrorKind::ContentIncomplete, - "unexpected eof, the response must be incomplete", - )); - } - Ok(&buf[..]) - } - - async fn read_header(&mut self) -> Result<&str> { - let header = self.read_line().await?; - let header = std::str::from_utf8(header).map_err(|err| { - Error::new(ErrorKind::Unexpected, "invalid data received").set_source(err) - })?; - - Ok(header) - } -} diff --git 
a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs index f0d48e4a9173..d2a5be030d47 100644 --- a/core/src/services/memcached/backend.rs +++ b/core/src/services/memcached/backend.rs @@ -24,7 +24,7 @@ use serde::Deserialize; use tokio::net::TcpStream; use tokio::sync::OnceCell; -use super::ascii; +use super::binary; use crate::raw::adapters::kv; use crate::raw::*; use crate::*; @@ -42,6 +42,10 @@ pub struct MemcachedConfig { /// /// default is "/" root: Option, + /// Memcached username, optional. + username: Option, + /// Memcached password, optional. + password: Option, /// The default ttl for put operations. default_ttl: Option, } @@ -74,6 +78,18 @@ impl MemcachedBuilder { self } + /// set the username. + pub fn username(&mut self, username: &str) -> &mut Self { + self.config.username = Some(username.to_string()); + self + } + + /// set the password. + pub fn password(&mut self, password: &str) -> &mut Self { + self.config.password = Some(password.to_string()); + self + } + /// Set the default ttl for memcached services. 
pub fn default_ttl(&mut self, ttl: Duration) -> &mut Self { self.config.default_ttl = Some(ttl); @@ -151,6 +167,8 @@ impl Builder for MemcachedBuilder { let conn = OnceCell::new(); Ok(MemcachedBackend::new(Adapter { endpoint, + username: self.config.username.clone(), + password: self.config.password.clone(), conn, default_ttl: self.config.default_ttl, }) @@ -164,6 +182,8 @@ pub type MemcachedBackend = kv::Backend; #[derive(Clone, Debug)] pub struct Adapter { endpoint: String, + username: Option, + password: Option, default_ttl: Option, conn: OnceCell>, } @@ -173,7 +193,11 @@ impl Adapter { let pool = self .conn .get_or_try_init(|| async { - let mgr = MemcacheConnectionManager::new(&self.endpoint); + let mgr = MemcacheConnectionManager::new( + &self.endpoint, + self.username.clone(), + self.password.clone(), + ); bb8::Pool::builder().build(mgr).await.map_err(|err| { Error::new(ErrorKind::ConfigInvalid, "connect to memecached failed") @@ -237,19 +261,23 @@ impl kv::Adapter for Adapter { #[derive(Clone, Debug)] struct MemcacheConnectionManager { address: String, + username: Option, + password: Option, } impl MemcacheConnectionManager { - fn new(address: &str) -> Self { + fn new(address: &str, username: Option, password: Option) -> Self { Self { address: address.to_string(), + username, + password, } } } #[async_trait] impl bb8::ManageConnection for MemcacheConnectionManager { - type Connection = ascii::Connection; + type Connection = binary::Connection; type Error = Error; /// TODO: Implement unix stream support. 
@@ -257,7 +285,12 @@ impl bb8::ManageConnection for MemcacheConnectionManager { let conn = TcpStream::connect(&self.address) .await .map_err(new_std_io_error)?; - Ok(ascii::Connection::new(conn)) + let mut conn = binary::Connection::new(conn); + + if let (Some(username), Some(password)) = (self.username.as_ref(), self.password.as_ref()) { + conn.auth(username, password).await?; + } + Ok(conn) } async fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> { diff --git a/core/src/services/memcached/binary.rs b/core/src/services/memcached/binary.rs new file mode 100644 index 000000000000..5fbce6a5ca5a --- /dev/null +++ b/core/src/services/memcached/binary.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use tokio::io::{self, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; + +use crate::raw::*; +use crate::*; + +pub(super) mod constants { + pub const OK_STATUS: u16 = 0x0; + pub const KEY_NOT_FOUND: u16 = 0x1; +} + +pub enum Opcode { + Get = 0x00, + Set = 0x01, + Delete = 0x04, + Version = 0x0b, + StartAuth = 0x21, +} + +pub enum Magic { + Request = 0x80, +} + +#[derive(Debug)] +pub struct StoreExtras { + pub flags: u32, + pub expiration: u32, +} + +#[derive(Debug, Default)] +pub struct PacketHeader { + pub magic: u8, + pub opcode: u8, + pub key_length: u16, + pub extras_length: u8, + pub data_type: u8, + pub vbucket_id_or_status: u16, + pub total_body_length: u32, + pub opaque: u32, + pub cas: u64, +} + +impl PacketHeader { + pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> { + writer.write_u8(self.magic).await?; + writer.write_u8(self.opcode).await?; + writer.write_u16(self.key_length).await?; + writer.write_u8(self.extras_length).await?; + writer.write_u8(self.data_type).await?; + writer.write_u16(self.vbucket_id_or_status).await?; + writer.write_u32(self.total_body_length).await?; + writer.write_u32(self.opaque).await?; + writer.write_u64(self.cas).await?; + Ok(()) + } + + pub async fn read(reader: &mut TcpStream) -> std::result::Result { + let header = PacketHeader { + magic: reader.read_u8().await?, + opcode: reader.read_u8().await?, + key_length: reader.read_u16().await?, + extras_length: reader.read_u8().await?, + data_type: reader.read_u8().await?, + vbucket_id_or_status: reader.read_u16().await?, + total_body_length: reader.read_u32().await?, + opaque: reader.read_u32().await?, + cas: reader.read_u64().await?, + }; + Ok(header) + } +} + +pub struct Response { + header: PacketHeader, + _key: Vec, + _extras: Vec, + value: Vec, +} + +pub struct Connection { + io: BufReader, +} + +impl Connection { + pub fn new(io: TcpStream) -> Self { + Self { + io: BufReader::new(io), + } + } + + pub async fn auth(&mut self, 
username: &str, password: &str) -> Result<()> { + let writer = self.io.get_mut(); + let key = "PLAIN"; + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::StartAuth as u8, + key_length: key.len() as u16, + total_body_length: (key.len() + username.len() + password.len() + 2) as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer + .write_all(format!("\x00{}\x00{}", username, password).as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + parse_response(writer).await?; + Ok(()) + } + + pub async fn version(&mut self) -> Result { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Version as u8, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + let response = parse_response(writer).await?; + let version = String::from_utf8(response.value); + match version { + Ok(version) => Ok(version), + Err(e) => { + Err(Error::new(ErrorKind::Unexpected, "unexpected data received").set_source(e)) + } + } + } + + pub async fn get(&mut self, key: &str) -> Result>> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Get as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + match parse_response(writer).await { + Ok(response) => { + if response.header.vbucket_id_or_status == 0x1 { + return Ok(None); + } + Ok(Some(response.value)) + } + Err(e) => Err(e), + } + } + + pub 
async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Set as u8, + key_length: key.len() as u16, + extras_length: 8, + total_body_length: (8 + key.len() + val.len()) as u32, + ..Default::default() + }; + let extras = StoreExtras { + flags: 0, + expiration, + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_u32(extras.flags) + .await + .map_err(new_std_io_error)?; + writer + .write_u32(extras.expiration) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.write_all(val).await.map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + + parse_response(writer).await?; + Ok(()) + } + + pub async fn delete(&mut self, key: &str) -> Result<()> { + let writer = self.io.get_mut(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Delete as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header + .write(writer) + .await + .map_err(new_std_io_error)?; + writer + .write_all(key.as_bytes()) + .await + .map_err(new_std_io_error)?; + writer.flush().await.map_err(new_std_io_error)?; + parse_response(writer).await?; + Ok(()) + } +} + +pub async fn parse_response(reader: &mut TcpStream) -> Result { + let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?; + + if header.vbucket_id_or_status != constants::OK_STATUS + && header.vbucket_id_or_status != constants::KEY_NOT_FOUND + { + return Err( + Error::new(ErrorKind::Unexpected, "unexpected status received") + .with_context("message", format!("{}", header.vbucket_id_or_status)), + ); + } + + let mut extras = vec![0x0; header.extras_length as usize]; + reader + .read_exact(extras.as_mut_slice()) + .await + .map_err(new_std_io_error)?; 
+ + let mut key = vec![0x0; header.key_length as usize]; + reader + .read_exact(key.as_mut_slice()) + .await + .map_err(new_std_io_error)?; + + let mut value = vec![ + 0x0; + (header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length)) + as usize + ]; + reader + .read_exact(value.as_mut_slice()) + .await + .map_err(new_std_io_error)?; + + Ok(Response { + header, + _key: key, + _extras: extras, + value, + }) +} diff --git a/core/src/services/memcached/docs.md b/core/src/services/memcached/docs.md index 0de179ab9894..844c7ee49b9e 100644 --- a/core/src/services/memcached/docs.md +++ b/core/src/services/memcached/docs.md @@ -5,7 +5,7 @@ This service can be used to: - [x] stat - [x] read - [x] write -- [x] create_dir +- [ ] create_dir - [x] delete - [ ] copy - [ ] rename @@ -17,6 +17,8 @@ This service can be used to: ## Configuration - `root`: Set the working directory of `OpenDAL` +- `username`: Set the username for authentication. +- `password`: Set the password for authentication. - `endpoint`: Set the network address of memcached server - `default_ttl`: Set the ttl for memcached service. @@ -37,6 +39,9 @@ async fn main() -> Result<()> { let mut builder = Memcached::default(); builder.endpoint("tcp://127.0.0.1:11211"); + // if you enable authentication, set username and password for authentication + // builder.username("admin"); + // builder.password("password"); let op: Operator = Operator::new(builder)?.finish(); Ok(()) diff --git a/core/src/services/memcached/mod.rs b/core/src/services/memcached/mod.rs index bbe45219a111..d293520797cb 100644 --- a/core/src/services/memcached/mod.rs +++ b/core/src/services/memcached/mod.rs @@ -18,4 +18,4 @@ mod backend; pub use backend::MemcachedBuilder as Memcached; pub use backend::MemcachedConfig; -mod ascii; +mod binary;