diff --git a/http/server.lua b/http/server.lua index ea014f2..6e33572 100644 --- a/http/server.lua +++ b/http/server.lua @@ -16,13 +16,25 @@ local errno = require 'errno' local DETACHED = 101 local function errorf(fmt, ...) - error(string.format(fmt, ...)) + error(string.format(fmt, ...), 3) end local function sprintf(fmt, ...) return string.format(fmt, ...) end +local function is_callable(obj) + local t_obj = type(obj) + if t_obj == 'function' then + return true + end + if t_obj == 'table' then + local mt = getmetatable(obj) + return (type(mt) == 'table' and type(mt.__call) == 'function') + end + return false +end + local function uri_escape(str) local res = {} if type(str) == 'table' then @@ -627,220 +639,6 @@ local function normalize_headers(hdrs) return res end -local function parse_request(req) - local p = lib._parse_request(req) - if p.error then - return p - end - p.path = uri_unescape(p.path) - if p.path:sub(1, 1) ~= "/" or p.path:find("./", nil, true) ~= nil then - p.error = "invalid uri" - return p - end - return p -end - -local function process_client(self, s, peer) - while true do - local hdrs = '' - - local is_eof = false - while true do - local chunk = s:read{ - delimiter = { "\n\n", "\r\n\r\n" } - } - - if chunk == '' then - is_eof = true - break -- eof - elseif chunk == nil then - log.error('failed to read request: %s', errno.strerror()) - return - end - - hdrs = hdrs .. chunk - - if string.endswith(hdrs, "\n\n") or string.endswith(hdrs, "\r\n\r\n") then - break - end - end - - if is_eof then - break - end - - log.debug("request:\n%s", hdrs) - local p = parse_request(hdrs) - if p.error ~= nil then - log.error('failed to parse request: %s', p.error) - s:write(sprintf("HTTP/1.0 400 Bad request\r\n\r\n%s", p.error)) - break - end - p.httpd = self - p.s = s - p.peer = peer - setmetatable(p, request_mt) - - if p.headers['expect'] == '100-continue' then - s:write('HTTP/1.0 100 Continue\r\n\r\n') - end - - local logreq = self.options.log_requests and log.info or log.debug - logreq("%s %s%s", p.method, p.path, - p.query ~= "" and "?"..p.query or "") - - local res, reason = pcall(self.options.handler, self, p) - p:read() -- skip remaining bytes of request body - local status, hdrs, body - - if not res then - status = 500 - hdrs = {} - local trace = debug.traceback() - local logerror = self.options.log_errors and log.error or log.debug - logerror('unhandled error: %s\n%s\nrequest:\n%s', - tostring(reason), trace, tostring(p)) - if self.options.display_errors then - body = - "Unhandled error: " .. tostring(reason) .. "\n" - .. trace .. "\n\n" - .. "\n\nRequest:\n" - .. tostring(p) - else - body = "Internal Error" - end - elseif type(reason) == 'table' then - if reason.status == nil then - status = 200 - elseif type(reason.status) == 'number' then - status = reason.status - else - error('response.status must be a number') - end - if reason.headers == nil then - hdrs = {} - elseif type(reason.headers) == 'table' then - hdrs = normalize_headers(reason.headers) - else - error('response.headers must be a table') - end - body = reason.body - elseif reason == nil then - status = 200 - hdrs = {} - elseif type(reason) == 'number' then - if reason == DETACHED then - break - end - else - error('invalid response') - end - - local gen, param, state - if type(body) == 'string' then - -- Plain string - hdrs['content-length'] = #body - elseif type(body) == 'function' then - -- Generating function - gen = body - hdrs['transfer-encoding'] = 'chunked' - elseif type(body) == 'table' and body.gen then - -- Iterator - gen, param, state = body.gen, body.param, body.state - hdrs['transfer-encoding'] = 'chunked' - elseif body == nil then - -- Empty body - hdrs['content-length'] = 0 - else - body = tostring(body) - hdrs['content-length'] = #body - end - - if hdrs.server == nil then - hdrs.server = sprintf('Tarantool http (tarantool v%s)', _TARANTOOL) - end - - if p.proto[1] ~= 1 then - hdrs.connection = 'close' - elseif p.broken then - hdrs.connection = 'close' - elseif rawget(p, 'body') == nil then - hdrs.connection = 'close' - elseif p.proto[2] == 1 then - if p.headers.connection == nil then - hdrs.connection = 'keep-alive' - elseif string.lower(p.headers.connection) ~= 'keep-alive' then - hdrs.connection = 'close' - else - hdrs.connection = 'keep-alive' - end - elseif p.proto[2] == 0 then - if p.headers.connection == nil then - hdrs.connection = 'close' - elseif string.lower(p.headers.connection) == 'keep-alive' then - hdrs.connection = 'keep-alive' - else - hdrs.connection = 'close' - end - end - - local response = { - "HTTP/1.1 "; - status; - " "; - reason_by_code(status); - "\r\n"; - }; - for k, v in pairs(hdrs) do - if type(v) == 'table' then - for i, sv in pairs(v) do - table.insert(response, sprintf("%s: %s\r\n", ucfirst(k), sv)) - end - else - table.insert(response, sprintf("%s: %s\r\n", ucfirst(k), v)) - end - end - table.insert(response, "\r\n") - - if type(body) == 'string' then - table.insert(response, body) - response = table.concat(response) - if not s:write(response) then - break - end - elseif gen then - response = table.concat(response) - if not s:write(response) then - break - end - response = nil - -- Transfer-Encoding: chunked - for _, part in gen, param, state do - part = tostring(part) - if not s:write(sprintf("%x\r\n%s\r\n", #part, part)) then - break - end - end - if not s:write("0\r\n\r\n") then - break - end - else - response = table.concat(response) - if not s:write(response) then - break - end - end - - if p.proto[1] ~= 1 then - break - end - - if hdrs.connection ~= 'keep-alive' then - break - end - end -end - local function httpd_stop(self) if type(self) ~= 'table' then error("httpd: usage: httpd:stop()") @@ -916,7 +714,7 @@ local function match_route(self, method, route) end local function set_helper(self, name, sub) - if sub == nil or type(sub) == 'function' then + if sub == nil or is_callable(sub) then self.helpers[ name ] = sub return self end @@ -924,7 +722,7 @@ local function set_helper(self, name, sub) end local function set_hook(self, name, sub) - if sub == nil or type(sub) == 'function' then + if sub == nil or is_callable(sub) then self.hooks[ name ] = sub return self end @@ -968,7 +766,7 @@ local function ctx_action(tx) local action = tx.endpoint.action if tx.httpd.options.cache_controllers then if tx.httpd.cache[ ctx ] ~= nil then - if type(tx.httpd.cache[ ctx ][ action ]) ~= 'function' then + if is_callable(tx.httpd.cache[ ctx ][ action ]) then errorf("Controller '%s' doesn't contain function '%s'", ctx, action) end @@ -997,7 +795,7 @@ local function ctx_action(tx) errorf("require '%s' didn't return table", ctx) end - if type(mod[ action ]) ~= 'function' then + if is_callable(mod[ action ]) then errorf("Controller '%s' doesn't contain function '%s'", ctx, action) end @@ -1039,8 +837,8 @@ local function add_route(self, opts, sub) sub = ctx_action - elseif type(sub) ~= 'function' then - errorf("wrong argument: expected function, but received %s", + elseif not is_callable(sub) then + errorf("wrong argument: expected callable, but received %s", type(sub)) end @@ -1132,19 +930,282 @@ local function url_for_httpd(httpd, name, args, query) end end +local function httpd_parse_request(request_raw) + local request_parsed = lib._parse_request(request_raw) + if request_parsed.error then + return nil, request_parsed.error + end + request_parsed.path = uri_unescape(request_parsed.path) + if request_parsed.path:sub(1, 1) ~= "/" or + request_parsed.path:find("./", nil, true) ~= nil then + return nil, "invalid uri" + end + return request_parsed +end + +local function httpd_http11_parse_request(session, request_raw) + local request_parsed, err = httpd_parse_request(request_raw) + if not request_parsed then + return nil, err + end + request_parsed.httpd = session.server + request_parsed.s = session.socket + request_parsed.peer = session.peer + setmetatable(request_parsed, request_mt) + + return request_parsed +end + +local function httpd_http11_handler(session) + local hdrs = '' + + while true do + local chunk = session:read{ delimiter = { "\n\n", "\r\n\r\n" } } + + if chunk == '' then + return false + elseif chunk == nil then + log.error('failed to read request: %s', errno.strerror()) + return false + end + + hdrs = hdrs .. chunk + + if hdrs:endswith("\n\n") or hdrs:endswith("\r\n\r\n") then + break + end + end + + log.debug("request:\n%s", hdrs) + local p, err = httpd_http11_parse_request(session, hdrs) + if not p then + log.error('failed to parse request: %s', err) + session:write(sprintf("HTTP/1.1 400 Bad request\r\n\r\n%s", err)) + return + end + + if p.headers['upgrade'] then + local proto_name = p.headers['upgrade']:lower() + local proto = session.server.upgrades[proto_name] + if not proto then + if not session.server.options.ignore_unknown_upgrade then + session:write('HTTP/1.1 400 Bad Request\r\n\r\n') + return false + end + else + local ok, upgrade_ok = pcall(proto.upgrade, session, p) + if not ok then + log.error("Failed to upgrade to '%s': %s", p.headers['upgrade'], + upgrade_ok) + session:write('HTTP/1.1 500 Internal Error\r\n\r\n') + return false + elseif not upgrade_ok then + -- TODO: should we close connection, or we should retry again + return false + end + + session.ctx.proto = proto.name + session.ctx.handler = proto.handler + return true + end + end + + if p.headers['expect'] == '100-continue' then + session:write('HTTP/1.1 100 Continue\r\n\r\n') + elseif p.headers['expect'] then + session:write('HTTP/1.1 417 Expectation Failed\r\n\r\n') + return false + end + + local logreq = session.server.options.log_requests and log.info or log.debug + logreq("%s %s%s", p.method, p.path, p.query ~= "" and "?"..p.query or "") + + local res, reason = pcall(session.server.options.handler, session.server, p) + p:read() -- skip remaining bytes of request body + local status, hdrs, body + + if not res then + status = 500 + hdrs = {} + local trace = debug.traceback() + local logerror = log.error + if session.server.options.log_errors then + logerror = log.debug + end + logerror('unhandled error: %s\n%s\nrequest:\n%s', + tostring(reason), trace, tostring(p)) + if session.server.options.display_errors then + body = "Unhandled error: %s\n%s \n\n\n\nRequest:\n%s" + body = body:format(tostring(reason), trace, tostring(p)) + else + body = "Internal Error" + end + elseif type(reason) == 'table' then + if reason.status == nil then + status = 200 + elseif type(reason.status) == 'number' then + status = reason.status + else + error('response.status must be a number') + end + if reason.headers == nil then + hdrs = {} + elseif type(reason.headers) == 'table' then + hdrs = normalize_headers(reason.headers) + else + error('response.headers must be a table') + end + body = reason.body + elseif reason == nil then + status = 200 + hdrs = {} + elseif type(reason) == 'number' then + if reason == DETACHED then + return false + end + else + error('invalid response') + end + + local gen, param, state + if type(body) == 'string' then + -- Plain string + hdrs['content-length'] = #body + elseif is_callable(body) then + -- Generating function + gen = body + hdrs['transfer-encoding'] = 'chunked' + elseif type(body) == 'table' and body.gen then + -- Iterator + gen, param, state = body.gen, body.param, body.state + hdrs['transfer-encoding'] = 'chunked' + elseif body == nil then + -- Empty body + hdrs['content-length'] = 0 + else + body = tostring(body) + hdrs['content-length'] = #body + end + + if hdrs.server == nil then + hdrs.server = sprintf('Tarantool http (tarantool v%s)', _TARANTOOL) + end + + if p.proto[1] ~= 1 then + hdrs.connection = 'close' + elseif p.broken then + hdrs.connection = 'close' + elseif rawget(p, 'body') == nil then + hdrs.connection = 'close' + elseif p.proto[2] == 1 then + if p.headers.connection == nil then + hdrs.connection = 'keep-alive' + elseif string.lower(p.headers.connection) ~= 'keep-alive' then + hdrs.connection = 'close' + else + hdrs.connection = 'keep-alive' + end + elseif p.proto[2] == 0 then + if p.headers.connection == nil then + hdrs.connection = 'close' + elseif string.lower(p.headers.connection) == 'keep-alive' then + hdrs.connection = 'keep-alive' + else + hdrs.connection = 'close' + end + end + + local response = { + "HTTP/1.1 "; + status; + " "; + reason_by_code(status); + "\r\n"; + }; + for k, v in pairs(hdrs) do + if type(v) == 'table' then + for i, sv in pairs(v) do + table.insert(response, sprintf("%s: %s\r\n", ucfirst(k), sv)) + end + else + table.insert(response, sprintf("%s: %s\r\n", ucfirst(k), v)) + end + end + table.insert(response, "\r\n") + + if type(body) == 'string' then + table.insert(response, body) + response = table.concat(response) + if not session:write(response) then + return false + end + elseif gen then + response = table.concat(response) + if not session:write(response) then + return false + end + response = nil + -- Transfer-Encoding: chunked + for _, part in gen, param, state do + part = tostring(part) + if not session:write(sprintf("%x\r\n%s\r\n", #part, part)) then + return false + end + end + if not session:write("0\r\n\r\n") then + return false + end + else + response = table.concat(response) + if not session:write(response) then + return false + end + end + + if p.proto[1] ~= 1 then + return false + end + + return (hdrs.connection == 'keep-alive') +end + +local session_methods = { + read = function(self, ...) return self.socket:read(...) end, + write = function(self, ...) return self.socket:write(...) end, +} + +local session_mt = { + __index = session_methods +} + +-- by default we're creating session with HTTP/1.1 support +local function session_new(self, socket, peer) + return setmetatable({ + server = self, + socket = socket, + peer = peer, + ctx = { proto = 'HTTP/1.1', handler = httpd_http11_handler }, + }, session_mt) +end + +local function httpd_tcp_handler(self, sckt, peer) + local session = session_new(self, sckt, peer) + + local rv = true + repeat + rv = session.ctx.handler(session) + until not rv +end + local function httpd_start(self) if type(self) ~= 'table' then error("httpd: usage: httpd:start()") end - local server = socket.tcp_server(self.host, self.port, - { name = 'http', - handler = function(...) - local res = process_client(self, ...) - end}) - if server == nil then - error(sprintf("Can't create tcp_server: %s", errno.strerror())) - end + local server = assertf(socket.tcp_server(self.host, self.port, { + name = 'http', + handler = function(...) httpd_tcp_handler(self, ...) end + }), "Can't create tcp_server: %s", errno.strerror()) rawset(self, 'is_run', true) rawset(self, 'tcp_server', server) @@ -1153,62 +1214,89 @@ local function httpd_start(self) return self end -local exports = { - DETACHED = DETACHED, - - new = function(host, port, options) - if options == nil then - options = {} +local function httpd_register_extension(self, ext_type, opts) + if ext_type:lower() == 'upgrade' then + if not (type(opts) == 'table') then + errorf("Upgrade extension argument should be table") end - if type(options) ~= 'table' then - errorf("options must be table not '%s'", type(options)) + if not (type(opts.name) == 'string') then + errorf("Upgrade extension name should be %s", 'options.name', 'string') + end + if not is_callable(opts.upgrade) then + errorf("Upgrade extension callback should be callable") + end + if not is_callable(opts.handler) then + errorf("Upgrade extension handler should be callable") end - local default = { - max_header_size = 4096, - header_timeout = 100, - handler = handler, - app_dir = '.', - charset = 'utf-8', - cache_templates = true, - cache_controllers = true, - cache_static = true, - log_requests = true, - log_errors = true, - display_errors = true, - } - - local self = { - host = host, - port = port, - is_run = false, - stop = httpd_stop, - start = httpd_start, - options = extend(default, options, true), - - routes = { }, - iroutes = { }, - helpers = { - url_for = url_for_helper, - }, - hooks = { }, - - -- methods - route = add_route, - match = match_route, - helper = set_helper, - hook = set_hook, - url_for = url_for_httpd, - - -- caches - cache = { - tpl = {}, - ctx = {}, - static = {}, - }, - } - return self + self.upgrades[opts.name:lower()] = table.copy(opts) + else + errorf('Unknown extension type: %s', ext_type) end +end + +local httpd_methods = { + stop = httpd_stop, + start = httpd_start, + route = add_route, + match = match_route, + helper = set_helper, + hook = set_hook, + url_for = url_for_httpd, + register_extension = httpd_register_extension, } -return exports +local httpd_mt = { + __index = httpd_methods +} + +local httpd_options_default = { + max_header_size = 4096, + header_timeout = 100, + handler = handler, + app_dir = '.', + charset = 'utf-8', + cache_templates = true, + cache_controllers = true, + cache_static = true, + log_requests = true, + log_errors = true, + display_errors = true, + ignore_unknown_upgrade = true, +} + +local function httpd_new(host, port, options) + options = options or {} + if type(options) ~= 'table' then + errorf("options must be table, not '%s'", opts_tp) + end + + -- populate options table with default values + options = extend(table.copy(httpd_options_default), options, true) + + local self = setmetatable({ + host = host, + port = port, + is_run = false, + options = options, + + routes = { }, + iroutes = { }, + helpers = { url_for = url_for_helper, }, + hooks = { }, + upgrades = { }, + + -- caches + cache = { tpl = {}, ctx = {}, static = {}, }, + }, httpd_mt) + + return self +end + +return { + DETACHED = DETACHED, + new = httpd_new, + parse_headers = httpd_parse_request, + uri_escape = uri_escape, + uri_unescape = uri_unescape, +} diff --git a/test/http.test.lua b/test/http.test.lua index ff78338..58fef74 100755 --- a/test/http.test.lua +++ b/test/http.test.lua @@ -9,8 +9,10 @@ local json = require('json') local yaml = require 'yaml' local urilib = require('uri') -local test = tap.test("http") -test:plan(7) +local socket = require('socket') + +local test = tap.test("http"); test:plan(8) + test:test("split_uri", function(test) test:plan(65) local function check(uri, rhs) @@ -388,4 +390,101 @@ test:test("server requests", function(test) httpd:stop() end) +test:test("upgrade", function(test) + test:plan(4) + + local log = require('log') + + local httpd = cfgserv() + httpd:start() + httpd:register_extension('upgrade', { + name = 'exist-error-upgrade', + upgrade = function() error('error') end, + handler = function() end, + }) + httpd:register_extension('upgrade', { + name = 'exist-fails-upgrade', + upgrade = function(session) + session:write('HTTP/1.1 426 Upgrade Required\r\n\r\n') + return false + end, + handler = function() end, + }) + + local switching_header = 'HTTP/1.1 101 Switching Protocols\r\n' .. + 'Upgrade: exist\r\n' .. + 'Connection: Upgrade\r\n\r\n' + httpd:register_extension('upgrade', { + name = 'exists', + upgrade = function(session) + session:write(switching_header) + return true + end, + handler = function(session) + while true do + local in_data = session:read(24) + if in_data == '' then + return false + end + session:write(in_data) + end + end, + }) + + test:test("upgrade failed, no protocol", function(test) + test:plan(1) + local r = http_client.get('http://127.0.0.1:12345/abc', { + headers = { upgrade = 'non-existent' } + }) + test:is(r.status, 400, 'Error code is 400') + end) + + test:test("upgrade failed, error while upgrade", function(test) + test:plan(1) + local r = http_client.get('http://127.0.0.1:12345/abc', { + headers = { upgrade = 'exist-error-upgrade' } + }) + test:is(r.status, 500, 'Error code is 500') + end) + + test:test("upgrade failed, upgrade return false", function(test) + test:plan(1) + local r = http_client.get('http://127.0.0.1:12345/abc', { + headers = { upgrade = 'exist-fails-upgrade' } + }) + test:is(r.status, 426, 'Error code is 426') + end) + + local ws_get_r = "GET /abc HTTP/1.1\r\nUpgrade:exists\r\n\r\n" + + test:test("upgrade success, simple tcp echo", function(test) + test:plan(3) + local sck = socket.tcp_connect('127.0.0.1', 12345) + sck:write(ws_get_r) + local data = '' + while true do + local tdata = sck:read({ delimiter = { '\n\n', '\r\n\r\n' } }) + if not tdata or tdata == '' or + tdata:endswith('\r\n\r\n') or tdata:endswith('\n\n') then + if tdata then + data = data .. tdata + end + break + end + end + + if not data:endswith('\r\n\r\n') then + test:fail('automatic fail') + else + test:is(#data, #switching_header, 'right http upgrade len') + end + + local msg = ('x'):rep(24) + sck:write(msg) + local res = sck:read(#msg) + test:is(#res, #msg, 'echo is ok') + test:is(res, msg, 'echo is ok') + end) +end) + os.exit(test:check() == true and 0 or 1)