diff --git a/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py b/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py index f9222c86d..1c86bc597 100644 --- a/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py +++ b/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py @@ -163,6 +163,19 @@ def create_ref_keys_ed448(): return (key, key.public_key()) +def create_ref_keys_sm2(): + key_len = 32 + key_lines = load_file("ecc_sm2.txt").splitlines() + private_key_d = bytes_to_long(compact(key_lines[2:5])) + public_key_xy = compact(key_lines[6:11]) + assert bord(public_key_xy[0]) == 4 # Uncompressed + public_key_x = bytes_to_long(public_key_xy[1:key_len+1]) + public_key_y = bytes_to_long(public_key_xy[key_len+1:]) + + return (ECC.construct(curve="sm2", d=private_key_d), + ECC.construct(curve="sm2", point_x=public_key_x, point_y=public_key_y)) + + # Create reference key pair # ref_private, ref_public = create_ref_keys_p521() @@ -814,6 +827,117 @@ def test_import_openssh_private_password(self): self.assertEqual(key, key_old) +class TestImport_SM2(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestImport_SM2, self).__init__(*args, **kwargs) + self.ref_private, self.ref_public = create_ref_keys_sm2() + + def test_import_public_der(self): + key_file = load_file("ecc_sm2_public.der") + + key = ECC._import_subjectPublicKeyInfo(key_file) + self.assertEqual(self.ref_public, key) + + key = ECC._import_der(key_file, None) + self.assertEqual(self.ref_public, key) + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_public, key) + + def test_import_sec1_uncompressed(self): + key_file = load_file("ecc_sm2_public.der") + value = extract_bitstring_from_spki(key_file) + key = ECC.import_key(key_file, curve_name='sm2') + self.assertEqual(self.ref_public, key) + + def test_import_sec1_compressed(self): + key_file = load_file("ecc_sm2_public_compressed.der") + value = extract_bitstring_from_spki(key_file) + key = ECC.import_key(key_file, curve_name='sm2') + self.assertEqual(self.ref_public, key) + + def test_import_rfc5915_der(self): + key_file = load_file("ecc_sm2_private.der") + + key = ECC._import_rfc5915_der(key_file, None) + self.assertEqual(self.ref_private, key) + + key = ECC._import_der(key_file, None) + self.assertEqual(self.ref_private, key) + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_private, key) + + def test_import_private_pkcs8_clear(self): + key_file = load_file("ecc_sm2_private_p8_clear.der") + + key = ECC._import_der(key_file, None) + self.assertEqual(self.ref_private, key) + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_private, key) + + def test_import_private_pkcs8_in_pem_clear(self): + key_file = load_file("ecc_sm2_private_p8_clear.pem") + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_private, key) + + def test_import_private_pkcs8_encrypted_1(self): + key_file = load_file("ecc_sm2_private_p8.der") + + key = ECC._import_der(key_file, "secret") + self.assertEqual(self.ref_private, key) + + key = ECC.import_key(key_file, "secret") + self.assertEqual(self.ref_private, key) + + def test_import_private_pkcs8_encrypted_2(self): + key_file = load_file("ecc_sm2_private_p8.pem") + + key = ECC.import_key(key_file, "secret") + self.assertEqual(self.ref_private, key) + + def test_import_x509_der(self): + key_file = load_file("ecc_sm2_x509.der") + + key = ECC._import_der(key_file, None) + self.assertEqual(self.ref_public, key) + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_public, key) + + def test_import_public_pem(self): + key_file = load_file("ecc_sm2_public.pem") + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_public, key) + + def test_import_private_pem(self): + key_file = load_file("ecc_sm2_private.pem") + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_private, key) + + def test_import_private_pem_encrypted(self): + for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm": + key_file = load_file("ecc_sm2_private_enc_%s.pem" % algo) + + key = ECC.import_key(key_file, "secret") + self.assertEqual(self.ref_private, key) + + key = ECC.import_key(tostr(key_file), b"secret") + self.assertEqual(self.ref_private, key) + + def test_import_x509_pem(self): + key_file = load_file("ecc_sm2_x509.pem") + + key = ECC.import_key(key_file) + self.assertEqual(self.ref_public, key) + + + class TestExport_P192(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -2613,6 +2737,258 @@ def test_error_params1(self): passphrase="secret") +class TestExport_SM2(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super(TestExport_SM2, self).__init__(*args, **kwargs) + self.ref_private, self.ref_public = create_ref_keys_sm2() + + def test_export_public_der_uncompressed(self): + key_file = load_file("ecc_sm2_public.der") + + encoded = self.ref_public._export_subjectPublicKeyInfo(False) + self.assertEqual(key_file, encoded) + + encoded = self.ref_public.export_key(format="DER") + self.assertEqual(key_file, encoded) + + encoded = self.ref_public.export_key(format="DER", compress=False) + self.assertEqual(key_file, encoded) + + def test_export_public_der_compressed(self): + key_file = load_file("ecc_sm2_public.der") + pub_key = ECC.import_key(key_file) + key_file_compressed = pub_key.export_key(format="DER", compress=True) + + key_file_compressed_ref = load_file("ecc_sm2_public_compressed.der") + self.assertEqual(key_file_compressed, key_file_compressed_ref) + + def test_export_public_sec1_uncompressed(self): + key_file = load_file("ecc_sm2_public.der") + value = extract_bitstring_from_spki(key_file) + + encoded = self.ref_public.export_key(format="SEC1") + self.assertEqual(value, encoded) + + def test_export_public_sec1_compressed(self): + key_file = load_file("ecc_sm2_public.der") + encoded = self.ref_public.export_key(format="SEC1", compress=True) + + key_file_compressed_ref = load_file("ecc_sm2_public_compressed.der") + value = extract_bitstring_from_spki(key_file_compressed_ref) + self.assertEqual(value, encoded) + + def test_export_rfc5915_private_der(self): + key_file = load_file("ecc_sm2_private.der") + + encoded = self.ref_private._export_rfc5915_private_der() + self.assertEqual(key_file, encoded) + + # --- + + encoded = self.ref_private.export_key(format="DER", use_pkcs8=False) + self.assertEqual(key_file, encoded) + + def test_export_private_pkcs8_clear(self): + key_file = load_file("ecc_sm2_private_p8_clear.der") + + encoded = self.ref_private._export_pkcs8() + self.assertEqual(key_file, encoded) + + # --- + + encoded = self.ref_private.export_key(format="DER") + self.assertEqual(key_file, encoded) + + def test_export_private_pkcs8_encrypted(self): + encoded = self.ref_private._export_pkcs8(passphrase="secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + + # This should prove that the output is password-protected + self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None) + + decoded = ECC._import_pkcs8(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + # --- + + encoded = self.ref_private.export_key(format="DER", + passphrase="secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + decoded = ECC.import_key(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + def test_export_public_pem_uncompressed(self): + key_file = load_file("ecc_sm2_public.pem", "rt").strip() + + encoded = self.ref_private._export_public_pem(False) + self.assertEqual(key_file, encoded) + + # --- + + encoded = self.ref_public.export_key(format="PEM") + self.assertEqual(key_file, encoded) + + encoded = self.ref_public.export_key(format="PEM", compress=False) + self.assertEqual(key_file, encoded) + + def test_export_public_pem_compressed(self): + key_file = load_file("ecc_sm2_public.pem", "rt").strip() + pub_key = ECC.import_key(key_file) + + key_file_compressed = pub_key.export_key(format="PEM", compress=True) + key_file_compressed_ref = load_file("ecc_sm2_public_compressed.pem", "rt").strip() + + self.assertEqual(key_file_compressed, key_file_compressed_ref) + + def test_export_private_pem_clear(self): + key_file = load_file("ecc_sm2_private.pem", "rt").strip() + + encoded = self.ref_private._export_private_pem(None) + self.assertEqual(key_file, encoded) + + # --- + + encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False) + self.assertEqual(key_file, encoded) + + def test_export_private_pem_encrypted(self): + encoded = self.ref_private._export_private_pem(passphrase=b"secret") + + # This should prove that the output is password-protected + self.assertRaises(ValueError, ECC.import_key, encoded) + + assert "EC PRIVATE KEY" in encoded + + decoded = ECC.import_key(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + # --- + + encoded = self.ref_private.export_key(format="PEM", + passphrase="secret", + use_pkcs8=False) + decoded = ECC.import_key(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + def test_export_private_pkcs8_and_pem_1(self): + # PKCS8 inside PEM with both unencrypted + key_file = load_file("ecc_sm2_private_p8_clear.pem", "rt").strip() + + encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem() + self.assertEqual(key_file, encoded) + + # --- + + encoded = self.ref_private.export_key(format="PEM") + self.assertEqual(key_file, encoded) + + def test_export_private_pkcs8_and_pem_2(self): + # PKCS8 inside PEM with PKCS8 encryption + encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + + # This should prove that the output is password-protected + self.assertRaises(ValueError, ECC.import_key, encoded) + + assert "ENCRYPTED PRIVATE KEY" in encoded + + decoded = ECC.import_key(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + # --- + + encoded = self.ref_private.export_key(format="PEM", + passphrase="secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + decoded = ECC.import_key(encoded, "secret") + self.assertEqual(self.ref_private, decoded) + + def test_prng(self): + # Test that password-protected containers use the provided PRNG + encoded1 = self.ref_private.export_key(format="PEM", + passphrase="secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC", + randfunc=get_fixed_prng()) + encoded2 = self.ref_private.export_key(format="PEM", + passphrase="secret", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC", + randfunc=get_fixed_prng()) + self.assertEqual(encoded1, encoded2) + + # --- + + encoded1 = self.ref_private.export_key(format="PEM", + use_pkcs8=False, + passphrase="secret", + randfunc=get_fixed_prng()) + encoded2 = self.ref_private.export_key(format="PEM", + use_pkcs8=False, + passphrase="secret", + randfunc=get_fixed_prng()) + self.assertEqual(encoded1, encoded2) + + def test_byte_or_string_passphrase(self): + encoded1 = self.ref_private.export_key(format="PEM", + use_pkcs8=False, + passphrase="secret", + randfunc=get_fixed_prng()) + encoded2 = self.ref_private.export_key(format="PEM", + use_pkcs8=False, + passphrase=b"secret", + randfunc=get_fixed_prng()) + self.assertEqual(encoded1, encoded2) + + def test_error_params1(self): + # Unknown format + self.assertRaises(ValueError, self.ref_private.export_key, format="XXX") + + # Missing 'protection' parameter when PKCS#8 is used + self.ref_private.export_key(format="PEM", passphrase="secret", + use_pkcs8=False) + self.assertRaises(ValueError, self.ref_private.export_key, format="PEM", + passphrase="secret") + + # DER format but no PKCS#8 + self.assertRaises(ValueError, self.ref_private.export_key, format="DER", + passphrase="secret", + use_pkcs8=False, + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + + # Incorrect parameters for public keys + self.assertRaises(ValueError, self.ref_public.export_key, format="DER", + use_pkcs8=False) + + # Empty password + self.assertRaises(ValueError, self.ref_private.export_key, format="PEM", + passphrase="", use_pkcs8=False) + self.assertRaises(ValueError, self.ref_private.export_key, format="PEM", + passphrase="", + protection="PBKDF2WithHMAC-SHA1AndAES128-CBC") + + def test_compressed_curve(self): + + # Compressed sm2 curve (Y-point is even) + pem1 = """-----BEGIN EC PRIVATE KEY----- + MFcCAQEEIN509xPy6uRGyYkiy0t1VBG7kGewTnEZGp6QEUR3guX3oAoGCCqBHM9V + AYItoSQDIgACVf1bhKVDUXZ3yWD6LZ41i/4yaJzWtEuujrCw+4rTIfU= + -----END EC PRIVATE KEY-----""" + + # Compressed sm2 curve (Y-point is odd) + pem2 = """-----BEGIN EC PRIVATE KEY----- + MFcCAQEEIKaNq594sjs8C65Pn+B4pjFPc7nN5ZIxCjTLSZXCRG4ioAoGCCqBHM9V + AYItoSQDIgADL4goDzlQBtdaSRyKWiWmYyyBGR7z/btnQVQlRKG66nU= + -----END EC PRIVATE KEY-----""" + + key1 = ECC.import_key(pem1) + low16 = int(key1.pointQ.y % 65536) + self.assertEqual(low16, 0xF810) + + key2 = ECC.import_key(pem2) + low16 = int(key2.pointQ.y % 65536) + self.assertEqual(low16, 0x6C81) + + def get_tests(config={}): tests = [] tests += list_test_cases(TestImport) @@ -2624,6 +3000,7 @@ def get_tests(config={}): tests += list_test_cases(TestImport_P521) tests += list_test_cases(TestImport_Ed25519) tests += list_test_cases(TestImport_Ed448) + tests += list_test_cases(TestImport_SM2) tests += list_test_cases(TestExport_P192) tests += list_test_cases(TestExport_P224) @@ -2632,6 +3009,7 @@ def get_tests(config={}): tests += list_test_cases(TestExport_P521) tests += list_test_cases(TestExport_Ed25519) tests += list_test_cases(TestExport_Ed448) + tests += list_test_cases(TestExport_SM2) except MissingTestVectorException: pass