Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes (Dec 2023) #515

Merged
merged 8 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codon/parser/visitors/simplify/collections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
loops[0].conds.empty();
if (canOptimize) {
auto iter = transform(loops[0].gen);
auto iter = transform(clone(loops[0].gen));
IdExpr *id;
if (iter->getCall() && (id = iter->getCall()->expr->getId())) {
// Turn off this optimization for static items
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/visitors/simplify/cond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ StmtPtr SimplifyVisitor::transformPattern(const ExprPtr &var, ExprPtr pattern,
suite));
} else if (auto eb = pattern->getBinary()) {
// Or pattern
if (eb->op == "|") {
if (eb->op == "|" || eb->op == "||") {
return N<SuiteStmt>(transformPattern(clone(var), clone(eb->lexpr), clone(suite)),
transformPattern(clone(var), clone(eb->rexpr), suite));
}
Expand Down
3 changes: 2 additions & 1 deletion codon/parser/visitors/simplify/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Case 2: function overload
if (auto c = ctx->find(stmt->name)) {
if (c->isFunc() && c->getModule() == ctx->getModule() &&
c->getBaseName() == ctx->getBaseName())
c->getBaseName() == ctx->getBaseName()) {
rootName = c->canonicalName;
}
}
}
if (rootName.empty())
Expand Down
32 changes: 22 additions & 10 deletions codon/parser/visitors/typecheck/access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ void TypecheckVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value);
if (!val) {
// Handle overloads
if (in(ctx->cache->overloads, expr->value))
if (in(ctx->cache->overloads, expr->value)) {
val = ctx->forceFind(getDispatch(expr->value)->ast->name);
}
seqassert(val, "cannot find '{}'", expr->value);
}
unify(expr->type, ctx->instantiate(val->type));
Expand Down Expand Up @@ -402,28 +403,39 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
}
}

if (methodArgs) {
FuncTypePtr bestMethod = nullptr;
bool goDispatch = methodArgs == nullptr;
if (!goDispatch) {
std::vector<FuncTypePtr> m;
// Use the provided arguments to select the best method
if (auto dot = expr->getDot()) {
// Case: method overloads (DotExpr)
auto methods =
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
auto m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
bestMethod = m.empty() ? nullptr : m[0];
m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
} else if (auto id = expr->getId()) {
// Case: function overloads (IdExpr)
std::vector<types::FuncTypePtr> methods;
for (auto &m : ctx->cache->overloads[id->value])
if (!endswith(m.name, ":dispatch"))
methods.push_back(ctx->cache->functions[m.name].type);
std::reverse(methods.begin(), methods.end());
auto m = findMatchingMethods(nullptr, methods, *methodArgs);
bestMethod = m.empty() ? nullptr : m[0];
m = findMatchingMethods(nullptr, methods, *methodArgs);
}
if (bestMethod)
return bestMethod;
} else {

if (m.size() == 1) {
return m[0];
} else if (m.size() > 1) {
for (auto &a : *methodArgs) {
if (auto u = a.value->type->getUnbound()) {
goDispatch = true;
}
}
if (!goDispatch)
return m[0];
}
}

if (goDispatch) {
// If overload is ambiguous, route through a dispatch function
std::string name;
if (auto dot = expr->getDot()) {
Expand Down
89 changes: 60 additions & 29 deletions codon/parser/visitors/typecheck/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,12 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
auto g = expr->lexpr->getType()->getClass();
for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass())
;
if (!g->generics[0].type->getClass()) {
if (!expr->isStatic())
expr->staticValue.type = StaticValue::INT;
unify(expr->type, ctx->getType("bool"));
return nullptr;
}
if (g->generics[0].type->is("NoneType"))
return transform(N<BoolExpr>(true));

Expand Down Expand Up @@ -729,19 +735,26 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
std::pair<bool, ExprPtr>
TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
const ExprPtr &expr, const ExprPtr &index) {
if (!tuple->getRecord())
return {false, nullptr};
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
!startswith(tuple->name, TYPE_PARTIAL)) {
if (tuple->is(TYPE_OPTIONAL)) {
if (auto newTuple = tuple->generics[0].type->getClass()) {
return transformStaticTupleIndex(
newTuple, transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr)), index);
} else {
return {true, nullptr};
bool isStaticString =
expr->isStatic() && expr->staticValue.type == StaticValue::STRING;

if (isStaticString && !expr->staticValue.evaluated) {
return {true, nullptr};
} else if (!isStaticString) {
if (!tuple->getRecord())
return {false, nullptr};
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
!startswith(tuple->name, TYPE_PARTIAL)) {
if (tuple->is(TYPE_OPTIONAL)) {
if (auto newTuple = tuple->generics[0].type->getClass()) {
return transformStaticTupleIndex(
newTuple, transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr)), index);
} else {
return {true, nullptr};
}
}
return {false, nullptr};
}
return {false, nullptr};
}

// Extract the static integer value from expression
Expand All @@ -760,15 +773,15 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
return false;
};

auto classFields = getClassFields(tuple.get());
auto sz = int64_t(tuple->getRecord()->args.size());
int64_t start = 0, stop = sz, step = 1;
auto sz = int64_t(isStaticString ? expr->staticValue.getString().size()
: tuple->getRecord()->args.size());
int64_t start = 0, stop = sz, step = 1, multiple = 0;
if (getInt(&start, index)) {
// Case: `tuple[int]`
auto i = translateIndex(start, stop);
if (i < 0 || i >= stop)
E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i);
return {true, transform(N<DotExpr>(expr, classFields[i].name))};
start = i;
} else if (auto slice = CAST(index, SliceExpr)) {
// Case: `tuple[int:int:int]`
if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) ||
Expand All @@ -781,23 +794,41 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
if (slice->step && !slice->stop)
stop = step > 0 ? sz : -(sz + 1);
sliceAdjustIndices(sz, &start, &stop, step);
multiple = 1;
} else {
return {false, nullptr};
}

// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
if (isStaticString) {
auto str = expr->staticValue.getString();
if (!multiple) {
return {true, transform(N<StringExpr>(str.substr(start, 1)))};
} else {
std::string newStr;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step)
newStr += str[i];
return {true, transform(N<StringExpr>(newStr))};
}
} else {
auto classFields = getClassFields(tuple.get());
if (!multiple) {
return {true, transform(N<DotExpr>(expr, classFields[start].name))};
} else {
// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
}
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
return {true, e};
}
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
return {true, e};
}

return {false, nullptr};
}

/// Follow Python indexing rules for static tuple indices.
Expand Down
24 changes: 18 additions & 6 deletions codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,20 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
!fn->ast->args[si].defaultValue) {
return -1;
}
reordered.push_back({nullptr, 0});
reordered.emplace_back(nullptr, 0);
} else {
seqassert(gi < fn->funcGenerics.size(), "bad fn");
if (!fn->funcGenerics[gi].type->isStaticType() &&
!args[slots[si][0]].value->isType())
return -1;
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
}
gi++;
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.push_back({nullptr, 0});
reordered.emplace_back(nullptr, 0);
} else {
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
}
}
return 0;
Expand Down Expand Up @@ -416,8 +416,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
!expectedClass->getUnion()) {
// Extract union types via __internal__.get_union
if (auto t = realize(expectedClass)) {
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), expr,
N<IdExpr>(t->realizedName())));
auto e = realize(expr->type);
if (!e)
return false;
bool ok = false;
for (auto &ut : e->getUnion()->getRealizationTypes()) {
if (ut->unify(t.get(), nullptr) >= 0) {
ok = true;
break;
}
}
if (ok) {
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), expr,
N<IdExpr>(t->realizedName())));
}
} else {
return false;
}
Expand Down
38 changes: 19 additions & 19 deletions stdlib/internal/python.codon
Original file line number Diff line number Diff line change
Expand Up @@ -439,25 +439,25 @@ def init_handles_dlopen(py_handle: cobj):
PyTuple_Type = dlsym(py_handle, "PyTuple_Type")
PySlice_Type = dlsym(py_handle, "PySlice_Type")
PyCapsule_Type = dlsym(py_handle, "PyCapsule_Type")
PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException"))[0]
PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception"))[0]
PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError"))[0]
PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError"))[0]
PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError"))[0]
PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError"))[0]
PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError"))[0]
PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError"))[0]
PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError"))[0]
PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0]
PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError"))[0]
PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError"))[0]
PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError"))[0]
PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError"))[0]
PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0]
PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError"))[0]
PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration"))[0]
PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError"))[0]
PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit"))[0]
PyExc_BaseException = Ptr[cobj](dlsym(py_handle, "PyExc_BaseException", cobj))[0]
PyExc_Exception = Ptr[cobj](dlsym(py_handle, "PyExc_Exception", cobj))[0]
PyExc_NameError = Ptr[cobj](dlsym(py_handle, "PyExc_NameError", cobj))[0]
PyExc_OSError = Ptr[cobj](dlsym(py_handle, "PyExc_OSError", cobj))[0]
PyExc_IOError = Ptr[cobj](dlsym(py_handle, "PyExc_IOError", cobj))[0]
PyExc_ValueError = Ptr[cobj](dlsym(py_handle, "PyExc_ValueError", cobj))[0]
PyExc_LookupError = Ptr[cobj](dlsym(py_handle, "PyExc_LookupError", cobj))[0]
PyExc_IndexError = Ptr[cobj](dlsym(py_handle, "PyExc_IndexError", cobj))[0]
PyExc_KeyError = Ptr[cobj](dlsym(py_handle, "PyExc_KeyError", cobj))[0]
PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError", cobj))[0]
PyExc_ArithmeticError = Ptr[cobj](dlsym(py_handle, "PyExc_ArithmeticError", cobj))[0]
PyExc_ZeroDivisionError = Ptr[cobj](dlsym(py_handle, "PyExc_ZeroDivisionError", cobj))[0]
PyExc_OverflowError = Ptr[cobj](dlsym(py_handle, "PyExc_OverflowError", cobj))[0]
PyExc_AttributeError = Ptr[cobj](dlsym(py_handle, "PyExc_AttributeError", cobj))[0]
PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError", cobj))[0]
PyExc_NotImplementedError = Ptr[cobj](dlsym(py_handle, "PyExc_NotImplementedError", cobj))[0]
PyExc_StopIteration = Ptr[cobj](dlsym(py_handle, "PyExc_StopIteration", cobj))[0]
PyExc_AssertionError = Ptr[cobj](dlsym(py_handle, "PyExc_AssertionError", cobj))[0]
PyExc_SystemExit = Ptr[cobj](dlsym(py_handle, "PyExc_SystemExit", cobj))[0]

def init_handles_static():
from C import Py_DecRef(cobj) as _Py_DecRef
Expand Down
2 changes: 1 addition & 1 deletion stdlib/itertools.codon
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type):
Make an iterator that returns accumulated sums, or accumulated results
of other binary functions (specified via the optional func argument).
"""
total = None
total: Optional[T] = None
for element in iterable:
total = element if total is None else func(unwrap(total), element)
yield unwrap(total)
Expand Down
4 changes: 4 additions & 0 deletions test/parser/simplify_expr.codon
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ print [i for i in range(3) if i%2 == 0] #: [0, 2]
print [i + j for i in range(1) for j in range(1)] #: [0]
print {i for i in range(3)} #: {0, 1, 2}

#%% comprehension_opt_clone
import sys
z = [i for i in sys.argv]

#%% generator,barebones
z = 3
g = (e for e in range(20) if e % z == 1)
Expand Down
16 changes: 16 additions & 0 deletions test/parser/simplify_stmt.codon
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,22 @@ foo((1, 33)) #: list 4 33
foo((9, ('ACGT', 3))) #: complex ('ACGT', 3)
foo(range(10)) #: else

for op in 'MI=DXSN':
match op:
case 'M' | '=' | 'X':
print('case 1')
case 'I' or 'S':
print('case 2')
case _:
print('case 3')
#: case 1
#: case 2
#: case 1
#: case 3
#: case 1
#: case 2
#: case 3

#%% match_err_1,barebones
match [1, 2]:
case [1, ..., 2, ..., 3]: pass
Expand Down
23 changes: 23 additions & 0 deletions test/parser/types.codon
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,18 @@ foo("he", z) #: hewoo
X(s='lolo') #: lolo lolo Int[1]
X('abc') #: abc abc Int[2]


def foo2(x: Static[str]):
print(x, x.__is_static__)
s: Static[str] = "abcdefghijkl"
foo2(s) #: abcdefghijkl True
foo2(s[1]) #: b True
foo2(s[1:5]) #: bcde True
foo2(s[10:50]) #: kl True
foo2(s[1:30:3]) #: behk True
foo2(s[::-1]) #: lkjihgfedcba True


#%% static_getitem
print Int[staticlen("ee")].__class__.__name__ #: Int[2]

Expand Down Expand Up @@ -2063,3 +2075,14 @@ print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz'))
c = B()
print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz'))
#: True False True


#%% delayed_dispatch
import math
def fox(a, b, key=None): # key=None delays it!
return a if a <= b else b

a = 1.0
b = 2.0
c = fox(a, b)
print(math.log(c) / 2) #: 0