Skip to content

Commit

Permalink
update org.junit.Assert to org.junit.jupiter.api.Assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
llama90 committed Jun 12, 2024
1 parent 383ffe0 commit beec961
Show file tree
Hide file tree
Showing 26 changed files with 1,686 additions and 1,783 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight.client;

import com.google.common.collect.ImmutableMap;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.HashMap;
import java.util.Map;

import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
Expand All @@ -43,22 +46,25 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/** Tests to ensure custom headers are passed along to the server for each command. */
import com.google.common.collect.ImmutableMap;

/**
* Tests to ensure custom headers are passed along to the server for each command.
*/
public class CustomHeaderTest {
FlightServer server;
FlightClient client;
BufferAllocator allocator;
TestCustomHeaderMiddleware.Factory headersMiddleware;
HeaderCallOption headers;
Map<String, String> testHeaders =
ImmutableMap.of(
Map<String, String> testHeaders = ImmutableMap.of(
"foo", "bar",
"bar", "foo",
"answer", "42");
"answer", "42"
);

@Before
public void setUp() throws Exception {
Expand All @@ -69,13 +75,11 @@ public void setUp() throws Exception {
callHeaders.insert(entry.getKey(), entry.getValue());
}
headers = new HeaderCallOption(callHeaders);
server =
FlightServer.builder(
allocator,
Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, /*port*/ 0),
new NoOpFlightProducer())
.middleware(FlightServerMiddleware.Key.of("customHeader"), headersMiddleware)
.build();
server = FlightServer.builder(allocator,
Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, /*port*/ 0),
new NoOpFlightProducer())
.middleware(FlightServerMiddleware.Key.of("customHeader"), headersMiddleware)
.build();
server.start();
client = FlightClient.builder(allocator, server.getLocation()).build();
}
Expand All @@ -90,8 +94,7 @@ public void tearDown() throws Exception {
public void testHandshake() {
try {
client.handshake(headers);
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.HANDSHAKE);
}
Expand All @@ -100,8 +103,7 @@ public void testHandshake() {
public void testGetSchema() {
try {
client.getSchema(FlightDescriptor.command(new byte[0]), headers);
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.GET_SCHEMA);
}
Expand All @@ -110,8 +112,7 @@ public void testGetSchema() {
public void testGetFlightInfo() {
try {
client.getInfo(FlightDescriptor.command(new byte[0]), headers);
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.GET_FLIGHT_INFO);
}
Expand All @@ -120,18 +121,16 @@ public void testGetFlightInfo() {
public void testListActions() {
try {
client.listActions(headers).iterator().next();
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.LIST_ACTIONS);
}

@Test
public void testListFlights() {
try {
client.listFlights(new Criteria(new byte[] {1}), headers).iterator().next();
} catch (Exception ignored) {
}
client.listFlights(new Criteria(new byte[]{1}), headers).iterator().next();
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.LIST_FLIGHTS);
}
Expand All @@ -140,20 +139,19 @@ public void testListFlights() {
public void testDoAction() {
try {
client.doAction(new Action("test"), headers).next();
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.DO_ACTION);
}

@Test
public void testStartPut() {
try {
final ClientStreamListener listener =
client.startPut(FlightDescriptor.command(new byte[0]), new SyncPutListener(), headers);
final ClientStreamListener listener = client.startPut(FlightDescriptor.command(new byte[0]),
new SyncPutListener(),
headers);
listener.getResult();
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.DO_PUT);
}
Expand All @@ -162,54 +160,62 @@ public void testStartPut() {
public void testGetStream() {
try (final FlightStream stream = client.getStream(new Ticket(new byte[0]), headers)) {
stream.next();
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.DO_GET);
}

@Test
public void testDoExchange() {
try (final FlightClient.ExchangeReaderWriter stream =
client.doExchange(FlightDescriptor.command(new byte[0]), headers)) {
try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(
FlightDescriptor.command(new byte[0]),
headers)
) {
stream.getReader().next();
} catch (Exception ignored) {
}
} catch (Exception ignored) { }

assertHeadersMatch(FlightMethod.DO_EXCHANGE);
}

private void assertHeadersMatch(FlightMethod method) {
for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
Assert.assertEquals(
entry.getValue(), headersMiddleware.getCustomHeader(method, entry.getKey()));
assertEquals(entry.getValue(), headersMiddleware.getCustomHeader(method, entry.getKey()));
}
}

/** A middleware used to test if customHeaders are being sent to the server properly. */
/**
* A middleware used to test if customHeaders are being sent to the server properly.
*/
static class TestCustomHeaderMiddleware implements FlightServerMiddleware {

public TestCustomHeaderMiddleware() {}
public TestCustomHeaderMiddleware() {
}

@Override
public void onBeforeSendingHeaders(CallHeaders callHeaders) {}
public void onBeforeSendingHeaders(CallHeaders callHeaders) {

}

@Override
public void onCallCompleted(CallStatus callStatus) {}
public void onCallCompleted(CallStatus callStatus) {

}

@Override
public void onCallErrored(Throwable throwable) {}
public void onCallErrored(Throwable throwable) {

}

/**
* A factory for the middleware that keeps track of the received headers and provides a way to
* check those values for a given Flight Method.
* A factory for the middleware that keeps track of the received headers and provides a way
* to check those values for a given Flight Method.
*/
static class Factory implements FlightServerMiddleware.Factory<TestCustomHeaderMiddleware> {
private final Map<FlightMethod, CallHeaders> receivedCallHeaders = new HashMap<>();

@Override
public TestCustomHeaderMiddleware onCallStarted(
CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
public TestCustomHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders callHeaders,
RequestContext requestContext) {

receivedCallHeaders.put(callInfo.method(), callHeaders);
return new TestCustomHeaderMiddleware();
Expand All @@ -223,5 +229,5 @@ public String getCustomHeader(FlightMethod method, String key) {
return headers.get(key);
}
}
}
}
}
Loading

0 comments on commit beec961

Please sign in to comment.