Commit 7d18a2b

Merge remote-tracking branch 'apache/master' into parquet-sbbf-docs

alamb committed May 4, 2024
2 parents bf18d68 + b3f06f6

Showing 11 changed files with 178 additions and 54 deletions.
11 changes: 10 additions & 1 deletion arrow-cast/src/parse.rs
@@ -435,11 +435,20 @@ impl Parser for Float64Type {
}
}

/// This API is only stable since Rust 1.70, so it can't be used while the current MSRV is lower
#[inline(always)]
fn is_some_and<T>(opt: Option<T>, f: impl FnOnce(T) -> bool) -> bool {
match opt {
None => false,
Some(x) => f(x),
}
}

macro_rules! parser_primitive {
($t:ty) => {
impl Parser for $t {
fn parse(string: &str) -> Option<Self::Native> {
if !string.as_bytes().last().is_some_and(|x| x.is_ascii_digit()) {
if !is_some_and(string.as_bytes().last(), |x| x.is_ascii_digit()) {
return None;
}
match atoi::FromRadix10SignedChecked::from_radix_10_signed_checked(
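For context (not part of the diff): the helper added above is a stand-in for `Option::is_some_and`, which was only stabilized in Rust 1.70 and therefore can't be used while the crate's MSRV is lower. A minimal standalone sketch of the equivalence (the `main` function is illustrative only):

fn is_some_and<T>(opt: Option<T>, f: impl FnOnce(T) -> bool) -> bool {
    match opt {
        None => false,
        Some(x) => f(x),
    }
}

fn main() {
    let s = "123";
    // Polyfill used while the MSRV is below 1.70:
    assert!(is_some_and(s.as_bytes().last(), |x| x.is_ascii_digit()));
    // Equivalent std call once the MSRV reaches 1.70:
    assert!(s.as_bytes().last().is_some_and(|x| x.is_ascii_digit()));
}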
42 changes: 30 additions & 12 deletions arrow-flight/src/client.rs
@@ -417,9 +417,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_exchange requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `RecordBatches`
/// let response: Vec<RecordBatch> = client
@@ -431,20 +429,40 @@ impl FlightClient {
/// .expect("error calling do_exchange");
/// # }
/// ```
pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let request = self.make_request(request);
let (sender, mut receiver) = futures::channel::oneshot::channel();

let response = self
.inner
.do_exchange(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

Ok(FlightRecordBatchStream::new_from_flight_data(response))
// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
}

/// Make a `ListFlights` call to the server with the provided
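For context (not part of the diff): because `do_exchange` now accepts a stream of `Result<FlightData>`, the output of `FlightDataEncoderBuilder` can be passed in without the previous `.map(|batch| batch.unwrap())` step. A hedged call-site sketch, restating the updated doc example; the wrapper function and imports are illustrative, not code from this commit:

use arrow_array::RecordBatch;
use arrow_flight::{encode::FlightDataEncoderBuilder, error::FlightError, FlightClient};
use futures::TryStreamExt;

async fn exchange_one(
    client: &mut FlightClient,
    batch: RecordBatch,
) -> Result<Vec<RecordBatch>, FlightError> {
    // The encoder already yields Result<FlightData>, which the new
    // do_exchange signature accepts directly.
    let flight_data = FlightDataEncoderBuilder::new()
        .build(futures::stream::iter(vec![Ok(batch)]));

    // Send the request stream and collect the decoded response batches.
    client.do_exchange(flight_data).await?.try_collect().await
}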
97 changes: 94 additions & 3 deletions arrow-flight/tests/client.rs
@@ -493,7 +493,7 @@ async fn test_do_exchange() {
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

@@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
let input_flight_data = test_flight_data().await;

let response = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await;
let response = match response {
Ok(_) => panic!("unexpected success"),
@@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
test_server.set_do_exchange_response(response);

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

@@ -593,6 +593,97 @@
.await;
}

#[tokio::test]
async fn test_do_exchange_error_stream_client() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e = Status::invalid_argument("bad arg: client");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e.clone(),
))]));

let output_flight_data = FlightData::new()
.with_descriptor(FlightDescriptor::new_cmd("Sample command"))
.with_data_body("body".as_bytes())
.with_data_header("header".as_bytes())
.with_app_metadata("metadata".as_bytes());

// server responds with one good message
let response = vec![Ok(output_flight_data)];
test_server.set_do_exchange_response(response);

let response_stream = client
.do_exchange(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect the error made by the client
expect_status(response, e);
// server still got the request messages until the client sent the error
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_do_exchange_error_client_and_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e_client = Status::invalid_argument("bad arg: client");
let e_server = Status::invalid_argument("bad arg: server");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e_client.clone(),
))]));

// server responds with an error (e.g. because it got truncated data)
let response = vec![Err(e_server)];
test_server.set_do_exchange_response(response);

let response_stream = client
.do_exchange(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect the error made by the client (not the server)
expect_status(response, e_client);
// server still got the request messages until the client sent the error
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_get_schema() {
do_test(|test_server, mut client| async move {
@@ -40,7 +40,13 @@ struct Args {
arrow: String,
#[clap(short, long, help("Path to JSON file"))]
json: String,
#[clap(value_enum, short, long, default_value_t = Mode::Validate, help="Mode of integration testing tool")]
#[clap(
value_enum,
short,
long,
default_value = "VALIDATE",
help = "Mode of integration testing tool"
)]
mode: Mode,
#[clap(short, long)]
verbose: bool,
28 changes: 12 additions & 16 deletions arrow-json/src/reader/serializer.rs
@@ -309,16 +309,16 @@ impl<'a, 'b> SerializeMap for ObjectSerializer<'a, 'b> {
type Ok = ();
type Error = SerializerError;

fn serialize_key<T: ?Sized>(&mut self, key: &T) -> Result<(), Self::Error>
fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
key.serialize(&mut *self.serializer)
}

fn serialize_value<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(&mut *self.serializer)
}
@@ -333,13 +333,9 @@ impl<'a, 'b> SerializeStruct for ObjectSerializer<'a, 'b> {
type Ok = ();
type Error = SerializerError;

fn serialize_field<T: ?Sized>(
&mut self,
key: &'static str,
value: &T,
) -> Result<(), Self::Error>
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
key.serialize(&mut *self.serializer)?;
value.serialize(&mut *self.serializer)
@@ -376,9 +372,9 @@ impl<'a, 'b> SerializeSeq for ListSerializer<'a, 'b> {
type Ok = ();
type Error = SerializerError;

fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(&mut *self.serializer)
}
@@ -393,9 +389,9 @@ impl<'a, 'b> SerializeTuple for ListSerializer<'a, 'b> {
type Ok = ();
type Error = SerializerError;

fn serialize_element<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(&mut *self.serializer)
}
@@ -410,9 +406,9 @@ impl<'a, 'b> SerializeTupleStruct for ListSerializer<'a, 'b> {
type Ok = ();
type Error = SerializerError;

fn serialize_field<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: Serialize,
T: Serialize + ?Sized,
{
value.serialize(&mut *self.serializer)
}
17 changes: 13 additions & 4 deletions arrow-json/src/writer/encoder.rs
@@ -155,12 +155,21 @@ struct StructArrayEncoder<'a> {
explicit_nulls: bool,
}

/// This API is only stable since Rust 1.70, so it can't be used while the current MSRV is lower
#[inline(always)]
fn is_some_and<T>(opt: Option<T>, f: impl FnOnce(T) -> bool) -> bool {
match opt {
None => false,
Some(x) => f(x),
}
}

impl<'a> Encoder for StructArrayEncoder<'a> {
fn encode(&mut self, idx: usize, out: &mut Vec<u8>) {
out.push(b'{');
let mut is_first = true;
for field_encoder in &mut self.encoders {
let is_null = field_encoder.nulls.as_ref().is_some_and(|n| n.is_null(idx));
let is_null = is_some_and(field_encoder.nulls.as_ref(), |n| n.is_null(idx));
if is_null && !self.explicit_nulls {
continue;
}
@@ -447,13 +456,13 @@ impl<'a> MapEncoder<'a> {
let (values, value_nulls) = make_encoder_impl(values, options)?;

// We sanity check nulls as these are currently not enforced by MapArray (#1697)
if key_nulls.is_some_and(|x| x.null_count() != 0) {
if is_some_and(key_nulls, |x| x.null_count() != 0) {
return Err(ArrowError::InvalidArgumentError(
"Encountered nulls in MapArray keys".to_string(),
));
}

if array.entries().nulls().is_some_and(|x| x.null_count() != 0) {
if is_some_and(array.entries().nulls(), |x| x.null_count() != 0) {
return Err(ArrowError::InvalidArgumentError(
"Encountered nulls in MapArray entries".to_string(),
));
@@ -478,7 +487,7 @@ impl<'a> Encoder for MapEncoder<'a> {

out.push(b'{');
for idx in start..end {
let is_null = self.value_nulls.as_ref().is_some_and(|n| n.is_null(idx));
let is_null = is_some_and(self.value_nulls.as_ref(), |n| n.is_null(idx));
if is_null && !self.explicit_nulls {
continue;
}
4 changes: 2 additions & 2 deletions arrow-select/src/take.rs
@@ -1401,9 +1401,9 @@ mod tests {
);
}

fn _test_take_string<'a, K: 'static>()
fn _test_take_string<'a, K>()
where
K: Array + PartialEq + From<Vec<Option<&'a str>>>,
K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
{
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);

8 changes: 4 additions & 4 deletions arrow/src/util/string_writer.rs
@@ -63,6 +63,7 @@
//! }
//! ```
use std::fmt::Formatter;
use std::io::{Error, ErrorKind, Result, Write};

#[derive(Debug)]
@@ -83,10 +84,9 @@ impl Default for StringWriter {
Self::new()
}
}

impl ToString for StringWriter {
fn to_string(&self) -> String {
self.data.clone()
impl std::fmt::Display for StringWriter {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.data)
}
}

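For context (not part of the diff): switching from a hand-written `ToString` impl to `Display` keeps `.to_string()` working through the std blanket impl `impl<T: Display> ToString for T`, and additionally lets the writer be used directly in format strings. A minimal standalone sketch (a trimmed stand-in, not the crate's full `StringWriter`):

use std::fmt::{Display, Formatter};

struct StringWriter {
    data: String,
}

impl Display for StringWriter {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.data)
    }
}

fn main() {
    let w = StringWriter { data: "a,b\n1,2\n".to_string() };
    assert_eq!(w.to_string(), "a,b\n1,2\n"); // via the blanket ToString impl
    println!("{w}"); // usable directly in format strings
}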
7 changes: 0 additions & 7 deletions parquet/src/column/writer/mod.rs
@@ -1172,13 +1172,6 @@ fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: &T)
// https://github.com/apache/parquet-mr/blob/master/parquet-column/src/main/java/org/apache/parquet/column/values/factory/DefaultV1ValuesWriterFactory.java
// https://github.com/apache/parquet-mr/blob/master/parquet-column/src/main/java/org/apache/parquet/column/values/factory/DefaultV2ValuesWriterFactory.java

/// Trait to define default encoding for types, including whether or not the type
/// supports dictionary encoding.
trait EncodingWriteSupport {
/// Returns true if dictionary is supported for column writer, false otherwise.
fn has_dictionary_support(props: &WriterProperties) -> bool;
}

/// Returns encoding for a column when no other encoding is provided in writer properties.
fn fallback_encoding(kind: Type, props: &WriterProperties) -> Encoding {
match (kind, props.writer_version()) {