diff --git a/integration_tests/test_membership_01.py b/integration_tests/test_membership_01.py index 1fab47cda0..0f3b2b0d94 100644 --- a/integration_tests/test_membership_01.py +++ b/integration_tests/test_membership_01.py @@ -6,6 +6,9 @@ def test_int_dict(): i = 4 assert (i in a) + a = {} + assert (1 not in a) + def test_str_dict(): a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'} i: str @@ -14,6 +17,9 @@ def test_str_dict(): i = 'c' assert (i in a) + a = {} + assert ('a' not in a) + def test_int_set(): a: set[i32] = {1, 2, 3, 4} i: i32 @@ -22,6 +28,9 @@ def test_int_set(): i = 4 assert (i in a) + a = set() + assert (1 not in a) + def test_str_set(): a: set[str] = {'a', 'b', 'c', 'e', 'f'} i: str @@ -30,6 +39,9 @@ def test_str_set(): i = 'c' assert (i in a) + a = set() + assert ('a' not in a) + test_int_dict() test_str_dict() test_int_set() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index a577dfd30a..4aad34d197 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1726,9 +1726,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor> ptr_loads = ptr_loads_copy; llvm::Value *capacity = LLVM::CreateLoad(*builder, llvm_utils->dict_api->get_pointer_to_capacity(right)); - llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module); - - tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true); + get_builder0(); + llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + llvm_utils->create_if_else(builder->CreateICmpEQ( + capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))), + [&]() { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res); + }, [&]() { + llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module); + LLVM::CreateStore(*builder, llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true), res); + }); + tmp = LLVM::CreateLoad(*builder, res); } void visit_SetContains(const ASR::SetContains_t &x) { @@ -1748,9 +1756,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor> ptr_loads = ptr_loads_copy; llvm::Value *capacity = LLVM::CreateLoad(*builder, llvm_utils->set_api->get_pointer_to_capacity(right)); - llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module); - - tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true); + get_builder0(); + llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + llvm_utils->create_if_else(builder->CreateICmpEQ( + capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))), + [&]() { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res); + }, [&]() { + llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module); + LLVM::CreateStore(*builder, llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true), res); + }); + tmp = LLVM::CreateLoad(*builder, res); } void visit_DictLen(const ASR::DictLen_t& x) {