diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 3a74a53b..a090fce4 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -1926,7 +1926,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) { SSL *ssl; int ilen; - VALUE len, str; + VALUE len, str, cb_state; VALUE opts = Qnil; if (nonblock) { @@ -1959,6 +1959,14 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_str_locktmp(str); for (;;) { int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); + + cb_state = rb_attr_get(self, ID_callback_state); + if (!NIL_P(cb_state)) { + rb_ivar_set(self, ID_callback_state, Qnil); + ossl_clear_error(); + rb_jump_tag(NUM2INT(cb_state)); + } + switch (ssl_get_error(ssl, nread)) { case SSL_ERROR_NONE: rb_str_unlocktmp(str); @@ -2048,7 +2056,7 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) SSL *ssl; rb_io_t *fptr; int num, nonblock = opts != Qfalse; - VALUE tmp; + VALUE tmp, cb_state; GetSSL(self, ssl); if (!ssl_started(ssl)) @@ -2065,6 +2073,14 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) for (;;) { int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num); + + cb_state = rb_attr_get(self, ID_callback_state); + if (!NIL_P(cb_state)) { + rb_ivar_set(self, ID_callback_state, Qnil); + ossl_clear_error(); + rb_jump_tag(NUM2INT(cb_state)); + } + switch (ssl_get_error(ssl, nwritten)) { case SSL_ERROR_NONE: return INT2NUM(nwritten); diff --git a/test/openssl/test_ssl_session.rb b/test/openssl/test_ssl_session.rb index 89cf672a..d9b49a20 100644 --- a/test/openssl/test_ssl_session.rb +++ b/test/openssl/test_ssl_session.rb @@ -219,11 +219,11 @@ def test_server_session_cache # deadlock. TEST_SESSION_REMOVE_CB = ENV["OSSL_TEST_ALL"] == "1" - def test_ctx_client_session_cb - ctx_proc = proc { |ctx| ctx.ssl_version = :TLSv1_2 } - start_server(ctx_proc: ctx_proc) do |port| + def test_ctx_client_session_cb_tls12 + start_server do |port| called = {} ctx = OpenSSL::SSL::SSLContext.new + ctx.min_version = ctx.max_version = :TLS1_2 ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT ctx.session_new_cb = lambda { |ary| sock, sess = ary @@ -233,7 +233,6 @@ def test_ctx_client_session_cb ctx.session_remove_cb = lambda { |ary| ctx, sess = ary called[:remove] = [ctx, sess] - # any resulting value is OK (ignored) } end @@ -241,8 +240,8 @@ def test_ctx_client_session_cb assert_equal(1, ctx.session_cache_stats[:cache_num]) assert_equal(1, ctx.session_cache_stats[:connect_good]) assert_equal([ssl, ssl.session], called[:new]) - assert(ctx.session_remove(ssl.session)) - assert(!ctx.session_remove(ssl.session)) + assert_equal(true, ctx.session_remove(ssl.session)) + assert_equal(false, ctx.session_remove(ssl.session)) if TEST_SESSION_REMOVE_CB assert_equal([ctx, ssl.session], called[:remove]) end @@ -250,6 +249,50 @@ def test_ctx_client_session_cb end end + def test_ctx_client_session_cb_tls13 + omit "TLS 1.3 not supported" unless tls13_supported? + omit "LibreSSL does not call session_new_cb in TLS 1.3" if libressl? + + start_server do |port| + called = {} + ctx = OpenSSL::SSL::SSLContext.new + ctx.min_version = :TLS1_3 + ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT + ctx.session_new_cb = lambda { |ary| + sock, sess = ary + called[:new] = [sock, sess] + } + + server_connect_with_session(port, ctx, nil) { |ssl| + ssl.puts("abc"); assert_equal("abc\n", ssl.gets) + + assert_operator(1, :<=, ctx.session_cache_stats[:cache_num]) + assert_operator(1, :<=, ctx.session_cache_stats[:connect_good]) + assert_equal([ssl, ssl.session], called[:new]) + } + end + end + + def test_ctx_client_session_cb_tls13_exception + omit "TLS 1.3 not supported" unless tls13_supported? + omit "LibreSSL does not call session_new_cb in TLS 1.3" if libressl? + + start_server do |port| + ctx = OpenSSL::SSL::SSLContext.new + ctx.min_version = :TLS1_3 + ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT + ctx.session_new_cb = lambda { |ary| + raise "in session_new_cb" + } + + server_connect_with_session(port, ctx, nil) { |ssl| + assert_raise_with_message(RuntimeError, /in session_new_cb/) { + ssl.puts("abc"); assert_equal("abc\n", ssl.gets) + } + } + end + end + def test_ctx_server_session_cb connections = nil called = {}