diff --git a/libc-bottom-half/cloudlibc/src/libc/poll/poll.c b/libc-bottom-half/cloudlibc/src/libc/poll/poll.c index c80532d06..dd030cbe8 100644 --- a/libc-bottom-half/cloudlibc/src/libc/poll/poll.c +++ b/libc-bottom-half/cloudlibc/src/libc/poll/poll.c @@ -140,6 +140,11 @@ typedef struct { static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) { + int event_count = 0; + for (size_t i = 0; i < nfds; ++i) { + fds[i].revents = 0; + } + size_t max_pollables = (2 * nfds) + 1; state_t states[max_pollables]; size_t state_index = 0; @@ -153,30 +158,41 @@ static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) switch (socket->state.tag) { case TCP_SOCKET_STATE_CONNECTING: { if ((pollfd->events & (POLLRDNORM | POLLWRNORM)) != 0) { - state_t state = { .pollable = socket->socket_pollable, + states[state_index++] = (state_t) { + .pollable = socket->socket_pollable, .pollfd = pollfd, .socket = socket, - .events = pollfd->events }; - states[state_index++] = state; + .events = pollfd->events + }; } break; } case TCP_SOCKET_STATE_CONNECTED: { if ((pollfd->events & POLLRDNORM) != 0) { - state_t state = { .pollable = socket->state.connected.input_pollable, + states[state_index++] = (state_t) { + .pollable = socket->state.connected.input_pollable, .pollfd = pollfd, .socket = socket, - .events = POLLRDNORM }; - states[state_index++] = state; + .events = POLLRDNORM + }; } if ((pollfd->events & POLLWRNORM) != 0) { - state_t state = { .pollable = socket->state.connected.output_pollable, + states[state_index++] = (state_t) { + .pollable = socket->state.connected.output_pollable, .pollfd = pollfd, .socket = socket, - .events = POLLWRNORM }; - states[state_index++] = state; + .events = POLLWRNORM + }; + } + break; + } + + case TCP_SOCKET_STATE_CONNECT_FAILED: { + if (pollfd->revents == 0) { + ++event_count; } + pollfd->revents |= pollfd->events; break; } @@ -197,6 +213,10 @@ static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) } } + if (event_count > 0 && timeout != 0) { + return event_count; + } + poll_borrow_pollable_t pollables[state_index + 1]; for (size_t i = 0; i < state_index; ++i) { pollables[i] = poll_borrow_pollable(states[i].pollable); @@ -213,11 +233,6 @@ static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) poll_list_borrow_pollable_t list = { .ptr = (poll_borrow_pollable_t*)&pollables, .len = pollable_count }; poll_poll(&list, &ready); - for (size_t i = 0; i < nfds; ++i) { - fds[i].revents = 0; - } - - int event_count = 0; for (size_t i = 0; i < ready.len; ++i) { size_t index = ready.ptr[i]; if (index < state_index) { @@ -231,12 +246,15 @@ static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) streams_own_pollable_t input_pollable = streams_method_input_stream_subscribe(input_stream_borrow); streams_borrow_output_stream_t output_stream_borrow = streams_borrow_output_stream(tuple.f1); streams_own_pollable_t output_pollable = streams_method_output_stream_subscribe(output_stream_borrow); - state->socket->state = (tcp_socket_state_t){ .tag = TCP_SOCKET_STATE_CONNECTED, .connected = { - .input_pollable = input_pollable, - .input = tuple.f0, - .output_pollable = output_pollable, - .output = tuple.f1, - } }; + state->socket->state = (tcp_socket_state_t) { + .tag = TCP_SOCKET_STATE_CONNECTED, + .connected = { + .input_pollable = input_pollable, + .input = tuple.f0, + .output_pollable = output_pollable, + .output = tuple.f1, + } + }; if (state->pollfd->revents == 0) { ++event_count; } @@ -244,9 +262,12 @@ static int poll_preview2(struct pollfd* fds, size_t nfds, int timeout) } else if (error == NETWORK_ERROR_CODE_WOULD_BLOCK) { // No events yet -- application will need to poll again } else { - state->socket->state = (tcp_socket_state_t){ .tag = TCP_SOCKET_STATE_CONNECT_FAILED, .connect_failed = { - .error_code = error, - } }; + state->socket->state = (tcp_socket_state_t) { + .tag = TCP_SOCKET_STATE_CONNECT_FAILED, + .connect_failed = { + .error_code = error, + } + }; if (state->pollfd->revents == 0) { ++event_count; } @@ -297,7 +318,7 @@ int poll(struct pollfd* fds, nfds_t nfds, int timeout) } else if (found_non_socket) { return poll_preview1(fds, nfds, timeout); } else if (timeout >= 0) { - return poll_preview2(fds, nfds, timeout); + return poll_preview2(fds, nfds, timeout); } else { errno = ENOTSUP; return -1;