diff --git a/shard.lock b/shard.lock index 942f7695d6..339b49edc2 100644 --- a/shard.lock +++ b/shard.lock @@ -16,6 +16,10 @@ shards: git: https://github.com/84codes/lz4.cr.git version: 1.0.0+git.commit.96d714f7593c66ca7425872fd26c7b1286806d3d + mqtt-protocol: + git: https://github.com/84codes/mqtt-protocol.cr.git + version: 0.2.0+git.commit.3f82ee85d029e6d0505cbe261b108e156df4e598 + systemd: git: https://github.com/84codes/systemd.cr.git version: 2.0.0 diff --git a/shard.yml b/shard.yml index 1798aacc3d..2a5ae429bf 100644 --- a/shard.yml +++ b/shard.yml @@ -32,6 +32,8 @@ dependencies: github: 84codes/systemd.cr lz4: github: 84codes/lz4.cr + mqtt-protocol: + github: 84codes/mqtt-protocol.cr development_dependencies: ameba: diff --git a/spec/mqtt_spec.cr b/spec/mqtt_spec.cr new file mode 100644 index 0000000000..c6e73d2835 --- /dev/null +++ b/spec/mqtt_spec.cr @@ -0,0 +1,51 @@ +require "spec" +require "socket" +require "./spec_helper" +require "mqtt-protocol" +require "../src/lavinmq/mqtt/connection_factory" + + +def setup_connection(s, pass) + left, right = UNIXSocket.pair + io = MQTT::Protocol::IO.new(left) + s.users.create("usr", "pass", [LavinMQ::Tag::Administrator]) + MQTT::Protocol::Connect.new("abc", false, 60u16, "usr", pass.to_slice, nil).to_io(io) + connection_factory = LavinMQ::MQTT::ConnectionFactory.new(right, + LavinMQ::ConnectionInfo.local, + s.users, + s.vhosts["/"]) + { connection_factory.start, io } +end + +describe LavinMQ do + src = "127.0.0.1" + dst = "127.0.0.1" + + it "MQTT connection should pass authentication" do + with_amqp_server do |s| + client, io = setup_connection(s, "pass") + client.should be_a(LavinMQ::MQTT::Client) + # client.close + MQTT::Protocol::Disconnect.new.to_io(io) + end + end + + it "unauthorized MQTT connection should not pass authentication" do + with_amqp_server do |s| + client, io = setup_connection(s, "pa&ss") + client.should_not be_a(LavinMQ::MQTT::Client) + # client.close + MQTT::Protocol::Disconnect.new.to_io(io) + end + end + + it "should handle a Ping" do + with_amqp_server do |s| + client, io = setup_connection(s, "pass") + client.should be_a(LavinMQ::MQTT::Client) + MQTT::Protocol::PingReq.new.to_io(io) + MQTT::Protocol::Packet.from_io(io).should be_a(MQTT::Protocol::Connack) + MQTT::Protocol::Packet.from_io(io).should be_a(MQTT::Protocol::PingResp) + end + end +end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index 31cdd8e76e..af2d574ab5 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -77,9 +77,9 @@ def with_amqp_server(tls = false, replicator = LavinMQ::Clustering::NoopServer.n ctx = OpenSSL::SSL::Context::Server.new ctx.certificate_chain = "spec/resources/server_certificate.pem" ctx.private_key = "spec/resources/server_key.pem" - spawn(name: "amqp tls listen") { s.listen_tls(tcp_server, ctx) } + spawn(name: "amqp tls listen") { s.listen_tls(tcp_server, ctx, "amqp") } else - spawn(name: "amqp tcp listen") { s.listen(tcp_server) } + spawn(name: "amqp tcp listen") { s.listen(tcp_server, "amqp") } end Fiber.yield yield s @@ -89,6 +89,16 @@ def with_amqp_server(tls = false, replicator = LavinMQ::Clustering::NoopServer.n end end +#do i need to do this? +# def with_mqtt_server(tls = false, & : LavinMQ::Server -> Nil) +# tcp_server = TCPServer.new("localhost", 0) +# s = LavinMQ::Server.new(LavinMQ::Config.instance.data_dir, replicator) +# begin +# if tls +# end + +# end + def with_http_server(&) with_amqp_server do |s| h = LavinMQ::HTTP::Server.new(s) diff --git a/src/lavinmq/config.cr b/src/lavinmq/config.cr index 3eafc64c87..bf1656e9f6 100644 --- a/src/lavinmq/config.cr +++ b/src/lavinmq/config.cr @@ -17,6 +17,8 @@ module LavinMQ property amqp_bind = "127.0.0.1" property amqp_port = 5672 property amqps_port = -1 + property mqtt_port = 1883 + property mqtt_bind = "127.0.0.1" property unix_path = "" property unix_proxy_protocol = 1_u8 # PROXY protocol version on unix domain socket connections property tcp_proxy_protocol = 0_u8 # PROXY protocol version on amqp tcp connections diff --git a/src/lavinmq/http/handler/websocket.cr b/src/lavinmq/http/handler/websocket.cr index 4a8fb131bd..b749807cb1 100644 --- a/src/lavinmq/http/handler/websocket.cr +++ b/src/lavinmq/http/handler/websocket.cr @@ -11,7 +11,7 @@ module LavinMQ Socket::IPAddress.new("127.0.0.1", 0) # Fake when UNIXAddress connection_info = ConnectionInfo.new(remote_address, local_address) io = WebSocketIO.new(ws) - spawn amqp_server.handle_connection(io, connection_info), name: "HandleWSconnection #{remote_address}" + spawn amqp_server.handle_connection(io, connection_info, "amqp"), name: "HandleWSconnection #{remote_address}" end end end diff --git a/src/lavinmq/http/http_server.cr b/src/lavinmq/http/http_server.cr index 0d0aeb8fcf..afe668aa27 100644 --- a/src/lavinmq/http/http_server.cr +++ b/src/lavinmq/http/http_server.cr @@ -22,7 +22,7 @@ module LavinMQ StaticController.new, ApiErrorHandler.new, AuthHandler.new(@amqp_server), - PrometheusController.new(@amqp_server), + # PrometheusController.new(@amqp_server), ApiDefaultsHandler.new, MainController.new(@amqp_server), DefinitionsController.new(@amqp_server), diff --git a/src/lavinmq/launcher.cr b/src/lavinmq/launcher.cr index 4e170e6859..501ca18b34 100644 --- a/src/lavinmq/launcher.cr +++ b/src/lavinmq/launcher.cr @@ -115,13 +115,13 @@ module LavinMQ private def listen if @config.amqp_port > 0 - spawn @amqp_server.listen(@config.amqp_bind, @config.amqp_port), + spawn @amqp_server.listen(@config.amqp_bind, @config.amqp_port, :amqp), name: "AMQP listening on #{@config.amqp_port}" end if @config.amqps_port > 0 if ctx = @tls_context - spawn @amqp_server.listen_tls(@config.amqp_bind, @config.amqps_port, ctx), + spawn @amqp_server.listen_tls(@config.amqp_bind, @config.amqps_port, ctx, :amqp), name: "AMQPS listening on #{@config.amqps_port}" end end @@ -131,7 +131,7 @@ module LavinMQ end unless @config.unix_path.empty? - spawn @amqp_server.listen_unix(@config.unix_path), name: "AMQP listening at #{@config.unix_path}" + spawn @amqp_server.listen_unix(@config.unix_path, :amqp), name: "AMQP listening at #{@config.unix_path}" end if @config.http_port > 0 @@ -150,6 +150,11 @@ module LavinMQ spawn(name: "HTTP listener") do @http_server.not_nil!.listen end + + if @config.mqtt_port > 0 + spawn @amqp_server.listen(@config.mqtt_bind, @config.mqtt_port, :mqtt), + name: "MQTT listening on #{@config.mqtt_port}" + end end private def dump_debug_info diff --git a/src/lavinmq/mqtt/client.cr b/src/lavinmq/mqtt/client.cr new file mode 100644 index 0000000000..333f6a6e60 --- /dev/null +++ b/src/lavinmq/mqtt/client.cr @@ -0,0 +1,96 @@ +require "openssl" +require "socket" +require "../client" +require "../error" + +module LavinMQ + module MQTT + class Client < LavinMQ::Client + include Stats + include SortableJSON + + getter vhost, channels, log, name, user + Log = ::Log.for "MQTT.client" + rate_stats({"send_oct", "recv_oct"}) + + def initialize(@socket : ::IO, + @connection_info : ConnectionInfo, + @vhost : VHost, + @user : User) + @io = MQTT::IO.new(@socket) + @lock = Mutex.new + @remote_address = @connection_info.src + @local_address = @connection_info.dst + @metadata = ::Log::Metadata.new(nil, {vhost: @vhost.name, address: @remote_address.to_s}) + @log = Logger.new(Log, @metadata) + @channels = Hash(UInt16, Client::Channel).new + @vhost.add_connection(self) + spawn read_loop + connection_name = "#{@remote_address} -> #{@local_address}" + @name = "#{@remote_address} -> #{@local_address}" + end + + private def read_loop + loop do + Log.trace { "waiting for packet" } + packet = read_and_handle_packet + # The disconnect packet has been handled and the socket has been closed. + # If we dont breakt the loop here we'll get a IO/Error on next read. + break if packet.is_a?(MQTT::Disconnect) + end + rescue ex : MQTT::Error::Connect + Log.warn { "Connect error #{ex.inspect}" } + ensure + @socket.close + @vhost.rm_connection(self) + end + + def read_and_handle_packet + packet : MQTT::Packet = MQTT::Packet.from_io(@io) + Log.info { "recv #{packet.inspect}" } + + case packet + when MQTT::Publish then pp "publish" + when MQTT::PubAck then pp "puback" + when MQTT::Subscribe then pp "subscribe" + when MQTT::Unsubscribe then pp "unsubscribe" + when MQTT::PingReq then receive_pingreq(packet) + when MQTT::Disconnect then return packet + else raise "invalid packet type for client to send" + end + packet + end + + private def send(packet) + @lock.synchronize do + packet.to_io(@io) + @socket.flush + end + # @broker.increment_bytes_sent(packet.bytesize) + # @broker.increment_messages_sent + # @broker.increment_publish_sent if packet.is_a?(MQTT::Protocol::Publish) + end + + def receive_pingreq(packet : MQTT::PingReq) + send(MQTT::PingResp.new) + end + + def details_tuple + { + vhost: @vhost.name, + user: @user.name, + protocol: "MQTT", + }.merge(stats_details) + end + + def update_rates + end + + def close(reason) + end + + def force_close + end + end + end +end diff --git a/src/lavinmq/mqtt/connection_factory.cr b/src/lavinmq/mqtt/connection_factory.cr new file mode 100644 index 0000000000..af458850f1 --- /dev/null +++ b/src/lavinmq/mqtt/connection_factory.cr @@ -0,0 +1,47 @@ +require "socket" +require "./protocol" +require "log" +require "./client" +require "../vhost" +require "../user" + +module LavinMQ + module MQTT + class ConnectionFactory + def initialize(@socket : ::IO, + @connection_info : ConnectionInfo, + @users : UserStore, + @vhost : VHost) + end + + def start + io = ::MQTT::Protocol::IO.new(@socket) + if packet = MQTT::Packet.from_io(@socket).as?(MQTT::Connect) + Log.trace { "recv #{packet.inspect}" } + if user = authenticate(io, packet, @users) + ::MQTT::Protocol::Connack.new(false, ::MQTT::Protocol::Connack::ReturnCode::Accepted).to_io(io) + io.flush + return LavinMQ::MQTT::Client.new(@socket, @connection_info, @vhost, user) + end + end + rescue ex + Log.warn { "Recieved the wrong packet" } + @socket.close + end + + def authenticate(io, packet, users) + return nil unless (username = packet.username) && (password = packet.password) + user = users[username]? + return user if user && user.password && user.password.not_nil!.verify(String.new(password)) + #probably not good to differentiate between user not found and wrong password + if user.nil? + Log.warn { "User \"#{username}\" not found" } + else + Log.warn { "Authentication failure for user \"#{username}\"" } + end + ::MQTT::Protocol::Connack.new(false, ::MQTT::Protocol::Connack::ReturnCode::NotAuthorized).to_io(io) + nil + end + end + end +end diff --git a/src/lavinmq/mqtt/protocol.cr b/src/lavinmq/mqtt/protocol.cr new file mode 100644 index 0000000000..9349f9560f --- /dev/null +++ b/src/lavinmq/mqtt/protocol.cr @@ -0,0 +1,7 @@ +require "mqtt-protocol" + +module LavinMQ + module MQTT + include ::MQTT::Protocol + end +end diff --git a/src/lavinmq/server.cr b/src/lavinmq/server.cr index ebffc78d32..659c18cf95 100644 --- a/src/lavinmq/server.cr +++ b/src/lavinmq/server.cr @@ -2,6 +2,7 @@ require "socket" require "openssl" require "systemd" require "./amqp" +require "./mqtt/protocol" require "./rough_time" require "../stdlib/*" require "./vhost_store" @@ -15,6 +16,7 @@ require "./proxy_protocol" require "./client/client" require "./client/connection_factory" require "./amqp/connection_factory" +require "./mqtt/connection_factory" require "./stats" module LavinMQ @@ -74,8 +76,8 @@ module LavinMQ Iterator(Client).chain(@vhosts.each_value.map(&.connections.each)) end - def listen(s : TCPServer) - @listeners[s] = :amqp + def listen(s : TCPServer, protocol) + @listeners[s] = :protocol Log.info { "Listening on #{s.local_address}" } loop do client = s.accept? || break @@ -85,7 +87,7 @@ module LavinMQ set_socket_options(client) set_buffer_size(client) conn_info = extract_conn_info(client) - handle_connection(client, conn_info) + handle_connection(client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting connection from #{remote_address}" } client.close rescue nil @@ -117,8 +119,8 @@ module LavinMQ end end - def listen(s : UNIXServer) - @listeners[s] = :amqp + def listen(s : UNIXServer, protocol) + @listeners[s] = :protocol Log.info { "Listening on #{s.local_address}" } loop do # do not try to use while client = s.accept? || break @@ -132,7 +134,7 @@ module LavinMQ when 2 then ProxyProtocol::V2.parse(client) else ConnectionInfo.local # TODO: use unix socket address, don't fake local end - handle_connection(client, conn_info) + handle_connection(client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting connection from #{remote_address}" } client.close rescue nil @@ -144,13 +146,13 @@ module LavinMQ @listeners.delete(s) end - def listen(bind = "::", port = 5672) + def listen(bind = "::", port = 5672, protocol = :amqp) s = TCPServer.new(bind, port) - listen(s) + listen(s, protocol) end - def listen_tls(s : TCPServer, context) - @listeners[s] = :amqps + def listen_tls(s : TCPServer, context, protocol) + @listeners[s] = :protocol Log.info { "Listening on #{s.local_address} (TLS)" } loop do # do not try to use while client = s.accept? || break @@ -165,7 +167,7 @@ module LavinMQ conn_info.ssl = true conn_info.ssl_version = ssl_client.tls_version conn_info.ssl_cipher = ssl_client.cipher - handle_connection(ssl_client, conn_info) + handle_connection(ssl_client, conn_info, protocol) rescue ex Log.warn(exception: ex) { "Error accepting TLS connection from #{remote_addr}" } client.close rescue nil @@ -177,15 +179,15 @@ module LavinMQ @listeners.delete(s) end - def listen_tls(bind, port, context) - listen_tls(TCPServer.new(bind, port), context) + def listen_tls(bind, port, context, protocol) + listen_tls(TCPServer.new(bind, port), context, protocol) end - def listen_unix(path : String) + def listen_unix(path : String, protocol) File.delete?(path) s = UNIXServer.new(path) File.chmod(path, 0o666) - listen(s) + listen(s, protocol) end def listen_clustering(bind, port) @@ -241,8 +243,17 @@ module LavinMQ end end - def handle_connection(socket, connection_info) - client = @amqp_connection_factory.start(socket, connection_info, @vhosts, @users) + def handle_connection(socket, connection_info, protocol) + case protocol + when :amqp + client = @amqp_connection_factory.start(socket, connection_info, @vhosts, @users) + when :mqtt + client = MQTT::ConnectionFactory.new(socket, connection_info, @users, @vhosts["/"]).start + else + Log.warn { "Unknown protocol '#{protocol}'" } + socket.close + end + ensure socket.close if client.nil? end diff --git a/static/js/connections.js b/static/js/connections.js index 611d4c5065..06b58c87b1 100644 --- a/static/js/connections.js +++ b/static/js/connections.js @@ -19,13 +19,14 @@ Table.renderTable('table', tableOptions, function (tr, item, all) { if (all) { const connectionLink = document.createElement('a') connectionLink.href = `connection#name=${encodeURIComponent(item.name)}` - if (item.client_properties.connection_name) { - connectionLink.appendChild(document.createElement('span')).textContent = item.name - connectionLink.appendChild(document.createElement('br')) - connectionLink.appendChild(document.createElement('small')).textContent = item.client_properties.connection_name - } else { + console.log(item) + // if (item.client_properties.connection_name) { + // connectionLink.appendChild(document.createElement('span')).textContent = item.name + // connectionLink.appendChild(document.createElement('br')) + // connectionLink.appendChild(document.createElement('small')).textContent = item.client_properties.connection_name + // } else { connectionLink.textContent = item.name - } + // } Table.renderCell(tr, 0, item.vhost) Table.renderCell(tr, 1, connectionLink) Table.renderCell(tr, 2, item.user) @@ -37,9 +38,9 @@ Table.renderTable('table', tableOptions, function (tr, item, all) { Table.renderCell(tr, 10, item.timeout, 'right') // Table.renderCell(tr, 8, item.auth_mechanism) const clientDiv = document.createElement('span') - clientDiv.textContent = `${item.client_properties.product} / ${item.client_properties.platform || ''}` + // clientDiv.textContent = `${item.client_properties.product} / ${item.client_properties.platform || ''}` clientDiv.appendChild(document.createElement('br')) - clientDiv.appendChild(document.createElement('small')).textContent = item.client_properties.version + // clientDiv.appendChild(document.createElement('small')).textContent = item.client_properties.version Table.renderCell(tr, 11, clientDiv) Table.renderCell(tr, 12, new Date(item.connected_at).toLocaleString(), 'center') }