From ca243210db91df366195d992e1d3c72b0adcfcf9 Mon Sep 17 00:00:00 2001 From: Janusz Jakubiec Date: Mon, 3 Jun 2024 07:55:27 +0200 Subject: [PATCH 1/4] Adding support for stream_management to escalus_ws --- src/escalus_tcp.erl | 8 +- src/escalus_ws.erl | 198 ++++++++++++++++++++++++++++++++------------ 2 files changed, 148 insertions(+), 58 deletions(-) mode change 100644 => 100755 src/escalus_tcp.erl mode change 100644 => 100755 src/escalus_ws.erl diff --git a/src/escalus_tcp.erl b/src/escalus_tcp.erl old mode 100644 new mode 100755 index 3e35426..a655ad8 --- a/src/escalus_tcp.erl +++ b/src/escalus_tcp.erl @@ -427,7 +427,7 @@ separate_ack_requests({true, H0, inactive}, Stanzas) -> Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)], Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)], - case {length(Enabled),length(Resumed)} of + case {length(Enabled), length(Resumed)} of %% Enabled SM: set the H param to 0 and activate counter. {1,0} -> {{true, 0, active}, [], Stanzas}; @@ -450,12 +450,12 @@ separate_ack_requests({true, H0, active}, Stanzas) -> make_ack(H) -> {escalus_stanza:sm_ack(H), H}. -reply_to_ack_requests({false,H,A}, _, _) -> {false, H, A}; -reply_to_ack_requests({true,H,inactive}, _, _) -> {true, H, inactive}; +reply_to_ack_requests({false, H, A}, _, _) -> {false, H, A}; +reply_to_ack_requests({true, H, inactive}, _, _) -> {true, H, inactive}; reply_to_ack_requests({true, H0, active}, Acks, State) -> {true, % TODO: Maybe compress here? - lists:foldl(fun({Ack,H}, _) -> raw_send(exml:to_iolist(Ack), State), H end, + lists:foldl(fun({Ack, H}, _) -> raw_send(exml:to_iolist(Ack), State), H end, H0, Acks), active}. diff --git a/src/escalus_ws.erl b/src/escalus_ws.erl old mode 100644 new mode 100755 index 08694a3..8fc4fd0 --- a/src/escalus_ws.erl +++ b/src/escalus_ws.erl @@ -16,6 +16,8 @@ send/2, is_connected/1, reset_parser/1, + get_sm_h/1, + set_sm_h/2, use_zlib/1, upgrade_to_tls/2, set_filter_predicate/2, @@ -40,9 +42,24 @@ -define(SERVER, ?MODULE). -record(state, {owner, socket, parser, legacy_ws, compress = false, - event_client, filter_pred, stream_ref}). + event_client, sm_state, filter_pred, stream_ref, sent_stanzas = []}). -type state() :: #state{}. +-type sm_state() :: {boolean(), non_neg_integer(), 'active'|'inactive'}. + +-type opts() :: #{ + host => string(), + port => pos_integer(), + wspath => string(), + wslegacy => boolean(), + event_client => undefined | escalus_event:event_client(), + ssl => boolean(), + ssl_opts => [ssl:ssl_option()], + ws_upgrade_timeout => pos_integer(), + stream_management => boolean(), + manual_ack => boolean() +}. + %%%=================================================================== %%% API %%%=================================================================== @@ -60,6 +77,15 @@ send(Pid, Elem) -> is_connected(Pid) -> erlang:is_process_alive(Pid). +-spec get_sm_h(pid()) -> non_neg_integer(). +get_sm_h(Pid) -> + gen_server:call(Pid, get_sm_h). + +-spec set_sm_h(pid(), non_neg_integer()) -> {ok, non_neg_integer()}. +set_sm_h(Pid, H) -> + gen_server:call(Pid, {set_sm_h, H}). + + -spec reset_parser(pid()) -> ok. reset_parser(Pid) -> gen_server:cast(Pid, reset_parser). @@ -155,31 +181,58 @@ assert_stream_end(StreamEndRep, Props) -> %%% gen_server callbacks %%%=================================================================== -%% TODO: refactor all opt defaults taken from Args into a default_opts function, -%% so that we know what options the module actually expects +default_options() -> + #{host => "localhost", + port => 5280, + wspath => "/ws-xmpp", + wslegacy => false, + event_client => undefined, + ssl => false, + ssl_opts => [], + ws_upgrade_timeout => 5000, + stream_management => false, + manual_ack => false}. + +-spec get_stream_management_opt(opts()) -> sm_state(). +get_stream_management_opt(#{stream_management := false}) -> + {false, 0, inactive}; +get_stream_management_opt(#{manual_ack := true}) -> + {false, 0, inactive}; +get_stream_management_opt(#{stream_management := true, manual_ack := false}) -> + {true, 0, inactive}. + +overwrite_default_opts(GivenOpts, DefaultOpts) -> + maps:merge(DefaultOpts, GivenOpts). + +do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) -> + TransportOpts = #{transport => tls, protocols => [http], + tls_opts => SSLOpts}, + do_connect(Opts, TransportOpts); +do_connect(Opts) -> + do_connect(Opts, #{transport => tcp, protocols => [http]}). + +do_connect(#{host := Host, port := Port}, TransportOpts) -> + Host1 = maybe_binary_to_list(Host), + {ok, ConnPid} = gun:open(Host1, Port, TransportOpts), + {ok, http} = gun:await_up(ConnPid), + ConnPid. + -spec init(list()) -> {ok, state()}. -init([Args, Owner]) -> - Host = get_host(Args, "localhost"), - Port = get_port(Args, 5280), - Resource = get_resource(Args, "/ws-xmpp"), - LegacyWS = get_legacy_ws(Args, false), - EventClient = proplists:get_value(event_client, Args), - SSL = proplists:get_value(ssl, Args, false), - SSLOpts = proplists:get_value(ssl_opts, Args, []), +init([Opts, Owner]) -> + Opts1 = overwrite_default_opts(maps:from_list(Opts), default_options()), + #{wspath := Resource, + wslegacy := LegacyWS, + event_client := EventClient, + ws_upgrade_timeout := Timeout} = Opts1, + SM = get_stream_management_opt(Opts1), + Resource1 = maybe_binary_to_list(Resource), + + ConnPid = do_connect(Opts1), + %% Disable http2 in protocols - TransportOpts = case SSL of - true -> - #{transport => tls, protocols => [http], - tls_opts => SSLOpts}; - _ -> - #{transport => tcp, protocols => [http]} - end, - {ok, ConnPid} = gun:open(Host, Port, TransportOpts), - {ok, http} = gun:await_up(ConnPid), WSUpgradeHeaders = [{<<"sec-websocket-protocol">>, <<"xmpp">>}], - StreamRef = gun:ws_upgrade(ConnPid, Resource, WSUpgradeHeaders, + StreamRef = gun:ws_upgrade(ConnPid, Resource1, WSUpgradeHeaders, #{protocols => [{<<"xmpp">>, gun_ws_h}]}), - Timeout = get_option(ws_upgrade_timeout, Args, 5000), wait_for_ws_upgrade(ConnPid, StreamRef, Timeout), ParserOpts = case LegacyWS of true -> []; @@ -190,6 +243,7 @@ init([Args, Owner]) -> socket = ConnPid, parser = Parser, legacy_ws = LegacyWS, + sm_state = SM, event_client = EventClient, stream_ref = StreamRef}}. @@ -208,6 +262,11 @@ wait_for_ws_upgrade(ConnPid, StreamRef, Timeout) -> -spec handle_call(term(), {pid(), term()}, state()) -> {reply, term(), state()} | {stop, normal, ok, state()}. +handle_call(get_sm_h, _From, #state{sm_state = {_, H, _}} = State) -> + {reply, H, State}; +handle_call({set_sm_h, H}, _From, #state{sm_state = {A, _OldH, S}} = State) -> + NewState = State#state{sm_state={A, H, S}}, + {reply, {ok, H}, NewState}; handle_call(use_zlib, _, #state{parser = Parser} = State) -> Zin = zlib:open(), Zout = zlib:open(), @@ -270,48 +329,81 @@ handle_data(Data, State = #state{parser = Parser, Decompressed = iolist_to_binary(zlib:inflate(Zin, Data)), exml_stream:parse(Parser, Decompressed) end, - NewState = State#state{parser = NewParser}, - escalus_connection:maybe_forward_to_owner(NewState#state.filter_pred, - NewState, + FwdState = State#state{parser = NewParser, sent_stanzas = []}, + escalus_connection:maybe_forward_to_owner(FwdState#state.filter_pred, + FwdState, Stanzas, - fun forward_to_owner/3, Timestamp), - case lists:filter(fun(Stanza) -> is_stream_end(Stanza, State) end, Stanzas) of - [] -> {noreply, NewState}; - _ -> {stop, normal, NewState} - end. + fun forward_to_owner/3, Timestamp). -spec is_stream_end(exml_stream:element(), state()) -> boolean(). is_stream_end(#xmlstreamend{}, #state{legacy_ws = true}) -> true; is_stream_end(#xmlel{name = <<"close">>}, #state{legacy_ws = false}) -> true; is_stream_end(_, _) -> false. -forward_to_owner(Stanzas, #state{owner = Owner, - event_client = EventClient}, Timestamp) -> +forward_to_owner(Stanzas0, #state{owner = Owner, + sm_state = SM0, + event_client = EventClient} = State, Timestamp) -> + {SM1, AckRequests, StanzasNoRs} = separate_ack_requests(SM0, Stanzas0), + reply_to_ack_requests(SM1, AckRequests, State), + lists:foreach(fun(Stanza) -> - escalus_event:incoming_stanza(EventClient, Stanza), - Owner ! escalus_connection:stanza_msg(Stanza, - #{recv_timestamp => Timestamp}) - end, Stanzas). + escalus_event:incoming_stanza(EventClient, Stanza), + Owner ! escalus_connection:stanza_msg(Stanza, #{recv_timestamp => Timestamp}) + end, StanzasNoRs), + + case lists:keyfind(xmlstreamend, 1, StanzasNoRs) of + false -> ok; + _ -> gen_server:cast(self(), stop) + end, + + {noreply, State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}}. + +separate_ack_requests({false, H0, A}, Stanzas) -> + %% Don't keep track of H + {{false, H0, A}, [], Stanzas}; +separate_ack_requests({true, H0, inactive}, Stanzas) -> + Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)], + Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)], + + case {length(Enabled), length(Resumed)} of + %% Enabled SM: set the H param to 0 and activate counter. + {1,0} -> {{true, 0, active}, [], Stanzas}; + + %% Resumed SM: keep the H param and activate counter. + {0,1} -> {{true, H0, active}, [], Stanzas}; + + %% No new SM state: continue as usual + {0,0} -> {{true, H0, inactive}, [], Stanzas} + end; +separate_ack_requests({true, H0, active}, Stanzas) -> + %% Count H and construct appropriate acks + F = fun(Stanza, {H, Acks, NonAckRequests}) -> + case escalus_pred:is_sm_ack_request(Stanza) of + true -> {H, [make_ack(H)|Acks], NonAckRequests}; + false -> {H+1, Acks, [Stanza|NonAckRequests]} + end + end, + {H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas), + {{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}. + +make_ack(H) -> {escalus_stanza:sm_ack(H), H}. + +reply_to_ack_requests({false, H, A}, _, _) -> + {false, H, A}; +reply_to_ack_requests({true, H, inactive}, _, _) -> + {true, H, inactive}; +reply_to_ack_requests({true, H0, active}, Acks, State) -> + {true, + lists:foldl(fun({Ack, H}, _) -> + Ack1 = exml:to_iolist(Ack), + gun:ws_send(State#state.socket, State#state.stream_ref, {text, Ack1}), + H + end, H0, Acks), + active}. common_terminate(_Reason, #state{parser = Parser}) -> exml_stream:free_parser(Parser). --spec get_port(list(), inet:port_number()) -> inet:port_number(). -get_port(Args, Default) -> - get_option(port, Args, Default). - --spec get_host(list(), string()) -> string(). -get_host(Args, Default) -> - maybe_binary_to_list(get_option(host, Args, Default)). - --spec get_resource(list(), string()) -> string(). -get_resource(Args, Default) -> - maybe_binary_to_list(get_option(wspath, Args, Default)). - --spec get_legacy_ws(list(), boolean()) -> boolean(). -get_legacy_ws(Args, Default) -> - get_option(wslegacy, Args, Default). - -spec maybe_binary_to_list(binary() | string()) -> string(). maybe_binary_to_list(B) when is_binary(B) -> binary_to_list(B); maybe_binary_to_list(S) when is_list(S) -> S. @@ -336,5 +428,3 @@ close_compression_streams({zlib, {Zin, Zout}}) -> ok = zlib:close(Zin), ok = zlib:close(Zout) end. - - From 511cb33a1d9ad86e373ca8350361cd7ca7cd53d7 Mon Sep 17 00:00:00 2001 From: Janusz Jakubiec Date: Wed, 5 Jun 2024 17:57:31 +0200 Subject: [PATCH 2/4] Fixing issue with stream ending --- src/escalus_ws.erl | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/escalus_ws.erl b/src/escalus_ws.erl index 8fc4fd0..f2a04c4 100755 --- a/src/escalus_ws.erl +++ b/src/escalus_ws.erl @@ -330,10 +330,14 @@ handle_data(Data, State = #state{parser = Parser, exml_stream:parse(Parser, Decompressed) end, FwdState = State#state{parser = NewParser, sent_stanzas = []}, - escalus_connection:maybe_forward_to_owner(FwdState#state.filter_pred, - FwdState, - Stanzas, - fun forward_to_owner/3, Timestamp). + NewState = escalus_connection:maybe_forward_to_owner(FwdState#state.filter_pred, + FwdState, + Stanzas, + fun forward_to_owner/3, Timestamp), + case lists:filter(fun(Stanza) -> is_stream_end(Stanza, State) end, Stanzas) of + [] -> {noreply, NewState}; + _ -> {stop, normal, NewState} + end. -spec is_stream_end(exml_stream:element(), state()) -> boolean(). is_stream_end(#xmlstreamend{}, #state{legacy_ws = true}) -> true; @@ -351,12 +355,7 @@ forward_to_owner(Stanzas0, #state{owner = Owner, Owner ! escalus_connection:stanza_msg(Stanza, #{recv_timestamp => Timestamp}) end, StanzasNoRs), - case lists:keyfind(xmlstreamend, 1, StanzasNoRs) of - false -> ok; - _ -> gen_server:cast(self(), stop) - end, - - {noreply, State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}}. + State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}. separate_ack_requests({false, H0, A}, Stanzas) -> %% Don't keep track of H From d345b524b69430c3bdf806cd62351940566408e1 Mon Sep 17 00:00:00 2001 From: Janusz Jakubiec Date: Thu, 13 Jun 2024 11:23:16 +0200 Subject: [PATCH 3/4] Adding changes in escalus_connection --- src/escalus_connection.erl | 4 +++ src/escalus_ws.erl | 70 ++++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/escalus_connection.erl b/src/escalus_connection.erl index 891d2d4..6133cfd 100644 --- a/src/escalus_connection.erl +++ b/src/escalus_connection.erl @@ -373,12 +373,16 @@ get_stream_end(#client{rcv_pid = Pid, jid = Jid}, Timeout) -> -spec get_sm_h(client()) -> non_neg_integer(). get_sm_h(#client{module = escalus_tcp, rcv_pid = Pid}) -> escalus_tcp:get_sm_h(Pid); +get_sm_h(#client{module = escalus_ws, rcv_pid = Pid}) -> + escalus_ws:get_sm_h(Pid); get_sm_h(#client{module = Mod}) -> error({get_sm_h, {undefined_for_escalus_module, Mod}}). -spec set_sm_h(client(), non_neg_integer()) -> {ok, non_neg_integer()}. set_sm_h(#client{module = escalus_tcp, rcv_pid = Pid}, H) -> escalus_tcp:set_sm_h(Pid, H); +set_sm_h(#client{module = escalus_ws, rcv_pid = Pid}, H) -> + escalus_ws:set_sm_h(Pid, H); set_sm_h(#client{module = Mod}, _) -> error({set_sm_h, {undefined_for_escalus_module, Mod}}). diff --git a/src/escalus_ws.erl b/src/escalus_ws.erl index f2a04c4..34034c6 100755 --- a/src/escalus_ws.erl +++ b/src/escalus_ws.erl @@ -65,8 +65,9 @@ %%%=================================================================== -spec connect([proplists:property()]) -> pid(). -connect(Args) -> - {ok, Pid} = gen_server:start_link(?MODULE, [Args, self()], []), +connect(Opts0) -> + Opts1 = opts_to_map(Opts0), + {ok, Pid} = gen_server:start_link(?MODULE, [Opts1, self()], []), Pid. -spec send(pid(), exml:element()) -> ok. @@ -178,7 +179,7 @@ assert_stream_end(StreamEndRep, Props) -> end. %%%=================================================================== -%%% gen_server callbacks +%%% Default options %%%=================================================================== default_options() -> @@ -193,33 +194,13 @@ default_options() -> stream_management => false, manual_ack => false}. --spec get_stream_management_opt(opts()) -> sm_state(). -get_stream_management_opt(#{stream_management := false}) -> - {false, 0, inactive}; -get_stream_management_opt(#{manual_ack := true}) -> - {false, 0, inactive}; -get_stream_management_opt(#{stream_management := true, manual_ack := false}) -> - {true, 0, inactive}. - -overwrite_default_opts(GivenOpts, DefaultOpts) -> - maps:merge(DefaultOpts, GivenOpts). - -do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) -> - TransportOpts = #{transport => tls, protocols => [http], - tls_opts => SSLOpts}, - do_connect(Opts, TransportOpts); -do_connect(Opts) -> - do_connect(Opts, #{transport => tcp, protocols => [http]}). - -do_connect(#{host := Host, port := Port}, TransportOpts) -> - Host1 = maybe_binary_to_list(Host), - {ok, ConnPid} = gun:open(Host1, Port, TransportOpts), - {ok, http} = gun:await_up(ConnPid), - ConnPid. +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== -spec init(list()) -> {ok, state()}. init([Opts, Owner]) -> - Opts1 = overwrite_default_opts(maps:from_list(Opts), default_options()), + Opts1 = overwrite_default_opts(Opts, default_options()), #{wspath := Resource, wslegacy := LegacyWS, event_client := EventClient, @@ -318,6 +299,30 @@ code_change(_OldVsn, State, _Extra) -> %%% Helpers %%%=================================================================== +-spec get_stream_management_opt(opts()) -> sm_state(). +get_stream_management_opt(#{stream_management := false}) -> + {false, 0, inactive}; +get_stream_management_opt(#{manual_ack := true}) -> + {false, 0, inactive}; +get_stream_management_opt(#{stream_management := true, manual_ack := false}) -> + {true, 0, inactive}. + +overwrite_default_opts(GivenOpts, DefaultOpts) -> + maps:merge(DefaultOpts, GivenOpts). + +do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) -> + TransportOpts = #{transport => tls, protocols => [http], + tls_opts => SSLOpts}, + do_connect(Opts, TransportOpts); +do_connect(Opts) -> + do_connect(Opts, #{transport => tcp, protocols => [http]}). + +do_connect(#{host := Host, port := Port}, TransportOpts) -> + Host1 = maybe_binary_to_list(Host), + {ok, ConnPid} = gun:open(Host1, Port, TransportOpts), + {ok, http} = gun:await_up(ConnPid), + ConnPid. + handle_data(Data, State = #state{parser = Parser, compress = Compress}) -> Timestamp = os:system_time(micro_seconds), @@ -407,13 +412,6 @@ common_terminate(_Reason, #state{parser = Parser}) -> maybe_binary_to_list(B) when is_binary(B) -> binary_to_list(B); maybe_binary_to_list(S) when is_list(S) -> S. --spec get_option(any(), list(), any()) -> any(). -get_option(Key, Opts, Default) -> - case lists:keyfind(Key, 1, Opts) of - false -> Default; - {Key, Value} -> Value - end. - close_compression_streams(false) -> ok; close_compression_streams({zlib, {Zin, Zout}}) -> @@ -427,3 +425,7 @@ close_compression_streams({zlib, {Zin, Zout}}) -> ok = zlib:close(Zin), ok = zlib:close(Zout) end. + +-spec opts_to_map([proplists:property()] | opts()) -> opts(). +opts_to_map(Opts) when is_map(Opts) -> Opts; +opts_to_map(Opts) when is_list(Opts) -> maps:from_list(Opts). From 48421fa6d8a89574faccf0aef65f65a9e6bc67b9 Mon Sep 17 00:00:00 2001 From: Janusz Jakubiec Date: Wed, 19 Jun 2024 09:46:38 +0200 Subject: [PATCH 4/4] Fixing CR comments --- src/escalus_connection.erl | 32 +++++++++++++++++++++++++++- src/escalus_tcp.erl | 34 +----------------------------- src/escalus_ws.erl | 43 ++++++-------------------------------- 3 files changed, 38 insertions(+), 71 deletions(-) diff --git a/src/escalus_connection.erl b/src/escalus_connection.erl index 6133cfd..1b385ff 100644 --- a/src/escalus_connection.erl +++ b/src/escalus_connection.erl @@ -41,7 +41,7 @@ upgrade_to_tls/1, start_stream/1]). --export([stanza_msg/2]). +-export([stanza_msg/2, separate_ack_requests/2]). %% Behaviour helpers -export([maybe_forward_to_owner/5]). @@ -474,10 +474,40 @@ maybe_forward_to_owner(_, State, Stanzas, Fun, Timestamp) -> stanza_msg(Stanza, Metadata) -> {stanza, self(), Stanza, Metadata}. +separate_ack_requests({false, H0, A}, Stanzas) -> + %% Don't keep track of H + {{false, H0, A}, [], Stanzas}; +separate_ack_requests({true, H0, inactive}, Stanzas) -> + Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)], + Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)], + + case {length(Enabled), length(Resumed)} of + %% Enabled SM: set the H param to 0 and activate counter. + {1,0} -> {{true, 0, active}, [], Stanzas}; + + %% Resumed SM: keep the H param and activate counter. + {0,1} -> {{true, H0, active}, [], Stanzas}; + + %% No new SM state: continue as usual + {0,0} -> {{true, H0, inactive}, [], Stanzas} + end; +separate_ack_requests({true, H0, active}, Stanzas) -> + %% Count H and construct appropriate acks + F = fun(Stanza, {H, Acks, NonAckRequests}) -> + case escalus_pred:is_sm_ack_request(Stanza) of + true -> {H, [make_ack(H)|Acks], NonAckRequests}; + false -> {H+1, Acks, [Stanza|NonAckRequests]} + end + end, + {H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas), + {{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}. + %%%=================================================================== %%% Helpers %%%=================================================================== +make_ack(H) -> {escalus_stanza:sm_ack(H), H}. + get_connection_steps(UserSpec) -> case lists:keyfind(connection_steps, 1, UserSpec) of false -> default_connection_steps(); diff --git a/src/escalus_tcp.erl b/src/escalus_tcp.erl index a655ad8..c1ffd4f 100755 --- a/src/escalus_tcp.erl +++ b/src/escalus_tcp.erl @@ -400,11 +400,10 @@ handle_data(Socket, Data, #state{parser = Parser, _ -> NewState end. - forward_to_owner(Stanzas0, #state{owner = Owner, sm_state = SM0, event_client = EventClient} = State, Timestamp) -> - {SM1, AckRequests, StanzasNoRs} = separate_ack_requests(SM0, Stanzas0), + {SM1, AckRequests, StanzasNoRs} = escalus_connection:separate_ack_requests(SM0, Stanzas0), reply_to_ack_requests(SM1, AckRequests, State), lists:foreach(fun(Stanza) -> @@ -419,37 +418,6 @@ forward_to_owner(Stanzas0, #state{owner = Owner, State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}. - -separate_ack_requests({false, H0, A}, Stanzas) -> - %% Don't keep track of H - {{false, H0, A}, [], Stanzas}; -separate_ack_requests({true, H0, inactive}, Stanzas) -> - Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)], - Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)], - - case {length(Enabled), length(Resumed)} of - %% Enabled SM: set the H param to 0 and activate counter. - {1,0} -> {{true, 0, active}, [], Stanzas}; - - %% Resumed SM: keep the H param and activate counter. - {0,1} -> {{true, H0, active}, [], Stanzas}; - - %% No new SM state: continue as usual - {0,0} -> {{true, H0, inactive}, [], Stanzas} - end; -separate_ack_requests({true, H0, active}, Stanzas) -> - %% Count H and construct appropriate acks - F = fun(Stanza, {H, Acks, NonAckRequests}) -> - case escalus_pred:is_sm_ack_request(Stanza) of - true -> {H, [make_ack(H)|Acks], NonAckRequests}; - false -> {H+1, Acks, [Stanza|NonAckRequests]} - end - end, - {H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas), - {{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}. - -make_ack(H) -> {escalus_stanza:sm_ack(H), H}. - reply_to_ack_requests({false, H, A}, _, _) -> {false, H, A}; reply_to_ack_requests({true, H, inactive}, _, _) -> {true, H, inactive}; reply_to_ack_requests({true, H0, active}, Acks, State) -> diff --git a/src/escalus_ws.erl b/src/escalus_ws.erl index 34034c6..4823d7e 100755 --- a/src/escalus_ws.erl +++ b/src/escalus_ws.erl @@ -86,7 +86,6 @@ get_sm_h(Pid) -> set_sm_h(Pid, H) -> gen_server:call(Pid, {set_sm_h, H}). - -spec reset_parser(pid()) -> ok. reset_parser(Pid) -> gen_server:cast(Pid, reset_parser). @@ -311,8 +310,8 @@ overwrite_default_opts(GivenOpts, DefaultOpts) -> maps:merge(DefaultOpts, GivenOpts). do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) -> - TransportOpts = #{transport => tls, protocols => [http], - tls_opts => SSLOpts}, + TransportOpts = #{transport => tls, protocols => [http], + tls_opts => SSLOpts}, do_connect(Opts, TransportOpts); do_connect(Opts) -> do_connect(Opts, #{transport => tcp, protocols => [http]}). @@ -350,9 +349,9 @@ is_stream_end(#xmlel{name = <<"close">>}, #state{legacy_ws = false}) -> true; is_stream_end(_, _) -> false. forward_to_owner(Stanzas0, #state{owner = Owner, - sm_state = SM0, - event_client = EventClient} = State, Timestamp) -> - {SM1, AckRequests, StanzasNoRs} = separate_ack_requests(SM0, Stanzas0), + sm_state = SM0, + event_client = EventClient} = State, Timestamp) -> + {SM1, AckRequests, StanzasNoRs} = escalus_connection:separate_ack_requests(SM0, Stanzas0), reply_to_ack_requests(SM1, AckRequests, State), lists:foreach(fun(Stanza) -> @@ -362,36 +361,6 @@ forward_to_owner(Stanzas0, #state{owner = Owner, State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}. -separate_ack_requests({false, H0, A}, Stanzas) -> - %% Don't keep track of H - {{false, H0, A}, [], Stanzas}; -separate_ack_requests({true, H0, inactive}, Stanzas) -> - Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)], - Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)], - - case {length(Enabled), length(Resumed)} of - %% Enabled SM: set the H param to 0 and activate counter. - {1,0} -> {{true, 0, active}, [], Stanzas}; - - %% Resumed SM: keep the H param and activate counter. - {0,1} -> {{true, H0, active}, [], Stanzas}; - - %% No new SM state: continue as usual - {0,0} -> {{true, H0, inactive}, [], Stanzas} - end; -separate_ack_requests({true, H0, active}, Stanzas) -> - %% Count H and construct appropriate acks - F = fun(Stanza, {H, Acks, NonAckRequests}) -> - case escalus_pred:is_sm_ack_request(Stanza) of - true -> {H, [make_ack(H)|Acks], NonAckRequests}; - false -> {H+1, Acks, [Stanza|NonAckRequests]} - end - end, - {H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas), - {{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}. - -make_ack(H) -> {escalus_stanza:sm_ack(H), H}. - reply_to_ack_requests({false, H, A}, _, _) -> {false, H, A}; reply_to_ack_requests({true, H, inactive}, _, _) -> @@ -426,6 +395,6 @@ close_compression_streams({zlib, {Zin, Zout}}) -> ok = zlib:close(Zout) end. --spec opts_to_map([proplists:property()] | opts()) -> opts(). +-spec opts_to_map(proplists:proplist() | opts()) -> opts(). opts_to_map(Opts) when is_map(Opts) -> Opts; opts_to_map(Opts) when is_list(Opts) -> maps:from_list(Opts).