diff --git a/lib/mariaex/connection/ssl.ex b/lib/mariaex/connection/ssl.ex index 2055bec..cf2e1b7 100644 --- a/lib/mariaex/connection/ssl.ex +++ b/lib/mariaex/connection/ssl.ex @@ -1,31 +1,8 @@ defmodule Mariaex.Connection.Ssl do - def recv(sock, bytes, timeout), do: :ssl.recv(sock, bytes, timeout) - def recv_active(sock, timeout, buffer \\ :active_once) do - receive do - {:ssl, ^sock, buffer} -> - {:ok, buffer} - {:ssl_closed, ^sock} -> - {:disconnect, {tag(), "async_recv", :closed, buffer}} - {:ssl_error, ^sock, reason} -> - {:disconnect, {tag(), "async_recv", reason, buffer}} - after - timeout -> - {:ok, <<>>} - end - end - def tag(), do: :ssl - def fake_message(sock, buffer), do: {:ssl, sock, buffer} - - def receive(_sock, {:ssl, _, blob}), do: blob - - def setopts({:sslsocket, {:gen_tcp, sock, :tls_connection, _},_pid}, opts) do - :inet.setopts(sock, opts) - end - def send(sock, data), do: :ssl.send(sock, data) def close(sock), do: :ssl.close(sock) diff --git a/lib/mariaex/connection/tcp.ex b/lib/mariaex/connection/tcp.ex index 789e479..781d24c 100644 --- a/lib/mariaex/connection/tcp.ex +++ b/lib/mariaex/connection/tcp.ex @@ -19,28 +19,8 @@ defmodule Mariaex.Connection.Tcp do def recv(sock, bytes, timeout), do: :gen_tcp.recv(sock, bytes, timeout) - def recv_active(sock, timeout, buffer \\ :active_once) do - receive do - {:tcp, ^sock, buffer} -> - {:ok, buffer} - {:tcp_closed, ^sock} -> - {:disconnect, {tag(), "async_recv", :closed, buffer}} - {:tcp_error, ^sock, reason} -> - {:disconnect, {tag(), "async_recv", reason, buffer}} - after - timeout -> - {:ok, <<>>} - end - end - def tag(), do: :tcp - def fake_message(sock, buffer), do: {:tcp, sock, buffer} - - def receive(_sock, {:tcp, _, blob}), do: blob - - def setopts(sock, opts), do: :inet.setopts(sock, opts) - def send(sock, data), do: :gen_tcp.send(sock, data) def close(sock), do: :gen_tcp.close(sock) diff --git a/lib/mariaex/protocol.ex b/lib/mariaex/protocol.ex index 7b24c62..2d21eff 100644 --- a/lib/mariaex/protocol.ex +++ b/lib/mariaex/protocol.ex @@ -240,7 +240,7 @@ defmodule Mariaex.Protocol do {:error, error, _} -> {:error, error} {:ok, _, _, state} -> - activate(state, state.buffer) |> connected() + {:ok, %{state | buffer: state.buffer, state: :running}} end end defp handle_handshake(packet(seqnum: seqnum, msg: auth_switch(plugin: plugin, salt: salt) = _packet), nil, state = %{opts: opts}) do @@ -295,63 +295,29 @@ defmodule Mariaex.Protocol do """ def disconnect(_, state = %{sock: {sock_mod, sock}}) do msg_send(text_cmd(command: com_quit(), statement: ""), state, 0) - case msg_recv(state) do - {:ok, packet(msg: ok_resp()), _state} -> - sock_mod.close(sock) - {:ok, packet(msg: _), _state} -> - sock_mod.close(sock) - {:error, _} -> - sock_mod.close(sock) - end - _ = sock_mod.recv_active(sock, 0, "") + sock_mod.close(sock) :ok end @doc """ DBConnection callback """ - def checkout(%{buffer: :active_once, sock: {sock_mod, sock}} = s) do - case setopts(s, [active: :false], :active_once) do - :ok -> sock_mod.recv_active(sock, 0, "") |> handle_recv_buffer(s) - {:disconnect, _, _} = dis -> dis + def checkout(%{sock: {sock_mod, sock}} = s) do + case sock_mod.recv(sock, 0, 0) do + {:error, :timeout} -> {:ok, %{s | buffer: <<>>}} + {:error, description} -> do_disconnect(s, {sock_mod, "recv", description, <<>>}) + {:ok, _} -> {:ok, %{s | buffer: <<>>}} end end - defp handle_recv_buffer({:ok, buffer}, s) do - {:ok, %{s | buffer: buffer}} - end - defp handle_recv_buffer({:disconnect, description}, s) do - do_disconnect(s, description) - end @doc """ DBConnection callback """ def checkin(%{buffer: buffer} = s) when is_binary(buffer) do - activate(s, buffer) + {:ok, %{s | buffer: <<>>}} end - ## Fake [active: once] if buffer not empty - defp activate(s, <<>>) do - case setopts(s, [active: :once], <<>>) do - :ok -> {:ok, %{s | buffer: :active_once, state: :running}} - other -> other - end - end - defp activate(%{sock: {mod, sock}} = s, buffer) do - msg = mod.fake_message(sock, buffer) - send(self(), msg) - {:ok, %{s | buffer: :active_once, state: :running}} - end - - defp setopts(%{sock: {mod, sock}} = s, opts, buffer) do - case mod.setopts(sock, opts) do - :ok -> - :ok - {:error, reason} -> - do_disconnect(s, {mod, "setopts", reason, buffer}) - end - end @doc """ DBConnection callback @@ -1236,7 +1202,6 @@ defmodule Mariaex.Protocol do def msg_decode(<< len :: size(24)-little-integer, _seqnum :: size(8)-integer, message :: binary>>=header, state) when byte_size(message) >= len do - {packet, rest} = decode(header, state.state) {:ok, packet, %{state | buffer: rest}} end diff --git a/test/start_test.exs b/test/start_test.exs index 9d26445..60618cf 100644 --- a/test/start_test.exs +++ b/test/start_test.exs @@ -38,10 +38,14 @@ defmodule StartTest do backoff_type: :stop] Process.flag :trap_exit, true - assert {:error, {%Mariaex.Error{message: "failed to upgraded socket: {:tls_alert, 'unknown ca'}"}, _}} = - Mariaex.Connection.start_link(test_opts) - assert {:error, {%Mariaex.Error{message: "failed to upgraded socket: {:options, {:cacertfile, []}}"}, _}} = - Mariaex.Connection.start_link(Keyword.put(test_opts, :ssl_opts, Keyword.drop(test_opts[:ssl_opts], [:cacertfile]))) + + {:error, %Mariaex.Error{message: message}} = Mariaex.Protocol.connect(test_opts) + assert message = "failed to upgraded socket: {:tls_alert, {:unknown_ca, 'received CLIENT ALERT: Fatal - Unknown CA'}}" + + {:error, %Mariaex.Error{message: message}} = Mariaex.Protocol.connect( + Keyword.put(test_opts, :ssl_opts, Keyword.drop(test_opts[:ssl_opts], [:cacertfile])) + ) + assert message = "failed to upgraded socket: {:options, {:cacertfile, []}}" end @tag :socket diff --git a/test/test_helper.exs b/test/test_helper.exs index cd2f97b..83e3238 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -67,6 +67,7 @@ cmds = [ ~s(mysql #{mysql_connect} -e "CREATE DATABASE mariaex_test DEFAULT CHARACTER SET 'utf8' COLLATE 'utf8_general_ci';"), ~s(mysql #{mysql_connect} -e "#{create_user} 'mariaex_user'@'%' IDENTIFIED BY 'mariaex_pass';"), ~s(mysql #{mysql_connect} -e "GRANT ALL ON *.* TO 'mariaex_user'@'%' WITH GRANT OPTION"), + ~s(mysql #{mysql_connect} -e "FLUSH PRIVILEGES"), ~s(mysql --host=#{mysql_host} --port=#{mysql_port} --protocol=tcp -u mariaex_user -pmariaex_pass mariaex_test -e "#{ sql }")