Skip to content

Commit

Permalink
Adjust SM3 SM4 API
Browse files Browse the repository at this point in the history
Remove sm3_digest. Use more _gmssl_export
  • Loading branch information
guanzhi committed Apr 19, 2024
1 parent ab7c9a7 commit 8cb306a
Show file tree
Hide file tree
Showing 17 changed files with 132 additions and 130 deletions.
16 changes: 15 additions & 1 deletion include/gmssl/sm3.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <string.h>
#include <stdint.h>
#include <gmssl/api.h>

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -70,11 +71,24 @@ void sm3_kdf(const uint8_t *in, size_t inlen, size_t outlen, uint8_t *out);
#define SM3_PBKDF2_MAX_SALT_SIZE 64
#define SM3_PBKDF2_DEFAULT_SALT_SIZE 8

int sm3_pbkdf2(const char *pass, size_t passlen,
_gmssl_export int sm3_pbkdf2(const char *pass, size_t passlen,
const uint8_t *salt, size_t saltlen, size_t count,
size_t outlen, uint8_t *out);


typedef struct {
union {
SM3_CTX sm3_ctx;
SM3_HMAC_CTX hmac_ctx;
void *handle;
};
int state;
} SM3_DIGEST_CTX;

_gmssl_export int sm3_digest_init(SM3_DIGEST_CTX *ctx, const uint8_t *key, size_t keylen);
_gmssl_export int sm3_digest_update(SM3_DIGEST_CTX *ctx, const uint8_t *data, size_t datalen);
_gmssl_export int sm3_digest_finish(SM3_DIGEST_CTX *ctx, uint8_t dgst[SM3_DIGEST_SIZE]);


#ifdef __cplusplus
}
Expand Down
140 changes: 60 additions & 80 deletions include/gmssl/sm4.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include <stdint.h>
#include <string.h>
#include <gmssl/ghash.h>
#include <gmssl/api.h>

#ifdef __cplusplus
extern "C" {
Expand All @@ -31,48 +33,72 @@ typedef struct {
void sm4_set_encrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]);
void sm4_set_decrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]);
void sm4_encrypt(const SM4_KEY *key, const uint8_t in[SM4_BLOCK_SIZE], uint8_t out[SM4_BLOCK_SIZE]);
#define sm4_decrypt(key,in,out) sm4_encrypt(key,in,out)


void sm4_cbc_encrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_cbc_decrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t nblocks, uint8_t *out);

int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);

typedef struct {
union {
SM4_KEY sm4_key;
void *handle;
};
uint8_t iv[SM4_BLOCK_SIZE];
uint8_t block[SM4_BLOCK_SIZE];
size_t block_nbytes;
} SM4_CBC_CTX;

_gmssl_export int sm4_cbc_encrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
_gmssl_export int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_cbc_decrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
_gmssl_export int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen);


void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[SM4_BLOCK_SIZE],
const uint8_t *in, size_t inlen, uint8_t *out);
#define sm4_ctr_decrypt(key,ctr,in,inlen,out) sm4_ctr_encrypt(key,ctr,in,inlen,out)

typedef struct {
union {
SM4_KEY sm4_key;
void *handle;
};
uint8_t ctr[SM4_BLOCK_SIZE];
uint8_t block[SM4_BLOCK_SIZE];
size_t block_nbytes;
} SM4_CTR_CTX;

#define SM4_GCM_IV_MIN_SIZE 1
#define SM4_GCM_IV_MAX_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951

#define SM4_GCM_IV_DEFAULT_BITS 96
#define SM4_GCM_IV_DEFAULT_SIZE 12
_gmssl_export int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]);
_gmssl_export int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen);

//#define NIST_SP800_GCM_MAX_IV_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951

#define NIST_SP800_GCM_MAX_IV_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951
#define SM4_GCM_MAX_IV_SIZE 64
#define SM4_GCM_MIN_IV_SIZE 1
#define SM4_GCM_DEFAULT_IV_SIZE 12

#define NIST_SP800_GCM_MAX_AAD_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951
#define SM4_GCM_MIN_AAD_SIZE 0
#define SM4_GCM_MAX_AAD_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951
#define SM4_GCM_MAX_AAD_SIZE (1<<24) // 16MiB

#define SM4_GCM_MIN_PLAINTEXT_SIZE 0
#define SM4_GCM_MAX_PLAINTEXT_SIZE ((((uint64_t)1 << 39) - 256) >> 3) // 68719476704

#define SM4_GCM_MAX_TAG_SIZE 16
#define SM4_GCM_MIN_TAG_SIZE 12
#define SM4_GCM_DEFAULT_TAG_SIZE 16
// For certain applications (voice or video), tag may be 64 or 32 bits
// see NIST Special Publication 800-38D, Appendix C for more details


int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
uint8_t *out, size_t taglen, uint8_t *tag);
Expand All @@ -81,47 +107,6 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *tag, size_t taglen, uint8_t *out);


typedef struct {
union {
SM4_KEY sm4_key;
void *handle;
};
uint8_t iv[SM4_BLOCK_SIZE];
uint8_t block[SM4_BLOCK_SIZE];
size_t block_nbytes;
} SM4_CBC_CTX;

int sm4_cbc_encrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen);

int sm4_cbc_decrypt_init(SM4_CBC_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen);


typedef struct {
union {
SM4_KEY sm4_key;
void *handle;
};
uint8_t ctr[SM4_BLOCK_SIZE];
uint8_t block[SM4_BLOCK_SIZE];
size_t block_nbytes;
} SM4_CTR_CTX;

int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]);
int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen);

#define sm4_ctr_decrypt_init(ctx,key,ctr) sm4_ctr_encrypt_init(ctx,key,ctr)
#define sm4_ctr_decrypt_update(ctx,in,inlen,out,outlen) sm4_ctr_encrypt_update(ctx,in,inlen,out,outlen)
#define sm4_ctr_decrypt_finish(ctx,out,outlen) sm4_ctr_encrypt_finish(ctx,out,outlen)


#include <gmssl/ghash.h>
#include <gmssl/api.h>

typedef struct {
SM4_CTR_CTX enc_ctx;
GHASH_CTX mac_ctx;
Expand All @@ -131,9 +116,6 @@ typedef struct {
size_t maclen;
} SM4_GCM_CTX;

#define SM4_GCM_KEY_SIZE 16
#define SM4_GCM_DEFAULT_TAG_SIZE 16

_gmssl_export int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx,
const uint8_t *key, size_t keylen, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, size_t taglen);
Expand All @@ -150,8 +132,6 @@ _gmssl_export int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx,
uint8_t *out, size_t *outlen);




#ifdef ENABLE_SM4_ECB
// call `sm4_set_decrypt_key` before decrypt
void sm4_ecb_encrypt(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out);
Expand All @@ -171,7 +151,7 @@ int sm4_ecb_decrypt_init(SM4_ECB_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE]);
int sm4_ecb_decrypt_update(SM4_ECB_CTX *ctx,
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_ecb_decrypt_finish(SM4_ECB_CTX *ctx, uint8_t *out, size_t *outlen);
#endif
#endif // ENABLE_SM4_ECB


#ifdef ENABLE_SM4_OFB
Expand All @@ -186,12 +166,12 @@ typedef struct {
size_t block_nbytes;
} SM4_OFB_CTX;

int sm4_ofb_encrypt_init(SM4_OFB_CTX *ctx,
_gmssl_export int sm4_ofb_encrypt_init(SM4_OFB_CTX *ctx,
const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
int sm4_ofb_encrypt_update(SM4_OFB_CTX *ctx,
_gmssl_export int sm4_ofb_encrypt_update(SM4_OFB_CTX *ctx,
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_ofb_encrypt_finish(SM4_OFB_CTX *ctx, uint8_t *out, size_t *outlen);
#endif
_gmssl_export int sm4_ofb_encrypt_finish(SM4_OFB_CTX *ctx, uint8_t *out, size_t *outlen);
#endif // ENABLE_SM4_OFB


#ifdef ENABLE_SM4_CFB
Expand All @@ -213,18 +193,18 @@ typedef struct {
size_t sbytes;
} SM4_CFB_CTX;

int sm4_cfb_encrypt_init(SM4_CFB_CTX *ctx, size_t sbytes,
_gmssl_export int sm4_cfb_encrypt_init(SM4_CFB_CTX *ctx, size_t sbytes,
const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
int sm4_cfb_encrypt_update(SM4_CFB_CTX *ctx,
_gmssl_export int sm4_cfb_encrypt_update(SM4_CFB_CTX *ctx,
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_cfb_encrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_cfb_encrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen);

int sm4_cfb_decrypt_init(SM4_CFB_CTX *ctx, size_t sbytes,
_gmssl_export int sm4_cfb_decrypt_init(SM4_CFB_CTX *ctx, size_t sbytes,
const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]);
int sm4_cfb_decrypt_update(SM4_CFB_CTX *ctx,
_gmssl_export int sm4_cfb_decrypt_update(SM4_CFB_CTX *ctx,
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_cfb_decrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen);
#endif
_gmssl_export int sm4_cfb_decrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen);
#endif // ENABLE_SM4_CFB


#ifdef ENABLE_SM4_CCM
Expand All @@ -234,21 +214,21 @@ int sm4_cfb_decrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen);
#define SM4_CCM_MAX_MAC_SIZE 16

// make sure inlen < 2^((15 - ivlen) * 8)
int sm4_ccm_encrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen,
_gmssl_export int sm4_ccm_encrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
uint8_t *out, size_t taglen, uint8_t *tag);
int sm4_ccm_decrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen,
_gmssl_export int sm4_ccm_decrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
const uint8_t *tag, size_t taglen, uint8_t *out);
#endif
#endif // ENABLE_SM4_CCM


#ifdef ENABLE_SM4_XTS
// call `sm4_set_encrypt_key` to set both `key1` and `key2`
int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16],
_gmssl_export int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16],
const uint8_t *in, size_t inlen, uint8_t *out);
// call `sm4_set_decrypt_key(key1)` and `sm4_set_encrypt_key(key2)`
int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16],
_gmssl_export int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16],
const uint8_t *in, size_t inlen, uint8_t *out);

typedef struct {
Expand All @@ -260,13 +240,13 @@ typedef struct {
size_t block_nbytes;
} SM4_XTS_CTX;

int sm4_xts_encrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size);
int sm4_xts_encrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_xts_encrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen);
int sm4_xts_decrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size);
int sm4_xts_decrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int sm4_xts_decrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen);
#endif
_gmssl_export int sm4_xts_encrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size);
_gmssl_export int sm4_xts_encrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_xts_encrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_xts_decrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size);
_gmssl_export int sm4_xts_decrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm4_xts_decrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen);
#endif // ENABLE_SM4_XTS


#ifdef __cplusplus
Expand Down
12 changes: 6 additions & 6 deletions src/digest.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ const DIGEST *digest_from_name(const char *name)
return NULL;
}

static int sm3_digest_init(DIGEST_CTX *ctx)
static int _sm3_digest_init(DIGEST_CTX *ctx)
{
if (!ctx) {
error_print();
Expand All @@ -128,7 +128,7 @@ static int sm3_digest_init(DIGEST_CTX *ctx)
return 1;
}

static int sm3_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen)
static int _sm3_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen)
{
if (!ctx || (!in && inlen != 0)) {
error_print();
Expand All @@ -138,7 +138,7 @@ static int sm3_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen)
return 1;
}

static int sm3_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst)
static int _sm3_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst)
{
if (!ctx || !dgst) {
error_print();
Expand All @@ -153,9 +153,9 @@ static const DIGEST sm3_digest_object = {
SM3_DIGEST_SIZE,
SM3_BLOCK_SIZE,
sizeof(SM3_CTX),
sm3_digest_init,
sm3_digest_update,
sm3_digest_finish,
_sm3_digest_init,
_sm3_digest_update,
_sm3_digest_finish,
};

const DIGEST *DIGEST_sm3(void)
Expand Down
11 changes: 9 additions & 2 deletions src/sm2_key.c
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,15 @@ int sm2_public_key_equ(const SM2_KEY *sm2_key, const SM2_KEY *pub_key)
int sm2_public_key_digest(const SM2_KEY *sm2_key, uint8_t dgst[32])
{
uint8_t bits[65];
sm2_z256_point_to_uncompressed_octets(&sm2_key->public_key, bits);
sm3_digest(bits, sizeof(bits), dgst);
SM3_CTX sm3_ctx;

if (sm2_z256_point_to_uncompressed_octets(&sm2_key->public_key, bits) != 1) {
error_print();
return -1;
}
sm3_init(&sm3_ctx);
sm3_update(&sm3_ctx, bits, sizeof(bits));
sm3_finish(&sm3_ctx, dgst);
return 1;
}

Expand Down
5 changes: 4 additions & 1 deletion src/sm2_z256.c
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,10 @@ int sm2_z256_point_from_hash(SM2_Z256_POINT *R, const uint8_t *data, size_t data

do {
// x = sm3(data) mod p
sm3_digest(data, datalen, dgst);
SM3_CTX sm3_ctx;
sm3_init(&sm3_ctx);
sm3_update(&sm3_ctx, data, datalen);
sm3_finish(&sm3_ctx, dgst);

sm2_z256_from_bytes(x, dgst);
if (sm2_z256_cmp(x, SM2_Z256_P) >= 0) {
Expand Down
2 changes: 2 additions & 0 deletions src/sm3.c
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ void sm3_finish(SM3_CTX *ctx, uint8_t *digest)
}
}

/*
void sm3_digest(const uint8_t *msg, size_t msglen,
uint8_t dgst[SM3_DIGEST_SIZE])
{
Expand All @@ -212,3 +213,4 @@ void sm3_digest(const uint8_t *msg, size_t msglen,
sm3_finish(&ctx, dgst);
memset(&ctx, 0, sizeof(ctx));
}
*/
2 changes: 1 addition & 1 deletion src/sm3_digest.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


#include <string.h>
#include <gmssl/sm3_digest.h>
#include <gmssl/sm3.h>
#include <gmssl/error.h>


Expand Down
Loading

0 comments on commit 8cb306a

Please sign in to comment.