Skip to content

Commit

Permalink
Allow set_tmp_ecdh to take cryptography elliptic curves (#1327)
Browse files Browse the repository at this point in the history
Deprecate `get_elliptic_curves` and `get_elliptic_curve`
  • Loading branch information
alex authored Jul 31, 2024
1 parent fb6d150 commit 8c42c52
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 5 deletions.
36 changes: 33 additions & 3 deletions src/OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import os
import socket
import typing
import warnings
from errno import errorcode
from functools import partial, wraps
from itertools import chain, count
from sys import platform
from typing import Any, Callable, List, Optional, Sequence, TypeVar
from weakref import WeakValueDictionary

from cryptography.hazmat.primitives.asymmetric import ec

from OpenSSL._util import (
StrOrBytesPath as _StrOrBytesPath,
)
Expand Down Expand Up @@ -1358,17 +1361,44 @@ def load_tmp_dh(self, dhfile: _StrOrBytesPath) -> None:
res = _lib.SSL_CTX_set_tmp_dh(self._context, dh)
_openssl_assert(res == 1)

def set_tmp_ecdh(self, curve: _EllipticCurve) -> None:
def set_tmp_ecdh(self, curve: _EllipticCurve | ec.EllipticCurve) -> None:
"""
Select a curve to use for ECDHE key exchange.
:param curve: A curve object to use as returned by either
:param curve: A curve instance from cryptography
(:class:`~cryptogragraphy.hazmat.primitives.asymmetric.ec.EllipticCurve`).
Alternatively (deprecated) a curve object from either
:meth:`OpenSSL.crypto.get_elliptic_curve` or
:meth:`OpenSSL.crypto.get_elliptic_curves`.
:return: None
"""
_lib.SSL_CTX_set_tmp_ecdh(self._context, curve._to_EC_KEY())

if isinstance(curve, _EllipticCurve):
warnings.warn(
(
"Passing pyOpenSSL elliptic curves to set_tmp_ecdh is "
"deprecated. You should use cryptography's elliptic curve "
"types instead."
),
DeprecationWarning,
stacklevel=2,
)
_lib.SSL_CTX_set_tmp_ecdh(self._context, curve._to_EC_KEY())
else:
name = curve.name
if name == "secp192r1":
name = "prime192v1"
elif name == "secp256r1":
name = "prime256v1"
nid = _lib.OBJ_txt2nid(name.encode())
if nid == _lib.NID_undef:
_raise_current_error()

ec = _lib.EC_KEY_new_by_curve_name(nid)
_openssl_assert(ec != _ffi.NULL)
ec = _ffi.gc(ec, _lib.EC_KEY_free)
_lib.SSL_CTX_set_tmp_ecdh(self._context, ec)

def set_cipher_list(self, cipher_list: bytes) -> None:
"""
Expand Down
28 changes: 27 additions & 1 deletion src/OpenSSL/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,20 @@ def get_elliptic_curves() -> set[_EllipticCurve]:
return _EllipticCurve._get_elliptic_curves(_lib)


_get_elliptic_curves_internal = get_elliptic_curves

utils.deprecated(
get_elliptic_curves,
__name__,
(
"get_elliptic_curves is deprecated. You should use the APIs in "
"cryptography instead."
),
DeprecationWarning,
name="get_elliptic_curves",
)


def get_elliptic_curve(name: str) -> _EllipticCurve:
"""
Return a single curve object selected by name.
Expand All @@ -588,12 +602,24 @@ def get_elliptic_curve(name: str) -> _EllipticCurve:
If the named curve is not supported then :py:class:`ValueError` is raised.
"""
for curve in get_elliptic_curves():
for curve in _get_elliptic_curves_internal():
if curve.name == name:
return curve
raise ValueError("unknown curve name", name)


utils.deprecated(
get_elliptic_curve,
__name__,
(
"get_elliptic_curve is deprecated. You should use the APIs in "
"cryptography instead."
),
DeprecationWarning,
name="get_elliptic_curve",
)


@functools.total_ordering
class X509Name:
"""
Expand Down
10 changes: 9 additions & 1 deletion tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import pytest
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.x509.oid import NameOID
from pretend import raiser

Expand Down Expand Up @@ -1685,6 +1685,14 @@ def test_set_tmp_ecdh(self):
continue
# The only easily "assertable" thing is that it does not raise an
# exception.
with pytest.deprecated_call():
context.set_tmp_ecdh(curve)

for name in dir(ec.EllipticCurveOID):
if name.startswith("_"):
continue
oid = getattr(ec.EllipticCurveOID, name)
curve = ec.get_curve_for_oid(oid)
context.set_tmp_ecdh(curve)

def test_set_session_cache_mode_wrong_args(self):
Expand Down

0 comments on commit 8c42c52

Please sign in to comment.