Skip to content

Commit

Permalink
fix: Conversion operations having poison results (#131)
Browse files Browse the repository at this point in the history
Closes #103
  • Loading branch information
doug-q authored Oct 21, 2024
1 parent 19fdb17 commit 22d1d0f
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 61 deletions.
193 changes: 136 additions & 57 deletions src/extension/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, ensure, Result};

use hugr::{
extension::{
prelude::{sum_with_error, ConstError, BOOL_T},
simple_op::MakeExtensionOp,
},
ops::{constant::Value, custom::ExtensionOp},
ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _},
std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES},
types::{TypeArg, TypeEnum},
types::{TypeArg, TypeEnum, TypeRow},
HugrView,
};

use inkwell::{values::BasicValue, FloatPredicate, IntPredicate};
use inkwell::{types::IntType, values::BasicValue, FloatPredicate, IntPredicate};

use crate::{
custom::{CodegenExtension, CodegenExtsBuilder},
Expand All @@ -21,6 +21,7 @@ use crate::{
EmitOpArgs,
},
sum::LLVMSumValue,
types::HugrType,
};

fn build_trunc_op<'c, H: HugrView>(
Expand All @@ -29,53 +30,58 @@ fn build_trunc_op<'c, H: HugrView>(
log_width: u64,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
// Note: This logic is copied from `llvm_type` in the IntTypes
// extension. We need to have a common source of truth for this.
let (width, (int_min_value_s, int_max_value_s), int_max_value_u) = match log_width {
0..=3 => (8, (i8::MIN as i64, i8::MAX as i64), u8::MAX as u64),
4 => (16, (i16::MIN as i64, i16::MAX as i64), u16::MAX as u64),
5 => (32, (i32::MIN as i64, i32::MAX as i64), u32::MAX as u64),
6 => (64, (i64::MIN, i64::MAX), u64::MAX),
m => return Err(anyhow!("ConversionEmitter: unsupported log_width: {}", m)),
};

let hugr_int_ty = INT_TYPES[log_width as usize].clone();
let int_ty = context
.typing_session()
.llvm_type(&hugr_int_ty)?
.into_int_type();
let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]);
// TODO: it would be nice to get this info out of `ops.node()`, this would
// require adding appropriate methods to `ConvertOpDef`. In the meantime, we
// assert that the output types are as we expect.
debug_assert_eq!(
TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]),
args.node().signature().output
);

let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else {
bail!("Expected `arithmetic.int` to lower to an llvm integer")
};

let hugr_sum_ty = sum_with_error(vec![hugr_int_ty]);
let sum_ty = context.typing_session().llvm_sum_type(hugr_sum_ty)?;
let sum_ty = context.llvm_sum_type(hugr_sum_ty)?;

let (width, int_min_value_s, int_max_value_s, int_max_value_u) = {
ensure!(
log_width <= 6,
"Expected log_width of output to be <= 6, found: {log_width}"
);
let width = 1 << log_width;
(
width,
i64::MIN >> (64 - width),
i64::MAX >> (64 - width),
u64::MAX >> (64 - width),
)
};

emit_custom_unary_op(context, args, |ctx, arg, _| {
// We have to check if the conversion will work, so we
// make the maximum int and convert to a float, then compare
// with the function input.
let flt_max = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_s as f64)
let flt_max = ctx.iw_context().f64_type().const_float(if signed {
int_max_value_s as f64
} else {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_u as f64)
};
int_max_value_u as f64
});

let within_upper_bound = ctx.builder().build_float_compare(
FloatPredicate::OLE,
FloatPredicate::OLT,
arg.into_float_value(),
flt_max,
"within_upper_bound",
)?;

let flt_min = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_min_value_s as f64)
let flt_min = ctx.iw_context().f64_type().const_float(if signed {
int_min_value_s as f64
} else {
ctx.iw_context().f64_type().const_float(0.0)
};
0.0
});

let within_lower_bound = ctx.builder().build_float_compare(
FloatPredicate::OLE,
Expand Down Expand Up @@ -401,7 +407,7 @@ mod test {
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
}

fn roundtrip_hugr(val: u64) -> Hugr {
fn roundtrip_hugr(val: u64, signed: bool) -> Hugr {
let int64 = INT_TYPES[6].clone();
SimpleHugrConfig::new()
.with_outs(USIZE_T)
Expand All @@ -412,14 +418,23 @@ mod test {
.add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
.unwrap()
.outputs_arr();
let [flt] = builder
.add_dataflow_op(ConvertOpDef::convert_u.with_log_width(6), [int])
.unwrap()
.outputs_arr();
let [int_or_err] = builder
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(6), [flt])
.unwrap()
.outputs_arr();
let [flt] = {
let op = if signed {
ConvertOpDef::convert_s.with_log_width(6)
} else {
ConvertOpDef::convert_u.with_log_width(6)
};
builder.add_dataflow_op(op, [int]).unwrap().outputs_arr()
};

let [int_or_err] = {
let op = if signed {
ConvertOpDef::trunc_s.with_log_width(6)
} else {
ConvertOpDef::trunc_u.with_log_width(6)
};
builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr()
};
let sum_ty = sum_with_error(int64.clone());
let variants = (0..sum_ty.num_variants())
.map(|i| sum_ty.get_variant(i).unwrap().clone().try_into().unwrap());
Expand Down Expand Up @@ -467,25 +482,89 @@ mod test {
#[case(4294967295)]
#[case(42)]
#[case(18_000_000_000_000_000_000)]
fn roundtrip(mut exec_ctx: TestContext, #[case] val: u64) {
fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val);
let hugr = roundtrip_hugr(val, false);
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
}

// N.B.: There's some strange behaviour at the upper end of the ints - the
// first case gets converted to something that's off by 1,000, but the second
// (which is (2 ^ 64) - 1) gets converted to (2 ^ 32) - off by 9 million!
// The fact that the first case works as expected tells me this isn't to do
// with int widths - maybe a floating point expert could explain that this
// is standard behaviour...
#[rstest]
#[case(18_446_744_073_709_550_000, 18_446_744_073_709_549_568)]
#[case(18_446_744_073_709_551_615, 9_223_372_036_854_775_808)] // 2 ^ 63
fn approx_roundtrip(mut exec_ctx: TestContext, #[case] val: u64, #[case] expected: u64) {
// Exact roundtrip conversion is defined on values up to 2**53 for f64.
#[case(0)]
#[case(3)]
#[case(255)]
#[case(4294967295)]
#[case(42)]
#[case(-9_000_000_000_000_000_000)]
fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val as u64, true);
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64);
}

// For unisgined ints larger than (1 << 54) - 1, f64s do not have enough
// precision to exactly roundtrip the int.
// The exact behaviour of the round-trip is is platform-dependent.
#[rstest]
#[case(u64::MAX)]
#[case(u64::MAX - 1)]
#[case(u64::MAX - (1 << 1))]
#[case(u64::MAX - (1 << 2))]
#[case(u64::MAX - (1 << 3))]
#[case(u64::MAX - (1 << 4))]
#[case(u64::MAX - (1 << 5))]
#[case(u64::MAX - (1 << 6))]
#[case(u64::MAX - (1 << 7))]
#[case(u64::MAX - (1 << 8))]
#[case(u64::MAX - (1 << 9))]
#[case(u64::MAX - (1 << 10))]
#[case(u64::MAX - (1 << 11))]
fn approx_roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
add_extensions(&mut exec_ctx);

let hugr = roundtrip_hugr(val, false);
let result = exec_ctx.exec_hugr_u64(hugr, "main");
let (v_r_max, v_r_min) = (val.max(result), val.min(result));
// If val is too large the `trunc_u` op in `hugr` will return None.
// In this case the hugr returns the magic number `999`.
assert!(result == 999 || (v_r_max - v_r_min) < 1 << 10);
}

#[rstest]
#[case(i64::MAX)]
#[case(i64::MAX - 1)]
#[case(i64::MAX - (1 << 1))]
#[case(i64::MAX - (1 << 2))]
#[case(i64::MAX - (1 << 3))]
#[case(i64::MAX - (1 << 4))]
#[case(i64::MAX - (1 << 5))]
#[case(i64::MAX - (1 << 6))]
#[case(i64::MAX - (1 << 7))]
#[case(i64::MAX - (1 << 8))]
#[case(i64::MAX - (1 << 9))]
#[case(i64::MAX - (1 << 10))]
#[case(i64::MAX - (1 << 11))]
#[case(i64::MIN)]
#[case(i64::MIN + 1)]
#[case(i64::MIN + (1 << 1))]
#[case(i64::MIN + (1 << 2))]
#[case(i64::MIN + (1 << 3))]
#[case(i64::MIN + (1 << 4))]
#[case(i64::MIN + (1 << 5))]
#[case(i64::MIN + (1 << 6))]
#[case(i64::MIN + (1 << 7))]
#[case(i64::MIN + (1 << 8))]
#[case(i64::MIN + (1 << 9))]
#[case(i64::MIN + (1 << 10))]
#[case(i64::MIN + (1 << 11))]
fn approx_roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val);
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));

let hugr = roundtrip_hugr(val as u64, true);
let result = exec_ctx.exec_hugr_u64(hugr, "main") as i64;
// If val.abs() is too large the `trunc_s` op in `hugr` will return None.
// In this case the hugr returns the magic number `999`.
assert!(result == 999 || (val - result).abs() < 1 << 10);
}

#[rstest]
Expand Down
2 changes: 1 addition & 1 deletion src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%within_upper_bound = fcmp ole double %0, 0x41DFFFFFFFC00000
%within_upper_bound = fcmp olt double %0, 0x41DFFFFFFFC00000
%within_lower_bound = fcmp ole double 0xC1E0000000000000, %0
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptosi double %0 to i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ alloca_block:
entry_block: ; preds = %alloca_block
store double %0, double* %"2_0", align 8
%"2_01" = load double, double* %"2_0", align 8
%within_upper_bound = fcmp ole double %"2_01", 0x41DFFFFFFFC00000
%within_upper_bound = fcmp olt double %"2_01", 0x41DFFFFFFFC00000
%within_lower_bound = fcmp ole double 0xC1E0000000000000, %"2_01"
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptosi double %"2_01" to i32
Expand Down
2 changes: 1 addition & 1 deletion src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%within_upper_bound = fcmp ole double %0, 0x43F0000000000000
%within_upper_bound = fcmp olt double %0, 0x43F0000000000000
%within_lower_bound = fcmp ole double 0.000000e+00, %0
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptoui double %0 to i64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ alloca_block:
entry_block: ; preds = %alloca_block
store double %0, double* %"2_0", align 8
%"2_01" = load double, double* %"2_0", align 8
%within_upper_bound = fcmp ole double %"2_01", 0x43F0000000000000
%within_upper_bound = fcmp olt double %"2_01", 0x43F0000000000000
%within_lower_bound = fcmp ole double 0.000000e+00, %"2_01"
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptoui double %"2_01" to i64
Expand Down

0 comments on commit 22d1d0f

Please sign in to comment.