diff --git a/src/sm4_gcm.c b/src/sm4_gcm.c index 780794a70..a86d3e29d 100644 --- a/src/sm4_gcm.c +++ b/src/sm4_gcm.c @@ -27,9 +27,6 @@ int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t taglen, uint8_t *tag) { - const uint8_t *pin = in; - uint8_t *pout = out; - size_t left = inlen; uint8_t H[16] = {0}; uint8_t Y[16]; uint8_t T[16]; @@ -51,19 +48,12 @@ int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, sm4_encrypt(key, Y, T); - while (left) { - uint8_t block[16]; - size_t len = left < 16 ? left : 16; - ctr_incr(Y); - sm4_encrypt(key, Y, block); - gmssl_memxor(pout, pin, block, len); - pin += len; - pout += len; - left -= len; - } + ctr_incr(Y); + sm4_ctr_encrypt(key, Y, in, inlen, out); ghash(H, aad, aadlen, out, inlen, H); gmssl_memxor(tag, T, H, taglen); + return 1; } @@ -71,9 +61,6 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, const uint8_t *tag, size_t taglen, uint8_t *out) { - const uint8_t *pin = in; - uint8_t *pout = out; - size_t left = inlen; uint8_t H[16] = {0}; uint8_t Y[16]; uint8_t T[16]; @@ -89,6 +76,7 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, } ghash(H, aad, aadlen, in, inlen, H); + sm4_encrypt(key, Y, T); gmssl_memxor(T, T, H, taglen); if (memcmp(T, tag, taglen) != 0) { @@ -96,16 +84,9 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, return -1; } - while (left) { - uint8_t block[16]; - size_t len = left < 16 ? left : 16; - ctr_incr(Y); - sm4_encrypt(key, Y, block); - gmssl_memxor(pout, pin, block, len); - pin += len; - pout += len; - left -= len; - } + ctr_incr(Y); + sm4_ctr_encrypt(key, Y, in, inlen, out); + return 1; }