diff --git a/lib/gen_stage.ex b/lib/gen_stage.ex index eccd0b1..f629375 100644 --- a/lib/gen_stage.ex +++ b/lib/gen_stage.ex @@ -858,8 +858,9 @@ defmodule GenStage do end The returned tuple may also contain 3 or 4 elements. The third - element may be the `:hibernate` atom or a set of options defined - below. + element may be a set of options defined below. The fourth element + is a timeout, the `:hibernate` atom or a `:continue` tuple. See + the return values for `c:GenServer.init/1` for more information. Returning `:ignore` will cause `start_link/3` to return `:ignore` and the process will exit normally without entering the loop or @@ -910,10 +911,14 @@ defmodule GenStage do @callback init(args :: term) :: {:producer, state} | {:producer, state, [producer_option]} + | {:producer, state, [producer_option], timeout() | {:continue, term} | :hibernate} | {:producer_consumer, state} | {:producer_consumer, state, [producer_consumer_option]} + | {:producer_consumer, state, [producer_consumer_option], + timeout() | {:continue, term} | :hibernate} | {:consumer, state} | {:consumer, state, [consumer_option]} + | {:consumer, state, [consumer_option], timeout() | {:continue, term} | :hibernate} | :ignore | {:stop, reason :: any} when state: any @@ -995,7 +1000,9 @@ defmodule GenStage do """ @callback handle_demand(demand :: pos_integer, state :: term) :: {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when new_state: term, reason: term, event: term @@ -1074,7 +1081,9 @@ defmodule GenStage do state :: term ) :: {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when event: term, new_state: term, reason: term @@ -1087,7 +1096,9 @@ defmodule GenStage do """ @callback handle_events(events :: [event], from, state :: term) :: {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when new_state: term, reason: term, event: term @@ -1128,9 +1139,13 @@ defmodule GenStage do """ @callback handle_call(request :: term, from :: GenServer.from(), state :: term) :: {:reply, reply, [event], new_state} + | {:reply, reply, [event], new_state, timeout()} | {:reply, reply, [event], new_state, :hibernate} + | {:reply, reply, [event], new_state, {:continue, term}} | {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, reply, new_state} | {:stop, reason, new_state} when reply: term, new_state: term, reason: term, event: term @@ -1145,12 +1160,21 @@ defmodule GenStage do the loop with new state `new_state`. Only `:producer` and `:producer_consumer` stages can return a non-empty list of events. + Returning `{:noreply, [event], state, timeout}` is similar to `{:noreply, state}` + , except that it also sets a timeout. See the "Timeouts" section in the + `GenServer` documentation for more information. + Returning `{:noreply, [event], new_state, :hibernate}` is similar to `{:noreply, new_state}` except the process is hibernated before continuing the loop. See the return values for `c:GenServer.handle_call/3` for more information on hibernation. Only `:producer` and `:producer_consumer` stages can return a non-empty list of events. + Returning `{:noreply, [event], new_state, {:continue, continue_arg}}` is similar + to `{:noreply, new_state}` except that immediately after entering the loop, the + `c:handle_continue/2` callback will be invoked with `continue_arg` as the first + argument and `state` as the second one. + Returning `{:stop, reason, new_state}` stops the loop and `terminate/2` is called with the reason `reason` and state `new_state`. The process exits with reason `reason`. @@ -1160,7 +1184,9 @@ defmodule GenStage do """ @callback handle_cast(request :: term, state :: term) :: {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason :: term, new_state} when new_state: term, event: term @@ -1180,7 +1206,32 @@ defmodule GenStage do """ @callback handle_info(message :: term, state :: term) :: {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} + | {:stop, reason :: term, new_state} + when new_state: term, event: term + + @doc """ + Invoked to handle `:continue` instructions. + + This callback can be used to perform work right after emitting events from + other callbacks. The "continue mechanism" makes sure that no messages, + calls, casts, or anything else will be handled between a callback emitting + a `:continue` tuple and the `c:handle_continue/2` callback being invoked. + + Return values are the same as `c:handle_cast/2`. + + This callback is optional. If one is not implemented, the server will fail + if a continue instruction is used. + + This callback is only supported on Erlang/OTP 21+. + """ + @callback handle_continue(continue :: term, state :: term) :: + {:noreply, [event], new_state} + | {:noreply, [event], new_state, timeout()} + | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason :: term, new_state} when new_state: term, event: term @@ -1217,6 +1268,7 @@ defmodule GenStage do format_status: 2, handle_call: 3, handle_cast: 2, + handle_continue: 2, handle_info: 2, terminate: 2 ] @@ -1820,17 +1872,26 @@ defmodule GenStage do {:producer, state, opts} when is_list(opts) -> init_producer(mod, opts, state) + {:producer, state, opts, additional_info} when is_list(opts) -> + init_producer(mod, opts, state, additional_info) + {:producer_consumer, state} -> - init_producer_consumer(mod, [], state) + init_producer_consumer(mod, [], state, nil) {:producer_consumer, state, opts} when is_list(opts) -> - init_producer_consumer(mod, opts, state) + init_producer_consumer(mod, opts, state, nil) + + {:producer_consumer, state, opts, additional_info} when is_list(opts) -> + init_producer_consumer(mod, opts, state, additional_info) {:consumer, state} -> - init_consumer(mod, [], state) + init_consumer(mod, [], state, nil) {:consumer, state, opts} when is_list(opts) -> - init_consumer(mod, opts, state) + init_consumer(mod, opts, state, nil) + + {:consumer, state, opts, additional_info} when is_list(opts) -> + init_consumer(mod, opts, state, additional_info) {:stop, _} = stop -> stop @@ -1862,13 +1923,18 @@ defmodule GenStage do dispatcher_mod: dispatcher_mod, dispatcher_state: dispatcher_state } - {:ok, stage} else {:error, message} -> {:stop, {:bad_opts, message}} end end + defp init_producer(mod, opts, state, additional_info) do + with {:ok, stage} <- init_producer(mod, opts, state) do + {:ok, stage, additional_info} + end + end + defp init_dispatcher(opts) do case Keyword.pop(opts, :dispatcher, GenStage.DemandDispatcher) do {dispatcher, opts} when is_atom(dispatcher) -> @@ -1885,7 +1951,7 @@ defmodule GenStage do end end - defp init_producer_consumer(mod, opts, state) do + defp init_producer_consumer(mod, opts, state, additional_info) do with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts), {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []), {:ok, buffer_size, opts} <- @@ -1904,22 +1970,76 @@ defmodule GenStage do dispatcher_state: dispatcher_state } - consumer_init_subscribe(subscribe_to, stage) + case handle_gen_server_init_args(additional_info, stage) do + {:ok, stage} -> + consumer_init_subscribe(subscribe_to, stage) + + {:ok, stage, args} -> + {:ok, stage} = consumer_init_subscribe(subscribe_to, stage) + {:ok, stage, args} + + {:stop, _, _} = error -> + error + end else {:error, message} -> {:stop, {:bad_opts, message}} end end - defp init_consumer(mod, opts, state) do + defp init_consumer(mod, opts, state, additional_info) do with {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []), :ok <- Utils.validate_no_opts(opts) do stage = %GenStage{mod: mod, state: state, type: :consumer} - consumer_init_subscribe(subscribe_to, stage) + + case handle_gen_server_init_args(additional_info, stage) do + {:ok, stage} -> + consumer_init_subscribe(subscribe_to, stage) + + {:ok, stage, args} -> + {:ok, stage} = consumer_init_subscribe(subscribe_to, stage) + {:ok, stage, args} + + {:stop, _, _} = error -> + error + end else {:error, message} -> {:stop, {:bad_opts, message}} end end + defp handle_gen_server_init_args({:continue, _term} = continue, stage) do + case handle_continue(continue, stage) do + {:noreply, stage} -> + {:ok, stage} + + {:noreply, stage, :hibernate} -> + {:ok, stage, :hibernate} + + {:noreply, stage, {:continue, _term} = continue} -> + {:ok, stage, continue} + + {:noreply, stage, timeout} -> + {:ok, stage, timeout} + + {:stop, reason, stage} -> + {:stop, reason, stage} + end + end + + defp handle_gen_server_init_args(:hibernate, stage), do: {:ok, stage, :hibernate} + + defp handle_gen_server_init_args(timeout, stage) + when (is_integer(timeout) and timeout >= 0) or timeout == :infinity, + do: {:ok, stage, timeout} + + defp handle_gen_server_init_args(nil, stage), do: {:ok, stage} + + @doc false + + def handle_continue(continue, %{state: state} = stage) do + noreply_callback(:handle_continue, [continue, state], stage) + end + @doc false def handle_call({:"$info", msg}, _from, stage) do @@ -1948,6 +2068,14 @@ defmodule GenStage do stage = dispatch_events(events, length(events), %{stage | state: state}) {:reply, reply, stage, :hibernate} + {:reply, reply, events, state, {:continue, _term} = continue} -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:reply, reply, stage, continue} + + {:reply, reply, events, state, timeout} -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:reply, reply, stage, timeout} + {:stop, reason, reply, state} -> {:stop, reason, reply, %{stage | state: state}} @@ -2092,7 +2220,7 @@ defmodule GenStage do case producers do %{^ref => entry} -> {batches, stage} = consumer_receive(from, entry, events, stage) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) _ -> msg = {:"$gen_producer", {self(), ref}, {:cancel, :unknown_subscription}} @@ -2219,6 +2347,14 @@ defmodule GenStage do end end + defp noreply_callback(:handle_continue, [continue, state], %{mod: mod} = stage) do + if function_exported?(mod, :handle_continue, 2) do + handle_noreply_callback(mod.handle_continue(continue, state), stage) + else + :error_handler.raise_undef_exception(mod, :handle_continue, [continue, state]) + end + end + defp noreply_callback(callback, args, %{mod: mod} = stage) do handle_noreply_callback(apply(mod, callback, args), stage) end @@ -2233,6 +2369,14 @@ defmodule GenStage do stage = dispatch_events(events, length(events), %{stage | state: state}) {:noreply, stage, :hibernate} + {:noreply, events, state, {:continue, _term} = continue} when is_list(events) -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:noreply, stage, continue} + + {:noreply, events, state, timeout} when is_list(events) -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:noreply, stage, timeout} + {:stop, reason, state} -> {:stop, reason, %{stage | state: state}} @@ -2364,6 +2508,9 @@ defmodule GenStage do # main module must know the consumer is no longer subscribed. dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage) + {:noreply, %{dispatcher_state: dispatcher_state} = stage, _hibernate_or_continue} -> + dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage) + {:stop, _, _} = stop -> stop end @@ -2574,17 +2721,27 @@ defmodule GenStage do {[{events, 0}], stage} end - defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _hibernate?) do + defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _gen_opts) do case mod.handle_events(batch, from, state) do {:noreply, events, state} when is_list(events) -> stage = dispatch_events(events, length(events), stage) ask(from, ask, [:noconnect]) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) - {:noreply, events, state, :hibernate} when is_list(events) -> + {:noreply, events, state, :hibernate} -> stage = dispatch_events(events, length(events), stage) ask(from, ask, [:noconnect]) - consumer_dispatch(batches, from, mod, state, stage, true) + consumer_dispatch(batches, from, mod, state, stage, :hibernate) + + {:noreply, events, state, {:continue, _} = continue} -> + stage = dispatch_events(events, length(events), stage) + ask(from, ask, [:noconnect]) + consumer_dispatch(batches, from, mod, state, stage, continue) + + {:noreply, events, state, timeout} -> + stage = dispatch_events(events, length(events), stage) + ask(from, ask, [:noconnect]) + consumer_dispatch(batches, from, mod, state, stage, timeout) {:stop, reason, state} -> {:stop, reason, %{stage | state: state}} @@ -2594,12 +2751,12 @@ defmodule GenStage do end end - defp consumer_dispatch([], _from, _mod, state, stage, false) do + defp consumer_dispatch([], _from, _mod, state, stage, nil) do {:noreply, %{stage | state: state}} end - defp consumer_dispatch([], _from, _mod, state, stage, true) do - {:noreply, %{stage | state: state}, :hibernate} + defp consumer_dispatch([], _from, _mod, state, stage, gen_opts) do + {:noreply, %{stage | state: state}, gen_opts} end defp consumer_subscribe({to, opts}, stage) when is_list(opts), @@ -2738,11 +2895,11 @@ defmodule GenStage do {producer_id, _, _} = entry from = {producer_id, ref} {batches, stage} = consumer_receive(from, entry, events, stage) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) %{} -> # We queued but producer was removed - consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, false) + consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, nil) end end @@ -2759,6 +2916,9 @@ defmodule GenStage do {:noreply, stage, :hibernate} -> take_pc_events(queue, counter, stage) + {:noreply, stage, {:continue, _term}} -> + take_pc_events(queue, counter, stage) + {:stop, _, _} = stop -> stop end @@ -2771,6 +2931,9 @@ defmodule GenStage do {:noreply, %{events: {queue, counter}} = stage, :hibernate} -> take_pc_events(queue, counter, stage) + {:noreply, %{events: {queue, counter}} = stage, {:continue, _term}} -> + take_pc_events(queue, counter, stage) + {:stop, _, _} = stop -> stop end diff --git a/test/gen_stage_test.exs b/test/gen_stage_test.exs index 0a59b74..17bf015 100644 --- a/test/gen_stage_test.exs +++ b/test/gen_stage_test.exs @@ -81,6 +81,109 @@ defmodule GenStageTest do events = Enum.to_list(counter..(counter + demand - 1)) {:noreply, events, counter + demand} end + + # Use continue instructions to modify the counter, this + # can be reached from any gen_server callback by supplying + # a continue instruction with an integer term. + def handle_continue(new_counter, _counter) when is_integer(new_counter) do + {:noreply, [], new_counter} + end + end + + defmodule CounterNestedContinue do + @moduledoc """ + A producer that works as a counter in batches. + It also supports events to be queued via sync + and async calls. A negative counter disables + the counting behaviour. + + This counter uses a nested handle_continue on init. + """ + + use GenStage + + def start_link(init, opts \\ []) do + GenStage.start_link(__MODULE__, init, opts) + end + + def sync_queue(stage, events) do + GenStage.call(stage, {:queue, events}) + end + + def async_queue(stage, events) do + GenStage.cast(stage, {:queue, events}) + end + + def stop(stage) do + GenStage.call(stage, :stop) + end + + ## Callbacks + + def init(init) do + init + end + + def handle_call(:stop, _from, state) do + {:stop, :shutdown, :ok, state} + end + + def handle_call({:early_reply_queue, events}, from, state) do + GenStage.reply(from, state) + {:noreply, events, state} + end + + def handle_call({:queue, events}, _from, state) do + {:reply, state, events, state} + end + + def handle_cast({:queue, events}, state) do + {:noreply, events, state} + end + + def handle_info({:queue, events}, state) do + {:noreply, events, state} + end + + def handle_info(other, state) do + is_pid(state) && send(state, other) + {:noreply, [], state} + end + + def handle_subscribe(:consumer, opts, from, state) do + is_pid(state) && send(state, {:producer_subscribed, from}) + {Keyword.get(opts, :producer_demand, :automatic), state} + end + + def handle_cancel(reason, from, state) do + is_pid(state) && send(state, {:producer_cancelled, from, reason}) + {:noreply, [], state} + end + + def handle_demand(demand, pid) when is_pid(pid) and demand > 0 do + {:noreply, [], pid} + end + + def handle_demand(demand, counter) when demand > 0 do + # If the counter is 3 and we ask for 2 items, we will + # emit the items 3 and 4, and set the state to 5. + events = Enum.to_list(counter..(counter + demand - 1)) + {:noreply, events, counter + demand} + end + + # Use continue instructions to modify the counter, this + # can be reached from any gen_server callback by supplying + # a continue instruction with an integer term. + # + # This particular handle_continue returns another continue instruction + # testing that we handle nested continues properly. + def handle_continue(500, _counter) do + {:noreply, [], 500, {:continue, 2000}} + end + + def handle_continue(2000, _counter) do + {:noreply, [], 2000} + end end defmodule DemandProducer do @@ -139,6 +242,11 @@ defmodule GenStageTest do is_pid(state) && send(state, other) {:noreply, [], state} end + + def handle_continue({:continue, term}, recipient) do + send(recipient, term) + {:noreply, [], recipient} + end end defmodule Postponer do @@ -258,6 +366,11 @@ defmodule GenStageTest do {:noreply, [], recipient} end + def handle_continue({:continue, term}, recipient) do + send(recipient, term) + {:noreply, [], recipient} + end + def terminate(reason, state) do send(state, {:terminated, reason}) end @@ -324,6 +437,97 @@ defmodule GenStageTest do } end + {otp_version, ""} = :otp_release |> :erlang.system_info() |> to_string() |> Integer.parse() + + if otp_version >= 21 do + describe "handle_continue tests" do + test "producing_init with continue instruction setting counter start position" do + {:ok, producer} = Counter.start_link({:producer, 0, [], {:continue, 500}}) + {:ok, _} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]}) + + batch = Enum.to_list(0..499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(500..999) + assert_receive {:consumed, ^batch} + batch = Enum.to_list(1000..1499) + assert_receive {:consumed, ^batch} + end + + test "producer_init with nested continue instruction setting counter start position" do + {:ok, producer} = CounterNestedContinue.start_link({:producer, 0, [], {:continue, 500}}) + {:ok, _} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]}) + + # The nested continue sets the counter to 2000 + batch = Enum.to_list(0..499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(500..999) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(1000..1499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(1500..1999) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(2000..2499) + assert_receive {:consumed, ^batch} + end + + test "consumer_init with continue instruction" do + {:ok, producer} = Counter.start_link({:producer, 0, [], {:continue, 500}}) + + {:ok, _} = + Forwarder.start_link( + {:consumer, self(), [subscribe_to: [producer]], {:continue, :continue_reached}} + ) + + assert_receive :continue_reached + end + + test "producer_consumer with continue instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, _doubler} = + Doubler.start_link( + {:producer_consumer, self(), + [subscribe_to: [{producer, max_demand: 100, min_demand: 80}]], + {:continue, :continue_reached}} + ) + + assert_receive :continue_reached + end + end + end + + describe "hibernate tests" do + test "producer_init with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0, [], :hibernate}) + + assert :erlang.process_info(producer, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + + test "consumer_init with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, consumer} = + Forwarder.start_link({:consumer, self(), [subscribe_to: [producer]], :hibernate}) + + assert :erlang.process_info(consumer, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + + test "producer_consumer with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, doubler} = + Doubler.start_link( + {:producer_consumer, self(), + [subscribe_to: [{producer, max_demand: 100, min_demand: 80}]], :hibernate} + ) + + assert :erlang.process_info(doubler, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + end + describe "producer-to-consumer demand" do test "with default max and min demand" do {:ok, producer} = Counter.start_link({:producer, 0})