From 96a283cf4ef39f04d3989f3d01cd73408ec5d872 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 12 Mar 2024 16:50:05 -0700 Subject: [PATCH 1/2] Fix two compute_with bugs. This PR fixes a bug in compute_with, and another bug I found while fixing it (we could really use a compute_with fuzzer). The first bug is that you can get into situations where the bounds of a producer func will refer directly to the loop variable of a consumer func, where the consumer is in a compute_with fused group. In main, that loop variable may not be defined because fused loop names have been rewritten to include the token ".fused.". This PR adds let stmts to define it just inside the fused loop body. The second bug is that not all parent loops in compute_with fused groups were having their bounds expanded to cover the region to be computed of all children, because the logic for deciding which loops to expand only considered the non-specialized pure definition. So e.g. compute_with applied to an update stage would fail to compute values of the child Func where they do not overlap with the parent Func. This PR visits all definitions of the parent Func of the fused group, instead of just the unspecialized pure definition of the parent Func. Fixes #8149 --- src/ScheduleFunctions.cpp | 224 ++++++++++++++++++++---------- test/correctness/compute_with.cpp | 87 +++++++++++- 2 files changed, 236 insertions(+), 75 deletions(-) diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index aa45841253b7..88f15ec3d36d 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1021,81 +1021,126 @@ class CollectBounds : public IRVisitor { } }; -class SubstituteFusedBounds : public IRMutator { -public: - const map &replacements; - explicit SubstituteFusedBounds(const map &r) - : replacements(r) { +// Rename a loop var in a compute_with cluster to include '.fused.', to +// disambiguate its bounds from the original loop bounds. The '.fused.' token is +// injected somewhere that's not going to change the results of var_name_match, +// so that it's unchanged as a scheduling point. +string fused_name(const string &var) { + size_t last_dot = var.rfind('.'); + internal_assert(last_dot != string::npos); + return var.substr(0, last_dot) + ".fused." + var.substr(last_dot + 1); +} + +// The bounds of every loop exist in 'replacements' should be replaced. The +// loop is also renamed by adding '.fused' in the original name before the +// variable name. +Stmt substitute_fused_bounds(Stmt s, const map &replacements) { + if (!s.defined() || replacements.empty()) { + return s; } -private: - using IRMutator::visit; + class SubstituteFusedBounds : public IRMutator { + const map &replacements; - Stmt visit(const For *op) override { - const auto *min_var = op->min.as(); - const auto *extent_var = op->extent.as(); - if (min_var && extent_var) { - Expr min_val, extent_val; - { - const auto &it = replacements.find(min_var->name); - if (it != replacements.end()) { - min_val = it->second; + using IRMutator::visit; + + Stmt visit(const For *op) override { + const auto *min_var = op->min.as(); + const auto *extent_var = op->extent.as(); + if (min_var && extent_var) { + Expr min_val, extent_val; + { + const auto &it = replacements.find(min_var->name); + if (it != replacements.end()) { + min_val = it->second; + } } - } - { - const auto &it = replacements.find(extent_var->name); - if (it != replacements.end()) { - extent_val = it->second; + { + const auto &it = replacements.find(extent_var->name); + if (it != replacements.end()) { + extent_val = it->second; + } + } + if (!min_val.defined() || !extent_val.defined()) { + return IRMutator::visit(op); + } + + Stmt body = mutate(op->body); + + string new_var = fused_name(op->name); + + ForType for_type = op->for_type; + DeviceAPI device_api = op->device_api; + if (is_const_one(extent_val)) { + // This is the child loop of a fused group. The real loop of the + // fused group is the loop of the parent function of the fused + // group. This child loop is just a scheduling point, and should + // never be a device transition, so we rewrite it to be a simple + // serial loop of extent 1." + for_type = ForType::Serial; + device_api = DeviceAPI::None; } + + Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), + Variable::make(Int(32), new_var + ".loop_extent"), + for_type, op->partition_policy, device_api, body); + + // Add let stmts defining the bound of the renamed for-loop. + stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); + stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); + stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); + // Replace any reference to the old loop name with the new one. + stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); + return stmt; + } else { + return IRMutator::visit(op); } - if (!min_val.defined() || !extent_val.defined()) { + } + + public: + explicit SubstituteFusedBounds(const map &r) + : replacements(r) { + } + } subs(replacements); + + return subs.mutate(s); +} + +// Add letstmts inside each parent loop that define the corresponding child loop +// vars as equal to it. Bounds inference might need a child loop var. +Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_aliases) { + if (!s.defined() || loop_var_aliases.empty()) { + return s; + } + + class AddLoopVarAliases : public IRMutator { + const map> &loop_var_aliases; + + using IRMutator::visit; + + Stmt visit(const For *op) override { + auto it = loop_var_aliases.find(op->name); + if (it == loop_var_aliases.end()) { return IRMutator::visit(op); } + Expr var = Variable::make(Int(32), op->name); Stmt body = mutate(op->body); - - size_t last_dot = op->name.rfind('.'); - internal_assert(last_dot != string::npos); - string new_var = op->name.substr(0, last_dot) + ".fused." + op->name.substr(last_dot + 1); - - ForType for_type = op->for_type; - DeviceAPI device_api = op->device_api; - if (is_const_one(extent_val)) { - // This is the child loop of a fused group. The real loop of the - // fused group is the loop of the parent function of the fused - // group. This child loop is just a scheduling point, and should - // never be a device transition, so we rewrite it to be a simple - // serial loop of extent 1." - for_type = ForType::Serial; - device_api = DeviceAPI::None; + for (const string &alias : it->second) { + body = LetStmt::make(alias, var, body); } - Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_extent"), - for_type, op->partition_policy, device_api, body); + return For::make(op->name, op->min, op->extent, op->for_type, + op->partition_policy, op->device_api, std::move(body)); + } - // Add let stmts defining the bound of the renamed for-loop. - stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); - stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); - // Replace any reference to the old loop name with the new one. - stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); - return stmt; - } else { - return IRMutator::visit(op); + public: + explicit AddLoopVarAliases(const map> &a) + : loop_var_aliases(a) { } - } -}; + } add_aliases(loop_var_aliases); -// The bounds of every loop exist in 'replacements' should be replaced. The -// loop is also renamed by adding '.fused' in the original name before the -// variable name. -Stmt substitute_fused_bounds(Stmt s, const map &replacements) { - if (!s.defined() || replacements.empty()) { - return s; - } else { - return SubstituteFusedBounds(replacements).mutate(s); - } + return add_aliases.mutate(s); } // Shift the iteration domain of a loop nest by some factor. @@ -1460,7 +1505,9 @@ class InjectFunctionRealization : public IRMutator { } Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update, - map &replacements, vector> &add_lets) { + map &replacements, + vector> &add_lets, + map> &aliases) { const vector &dims = def.schedule().dims(); // From inner to outer const LoopLevel &fuse_level = def.schedule().fuse_level().level; @@ -1499,6 +1546,10 @@ class InjectFunctionRealization : public IRMutator { replacements.emplace(var + ".loop_extent", make_const(Int(32), 1)); replacements.emplace(var + ".loop_min", val); replacements.emplace(var + ".loop_max", val); + + string var_fused = fused_name(var_orig); + aliases[var_fused].emplace(std::move(var_orig)); + aliases[var_fused].emplace(std::move(var)); } } @@ -1550,18 +1601,17 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of the parent fused loop (i.e. the first one to be // realized in the group) with union of the bounds of the fused group. - Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, const map &bounds) { - string prefix = f.name() + ".s0"; - const Definition &def = f.definition(); + Stmt replace_parent_bound_with_union_bound(string func, int stage, + const Definition &def, Stmt produce, + const map &bounds, + map &replacements) { - if (!def.defined()) { + if (def.schedule().fused_pairs().empty()) { return produce; } const vector &dims = def.schedule().dims(); // From inner to outer - map replacements; - vector dependence = collect_all_dependence(def); // Compute the union of the bounds of the fused loops. @@ -1582,6 +1632,8 @@ class InjectFunctionRealization : public IRMutator { // the parent, e.g. y.yi and yi. int dim2_idx = (int)(dims_2.size() - (dims.size() - i)); internal_assert(dim2_idx < (int)dims_2.size()); + string var_1 = func + ".s" + std::to_string(stage) + + "." + dims[i].var; string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; @@ -1592,7 +1644,6 @@ class InjectFunctionRealization : public IRMutator { Expr max_2 = bounds.find(var_2 + ".loop_max")->second; Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second; - string var_1 = prefix + "." + dims[i].var; internal_assert(bounds.count(var_1 + ".loop_min")); internal_assert(bounds.count(var_1 + ".loop_max")); internal_assert(bounds.count(var_1 + ".loop_extent")); @@ -1616,8 +1667,26 @@ class InjectFunctionRealization : public IRMutator { } } - // Now, replace the bounds of the parent fused loops with the union bounds. + // Now, replace the bounds of the parent fused loops with the union + // bounds. + for (auto &spec : def.specializations()) { + produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements); + } + + return produce; + } + + Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, + const map &bounds) { + map replacements; + + int stage = 0; + produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements); + for (const Definition &def : f.updates()) { + produce = replace_parent_bound_with_union_bound(f.name(), stage++, def, produce, bounds, replacements); + } produce = substitute_fused_bounds(produce, replacements); + return produce; } @@ -1748,22 +1817,23 @@ class InjectFunctionRealization : public IRMutator { Stmt producer; map replacements; vector> add_lets; + map> aliases; for (const auto &func_stage : stage_order) { const auto &f = func_stage.first; if (f.has_extern_definition() && (func_stage.second == 0)) { - const Stmt &produceDef = Internal::build_extern_produce(env, f, target); - producer = inject_stmt(producer, produceDef, LoopLevel::inlined().lock()); + const Stmt &produce_def = Internal::build_extern_produce(env, f, target); + producer = inject_stmt(producer, produce_def, LoopLevel::inlined().lock()); continue; } string def_prefix = f.name() + ".s" + std::to_string(func_stage.second) + "."; const auto &def = (func_stage.second == 0) ? f.definition() : f.updates()[func_stage.second - 1]; - const Stmt &produceDef = build_produce_definition(f, def_prefix, def, func_stage.second > 0, - replacements, add_lets); - producer = inject_stmt(producer, produceDef, def.schedule().fuse_level().level); + const Stmt &produce_def = build_produce_definition(f, def_prefix, def, func_stage.second > 0, + replacements, add_lets, aliases); + producer = inject_stmt(producer, produce_def, def.schedule().fuse_level().level); } internal_assert(producer.defined()); @@ -1799,8 +1869,14 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of parent fused loop with union of bounds of // the fused loops. + Function group_parent = funcs.back(); producer = replace_parent_bound_with_union_bound(funcs.back(), producer, bounds); + // Define the old loop var names as equal to the corresponding parent + // fused loop var. Bounds inference might refer directly to the original + // loop vars. + producer = add_loop_var_aliases(producer, aliases); + // Add the producer nodes. for (const auto &i : funcs) { producer = ProducerConsumer::make_produce(i.name(), producer); diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index 053570a2f5c0..0152642028eb 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -2204,6 +2204,89 @@ int two_compute_at_test() { return 0; } +// Test for the issue described in https://github.com/halide/Halide/issues/8149. +int child_var_dependent_bounds_test() { + Func f{"f"}, g{"g"}; + Var x{"x"}, y{"y"}; + RDom r(0, 10, "r"); + + Func f_inter{"f_inter"}, g_inter{"g_inter"}; + + f_inter(x, y) = x; + f_inter(x, y) += 1; + f(x) = x; + f(x) += f_inter(x, r); + + g_inter(x, y) = x; + g_inter(x, y) += 1; + g(x) = x; + g(x) += g_inter(x, r); + + f_inter.compute_at(f, r); + g_inter.compute_at(f, r); + g.update().compute_with(f.update(), r); + f.update().unscheduled(); + + Pipeline p({f, g}); + + p.compile_jit(); + Buffer f_buf(10), g_buf(10); + + f_buf.set_min(2); + p.realize({f_buf, g_buf}); + f_buf.set_min(0); + + for (int i = 0; i < 10; i++) { + int correct_f = 10 + 11 * (i + 2); + int correct_g = 10 + 11 * i; + if (f_buf(i) != correct_f) { + printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f); + } + if (g_buf(i) != correct_g) { + printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_f); + } + } + + return 0; +} + +int overlapping_updates_test() { + Func f{"f"}, g{"g"}; + Var x{"x"}; + + f(x) = 0; + f(x) += x; + g(x) = 0; + g(x) += x; + + g.update().compute_with(f.update(), x); + f.update().unscheduled(); + + Pipeline p({f, g}); + + p.compile_jit(); + Buffer f_buf(10), g_buf(10); + + f_buf.set_min(2); + p.realize({f_buf, g_buf}); + f_buf.set_min(0); + + for (int i = 0; i < 10; i++) { + int correct_f = i + 2; + int correct_g = i; + if (f_buf(i) != correct_f) { + printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f); + return 1; + } + if (g_buf(i) != correct_g) { + printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_f); + return 1; + } + } + + return 0; +} + } // namespace int main(int argc, char **argv) { @@ -2247,7 +2330,9 @@ int main(int argc, char **argv) { {"different arg number compute_at test", different_arg_num_compute_at_test}, {"store_at different levels test", store_at_different_levels_test}, {"rvar bounds test", rvar_bounds_test}, - {"two_compute_at test", two_compute_at_test}, + {"two compute at test", two_compute_at_test}, + {"overlapping updates test", overlapping_updates_test}, + {"child var dependent bounds test", child_var_dependent_bounds_test}, }; using Sharder = Halide::Internal::Test::Sharder; From 44b4d74645c0b199ce4cec7f00576a94a0c4ec86 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 12 Mar 2024 17:23:03 -0700 Subject: [PATCH 2/2] clang-tidy --- src/ScheduleFunctions.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 88f15ec3d36d..8fa2fd71a7a2 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1601,7 +1601,7 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of the parent fused loop (i.e. the first one to be // realized in the group) with union of the bounds of the fused group. - Stmt replace_parent_bound_with_union_bound(string func, int stage, + Stmt replace_parent_bound_with_union_bound(const string &func, int stage, const Definition &def, Stmt produce, const map &bounds, map &replacements) { @@ -1669,7 +1669,7 @@ class InjectFunctionRealization : public IRMutator { // Now, replace the bounds of the parent fused loops with the union // bounds. - for (auto &spec : def.specializations()) { + for (const auto &spec : def.specializations()) { produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements); }