diff --git a/signxml/verifier.py b/signxml/verifier.py index f17be69..a1edb7c 100644 --- a/signxml/verifier.py +++ b/signxml/verifier.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, replace from typing import Callable, FrozenSet, List, Optional, Union +import cryptography.exceptions from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa, utils from cryptography.hazmat.primitives.asymmetric.padding import MGF1, PSS, AsymmetricPadding, PKCS1v15 @@ -114,14 +115,20 @@ def _get_signature(self, root): def _verify_signature_with_pubkey( self, + *, signed_info_c14n: bytes, raw_signature: bytes, - key_value: etree._Element, - der_encoded_key_value: Optional[etree._Element], signature_alg: SignatureMethod, + key_value: Optional[etree._Element] = None, + der_encoded_key_value: Optional[etree._Element] = None, + signing_certificate: Optional[x509.Certificate] = None, ) -> None: if der_encoded_key_value is not None: key = load_der_public_key(b64decode(der_encoded_key_value.text)) # type: ignore + elif signing_certificate is not None: + key = signing_certificate.public_key() + elif key_value is None: + raise InvalidInput("Expected one of key_value, der_encoded_key_value, or signing_certificate to be set") digest_alg_impl = digest_algorithm_implementations[signature_alg]() if signature_alg.name.startswith("ECDSA_"): @@ -137,8 +144,8 @@ def _verify_signature_with_pubkey( key = ecpn.public_key() elif not isinstance(key, ec.EllipticCurvePublicKey): raise InvalidInput("DER encoded key value does not match specified signature algorithm") - dss_signature = self._encode_dss_signature(raw_signature, key.key_size) - key.verify(dss_signature, data=signed_info_c14n, signature_algorithm=ec.ECDSA(digest_alg_impl)) + signature_for_ecdsa = self._encode_dss_signature(raw_signature, key.key_size) + key.verify(signature_for_ecdsa, data=signed_info_c14n, signature_algorithm=ec.ECDSA(digest_alg_impl)) elif signature_alg.name.startswith("DSA_"): if key_value is not None: dsa_key_value = self._find(key_value, "DSAKeyValue") @@ -167,7 +174,7 @@ def _verify_signature_with_pubkey( padding = PSS(mgf=MGF1(algorithm=digest_alg_impl), salt_length=digest_alg_impl.digest_size) key.verify(raw_signature, data=signed_info_c14n, padding=padding, algorithm=digest_alg_impl) else: - raise NotImplementedError() + raise InvalidInput(f"Unsupported signature algorithm {signature_alg}") def _encode_dss_signature(self, raw_signature: bytes, key_size_bits: int) -> bytes: want_raw_signature_len = bits_to_bytes_unit(key_size_bits) * 2 @@ -226,16 +233,6 @@ def _apply_transforms(self, payload, *, transforms_node: etree._Element, signatu def get_cert_chain_verifier(self, ca_pem_file, ca_path): return X509CertChainVerifier(ca_pem_file=ca_pem_file, ca_path=ca_path) - def _verify_signature_with_public_key_impl(self, *, signature, data, public_key, signature_alg_impl): - verify_args = dict(signature=signature, data=data) - if isinstance(public_key, rsa.RSAPublicKey): - verify_args.update(padding=PKCS1v15(), algorithm=signature_alg_impl) - elif isinstance(public_key, dsa.DSAPublicKey): - verify_args.update(algorithm=signature_alg_impl) - elif isinstance(public_key, ec.EllipticCurvePublicKey): - verify_args.update(signature_algorithm=ec.ECDSA(signature_alg_impl)) - public_key.verify(**verify_args) - def verify( self, data, @@ -406,17 +403,15 @@ def verify( if subject_cn_from_signing_cert != cert_subject_name: raise InvalidSignature("Certificate subject common name mismatch") - if signature_alg.name.startswith("ECDSA"): - raw_signature = self._encode_dss_signature(raw_signature, signing_cert.public_key().key_size) - - cert_public_key = signing_cert.public_key() - signature_alg_impl = digest_algorithm_implementations[signature_alg]() - self._verify_signature_with_public_key_impl( - signature=raw_signature, - data=signed_info_c14n, - public_key=cert_public_key, - signature_alg_impl=signature_alg_impl, - ) + try: + self._verify_signature_with_pubkey( + signed_info_c14n=signed_info_c14n, + raw_signature=raw_signature, + signing_certificate=signing_cert, + signature_alg=signature_alg, + ) + except cryptography.exceptions.InvalidSignature as e: + raise InvalidSignature(f"Signature verification failed: {e}") # If both X509Data and KeyValue are present, match one against the other and raise an error on mismatch if key_value is not None: