Skip to content

Commit

Permalink
Key Update Crash on Allocation Failure
Browse files Browse the repository at this point in the history
  • Loading branch information
nibanks committed Aug 8, 2024
1 parent a6f38aa commit a74fea6
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/platform/crypt.c
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ QuicPacketKeyUpdate(
_Out_ QUIC_PACKET_KEY** NewKey
)
{
if (OldKey->Type != QUIC_PACKET_KEY_1_RTT) {
if (OldKey == NULL || OldKey->Type != QUIC_PACKET_KEY_1_RTT) {
return QUIC_STATUS_INVALID_STATE;
}

Expand Down
5 changes: 5 additions & 0 deletions src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ QuicDrillTestServerVNPacket(
_In_ int Family
);

void
QuicDrillTestKeyUpdateDuringHandshake(
_In_ int Family
);

//
// Datagram tests
//
Expand Down
9 changes: 9 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,15 @@ TEST_P(WithDrillInitialPacketTokenArgs, QuicDrillTestServerVNPacket) {
}
}

TEST_P(WithDrillInitialPacketTokenArgs, QuicDrillTestKeyUpdateDuringHandshake) {
TestLoggerT<ParamType> Logger("QuicDrillTestKeyUpdateDuringHandshake", GetParam());
if (TestingKernelMode) {
//ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_DRILL_VN_PACKET_TOKEN, GetParam().Family));
} else {
QuicDrillTestKeyUpdateDuringHandshake(GetParam().Family);
}
}

TEST_P(WithDatagramNegotiationArgs, DatagramNegotiation) {
TestLoggerT<ParamType> Logger("QuicTestDatagramNegotiation", GetParam());
if (TestingKernelMode) {
Expand Down
141 changes: 126 additions & 15 deletions src/test/lib/DrillDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
--*/

#include "precomp.h"
#include <quic_crypt.h>
#include <msquichelper.h>

#ifdef QUIC_CLOG
#include "DrillDescriptor.cpp.clog.h"
#endif
Expand Down Expand Up @@ -152,7 +155,7 @@ DrillVNPacketDescriptor::write(
return PacketBuffer;
}

DrillInitialPacketDescriptor::DrillInitialPacketDescriptor()
DrillInitialPacketDescriptor::DrillInitialPacketDescriptor(uint8_t SrcCidLength)
{
Type = Initial;
Header.FixedBit = 1;
Expand All @@ -161,18 +164,20 @@ DrillInitialPacketDescriptor::DrillInitialPacketDescriptor()
const uint8_t CidValMax = 8;
for (uint8_t CidVal = 0; CidVal <= CidValMax; CidVal++) {
DestCid.push_back(CidVal);
SourceCid.push_back(CidValMax - CidVal);
}

for (uint8_t CidVal = 0; CidVal < SrcCidLength; CidVal++) {
SourceCid.push_back(SrcCidLength - CidVal);
}
}

DrillBuffer
DrillInitialPacketDescriptor::write(
bool EncryptPayload
) const
{
DrillBuffer PacketBuffer = DrillPacketDescriptor::write();

size_t CalculatedPacketLength = PacketBuffer.size();

DrillBuffer EncodedTokenLength;
if (TokenLen != nullptr) {
EncodedTokenLength = QuicDrillEncodeQuicVarInt(*TokenLen);
Expand All @@ -181,24 +186,20 @@ DrillInitialPacketDescriptor::write(
}
PacketBuffer.insert(PacketBuffer.end(), EncodedTokenLength.begin(), EncodedTokenLength.end());

CalculatedPacketLength += EncodedTokenLength.size();

if (Token.size()) {
PacketBuffer.insert(PacketBuffer.end(), Token.begin(), Token.end());
CalculatedPacketLength += Token.size();
}

//
// Note: this ignores the bits in the Header that specify how many bytes
// are used. The caller must ensure these are in-sync.
// Packet number buffer.
//
DrillBuffer PacketNumberBuffer;
if (PacketNumber < 0x100) {
if (Header.PacketNumLen == 0) {
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
} else if (PacketNumber < 0x10000) {
} else if (Header.PacketNumLen == 1) {
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
} else if (PacketNumber < 0x1000000) {
} else if (Header.PacketNumLen == 2) {
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 16));
PacketNumberBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
Expand All @@ -209,16 +210,17 @@ DrillInitialPacketDescriptor::write(
PacketNumberBuffer.push_back((uint8_t) PacketNumber);
}

CalculatedPacketLength += PacketNumberBuffer.size();
CalculatedPacketLength += Payload.size();

//
// Write packet length.
//
DrillBuffer PacketLengthBuffer;
if (PacketLength != nullptr) {
PacketLengthBuffer = QuicDrillEncodeQuicVarInt(*PacketLength);
} else {
size_t CalculatedPacketLength = PacketNumberBuffer.size() + Payload.size();
if (EncryptPayload) {
CalculatedPacketLength += CXPLAT_ENCRYPTION_OVERHEAD;
}
PacketLengthBuffer = QuicDrillEncodeQuicVarInt(CalculatedPacketLength);
}
PacketBuffer.insert(PacketBuffer.end(), PacketLengthBuffer.begin(), PacketLengthBuffer.end());
Expand All @@ -228,12 +230,121 @@ DrillInitialPacketDescriptor::write(
//
PacketBuffer.insert(PacketBuffer.end(), PacketNumberBuffer.begin(), PacketNumberBuffer.end());

auto HeaderLength = (uint16_t)PacketBuffer.size();

//
// Write payload.
//
if (Payload.size() > 0) {
PacketBuffer.insert(PacketBuffer.end(), Payload.begin(), Payload.end());
}

if (EncryptPayload) {
for (uint8_t i = 0; i < CXPLAT_ENCRYPTION_OVERHEAD; ++i) {
PacketBuffer.push_back(0);
}
encrypt(PacketBuffer, HeaderLength, (uint8_t)PacketNumberBuffer.size());
}

return PacketBuffer;
}

struct StrBuffer {
uint8_t* Data;
uint16_t Length;

StrBuffer(const char* HexBytes)
{
Length = (uint16_t)(strlen(HexBytes) / 2);
Data = new uint8_t[Length];

for (uint16_t i = 0; i < Length; ++i) {
Data[i] =
(DecodeHexChar(HexBytes[i * 2]) << 4) |
DecodeHexChar(HexBytes[i * 2 + 1]);
}
}

~StrBuffer() { delete [] Data; }
};

void
DrillInitialPacketDescriptor::encrypt(
DrillBuffer& PacketBuffer,
uint16_t HeaderLength,
uint8_t PacketNumberLength
) const
{
const QUIC_HKDF_LABELS HkdfLabels = { "quic key", "quic iv", "quic hp", "quic ku" };
const StrBuffer InitialSalt("38762cf7f55934b34d179ae6a4c80cadccbb7f0a");

QUIC_PACKET_KEY* WriteKey;
QuicPacketKeyCreateInitial(
FALSE,
&HkdfLabels,
InitialSalt.Data,
(uint8_t)DestCid.size(),
DestCid.data(),
nullptr,
&WriteKey);

uint8_t Iv[CXPLAT_IV_LENGTH];
uint64_t FullPacketNumber = PacketNumber;
QuicCryptoCombineIvAndPacketNumber(
WriteKey->Iv, (uint8_t*)&FullPacketNumber, Iv);

CxPlatEncrypt(
WriteKey->PacketKey,
Iv,
HeaderLength,
PacketBuffer.data(),
(uint16_t)PacketBuffer.size() - HeaderLength,
PacketBuffer.data() + HeaderLength);

uint8_t HpMask[16];
CxPlatHpComputeMask(
WriteKey->HeaderKey,
1,
PacketBuffer.data() + HeaderLength,
HpMask);

uint16_t PacketNumberOffset = HeaderLength - PacketNumberLength;
PacketBuffer[0] ^= HpMask[0] & 0x0F;
for (uint8_t i = 0; i < PacketNumberLength; ++i) {
PacketBuffer[PacketNumberOffset + i] ^= HpMask[i + 1];
}

QuicPacketKeyFree(WriteKey);
}

union QuicShortHeader {
uint8_t HeaderByte;
struct {
uint8_t PacketNumLen : 2;
uint8_t KeyPhase : 1;
uint8_t Reserved : 2;
uint8_t SpinBit : 1;
uint8_t FixedBit : 1;
uint8_t LongHeader : 1;
};
};

DrillBuffer
Drill1RttPacketDescriptor::write(
) const
{
DrillBuffer PacketBuffer;
QuicShortHeader Header = { 0 };
Header.PacketNumLen = 3;
Header.KeyPhase = KeyPhase;

PacketBuffer.push_back(Header.HeaderByte);
PacketBuffer.insert(PacketBuffer.end(), DestCid.begin(), DestCid.end());
PacketBuffer.push_back((uint8_t) (PacketNumber >> 24));// TODO - different packet number sizes
PacketBuffer.push_back((uint8_t) (PacketNumber >> 16));
PacketBuffer.push_back((uint8_t) (PacketNumber >> 8));
PacketBuffer.push_back((uint8_t) PacketNumber);
PacketBuffer.insert(PacketBuffer.end(), Payload.begin(), Payload.end());

return PacketBuffer;
}
25 changes: 23 additions & 2 deletions src/test/lib/DrillDescriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,34 @@ struct DrillInitialPacketDescriptor : DrillPacketDescriptor {

DrillBuffer Payload;

DrillInitialPacketDescriptor(uint8_t SrcCidLength = 9);

DrillInitialPacketDescriptor();
//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write(bool EncryptPayload = false) const;

private:

void encrypt(DrillBuffer& PacketBuffer, uint16_t HeaderLength, uint8_t PacketNumberLength) const;
};

struct Drill1RttPacketDescriptor {

DrillBuffer DestCid;

uint8_t KeyPhase {0};

uint32_t PacketNumber {0};

DrillBuffer Payload;

Drill1RttPacketDescriptor() {}

//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write() const;
DrillBuffer write() const;
};

enum DrillVarIntSize {
Expand Down
49 changes: 49 additions & 0 deletions src/test/lib/QuicDrill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,52 @@ QuicDrillTestServerVNPacket(

CxPlatSleep(500);
}

void
QuicDrillTestKeyUpdateDuringHandshake(
_In_ int Family
)
{
MsQuicRegistration Registration(true);
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

if (QuitTestIsFeatureSupported(CXPLAT_DATAPATH_FEATURE_RAW)) {
return;
}

QUIC_ADDRESS_FAMILY QuicAddrFamily = (Family == 4) ? QUIC_ADDRESS_FAMILY_INET : QUIC_ADDRESS_FAMILY_INET6;
QuicAddr ServerLocalAddr(QuicAddrFamily);

MsQuicAutoAcceptListener Listener(Registration, MsQuicConnection::NoOpCallback);
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest", &ServerLocalAddr.SockAddr));
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

DrillSender Sender;
TEST_QUIC_SUCCEEDED(
Sender.Initialize(
QUIC_TEST_LOOPBACK_FOR_AF(QuicAddrFamily),
QuicAddrFamily,
(QuicAddrFamily == QUIC_ADDRESS_FAMILY_INET) ?
ServerLocalAddr.SockAddr.Ipv4.sin_port :
ServerLocalAddr.SockAddr.Ipv6.sin6_port));

DrillInitialPacketDescriptor InitialPacketBuffer(0);
InitialPacketBuffer.Header.PacketNumLen = 3;
InitialPacketBuffer.Payload.push_back(1); // Ping frame
for (uint16_t i = 0; i < 1199; ++i) { InitialPacketBuffer.Payload.push_back(0); } // Padding frames

Drill1RttPacketDescriptor OneRttPacketBuffer;
OneRttPacketBuffer.DestCid.insert(
OneRttPacketBuffer.DestCid.end(),
InitialPacketBuffer.DestCid.begin(),
InitialPacketBuffer.DestCid.end());
OneRttPacketBuffer.KeyPhase = 1;
OneRttPacketBuffer.Payload.push_back(1); // Ping frame
for (uint16_t i = 0; i < 80; ++i) { OneRttPacketBuffer.Payload.push_back(0); } // Padding frames

TEST_QUIC_SUCCEEDED(Sender.Send(InitialPacketBuffer.write(true)));
TEST_QUIC_SUCCEEDED(Sender.Send(OneRttPacketBuffer.write()));

CxPlatSleep(500);
}

0 comments on commit a74fea6

Please sign in to comment.