Skip to content

Commit

Permalink
Property setters (#501)
Browse files Browse the repository at this point in the history
* Fix __from_gpu_new__

* Fix GPU tests

* Update GPU debug codegen

* Add will-return attribute for GPU compilation

* Fix isinstance on unresolved types

* Fix union type instantiation and pendingRealizations placement

* Add float16, bfloat16 and float128 IR types

* Add float16, bfloat16 and float128 types

* Mark complex64 as no-python

* Fix float methods

* Add float tests

* Disable some float tests

* Fix bitset in reaching definitions analysis

* Fix static bool unification

* Add property setters

* Remove log

* Add Union hasattr support

* Fix union bugs; Move union logic to internal.codon; Add fn_can_call for any expression

* Fix isinstance(x, Union)

---------

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
  • Loading branch information
arshajii and inumanag committed Dec 4, 2023
1 parent 78a3d7d commit b4a3f89
Show file tree
Hide file tree
Showing 19 changed files with 237 additions and 193 deletions.
3 changes: 2 additions & 1 deletion codon/compiler/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code,
auto fo = fopen("_dump_typecheck.sexp", "w");
fmt::print(fo, "{}\n", typechecked->toString(0));
for (auto &f : cache->functions)
for (auto &r : f.second.realizations)
for (auto &r : f.second.realizations) {
fmt::print(fo, "{}\n", r.second->ast->toString(0));
}
fclose(fo);
}

Expand Down
4 changes: 2 additions & 2 deletions codon/parser/ast/stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ ForStmt::ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite,
ExprPtr decorator, std::vector<CallExpr::Arg> ompArgs)
: Stmt(), var(std::move(var)), iter(std::move(iter)), suite(std::move(suite)),
elseSuite(std::move(elseSuite)), decorator(std::move(decorator)),
ompArgs(std::move(ompArgs)), wrapped(false) {}
ompArgs(std::move(ompArgs)), wrapped(false), flat(false) {}
ForStmt::ForStmt(const ForStmt &stmt)
: Stmt(stmt), var(ast::clone(stmt.var)), iter(ast::clone(stmt.iter)),
suite(ast::clone(stmt.suite)), elseSuite(ast::clone(stmt.elseSuite)),
decorator(ast::clone(stmt.decorator)), ompArgs(ast::clone_nop(stmt.ompArgs)),
wrapped(stmt.wrapped) {}
wrapped(stmt.wrapped), flat(stmt.flat) {}
std::string ForStmt::toString(int indent) const {
std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " ";
std::string attr;
Expand Down
2 changes: 2 additions & 0 deletions codon/parser/ast/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ struct ForStmt : public Stmt {

/// Indicates if iter was wrapped with __iter__() call.
bool wrapped;
/// True if there are no break/continue within the loop
bool flat;

ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite = nullptr,
ExprPtr decorator = nullptr, std::vector<CallExpr::Arg> ompArgs = {});
Expand Down
3 changes: 3 additions & 0 deletions codon/parser/visitors/simplify/ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ struct SimplifyContext : public Context<SimplifyItem> {
/// List of variables "seen" before their assignment within a loop.
/// Used to dominate variables that are updated within a loop.
std::unordered_set<std::string> seenVars;
/// False if a loop has continue/break statement. Used for flattening static
/// loops.
bool flat = true;
};
std::vector<Loop> loops;

Expand Down
4 changes: 4 additions & 0 deletions codon/parser/visitors/simplify/loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace codon::ast {
void SimplifyVisitor::visit(ContinueStmt *stmt) {
if (!ctx->getBase()->getLoop())
E(Error::EXPECTED_LOOP, stmt, "continue");
ctx->getBase()->getLoop()->flat = false;
}

/// Ensure that `break` is in a loop.
Expand All @@ -28,6 +29,7 @@ void SimplifyVisitor::visit(ContinueStmt *stmt) {
void SimplifyVisitor::visit(BreakStmt *stmt) {
if (!ctx->getBase()->getLoop())
E(Error::EXPECTED_LOOP, stmt, "break");
ctx->getBase()->getLoop()->flat = false;
if (!ctx->getBase()->getLoop()->breakVar.empty()) {
resultStmt = N<SuiteStmt>(
transform(N<AssignStmt>(N<IdExpr>(ctx->getBase()->getLoop()->breakVar),
Expand Down Expand Up @@ -116,6 +118,8 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
stmt->suite = transform(N<SuiteStmt>(stmts));
}

if (ctx->getBase()->getLoop()->flat)
stmt->flat = true;
// Complete while-else clause
if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) {
resultStmt = N<SuiteStmt>(assign, N<ForStmt>(*stmt),
Expand Down
10 changes: 7 additions & 3 deletions codon/parser/visitors/typecheck/access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,15 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,

// Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)`
if (typ->getUnion()) {
if (!typ->canRealize())
return nullptr; // delay!
// bool isMember = false;
// for (auto &t: typ->getUnion()->getRealizationTypes())
// if (ctx->findMethod(t.get(), expr->member).empty())
return transform(N<CallExpr>(
N<IdExpr>("__internal__.get_union_method:0"),
N<IdExpr>("__internal__.union_member:0"),
std::vector<CallExpr::Arg>{{"union", expr->expr},
{"method", N<StringExpr>(expr->member)},
{"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}}));
{"member", N<StringExpr>(expr->member)}}));
}

// For debugging purposes:
Expand Down
25 changes: 21 additions & 4 deletions codon/parser/visitors/typecheck/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
return transform(N<BoolExpr>(typ->getRecord() != nullptr));
} else if (typExpr->isId("ByRef")) {
return transform(N<BoolExpr>(typ->getRecord() == nullptr));
} else if (typExpr->isId("Union")) {
return transform(N<BoolExpr>(typ->getUnion() != nullptr));
} else if (!typExpr->type->getUnion() && typ->getUnion()) {
auto unionTypes = typ->getUnion()->getRealizationTypes();
int tag = -1;
Expand Down Expand Up @@ -997,10 +999,6 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
if (!typ)
return {true, nullptr};

auto fn = expr->args[0].value->type->getFunc();
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());

auto inargs = unpackTupleTypes(expr->args[1].value);
auto kwargs = unpackTupleTypes(expr->args[2].value);
seqassert(inargs && kwargs, "bad call to fn_can_call");
Expand All @@ -1014,6 +1012,25 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
callArgs.push_back({a.first, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a.second);
}

auto fn = expr->args[0].value->type->getFunc();
if (!fn) {
bool canCompile = true;
// Special case: not a function, just try compiling it!
auto ocache = *(ctx->cache);
auto octx = *ctx;
try {
transform(N<CallExpr>(clone(expr->args[0].value),
N<StarExpr>(clone(expr->args[1].value)),
N<KeywordStarExpr>(clone(expr->args[2].value))));
} catch (const exc::ParserException &e) {
// LOG("{}", e.what());
canCompile = false;
*ctx = octx;
*(ctx->cache) = ocache;
}
return {true, transform(N<BoolExpr>(canCompile))};
}
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) {
expr->staticValue.type = StaticValue::INT;
Expand Down
10 changes: 8 additions & 2 deletions codon/parser/visitors/typecheck/class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,20 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
if (startswith(typeName, TYPE_KWTUPLE))
stmt->getClass()->suite = N<SuiteStmt>(getItem, contains, getDef);

// Add repr for KwArgs:
// Add repr and call for partials:
// `def __repr__(self): return __magic__.repr_partial(self)`
auto repr = N<FunctionStmt>(
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>("__magic__"), "repr_partial"), N<IdExpr>("self")))));
auto pcall = N<FunctionStmt>(
"__call__", nullptr,
std::vector<Param>{Param{"self"}, Param{"*args"}, Param{"**kwargs"}},
N<SuiteStmt>(
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("self"), N<StarExpr>(N<IdExpr>("args")),
N<KeywordStarExpr>(N<IdExpr>("kwargs"))))));
if (startswith(typeName, TYPE_PARTIAL))
stmt->getClass()->suite = repr;
stmt->getClass()->suite = N<SuiteStmt>(repr, pcall);

// Simplify in the standard library context and type check
stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt,
Expand Down
16 changes: 16 additions & 0 deletions codon/parser/visitors/typecheck/error.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ using namespace types;
/// f = exc; ...; break # PyExc
/// raise```
void TypecheckVisitor::visit(TryStmt *stmt) {
// TODO: static can-compile check
// if (stmt->catches.size() == 1 && stmt->catches[0].var.empty() &&
// stmt->catches[0].exc->isId("std.internal.types.error.StaticCompileError")) {
// /// TODO: this is right now _very_ dangerous; inferred types here will remain!
// bool compiled = true;
// try {
// auto nctx = std::make_shared<TypeContext>(*ctx);
// TypecheckVisitor(nctx).transform(clone(stmt->suite));
// } catch (const exc::ParserException &exc) {
// compiled = false;
// }
// resultStmt = compiled ? transform(stmt->suite) :
// transform(stmt->catches[0].suite); LOG("testing!! {} {}", getSrcInfo(),
// compiled); return;
// }

ctx->blockLevel++;
transform(stmt->suite);
ctx->blockLevel--;
Expand Down
4 changes: 4 additions & 0 deletions codon/parser/visitors/typecheck/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ types::FuncTypePtr TypecheckVisitor::makeFunctionType(FunctionStmt *stmt) {
ctx->typecheckLevel++;
if (stmt->ret) {
unify(baseType->generics[1].type, transformType(stmt->ret)->getType());
if (stmt->ret->isId("Union")) {
baseType->generics[1].type->getUnion()->generics[0].type->getUnbound()->kind =
LinkType::Generic;
}
} else {
generics.push_back(unify(baseType->generics[1].type, ctx->getUnbound()));
}
Expand Down
142 changes: 23 additions & 119 deletions codon/parser/visitors/typecheck/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
}

if (result->isDone()) {
// Special union case: if union cannot be inferred return type is Union[NoneType]
if (auto tr = ctx->getRealizationBase()->returnType) {
if (auto tu = tr->getUnion()) {
if (!tu->isSealed()) {
if (tu->pendingTypes[0]->getLink() &&
tu->pendingTypes[0]->getLink()->kind == LinkType::Unbound) {
tu->addType(ctx->forceFind("NoneType")->type);
tu->seal();
}
}
}
}
break;
} else if (changedNodes) {
continue;
Expand Down Expand Up @@ -353,17 +365,17 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)

if (!ret) {
realizations.erase(key);
ctx->realizationBases.pop_back();
ctx->popBlock();
ctx->typecheckLevel--;
getLogger().level--;
if (!startswith(ast->name, "._lambda")) {
// Lambda typecheck failures are "ignored" as they are treated as statements,
// not functions.
// TODO: generalize this further.
// LOG("{}", ast->suite->toString(2));
error("cannot typecheck the program");
}
ctx->realizationBases.pop_back();
ctx->popBlock();
ctx->typecheckLevel--;
getLogger().level--;
return nullptr; // inference must be delayed
}

Expand Down Expand Up @@ -836,127 +848,19 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
N<IdExpr>("__internal__.new_union:0"), N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) {
// Special case: __internal__.new_union
// def __internal__.new_union(value, U[T0, ..., TN]):
// if isinstance(value, T0):
// return __internal__.union_make(0, value, U[T0, ..., TN])
// if isinstance(value, Union[T0]):
// return __internal__.union_make(
// 0, __internal__.get_union(value, T0), U[T0, ..., TN])
// ... <for all T0...TN> ...
// compile_error("invalid union constructor")
auto unionType = type->funcGenerics[0].type->getUnion();
auto unionTypes = unionType->getRealizationTypes();

auto objVar = ast->args[0].name;
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto &t : unionTypes) {
suite->stmts.push_back(N<IfStmt>(
N<CallExpr>(N<IdExpr>("isinstance"), N<IdExpr>(objVar),
NT<IdExpr>(t->realizedName())),
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_make:0"),
N<IntExpr>(tag), N<IdExpr>(objVar),
N<IdExpr>(unionType->realizedTypeName())))));
// Check for Union[T]
suite->stmts.push_back(N<IfStmt>(
N<CallExpr>(
N<IdExpr>("isinstance"), N<IdExpr>(objVar),
NT<InstantiateExpr>(NT<IdExpr>("Union"),
std::vector<ExprPtr>{NT<IdExpr>(t->realizedName())})),
N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.union_make:0"), N<IntExpr>(tag),
N<CallExpr>(N<IdExpr>("__internal__.get_union:0"),
N<IdExpr>(objVar), NT<IdExpr>(t->realizedName())),
N<IdExpr>(unionType->realizedTypeName())))));
tag++;
}
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("compile_error"), N<StringExpr>("invalid union constructor"))));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union:0")) {
// Special case: __internal__.get_union
// def __internal__.new_union(union: Union[T0,...,TN], T):
// if __internal__.union_get_tag(union) == 0:
// return __internal__.union_get_data(union, T0)
// ... <for all T0...TN>
// raise TypeError("getter")
auto unionType = type->getArgTypes()[0]->getUnion();
auto unionTypes = unionType->getRealizationTypes();

auto targetType = type->funcGenerics[0].type;
auto selfVar = ast->args[0].name;
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto t : unionTypes) {
if (t->realizedName() == targetType->realizedName()) {
suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag)),
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar),
NT<IdExpr>(t->realizedName())))));
}
tag++;
}
suite->stmts.push_back(
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union getter"))));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__._get_union_method:0")) {
// def __internal__._get_union_method(union: Union[T0,...,TN], method, *args, **kw):
// if __internal__.union_get_tag(union) == 0:
// return __internal__.union_get_data(union, T0).method(*args, **kw)
// ... <for all T0...TN>
// raise TypeError("call")
auto szt = type->funcGenerics[0].type->getStatic();
auto fnName = szt->evaluate().getString();
auto unionType = type->getArgTypes()[0]->getUnion();
auto unionTypes = unionType->getRealizationTypes();

auto selfVar = ast->args[0].name;
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto &t : unionTypes) {
auto callee =
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName);
auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1)));
auto kwargs = N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)));
std::vector<CallExpr::Arg> callArgs;
ExprPtr check =
N<CallExpr>(N<IdExpr>("hasattr"), NT<IdExpr>(t->realizedName()),
N<StringExpr>(fnName), args->clone(), kwargs->clone());
suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(
check, "&&",
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag))),
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(callee, args, kwargs)))));
tag++;
}
suite->stmts.push_back(
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));

auto ret = ctx->instantiate(ctx->getType("Union"));
unify(type->getRetType(), ret);
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
// def __internal__.get_union_first(union: Union[T0]):
} else if (startswith(ast->name, "__internal__.get_union_tag:0")) {
// def __internal__.get_union_tag(union: Union, tag: Static[int]):
// return __internal__.union_get_data(union, T0)
auto szt = type->funcGenerics[0].type->getStatic();
auto tag = szt->evaluate().getInt();
auto unionType = type->getArgTypes()[0]->getUnion();
auto unionTypes = unionType->getRealizationTypes();

if (tag < 0 || tag >= unionTypes.size())
E(Error::CUSTOM, getSrcInfo(), "bad union tag");
auto selfVar = ast->args[0].name;
auto suite = N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"), N<IdExpr>(selfVar),
NT<IdExpr>(unionTypes[0]->realizedName()))));
NT<IdExpr>(unionTypes[tag]->realizedName()))));
ast->suite = suite;
}
return ast;
Expand Down

0 comments on commit b4a3f89

Please sign in to comment.