diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 172ddc865b22..422f676cb7d9 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -138,7 +138,7 @@ jobs: - name: Setup LocalStack (AWS emulation) run: | - echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:3.2.0)" >> $GITHUB_ENV + echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:3.3.0)" >> $GITHUB_ENV echo "EC2_METADATA_CONTAINER=$(docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2)" >> $GITHUB_ENV aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5 diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml index 746045cc8dde..8bc33b1874e4 100644 --- a/arrow-buffer/Cargo.toml +++ b/arrow-buffer/Cargo.toml @@ -46,4 +46,8 @@ rand = { version = "0.8", default-features = false, features = ["std", "std_rng" [[bench]] name = "i256" +harness = false + +[[bench]] +name = "offset" harness = false \ No newline at end of file diff --git a/arrow-buffer/benches/offset.rs b/arrow-buffer/benches/offset.rs new file mode 100644 index 000000000000..1aea5024fbd1 --- /dev/null +++ b/arrow-buffer/benches/offset.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_buffer::{OffsetBuffer, OffsetBufferBuilder}; +use criterion::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +const SIZE: usize = 1024; + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(42); + let lengths: Vec = black_box((0..SIZE).map(|_| rng.gen_range(0..40)).collect()); + + c.bench_function("OffsetBuffer::from_lengths", |b| { + b.iter(|| OffsetBuffer::::from_lengths(lengths.iter().copied())); + }); + + c.bench_function("OffsetBufferBuilder::push_length", |b| { + b.iter(|| { + let mut builder = OffsetBufferBuilder::::new(lengths.len()); + lengths.iter().for_each(|x| builder.push_length(*x)); + builder.finish() + }); + }); + + let offsets = OffsetBuffer::::from_lengths(lengths.iter().copied()).into_inner(); + + c.bench_function("OffsetBuffer::new", |b| { + b.iter(|| OffsetBuffer::new(black_box(offsets.clone()))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 552e3f1615c7..f26cde05b7ab 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -171,23 +171,33 @@ impl Buffer { /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. /// Doing so allows the same memory region to be shared between buffers. + /// /// # Panics + /// /// Panics iff `offset` is larger than `len`. pub fn slice(&self, offset: usize) -> Self { + let mut s = self.clone(); + s.advance(offset); + s + } + + /// Increases the offset of this buffer by `offset` + /// + /// # Panics + /// + /// Panics iff `offset` is larger than `len`. + #[inline] + pub fn advance(&mut self, offset: usize) { assert!( offset <= self.length, "the offset of the new Buffer cannot exceed the existing length" ); + self.length -= offset; // Safety: // This cannot overflow as // `self.offset + self.length < self.data.len()` // `offset < self.length` - let ptr = unsafe { self.ptr.add(offset) }; - Self { - data: self.data.clone(), - length: self.length - offset, - ptr, - } + self.ptr = unsafe { self.ptr.add(offset) }; } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`, diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index 652d30c3b0ab..e9087d30098c 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -16,7 +16,7 @@ // under the License. use crate::buffer::ScalarBuffer; -use crate::{ArrowNativeType, MutableBuffer}; +use crate::{ArrowNativeType, MutableBuffer, OffsetBufferBuilder}; use std::ops::Deref; /// A non-empty buffer of monotonically increasing, positive integers. 
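The `Buffer` hunk above reimplements `slice` on top of a new in-place `advance` that bumps the pointer and shrinks the length without constructing a new buffer. A minimal sketch of how the two differ, assuming `arrow-buffer` as a dependency (values are illustrative, not from the diff):

```rust
use arrow_buffer::Buffer;

fn main() {
    let buffer = Buffer::from_vec(vec![1u8, 2, 3, 4, 5]);

    // `slice` clones the (cheaply reference-counted) buffer and advances the copy
    let sliced = buffer.slice(2);
    assert_eq!(sliced.as_slice(), &[3, 4, 5]);
    assert_eq!(buffer.as_slice(), &[1, 2, 3, 4, 5]); // original untouched

    // `advance` moves the start of an existing buffer in place, skipping the clone
    let mut owned = buffer;
    owned.advance(2);
    assert_eq!(owned.as_slice(), &[3, 4, 5]);
}
```

Having `slice` delegate to `advance` also keeps the bounds check and pointer arithmetic in a single place.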
@@ -55,7 +55,6 @@ use std::ops::Deref; /// (offsets[i], /// offsets[i+1]) /// ``` - #[derive(Debug, Clone)] pub struct OffsetBuffer(ScalarBuffer); @@ -174,6 +173,12 @@ impl AsRef<[T]> for OffsetBuffer { } } +impl From> for OffsetBuffer { + fn from(value: OffsetBufferBuilder) -> Self { + value.finish() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 3826d74e43bd..2019cc79830d 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -160,6 +160,15 @@ impl From> for ScalarBuffer { } } +impl From> for Vec { + fn from(value: ScalarBuffer) -> Self { + value + .buffer + .into_vec() + .unwrap_or_else(|buffer| buffer.typed_data::().into()) + } +} + impl From> for ScalarBuffer { fn from(mut value: BufferBuilder) -> Self { let len = value.len(); @@ -208,6 +217,8 @@ impl PartialEq> for Vec { #[cfg(test)] mod tests { + use std::{ptr::NonNull, sync::Arc}; + use super::*; #[test] @@ -284,4 +295,45 @@ mod tests { let scalar_buffer = ScalarBuffer::from(buffer_builder); assert_eq!(scalar_buffer.as_ref(), input); } + + #[test] + fn into_vec() { + let input = vec![1u8, 2, 3, 4]; + + // No copy + let input_buffer = Buffer::from_vec(input.clone()); + let input_ptr = input_buffer.as_ptr(); + let input_len = input_buffer.len(); + let scalar_buffer = ScalarBuffer::::new(input_buffer, 0, input_len); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec.as_slice(), input.as_slice()); + assert_eq!(vec.as_ptr(), input_ptr); + + // Custom allocation - makes a copy + let mut input_clone = input.clone(); + let input_ptr = NonNull::new(input_clone.as_mut_ptr()).unwrap(); + let dealloc = Arc::new(()); + let buffer = + unsafe { Buffer::from_custom_allocation(input_ptr, input_clone.len(), dealloc as _) }; + let scalar_buffer = ScalarBuffer::::new(buffer, 0, input.len()); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec, input.as_slice()); + assert_ne!(vec.as_ptr(), input_ptr.as_ptr()); + + // Offset - makes a copy + let input_buffer = Buffer::from_vec(input.clone()); + let input_ptr = input_buffer.as_ptr(); + let input_len = input_buffer.len(); + let scalar_buffer = ScalarBuffer::::new(input_buffer, 1, input_len - 1); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec.as_slice(), &input[1..]); + assert_ne!(vec.as_ptr(), input_ptr); + + // Inner buffer Arc ref count != 0 - makes a copy + let buffer = Buffer::from_slice_ref(input.as_slice()); + let scalar_buffer = ScalarBuffer::::new(buffer, 0, input.len()); + let vec = Vec::from(scalar_buffer); + assert_eq!(vec, input.as_slice()); + assert_ne!(vec.as_ptr(), input.as_ptr()); + } } diff --git a/arrow-buffer/src/builder/mod.rs b/arrow-buffer/src/builder/mod.rs index d5d5a7d3f18d..f7e0e29dace4 100644 --- a/arrow-buffer/src/builder/mod.rs +++ b/arrow-buffer/src/builder/mod.rs @@ -18,9 +18,12 @@ //! Buffer builders mod boolean; -pub use boolean::*; mod null; +mod offset; + +pub use boolean::*; pub use null::*; +pub use offset::*; use crate::{ArrowNativeType, Buffer, MutableBuffer}; use std::{iter, marker::PhantomData}; diff --git a/arrow-buffer/src/builder/offset.rs b/arrow-buffer/src/builder/offset.rs new file mode 100644 index 000000000000..6a236d2a3e12 --- /dev/null +++ b/arrow-buffer/src/builder/offset.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
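The new `From<ScalarBuffer<T>> for Vec<T>` above reuses the underlying allocation when the buffer uniquely owns a standard `Vec`-backed allocation at offset zero, and falls back to copying otherwise (mirroring the `into_vec` test in this hunk). A minimal sketch, assuming `arrow-buffer` as a dependency and illustrative values:

```rust
use arrow_buffer::{Buffer, ScalarBuffer};

fn main() {
    // A buffer that uniquely owns its allocation converts back to a Vec without copying
    let buffer = Buffer::from_vec(vec![1i32, 2, 3, 4]);
    let scalar = ScalarBuffer::<i32>::new(buffer, 0, 4);
    let ptr = scalar.as_ptr();
    let vec = Vec::from(scalar);
    assert_eq!(vec, [1, 2, 3, 4]);
    assert_eq!(vec.as_ptr(), ptr); // same allocation, no copy

    // A shared buffer cannot give up its allocation, so the data is copied
    let shared = ScalarBuffer::<i32>::from(vec![5, 6, 7]);
    let other_ref = shared.clone();
    let vec = Vec::from(shared);
    assert_eq!(vec, [5, 6, 7]);
    assert_ne!(vec.as_ptr(), other_ref.as_ptr()); // `other_ref` still holds the original
}
```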
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ops::Deref; + +use crate::{ArrowNativeType, OffsetBuffer}; + +#[derive(Debug)] +pub struct OffsetBufferBuilder { + offsets: Vec, + last_offset: usize, +} + +/// Builder of [`OffsetBuffer`] +impl OffsetBufferBuilder { + /// Create a new builder with space for `capacity + 1` offsets + pub fn new(capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(capacity + 1); + offsets.push(O::usize_as(0)); + Self { + offsets, + last_offset: 0, + } + } + + /// Push a slice of `length` bytes + /// + /// # Panics + /// + /// Panics if adding `length` would overflow `usize` + #[inline] + pub fn push_length(&mut self, length: usize) { + self.last_offset = self.last_offset.checked_add(length).expect("overflow"); + self.offsets.push(O::usize_as(self.last_offset)) + } + + /// Reserve space for at least `additional` further offsets + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + } + + /// Takes the builder itself and returns an [`OffsetBuffer`] + /// + /// # Panics + /// + /// Panics if offsets overflow `O` + pub fn finish(self) -> OffsetBuffer { + O::from_usize(self.last_offset).expect("overflow"); + unsafe { OffsetBuffer::new_unchecked(self.offsets.into()) } + } + + /// Builds the [OffsetBuffer] without resetting the builder. + /// + /// # Panics + /// + /// Panics if offsets overflow `O` + pub fn finish_cloned(&self) -> OffsetBuffer { + O::from_usize(self.last_offset).expect("overflow"); + unsafe { OffsetBuffer::new_unchecked(self.offsets.clone().into()) } + } +} + +impl Deref for OffsetBufferBuilder { + type Target = [O]; + + fn deref(&self) -> &Self::Target { + self.offsets.as_ref() + } +} + +#[cfg(test)] +mod tests { + use crate::OffsetBufferBuilder; + + #[test] + fn test_basic() { + let mut builder = OffsetBufferBuilder::::new(5); + assert_eq!(builder.len(), 1); + assert_eq!(&*builder, &[0]); + let finished = builder.finish_cloned(); + assert_eq!(finished.len(), 1); + assert_eq!(&*finished, &[0]); + + builder.push_length(2); + builder.push_length(6); + builder.push_length(0); + builder.push_length(13); + + let finished = builder.finish(); + assert_eq!(&*finished, &[0, 2, 8, 8, 21]); + } + + #[test] + #[should_panic(expected = "overflow")] + fn test_usize_overflow() { + let mut builder = OffsetBufferBuilder::::new(5); + builder.push_length(1); + builder.push_length(usize::MAX); + builder.finish(); + } + + #[test] + #[should_panic(expected = "overflow")] + fn test_i32_overflow() { + let mut builder = OffsetBufferBuilder::::new(5); + builder.push_length(1); + builder.push_length(i32::MAX as usize); + builder.finish(); + } +} diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs new file mode 100644 index 000000000000..600f868a3e01 --- /dev/null +++ b/arrow-cast/src/cast/decimal.rs @@ -0,0 +1,573 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
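The `OffsetBufferBuilder` above (together with the `From<OffsetBufferBuilder<O>> for OffsetBuffer<O>` impl earlier in this diff) targets the case where slice lengths are produced one at a time. A hedged usage sketch with illustrative lengths:

```rust
use arrow_buffer::{OffsetBuffer, OffsetBufferBuilder};

fn main() {
    // Build offsets incrementally from slice lengths
    let mut builder = OffsetBufferBuilder::<i32>::new(3);
    builder.push_length(2);
    builder.push_length(0);
    builder.push_length(5);

    // Either call `finish()` directly or go through the new `From` impl
    let offsets: OffsetBuffer<i32> = builder.into();
    assert_eq!(&*offsets, &[0, 2, 2, 7]);

    // Equivalent one-shot construction when all lengths are known up front
    let from_lengths = OffsetBuffer::<i32>::from_lengths([2, 0, 5]);
    assert_eq!(&*from_lengths, &[0, 2, 2, 7]);
}
```

The builder is mainly useful when lengths are computed on the fly and a temporary `Vec` of lengths would otherwise be materialised only to feed `from_lengths`.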
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// A utility trait that provides checked conversions between +/// decimal types inspired by [`NumCast`] +pub(crate) trait DecimalCast: Sized { + fn to_i128(self) -> Option; + + fn to_i256(self) -> Option; + + fn from_decimal(n: T) -> Option; +} + +impl DecimalCast for i128 { + fn to_i128(self) -> Option { + Some(self) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self)) + } + + fn from_decimal(n: T) -> Option { + n.to_i128() + } +} + +impl DecimalCast for i256 { + fn to_i128(self) -> Option { + self.to_i128() + } + + fn to_i256(self) -> Option { + Some(self) + } + + fn from_decimal(n: T) -> Option { + n.to_i256() + } +} + +pub(crate) fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +pub(crate) fn convert_to_smaller_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let div = I::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((input_scale - output_scale) as u32)?; + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + let f = |x: I::Native| { + // div is >= 10 and so this cannot overflow + let d = x.div_wrapping(div); + let r = x.mod_wrapping(div); + + // Round result + let adjusted = match x >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + }; + O::Native::from_decimal(adjusted) + }; + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +pub(crate) fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let mul = O::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((output_scale - input_scale) as u32)?; + + let f = |x| O::Native::from_decimal(x).and_then(|x| 
x.mul_checked(mul).ok()); + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +// Only support one type of decimal cast operations +pub(crate) fn cast_decimal_to_decimal_same_type( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = match input_scale.cmp(&output_scale) { + Ordering::Equal => { + // the scale doesn't change, the native value don't need to be changed + array.clone() + } + Ordering::Greater => convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )?, + Ordering::Less => { + // input_scale < output_scale + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +// Support two different types of decimal cast operations +pub(crate) fn cast_decimal_to_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = if input_scale > output_scale { + convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } else { + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +/// Parses given string to specified decimal native (i128/i256) based on given +/// scale. Returns an `Err` if it cannot parse given string. 
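A hedged illustration of the rounding performed by `convert_to_smaller_scale_decimal` above, exercised through the public `arrow_cast::cast` entry point rather than the private helpers. The 10→1 scale change and the values are illustrative; `arrow-array`, `arrow-cast`, and `arrow-schema` are assumed as dependencies:

```rust
use arrow_array::{Array, Decimal128Array};
use arrow_cast::cast;
use arrow_schema::DataType;

fn main() {
    // 1.25 and -1.25 stored with precision 10, scale 2
    let input = Decimal128Array::from(vec![125_i128, -125])
        .with_precision_and_scale(10, 2)
        .unwrap();

    // Casting down to scale 1 divides by 10 and rounds half away from zero
    let output = cast(&input, &DataType::Decimal128(10, 1)).unwrap();
    let output = output.as_any().downcast_ref::<Decimal128Array>().unwrap();
    assert_eq!(output.value(0), 13);  // 1.25  -> 1.3
    assert_eq!(output.value(1), -13); // -1.25 -> -1.3
}
```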
+pub(crate) fn parse_string_to_decimal_native( + value_str: &str, + scale: usize, +) -> Result +where + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let value_str = value_str.trim(); + let parts: Vec<&str> = value_str.split('.').collect(); + if parts.len() > 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + let (negative, first_part) = if parts[0].is_empty() { + (false, parts[0]) + } else { + match parts[0].as_bytes()[0] { + b'-' => (true, &parts[0][1..]), + b'+' => (false, &parts[0][1..]), + _ => (false, parts[0]), + } + }; + + let integers = first_part.trim_start_matches('0'); + let decimals = if parts.len() == 2 { parts[1] } else { "" }; + + if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + // Adjust decimal based on scale + let mut number_decimals = if decimals.len() > scale { + let decimal_number = i256::from_string(decimals).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) + })?; + + let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; + + let half = div.div_wrapping(i256::from_i128(2)); + let half_neg = half.neg_wrapping(); + + let d = decimal_number.div_wrapping(div); + let r = decimal_number.mod_wrapping(div); + + // Round result + let adjusted = match decimal_number >= i256::ZERO { + true if r >= half => d.add_wrapping(i256::ONE), + false if r <= half_neg => d.sub_wrapping(i256::ONE), + _ => d, + }; + + let integers = if !integers.is_empty() { + i256::from_string(integers) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Cannot parse decimal format: {value_str}" + )) + }) + .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? + } else { + i256::ZERO + }; + + format!("{}", integers.add_wrapping(adjusted)) + } else { + let padding = if scale > decimals.len() { scale } else { 0 }; + + let decimals = format!("{decimals:0( + from: &GenericStringArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if cast_options.safe { + let iter = from.iter().map(|v| { + v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + .and_then(|v| { + T::validate_decimal_precision(v, precision) + .is_ok() + .then_some(v) + }) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + .with_precision_and_scale(precision, scale)? + }) + } else { + let vec = from + .iter() + .map(|v| { + v.map(|v| { + parse_string_to_decimal_native::(v, scale as usize) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + }) + .transpose() + }) + .collect::, _>>()?; + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. 
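`string_to_decimal_cast` above branches on `cast_options.safe`. A minimal sketch of the observable difference, assuming `arrow-array`/`arrow-cast`/`arrow-schema` as dependencies and illustrative input strings: safe casts replace unparseable values with nulls, unsafe casts return an error.

```rust
use arrow_array::{Array, StringArray};
use arrow_cast::{cast_with_options, CastOptions};
use arrow_schema::DataType;

fn main() {
    let strings = StringArray::from(vec!["1.25", "not a number"]);
    let to = DataType::Decimal128(10, 2);

    // safe = true (the default): the invalid string becomes null
    let safe = cast_with_options(&strings, &to, &CastOptions::default()).unwrap();
    assert_eq!(safe.null_count(), 1);

    // safe = false: the invalid string surfaces as a CastError
    let options = CastOptions { safe: false, ..Default::default() };
    assert!(cast_with_options(&strings, &to, &options).is_err());
}
```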
+ Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + .with_precision_and_scale(precision, scale)? + }) + } +} + +/// Cast Utf8 to decimal +pub(crate) fn cast_string_to_decimal( + from: &dyn Array, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if scale < 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal with negative scale {scale}" + ))); + } + + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal greater than maximum scale {}", + T::MAX_SCALE + ))); + } + + Ok(Arc::new(string_to_decimal_cast::( + from.as_any() + .downcast_ref::>() + .unwrap(), + precision, + scale, + cast_options, + )?)) +} + +pub(crate) fn cast_floating_point_to_decimal128( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal128Type>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok()) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal128Type, _>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +pub(crate) fn cast_floating_point_to_decimal256( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal256Type>(|v| { + i256::from_f64((v.as_() * mul).round()) + .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok()) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal256Type, _>(|v| { + i256::from_f64((v.as_() * mul).round()) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +pub(crate) fn cast_decimal_to_integer( + array: &dyn Array, + base: D::Native, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: ArrowPrimitiveType, + ::Native: NumCast, + D: DecimalType + ArrowPrimitiveType, + ::Native: ArrowNativeTypeOp + ToPrimitive, +{ + let array = array.as_primitive::(); + + let div: D::Native = base.pow_checked(scale as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}. 
The scale {} causes overflow.", + D::PREFIX, + scale, + )) + })?; + + let mut value_builder = PrimitiveBuilder::::with_capacity(array.len()); + + if cast_options.safe { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array + .value(i) + .div_checked(div) + .ok() + .and_then(::from::); + + value_builder.append_option(v); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array.value(i).div_checked(div)?; + + let value = ::from::(v).ok_or_else(|| { + ArrowError::CastError(format!( + "value of {:?} is out of range {}", + v, + T::DATA_TYPE + )) + })?; + + value_builder.append_value(value); + } + } + } + Ok(Arc::new(value_builder.finish())) +} + +// Cast the decimal array to floating-point array +pub(crate) fn cast_decimal_to_float( + array: &dyn Array, + op: F, +) -> Result +where + F: Fn(D::Native) -> T::Native, +{ + let array = array.as_primitive::(); + let array = array.unary::<_, T>(op); + Ok(Arc::new(array)) +} diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs new file mode 100644 index 000000000000..244e101f1d8d --- /dev/null +++ b/arrow-cast/src/cast/dictionary.rs @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Attempts to cast an `ArrayDictionary` with index type K into +/// `to_type` for supported types. 
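A short, hedged example of `cast_decimal_to_integer` above through the public cast API: each value is divided by `base^scale`, truncating the fractional digits toward zero (values illustrative; `arrow-array`/`arrow-cast`/`arrow-schema` assumed as dependencies):

```rust
use arrow_array::{Array, Decimal128Array, Int32Array};
use arrow_cast::cast;
use arrow_schema::DataType;

fn main() {
    // 12.34 and -5.99 at precision 10, scale 2
    let decimals = Decimal128Array::from(vec![1234_i128, -599])
        .with_precision_and_scale(10, 2)
        .unwrap();

    let ints = cast(&decimals, &DataType::Int32).unwrap();
    let ints = ints.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(ints.value(0), 12); // 12.34 -> 12
    assert_eq!(ints.value(1), -5); // -5.99 -> -5 (truncation toward zero)
}
```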
+/// +/// K is the key type +pub(crate) fn dictionary_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match to_type { + Dictionary(to_index_type, to_value_type) => { + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + let keys_array: ArrayRef = + Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); + let values_array = dict_array.values(); + let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } + + let data = cast_keys.into_data(); + let builder = data + .into_builder() + .data_type(to_type.clone()) + .child_data(vec![cast_values.into_data()]); + + // Safety + // Cast keys are still valid + let data = unsafe { builder.build_unchecked() }; + + // create the appropriate array type + let new_array: ArrayRef = match **to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported type {to_index_type:?} for dictionary index" + ))); + } + }; + + Ok(new_array) + } + _ => unpack_dictionary::(array, to_type, cast_options), + } +} + +// Unpack a dictionary where the keys are of type into a flattened array of type to_type +pub(crate) fn unpack_dictionary( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let dict_array = array.as_dictionary::(); + let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; + take(cast_dict_values.as_ref(), dict_array.keys(), None) +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +pub(crate) fn cast_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match *dict_value_type { + Int8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Decimal128(_, _) => { + 
pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Decimal256(_, _) => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Utf8 => pack_byte_to_dictionary::>(array, cast_options), + LargeUtf8 => pack_byte_to_dictionary::>(array, cast_options), + Binary => pack_byte_to_dictionary::>(array, cast_options), + LargeBinary => pack_byte_to_dictionary::>(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Unsupported output type for dictionary packing: {dict_value_type:?}" + ))), + } +} + +// Packs the data from the primitive array of type to a +// DictionaryArray with keys of type K and values of value_type V +pub(crate) fn pack_numeric_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + // attempt to cast the source array values to the target value type (the dictionary values type) + let cast_values = cast_with_options(array, dict_value_type, cast_options)?; + let values = cast_values.as_primitive::(); + + let mut b = PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +// Packs the data as a GenericByteDictionaryBuilder, if possible, with the +// key types of K +pub(crate) fn pack_byte_to_dictionary( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + let cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; + let values = cast_values + .as_any() + .downcast_ref::>() + .unwrap(); + let mut b = GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} diff --git a/arrow-cast/src/cast/list.rs b/arrow-cast/src/cast/list.rs new file mode 100644 index 000000000000..33faacdccb92 --- /dev/null +++ b/arrow-cast/src/cast/list.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Helper function that takes a primitive array and casts to a (generic) list array. 
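`pack_byte_to_dictionary` above interns repeated byte values behind integer keys. A minimal sketch via the public cast API; the `Int32` key type and the input values are illustrative:

```rust
use arrow_array::{cast::AsArray, types::Int32Type, Array, StringArray};
use arrow_cast::cast;
use arrow_schema::DataType;

fn main() {
    let strings = StringArray::from(vec!["a", "b", "a", "a"]);
    let to = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

    let dict = cast(&strings, &to).unwrap();
    let dict = dict.as_dictionary::<Int32Type>();

    // Four keys, but only two distinct values in the dictionary
    assert_eq!(dict.keys().len(), 4);
    assert_eq!(dict.values().len(), 2);
}
```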
+pub(crate) fn cast_values_to_list( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let values = cast_with_options(array, to.data_type(), cast_options)?; + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); + let list = GenericListArray::::new(to.clone(), offsets, values, None); + Ok(Arc::new(list)) +} + +/// Helper function that takes a primitive array and casts to a fixed size list array. +pub(crate) fn cast_values_to_fixed_size_list( + array: &dyn Array, + to: &FieldRef, + size: i32, + cast_options: &CastOptions, +) -> Result { + let values = cast_with_options(array, to.data_type(), cast_options)?; + let list = FixedSizeListArray::new(to.clone(), size, values, None); + Ok(Arc::new(list)) +} + +pub(crate) fn cast_fixed_size_list_to_list( + array: &dyn Array, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let fixed_size_list: &FixedSizeListArray = array.as_fixed_size_list(); + let list: GenericListArray = fixed_size_list.clone().into(); + Ok(Arc::new(list)) +} + +pub(crate) fn cast_list_to_fixed_size_list( + array: &GenericListArray, + field: &FieldRef, + size: i32, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let cap = array.len() * size as usize; + + let mut nulls = (cast_options.safe || array.null_count() != 0).then(|| { + let mut buffer = BooleanBufferBuilder::new(array.len()); + match array.nulls() { + Some(n) => buffer.append_buffer(n.inner()), + None => buffer.append_n(array.len(), true), + } + buffer + }); + + // Nulls in FixedSizeListArray take up space and so we must pad the values + let values = array.values().to_data(); + let mut mutable = MutableArrayData::new(vec![&values], cast_options.safe, cap); + // The end position in values of the last incorrectly-sized list slice + let mut last_pos = 0; + for (idx, w) in array.offsets().windows(2).enumerate() { + let start_pos = w[0].as_usize(); + let end_pos = w[1].as_usize(); + let len = end_pos - start_pos; + + if len != size as usize { + if cast_options.safe || array.is_null(idx) { + if last_pos != start_pos { + // Extend with valid slices + mutable.extend(0, last_pos, start_pos); + } + // Pad this slice with nulls + mutable.extend_nulls(size as _); + nulls.as_mut().unwrap().set_bit(idx, false); + // Set last_pos to the end of this slice's values + last_pos = end_pos + } else { + return Err(ArrowError::CastError(format!( + "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}", + ))); + } + } + } + + let values = match last_pos { + 0 => array.values().slice(0, cap), // All slices were the correct length + _ => { + if mutable.len() != cap { + // Remaining slices were all correct length + let remaining = cap - mutable.len(); + mutable.extend(0, last_pos, last_pos + remaining) + } + make_array(mutable.freeze()) + } + }; + + // Cast the inner values if necessary + let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?; + + // Construct the FixedSizeListArray + let nulls = nulls.map(|mut x| x.finish().into()); + let array = FixedSizeListArray::new(field.clone(), size, values, nulls); + Ok(Arc::new(array)) +} + +/// Helper function that takes an Generic list container and casts the inner datatype. 
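A hedged sketch of `cast_list_to_fixed_size_list` above: under safe casting, list slices whose length differs from `size` are padded and masked as null rather than producing an error (values illustrative; `arrow-array`/`arrow-cast`/`arrow-schema` assumed as dependencies):

```rust
use arrow_array::{types::Int32Type, Array, ListArray};
use arrow_cast::cast;
use arrow_schema::{DataType, Field};
use std::sync::Arc;

fn main() {
    // [[1, 2], [3, 4], [5]]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3), Some(4)]),
        Some(vec![Some(5)]),
    ]);

    let to = DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2);
    let fixed = cast(&list, &to).unwrap();

    // The last list has length 1 != 2, so safe casting turns that slot into a null
    assert_eq!(fixed.null_count(), 1);
}
```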
+pub(crate) fn cast_list_values( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = cast_with_options(list.values(), to.data_type(), cast_options)?; + Ok(Arc::new(GenericListArray::::new( + to.clone(), + list.offsets().clone(), + values, + list.nulls().cloned(), + ))) +} + +/// Cast the container type of List/Largelist array along with the inner datatype +pub(crate) fn cast_list( + array: &dyn Array, + field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = list.values(); + let offsets = list.offsets(); + let nulls = list.nulls().cloned(); + + if !O::IS_LARGE && values.len() > i32::MAX as usize { + return Err(ArrowError::ComputeError( + "LargeList too large to cast to List".into(), + )); + } + + // Recursively cast values + let values = cast_with_options(values, field.data_type(), cast_options)?; + let offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect(); + + // Safety: valid offsets and checked for overflow + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Ok(Arc::new(GenericListArray::::new( + field.clone(), + offsets, + values, + nulls, + ))) +} diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast/mod.rs similarity index 91% rename from arrow-cast/src/cast.rs rename to arrow-cast/src/cast/mod.rs index 7868946532c4..52eb0d367271 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast/mod.rs @@ -37,6 +37,13 @@ //! assert_eq!(7.0, c.value(2)); //! ``` +mod decimal; +mod dictionary; +mod list; +use crate::cast::decimal::*; +use crate::cast::dictionary::*; +use crate::cast::list::*; + use chrono::{NaiveTime, Offset, TimeZone, Utc}; use std::cmp::Ordering; use std::sync::Arc; @@ -334,92 +341,6 @@ where Ok(Arc::new(array.with_precision_and_scale(precision, scale)?)) } -fn cast_floating_point_to_decimal128( - array: &PrimitiveArray, - precision: u8, - scale: i8, - cast_options: &CastOptions, -) -> Result -where - ::Native: AsPrimitive, -{ - let mul = 10_f64.powi(scale as i32); - - if cast_options.safe { - array - .unary_opt::<_, Decimal128Type>(|v| { - (mul * v.as_()) - .round() - .to_i128() - .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok()) - }) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } else { - array - .try_unary::<_, Decimal128Type, _>(|v| { - (mul * v.as_()) - .round() - .to_i128() - .ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - Decimal128Type::PREFIX, - precision, - scale, - v - )) - }) - .and_then(|v| { - Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) - }) - })? - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } -} - -fn cast_floating_point_to_decimal256( - array: &PrimitiveArray, - precision: u8, - scale: i8, - cast_options: &CastOptions, -) -> Result -where - ::Native: AsPrimitive, -{ - let mul = 10_f64.powi(scale as i32); - - if cast_options.safe { - array - .unary_opt::<_, Decimal256Type>(|v| { - i256::from_f64((v.as_() * mul).round()) - .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok()) - }) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } else { - array - .try_unary::<_, Decimal256Type, _>(|v| { - i256::from_f64((v.as_() * mul).round()) - .ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). 
Overflowing on {:?}", - Decimal256Type::PREFIX, - precision, - scale, - v - )) - }) - .and_then(|v| { - Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) - }) - })? - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } -} - /// Cast the array from interval year month to month day nano fn cast_interval_year_month_to_interval_month_day_nano( array: &dyn Array, @@ -549,79 +470,6 @@ fn cast_reinterpret_arrays().reinterpret_cast::())) } -fn cast_decimal_to_integer( - array: &dyn Array, - base: D::Native, - scale: i8, - cast_options: &CastOptions, -) -> Result -where - T: ArrowPrimitiveType, - ::Native: NumCast, - D: DecimalType + ArrowPrimitiveType, - ::Native: ArrowNativeTypeOp + ToPrimitive, -{ - let array = array.as_primitive::(); - - let div: D::Native = base.pow_checked(scale as u32).map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast to {:?}. The scale {} causes overflow.", - D::PREFIX, - scale, - )) - })?; - - let mut value_builder = PrimitiveBuilder::::with_capacity(array.len()); - - if cast_options.safe { - for i in 0..array.len() { - if array.is_null(i) { - value_builder.append_null(); - } else { - let v = array - .value(i) - .div_checked(div) - .ok() - .and_then(::from::); - - value_builder.append_option(v); - } - } - } else { - for i in 0..array.len() { - if array.is_null(i) { - value_builder.append_null(); - } else { - let v = array.value(i).div_checked(div)?; - - let value = ::from::(v).ok_or_else(|| { - ArrowError::CastError(format!( - "value of {:?} is out of range {}", - v, - T::DATA_TYPE - )) - })?; - - value_builder.append_value(value); - } - } - } - Ok(Arc::new(value_builder.finish())) -} - -// cast the decimal array to floating-point array -fn cast_decimal_to_float( - array: &dyn Array, - op: F, -) -> Result -where - F: Fn(D::Native) -> T::Native, -{ - let array = array.as_primitive::(); - let array = array.unary::<_, T>(op); - Ok(Arc::new(array)) -} - fn make_timestamp_array( array: &PrimitiveArray, unit: TimeUnit, @@ -2097,212 +1945,6 @@ const fn time_unit_multiple(unit: &TimeUnit) -> i64 { } } -/// A utility trait that provides checked conversions between -/// decimal types inspired by [`NumCast`] -trait DecimalCast: Sized { - fn to_i128(self) -> Option; - - fn to_i256(self) -> Option; - - fn from_decimal(n: T) -> Option; -} - -impl DecimalCast for i128 { - fn to_i128(self) -> Option { - Some(self) - } - - fn to_i256(self) -> Option { - Some(i256::from_i128(self)) - } - - fn from_decimal(n: T) -> Option { - n.to_i128() - } -} - -impl DecimalCast for i256 { - fn to_i128(self) -> Option { - self.to_i128() - } - - fn to_i256(self) -> Option { - Some(self) - } - - fn from_decimal(n: T) -> Option { - n.to_i256() - } -} - -fn cast_decimal_to_decimal_error( - output_precision: u8, - output_scale: i8, -) -> impl Fn(::Native) -> ArrowError -where - I: DecimalType, - O: DecimalType, - I::Native: DecimalCast + ArrowNativeTypeOp, - O::Native: DecimalCast + ArrowNativeTypeOp, -{ - move |x: I::Native| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). 
Overflowing on {:?}", - O::PREFIX, - output_precision, - output_scale, - x - )) - } -} - -fn convert_to_smaller_scale_decimal( - array: &PrimitiveArray, - input_scale: i8, - output_precision: u8, - output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> -where - I: DecimalType, - O: DecimalType, - I::Native: DecimalCast + ArrowNativeTypeOp, - O::Native: DecimalCast + ArrowNativeTypeOp, -{ - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let div = I::Native::from_decimal(10_i128) - .unwrap() - .pow_checked((input_scale - output_scale) as u32)?; - - let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); - let half_neg = half.neg_wrapping(); - - let f = |x: I::Native| { - // div is >= 10 and so this cannot overflow - let d = x.div_wrapping(div); - let r = x.mod_wrapping(div); - - // Round result - let adjusted = match x >= I::Native::ZERO { - true if r >= half => d.add_wrapping(I::Native::ONE), - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), - _ => d, - }; - O::Native::from_decimal(adjusted) - }; - - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, - }) -} - -fn convert_to_bigger_or_equal_scale_decimal( - array: &PrimitiveArray, - input_scale: i8, - output_precision: u8, - output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> -where - I: DecimalType, - O: DecimalType, - I::Native: DecimalCast + ArrowNativeTypeOp, - O::Native: DecimalCast + ArrowNativeTypeOp, -{ - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let mul = O::Native::from_decimal(10_i128) - .unwrap() - .pow_checked((output_scale - input_scale) as u32)?; - - let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - - Ok(match cast_options.safe { - true => array.unary_opt(f), - false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, - }) -} - -// Only support one type of decimal cast operations -fn cast_decimal_to_decimal_same_type( - array: &PrimitiveArray, - input_scale: i8, - output_precision: u8, - output_scale: i8, - cast_options: &CastOptions, -) -> Result -where - T: DecimalType, - T::Native: DecimalCast + ArrowNativeTypeOp, -{ - let array: PrimitiveArray = match input_scale.cmp(&output_scale) { - Ordering::Equal => { - // the scale doesn't change, the native value don't need to be changed - array.clone() - } - Ordering::Greater => convert_to_smaller_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )?, - Ordering::Less => { - // input_scale < output_scale - convert_to_bigger_or_equal_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )? - } - }; - - Ok(Arc::new(array.with_precision_and_scale( - output_precision, - output_scale, - )?)) -} - -// Support two different types of decimal cast operations -fn cast_decimal_to_decimal( - array: &PrimitiveArray, - input_scale: i8, - output_precision: u8, - output_scale: i8, - cast_options: &CastOptions, -) -> Result -where - I: DecimalType, - O: DecimalType, - I::Native: DecimalCast + ArrowNativeTypeOp, - O::Native: DecimalCast + ArrowNativeTypeOp, -{ - let array: PrimitiveArray = if input_scale > output_scale { - convert_to_smaller_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )? - } else { - convert_to_bigger_or_equal_scale_decimal::( - array, - input_scale, - output_precision, - output_scale, - cast_options, - )? 
- }; - - Ok(Arc::new(array.with_precision_and_scale( - output_precision, - output_scale, - )?)) -} - /// Convert Array into a PrimitiveArray of type, and apply numeric cast fn cast_numeric_arrays( from: &dyn Array, @@ -2615,196 +2257,6 @@ where Ok(Arc::new(output_array)) } -/// Parses given string to specified decimal native (i128/i256) based on given -/// scale. Returns an `Err` if it cannot parse given string. -fn parse_string_to_decimal_native( - value_str: &str, - scale: usize, -) -> Result -where - T::Native: DecimalCast + ArrowNativeTypeOp, -{ - let value_str = value_str.trim(); - let parts: Vec<&str> = value_str.split('.').collect(); - if parts.len() > 2 { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - let (negative, first_part) = if parts[0].is_empty() { - (false, parts[0]) - } else { - match parts[0].as_bytes()[0] { - b'-' => (true, &parts[0][1..]), - b'+' => (false, &parts[0][1..]), - _ => (false, parts[0]), - } - }; - - let integers = first_part.trim_start_matches('0'); - let decimals = if parts.len() == 2 { parts[1] } else { "" }; - - if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - // Adjust decimal based on scale - let mut number_decimals = if decimals.len() > scale { - let decimal_number = i256::from_string(decimals).ok_or_else(|| { - ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) - })?; - - let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; - - let half = div.div_wrapping(i256::from_i128(2)); - let half_neg = half.neg_wrapping(); - - let d = decimal_number.div_wrapping(div); - let r = decimal_number.mod_wrapping(div); - - // Round result - let adjusted = match decimal_number >= i256::ZERO { - true if r >= half => d.add_wrapping(i256::ONE), - false if r <= half_neg => d.sub_wrapping(i256::ONE), - _ => d, - }; - - let integers = if !integers.is_empty() { - i256::from_string(integers) - .ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Cannot parse decimal format: {value_str}" - )) - }) - .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? - } else { - i256::ZERO - }; - - format!("{}", integers.add_wrapping(adjusted)) - } else { - let padding = if scale > decimals.len() { scale } else { 0 }; - - let decimals = format!("{decimals:0( - from: &GenericStringArray, - precision: u8, - scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> -where - T: DecimalType, - T::Native: DecimalCast + ArrowNativeTypeOp, -{ - if cast_options.safe { - let iter = from.iter().map(|v| { - v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) - .and_then(|v| { - T::validate_decimal_precision(v, precision) - .is_ok() - .then_some(v) - }) - }); - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - Ok(unsafe { - PrimitiveArray::::from_trusted_len_iter(iter) - .with_precision_and_scale(precision, scale)? 
- }) - } else { - let vec = from - .iter() - .map(|v| { - v.map(|v| { - parse_string_to_decimal_native::(v, scale as usize) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) - .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) - }) - .transpose() - }) - .collect::, _>>()?; - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - Ok(unsafe { - PrimitiveArray::::from_trusted_len_iter(vec.iter()) - .with_precision_and_scale(precision, scale)? - }) - } -} - -/// Cast Utf8 to decimal -fn cast_string_to_decimal( - from: &dyn Array, - precision: u8, - scale: i8, - cast_options: &CastOptions, -) -> Result -where - T: DecimalType, - T::Native: DecimalCast + ArrowNativeTypeOp, -{ - if scale < 0 { - return Err(ArrowError::InvalidArgumentError(format!( - "Cannot cast string to decimal with negative scale {scale}" - ))); - } - - if scale > T::MAX_SCALE { - return Err(ArrowError::InvalidArgumentError(format!( - "Cannot cast string to decimal greater than maximum scale {}", - T::MAX_SCALE - ))); - } - - Ok(Arc::new(string_to_decimal_cast::( - from.as_any() - .downcast_ref::>() - .unwrap(), - precision, - scale, - cast_options, - )?)) -} - /// Cast numeric types to Boolean /// /// Any zero value returns `false` while non-zero returns `true` @@ -2873,208 +2325,6 @@ where unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } } -/// Attempts to cast an `ArrayDictionary` with index type K into -/// `to_type` for supported types. -/// -/// K is the key type -fn dictionary_cast( - array: &dyn Array, - to_type: &DataType, - cast_options: &CastOptions, -) -> Result { - use DataType::*; - - match to_type { - Dictionary(to_index_type, to_value_type) => { - let dict_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), - ) - })?; - - let keys_array: ArrayRef = - Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); - let values_array = dict_array.values(); - let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; - let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; - - // Failure to cast keys (because they don't fit in the - // target type) results in NULL values; - if cast_keys.null_count() > keys_array.null_count() { - return Err(ArrowError::ComputeError(format!( - "Could not convert {} dictionary indexes from {:?} to {:?}", - cast_keys.null_count() - keys_array.null_count(), - keys_array.data_type(), - to_index_type - ))); - } - - let data = cast_keys.into_data(); - let builder = data - .into_builder() - .data_type(to_type.clone()) - .child_data(vec![cast_values.into_data()]); - - // Safety - // Cast keys are still valid - let data = unsafe { builder.build_unchecked() }; - - // create the appropriate array type - let new_array: ArrayRef = match **to_index_type { - Int8 => Arc::new(DictionaryArray::::from(data)), - Int16 => Arc::new(DictionaryArray::::from(data)), - Int32 => Arc::new(DictionaryArray::::from(data)), - Int64 => Arc::new(DictionaryArray::::from(data)), - UInt8 => Arc::new(DictionaryArray::::from(data)), - UInt16 => Arc::new(DictionaryArray::::from(data)), - UInt32 => Arc::new(DictionaryArray::::from(data)), - UInt64 => Arc::new(DictionaryArray::::from(data)), - _ => { - return Err(ArrowError::CastError(format!( - "Unsupported type 
{to_index_type:?} for dictionary index" - ))); - } - }; - - Ok(new_array) - } - _ => unpack_dictionary::(array, to_type, cast_options), - } -} - -// Unpack a dictionary where the keys are of type into a flattened array of type to_type -fn unpack_dictionary( - array: &dyn Array, - to_type: &DataType, - cast_options: &CastOptions, -) -> Result -where - K: ArrowDictionaryKeyType, -{ - let dict_array = array.as_dictionary::(); - let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; - take(cast_dict_values.as_ref(), dict_array.keys(), None) -} - -/// Attempts to encode an array into an `ArrayDictionary` with index -/// type K and value (dictionary) type value_type -/// -/// K is the key type -fn cast_to_dictionary( - array: &dyn Array, - dict_value_type: &DataType, - cast_options: &CastOptions, -) -> Result { - use DataType::*; - - match *dict_value_type { - Int8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Int16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Int32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Int64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - UInt8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Decimal128(_, _) => { - pack_numeric_to_dictionary::(array, dict_value_type, cast_options) - } - Decimal256(_, _) => { - pack_numeric_to_dictionary::(array, dict_value_type, cast_options) - } - Utf8 => pack_byte_to_dictionary::>(array, cast_options), - LargeUtf8 => pack_byte_to_dictionary::>(array, cast_options), - Binary => pack_byte_to_dictionary::>(array, cast_options), - LargeBinary => pack_byte_to_dictionary::>(array, cast_options), - _ => Err(ArrowError::CastError(format!( - "Unsupported output type for dictionary packing: {dict_value_type:?}" - ))), - } -} - -// Packs the data from the primitive array of type to a -// DictionaryArray with keys of type K and values of value_type V -fn pack_numeric_to_dictionary( - array: &dyn Array, - dict_value_type: &DataType, - cast_options: &CastOptions, -) -> Result -where - K: ArrowDictionaryKeyType, - V: ArrowPrimitiveType, -{ - // attempt to cast the source array values to the target value type (the dictionary values type) - let cast_values = cast_with_options(array, dict_value_type, cast_options)?; - let values = cast_values.as_primitive::(); - - let mut b = PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); - - // copy each element one at a time - for i in 0..values.len() { - if values.is_null(i) { - b.append_null(); - } else { - b.append(values.value(i))?; - } - } - Ok(Arc::new(b.finish())) -} - -// Packs the data as a GenericByteDictionaryBuilder, if possible, with the -// key types of K -fn pack_byte_to_dictionary( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result -where - K: ArrowDictionaryKeyType, - T: ByteArrayType, -{ - let cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; - let values = cast_values - .as_any() - .downcast_ref::>() - .unwrap(); - let mut b = GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); - - // copy each element one at a time - for i in 0..values.len() { - if values.is_null(i) { - b.append_null(); - } else { - 
b.append(values.value(i))?; - } - } - Ok(Arc::new(b.finish())) -} - -/// Helper function that takes a primitive array and casts to a (generic) list array. -fn cast_values_to_list( - array: &dyn Array, - to: &FieldRef, - cast_options: &CastOptions, -) -> Result { - let values = cast_with_options(array, to.data_type(), cast_options)?; - let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); - let list = GenericListArray::::new(to.clone(), offsets, values, None); - Ok(Arc::new(list)) -} - -/// Helper function that takes a primitive array and casts to a fixed size list array. -fn cast_values_to_fixed_size_list( - array: &dyn Array, - to: &FieldRef, - size: i32, - cast_options: &CastOptions, -) -> Result { - let values = cast_with_options(array, to.data_type(), cast_options)?; - let list = FixedSizeListArray::new(to.clone(), size, values, None); - Ok(Arc::new(list)) -} - /// A specified helper to cast from `GenericBinaryArray` to `GenericStringArray` when they have same /// offset size so re-encoding offset is unnecessary. fn cast_binary_to_string( @@ -3217,133 +2467,6 @@ where Ok(Arc::new(GenericByteArray::::from(array_data))) } -fn cast_fixed_size_list_to_list(array: &dyn Array) -> Result -where - OffsetSize: OffsetSizeTrait, -{ - let fixed_size_list: &FixedSizeListArray = array.as_fixed_size_list(); - let list: GenericListArray = fixed_size_list.clone().into(); - Ok(Arc::new(list)) -} - -fn cast_list_to_fixed_size_list( - array: &GenericListArray, - field: &FieldRef, - size: i32, - cast_options: &CastOptions, -) -> Result -where - OffsetSize: OffsetSizeTrait, -{ - let cap = array.len() * size as usize; - - let mut nulls = (cast_options.safe || array.null_count() != 0).then(|| { - let mut buffer = BooleanBufferBuilder::new(array.len()); - match array.nulls() { - Some(n) => buffer.append_buffer(n.inner()), - None => buffer.append_n(array.len(), true), - } - buffer - }); - - // Nulls in FixedSizeListArray take up space and so we must pad the values - let values = array.values().to_data(); - let mut mutable = MutableArrayData::new(vec![&values], cast_options.safe, cap); - // The end position in values of the last incorrectly-sized list slice - let mut last_pos = 0; - for (idx, w) in array.offsets().windows(2).enumerate() { - let start_pos = w[0].as_usize(); - let end_pos = w[1].as_usize(); - let len = end_pos - start_pos; - - if len != size as usize { - if cast_options.safe || array.is_null(idx) { - if last_pos != start_pos { - // Extend with valid slices - mutable.extend(0, last_pos, start_pos); - } - // Pad this slice with nulls - mutable.extend_nulls(size as _); - nulls.as_mut().unwrap().set_bit(idx, false); - // Set last_pos to the end of this slice's values - last_pos = end_pos - } else { - return Err(ArrowError::CastError(format!( - "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}", - ))); - } - } - } - - let values = match last_pos { - 0 => array.values().slice(0, cap), // All slices were the correct length - _ => { - if mutable.len() != cap { - // Remaining slices were all correct length - let remaining = cap - mutable.len(); - mutable.extend(0, last_pos, last_pos + remaining) - } - make_array(mutable.freeze()) - } - }; - - // Cast the inner values if necessary - let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?; - - // Construct the FixedSizeListArray - let nulls = nulls.map(|mut x| x.finish().into()); - let array = FixedSizeListArray::new(field.clone(), size, values, nulls); - Ok(Arc::new(array)) -} 
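For orientation, the list-cast behaviour backed by the helpers removed around this hunk is reachable through the public `cast_with_options` entry point. The sketch below is illustrative only and assumes the standard `arrow_array` constructors, which are not part of this diff:

```rust
// Hedged sketch: cast a List<Int32> whose slices all have length 2 into a
// FixedSizeList of size 2 via the public arrow_cast API.
use std::sync::Arc;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ListArray};
use arrow_cast::{cast_with_options, CastOptions};
use arrow_schema::{ArrowError, DataType, Field};

fn list_to_fixed_size_list() -> Result<(), ArrowError> {
    // [[1, 2], [3, 4]]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
        Some(vec![Some(3), Some(4)]),
    ]);
    let to = DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2);
    let fixed = cast_with_options(&list, &to, &CastOptions::default())?;
    assert_eq!(fixed.len(), 2);
    Ok(())
}
```

With the default `safe: true` options, a slice whose length does not match the target size becomes a null entry instead of an error, mirroring the padding logic in the removed `cast_list_to_fixed_size_list`.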
- -/// Helper function that takes an Generic list container and casts the inner datatype. -fn cast_list_values( - array: &dyn Array, - to: &FieldRef, - cast_options: &CastOptions, -) -> Result { - let list = array.as_list::(); - let values = cast_with_options(list.values(), to.data_type(), cast_options)?; - Ok(Arc::new(GenericListArray::::new( - to.clone(), - list.offsets().clone(), - values, - list.nulls().cloned(), - ))) -} - -/// Cast the container type of List/Largelist array along with the inner datatype -fn cast_list( - array: &dyn Array, - field: &FieldRef, - cast_options: &CastOptions, -) -> Result { - let list = array.as_list::(); - let values = list.values(); - let offsets = list.offsets(); - let nulls = list.nulls().cloned(); - - if !O::IS_LARGE && values.len() > i32::MAX as usize { - return Err(ArrowError::ComputeError( - "LargeList too large to cast to List".into(), - )); - } - - // Recursively cast values - let values = cast_with_options(values, field.data_type(), cast_options)?; - let offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect(); - - // Safety: valid offsets and checked for overflow - let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; - - Ok(Arc::new(GenericListArray::::new( - field.clone(), - offsets, - values, - nulls, - ))) -} - #[cfg(test)] mod tests { use arrow_buffer::{Buffer, NullBuffer}; diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 83c8965fdf8a..5e0530289623 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -1877,7 +1877,7 @@ mod tests { #[test] fn test_bounded() { let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); - let data = vec![ + let data = [ vec!["0"], vec!["1"], vec!["2"], @@ -1919,7 +1919,7 @@ mod tests { #[test] fn test_empty_projection() { let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); - let data = vec![vec!["0"], vec!["1"]]; + let data = [vec!["0"], vec!["1"]]; let data = data .iter() diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index efd8b6dec90f..a8f8d1606506 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -16,6 +16,7 @@ // under the License. 
use arrow_flight::sql::server::PeekableFlightDataStream; +use arrow_flight::sql::DoPutPreparedStatementResult; use base64::prelude::BASE64_STANDARD; use base64::Engine; use futures::{stream, Stream, TryStreamExt}; @@ -272,7 +273,7 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { - ticket: query.encode_to_vec().into(), + ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); @@ -292,7 +293,7 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { - ticket: query.encode_to_vec().into(), + ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); @@ -312,7 +313,7 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result, Status> { let flight_descriptor = request.into_inner(); let ticket = Ticket { - ticket: query.encode_to_vec().into(), + ticket: query.as_any().encode_to_vec().into(), }; let endpoint = FlightEndpoint::new().with_ticket(ticket); @@ -341,7 +342,7 @@ impl FlightSqlService for FlightSqlServiceImpl { request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); - let ticket = Ticket::new(query.encode_to_vec()); + let ticket = Ticket::new(query.as_any().encode_to_vec()); let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() @@ -399,7 +400,7 @@ impl FlightSqlService for FlightSqlServiceImpl { request: Request, ) -> Result, Status> { let flight_descriptor = request.into_inner(); - let ticket = Ticket::new(query.encode_to_vec()); + let ticket = Ticket::new(query.as_any().encode_to_vec()); let endpoint = FlightEndpoint::new().with_ticket(ticket); let flight_info = FlightInfo::new() @@ -619,7 +620,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, _request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", )) diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index 2b2f4af7ac90..01ea9b61a8f7 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -808,6 +808,25 @@ pub struct DoPutUpdateResult { #[prost(int64, tag = "1")] pub record_count: i64, } +/// An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. +/// +/// *Note on legacy behavior*: previous versions of the protocol did not return any result for +/// this command, and that behavior should still be supported by clients. In that case, the client +/// can continue as though the fields in this message were not provided or set to sensible default values. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DoPutPreparedStatementResult { + /// Represents a (potentially updated) opaque handle for the prepared statement on the server. + /// Because the handle could potentially be updated, any previous handles for this prepared + /// statement should be considered invalid, and all subsequent requests for this prepared + /// statement must use this new handle. + /// The updated handle allows implementing query parameters with stateless services. 
+ /// + /// When an updated handle is not provided by the server, clients should continue + /// using the previous handle provided by `ActionCreatePreparedStatementResponse`. + #[prost(bytes = "bytes", optional, tag = "1")] + pub prepared_statement_handle: ::core::option::Option<::prost::bytes::Bytes>, +} /// /// Request message for the "CancelQuery" action. /// diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index a014137f6fa9..44250fbe63e2 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -35,7 +35,8 @@ use crate::sql::{ CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, + CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, + SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ @@ -501,6 +502,7 @@ impl PreparedStatement { } /// Submit parameters to the server, if any have been set on this prepared statement instance + /// Updates our stored prepared statement handle with the handle given by the server response. async fn write_bind_params(&mut self) -> Result<(), ArrowError> { if let Some(ref params_batch) = self.parameter_binding { let cmd = CommandPreparedStatementQuery { @@ -519,17 +521,38 @@ impl PreparedStatement { .await .map_err(flight_error_to_arrow_error)?; - self.flight_sql_client + // Attempt to update the stored handle with any updated handle in the DoPut result. + // Older servers do not respond with a result for DoPut, so skip this step when + // the stream closes with no response. + if let Some(result) = self + .flight_sql_client .do_put(stream::iter(flight_data)) .await? - .try_collect::<Vec<_>>() + .message() .await - .map_err(status_to_arrow_error)?; + .map_err(status_to_arrow_error)? + { + if let Some(handle) = self.unpack_prepared_statement_handle(&result)? { + self.handle = handle; + } + } } - Ok(()) } + /// Decodes the app_metadata stored in a [`PutResult`] as a + /// [`DoPutPreparedStatementResult`] and then returns + /// the inner prepared statement handle as [`Bytes`] + fn unpack_prepared_statement_handle( + &self, + put_result: &PutResult, + ) -> Result<Option<Bytes>, ArrowError> { + let any = Any::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?; + Ok(any + .unpack::<DoPutPreparedStatementResult>()? + .and_then(|result| result.prepared_statement_handle)) + } + /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. 
pub async fn close(mut self) -> Result<(), ArrowError> { diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 97645ae7840d..089ee4dd8c3e 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -75,6 +75,7 @@ pub use gen::CommandPreparedStatementUpdate; pub use gen::CommandStatementQuery; pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; +pub use gen::DoPutPreparedStatementResult; pub use gen::DoPutUpdateResult; pub use gen::Nullable; pub use gen::Searchable; @@ -251,6 +252,7 @@ prost_message_ext!( CommandStatementSubstraitPlan, CommandStatementUpdate, DoPutUpdateResult, + DoPutPreparedStatementResult, TicketStatementQuery, ); diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 0431e58111a4..c18024cf068a 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -33,7 +33,8 @@ use super::{ CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery, + DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, + TicketStatementQuery, }; use crate::{ flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, @@ -397,11 +398,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { } /// Bind parameters to given prepared statement. + /// + /// Returns an opaque handle that the client should pass + /// back to the server during subsequent requests with this + /// prepared statement. async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, _request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_query has no default implementation", )) @@ -709,7 +714,13 @@ where Ok(Response::new(Box::pin(output))) } Command::CommandPreparedStatementQuery(command) => { - self.do_put_prepared_statement_query(command, request).await + let result = self + .do_put_prepared_statement_query(command, request) + .await?; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) } Command::CommandStatementSubstraitPlan(command) => { let record_count = self.do_put_substrait_plan(command, request).await?; diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index cc270eeb6186..50a4ec0d8c66 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -32,17 +32,18 @@ use arrow_flight::{ CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, - CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery, + CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, + TicketStatementQuery, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket, + HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use 
arrow_schema::{ArrowError, DataType, Field, Schema}; use assert_cmd::Command; use bytes::Bytes; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, TryStreamExt}; use prost::Message; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; @@ -51,7 +52,7 @@ const QUERY: &str = "SELECT * FROM table;"; #[tokio::test] async fn test_simple() { - let test_server = FlightSqlServiceImpl {}; + let test_server = FlightSqlServiceImpl::default(); let fixture = TestFixture::new(&test_server).await; let addr = fixture.addr; @@ -92,10 +93,9 @@ async fn test_simple() { const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; +const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; -#[tokio::test] -async fn test_do_put_prepared_statement() { - let test_server = FlightSqlServiceImpl {}; +async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { let fixture = TestFixture::new(&test_server).await; let addr = fixture.addr; @@ -136,11 +136,40 @@ async fn test_do_put_prepared_statement() { ); } +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateless() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: true, + }) + .await +} + +#[tokio::test] +pub async fn test_do_put_prepared_statement_stateful() { + test_do_put_prepared_statement(FlightSqlServiceImpl { + stateless_prepared_statements: false, + }) + .await +} + /// All tests must complete within this many seconds or else the test server is shutdown const DEFAULT_TIMEOUT_SECONDS: u64 = 30; -#[derive(Clone, Default)] -pub struct FlightSqlServiceImpl {} +#[derive(Clone)] +pub struct FlightSqlServiceImpl { + /// Whether to emulate stateless (true) or stateful (false) behavior for + /// prepared statements. 
stateful servers will not return an updated + /// handle after executing `DoPut(CommandPreparedStatementQuery)` + stateless_prepared_statements: bool, +} + +impl Default for FlightSqlServiceImpl { + fn default() -> Self { + Self { + stateless_prepared_statements: true, + } + } +} impl FlightSqlServiceImpl { /// Return an [`FlightServiceServer`] that can be used with a @@ -274,10 +303,17 @@ impl FlightSqlService for FlightSqlServiceImpl { cmd: CommandPreparedStatementQuery, _request: Request, ) -> Result, Status> { - assert_eq!( - cmd.prepared_statement_handle, - PREPARED_STATEMENT_HANDLE.as_bytes() - ); + if self.stateless_prepared_statements { + assert_eq!( + cmd.prepared_statement_handle, + UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } else { + assert_eq!( + cmd.prepared_statement_handle, + PREPARED_STATEMENT_HANDLE.as_bytes() + ); + } let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } @@ -524,7 +560,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, request: Request, - ) -> Result::DoPutStream>, Status> { + ) -> Result { // just make sure decoding the parameters works let parameters = FlightRecordBatchStream::new_from_flight_data( request.into_inner().map_err(|e| e.into()), @@ -543,10 +579,15 @@ impl FlightSqlService for FlightSqlServiceImpl { ))); } } - - Ok(Response::new( - futures::stream::once(async { Ok(PutResult::default()) }).boxed(), - )) + let handle = if self.stateless_prepared_statements { + UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into() + } else { + PREPARED_STATEMENT_HANDLE.to_string().into() + }; + let result = DoPutPreparedStatementResult { + prepared_statement_handle: Some(handle), + }; + Ok(result) } async fn do_put_prepared_statement_update( diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs index 88cdad64f92f..a683075990c7 100644 --- a/arrow-integration-testing/tests/ipc_reader.rs +++ b/arrow-integration-testing/tests/ipc_reader.rs @@ -19,10 +19,12 @@ //! 
in `testing/arrow-ipc-stream/integration/...` use arrow::error::ArrowError; -use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::ipc::reader::{FileReader, StreamDecoder, StreamReader}; use arrow::util::test_util::arrow_test_data; +use arrow_buffer::Buffer; use arrow_integration_testing::read_gzip_json; use std::fs::File; +use std::io::Read; #[test] fn read_0_1_4() { @@ -182,18 +184,45 @@ fn verify_arrow_stream(testdata: &str, version: &str, path: &str) { let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); println!("Verifying {filename}"); + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + // Compare contents to the expected output format in JSON { println!(" verifying content"); let file = File::open(&filename).unwrap(); let mut reader = StreamReader::try_new(file, None).unwrap(); - // read expected JSON output - let arrow_json = read_gzip_json(version, path); assert!(arrow_json.equals_reader(&mut reader).unwrap()); // the next batch must be empty assert!(reader.next().is_none()); // the stream must indicate that it's finished assert!(reader.is_finished()); } + + // Test stream decoder + let expected = arrow_json.get_record_batches().unwrap(); + for chunk_sizes in [1, 2, 8, 123] { + let mut decoder = StreamDecoder::new(); + let stream = chunked_file(&filename, chunk_sizes); + let mut actual = Vec::with_capacity(expected.len()); + for mut x in stream { + while !x.is_empty() { + if let Some(x) = decoder.decode(&mut x).unwrap() { + actual.push(x); + } + } + } + decoder.finish().unwrap(); + assert_eq!(expected, actual); + } +} + +fn chunked_file(filename: &str, chunk_size: u64) -> impl Iterator { + let mut file = File::open(filename).unwrap(); + std::iter::from_fn(move || { + let mut buf = vec![]; + let read = (&mut file).take(chunk_size).read_to_end(&mut buf).unwrap(); + (read != 0).then(|| Buffer::from_vec(buf)) + }) } diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index d0b9145c7f27..49da0efae3a6 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -17,12 +17,17 @@ //! Utilities for converting between IPC types and native Arrow types +use arrow_buffer::Buffer; use arrow_schema::*; -use flatbuffers::{FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset}; +use flatbuffers::{ + FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier, + VerifierOptions, WIPOffset, +}; use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use crate::{size_prefixed_root_as_message, KeyValue, CONTINUATION_MARKER}; +use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER}; use DataType::*; /// Serialize a schema in IPC format @@ -817,6 +822,45 @@ pub(crate) fn get_fb_dictionary<'a>( builder.finish() } +/// An owned container for a validated [`Message`] +/// +/// Safely decoding a flatbuffer requires validating the various embedded offsets, +/// see [`Verifier`]. This is a potentially expensive operation, and it is therefore desirable +/// to only do this once. [`crate::root_as_message`] performs this validation on construction, +/// however, it returns a [`Message`] borrowing the provided byte slice. This prevents +/// storing this [`Message`] in the same data structure that owns the buffer, as this +/// would require self-referential borrows. +/// +/// [`MessageBuffer`] solves this problem by providing a safe API for a [`Message`] +/// without a lifetime bound. 
+#[derive(Clone)] +pub struct MessageBuffer(Buffer); + +impl Debug for MessageBuffer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl MessageBuffer { + /// Try to create a [`MessageBuffer`] from the provided [`Buffer`] + pub fn try_new(buf: Buffer) -> Result { + let opts = VerifierOptions::default(); + let mut v = Verifier::new(&opts, &buf); + >::run_verifier(&mut v, 0).map_err(|err| { + ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) + })?; + Ok(Self(buf)) + } + + /// Return the [`Message`] + #[inline] + pub fn as_ref(&self) -> Message<'_> { + // SAFETY: Run verifier on construction + unsafe { crate::root_as_message_unchecked(&self.0) } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index da5adf8d8f2c..4591777c1e37 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -20,6 +20,10 @@ //! The `FileReader` and `StreamReader` have similar interfaces, //! however the `FileReader` expects a reader that supports `Seek`ing +mod stream; + +pub use stream::*; + use flatbuffers::{VectorIter, VerifierOptions}; use std::collections::{HashMap, VecDeque}; use std::fmt; diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs new file mode 100644 index 000000000000..7807228175ac --- /dev/null +++ b/arrow-ipc/src/reader/stream.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
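The decoder defined in this new file keeps a validated `MessageBuffer` (added to `arrow-ipc/src/convert.rs` above) in its `Body` state. A minimal standalone sketch, assuming the type remains exported as `arrow_ipc::convert::MessageBuffer`:

```rust
// Hedged sketch: run the flatbuffer verifier once in try_new, keep the owned
// MessageBuffer, and borrow Message views from it without re-validating.
use arrow_buffer::Buffer;
use arrow_ipc::convert::MessageBuffer;
use arrow_schema::ArrowError;

fn body_length(message_bytes: Buffer) -> Result<i64, ArrowError> {
    let owned = MessageBuffer::try_new(message_bytes)?; // validation happens here, once
    let message = owned.as_ref(); // cheap borrow of the verified flatbuffer
    Ok(message.bodyLength())
}
```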
+ +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_schema::{ArrowError, SchemaRef}; + +use crate::convert::MessageBuffer; +use crate::reader::{read_dictionary, read_record_batch}; +use crate::{MessageHeader, CONTINUATION_MARKER}; + +/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes +/// +/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface +#[derive(Debug, Default)] +pub struct StreamDecoder { + /// The schema of this decoder, if read + schema: Option, + /// Lookup table for dictionaries by ID + dictionaries: HashMap, + /// The decoder state + state: DecoderState, + /// A scratch buffer when a read is split across multiple `Buffer` + buf: MutableBuffer, +} + +#[derive(Debug)] +enum DecoderState { + /// Decoding the message header + Header { + /// Temporary buffer + buf: [u8; 4], + /// Number of bytes read into buf + read: u8, + /// If we have read a continuation token + continuation: bool, + }, + /// Decoding the message flatbuffer + Message { + /// The size of the message flatbuffer + size: u32, + }, + /// Decoding the message body + Body { + /// The message flatbuffer + message: MessageBuffer, + }, + /// Reached the end of the stream + Finished, +} + +impl Default for DecoderState { + fn default() -> Self { + Self::Header { + buf: [0; 4], + read: 0, + continuation: false, + } + } +} + +impl StreamDecoder { + /// Create a new [`StreamDecoder`] + pub fn new() -> Self { + Self::default() + } + + /// Try to read the next [`RecordBatch`] from the provided [`Buffer`] + /// + /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes. + /// + /// The push-based interface facilitates integration with sources that yield arbitrarily + /// delimited bytes ranges, such as a chunked byte stream received from object storage + /// + /// ``` + /// # use arrow_array::RecordBatch; + /// # use arrow_buffer::Buffer; + /// # use arrow_ipc::reader::StreamDecoder; + /// # use arrow_schema::ArrowError; + /// # + /// fn print_stream(src: impl Iterator) -> Result<(), ArrowError> { + /// let mut decoder = StreamDecoder::new(); + /// for mut x in src { + /// while !x.is_empty() { + /// if let Some(x) = decoder.decode(&mut x)? 
{ + /// println!("{x:?}"); + /// } + /// } + /// } + /// decoder.finish().unwrap(); + /// Ok(()) + /// } + /// ``` + pub fn decode(&mut self, buffer: &mut Buffer) -> Result, ArrowError> { + while !buffer.is_empty() { + match &mut self.state { + DecoderState::Header { + buf, + read, + continuation, + } => { + let offset_buf = &mut buf[*read as usize..]; + let to_read = buffer.len().min(offset_buf.len()); + offset_buf[..to_read].copy_from_slice(&buffer[..to_read]); + *read += to_read as u8; + buffer.advance(to_read); + if *read == 4 { + if !*continuation && buf == &CONTINUATION_MARKER { + *continuation = true; + *read = 0; + continue; + } + let size = u32::from_le_bytes(*buf); + + if size == 0 { + self.state = DecoderState::Finished; + continue; + } + self.state = DecoderState::Message { size }; + } + } + DecoderState::Message { size } => { + let len = *size as usize; + if self.buf.is_empty() && buffer.len() > len { + let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?; + self.state = DecoderState::Body { message }; + buffer.advance(len); + continue; + } + + let to_read = buffer.len().min(len - self.buf.len()); + self.buf.extend_from_slice(&buffer[..to_read]); + buffer.advance(to_read); + if self.buf.len() == len { + let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?; + self.state = DecoderState::Body { message }; + } + } + DecoderState::Body { message } => { + let message = message.as_ref(); + let body_length = message.bodyLength() as usize; + + let body = if self.buf.is_empty() && buffer.len() >= body_length { + let body = buffer.slice_with_length(0, body_length); + buffer.advance(body_length); + body + } else { + let to_read = buffer.len().min(body_length - self.buf.len()); + self.buf.extend_from_slice(&buffer[..to_read]); + buffer.advance(to_read); + + if self.buf.len() != body_length { + continue; + } + std::mem::take(&mut self.buf).into() + }; + + let version = message.version(); + match message.header_type() { + MessageHeader::Schema => { + if self.schema.is_some() { + return Err(ArrowError::IpcError( + "Not expecting a schema when messages are read".to_string(), + )); + } + + let ipc_schema = message.header_as_schema().unwrap(); + let schema = crate::convert::fb_to_schema(ipc_schema); + self.state = DecoderState::default(); + self.schema = Some(Arc::new(schema)); + } + MessageHeader::RecordBatch => { + let batch = message.header_as_record_batch().unwrap(); + let schema = self.schema.clone().ok_or_else(|| { + ArrowError::IpcError("Missing schema".to_string()) + })?; + let batch = read_record_batch( + &body, + batch, + schema, + &self.dictionaries, + None, + &version, + )?; + self.state = DecoderState::default(); + return Ok(Some(batch)); + } + MessageHeader::DictionaryBatch => { + let dictionary = message.header_as_dictionary_batch().unwrap(); + let schema = self.schema.as_deref().ok_or_else(|| { + ArrowError::IpcError("Missing schema".to_string()) + })?; + read_dictionary( + &body, + dictionary, + schema, + &mut self.dictionaries, + &version, + )?; + self.state = DecoderState::default(); + } + MessageHeader::NONE => { + self.state = DecoderState::default(); + } + t => { + return Err(ArrowError::IpcError(format!( + "Message type unsupported by StreamDecoder: {t:?}" + ))) + } + } + } + DecoderState::Finished => { + return Err(ArrowError::IpcError("Unexpected EOS".to_string())) + } + } + } + Ok(None) + } + + /// Signal the end of stream + /// + /// Returns an error if any partial data remains in the stream + pub fn finish(&mut self) -> 
Result<(), ArrowError> { + match self.state { + DecoderState::Finished + | DecoderState::Header { + read: 0, + continuation: false, + .. + } => Ok(()), + _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::writer::StreamWriter; + use arrow_array::{Int32Array, Int64Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + + // Further tests in arrow-integration-testing/tests/ipc_reader.rs + + #[test] + fn test_eos() { + let schema = Arc::new(Schema::new(vec![ + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + ])); + + let input = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + Arc::new(Int64Array::from(vec![1, 2, 3])) as _, + ], + ) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap(); + s.write(&input).unwrap(); + s.finish().unwrap(); + drop(s); + + let buffer = Buffer::from_vec(buf); + + let mut b = buffer.slice_with_length(0, buffer.len() - 1); + let mut decoder = StreamDecoder::new(); + let output = decoder.decode(&mut b).unwrap().unwrap(); + assert_eq!(output, input); + assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte + assert!(decoder.decode(&mut b).unwrap().is_none()); + + let err = decoder.finish().unwrap_err().to_string(); + assert_eq!(err, "Ipc error: Unexpected End of Stream"); + } +} diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 265d6be1a503..2a3474fe0fc6 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -721,6 +721,7 @@ impl DictionaryTracker { } } +/// Writer for an IPC file pub struct FileWriter { /// The object to write to writer: BufWriter, @@ -745,13 +746,13 @@ pub struct FileWriter { } impl FileWriter { - /// Try create a new writer, with the schema written as part of the header + /// Try to create a new writer, with the schema written as part of the header pub fn try_new(writer: W, schema: &Schema) -> Result { let write_options = IpcWriteOptions::default(); Self::try_new_with_options(writer, schema, write_options) } - /// Try create a new writer with IpcWriteOptions + /// Try to create a new writer with IpcWriteOptions pub fn try_new_with_options( writer: W, schema: &Schema, @@ -901,6 +902,7 @@ impl RecordBatchWriter for FileWriter { } } +/// Writer for an IPC stream pub struct StreamWriter { /// The object to write to writer: BufWriter, @@ -915,7 +917,7 @@ pub struct StreamWriter { } impl StreamWriter { - /// Try create a new writer, with the schema written as part of the header + /// Try to create a new writer, with the schema written as part of the header pub fn try_new(writer: W, schema: &Schema) -> Result { let write_options = IpcWriteOptions::default(); Self::try_new_with_options(writer, schema, write_options) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 99055573345a..628e5c96693d 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -416,7 +416,7 @@ impl Decoder { /// should be included in the next call to [`Self::decode`] /// /// There is no requirement that `buf` contains a whole number of records, facilitating - /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] + /// integration with arbitrary byte streams, such as those yielded by [`BufRead`] pub fn decode(&mut self, buf: &[u8]) -> Result { self.tape_decoder.decode(buf) } diff --git a/arrow-row/src/lib.rs 
b/arrow-row/src/lib.rs index c2f5293f94c8..037ed404adca 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -2223,7 +2223,7 @@ mod tests { let first = Int32Array::from(vec![None, Some(2), Some(4)]); let second = Int32Array::from(vec![Some(2), None, Some(4)]); - let arrays = vec![Arc::new(first) as ArrayRef, Arc::new(second) as ArrayRef]; + let arrays = [Arc::new(first) as ArrayRef, Arc::new(second) as ArrayRef]; for array in arrays.iter() { rows.clear(); diff --git a/format/FlightSql.proto b/format/FlightSql.proto index f78e77e23278..4fc68f2a5db0 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -1796,7 +1796,27 @@ // an unknown updated record count. int64 record_count = 1; } - + + /* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`. + * + * *Note on legacy behavior*: previous versions of the protocol did not return any result for + * this command, and that behavior should still be supported by clients. In that case, the client + * can continue as though the fields in this message were not provided or set to sensible default values. + */ + message DoPutPreparedStatementResult { + option (experimental) = true; + + // Represents a (potentially updated) opaque handle for the prepared statement on the server. + // Because the handle could potentially be updated, any previous handles for this prepared + // statement should be considered invalid, and all subsequent requests for this prepared + // statement must use this new handle. + // The updated handle allows implementing query parameters with stateless services. + // + // When an updated handle is not provided by the server, clients should continue + // using the previous handle provided by `ActionCreatePreparedStatementResponse`. + optional bytes prepared_statement_handle = 1; + } + /* * Request message for the "CancelQuery" action. 
* diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index a1e80ce51ded..79813a0ea1fb 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -45,12 +45,12 @@ walkdir = "2" # Cloud storage support base64 = { version = "0.22", default-features = false, features = ["std"], optional = true } -hyper = { version = "0.14", default-features = false, optional = true } +hyper = { version = "1.2", default-features = false, optional = true } quick-xml = { version = "0.31.0", features = ["serialize", "overlapped-lists"], optional = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } serde_json = { version = "1.0", default-features = false, optional = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } -reqwest = { version = "0.11", default-features = false, features = ["rustls-tls-native-roots"], optional = true } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2"], optional = true } ring = { version = "0.17", default-features = false, features = ["std"], optional = true } rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true } tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-util"] } @@ -69,7 +69,9 @@ tls-webpki-roots = ["reqwest?/rustls-tls-webpki-roots"] [dev-dependencies] # In alphabetical order futures-test = "0.3" -hyper = { version = "0.14.24", features = ["server"] } +hyper = { version = "1.2", features = ["server"] } +hyper-util = "0.1" +http-body-util = "0.1" rand = "0.8" tempfile = "3.1.0" diff --git a/object_store/src/aws/builder.rs b/object_store/src/aws/builder.rs index a578d1abbfae..664e18364600 100644 --- a/object_store/src/aws/builder.rs +++ b/object_store/src/aws/builder.rs @@ -1333,10 +1333,7 @@ mod tests { .unwrap_err() .to_string(); - assert_eq!( - "Generic HTTP client error: builder error: unknown proxy scheme", - err - ); + assert_eq!("Generic HTTP client error: builder error", err); } #[test] diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index aa9f6bf3320c..4d101456fd16 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -138,6 +138,7 @@ struct BatchDeleteResponse { #[derive(Deserialize)] enum DeleteObjectResult { + #[allow(unused)] Deleted(DeletedObject), Error(DeleteError), } diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index f8614f4f563c..dd7fa5b41da3 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -738,7 +738,7 @@ struct CreateSessionOutput { mod tests { use super::*; use crate::client::mock_server::MockServer; - use hyper::{Body, Response}; + use hyper::Response; use reqwest::{Client, Method}; use std::env; @@ -939,7 +939,7 @@ mod tests { #[tokio::test] async fn test_mock() { - let server = MockServer::new(); + let server = MockServer::new().await; const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token"; @@ -955,7 +955,7 @@ mod tests { server.push_fn(|req| { assert_eq!(req.uri().path(), "/latest/api/token"); assert_eq!(req.method(), &Method::PUT); - Response::new(Body::from("cupcakes")) + Response::new("cupcakes".to_string()) }); server.push_fn(|req| { assert_eq!( @@ -965,14 +965,14 @@ mod tests { assert_eq!(req.method(), &Method::GET); let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); assert_eq!(t, "cupcakes"); - Response::new(Body::from("myrole")) + 
Response::new("myrole".to_string()) }); server.push_fn(|req| { assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); assert_eq!(req.method(), &Method::GET); let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); assert_eq!(t, "cupcakes"); - Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string()) }); let creds = instance_creds(&client, &retry_config, endpoint, true) @@ -989,7 +989,7 @@ mod tests { assert_eq!(req.method(), &Method::PUT); Response::builder() .status(StatusCode::FORBIDDEN) - .body(Body::empty()) + .body(String::new()) .unwrap() }); server.push_fn(|req| { @@ -999,13 +999,13 @@ mod tests { ); assert_eq!(req.method(), &Method::GET); assert!(req.headers().get(IMDSV2_HEADER).is_none()); - Response::new(Body::from("myrole")) + Response::new("myrole".to_string()) }); server.push_fn(|req| { assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); assert_eq!(req.method(), &Method::GET); assert!(req.headers().get(IMDSV2_HEADER).is_none()); - Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string()) }); let creds = instance_creds(&client, &retry_config, endpoint, true) @@ -1020,7 +1020,7 @@ mod tests { server.push( Response::builder() .status(StatusCode::FORBIDDEN) - .body(Body::empty()) + .body(String::new()) .unwrap(), ); diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index b11f4513b6df..b33771de9a86 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -17,17 +17,14 @@ //! An object store implementation for S3 //! -//! ## Multi-part uploads +//! ## Multipart uploads //! -//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method. -//! Data passed to the writer is automatically buffered to meet the minimum size -//! requirements for a part. Multiple parts are uploaded concurrently. +//! Multipart uploads can be initiated with the [ObjectStore::put_multipart] method. //! //! If the writer fails for any reason, you may have parts uploaded to AWS but not -//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method -//! to abort the upload and drop those unneeded parts. In addition, you may wish to -//! consider implementing [automatic cleanup] of unused parts that are older than one -//! week. +//! used that you will be charged for. [`MultipartUpload::abort`] may be invoked to drop +//! these unneeded parts, however, it is recommended that you consider implementing +//! [automatic cleanup] of unused parts that are older than some threshold. //! //! 
[automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ @@ -38,18 +35,17 @@ use futures::{StreamExt, TryStreamExt}; use reqwest::header::{HeaderName, IF_MATCH, IF_NONE_MATCH}; use reqwest::{Method, StatusCode}; use std::{sync::Arc, time::Duration}; -use tokio::io::AsyncWrite; use url::Url; use crate::aws::client::{RequestError, S3Client}; use crate::client::get::GetClientExt; use crate::client::list::ListClientExt; use crate::client::CredentialProvider; -use crate::multipart::{MultiPartStore, PartId, PutPart, WriteMultiPart}; +use crate::multipart::{MultipartStore, PartId}; use crate::signer::Signer; use crate::{ - Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, PutMode, - PutOptions, PutResult, Result, + Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, + ObjectStore, Path, PutMode, PutOptions, PutResult, Result, UploadPart, }; static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging"); @@ -85,6 +81,7 @@ const STORE: &str = "S3"; /// [`CredentialProvider`] for [`AmazonS3`] pub type AwsCredentialProvider = Arc>; +use crate::client::parts::Parts; pub use credential::{AwsAuthorizer, AwsCredential}; /// Interface for [Amazon S3](https://aws.amazon.com/s3/). @@ -211,25 +208,18 @@ impl ObjectStore for AmazonS3 { } } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let id = self.client.create_multipart(location).await?; - - let upload = S3MultiPartUpload { - location: location.clone(), - upload_id: id.clone(), - client: Arc::clone(&self.client), - }; - - Ok((id, Box::new(WriteMultiPart::new(upload, 8)))) - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - self.client - .delete_request(location, &[("uploadId", multipart_id)]) - .await + async fn put_multipart(&self, location: &Path) -> Result> { + let upload_id = self.client.create_multipart(location).await?; + + Ok(Box::new(S3MultiPartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + upload_id: upload_id.clone(), + parts: Default::default(), + }), + })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -319,30 +309,55 @@ impl ObjectStore for AmazonS3 { } } +#[derive(Debug)] struct S3MultiPartUpload { + part_idx: usize, + state: Arc, +} + +#[derive(Debug)] +struct UploadState { + parts: Parts, location: Path, upload_id: String, client: Arc, } #[async_trait] -impl PutPart for S3MultiPartUpload { - async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { - self.client - .put_part(&self.location, &self.upload_id, part_idx, buf.into()) +impl MultipartUpload for S3MultiPartUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state + .client + .put_part(&state.location, &state.upload_id, idx, data) + .await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .complete_multipart(&self.state.location, &self.state.upload_id, parts) .await } - async fn complete(&self, completed_parts: Vec) -> Result<()> { - self.client - .complete_multipart(&self.location, &self.upload_id, completed_parts) - .await?; - Ok(()) + async fn abort(&mut self) 
-> Result<()> { + self.state + .client + .delete_request(&self.state.location, &[("uploadId", &self.state.upload_id)]) + .await } } #[async_trait] -impl MultiPartStore for AmazonS3 { +impl MultipartStore for AmazonS3 { async fn create_multipart(&self, path: &Path) -> Result { self.client.create_multipart(path).await } @@ -377,7 +392,6 @@ mod tests { use crate::{client::get::GetClient, tests::*}; use bytes::Bytes; use hyper::HeaderMap; - use tokio::io::AsyncWriteExt; const NON_EXISTENT_NAME: &str = "nonexistentname"; @@ -542,9 +556,9 @@ mod tests { store.put(&locations[0], data.clone()).await.unwrap(); store.copy(&locations[0], &locations[1]).await.unwrap(); - let (_, mut writer) = store.put_multipart(&locations[2]).await.unwrap(); - writer.write_all(&data).await.unwrap(); - writer.shutdown().await.unwrap(); + let mut upload = store.put_multipart(&locations[2]).await.unwrap(); + upload.put_part(data.clone()).await.unwrap(); + upload.complete().await.unwrap(); for location in &locations { let res = store diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 9360831974ca..6dc3141b08c8 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -930,8 +930,8 @@ impl CredentialProvider for AzureCliCredential { #[cfg(test)] mod tests { use futures::executor::block_on; - use hyper::body::to_bytes; - use hyper::{Body, Response, StatusCode}; + use http_body_util::BodyExt; + use hyper::{Response, StatusCode}; use reqwest::{Client, Method}; use tempfile::NamedTempFile; @@ -942,7 +942,7 @@ mod tests { #[tokio::test] async fn test_managed_identity() { - let server = MockServer::new(); + let server = MockServer::new().await; std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); @@ -964,7 +964,7 @@ mod tests { assert_eq!(t, "env-secret"); let t = req.headers().get("metadata").unwrap().to_str().unwrap(); assert_eq!(t, "true"); - Response::new(Body::from( + Response::new( r#" { "access_token": "TOKEN", @@ -975,8 +975,9 @@ mod tests { "resource": "https://management.azure.com/", "token_type": "Bearer" } - "#, - )) + "# + .to_string(), + ) }); let credential = ImdsManagedIdentityProvider::new( @@ -999,7 +1000,7 @@ mod tests { #[tokio::test] async fn test_workload_identity() { - let server = MockServer::new(); + let server = MockServer::new().await; let tokenfile = NamedTempFile::new().unwrap(); let tenant = "tenant"; std::fs::write(tokenfile.path(), "federated-token").unwrap(); @@ -1012,10 +1013,10 @@ mod tests { server.push_fn(move |req| { assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); assert_eq!(req.method(), &Method::POST); - let body = block_on(to_bytes(req.into_body())).unwrap(); + let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() }); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.contains("federated-token")); - Response::new(Body::from( + Response::new( r#" { "access_token": "TOKEN", @@ -1026,8 +1027,9 @@ mod tests { "resource": "https://management.azure.com/", "token_type": "Bearer" } - "#, - )) + "# + .to_string(), + ) }); let credential = WorkloadIdentityOAuthProvider::new( @@ -1050,7 +1052,7 @@ mod tests { #[tokio::test] async fn test_no_credentials() { - let server = MockServer::new(); + let server = MockServer::new().await; let endpoint = server.url(); let store = MicrosoftAzureBuilder::new() @@ -1068,7 +1070,7 @@ mod tests { assert!(req.headers().get("Authorization").is_none()); Response::builder() .status(StatusCode::NOT_FOUND) - 
.body(Body::from("not found")) + .body("not found".to_string()) .unwrap() }); diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 712b7a36c56a..5d3a405ccc93 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -19,19 +19,15 @@ //! //! ## Streaming uploads //! -//! [ObjectStore::put_multipart] will upload data in blocks and write a blob from those -//! blocks. Data is buffered internally to make blocks of at least 5MB and blocks -//! are uploaded concurrently. +//! [ObjectStore::put_multipart] will upload data in blocks and write a blob from those blocks. //! -//! [ObjectStore::abort_multipart] is a no-op, since Azure Blob Store doesn't provide -//! a way to drop old blocks. Instead unused blocks are automatically cleaned up -//! after 7 days. +//! Unused blocks will automatically be dropped after 7 days. use crate::{ - multipart::{MultiPartStore, PartId, PutPart, WriteMultiPart}, + multipart::{MultipartStore, PartId}, path::Path, signer::Signer, - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, - Result, + GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, ObjectStore, + PutOptions, PutResult, Result, UploadPart, }; use async_trait::async_trait; use bytes::Bytes; @@ -40,7 +36,6 @@ use reqwest::Method; use std::fmt::Debug; use std::sync::Arc; use std::time::Duration; -use tokio::io::AsyncWrite; use url::Url; use crate::client::get::GetClientExt; @@ -54,6 +49,8 @@ mod credential; /// [`CredentialProvider`] for [`MicrosoftAzure`] pub type AzureCredentialProvider = Arc>; +use crate::azure::client::AzureClient; +use crate::client::parts::Parts; pub use builder::{AzureConfigKey, MicrosoftAzureBuilder}; pub use credential::AzureCredential; @@ -94,21 +91,15 @@ impl ObjectStore for MicrosoftAzure { self.client.put_blob(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let inner = AzureMultiPartUpload { - client: Arc::clone(&self.client), - location: location.to_owned(), - }; - Ok((String::new(), Box::new(WriteMultiPart::new(inner, 8)))) - } - - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> { - // There is no way to drop blocks that have been uploaded. Instead, they simply - // expire in 7 days. 
- Ok(()) + async fn put_multipart(&self, location: &Path) -> Result> { + Ok(Box::new(AzureMultiPartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + location: location.clone(), + parts: Default::default(), + }), + })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -197,26 +188,49 @@ impl Signer for MicrosoftAzure { /// put_multipart_part -> PUT block /// complete -> PUT block list /// abort -> No equivalent; blocks are simply dropped after 7 days -#[derive(Debug, Clone)] +#[derive(Debug)] struct AzureMultiPartUpload { - client: Arc, + part_idx: usize, + state: Arc, +} + +#[derive(Debug)] +struct UploadState { location: Path, + parts: Parts, + client: Arc, } #[async_trait] -impl PutPart for AzureMultiPartUpload { - async fn put_part(&self, buf: Vec, idx: usize) -> Result { - self.client.put_block(&self.location, idx, buf.into()).await +impl MultipartUpload for AzureMultiPartUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state.client.put_block(&state.location, idx, data).await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .put_block_list(&self.state.location, parts) + .await } - async fn complete(&self, parts: Vec) -> Result<()> { - self.client.put_block_list(&self.location, parts).await?; + async fn abort(&mut self) -> Result<()> { + // Nothing to do Ok(()) } } #[async_trait] -impl MultiPartStore for MicrosoftAzure { +impl MultipartStore for MicrosoftAzure { async fn create_multipart(&self, _: &Path) -> Result { Ok(String::new()) } diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index fdefe599f79e..de6d4eb1bb9c 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -18,7 +18,7 @@ //! Utilities for performing tokio-style buffered IO use crate::path::Path; -use crate::{MultipartId, ObjectMeta, ObjectStore}; +use crate::{ObjectMeta, ObjectStore, WriteMultipart}; use bytes::Bytes; use futures::future::{BoxFuture, FutureExt}; use futures::ready; @@ -27,7 +27,7 @@ use std::io::{Error, ErrorKind, SeekFrom}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; /// The default buffer size used by [`BufReader`] pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024; @@ -207,13 +207,17 @@ impl AsyncBufRead for BufReader { /// An async buffered writer compatible with the tokio IO traits /// +/// This writer adaptively uses [`ObjectStore::put`] or +/// [`ObjectStore::put_multipart`] depending on the amount of data that has +/// been written. +/// /// Up to `capacity` bytes will be buffered in memory, and flushed on shutdown /// using [`ObjectStore::put`]. 
If `capacity` is exceeded, data will instead be /// streamed using [`ObjectStore::put_multipart`] pub struct BufWriter { capacity: usize, + max_concurrency: usize, state: BufWriterState, - multipart_id: Option, store: Arc, } @@ -221,22 +225,19 @@ impl std::fmt::Debug for BufWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("BufWriter") .field("capacity", &self.capacity) - .field("multipart_id", &self.multipart_id) .finish() } } -type MultipartResult = (MultipartId, Box); - enum BufWriterState { /// Buffer up to capacity bytes Buffer(Path, Vec), /// [`ObjectStore::put_multipart`] - Prepare(BoxFuture<'static, std::io::Result>), + Prepare(BoxFuture<'static, std::io::Result>), /// Write to a multipart upload - Write(Box), + Write(Option), /// [`ObjectStore::put`] - Put(BoxFuture<'static, std::io::Result<()>>), + Flush(BoxFuture<'static, std::io::Result<()>>), } impl BufWriter { @@ -250,14 +251,32 @@ impl BufWriter { Self { capacity, store, + max_concurrency: 8, state: BufWriterState::Buffer(path, Vec::new()), - multipart_id: None, } } - /// Returns the [`MultipartId`] if multipart upload - pub fn multipart_id(&self) -> Option<&MultipartId> { - self.multipart_id.as_ref() + /// Override the maximum number of in-flight requests for this writer + /// + /// Defaults to 8 + pub fn with_max_concurrency(self, max_concurrency: usize) -> Self { + Self { + max_concurrency, + ..self + } + } + + /// Abort this writer, cleaning up any partially uploaded state + /// + /// # Panic + /// + /// Panics if this writer has already been shutdown or aborted + pub async fn abort(&mut self) -> crate::Result<()> { + match &mut self.state { + BufWriterState::Buffer(_, _) | BufWriterState::Prepare(_) => Ok(()), + BufWriterState::Flush(_) => panic!("Already shut down"), + BufWriterState::Write(x) => x.take().unwrap().abort().await, + } } } @@ -268,14 +287,19 @@ impl AsyncWrite for BufWriter { buf: &[u8], ) -> Poll> { let cap = self.capacity; + let max_concurrency = self.max_concurrency; loop { return match &mut self.state { - BufWriterState::Write(write) => Pin::new(write).poll_write(cx, buf), - BufWriterState::Put(_) => panic!("Already shut down"), + BufWriterState::Write(Some(write)) => { + ready!(write.poll_for_capacity(cx, max_concurrency))?; + write.write(buf); + Poll::Ready(Ok(buf.len())) + } + BufWriterState::Write(None) | BufWriterState::Flush(_) => { + panic!("Already shut down") + } BufWriterState::Prepare(f) => { - let (id, w) = ready!(f.poll_unpin(cx)?); - self.state = BufWriterState::Write(w); - self.multipart_id = Some(id); + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); continue; } BufWriterState::Buffer(path, b) => { @@ -284,9 +308,10 @@ impl AsyncWrite for BufWriter { let path = std::mem::take(path); let store = Arc::clone(&self.store); self.state = BufWriterState::Prepare(Box::pin(async move { - let (id, mut writer) = store.put_multipart(&path).await?; - writer.write_all(&buffer).await?; - Ok((id, writer)) + let upload = store.put_multipart(&path).await?; + let mut chunked = WriteMultipart::new(upload); + chunked.write(&buffer); + Ok(chunked) })); continue; } @@ -300,13 +325,10 @@ impl AsyncWrite for BufWriter { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { return match &mut self.state { - BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())), - BufWriterState::Write(write) => Pin::new(write).poll_flush(cx), - BufWriterState::Put(_) => panic!("Already shut down"), + BufWriterState::Write(_) | 
BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())), + BufWriterState::Flush(_) => panic!("Already shut down"), BufWriterState::Prepare(f) => { - let (id, w) = ready!(f.poll_unpin(cx)?); - self.state = BufWriterState::Write(w); - self.multipart_id = Some(id); + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); continue; } }; @@ -317,21 +339,28 @@ impl AsyncWrite for BufWriter { loop { match &mut self.state { BufWriterState::Prepare(f) => { - let (id, w) = ready!(f.poll_unpin(cx)?); - self.state = BufWriterState::Write(w); - self.multipart_id = Some(id); + self.state = BufWriterState::Write(ready!(f.poll_unpin(cx)?).into()); } BufWriterState::Buffer(p, b) => { let buf = std::mem::take(b); let path = std::mem::take(p); let store = Arc::clone(&self.store); - self.state = BufWriterState::Put(Box::pin(async move { + self.state = BufWriterState::Flush(Box::pin(async move { store.put(&path, buf.into()).await?; Ok(()) })); } - BufWriterState::Put(f) => return f.poll_unpin(cx), - BufWriterState::Write(w) => return Pin::new(w).poll_shutdown(cx), + BufWriterState::Flush(f) => return f.poll_unpin(cx), + BufWriterState::Write(x) => { + let upload = x.take().unwrap(); + self.state = BufWriterState::Flush( + async move { + upload.finish().await?; + Ok(()) + } + .boxed(), + ) + } } } } @@ -352,7 +381,7 @@ mod tests { use super::*; use crate::memory::InMemory; use crate::path::Path; - use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt}; + use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; #[tokio::test] async fn test_buf_reader() { @@ -443,9 +472,7 @@ mod tests { writer.write_all(&[0; 20]).await.unwrap(); writer.flush().await.unwrap(); writer.write_all(&[0; 5]).await.unwrap(); - assert!(writer.multipart_id().is_none()); writer.shutdown().await.unwrap(); - assert!(writer.multipart_id().is_none()); assert_eq!(store.head(&path).await.unwrap().size, 25); // Test multipart @@ -453,9 +480,7 @@ mod tests { writer.write_all(&[0; 20]).await.unwrap(); writer.flush().await.unwrap(); writer.write_all(&[0; 20]).await.unwrap(); - assert!(writer.multipart_id().is_some()); writer.shutdown().await.unwrap(); - assert!(writer.multipart_id().is_some()); assert_eq!(store.head(&path).await.unwrap().size, 40); } diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index d33556f4b12e..6db7f4b35e24 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -25,14 +25,13 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use futures::stream::BoxStream; use futures::StreamExt; -use tokio::io::AsyncWrite; use crate::path::Path; +use crate::Result; use crate::{ - GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, - PutResult, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutOptions, PutResult, }; -use crate::{MultipartId, Result}; /// Wraps a [`ObjectStore`] and makes its get response return chunks /// in a controllable manner. 
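As context for the `buffered.rs` changes above, the following is a minimal usage sketch and not part of this diff: it illustrates the adaptive behaviour where small writes are flushed with a single `put` while larger writes spill into a streaming multipart upload. The `with_capacity` constructor signature and the use of `InMemory` are assumptions for illustration; `with_max_concurrency`, `write_all`, and `shutdown` follow the code and tests shown above.

```rust
use std::sync::Arc;

use object_store::buffered::BufWriter;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::AsyncWriteExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let path = Path::from("data/large_file");

    // Buffer up to 1 MiB in memory; writes beyond that switch to
    // ObjectStore::put_multipart, with up to 8 part requests in flight
    // (the default shown in the diff above).
    let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 1024 * 1024)
        .with_max_concurrency(8);

    writer.write_all(&vec![0u8; 4 * 1024 * 1024]).await?;
    // Nothing is observable at `path` until shutdown completes the upload.
    writer.shutdown().await?;

    assert_eq!(store.head(&path).await?.size, 4 * 1024 * 1024);
    Ok(())
}
```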
@@ -67,17 +66,10 @@ impl ObjectStore for ChunkedStore { self.inner.put_opts(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> Result> { self.inner.put_multipart(location).await } - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - self.inner.abort_multipart(location, multipart_id).await - } - async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { let r = self.inner.get_opts(location, options).await?; let stream = match r.payload { diff --git a/object_store/src/client/mock_server.rs b/object_store/src/client/mock_server.rs index 70b856186d72..aa5a9e0ab4dd 100644 --- a/object_store/src/client/mock_server.rs +++ b/object_store/src/client/mock_server.rs @@ -17,18 +17,23 @@ use futures::future::BoxFuture; use futures::FutureExt; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; use parking_lot::Mutex; use std::collections::VecDeque; use std::convert::Infallible; use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; +use tokio::net::TcpListener; use tokio::sync::oneshot; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; -pub type ResponseFn = Box) -> BoxFuture<'static, Response> + Send>; +pub type ResponseFn = + Box) -> BoxFuture<'static, Response> + Send>; /// A mock server pub struct MockServer { @@ -39,39 +44,48 @@ pub struct MockServer { } impl MockServer { - pub fn new() -> Self { + pub async fn new() -> Self { let responses: Arc>> = Arc::new(Mutex::new(VecDeque::with_capacity(10))); - let r = Arc::clone(&responses); - let make_service = make_service_fn(move |_conn| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let r = Arc::clone(&r); - let next = r.lock().pop_front(); - async move { - Ok::<_, Infallible>(match next { - Some(r) => r(req).await, - None => Response::new(Body::from("Hello World")), - }) - } - })) - } - }); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); - let (shutdown, rx) = oneshot::channel::<()>(); - let server = Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); + let (shutdown, mut rx) = oneshot::channel::<()>(); - let url = format!("http://{}", server.local_addr()); + let url = format!("http://{}", listener.local_addr().unwrap()); + let r = Arc::clone(&responses); let handle = tokio::spawn(async move { - server - .with_graceful_shutdown(async { - rx.await.ok(); - }) - .await - .unwrap() + let mut set = JoinSet::new(); + + loop { + let (stream, _) = tokio::select! 
{ + conn = listener.accept() => conn.unwrap(), + _ = &mut rx => break, + }; + + let r = Arc::clone(&r); + set.spawn(async move { + let _ = http1::Builder::new() + .serve_connection( + TokioIo::new(stream), + service_fn(move |req| { + let r = Arc::clone(&r); + let next = r.lock().pop_front(); + async move { + Ok::<_, Infallible>(match next { + Some(r) => r(req).await, + None => Response::new("Hello World".to_string()), + }) + } + }), + ) + .await; + }); + } + + set.abort_all(); }); Self { @@ -88,14 +102,14 @@ impl MockServer { } /// Add a response - pub fn push(&self, response: Response) { + pub fn push(&self, response: Response) { self.push_fn(|_| response) } /// Add a response function pub fn push_fn(&self, f: F) where - F: FnOnce(Request) -> Response + Send + 'static, + F: FnOnce(Request) -> Response + Send + 'static, { let f = Box::new(|req| async move { f(req) }.boxed()); self.responses.lock().push_back(f) @@ -103,8 +117,8 @@ impl MockServer { pub fn push_async_fn(&self, f: F) where - F: FnOnce(Request) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, + F: FnOnce(Request) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, { self.responses.lock().push_back(Box::new(|r| f(r).boxed())) } diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index 252e9fdcadf5..7728f38954f9 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -40,6 +40,9 @@ pub mod header; #[cfg(any(feature = "aws", feature = "gcp"))] pub mod s3; +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +pub mod parts; + use async_trait::async_trait; use std::collections::HashMap; use std::str::FromStr; diff --git a/object_store/src/client/parts.rs b/object_store/src/client/parts.rs new file mode 100644 index 000000000000..9fc301edcf81 --- /dev/null +++ b/object_store/src/client/parts.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::multipart::PartId; +use parking_lot::Mutex; + +/// An interior mutable collection of upload parts and their corresponding part index +#[derive(Debug, Default)] +pub(crate) struct Parts(Mutex>); + +impl Parts { + /// Record the [`PartId`] for a given index + /// + /// Note: calling this method multiple times with the same `part_idx` + /// will result in multiple [`PartId`] in the final output + pub(crate) fn put(&self, part_idx: usize, id: PartId) { + self.0.lock().push((part_idx, id)) + } + + /// Produce the final list of [`PartId`] ordered by `part_idx` + /// + /// `expected` is the number of parts expected in the final result + pub(crate) fn finish(&self, expected: usize) -> crate::Result> { + let mut parts = self.0.lock(); + if parts.len() != expected { + return Err(crate::Error::Generic { + store: "Parts", + source: "Missing part".to_string().into(), + }); + } + parts.sort_unstable_by_key(|(idx, _)| *idx); + Ok(parts.drain(..).map(|(_, v)| v).collect()) + } +} diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index fbd3645d2780..e4bb5c9e731a 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -259,13 +259,16 @@ impl RetryExt for reqwest::RequestBuilder { Err(e) => { let mut do_retry = false; - if req.method().is_safe() && e.is_timeout() { + if e.is_connect() || ( req.method().is_safe() && e.is_timeout()) { do_retry = true - } else if let Some(source) = e.source() { - if let Some(e) = source.downcast_ref::() { - if e.is_connect() || e.is_closed() || e.is_incomplete_message() { - do_retry = true; + } else { + let mut source = e.source(); + while let Some(e) = source { + if let Some(e) = e.downcast_ref::() { + do_retry = e.is_closed() || e.is_incomplete_message(); + break } + source = e.source(); } } @@ -305,13 +308,13 @@ mod tests { use crate::client::retry::{Error, RetryExt}; use crate::RetryConfig; use hyper::header::LOCATION; - use hyper::{Body, Response}; + use hyper::Response; use reqwest::{Client, Method, StatusCode}; use std::time::Duration; #[tokio::test] async fn test_retry() { - let mock = MockServer::new(); + let mock = MockServer::new().await; let retry = RetryConfig { backoff: Default::default(), @@ -334,7 +337,7 @@ mod tests { mock.push( Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::from("cupcakes")) + .body("cupcakes".to_string()) .unwrap(), ); @@ -350,7 +353,7 @@ mod tests { mock.push( Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -366,7 +369,7 @@ mod tests { mock.push( Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -377,7 +380,7 @@ mod tests { mock.push( Response::builder() .status(StatusCode::NO_CONTENT) - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -389,7 +392,7 @@ mod tests { Response::builder() .status(StatusCode::FOUND) .header(LOCATION, "/foo") - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -402,7 +405,7 @@ mod tests { Response::builder() .status(StatusCode::FOUND) .header(LOCATION, "/bar") - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -416,19 +419,19 @@ mod tests { Response::builder() .status(StatusCode::FOUND) .header(LOCATION, "/bar") - .body(Body::empty()) + .body(String::new()) .unwrap(), ); } let e = do_request().await.unwrap_err().to_string(); - assert!(e.ends_with("too many redirects"), "{}", e); + assert!(e.contains("error following redirect for url"), "{}", e); // 
Handles redirect missing location mock.push( Response::builder() .status(StatusCode::FOUND) - .body(Body::empty()) + .body(String::new()) .unwrap(), ); @@ -441,7 +444,7 @@ mod tests { mock.push( Response::builder() .status(StatusCode::BAD_GATEWAY) - .body(Body::from("ignored")) + .body("ignored".to_string()) .unwrap(), ); } @@ -486,7 +489,7 @@ mod tests { let res = client.request(Method::PUT, mock.url()).send_retry(&retry); let e = res.await.unwrap_err().to_string(); assert!( - e.contains("Error after 0 retries in") && e.contains("operation timed out"), + e.contains("Error after 0 retries in") && e.contains("error sending request for url"), "{e}" ); diff --git a/object_store/src/gcp/builder.rs b/object_store/src/gcp/builder.rs index 14c4257dc6a3..2cf75040b858 100644 --- a/object_store/src/gcp/builder.rs +++ b/object_store/src/gcp/builder.rs @@ -594,10 +594,7 @@ mod tests { .unwrap_err() .to_string(); - assert_eq!( - "Generic HTTP client error: builder error: unknown proxy scheme", - err - ); + assert_eq!("Generic HTTP client error: builder error", err); } #[test] diff --git a/object_store/src/gcp/client.rs b/object_store/src/gcp/client.rs index e4b0f9af7d15..def53beefe78 100644 --- a/object_store/src/gcp/client.rs +++ b/object_store/src/gcp/client.rs @@ -272,7 +272,7 @@ impl GoogleCloudStorageClient { }) } - /// Initiate a multi-part upload + /// Initiate a multipart upload pub async fn multipart_initiate(&self, path: &Path) -> Result { let credential = self.get_credential().await?; let url = self.object_url(path); diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 8633abbfb4dc..2058d1f8055b 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -17,18 +17,14 @@ //! An object store implementation for Google Cloud Storage //! -//! ## Multi-part uploads +//! ## Multipart uploads //! -//! [Multi-part uploads](https://cloud.google.com/storage/docs/multipart-uploads) -//! can be initiated with the [ObjectStore::put_multipart] method. -//! Data passed to the writer is automatically buffered to meet the minimum size -//! requirements for a part. Multiple parts are uploaded concurrently. -//! -//! If the writer fails for any reason, you may have parts uploaded to GCS but not -//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method -//! to abort the upload and drop those unneeded parts. In addition, you may wish to -//! consider implementing automatic clean up of unused parts that are older than one -//! week. +//! [Multipart uploads](https://cloud.google.com/storage/docs/multipart-uploads) +//! can be initiated with the [ObjectStore::put_multipart] method. If neither +//! [`MultipartUpload::complete`] nor [`MultipartUpload::abort`] is invoked, you may +//! have parts uploaded to GCS but not used, that you will be charged for. It is recommended +//! you configure a [lifecycle rule] to abort incomplete multipart uploads after a certain +//! period of time to avoid being charged for storing partial uploads. //! //! ## Using HTTP/2 //! @@ -36,24 +32,24 @@ //! because it allows much higher throughput in our benchmarks (see //! [#5194](https://github.com/apache/arrow-rs/issues/5194)). HTTP/2 can be //! enabled by setting [crate::ClientConfigKey::Http1Only] to false. +//! +//! 
[lifecycle rule]: https://cloud.google.com/storage/docs/lifecycle#abort-mpu use std::sync::Arc; use crate::client::CredentialProvider; use crate::{ - multipart::{PartId, PutPart, WriteMultiPart}, - path::Path, - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, - Result, + multipart::PartId, path::Path, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, + ObjectMeta, ObjectStore, PutOptions, PutResult, Result, UploadPart, }; use async_trait::async_trait; use bytes::Bytes; use client::GoogleCloudStorageClient; use futures::stream::BoxStream; -use tokio::io::AsyncWrite; use crate::client::get::GetClientExt; use crate::client::list::ListClientExt; -use crate::multipart::MultiPartStore; +use crate::client::parts::Parts; +use crate::multipart::MultipartStore; pub use builder::{GoogleCloudStorageBuilder, GoogleConfigKey}; pub use credential::GcpCredential; @@ -89,27 +85,50 @@ impl GoogleCloudStorage { } } +#[derive(Debug)] struct GCSMultipartUpload { + state: Arc, + part_idx: usize, +} + +#[derive(Debug)] +struct UploadState { client: Arc, path: Path, multipart_id: MultipartId, + parts: Parts, } #[async_trait] -impl PutPart for GCSMultipartUpload { - /// Upload an object part - async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { - self.client - .put_part(&self.path, &self.multipart_id, part_idx, buf.into()) +impl MultipartUpload for GCSMultipartUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let idx = self.part_idx; + self.part_idx += 1; + let state = Arc::clone(&self.state); + Box::pin(async move { + let part = state + .client + .put_part(&state.path, &state.multipart_id, idx, data) + .await?; + state.parts.put(idx, part); + Ok(()) + }) + } + + async fn complete(&mut self) -> Result { + let parts = self.state.parts.finish(self.part_idx)?; + + self.state + .client + .multipart_complete(&self.state.path, &self.state.multipart_id, parts) .await } - /// Complete a multipart upload - async fn complete(&self, completed_parts: Vec) -> Result<()> { - self.client - .multipart_complete(&self.path, &self.multipart_id, completed_parts) - .await?; - Ok(()) + async fn abort(&mut self) -> Result<()> { + self.state + .client + .multipart_cleanup(&self.state.path, &self.state.multipart_id) + .await } } @@ -119,27 +138,18 @@ impl ObjectStore for GoogleCloudStorage { self.client.put(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> Result> { let upload_id = self.client.multipart_initiate(location).await?; - let inner = GCSMultipartUpload { - client: Arc::clone(&self.client), - path: location.clone(), - multipart_id: upload_id.clone(), - }; - - Ok((upload_id, Box::new(WriteMultiPart::new(inner, 8)))) - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - self.client - .multipart_cleanup(location, multipart_id) - .await?; - - Ok(()) + Ok(Box::new(GCSMultipartUpload { + part_idx: 0, + state: Arc::new(UploadState { + client: Arc::clone(&self.client), + path: location.clone(), + multipart_id: upload_id.clone(), + parts: Default::default(), + }), + })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -176,7 +186,7 @@ impl ObjectStore for GoogleCloudStorage { } #[async_trait] -impl MultiPartStore for GoogleCloudStorage { +impl MultipartStore for GoogleCloudStorage { async fn create_multipart(&self, path: &Path) -> Result { 
self.client.multipart_initiate(path).await } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index f1d11db4762c..626337df27f9 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -37,7 +37,6 @@ use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use snafu::{OptionExt, ResultExt, Snafu}; -use tokio::io::AsyncWrite; use url::Url; use crate::client::get::GetClientExt; @@ -45,7 +44,7 @@ use crate::client::header::get_etag; use crate::http::client::Client; use crate::path::Path; use crate::{ - ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, + ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutMode, PutOptions, PutResult, Result, RetryConfig, }; @@ -115,15 +114,8 @@ impl ObjectStore for HttpStore { }) } - async fn put_multipart( - &self, - _location: &Path, - ) -> Result<(MultipartId, Box)> { - Err(super::Error::NotImplemented) - } - - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> { - Err(super::Error::NotImplemented) + async fn put_multipart(&self, _location: &Path) -> Result> { + Err(crate::Error::NotImplemented) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 8132002b6e01..97604a7dce68 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -88,11 +88,11 @@ //! //! # Why not a Filesystem Interface? //! -//! Whilst this crate does provide a [`BufReader`], the [`ObjectStore`] interface mirrors the APIs -//! of object stores and not filesystems, opting to provide stateless APIs instead of the cursor -//! based interfaces such as [`Read`] or [`Seek`] favoured by filesystems. +//! The [`ObjectStore`] interface is designed to mirror the APIs +//! of object stores and *not* filesystems, and thus has stateless APIs instead +//! of cursor based interfaces such as [`Read`] or [`Seek`] available in filesystems. //! -//! This provides some compelling advantages: +//! This design provides the following advantages: //! //! * All operations are atomic, and readers cannot observe partial and/or failed writes //! * Methods map directly to object store APIs, providing both efficiency and predictability @@ -100,7 +100,12 @@ //! * Allows for functionality not native to filesystems, such as operation preconditions //! and atomic multipart uploads //! +//! This crate does provide [`BufReader`] and [`BufWriter`] adapters +//! which provide a more filesystem-like API for working with the +//! [`ObjectStore`] trait, however, they should be used with care +//! //! [`BufReader`]: buffered::BufReader +//! [`BufWriter`]: buffered::BufWriter //! //! # Adapters //! @@ -264,12 +269,11 @@ //! //! # Multipart Upload //! -//! Use the [`ObjectStore::put_multipart`] method to atomically write a large amount of data, -//! with implementations automatically handling parallel, chunked upload where appropriate. +//! Use the [`ObjectStore::put_multipart`] method to atomically write a large amount of data //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # use object_store::ObjectStore; +//! # use object_store::{ObjectStore, WriteMultipart}; //! # use std::sync::Arc; //! # use bytes::Bytes; //! # use tokio::io::AsyncWriteExt; @@ -281,12 +285,10 @@ //! # //! let object_store: Arc = get_object_store(); //! let path = Path::from("data/large_file"); -//! 
let (_id, mut writer) = object_store.put_multipart(&path).await.unwrap(); -//! -//! let bytes = Bytes::from_static(b"hello"); -//! writer.write_all(&bytes).await.unwrap(); -//! writer.flush().await.unwrap(); -//! writer.shutdown().await.unwrap(); +//! let upload = object_store.put_multipart(&path).await.unwrap(); +//! let mut write = WriteMultipart::new(upload); +//! write.write(b"hello"); +//! write.finish().await.unwrap(); //! # } //! ``` //! @@ -496,9 +498,11 @@ pub use tags::TagSet; pub mod multipart; mod parse; +mod upload; mod util; pub use parse::{parse_url, parse_url_opts}; +pub use upload::*; pub use util::GetRange; use crate::path::Path; @@ -515,12 +519,11 @@ use std::fmt::{Debug, Formatter}; use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; use std::sync::Arc; -use tokio::io::AsyncWrite; /// An alias for a dynamically dispatched object store implementation. pub type DynObjectStore = dyn ObjectStore; -/// Id type for multi-part uploads. +/// Id type for multipart uploads. pub type MultipartId = String; /// Universal API to multiple object store services. @@ -538,48 +541,11 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// Save the provided bytes to the specified location with the given options async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result; - /// Get a multi-part upload that allows writing data in chunks. - /// - /// Most cloud-based uploads will buffer and upload parts in parallel. - /// - /// To complete the upload, [AsyncWrite::poll_shutdown] must be called - /// to completion. This operation is guaranteed to be atomic, it will either - /// make all the written data available at `location`, or fail. No clients - /// should be able to observe a partially written object. - /// - /// For some object stores (S3, GCS, and local in particular), if the - /// writer fails or panics, you must call [ObjectStore::abort_multipart] - /// to clean up partially written data. - /// - ///
- /// It is recommended applications wait for any in-flight requests to complete by calling `flush`, if - /// there may be a significant gap in time (> ~30s) before the next write. - /// These gaps can include times where the function returns control to the - /// caller while keeping the writer open. If `flush` is not called, futures - /// for in-flight requests may be left unpolled long enough for the requests - /// to time out, causing the write to fail. - ///
- /// - /// For applications requiring fine-grained control of multipart uploads - /// see [`MultiPartStore`], although note that this interface cannot be - /// supported by all [`ObjectStore`] backends. - /// - /// For applications looking to implement this interface for a custom - /// multipart API, see [`WriteMultiPart`] which handles the complexities - /// of performing parallel uploads of fixed size parts. - /// - /// [`WriteMultiPart`]: multipart::WriteMultiPart - /// [`MultiPartStore`]: multipart::MultiPartStore - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)>; - - /// Cleanup an aborted upload. + /// Perform a multipart upload /// - /// See documentation for individual stores for exact behavior, as capabilities - /// vary by object store. - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()>; + /// Client should prefer [`ObjectStore::put`] for small payloads, as streaming uploads + /// typically require multiple separate requests. See [`MultipartUpload`] for more information + async fn put_multipart(&self, location: &Path) -> Result>; /// Return the bytes that are stored at the specified location. async fn get(&self, location: &Path) -> Result { @@ -764,21 +730,10 @@ macro_rules! as_ref_impl { self.as_ref().put_opts(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> Result> { self.as_ref().put_multipart(location).await } - async fn abort_multipart( - &self, - location: &Path, - multipart_id: &MultipartId, - ) -> Result<()> { - self.as_ref().abort_multipart(location, multipart_id).await - } - async fn get(&self, location: &Path) -> Result { self.as_ref().get(location).await } @@ -1241,14 +1196,12 @@ mod test_util { #[cfg(test)] mod tests { use super::*; - use crate::multipart::MultiPartStore; + use crate::multipart::MultipartStore; use crate::test_util::flatten_list_stream; use chrono::TimeZone; use futures::stream::FuturesUnordered; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; - use std::future::Future; - use tokio::io::AsyncWriteExt; pub(crate) async fn put_get_delete_list(storage: &DynObjectStore) { put_get_delete_list_opts(storage).await @@ -1923,12 +1876,11 @@ mod tests { let location = Path::from("test_dir/test_upload_file.txt"); // Can write to storage - let data = get_chunks(5_000, 10); + let data = get_chunks(5 * 1024 * 1024, 3); let bytes_expected = data.concat(); - let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); - for chunk in &data { - writer.write_all(chunk).await.unwrap(); - } + let mut upload = storage.put_multipart(&location).await.unwrap(); + let uploads = data.into_iter().map(|x| upload.put_part(x)); + futures::future::try_join_all(uploads).await.unwrap(); // Object should not yet exist in store let meta_res = storage.head(&location).await; @@ -1944,7 +1896,8 @@ mod tests { let result = storage.list_with_delimiter(None).await.unwrap(); assert_eq!(&result.objects, &[]); - writer.shutdown().await.unwrap(); + upload.complete().await.unwrap(); + let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); assert_eq!(bytes_expected, bytes_written); @@ -1952,22 +1905,19 @@ mod tests { // Sizes chosen to ensure we write three parts let data = get_chunks(3_200_000, 7); let bytes_expected = data.concat(); - let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); + let upload = 
storage.put_multipart(&location).await.unwrap(); + let mut writer = WriteMultipart::new(upload); for chunk in &data { - writer.write_all(chunk).await.unwrap(); + writer.write(chunk) } - writer.shutdown().await.unwrap(); + writer.finish().await.unwrap(); let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); assert_eq!(bytes_expected, bytes_written); // We can abort an empty write let location = Path::from("test_dir/test_abort_upload.txt"); - let (upload_id, writer) = storage.put_multipart(&location).await.unwrap(); - drop(writer); - storage - .abort_multipart(&location, &upload_id) - .await - .unwrap(); + let mut upload = storage.put_multipart(&location).await.unwrap(); + upload.abort().await.unwrap(); let get_res = storage.get(&location).await; assert!(get_res.is_err()); assert!(matches!( @@ -1976,17 +1926,13 @@ mod tests { )); // We can abort an in-progress write - let (upload_id, mut writer) = storage.put_multipart(&location).await.unwrap(); - if let Some(chunk) = data.first() { - writer.write_all(chunk).await.unwrap(); - let _ = writer.write(chunk).await.unwrap(); - } - drop(writer); - - storage - .abort_multipart(&location, &upload_id) + let mut upload = storage.put_multipart(&location).await.unwrap(); + upload + .put_part(data.first().unwrap().clone()) .await .unwrap(); + + upload.abort().await.unwrap(); let get_res = storage.get(&location).await; assert!(get_res.is_err()); assert!(matches!( @@ -2181,7 +2127,34 @@ mod tests { storage.delete(&path2).await.unwrap(); } - pub(crate) async fn multipart(storage: &dyn ObjectStore, multipart: &dyn MultiPartStore) { + pub(crate) async fn copy_rename_nonexistent_object(storage: &DynObjectStore) { + // Create empty source object + let path1 = Path::from("test1"); + + // Create destination object + let path2 = Path::from("test2"); + storage.put(&path2, Bytes::from("hello")).await.unwrap(); + + // copy() errors if source does not exist + let result = storage.copy(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // rename() errors if source does not exist + let result = storage.rename(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // copy_if_not_exists() errors if source does not exist + let result = storage.copy_if_not_exists(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // Clean up + storage.delete(&path2).await.unwrap(); + } + + pub(crate) async fn multipart(storage: &dyn ObjectStore, multipart: &dyn MultipartStore) { let path = Path::from("test_multipart"); let chunk_size = 5 * 1024 * 1024; @@ -2248,7 +2221,7 @@ mod tests { pub(crate) async fn tagging(storage: &dyn ObjectStore, validate: bool, get_tags: F) where F: Fn(Path) -> Fut + Send + Sync, - Fut: Future> + Send, + Fut: std::future::Future> + Send, { use bytes::Buf; use serde::Deserialize; diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index d1363d9a4d46..e5f6841638e1 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -18,18 +18,16 @@ //! 
An object store that limits the maximum concurrency of the wrapped implementation use crate::{ - BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, - ObjectStore, Path, PutOptions, PutResult, Result, StreamExt, + BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, Path, PutOptions, PutResult, Result, StreamExt, UploadPart, }; use async_trait::async_trait; use bytes::Bytes; use futures::{FutureExt, Stream}; -use std::io::{Error, IoSlice}; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::AsyncWrite; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; /// Store wrapper that wraps an inner store and limits the maximum number of concurrent @@ -81,18 +79,12 @@ impl ObjectStore for LimitStore { let _permit = self.semaphore.acquire().await.unwrap(); self.inner.put_opts(location, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let (id, write) = self.inner.put_multipart(location).await?; - Ok((id, Box::new(PermitWrapper::new(write, permit)))) - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - let _permit = self.semaphore.acquire().await.unwrap(); - self.inner.abort_multipart(location, multipart_id).await + async fn put_multipart(&self, location: &Path) -> Result> { + let upload = self.inner.put_multipart(location).await?; + Ok(Box::new(LimitUpload { + semaphore: Arc::clone(&self.semaphore), + upload, + })) } async fn get(&self, location: &Path) -> Result { let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); @@ -221,39 +213,42 @@ impl Stream for PermitWrapper { } } -impl AsyncWrite for PermitWrapper { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } +/// An [`MultipartUpload`] wrapper that limits the maximum number of concurrent requests +#[derive(Debug)] +pub struct LimitUpload { + upload: Box, + semaphore: Arc, +} - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) +impl LimitUpload { + /// Create a new [`LimitUpload`] limiting `upload` to `max_concurrency` concurrent requests + pub fn new(upload: Box, max_concurrency: usize) -> Self { + Self { + upload, + semaphore: Arc::new(Semaphore::new(max_concurrency)), + } } +} - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) +#[async_trait] +impl MultipartUpload for LimitUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let upload = self.upload.put_part(data); + let s = Arc::clone(&self.semaphore); + Box::pin(async move { + let _permit = s.acquire().await.unwrap(); + upload.await + }) } - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + async fn complete(&mut self) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.upload.complete().await } - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() + async fn abort(&mut self) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.upload.abort().await } } diff --git a/object_store/src/local.rs 
b/object_store/src/local.rs index d631771778db..6cc0c672af45 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -16,34 +16,32 @@ // under the License. //! An object store implementation for a local filesystem -use crate::{ - maybe_spawn_blocking, - path::{absolute_path_to_url, Path}, - util::InvalidGetRange, - GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, ObjectStore, - PutMode, PutOptions, PutResult, Result, -}; -use async_trait::async_trait; -use bytes::Bytes; -use chrono::{DateTime, Utc}; -use futures::future::BoxFuture; -use futures::ready; -use futures::{stream::BoxStream, StreamExt}; -use futures::{FutureExt, TryStreamExt}; -use snafu::{ensure, ResultExt, Snafu}; use std::fs::{metadata, symlink_metadata, File, Metadata, OpenOptions}; use std::io::{ErrorKind, Read, Seek, SeekFrom, Write}; use std::ops::Range; -use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; use std::time::SystemTime; use std::{collections::BTreeSet, convert::TryFrom, io}; use std::{collections::VecDeque, path::PathBuf}; -use tokio::io::AsyncWrite; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt}; +use futures::{FutureExt, TryStreamExt}; +use parking_lot::Mutex; +use snafu::{ensure, OptionExt, ResultExt, Snafu}; use url::Url; use walkdir::{DirEntry, WalkDir}; +use crate::{ + maybe_spawn_blocking, + path::{absolute_path_to_url, Path}, + util::InvalidGetRange, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMode, PutOptions, PutResult, Result, UploadPart, +}; + /// A specialized `Error` for filesystem object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] @@ -155,6 +153,9 @@ pub(crate) enum Error { InvalidPath { path: String, }, + + #[snafu(display("Upload aborted"))] + Aborted, } impl From for super::Error { @@ -342,8 +343,7 @@ impl ObjectStore for LocalFileSystem { let path = self.path_to_filesystem(location)?; maybe_spawn_blocking(move || { - let (mut file, suffix) = new_staged_upload(&path)?; - let staging_path = staged_upload_path(&path, &suffix); + let (mut file, staging_path) = new_staged_upload(&path)?; let mut e_tag = None; let err = match file.write_all(&bytes) { @@ -395,31 +395,10 @@ impl ObjectStore for LocalFileSystem { .await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let dest = self.path_to_filesystem(location)?; - - let (file, suffix) = new_staged_upload(&dest)?; - Ok(( - suffix.clone(), - Box::new(LocalUpload::new(dest, suffix, Arc::new(file))), - )) - } - - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { + async fn put_multipart(&self, location: &Path) -> Result> { let dest = self.path_to_filesystem(location)?; - let path: PathBuf = staged_upload_path(&dest, multipart_id); - - maybe_spawn_blocking(move || match std::fs::remove_file(&path) { - Ok(_) => Ok(()), - Err(source) => match source.kind() { - ErrorKind::NotFound => Ok(()), // Already deleted - _ => Err(Error::UnableToDeleteFile { path, source }.into()), - }, - }) - .await + let (file, src) = new_staged_upload(&dest)?; + Ok(Box::new(LocalUpload::new(src, dest, file))) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -619,7 +598,10 @@ impl ObjectStore for LocalFileSystem { } Err(source) => match source.kind() { ErrorKind::AlreadyExists => id += 1, - ErrorKind::NotFound => create_parent_dirs(&to, source)?, + 
ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), }, } @@ -634,7 +616,10 @@ impl ObjectStore for LocalFileSystem { match std::fs::rename(&from, &to) { Ok(_) => return Ok(()), Err(source) => match source.kind() { - ErrorKind::NotFound => create_parent_dirs(&to, source)?, + ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), }, } @@ -657,7 +642,10 @@ impl ObjectStore for LocalFileSystem { } .into()) } - ErrorKind::NotFound => create_parent_dirs(&to, source)?, + ErrorKind::NotFound => match from.exists() { + true => create_parent_dirs(&to, source)?, + false => return Err(Error::NotFound { path: from, source }.into()), + }, _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), }, } @@ -677,17 +665,17 @@ fn create_parent_dirs(path: &std::path::Path, source: io::Error) -> Result<()> { Ok(()) } -/// Generates a unique file path `{base}#{suffix}`, returning the opened `File` and `suffix` +/// Generates a unique file path `{base}#{suffix}`, returning the opened `File` and `path` /// /// Creates any directories if necessary -fn new_staged_upload(base: &std::path::Path) -> Result<(File, String)> { +fn new_staged_upload(base: &std::path::Path) -> Result<(File, PathBuf)> { let mut multipart_id = 1; loop { let suffix = multipart_id.to_string(); let path = staged_upload_path(base, &suffix); let mut options = OpenOptions::new(); match options.read(true).write(true).create_new(true).open(&path) { - Ok(f) => return Ok((f, suffix)), + Ok(f) => return Ok((f, path)), Err(source) => match source.kind() { ErrorKind::AlreadyExists => multipart_id += 1, ErrorKind::NotFound => create_parent_dirs(&path, source)?, @@ -705,194 +693,91 @@ fn staged_upload_path(dest: &std::path::Path, suffix: &str) -> PathBuf { staging_path.into() } -enum LocalUploadState { - /// Upload is ready to send new data - Idle(Arc), - /// In the middle of a write - Writing(Arc, BoxFuture<'static, Result>), - /// In the middle of syncing data and closing file. - /// - /// Future will contain last reference to file, so it will call drop on completion. 
- ShuttingDown(BoxFuture<'static, Result<(), io::Error>>), - /// File is being moved from it's temporary location to the final location - Committing(BoxFuture<'static, Result<(), io::Error>>), - /// Upload is complete - Complete, +#[derive(Debug)] +struct LocalUpload { + /// The upload state + state: Arc, + /// The location of the temporary file + src: Option, + /// The next offset to write into the file + offset: u64, } -struct LocalUpload { - inner_state: LocalUploadState, +#[derive(Debug)] +struct UploadState { dest: PathBuf, - multipart_id: MultipartId, + file: Mutex>, } impl LocalUpload { - pub fn new(dest: PathBuf, multipart_id: MultipartId, file: Arc) -> Self { + pub fn new(src: PathBuf, dest: PathBuf, file: File) -> Self { Self { - inner_state: LocalUploadState::Idle(file), - dest, - multipart_id, + state: Arc::new(UploadState { + dest, + file: Mutex::new(Some(file)), + }), + src: Some(src), + offset: 0, } } } -impl AsyncWrite for LocalUpload { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - let invalid_state = |condition: &str| -> Poll> { - Poll::Ready(Err(io::Error::new( - ErrorKind::InvalidInput, - format!("Tried to write to file {condition}."), - ))) - }; +#[async_trait] +impl MultipartUpload for LocalUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let offset = self.offset; + self.offset += data.len() as u64; - if let Ok(runtime) = tokio::runtime::Handle::try_current() { - let mut data: Vec = buf.to_vec(); - let data_len = data.len(); - - loop { - match &mut self.inner_state { - LocalUploadState::Idle(file) => { - let file = Arc::clone(file); - let file2 = Arc::clone(&file); - let data: Vec = std::mem::take(&mut data); - self.inner_state = LocalUploadState::Writing( - file, - Box::pin( - runtime - .spawn_blocking(move || (&*file2).write_all(&data)) - .map(move |res| match res { - Err(err) => Err(io::Error::new(ErrorKind::Other, err)), - Ok(res) => res.map(move |_| data_len), - }), - ), - ); - } - LocalUploadState::Writing(file, inner_write) => { - let res = ready!(inner_write.poll_unpin(cx)); - self.inner_state = LocalUploadState::Idle(Arc::clone(file)); - return Poll::Ready(res); - } - LocalUploadState::ShuttingDown(_) => { - return invalid_state("when writer is shutting down"); - } - LocalUploadState::Committing(_) => { - return invalid_state("when writer is committing data"); - } - LocalUploadState::Complete => { - return invalid_state("when writer is complete"); - } - } - } - } else if let LocalUploadState::Idle(file) = &self.inner_state { - let file = Arc::clone(file); - (&*file).write_all(buf)?; - Poll::Ready(Ok(buf.len())) - } else { - // If we are running on this thread, then only possible states are Idle and Complete. 
- invalid_state("when writer is already complete.") - } + let s = Arc::clone(&self.state); + maybe_spawn_blocking(move || { + let mut f = s.file.lock(); + let file = f.as_mut().context(AbortedSnafu)?; + file.seek(SeekFrom::Start(offset)) + .context(SeekSnafu { path: &s.dest })?; + file.write_all(&data).context(UnableToCopyDataToFileSnafu)?; + Ok(()) + }) + .boxed() } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) + async fn complete(&mut self) -> Result { + let src = self.src.take().context(AbortedSnafu)?; + let s = Arc::clone(&self.state); + maybe_spawn_blocking(move || { + // Ensure no inflight writes + let f = s.file.lock().take().context(AbortedSnafu)?; + std::fs::rename(&src, &s.dest).context(UnableToRenameFileSnafu)?; + let metadata = f.metadata().map_err(|e| Error::Metadata { + source: e.into(), + path: src.to_string_lossy().to_string(), + })?; + + Ok(PutResult { + e_tag: Some(get_etag(&metadata)), + version: None, + }) + }) + .await } - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if let Ok(runtime) = tokio::runtime::Handle::try_current() { - loop { - match &mut self.inner_state { - LocalUploadState::Idle(file) => { - // We are moving file into the future, and it will be dropped on it's completion, closing the file. - let file = Arc::clone(file); - self.inner_state = LocalUploadState::ShuttingDown(Box::pin( - runtime - .spawn_blocking(move || (*file).sync_all()) - .map(move |res| match res { - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)), - Ok(res) => res, - }), - )); - } - LocalUploadState::ShuttingDown(fut) => match fut.poll_unpin(cx) { - Poll::Ready(res) => { - res?; - let staging_path = staged_upload_path(&self.dest, &self.multipart_id); - let dest = self.dest.clone(); - self.inner_state = LocalUploadState::Committing(Box::pin( - runtime - .spawn_blocking(move || std::fs::rename(&staging_path, &dest)) - .map(move |res| match res { - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)), - Ok(res) => res, - }), - )); - } - Poll::Pending => { - return Poll::Pending; - } - }, - LocalUploadState::Writing(_, _) => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Tried to commit a file where a write is in progress.", - ))); - } - LocalUploadState::Committing(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.inner_state = LocalUploadState::Complete; - return Poll::Ready(res); - } - LocalUploadState::Complete => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Already complete", - ))) - } - } - } - } else { - let staging_path = staged_upload_path(&self.dest, &self.multipart_id); - match &mut self.inner_state { - LocalUploadState::Idle(file) => { - let file = Arc::clone(file); - self.inner_state = LocalUploadState::Complete; - file.sync_all()?; - drop(file); - std::fs::rename(staging_path, &self.dest)?; - Poll::Ready(Ok(())) - } - _ => { - // If we are running on this thread, then only possible states are Idle and Complete. 
- Poll::Ready(Err(io::Error::new(ErrorKind::Other, "Already complete"))) - } - } - } + async fn abort(&mut self) -> Result<()> { + let src = self.src.take().context(AbortedSnafu)?; + maybe_spawn_blocking(move || { + std::fs::remove_file(&src).context(UnableToDeleteFileSnafu { path: &src })?; + Ok(()) + }) + .await } } impl Drop for LocalUpload { fn drop(&mut self) { - match self.inner_state { - LocalUploadState::Complete => (), - _ => { - self.inner_state = LocalUploadState::Complete; - let path = staged_upload_path(&self.dest, &self.multipart_id); - // Try to cleanup intermediate file ignoring any error - match tokio::runtime::Handle::try_current() { - Ok(r) => drop(r.spawn_blocking(move || std::fs::remove_file(path))), - Err(_) => drop(std::fs::remove_file(path)), - }; - } + if let Some(src) = self.src.take() { + // Try to clean up intermediate file ignoring any error + match tokio::runtime::Handle::try_current() { + Ok(r) => drop(r.spawn_blocking(move || std::fs::remove_file(src))), + Err(_) => drop(std::fs::remove_file(src)), + }; } } } @@ -1095,12 +980,13 @@ fn convert_walkdir_result( #[cfg(test)] mod tests { - use super::*; - use crate::test_util::flatten_list_stream; - use crate::tests::*; use futures::TryStreamExt; use tempfile::{NamedTempFile, TempDir}; - use tokio::io::AsyncWriteExt; + + use crate::test_util::flatten_list_stream; + use crate::tests::*; + + use super::*; #[tokio::test] async fn file_test() { @@ -1113,6 +999,7 @@ mod tests { list_with_delimiter(&integration).await; rename_and_copy(&integration).await; copy_if_not_exists(&integration).await; + copy_rename_nonexistent_object(&integration).await; stream_get(&integration).await; put_opts(&integration, false).await; } @@ -1125,7 +1012,18 @@ mod tests { put_get_delete_list(&integration).await; list_uses_directories_correctly(&integration).await; list_with_delimiter(&integration).await; - stream_get(&integration).await; + + // Can't use stream_get test as WriteMultipart uses a tokio JoinSet + let p = Path::from("manual_upload"); + let mut upload = integration.put_multipart(&p).await.unwrap(); + upload.put_part(Bytes::from_static(b"123")).await.unwrap(); + upload.put_part(Bytes::from_static(b"45678")).await.unwrap(); + let r = upload.complete().await.unwrap(); + + let get = integration.get(&p).await.unwrap(); + assert_eq!(get.meta.e_tag.as_ref().unwrap(), r.e_tag.as_ref().unwrap()); + let actual = get.bytes().await.unwrap(); + assert_eq!(actual.as_ref(), b"12345678"); }); } @@ -1422,12 +1320,11 @@ mod tests { let location = Path::from("some_file"); let data = Bytes::from("arbitrary data"); - let (multipart_id, mut writer) = integration.put_multipart(&location).await.unwrap(); - writer.write_all(&data).await.unwrap(); + let mut u1 = integration.put_multipart(&location).await.unwrap(); + u1.put_part(data.clone()).await.unwrap(); - let (multipart_id_2, mut writer_2) = integration.put_multipart(&location).await.unwrap(); - assert_ne!(multipart_id, multipart_id_2); - writer_2.write_all(&data).await.unwrap(); + let mut u2 = integration.put_multipart(&location).await.unwrap(); + u2.put_part(data).await.unwrap(); let list = flatten_list_stream(&integration, None).await.unwrap(); assert_eq!(list.len(), 0); @@ -1520,11 +1417,13 @@ mod tests { #[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod not_wasm_tests { - use crate::local::LocalFileSystem; - use crate::{ObjectStore, Path}; use std::time::Duration; + + use bytes::Bytes; use tempfile::TempDir; - use tokio::io::AsyncWriteExt; + + use crate::local::LocalFileSystem; + use 
crate::{ObjectStore, Path}; #[tokio::test] async fn test_cleanup_intermediate_files() { @@ -1532,12 +1431,13 @@ mod not_wasm_tests { let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); let location = Path::from("some_file"); - let (_, mut writer) = integration.put_multipart(&location).await.unwrap(); - writer.write_all(b"hello").await.unwrap(); + let data = Bytes::from_static(b"hello"); + let mut upload = integration.put_multipart(&location).await.unwrap(); + upload.put_part(data).await.unwrap(); let file_count = std::fs::read_dir(root.path()).unwrap().count(); assert_eq!(file_count, 1); - drop(writer); + drop(upload); tokio::time::sleep(Duration::from_millis(1)).await; @@ -1549,13 +1449,15 @@ mod not_wasm_tests { #[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use crate::local::LocalFileSystem; - use crate::{ObjectStore, Path}; + use std::fs::OpenOptions; + use nix::sys::stat; use nix::unistd; - use std::fs::OpenOptions; use tempfile::TempDir; + use crate::local::LocalFileSystem; + use crate::{ObjectStore, Path}; + #[tokio::test] async fn test_fifo() { let filename = "some_file"; diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 41ee1091a3b2..6c960d4f24fb 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -16,27 +16,24 @@ // under the License. //! An in-memory object store implementation -use crate::multipart::{MultiPartStore, PartId}; -use crate::util::InvalidGetRange; -use crate::{ - path::Path, GetRange, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, - PutMode, PutOptions, PutResult, Result, UpdateVersion, -}; -use crate::{GetOptions, MultipartId}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::Range; +use std::sync::Arc; + use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::{stream::BoxStream, StreamExt}; use parking_lot::RwLock; use snafu::{OptionExt, ResultExt, Snafu}; -use std::collections::BTreeSet; -use std::collections::{BTreeMap, HashMap}; -use std::io; -use std::ops::Range; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; -use tokio::io::AsyncWrite; + +use crate::multipart::{MultipartStore, PartId}; +use crate::util::InvalidGetRange; +use crate::GetOptions; +use crate::{ + path::Path, GetRange, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, + ObjectMeta, ObjectStore, PutMode, PutOptions, PutResult, Result, UpdateVersion, UploadPart, +}; /// A specialized `Error` for in-memory object store-related errors #[derive(Debug, Snafu)] @@ -213,23 +210,12 @@ impl ObjectStore for InMemory { }) } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - Ok(( - String::new(), - Box::new(InMemoryUpload { - location: location.clone(), - data: Vec::new(), - storage: Arc::clone(&self.storage), - }), - )) - } - - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> { - // Nothing to clean up - Ok(()) + async fn put_multipart(&self, location: &Path) -> Result> { + Ok(Box::new(InMemoryUpload { + location: location.clone(), + parts: vec![], + storage: Arc::clone(&self.storage), + })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { @@ -391,7 +377,7 @@ impl ObjectStore for InMemory { } #[async_trait] -impl MultiPartStore for InMemory { +impl MultipartStore for InMemory { async fn create_multipart(&self, _path: &Path) -> Result { let mut storage = self.storage.write(); let etag = storage.next_etag; 
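The hunk above switches `InMemory::put_multipart` to return the new `MultipartUpload` handle (its `InMemoryUpload` implementation follows below). The sketch here is not part of this diff; it mirrors the updated `stream_get` test to show a caller driving an upload directly, using `InMemory` purely for illustration. Note that real cloud stores typically require non-final parts to be at least 5 MiB.

```rust
use bytes::Bytes;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{MultipartUpload, ObjectStore};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let path = Path::from("test_multipart");

    let mut upload = store.put_multipart(&path).await?;

    // put_part returns a boxed future per part, so parts can be uploaded
    // concurrently while their indices determine the final ordering.
    let parts: Vec<_> = (0..3)
        .map(|i| upload.put_part(Bytes::from(vec![i as u8; 1024])))
        .collect();
    futures::future::try_join_all(parts).await?;

    // The object only becomes visible once complete() succeeds; abort()
    // would instead discard any staged parts.
    upload.complete().await?;

    let bytes = store.get(&path).await?.bytes().await?;
    assert_eq!(bytes.len(), 3 * 1024);
    Ok(())
}
```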
@@ -482,45 +468,42 @@ impl InMemory { } } +#[derive(Debug)] struct InMemoryUpload { location: Path, - data: Vec, + parts: Vec, storage: Arc>, } -impl AsyncWrite for InMemoryUpload { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - self.data.extend_from_slice(buf); - Poll::Ready(Ok(buf.len())) +#[async_trait] +impl MultipartUpload for InMemoryUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + self.parts.push(data); + Box::pin(futures::future::ready(Ok(()))) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) + async fn complete(&mut self) -> Result { + let cap = self.parts.iter().map(|x| x.len()).sum(); + let mut buf = Vec::with_capacity(cap); + self.parts.iter().for_each(|x| buf.extend_from_slice(x)); + let etag = self.storage.write().insert(&self.location, buf.into()); + Ok(PutResult { + e_tag: Some(etag.to_string()), + version: None, + }) } - fn poll_shutdown( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - let data = Bytes::from(std::mem::take(&mut self.data)); - self.storage.write().insert(&self.location, data); - Poll::Ready(Ok(())) + async fn abort(&mut self) -> Result<()> { + Ok(()) } } #[cfg(test)] mod tests { - use super::*; - use crate::tests::*; + use super::*; + #[tokio::test] async fn in_memory_test() { let integration = InMemory::new(); diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs index 1dcd5a6f4960..26cce3936244 100644 --- a/object_store/src/multipart.rs +++ b/object_store/src/multipart.rs @@ -17,34 +17,16 @@ //! Cloud Multipart Upload //! -//! This crate provides an asynchronous interface for multipart file uploads to cloud storage services. -//! It's designed to offer efficient, non-blocking operations, +//! This crate provides an asynchronous interface for multipart file uploads to +//! cloud storage services. It's designed to offer efficient, non-blocking operations, //! especially useful when dealing with large files or high-throughput systems. use async_trait::async_trait; use bytes::Bytes; -use futures::{stream::FuturesUnordered, Future, StreamExt}; -use std::{io, pin::Pin, sync::Arc, task::Poll}; -use tokio::io::AsyncWrite; use crate::path::Path; use crate::{MultipartId, PutResult, Result}; -type BoxedTryFuture = Pin> + Send>>; - -/// A trait used in combination with [`WriteMultiPart`] to implement -/// [`AsyncWrite`] on top of an API for multipart upload -#[async_trait] -pub trait PutPart: Send + Sync + 'static { - /// Upload a single part - async fn put_part(&self, buf: Vec, part_idx: usize) -> Result; - - /// Complete the upload with the provided parts - /// - /// `completed_parts` is in order of part number - async fn complete(&self, completed_parts: Vec) -> Result<()>; -} - /// Represents a part of a file that has been successfully uploaded in a multipart upload process. #[derive(Debug, Clone)] pub struct PartId { @@ -52,222 +34,6 @@ pub struct PartId { pub content_id: String, } -/// Wrapper around a [`PutPart`] that implements [`AsyncWrite`] -/// -/// Data will be uploaded in fixed size chunks of 10 MiB in parallel, -/// up to the configured maximum concurrency -pub struct WriteMultiPart { - inner: Arc, - /// A list of completed parts, in sequential order. 
- completed_parts: Vec>, - /// Part upload tasks currently running - tasks: FuturesUnordered>, - /// Maximum number of upload tasks to run concurrently - max_concurrency: usize, - /// Buffer that will be sent in next upload. - current_buffer: Vec, - /// Size of each part. - /// - /// While S3 and Minio support variable part sizes, R2 requires they all be - /// exactly the same size. - part_size: usize, - /// Index of current part - current_part_idx: usize, - /// The completion task - completion_task: Option>, -} - -impl WriteMultiPart { - /// Create a new multipart upload with the implementation and the given maximum concurrency - pub fn new(inner: T, max_concurrency: usize) -> Self { - Self { - inner: Arc::new(inner), - completed_parts: Vec::new(), - tasks: FuturesUnordered::new(), - max_concurrency, - current_buffer: Vec::new(), - // TODO: Should self vary by provider? - // TODO: Should we automatically increase then when part index gets large? - - // Minimum size of 5 MiB - // https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html - // https://cloud.google.com/storage/quotas#requests - part_size: 10 * 1024 * 1024, - current_part_idx: 0, - completion_task: None, - } - } - - // Add data to the current buffer, returning the number of bytes added - fn add_to_buffer(mut self: Pin<&mut Self>, buf: &[u8], offset: usize) -> usize { - let remaining_capacity = self.part_size - self.current_buffer.len(); - let to_copy = std::cmp::min(remaining_capacity, buf.len() - offset); - self.current_buffer - .extend_from_slice(&buf[offset..offset + to_copy]); - to_copy - } - - /// Poll current tasks - fn poll_tasks( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Result<(), io::Error> { - if self.tasks.is_empty() { - return Ok(()); - } - while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) { - let (part_idx, part) = res?; - let total_parts = self.completed_parts.len(); - self.completed_parts - .resize(std::cmp::max(part_idx + 1, total_parts), None); - self.completed_parts[part_idx] = Some(part); - } - Ok(()) - } - - // The `poll_flush` function will only flush the in-progress tasks. - // The `final_flush` method called during `poll_shutdown` will flush - // the `current_buffer` along with in-progress tasks. - // Please see https://github.com/apache/arrow-rs/issues/3390 for more details. 
- fn final_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - // Poll current tasks - self.as_mut().poll_tasks(cx)?; - - // If current_buffer is not empty, see if it can be submitted - if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency { - let out_buffer: Vec = std::mem::take(&mut self.current_buffer); - let inner = Arc::clone(&self.inner); - let part_idx = self.current_part_idx; - self.tasks.push(Box::pin(async move { - let upload_part = inner.put_part(out_buffer, part_idx).await?; - Ok((part_idx, upload_part)) - })); - } - - self.as_mut().poll_tasks(cx)?; - - // If tasks and current_buffer are empty, return Ready - if self.tasks.is_empty() && self.current_buffer.is_empty() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } -} - -impl AsyncWrite for WriteMultiPart { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - // Poll current tasks - self.as_mut().poll_tasks(cx)?; - - let mut offset = 0; - - loop { - // Fill up current buffer - offset += self.as_mut().add_to_buffer(buf, offset); - - // If we don't have a full buffer or we have too many tasks, break - if self.current_buffer.len() < self.part_size - || self.tasks.len() >= self.max_concurrency - { - break; - } - - let new_buffer = Vec::with_capacity(self.part_size); - let out_buffer = std::mem::replace(&mut self.current_buffer, new_buffer); - let inner = Arc::clone(&self.inner); - let part_idx = self.current_part_idx; - self.tasks.push(Box::pin(async move { - let upload_part = inner.put_part(out_buffer, part_idx).await?; - Ok((part_idx, upload_part)) - })); - self.current_part_idx += 1; - - // We need to poll immediately after adding to setup waker - self.as_mut().poll_tasks(cx)?; - } - - // If offset is zero, then we didn't write anything because we didn't - // have capacity for more tasks and our buffer is full. 
- if offset == 0 && !buf.is_empty() { - Poll::Pending - } else { - Poll::Ready(Ok(offset)) - } - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - // Poll current tasks - self.as_mut().poll_tasks(cx)?; - - // If tasks is empty, return Ready - if self.tasks.is_empty() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - // First, poll flush - match self.as_mut().final_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(res) => res?, - }; - - // If shutdown task is not set, set it - let parts = std::mem::take(&mut self.completed_parts); - let parts = parts - .into_iter() - .enumerate() - .map(|(idx, part)| { - part.ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - format!("Missing information for upload part {idx}"), - ) - }) - }) - .collect::>()?; - - let inner = Arc::clone(&self.inner); - let completion_task = self.completion_task.get_or_insert_with(|| { - Box::pin(async move { - inner.complete(parts).await?; - Ok(()) - }) - }); - - Pin::new(completion_task).poll(cx) - } -} - -impl std::fmt::Debug for WriteMultiPart { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("WriteMultiPart") - .field("completed_parts", &self.completed_parts) - .field("tasks", &self.tasks) - .field("max_concurrency", &self.max_concurrency) - .field("current_buffer", &self.current_buffer) - .field("part_size", &self.part_size) - .field("current_part_idx", &self.current_part_idx) - .finish() - } -} - /// A low-level interface for interacting with multipart upload APIs /// /// Most use-cases should prefer [`ObjectStore::put_multipart`] as this is supported by more @@ -277,7 +43,7 @@ impl std::fmt::Debug for WriteMultiPart { /// [`ObjectStore::put_multipart`]: crate::ObjectStore::put_multipart /// [`LocalFileSystem`]: crate::local::LocalFileSystem #[async_trait] -pub trait MultiPartStore: Send + Sync + 'static { +pub trait MultipartStore: Send + Sync + 'static { /// Creates a new multipart upload, returning the [`MultipartId`] async fn create_multipart(&self, path: &Path) -> Result; @@ -288,10 +54,11 @@ pub trait MultiPartStore: Send + Sync + 'static { /// /// Most stores require that all parts excluding the last are at least 5 MiB, and some /// further require that all parts excluding the last be the same size, e.g. [R2]. - /// [`WriteMultiPart`] performs writes in fixed size blocks of 10 MiB, and clients wanting + /// [`WriteMultipart`] performs writes in fixed size blocks of 5 MiB, and clients wanting /// to maximise compatibility should look to do likewise. 
/// /// [R2]: https://developers.cloudflare.com/r2/objects/multipart-objects/#limitations + /// [`WriteMultipart`]: crate::upload::WriteMultipart async fn put_part( &self, path: &Path, diff --git a/object_store/src/parse.rs b/object_store/src/parse.rs index 116c2ad2ac0e..5549fd3a3e5f 100644 --- a/object_store/src/parse.rs +++ b/object_store/src/parse.rs @@ -311,14 +311,14 @@ mod tests { #[cfg(feature = "http")] async fn test_url_http() { use crate::client::mock_server::MockServer; - use hyper::{header::USER_AGENT, Body, Response}; + use hyper::{header::USER_AGENT, Response}; - let server = MockServer::new(); + let server = MockServer::new().await; server.push_fn(|r| { assert_eq!(r.uri().path(), "/foo/bar"); assert_eq!(r.headers().get(USER_AGENT).unwrap(), "test_url"); - Response::new(Body::empty()) + Response::new(String::new()) }); let test = format!("{}/foo/bar", server.url()); diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 38f9b07bbd05..053f71a2d063 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -19,12 +19,11 @@ use bytes::Bytes; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use std::ops::Range; -use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, - Result, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutOptions, + PutResult, Result, }; #[doc(hidden)] @@ -91,18 +90,11 @@ impl ObjectStore for PrefixStore { self.inner.put_opts(&full_path, bytes, opts).await } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { + async fn put_multipart(&self, location: &Path) -> Result> { let full_path = self.full_path(location); self.inner.put_multipart(&full_path).await } - async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> { - let full_path = self.full_path(location); - self.inner.abort_multipart(&full_path, multipart_id).await - } async fn get(&self, location: &Path) -> Result { let full_path = self.full_path(location); self.inner.get(&full_path).await diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 252256a4599e..65fac5922f69 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -20,16 +20,16 @@ use parking_lot::Mutex; use std::ops::Range; use std::{convert::TryInto, sync::Arc}; +use crate::multipart::{MultipartStore, PartId}; use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, - PutResult, Result, + path::Path, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta, + ObjectStore, PutOptions, PutResult, Result, }; -use crate::{GetOptions, MultipartId}; +use crate::{GetOptions, UploadPart}; use async_trait::async_trait; use bytes::Bytes; use futures::{stream::BoxStream, FutureExt, StreamExt}; use std::time::Duration; -use tokio::io::AsyncWrite; /// Configuration settings for throttled store #[derive(Debug, Default, Clone, Copy)] @@ -111,12 +111,12 @@ async fn sleep(duration: Duration) { /// **Note that the behavior of the wrapper is deterministic and might not reflect real-world /// conditions!** #[derive(Debug)] -pub struct ThrottledStore { +pub struct ThrottledStore { inner: T, config: Arc>, } -impl ThrottledStore { +impl ThrottledStore { /// Create new wrapper with zero waiting times. 
pub fn new(inner: T, config: ThrottleConfig) -> Self { Self { @@ -158,15 +158,12 @@ impl ObjectStore for ThrottledStore { self.inner.put_opts(location, bytes, opts).await } - async fn put_multipart( - &self, - _location: &Path, - ) -> Result<(MultipartId, Box)> { - Err(super::Error::NotImplemented) - } - - async fn abort_multipart(&self, _location: &Path, _multipart_id: &MultipartId) -> Result<()> { - Err(super::Error::NotImplemented) + async fn put_multipart(&self, location: &Path) -> Result> { + let upload = self.inner.put_multipart(location).await?; + Ok(Box::new(ThrottledUpload { + upload, + sleep: self.config().wait_put_per_call, + })) } async fn get(&self, location: &Path) -> Result { @@ -324,6 +321,63 @@ where .boxed() } +#[async_trait] +impl MultipartStore for ThrottledStore { + async fn create_multipart(&self, path: &Path) -> Result { + self.inner.create_multipart(path).await + } + + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + data: Bytes, + ) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put_part(path, id, part_idx, data).await + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result { + self.inner.complete_multipart(path, id, parts).await + } + + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> { + self.inner.abort_multipart(path, id).await + } +} + +#[derive(Debug)] +struct ThrottledUpload { + upload: Box, + sleep: Duration, +} + +#[async_trait] +impl MultipartUpload for ThrottledUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let duration = self.sleep; + let put = self.upload.put_part(data); + Box::pin(async move { + sleep(duration).await; + put.await + }) + } + + async fn complete(&mut self) -> Result { + self.upload.complete().await + } + + async fn abort(&mut self) -> Result<()> { + self.upload.abort().await + } +} + #[cfg(test)] mod tests { use super::*; @@ -359,6 +413,8 @@ mod tests { list_with_delimiter(&store).await; rename_and_copy(&store).await; copy_if_not_exists(&store).await; + stream_get(&store).await; + multipart(&store, &store).await; } #[tokio::test] diff --git a/object_store/src/upload.rs b/object_store/src/upload.rs new file mode 100644 index 000000000000..fe864e2821c9 --- /dev/null +++ b/object_store/src/upload.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::ready; +use tokio::task::JoinSet; + +use crate::{PutResult, Result}; + +/// An upload part request +pub type UploadPart = BoxFuture<'static, Result<()>>; + +/// A trait allowing writing an object in fixed size chunks +/// +/// Consecutive chunks of data can be written by calling [`MultipartUpload::put_part`] and polling +/// the returned futures to completion. Multiple futures returned by [`MultipartUpload::put_part`] +/// may be polled in parallel, allowing for concurrent uploads. +/// +/// Once all part uploads have been polled to completion, the upload can be completed by +/// calling [`MultipartUpload::complete`]. This will make the entire uploaded object visible +/// as an atomic operation. It is implementation defined behaviour if [`MultipartUpload::complete`] +/// is called before all [`UploadPart`] have been polled to completion. +#[async_trait] +pub trait MultipartUpload: Send + std::fmt::Debug { + /// Upload the next part + /// + /// Most stores require that all parts excluding the last are at least 5 MiB, and some + /// further require that all parts excluding the last be the same size, e.g. [R2]. + /// Clients wanting to maximise compatibility should therefore perform writes in + /// fixed size blocks larger than 5 MiB. + /// + /// Implementations may invoke this method multiple times and then await on the + /// returned futures in parallel + /// + /// ```no_run + /// # use futures::StreamExt; + /// # use object_store::MultipartUpload; + /// # + /// # async fn test() { + /// # + /// let mut upload: Box<dyn MultipartUpload> = todo!(); + /// let p1 = upload.put_part(vec![0; 10 * 1024 * 1024].into()); + /// let p2 = upload.put_part(vec![1; 10 * 1024 * 1024].into()); + /// futures::future::try_join(p1, p2).await.unwrap(); + /// upload.complete().await.unwrap(); + /// # } + /// ``` + /// + /// [R2]: https://developers.cloudflare.com/r2/objects/multipart-objects/#limitations + fn put_part(&mut self, data: Bytes) -> UploadPart; + + /// Complete the multipart upload + /// + /// It is implementation defined behaviour if this method is called before polling + /// all [`UploadPart`] returned by [`MultipartUpload::put_part`] to completion. Additionally, + /// it is implementation defined behaviour to call [`MultipartUpload::complete`] + /// on an already completed or aborted [`MultipartUpload`]. + async fn complete(&mut self) -> Result<PutResult>; + + /// Abort the multipart upload + /// + /// If a [`MultipartUpload`] is dropped without calling [`MultipartUpload::complete`], + /// some object stores will automatically clean up any previously uploaded parts. + /// However, some stores, such as S3 and GCS, cannot perform cleanup on drop. + /// As such [`MultipartUpload::abort`] can be invoked to perform this cleanup. + /// + /// It will not be possible to call `abort` in all failure scenarios, for example + /// non-graceful shutdown of the calling application. It is therefore recommended + /// object stores are configured with lifecycle rules to automatically cleanup + /// unused parts older than some threshold. See [crate::aws] and [crate::gcp] + /// for more information. 
+ /// + /// It is implementation defined behaviour to call [`MultipartUpload::abort`] + /// on an already completed or aborted [`MultipartUpload`] + async fn abort(&mut self) -> Result<()>; +} + +/// A synchronous write API for uploading data in parallel in fixed size chunks +/// +/// Uses multiple tokio tasks in a [`JoinSet`] to multiplex upload tasks in parallel +/// +/// The design also takes inspiration from [`Sink`] with [`WriteMultipart::wait_for_capacity`] +/// allowing back pressure on producers, prior to buffering the next part. However, unlike +/// [`Sink`] this back pressure is optional, allowing integration with synchronous producers +/// +/// [`Sink`]: futures::sink::Sink +#[derive(Debug)] +pub struct WriteMultipart { + upload: Box<dyn MultipartUpload>, + + buffer: Vec<u8>, + + tasks: JoinSet<Result<()>>, +} + +impl WriteMultipart { + /// Create a new [`WriteMultipart`] that will upload using 5MB chunks + pub fn new(upload: Box<dyn MultipartUpload>) -> Self { + Self::new_with_chunk_size(upload, 5 * 1024 * 1024) + } + + /// Create a new [`WriteMultipart`] that will upload in fixed `chunk_size` sized chunks + pub fn new_with_chunk_size(upload: Box<dyn MultipartUpload>, chunk_size: usize) -> Self { + Self { + upload, + buffer: Vec::with_capacity(chunk_size), + tasks: Default::default(), + } + } + + /// Polls for there to be less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::wait_for_capacity`] for an async version of this function + pub fn poll_for_capacity( + &mut self, + cx: &mut Context<'_>, + max_concurrency: usize, + ) -> Poll<Result<()>> { + while !self.tasks.is_empty() && self.tasks.len() >= max_concurrency { + ready!(self.tasks.poll_join_next(cx)).unwrap()?? + } + Poll::Ready(Ok(())) + } + + /// Wait until there are less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::poll_for_capacity`] for a [`Poll`] version of this function + pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> { + futures::future::poll_fn(|cx| self.poll_for_capacity(cx, max_concurrency)).await + } + + /// Write data to this [`WriteMultipart`] + /// + /// Note this method is synchronous (not `async`) and will immediately + /// start new uploads as soon as the internal `chunk_size` is hit, + /// regardless of how many outstanding uploads are already in progress. + /// + /// Back pressure can optionally be applied to producers by calling + /// [`Self::wait_for_capacity`] prior to calling this method + pub fn write(&mut self, mut buf: &[u8]) { + while !buf.is_empty() { + let capacity = self.buffer.capacity(); + let remaining = capacity - self.buffer.len(); + let to_read = buf.len().min(remaining); + self.buffer.extend_from_slice(&buf[..to_read]); + if to_read == remaining { + let part = std::mem::replace(&mut self.buffer, Vec::with_capacity(capacity)); + self.put_part(part.into()) + } + buf = &buf[to_read..] 
+ } + } + + fn put_part(&mut self, part: Bytes) { + self.tasks.spawn(self.upload.put_part(part)); + } + + /// Abort this upload, attempting to clean up any successfully uploaded parts + pub async fn abort(mut self) -> Result<()> { + self.tasks.shutdown().await; + self.upload.abort().await + } + + /// Flush final chunk, and await completion of all in-flight requests + pub async fn finish(mut self) -> Result { + if !self.buffer.is_empty() { + let part = std::mem::take(&mut self.buffer); + self.put_part(part.into()) + } + + self.wait_for_capacity(0).await?; + self.upload.complete().await + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures::FutureExt; + + use crate::memory::InMemory; + use crate::path::Path; + use crate::throttle::{ThrottleConfig, ThrottledStore}; + use crate::ObjectStore; + + use super::*; + + #[tokio::test] + async fn test_concurrency() { + let config = ThrottleConfig { + wait_put_per_call: Duration::from_millis(1), + ..Default::default() + }; + + let path = Path::from("foo"); + let store = ThrottledStore::new(InMemory::new(), config); + let upload = store.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new_with_chunk_size(upload, 10); + + for _ in 0..20 { + write.write(&[0; 5]); + } + assert!(write.wait_for_capacity(10).now_or_never().is_none()); + write.wait_for_capacity(10).await.unwrap() + } +} diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index f73d78578f08..309a86d8fe9d 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -25,7 +25,6 @@ use object_store::path::Path; use object_store::*; use std::fmt::Formatter; use tempfile::tempdir; -use tokio::io::AsyncWrite; #[derive(Debug)] struct MyStore(LocalFileSystem); @@ -42,14 +41,7 @@ impl ObjectStore for MyStore { self.0.put_opts(path, data, opts).await } - async fn put_multipart( - &self, - _: &Path, - ) -> Result<(MultipartId, Box)> { - todo!() - } - - async fn abort_multipart(&self, _: &Path, _: &MultipartId) -> Result<()> { + async fn put_multipart(&self, _location: &Path) -> Result> { todo!() } diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index e6d612e0cc62..04dfed408c02 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -53,6 +53,9 @@ brotli = { version = "3.3", default-features = false, features = ["std"], option flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } zstd = { version = "0.13.0", optional = true, default-features = false } +# TODO: temporary to fix parquet wasm build +# upstream issue: https://github.com/gyscos/zstd-rs/issues/269 +zstd-sys = { version = "=2.0.9", optional = true, default-features = false } chrono = { workspace = true } num = { version = "0.4", default-features = false } num-bigint = { version = "0.4", default-features = false } @@ -77,6 +80,9 @@ brotli = { version = "3.3", default-features = false, features = ["std"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"] } zstd = { version = "0.13", default-features = false } +# TODO: temporary to fix parquet wasm build +# upstream issue: https://github.com/gyscos/zstd-rs/issues/269 +zstd-sys = { version = "=2.0.9", default-features = false } serde_json = { version = "1.0", features = ["std"], default-features = false } 
arrow = { workspace = true, features = ["ipc", "test_utils", "prettyprint", "json"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "io-util", "fs"] } @@ -104,6 +110,8 @@ experimental = [] async = ["futures", "tokio"] # Enable object_store integration object_store = ["dep:object_store", "async"] +# Group Zstd dependencies +zstd = ["dep:zstd", "zstd-sys"] [[example]] name = "read_parquet" diff --git a/parquet/pytest/requirements.txt b/parquet/pytest/requirements.txt index 7462e8ff3b0d..aa91a8bb6415 100644 --- a/parquet/pytest/requirements.txt +++ b/parquet/pytest/requirements.txt @@ -24,28 +24,29 @@ attrs==22.1.0 \ --hash=sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6 \ --hash=sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c # via pytest -black==22.10.0 \ - --hash=sha256:14ff67aec0a47c424bc99b71005202045dc09270da44a27848d534600ac64fc7 \ - --hash=sha256:197df8509263b0b8614e1df1756b1dd41be6738eed2ba9e9769f3880c2b9d7b6 \ - --hash=sha256:1e464456d24e23d11fced2bc8c47ef66d471f845c7b7a42f3bd77bf3d1789650 \ - --hash=sha256:2039230db3c6c639bd84efe3292ec7b06e9214a2992cd9beb293d639c6402edb \ - --hash=sha256:21199526696b8f09c3997e2b4db8d0b108d801a348414264d2eb8eb2532e540d \ - --hash=sha256:2644b5d63633702bc2c5f3754b1b475378fbbfb481f62319388235d0cd104c2d \ - --hash=sha256:432247333090c8c5366e69627ccb363bc58514ae3e63f7fc75c54b1ea80fa7de \ - --hash=sha256:444ebfb4e441254e87bad00c661fe32df9969b2bf224373a448d8aca2132b395 \ - --hash=sha256:5b9b29da4f564ba8787c119f37d174f2b69cdfdf9015b7d8c5c16121ddc054ae \ - --hash=sha256:5cc42ca67989e9c3cf859e84c2bf014f6633db63d1cbdf8fdb666dcd9e77e3fa \ - --hash=sha256:5d8f74030e67087b219b032aa33a919fae8806d49c867846bfacde57f43972ef \ - --hash=sha256:72ef3925f30e12a184889aac03d77d031056860ccae8a1e519f6cbb742736383 \ - --hash=sha256:819dc789f4498ecc91438a7de64427c73b45035e2e3680c92e18795a839ebb66 \ - --hash=sha256:915ace4ff03fdfff953962fa672d44be269deb2eaf88499a0f8805221bc68c87 \ - --hash=sha256:9311e99228ae10023300ecac05be5a296f60d2fd10fff31cf5c1fa4ca4b1988d \ - --hash=sha256:974308c58d057a651d182208a484ce80a26dac0caef2895836a92dd6ebd725e0 \ - --hash=sha256:b8b49776299fece66bffaafe357d929ca9451450f5466e997a7285ab0fe28e3b \ - --hash=sha256:c957b2b4ea88587b46cf49d1dc17681c1e672864fd7af32fc1e9664d572b3458 \ - --hash=sha256:e41a86c6c650bcecc6633ee3180d80a025db041a8e2398dcc059b3afa8382cd4 \ - --hash=sha256:f513588da599943e0cde4e32cc9879e825d58720d6557062d1098c5ad80080e1 \ - --hash=sha256:fba8a281e570adafb79f7755ac8721b6cf1bbf691186a287e990c7929c7692ff +black==24.3.0 \ + --hash=sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f \ + --hash=sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93 \ + --hash=sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11 \ + --hash=sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0 \ + --hash=sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9 \ + --hash=sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5 \ + --hash=sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213 \ + --hash=sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d \ + --hash=sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7 \ + --hash=sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837 \ + --hash=sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f \ + 
--hash=sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395 \ + --hash=sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995 \ + --hash=sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f \ + --hash=sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597 \ + --hash=sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959 \ + --hash=sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5 \ + --hash=sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb \ + --hash=sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4 \ + --hash=sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7 \ + --hash=sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd \ + --hash=sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7 # via -r requirements.in click==8.1.3 \ --hash=sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e \ @@ -93,10 +94,12 @@ numpy==1.23.5 \ --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ --hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 # via pandas -packaging==21.3 \ - --hash=sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb \ - --hash=sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522 - # via pytest +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # black + # pytest pandas==1.5.2 \ --hash=sha256:0183cb04a057cc38fde5244909fca9826d5d57c4a5b7390c0cc3fa7acd9fa883 \ --hash=sha256:1fc87eac0541a7d24648a001d553406f4256e744d92df1df8ebe41829a915028 \ @@ -142,10 +145,6 @@ py4j==0.10.9.5 \ --hash=sha256:276a4a3c5a2154df1860ef3303a927460e02e97b047dc0a47c1c3fb8cce34db6 \ --hash=sha256:52d171a6a2b031d8a5d1de6efe451cf4f5baff1a2819aabc3741c8406539ba04 # via pyspark -pyparsing==3.0.9 \ - --hash=sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb \ - --hash=sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc - # via packaging pyspark==3.3.1 \ --hash=sha256:e99fa7de92be406884bfd831c32b9306a3a99de44cfc39a2eefb6ed07445d5fa # via -r requirements.in @@ -171,3 +170,7 @@ tomli==2.0.1 \ # via # black # pytest +typing-extensions==4.10.0 \ + --hash=sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475 \ + --hash=sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb + # via black diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 18c8617e07e6..7206e6a6907d 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -1200,7 +1200,7 @@ mod tests { let schema = Schema::new(vec![string_field, binary_field]); let raw_string_values = vec!["foo", "bar", "baz", "quux"]; - let raw_binary_values = vec![ + let raw_binary_values = [ b"foo".to_vec(), b"bar".to_vec(), b"baz".to_vec(), diff --git a/parquet/src/column/mod.rs b/parquet/src/column/mod.rs index c81d6290abc2..fc3c5cf34221 100644 --- a/parquet/src/column/mod.rs +++ b/parquet/src/column/mod.rs @@ -84,9 +84,9 @@ //! let reader = SerializedFileReader::new(file).unwrap(); //! let metadata = reader.metadata(); //! -//! let mut values = vec![0; 8]; -//! let mut def_levels = vec![0; 8]; -//! let mut rep_levels = vec![0; 8]; +//! let mut values = vec![]; +//! 
let mut def_levels = vec![]; +//! let mut rep_levels = vec![]; //! //! for i in 0..metadata.num_row_groups() { //! let row_group_reader = reader.get_row_group(i).unwrap(); @@ -112,9 +112,9 @@ //! } //! } //! -//! assert_eq!(values, vec![1, 2, 3, 0, 0, 0, 0, 0]); -//! assert_eq!(def_levels, vec![3, 3, 3, 2, 2, 0, 0, 0]); -//! assert_eq!(rep_levels, vec![0, 1, 0, 1, 1, 0, 0, 0]); +//! assert_eq!(values, vec![1, 2, 3]); +//! assert_eq!(def_levels, vec![3, 3, 3, 2, 2]); +//! assert_eq!(rep_levels, vec![0, 1, 0, 1, 1]); //! ``` pub mod page; diff --git a/parquet/src/compression.rs b/parquet/src/compression.rs index 89f4b64d48b5..10560210e4e8 100644 --- a/parquet/src/compression.rs +++ b/parquet/src/compression.rs @@ -145,21 +145,40 @@ pub(crate) trait CompressionLevel { /// bytes for the compression type. /// This returns `None` if the codec type is `UNCOMPRESSED`. pub fn create_codec(codec: CodecType, _options: &CodecOptions) -> Result>> { + #[allow(unreachable_code, unused_variables)] match codec { - #[cfg(any(feature = "brotli", test))] - CodecType::BROTLI(level) => Ok(Some(Box::new(BrotliCodec::new(level)))), - #[cfg(any(feature = "flate2", test))] - CodecType::GZIP(level) => Ok(Some(Box::new(GZipCodec::new(level)))), - #[cfg(any(feature = "snap", test))] - CodecType::SNAPPY => Ok(Some(Box::new(SnappyCodec::new()))), - #[cfg(any(feature = "lz4", test))] - CodecType::LZ4 => Ok(Some(Box::new(LZ4HadoopCodec::new( - _options.backward_compatible_lz4, - )))), - #[cfg(any(feature = "zstd", test))] - CodecType::ZSTD(level) => Ok(Some(Box::new(ZSTDCodec::new(level)))), - #[cfg(any(feature = "lz4", test))] - CodecType::LZ4_RAW => Ok(Some(Box::new(LZ4RawCodec::new()))), + CodecType::BROTLI(level) => { + #[cfg(any(feature = "brotli", test))] + return Ok(Some(Box::new(BrotliCodec::new(level)))); + Err(ParquetError::General("Disabled feature at compile time: brotli".into())) + }, + CodecType::GZIP(level) => { + #[cfg(any(feature = "flate2", test))] + return Ok(Some(Box::new(GZipCodec::new(level)))); + Err(ParquetError::General("Disabled feature at compile time: flate2".into())) + }, + CodecType::SNAPPY => { + #[cfg(any(feature = "snap", test))] + return Ok(Some(Box::new(SnappyCodec::new()))); + Err(ParquetError::General("Disabled feature at compile time: snap".into())) + }, + CodecType::LZ4 => { + #[cfg(any(feature = "lz4", test))] + return Ok(Some(Box::new(LZ4HadoopCodec::new( + _options.backward_compatible_lz4, + )))); + Err(ParquetError::General("Disabled feature at compile time: lz4".into())) + }, + CodecType::ZSTD(level) => { + #[cfg(any(feature = "zstd", test))] + return Ok(Some(Box::new(ZSTDCodec::new(level)))); + Err(ParquetError::General("Disabled feature at compile time: zstd".into())) + }, + CodecType::LZ4_RAW => { + #[cfg(any(feature = "lz4", test))] + return Ok(Some(Box::new(LZ4RawCodec::new()))); + Err(ParquetError::General("Disabled feature at compile time: lz4".into())) + }, CodecType::UNCOMPRESSED => Ok(None), _ => Err(nyi_err!("The codec type {} is not supported yet", codec)), }
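// A minimal sketch (not part of this diff) showing how the new `WriteMultipart` helper from
// object_store/src/upload.rs replaces the removed `WriteMultiPart` AsyncWrite adapter: it wraps a
// `Box<dyn MultipartUpload>`, buffers synchronous writes into fixed-size chunks, and offers
// optional back pressure. It assumes `WriteMultipart` is re-exported at the crate root alongside
// `MultipartUpload`; the exact re-exports are not shown in this diff.
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, WriteMultipart};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let location = Path::from("data/large-object");

    let upload = store.put_multipart(&location).await?;
    // Defaults to 5 MiB chunks; `new_with_chunk_size` can override this.
    let mut writer = WriteMultipart::new(upload);

    for chunk in std::iter::repeat(vec![0u8; 1024]).take(16) {
        // Optional back pressure: wait until fewer than 8 parts are in flight.
        // (A no-op here, since 16 KiB never fills a 5 MiB chunk, but this is the
        // intended pattern when streaming large amounts of data.)
        writer.wait_for_capacity(8).await?;
        writer.write(&chunk);
    }

    // Flush the final partial chunk and await all in-flight part uploads.
    let result = writer.finish().await?;
    println!("etag: {:?}", result.e_tag);
    Ok(())
}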