Skip to content

Commit

Permalink
Add table64 lowering pass (#6595)
Browse files Browse the repository at this point in the history
Changes to wasm-validator.cpp here are mostly for consistency between
elem and data segment validation.
  • Loading branch information
sbc100 committed May 15, 2024
1 parent 2cc5e06 commit 2b60f8a
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 67 deletions.
1 change: 1 addition & 0 deletions src/passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ set(passes_SOURCES
ReorderGlobals.cpp
ReorderLocals.cpp
ReReloop.cpp
Table64Lowering.cpp
TrapMode.cpp
TypeGeneralizing.cpp
TypeRefining.cpp
Expand Down
75 changes: 39 additions & 36 deletions src/passes/Memory64Lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {
return;
}
auto& module = *getModule();
auto memory = module.getMemory(memoryName);
auto* memory = module.getMemory(memoryName);
if (memory->is64()) {
assert(ptr->type == Type::i64);
Builder builder(module);
ptr = builder.makeUnary(UnaryOp::WrapInt64, ptr);
ptr = Builder(module).makeUnary(UnaryOp::WrapInt64, ptr);
}
}

Expand All @@ -51,12 +50,11 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {
return;
}
auto& module = *getModule();
auto memory = module.getMemory(memoryName);
auto* memory = module.getMemory(memoryName);
if (memory->is64()) {
assert(ptr->type == Type::i64);
ptr->type = Type::i32;
Builder builder(module);
ptr = builder.makeUnary(UnaryOp::ExtendUInt32, ptr);
ptr = Builder(module).makeUnary(UnaryOp::ExtendUInt32, ptr);
}
}

Expand All @@ -66,9 +64,9 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {

void visitMemorySize(MemorySize* curr) {
auto& module = *getModule();
auto memory = module.getMemory(curr->memory);
auto* memory = module.getMemory(curr->memory);
if (memory->is64()) {
auto size = static_cast<Expression*>(curr);
auto* size = static_cast<Expression*>(curr);
extendAddress64(size, curr->memory);
curr->type = Type::i32;
replaceCurrent(size);
Expand All @@ -77,10 +75,10 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {

void visitMemoryGrow(MemoryGrow* curr) {
auto& module = *getModule();
auto memory = module.getMemory(curr->memory);
auto* memory = module.getMemory(curr->memory);
if (memory->is64()) {
wrapAddress64(curr->delta, curr->memory);
auto size = static_cast<Expression*>(curr);
auto* size = static_cast<Expression*>(curr);
extendAddress64(size, curr->memory);
curr->type = Type::i32;
replaceCurrent(size);
Expand Down Expand Up @@ -129,34 +127,39 @@ struct Memory64Lowering : public WalkerPass<PostWalker<Memory64Lowering>> {
}

void visitDataSegment(DataSegment* segment) {
if (!segment->isPassive) {
if (auto* c = segment->offset->dynCast<Const>()) {
c->value = Literal(static_cast<uint32_t>(c->value.geti64()));
c->type = Type::i32;
} else if (auto* get = segment->offset->dynCast<GlobalGet>()) {
auto& module = *getModule();
auto* g = module.getGlobal(get->name);
if (g->imported() && g->base == MEMORY_BASE) {
ImportInfo info(module);
auto* memoryBase32 = info.getImportedGlobal(g->module, MEMORY_BASE32);
if (!memoryBase32) {
Builder builder(module);
memoryBase32 = builder
.makeGlobal(MEMORY_BASE32,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Immutable)
.release();
memoryBase32->module = g->module;
memoryBase32->base = MEMORY_BASE32;
module.addGlobal(memoryBase32);
}
// Use this alternative import when initializing the segment.
assert(memoryBase32);
get->type = Type::i32;
get->name = memoryBase32->name;
if (segment->isPassive) {
// passive segments don't have any offset to adjust
return;
}

if (auto* c = segment->offset->dynCast<Const>()) {
c->value = Literal(static_cast<uint32_t>(c->value.geti64()));
c->type = Type::i32;
} else if (auto* get = segment->offset->dynCast<GlobalGet>()) {
auto& module = *getModule();
auto* g = module.getGlobal(get->name);
if (g->imported() && g->base == MEMORY_BASE) {
ImportInfo info(module);
auto* memoryBase32 = info.getImportedGlobal(g->module, MEMORY_BASE32);
if (!memoryBase32) {
Builder builder(module);
memoryBase32 = builder
.makeGlobal(MEMORY_BASE32,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Immutable)
.release();
memoryBase32->module = g->module;
memoryBase32->base = MEMORY_BASE32;
module.addGlobal(memoryBase32);
}
// Use this alternative import when initializing the segment.
assert(memoryBase32);
get->type = Type::i32;
get->name = memoryBase32->name;
}
} else {
WASM_UNREACHABLE("unexpected elem offset");
}
}

Expand Down
145 changes: 145 additions & 0 deletions src/passes/Table64Lowering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//
// Lowers a module with a 64-bit table to one with a 32-bit table.
//
// This pass can be deleted once table64 is implemented in Wasm engines:
// https://github.com/WebAssembly/memory64/issues/51
//

#include "ir/bits.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"

namespace wasm {

static Name TABLE_BASE("__table_base");
static Name TABLE_BASE32("__table_base32");

struct Table64Lowering : public WalkerPass<PostWalker<Table64Lowering>> {

void wrapAddress64(Expression*& ptr, Name tableName) {
if (ptr->type == Type::unreachable) {
return;
}
auto& module = *getModule();
auto* table = module.getTable(tableName);
if (table->is64()) {
assert(ptr->type == Type::i64);
ptr = Builder(module).makeUnary(UnaryOp::WrapInt64, ptr);
}
}

void extendAddress64(Expression*& ptr, Name tableName) {
if (ptr->type == Type::unreachable) {
return;
}
auto& module = *getModule();
auto* table = module.getTable(tableName);
if (table->is64()) {
assert(ptr->type == Type::i64);
ptr->type = Type::i32;
ptr = Builder(module).makeUnary(UnaryOp::ExtendUInt32, ptr);
}
}

void visitTableSize(TableSize* curr) {
auto& module = *getModule();
auto* table = module.getTable(curr->table);
if (table->is64()) {
auto* size = static_cast<Expression*>(curr);
extendAddress64(size, curr->table);
replaceCurrent(size);
}
}

void visitTableGrow(TableGrow* curr) {
auto& module = *getModule();
auto* table = module.getTable(curr->table);
if (table->is64()) {
wrapAddress64(curr->delta, curr->table);
auto* size = static_cast<Expression*>(curr);
extendAddress64(size, curr->table);
replaceCurrent(size);
}
}

void visitTableFill(TableFill* curr) {
wrapAddress64(curr->dest, curr->table);
wrapAddress64(curr->size, curr->table);
}

void visitTableCopy(TableCopy* curr) {
wrapAddress64(curr->dest, curr->destTable);
wrapAddress64(curr->source, curr->sourceTable);
wrapAddress64(curr->size, curr->destTable);
}

void visitCallIndirect(CallIndirect* curr) {
wrapAddress64(curr->target, curr->table);
}

void visitTable(Table* table) {
// This is visited last.
if (table->is64()) {
table->indexType = Type::i32;
}
}

void visitElementSegment(ElementSegment* segment) {
if (segment->table.isNull()) {
// Passive segments don't have any offset to update.
return;
}

if (auto* c = segment->offset->dynCast<Const>()) {
c->value = Literal(static_cast<uint32_t>(c->value.geti64()));
c->type = Type::i32;
} else if (auto* get = segment->offset->dynCast<GlobalGet>()) {
auto& module = *getModule();
auto* g = module.getGlobal(get->name);
if (g->imported() && g->base == TABLE_BASE) {
ImportInfo info(module);
auto* memoryBase32 = info.getImportedGlobal(g->module, TABLE_BASE32);
if (!memoryBase32) {
Builder builder(module);
memoryBase32 = builder
.makeGlobal(TABLE_BASE32,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Immutable)
.release();
memoryBase32->module = g->module;
memoryBase32->base = TABLE_BASE32;
module.addGlobal(memoryBase32);
}
// Use this alternative import when initializing the segment.
assert(memoryBase32);
get->type = Type::i32;
get->name = memoryBase32->name;
}
} else {
WASM_UNREACHABLE("unexpected elem offset");
}
}
};

Pass* createTable64LoweringPass() { return new Table64Lowering(); }

} // namespace wasm
3 changes: 3 additions & 0 deletions src/passes/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ void PassRegistry::registerPasses() {
"lower loads and stores to a 64-bit memory to instead use a "
"32-bit one",
createMemory64LoweringPass);
registerPass("table64-lowering",
"lower 64-bit tables 32-bit ones",
createTable64LoweringPass);
registerPass("memory-packing",
"packs memory into separate segments, skipping zeros",
createMemoryPackingPass);
Expand Down
1 change: 1 addition & 0 deletions src/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ Pass* createStripEHPass();
Pass* createStubUnsupportedJSOpsPass();
Pass* createSSAifyPass();
Pass* createSSAifyNoMergePass();
Pass* createTable64LoweringPass();
Pass* createTranslateToExnrefPass();
Pass* createTrapModeClamp();
Pass* createTrapModeJS();
Expand Down
42 changes: 15 additions & 27 deletions src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3754,25 +3754,14 @@ static void validateDataSegments(Module& module, ValidationInfo& info) {
"active segment must have a valid memory name")) {
continue;
}
if (memory->is64()) {
if (!info.shouldBeEqual(segment->offset->type,
Type(Type::i64),
segment->offset,
"segment offset should be i64")) {
continue;
}
} else {
if (!info.shouldBeEqual(segment->offset->type,
Type(Type::i32),
segment->offset,
"segment offset should be i32")) {
continue;
}
}
info.shouldBeEqual(segment->offset->type,
memory->indexType,
segment->offset,
"segment offset must match memory index type");
info.shouldBeTrue(
Properties::isValidConstantExpression(module, segment->offset),
segment->offset,
"memory segment offset should be constant");
"memory segment offset must be constant");
FunctionValidator(module, &info).validate(segment->offset);
}
}
Expand Down Expand Up @@ -3846,31 +3835,30 @@ static void validateTables(Module& module, ValidationInfo& info) {
"elem",
"Non-nullable reference types are not yet supported for tables");

if (segment->table.is()) {
bool isPassive = !segment->table.is();
if (isPassive) {
info.shouldBeTrue(
!segment->offset, "elem", "passive segment should not have an offset");
} else {
auto table = module.getTableOrNull(segment->table);
info.shouldBeTrue(table != nullptr,
"elem",
"element segment must have a valid table name");
info.shouldBeTrue(!!segment->offset,
"elem",
"table segment offset should have an offset");
info.shouldBeTrue(
!!segment->offset, "elem", "table segment offset must have an offset");
info.shouldBeEqual(segment->offset->type,
Type(Type::i32),
table->indexType,
segment->offset,
"element segment offset should be i32");
"element segment offset must match table index type");
info.shouldBeTrue(
Properties::isValidConstantExpression(module, segment->offset),
segment->offset,
"table segment offset should be constant");
"table segment offset must be constant");
info.shouldBeTrue(
Type::isSubType(segment->type, table->type),
"elem",
"element segment type must be a subtype of the table type");
validator.validate(segment->offset);
} else {
info.shouldBeTrue(!segment->offset,
"elem",
"non-table segment offset should have no offset");
}
for (auto* expr : segment->data) {
info.shouldBeTrue(Properties::isValidConstantExpression(module, expr),
Expand Down
2 changes: 0 additions & 2 deletions src/wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,8 +851,6 @@ void TableSize::finalize() {
void TableGrow::finalize() {
if (delta->type == Type::unreachable || value->type == Type::unreachable) {
type = Type::unreachable;
} else {
type = Type::i32;
}
}

Expand Down
2 changes: 2 additions & 0 deletions test/lit/help/wasm-opt.test
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,8 @@
;; CHECK-NEXT:
;; CHECK-NEXT: --symbolmap (alias for print-function-map)
;; CHECK-NEXT:
;; CHECK-NEXT: --table64-lowering lower 64-bit tables 32-bit ones
;; CHECK-NEXT:
;; CHECK-NEXT: --translate-to-exnref translate old Phase 3 EH
;; CHECK-NEXT: instructions to new ones with
;; CHECK-NEXT: exnref
Expand Down
2 changes: 2 additions & 0 deletions test/lit/help/wasm2js.test
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@
;; CHECK-NEXT:
;; CHECK-NEXT: --symbolmap (alias for print-function-map)
;; CHECK-NEXT:
;; CHECK-NEXT: --table64-lowering lower 64-bit tables 32-bit ones
;; CHECK-NEXT:
;; CHECK-NEXT: --translate-to-exnref translate old Phase 3 EH
;; CHECK-NEXT: instructions to new ones with
;; CHECK-NEXT: exnref
Expand Down

0 comments on commit 2b60f8a

Please sign in to comment.