Skip to content

Commit

Permalink
feat(sdk): Add and expose tamper error types (#187)
Browse files Browse the repository at this point in the history
Exception type for tamper detection -- make it the parent to the
exception types we want to be able to catch for tamper

---------

Co-authored-by: Dave Mihalcik <[email protected]>
  • Loading branch information
elizabethhealy and dmihalcik-virtru authored Oct 30, 2024
1 parent 94b161d commit b4f95e6
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 15 deletions.
20 changes: 17 additions & 3 deletions sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.Status;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.PublicKeyRequest;
import io.opentdf.platform.kas.PublicKeyResponse;
import io.opentdf.platform.kas.RewrapRequest;
import io.opentdf.platform.kas.RewrapResponse;
import io.opentdf.platform.sdk.Config.KASInfo;
import io.opentdf.platform.sdk.nanotdf.ECKeyPair;
import io.opentdf.platform.sdk.nanotdf.NanoTDFType;
import io.opentdf.platform.sdk.TDF.KasBadRequestException;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
Expand Down Expand Up @@ -182,9 +186,19 @@ public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy) {
.newBuilder()
.setSignedRequestToken(jwt.serialize())
.build();
var response = getStub(keyAccess.url).rewrap(request);
var wrappedKey = response.getEntityWrappedKey().toByteArray();
return decryptor.decrypt(wrappedKey);
RewrapResponse response;
try {
response = getStub(keyAccess.url).rewrap(request);
var wrappedKey = response.getEntityWrappedKey().toByteArray();
return decryptor.decrypt(wrappedKey);
} catch (StatusRuntimeException e) {
if (e.getStatus().getCode() == Status.Code.INVALID_ARGUMENT) {
// 400 Bad Request
throw new KasBadRequestException("rewrap request 400: " + e.toString());
}
throw e;
}

}

public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL) {
Expand Down
5 changes: 4 additions & 1 deletion sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import io.opentdf.platform.sdk.TDF.AssertionException;

import org.apache.commons.codec.binary.Hex;
import org.erdtman.jcs.JsonCanonicalizer;

Expand Down Expand Up @@ -381,7 +384,7 @@ public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey
public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey)
throws ParseException, JOSEException {
if (binding == null) {
throw new SDKException("Binding is null in assertion");
throw new AssertionException("Binding is null in assertion", this.id);
}

String signatureString = binding.signature;
Expand Down
38 changes: 28 additions & 10 deletions sdk/src/main/java/io/opentdf/platform/sdk/TDF.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,30 +119,48 @@ public FailedToCreateGMAC(String errorMessage) {
}
}

public static class NotValidateRootSignature extends RuntimeException {
public NotValidateRootSignature(String errorMessage) {
public static class TDFReadFailed extends RuntimeException {
public TDFReadFailed(String errorMessage) {
super(errorMessage);
}
}

public static class TamperException extends SDKException {
public TamperException(String errorMessage) {
super("[tamper detected] "+errorMessage);
}
}

public static class RootSignatureValidationException extends TamperException {
public RootSignatureValidationException(String errorMessage) {
super(errorMessage);
}
}

public static class SegmentSizeMismatch extends RuntimeException {
public static class SegmentSizeMismatch extends TamperException {
public SegmentSizeMismatch(String errorMessage) {
super(errorMessage);
}
}

public static class SegmentSignatureMismatch extends RuntimeException {
public static class SegmentSignatureMismatch extends TamperException {
public SegmentSignatureMismatch(String errorMessage) {
super(errorMessage);
}
}

public static class TDFReadFailed extends RuntimeException {
public TDFReadFailed(String errorMessage) {
public static class KasBadRequestException extends TamperException {
public KasBadRequestException(String errorMessage) {
super(errorMessage);
}
}

public static class AssertionException extends TamperException {
public AssertionException(String errorMessage, String id) {
super("assertion id: "+ id + "; " + errorMessage);
}
}

public static class EncryptedMetadata {
private String ciphertext;
private String iv;
Expand Down Expand Up @@ -558,7 +576,7 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas)

public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas,
Config.TDFReaderConfig tdfReaderConfig)
throws NotValidateRootSignature, SegmentSizeMismatch,
throws RootSignatureValidationException, SegmentSizeMismatch,
IOException, FailedToCreateGMAC, JOSEException, ParseException, NoSuchAlgorithmException, DecoderException {

TDFReader tdfReader = new TDFReader(tdf);
Expand Down Expand Up @@ -666,7 +684,7 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas,
}

if (rootSignature.compareTo(rootSigValue) != 0) {
throw new NotValidateRootSignature("root signature validation failed");
throw new RootSignatureValidationException("root signature validation failed");
}

int segmentSize = manifest.encryptionInformation.integrityInformation.segmentSizeDefault;
Expand Down Expand Up @@ -701,11 +719,11 @@ public Reader loadTDF(SeekableByteChannel tdf, SDK.KAS kas,
var encodeSignature = Base64.getEncoder().encodeToString(signature.getBytes());

if (!Objects.equals(hashOfAssertion, hashValues.getAssertionHash())) {
throw new SDKException("assertion hash mismatch");
throw new AssertionException("assertion hash mismatch", assertion.id);
}

if (!Objects.equals(encodeSignature, hashValues.getSignature())) {
throw new SDKException("failed integrity check on assertion signature");
throw new AssertionException("failed integrity check on assertion signature", assertion.id);
}
}

Expand Down
58 changes: 57 additions & 1 deletion sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse;
import io.opentdf.platform.policy.attributes.AttributesServiceGrpc;
import io.opentdf.platform.sdk.Config.KASInfo;
import io.opentdf.platform.sdk.TDF.Reader;
import io.opentdf.platform.sdk.nanotdf.NanoTDFType;
import org.apache.commons.compress.utils.SeekableInMemoryByteChannel;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -30,7 +31,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

Expand Down Expand Up @@ -333,6 +333,62 @@ void testSimpleTDFWithAssertionWithHS256() throws Exception {
}
}

@Test
void testSimpleTDFWithAssertionWithHS256Failure() throws Exception {

ListenableFuture<GetAttributeValuesByFqnsResponse> resp1 = mock(ListenableFuture.class);
lenient().when(resp1.get()).thenReturn(GetAttributeValuesByFqnsResponse.newBuilder().build());
lenient().when(attributeGrpcStub.getAttributeValuesByFqns(any(GetAttributeValuesByFqnsRequest.class)))
.thenReturn(resp1);

// var keypair = CryptoUtils.generateRSAKeypair();
SecureRandom secureRandom = new SecureRandom();
byte[] key = new byte[32];
secureRandom.nextBytes(key);

String assertion1Id = "assertion1";
var assertionConfig1 = new AssertionConfig();
assertionConfig1.id = assertion1Id;
assertionConfig1.type = AssertionConfig.Type.BaseAssertion;
assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj;
assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted;
assertionConfig1.statement = new AssertionConfig.Statement();
assertionConfig1.statement.format = "base64binary";
assertionConfig1.statement.schema = "text";
assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4=";
assertionConfig1.assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, key);

Config.TDFConfig config = Config.newTDFConfig(
Config.withAutoconfigure(false),
Config.withKasInformation(getKASInfos()),
Config.withAssertionConfig(assertionConfig1));

String plainText = "this is extremely sensitive stuff!!!";
InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes());
ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream();

TDF tdf = new TDF();
tdf.createTDF(plainTextInputStream, tdfOutputStream, config, kas, attributeGrpcStub);

byte[] notkey = new byte[32];
secureRandom.nextBytes(notkey);
var assertionVerificationKeys = new Config.AssertionVerificationKeys();
assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256,
notkey);
Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig(
Config.withAssertionVerificationKeys(assertionVerificationKeys));

var unwrappedData = new ByteArrayOutputStream();
Reader reader;
try {
reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), kas, readerConfig);
throw new RuntimeException("assertion verify key error thrown");

} catch (SDKException e) {
assertThat(e).hasMessageContaining("verify");
}
}

@Test
public void testCreatingTDFWithMultipleSegments() throws Exception {

Expand Down

0 comments on commit b4f95e6

Please sign in to comment.