Skip to content

Commit

Permalink
Fix validation (#533)
Browse files Browse the repository at this point in the history
- Changes cert entity validation to only prevent collision with non-managed entities
- Changes key operations validation to only allow setting operations which make sense for the key type (no sign/verify for OCT, no encrypt/decrypt/wrap/unwrap for EC)
- Updates tests

Resolves #529
{minor}

Signed-off-by: Esta Nagy <[email protected]>
  • Loading branch information
nagyesta authored Mar 31, 2023
1 parent a82d17b commit 96c9ce5
Show file tree
Hide file tree
Showing 31 changed files with 282 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ public VersionedKeyEntityId generateKeyPair(final ReadOnlyCertificatePolicy inpu
final OffsetDateTime expiry = now.plusMonths(input.getValidityMonths());
return vaultFake.keyVaultFake().createKeyVersion(input.getName(), KeyCreateDetailedInput.builder()
.key(input.toKeyCreationInput())
.keyOperations(List.of(
KeyOperation.SIGN, KeyOperation.VERIFY,
KeyOperation.ENCRYPT, KeyOperation.DECRYPT,
KeyOperation.WRAP_KEY, KeyOperation.UNWRAP_KEY))
.keyOperations(List.of(KeyOperation.SIGN, KeyOperation.VERIFY))
.notBefore(now)
.expiresOn(expiry)
.enabled(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import com.github.nagyesta.lowkeyvault.model.v7_2.key.request.JsonWebKeyImportRequest;
import com.github.nagyesta.lowkeyvault.model.v7_3.certificate.CertificateRestoreInput;
import com.github.nagyesta.lowkeyvault.service.EntityId;
import com.github.nagyesta.lowkeyvault.service.certificate.ReadOnlyKeyVaultCertificateEntity;
import com.github.nagyesta.lowkeyvault.service.certificate.id.VersionedCertificateEntityId;
import com.github.nagyesta.lowkeyvault.service.common.BaseVaultEntity;
import com.github.nagyesta.lowkeyvault.service.common.impl.KeyVaultBaseEntity;
import com.github.nagyesta.lowkeyvault.service.exception.CryptoException;
import com.github.nagyesta.lowkeyvault.service.key.id.KeyEntityId;
Expand Down Expand Up @@ -41,7 +43,6 @@ public class KeyVaultCertificateEntity
private final String originalCertificateContents;
private CertificatePolicy issuancePolicy;
private PKCS10CertificationRequest csr;

/**
* Constructor for certificate creation.
*
Expand All @@ -57,10 +58,7 @@ public KeyVaultCertificateEntity(@NonNull final String name,
"Certificate name (" + name + ") did not match name from certificate creation input: " + input.getName());
final KeyEntityId kid = new KeyEntityId(vault.baseUri(), name);
final SecretEntityId sid = new SecretEntityId(vault.baseUri(), name);
Assert.state(!vault.keyVaultFake().getEntities().containsName(kid.id()),
"Key must not exist to be able to store certificate data in it. " + kid.asUriNoVersion(vault.baseUri()));
Assert.state(!vault.secretVaultFake().getEntities().containsName(sid.id()),
"Secret must not exist to be able to store certificate data in it. " + sid.asUriNoVersion(vault.baseUri()));
assertNoNameCollisionWithNotManagedEntity(vault, kid, sid);
this.issuancePolicy = new CertificatePolicy(input);
this.originalCertificatePolicy = new CertificatePolicy(input);
this.generator = new CertificateBackingEntityGenerator(vault);
Expand Down Expand Up @@ -100,10 +98,7 @@ public KeyVaultCertificateEntity(@NonNull final String name,
"Certificate name (" + name + ") did not match name from certificate creation input: " + policy.getName());
final KeyEntityId kid = new KeyEntityId(vault.baseUri(), name);
final SecretEntityId sid = new SecretEntityId(vault.baseUri(), name);
Assert.state(!vault.keyVaultFake().getEntities().containsName(kid.id()),
"Key must not exist to be able to store certificate data in it. " + kid.asUriNoVersion(vault.baseUri()));
Assert.state(!vault.secretVaultFake().getEntities().containsName(sid.id()),
"Secret must not exist to be able to store certificate data in it. " + sid.asUriNoVersion(vault.baseUri()));
assertNoNameCollisionWithNotManagedEntity(vault, kid, sid);
this.issuancePolicy = new CertificatePolicy(policy);
this.originalCertificatePolicy = new CertificatePolicy(originalCertificateData);
this.generator = new CertificateBackingEntityGenerator(vault);
Expand Down Expand Up @@ -134,7 +129,7 @@ public KeyVaultCertificateEntity(@NonNull final ReadOnlyCertificatePolicy input,
super(vault);
Assert.state(vault.keyVaultFake().getEntities().containsEntity(kid),
"Key must exist to be able to renew certificate using it. " + kid.asUriNoVersion(vault.baseUri()));
Assert.state(vault.secretVaultFake().getEntities().containsName(input.getName()),
Assert.state(vault.secretVaultFake().getEntities().containsEntityMatching(input.getName(), BaseVaultEntity::isManaged),
"A version of the Secret must exist to be able to generate a new version using name: " + input.getName());
this.issuancePolicy = new CertificatePolicy(input);
this.originalCertificatePolicy = new CertificatePolicy(input);
Expand Down Expand Up @@ -169,6 +164,7 @@ public KeyVaultCertificateEntity(@NonNull final VersionedCertificateEntityId id,
final JsonWebKeyImportRequest keyImportRequest = input.getKeyData();
final VersionedKeyEntityId kid = new VersionedKeyEntityId(vault.baseUri(), id.id(), input.getKeyVersion());
final VersionedSecretEntityId sid = new VersionedSecretEntityId(vault.baseUri(), id.id(), id.version());
assertNoNameCollisionWithNotManagedEntity(vault, kid, sid);
this.issuancePolicy = new CertificatePolicy(policy);
this.originalCertificatePolicy = new CertificatePolicy(originalCertificateData);
this.generator = new CertificateBackingEntityGenerator(vault);
Expand Down Expand Up @@ -297,6 +293,18 @@ public void regenerateCertificate(final VaultFake vault) {
}
}

private static void assertNoNameCollisionWithNotManagedEntity(
final VaultFake vault, final KeyEntityId kid, final SecretEntityId sid) {
Assert.state(!vault.keyVaultFake().getEntities().containsEntityMatching(kid.id(), KeyVaultCertificateEntity::isNotManaged),
"Key must not exist to be able to store certificate data in it. " + kid.asUriNoVersion(vault.baseUri()));
Assert.state(!vault.secretVaultFake().getEntities().containsEntityMatching(sid.id(), KeyVaultCertificateEntity::isNotManaged),
"Secret must not exist to be able to store certificate data in it. " + sid.asUriNoVersion(vault.baseUri()));
}

private static boolean isNotManaged(final BaseVaultEntity<? extends EntityId> e) {
return !e.isManaged();
}

private void normalizeCoreTimeStamps(final ReadOnlyCertificatePolicy certPolicy, final OffsetDateTime createOrUpdate) {
this.setNotBefore(certPolicy.getValidityStart());
this.setExpiry(certPolicy.getValidityStart().plusMonths(certPolicy.getValidityMonths()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;

public interface ReadOnlyVersionedEntityMultiMap<K extends EntityId, V extends K, RE extends BaseVaultEntity<V>> {

Expand All @@ -17,6 +18,8 @@ public interface ReadOnlyVersionedEntityMultiMap<K extends EntityId, V extends K

boolean containsName(String name);

boolean containsEntityMatching(String name, Predicate<RE> predicate);

boolean containsEntity(K entityId);

void assertContainsEntity(V entityId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -67,6 +68,11 @@ public boolean containsName(@NonNull final String name) {
return entities.containsKey(name);
}

@Override
public boolean containsEntityMatching(final String name, final Predicate<RE> predicate) {
return containsName(name) && entities.get(name).values().stream().anyMatch(predicate);
}

@Override
public boolean containsEntity(@NonNull final K entityId) {
return containsName(entityId.id()) && entities.get(entityId.id()).containsKey(entityId.version());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import java.util.List;

import static com.github.nagyesta.lowkeyvault.service.key.util.KeyGenUtil.generateAes;

Expand Down Expand Up @@ -56,6 +57,11 @@ public int getKeySize() {
return getKeyParam();
}

@Override
protected List<KeyOperation> disallowedOperations() {
return List.of(KeyOperation.SIGN, KeyOperation.VERIFY);
}

@Override
public byte[] encryptBytes(
@NonNull final byte[] clear, @NonNull final EncryptionAlgorithm encryptionAlgorithm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.security.Signature;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.List;
import java.util.Optional;

import static com.github.nagyesta.lowkeyvault.service.key.util.KeyGenUtil.generateEc;
Expand Down Expand Up @@ -69,6 +70,11 @@ public KeyCurveName getKeyCurveName() {
return getKeyParam();
}

@Override
protected List<KeyOperation> disallowedOperations() {
return List.of(KeyOperation.WRAP_KEY, KeyOperation.UNWRAP_KEY, KeyOperation.ENCRYPT, KeyOperation.DECRYPT);
}

@Override
public byte[] encryptBytes(final byte[] clear, final EncryptionAlgorithm encryptionAlgorithm, final byte[] iv) {
throw new UnsupportedOperationException("Encrypt is not supported for EC keys.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import com.github.nagyesta.lowkeyvault.service.vault.VaultFake;
import lombok.NonNull;
import org.slf4j.Logger;
import org.springframework.util.Assert;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;

/**
* Common Key entity base class.
Expand Down Expand Up @@ -63,10 +65,16 @@ public List<KeyOperation> getOperations() {
}

public void setOperations(final List<KeyOperation> operations) {
final List<KeyOperation> invalid = operations.stream().filter(this.disallowedOperations()::contains).collect(Collectors.toList());
Assert.isTrue(invalid.isEmpty(), "Operation not allowed for this key type: " + invalid + ".");
this.updatedNow();
this.operations = List.copyOf(operations);
}

protected List<KeyOperation> disallowedOperations() {
return Collections.emptyList();
}

protected <R> R doCrypto(final Callable<R> task, final String message, final Logger log) {
try {
return task.call();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void testConstructorShouldThrowExceptionWhenCalledWithAlreadyUsedKeyName() {

final ReadOnlyVersionedEntityMultiMap<KeyEntityId, VersionedKeyEntityId, ReadOnlyKeyVaultKeyEntity> keyMap
= mock(ReadOnlyVersionedEntityMultiMap.class);
when(keyMap.containsName(eq(id.id()))).thenReturn(true);
when(keyMap.containsEntityMatching(eq(id.id()), any())).thenReturn(true);

final KeyVaultFake keyFake = mock(KeyVaultFake.class);
when(keyFake.getEntities()).thenReturn(keyMap);
Expand All @@ -137,7 +137,7 @@ void testConstructorShouldThrowExceptionWhenCalledWithAlreadyUsedKeyName() {
//then + exception
verify(vault).keyVaultFake();
verify(keyFake).getEntities();
verify(keyMap).containsName(eq(id.id()));
verify(keyMap).containsEntityMatching(eq(id.id()), any());
}

@SuppressWarnings("unchecked")
Expand All @@ -149,7 +149,7 @@ void testConstructorShouldThrowExceptionWhenCalledWithAlreadyUsedSecretName() {

final ReadOnlyVersionedEntityMultiMap<SecretEntityId, VersionedSecretEntityId, ReadOnlyKeyVaultSecretEntity> secretMap
= mock(ReadOnlyVersionedEntityMultiMap.class);
when(secretMap.containsName(eq(id.id()))).thenReturn(true);
when(secretMap.containsEntityMatching(eq(id.id()), any())).thenReturn(true);

final ReadOnlyVersionedEntityMultiMap<KeyEntityId, VersionedKeyEntityId, ReadOnlyKeyVaultKeyEntity> keyMap
= mock(ReadOnlyVersionedEntityMultiMap.class);
Expand All @@ -172,7 +172,7 @@ void testConstructorShouldThrowExceptionWhenCalledWithAlreadyUsedSecretName() {
//then + exception
verify(vault).secretVaultFake();
verify(secretFake).getEntities();
verify(secretMap).containsName(eq(id.id()));
verify(secretMap).containsEntityMatching(eq(id.id()), any());
}

@Test
Expand Down Expand Up @@ -383,7 +383,7 @@ void testRenewalConstructorShouldThrowExceptionWhenNoMatchingSecretNameFound() {

final ReadOnlyVersionedEntityMultiMap<SecretEntityId, VersionedSecretEntityId, ReadOnlyKeyVaultSecretEntity> secretMap
= mock(ReadOnlyVersionedEntityMultiMap.class);
when(keyMap.containsName(eq(id.id()))).thenReturn(false);
when(keyMap.containsEntityMatching(eq(id.id()), any())).thenReturn(false);
final SecretVaultFake secretFake = mock(SecretVaultFake.class);
when(secretFake.getEntities()).thenReturn(secretMap);

Expand All @@ -400,7 +400,7 @@ void testRenewalConstructorShouldThrowExceptionWhenNoMatchingSecretNameFound() {
verify(keyFake).getEntities();
verify(keyMap).containsEntity(eq(kid));
verify(secretFake).getEntities();
verify(secretMap).containsName(eq(id.id()));
verify(secretMap).containsEntityMatching(eq(id.id()), any());
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ void testSignShouldThrowExceptionWhenCalled() {
final VaultFake vaultFake = new VaultFakeImpl(HTTPS_LOWKEY_VAULT);
final AesKeyVaultKeyEntity underTest = new AesKeyVaultKeyEntity(
VERSIONED_KEY_ENTITY_ID_1_VERSION_1, vaultFake, KeyType.OCT.getValidKeyParameters(Integer.class).first(), false);
underTest.setOperations(List.of(KeyOperation.SIGN, KeyOperation.VERIFY));
underTest.setEnabled(true);

//when
Expand All @@ -234,7 +233,6 @@ void testVerifyShouldThrowExceptionWhenCalled() {
final VaultFake vaultFake = new VaultFakeImpl(HTTPS_LOWKEY_VAULT);
final AesKeyVaultKeyEntity underTest = new AesKeyVaultKeyEntity(
VERSIONED_KEY_ENTITY_ID_1_VERSION_1, vaultFake, KeyType.OCT.getValidKeyParameters(Integer.class).first(), false);
underTest.setOperations(List.of(KeyOperation.SIGN, KeyOperation.VERIFY));
underTest.setEnabled(true);

//when
Expand All @@ -245,6 +243,21 @@ void testVerifyShouldThrowExceptionWhenCalled() {
//then + exception
}

@Test
void testSetOperationsShouldThrowExceptionWhenCalledWithSignOrVerify() {
//given
final VaultFake vaultFake = new VaultFakeImpl(HTTPS_LOWKEY_VAULT);
final AesKeyVaultKeyEntity underTest = new AesKeyVaultKeyEntity(
VERSIONED_KEY_ENTITY_ID_1_VERSION_1, vaultFake, KeyType.OCT.getValidKeyParameters(Integer.class).first(), false);
underTest.setEnabled(true);

//when
Assertions.assertThrows(IllegalArgumentException.class,
() -> underTest.setOperations(List.of(KeyOperation.SIGN, KeyOperation.VERIFY)));

//then + exception
}

@Test
void testKeyCreationInputShouldReturnOriginalParameters() {
//given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ public static Stream<Arguments> keyOperationsProvider() {
.add(Arguments.of(List.of()))
.add(Arguments.of(List.of(KeyOperation.ENCRYPT)))
.add(Arguments.of(List.of(KeyOperation.ENCRYPT, KeyOperation.DECRYPT)))
.add(Arguments.of(Arrays.asList(KeyOperation.values())))
.build();
}

Expand Down Expand Up @@ -888,7 +887,7 @@ void testRotateKeyShouldCreateNewKeyVersionKeepingTagsAndOperationsWhenCalledWit
//given
final KeyCurveName keyParameter = KeyCurveName.P_384;
final Map<String, String> tags = Map.of(KEY_1, VALUE_1);
final List<KeyOperation> operations = List.of(KeyOperation.ENCRYPT);
final List<KeyOperation> operations = List.of(KeyOperation.SIGN);

final KeyVaultFake underTest = createUnderTest();
final VersionedKeyEntityId keyEntityId = underTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ public void countCertificatesAreImportedFromTheGivenResourceUsingPassword(
final int count, final String resource, final String password) throws IOException {
final byte[] content = Objects.requireNonNull(getClass().getResourceAsStream("/certs/" + resource)).readAllBytes();
final CertificateClient client = context.getClient(context.getCertificateServiceVersion());
IntStream.range(0, count).forEach(i -> {
IntStream.range(1, count + 1).forEach(i -> {
final String name = "multi-import-" + i;
final ImportCertificateOptions options = new ImportCertificateOptions(name, content);
options.setEnabled(true);
Optional.ofNullable(password).ifPresent(options::setPassword);
final KeyVaultCertificateWithPolicy certificate = client
.importCertificate(options);
Expand Down Expand Up @@ -204,7 +205,7 @@ public void theCertificateVersionsAreListed() {
@And("{int} certificates with {name} prefix are deleted")
public void certificatesWithMultiImportPrefixAreDeleted(final int count, final String prefix) {
final CertificateClient client = context.getClient(context.getCertificateServiceVersion());
IntStream.range(0, count).forEach(i -> {
IntStream.range(1, count + 1).forEach(i -> {
final DeletedCertificate deletedCertificate = client.beginDeleteCertificate(prefix + i)
.waitForCompletion().getValue();
context.setLastDeleted(deletedCertificate);
Expand All @@ -214,15 +215,15 @@ public void certificatesWithMultiImportPrefixAreDeleted(final int count, final S
@And("{int} certificates with {name} prefix are purged")
public void certificatesWithMultiImportPrefixArePurged(final int count, final String prefix) {
final CertificateClient client = context.getClient(context.getCertificateServiceVersion());
IntStream.range(0, count).forEach(i -> {
IntStream.range(1, count + 1).forEach(i -> {
client.purgeDeletedCertificate(prefix + i);
});
}

@And("{int} certificates with {name} prefix are recovered")
public void certificatesWithMultiImportPrefixAreRecovered(final int count, final String prefix) {
final CertificateClient client = context.getClient(context.getCertificateServiceVersion());
IntStream.range(0, count).forEach(i -> {
IntStream.range(1, count + 1).forEach(i -> {
client.beginRecoverDeletedCertificate(prefix + i).waitForCompletion();
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public void ecKeyImportedWithNameAndParameters(final String name, final KeyCurve
final KeyPair keyPair = KeyGenUtil.generateEc(curveName);
context.setKeyPair(keyPair);
final JsonWebKey key = JsonWebKey.fromEc(keyPair, BOUNCY_CASTLE_PROVIDER)
.setKeyOps(List.of(KeyOperation.SIGN, KeyOperation.ENCRYPT, KeyOperation.WRAP_KEY));
.setKeyOps(List.of(KeyOperation.SIGN));
if (hsm) {
key.setKeyType(KeyType.EC_HSM);
}
Expand Down Expand Up @@ -155,7 +155,7 @@ public void octKeyImportedWithNameAndParameters(final String name, final int siz
final SecretKey secretKey = KeyGenUtil.generateAes(size);
context.setSecretKey(secretKey);
final JsonWebKey key = JsonWebKey.fromAes(secretKey)
.setKeyOps(List.of(KeyOperation.SIGN, KeyOperation.ENCRYPT, KeyOperation.WRAP_KEY))
.setKeyOps(List.of(KeyOperation.ENCRYPT, KeyOperation.WRAP_KEY))
.setKeyType(KeyType.OCT_HSM);
final ImportKeyOptions options = new ImportKeyOptions(name, key)
.setHardwareProtected(true);
Expand Down
Loading

0 comments on commit 96c9ce5

Please sign in to comment.