Skip to content

Commit

Permalink
Update JBL lunchbox, to utilize new luncheon library (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjswedes authored Mar 5, 2024
1 parent e8a72f4 commit c3ad669
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 85 deletions.
114 changes: 48 additions & 66 deletions drivers/SmartThings/jbl/src/lunchbox/rest.lua
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
---@class ChunkedResponse : Response
---@field package _received_body boolean
---@field package _parsed_headers boolean
---@field public new fun(status_code: number, socket: table?): ChunkedResponse
---@field public fill_body fun(self: ChunkedResponse): string?
---@field public append_body fun(self: ChunkedResponse, next_chunk_body: string): ChunkedResponse

local socket = require "cosock.socket"
local utils = require "utils"

local utils = require "utils"
local lb_utils = require "lunchbox.util"
local Request = require "luncheon.request"
local Response = require "luncheon.response"
local Response = require "luncheon.response" --[[@as ChunkedResponse]]

local api_version = require("version").api

local RestCallStates = {
SEND = "Send",
Expand All @@ -14,16 +23,14 @@ local RestCallStates = {
}

local function connect(client)
local default_port = 80
local port = 80
local use_ssl = false

if client.base_url.scheme == "https" then
default_port = 443
port = 443
use_ssl = true
end

local port = client.base_url.port or default_port

local sock, err = client.socket_builder(client.base_url.host, port, use_ssl)

if sock == nil then
Expand All @@ -43,16 +50,22 @@ local function reconnect(client)
return connect(client)
end

---comment
---@param client RestClient
---@param request Request
---@return integer? bytes_sent
---@return string? err_msg
---@return integer idx
local function send_request(client, request)
if client.socket == nil then
return nil, "no socket available"
return nil, "no socket available", 0
end
local payload = request:serialize()

local bytes, err, idx = nil, nil, 0

repeat bytes, err, idx = client.socket:send(payload, idx + 1, #payload) until (bytes == #payload)
or (err ~= nil)
or (err ~= nil)

return bytes, err, idx
end
Expand All @@ -63,7 +76,7 @@ local function parse_chunked_response(original_response, sock)
EXPECTING_BODY_CHUNK = "ExpectingBodyChunk",
}

local full_response = Response.new(original_response.status, nil)
local full_response = Response.new(original_response.status, nil) --[[@as ChunkedResponse]]

for header in original_response.headers:iter() do full_response.headers:append_chunk(header) end

Expand Down Expand Up @@ -130,39 +143,12 @@ local function parse_chunked_response(original_response, sock)
return full_response
end

local function recv_additional_response(original_response, sock)
local full_response = Response.new(original_response.status, nil)
local headers = original_response:get_headers()
local content_length_str = headers:get_one("Content-Length")
local content_length = nil
local bytes_read = 0
if content_length_str then
content_length = math.tointeger(content_length_str)
end

local next_recv, next_err, partial

repeat
next_recv, next_err, partial = sock:receive(content_length - bytes_read)

if next_recv ~= nil and #next_recv >= 1 then
full_response:append_body(next_recv)
bytes_read = bytes_read + #next_recv
end

if partial ~= nil and #partial >= 1 then
full_response:append_body(partial)
bytes_read = bytes_read + #partial
end
until next_err == "closed" or bytes_read >= content_length

full_response._received_body = true
full_response._parsed_headers = true

return full_response
end

local function handle_response(sock)
if api_version >= 9 then
local response, err = Response.tcp_source(sock)
if err or (not response) then return response, (err or "unknown error") end
return response, response:fill_body()
end
-- called select right before passing in so we receive immediately
local initial_recv, initial_err, partial = Response.source(function() return sock:receive('*l') end)

Expand All @@ -171,9 +157,7 @@ local function handle_response(sock)
if initial_recv ~= nil then
local headers = initial_recv:get_headers()

if headers:get_one("Content-Length") then
full_response = recv_additional_response(initial_recv, sock)
elseif headers:get_one("Transfer-Encoding") == "chunked" then
if headers and headers:get_one("Transfer-Encoding") == "chunked" then
local response, err = parse_chunked_response(initial_recv, sock)
if err ~= nil then
return nil, err
Expand All @@ -191,12 +175,12 @@ end

local function execute_request(client, request, retry_fn)
if not client._active then
return nil, "Called `execute request` on a terminated REST Client"
return nil, "Called `execute request` on a terminated REST Client", nil
end

if client.socket == nil then
local success, err = connect(client)
if not success then return nil, err end
if not success then return nil, err, nil end
end

local should_retry = retry_fn
Expand All @@ -208,7 +192,7 @@ local function execute_request(client, request, retry_fn)
-- send output
local _bytes_sent, send_err, _idx = nil, nil, 0
-- recv output
local response, recv_err, _partial = nil, nil, nil
local response, recv_err, partial = nil, nil, nil
-- return values
local ret, err = nil, nil

Expand All @@ -235,7 +219,7 @@ local function execute_request(client, request, retry_fn)
current_state = RestCallStates.COMPLETE
end
elseif current_state == RestCallStates.RECEIVE then
response, recv_err, _partial = handle_response(client.socket)
response, recv_err, partial = handle_response(client.socket)

if not recv_err then
ret = response
Expand Down Expand Up @@ -271,7 +255,7 @@ local function execute_request(client, request, retry_fn)
end
until current_state == RestCallStates.COMPLETE

return ret, err
return ret, err, partial
end

---@class RestClient
Expand All @@ -283,19 +267,17 @@ RestClient.__index = RestClient

function RestClient.one_shot_get(full_url, additional_headers, socket_builder)
local url_table = lb_utils.force_url_table(full_url)
local client = RestClient.new(url_table.scheme .. "://" .. url_table.authority, socket_builder)
local client = RestClient.new(url_table.scheme .. "://" .. url_table.host, socket_builder)
local ret, err = client:get(url_table.path, additional_headers)
client:shutdown()
client = nil
return ret, err
end

function RestClient.one_shot_post(full_url, body, additional_headers, socket_builder)
local url_table = lb_utils.force_url_table(full_url)
local client = RestClient.new(url_table.scheme .. "://" .. url_table.authority, socket_builder)
local client = RestClient.new(url_table.scheme .. "://" .. url_table.host, socket_builder)
local ret, err = client:post(url_table.path, body, additional_headers)
client:shutdown()
client = nil
return ret, err
end

Expand All @@ -322,10 +304,10 @@ end

function RestClient:get(path, additional_headers, retry_fn)
local request = Request.new("GET", path, nil):add_header(
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)

if additional_headers ~= nil and type(additional_headers) == "table" then
for k, v in pairs(additional_headers) do request = request:add_header(k, v) end
Expand All @@ -336,10 +318,10 @@ end

function RestClient:post(path, body_string, additional_headers, retry_fn)
local request = Request.new("POST", path, nil):add_header(
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)

if additional_headers ~= nil and type(additional_headers) == "table" then
for k, v in pairs(additional_headers) do request = request:add_header(k, v) end
Expand All @@ -352,10 +334,10 @@ end

function RestClient:put(path, body_string, additional_headers, retry_fn)
local request = Request.new("PUT", path, nil):add_header(
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)
"user-agent", "smartthings-lua-edge-driver"
):add_header("host", string.format("%s", self.base_url.host)):add_header(
"connection", "keep-alive"
)

if additional_headers ~= nil and type(additional_headers) == "table" then
for k, v in pairs(additional_headers) do request = request:add_header(k, v) end
Expand All @@ -372,7 +354,7 @@ function RestClient.new(base_url, sock_builder)
if type(sock_builder) ~= "function" then sock_builder = utils.labeled_socket_builder() end

return
setmetatable({base_url = base_url, socket_builder = sock_builder, socket = nil, _active = true}, RestClient)
setmetatable({base_url = base_url, socket_builder = sock_builder, socket = nil, _active = true}, RestClient)
end

return RestClient
39 changes: 20 additions & 19 deletions drivers/SmartThings/jbl/src/lunchbox/sse/eventsource.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ local Response = require "luncheon.response"
--- @field public onopen function in-line callback for on-open events
--- @field public onmessage function in-line callback for on-message events
--- @field public onerror function in-line callback for on-error events; error callbacks will fire
--- @field private _reconnect boolean flag that says whether or not the client should attempt to reconnect on close.
--- @field private _reconnect_time_millis number The amount of time to wait between reconnects, in millis. Can be sent by the server.
--- @field private _sock_builder function|nil optional. If this function exists, it will be called to create a new TCP socket on connection.
--- @field private _sock table the TCP socket for the connection
--- @field private _needs_more boolean flag to track whether or not we're still expecting mroe on this source before we dispatch
--- @field private _last_field string the last field the parsing path saw, in case it needs to append more to its value
--- @field private _extra_headers table a table of string:string key-value pairs that will be inserted in to the initial requests's headers.
--- @field private _parse_buffers table inner state, keeps track of the various event stream buffers in between dispatches.
--- @field private _listeners table event listeners attached using the add_event_listener API instead of the inline callbacks.
--- @field package _reconnect boolean flag that says whether or not the client should attempt to reconnect on close.
--- @field package _reconnect_time_millis number The amount of time to wait between reconnects, in millis. Can be sent by the server.
--- @field package _sock_builder function|nil optional. If this function exists, it will be called to create a new TCP socket on connection.
--- @field package _sock table? the TCP socket for the connection
--- @field package _needs_more boolean flag to track whether or not we're still expecting mroe on this source before we dispatch
--- @field package _last_field string the last field the parsing path saw, in case it needs to append more to its value
--- @field package _extra_headers table a table of string:string key-value pairs that will be inserted in to the initial requests's headers.
--- @field package _parse_buffers table inner state, keeps track of the various event stream buffers in between dispatches.
--- @field package _listeners table event listeners attached using the add_event_listener API instead of the inline callbacks.
local EventSource = {}
EventSource.__index = EventSource

Expand Down Expand Up @@ -81,7 +81,7 @@ local function send_stream_start_request(payload, sock)
until (bytes == #payload) or (err ~= nil)

if err then
log.error("send error: " .. err)
log.error_with({ hub_logs = true }, "send error: " .. err)
end

return bytes, err, idx
Expand Down Expand Up @@ -156,8 +156,8 @@ local valid_fields = util.read_only {
-- h/t to github.com/FreeMasen for the suggestions on the efficient implementation of this
local function find_line_endings(chunk)
local r_idx, n_idx = string.find(chunk, "[\r\n]+")
if r_idx == n_idx then
-- 1 character
if r_idx == nil or r_idx == n_idx then
-- 1 character or no match
return r_idx, n_idx
end
local slice = string.sub(chunk, r_idx, n_idx)
Expand Down Expand Up @@ -289,8 +289,8 @@ local function connecting_action(source)
local response
response, err = Response.tcp_source(source._sock)

if err ~= nil then
return nil, err
if not response or err ~= nil then
return nil, err or "nil response from Response.tcp_source"
end

if response.status ~= 200 then
Expand All @@ -301,7 +301,7 @@ local function connecting_action(source)
if err ~= nil then
return nil, err
end
local content_type = string.lower((headers:get_one('content-type') or "none"))
local content_type = string.lower((headers and headers:get_one('content-type') or "none"))
if not content_type:find("text/event-stream", 1, true) then
local err_msg = "Expected content type of text/event-stream in response headers, received: " .. content_type
return nil, err_msg
Expand Down Expand Up @@ -375,7 +375,8 @@ local function open_action(source)
local recv_dbg = recv or "<NIL>"
if #recv_dbg == 0 then recv_dbg = "<EMPTY>" end
recv_dbg = recv_dbg:gsub("\r\n", "<CRLF>"):gsub("\n", "<LF>"):gsub("\r", "<CR>")
log.error(string.format("Received %s while expecting a chunked encoding payload length (hex number)\n", recv_dbg))
log.error_with({ hub_logs = true },
string.format("Received %s while expecting a chunked encoding payload length (hex number)\n", recv_dbg))
end
end

Expand Down Expand Up @@ -466,16 +467,16 @@ function EventSource.new(url, extra_headers, sock_builder)
local st_utils = require "st.utils"
while true do
if source.ready_state == EventSource.ReadyStates.CLOSED and
not source._reconnect
not source._reconnect
then
return
end
local _, action_err, partial = state_actions[source.ready_state](source)
if action_err ~= nil then
if action_err ~= "timeout" or action_err ~= "wantread" then
log.error("Event Source Coroutine State Machine error: " .. action_err)
log.error_with({ hub_logs = true }, "Event Source Coroutine State Machine error: " .. action_err)
if partial ~= nil and #partial > 0 then
log.error(st_utils.stringify_table(partial, "\tReceived Partial", true))
log.error_with({ hub_logs = true }, st_utils.stringify_table(partial, "\tReceived Partial", true))
end
source.ready_state = EventSource.ReadyStates.CLOSED
end
Expand Down

0 comments on commit c3ad669

Please sign in to comment.