From 416cc5fa59cb8cc5bb5af82d68b60a1aa88e9760 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagi=C4=87?= Date: Tue, 26 Dec 2023 15:35:03 +0100 Subject: [PATCH] Bugfixes (Dec 2023) (#515) * Delay overload selection when arguments are not known (delayed dispatch) * Delay 'is None' for 'Optional[T]' until type is known * Fix union overload selection * Add static string slicing * Fix itertools.accumulate * Fix list comprehension optimization ( minitech:imports-in-list-comprehensions ) * Fix match or patterns * Fix tests and faulty static tuple issue --- .../parser/visitors/simplify/collections.cpp | 2 +- codon/parser/visitors/simplify/cond.cpp | 2 +- codon/parser/visitors/simplify/function.cpp | 3 +- codon/parser/visitors/typecheck/access.cpp | 32 ++++--- codon/parser/visitors/typecheck/op.cpp | 89 +++++++++++++------ codon/parser/visitors/typecheck/typecheck.cpp | 24 +++-- stdlib/internal/python.codon | 38 ++++---- stdlib/itertools.codon | 2 +- test/parser/simplify_expr.codon | 4 + test/parser/simplify_stmt.codon | 16 ++++ test/parser/types.codon | 23 +++++ 11 files changed, 167 insertions(+), 68 deletions(-) diff --git a/codon/parser/visitors/simplify/collections.cpp b/codon/parser/visitors/simplify/collections.cpp index ee24134f..f23ef5fb 100644 --- a/codon/parser/visitors/simplify/collections.cpp +++ b/codon/parser/visitors/simplify/collections.cpp @@ -59,7 +59,7 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) { bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 && loops[0].conds.empty(); if (canOptimize) { - auto iter = transform(loops[0].gen); + auto iter = transform(clone(loops[0].gen)); IdExpr *id; if (iter->getCall() && (id = iter->getCall()->expr->getId())) { // Turn off this optimization for static items diff --git a/codon/parser/visitors/simplify/cond.cpp b/codon/parser/visitors/simplify/cond.cpp index 2870647e..29bbe3fd 100644 --- a/codon/parser/visitors/simplify/cond.cpp +++ b/codon/parser/visitors/simplify/cond.cpp @@ -151,7 +151,7 @@ StmtPtr SimplifyVisitor::transformPattern(const ExprPtr &var, ExprPtr pattern, suite)); } else if (auto eb = pattern->getBinary()) { // Or pattern - if (eb->op == "|") { + if (eb->op == "|" || eb->op == "||") { return N(transformPattern(clone(var), clone(eb->lexpr), clone(suite)), transformPattern(clone(var), clone(eb->rexpr), suite)); } diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp index 10a6ae5e..4fc41be2 100644 --- a/codon/parser/visitors/simplify/function.cpp +++ b/codon/parser/visitors/simplify/function.cpp @@ -138,8 +138,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // Case 2: function overload if (auto c = ctx->find(stmt->name)) { if (c->isFunc() && c->getModule() == ctx->getModule() && - c->getBaseName() == ctx->getBaseName()) + c->getBaseName() == ctx->getBaseName()) { rootName = c->canonicalName; + } } } if (rootName.empty()) diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index 174cf294..014992b0 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -29,8 +29,9 @@ void TypecheckVisitor::visit(IdExpr *expr) { auto val = ctx->find(expr->value); if (!val) { // Handle overloads - if (in(ctx->cache->overloads, expr->value)) + if (in(ctx->cache->overloads, expr->value)) { val = ctx->forceFind(getDispatch(expr->value)->ast->name); + } seqassert(val, "cannot find '{}'", expr->value); } unify(expr->type, ctx->instantiate(val->type)); @@ -402,15 +403,15 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, } } - if (methodArgs) { - FuncTypePtr bestMethod = nullptr; + bool goDispatch = methodArgs == nullptr; + if (!goDispatch) { + std::vector m; // Use the provided arguments to select the best method if (auto dot = expr->getDot()) { // Case: method overloads (DotExpr) auto methods = ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false); - auto m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs); - bestMethod = m.empty() ? nullptr : m[0]; + m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs); } else if (auto id = expr->getId()) { // Case: function overloads (IdExpr) std::vector methods; @@ -418,12 +419,23 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, if (!endswith(m.name, ":dispatch")) methods.push_back(ctx->cache->functions[m.name].type); std::reverse(methods.begin(), methods.end()); - auto m = findMatchingMethods(nullptr, methods, *methodArgs); - bestMethod = m.empty() ? nullptr : m[0]; + m = findMatchingMethods(nullptr, methods, *methodArgs); } - if (bestMethod) - return bestMethod; - } else { + + if (m.size() == 1) { + return m[0]; + } else if (m.size() > 1) { + for (auto &a : *methodArgs) { + if (auto u = a.value->type->getUnbound()) { + goDispatch = true; + } + } + if (!goDispatch) + return m[0]; + } + } + + if (goDispatch) { // If overload is ambiguous, route through a dispatch function std::string name; if (auto dot = expr->getDot()) { diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 330e2e2e..cd3d07fc 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -580,6 +580,12 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { auto g = expr->lexpr->getType()->getClass(); for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass()) ; + if (!g->generics[0].type->getClass()) { + if (!expr->isStatic()) + expr->staticValue.type = StaticValue::INT; + unify(expr->type, ctx->getType("bool")); + return nullptr; + } if (g->generics[0].type->is("NoneType")) return transform(N(true)); @@ -729,19 +735,26 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) { std::pair TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, const ExprPtr &expr, const ExprPtr &index) { - if (!tuple->getRecord()) - return {false, nullptr}; - if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) && - !startswith(tuple->name, TYPE_PARTIAL)) { - if (tuple->is(TYPE_OPTIONAL)) { - if (auto newTuple = tuple->generics[0].type->getClass()) { - return transformStaticTupleIndex( - newTuple, transform(N(N(FN_UNWRAP), expr)), index); - } else { - return {true, nullptr}; + bool isStaticString = + expr->isStatic() && expr->staticValue.type == StaticValue::STRING; + + if (isStaticString && !expr->staticValue.evaluated) { + return {true, nullptr}; + } else if (!isStaticString) { + if (!tuple->getRecord()) + return {false, nullptr}; + if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) && + !startswith(tuple->name, TYPE_PARTIAL)) { + if (tuple->is(TYPE_OPTIONAL)) { + if (auto newTuple = tuple->generics[0].type->getClass()) { + return transformStaticTupleIndex( + newTuple, transform(N(N(FN_UNWRAP), expr)), index); + } else { + return {true, nullptr}; + } } + return {false, nullptr}; } - return {false, nullptr}; } // Extract the static integer value from expression @@ -760,15 +773,15 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, return false; }; - auto classFields = getClassFields(tuple.get()); - auto sz = int64_t(tuple->getRecord()->args.size()); - int64_t start = 0, stop = sz, step = 1; + auto sz = int64_t(isStaticString ? expr->staticValue.getString().size() + : tuple->getRecord()->args.size()); + int64_t start = 0, stop = sz, step = 1, multiple = 0; if (getInt(&start, index)) { // Case: `tuple[int]` auto i = translateIndex(start, stop); if (i < 0 || i >= stop) E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i); - return {true, transform(N(expr, classFields[i].name))}; + start = i; } else if (auto slice = CAST(index, SliceExpr)) { // Case: `tuple[int:int:int]` if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) || @@ -781,23 +794,41 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, if (slice->step && !slice->stop) stop = step > 0 ? sz : -(sz + 1); sliceAdjustIndices(sz, &start, &stop, step); + multiple = 1; + } else { + return {false, nullptr}; + } - // Generate a sub-tuple - auto var = N(ctx->cache->getTemporaryVar("tup")); - auto ass = N(var, expr); - std::vector te; - for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) { - if (i < 0 || i >= sz) - E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i); - te.push_back(N(clone(var), classFields[i].name)); + if (isStaticString) { + auto str = expr->staticValue.getString(); + if (!multiple) { + return {true, transform(N(str.substr(start, 1)))}; + } else { + std::string newStr; + for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) + newStr += str[i]; + return {true, transform(N(newStr))}; + } + } else { + auto classFields = getClassFields(tuple.get()); + if (!multiple) { + return {true, transform(N(expr, classFields[start].name))}; + } else { + // Generate a sub-tuple + auto var = N(ctx->cache->getTemporaryVar("tup")); + auto ass = N(var, expr); + std::vector te; + for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) { + if (i < 0 || i >= sz) + E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i); + te.push_back(N(clone(var), classFields[i].name)); + } + ExprPtr e = transform( + N(std::vector{ass}, + N(N(N(TYPE_TUPLE), "__new__"), te))); + return {true, e}; } - ExprPtr e = transform( - N(std::vector{ass}, - N(N(N(TYPE_TUPLE), "__new__"), te))); - return {true, e}; } - - return {false, nullptr}; } /// Follow Python indexing rules for static tuple indices. diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index c7eb9566..2969ffdc 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -264,20 +264,20 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, !fn->ast->args[si].defaultValue) { return -1; } - reordered.push_back({nullptr, 0}); + reordered.emplace_back(nullptr, 0); } else { seqassert(gi < fn->funcGenerics.size(), "bad fn"); if (!fn->funcGenerics[gi].type->isStaticType() && !args[slots[si][0]].value->isType()) return -1; - reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); + reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]); } gi++; } else if (si == s || si == k || slots[si].size() != 1) { // Ignore *args, *kwargs and default arguments - reordered.push_back({nullptr, 0}); + reordered.emplace_back(nullptr, 0); } else { - reordered.push_back({args[slots[si][0]].value->type, slots[si][0]}); + reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]); } } return 0; @@ -416,8 +416,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType, !expectedClass->getUnion()) { // Extract union types via __internal__.get_union if (auto t = realize(expectedClass)) { - expr = transform(N(N("__internal__.get_union:0"), expr, - N(t->realizedName()))); + auto e = realize(expr->type); + if (!e) + return false; + bool ok = false; + for (auto &ut : e->getUnion()->getRealizationTypes()) { + if (ut->unify(t.get(), nullptr) >= 0) { + ok = true; + break; + } + } + if (ok) { + expr = transform(N(N("__internal__.get_union:0"), expr, + N(t->realizedName()))); + } } else { return false; } diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 78dd83ca..db7d7238 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -439,25 +439,25 @@ def init_handles_dlopen(py_handle: cobj): PyTuple_Type = dlsym(py_handle, "PyTuple_Type") PySlice_Type = dlsym(py_handle, "PySlice_Type") PyCapsule_Type = dlsym(py_handle, "PyCapsule_Type") - PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException"))[0] - PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception"))[0] - PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError"))[0] - PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError"))[0] - PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError"))[0] - PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError"))[0] - PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError"))[0] - PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError"))[0] - PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError"))[0] - PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0] - PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError"))[0] - PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError"))[0] - PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError"))[0] - PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError"))[0] - PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0] - PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError"))[0] - PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration"))[0] - PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError"))[0] - PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit"))[0] + PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException", cobj))[0] + PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception", cobj))[0] + PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError", cobj))[0] + PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError", cobj))[0] + PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError", cobj))[0] + PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError", cobj))[0] + PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError", cobj))[0] + PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError", cobj))[0] + PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError", cobj))[0] + PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError", cobj))[0] + PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError", cobj))[0] + PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError", cobj))[0] + PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError", cobj))[0] + PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError", cobj))[0] + PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError", cobj))[0] + PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError", cobj))[0] + PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration", cobj))[0] + PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError", cobj))[0] + PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit", cobj))[0] def init_handles_static(): from C import Py_DecRef(cobj) as _Py_DecRef diff --git a/stdlib/itertools.codon b/stdlib/itertools.codon index 951799db..4947af18 100644 --- a/stdlib/itertools.codon +++ b/stdlib/itertools.codon @@ -60,7 +60,7 @@ def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type): Make an iterator that returns accumulated sums, or accumulated results of other binary functions (specified via the optional func argument). """ - total = None + total: Optional[T] = None for element in iterable: total = element if total is None else func(unwrap(total), element) yield unwrap(total) diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 1cf19b9e..872c056e 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -175,6 +175,10 @@ print [i for i in range(3) if i%2 == 0] #: [0, 2] print [i + j for i in range(1) for j in range(1)] #: [0] print {i for i in range(3)} #: {0, 1, 2} +#%% comprehension_opt_clone +import sys +z = [i for i in sys.argv] + #%% generator,barebones z = 3 g = (e for e in range(20) if e % z == 1) diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index 41486d2d..ea8ac944 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -297,6 +297,22 @@ foo((1, 33)) #: list 4 33 foo((9, ('ACGT', 3))) #: complex ('ACGT', 3) foo(range(10)) #: else +for op in 'MI=DXSN': + match op: + case 'M' | '=' | 'X': + print('case 1') + case 'I' or 'S': + print('case 2') + case _: + print('case 3') +#: case 1 +#: case 2 +#: case 1 +#: case 3 +#: case 1 +#: case 2 +#: case 3 + #%% match_err_1,barebones match [1, 2]: case [1, ..., 2, ..., 3]: pass diff --git a/test/parser/types.codon b/test/parser/types.codon index 1cb612bd..02829c3e 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -463,6 +463,18 @@ foo("he", z) #: hewoo X(s='lolo') #: lolo lolo Int[1] X('abc') #: abc abc Int[2] + +def foo2(x: Static[str]): + print(x, x.__is_static__) +s: Static[str] = "abcdefghijkl" +foo2(s) #: abcdefghijkl True +foo2(s[1]) #: b True +foo2(s[1:5]) #: bcde True +foo2(s[10:50]) #: kl True +foo2(s[1:30:3]) #: behk True +foo2(s[::-1]) #: lkjihgfedcba True + + #%% static_getitem print Int[staticlen("ee")].__class__.__name__ #: Int[2] @@ -2063,3 +2075,14 @@ print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) c = B() print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) #: True False True + + +#%% delayed_dispatch +import math +def fox(a, b, key=None): # key=None delays it! + return a if a <= b else b + +a = 1.0 +b = 2.0 +c = fox(a, b) +print(math.log(c) / 2) #: 0