Skip to content

Commit

Permalink
update method signature, with default and name of tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Marco Malagoli committed Sep 3, 2024
1 parent 7f000a0 commit 88869bf
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 20 deletions.
69 changes: 58 additions & 11 deletions csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
Expand All @@ -32,7 +33,12 @@ public FlightClient(ChannelBase grpcChannel)
_client = new FlightService.FlightServiceClient(grpcChannel);
}

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria = null, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria = null, Metadata headers = null)
{
return ListFlights(criteria, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
if (criteria == null)
{
Expand All @@ -45,24 +51,39 @@ public FlightClient(ChannelBase grpcChannel)
return new AsyncServerStreamingCall<FlightInfo>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers = null)
{
return ListActions(headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var response = _client.ListActions(EmptyInstance, headers, deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.ActionType, FlightActionType>(response.ResponseStream, actionType => new FlightActionType(actionType));

return new AsyncServerStreamingCall<FlightActionType>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null)
{
return GetStream(ticket, headers, null, CancellationToken.None);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, cancellationToken);
var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream);
return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers);
return GetInfo(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

var flightInfo = flightInfoResult
.ResponseAsync
Expand All @@ -77,7 +98,12 @@ public FlightClient(ChannelBase grpcChannel)
flightInfoResult.Dispose);
}

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
{
return StartPut(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
Expand All @@ -91,7 +117,13 @@ public FlightClient(ChannelBase grpcChannel)
channels.Dispose);
}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
{
return Handshake(headers, null, CancellationToken.None);

}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.Handshake(headers, deadline, cancellationToken);
var readStream = new StreamReader<HandshakeResponse, FlightHandshakeResponse>(channel.ResponseStream, response => new FlightHandshakeResponse(response));
Expand All @@ -107,7 +139,12 @@ public FlightClient(ChannelBase grpcChannel)
return call;
}

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null)
{
return DoExchange(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.DoExchange(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor);
Expand All @@ -123,14 +160,24 @@ public FlightClient(ChannelBase grpcChannel)
return call;
}

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers = null)
{
return DoAction(action, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoAction(action.ToProtocol(), headers);
var stream = _client.DoAction(action.ToProtocol(), headers, deadline, cancellationToken);
var streamReader = new StreamReader<Protocol.Result, FlightResult>(stream.ResponseStream, result => new FlightResult(result));
return new AsyncServerStreamingCall<FlightResult>(streamReader, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null, System.DateTime? deadline = null, System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken))
public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null)
{
return GetSchema(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

Expand Down
42 changes: 33 additions & 9 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -420,14 +420,29 @@ public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalB
}

[Fact]
public async Task EnsureCallRaiseDeadlineExceeded()
public async Task EnsureCallRaisesDeadlineExceeded()
{
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline");
var deadline = DateTime.UtcNow;
var batch = CreateTestBatch(0, 100);

RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
Expand All @@ -442,24 +457,36 @@ public async Task EnsureCallRaiseDeadlineExceeded()
var handshakeStreamingCall = _flightClient.Handshake(null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);
}

[Fact]
public async Task EnsureCallRaiseRequestCancelled()
public async Task EnsureCallRaisesRequestCancelled()
{
var cts = new CancellationTokenSource();
cts.CancelAfter(1);

var batch = CreateTestBatch(0, 100);
var metadata = new Metadata();
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled");
await Task.Delay(5);
RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

Expand All @@ -474,9 +501,6 @@ public async Task EnsureCallRaiseRequestCancelled()
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

}
}
}

0 comments on commit 88869bf

Please sign in to comment.