Skip to content

Commit

Permalink
Bugfixes (Dec 2023) (#515)
Browse files Browse the repository at this point in the history
* Delay overload selection when arguments are not known (delayed dispatch)

* Delay 'is None' for 'Optional[T]' until type is known

* Fix union overload selection

* Add static string slicing

* Fix itertools.accumulate

* Fix list comprehension optimization ( minitech:imports-in-list-comprehensions )

* Fix match or patterns

* Fix tests and faulty static tuple issue
  • Loading branch information
inumanag authored Dec 26, 2023
1 parent 32a624b commit 416cc5f
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 68 deletions.
2 changes: 1 addition & 1 deletion codon/parser/visitors/simplify/collections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
loops[0].conds.empty();
if (canOptimize) {
auto iter = transform(loops[0].gen);
auto iter = transform(clone(loops[0].gen));
IdExpr *id;
if (iter->getCall() && (id = iter->getCall()->expr->getId())) {
// Turn off this optimization for static items
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/visitors/simplify/cond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ StmtPtr SimplifyVisitor::transformPattern(const ExprPtr &var, ExprPtr pattern,
suite));
} else if (auto eb = pattern->getBinary()) {
// Or pattern
if (eb->op == "|") {
if (eb->op == "|" || eb->op == "||") {
return N<SuiteStmt>(transformPattern(clone(var), clone(eb->lexpr), clone(suite)),
transformPattern(clone(var), clone(eb->rexpr), suite));
}
Expand Down
3 changes: 2 additions & 1 deletion codon/parser/visitors/simplify/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Case 2: function overload
if (auto c = ctx->find(stmt->name)) {
if (c->isFunc() && c->getModule() == ctx->getModule() &&
c->getBaseName() == ctx->getBaseName())
c->getBaseName() == ctx->getBaseName()) {
rootName = c->canonicalName;
}
}
}
if (rootName.empty())
Expand Down
32 changes: 22 additions & 10 deletions codon/parser/visitors/typecheck/access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ void TypecheckVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value);
if (!val) {
// Handle overloads
if (in(ctx->cache->overloads, expr->value))
if (in(ctx->cache->overloads, expr->value)) {
val = ctx->forceFind(getDispatch(expr->value)->ast->name);
}
seqassert(val, "cannot find '{}'", expr->value);
}
unify(expr->type, ctx->instantiate(val->type));
Expand Down Expand Up @@ -402,28 +403,39 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
}
}

if (methodArgs) {
FuncTypePtr bestMethod = nullptr;
bool goDispatch = methodArgs == nullptr;
if (!goDispatch) {
std::vector<FuncTypePtr> m;
// Use the provided arguments to select the best method
if (auto dot = expr->getDot()) {
// Case: method overloads (DotExpr)
auto methods =
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
auto m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
bestMethod = m.empty() ? nullptr : m[0];
m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
} else if (auto id = expr->getId()) {
// Case: function overloads (IdExpr)
std::vector<types::FuncTypePtr> methods;
for (auto &m : ctx->cache->overloads[id->value])
if (!endswith(m.name, ":dispatch"))
methods.push_back(ctx->cache->functions[m.name].type);
std::reverse(methods.begin(), methods.end());
auto m = findMatchingMethods(nullptr, methods, *methodArgs);
bestMethod = m.empty() ? nullptr : m[0];
m = findMatchingMethods(nullptr, methods, *methodArgs);
}
if (bestMethod)
return bestMethod;
} else {

if (m.size() == 1) {
return m[0];
} else if (m.size() > 1) {
for (auto &a : *methodArgs) {
if (auto u = a.value->type->getUnbound()) {
goDispatch = true;
}
}
if (!goDispatch)
return m[0];
}
}

if (goDispatch) {
// If overload is ambiguous, route through a dispatch function
std::string name;
if (auto dot = expr->getDot()) {
Expand Down
89 changes: 60 additions & 29 deletions codon/parser/visitors/typecheck/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,12 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
auto g = expr->lexpr->getType()->getClass();
for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass())
;
if (!g->generics[0].type->getClass()) {
if (!expr->isStatic())
expr->staticValue.type = StaticValue::INT;
unify(expr->type, ctx->getType("bool"));
return nullptr;
}
if (g->generics[0].type->is("NoneType"))
return transform(N<BoolExpr>(true));

Expand Down Expand Up @@ -729,19 +735,26 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
std::pair<bool, ExprPtr>
TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
const ExprPtr &expr, const ExprPtr &index) {
if (!tuple->getRecord())
return {false, nullptr};
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
!startswith(tuple->name, TYPE_PARTIAL)) {
if (tuple->is(TYPE_OPTIONAL)) {
if (auto newTuple = tuple->generics[0].type->getClass()) {
return transformStaticTupleIndex(
newTuple, transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr)), index);
} else {
return {true, nullptr};
bool isStaticString =
expr->isStatic() && expr->staticValue.type == StaticValue::STRING;

if (isStaticString && !expr->staticValue.evaluated) {
return {true, nullptr};
} else if (!isStaticString) {
if (!tuple->getRecord())
return {false, nullptr};
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
!startswith(tuple->name, TYPE_PARTIAL)) {
if (tuple->is(TYPE_OPTIONAL)) {
if (auto newTuple = tuple->generics[0].type->getClass()) {
return transformStaticTupleIndex(
newTuple, transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr)), index);
} else {
return {true, nullptr};
}
}
return {false, nullptr};
}
return {false, nullptr};
}

// Extract the static integer value from expression
Expand All @@ -760,15 +773,15 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
return false;
};

auto classFields = getClassFields(tuple.get());
auto sz = int64_t(tuple->getRecord()->args.size());
int64_t start = 0, stop = sz, step = 1;
auto sz = int64_t(isStaticString ? expr->staticValue.getString().size()
: tuple->getRecord()->args.size());
int64_t start = 0, stop = sz, step = 1, multiple = 0;
if (getInt(&start, index)) {
// Case: `tuple[int]`
auto i = translateIndex(start, stop);
if (i < 0 || i >= stop)
E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i);
return {true, transform(N<DotExpr>(expr, classFields[i].name))};
start = i;
} else if (auto slice = CAST(index, SliceExpr)) {
// Case: `tuple[int:int:int]`
if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) ||
Expand All @@ -781,23 +794,41 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
if (slice->step && !slice->stop)
stop = step > 0 ? sz : -(sz + 1);
sliceAdjustIndices(sz, &start, &stop, step);
multiple = 1;
} else {
return {false, nullptr};
}

// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
if (isStaticString) {
auto str = expr->staticValue.getString();
if (!multiple) {
return {true, transform(N<StringExpr>(str.substr(start, 1)))};
} else {
std::string newStr;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step)
newStr += str[i];
return {true, transform(N<StringExpr>(newStr))};
}
} else {
auto classFields = getClassFields(tuple.get());
if (!multiple) {
return {true, transform(N<DotExpr>(expr, classFields[start].name))};
} else {
// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
}
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
return {true, e};
}
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
return {true, e};
}

return {false, nullptr};
}

/// Follow Python indexing rules for static tuple indices.
Expand Down
24 changes: 18 additions & 6 deletions codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,20 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
!fn->ast->args[si].defaultValue) {
return -1;
}
reordered.push_back({nullptr, 0});
reordered.emplace_back(nullptr, 0);
} else {
seqassert(gi < fn->funcGenerics.size(), "bad fn");
if (!fn->funcGenerics[gi].type->isStaticType() &&
!args[slots[si][0]].value->isType())
return -1;
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
}
gi++;
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.push_back({nullptr, 0});
reordered.emplace_back(nullptr, 0);
} else {
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
}
}
return 0;
Expand Down Expand Up @@ -416,8 +416,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
!expectedClass->getUnion()) {
// Extract union types via __internal__.get_union
if (auto t = realize(expectedClass)) {
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), expr,
N<IdExpr>(t->realizedName())));
auto e = realize(expr->type);
if (!e)
return false;
bool ok = false;
for (auto &ut : e->getUnion()->getRealizationTypes()) {
if (ut->unify(t.get(), nullptr) >= 0) {
ok = true;
break;
}
}
if (ok) {
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), expr,
N<IdExpr>(t->realizedName())));
}
} else {
return false;
}
Expand Down
38 changes: 19 additions & 19 deletions stdlib/internal/python.codon
Original file line number Diff line number Diff line change
Expand Up @@ -439,25 +439,25 @@ def init_handles_dlopen(py_handle: cobj):
PyTuple_Type = dlsym(py_handle, "PyTuple_Type")
PySlice_Type = dlsym(py_handle, "PySlice_Type")
PyCapsule_Type = dlsym(py_handle, "PyCapsule_Type")
PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException"))[0]
PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception"))[0]
PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError"))[0]
PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError"))[0]
PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError"))[0]
PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError"))[0]
PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError"))[0]
PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError"))[0]
PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError"))[0]
PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0]
PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError"))[0]
PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError"))[0]
PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError"))[0]
PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError"))[0]
PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0]
PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError"))[0]
PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration"))[0]
PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError"))[0]
PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit"))[0]
PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException", cobj))[0]
PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception", cobj))[0]
PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError", cobj))[0]
PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError", cobj))[0]
PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError", cobj))[0]
PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError", cobj))[0]
PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError", cobj))[0]
PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError", cobj))[0]
PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError", cobj))[0]
PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError", cobj))[0]
PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError", cobj))[0]
PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError", cobj))[0]
PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError", cobj))[0]
PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError", cobj))[0]
PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError", cobj))[0]
PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError", cobj))[0]
PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration", cobj))[0]
PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError", cobj))[0]
PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit", cobj))[0]

def init_handles_static():
from C import Py_DecRef(cobj) as _Py_DecRef
Expand Down
2 changes: 1 addition & 1 deletion stdlib/itertools.codon
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type):
Make an iterator that returns accumulated sums, or accumulated results
of other binary functions (specified via the optional func argument).
"""
total = None
total: Optional[T] = None
for element in iterable:
total = element if total is None else func(unwrap(total), element)
yield unwrap(total)
Expand Down
4 changes: 4 additions & 0 deletions test/parser/simplify_expr.codon
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ print [i for i in range(3) if i%2 == 0] #: [0, 2]
print [i + j for i in range(1) for j in range(1)] #: [0]
print {i for i in range(3)} #: {0, 1, 2}

#%% comprehension_opt_clone
import sys
z = [i for i in sys.argv]

#%% generator,barebones
z = 3
g = (e for e in range(20) if e % z == 1)
Expand Down
16 changes: 16 additions & 0 deletions test/parser/simplify_stmt.codon
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,22 @@ foo((1, 33)) #: list 4 33
foo((9, ('ACGT', 3))) #: complex ('ACGT', 3)
foo(range(10)) #: else

for op in 'MI=DXSN':
match op:
case 'M' | '=' | 'X':
print('case 1')
case 'I' or 'S':
print('case 2')
case _:
print('case 3')
#: case 1
#: case 2
#: case 1
#: case 3
#: case 1
#: case 2
#: case 3

#%% match_err_1,barebones
match [1, 2]:
case [1, ..., 2, ..., 3]: pass
Expand Down
23 changes: 23 additions & 0 deletions test/parser/types.codon
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,18 @@ foo("he", z) #: hewoo
X(s='lolo') #: lolo lolo Int[1]
X('abc') #: abc abc Int[2]


def foo2(x: Static[str]):
print(x, x.__is_static__)
s: Static[str] = "abcdefghijkl"
foo2(s) #: abcdefghijkl True
foo2(s[1]) #: b True
foo2(s[1:5]) #: bcde True
foo2(s[10:50]) #: kl True
foo2(s[1:30:3]) #: behk True
foo2(s[::-1]) #: lkjihgfedcba True


#%% static_getitem
print Int[staticlen("ee")].__class__.__name__ #: Int[2]

Expand Down Expand Up @@ -2063,3 +2075,14 @@ print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz'))
c = B()
print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz'))
#: True False True


#%% delayed_dispatch
import math
def fox(a, b, key=None): # key=None delays it!
return a if a <= b else b

a = 1.0
b = 2.0
c = fox(a, b)
print(math.log(c) / 2) #: 0

0 comments on commit 416cc5f

Please sign in to comment.