Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/postgresql): Support for writing DECIMAL types #1288

Merged
merged 35 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a24b046
Initial hacks
WillAyd Nov 3, 2023
dbddd8b
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Nov 13, 2023
c5dfd05
feat(c/driver/postgresql): Support for writing DECIMAL128
WillAyd Nov 13, 2023
a091bcd
removed TODO
WillAyd Nov 13, 2023
61eb3cc
trailing decimals
WillAyd Nov 14, 2023
a76ab65
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Nov 26, 2023
078de30
more decimal hacks
WillAyd Nov 28, 2023
bf8ed7b
working for positive decimal values
WillAyd Nov 29, 2023
75cbd58
Merge branch 'main' into copy-decimal
WillAyd Nov 29, 2023
94bf657
negative value support
WillAyd Nov 29, 2023
4b49999
skip other drivers
WillAyd Nov 29, 2023
c046632
No std::string_view
WillAyd Nov 29, 2023
3957b6d
cleanups
WillAyd Nov 29, 2023
c5d19bb
more generic ToString
WillAyd Nov 29, 2023
ba44774
don't hardcode precision and scale
WillAyd Nov 29, 2023
06b6349
Decimal256 Support
WillAyd Nov 30, 2023
bc19709
remove dead code
WillAyd Nov 30, 2023
10e6e09
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Dec 14, 2023
e9967a7
less string
WillAyd Dec 14, 2023
6a0d3c9
Allocate up front
WillAyd Dec 16, 2023
df7ba3e
compiling with lifecycle issues
WillAyd Dec 18, 2023
ac733bf
lifecycle workarounds
WillAyd Dec 18, 2023
0cf303f
Try parametrized postgres-test suite
WillAyd Dec 18, 2023
59cdb22
fix test precision / scale arguments
WillAyd Dec 18, 2023
759b0f1
add nullability testing
WillAyd Dec 18, 2023
9472ff5
decimal256 test cases (but failing)
WillAyd Dec 18, 2023
0eba157
passing DECIMAL256 tests
WillAyd Dec 18, 2023
97c2d5c
lint
WillAyd Dec 18, 2023
443efed
endian agnosticism
WillAyd Dec 18, 2023
5a93f9e
fixups
WillAyd Dec 18, 2023
b629aca
msvc compat?
WillAyd Dec 22, 2023
f5100d0
fix COPY test
WillAyd Dec 22, 2023
7e7351d
Simple benchmark
WillAyd Dec 22, 2023
dc1b735
return int instead of void
WillAyd Dec 22, 2023
cc252cd
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,152 @@ class PostgresCopyIntervalFieldWriter : public PostgresCopyFieldWriter {
}
};

// Inspiration for this taken from get_str_from_var in the pg source
// src/backend/utils/adt/numeric.c
template<enum ArrowType T>
class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
public:
PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
precision_{precision}, scale_{scale} {}

ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
struct ArrowDecimal decimal;
ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);

const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;

// Number of decimal digits per Postgres digit
constexpr int kDecDigits = 4;
std::vector<int16_t> pg_digits;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably too small of an optimization to matter, but in principle you should be able to put an upper bound on the number of digits needed to represent an Arrow decimal, and then just stack-allocate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea that's true and actually how postgres does it internally.

https://github.com/postgres/postgres/blob/8680bae8463a0b213893ca6a1c5bb2c2530e823c/src/backend/utils/adt/numeric.c#L8026

If we wanted to stack allocate I guess would just expand that out to whatever is required to store up to 4 decimal words?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, decimal128 would be 38 digits and decimal256 would be 76

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK cool. Shouldn't be too hard to switch to that - just need to figure out how to handle once I get multi-word decimals supported

int16_t weight = -(scale_ / kDecDigits);
int16_t dscale = scale_;
bool seen_decimal = scale_ == 0;
bool truncating_trailing_zeros = true;

char decimal_string[max_decimal_digits_ + 1];
int digits_remaining = DecimalToString<bitwidth_>(&decimal, decimal_string);
do {
const int start_pos = digits_remaining < kDecDigits ?
0 : digits_remaining - kDecDigits;
const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
char substr[kDecDigits + 1];
std::memcpy(substr, decimal_string + start_pos, len);
substr[len] = '\0';
int16_t val = static_cast<int16_t>(std::atoi(substr));

if (val == 0) {
if (!seen_decimal && truncating_trailing_zeros) {
dscale -= kDecDigits;
}
} else {
pg_digits.insert(pg_digits.begin(), val);
if (!seen_decimal && truncating_trailing_zeros) {
if (val % 1000 == 0) {
dscale -= 3;
} else if (val % 100 == 0) {
dscale -= 2;
} else if (val % 10 == 0) {
dscale -= 1;
}
}
truncating_trailing_zeros = false;
}
digits_remaining -= kDecDigits;
if (digits_remaining <= 0) {
break;
}
weight++;

if (start_pos <= static_cast<int>(std::strlen(decimal_string)) - scale_) {
seen_decimal = true;
}
} while (true);

int16_t ndigits = pg_digits.size();
int32_t field_size_bytes = sizeof(ndigits)
+ sizeof(weight)
+ sizeof(sign)
+ sizeof(dscale)
+ ndigits * sizeof(int16_t);

NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

const size_t pg_digit_bytes = sizeof(int16_t) * pg_digits.size();
NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
for (auto pg_digit : pg_digits) {
WriteUnsafe<int16_t>(buffer, pg_digit);
}

return ADBC_STATUS_OK;
}

private:
// returns the length of the string
template <int32_t DEC_WIDTH>
int DecimalToString(struct ArrowDecimal* decimal, char* out) {
constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4;
uint8_t tmp[DEC_WIDTH / 8];
ArrowDecimalGetBytes(decimal, tmp);
uint64_t buf[DEC_WIDTH / 64];
std::memcpy(buf, tmp, sizeof(buf));
const int16_t sign = ArrowDecimalSign(decimal) > 0 ? kNumericPos : kNumericNeg;
const bool is_negative = sign == kNumericNeg ? true : false;
if (is_negative) {
buf[0] = ~buf[0] + 1;
for (size_t i = 1; i < nwords; i++) {
buf[i] = ~buf[i];
}
}

// Basic approach adopted from https://stackoverflow.com/a/8023862/621736
char s[max_decimal_digits_ + 1];
std::memset(s, '0', sizeof(s) - 1);
s[sizeof(s) - 1] = '\0';

for (size_t i = 0; i < DEC_WIDTH; i++) {
int carry;

carry = (buf[nwords - 1] >= 0x7FFFFFFFFFFFFFFF);
for (size_t j = nwords - 1; j > 0; j--) {
buf[j] = ((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j-1] >= 0x7FFFFFFFFFFFFFFF);
}
buf[0] = ((buf[0] << 1) & 0xFFFFFFFFFFFFFFFF);

for (int j = sizeof(s) - 2; j>= 0; j--) {
s[j] += s[j] - '0' + carry;
carry = (s[j] > '9');
if (carry) {
s[j] -= 10;
}
}
}

char* p = s;
while ((p[0] == '0') && (p < &s[sizeof(s) - 2])) {
p++;
}

const size_t ndigits = sizeof(s) - 1 - (p - s);
std::memcpy(out, p, ndigits);
out[ndigits] = '\0';

return ndigits;
}

static constexpr uint16_t kNumericPos = 0x0000;
static constexpr uint16_t kNumericNeg = 0x4000;
static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;
static constexpr size_t max_decimal_digits_ =
(T == NANOARROW_TYPE_DECIMAL128) ? 39 : 78;
const int32_t precision_;
const int32_t scale_;
};

template <enum ArrowTimeUnit TU>
class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
public:
Expand Down Expand Up @@ -1392,6 +1538,20 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
case NANOARROW_TYPE_DOUBLE:
*out = new PostgresCopyDoubleFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_DECIMAL128: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = new PostgresCopyNumericFieldWriter<
NANOARROW_TYPE_DECIMAL128>(precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DECIMAL256: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = new PostgresCopyNumericFieldWriter<
NANOARROW_TYPE_DECIMAL256>(precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
Expand Down
66 changes: 66 additions & 0 deletions c/driver/postgresql/postgres_copy_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,72 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) {
EXPECT_EQ(std::string(item.data, item.size_bytes), "inf");
}

// This buffer is similar to the read variant above but removes special values
// nan, ±inf as they are not supported via the Arrow Decimal types
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (NULL), (-123.456),
// ('0.00001234'), (1.0000), (123.456), (1000000)) AS drvd(col))
// TO STDOUT WITH (FORMAT binary);
static uint8_t kTestPgCopyNumericWrite[] = {
0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x00,
0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11,
0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00,
0x08, 0x04, 0xd2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00,
0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0a, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff};

TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) {
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
struct ArrowError na_error;
constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128;
constexpr int32_t size = 128;
constexpr int32_t precision = 38;
constexpr int32_t scale = 8;

struct ArrowDecimal decimal1;
struct ArrowDecimal decimal2;
struct ArrowDecimal decimal3;
struct ArrowDecimal decimal4;
struct ArrowDecimal decimal5;

ArrowDecimalInit(&decimal1, size, 19, 8);
ArrowDecimalSetInt(&decimal1, -12345600000);
ArrowDecimalInit(&decimal2, size, 19, 8);
ArrowDecimalSetInt(&decimal2, 1234);
ArrowDecimalInit(&decimal3, size, 19, 8);
ArrowDecimalSetInt(&decimal3, 100000000);
ArrowDecimalInit(&decimal4, size, 19, 8);
ArrowDecimalSetInt(&decimal4, 12345600000);
ArrowDecimalInit(&decimal5, size, 19, 8);
ArrowDecimalSetInt(&decimal5, 100000000000000);

const std::vector<std::optional<ArrowDecimal*>> values = {
std::nullopt, &decimal1, &decimal2, &decimal3, &decimal4, &decimal5};

ArrowSchemaInit(&schema.value);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 4 lines are essentially what adbc_validation::MakeSchema does. Though that function is templated by type, I wasn't sure if there was a way to make the template type and optionally precision / scale. There certainly could be a more graceful way of handling this in C++ that I am unaware of

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been meaning to overhaul this for a while now. (Or possibly give in and depend on arrow-cpp...) The current approach really only works for primitive types.

ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0);
ASSERT_EQ(AdbcNsArrowSchemaSetTypeDecimal(schema.value.children[0],
type, precision, scale), 0);
ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0);
ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
&na_error, values), ADBC_STATUS_OK);

PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

const struct ArrowBuffer buf = tester.WriteBuffer();
// The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
constexpr size_t buf_size = sizeof(kTestPgCopyNumericWrite) - 2;
ASSERT_EQ(buf.size_bytes, buf_size);
for (size_t i = 0; i < buf_size; i++) {
ASSERT_EQ(buf.data[i], kTestPgCopyNumericWrite[i]) << " at position " << i;
}
}

// COPY (SELECT CAST(col AS TIMESTAMP) FROM ( VALUES ('1900-01-01 12:34:56'),
// ('2100-01-01 12:34:56'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY);
static uint8_t kTestPgCopyTimestamp[] = {
Expand Down
Loading
Loading