Skip to content

Commit

Permalink
update kdf module
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaozg committed Dec 30, 2023
1 parent 2b17687 commit 9e73d9c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 9 deletions.
51 changes: 49 additions & 2 deletions src/kdf.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,31 @@ static int openssl_kdf_fetch(lua_State *L)

return 1;
}
#endif

/***
compute KDF delive
compute KDF delive, openssl version >= v3
@function deilver
@tparam evp_kdf|string kdf
@tparam table array of paramaters
@treturn string result binary string
*/

/***
compute KDF delive, openssl version < v3
@function deilver
@tparam string pass
@tparam string salt
@tparam string|object|nid digest
@tparam[opt=1000] number iterator
@tparam[opt=32] number keylen
@treturn string deilved result binary string
*/
static int openssl_kdf_derive(lua_State *L)
{
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
EVP_KDF *kdf = get_kdf(L, 1);
OSSL_PARAM *params = openssl_toparams(L, 2);
unsigned char key[64] = {0};
Expand All @@ -124,8 +138,35 @@ static int openssl_kdf_derive(lua_State *L)
EVP_KDF_CTX_free(ctx);
OPENSSL_free(params);
return ret;
#else
size_t passlen, saltlen;
const char* pass = luaL_checklstring (L, 1, &passlen);
const char* salt = luaL_checklstring (L, 2, &saltlen);
const EVP_MD* md = get_digest(L, 3, NULL);
int iter = luaL_optinteger(L, 4, 1000);
int keylen = luaL_optinteger(L, 5, 32);
unsigned char key[256] = {0};

luaL_argcheck(L, keylen <= sizeof(key), 5,
"out of support range, limited to 256");

int ret = PKCS5_PBKDF2_HMAC(pass, (int)passlen,
salt, (int)saltlen,
iter,
md,
keylen,
key);
if (ret==1)
{
lua_pushlstring(L, key, keylen);
} else
ret = openssl_pushresult(L, ret);

return ret;
#endif
}

#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
/***
openssl.kdf_ctx object
@type kdf_ctx
Expand Down Expand Up @@ -512,24 +553,30 @@ static luaL_Reg kdf_funs[] =

{NULL, NULL}
};
#endif

static const luaL_Reg kdf_R[] =
{
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
{"fetch", openssl_kdf_fetch},
{"iterator", openssl_kdf_iterator_kdf},
#endif
{"derive", openssl_kdf_derive},

{NULL, NULL}
};

int luaopen_kdf(lua_State *L)
{
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
auxiliar_newclass(L, "openssl.kdf", kdf_funs);
auxiliar_newclass(L, "openssl.kdf_ctx", kdf_ctx_funs);
#endif

lua_newtable(L);
luaL_setfuncs(L, kdf_R, 0);

#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
lua_pushliteral(L, "names");
lua_newtable(L);

Expand All @@ -550,7 +597,7 @@ int luaopen_kdf(lua_State *L)
#endif

lua_rawset(L, -3);
#endif

return 1;
}
#endif
2 changes: 1 addition & 1 deletion src/openssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,9 @@ LUALIB_API int luaopen_openssl(lua_State*L)
lua_setfield(L, -2, "mac");
luaopen_param(L);
lua_setfield(L, -2, "param");
#endif
luaopen_kdf(L);
lua_setfield(L, -2, "kdf");
#endif

luaopen_pkey(L);
lua_setfield(L, -2, "pkey");
Expand Down
2 changes: 1 addition & 1 deletion src/openssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ LUA_FUNCTION(luaopen_dh);
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
LUA_FUNCTION(luaopen_mac);
LUA_FUNCTION(luaopen_param);
LUA_FUNCTION(luaopen_kdf);
#endif
LUA_FUNCTION(luaopen_kdf);

void openssl_add_method_or_alias(const OBJ_NAME *name, void *arg) ;
void openssl_add_method(const OBJ_NAME *name, void *arg);
Expand Down
25 changes: 20 additions & 5 deletions test/2.kdf.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ local lu = require 'luaunit'

local openssl = require 'openssl'
local kdf = require'openssl'.kdf
if not kdf then
return
end

TestKDF = {}

function TestKDF:testDerive()
local pwd = "1234567890";
local salt = "0987654321"
local md = 'sha256'
local iter = 4096
local keylen = 32

local key = assert(kdf.derive(pwd, salt, md, iter, keylen))
print('key', key)
assert(key)
assert(#key == 32)
end

function TestKDF:testBasic()
if not kdf.iterator then return end
kdf.iterator(function(k)
assert(k:name())
assert(k)
Expand All @@ -25,8 +36,10 @@ function TestKDF:testBasic()
end

function TestKDF:testPBKDF2()
if not kdf.fetch then return end

local pwd = "1234567890";
local salt = "0987654321" -- <D-s>getSalt(pwd)
local salt = "0987654321" -- getSalt(pwd)
local pbkdf2 = kdf.fetch('PBKDF2')
local t = assert(pbkdf2:settable_ctx_params())
local key = assert(pbkdf2:derive({
Expand Down Expand Up @@ -59,8 +72,10 @@ function TestKDF:testPBKDF2()
end

function TestKDF:testPBKDF2CTX()
if not kdf.fetch then return end

local pwd = "1234567890";
local salt = "0987654321" -- <D-s>getSalt(pwd)
local salt = "0987654321" -- getSalt(pwd)
local pbkdf2 = kdf.fetch('PBKDF2')
local ctx = assert(pbkdf2:new())

Expand Down

0 comments on commit 9e73d9c

Please sign in to comment.