diff --git a/src/mqtt_sessions.erl b/src/mqtt_sessions.erl index 2016b52..dc7cbbe 100644 --- a/src/mqtt_sessions.erl +++ b/src/mqtt_sessions.erl @@ -1,7 +1,9 @@ %% @author Marc Worrell -%% @copyright 2018 Marc Worrell +%% @copyright 2018-2024 Marc Worrell +%% @doc Session management for a MQTT server. +%% @end -%% Copyright 2018 Marc Worrell +%% Copyright 2018-2024 Marc Worrell %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -31,8 +33,6 @@ find_session/1, find_session/2, - fetch_queue/1, - fetch_queue/2, session_count/1, router_info/1, @@ -147,17 +147,6 @@ session_count(Pool) -> router_info(Pool) -> mqtt_sessions_router:info(Pool). --spec fetch_queue( session_ref() ) -> {ok, list( mqtt_packet_map:mqtt_packet() | binary() )} | {error, notfound}. -fetch_queue(ClientId) -> - fetch_queue(?MQTT_SESSIONS_DEFAULT, ClientId). - --spec fetch_queue( atom(), session_ref() ) -> {ok, list( mqtt_packet_map:mqtt_packet() | binary() )} | {error, notfound}. -fetch_queue(Pool, ClientId) -> - case find_session(Pool, ClientId) of - {ok, Pid} -> mqtt_sessions_process:fetch_queue(Pid); - {error, _} = Error -> Error - end. - -spec get_user_context( session_ref() ) -> {ok, term()} | {error, notfound | noproc}. get_user_context(ClientId) -> get_user_context(?MQTT_SESSIONS_DEFAULT, ClientId). diff --git a/src/mqtt_sessions_process.erl b/src/mqtt_sessions_process.erl index 2a3effa..3cd0f77 100644 --- a/src/mqtt_sessions_process.erl +++ b/src/mqtt_sessions_process.erl @@ -1,7 +1,9 @@ -%% @doc Process handling one single MQTT session. -%% Transports attaches and detaches from this session. %% @author Marc Worrell %% @copyright 2018-2024 Marc Worrell +%% @doc Process handling one single MQTT session. +%% MQTT connections attach and detach from this session. Buffers outgoing +%% messages if there is not connection attached. +%% @end %% Copyright 2018-2024 Marc Worrell %% @@ -22,6 +24,9 @@ %% TODO: Refuse incoming publish messages if too many publish_jobs %% TODO: Limit incoming_data buffer size +%% Cleanup awaiting_rel for too old waiting +%% Refactor use of awaiting_ack to not use buffer + % MQTTv5 spec http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html % MQTTv3.1.1 spec http://docs.oasis-open.org/mqtt/mqtt/v5.0/cos01/mqtt-v5.0-cos01.html @@ -40,7 +45,6 @@ kill/1, incoming_connect/3, incoming_data/2, - fetch_queue/1, start_link/3 ]). @@ -56,12 +60,13 @@ -define(MAX_PACKET_ID, 65535). -define(RECEIVE_MAXIMUM, 65535). -define(KEEP_ALIVE_DEFAULT, 30). % Default keep alive in seconds --define(SESSION_EXPIRY, 600). % Default session expiration --define(SESSION_EXPIRY_DEFAULT, 3600). % Maximum allowed session expiration +-define(SESSION_EXPIRY, 900). % Default session expiration (15 minutes) +-define(SESSION_EXPIRY_MAX, 3600). % Maximum allowed session expiration (1 hour) -define(MESSAGE_EXPIRY_DEFAULT, 3600). +-define(ACK_EXPIRY, 600). --define(MAX_QUEUED, 500). % Max pending messages for any QoS --define(MAX_INFLIGHT_ACK, 250). % Max in-flight or pending messages waiting with QoS 1 or 2 +-define(MAX_BUFFERED, 500). % Max buffered QoS 0 messages +-define(MAX_INFLIGHT_ACK, 500). % Max in-flight QoS 1/2 messages -define(KILL_TIMEOUT, 5000). @@ -69,6 +74,24 @@ -type packet_id() :: 0..65535. % ?MAX_PACKET_ID +-record(queued, { + msg_nr :: pos_integer(), + type :: atom(), + packet_id = undefined :: undefined | packet_id(), + queued :: non_neg_integer(), + expiry :: non_neg_integer(), + qos = 0 :: 0..2, + message :: mqtt_packet_map:mqtt_packet() +}). + +-record(wait_for, { + msg_nr :: pos_integer(), + type :: atom(), + message = undefined :: undefined | mqtt_packet_map:mqtt_packet(), + is_sent = true :: boolean(), + queued :: non_neg_integer() +}). + -record(state, { protocol_version :: mqtt_packet_map:mqtt_version(), pool :: atom(), @@ -79,15 +102,15 @@ transport = undefined :: mqtt_sessions:transport() | undefined, connection_pid = undefined :: pid() | undefined, is_session_present = false :: boolean(), - pending_connack = undefined :: term(), - pending :: queue:queue(), + is_connected = false :: boolean(), + buffer = #{} :: #{ non_neg_integer() => #queued{} }, packet_id = 1 :: packet_id(), send_quota = ?RECEIVE_MAXIMUM :: non_neg_integer(), - awaiting_ack = #{} :: map(), % Initiated by server + awaiting_ack = #{} :: #{ non_neg_integer() => #wait_for{} }, % Initiated by server awaiting_rel = #{} :: map(), % Initiated by client will = undefined :: undefined | map(), will_pid = undefined :: undefined | pid(), - msg_nr = 0 :: non_neg_integer(), + msg_nr = 0 :: non_neg_integer(), % Incremental counter to keep the buffer in sequence keep_alive = ?KEEP_ALIVE_DEFAULT :: non_neg_integer(), keep_alive_counter = 3 :: integer(), keep_alive_ref :: undefined | reference(), @@ -103,16 +126,6 @@ publish_jobs = #{} :: map() }). --record(queued, { - type :: atom(), - msg_nr :: pos_integer(), - packet_id = undefined :: undefined | packet_id(), - queued :: non_neg_integer(), - expiry :: non_neg_integer(), - qos = 0 :: 0..2, - message :: mqtt_packet_map:mqtt_packet() -}). - -include_lib("kernel/include/logger.hrl"). -include_lib("mqtt_packet_map/include/mqtt_packet_map.hrl"). @@ -179,10 +192,6 @@ incoming_connect(Pid, Msg, Options) when is_map(Options) -> incoming_data(Pid, Data) -> gen_server:call(Pid, {incoming_data, Data, self()}). --spec fetch_queue(pid()) -> {ok, list( map() | binary() )}. -fetch_queue( Pid ) -> - gen_server:call(Pid, fetch_queue, infinity). - -spec start_link( Pool::atom(), ClientId::binary(), mqtt_sessions:session_options() ) -> {ok, pid()}. start_link( Pool, ClientId, SessionOptions ) -> gen_server:start_link(?MODULE, [ Pool, ClientId, SessionOptions ], []). @@ -209,19 +218,13 @@ init([ Pool, ClientId, SessionOptions ]) -> user_context = Runtime:new_user_context(Pool, ClientId, SessionOptions1), client_id = ClientId, routing_id = RoutingId, - pending = queue:new(), + buffer = #{}, will_pid = WillPid, keep_alive = ?KEEP_ALIVE_DEFAULT, keep_alive_counter = 3, keep_alive_ref = KeepAliveRef }}. -handle_call(fetch_queue, _From, #state{ pending_connack = undefined } = State) -> - Qs = [ Msg || #queued{ message = Msg } <- queue:to_list(State#state.pending) ], - {reply, {ok, encode(State#state.protocol_version, Qs)}, State#state{ pending = queue:new() }}; -handle_call(fetch_queue, _From, #state{ pending_connack = ConnAck } = State) -> - {reply, {ok, encode(State#state.protocol_version, ConnAck)}, State#state{ pending_connack = undefined }}; - handle_call(get_user_context, _From, #state{ user_context = UserContext } = State) -> {reply, {ok, UserContext}, State}; handle_call({set_user_context, UserContext}, _From, State) -> @@ -414,7 +417,7 @@ handle_incoming(#{ type := unsubscribe } = Msg, _Options, State) -> packet_unsubscribe(Msg, State); handle_incoming(#{ type := pingreq }, _Options, State) -> - State1 = reply(#{ type => pingresp }, State), + State1 = reply_or_drop(#{ type => pingresp }, State), {ok, State1}; handle_incoming(#{ type := pingresp }, _Options, State) -> {ok, State}; @@ -443,13 +446,13 @@ packet_connect(#{ protocol_version := V, protocol_name := <<"MQTT">> }, Options, type => connack, reason_code => ?MQTT_RC_NOT_AUTHORIZED }, - _ = reply(ConnAck, set_connection(Options, State)), + _ = reply_to_transport(ConnAck, set_connection(Options, State)), {error, protocol_version_changed}; packet_connect(#{ protocol_version := 5, protocol_name := <<"MQTT">>, properties := Props } = Msg, Options, State) -> % MQTT v5 ExpiryInterval = case maps:get(session_expiry_interval, Props, none) of - none -> ?SESSION_EXPIRY_DEFAULT; - EI -> EI + none -> ?SESSION_EXPIRY; + EI -> max(EI, ?SESSION_EXPIRY_MAX) end, KeepAlive = maps:get(keep_alive, Msg, ?KEEP_ALIVE_DEFAULT), StateIfAccept = State#state{ @@ -478,7 +481,7 @@ packet_connect(_ConnectMsg, Options, State) -> type => connack, reason_code => ?MQTT_RC_PROTOCOL_VERSION }, - _ = reply(ConnAck, set_connection(Options, State)), + _ = reply_to_transport(ConnAck, set_connection(Options, State)), {error, protocol_version}. packet_connect_auth(Msg, #state{ runtime = Runtime, user_context = UserContext } = State) -> @@ -494,7 +497,7 @@ handle_connect_auth_1({ok, #{ type := connack, reason_code := ?MQTT_RC_SUCCESS } #{ clean_start := CleanStart }, StateIfAccept, #state{ is_session_present = IsSessionPresent }) -> StateCleaned = maybe_clean_start(CleanStart, StateIfAccept), - %% Set the session_present flag to true, when the runtime omitted it, and when there is a + %% Set the session_present flag to true, if the runtime omitted it, and if there is a %% session present. ConnAck1 = case maps:find(session_present, ConnAck) of {ok, _} -> ConnAck; @@ -505,13 +508,15 @@ handle_connect_auth_1({ok, #{ type := connack, reason_code := ?MQTT_RC_SUCCESS } State1 = StateCleaned#state{ user_context = UserContext1, is_session_present = true, - will = undefined + is_connected = true, + will = undefined, + connect_count = StateCleaned#state.connect_count + 1 }, State2 = reply_connack(ConnAck1, State1), mqtt_sessions_will:connected(State2#state.will_pid, StateIfAccept#state.will, State2#state.session_expiry_interval, State2#state.user_context), - State3 = resend_unacknowledged( cleanup_pending_qos0(State2) ), + State3 = resend_buffered_and_unacknowledged(State2), {ok, State3}; handle_connect_auth_1({ok, #{ type := connack, reason_code := ReasonCode } = ConnAck, _UserContext1}, _Msg, StateIfAccept, _State) -> _ = reply_connack(ConnAck, StateIfAccept), @@ -526,9 +531,10 @@ handle_connect_auth_1({ok, #{ type := connack, reason_code := ReasonCode } = Con {error, connection_refused}; handle_connect_auth_1({ok, #{ type := auth } = Auth, UserContext1}, _Msg, StateIfAccept, _State) -> State1 = StateIfAccept#state{ - user_context = UserContext1 + user_context = UserContext1, + is_connected = true }, - State2 = reply(Auth, State1), + State2 = reply_or_drop(Auth, State1), mqtt_sessions_will:connected(State2#state.will_pid, undefined, State2#state.session_expiry_interval, State2#state.user_context), {ok, State2}; @@ -544,15 +550,19 @@ handle_connect_auth_1({error, Reason}, Msg, _StateIfAccept, _State) -> {error, connection_refused}. -%% @doc Drop all current subscriptions and pending messages on a clean start +%% @doc Drop all current subscriptions and buffered messages on a clean start maybe_clean_start(false, State) -> State; maybe_clean_start(true, #state{ pool = Pool } = State) -> mqtt_sessions_router:unsubscribe_pid(Pool, self()), - State#state{ pending = queue:new() }. + State#state{ + buffer = #{}, + awaiting_ack = #{}, + awaiting_rel = #{} + }. -%% @doc Handle a publish request +%% @doc Handle a publish request from remote to here packet_publish(#{ topic := Topic, qos := 0 } = Msg, #state{ runtime = Runtime, user_context = UCtx, client_id = ClientId } = State) -> case Topic of @@ -581,7 +591,7 @@ packet_publish(#{ topic := Topic, qos := 1, dup := Dup, packet_id := PacketId } packet_id => PacketId, reason_code => ?MQTT_RC_PACKET_ID_IN_USE }, - reply(PubAck, State); + reply_or_drop(PubAck, State); {ok, {pubrel, RC, _}} when Dup -> % There is a qos 2 level message with the same packet id % But the received mesage is a duplicate, just ack. @@ -590,7 +600,7 @@ packet_publish(#{ topic := Topic, qos := 1, dup := Dup, packet_id := PacketId } packet_id => PacketId, reason_code => RC }, - reply(PubAck, State); + reply_or_drop(PubAck, State); error -> RC = case Runtime:is_allowed(publish, Topic, Msg, UCtx) of true -> @@ -606,7 +616,7 @@ packet_publish(#{ topic := Topic, qos := 1, dup := Dup, packet_id := PacketId } packet_id => PacketId, reason_code => RC }, - State1 = reply(PubAck, State), + State1 = reply_or_drop(PubAck, State), {ok, State1} end; packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } = Msg, @@ -618,14 +628,14 @@ packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } packet_id => PacketId, reason_code => ?MQTT_RC_PACKET_ID_IN_USE }, - reply(PubRec, State); + reply_or_drop(PubRec, State); {ok, {pubrel, RC, _}} when Dup -> PubRec = #{ type => pubrec, packet_id => PacketId, reason_code => RC }, - State1 = reply(PubRec, State), + State1 = reply_or_drop(PubRec, State), {ok, State1}; error -> RC = case Runtime:is_allowed(publish, Topic, Msg, UCtx) of @@ -639,7 +649,9 @@ packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } end, State1 = if RC < 16#80 -> - State#state{ awaiting_rel = WaitRel#{ PacketId => {pubrel, RC, mqtt_sessions_timestamp:timestamp()} } }; + State#state{ + awaiting_rel = WaitRel#{ PacketId => {pubrel, RC, mqtt_sessions_timestamp:timestamp()} } + }; true -> State end, @@ -648,7 +660,7 @@ packet_publish(#{ topic := Topic, qos := 2, dup := Dup, packet_id := PacketId } packet_id => PacketId, reason_code => RC }, - State2 = reply(PubRec, State1), + State2 = reply_or_drop(PubRec, State1), {ok, State2} end. @@ -662,7 +674,7 @@ packet_pubrel(#{ packet_id := PacketId, reason_code := ?MQTT_RC_SUCCESS }, #stat reason_code => ?MQTT_RC_SUCCESS }, WaitRel1 = maps:remove(PacketId, WaitRel), - State1 = reply(PubComp, State), + State1 = reply_or_drop(PubComp, State), {ok, State1#state{ awaiting_rel = WaitRel1 }}; error -> PubComp = #{ @@ -670,7 +682,7 @@ packet_pubrel(#{ packet_id := PacketId, reason_code := ?MQTT_RC_SUCCESS }, #stat packet_id => PacketId, reason_code => ?MQTT_RC_PACKET_ID_NOT_FOUND }, - State1 = reply(PubComp, State), + State1 = reply_or_drop(PubComp, State), {ok, State1} end; packet_pubrel(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_rel = WaitRel } = State) -> @@ -688,9 +700,11 @@ packet_pubrel(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_re %% @doc Handle puback for QoS 1 publish messages sent to the client packet_puback(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) -> WaitAck1 = case maps:find(PacketId, WaitAck) of - {ok, {_MsgNr, puback, _Msg}} -> + {ok, #wait_for{ is_sent = false }} -> + WaitAck; + {ok, #wait_for{ type = puback }} -> maps:remove(PacketId, WaitAck); - {ok, {_MsgNr, Wait, Msg}} -> + {ok, #wait_for{ type = Wait, message = Msg }} -> ?LOG_WARNING(#{ in => mqtt_sessions, text => <<"PUBACK for message wating for something else - dropping pending ack">>, @@ -708,11 +722,13 @@ packet_puback(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = Sta %% @doc Handle pubrec for QoS 2 publish messages sent to the client packet_pubrec(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_ack = WaitAck } = State) when RC >= 16#80 -> WaitAck1 = case maps:find(PacketId, WaitAck) of - {ok, {_MsgNr, pubrec, _Msg}} -> + {ok, #wait_for{ is_sent = false }} -> + WaitAck; + {ok, #wait_for{ type = pubrec }} -> maps:remove(PacketId, WaitAck); - {ok, {_MsgNr, pubcomp, _Msg}} -> + {ok, #wait_for{ type = pubcomp }} -> maps:remove(PacketId, WaitAck); - {ok, {_MsgNr, Wait, Msg}} -> + {ok, #wait_for{ type = Wait, message = Msg }} -> ?LOG_WARNING(#{ in => mqtt_sessions, text => <<"PUBREC for message wating for something else - dropping pending ack">>, @@ -728,11 +744,18 @@ packet_pubrec(#{ packet_id := PacketId, reason_code := RC }, #state{ awaiting_ac {ok, State#state{ awaiting_ack = WaitAck1 }}; packet_pubrec(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) -> {WaitAck1, RC} = case maps:find(PacketId, WaitAck) of - {ok, {MsgNr, pubrec, _Msg}} -> - {WaitAck#{ PacketId => {MsgNr, pubcomp, undefined} }, ?MQTT_RC_SUCCESS}; - {ok, {_MsgNr, pubcomp, _Msg}} -> + {ok, #wait_for{ msg_nr = MsgNr, type = pubrec }} -> + WaitFor = #wait_for{ + msg_nr = MsgNr, + type = pubcomp, + queued = mqtt_sessions_timestamp:timestamp() + }, + {WaitAck#{ PacketId => WaitFor }, ?MQTT_RC_SUCCESS}; + {ok, #wait_for{ type = pubcomp }} -> + {WaitAck, ?MQTT_RC_SUCCESS}; + {ok, #wait_for{ is_sent = false }} -> {WaitAck, ?MQTT_RC_SUCCESS}; - {ok, {_MsgNr, Wait, Msg}} -> + {ok, #wait_for{ type = Wait, message = Msg }} -> ?LOG_WARNING(#{ in => mqtt_sessions, text => <<"PUBREC for message wating for something else - dropping pending ack">>, @@ -751,14 +774,16 @@ packet_pubrec(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = Sta packet_id => PacketId, reason_code => RC }, - {ok, reply(PubRel, State1)}. + {ok, reply_or_drop(PubRel, State1)}. %% @doc Handle pubcomp for QoS 2 publish messages sent to the client packet_pubcomp(#{ packet_id := PacketId }, #state{ awaiting_ack = WaitAck } = State) -> WaitAck1 = case maps:find(PacketId, WaitAck) of - {ok, {_MsgNr, pubcomp, _Msg}} -> + {ok, #wait_for{ type = pubcomp }} -> maps:remove(PacketId, WaitAck); - {ok, {_MsgNr, Wait, Msg}} -> + {ok, #wait_for{ is_sent = false }} -> + WaitAck; + {ok, #wait_for{ type = Wait, message = Msg }} -> ?LOG_WARNING(#{ in => mqtt_sessions, text => <<"PUBCOMP for message wating for something else - dropping pending ack">>, @@ -805,7 +830,7 @@ packet_subscribe(#{ topics := Topics } = Msg, #state{ runtime = Runtime, user_co packet_id => maps:get(packet_id, Msg, 0), acks => Resp }, - State1 = reply(SubAck, State), + State1 = reply_or_drop(SubAck, State), {ok, State1}. %% @doc Handle the unsubscribe request @@ -823,7 +848,7 @@ packet_unsubscribe(#{ topics := Topics } = Msg, State) -> packet_id => maps:get(packet_id, Msg, 0), acks => Resp }, - State1 = reply(UnsubAck, State), + State1 = reply_or_drop(UnsubAck, State), {ok, State1}. @@ -861,10 +886,11 @@ relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) -> StatePurged = maybe_purge(State), case QoS of 0 -> - reply(Msg2#{ packet_id => 0 }, StatePurged); + State2 = #state{ msg_nr = MsgNr } = inc_msg_nr(StatePurged), + reply_or_queue(Msg2#{ packet_id => 0 }, MsgNr, State2); _ -> case maps:size(StatePurged#state.awaiting_ack) >= ?MAX_INFLIGHT_ACK of - true -> + true when State#state.transport =/= undefined -> ?LOG_INFO(#{ in => mqtt_session, text => <<"Not accepting QoS 1/2 message, too many inflight or queued acks">>, @@ -872,6 +898,9 @@ relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) -> reason => buffer_full }), StatePurged; + true -> + % Dormant session, just drop excess messages. + StatePurged; false -> State1 = #state{ packet_id = PacketId } = inc_packet_id(StatePurged), State2 = #state{ msg_nr = MsgNr } = inc_msg_nr(State1), @@ -882,10 +911,24 @@ relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) -> Msg3 = Msg2#{ packet_id => PacketId }, - State3 = State2#state{ - awaiting_ack = (State2#state.awaiting_ack)#{ PacketId => {MsgNr, AckRec, Msg3} } + {IsSent, State3} = if + State2#state.transport =:= undefined -> + {false, State2}; + true -> + {true, reply_or_drop(Msg3, State2)} + end, + WaitFor = #wait_for{ + msg_nr = MsgNr, + message = Msg3, + type = AckRec, + is_sent = IsSent, + queued = mqtt_sessions_timestamp:timestamp() }, - reply(Msg3, State3) + State3#state{ + awaiting_ack = (State3#state.awaiting_ack)#{ + PacketId => WaitFor + } + } end end. @@ -894,37 +937,62 @@ relay_publish(#{ type := publish, message := Msg } = MqttMsg, State) -> % ------------------------------- queue functions --------------------------------------- % --------------------------------------------------------------------------------------- -cleanup_pending_qos0(#state{ pending = Pending } = State) -> - Pending1 = queue:filter(fun(#queued{ qos = QoS }) -> QoS > 0 end, Pending), - State#state{ pending = Pending1 }. +delete_buffered_qos0(#state{ buffer = Buffer } = State) -> + Buffer1 = maps:filter(fun(_MsgNr, #queued{ qos = QoS }) -> QoS > 0 end, Buffer), + State#state{ buffer = Buffer1 }. -resend_unacknowledged(#state{ awaiting_ack = AwaitAck } = State) -> - Msgs = maps:fold( +resend_buffered_and_unacknowledged(#state{ awaiting_ack = AwaitAck, buffer = Buffer } = State) -> + ResendMap = maps:fold( fun - (_PacketId, {MsgNr, pubrec, Msg}, Acc) -> - [ {MsgNr, Msg#{ dup => true }} | Acc ]; - (PacketId, {MsgNr, pubcomp, _Msg}, Acc) -> + (_PacketId, #wait_for{ is_sent = false, msg_nr = MsgNr, message = Msg }, Acc) -> + % Unsent QoS 1 or 2 message + Acc#{ MsgNr => Msg }; + (_PacketId, #wait_for{ msg_nr = MsgNr, type = puback, message = Msg }, Acc) -> + Acc#{ MsgNr => Msg#{ dup => true } }; + (_PacketId, #wait_for{ msg_nr = MsgNr, type = pubrec, message = Msg }, Acc) -> + Acc#{ MsgNr => Msg#{ dup => true } }; + (PacketId, #wait_for{ msg_nr = MsgNr, type = pubcomp }, Acc) -> PubComp = #{ type => pubrec, packet_id => PacketId }, - [ {MsgNr, PubComp} | Acc ]; - (_PacketId, {MsgNr, suback, Msg}, Acc) -> - [ {MsgNr, Msg} | Acc ]; - (_PacketId, {MsgNr, unsuback, Msg}, Acc) -> - [ {MsgNr, Msg} | Acc ]; + Acc#{ MsgNr => PubComp }; + (_PacketId, #wait_for{ msg_nr = MsgNr, type = suback, message = Msg }, Acc) -> + Acc#{ MsgNr => Msg }; + (_PacketId, #wait_for{ msg_nr = MsgNr, type = unsuback, message = Msg }, Acc) -> + Acc#{ MsgNr => Msg }; (_PacketId, _, Acc) -> Acc end, - [], + Buffer, AwaitAck), + ResendList = lists:sort(maps:to_list(ResendMap)), lists:foldl( - fun({_Nr, Msg}, StateAcc) -> - reply(Msg, StateAcc) + fun + (_Msg, #state{ transport = undefined } = AccState) -> + AccState; + ({_MsgNr, #{ type := publish, packet_id := PacketId } = Msg}, AccState) -> + AccState1 = mark_packet_sent(PacketId, AccState), + reply_or_drop(Msg, AccState1); + ({_MsgNr, #{ type := _} = Msg}, AccState) -> + reply_or_drop(Msg, AccState); + ({MsgNr, #queued{ message = Msg } = Q}, AccState) -> + case reply_or_drop(Msg, AccState) of + #state{ transport = undefined, buffer = AccBuffer } = AccState1 -> + AccBuffer1 = AccBuffer#{ MsgNr => Q }, + AccState1#state{ buffer = AccBuffer1 }; + #state{} = AccState1 -> + AccState1 + end end, - State, - lists:sort(Msgs)). + State#state{ buffer = #{} }, + ResendList). +mark_packet_sent(PacketId, #state{ awaiting_ack = AwaitAck } = State) -> + WaitFor = maps:get(PacketId, AwaitAck), + State#state{ + awaiting_ack = AwaitAck#{ PacketId => WaitFor#wait_for{ is_sent = true } } + }. % --------------------------------------------------------------------------------------- % -------------------------------- misc functions --------------------------------------- @@ -938,11 +1006,10 @@ do_disconnected(#state{ will_pid = WillPid } = State) -> %% @todo Cleanup pending messages and awaiting states. cleanup_state_disconnected(State) -> - cleanup_pending_qos0(State#state{ - pending_connack = undefined, + delete_buffered_qos0(State#state{ connection_pid = undefined, transport = undefined, - awaiting_rel = #{} + is_connected = false }). @@ -959,9 +1026,9 @@ reply_connack(#{ type := connack, reason_code := ?MQTT_RC_SUCCESS } = ConnAck, S <<"cotonic-routing-id">> => State#state.routing_id } }, - reply(ConnAck1, State); + reply_to_transport(ConnAck1, State); reply_connack(#{ type := connack } = ConnAck, State) -> - reply(ConnAck, State). + reply_to_transport(ConnAck, State). %% @doc Check the connect packet, extract the will as a map for the will-watchdog. @@ -986,8 +1053,16 @@ extract_will(#{ type := connect, will_flag := true, properties := Props } = Msg) retain => maps:get(will_retain, Msg, false) }. +force_disconnect(#state{ connection_pid = undefined, transport = undefined } = State) -> + State; force_disconnect(State) -> State1 = disconnect_transport(State), + if + is_pid(State#state.connection_pid) -> + State#state.connection_pid ! {mqtt_transport, self(), disconnect}; + true -> + ok + end, State2 = cleanup_state_disconnected(State1), case State2#state.is_session_present of false -> @@ -1001,21 +1076,43 @@ disconnect_transport(#state{ transport = undefined } = State) -> State; disconnect_transport(#state{ transport = Transport } = State) when is_pid(Transport) -> Transport ! {mqtt_transport, self(), disconnect}, - State#state{ transport = undefined }; + State#state{ transport = undefined, is_connected = false }; disconnect_transport(#state{ transport = Transport } = State) when is_function(Transport) -> Transport(disconnect), - State#state{ transport = undefined }. + State#state{ transport = undefined, is_connected = false }; +disconnect_transport(#state{ transport = {M, F, A} } = State) -> + erlang:apply(M, F, [disconnect | A]), + State#state{ transport = undefined, is_connected = false }. + +reply_to_transport(_Msg, #state{ transport = undefined } = State) -> + State; +reply_to_transport(Msg, State) -> + case send_transport(Msg, State) of + ok -> + State; + {error, _} -> + force_disconnect(State) + end. -reply(undefined, State) -> +reply_or_drop(_Msg, #state{ is_connected = false } = State) -> State; -reply(Msg, #state{ transport = undefined } = State) -> - maybe_purge( queue(Msg, State) ); -reply(Msg, State) -> +reply_or_drop(Msg, State) -> + case send_transport(Msg, State) of + ok -> + State; + {error, _} -> + force_disconnect(State) + end. + +reply_or_queue(Msg, MsgNr, #state{ is_connected = false } = State) -> + maybe_purge( queue(Msg, MsgNr, State) ); +reply_or_queue(Msg, MsgNr, State) -> case send_transport(Msg, State) of ok -> State; {error, _} -> - maybe_purge( queue(Msg, State#state{ transport = undefined }) ) + State1 = force_disconnect(State), + maybe_purge( queue(Msg, MsgNr, State1) ) end. send_transport(_Msg, #state{ transport = undefined }) -> @@ -1028,7 +1125,7 @@ send_transport(Msg, #state{ transport = Pid }) when is_pid(Pid) -> Pid ! {mqtt_transport, self(), Msg}, ok; false -> - ok + {error, transport_down} end; send_transport(Msg, #state{ transport = Fun }) when is_function(Fun) -> Fun(Msg); @@ -1037,14 +1134,7 @@ send_transport(Msg, #state{ transport = {M, F, A} }) -> %% @doc Queue a message, extract, type, message expiry, and QoS -queue(#{ type := connack } = Msg, State) -> - State#state{ pending_connack = Msg }; -queue(#{ type := auth } = Msg, State) -> - State#state{ pending_connack = Msg }; -queue(Msg, State) -> - queue_1(Msg, inc_msg_nr(State)). - -queue_1(#{ type := Type } = Msg, #state{ msg_nr = MsgNr, pending = Pending } = State) -> +queue(#{ type := Type } = Msg, MsgNr, #state{ buffer = Buffer } = State) -> Props = maps:get(properties, Msg, #{}), Now = mqtt_sessions_timestamp:timestamp(), Item = #queued{ @@ -1056,79 +1146,75 @@ queue_1(#{ type := Type } = Msg, #state{ msg_nr = MsgNr, pending = Pending } = S qos = maps:get(qos, Msg, 1), message = Msg }, - State#state{ pending = queue:in(Item, Pending) }. + State#state{ buffer = Buffer#{ MsgNr => Item } }. -maybe_purge(#state{ pending = Queue, awaiting_ack = WaitAcks } = State) -> - case queue:len(Queue) > ?MAX_QUEUED orelse maps:size(WaitAcks) > ?MAX_INFLIGHT_ACK of - true -> - PacketIdsBefore = queue:fold( - fun - (#queued{ qos = 0 }, Acc) -> Acc; - (#queued{ packet_id = PacketId }, Acc) -> [ PacketId | Acc ] - end, - [], - Queue), - PurgedQueue = purge(Queue), - PacketIdsAfter = queue:fold( - fun - (#queued{ qos = 0 }, Acc) -> Acc; - (#queued{ packet_id = PacketId }, Acc) -> [ PacketId | Acc ] - end, - [], - PurgedQueue), - PurgedPacketIds = PacketIdsBefore -- PacketIdsAfter, - PurgedWaitAcks = maps:without(PurgedPacketIds, WaitAcks), - ?LOG_DEBUG(#{ - in => mqtt_sessions, - text => <<"Purged pending messages">>, - result => ok, - pending_before => queue:len(Queue), - pending_after => queue:len(PurgedQueue), - dropped_acks => length(PurgedPacketIds), - pending_acks => maps:size(PurgedWaitAcks) - }), - State#state{ - pending = PurgedQueue, - awaiting_ack = PurgedWaitAcks - }; - false -> - State - end. +maybe_purge(#state{ buffer = Buffer, awaiting_ack = WaitAcks } = State) -> + State#state{ + buffer = maybe_purge_buffer(maps:size(Buffer), Buffer), + awaiting_ack = maybe_purge_ack(WaitAcks) + }. -purge(Queue) -> - {value, #queued{ queued = Oldest }} = queue:peek(Queue), - {value, #queued{ queued = Newest }} = queue:peek_r(Queue), - PurgeTime = mqtt_sessions_timestamp:timestamp(), - QoS0PurgeAge = (Newest - Oldest) div 2, - Queue1 = queue:filter( +maybe_purge_ack(WaitAcks) -> + Now = mqtt_sessions_timestamp:timestamp(), + maps:filter( fun - (#queued{ qos = 0, queued = Queued, expiry = Expiry }) -> - PurgeTime < Expiry andalso PurgeTime < (Queued + QoS0PurgeAge); - (#queued{ expiry = Expiry }) -> - PurgeTime < Expiry + (_, #wait_for{ is_sent = true, message = Msg, queued = Queued }) -> + Props = maps:get(properties, Msg, #{}), + Expiry = Queued + maps:get(message_expiry_interval, Props, ?MESSAGE_EXPIRY_DEFAULT), + Expiry > Now; + (_, #wait_for{ queued = Queued }) -> + Queued + ?ACK_EXPIRY > Now + end, + WaitAcks). + +maybe_purge_buffer(Size, Buffer) when Size > ?MAX_BUFFERED-> + purge_buffer(Buffer); +maybe_purge_buffer(_Size, Buffer) -> + Buffer. + +purge_buffer(Buffer) -> + Now = mqtt_sessions_timestamp:timestamp(), + {Oldest, Newest} = maps:fold( + fun(_MsgNr, #queued{ queued = Time }, {OAcc, NAcc}) -> + OAcc1 = min(OAcc, Time), + NAcc1 = max(NAcc, Time), + {OAcc1, NAcc1} end, - Queue), - case queue:len(Queue1) > ?MAX_QUEUED of + {Now, 0}, + Buffer), + Buffer1 = if + Oldest =< Newest -> + QoS0PurgeAge = (Newest - Oldest) div 2, + maps:filter( + fun + (_MsgNr, #queued{ qos = 0, queued = Queued, expiry = Expiry }) -> + Now < Expiry andalso Now < (Queued + QoS0PurgeAge); + (_MsgNr, #queued{ expiry = Expiry }) -> + Now < Expiry + end, + Buffer); + true -> + Buffer + end, + case maps:size(Buffer1) > ?MAX_BUFFERED of true -> % Drop all QoS 0 messages - queue:filter(fun(#queued{ qos = QoS }) -> QoS > 0 end, Queue1); + maps:filter(fun(_, #queued{ qos = QoS }) -> QoS > 0 end, Buffer1); false -> - Queue1 + Buffer1 end. -spec encode( mqtt_packet_map:mqtt_version(), mqtt_packet_map:mqtt_packet() | list( mqtt_packet_map:mqtt_packet() )) -> binary(). encode(ProtocolVersion, Msg) when is_map(Msg) -> {ok, Bin} = mqtt_packet_map:encode(ProtocolVersion, Msg), - Bin; -encode(ProtocolVersion, Ms) when is_list(Ms) -> - iolist_to_binary([ encode(ProtocolVersion, M) || M <- Ms ]). + Bin. %% @doc Set the new connection, disconnect existing transport. set_connection(#{ connection_pid := ConnectionPid, transport := Transport }, State) -> case State#state.connection_pid of ConnectionPid -> - State; + State#state{ transport = Transport }; undefined -> set_connection_1(ConnectionPid, Transport, State); OldConnectionPid -> @@ -1148,7 +1234,7 @@ start_keep_alive(#state{ keep_alive = N } = State) -> erlang:send_after(N * 500, self(), {keep_alive, Ref}), State#state{ keep_alive_counter = 3, keep_alive_ref = Ref }. -%% @doc Increment the message number, this number is used for order of resent messages +%% @doc Increment the message number, this number is used for order of resending buffered messages inc_msg_nr(#state{ msg_nr = Nr } = State) -> State#state{ msg_nr = Nr + 1 }. diff --git a/test/mqtt_sessions_protocol_SUITE.erl b/test/mqtt_sessions_protocol_SUITE.erl index ce351ae..bff1c1d 100644 --- a/test/mqtt_sessions_protocol_SUITE.erl +++ b/test/mqtt_sessions_protocol_SUITE.erl @@ -9,7 +9,7 @@ -dialyzer({nowarn_function, connect_disconnect_v5_test/1}). -dialyzer({nowarn_function, connect_reconnect_v5_test/1}). -dialyzer({nowarn_function, connect_reconnect_clean_v5_test/1}). - +-dialyzer({nowarn_function, connect_reconnect_buffered_v5_test/1}). %%-------------------------------------------------------------------- %% COMMON TEST CALLBACK FUNCTIONS @@ -32,7 +32,8 @@ all() -> [ connect_disconnect_v5_test, connect_reconnect_v5_test, - connect_reconnect_clean_v5_test + connect_reconnect_clean_v5_test, + connect_reconnect_buffered_v5_test ]. %%-------------------------------------------------------------------- @@ -330,3 +331,202 @@ connect_reconnect_clean_v5_test(_Config) -> ok end, ok. + +connect_reconnect_buffered_v5_test(_Config) -> + Connect = #{ + type => connect, + protocol_name => <<"MQTT">>, + protocol_version => 5, + clean_start => true, + client_id => <<"test4">>, + will_flag => false, + username => <<>>, + password => <<>>, + properties => #{ + } + }, + {ok, ConnectMsg} = mqtt_packet_map:encode(5, Connect), + Options = #{ + transport => self() + }, + {ok, {SessionPid, <<>>}} = mqtt_sessions:incoming_connect(ConnectMsg, Options), + true = is_pid(SessionPid), + receive + {mqtt_transport, SessionPid, MsgBin} when is_binary(MsgBin) -> + {ok, {ConnAck, <<>>}} = mqtt_packet_map:decode(MsgBin), + #{ + type := connack, + session_present := false, + reason_code := 0 + } = ConnAck, + ok + end, + % Subscribe to a topic + Subscribe = #{ + type => subscribe, + topics => [ #{ topic => <<"reconnect_v5_test">>, qos => 2 } ] + }, + {ok, SubMsg} = mqtt_packet_map:encode(5, Subscribe), + mqtt_sessions:incoming_data(SessionPid, SubMsg), + receive + {mqtt_transport, SessionPid, SubAckMsgBin} when is_binary(SubAckMsgBin) -> + {ok, {SubAck, <<>>}} = mqtt_packet_map:decode(SubAckMsgBin), + #{ + type := suback, + acks := [ {ok, 2} ] + } = SubAck, + ok + end, + Disconnect = #{ + type => disconnect, + reason_code => 0, + properties => #{ + session_expiry_interval => 3600 + } + }, + {ok, DisconnectMsg} = mqtt_packet_map:encode(5, Disconnect), + ok = mqtt_sessions:incoming_data(SessionPid, DisconnectMsg), + % + % The connection should be signaled to disconnect + % + receive + {mqtt_transport, SessionPid, disconnect} -> + ok + end, + % + % And the session process should still be running + % + timer:sleep(100), + true = erlang:is_process_alive(SessionPid), + % + % Message should be queued + % + PubMsg1 = #{ + type => publish, + topic => <<"reconnect_v5_test">>, + payload => <<"hello-offline-qos-0">>, + qos => 0 + }, + ok = mqtt_sessions:publish(PubMsg1, undefined), + + PubMsg2 = #{ + type => publish, + topic => <<"reconnect_v5_test">>, + payload => <<"hello-offline-qos-1">>, + qos => 1 + }, + ok = mqtt_sessions:publish(PubMsg2, undefined), + + PubMsg3 = #{ + type => publish, + topic => <<"reconnect_v5_test">>, + payload => <<"hello-offline-qos-2">>, + qos => 2 + }, + ok = mqtt_sessions:publish(PubMsg3, undefined), + % + % Reconnect without clean_start + % + Reconnect = #{ + type => connect, + protocol_name => <<"MQTT">>, + protocol_version => 5, + clean_start => false, + client_id => <<"test4">>, + will_flag => false, + username => <<>>, + password => <<>>, + properties => #{ + } + }, + {ok, ReconnectMsg} = mqtt_packet_map:encode(5, Reconnect), + {ok, {SessionPid, <<>>}} = mqtt_sessions:incoming_connect(ReconnectMsg, Options), + % + % We have reconnected to the existing session, check the connack + % for the 'session_present' flag. + % + receive + {mqtt_transport, SessionPid, MsgBin2} when is_binary(MsgBin2) -> + {ok, {ConnAck2, <<>>}} = mqtt_packet_map:decode(MsgBin2), + #{ + type := connack, + session_present := true, + reason_code := 0 + } = ConnAck2, + ok + end, + % + % Should receive queued QoS 0 message + % + receive + {mqtt_transport, SessionPid, PubMsg1Bin} when is_binary(PubMsg1Bin) -> + {ok, {PubMsg1Received, <<>>}} = mqtt_packet_map:decode(PubMsg1Bin), + #{ + type := publish, + payload := <<"hello-offline-qos-0">>, + qos := 0, + packet_id := undefined + } = PubMsg1Received, + ok + after 10 -> + ct:fail(unsubscribed) + end, + % + % Should receive queued QoS 1 message + % + PacketId2 = receive + {mqtt_transport, SessionPid, PubMsg2Bin} when is_binary(PubMsg2Bin) -> + {ok, {PubMsg2Received, <<>>}} = mqtt_packet_map:decode(PubMsg2Bin), + #{ + type := publish, + qos := 1, + payload := <<"hello-offline-qos-1">>, + packet_id := PId2 + } = PubMsg2Received, + PId2 + after 10 -> + ct:fail(unsubscribed) + end, + Ack2 = #{ + type => puback, + packet_id => PacketId2 + }, + {ok, Ack2Data} = mqtt_packet_map:encode(5, Ack2), + mqtt_sessions_process:incoming_data(SessionPid, Ack2Data), + % + % Should receive queued QoS 2 message + % + PacketId3 = receive + {mqtt_transport, SessionPid, PubMsg3Bin} when is_binary(PubMsg3Bin) -> + {ok, {PubMsg3Received, <<>>}} = mqtt_packet_map:decode(PubMsg3Bin), + #{ + type := publish, + qos := 2, + payload := <<"hello-offline-qos-2">>, + packet_id := PId3 + } = PubMsg3Received, + PId3 + after 10 -> + ct:fail(unsubscribed) + end, + % Acknowledge with pubrel + Rel3 = #{ + type => pubrel, + packet_id => PacketId3 + }, + {ok, Rel3Data} = mqtt_packet_map:encode(5, Rel3), + mqtt_sessions_process:incoming_data(SessionPid, Rel3Data), + % Should receive a pubcomp back + receive + {mqtt_transport, SessionPid, PubMsg4Bin} when is_binary(PubMsg4Bin) -> + {ok, {PubMsg4Received, <<>>}} = mqtt_packet_map:decode(PubMsg4Bin), + #{ + type := pubcomp, + packet_id := PacketId3 + } = PubMsg4Received, + ok + after 10 -> + ct:fail(pubcomp) + end, + + ok.