Skip to content

Commit

Permalink
Merge branch 'develop' into feature/generic_slices
Browse files Browse the repository at this point in the history
  • Loading branch information
pakaelbling committed Feb 20, 2024
2 parents 262c6b5 + 7a787bf commit 9390a70
Show file tree
Hide file tree
Showing 42 changed files with 1,004 additions and 161 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:

- name: Cache Dependencies
id: cache-deps
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: llvm
key: manylinux-llvm
Expand Down Expand Up @@ -133,7 +133,7 @@ jobs:
- name: Cache Dependencies
id: cache-deps
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: llvm
key: ${{ runner.os }}-llvm
Expand Down
11 changes: 4 additions & 7 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,21 @@ LLVMVisitor::LLVMVisitor()
auto &registry = *llvm::PassRegistry::getPassRegistry();
llvm::initializeCore(registry);
llvm::initializeScalarOpts(registry);
llvm::initializeObjCARCOpts(registry);
llvm::initializeVectorization(registry);
llvm::initializeIPO(registry);
llvm::initializeAnalysis(registry);
llvm::initializeTransformUtils(registry);
llvm::initializeInstCombine(registry);
llvm::initializeAggressiveInstCombine(registry);
llvm::initializeInstrumentation(registry);
llvm::initializeTarget(registry);

llvm::initializeExpandLargeDivRemLegacyPassPass(registry);
llvm::initializeExpandLargeFpConvertLegacyPassPass(registry);
llvm::initializeExpandMemCmpPassPass(registry);
llvm::initializeScalarizeMaskedMemIntrinLegacyPassPass(registry);
llvm::initializeSelectOptimizePass(registry);
llvm::initializeCallBrPreparePass(registry);
llvm::initializeCodeGenPreparePass(registry);
llvm::initializeAtomicExpandPass(registry);
llvm::initializeRewriteSymbolsLegacyPassPass(registry);
llvm::initializeWinEHPreparePass(registry);
llvm::initializeDwarfEHPrepareLegacyPassPass(registry);
llvm::initializeSafeStackLegacyPassPass(registry);
Expand All @@ -110,8 +109,6 @@ LLVMVisitor::LLVMVisitor()
llvm::initializeExpandVectorPredicationPass(registry);
llvm::initializeWasmEHPreparePass(registry);
llvm::initializeWriteBitcodePassPass(registry);
llvm::initializeHardwareLoopsPass(registry);
llvm::initializeTypePromotionPass(registry);
llvm::initializeReplaceWithVeclibLegacyPass(registry);
llvm::initializeJMCInstrumenterPass(registry);
}
Expand Down Expand Up @@ -1855,7 +1852,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
// set up debug info
// for now we just set all to func's source location
auto *srcInfo = getSrcInfo(x);
for (auto &block : func->getBasicBlockList()) {
for (auto &block : *func) {
for (auto &inst : block) {
if (!inst.getDebugLoc()) {
inst.setDebugLoc(llvm::DebugLoc(llvm::DILocation::get(
Expand Down
6 changes: 1 addition & 5 deletions codon/cir/llvm/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/CaptureTracking.h"
Expand All @@ -26,7 +25,6 @@
#include "llvm/ExecutionEngine/JITLink/JITLink.h"
#include "llvm/ExecutionEngine/JITLink/JITLinkDylib.h"
#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h"
#include "llvm/ExecutionEngine/JITLink/MemoryFlags.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/MCJIT.h"
#include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
Expand Down Expand Up @@ -81,7 +79,6 @@
#include "llvm/LinkAllIR.h"
#include "llvm/LinkAllPasses.h"
#include "llvm/Linker/Linker.h"
#include "llvm/MC/SubtargetFeature.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Allocator.h"
Expand All @@ -92,7 +89,6 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Memory.h"
#include "llvm/Support/Process.h"
Expand All @@ -105,11 +101,11 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
#include "llvm/Transforms/IPO/StripSymbols.h"
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
Expand Down
2 changes: 1 addition & 1 deletion codon/cir/llvm/optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) {
f.addFnAttr("no-frame-pointer-elim-non-leaf");
f.addFnAttr("no-jump-tables", "false");

for (auto &block : f.getBasicBlockList()) {
for (auto &block : f) {
for (auto &inst : block) {
if (auto *call = llvm::dyn_cast<llvm::CallInst>(&inst)) {
call->setTailCall(false);
Expand Down
24 changes: 11 additions & 13 deletions codon/cir/transform/parallel/openmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,16 @@ struct Reduction {
case Kind::XOR:
result = *lhs ^ *arg;
break;
case Kind::MIN: {
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("min", {tup->getType()}, {}, builtinModule);
seqassertn(fn, "min function not found");
result = util::call(fn, {tup});
break;
}
case Kind::MIN:
case Kind::MAX: {
// signature is (tuple of args, key, default)
auto name = (kind == Kind::MIN ? "min" : "max");
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("max", {tup->getType()}, {}, builtinModule);
seqassertn(fn, "max function not found");
result = util::call(fn, {tup});
auto *none = (*M->getNoneType())();
auto *fn = M->getOrRealizeFunc(
name, {tup->getType(), none->getType(), none->getType()}, {}, builtinModule);
seqassertn(fn, "{} function not found", name);
result = util::call(fn, {tup, none, none});
break;
}
default:
Expand Down Expand Up @@ -432,6 +430,7 @@ struct ReductionIdentifier : public util::Operator {
auto *ptrType = cast<types::PointerType>(shared->getType());
seqassertn(ptrType, "expected shared var to be of pointer type");
auto *type = ptrType->getBase();
auto *noneType = M->getOptionalType(M->getNoneType());

// double-check the call
if (!util::isCallOf(v, Module::SETITEM_MAGIC_NAME,
Expand All @@ -454,7 +453,8 @@ struct ReductionIdentifier : public util::Operator {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
continue;
} else {
if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type})}, type,
if (!util::isCallOf(item, rf.name,
{M->getTupleType({type, type}), noneType, noneType}, type,
/*method=*/false))
continue;
}
Expand Down Expand Up @@ -1183,9 +1183,7 @@ struct GPULoopBodyStubReplacer : public util::Operator {

std::vector<Value *> newArgs;
for (auto *arg : *replacement) {
// std::cout << "A: " << *arg << std::endl;
if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) {
// std::cout << "(loop var)" << std::endl;
newArgs.push_back(idx);
} else {
newArgs.push_back(util::tupleGet(args, next++));
Expand Down
12 changes: 10 additions & 2 deletions codon/cir/transform/pythonic/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ struct GeneratorSumTransformer : public util::Operator {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
if (v->getValue()) {
v->replaceAll(util::series(v->getValue(), newReturn));
} else {
v->replaceAll(newReturn);
}
}

void handle(YieldInInstr *v) override { valid = false; }
Expand Down Expand Up @@ -97,7 +101,11 @@ struct GeneratorAnyAllTransformer : public util::Operator {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(!any));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
if (v->getValue()) {
v->replaceAll(util::series(v->getValue(), newReturn));
} else {
v->replaceAll(newReturn);
}
}

void handle(YieldInInstr *v) override { valid = false; }
Expand Down
6 changes: 4 additions & 2 deletions codon/compiler/debug_listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,15 @@ llvm::Error DebugPlugin::notifyFailed(llvm::orc::MaterializationResponsibility &
return llvm::Error::success();
}

llvm::Error DebugPlugin::notifyRemovingResources(llvm::orc::ResourceKey key) {
llvm::Error DebugPlugin::notifyRemovingResources(llvm::orc::JITDylib &jd,
llvm::orc::ResourceKey key) {
std::lock_guard<std::mutex> lock(pluginMutex);
registeredObjs.erase(key);
return llvm::Error::success();
}

void DebugPlugin::notifyTransferringResources(llvm::orc::ResourceKey dstKey,
void DebugPlugin::notifyTransferringResources(llvm::orc::JITDylib &jd,
llvm::orc::ResourceKey dstKey,
llvm::orc::ResourceKey srcKey) {
std::lock_guard<std::mutex> lock(pluginMutex);
auto it = registeredObjs.find(srcKey);
Expand Down
6 changes: 4 additions & 2 deletions codon/compiler/debug_listener.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ class DebugPlugin : public llvm::orc::ObjectLinkingLayer::Plugin {
llvm::MemoryBufferRef inputObject) override;
llvm::Error notifyEmitted(llvm::orc::MaterializationResponsibility &mr) override;
llvm::Error notifyFailed(llvm::orc::MaterializationResponsibility &mr) override;
llvm::Error notifyRemovingResources(llvm::orc::ResourceKey key) override;
void notifyTransferringResources(llvm::orc::ResourceKey dstKey,
llvm::Error notifyRemovingResources(llvm::orc::JITDylib &jd,
llvm::orc::ResourceKey key) override;
void notifyTransferringResources(llvm::orc::JITDylib &jd,
llvm::orc::ResourceKey dstKey,
llvm::orc::ResourceKey srcKey) override;
void modifyPassConfig(llvm::orc::MaterializationResponsibility &mr,
llvm::jitlink::LinkGraph &,
Expand Down
7 changes: 3 additions & 4 deletions codon/compiler/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ llvm::Expected<std::unique_ptr<Engine>> Engine::create() {

auto sess = std::make_unique<llvm::orc::ExecutionSession>(std::move(*epc));

auto epciu =
llvm::orc::EPCIndirectionUtils::Create(sess->getExecutorProcessControl());
auto epciu = llvm::orc::EPCIndirectionUtils::Create(*sess);
if (!epciu)
return epciu.takeError();

(*epciu)->createLazyCallThroughManager(
*sess, llvm::pointerToJITTargetAddress(&handleLazyCallThroughError));
*sess, llvm::orc::ExecutorAddr::fromPtr(&handleLazyCallThroughError));

if (auto err = llvm::orc::setUpInProcessLCTMReentryViaEPCIU(**epciu))
return std::move(err);
Expand All @@ -87,7 +86,7 @@ llvm::Error Engine::addModule(llvm::orc::ThreadSafeModule module,
return optimizeLayer.add(rt, std::move(module));
}

llvm::Expected<llvm::JITEvaluatedSymbol> Engine::lookup(llvm::StringRef name) {
llvm::Expected<llvm::orc::ExecutorSymbolDef> Engine::lookup(llvm::StringRef name) {
return sess->lookup({&mainJD}, mangle(name.str()));
}

Expand Down
2 changes: 1 addition & 1 deletion codon/compiler/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Engine {
llvm::Error addModule(llvm::orc::ThreadSafeModule module,
llvm::orc::ResourceTrackerSP rt = nullptr);

llvm::Expected<llvm::JITEvaluatedSymbol> lookup(llvm::StringRef name);
llvm::Expected<llvm::orc::ExecutorSymbolDef> lookup(llvm::StringRef name);
};

} // namespace jit
Expand Down
6 changes: 3 additions & 3 deletions codon/compiler/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ llvm::Error JIT::init() {
if (auto err = func.takeError())
return err;

auto *main = (MainFunc *)func->getAddress();
auto *main = func->getAddress().toPtr<MainFunc>();
(*main)(0, nullptr);
return llvm::Error::success();
}
Expand Down Expand Up @@ -174,7 +174,7 @@ llvm::Expected<void *> JIT::address(const ir::Func *input) {
if (auto err = func.takeError())
return std::move(err);

return (void *)func->getAddress();
return (void *)func->getAddress().getValue();
}

llvm::Expected<std::string> JIT::run(const ir::Func *input) {
Expand Down Expand Up @@ -292,7 +292,7 @@ JITResult JIT::executePython(const std::string &name,
auto *wrapper = it->second;
const std::string name = ir::LLVMVisitor::getNameForFunction(wrapper);
auto func = llvm::cantFail(engine->lookup(name));
wrap = (PyWrapperFunc *)func.getAddress();
wrap = func.getAddress().toPtr<PyWrapperFunc>();
} else {
static int idx = 0;
auto wrapname = "__codon_wrapped__" + name + "_" + std::to_string(idx++);
Expand Down
4 changes: 2 additions & 2 deletions codon/compiler/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void BoehmGCJITLinkMemoryManager::allocate(const llvm::jitlink::JITLinkDylib *JD
auto &Seg = KV.second;

auto &SegAddr =
(AG.getMemDeallocPolicy() == llvm::jitlink::MemDeallocPolicy::Standard)
(AG.getMemLifetimePolicy() == llvm::orc::MemLifetimePolicy::Standard)
? NextStandardSegAddr
: NextFinalizeSegAddr;

Expand All @@ -189,7 +189,7 @@ void BoehmGCJITLinkMemoryManager::allocate(const llvm::jitlink::JITLinkDylib *JD
SegAddr += llvm::alignTo(Seg.ContentSize + Seg.ZeroFillSize, PageSize);

if (static_cast<int>(AG.getMemProt()) &
static_cast<int>(llvm::jitlink::MemProt::Write)) {
static_cast<int>(llvm::orc::MemProt::Write)) {
seq_gc_add_roots((void *)Seg.Addr.getValue(), (void *)SegAddr.getValue());
}
}
Expand Down
1 change: 1 addition & 0 deletions codon/parser/ast/stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ const std::string Attr::CVarArg = ".__vararg__";
const std::string Attr::Method = ".__method__";
const std::string Attr::Capture = ".__capture__";
const std::string Attr::HasSelf = ".__hasself__";
const std::string Attr::IsGenerator = ".__generator__";
const std::string Attr::Extend = "extend";
const std::string Attr::Tuple = "tuple";
const std::string Attr::Test = "std.internal.attributes.test";
Expand Down
1 change: 1 addition & 0 deletions codon/parser/ast/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ struct Attr {
const static std::string Method;
const static std::string Capture;
const static std::string HasSelf;
const static std::string IsGenerator;
// Class attributes
const static std::string Extend;
const static std::string Tuple;
Expand Down
1 change: 1 addition & 0 deletions codon/parser/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.h"

#include <cinttypes>
#include <climits>
#include <string>
#include <vector>

Expand Down
30 changes: 26 additions & 4 deletions codon/parser/visitors/simplify/access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void SimplifyVisitor::visit(IdExpr *expr) {
if (!checked) {
// Prepend access with __internal__.undef([var]__used__, "[var name]")
auto checkStmt = N<ExprStmt>(N<CallExpr>(
N<DotExpr>("__internal__", "undef"),
N<IdExpr>("__internal__.undef"),
N<IdExpr>(fmt::format("{}.__used__", val->canonicalName)),
N<StringExpr>(ctx->cache->reverseIdentifierLookup[val->canonicalName])));
if (!ctx->isConditionalExpr) {
Expand Down Expand Up @@ -230,10 +230,11 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {

// Find the longest prefix that corresponds to the existing import
// (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`)
SimplifyContext::Item val = nullptr;
SimplifyContext::Item val = nullptr, importVal = nullptr;
for (auto i = chain.size(); i-- > 0;) {
val = ctx->find(join(chain, "/", 0, i + 1));
if (val && val->isImport()) {
importVal = val;
importName = val->importPath, importEnd = i + 1;
break;
}
Expand All @@ -254,6 +255,14 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
return {importEnd, val};
} else {
val = fctx->find(join(chain, ".", importEnd, i + 1));
if (val && i + 1 != chain.size() && val->isImport()) {
importVal = val;
importName = val->importPath;
importEnd = i + 1;
fctx = ctx->cache->imports[importName].ctx;
i = chain.size();
continue;
}
if (val && (importName.empty() || val->isType() || !val->isConditional())) {
itemName = val->canonicalName, itemEnd = i + 1;
break;
Expand All @@ -264,10 +273,23 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
if (ctx->getBase()->pyCaptures)
return {1, nullptr};
E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]);
}
if (itemName.empty())
} else if (itemName.empty()) {
if (!ctx->isStdlibLoading && endswith(importName, "__init__.codon")) {
auto import = ctx->cache->imports[importName];
auto file =
getImportFile(ctx->cache->argv0, chain[importEnd], importName, false,
ctx->cache->module0, ctx->cache->pluginImportPaths);
if (file) {
auto s = SimplifyVisitor(import.ctx, preamble)
.transform(N<ImportStmt>(N<IdExpr>(chain[importEnd]), nullptr));
prependStmts->push_back(s);
return getImport(chain);
}
}

E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd],
ctx->cache->imports[importName].moduleName);
}
importEnd = itemEnd;
}
return {importEnd, val};
Expand Down

0 comments on commit 9390a70

Please sign in to comment.