Skip to content

Commit

Permalink
Migrate all JUnit assertions to AssertJ in polaris-service module (a…
Browse files Browse the repository at this point in the history
…pache#68)

* Use AssertJ assertions in polaris-service tests

* Use `@ParameterizedTest` in polaris-service tests
  • Loading branch information
ebyhr authored Aug 2, 2024
1 parent 3d70e75 commit 97e26f2
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
*/
package io.polaris.service.auth;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
Expand Down Expand Up @@ -138,12 +137,12 @@ public void testSuccessfulTokenGeneration() throws Exception {
} catch (Exception e) {
fail("Unexpected exception: " + e);
}
assertNotNull(token);
assertEquals(420, token.getExpiresIn());
assertThat(token).isNotNull();
assertThat(token.getExpiresIn()).isEqualTo(420);

LocalRSAKeyProvider provider = new LocalRSAKeyProvider();
assertNotNull(provider.getPrivateKey());
assertNotNull(provider.getPublicKey());
assertThat(provider.getPrivateKey()).isNotNull();
assertThat(provider.getPublicKey()).isNotNull();
JWTVerifier verifier =
JWT.require(
Algorithm.RSA256(
Expand All @@ -152,8 +151,8 @@ public void testSuccessfulTokenGeneration() throws Exception {
.withIssuer("polaris")
.build();
DecodedJWT decodedJWT = verifier.verify(token.getAccessToken());
assertNotNull(decodedJWT);
assertEquals(decodedJWT.getClaim("scope").asString(), "PRINCIPAL_ROLE:TEST");
assertEquals(decodedJWT.getClaim("client_id").asString(), "test-client-id");
assertThat(decodedJWT).isNotNull();
assertThat(decodedJWT.getClaim("scope").asString()).isEqualTo("PRINCIPAL_ROLE:TEST");
assertThat(decodedJWT.getClaim("client_id").asString()).isEqualTo("test-client-id");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
*/
package io.polaris.service.auth;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.assertj.core.api.Assertions.assertThat;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
Expand Down Expand Up @@ -82,13 +81,13 @@ public Map<String, Object> contextVariables() {
TokenResponse token =
generator.generateFromClientSecrets(
clientId, mainSecret, TokenRequestValidator.CLIENT_CREDENTIALS, "PRINCIPAL_ROLE:TEST");
assertNotNull(token);
assertThat(token).isNotNull();

JWTVerifier verifier = JWT.require(Algorithm.HMAC256("polaris")).withIssuer("polaris").build();
DecodedJWT decodedJWT = verifier.verify(token.getAccessToken());
assertNotNull(decodedJWT);
assertEquals(666, token.getExpiresIn());
assertEquals(decodedJWT.getClaim("scope").asString(), "PRINCIPAL_ROLE:TEST");
assertEquals(decodedJWT.getClaim("client_id").asString(), clientId);
assertThat(decodedJWT).isNotNull();
assertThat(token.getExpiresIn()).isEqualTo(666);
assertThat(decodedJWT.getClaim("scope").asString()).isEqualTo("PRINCIPAL_ROLE:TEST");
assertThat(decodedJWT.getClaim("client_id").asString()).isEqualTo(clientId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,74 +15,72 @@
*/
package io.polaris.service.auth;

import java.util.Arrays;
import java.util.Optional;
import org.junit.jupiter.api.Assertions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

public class TokenRequestValidatorTest {
@Test
public void testValidateForClientCredentialsFlowNullClientId() {
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_client,
new TokenRequestValidator()
.validateForClientCredentialsFlow(null, "notnull", "notnull", "nontnull")
.get());
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_client,
new TokenRequestValidator()
.validateForClientCredentialsFlow("", "notnull", "notnull", "nonnull")
.get());
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow(null, "notnull", "notnull", "nontnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_client);
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow("", "notnull", "notnull", "nonnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_client);
}

@Test
public void testValidateForClientCredentialsFlowNullClientSecret() {
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_client,
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", null, "notnull", "nontnull")
.get());
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_client,
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", "", "notnull", "notnull")
.get());
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", null, "notnull", "nontnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_client);
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", "", "notnull", "notnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_client);
}

@Test
public void testValidateForClientCredentialsFlowInvalidGrantType() {
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_grant,
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "not-client-credentials", "notnull")
.get());
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_grant,
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", "client-secret", "grant", "notnull")
.get());
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "not-client-credentials", "notnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_grant);
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow("client-id", "client-secret", "grant", "notnull")
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_grant);
}

@Test
public void testValidateForClientCredentialsFlowInvalidScope() {
for (String scope :
Arrays.asList("null", "", ",", "ALL", "PRINCIPAL_ROLE:", "PRINCIPAL_ROLE")) {
Assertions.assertEquals(
OAuthTokenErrorResponse.Error.invalid_scope,
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "client_credentials", scope)
.get());
}
@ParameterizedTest
@ValueSource(strings = {"null", "", ",", "ALL", "PRINCIPAL_ROLE:", "PRINCIPAL_ROLE"})
public void testValidateForClientCredentialsFlowInvalidScope(String scope) {
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "client_credentials", scope)
.get())
.isEqualTo(OAuthTokenErrorResponse.Error.invalid_scope);
}

@Test
public void testValidateForClientCredentialsFlowAllValid() {
Assertions.assertEquals(
Optional.empty(),
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "client_credentials", "PRINCIPAL_ROLE:ALL"));
Assertions.assertThat(
new TokenRequestValidator()
.validateForClientCredentialsFlow(
"client-id", "client-secret", "client_credentials", "PRINCIPAL_ROLE:ALL"))
.isEmpty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static io.polaris.service.context.DefaultContextResolver.REALM_PROPERTY_KEY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.adobe.testing.s3mock.testcontainers.S3MockContainer;
import io.dropwizard.testing.ConfigOverride;
Expand Down Expand Up @@ -46,7 +47,6 @@
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -299,13 +299,8 @@ public void testCreateAndUpdateExternalTable() {
assertThat(tables).hasSize(1).extracting(row -> row.getString(1)).containsExactly("mytb1");
long rowCount = spark.sql("SELECT * FROM mytb1").count();
assertThat(rowCount).isEqualTo(3);
try {
spark.sql("INSERT INTO mytb1 VALUES (20, 'new_text')");
Assertions.fail("Expected exception when inserting into external table");
} catch (Exception e) {
LoggerFactory.getLogger(getClass()).info("Expected exception", e);
// expected exception
}
assertThatThrownBy(() -> spark.sql("INSERT INTO mytb1 VALUES (20, 'new_text')"))
.isInstanceOf(Exception.class);

spark.sql("INSERT INTO " + CATALOG_NAME + ".ns1.tb1 VALUES (20, 'new_text')");
tableResponse = loadTable(CATALOG_NAME, "ns1", "tb1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import io.polaris.core.admin.model.StorageConfigInfo;
import io.polaris.core.entity.CatalogEntity;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -50,13 +50,10 @@ public void testInvalidAllowedLocationPrefix() {
.setProperties(prop)
.setStorageConfigInfo(awsStorageConfigModel)
.build();
Exception ex =
Assertions.assertThrows(
IllegalArgumentException.class, () -> CatalogEntity.fromCatalog(awsCatalog));
Assertions.assertTrue(
ex.getMessage()
.contains(
"Location prefix not allowed: 'unsupportPrefix://mybucket/path', expected prefix: 's3://'"));
Assertions.assertThatThrownBy(() -> CatalogEntity.fromCatalog(awsCatalog))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining(
"Location prefix not allowed: 'unsupportPrefix://mybucket/path', expected prefix: 's3://'");

// Invaliad azure prefix
AzureStorageConfigInfo azureStorageConfigModel =
Expand All @@ -74,12 +71,9 @@ public void testInvalidAllowedLocationPrefix() {
new CatalogProperties("abfs://[email protected]/path"))
.setStorageConfigInfo(azureStorageConfigModel)
.build();
Exception ex2 =
Assertions.assertThrows(
IllegalArgumentException.class, () -> CatalogEntity.fromCatalog(azureCatalog));
Assertions.assertTrue(
ex2.getMessage()
.contains("Invalid azure adls location uri unsupportPrefix://mybucket/path"));
Assertions.assertThatThrownBy(() -> CatalogEntity.fromCatalog(azureCatalog))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Invalid azure adls location uri unsupportPrefix://mybucket/path");

// invalid gcp prefix
GcpStorageConfigInfo gcpStorageConfigModel =
Expand All @@ -94,13 +88,10 @@ public void testInvalidAllowedLocationPrefix() {
.setProperties(new CatalogProperties("gs://externally-owned-bucket"))
.setStorageConfigInfo(gcpStorageConfigModel)
.build();
Exception ex3 =
Assertions.assertThrows(
IllegalArgumentException.class, () -> CatalogEntity.fromCatalog(gcpCatalog));
Assertions.assertTrue(
ex3.getMessage()
.contains(
"Location prefix not allowed: 'unsupportPrefix://mybucket/path', expected prefix: 'gs://'"));
Assertions.assertThatThrownBy(() -> CatalogEntity.fromCatalog(gcpCatalog))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining(
"Location prefix not allowed: 'unsupportPrefix://mybucket/path', expected prefix: 'gs://'");
}

@Test
Expand Down Expand Up @@ -129,10 +120,9 @@ public void testExceedMaxAllowedLocations() {
.setProperties(prop)
.setStorageConfigInfo(awsStorageConfigModel)
.build();
Exception ex =
Assertions.assertThrows(
IllegalArgumentException.class, () -> CatalogEntity.fromCatalog(awsCatalog));
Assertions.assertTrue(ex.getMessage().contains("Number of allowed locations exceeds 5"));
Assertions.assertThatThrownBy(() -> CatalogEntity.fromCatalog(awsCatalog))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Number of allowed locations exceeds 5");
}

@Test
Expand All @@ -155,7 +145,7 @@ public void testValidAllowedLocationPrefix() {
.setProperties(prop)
.setStorageConfigInfo(awsStorageConfigModel)
.build();
Assertions.assertDoesNotThrow(() -> CatalogEntity.fromCatalog(awsCatalog));
Assertions.assertThatNoException().isThrownBy(() -> CatalogEntity.fromCatalog(awsCatalog));

basedLocation = "abfs://[email protected]/path";
prop.put(CatalogEntity.DEFAULT_BASE_LOCATION_KEY, basedLocation);
Expand All @@ -172,7 +162,7 @@ public void testValidAllowedLocationPrefix() {
.setProperties(new CatalogProperties(basedLocation))
.setStorageConfigInfo(azureStorageConfigModel)
.build();
Assertions.assertDoesNotThrow(() -> CatalogEntity.fromCatalog(azureCatalog));
Assertions.assertThatNoException().isThrownBy(() -> CatalogEntity.fromCatalog(azureCatalog));

basedLocation = "gs://externally-owned-bucket";
prop.put(CatalogEntity.DEFAULT_BASE_LOCATION_KEY, basedLocation);
Expand All @@ -188,7 +178,7 @@ public void testValidAllowedLocationPrefix() {
.setProperties(new CatalogProperties(basedLocation))
.setStorageConfigInfo(gcpStorageConfigModel)
.build();
Assertions.assertDoesNotThrow(() -> CatalogEntity.fromCatalog(gcpCatalog));
Assertions.assertThatNoException().isThrownBy(() -> CatalogEntity.fromCatalog(gcpCatalog));
}

@ParameterizedTest
Expand All @@ -211,9 +201,6 @@ public void testInvalidArn(String roleArn) {
.setProperties(prop)
.setStorageConfigInfo(awsStorageConfigModel)
.build();
Exception ex =
Assertions.assertThrows(
IllegalArgumentException.class, () -> CatalogEntity.fromCatalog(awsCatalog));
String expectedMessage = "";
switch (roleArn) {
case "":
Expand All @@ -227,6 +214,8 @@ public void testInvalidArn(String roleArn) {
expectedMessage = "Invalid role ARN format";
}
;
Assertions.assertEquals(ex.getMessage(), expectedMessage);
Assertions.assertThatThrownBy(() -> CatalogEntity.fromCatalog(awsCatalog))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage(expectedMessage);
}
}

0 comments on commit 97e26f2

Please sign in to comment.