From 478d5833cfbc7701bd369a0dd0a53c86b52cbd14 Mon Sep 17 00:00:00 2001 From: Techatrix Date: Fri, 5 Jul 2024 16:27:00 +0200 Subject: [PATCH] add a new parser for LSP headers --- src/lsp.zig | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/src/lsp.zig b/src/lsp.zig index fa4b13e..fc91189 100644 --- a/src/lsp.zig +++ b/src/lsp.zig @@ -820,6 +820,122 @@ pub const JsonRPCMessage = union(enum) { } }; +pub const BaseProtocolHeader = struct { + content_length: usize, + + pub const max_header_length: usize = 1024; + + pub const ParseError = error{ + EndOfStream, + /// The message is longer than `std.math.maxInt(usize)`. + OversizedMessage, + /// The header field is longer than `max_header_length`. The ": " doesn't count towards the header field length. + OversizedHeaderField, + MissingContentLength, + InvalidContentLength, + InvalidHeaderField, + }; + + pub inline fn parse(reader: anytype) (@TypeOf(reader).Error || ParseError)!BaseProtocolHeader { + return @errorCast(parseAny(reader.any())); + } + + /// Type erased version of `parse`. + pub fn parseAny(reader: std.io.AnyReader) (anyerror || ParseError)!BaseProtocolHeader { + var content_length: ?usize = null; + + outer: while (true) { + var maybe_colon_index: ?usize = null; + + var buffer: [max_header_length]u8 = undefined; + var buffer_index: usize = 0; + + while (true) { + const byte: u8 = try reader.readByte(); + switch (byte) { + '\n' => return error.InvalidHeaderField, + '\r' => { + if (try reader.readByte() != '\n') return error.InvalidHeaderField; + if (buffer_index == 0) break :outer; + break; + }, + ':' => { + // The ": " is not being added to the buffer here! + if (try reader.readByte() != ' ') return error.InvalidHeaderField; + maybe_colon_index = buffer_index; + }, + else => { + if (buffer_index == max_header_length) return error.OversizedHeaderField; + buffer[buffer_index] = byte; + buffer_index += 1; + }, + } + } + + const colon_index = maybe_colon_index orelse return error.InvalidHeaderField; + + const header = buffer[0..buffer_index]; + const header_name = header[0..colon_index]; + const header_value = header[colon_index..]; + + if (!std.ascii.eqlIgnoreCase(header_name, "content-length")) continue; + + content_length = std.fmt.parseUnsigned(usize, header_value, 10) catch |err| switch (err) { + error.Overflow => return error.OversizedMessage, + error.InvalidCharacter => return error.InvalidContentLength, + }; + } + + return .{ + .content_length = content_length orelse return error.MissingContentLength, + }; + } + + pub fn format( + header: BaseProtocolHeader, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + out_stream: anytype, + ) !void { + _ = options; + if (fmt.len != 0) std.fmt.invalidFmtError(fmt, header); + try std.fmt.format(out_stream, "Content-Length: {d}\r\n\r\n", .{header.content_length}); + } + + test parse { + try expectParseError("", error.EndOfStream); + try expectParseError("\n", error.InvalidHeaderField); + try expectParseError("\n\r", error.InvalidHeaderField); + try expectParseError("\r", error.EndOfStream); + try expectParseError("\r\n", error.MissingContentLength); + try expectParseError("\r\n\r\n", error.MissingContentLength); + + try expectParseError("content-length: 32\r\n", error.EndOfStream); + try expectParseError("content-length 32\r\n\r\n", error.InvalidHeaderField); + try expectParseError("content-length:32\r\n\r\n", error.InvalidHeaderField); + try expectParseError("contentLength: 32\r\n\r\n", error.MissingContentLength); + try expectParseError("content-length: abababababab\r\n\r\n", error.InvalidContentLength); + try expectParseError("content-length: 9999999999999999999999999999999999\r\n\r\n", error.OversizedMessage); + + try expectParse("content-length: 32\r\n\r\n", .{ .content_length = 32 }); + try expectParse("Content-Length: 32\r\n\r\n", .{ .content_length = 32 }); + + try expectParse("content-type: whatever\r\nContent-Length: 666\r\n\r\n", .{ .content_length = 666 }); + try expectParse("Content-Type: impostor\r\ncontent-length: 42\r\n\r\n", .{ .content_length = 42 }); + } + + fn expectParse(input: []const u8, expected_header: BaseProtocolHeader) !void { + var fbs = std.io.fixedBufferStream(input); + const actual_header = try parse(fbs.reader()); + try std.testing.expectEqual(expected_header.content_length, actual_header.content_length); + } + + fn expectParseError(input: []const u8, expected_error: ParseError) !void { + var fbs = std.io.fixedBufferStream(input); + try std.testing.expectError(expected_error, parse(fbs.reader())); + } +}; + pub const MethodWithParams = struct { method: []const u8, /// The `std.json.Value` can only be `.null`, `.array` or `.object`.