Skip to content

Commit

Permalink
Require explicit socket port reuse
Browse files Browse the repository at this point in the history
Doing it implicitly can lead to mistaken socket leaks and reuse.
It now matches CPython.

Fixes adafruit#8443
  • Loading branch information
tannewt committed Feb 16, 2024
1 parent 02a4886 commit 9f3987a
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 27 deletions.
12 changes: 5 additions & 7 deletions ports/espressif/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_o
}
}

bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
const char *host, size_t hostlen, uint32_t port) {
struct sockaddr_in bind_addr;
const char *broadcast = "<broadcast>";
Expand All @@ -351,13 +351,11 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
bind_addr.sin_family = AF_INET;
bind_addr.sin_port = htons(port);

int opt = 1;
int err = lwip_setsockopt(self->num, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
if (err != 0) {
mp_raise_RuntimeError(MP_ERROR_TEXT("Cannot set socket options"));
}
int result = lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr));
return result == 0;
if (result == 0) {
return 0;
}
return errno;
}

void socketpool_socket_close(socketpool_socket_obj_t *self) {
Expand Down
7 changes: 3 additions & 4 deletions ports/raspberrypi/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_o
return MP_OBJ_FROM_PTR(accepted);
}

bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
const char *host, size_t hostlen, uint32_t port) {

// get address
Expand All @@ -876,7 +876,6 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
} else {
bind_addr_ptr = IP_ANY_TYPE;
}
ip_set_option(socket->pcb.ip, SOF_REUSEADDR);

err_t err = ERR_ARG;
switch (socket->type) {
Expand All @@ -891,10 +890,10 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
}

if (err != ERR_OK) {
mp_raise_OSError(error_lookup_table[-err]);
return error_lookup_table[-err];
}

return mp_const_none;
return 0;
}

STATIC err_t _lwip_tcp_close_poll(void *arg, struct tcp_pcb *pcb) {
Expand Down
6 changes: 3 additions & 3 deletions shared-bindings/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ STATIC mp_obj_t socketpool_socket_bind(mp_obj_t self_in, mp_obj_t addr_in) {
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
}

bool ok = common_hal_socketpool_socket_bind(self, host, hostlen, (uint32_t)port);
if (!ok) {
mp_raise_ValueError(MP_ERROR_TEXT("Error: Failure to bind"));
size_t error = common_hal_socketpool_socket_bind(self, host, hostlen, (uint32_t)port);
if (error != 0) {
mp_raise_OSError(error);
}

return mp_const_none;
Expand Down
2 changes: 1 addition & 1 deletion shared-bindings/socketpool/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
extern const mp_obj_type_t socketpool_socket_type;

socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_t *port);
bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
void common_hal_socketpool_socket_close(socketpool_socket_obj_t *self);
void common_hal_socketpool_socket_connect(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t *self);
Expand Down
8 changes: 8 additions & 0 deletions shared-bindings/socketpool/SocketPool.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ MP_DEFINE_EXCEPTION(gaierror, OSError)
//| SOCK_RAW: int
//| EAI_NONAME: int
//|
//| SOL_SOCKET: int
//|
//| SO_REUSEADDR: int
//|
//| TCP_NODELAY: int
//|
//| IPPROTO_IP: int
Expand Down Expand Up @@ -196,6 +200,10 @@ STATIC const mp_rom_map_elem_t socketpool_socketpool_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_SOCK_DGRAM), MP_ROM_INT(SOCKETPOOL_SOCK_DGRAM) },
{ MP_ROM_QSTR(MP_QSTR_SOCK_RAW), MP_ROM_INT(SOCKETPOOL_SOCK_RAW) },

{ MP_ROM_QSTR(MP_QSTR_SOL_SOCKET), MP_ROM_INT(SOCKETPOOL_SOL_SOCKET) },

{ MP_ROM_QSTR(MP_QSTR_SO_REUSEADDR), MP_ROM_INT(SOCKETPOOL_SO_REUSEADDR) },

{ MP_ROM_QSTR(MP_QSTR_TCP_NODELAY), MP_ROM_INT(SOCKETPOOL_TCP_NODELAY) },

{ MP_ROM_QSTR(MP_QSTR_IPPROTO_IP), MP_ROM_INT(SOCKETPOOL_IPPROTO_IP) },
Expand Down
8 changes: 8 additions & 0 deletions shared-bindings/socketpool/SocketPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ typedef enum {
SOCKETPOOL_TCP_NODELAY = 1,
} socketpool_socketpool_tcpopt_t;

typedef enum {
SOCKETPOOL_SOL_SOCKET = 0xfff,
} socketpool_socketpool_optlevel_t;

typedef enum {
SOCKETPOOL_SO_REUSEADDR = 0x0004,
} socketpool_socketpool_socketopt_t;

typedef enum {
SOCKETPOOL_IP_MULTICAST_TTL = 5,
} socketpool_socketpool_ipopt_t;
Expand Down
24 changes: 12 additions & 12 deletions shared-module/ssl/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
(void)level;
mp_printf(&mp_plat_print, "DBG:%s:%04d: %s\n", file, line, str);
}
#define DEBUG(fmt, ...) mp_printf(&mp_plat_print, "DBG:%s:%04d: " fmt "\n", __FILE__, __LINE__,##__VA_ARGS__)
#define DEBUG_PRINT(fmt, ...) mp_printf(&mp_plat_print, "DBG:%s:%04d: " fmt "\n", __FILE__, __LINE__,##__VA_ARGS__)
#else
#define DEBUG(...) do {} while (0)
#define DEBUG_PRINT(...) do {} while (0)
#endif

STATIC NORETURN void mbedtls_raise_error(int err) {
Expand Down Expand Up @@ -107,10 +107,10 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {

// mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
mp_int_t out_sz = socketpool_socket_send(sock, buf, len);
DEBUG("socket_send() -> %d", out_sz);
DEBUG_PRINT("socket_send() -> %d", out_sz);
if (out_sz < 0) {
int err = -out_sz;
DEBUG("sock_stream->write() -> %d nonblocking? %d", out_sz, mp_is_nonblocking_error(err));
DEBUG_PRINT("sock_stream->write() -> %d nonblocking? %d", out_sz, mp_is_nonblocking_error(err));
if (mp_is_nonblocking_error(err)) {
return MBEDTLS_ERR_SSL_WANT_WRITE;
}
Expand All @@ -125,7 +125,7 @@ STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

mp_int_t out_sz = socketpool_socket_recv_into(sock, buf, len);
DEBUG("socket_recv() -> %d", out_sz);
DEBUG_PRINT("socket_recv() -> %d", out_sz);
if (out_sz < 0) {
int err = -out_sz;
if (mp_is_nonblocking_error(err)) {
Expand Down Expand Up @@ -261,14 +261,14 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t

mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
int ret = mbedtls_ssl_read(&self->ssl, buf, len);
DEBUG("recv_into mbedtls_ssl_read() -> %d\n", ret);
DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
DEBUG("returning %d\n", 0);
DEBUG_PRINT("returning %d\n", 0);
// end of stream
return 0;
}
if (ret >= 0) {
DEBUG("returning %d\n", ret);
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
Expand All @@ -279,15 +279,15 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
// renegotiation.
ret = MP_EWOULDBLOCK;
}
DEBUG("raising errno [error case] %d\n", ret);
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mp_raise_OSError(ret);
}

mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
int ret = mbedtls_ssl_write(&self->ssl, buf, len);
DEBUG("send mbedtls_ssl_write() -> %d\n", ret);
DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret);
if (ret >= 0) {
DEBUG("returning %d\n", ret);
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
Expand All @@ -298,7 +298,7 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
// renegotiation.
ret = MP_EWOULDBLOCK;
}
DEBUG("raising errno [error case] %d\n", ret);
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mp_raise_OSError(ret);
}

Expand Down

0 comments on commit 9f3987a

Please sign in to comment.