From bac6ae58dd179694c396c3b108a60d582fb5e44e Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Tue, 17 Jan 2023 10:21:59 -0500 Subject: [PATCH] Generator argument optimization (and more) (#175) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix ABI incompatibilities * Fix codon-jit on macOS * Fix scoping bugs * Fix .codon detection * Handle static arguments in magic methods; Update simd; Fix misc. bugs * Avoid partial calls with generators * clang-format * Add generator-argument optimization * Fix typo * Fix omp test * Make sure sum() does not call __iadd__ * Clarify difference in docs * Fix any/all generator pass * Fix InstantiateExpr simplification; Support .py as module extension * clang-format * Bump version Co-authored-by: Ibrahim Numanagić --- CMakeLists.txt | 6 +- codon/cir/func.cpp | 2 +- codon/cir/llvm/llvisitor.cpp | 6 +- codon/cir/llvm/llvisitor.h | 4 +- codon/cir/llvm/optimize.h | 2 +- codon/cir/module.cpp | 2 +- codon/cir/transform/manager.cpp | 2 + codon/cir/transform/pythonic/generator.cpp | 235 ++++++++++++++++++ codon/cir/transform/pythonic/generator.h | 25 ++ codon/cir/types/types.cpp | 2 +- codon/cir/types/types.h | 2 +- codon/compiler/compiler.h | 6 +- codon/compiler/engine.cpp | 2 +- codon/compiler/engine.h | 2 +- codon/compiler/jit.h | 6 +- codon/dsl/dsl.h | 2 +- codon/dsl/plugins.h | 2 +- codon/parser/cache.h | 2 +- codon/parser/common.cpp | 19 +- codon/parser/visitors/simplify/access.cpp | 4 + .../parser/visitors/simplify/collections.cpp | 65 +++-- codon/parser/visitors/simplify/function.cpp | 4 +- codon/parser/visitors/simplify/op.cpp | 10 +- codon/parser/visitors/translate/translate.cpp | 4 +- codon/parser/visitors/translate/translate.h | 2 +- .../parser/visitors/translate/translate_ctx.h | 4 +- codon/parser/visitors/typecheck/call.cpp | 16 +- codon/parser/visitors/typecheck/function.cpp | 6 +- codon/parser/visitors/typecheck/infer.cpp | 2 +- codon/parser/visitors/typecheck/op.cpp | 8 +- codon/parser/visitors/typecheck/typecheck.cpp | 13 + codon/parser/visitors/typecheck/typecheck.h | 3 + docs/intro/differences.md | 4 +- extra/python/codon/decorator.py | 14 +- stdlib/experimental/simd.codon | 6 +- stdlib/internal/builtin.codon | 29 ++- stdlib/internal/types/collections/list.codon | 7 - test/core/bltin.codon | 65 +++++ test/main.cpp | 6 +- test/parser/simplify_expr.codon | 10 +- test/transform/omp.codon | 2 +- 41 files changed, 514 insertions(+), 99 deletions(-) create mode 100644 codon/cir/transform/pythonic/generator.cpp create mode 100644 codon/cir/transform/pythonic/generator.h diff --git a/CMakeLists.txt b/CMakeLists.txt index bd4f4744..da3f0243 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,10 @@ cmake_minimum_required(VERSION 3.14) project( Codon - VERSION "0.15.3" + VERSION "0.15.4" HOMEPAGE_URL "https://github.com/exaloop/codon" DESCRIPTION "high-performance, extensible Python compiler") -set(CODON_JIT_PYTHON_VERSION "0.1.1") +set(CODON_JIT_PYTHON_VERSION "0.1.2") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" "${PROJECT_SOURCE_DIR}/codon/config/config.h") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" @@ -197,6 +197,7 @@ set(CODON_HPPFILES codon/cir/transform/parallel/schedule.h codon/cir/transform/pass.h codon/cir/transform/pythonic/dict.h + codon/cir/transform/pythonic/generator.h codon/cir/transform/pythonic/io.h codon/cir/transform/pythonic/list.h codon/cir/transform/pythonic/str.h @@ -304,6 +305,7 @@ set(CODON_CPPFILES codon/cir/transform/parallel/schedule.cpp codon/cir/transform/pass.cpp codon/cir/transform/pythonic/dict.cpp + codon/cir/transform/pythonic/generator.cpp codon/cir/transform/pythonic/io.cpp codon/cir/transform/pythonic/list.cpp codon/cir/transform/pythonic/str.cpp diff --git a/codon/cir/func.cpp b/codon/cir/func.cpp index 58b6a9de..c3689d5e 100644 --- a/codon/cir/func.cpp +++ b/codon/cir/func.cpp @@ -4,12 +4,12 @@ #include -#include "codon/parser/common.h" #include "codon/cir/module.h" #include "codon/cir/util/iterators.h" #include "codon/cir/util/operator.h" #include "codon/cir/util/visitor.h" #include "codon/cir/var.h" +#include "codon/parser/common.h" namespace codon { namespace ir { diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index 00c01348..f4050602 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -10,13 +10,13 @@ #include #include +#include "codon/cir/dsl/codegen.h" +#include "codon/cir/llvm/optimize.h" +#include "codon/cir/util/irtools.h" #include "codon/compiler/debug_listener.h" #include "codon/compiler/memory_manager.h" #include "codon/parser/common.h" #include "codon/runtime/lib.h" -#include "codon/cir/dsl/codegen.h" -#include "codon/cir/llvm/optimize.h" -#include "codon/cir/util/irtools.h" #include "codon/util/common.h" namespace codon { diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index c5e39da8..499bf98a 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -2,9 +2,9 @@ #pragma once -#include "codon/dsl/plugins.h" -#include "codon/cir/llvm/llvm.h" #include "codon/cir/cir.h" +#include "codon/cir/llvm/llvm.h" +#include "codon/dsl/plugins.h" #include "codon/util/common.h" #include diff --git a/codon/cir/llvm/optimize.h b/codon/cir/llvm/optimize.h index 5029aff3..3ee9c77c 100644 --- a/codon/cir/llvm/optimize.h +++ b/codon/cir/llvm/optimize.h @@ -4,8 +4,8 @@ #include -#include "codon/dsl/plugins.h" #include "codon/cir/llvm/llvm.h" +#include "codon/dsl/plugins.h" namespace codon { namespace ir { diff --git a/codon/cir/module.cpp b/codon/cir/module.cpp index 7997c78b..3588a455 100644 --- a/codon/cir/module.cpp +++ b/codon/cir/module.cpp @@ -5,8 +5,8 @@ #include #include -#include "codon/parser/cache.h" #include "codon/cir/func.h" +#include "codon/parser/cache.h" namespace codon { namespace ir { diff --git a/codon/cir/transform/manager.cpp b/codon/cir/transform/manager.cpp index 0cafa3d1..b7ebab46 100644 --- a/codon/cir/transform/manager.cpp +++ b/codon/cir/transform/manager.cpp @@ -18,6 +18,7 @@ #include "codon/cir/transform/parallel/openmp.h" #include "codon/cir/transform/pass.h" #include "codon/cir/transform/pythonic/dict.h" +#include "codon/cir/transform/pythonic/generator.h" #include "codon/cir/transform/pythonic/io.h" #include "codon/cir/transform/pythonic/list.h" #include "codon/cir/transform/pythonic/str.h" @@ -162,6 +163,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) { registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); + registerPass(std::make_unique()); registerPass(std::make_unique()); // lowering diff --git a/codon/cir/transform/pythonic/generator.cpp b/codon/cir/transform/pythonic/generator.cpp new file mode 100644 index 00000000..b1be1a8d --- /dev/null +++ b/codon/cir/transform/pythonic/generator.cpp @@ -0,0 +1,235 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#include "generator.h" + +#include + +#include "codon/cir/util/cloning.h" +#include "codon/cir/util/irtools.h" +#include "codon/cir/util/matching.h" + +namespace codon { +namespace ir { +namespace transform { +namespace pythonic { +namespace { +bool isSum(Func *f) { + return f && f->getName().rfind("std.internal.builtin.sum:", 0) == 0; +} + +bool isAny(Func *f) { + return f && f->getName().rfind("std.internal.builtin.any:", 0) == 0; +} + +bool isAll(Func *f) { + return f && f->getName().rfind("std.internal.builtin.all:", 0) == 0; +} + +// Replaces yields with updates to the accumulator variable. +struct GeneratorSumTransformer : public util::Operator { + Var *accumulator; + bool valid; + + explicit GeneratorSumTransformer(Var *accumulator) + : util::Operator(), accumulator(accumulator), valid(true) {} + + void handle(YieldInstr *v) override { + auto *M = v->getModule(); + auto *val = v->getValue(); + if (!val) { + valid = false; + return; + } + + Value *rhs = val; + if (val->getType()->is(M->getBoolType())) { + rhs = M->Nr(rhs, M->getInt(1), M->getInt(0)); + } + + Value *add = *M->Nr(accumulator) + *rhs; + if (!add || !add->getType()->is(accumulator->getType())) { + valid = false; + return; + } + + auto *assign = M->Nr(accumulator, add); + v->replaceAll(assign); + } + + void handle(ReturnInstr *v) override { + auto *M = v->getModule(); + auto *newReturn = M->Nr(M->Nr(accumulator)); + see(newReturn); + v->replaceAll(util::series(v->getValue(), newReturn)); + } + + void handle(YieldInInstr *v) override { valid = false; } +}; + +// Replaces yields with conditional returns of the any/all answer. +struct GeneratorAnyAllTransformer : public util::Operator { + bool any; // true=any, false=all + bool valid; + + explicit GeneratorAnyAllTransformer(bool any) + : util::Operator(), any(any), valid(true) {} + + void handle(YieldInstr *v) override { + auto *M = v->getModule(); + auto *val = v->getValue(); + auto *valBool = val ? (*M->getBoolType())(*val) : nullptr; + if (!valBool) { + valid = false; + return; + } else if (!any) { + valBool = M->Nr(valBool, M->getBool(false), M->getBool(true)); + } + + auto *newReturn = M->Nr(M->getBool(any)); + see(newReturn); + auto *rep = M->Nr(valBool, util::series(newReturn)); + v->replaceAll(rep); + } + + void handle(ReturnInstr *v) override { + if (saw(v)) + return; + auto *M = v->getModule(); + auto *newReturn = M->Nr(M->getBool(!any)); + see(newReturn); + v->replaceAll(util::series(v->getValue(), newReturn)); + } + + void handle(YieldInInstr *v) override { valid = false; } +}; + +Func *genToSum(BodiedFunc *gen, types::Type *startType, types::Type *outType) { + if (!gen || !gen->isGenerator()) + return nullptr; + + auto *M = gen->getModule(); + auto *fn = M->Nr("__sum_wrapper"); + auto *genType = cast(gen->getType()); + if (!genType) + return nullptr; + + std::vector argTypes(genType->begin(), genType->end()); + argTypes.push_back(startType); + + std::vector names; + for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) { + names.push_back((*it)->getName()); + } + names.push_back("start"); + + auto *fnType = M->getFuncType(outType, argTypes); + fn->realize(fnType, names); + + std::unordered_map argRemap; + for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin(); + it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) { + argRemap.emplace((*it1)->getId(), *it2); + } + + util::CloneVisitor cv(M); + auto *body = cast(cv.clone(gen->getBody(), fn, argRemap)); + fn->setBody(body); + + Value *init = M->Nr(fn->arg_back()); + if (startType->is(M->getIntType()) && outType->is(M->getFloatType())) + init = (*M->getFloatType())(*init); + + if (!init || !init->getType()->is(outType)) + return nullptr; + + auto *accumulator = util::makeVar(init, body, fn, /*prepend=*/true)->getVar(); + GeneratorSumTransformer xgen(accumulator); + fn->accept(xgen); + body->push_back(M->Nr(M->Nr(accumulator))); + + if (!xgen.valid) + return nullptr; + + return fn; +} + +Func *genToAnyAll(BodiedFunc *gen, bool any) { + if (!gen || !gen->isGenerator()) + return nullptr; + + auto *M = gen->getModule(); + auto *fn = M->Nr(any ? "__any_wrapper" : "__all_wrapper"); + auto *genType = cast(gen->getType()); + + std::vector argTypes(genType->begin(), genType->end()); + std::vector names; + for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) { + names.push_back((*it)->getName()); + } + + auto *fnType = M->getFuncType(M->getBoolType(), argTypes); + fn->realize(fnType, names); + + std::unordered_map argRemap; + for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin(); + it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) { + argRemap.emplace((*it1)->getId(), *it2); + } + + util::CloneVisitor cv(M); + auto *body = cast(cv.clone(gen->getBody(), fn, argRemap)); + fn->setBody(body); + + GeneratorAnyAllTransformer xgen(any); + fn->accept(xgen); + body->push_back(M->Nr(M->getBool(!any))); + + if (!xgen.valid) + return nullptr; + + return fn; +} +} // namespace + +const std::string GeneratorArgumentOptimization::KEY = + "core-pythonic-generator-argument-opt"; + +void GeneratorArgumentOptimization::handle(CallInstr *v) { + auto *M = v->getModule(); + auto *func = util::getFunc(v->getCallee()); + + if (isSum(func) && v->numArgs() == 2) { + auto *call = cast(v->front()); + if (!call) + return; + + auto *gen = util::getFunc(call->getCallee()); + auto *start = v->back(); + + if (auto *fn = genToSum(cast(gen), start->getType(), v->getType())) { + std::vector args(call->begin(), call->end()); + args.push_back(start); + v->replaceAll(util::call(fn, args)); + } + } else { + bool any = isAny(func), all = isAll(func); + if (!(any || all) || v->numArgs() != 1 || !v->getType()->is(M->getBoolType())) + return; + + auto *call = cast(v->front()); + if (!call) + return; + + auto *gen = util::getFunc(call->getCallee()); + + if (auto *fn = genToAnyAll(cast(gen), any)) { + std::vector args(call->begin(), call->end()); + v->replaceAll(util::call(fn, args)); + } + } +} + +} // namespace pythonic +} // namespace transform +} // namespace ir +} // namespace codon diff --git a/codon/cir/transform/pythonic/generator.h b/codon/cir/transform/pythonic/generator.h new file mode 100644 index 00000000..663f8868 --- /dev/null +++ b/codon/cir/transform/pythonic/generator.h @@ -0,0 +1,25 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#pragma once + +#include "codon/cir/transform/pass.h" + +namespace codon { +namespace ir { +namespace transform { +namespace pythonic { + +/// Pass to optimize passing a generator to some built-in functions +/// like sum(), any() or all(), which will be converted to regular +/// for-loops. +class GeneratorArgumentOptimization : public OperatorPass { +public: + static const std::string KEY; + std::string getKey() const override { return KEY; } + void handle(CallInstr *v) override; +}; + +} // namespace pythonic +} // namespace transform +} // namespace ir +} // namespace codon diff --git a/codon/cir/types/types.cpp b/codon/cir/types/types.cpp index 5c53dee4..078230b0 100644 --- a/codon/cir/types/types.cpp +++ b/codon/cir/types/types.cpp @@ -6,12 +6,12 @@ #include #include -#include "codon/parser/cache.h" #include "codon/cir/module.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/iterators.h" #include "codon/cir/util/visitor.h" #include "codon/cir/value.h" +#include "codon/parser/cache.h" #include namespace codon { diff --git a/codon/cir/types/types.h b/codon/cir/types/types.h index 17ae8926..fff29aa8 100644 --- a/codon/cir/types/types.h +++ b/codon/cir/types/types.h @@ -8,10 +8,10 @@ #include #include -#include "codon/parser/ast.h" #include "codon/cir/base.h" #include "codon/cir/util/packs.h" #include "codon/cir/util/visitor.h" +#include "codon/parser/ast.h" #include #include diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index 5d5712e6..cc13bfa7 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -7,12 +7,12 @@ #include #include -#include "codon/compiler/error.h" -#include "codon/dsl/plugins.h" -#include "codon/parser/cache.h" #include "codon/cir/llvm/llvisitor.h" #include "codon/cir/module.h" #include "codon/cir/transform/manager.h" +#include "codon/compiler/error.h" +#include "codon/dsl/plugins.h" +#include "codon/parser/cache.h" namespace codon { diff --git a/codon/compiler/engine.cpp b/codon/compiler/engine.cpp index b7741d69..d3962c98 100644 --- a/codon/compiler/engine.cpp +++ b/codon/compiler/engine.cpp @@ -2,8 +2,8 @@ #include "engine.h" -#include "codon/compiler/memory_manager.h" #include "codon/cir/llvm/optimize.h" +#include "codon/compiler/memory_manager.h" namespace codon { namespace jit { diff --git a/codon/compiler/engine.h b/codon/compiler/engine.h index 32290838..0edf2356 100644 --- a/codon/compiler/engine.h +++ b/codon/compiler/engine.h @@ -5,8 +5,8 @@ #include #include -#include "codon/compiler/debug_listener.h" #include "codon/cir/llvm/llvm.h" +#include "codon/compiler/debug_listener.h" namespace codon { namespace jit { diff --git a/codon/compiler/jit.h b/codon/compiler/jit.h index e873b465..8f9ff226 100644 --- a/codon/compiler/jit.h +++ b/codon/compiler/jit.h @@ -7,14 +7,14 @@ #include #include +#include "codon/cir/llvm/llvisitor.h" +#include "codon/cir/transform/manager.h" +#include "codon/cir/var.h" #include "codon/compiler/compiler.h" #include "codon/compiler/engine.h" #include "codon/compiler/error.h" #include "codon/parser/cache.h" #include "codon/runtime/lib.h" -#include "codon/cir/llvm/llvisitor.h" -#include "codon/cir/transform/manager.h" -#include "codon/cir/var.h" #include "codon/compiler/jit_extern.h" diff --git a/codon/dsl/dsl.h b/codon/dsl/dsl.h index 7e095d5b..8bdeb8ca 100644 --- a/codon/dsl/dsl.h +++ b/codon/dsl/dsl.h @@ -2,10 +2,10 @@ #pragma once -#include "codon/parser/cache.h" #include "codon/cir/cir.h" #include "codon/cir/transform/manager.h" #include "codon/cir/transform/pass.h" +#include "codon/parser/cache.h" #include "llvm/Passes/PassBuilder.h" #include #include diff --git a/codon/dsl/plugins.h b/codon/dsl/plugins.h index c3015f8d..005c4f5f 100644 --- a/codon/dsl/plugins.h +++ b/codon/dsl/plugins.h @@ -7,9 +7,9 @@ #include #include +#include "codon/cir/util/iterators.h" #include "codon/compiler/error.h" #include "codon/dsl/dsl.h" -#include "codon/cir/util/iterators.h" #include "llvm/Support/DynamicLibrary.h" namespace codon { diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 1fe4ea16..d230bac1 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -8,10 +8,10 @@ #include #include +#include "codon/cir/cir.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/ctx.h" -#include "codon/cir/cir.h" #define FILE_GENERATED "" #define MODULE_MAIN "__main__" diff --git a/codon/parser/common.cpp b/codon/parser/common.cpp index cb8bb419..88a250a0 100644 --- a/codon/parser/common.cpp +++ b/codon/parser/common.cpp @@ -207,9 +207,12 @@ std::string library_path() { namespace { -void addPath(std::vector &paths, const std::string &path) { - if (llvm::sys::fs::exists(path)) +bool addPath(std::vector &paths, const std::string &path) { + if (llvm::sys::fs::exists(path)) { paths.push_back(getAbsolutePath(path)); + return true; + } + return false; } std::vector getStdLibPaths(const std::string &argv0, @@ -244,7 +247,9 @@ ImportFile getRoot(const std::string argv0, const std::vector &plug } if (!isStdLib && startswith(s, module0Root)) root = module0Root; - const std::string ext = ".codon"; + std::string ext = ".codon"; + if (!((root.empty() || startswith(s, root)) && endswith(s, ext))) + ext = ".py"; seqassertn((root.empty() || startswith(s, root)) && endswith(s, ext), "bad path substitution: {}, {}", s, root); auto module = s.substr(root.size() + 1, s.size() - root.size() - ext.size() - 1); @@ -280,6 +285,14 @@ std::shared_ptr getImportFile(const std::string &argv0, path = llvm::SmallString<128>(parentRelativeTo); llvm::sys::path::append(path, what, "__init__.codon"); addPath(paths, std::string(path)); + + path = llvm::SmallString<128>(parentRelativeTo); + llvm::sys::path::append(path, what); + llvm::sys::path::replace_extension(path, "py"); + addPath(paths, std::string(path)); + path = llvm::SmallString<128>(parentRelativeTo); + llvm::sys::path::append(path, what, "__init__.py"); + addPath(paths, std::string(path)); } } for (auto &p : getStdLibPaths(argv0, plugins)) { diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp index 4e3459cb..5013c4a8 100644 --- a/codon/parser/visitors/simplify/access.cpp +++ b/codon/parser/visitors/simplify/access.cpp @@ -14,6 +14,10 @@ using namespace codon::error; namespace codon::ast { void SimplifyVisitor::visit(IdExpr *expr) { + if (startswith(expr->value, TYPE_TUPLE)) { + expr->markType(); + return; + } auto val = ctx->findDominatingBinding(expr->value); if (!val) E(Error::ID_NOT_FOUND, expr, expr->value); diff --git a/codon/parser/visitors/simplify/collections.cpp b/codon/parser/visitors/simplify/collections.cpp index 9b2039b8..ee24134f 100644 --- a/codon/parser/visitors/simplify/collections.cpp +++ b/codon/parser/visitors/simplify/collections.cpp @@ -53,15 +53,19 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) { auto loops = clone_nop(expr->loops); // Clone as loops will be modified - std::string optimizeVar; - if (expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 && - loops[0].conds.empty()) { - // List comprehension optimization: - // Use `iter.__len__()` when creating list if there is a single for loop - // without any if conditions in the comprehension - optimizeVar = ctx->cache->getTemporaryVar("i"); - stmts.push_back(transform(N(N(optimizeVar), loops[0].gen))); - loops[0].gen = N(optimizeVar); + // List comprehension optimization: + // Use `iter.__len__()` when creating list if there is a single for loop + // without any if conditions in the comprehension + bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 && + loops[0].conds.empty(); + if (canOptimize) { + auto iter = transform(loops[0].gen); + IdExpr *id; + if (iter->getCall() && (id = iter->getCall()->expr->getId())) { + // Turn off this optimization for static items + canOptimize &= !startswith(id->value, "std.internal.types.range.staticrange"); + canOptimize &= !startswith(id->value, "statictuple"); + } } SuiteStmt *prev = nullptr; @@ -72,16 +76,32 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) { if (expr->kind == GeneratorExpr::ListGenerator) { // List comprehensions std::vector args; - if (!optimizeVar.empty()) { - // Use special List.__init__(bool, [optimizeVar]) constructor - args = {N(true), N(optimizeVar)}; - } - stmts.push_back( - transform(N(clone(var), N(N("List"), args)))); prev->stmts.push_back( N(N(N(clone(var), "append"), clone(expr->expr)))); - stmts.push_back(transform(suite)); - resultExpr = N(stmts, transform(var)); + auto noOptStmt = + N(N(clone(var), N(N("List"))), suite); + if (canOptimize) { + seqassert(suite->getSuite() && !suite->getSuite()->stmts.empty() && + CAST(suite->getSuite()->stmts[0], ForStmt), + "bad comprehension transformation"); + auto optimizeVar = ctx->cache->getTemporaryVar("i"); + auto optSuite = clone(suite); + CAST(optSuite->getSuite()->stmts[0], ForStmt)->iter = N(optimizeVar); + + auto optStmt = N( + N(N(optimizeVar), clone(expr->loops[0].gen)), + N( + clone(var), + N(N("List"), + N(N(N(optimizeVar), "__len__")))), + optSuite); + resultExpr = transform( + N(N(N("hasattr"), clone(expr->loops[0].gen), + N("__len__")), + N(optStmt, clone(var)), N(noOptStmt, var))); + } else { + resultExpr = transform(N(noOptStmt, var)); + } } else if (expr->kind == GeneratorExpr::SetGenerator) { // Set comprehensions stmts.push_back( @@ -94,7 +114,16 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) { // Generators: converted to lambda functions that yield the target expression prev->stmts.push_back(N(clone(expr->expr))); stmts.push_back(suite); - resultExpr = N(N(N(makeAnonFn(stmts)), "__iter__")); + + auto anon = makeAnonFn(stmts); + if (auto call = anon->getCall()) { + seqassert(!call->args.empty() && call->args.back().value->getEllipsis(), + "bad lambda: {}", *call); + call->args.pop_back(); + } else { + anon = N(anon); + } + resultExpr = anon; } std::swap(avoidDomination, ctx->avoidDomination); } diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp index 0e78276d..1286bf74 100644 --- a/codon/parser/visitors/simplify/function.cpp +++ b/codon/parser/visitors/simplify/function.cpp @@ -322,8 +322,8 @@ ExprPtr SimplifyVisitor::makeAnonFn(std::vector suite, prependStmts->push_back(fs->stmts[0]); for (StmtPtr s = fs->stmts[1]; s;) { if (auto suite = s->getSuite()) { - // Suites can only occur when __internal__.undef is inserted for a partial call - // argument. Extract __internal__.undef checks and prepend them + // Suites can only occur when captures are inserted for a partial call + // argument. seqassert(suite->stmts.size() == 2, "invalid function transform"); prependStmts->push_back(suite->stmts[0]); s = suite->stmts[1]; diff --git a/codon/parser/visitors/simplify/op.cpp b/codon/parser/visitors/simplify/op.cpp index 911dafa3..b5d9e503 100644 --- a/codon/parser/visitors/simplify/op.cpp +++ b/codon/parser/visitors/simplify/op.cpp @@ -95,8 +95,12 @@ void SimplifyVisitor::visit(IndexExpr *expr) { } } -/// Ignore it. Already transformed. Sometimes called again -/// during class extension. -void SimplifyVisitor::visit(InstantiateExpr *expr) {} +/// Already transformed. Sometimes needed again +/// for identifier analysis. +void SimplifyVisitor::visit(InstantiateExpr *expr) { + transformType(expr->typeExpr); + for (auto &tp : expr->typeParams) + transform(tp, true); +} } // namespace codon::ast diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 426159e2..c15cc235 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -7,11 +7,11 @@ #include #include +#include "codon/cir/transform/parallel/schedule.h" +#include "codon/cir/util/cloning.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/translate/translate_ctx.h" -#include "codon/cir/transform/parallel/schedule.h" -#include "codon/cir/util/cloning.h" using codon::ir::cast; using codon::ir::transform::parallel::OMPSched; diff --git a/codon/parser/visitors/translate/translate.h b/codon/parser/visitors/translate/translate.h index 648a01ca..1750543f 100644 --- a/codon/parser/visitors/translate/translate.h +++ b/codon/parser/visitors/translate/translate.h @@ -8,12 +8,12 @@ #include #include +#include "codon/cir/cir.h" #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" #include "codon/parser/visitors/translate/translate_ctx.h" #include "codon/parser/visitors/visitor.h" -#include "codon/cir/cir.h" namespace codon::ast { diff --git a/codon/parser/visitors/translate/translate_ctx.h b/codon/parser/visitors/translate/translate_ctx.h index 83ecdc0e..7d6d11d8 100644 --- a/codon/parser/visitors/translate/translate_ctx.h +++ b/codon/parser/visitors/translate/translate_ctx.h @@ -7,11 +7,11 @@ #include #include +#include "codon/cir/cir.h" +#include "codon/cir/types/types.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" #include "codon/parser/ctx.h" -#include "codon/cir/cir.h" -#include "codon/cir/types/types.h" namespace codon::ast { diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index bfd6305e..0cb7687b 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -909,22 +909,30 @@ void TypecheckVisitor::addFunctionGenerics(const FuncType *t) { for (auto parent = t->funcParent; parent;) { if (auto f = parent->getFunc()) { // Add parent function generics - for (auto &g : f->funcGenerics) + for (auto &g : f->funcGenerics) { + // LOG(" -> {} := {}", g.name, g.type->debugString(true)); ctx->add(TypecheckItem::Type, g.name, g.type); + } parent = f->funcParent; } else { // Add parent class generics seqassert(parent->getClass(), "not a class: {}", parent); - for (auto &g : parent->getClass()->generics) + for (auto &g : parent->getClass()->generics) { + // LOG(" => {} := {}", g.name, g.type->debugString(true)); ctx->add(TypecheckItem::Type, g.name, g.type); - for (auto &g : parent->getClass()->hiddenGenerics) + } + for (auto &g : parent->getClass()->hiddenGenerics) { + // LOG(" :> {} := {}", g.name, g.type->debugString(true)); ctx->add(TypecheckItem::Type, g.name, g.type); + } break; } } // Add function generics - for (auto &g : t->funcGenerics) + for (auto &g : t->funcGenerics) { + // LOG(" >> {} := {}", g.name, g.type->debugString(true)); ctx->add(TypecheckItem::Type, g.name, g.type); + } } /// Generate a partial type `Partial.N` for a given function. diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index a80546c4..314e46cc 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -155,7 +155,8 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { // Generalize generics and remove them from the context for (const auto &g : generics) { for (auto &u : g->getUnbounds()) - u->getUnbound()->kind = LinkType::Generic; + if (u->getUnbound()) + u->getUnbound()->kind = LinkType::Generic; } // Construct the type @@ -163,8 +164,9 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { baseType, ctx->cache->functions[stmt->name].ast.get(), explicits); funcTyp->setSrcInfo(getSrcInfo()); - if (isClassMember && stmt->attributes.has(Attr::Method)) + if (isClassMember && stmt->attributes.has(Attr::Method)) { funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type; + } funcTyp = std::static_pointer_cast(funcTyp->generalize(ctx->typecheckLevel)); diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 7209b277..90b528ea 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -7,11 +7,11 @@ #include #include +#include "codon/cir/types/types.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" -#include "codon/cir/types/types.h" using fmt::format; using namespace codon::error; diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 65342602..b6551acd 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -636,7 +636,7 @@ ExprPtr TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isA // In-place operations: check if `lhs.__iop__(lhs, rhs)` exists if (!method && expr->inPlace) { - method = findBestMethod(lt, format("__i{}__", magic), {lt, rt}); + method = findBestMethod(lt, format("__i{}__", magic), {expr->lexpr, expr->rexpr}); } if (method) @@ -667,11 +667,11 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) { } // Normal operations: check if `lhs.__magic__(lhs, rhs)` exists - auto method = findBestMethod(lt, format("__{}__", magic), {lt, rt}); + auto method = findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr}); // Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists - if (!method && - (method = findBestMethod(rt, format("__{}__", rightMagic), {rt, lt}))) { + if (!method && (method = findBestMethod(rt, format("__{}__", rightMagic), + {expr->rexpr, expr->lexpr}))) { swap(expr->lexpr, expr->rexpr); } diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 537b2073..4d0d4ec4 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -187,6 +187,19 @@ TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &mem return m.empty() ? nullptr : m[0]; } +/// Select the best method indicated of an object that matches the given argument +/// types. See @c findMatchingMethods for details. +types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, + const std::string &member, + const std::vector &args) { + std::vector callArgs; + for (auto &a : args) + callArgs.push_back({"", a}); + auto methods = ctx->findMethod(typ->name, member, false); + auto m = findMatchingMethods(typ, methods, callArgs); + return m.empty() ? nullptr : m[0]; +} + /// Select the best method among the provided methods given the list of arguments. /// See @c reorderNamedArgs for details. std::vector diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 75332c18..4fb73a77 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -207,6 +207,9 @@ class TypecheckVisitor : public CallbackASTVisitor { types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, const std::string &member, const std::vector &args); + types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, + const std::string &member, + const std::vector &args); std::vector findMatchingMethods(const types::ClassTypePtr &typ, const std::vector &methods, diff --git a/docs/intro/differences.md b/docs/intro/differences.md index b363748c..54fb005e 100644 --- a/docs/intro/differences.md +++ b/docs/intro/differences.md @@ -14,8 +14,8 @@ in mind. - **Strings:** Codon currently uses ASCII strings unlike Python's unicode strings. -- **Dictionaries:** Codon's dictionary type is not sorted - internally, unlike Python's. +- **Dictionaries:** Codon's dictionary type does not preserve + insertion order, unlike Python's as of 3.6. # Type checking diff --git a/extra/python/codon/decorator.py b/extra/python/codon/decorator.py index 77370264..9d6e0c01 100644 --- a/extra/python/codon/decorator.py +++ b/extra/python/codon/decorator.py @@ -17,13 +17,21 @@ from .codon_jit import JITWrapper, JITError, codon_library if "CODON_PATH" not in os.environ: + codon_path = [] codon_lib_path = codon_library() - if not codon_lib_path: + if codon_lib_path: + codon_path.append(Path(codon_lib_path).parent / "stdlib") + codon_path.append( + Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib" + ) + for path in codon_path: + if path.exists(): + os.environ["CODON_PATH"] = str(path.resolve()) + break + else: raise RuntimeError( "Cannot locate Codon. Please install Codon or set CODON_PATH." ) - codon_path = (Path(codon_lib_path).parent / "stdlib").resolve() - os.environ["CODON_PATH"] = str(codon_path) pod_conversions = { type(None): "pyobj", diff --git a/stdlib/experimental/simd.codon b/stdlib/experimental/simd.codon index 8eb6fa5c..4416d762 100644 --- a/stdlib/experimental/simd.codon +++ b/stdlib/experimental/simd.codon @@ -1,6 +1,6 @@ # Copyright (C) 2022-2023 Exaloop Inc. -@tuple +@tuple(container=False) # disallow default __getitem__ class Vec[T, N: Static[int]]: ZERO_16x8i = Vec[u8,16](u8(0)) FF_16x8i = Vec[u8,16](u8(0xff)) @@ -307,6 +307,10 @@ class Vec[T, N: Static[int]]: else: return "?" + def scatter(self: Vec[T, N]) -> List[T]: + return [self[i] for i in staticrange(N)] + + u8x16 = Vec[u8, 16] u8x32 = Vec[u8, 32] f32x8 = Vec[f32, 8] diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index 26109be9..be18e3cf 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -248,19 +248,26 @@ def round(x, n=0): nx = float.__pow__(10.0, n) return float.__round__(x * nx) / nx -def sum(xi): +def _sum_start(x, start): + if isinstance(x.__iter__(), Generator[float]) and isinstance(start, int): + return float(start) + else: + return start + +def sum(x, start=0): """ - Return the sum of the items added together from xi + Return the sum of the items added together from x """ - x = iter(xi) - if not x.done(): - s = x.next() - while not x.done(): - s += x.next() - x.destroy() - return s - else: - x.destroy() + s = _sum_start(x, start) + + for a in x: + # don't use += to avoid calling iadd + if isinstance(a, bool): + s = s + (1 if a else 0) + else: + s = s + a + + return s def repr(x): """Return the string representation of x""" diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index 97be7656..fb96844c 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -24,13 +24,6 @@ class List: self.arr = Array[T](capacity) self.len = 0 - def __init__(self, dummy: bool, other): - """Dummy __init__ used for list comprehension optimization""" - if hasattr(other, "__len__"): - self.__init__(other.__len__()) - else: - self.__init__() - def __init__(self, arr: Array[T], len: int): self.arr = arr self.len = len diff --git a/test/core/bltin.codon b/test/core/bltin.codon index 71dea1e8..ec334bf4 100644 --- a/test/core/bltin.codon +++ b/test/core/bltin.codon @@ -85,6 +85,70 @@ def test_map_filter(): assert list(filter(lambda i: i%2 == 0, map(lambda i: i*i, range(10)))) == [0, 4, 16, 36, 64] +@test +def test_gen_builtins(): + assert sum([1, 2, 3]) == 6 + assert sum([1, 2, 3], 0.5) == 6.5 + assert sum([True, False, True, False, True], 0.5) == 3.5 + assert sum(List[float]()) == 0 + assert sum(i/2 for i in range(10)) == 22.5 + + def g1(): + yield 1.5 + yield 2.5 + return + yield 3.5 + + assert sum(g1(), 10) == 14.0 + + def g2(): + yield True + yield False + yield True + + assert sum(g2()) == 2 + + class A: + iadd_count = 0 + n: int + + def __init__(self, n): + self.n = n + + def __add__(self, other): + return A(self.n + other.n) + + def __iadd__(self, other): + A.iadd_count += 1 + self.n += other.n + return self + + assert sum((A(i) for i in range(5)), A(100)).n == 110 + assert A.iadd_count == 0 + + def g3(a, b): + for i in range(10): + yield a + yield b + + assert all([True, True]) + assert all(i for i in range(0)) + assert not all([True, False]) + assert all(List[str]()) + assert all(g3(True, True)) + assert not all(g3(True, False)) + assert not all(g3(False, True)) + assert not all(g3(False, False)) + + assert any([True, True]) + assert not any(i for i in range(0)) + assert not any([False, False]) + assert not any(List[bool]()) + assert any(g3(True, True)) + assert any(g3(True, False)) + assert any(g3(False, True)) + assert not any(g3(False, False)) + @test def test_int_format(): n = 0 @@ -269,6 +333,7 @@ def test_files(open_fn): test_min_max() test_map_filter() +test_gen_builtins() test_int_format() test_reversed() test_divmod() diff --git a/test/main.cpp b/test/main.cpp index 15259705..fbc82a9d 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -15,15 +15,15 @@ #include #include -#include "codon/compiler/compiler.h" -#include "codon/compiler/error.h" -#include "codon/parser/common.h" #include "codon/cir/analyze/dataflow/capture.h" #include "codon/cir/analyze/dataflow/reaching.h" #include "codon/cir/util/inlining.h" #include "codon/cir/util/irtools.h" #include "codon/cir/util/operator.h" #include "codon/cir/util/outlining.h" +#include "codon/compiler/compiler.h" +#include "codon/compiler/error.h" +#include "codon/parser/common.h" #include "codon/util/common.h" #include "gtest/gtest.h" diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 274d0ac0..7723dd2b 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -157,12 +157,10 @@ print d #: {0: 9} #%% comprehension_opt,barebones @extend class List: - def __init__(self, dummy: bool, other): - if hasattr(other, '__len__'): - print 'optimize', other.__len__() - self.__init__(other.__len__()) - else: - self.__init__() + def __init__(self, cap: int): + print 'optimize', cap + self.arr = Array[T](cap) + self.len = 0 def foo(): yield 0 yield 1 diff --git a/test/transform/omp.codon b/test/transform/omp.codon index 6626594d..c964170e 100644 --- a/test/transform/omp.codon +++ b/test/transform/omp.codon @@ -457,7 +457,7 @@ def test_omp_reductions(): @par for i in L[1:1001]: c += f32(i) - assert c == sum(f32(i) for i in range(1001)) + assert c == sum((f32(i) for i in range(1001)), f32(0)) c = f32(1.) @par