Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix two compute_with bugs. #8152

Merged
merged 2 commits into from Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
224 changes: 150 additions & 74 deletions src/ScheduleFunctions.cpp
Expand Up @@ -1021,81 +1021,126 @@ class CollectBounds : public IRVisitor {
}
};

class SubstituteFusedBounds : public IRMutator {
public:
const map<string, Expr> &replacements;
explicit SubstituteFusedBounds(const map<string, Expr> &r)
: replacements(r) {
// Rename a loop var in a compute_with cluster to include '.fused.', to
// disambiguate its bounds from the original loop bounds. The '.fused.' token is
// injected somewhere that's not going to change the results of var_name_match,
// so that it's unchanged as a scheduling point.
string fused_name(const string &var) {
    // Split at the final '.', which separates the qualifying prefix from the
    // trailing variable name. A name with no '.' at all is an internal error.
    const size_t sep = var.rfind('.');
    internal_assert(sep != string::npos);

    // Rebuild the name as "<prefix>.fused.<trailing component>".
    string renamed = var.substr(0, sep);
    renamed += ".fused.";
    renamed += var.substr(sep + 1);
    return renamed;
}

// The bounds of every loop that exists in 'replacements' are replaced. The
// loop is also renamed by inserting '.fused.' into the original name, just
// before the variable name.
Stmt substitute_fused_bounds(Stmt s, const map<string, Expr> &replacements) {
if (!s.defined() || replacements.empty()) {
return s;
}

private:
using IRMutator::visit;
class SubstituteFusedBounds : public IRMutator {
const map<string, Expr> &replacements;

Stmt visit(const For *op) override {
const auto *min_var = op->min.as<Variable>();
const auto *extent_var = op->extent.as<Variable>();
if (min_var && extent_var) {
Expr min_val, extent_val;
{
const auto &it = replacements.find(min_var->name);
if (it != replacements.end()) {
min_val = it->second;
using IRMutator::visit;

Stmt visit(const For *op) override {
const auto *min_var = op->min.as<Variable>();
const auto *extent_var = op->extent.as<Variable>();
if (min_var && extent_var) {
Expr min_val, extent_val;
{
const auto &it = replacements.find(min_var->name);
if (it != replacements.end()) {
min_val = it->second;
}
}
}
{
const auto &it = replacements.find(extent_var->name);
if (it != replacements.end()) {
extent_val = it->second;
{
const auto &it = replacements.find(extent_var->name);
if (it != replacements.end()) {
extent_val = it->second;
}
}
if (!min_val.defined() || !extent_val.defined()) {
return IRMutator::visit(op);
}

Stmt body = mutate(op->body);

string new_var = fused_name(op->name);

ForType for_type = op->for_type;
DeviceAPI device_api = op->device_api;
if (is_const_one(extent_val)) {
// This is the child loop of a fused group. The real loop of the
// fused group is the loop of the parent function of the fused
// group. This child loop is just a scheduling point, and should
// never be a device transition, so we rewrite it to be a simple
// serial loop of extent 1.
for_type = ForType::Serial;
device_api = DeviceAPI::None;
}

Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"),
Variable::make(Int(32), new_var + ".loop_extent"),
for_type, op->partition_policy, device_api, body);

// Add let stmts defining the bound of the renamed for-loop.
stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt);
stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt);
stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt);
// Replace any reference to the old loop name with the new one.
stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt);
return stmt;
} else {
return IRMutator::visit(op);
}
if (!min_val.defined() || !extent_val.defined()) {
}

public:
explicit SubstituteFusedBounds(const map<string, Expr> &r)
: replacements(r) {
}
} subs(replacements);

return subs.mutate(s);
}

// Add LetStmts inside each parent loop that define the corresponding child loop
// vars as equal to the parent loop var. Bounds inference might need a child loop var.
Stmt add_loop_var_aliases(Stmt s, const map<string, set<string>> &loop_var_aliases) {
if (!s.defined() || loop_var_aliases.empty()) {
return s;
}

class AddLoopVarAliases : public IRMutator {
const map<string, set<string>> &loop_var_aliases;

using IRMutator::visit;

Stmt visit(const For *op) override {
auto it = loop_var_aliases.find(op->name);
if (it == loop_var_aliases.end()) {
return IRMutator::visit(op);
}

Expr var = Variable::make(Int(32), op->name);
Stmt body = mutate(op->body);

size_t last_dot = op->name.rfind('.');
internal_assert(last_dot != string::npos);
string new_var = op->name.substr(0, last_dot) + ".fused." + op->name.substr(last_dot + 1);

ForType for_type = op->for_type;
DeviceAPI device_api = op->device_api;
if (is_const_one(extent_val)) {
// This is the child loop of a fused group. The real loop of the
// fused group is the loop of the parent function of the fused
// group. This child loop is just a scheduling point, and should
// never be a device transition, so we rewrite it to be a simple
// serial loop of extent 1.
for_type = ForType::Serial;
device_api = DeviceAPI::None;
for (const string &alias : it->second) {
body = LetStmt::make(alias, var, body);
}

Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"),
Variable::make(Int(32), new_var + ".loop_extent"),
for_type, op->partition_policy, device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type,
op->partition_policy, op->device_api, std::move(body));
}

// Add let stmts defining the bound of the renamed for-loop.
stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt);
stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt);
stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt);
// Replace any reference to the old loop name with the new one.
stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt);
return stmt;
} else {
return IRMutator::visit(op);
public:
explicit AddLoopVarAliases(const map<string, set<string>> &a)
: loop_var_aliases(a) {
}
}
};
} add_aliases(loop_var_aliases);

// The bounds of every loop that exists in 'replacements' are replaced. The
// loop is also renamed by inserting '.fused.' into the original name, just
// before the variable name.
Stmt substitute_fused_bounds(Stmt s, const map<string, Expr> &replacements) {
if (!s.defined() || replacements.empty()) {
return s;
} else {
return SubstituteFusedBounds(replacements).mutate(s);
}
return add_aliases.mutate(s);
}

// Shift the iteration domain of a loop nest by some factor.
Expand Down Expand Up @@ -1460,7 +1505,9 @@ class InjectFunctionRealization : public IRMutator {
}

Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update,
map<string, Expr> &replacements, vector<pair<string, Expr>> &add_lets) {
map<string, Expr> &replacements,
vector<pair<string, Expr>> &add_lets,
map<string, set<string>> &aliases) {
const vector<Dim> &dims = def.schedule().dims(); // From inner to outer
const LoopLevel &fuse_level = def.schedule().fuse_level().level;

Expand Down Expand Up @@ -1499,6 +1546,10 @@ class InjectFunctionRealization : public IRMutator {
replacements.emplace(var + ".loop_extent", make_const(Int(32), 1));
replacements.emplace(var + ".loop_min", val);
replacements.emplace(var + ".loop_max", val);

string var_fused = fused_name(var_orig);
aliases[var_fused].emplace(std::move(var_orig));
aliases[var_fused].emplace(std::move(var));
}
}

Expand Down Expand Up @@ -1550,18 +1601,17 @@ class InjectFunctionRealization : public IRMutator {

// Replace the bounds of the parent fused loop (i.e. the first one to be
// realized in the group) with union of the bounds of the fused group.
Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, const map<string, Expr> &bounds) {
string prefix = f.name() + ".s0";
const Definition &def = f.definition();
Stmt replace_parent_bound_with_union_bound(const string &func, int stage,
const Definition &def, Stmt produce,
const map<string, Expr> &bounds,
map<string, Expr> &replacements) {

if (!def.defined()) {
if (def.schedule().fused_pairs().empty()) {
return produce;
}

const vector<Dim> &dims = def.schedule().dims(); // From inner to outer

map<string, Expr> replacements;

vector<FusedPair> dependence = collect_all_dependence(def);

// Compute the union of the bounds of the fused loops.
Expand All @@ -1582,6 +1632,8 @@ class InjectFunctionRealization : public IRMutator {
// the parent, e.g. y.yi and yi.
int dim2_idx = (int)(dims_2.size() - (dims.size() - i));
internal_assert(dim2_idx < (int)dims_2.size());
string var_1 = func + ".s" + std::to_string(stage) +
"." + dims[i].var;

string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) +
"." + dims_2[dim2_idx].var;
Expand All @@ -1592,7 +1644,6 @@ class InjectFunctionRealization : public IRMutator {
Expr max_2 = bounds.find(var_2 + ".loop_max")->second;
Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second;

string var_1 = prefix + "." + dims[i].var;
internal_assert(bounds.count(var_1 + ".loop_min"));
internal_assert(bounds.count(var_1 + ".loop_max"));
internal_assert(bounds.count(var_1 + ".loop_extent"));
Expand All @@ -1616,8 +1667,26 @@ class InjectFunctionRealization : public IRMutator {
}
}

// Now, replace the bounds of the parent fused loops with the union bounds.
// Now, replace the bounds of the parent fused loops with the union
// bounds.
for (const auto &spec : def.specializations()) {
produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements);
}

return produce;
}

Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce,
                                           const map<string, Expr> &bounds) {
    // Shared across every stage of f: each per-definition call below receives
    // this map by reference so the substitutions are applied in one pass at the end.
    map<string, Expr> replacements;

    // Stage 0 is f.definition(); the entries of f.updates() follow as stages 1, 2, ...
    produce = replace_parent_bound_with_union_bound(f.name(), 0, f.definition(), produce, bounds, replacements);

    int update_stage = 1;
    for (const Definition &update : f.updates()) {
        produce = replace_parent_bound_with_union_bound(f.name(), update_stage, update, produce, bounds, replacements);
        ++update_stage;
    }

    // Rewrite the loops named in 'replacements' once all stages have been visited.
    return substitute_fused_bounds(produce, replacements);
}

Expand Down Expand Up @@ -1748,22 +1817,23 @@ class InjectFunctionRealization : public IRMutator {
Stmt producer;
map<string, Expr> replacements;
vector<pair<string, Expr>> add_lets;
map<string, set<string>> aliases;

for (const auto &func_stage : stage_order) {
const auto &f = func_stage.first;

if (f.has_extern_definition() && (func_stage.second == 0)) {
const Stmt &produceDef = Internal::build_extern_produce(env, f, target);
producer = inject_stmt(producer, produceDef, LoopLevel::inlined().lock());
const Stmt &produce_def = Internal::build_extern_produce(env, f, target);
producer = inject_stmt(producer, produce_def, LoopLevel::inlined().lock());
Comment on lines -1756 to +1827
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were these changes done automatically? Why the churn?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a drive-by fix to a local variable that didn't match our naming conventions.

continue;
}

string def_prefix = f.name() + ".s" + std::to_string(func_stage.second) + ".";
const auto &def = (func_stage.second == 0) ? f.definition() : f.updates()[func_stage.second - 1];

const Stmt &produceDef = build_produce_definition(f, def_prefix, def, func_stage.second > 0,
replacements, add_lets);
producer = inject_stmt(producer, produceDef, def.schedule().fuse_level().level);
const Stmt &produce_def = build_produce_definition(f, def_prefix, def, func_stage.second > 0,
replacements, add_lets, aliases);
producer = inject_stmt(producer, produce_def, def.schedule().fuse_level().level);
}

internal_assert(producer.defined());
Expand Down Expand Up @@ -1799,8 +1869,14 @@ class InjectFunctionRealization : public IRMutator {

// Replace the bounds of parent fused loop with union of bounds of
// the fused loops.
Function group_parent = funcs.back();
producer = replace_parent_bound_with_union_bound(funcs.back(), producer, bounds);

// Define the old loop var names as equal to the corresponding parent
// fused loop var. Bounds inference might refer directly to the original
// loop vars.
producer = add_loop_var_aliases(producer, aliases);

// Add the producer nodes.
for (const auto &i : funcs) {
producer = ProducerConsumer::make_produce(i.name(), producer);
Expand Down