fetch, tls, and http fixes #24740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (6 commits) on Aug 8, 2025
lib/std/Io/Reader.zig (4 changes: 1 addition & 3 deletions)
@@ -25,9 +25,7 @@ pub const VTable = struct {
///
/// Returns the number of bytes written, which will be at minimum `0` and
/// at most `limit`. The number returned, including zero, does not indicate
/// end of stream. `limit` is guaranteed to be at least as large as the
/// buffer capacity of `w`, a value whose minimum size is determined by the
/// stream implementation.
/// end of stream.
///
/// The reader's internal logical seek position moves forward in accordance
/// with the number of bytes returned from this function.
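The dropped sentence matters downstream: a `stream` implementation can no longer count on `limit` covering the destination's buffer capacity, and a zero return remains a valid "no progress to the writer yet" answer rather than end of stream. A minimal sketch written against the revised contract, assuming the `std.Io.Reader` and `std.Io.Writer` shapes visible elsewhere in this diff:

```zig
const std = @import("std");

// Sketch only: `limit` may now be smaller than `w`'s buffer capacity, so the
// implementation clamps to it, and returning 0 is legal.
const StaticSource = struct {
    bytes: []const u8,
    reader: std.Io.Reader,

    fn init(bytes: []const u8, buffer: []u8) StaticSource {
        return .{
            .bytes = bytes,
            .reader = .{
                .buffer = buffer,
                .seek = 0,
                .end = 0,
                .vtable = &.{ .stream = stream },
            },
        };
    }

    fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
        const s: *StaticSource = @alignCast(@fieldParentPtr("reader", r));
        if (s.bytes.len == 0) return error.EndOfStream;
        // Clamp to `limit` instead of assuming room for a whole buffer.
        const n = @min(@intFromEnum(limit), s.bytes.len);
        try w.writeAll(s.bytes[0..n]);
        s.bytes = s.bytes[n..];
        return n;
    }
};
```

The TLS client below leans on exactly this property: it returns 0 after depositing plaintext into its own buffer.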
lib/std/crypto/tls/Client.zig (68 changes: 45 additions & 23 deletions)
@@ -61,9 +61,6 @@ pub const ReadError = error{
TlsUnexpectedMessage,
TlsIllegalParameter,
TlsSequenceOverflow,
/// The buffer provided to the read function was not at least
/// `min_buffer_len`.
OutputBufferUndersize,
};

pub const SslKeyLog = struct {
@@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
};
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
return error.TlsBadRecordMac;
cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len;
// TODO use scalar, non-slice version
cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len;
},
}
read_seq += 1;
@@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_fragment_buf[0..message_len];
const ad = std.mem.toBytes(big(read_seq)) ++
const ad = mem.toBytes(big(read_seq)) ++
record_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = record_decoder.array(P.record_iv_length).*;
const masked_read_seq = read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
&.{ "server finished", &p.transcript_hash.finalResult() },
P.verify_data_length,
),
.app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
.app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block),
} };
const pv = &p.version.tls_1_2;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
client_verify_cleartext.len ..][0..client_verify_cleartext.len],
client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
&client_verify_cleartext,
std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
nonce,
pv.app_cipher.client_write_key,
);
@@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
.input = input,
.reader = .{
.buffer = options.read_buffer,
.vtable = &.{ .stream = stream },
.vtable = &.{
.stream = stream,
.readVec = readVec,
},
.seek = 0,
.end = 0,
},
@@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord(
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.write_seq));
const operand: V = pad ++ mem.toBytes(big(c.write_seq));
break :nonce @as(V, pv.client_iv) ^ operand;
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
@@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord(
record_header.* = .{@intFromEnum(inner_content_type)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int(u16, P.record_iv_length + message_len + P.mac_length);
const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
ciphertext_end += P.record_iv_length;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool {
}
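The cryptographic bookkeeping in the hunks above and below follows two fixed patterns. The TLS 1.3 paths build each per-record nonce by left-padding the 64-bit big-endian record sequence number to the nonce length and XORing it with the static IV (RFC 8446, Section 5.3), which the diff expresses as a vector XOR. The TLS 1.2 paths authenticate each record with a 13-byte additional-data block: sequence number, the first three record-header bytes, and the plaintext length (RFC 5246, Section 6.2.3.3). A hedged, standalone restatement of both helpers, using only documented `std.mem` functions in place of the file-local `big` helper:

```zig
const std = @import("std");

/// Per-record TLS 1.3 nonce (RFC 8446, Section 5.3): left-pad the 64-bit
/// big-endian sequence number to the nonce length, then XOR with the IV.
fn recordNonce(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 {
    const V = @Vector(nonce_len, u8);
    const pad = [1]u8{0} ** (nonce_len - 8);
    const operand: V = pad ++ std.mem.toBytes(std.mem.nativeToBig(u64, seq));
    return @as(V, iv) ^ operand;
}

/// TLS 1.2 AEAD additional data (RFC 5246, Section 6.2.3.3): sequence
/// number, record type plus legacy version (header bytes 0..3), and the
/// plaintext length, all big-endian.
fn additionalData(seq: u64, record_header: *const [5]u8, len: u16) [13]u8 {
    return std.mem.toBytes(std.mem.nativeToBig(u64, seq)) ++
        record_header[0..3].* ++
        std.mem.toBytes(std.mem.nativeToBig(u16, len));
}

test recordNonce {
    const iv: [12]u8 = .{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
    const nonce = recordNonce(12, iv, 1);
    // XOR with sequence number 1 flips the IV's low byte back to zero.
    try std.testing.expectEqual(@as(u8, 0), nonce[11]);
}
```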

fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
// This function writes exclusively to the buffer.
_ = w;
_ = limit;
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
return readIndirect(c);
}

fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
// This function writes exclusively to the buffer.
_ = data;
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
return readIndirect(c);
}

fn readIndirect(c: *Client) Reader.Error!usize {
const r = &c.reader;
if (c.eof()) return error.EndOfStream;
const input = c.input;
// If at least one full encrypted record is not buffered, read once.
@@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
if (record_end > input.buffered().len) return 0;
}

var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
if (r.seek == r.end) {
r.seek = 0;
r.end = 0;
}
const cleartext_buffer = r.buffer[r.end..];

const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
.tls_1_3 => {
const pv = &p.tls_1_3;
@@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
const operand: V = pad ++ mem.toBytes(big(c.read_seq));
break :nonce @as(V, pv.server_iv) ^ operand;
};
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
return failRead(c, error.TlsBadRecordMac);
// TODO use scalar, non-slice version
const msg = mem.trimRight(u8, cleartext, "\x00");
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) };
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
const ad = std.mem.toBytes(big(c.read_seq)) ++
const ad = mem.toBytes(big(c.read_seq)) ++
ad_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
const masked_read_seq = c.read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
};
const ciphertext = input.take(message_len) catch unreachable; // already peeked
const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
return failRead(c, error.TlsBadRecordMac);
break :cleartext .{ cleartext, ct };
break :cleartext .{ cleartext.len, ct };
},
else => unreachable,
},
};
const cleartext = cleartext_buffer[0..cleartext_len];
c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
switch (inner_ct) {
.alert => {
@@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
return 0;
},
.application_data => {
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
try w.writeAll(cleartext);
return cleartext.len;
r.end += cleartext.len;
return 0;
},
else => return failRead(c, error.TlsUnexpectedMessage),
}
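The shape of the new read path: `stream` and `readVec` both recover the `Client` via `@fieldParentPtr` and delegate to `readIndirect`, which decrypts each record directly into the unused tail of the reader's own buffer and returns 0, instead of decrypting into a `max_ciphertext_len` stack buffer and copying through the supplied writer. Under the revised `stream` contract in lib/std/Io/Reader.zig above, the zero return is legal, and this is what lets the `OutputBufferUndersize` error be deleted. A self-contained sketch of the buffer discipline, with a stand-in memcpy where the real code runs AEAD decryption:

```zig
const std = @import("std");

// Sketch of the in-place decryption discipline used by readIndirect;
// @memcpy stands in for P.AEAD.decrypt writing plaintext into `dest`.
fn appendPlaintext(r: *std.Io.Reader, record: []const u8) error{TlsRecordOverflow}!void {
    if (r.seek == r.end) {
        // Everything previously buffered was consumed; rewind so the whole
        // capacity is reusable before decrypting the next record.
        r.seek = 0;
        r.end = 0;
    }
    const dest = r.buffer[r.end..];
    if (record.len > dest.len) return error.TlsRecordOverflow;
    @memcpy(dest[0..record.len], record); // real code: AEAD-decrypt into dest
    r.end += record.len; // plaintext is now visible to the reader's consumers
}
```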
lib/std/http.zig (58 changes: 35 additions & 23 deletions)
@@ -292,6 +292,14 @@ pub const ContentEncoding = enum {
});
return map.get(s);
}

pub fn minBufferCapacity(ce: ContentEncoding) usize {
return switch (ce) {
.zstd => std.compress.zstd.default_window_len,
.gzip, .deflate => std.compress.flate.max_window_len,
Comment on lines +298 to +299

squeek502 (Collaborator) commented on Aug 8, 2025:

General comment, mostly borne out of my own confusion: this concept of "buffer capacity" for std.compress seems to be defined in different ways in many different places, and some of them appear to disagree with each other.

For zstd, these doc comments would lead me to believe that default_window_len is not enough:

/// The output buffer is asserted to have capacity for `window_len` plus
/// `zstd.block_size_max`.

/// When connecting `reader` to a `Writer`, `buffer` should be empty, and
/// `Writer.buffer` capacity has requirements based on `Options.window_len`.
///
/// Otherwise, `buffer` has those requirements.

For deflate, these constants are just generally confusing to me:

/// When decompressing, the output buffer is used as the history window, so
/// less than this may result in failure to decompress streams that were
/// compressed with a larger window.
pub const max_window_len = history_len * 2;
pub const history_len = 32768;

It seems like some more definitive recommended_buffer_size or min_buffer_size constant should be defined per compression algorithm and used consistently.

Member (author) replied:

Yeah, I've been figuring this stuff out as I go along and am only just starting to get a handle on it. Once I'm confident in the flow chart I'll do another pass over the doc comments and constants and make sure everything is clear and agrees with itself.

Even in this PR you can see I changed the rules for Reader.stream.

Member (author) replied:

I tried to do all this stuff in a branch and only inflict it on master once I had a fully done, coherent change, but it's just too massive. Everybody has to come along for the ride and suffer with me, sorry.
.compress, .identity => 0,
};
}
};
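The new helper gives callers a single place to ask how large a decompression buffer each encoding needs, which is the concern raised in the review thread above. A hedged usage sketch; `gpa` is an assumption standing in for whatever allocator the application uses:

```zig
const std = @import("std");

// For .identity and .compress this allocates zero bytes, since no
// decompression window is needed.
fn allocDecompressBuffer(gpa: std.mem.Allocator, ce: std.http.ContentEncoding) ![]u8 {
    return gpa.alloc(u8, ce.minBufferCapacity());
}
```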

pub const Connection = enum {
@@ -412,7 +420,7 @@ pub const Reader = struct {
/// * `interfaceDecompressing`
pub fn bodyReader(
reader: *Reader,
buffer: []u8,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
) *std.Io.Reader {
@@ -421,7 +429,7 @@
.chunked => {
reader.state = .{ .body_remaining_chunk_len = .head };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@@ -435,7 +443,7 @@
if (content_length) |len| {
reader.state = .{ .body_remaining_content_length = len };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@@ -460,11 +468,12 @@
/// * `interface`
pub fn bodyReaderDecompressing(
reader: *Reader,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
content_encoding: ContentEncoding,
decompressor: *Decompressor,
decompression_buffer: []u8,
decompress: *Decompress,
decompress_buffer: []u8,
) *std.Io.Reader {
if (transfer_encoding == .none and content_length == null) {
assert(reader.state == .received_head);
@@ -474,22 +483,22 @@
return reader.in;
},
.deflate => {
decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .zlib, decompress_buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .gzip, decompress_buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(reader.in, decompress_buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
}
const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
const transfer_reader = bodyReader(reader, transfer_buffer, transfer_encoding, content_length);
return decompress.init(transfer_reader, decompress_buffer, content_encoding);
}
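A hedged call-site sketch for the renamed signature: the caller now supplies the transfer buffer (threaded through to `bodyReader` rather than always empty, per the last hunk above), the head fields, a caller-owned `Decompress`, and its window buffer. Names, sizes, and head values are assumptions:

```zig
// `http_reader`, `transfer_encoding`, `content_length`, and
// `content_encoding` stand in for a real parsed response head.
var decompress: std.http.Decompress = undefined;
var transfer_buffer: [4096]u8 = undefined;
var decompress_buffer: [std.compress.flate.max_window_len]u8 = undefined;
const body: *std.Io.Reader = http_reader.bodyReaderDecompressing(
    &transfer_buffer,
    transfer_encoding,
    content_length,
    content_encoding,
    &decompress,
    &decompress_buffer,
);
// `body` now yields decompressed bytes through the std.Io.Reader interface.
```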

fn contentLengthStream(
@@ -691,33 +700,33 @@
}
};

pub const Decompressor = union(enum) {
pub const Decompress = union(enum) {
flate: std.compress.flate.Decompress,
zstd: std.compress.zstd.Decompress,
none: *std.Io.Reader,

pub fn init(
decompressor: *Decompressor,
decompress: *Decompress,
transfer_reader: *std.Io.Reader,
buffer: []u8,
content_encoding: ContentEncoding,
) *std.Io.Reader {
switch (content_encoding) {
.identity => {
decompressor.* = .{ .none = transfer_reader };
decompress.* = .{ .none = transfer_reader };
return transfer_reader;
},
.deflate => {
decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
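The renamed union can also be wired up directly over any transfer reader, not only through `bodyReaderDecompressing`. A hedged fragment; `transfer_reader` is an assumption (for example, the result of `bodyReader` above):

```zig
var decompress: std.http.Decompress = undefined;
var window: [std.compress.flate.max_window_len]u8 = undefined;
const plain: *std.Io.Reader = decompress.init(transfer_reader, &window, .deflate);
// Per the switch above, .deflate selects the zlib container for flate.
_ = plain;
```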
@@ -794,7 +803,7 @@
}

/// When using content-length, asserts that the amount of data sent matches
/// the value sent in the header, then flushes.
/// the value sent in the header, then flushes `http_protocol_output`.
///
/// When using transfer-encoding: chunked, writes the end-of-stream message
/// with empty trailers, then flushes the stream to the system. Asserts any
@@ -818,10 +827,13 @@
///
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// Does not flush `http_protocol_output`, but does flush `writer`.
///
/// See also:
/// * `end`
/// * `endChunked`
pub fn endUnflushed(w: *BodyWriter) Error!void {
try w.writer.flush();
switch (w.state) {
.end => unreachable,
.content_length => |len| {
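The flush layering documented in this last hunk is two-level: `end` finishes the message and flushes `http_protocol_output`, while `endUnflushed` (which now begins by flushing the body `writer`) leaves the connection buffer untouched, so a server can batch several messages into one flush. A hedged fragment; `body_writer`, `payload`, and `http_protocol_output` are assumptions standing in for real connection state:

```zig
// `body_writer` is a std.http.BodyWriter; `http_protocol_output` is the
// connection's buffered writer referenced by the doc comments above.
try body_writer.writer.writeAll(payload);
try body_writer.endUnflushed(); // message complete; bytes may still be buffered
// ... queue more messages on the same connection ...
try http_protocol_output.flush(); // one flush pushes everything to the socket
```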