diff --git a/drivers/SmartThings/jbl/config.yml b/drivers/SmartThings/jbl/config.yml new file mode 100644 index 0000000000..0d3c374e74 --- /dev/null +++ b/drivers/SmartThings/jbl/config.yml @@ -0,0 +1,7 @@ +name: 'JBL' +packageKey: 'jbl' +permissions: + lan: {} + discovery: {} +description: "SmartThings driver for JBL devices" +vendorSupportInformation: "https://support.smartthings.com" \ No newline at end of file diff --git a/drivers/SmartThings/jbl/profiles/jbl.yml b/drivers/SmartThings/jbl/profiles/jbl.yml new file mode 100644 index 0000000000..143c90634f --- /dev/null +++ b/drivers/SmartThings/jbl/profiles/jbl.yml @@ -0,0 +1,30 @@ +name: jbl +components: +- id: main + capabilities: + - id: mediaPlayback + version: 1 + config: + values: + - key: "playbackStatus.value" + enabledValues: + - 'playing' + - 'paused' + - key: "{{enumCommands}}" + enabledValues: + - 'play' + - 'pause' + - id: mediaTrackControl + version: 1 + - id: audioMute + version: 1 + - id: audioVolume + version: 1 + - id: audioTrackData + version: 1 + - id: refresh + version: 1 + - id: audioNotification + version: 1 + categories: + - name: Speaker diff --git a/drivers/SmartThings/jbl/search-parameters.yml b/drivers/SmartThings/jbl/search-parameters.yml new file mode 100644 index 0000000000..b17e1d94b7 --- /dev/null +++ b/drivers/SmartThings/jbl/search-parameters.yml @@ -0,0 +1,3 @@ +--- +mdns: + - service: "_jbl._tcp" diff --git a/drivers/SmartThings/jbl/src/discovery.lua b/drivers/SmartThings/jbl/src/discovery.lua new file mode 100644 index 0000000000..7549770e7f --- /dev/null +++ b/drivers/SmartThings/jbl/src/discovery.lua @@ -0,0 +1,104 @@ +local log = require "log" +local discovery = {} + +local fields = require "fields" +local discovery_mdns = require "discovery_mdns" + +local socket = require "cosock.socket" + +-- mapping from device DNI to info needed at discovery/init time +local device_discovery_cache = {} + +local function set_device_field(driver, device) + + log.info(string.format("set_device_field : %s", device.device_network_id)) + local device_cache_value = device_discovery_cache[device.device_network_id] + + -- persistent fields + device:set_field(fields.DEVICE_IPV4, device_cache_value.ip, {persist = true}) + device:set_field(fields.DEVICE_INFO, device_cache_value.device_info , {persist = true}) + device:set_field(fields.CREDENTIAL, device_cache_value.credential , {persist = true}) +end + +local function update_device_discovery_cache(driver, dni, ip, credential) + log.info(string.format("update_device_discovery_cache for device dni: %s, %s", dni, ip)) + local device_info = driver.discovery_helper.get_device_info(driver, dni, ip) + device_discovery_cache[dni] = { + ip = ip, + device_info = device_info, + credential = credential, + } +end + +local function try_add_device(driver, device_dni, device_ip) + log.trace(string.format("try_add_device : dni=%s, ip=%s", device_dni, device_ip)) + + local credential = driver.discovery_helper.get_credential(driver, device_dni, device_ip) + + if not credential then + log.error(string.format("failed to get credential. dni=%s, ip=%s", device_dni, device_ip)) + return + end + + update_device_discovery_cache(driver, device_dni, device_ip, credential) + local create_device_msg = driver.discovery_helper.get_device_create_msg(driver, device_dni, device_ip) + driver:try_create_device(create_device_msg) +end + +function discovery.device_added(driver, device) + log.info("device_added : dni = " .. tostring(device.device_network_id)) + set_device_field(driver, device) + device_discovery_cache[device.device_network_id] = nil + driver.lifecycle_handlers.init(driver, device) +end + +function discovery.find_ip_table(driver) + local ip_table= discovery_mdns.find_ip_table_by_mdns(driver) + return ip_table +end + + +local function discovery_device(driver) + local unknown_discovered_devices = {} + local known_discovered_devices = {} + local known_devices = {} + + for _, device in pairs(driver:get_devices()) do + known_devices[device.device_network_id] = device + end + + local ip_table = discovery.find_ip_table(driver) + + for dni, ip in pairs(ip_table) do + log.info(string.format("discovery_device dni, ip = %s, %s", dni, ip)) + if not known_devices or not known_devices[dni] then + unknown_discovered_devices[dni] = ip + else + known_discovered_devices[dni] = ip + end + end + + for dni, ip in pairs(known_discovered_devices) do + log.trace(string.format("known dni=%s, ip=%s", dni, ip)) + end + + for dni, ip in pairs(unknown_discovered_devices) do + log.trace(string.format("unknown dni=%s, ip=%s", dni, ip)) + if not device_discovery_cache[dni] then + try_add_device(driver, dni, ip) + end + end + +end + +function discovery.do_network_discovery(driver, _, should_continue) + log.info("discovery.do_network_discovery :Starting mDNS discovery") + + while should_continue() do + discovery_device(driver) + socket.sleep(0.2) + end + log.info("discovery.do_network_discovery: Ending mDNS discovery") +end + +return discovery diff --git a/drivers/SmartThings/jbl/src/discovery_mdns.lua b/drivers/SmartThings/jbl/src/discovery_mdns.lua new file mode 100644 index 0000000000..febd14442f --- /dev/null +++ b/drivers/SmartThings/jbl/src/discovery_mdns.lua @@ -0,0 +1,150 @@ +local log = require "log" +local mdns = require "st.mdns" +local net_utils = require "st.net_utils" + +local discovery_mdns = {} + +local function byte_array_to_plain_text(byte_array) + local str = "" + for _, value in pairs(byte_array) do + str = str .. string.char(value) + end + return str +end + +local function get_text_by_srvname(srvname, discovery_responses) + for _,answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.TxtRecord ~= nil and answer_item.name == srvname then + return answer_item.kind.TxtRecord.text + end + end +end + +local function get_srvname_by_hostname(hostname, discovery_responses) + for _,answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.SrvRecord ~= nil and answer_item.kind.SrvRecord.target == hostname then + return answer_item.name + end + end +end + +local function get_hostname_by_ip(ip, discovery_responses) + for _,answer_item in pairs(discovery_responses.answers or {}) do + if answer_item.kind.ARecord ~= nil and answer_item.kind.ARecord.ipv4 == ip then + return answer_item.name + end + end +end + + +local function find_text_in_answers_by_ip(ip, discovery_responses) + local hostname = get_hostname_by_ip(ip, discovery_responses) + local srvname = get_srvname_by_hostname(hostname, discovery_responses) + local text = get_text_by_srvname(srvname,discovery_responses) + + return text +end + +function discovery_mdns.find_text_list_in_mdns_response(driver, ip, discovery_responses) + local text_list = {} + + for _, found_item in pairs(discovery_responses.found or {}) do + if found_item.host_info.address == ip then + for _, raw_text_array in pairs(found_item.txt.text or {}) do + local text_item = byte_array_to_plain_text(raw_text_array) + table.insert(text_list, text_item) + end + end + end + + local answer_text = find_text_in_answers_by_ip(ip, discovery_responses) + for _, text_item in pairs(answer_text or {}) do + table.insert(text_list, text_item) + end + return text_list +end + +local function filter_response_by_servie_name(service_type, domain, discovery_responses) + local filtered_responses = { + answers = {}, + found = {} + } + + for _, answer in pairs(discovery_responses.answers or {}) do + table.insert(filtered_responses.answers, answer) + end + + for _, additional in pairs(discovery_responses.additional or {}) do + table.insert(filtered_responses.answers, additional) + end + + for _, found in pairs(discovery_responses.found or {}) do + if found.service_info.service_type == service_type then + table.insert(filtered_responses.found, found) + end + end + + return filtered_responses +end + +local function insert_dni_ip_from_answers(driver, filtered_responses, target_table) + for _, answer in pairs(filtered_responses.answers) do + local dni, ip + log.info("answer_name, arecod = " .. tostring(answer.name) .. ", " .. tostring(answer.kind.ARecord)) + + if answer.kind.ARecord ~= nil then + ip = answer.kind.ARecord.ipv4 + end + + if ip ~= nil then + dni = driver.discovery_helper.get_dni(driver, ip, filtered_responses) + + if dni ~= nil then + target_table[dni] = ip + end + end + end +end + +local function insert_dni_ip_from_found(driver, filtered_responses, target_table) + for _, found in pairs(filtered_responses.found) do + local dni, ip + log.info("found_name = " .. tostring(found.service_info.service_type)) + if found.host_info.address ~= nil and net_utils.validate_ipv4_string(found.host_info.address) then + log.info("ip = " .. tostring(found.host_info.address)) + ip = found.host_info.address + end + + if ip ~= nil then + dni = driver.discovery_helper.get_dni(driver, ip, filtered_responses) + + if dni ~= nil then + target_table[dni] = ip + end + end + end +end + +local function get_dni_ip_table_from_mdns_responses(driver, service_type, domain, discovery_responses) + local dni_ip_table = {} + + local filtered_responses = filter_response_by_servie_name(service_type, domain, discovery_responses) + + insert_dni_ip_from_answers(driver, filtered_responses, dni_ip_table) + insert_dni_ip_from_found(driver, filtered_responses, dni_ip_table) + + return dni_ip_table +end + +function discovery_mdns.find_ip_table_by_mdns(driver) + log.info("discovery_mdns.find_device_ips") + + local service_type, domain = driver.discovery_helper.get_service_type_and_domain() + local discovery_responses = mdns.discover(service_type, domain) or {found = {}} + + local dni_ip_table = get_dni_ip_table_from_mdns_responses(driver, service_type, domain, discovery_responses) + + return dni_ip_table +end + +return discovery_mdns \ No newline at end of file diff --git a/drivers/SmartThings/jbl/src/fields.lua b/drivers/SmartThings/jbl/src/fields.lua new file mode 100644 index 0000000000..2df3fb0708 --- /dev/null +++ b/drivers/SmartThings/jbl/src/fields.lua @@ -0,0 +1,16 @@ +--- Table of constants used to index in to device store fields +--- @module "fields" +--- @class table +--- @field IPV4 string the ipV4 address of the device + +local fields = { + DEVICE_IPV4 = "device_ipv4", + DEVICE_INFO = "device_info", + CONN_INFO = "conn_info", + EVENT_SOURCE = "eventsource", + MONITORING_TIMER = "monitoring_timer", + CREDENTIAL = "credential", + _INIT = "init" +} + +return fields diff --git a/drivers/SmartThings/jbl/src/init.lua b/drivers/SmartThings/jbl/src/init.lua new file mode 100644 index 0000000000..08b59d992f --- /dev/null +++ b/drivers/SmartThings/jbl/src/init.lua @@ -0,0 +1,221 @@ +-- Copyright 2022 SmartThings +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +-- except in compliance with the License. You may obtain a copy of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software distributed under the +-- License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +-- either express or implied. See the License for the specific language governing permissions +-- and limitations under the License. +-- +-- =============================================================================================== + +local log = require "log" + +local capabilities = require "st.capabilities" +local Driver = require "st.driver" + +local discovery = require "discovery" +local fields = require "fields" + +local jbl_discovery_helper = require "jbl.discovery_helper" +local jbl_device_manager = require "jbl.device_manager" +local jbl_capability_handler = require "jbl.capability_handler" + +local EventSource = require "lunchbox.sse.eventsource" + +local DEFAULT_MONITORING_INTERVAL = 300 +local CREDENTIAL_KEY_HEADER = "Authorization" + +local function handle_sse_event(driver, device, msg) + driver.device_manager.handle_sse_event(driver, device, msg.type, msg.data) +end + +local function create_sse(driver, device, credential) + log.info("create_sse : dni = " .. tostring(device.device_network_id)) + local conn_info = device:get_field(fields.CONN_INFO) + + if not driver.device_manager.is_valid_connection(driver, device, conn_info) then + log.error("create_sse : invalid connection") + return + end + + local sse_url = driver.device_manager.get_sse_url(driver, device, conn_info) + if not sse_url then + log.error("failed to get sse_url") + else + log.trace("Creating SSE EventSource for " .. device.device_network_id .. ", sse_url = " .. sse_url) + local eventsource = EventSource.new(sse_url, {[CREDENTIAL_KEY_HEADER] = credential}, nil) + + eventsource.onmessage = function(msg) + if msg then + handle_sse_event(driver, device, msg) + end + end + + eventsource.onerror = function(msg) + log.error("Eventsource error: dni = " .. tostring(device.device_network_id) .. ", Error=" .. tostring(msg)) + device:offline() + end + + eventsource.onopen = function(msg) + log.info("Eventsource open: dni = " .. tostring(device.device_network_id)) + device:online() + end + + local old_eventsource = device:get_field(fields.EVENT_SOURCE) + if old_eventsource then + log.info("Eventsource Close: dni = " .. tostring(device.device_network_id)) + old_eventsource:close() + end + device:set_field(fields.EVENT_SOURCE, eventsource) + end +end + +local function update_connection(driver, device, device_ip, device_info) + local device_dni = device.device_network_id + log.info("update connection, dni = " .. tostring(device_dni)) + + local conn_info = driver.discovery_helper.get_connection_info(driver, device_dni, device_ip, device_info) + + local credential = device:get_field(fields.CREDENTIAL) + + conn_info:add_header(CREDENTIAL_KEY_HEADER, credential) + + if driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:set_field(fields.CONN_INFO, conn_info) + + create_sse(driver,device, credential) + end +end + + +local function find_new_connetion(driver, device) + log.info("find new connection for dni=" .. tostring(device.device_network_id)) + local ip_table = discovery.find_ip_table(driver) + local ip = ip_table[device.device_network_id] + if ip then + device:set_field(fields.DEVICE_IPV4, ip, {persist = true}) + local device_info = device:get_field(fields.DEVICE_INFO) + update_connection(driver, device, ip, device_info) + end +end + +local function check_and_update_connection(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + if not driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:offline() + find_new_connetion(driver, device) + conn_info = device:get_field(fields.CONN_INFO) + end + + if driver.device_manager.is_valid_connection(driver, device, conn_info) then + device:online() + end +end + +local function create_monitoring_thread(driver, device, device_info) + local old_timer = device:get_field(fields.MONITORING_TIMER) + if old_timer ~= nil then + log.info("monitoring_timer: dni=" .. device.device_network_id .. ", remove old timer") + device.thread:cancel_timer(old_timer) + end + + local monitoring_interval = DEFAULT_MONITORING_INTERVAL + + log.info("create_monitoring_thread: dni=" .. device.device_network_id) + local new_timer = device.thread:call_on_schedule(monitoring_interval, function() + check_and_update_connection(driver, device) + driver.device_manager.device_monitor(driver, device, device_info) + end, "monitor_timer") + device:set_field(fields.MONITORING_TIMER, new_timer) +end + +local function refresh(driver, device, cmd) + log.info("refresh : dni = " .. tostring(device.device_network_id)) + check_and_update_connection(driver, device) + driver.device_manager.refresh(driver, device) +end + +local function device_removed(driver, device) + log.info("device_removed : dni = " .. tostring(device.device_network_id)) + local eventsource = device:get_field(fields.EVENT_SOURCE) + if eventsource then + log.info("Eventsource Close: dni = " .. tostring(device.device_network_id)) + eventsource:close() + end +end + +local function device_init(driver, device) + log.info("device_init : dni = " .. tostring(device.device_network_id)) + + if device:get_field(fields._INIT) then + log.info(string.format("device_init : already initialized. dni = %s", device.device_network_id)) + return + end + + local device_dni = device.device_network_id + + driver.controlled_devices[device_dni] = device + + local device_ip = device:get_field(fields.DEVICE_IPV4) + local device_info = device:get_field(fields.DEVICE_INFO) + local credential = device:get_field(fields.CREDENTIAL) + + if not credential then + log.error("failed to find credential.") + device:offline() + return + end + + log.trace("Creating device monitoring for " .. device.device_network_id) + create_monitoring_thread(driver, device, device_info) + + update_connection(driver, device, device_ip, device_info) + + refresh(driver, device, nil) + device:set_field(fields._INIT, true, { persist = false }) +end + +local lan_driver = Driver("jbl", + { + discovery = discovery.do_network_discovery, + lifecycle_handlers = {added = discovery.device_added, init = device_init, removed = device_removed}, + capability_handlers = { + [capabilities.refresh.ID] = { + [capabilities.refresh.commands.refresh.NAME] = refresh, + }, + [capabilities.audioMute.ID] = { + [capabilities.audioMute.commands.setMute.NAME] = jbl_capability_handler.set_mute_handler, + [capabilities.audioMute.commands.mute.NAME] = jbl_capability_handler.mute_handler, + [capabilities.audioMute.commands.unmute.NAME] = jbl_capability_handler.unmute_handler, + }, + [capabilities.audioVolume.ID] = { + [capabilities.audioVolume.commands.setVolume.NAME] = jbl_capability_handler.set_volume_handler, + }, + [capabilities.mediaTrackControl.ID] = { + [capabilities.mediaTrackControl.commands.nextTrack.NAME] = jbl_capability_handler.next_track_handler, + [capabilities.mediaTrackControl.commands.previousTrack.NAME] = jbl_capability_handler.previous_track_handler, + }, + [capabilities.mediaPlayback.ID] = { + [capabilities.mediaPlayback.commands.play.NAME] = jbl_capability_handler.playback_play_handler, + [capabilities.mediaPlayback.commands.pause.NAME] = jbl_capability_handler.playback_pause_handler, + }, + [capabilities.audioNotification.ID] = { + [capabilities.audioNotification.commands.playTrack.NAME] = jbl_capability_handler.audioNotification_handler, + [capabilities.audioNotification.commands.playTrackAndRestore.NAME] = jbl_capability_handler.audioNotification_handler, + [capabilities.audioNotification.commands.playTrackAndResume.NAME] = jbl_capability_handler.audioNotification_handler, + }, + }, + + discovery_helper = jbl_discovery_helper, + device_manager = jbl_device_manager, + controlled_devices = {}, + } +) + +log.info("Starting lan driver") +lan_driver:run() +log.warn("lan driver exiting") \ No newline at end of file diff --git a/drivers/SmartThings/jbl/src/jbl/api.lua b/drivers/SmartThings/jbl/src/jbl/api.lua new file mode 100644 index 0000000000..2487818ec5 --- /dev/null +++ b/drivers/SmartThings/jbl/src/jbl/api.lua @@ -0,0 +1,138 @@ +local log = require "log" +local json = require "st.json" +local RestClient = require "lunchbox.rest" +local utils = require "utils" +local cosock = require "cosock" + +local jbl_api = {} +jbl_api.__index = jbl_api + +local CREDENTIAL_TIME_OUT_SECONDS = 30 + +local SSL_CONFIG = { + mode = "client", + protocol = "any", + verify = "peer", + options = "all", + cafile="./selfSignedRoot.crt" +} + +local ADDITIONAL_HEADERS = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", +} + +function jbl_api.labeled_socket_builder(label) + local socket_builder = utils.labeled_socket_builder(label, SSL_CONFIG) + return socket_builder +end + +local function get_base_url(device_ip) + return "https://" .. device_ip .. ":4443" +end + +local function process_rest_response(response, err, partial) + if err ~= nil then + return response, err, nil + elseif response ~= nil then + local _, decoded_json = pcall(json.decode, response:get_body()) + return decoded_json, nil, response.status + else + return nil, "no response or error received", nil + end + end + +local function retry_fn(retry_attempts) + local count = 0 + return function() + count = count + 1 + return count < retry_attempts + end +end + +local function do_get(api_instance, path) + return process_rest_response(api_instance.client:get(path, api_instance.headers, retry_fn(5))) +end + +local function do_post(api_instance, path, payload) + return process_rest_response(api_instance.client:post(path, payload, api_instance.headers, retry_fn(5))) +end + + + +function jbl_api.new_device_manager(bridge_ip, bridge_info, socket_builder) + local base_url = get_base_url(bridge_ip) + + return setmetatable( + { + headers = ADDITIONAL_HEADERS, + client = RestClient.new(base_url, socket_builder), + base_url = base_url, + }, jbl_api + ) +end + +function jbl_api:add_header(key, value) + log.info("add_header : " .. key .. ", " .. value) + self.headers[key] = value +end + +function jbl_api.get_credential(bridge_ip, socket_builder) + local start_time = cosock.socket.gettime() + local timeout_time = start_time + CREDENTIAL_TIME_OUT_SECONDS + log.info("get_credential : start (" .. tostring(start_time) .. "), timeout (" .. tostring(timeout_time) .. ")") + + while true do + local response, error, status = process_rest_response(RestClient.one_shot_get(get_base_url(bridge_ip) .. "/authcode", ADDITIONAL_HEADERS, socket_builder)) + local now = cosock.socket.gettime() + + if now > timeout_time then + log.error("get_credential take too long time : now(" .. tostring(now) .. ") > timeout (" .. tostring(timeout_time) .. ")") + return nil + end + + if not error and status == 200 then + local token = response + return token + end + + cosock.socket.sleep(1) + end +end + +function jbl_api.get_info(device_ip, socket_builder) + return process_rest_response(RestClient.one_shot_get(get_base_url(device_ip) .. "/info", ADDITIONAL_HEADERS, socket_builder)) +end + +function jbl_api:get_status() + return do_get(self, "/status") +end + +function jbl_api:get_volume() + return do_get(self, "/volume") +end + +function jbl_api:post_volume(payload) + log.info("post volume : payload = " .. payload) + return do_post(self, "/volume", payload) +end + +function jbl_api:post_playback_uri(payload) + log.info("post playback_uri : payload = " .. payload) + return do_post(self, "/playbackUri", payload) +end + +function jbl_api:get_playback() + return do_get(self, "/playback") +end + +function jbl_api:post_playback(payload) + log.info("post playback : payload = " .. payload) + return do_post(self, "/playback", payload) +end + +function jbl_api:get_sse_url() + return self.base_url .. "/events" +end + +return jbl_api \ No newline at end of file diff --git a/drivers/SmartThings/jbl/src/jbl/capability_handler.lua b/drivers/SmartThings/jbl/src/jbl/capability_handler.lua new file mode 100644 index 0000000000..7595e86b86 --- /dev/null +++ b/drivers/SmartThings/jbl/src/jbl/capability_handler.lua @@ -0,0 +1,135 @@ +local log = require "log" +local json = require "st.json" +local fields = require "fields" + +local capability_handler = {} +capability_handler.__index = capability_handler + +local function smartthings_playback_capability_handler(driver, device, capability_status) + local st_status_to_jbl_playback_status_table = { + paused = "pause", + playing = "play", + } + + local conn_info = device:get_field(fields.CONN_INFO) + log.info(string.format("media-playback.set_playback_status_handler : dni = %s, status = %s", device.device_network_id, capability_status)) + + local jbl_playback_status = st_status_to_jbl_playback_status_table[capability_status] + + local _, err, status = conn_info:post_playback(string.format('{"playback": "%s"}', jbl_playback_status)) + if not err and status == 200 then + log.info(string.format("post_playback success, dni = %s", device.device_network_id)) + elseif status == 404 then + log.error("404 error. delete device. dni = " .. tostring(device.device_network_id)) + device:offline() + end +end + +function capability_handler.playback_play_handler(driver, device, args) + smartthings_playback_capability_handler(driver, device, "playing") +end + +function capability_handler.playback_pause_handler(driver, device, args) + smartthings_playback_capability_handler(driver, device, "paused") +end + +function capability_handler.next_track_handler(driver, device, args) + local conn_info = device:get_field(fields.CONN_INFO) + log.info(string.format("media_track_control.next_track_handler : dni = %s", device.device_network_id)) + + local _, err, status = conn_info:post_playback('{"playback": "next track"}') + if err or status == 404 then + log.error("media_track_control.next_track_handler : 404 error. delete device. dni = " .. tostring(device.device_network_id)) + device:offline() + end +end + +function capability_handler.previous_track_handler(driver, device, args) + local conn_info = device:get_field(fields.CONN_INFO) + log.info(string.format("media_track_control.previous_track_handler : dni = %s", device.device_network_id)) + + local _, err, status = conn_info:post_playback('{"playback": "previous track"}') + if err or status == 404 then + log.error("media_track_control.previous_track_handler : 404 error. delete device. dni = " .. tostring(device.device_network_id)) + device:offline() + end +end + +function capability_handler.set_volume_handler(driver, device, args) + local volume = args.args.volume + local conn_info = device:get_field(fields.CONN_INFO) + log.info(string.format("audio_volume.set_volume : dni = %s, volume = %d", device.device_network_id, volume)) + + + local _, err, status = conn_info:post_volume(string.format('{"volume": %d}', volume)) + if not err and status == 200 then + log.info(string.format("post_volume success, dni = %s", device.device_network_id)) + elseif status == 404 then + log.error("set_volume_handler : 404 error. delete device. dni = " .. tostring(device.device_network_id)) + device:offline() + end +end + +function capability_handler.audioNotification_handler(driver, device, args) + local uri = args.args.uri + local level = args.args.level + local conn_info = device:get_field(fields.CONN_INFO) + + log.info(string.format("%s, %s : level = %s", args.capability, args.command, device.device_network_id, level)) + log.info(string.format("URI: %s", uri)) + + local payload_table = { + ["uri"] = uri, + } + + if level then + payload_table["volume"] = tostring(level) + if args.command == "playTrackAndRestore" then + payload_table["setToMasterVolume"] = false + else + payload_table["setToMasterVolume"] = true + end + end + + if args.command == "playTrackAndResume" then + payload_table["resumeCurrentPlayback"] = true + end + + local payload = json.encode(payload_table) + + conn_info:post_playback_uri(payload) +end + +local function smartthings_audioMute_capability_handler(driver, device, mute_state) + local conn_info = device:get_field(fields.CONN_INFO) + local st_state_to_jbl_muted_state_table = { + muted = "mute", + unmuted = "unmute", + } + log.info(string.format("smartthings_audioMute_capability_handler dni = %s, status = %s", device.device_network_id, mute_state)) + + local jbl_mute_state = st_state_to_jbl_muted_state_table[mute_state] + + local _, err, status = conn_info:post_playback(string.format('{"mute": "%s"}', jbl_mute_state)) + if not err and status == 200 then + log.info(string.format("post_playback success, dni = %s", device.device_network_id)) + elseif status == 404 then + log.error("smartthings_audioMute_capability_handler : 404 error. delete device. dni = " .. tostring(device.device_network_id)) + device:offline() + end +end + +function capability_handler.set_mute_handler(driver, device, args) + local mute_state = args.args.state + smartthings_audioMute_capability_handler(driver, device, mute_state) +end + +function capability_handler.mute_handler(driver, device, args) + smartthings_audioMute_capability_handler(driver, device, "muted") +end + +function capability_handler.unmute_handler(driver, device, args) + smartthings_audioMute_capability_handler(driver, device, "unmuted") +end + +return capability_handler \ No newline at end of file diff --git a/drivers/SmartThings/jbl/src/jbl/device_manager.lua b/drivers/SmartThings/jbl/src/jbl/device_manager.lua new file mode 100644 index 0000000000..8542b05aab --- /dev/null +++ b/drivers/SmartThings/jbl/src/jbl/device_manager.lua @@ -0,0 +1,139 @@ +local log = require "log" +local json = require "st.json" +local fields = require "fields" + +local capabilities = require "st.capabilities" + +local device_manager = {} +device_manager.__index = device_manager + +local function is_new_audioTrackData(device, audioTrackData) + local latestTrackData = device:get_latest_state("main", capabilities.audioTrackData.ID, capabilities.audioTrackData.audioTrackData.NAME, "") + + if not latestTrackData or latestTrackData == "" then + return true + end + + if (audioTrackData.title == latestTrackData.title) + and (audioTrackData.artist == latestTrackData.artist) + and (audioTrackData.album == latestTrackData.album) + and (audioTrackData.albumArtUrl == latestTrackData.albumArtUrl) + and (audioTrackData.mediaSource == latestTrackData.mediaSource) then + return false + end + + return true +end + +local jbl_playback_state_to_smartthings_playback_status_table = { + paused = "paused", + playing = "playing", +} + +function device_manager.handle_status(driver, device, status) + if not status then + log.error("device_manager.handle_status : status is nil") + return + end + + local playback_status = jbl_playback_state_to_smartthings_playback_status_table[status.playback] + if playback_status and playback_status ~= device:get_latest_state("main", capabilities.mediaPlayback.ID, capabilities.mediaPlayback.playbackStatus.NAME, "") then + log.info("device_manager.handle_status : update playbackStatus = " .. tostring(playback_status) .. ", dni = " .. tostring(device.device_network_id)) + device:emit_event(capabilities.mediaPlayback.playbackStatus[playback_status]()) + end + + if status.volume and status.volume ~= device:get_latest_state("main", capabilities.audioVolume.ID, capabilities.audioVolume.volume.NAME, 0) then + log.info("device_manager.handle_status : update volume = " .. tostring(status.volume) .. ", dni = " .. tostring(device.device_network_id)) + device:emit_event(capabilities.audioVolume.volume(status.volume)) + end + + + if status.mute and status.mute ~= "" and status.mute ~= device:get_latest_state("main", capabilities.audioMute.ID, capabilities.audioMute.mute.NAME, "") then + log.info("device_manager.handle_status : update mute = " .. tostring(status.mute) .. ", dni = " .. tostring(device.device_network_id)) + device:emit_event(capabilities.audioMute.mute[status.mute]()) + end + + if status.track and status.track ~= '' then + local audioTrackData = {} + audioTrackData.title = status.track.title or "" + audioTrackData.artist = status.track.artist or "" + audioTrackData.album = status.track.album or "" + if status.track.mediaSource then + audioTrackData.mediaSource = status.track.mediaSource.name + else + audioTrackData.mediaSource = "" + end + + if is_new_audioTrackData(device, audioTrackData) then + log.info("device_manager.handle_status : update audioTrackData = " .. tostring(json.encode(audioTrackData)) .. ", dni = " .. tostring(device.device_network_id) .. tostring(json.encode(status))) + device:emit_event(capabilities.audioTrackData.audioTrackData(audioTrackData)) + device:emit_event(capabilities.audioTrackData.totalTime(status.track.totalTime or 0)) + end + end +end + + +function device_manager.update_status(driver, device) + local conn_info = device:get_field(fields.CONN_INFO) + + if not conn_info then + log.warn(" device_manager.update_status : failed to find conn_info, dni = " .. tostring(device.device_network_id)) + return + end + + local response, err, status = conn_info:get_status() + if err or status ~= 200 then + log.error("device_manager.update_status : failed to get status, dni = " .. tostring(device.device_network_id) .. ", err = " .. tostring(err) .. ", status = " .. tostring(status)) + if status == 404 then + log.info("device_manager.update_status : deleted, dni = " .. tostring(device.device_network_id)) + device:offline() + end + return + end + device_manager.handle_status(driver, device, response) +end + +local sse_event_handlers = { + ["message"] = device_manager.handle_status, +} + +function device_manager.handle_sse_event(driver, device, event_type, data) + local device_json = json.decode(data) or nil + + local event_handler = sse_event_handlers[event_type] + if event_handler then + event_handler(driver, device, device_json) + else + log.error(string.format("handle_sse_event : unknown event type. dni = %s, event_type = '%s'", device.device_network_id, event_type)) + end +end + +function device_manager.refresh(driver, device) + device_manager.update_status(driver, device) +end + + +function device_manager.is_valid_connection(driver, device, conn_info) + if not conn_info then + log.error(" device_manager.is_valid_connection : failed to find conn_info, dni = " .. tostring(device.device_network_id)) + return false + end + local _, err, status = conn_info:get_status() + if err or status ~= 200 then + log.error(" device_manager.is_valid_connection : failed to connect to device, dni = " .. tostring(device.device_network_id) .. ", err = " .. tostring(err) .. ", status = " .. tostring(status)) + return false + end + + return true +end + +function device_manager.device_monitor(driver, device, device_info) + --TODO: add device monitoring logic (ip change, online/offline, etc ..) + log.info("device_monitor = " .. tostring(device.device_network_id)) + device_manager.refresh(driver, device) +end + +function device_manager.get_sse_url(driver, device, conn_info) + return conn_info:get_sse_url() +end +return device_manager diff --git a/drivers/SmartThings/jbl/src/jbl/discovery_helper.lua b/drivers/SmartThings/jbl/src/jbl/discovery_helper.lua new file mode 100644 index 0000000000..1759e6b387 --- /dev/null +++ b/drivers/SmartThings/jbl/src/jbl/discovery_helper.lua @@ -0,0 +1,97 @@ +-- Copyright 2022 SmartThings +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +-- except in compliance with the License. You may obtain a copy of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software distributed under the +-- License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +-- either express or implied. See the License for the specific language governing permissions +-- and limitations under the License. +-- +-- =============================================================================================== + + +local log = require "log" + +local discovery_helper = {} + +local SERVICE_TYPE = "_jbl._tcp" +local DOMAIN = "local" + +local jbl_api = require "jbl.api" +local discovery_mdns = require "discovery_mdns" + +function discovery_helper.get_dni(driver, ip, discovery_responses) + local text_list = discovery_mdns.find_text_list_in_mdns_response(driver, ip, discovery_responses) + for _, text in ipairs(text_list) do + for key, value in string.gmatch(text, "(%S+)=(%S+)") do + if key == "mac" then + log.info("discovery_helper.get_dni : use mac as dni, dni = " .. value) + return value + end + end + end + + log.error("discovery_helper.get_dni : failed to find dni") + return nil +end + +function discovery_helper.get_service_type_and_domain() + return SERVICE_TYPE, DOMAIN +end + +function discovery_helper.get_device_create_msg(driver, device_dni, device_ip) + local device_info = jbl_api.get_info(device_ip, jbl_api.labeled_socket_builder(device_dni)) + + if not device_info then + log.error("failed to create device create msg. device_info is nil. dni = " .. device_dni) + return nil + end + + local create_device_msg = { + type = "LAN", + device_network_id = device_dni, + label = device_info.label, + profile = "jbl", + manufacturer = device_info.manufacturerName, + model = device_info.modelName, + vendor_provided_label = device_info.label, + } + + return create_device_msg +end + +function discovery_helper.get_credential(driver, bridge_dni, bridge_ip) + local credential = jbl_api.get_credential(bridge_ip, jbl_api.labeled_socket_builder(bridge_dni)) + + if not credential then + log.error("credential is nil") + return nil + end + + return "Bearer " .. credential.token +end + +function discovery_helper.get_connection_info(driver, device_dni, device_ip, device_info) + local conn_info = jbl_api.new_device_manager(device_ip, device_info, jbl_api.labeled_socket_builder(device_dni)) + + if conn_info == nil then + log.error("conn_info is nil") + end + + return conn_info +end + +function discovery_helper.get_device_info(driver, device_dni, device_ip) + local device_info = jbl_api.get_info(device_ip, jbl_api.labeled_socket_builder(device_dni)) + + if device_info == nil then + log.error("device_info is nil") + end + + return device_info +end + +return discovery_helper diff --git a/drivers/SmartThings/jbl/src/lunchbox/init.lua b/drivers/SmartThings/jbl/src/lunchbox/init.lua new file mode 100644 index 0000000000..2745880f50 --- /dev/null +++ b/drivers/SmartThings/jbl/src/lunchbox/init.lua @@ -0,0 +1,4 @@ +local RestClient = require "lunchbox.rest" +local EventSource = require "lunchbox.sse.eventsource" + +return {RestClient = RestClient, EventSource = EventSource} diff --git a/drivers/SmartThings/jbl/src/lunchbox/rest.lua b/drivers/SmartThings/jbl/src/lunchbox/rest.lua new file mode 100644 index 0000000000..473e9abf6d --- /dev/null +++ b/drivers/SmartThings/jbl/src/lunchbox/rest.lua @@ -0,0 +1,372 @@ +local socket = require "cosock.socket" +local utils = require "utils" + +local lb_utils = require "lunchbox.util" +local Request = require "luncheon.request" +local Response = require "luncheon.response" + +local RestCallStates = { + SEND = "Send", + RECEIVE = "Receive", + RETRY = "Retry", + RECONNECT = "Reconnect", + COMPLETE = "Complete", +} + +local function connect(client) + local default_port = 80 + local use_ssl = false + + if client.base_url.scheme == "https" then + default_port = 443 + use_ssl = true + end + + local port = client.base_url.port or default_port + + local sock, err = client.socket_builder(client.base_url.host, port, use_ssl) + + if sock == nil then + client.socket = nil + return false, err + end + + client.socket = sock + return true +end + +local function reconnect(client) + if client.socket ~= nil then + client.socket:close() + client.socket = nil + end + return connect(client) +end + +local function send_request(client, request) + if client.socket == nil then + return nil, "no socket available" + end + local payload = request:serialize() + + local bytes, err, idx = nil, nil, 0 + + repeat bytes, err, idx = client.socket:send(payload, idx + 1, #payload) until (bytes == #payload) + or (err ~= nil) + + return bytes, err, idx +end + +local function parse_chunked_response(original_response, sock) + local ChunkedTransferStates = { + EXPECTING_CHUNK_LENGTH = "ExpectingChunkLength", + EXPECTING_BODY_CHUNK = "ExpectingBodyChunk", + } + + local full_response = Response.new(original_response.status, nil) + + for header in original_response.headers:iter() do full_response.headers:append_chunk(header) end + + local original_body, err = original_response:get_body() + if type(original_body) ~= "string" or err ~= nil then + return original_body, (err or "unexpected nil in error position") + end + local next_chunk_bytes = tonumber(original_body, 16) + local next_chunk_body = "" + local bytes_read = 0; + + local state = ChunkedTransferStates.EXPECTING_BODY_CHUNK + + repeat + local pat = nil + local next_recv, next_err, partial = nil, nil, nil + + if state == ChunkedTransferStates.EXPECTING_BODY_CHUNK then + pat = next_chunk_bytes + else + pat = "*l" + end + + next_recv, next_err, partial = sock:receive(pat) + + if next_err ~= nil then + if string.lower(next_err) == "closed" then + if partial ~= nil and #partial >= 1 then + full_response:append_body(partial) + next_chunk_bytes = 0 + end + else + return nil, ("unexpected error reading chunked transfer: " .. next_err) + end + end + + if next_recv ~= nil and #next_recv >= 1 then + if state == ChunkedTransferStates.EXPECTING_BODY_CHUNK then + bytes_read = bytes_read + #next_recv + next_chunk_body = next_chunk_body .. next_recv + + if bytes_read >= next_chunk_bytes then + full_response = full_response:append_body(next_chunk_body) + next_chunk_body = "" + bytes_read = 0 + + state = ChunkedTransferStates.EXPECTING_CHUNK_LENGTH + end + elseif state == ChunkedTransferStates.EXPECTING_CHUNK_LENGTH then + next_chunk_bytes = tonumber(next_recv, 16) + + state = ChunkedTransferStates.EXPECTING_BODY_CHUNK + end + end + until next_chunk_bytes == 0 + + local _ = sock:receive("*l") -- clear the trailing CRLF + + full_response._received_body = true + full_response._parsed_headers = true + + return full_response +end + +local function recv_additional_response(original_response, sock) + local full_response = Response.new(original_response.status, nil) + local headers = original_response:get_headers() + local content_length_str = headers:get_one("Content-Length") + local content_length = nil + local bytes_read = 0 + if content_length_str then + content_length = math.tointeger(content_length_str) + end + + local next_recv, next_err, partial + + repeat + next_recv, next_err, partial = sock:receive(content_length - bytes_read) + + if next_recv ~= nil and #next_recv >= 1 then + full_response:append_body(next_recv) + bytes_read = bytes_read + #next_recv + end + + if partial ~= nil and #partial >= 1 then + full_response:append_body(partial) + bytes_read = bytes_read + #partial + end + until next_err == "closed" or bytes_read >= content_length + + full_response._received_body = true + full_response._parsed_headers = true + + return full_response +end + +local function handle_response(sock) + -- called select right before passing in so we receive immediately + local initial_recv, initial_err, partial = Response.source(function() return sock:receive('*l') end) + + local full_response = nil + + if initial_recv ~= nil then + local headers = initial_recv:get_headers() + + if headers:get_one("Content-Length") then + full_response = recv_additional_response(initial_recv, sock) + elseif headers:get_one("Transfer-Encoding") == "chunked" then + full_response = parse_chunked_response(initial_recv, sock) + else + full_response = initial_recv + end + + return full_response + else + return nil, initial_err, partial + end +end + +local function execute_request(client, request, retry_fn) + if not client._active then + return nil, "Called `execute request` on a terminated REST Client" + end + + if client.socket == nil then + local success, err = connect(client) + if not success then return nil, err end + end + + local should_retry = retry_fn + + if type(should_retry) ~= "function" then + should_retry = function() return false end + end + + -- send output + local _bytes_sent, send_err, _idx = nil, nil, 0 + -- recv output + local response, recv_err, _partial = nil, nil, nil + -- return values + local ret, err = nil, nil + + local backoff = utils.backoff_builder(60, 1, 0.1) + local current_state = RestCallStates.SEND + + repeat + local retry = should_retry() + if current_state == RestCallStates.SEND then + backoff = utils.backoff_builder(60, 1, 0.1) + _bytes_sent, send_err, _idx = send_request(client, request) + + if not send_err then + current_state = RestCallStates.RECEIVE + elseif retry then + if string.lower(send_err) == "closed" or string.lower(send_err):match("broken pipe") then + current_state = RestCallStates.RECONNECT + else + current_state = RestCallStates.RETRY + end + else + ret = nil + err = send_err + current_state = RestCallStates.COMPLETE + end + elseif current_state == RestCallStates.RECEIVE then + response, recv_err, _partial = handle_response(client.socket) + + if not recv_err then + ret = response + err = nil + current_state = RestCallStates.COMPLETE + elseif retry then + if string.lower(recv_err) == "closed" or string.lower(recv_err):match("broken pipe") then + current_state = RestCallStates.RECONNECT + else + current_state = RestCallStates.RETRY + end + else + ret = nil + err = recv_err + current_state = RestCallStates.COMPLETE + end + elseif current_state == RestCallStates.RECONNECT then + local success, reconn_err = reconnect(client) + if success then + current_state = RestCallStates.RETRY + elseif not retry then + ret = nil + err = reconn_err + current_state = RestCallStates.COMPLETE + else + socket.sleep(backoff()) + end + elseif current_state == RestCallStates.RETRY then + bytes_sent, send_err, _idx = nil, nil, 0 + response, recv_err, partial = nil, nil, nil + current_state = RestCallStates.SEND + socket.sleep(backoff()) + end + until current_state == RestCallStates.COMPLETE + + return ret, err +end + +---@class RestClient +--- +---@field base_url table `net.url` URL table +---@field socket table `cosock` TCP socket +local RestClient = {} +RestClient.__index = RestClient + +function RestClient.one_shot_get(full_url, additional_headers, socket_builder) + local url_table = lb_utils.force_url_table(full_url) + local client = RestClient.new(url_table.scheme .. "://" .. url_table.authority, socket_builder) + local ret, err = client:get(url_table.path, additional_headers) + client:shutdown() + client = nil + return ret, err +end + +function RestClient.one_shot_post(full_url, body, additional_headers, socket_builder) + local url_table = lb_utils.force_url_table(full_url) + local client = RestClient.new(url_table.scheme .. "://" .. url_table.authority, socket_builder) + local ret, err = client:post(url_table.path, body, additional_headers) + client:shutdown() + client = nil + return ret, err +end + +function RestClient:close_socket() + if self.socket ~= nil and self._active then + self.socket:close() + self.socket = nil + end +end + +function RestClient:shutdown() + self:close_socket() + self._active = false +end + +function RestClient:update_base_url(new_url) + if self.socket ~= nil then + self.socket:close() + self.socket = nil + end + + self.base_url = lb_utils.force_url_table(new_url) +end + +function RestClient:get(path, additional_headers, retry_fn) + local request = Request.new("GET", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + return execute_request(self, request, retry_fn) +end + +function RestClient:post(path, body_string, additional_headers, retry_fn) + local request = Request.new("POST", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + request = request:append_body(body_string) + + return execute_request(self, request, retry_fn) +end + +function RestClient:put(path, body_string, additional_headers, retry_fn) + local request = Request.new("PUT", path, nil):add_header( + "user-agent", "smartthings-lua-edge-driver" + ):add_header("host", string.format("%s", self.base_url.host)):add_header( + "connection", "keep-alive" + ) + + if additional_headers ~= nil and type(additional_headers) == "table" then + for k, v in pairs(additional_headers) do request = request:add_header(k, v) end + end + + request = request:append_body(body_string) + + return execute_request(self, request, retry_fn) +end + +function RestClient.new(base_url, sock_builder) + base_url = lb_utils.force_url_table(base_url) + + if type(sock_builder) ~= "function" then sock_builder = utils.labeled_socket_builder() end + + return + setmetatable({base_url = base_url, socket_builder = sock_builder, socket = nil, _active = true}, RestClient) +end + +return RestClient diff --git a/drivers/SmartThings/jbl/src/lunchbox/sse/eventsource.lua b/drivers/SmartThings/jbl/src/lunchbox/sse/eventsource.lua new file mode 100644 index 0000000000..d016ff4908 --- /dev/null +++ b/drivers/SmartThings/jbl/src/lunchbox/sse/eventsource.lua @@ -0,0 +1,510 @@ +local cosock = require "cosock" +local socket = require "cosock.socket" +local ssl = require "cosock.ssl" + +local log = require "log" +local util = require "lunchbox.util" +local Request = require "luncheon.request" +local Response = require "luncheon.response" + +--- A pure Lua implementation of the EventSource interface. +--- The EventSource interface represents the client end of an HTTP(S) +--- connection that receives an event stream following the Server-Sent events +--- specification. +--- +--- MDN Documentation for EventSource: https://developer.mozilla.org/en-US/docs/Web/API/EventSource +--- HTML Spec: https://html.spec.whatwg.org/multipage/server-sent-events.html +--- +--- @class EventSource +--- @field public url table A `net.url` table representing the URL for the connection +--- @field public ready_state number Enumeration of the ready states outlined in the spec. +--- @field public onopen function in-line callback for on-open events +--- @field public onmessage function in-line callback for on-message events +--- @field public onerror function in-line callback for on-error events; error callbacks will fire +--- @field private _reconnect boolean flag that says whether or not the client should attempt to reconnect on close. +--- @field private _reconnect_time_millis number The amount of time to wait between reconnects, in millis. Can be sent by the server. +--- @field private _sock_builder function|nil optional. If this function exists, it will be called to create a new TCP socket on connection. +--- @field private _sock table the TCP socket for the connection +--- @field private _needs_more boolean flag to track whether or not we're still expecting mroe on this source before we dispatch +--- @field private _last_field string the last field the parsing path saw, in case it needs to append more to its value +--- @field private _extra_headers table a table of string:string key-value pairs that will be inserted in to the initial requests's headers. +--- @field private _parse_buffers table inner state, keeps track of the various event stream buffers in between dispatches. +--- @field private _listeners table event listeners attached using the add_event_listener API instead of the inline callbacks. +local EventSource = {} +EventSource.__index = EventSource + +--- The Ready States that an EventSource can be in. We use base 0 to match the specification. +EventSource.ReadyStates = util.read_only { + CONNECTING = 0, -- The connection has not yet been established + OPEN = 1, -- The connection is open + CLOSED = 2 -- The connection has closed +} + +--- The event types supported by this source, patterned after their values in JavaScript. +EventSource.EventTypes = util.read_only { + ON_OPEN = "open", + ON_MESSAGE = "message", + ON_ERROR = "error", +} + +--- Helper function that creates the initial Request to start the stream. +--- @function create_request +--- @local +--- @param url_table table a net.url table +--- @param extra_headers table a set of key/value pairs (strings) to capture any extra HTTP headers needed. +local function create_request(url_table, extra_headers) + local request = Request.new("GET", url_table.path, nil) + :add_header("user-agent", "smartthings-lua-edge-driver") + :add_header("host", string.format("%s", url_table.host)) + :add_header("connection", "keep-alive") + :add_header("accept", "text/event-stream") + + if type(extra_headers) == "table" then + for k, v in pairs(extra_headers) do + request = request:add_header(k, v) + end + end + + return request +end + +--- Helper function to send the request and kick off the stream. +--- @function send_stream_start_request +--- @local +--- @param payload string the entire string buffer to send +--- @param sock table the TCP socket to send it over +local function send_stream_start_request(payload, sock) + local bytes, err, idx = nil, nil, 0 + + repeat + bytes, err, idx = sock:send(payload, idx + 1, #payload) + until (bytes == #payload) or (err ~= nil) + + if err then + log.error("send error: " .. err) + end + + return bytes, err, idx +end + +--- Helper function to create an table representing an event from the source's parse buffers. +--- @function make_event +--- @local +--- @param source EventSource +local function make_event(source) + local event_type = nil + + if #source._parse_buffers["event"] > 0 then + event_type = source._parse_buffers["event"] + end + + return { + type = event_type or "message", + data = source._parse_buffers["data"], + origin = source.url.scheme .. "://" .. source.url.host, + lastEventId = source._parse_buffers["id"] + } +end + +--- SSE spec for dispatching an event: +--- https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage +--- @function dispatch_event +--- @local +--- @param source EventSource +local function dispatch_event(source) + local data_buffer = source._parse_buffers["data"] + local is_blank_line = data_buffer ~= nil and + (#data_buffer == 0) or + data_buffer == "\n" or + data_buffer == "\r" or + data_buffer == "\r\n" + if data_buffer ~= nil and not is_blank_line then + local event = util.read_only(make_event(source)) + + if type(source.onmessage) == "function" then + source.onmessage(event) + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_MESSAGE]) do + if type(listener) == "function" then + listener(event) + end + end + end + + source._parse_buffers["event"] = "" + source._parse_buffers["data"] = "" +end + +local valid_fields = util.read_only { + ["event"] = true, + ["data"] = true, + ["id"] = true, + ["retry"] = true +} + +-- An event stream "line" can end in more than one way; from the spec: +-- Lines must be separated by either +-- a U+000D CARRIAGE RETURN U+000A LINE FEED (CRLF) character pair, +-- a single U+000A LINE FEED (LF) character, +-- or a single U+000D CARRIAGE RETURN (CR) character. +-- +-- util.iter_string_lines won't suffice here because: +-- a.) it assumes \n, and +-- b.) it doesn't differentiate between a "line" that ends without a newline and one that does. +-- +-- h/t to github.com/FreeMasen for the suggestions on the efficient implementation of this +local function find_line_endings(chunk) + local r_idx, n_idx = string.find(chunk, "[\r\n]+") + if r_idx == n_idx then + -- 1 character + return r_idx, n_idx + end + local slice = string.sub(chunk, r_idx, n_idx) + if slice == "\r\n" then + return r_idx, n_idx + end + -- invalid multi character match, return first character only + return r_idx, r_idx +end + +local function event_lines(chunk) + local remaining = chunk + local line_end, rn_end + local remainder_sent = false + return function() + line_end, rn_end = find_line_endings(remaining) + if not line_end then + if remainder_sent or (not remaining) or #remaining == 0 then + return nil + else + remainder_sent = true + return remaining, false + end + end + local next_line = string.sub(remaining, 1, line_end - 1) + remaining = string.sub(remaining, rn_end + 1) + return next_line, true + end +end +--- SSE spec for interpreting an event stream: +--- https://html.spec.whatwg.org/multipage/server-sent-events.html#the-eventsource-interface +--- @function parse +--- @local +--- @param source EventSource +--- @param recv string the received payload from the last socket receive +local function sse_parse_chunk(source, recv) + for line, complete in event_lines(recv) do + if not source._needs_more and (#line == 0 or (not line:match("([%w%p]+)"))) then -- empty/blank lines indicate dispatch + dispatch_event(source) + elseif source._needs_more then + local append = line + if source._last_field == "data" and complete then append = append .. "\n" end + if complete then source._needs_more = false end + source._parse_buffers[source._last_field] = source._parse_buffers[source._last_field] .. append + else + if line:sub(1, 1) ~= ":" then -- ignore any complete lines that start w/ a colon + local matches = line:gmatch("(%w*)(:*)(.*)") -- colon after field is optional, in that case it's a field w/ no value + + for field, _colon, value in matches do + value = value:gsub("^[^%g]", "", 1) -- trim a single leading space character + + if valid_fields[field] then + source._last_field = field + if field == "retry" then + local new_time = tonumber(value, 10) + if type(new_time) == "number" then + source._reconnect_time_millis = new_time + end + elseif field == "data" then + local append = (value or "") + if complete then append = append .. "\n" end + source._parse_buffers[field] = source._parse_buffers[field] .. append + elseif field == "id" then + -- skip ID's if they contain the NULL character + if not string.find(value, '\0') then + source._parse_buffers[field] = value + end + else + source._parse_buffers[field] = value + end + end + source._needs_more = source._needs_more or (not complete) + end + end + end + end +end + +--- Helper function that captures the cyclic logic of the EventSource while in the CONNECTING state. +--- @function connecting_action +--- @local +--- @param source EventSource +local function connecting_action(source) + if not source._sock then + if type(source._sock_builder) == "function" then + source._sock = source._sock_builder() + else + source._sock, err = socket.tcp() + if err ~= nil then return nil, err end + + _, err = source._sock:settimeout(60) + if err ~= nil then return nil, err end + + _, err = source._sock:connect(source.url.host, source.url.port) + if err ~= nil then return nil, err end + + _, err = source._sock:setoption("keepalive", true) + if err ~= nil then return nil, err end + + if source.url.scheme == "https" then + source._sock, err = ssl.wrap(source._sock, { + mode = "client", + protocol = "any", + verify = "none", + options = "all" + }) + if err ~= nil then return nil, err end + + _, err = source._sock:dohandshake() + if err ~= nil then return nil, err end + end + end + end + + local request = create_request(source.url, source._extra_headers) + + local last_event_id = source._parse_buffers["id"] + + if last_event_id ~= nil and #last_event_id > 0 then + request = request:add_header("Last-Event-ID", last_event_id) + end + + local _, err, _ = send_stream_start_request(request:serialize(), source._sock) + + if err ~= nil then + return nil, err + end + + local response + response, err = Response.tcp_source(source._sock) + + if err ~= nil then + return nil, err + end + + if response.status ~= 200 then + return nil, "Server responded with status other than 200 OK", { response.status, response.status_msg } + end + + local headers, err = response:get_headers() + if err ~= nil then + return nil, err + end + local content_type = string.lower((headers:get_one('content-type') or "none")) + if not content_type:find("text/event-stream", 1, true) then + local err_msg = "Expected content type of text/event-stream in response headers, received: " .. content_type + return nil, err_msg + end + + source.ready_state = EventSource.ReadyStates.OPEN + + if type(source.onopen) == "function" then + source.onopen() + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_OPEN]) do + if type(listener) == "function" then + listener() + end + end +end +--- Helper function that captures the cyclic logic of the EventSource while in the OPEN state. +--- @function open_action +--- @local +--- @param source EventSource +local function open_action(source) + local recv, err, partial = source._sock:receive('*l') + + if err then + --- connection is fine but there was nothing + --- to be read from the other end so we just + --- early return. + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + + -- the number of bytes to read per the chunked encoding spec + local recv_as_num = tonumber(recv, 16) + + if recv_as_num ~= nil then + recv, err, partial = source._sock:receive(recv_as_num) + if err then + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + local _, err, partial = source._sock:receive('*l') -- clear the final line + + if err then + if err == "timeout" or err == "wantread" then + return + else + --- real error, close the connection. + source._sock:close() + source._sock = nil + source.ready_state = EventSource.ReadyStates.CLOSED + return nil, err, partial + end + end + sse_parse_chunk(source, recv) + else + local recv_dbg = recv or "" + if #recv_dbg == 0 then recv_dbg = "" end + recv_dbg = recv_dbg:gsub("\r\n", ""):gsub("\n", ""):gsub("\r", "") + log.error(string.format("Received %s while expecting a chunked encoding payload length (hex number)\n", recv_dbg)) + end +end + +--- Helper function that captures the cyclic logic of the EventSource while in the CLOSED state. +--- @function closed_action +--- @local +--- @param source EventSource +local function closed_action(source) + if source._sock ~= nil then + source._sock:close() + source._sock = nil + end + + if source._reconnect then + if type(source.onerror) == "function" then + source.onerror() + end + + for _, listener in ipairs(source._listeners[EventSource.EventTypes.ON_ERROR]) do + if type(listener) == "function" then + listener() + end + end + + local sleep_time_secs = source._reconnect_time_millis / 1000.0 + socket.sleep(sleep_time_secs) + + source.ready_state = EventSource.ReadyStates.CONNECTING + end +end + +local state_actions = { + [EventSource.ReadyStates.CONNECTING] = connecting_action, + [EventSource.ReadyStates.OPEN] = open_action, + [EventSource.ReadyStates.CLOSED] = closed_action +} + +--- Create a new EventSource. The only required parameter is the URL, which can +--- be a string or a net.url table. The string form will be converted to a net.url table. +--- +--- @param url string|table a string or a net.url table representing the complete URL (minimally a scheme/host/path, port optional) for the event stream. +--- @param extra_headers table|nil an optional table of key-value pairs (strings) to be added to the initial GET request +--- @param sock_builder function|nil an optional function to be used to create the TCP socket for the stream. If nil, a set of defaults will be used to create a new TCP socket. +--- @return EventSource a new EventSource +function EventSource.new(url, extra_headers, sock_builder) + local url_table = util.force_url_table(url) + + if not url_table.port then + if url_table.scheme == "http" then + url_table.port = 80 + elseif url_table.scheme == "https" then + url_table.port = 443 + end + end + + local sock = nil + + if type(sock_builder) == "function" then + sock = sock_builder() + end + + local source = setmetatable({ + url = url_table, + ready_state = EventSource.ReadyStates.CONNECTING, + onopen = nil, + onmessage = nil, + onerror = nil, + _needs_more = false, + _last_field = nil, + _reconnect = true, + _reconnect_time_millis = 1000, + _sock_builder = sock_builder, + _sock = sock, + _extra_headers = extra_headers, + _parse_buffers = { + ["data"] = "", + ["id"] = "", + ["event"] = "", + }, + _listeners = { + [EventSource.EventTypes.ON_OPEN] = {}, + [EventSource.EventTypes.ON_MESSAGE] = {}, + [EventSource.EventTypes.ON_ERROR] = {} + }, + }, EventSource) + + cosock.spawn(function() + local st_utils = require "st.utils" + while true do + if source.ready_state == EventSource.ReadyStates.CLOSED and + not source._reconnect + then + return + end + local _, action_err, partial = state_actions[source.ready_state](source) + if action_err ~= nil then + if action_err ~= "timeout" or action_err ~= "wantread" then + log.error("Event Source Coroutine State Machine error: " .. action_err) + if partial ~= nil and #partial > 0 then + log.error(st_utils.stringify_table(partial, "\tReceived Partial", true)) + end + source.ready_state = EventSource.ReadyStates.CLOSED + end + end + end + end) + + return source +end + +--- Close the event source, signalling that a reconnect is not desired +function EventSource:close() + self._reconnect = false + if self._sock ~= nil then + self._sock:close() + end + self._sock = nil + self.ready_state = EventSource.ReadyStates.CLOSED +end + +--- Add a callback to the event source +---@param listener_type string One of "message", "open", or "error" +---@param listener function the callback to be called in case of an event. Open and Error events have no payload. The message event will have a single argument, a table. +function EventSource:add_event_listener(listener_type, listener) + local list = self._listeners[listener_type] + + if list then + table.insert(list, listener) + end +end + +return EventSource diff --git a/drivers/SmartThings/jbl/src/lunchbox/util.lua b/drivers/SmartThings/jbl/src/lunchbox/util.lua new file mode 100644 index 0000000000..54a421047f --- /dev/null +++ b/drivers/SmartThings/jbl/src/lunchbox/util.lua @@ -0,0 +1,46 @@ +local net_url = require "net.url" + +local util = {} + +util.force_url_table = function(url) + if type(url) ~= "table" then url = net_url.parse(url) end + + if not url.port then + if url.scheme == "http" then + url.port = 80 + elseif url.scheme == "https" then + url.port = 443 + end + end + + return url +end + +util.read_only = function(tbl) + if type(tbl) == "table" then + local proxy = {} + local mt = { -- create metatable + __index = tbl, + __newindex = function(t, k, v) error("attempt to update a read-only table", 2) end, + } + setmetatable(proxy, mt) + return proxy + else + return tbl + end +end + +util.iter_string_lines = function(str) + if str:sub(-1) ~= "\n" then str = str .. "\n" end + + return str:gmatch("(.-)\n") +end + +util.copy_data = function(tbl) + local ret = {} + for k, v in pairs(tbl) do ret[k] = v end + + return ret +end + +return util diff --git a/drivers/SmartThings/jbl/src/selfSignedRoot.crt b/drivers/SmartThings/jbl/src/selfSignedRoot.crt new file mode 100644 index 0000000000..329477095e --- /dev/null +++ b/drivers/SmartThings/jbl/src/selfSignedRoot.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoMCFCWrw9rEB3QpBgEh9WNAlaCHkEyDMA0GCSqGSIb3DQEBCwUAMIGJ +MQswCQYDVQQGEwJVUzEUMBIGA1UECAwLQ29ubmVjdGljdXQxETAPBgNVBAcMCFN0 +YW1mb3JkMR0wGwYDVQQKDBRIQVJNQU4gSW50ZXJuYXRpb25hbDETMBEGA1UECwwK +VGVjaG5vbG9neTEdMBsGA1UEAwwUKi5kZXZpY2VzLmhhcm1hbi5jb20wHhcNMjMw +NzA1MDYwODI1WhcNNDMwNjMwMDYwODI1WjCBiTELMAkGA1UEBhMCVVMxFDASBgNV +BAgMC0Nvbm5lY3RpY3V0MREwDwYDVQQHDAhTdGFtZm9yZDEdMBsGA1UECgwUSEFS +TUFOIEludGVybmF0aW9uYWwxEzARBgNVBAsMClRlY2hub2xvZ3kxHTAbBgNVBAMM +FCouZGV2aWNlcy5oYXJtYW4uY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAtXpf1YCY0GHeyxlDpxS4RACruP6bDXAcQ7CmY68GruCkQBQktPU8J0CN +SjdfMoEnXPC/EVg89/0BoiTaGVrnIkuJPkmbBWkluKpINhQ1IoKbr33PnJoTIG/7 +er5NHWq55dvY+FQIrZ8H4wZwtlH26ceKif2/4mHJtH5nl71AqXoGcT8W8B3nPIrP +0nvuC/FtlIyb2TlvD+lya+T0Uocu5mabl2TOaNjisndm4yke8YFaMQAj7bQEv66W +7uvacllzezl/Je3cVAfL3xM0D4GhtD6CD9TKsZulRTqy6YpPKptrBrGotxOzX3dK +5UwveZ7/xdCclEaQlfNTzznDjXLa+QIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBW +GrdlIpLkb3cafaIsDtrALG8bHZR1eKXNOz56pJwYeoQl80yTFj02ehCfupXEbr7A +OQ3hdDSuV/ovV+xmbT29XLC4gvfYLasmrPfoxTT2w7tfNOkxkKV/E4VvQUXL/+Vk +uSifYLdIM89QGjirocGQ2wKlOq3LUfQTp3IZ2C9a1sn21aKj1no3/xKiazrAT7JQ +pcVPvSyyXZzyVQW8XSVThiHH9EwRU6/3hxguYEnL+Wkg9wBmav/qmmLNEBQSQoSi +Fg6nk2z+7ohsTB74URXhjMxloUG0k/UQ4gvH/cDFTthJMT+A+aLqVg4xiEb6M9iz +b9g1LEECKNF5dCPb7ICb +-----END CERTIFICATE----- diff --git a/drivers/SmartThings/jbl/src/utils.lua b/drivers/SmartThings/jbl/src/utils.lua new file mode 100644 index 0000000000..c67816359c --- /dev/null +++ b/drivers/SmartThings/jbl/src/utils.lua @@ -0,0 +1,178 @@ +local log = require "log" +---@module 'utils' +local utils = {} + +function utils.str_starts_with(str, start) + return str:sub(1, #start) == start +end + +function utils.is_nan(number) + -- IEEE 754 dictates that NaN compares falsey to everything, including itself. + if number ~= number then + return true + end + + -- If someone passes in something that isn't a Number type, it'll pass the above check. + -- Philosophical question: Something that isn't a number can't technicaly have the value + -- of "nan" but "nan" stands for "not a number", so what do we do here? + if type(number) ~= "number" then + log.warn(string.format("utils.is_nan received value of type %s as argument, returning true", type(number))) + return true + end + + -- In the event that something goes wrong with the above two things, + -- we simply compare the tostring against a known NaN value. + return tostring(number) == tostring(0 / 0) +end + +-- build a exponential backoff time value generator +-- +-- max: the maximum wait interval (not including `rand factor`) +-- inc: the rate at which to exponentially back off +-- rand: a randomization range of (-rand, rand) to be added to each interval +function utils.backoff_builder(max, inc, rand) + local count = 0 + inc = inc or 1 + return function() + local randval = 0 + if rand then + randval = math.random() * rand * 2 - rand + end + + local base = inc * (2 ^ count - 1) + count = count + 1 + + -- ensure base backoff (not including random factor) is less than max + if max then base = math.min(base, max) end + + -- ensure total backoff is >= 0 + return math.max(base + randval, 0) + end +end + +function utils.labeled_socket_builder(label, ssl_config) + local log = require "log" + local socket = require "cosock.socket" + local ssl = require "cosock.ssl" + + label = (label or "") + if #label > 0 then + label = label .. " " + end + + if not ssl_config then + ssl_config = { mode = "client", protocol = "any", verify = "none", options = "all" } + end + + local function make_socket(host, port, wrap_ssl) + log.info( + string.format( + "%sCreating TCP socket for REST Connection", label + ) + ) + local _ = nil + local sock, err = socket.tcp() + + if err ~= nil or (not sock) then + return nil, (err or "unknown error creating TCP socket") + end + + log.info( + string.format( + "%sSetting TCP socket timeout for REST Connection", label + ) + ) + _, err = sock:settimeout(60) + if err ~= nil then + return nil, "settimeout error: " .. err + end + + log.info( + string.format( + "%sConnecting TCP socket for REST Connection", label + ) + ) + _, err = sock:connect(host, port) + if err ~= nil then + return nil, "Connect error: " .. err + end + + log.info( + string.format( + "%sSet Keepalive for TCP socket for REST Connection", label + ) + ) + _, err = sock:setoption("keepalive", true) + if err ~= nil then + return nil, "Setoption error: " .. err + end + + if wrap_ssl then + log.info( + string.format( + "%sCreating SSL wrapper for for REST Connection", label + ) + ) + sock, err = + ssl.wrap(sock, ssl_config) + if err ~= nil then + return nil, "SSL wrap error: " .. err + end + log.info( + string.format( + "%sPerforming SSL handshake for for REST Connection", label + ) + ) + _, err = sock:dohandshake() + if err ~= nil then + return nil, "Error with SSL handshake: " .. err + end + end + + log.info( + string.format( + "%sSuccessfully created TCP connection", label + ) + ) + return sock, err + end + return make_socket +end + +--- From https://gist.github.com/sapphyrus/fd9aeb871e3ce966cc4b0b969f62f539 +--- MIT licensed +function utils.deep_table_eq(tbl1, tbl2) + if tbl1 == tbl2 then + return true + elseif type(tbl1) == "table" and type(tbl2) == "table" then + for key1, value1 in pairs(tbl1) do + local value2 = tbl2[key1] + + if value2 == nil then + -- avoid the type call for missing keys in tbl2 by directly comparing with nil + return false + elseif value1 ~= value2 then + if type(value1) == "table" and type(value2) == "table" then + if not utils.deep_table_eq(value1, value2) then + return false + end + else + return false + end + end + end + + -- check for missing keys in tbl1 + for key2, _ in pairs(tbl2) do + if tbl1[key2] == nil then + return false + end + end + + return true + end + + return false +end + +return utils