From dcdd9d5b06a163422ba0a6370b0d5f4fb354ec6d Mon Sep 17 00:00:00 2001 From: Jimmy Lu Date: Mon, 18 Nov 2024 14:22:40 -0800 Subject: [PATCH] fix: Unaligned memory access in ByteStream and PrestoSerializer Summary: A few places in the serialization code perform unaligned memory accesses, which cause crashes in `ByteOutputStream::append`. This change cleans up those accesses. Differential Revision: D66125705 --- velox/common/memory/ByteStream.h | 51 ++++--- velox/common/memory/tests/ByteStreamTest.cpp | 12 ++ velox/serializers/PrestoSerializer.cpp | 153 +++++++------------ 3 files changed, 99 insertions(+), 117 deletions(-) diff --git a/velox/common/memory/ByteStream.h b/velox/common/memory/ByteStream.h index 1cf866bad890e..c8435299c118e 100644 --- a/velox/common/memory/ByteStream.h +++ b/velox/common/memory/ByteStream.h @@ -15,12 +15,13 @@ */ #pragma once -#include #include "velox/common/base/Scratch.h" #include "velox/common/memory/StreamArena.h" #include "velox/type/Type.h" +#include #include + #include namespace facebook::velox { @@ -122,20 +123,24 @@ class ByteInputStream { template T read() { + static_assert(std::is_trivially_copyable_v); if (current_->position + sizeof(T) <= current_->size) { + auto* source = current_->buffer + current_->position; current_->position += sizeof(T); - return *reinterpret_cast( - current_->buffer + current_->position - sizeof(T)); + return folly::loadUnaligned(source); } // The number straddles two buffers. We read byte by byte and make a // little-endian uint64_t. The bytes can be cast to any integer or floating // point type since the wire format has the machine byte order. 
static_assert(sizeof(T) <= sizeof(uint64_t)); - uint64_t value = 0; + union { + uint64_t bits; + T typed; + } value{}; for (int32_t i = 0; i < sizeof(T); ++i) { - value |= static_cast(readByte()) << (i * 8); + value.bits |= static_cast(readByte()) << (i * 8); } - return *reinterpret_cast(&value); + return value.typed; } template @@ -288,21 +293,15 @@ class ByteOutputStream { template void append(folly::Range values) { + static_assert(std::is_trivially_copyable_v); if (current_->position + sizeof(T) * values.size() > current_->size) { appendStringView(std::string_view( reinterpret_cast(&values[0]), values.size() * sizeof(T))); return; } - - auto* target = reinterpret_cast(current_->buffer + current_->position); - const auto* end = target + values.size(); - auto* valuePtr = &values[0]; - while (target != end) { - *target = *valuePtr; - ++target; - ++valuePtr; - } + auto* target = current_->buffer + current_->position; + memcpy(target, values.data(), values.size() * sizeof(T)); current_->position += sizeof(T) * values.size(); } @@ -317,10 +316,11 @@ class ByteOutputStream { // There must be 8 bytes writable. If available is 56, there are 7, so >. if (available > 56) { const auto offset = position & 7; - uint64_t* buffer = - reinterpret_cast(current_->buffer + (position >> 3)); const auto mask = bits::lowMask(offset); - *buffer = (*buffer & mask) | (bits[0] << offset); + auto* buffer = current_->buffer + (position >> 3); + auto value = folly::loadUnaligned(buffer); + value = (value & mask) | (bits[0] << offset); + folly::storeUnaligned(buffer, value); current_->position += end; return; } @@ -362,7 +362,7 @@ class ByteOutputStream { // Returns a range of 'size' items of T. 
If there is no contiguous space in // 'this', uses 'scratch' to make a temp block that is appended to 'this' in template - T* getAppendWindow(int32_t size, ScratchPtr& scratchPtr) { + uint8_t* getAppendWindow(int32_t size, ScratchPtr& scratchPtr) { const int32_t bytes = sizeof(T) * size; if (!current_) { extend(bytes); @@ -370,15 +370,14 @@ class ByteOutputStream { auto available = current_->size - current_->position; if (available >= bytes) { current_->position += bytes; - return reinterpret_cast( - current_->buffer + current_->position - bytes); + return current_->buffer + current_->position - bytes; } // If the tail is not large enough, make temp of the right size // in scratch. Extend the stream so that there is guaranteed space to copy // the scratch to the stream. This copy takes place in destruction of // AppendWindow and must not allocate so that it is noexcept. ensureSpace(bytes); - return scratchPtr.get(size); + return reinterpret_cast(scratchPtr.get(size)); } void extend(int32_t bytes); @@ -425,6 +424,12 @@ class ByteOutputStream { friend class AppendWindow; }; +template <> +inline void ByteOutputStream::append( + folly::Range*> /*values*/) { + VELOX_FAIL("Cannot serialize OPAQUE data"); +} + /// A scoped wrapper that provides 'size' T's of writable space in 'stream'. /// Normally gives an address into 'stream's buffer but can use 'scratch' to /// make a contiguous piece if stream does not have a suitable run. 
@@ -448,7 +453,7 @@ class AppendWindow { } } - T* get(int32_t size) { + uint8_t* get(int32_t size) { return stream_.getAppendWindow(size, scratchPtr_); } diff --git a/velox/common/memory/tests/ByteStreamTest.cpp b/velox/common/memory/tests/ByteStreamTest.cpp index d97bc2673ec9d..0916b412934cf 100644 --- a/velox/common/memory/tests/ByteStreamTest.cpp +++ b/velox/common/memory/tests/ByteStreamTest.cpp @@ -357,6 +357,18 @@ TEST_F(ByteStreamTest, reuse) { } } +TEST_F(ByteStreamTest, unalignedWrite) { + constexpr int kSize = 1 + sizeof(int128_t); + auto arena = newArena(); + ByteOutputStream stream(arena.get()); + stream.startWrite(kSize); + stream.appendStringView(std::string_view("x")); + int128_t data{}; + // This only crashes in opt mode. + stream.append(folly::Range(&data, 1)); + ASSERT_EQ(stream.size(), kSize); +} + class InputByteStreamTest : public ByteStreamTest, public testing::WithParamInterface { protected: diff --git a/velox/serializers/PrestoSerializer.cpp b/velox/serializers/PrestoSerializer.cpp index 80301dde0a7f9..a26aac828f48f 100644 --- a/velox/serializers/PrestoSerializer.cpp +++ b/velox/serializers/PrestoSerializer.cpp @@ -1483,7 +1483,7 @@ class VectorStream { void initializeHeader(std::string_view name, StreamArena& streamArena) { streamArena.newTinyRange(50, nullptr, &header_); header_.size = name.size() + sizeof(int32_t); - *reinterpret_cast(header_.buffer) = name.size(); + folly::storeUnaligned(header_.buffer, name.size()); ::memcpy(header_.buffer + sizeof(int32_t), &name[0], name.size()); } @@ -2396,46 +2396,34 @@ int32_t rowsToRanges( return fill; } -template +template void copyWords( - T* destination, + uint8_t* destination, const int32_t* indices, int32_t numIndices, const T* values, - bool isLongDecimal = false) { - if (std::is_same_v && isLongDecimal) { - for (auto i = 0; i < numIndices; ++i) { - reinterpret_cast(destination)[i] = toJavaDecimalValue( - reinterpret_cast(values)[indices[i]]); - } - return; - } + Conv&& conv = {}) { for 
(auto i = 0; i < numIndices; ++i) { - destination[i] = values[indices[i]]; + folly::storeUnaligned( + destination + i * sizeof(T), conv(values[indices[i]])); } } -template +template void copyWordsWithRows( - T* destination, + uint8_t* destination, const int32_t* rows, const int32_t* indices, int32_t numIndices, const T* values, - bool isLongDecimal = false) { + Conv&& conv = {}) { if (!indices) { - copyWords(destination, rows, numIndices, values, isLongDecimal); - return; - } - if (std::is_same_v && isLongDecimal) { - for (auto i = 0; i < numIndices; ++i) { - reinterpret_cast(destination)[i] = toJavaDecimalValue( - reinterpret_cast(values)[rows[indices[i]]]); - } + copyWords(destination, rows, numIndices, values, std::forward(conv)); return; } for (auto i = 0; i < numIndices; ++i) { - destination[i] = values[rows[indices[i]]]; + folly::storeUnaligned( + destination + i * sizeof(T), conv(values[rows[indices[i]]])); } } @@ -2469,7 +2457,7 @@ void appendNonNull( if constexpr (sizeof(T) == 8) { AppendWindow window(out, scratch); - int64_t* output = window.get(numNonNull); + auto* output = window.get(numNonNull); copyWordsWithRows( output, rows.data(), @@ -2478,7 +2466,7 @@ void appendNonNull( reinterpret_cast(values)); } else if constexpr (sizeof(T) == 4) { AppendWindow window(out, scratch); - int32_t* output = window.get(numNonNull); + auto* output = window.get(numNonNull); copyWordsWithRows( output, rows.data(), @@ -2487,14 +2475,19 @@ void appendNonNull( reinterpret_cast(values)); } else { AppendWindow window(out, scratch); - T* output = window.get(numNonNull); - copyWordsWithRows( - output, - rows.data(), - nonNullIndices, - numNonNull, - values, - stream->isLongDecimal()); + auto* output = window.get(numNonNull); + if (stream->isLongDecimal()) { + copyWordsWithRows( + output, + rows.data(), + nonNullIndices, + numNonNull, + values, + toJavaDecimalValue); + } else { + copyWordsWithRows( + output, rows.data(), nonNullIndices, numNonNull, values); + } } } @@ -2563,56 
+2556,34 @@ void serializeFlatVector( auto* flatVector = vector->asUnchecked>(); auto* rawValues = flatVector->rawValues(); if (!flatVector->mayHaveNulls()) { - if (std::is_same_v) { - appendTimestamps( - nullptr, - rows, - reinterpret_cast(rawValues), - stream, - scratch); - return; - } - - if (std::is_same_v) { - appendStrings( - nullptr, - rows, - reinterpret_cast(rawValues), - stream, - scratch); - return; + if constexpr (std::is_same_v) { + appendTimestamps(nullptr, rows, rawValues, stream, scratch); + } else if constexpr (std::is_same_v) { + appendStrings(nullptr, rows, rawValues, stream, scratch); + } else { + stream->appendNonNull(rows.size()); + AppendWindow window(stream->values(), scratch); + auto* output = window.get(rows.size()); + if (stream->isLongDecimal()) { + copyWords( + output, rows.data(), rows.size(), rawValues, toJavaDecimalValue); + } else { + copyWords(output, rows.data(), rows.size(), rawValues); + } } - - stream->appendNonNull(rows.size()); - AppendWindow window(stream->values(), scratch); - T* output = window.get(rows.size()); - copyWords( - output, rows.data(), rows.size(), rawValues, stream->isLongDecimal()); return; } ScratchPtr nullsHolder(scratch); uint64_t* nulls = nullsHolder.get(bits::nwords(rows.size())); simd::gatherBits(vector->rawNulls(), rows, nulls); - if (std::is_same_v) { - appendTimestamps( - nulls, - rows, - reinterpret_cast(rawValues), - stream, - scratch); - return; - } - if (std::is_same_v) { - appendStrings( - nulls, - rows, - reinterpret_cast(rawValues), - stream, - scratch); - return; + if constexpr (std::is_same_v) { + appendTimestamps(nulls, rows, rawValues, stream, scratch); + } else if constexpr (std::is_same_v) { + appendStrings(nulls, rows, rawValues, stream, scratch); + } else { + appendNonNull(stream, nulls, rows, rawValues, scratch); } - appendNonNull(stream, nulls, rows, rawValues, scratch); } uint64_t bitsToBytesMap[256]; @@ -2628,15 +2599,14 @@ void serializeFlatVector( VectorStream* stream, Scratch& 
scratch) { auto* flatVector = vector->as>(); - auto* rawValues = flatVector->rawValues(); + auto* rawValues = flatVector->rawValues(); ScratchPtr bitsHolder(scratch); uint64_t* valueBits; int32_t numValueBits; if (!flatVector->mayHaveNulls()) { stream->appendNonNull(rows.size()); valueBits = bitsHolder.get(bits::nwords(rows.size())); - simd::gatherBits( - reinterpret_cast(rawValues), rows, valueBits); + simd::gatherBits(rawValues, rows, valueBits); numValueBits = rows.size(); } else { uint64_t* nulls = bitsHolder.get(bits::nwords(rows.size())); @@ -2651,7 +2621,7 @@ void serializeFlatVector( folly::Range(nonNulls, numValueBits), nonNulls); simd::gatherBits( - reinterpret_cast(rawValues), + rawValues, folly::Range(nonNulls, numValueBits), valueBits); } @@ -2665,10 +2635,11 @@ void serializeFlatVector( const auto numBytes = bits::nbytes(numValueBits); for (auto i = 0; i < numBytes; ++i) { uint64_t word = bitsToBytes(reinterpret_cast(valueBits)[i]); + auto* target = output + i * 8; if (i < numBytes - 1) { - reinterpret_cast(output)[i] = word; + folly::storeUnaligned(target, word); } else { - memcpy(output + i * 8, &word, numValueBits - i * 8); + memcpy(target, &word, numValueBits - i * 8); } } } @@ -3080,10 +3051,8 @@ void estimateFlattenedConstantSerializedSize( int32_t elementSize = sizeof(T); if (constantVector->isNullAt(0)) { elementSize = 1; - } else if (std::is_same_v) { - const auto value = constantVector->valueAt(0); - const auto* string = reinterpret_cast(&value); - elementSize = string->size(); + } else if constexpr (std::is_same_v) { + elementSize = constantVector->valueAt(0).size(); } for (int32_t i = 0; i < ranges.size(); ++i) { *sizes[i] += elementSize * ranges[i].size; @@ -3341,10 +3310,8 @@ void estimateFlattenedConstantSerializedSize( folly::Range(&singleRow, 1), &sizePtr, scratch); - } else if (std::is_same_v) { - const auto value = constantVector->valueAt(0); - const auto string = reinterpret_cast(&value); - elementSize = string->size(); + } else if 
constexpr (std::is_same_v) { + elementSize = constantVector->valueAt(0).size(); } for (int32_t i = 0; i < rows.size(); ++i) { *sizes[i] += elementSize; @@ -3691,10 +3658,8 @@ void estimateConstantSerializedSize( newRanges, &elementSizePtr, scratch); - } else if (std::is_same_v) { - auto value = constantVector->valueAt(0); - auto string = reinterpret_cast(&value); - elementSize = string->size(); + } else if constexpr (std::is_same_v) { + elementSize = constantVector->valueAt(0).size(); } else { elementSize = sizeof(T); }