Skip to content

Commit

Permalink
Added decimal casting for the E notation
Browse files Browse the repository at this point in the history
  • Loading branch information
Nekit2217 committed Apr 9, 2024
1 parent 6306df0 commit 114e35e
Showing 1 changed file with 212 additions and 12 deletions.
224 changes: 212 additions & 12 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,105 @@ impl Parser for Date64Type {
}
}

fn parse_e_notation<T: DecimalType> (
s: &str,
mut digits: u16,
mut fractionals: i16,
mut result: T::Native,
precision: u16,
scale: i16
) -> Result<T::Native, ArrowError> {
if digits == 0 && fractionals == 0{
return Err(ArrowError::ParseError(format!("can't parse the string value {s} to decimal")));
}

let mut exp: i16 = 0;
let base = T::Native::usize_as(10);

let mut exp_start: bool = false;
// e has a plus sign
let mut pos_shift_direction: bool = true;

let mut bs = s.as_bytes().iter().skip((digits + 1) as usize);

while let Some(b) = bs.next() {
match b {
b'0'..=b'9' => {
result = result.mul_wrapping(base);
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
if fractionals > 0 {
fractionals += 1;
}
digits += 1;
}
&b'e' | &b'E' => {
exp_start = true;
}
_ => {
return Err(ArrowError::ParseError(format!("can't parse the string value {s} to decimal")));
}
};

if exp_start {
pos_shift_direction = match bs.next() {
Some(&b'-') => false,
Some(&b'+') => true,
Some(b) => {
if !b.is_ascii_digit() {
return Err(ArrowError::ParseError(format!("can't parse the string value {s} to decimal")));
}

exp *= 10;
exp += (b - b'0') as i16;

true
},
None => return Err(ArrowError::ParseError(format!("can't parse the string value {s} to decimal")))
};

for b in bs.by_ref() {
if !b.is_ascii_digit() {
return Err(ArrowError::ParseError(format!("can't parse the string value {s} to decimal")));
}
exp *= 10;
exp += (b - b'0') as i16;
}
}
}

if !pos_shift_direction {
// exponent has a large negative sign
// 1.12345e-30 => 0.0{29}12345, scale = 5
if exp - (digits as i16 + scale) > 0 {
return Ok(T::Native::usize_as(0));
}
exp *= -1;
}

// comma offset
exp = fractionals - exp;
// We have zeros on the left, we need to count them
if !pos_shift_direction && exp > digits as i16 {
digits = exp as u16;
}
// Number of numbers to be removed or added
exp = scale - exp;

if (digits as i16 + exp) as u16 > precision {
return Err(ArrowError::ParseError(format!("parse decimal overflow ({s})")));
}


if exp < 0 {
result = result.div_wrapping(base.pow_wrapping((exp * -1) as _));
} else {
result = result.mul_wrapping(base.pow_wrapping(exp as _));
}

Ok(result)
}


/// Parse the string format decimal value to i128/i256 format and checking the precision and scale.
/// The result value can't be out of bounds.
pub fn parse_decimal<T: DecimalType>(
Expand All @@ -679,8 +778,8 @@ pub fn parse_decimal<T: DecimalType>(
scale: i8,
) -> Result<T::Native, ArrowError> {
let mut result = T::Native::usize_as(0);
let mut fractionals = 0;
let mut digits = 0;
let mut fractionals: i8 = 0;
let mut digits: u8 = 0;
let base = T::Native::usize_as(10);

let bs = s.as_bytes();
Expand All @@ -696,7 +795,10 @@ pub fn parse_decimal<T: DecimalType>(
)));
}

let mut is_e_notation = false;

let mut bs = bs.iter();

// Overflow checks are not required if 10^(precision - 1) <= T::MAX holds.
// Thus, if we validate the precision correctly, we can skip overflow checks.
while let Some(b) = bs.next() {
Expand All @@ -713,6 +815,23 @@ pub fn parse_decimal<T: DecimalType>(
b'.' => {
for b in bs.by_ref() {
if !b.is_ascii_digit() {
if *b == b'e' || *b == b'E' {
result = match parse_e_notation::<T>(
s,
digits.clone() as u16,
fractionals.clone() as i16,
result,
precision.clone() as u16,
scale.clone() as i16,
) {
Err(e) => return Err(e),
Ok(v) => v
};

is_e_notation = true;

break;
}
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
Expand All @@ -729,13 +848,34 @@ pub fn parse_decimal<T: DecimalType>(
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
}

if is_e_notation {
break
}

// Fail on "."
if digits == 0 {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
}
b'e' | b'E' => {
result = match parse_e_notation::<T>(
s,
digits.clone() as u16,
fractionals.clone() as i16,
result,
precision.clone() as u16,
scale.clone() as i16
) {
Err(e) => return Err(e),
Ok(v) => v
};

is_e_notation = true;

break;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
Expand All @@ -744,15 +884,17 @@ pub fn parse_decimal<T: DecimalType>(
}
}

if fractionals < scale {
let exp = scale - fractionals;
if exp as u8 + digits > precision {
return Err(ArrowError::ParseError("parse decimal overflow".to_string()));
if !is_e_notation {
if fractionals < scale {
let exp = scale - fractionals;
if exp as u8 + digits > precision {
return Err(ArrowError::ParseError(format!("parse decimal overflow ({s})")));
}
let mul = base.pow_wrapping(exp as _);
result = result.mul_wrapping(mul);
} else if digits > precision {
return Err(ArrowError::ParseError(format!("parse decimal overflow ({s})")));
}
let mul = base.pow_wrapping(exp as _);
result = result.mul_wrapping(mul);
} else if digits > precision {
return Err(ArrowError::ParseError("parse decimal overflow".to_string()));
}

Ok(if negative {
Expand All @@ -762,6 +904,7 @@ pub fn parse_decimal<T: DecimalType>(
})
}


pub fn parse_interval_year_month(
value: &str,
) -> Result<<IntervalYearMonthType as ArrowPrimitiveType>::Native, ArrowError> {
Expand Down Expand Up @@ -2202,7 +2345,38 @@ mod tests {
let result_256 = parse_decimal::<Decimal256Type>(s, 20, 3);
assert_eq!(i256::from_i128(i), result_256.unwrap());
}
let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-"];
let e_notation_tests = [
("1.23e3", "1230.0", 2),
("5.6714e+2", "567.14", 4),
("5.6714e-2", "0.056714", 4),
("5.6714e-2", "0.056714", 3),
("5.6741214125e2", "567.41214125", 4),
("8.91E4", "89100.0", 2),
("3.14E+5", "314000.0", 2),
("2.718e0", "2.718", 2),
("9.999999e-1", "0.9999999", 4),
("1.23e+3", "1230", 2),
("1.234559e+3", "1234.559", 2),
("1.00E-10", "0.0000000001", 11),
("1.23e-4", "0.000123", 2),
("9.876e7", "98760000.0", 2),
("5.432E+8", "543200000.0", 10),
("1.234567e9", "1234567000.0", 2),
("1.234567e2", "123.45670000", 2),
("4749.3e-5", "0.047493", 10),
("4749.3e+5", "474930000", 10),
("4749.3e-5", "0.047493", 1),
("4749.3e+5", "474930000", 1),
];
for (e, d, scale) in e_notation_tests {
let result_128_e = parse_decimal::<Decimal128Type>(e, 20, scale);
let result_128_d = parse_decimal::<Decimal128Type>(d, 20, scale);
assert_eq!(result_128_e.unwrap(), result_128_d.unwrap());
let result_256_e = parse_decimal::<Decimal256Type>(e, 20, scale);
let result_256_d = parse_decimal::<Decimal256Type>(d, 20, scale);
assert_eq!(result_256_e.unwrap(), result_256_d.unwrap());
}
let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-", "e", "1.3e+e3"];
for s in can_not_parse_tests {
let result_128 = parse_decimal::<Decimal128Type>(s, 20, 3);
assert_eq!(
Expand All @@ -2215,7 +2389,7 @@ mod tests {
result_256.unwrap_err().to_string()
);
}
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99", "9.999999999e7"];
for s in overflow_parse_tests {
let result_128 = parse_decimal::<Decimal128Type>(s, 10, 3);
let expected_128 = "Parser error: parse decimal overflow";
Expand Down Expand Up @@ -2262,6 +2436,16 @@ mod tests {
99999999999999999999999999999999999999i128,
38,
),
(
"0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001016744",
0i128,
15,
),
(
"1.016744e-320",
0i128,
15,
),
];
for (s, i, scale) in edge_tests_128 {
let result_128 = parse_decimal::<Decimal128Type>(s, 38, scale);
Expand Down Expand Up @@ -2292,6 +2476,14 @@ mod tests {
.unwrap(),
26,
),
(
"9.999999999999999999999999999999999999999999999999999999999999999999999999999e49",
i256::from_string(
"9999999999999999999999999999999999999999999999999999999999999999999999999999",
)
.unwrap(),
26,
),
(
"99999999999999999999999999999999999999999999999999",
i256::from_string(
Expand All @@ -2300,6 +2492,14 @@ mod tests {
.unwrap(),
26,
),
(
"9.9999999999999999999999999999999999999999999999999e+49",
i256::from_string(
"9999999999999999999999999999999999999999999999999900000000000000000000000000",
)
.unwrap(),
26,
),
];
for (s, i, scale) in edge_tests_256 {
let result = parse_decimal::<Decimal256Type>(s, 76, scale);
Expand Down

0 comments on commit 114e35e

Please sign in to comment.