Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43907: [C#][FlightRPC] Add Grpc Call Options support on Flight Client #43910

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
Expand All @@ -34,35 +35,55 @@ public FlightClient(ChannelBase grpcChannel)

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria = null, Metadata headers = null)
{
if(criteria == null)
return ListFlights(criteria, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
if (criteria == null)
{
criteria = FlightCriteria.Empty;
}
var response = _client.ListFlights(criteria.ToProtocol(), headers);

var response = _client.ListFlights(criteria.ToProtocol(), headers, deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.FlightInfo, FlightInfo>(response.ResponseStream, inFlight => new FlightInfo(inFlight));

return new AsyncServerStreamingCall<FlightInfo>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers = null)
{
var response = _client.ListActions(EmptyInstance, headers);
return ListActions(headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var response = _client.ListActions(EmptyInstance, headers, deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.ActionType, FlightActionType>(response.ResponseStream, actionType => new FlightActionType(actionType));

return new AsyncServerStreamingCall<FlightActionType>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers);
return GetStream(ticket, headers, null, CancellationToken.None);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, cancellationToken);
var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream);
return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers);
return GetInfo(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

var flightInfo = flightInfoResult
.ResponseAsync
Expand All @@ -79,7 +100,12 @@ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Met

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var channels = _client.DoPut(headers);
return StartPut(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
var readStream = new StreamReader<Protocol.PutResult, FlightPutResult>(channels.ResponseStream, putResult => new FlightPutResult(putResult));
return new FlightRecordBatchDuplexStreamingCall(
Expand All @@ -93,7 +119,13 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
{
var channel = _client.Handshake(headers);
return Handshake(headers, null, CancellationToken.None);

}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.Handshake(headers, deadline, cancellationToken);
var readStream = new StreamReader<HandshakeResponse, FlightHandshakeResponse>(channel.ResponseStream, response => new FlightHandshakeResponse(response));
var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream);
var call = new AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse>(
Expand All @@ -109,7 +141,12 @@ public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse>

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var channel = _client.DoExchange(headers);
return DoExchange(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.DoExchange(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor);
var responseStream = new FlightClientRecordBatchStreamReader(channel.ResponseStream);
var call = new FlightRecordBatchExchangeCall(
Expand All @@ -125,14 +162,24 @@ public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescripto

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers = null)
{
var stream = _client.DoAction(action.ToProtocol(), headers);
return DoAction(action, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoAction(action.ToProtocol(), headers, deadline, cancellationToken);
var streamReader = new StreamReader<Protocol.Result, FlightResult>(stream.ResponseStream, result => new FlightResult(result));
return new AsyncServerStreamingCall<FlightResult>(streamReader, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers);
return GetSchema(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

var schema = schemaResult
.ResponseAsync
Expand Down
97 changes: 92 additions & 5 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
using Python.Runtime;
using Xunit;

namespace Apache.Arrow.Flight.Tests
Expand Down Expand Up @@ -70,7 +73,7 @@ private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params R

var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());

foreach(var batch in batches)
foreach (var batch in batches)
{
flightHolder.AddBatch(batch);
}
Expand Down Expand Up @@ -187,8 +190,8 @@ public async Task TestGetFlightMetadata()

var getStream = _flightClient.GetStream(endpoint.Ticket);

List<ByteString> actualMetadata = new List<ByteString>();
while(await getStream.ResponseStream.MoveNext(default))
List<ByteString> actualMetadata = new List<ByteString>();
while (await getStream.ResponseStream.MoveNext(default))
{
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
}
Expand Down Expand Up @@ -277,7 +280,7 @@ public async Task TestListFlights()

var actualFlights = await listFlightStream.ResponseStream.ToListAsync();

for(int i = 0; i < expectedFlightInfo.Count; i++)
for (int i = 0; i < expectedFlightInfo.Count; i++)
{
FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]);
}
Expand Down Expand Up @@ -386,7 +389,7 @@ public async Task TestGetBatchesWithAsyncEnumerable()


List<RecordBatch> resultList = new List<RecordBatch>();
await foreach(var recordBatch in getStream.ResponseStream)
await foreach (var recordBatch in getStream.ResponseStream)
{
resultList.Add(recordBatch);
}
Expand Down Expand Up @@ -415,5 +418,89 @@ public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalB
Assert.Equal(expectedBatch.Length, result.TotalRecords);
Assert.Equal(expectedTotalBytes, result.TotalBytes);
}

[Fact]
public async Task EnsureCallRaisesDeadlineExceeded()
{
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline");
var deadline = DateTime.UtcNow;
var batch = CreateTestBatch(0, 100);

RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var handshakeStreamingCall = _flightClient.Handshake(null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
}

[Fact]
public async Task EnsureCallRaisesRequestCancelled()
{
var cts = new CancellationTokenSource();
cts.CancelAfter(1);

var batch = CreateTestBatch(0, 100);
var metadata = new Metadata();
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled");
await Task.Delay(5);
RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var handshakeStreamingCall = _flightClient.Handshake(null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

}
}
}
Loading