refactor(write path): newtype to enforce use of fully initialized slices #8717

Merged · 6 commits · Aug 14, 2024
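
The core of this PR is a `FullSlice<Buf>` newtype in `virtual_file::owned_buffers_io::io_buf_ext` that replaces the looser `BoundedBuf<Buf = Buf>` bounds throughout the write path: a `FullSlice` witnesses that every byte of the wrapped slice is initialized (`bytes_init() == bytes_total()`), so callees no longer need to re-derive or double-check bounds. The definition itself is not part of this diff; the following is a minimal sketch inferred from the call sites below (`must_new`, `into_raw_slice`, `Deref` to `[u8]`, and the `IoBufExt::slice_len` helper), so treat the details as assumptions rather than the merged code:

use std::ops::Deref;

use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};

/// Witnesses that the wrapped slice is fully initialized:
/// `bytes_init() == bytes_total()`.
pub struct FullSlice<B> {
    slice: Slice<B>,
}

impl<B: IoBuf> FullSlice<B> {
    /// Panics if the slice is not fully initialized.
    pub(crate) fn must_new(slice: Slice<B>) -> Self {
        assert_eq!(slice.bytes_init(), slice.bytes_total());
        FullSlice { slice }
    }

    pub(crate) fn into_raw_slice(self) -> Slice<B> {
        self.slice
    }
}

impl<B: IoBuf> Deref for FullSlice<B> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        // The constructor verified full initialization, so exposing the
        // whole window as &[u8] is sound.
        &self.slice[..]
    }
}

pub(crate) trait IoBufExt: IoBuf + Sized {
    /// Wraps the fully initialized prefix 0..bytes_init() as a FullSlice.
    fn slice_len(self) -> FullSlice<Self>;
}

impl<B: IoBuf> IoBufExt for B {
    fn slice_len(self) -> FullSlice<B> {
        let len = self.bytes_init();
        // NB: the real helper may special-case len == 0; elided here.
        let slice = self.slice(0..len);
        FullSlice::must_new(slice)
    }
}
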
93 changes: 51 additions & 42 deletions pageserver/src/tenant/blob_io.rs
@@ -24,6 +24,7 @@ use tracing::warn;
use crate::context::RequestContext;
use crate::page_cache::PAGE_SZ;
use crate::tenant::block_io::BlockCursor;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::VirtualFile;
use std::cmp::min;
use std::io::{Error, ErrorKind};
@@ -186,11 +187,11 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
/// You need to make sure that the internal buffer is empty, otherwise
/// data will be written in the wrong order.
#[inline(always)]
async fn write_all_unbuffered<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
async fn write_all_unbuffered<Buf: IoBuf + Send>(
&mut self,
src_buf: B,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (B::Buf, Result<(), Error>) {
) -> (FullSlice<Buf>, Result<(), Error>) {
let (src_buf, res) = self.inner.write_all(src_buf, ctx).await;
let nbytes = match res {
Ok(nbytes) => nbytes,
@@ -204,8 +205,9 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
/// Flushes the internal buffer to the underlying `VirtualFile`.
pub async fn flush_buffer(&mut self, ctx: &RequestContext) -> Result<(), Error> {
let buf = std::mem::take(&mut self.buf);
let (mut buf, res) = self.inner.write_all(buf, ctx).await;
let (slice, res) = self.inner.write_all(buf.slice_len(), ctx).await;
res?;
let mut buf = slice.into_raw_slice().into_inner();
buf.clear();
self.buf = buf;
Ok(())
@@ -222,19 +224,30 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
}

/// Internal, possibly buffered, write function
async fn write_all<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
async fn write_all<Buf: IoBuf + Send>(
&mut self,
src_buf: B,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (B::Buf, Result<(), Error>) {
) -> (FullSlice<Buf>, Result<(), Error>) {
let src_buf = src_buf.into_raw_slice();
let src_buf_bounds = src_buf.bounds();
let restore = move |src_buf_slice: Slice<_>| {
FullSlice::must_new(Slice::from_buf_bounds(
src_buf_slice.into_inner(),
src_buf_bounds,
))
};

if !BUFFERED {
assert!(self.buf.is_empty());
return self.write_all_unbuffered(src_buf, ctx).await;
return self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
}
let remaining = Self::CAPACITY - self.buf.len();
let src_buf_len = src_buf.bytes_init();
if src_buf_len == 0 {
return (Slice::into_inner(src_buf.slice_full()), Ok(()));
return (restore(src_buf), Ok(()));
}
let mut src_buf = src_buf.slice(0..src_buf_len);
// First try to copy as much as we can into the buffer
@@ -245,7 +258,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
// Then, if the buffer is full, flush it out
if self.buf.len() == Self::CAPACITY {
if let Err(e) = self.flush_buffer(ctx).await {
return (Slice::into_inner(src_buf), Err(e));
return (restore(src_buf), Err(e));
}
}
// Finally, write the tail of src_buf:
@@ -258,27 +271,29 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
let copied = self.write_into_buffer(&src_buf);
// We just verified above that src_buf fits into our internal buffer.
assert_eq!(copied, src_buf.len());
Slice::into_inner(src_buf)
restore(src_buf)
} else {
let (src_buf, res) = self.write_all_unbuffered(src_buf, ctx).await;
let (src_buf, res) = self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
if let Err(e) = res {
return (src_buf, Err(e));
}
src_buf
}
} else {
Slice::into_inner(src_buf)
restore(src_buf)
};
(src_buf, Ok(()))
}

/// Write a blob of data. Returns the offset that it was written to,
/// which can be used to retrieve the data later.
pub async fn write_blob<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
pub async fn write_blob<Buf: IoBuf + Send>(
&mut self,
srcbuf: B,
srcbuf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (B::Buf, Result<u64, Error>) {
) -> (FullSlice<Buf>, Result<u64, Error>) {
let (buf, res) = self
.write_blob_maybe_compressed(srcbuf, ctx, ImageCompressionAlgorithm::Disabled)
.await;
@@ -287,43 +302,40 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {

/// Write a blob of data. Returns the offset that it was written to,
/// which can be used to retrieve the data later.
pub async fn write_blob_maybe_compressed<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
pub(crate) async fn write_blob_maybe_compressed<Buf: IoBuf + Send>(
&mut self,
srcbuf: B,
srcbuf: FullSlice<Buf>,
ctx: &RequestContext,
algorithm: ImageCompressionAlgorithm,
) -> (B::Buf, Result<(u64, CompressionInfo), Error>) {
) -> (FullSlice<Buf>, Result<(u64, CompressionInfo), Error>) {
let offset = self.offset;
let mut compression_info = CompressionInfo {
written_compressed: false,
compressed_size: None,
};

let len = srcbuf.bytes_init();
let len = srcbuf.len();

let mut io_buf = self.io_buf.take().expect("we always put it back below");
io_buf.clear();
let mut compressed_buf = None;
let ((io_buf, hdr_res), srcbuf) = async {
let ((io_buf_slice, hdr_res), srcbuf) = async {
if len < 128 {
// Short blob. Write a 1-byte length header
io_buf.put_u8(len as u8);
(
self.write_all(io_buf, ctx).await,
srcbuf.slice_full().into_inner(),
)
(self.write_all(io_buf.slice_len(), ctx).await, srcbuf)
} else {
// Write a 4-byte length header
if len > MAX_SUPPORTED_LEN {
return (
(
io_buf,
io_buf.slice_len(),
Err(Error::new(
ErrorKind::Other,
format!("blob too large ({len} bytes)"),
)),
),
srcbuf.slice_full().into_inner(),
srcbuf,
);
}
let (high_bit_mask, len_written, srcbuf) = match algorithm {
@@ -336,40 +348,37 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
} else {
async_compression::tokio::write::ZstdEncoder::new(Vec::new())
};
let slice = srcbuf.slice_full();
encoder.write_all(&slice[..]).await.unwrap();
encoder.write_all(&srcbuf[..]).await.unwrap();
encoder.shutdown().await.unwrap();
let compressed = encoder.into_inner();
compression_info.compressed_size = Some(compressed.len());
if compressed.len() < len {
compression_info.written_compressed = true;
let compressed_len = compressed.len();
compressed_buf = Some(compressed);
(BYTE_ZSTD, compressed_len, slice.into_inner())
(BYTE_ZSTD, compressed_len, srcbuf)
} else {
(BYTE_UNCOMPRESSED, len, slice.into_inner())
(BYTE_UNCOMPRESSED, len, srcbuf)
}
}
ImageCompressionAlgorithm::Disabled => {
(BYTE_UNCOMPRESSED, len, srcbuf.slice_full().into_inner())
}
ImageCompressionAlgorithm::Disabled => (BYTE_UNCOMPRESSED, len, srcbuf),
};
let mut len_buf = (len_written as u32).to_be_bytes();
assert_eq!(len_buf[0] & 0xf0, 0);
len_buf[0] |= high_bit_mask;
io_buf.extend_from_slice(&len_buf[..]);
(self.write_all(io_buf, ctx).await, srcbuf)
(self.write_all(io_buf.slice_len(), ctx).await, srcbuf)
}
}
.await;
self.io_buf = Some(io_buf);
self.io_buf = Some(io_buf_slice.into_raw_slice().into_inner());
match hdr_res {
Ok(_) => (),
Err(e) => return (Slice::into_inner(srcbuf.slice(..)), Err(e)),
Err(e) => return (srcbuf, Err(e)),
}
let (srcbuf, res) = if let Some(compressed_buf) = compressed_buf {
let (_buf, res) = self.write_all(compressed_buf, ctx).await;
(Slice::into_inner(srcbuf.slice(..)), res)
let (_buf, res) = self.write_all(compressed_buf.slice_len(), ctx).await;
(srcbuf, res)
} else {
self.write_all(srcbuf, ctx).await
};
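
For orientation, the framing this function writes ahead of the payload: blobs shorter than 128 bytes get a 1-byte length header, everything else a 4-byte big-endian length whose top nibble carries the compression marker (`high_bit_mask`). A condensed sketch of just that encoding; the concrete values of `BYTE_UNCOMPRESSED`/`BYTE_ZSTD` are assumptions, since the diff only shows that they live in the top nibble:

// Assumed flag values; the diff only shows that len_buf[0] & 0xf0 must
// start out zero and that high_bit_mask is OR'd into it.
const BYTE_UNCOMPRESSED: u8 = 0x80;
const BYTE_ZSTD: u8 = 0x80 | 0x10;

fn encode_len_header(len: usize, high_bit_mask: u8) -> Vec<u8> {
    if len < 128 {
        // Short blob: 1-byte header, top bit clear.
        vec![len as u8]
    } else {
        // Long blob: 4-byte big-endian length, top nibble holds the flags.
        // The caller has already rejected len > MAX_SUPPORTED_LEN, as above.
        let mut len_buf = (len as u32).to_be_bytes();
        assert_eq!(len_buf[0] & 0xf0, 0, "blob too large ({len} bytes)");
        len_buf[0] |= high_bit_mask;
        len_buf.to_vec()
    }
}
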
@@ -432,21 +441,21 @@ pub(crate) mod tests {
let (_, res) = if compression {
let res = wtr
.write_blob_maybe_compressed(
blob.clone(),
blob.clone().slice_len(),
ctx,
ImageCompressionAlgorithm::Zstd { level: Some(1) },
)
.await;
(res.0, res.1.map(|(off, _)| off))
} else {
wtr.write_blob(blob.clone(), ctx).await
wtr.write_blob(blob.clone().slice_len(), ctx).await
};
let offs = res?;
offsets.push(offs);
}
// Write out one page worth of zeros so that we can
// read again with read_blk
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ], ctx).await;
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ].slice_len(), ctx).await;
let offs = res?;
println!("Writing final blob at offs={offs}");
wtr.flush_buffer(ctx).await?;
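The subtle piece of the `write_all` hunk above is the `restore` closure: the function narrows the incoming `FullSlice` down to a plain `Slice` while it copies and flushes, and must rebuild a `FullSlice` over the original bounds before handing the buffer back. A condensed round trip, reusing the `FullSlice`/`IoBufExt` sketch from the top and assuming `bounds()`/`from_buf_bounds()` behave as used in the diff:

use tokio_epoll_uring::{BoundedBuf, Slice};

fn narrow_then_restore() {
    let full: FullSlice<Vec<u8>> = vec![1u8, 2, 3, 4].slice_len();
    let src: Slice<Vec<u8>> = full.into_raw_slice();
    let bounds = src.bounds(); // remember the full 0..4 window
    let tail = src.slice(2..); // narrow to the unwritten tail, as write_all does
    // ... IO consumes `tail` and hands back the underlying buffer ...
    let restored = FullSlice::must_new(Slice::from_buf_bounds(tail.into_inner(), bounds));
    assert_eq!(&restored[..], &[1, 2, 3, 4]);
}
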
31 changes: 8 additions & 23 deletions pageserver/src/tenant/ephemeral_file/page_caching.rs
@@ -4,6 +4,7 @@
use crate::context::RequestContext;
use crate::page_cache::{self, PAGE_SZ};
use crate::tenant::block_io::BlockLease;
use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice;
use crate::virtual_file::VirtualFile;

use once_cell::sync::Lazy;
@@ -208,21 +209,11 @@ impl PreWarmingWriter {
}

impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter {
async fn write_all<
B: tokio_epoll_uring::BoundedBuf<Buf = Buf>,
Buf: tokio_epoll_uring::IoBuf + Send,
>(
async fn write_all<Buf: tokio_epoll_uring::IoBuf + Send>(
&mut self,
buf: B,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> std::io::Result<(usize, B::Buf)> {
let buf = buf.slice(..);
let saved_bounds = buf.bounds(); // save for reconstructing the Slice from iobuf after the IO is done
let check_bounds_stuff_works = if cfg!(test) && cfg!(debug_assertions) {
Some(buf.to_vec())
} else {
None
};
) -> std::io::Result<(usize, FullSlice<Buf>)> {
let buflen = buf.len();
assert_eq!(
buflen % PAGE_SZ,
@@ -231,10 +222,10 @@ impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter {
);

// Do the IO.
let iobuf = match self.file.write_all(buf, ctx).await {
(iobuf, Ok(nwritten)) => {
let buf = match self.file.write_all(buf, ctx).await {
(buf, Ok(nwritten)) => {
assert_eq!(nwritten, buflen);
iobuf
buf
}
(_, Err(e)) => {
return Err(std::io::Error::new(
@@ -248,12 +239,6 @@ impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter {
}
};

// Reconstruct the Slice (the write path consumed the Slice and returned us the underlying IoBuf)
let buf = tokio_epoll_uring::Slice::from_buf_bounds(iobuf, saved_bounds);
if let Some(check_bounds_stuff_works) = check_bounds_stuff_works {
assert_eq!(&check_bounds_stuff_works, &*buf);
}

let nblocks = buflen / PAGE_SZ;
let nblocks32 = u32::try_from(nblocks).unwrap();

@@ -300,6 +285,6 @@ impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter {
}

self.nwritten_blocks = self.nwritten_blocks.checked_add(nblocks32).unwrap();
Ok((buflen, buf.into_inner()))
Ok((buflen, buf))
}
}
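
With `FullSlice` flowing through the trait, the impl above sheds the save-bounds/reconstruct/debug-compare dance that the old `BoundedBuf`-based signature forced on it. The trait definition is not shown in this diff; judging from the impl, it now reads roughly like this sketch:

use crate::context::RequestContext;
use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice;

pub trait OwnedAsyncWriter {
    /// Writes the whole slice and returns the buffer, still wrapped as a
    /// FullSlice, so the caller can reuse it without re-checking bounds.
    async fn write_all<Buf: tokio_epoll_uring::IoBuf + Send>(
        &mut self,
        buf: FullSlice<Buf>,
        ctx: &RequestContext,
    ) -> std::io::Result<(usize, FullSlice<Buf>)>;
}
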
pageserver/src/tenant/ephemeral_file/zero_padded_read_write/zero_padded.rs
@@ -5,6 +5,8 @@

use std::mem::MaybeUninit;

use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice;

/// See module-level comment.
pub struct Buffer<const N: usize> {
allocation: Box<[u8; N]>,
@@ -60,10 +62,10 @@ impl<const N: usize> crate::virtual_file::owned_buffers_io::write::Buffer for Buffer<N> {
self.written
}

fn flush(self) -> tokio_epoll_uring::Slice<Self> {
fn flush(self) -> FullSlice<Self> {
self.invariants();
let written = self.written;
tokio_epoll_uring::BoundedBuf::slice(self, 0..written)
FullSlice::must_new(tokio_epoll_uring::BoundedBuf::slice(self, 0..written))
}

fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
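`flush` can call `FullSlice::must_new` without risking its panic because `written` only ever counts bytes actually copied into the allocation (`self.invariants()` checks this just before slicing), so the `0..written` window is initialized end to end. The same property in miniature, with a `Vec` standing in for `Buffer<N>` (an assumption; the `IoBuf` impl for `Buffer<N>` is not shown in this diff):

use tokio_epoll_uring::{BoundedBuf, IoBuf};

fn flush_window_is_fully_initialized() {
    let buf: Vec<u8> = vec![1, 2, 3]; // for Vec, bytes_init() == len()
    let written = buf.len();
    let slice = buf.slice(0..written);
    // Exactly the property FullSlice::must_new asserts:
    assert_eq!(slice.bytes_init(), slice.bytes_total());
}
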
5 changes: 2 additions & 3 deletions pageserver/src/tenant/remote_timeline_client/download.rs
@@ -23,6 +23,7 @@ use crate::span::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::remote_timeline_client::{remote_layer_path, remote_timelines_path};
use crate::tenant::storage_layer::LayerName;
use crate::tenant::Generation;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::virtual_file::{on_fatal_io_error, MaybeFatalIo, VirtualFile};
use crate::TEMP_FILE_SUFFIX;
use remote_storage::{DownloadError, GenericRemoteStorage, ListingMode, RemotePath};
@@ -219,9 +220,7 @@ async fn download_object<'a>(
Ok(chunk) => chunk,
Err(e) => return Err(e),
};
buffered
.write_buffered(tokio_epoll_uring::BoundedBuf::slice_full(chunk), ctx)
.await?;
buffered.write_buffered(chunk.slice_len(), ctx).await?;
}
let size_tracking = buffered.flush_and_into_inner(ctx).await?;
Ok(size_tracking.into_inner())
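
This last hunk shows the intended ergonomics at an ordinary call site: a fully received chunk is initialized by construction, so the caller asserts that once with `slice_len()` instead of hand-building a `Slice` via `slice_full`. In miniature (a sketch; `write_buffered` and the surrounding stream loop are as in `download_object` above):

use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;

fn wrap_downloaded_chunk() {
    // Every byte of a fully received chunk is initialized, so wrapping it
    // is a free assertion, not a copy.
    let chunk: Vec<u8> = vec![0xAB; 4096];
    let full = chunk.slice_len();
    assert_eq!(full.len(), 4096);
    // buffered.write_buffered(full, ctx).await?; // as at the call site above
}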