diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index bb26220a..9ab42a77 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -276,7 +276,7 @@ class AstToIR: assert False if lhs_type->is_integer_type() and rhs_type->is_integer_type(): - is_signed = lhs_type->kind == TypeKind::SignedInteger and rhs_type->kind == TypeKind::SignedInteger + is_signed = lhs_type->kind == TypeKind::SignedInteger or rhs_type->kind == TypeKind::SignedInteger if op == AstExpressionKind::Add: return LLVMBuildAdd(self->builder, lhs, rhs, "add") if op == AstExpressionKind::Subtract: diff --git a/src/codegen.c b/src/codegen.c index 96fccb99..470cf29b 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -392,11 +392,11 @@ static void codegen_instruction(const struct State *st, const CfInstruction *ins setdest(LLVMBuildFCmp(st->builder, LLVMRealOEQ, getop(0), getop(1), "num_eq")); break; case CF_NUM_LT: - if (is_integer_type(ins->operands[0]->type)) - // TODO: unsigned less than + if (ins->operands[0]->type->kind == TYPE_UNSIGNED_INTEGER && ins->operands[1]->type->kind == TYPE_UNSIGNED_INTEGER) + setdest(LLVMBuildICmp(st->builder, LLVMIntULT, getop(0), getop(1), "num_lt")); + else if (is_integer_type(ins->operands[0]->type) && is_integer_type(ins->operands[1]->type)) setdest(LLVMBuildICmp(st->builder, LLVMIntSLT, getop(0), getop(1), "num_lt")); else - // TODO: signed less than setdest(LLVMBuildFCmp(st->builder, LLVMRealOLT, getop(0), getop(1), "num_lt")); break; } diff --git a/tests/should_succeed/compare_bytes.jou b/tests/should_succeed/compare_bytes.jou new file mode 100644 index 00000000..9aad815a --- /dev/null +++ b/tests/should_succeed/compare_bytes.jou @@ -0,0 +1,20 @@ +import "stdlib/io.jou" + + +# TODO: It's not possible to test things like "signed < unsigned" yet with types of the same size. +# For 8-bit there's only unsigned (byte), for 16-bit only signed (short), 32-bit only signed (int), 64-bit only signed (long). +def main() -> int: + a = 0 as byte + b = 100 as byte + c = 200 as byte + + printf("%d %d %d\n", a < b, b < b, c < b) # Output: 1 0 0 + printf("%d %d %d\n", a > b, b > b, c > b) # Output: 0 0 1 + + printf("%d %d %d\n", a <= b, b <= b, c <= b) # Output: 1 1 0 + printf("%d %d %d\n", a >= b, b >= b, c >= b) # Output: 0 1 1 + + printf("%d %d %d\n", a == b, b == b, c == b) # Output: 0 1 0 + printf("%d %d %d\n", a != b, b != b, c != b) # Output: 1 0 1 + + return 0