diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index aa45841253b7..8fa2fd71a7a2 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1021,81 +1021,126 @@ class CollectBounds : public IRVisitor { } }; -class SubstituteFusedBounds : public IRMutator { -public: - const map &replacements; - explicit SubstituteFusedBounds(const map &r) - : replacements(r) { +// Rename a loop var in a compute_with cluster to include '.fused.', to +// disambiguate its bounds from the original loop bounds. The '.fused.' token is +// injected somewhere that's not going to change the results of var_name_match, +// so that it's unchanged as a scheduling point. +string fused_name(const string &var) { + size_t last_dot = var.rfind('.'); + internal_assert(last_dot != string::npos); + return var.substr(0, last_dot) + ".fused." + var.substr(last_dot + 1); +} + +// The bounds of every loop exist in 'replacements' should be replaced. The +// loop is also renamed by adding '.fused' in the original name before the +// variable name. +Stmt substitute_fused_bounds(Stmt s, const map &replacements) { + if (!s.defined() || replacements.empty()) { + return s; } -private: - using IRMutator::visit; + class SubstituteFusedBounds : public IRMutator { + const map &replacements; - Stmt visit(const For *op) override { - const auto *min_var = op->min.as(); - const auto *extent_var = op->extent.as(); - if (min_var && extent_var) { - Expr min_val, extent_val; - { - const auto &it = replacements.find(min_var->name); - if (it != replacements.end()) { - min_val = it->second; + using IRMutator::visit; + + Stmt visit(const For *op) override { + const auto *min_var = op->min.as(); + const auto *extent_var = op->extent.as(); + if (min_var && extent_var) { + Expr min_val, extent_val; + { + const auto &it = replacements.find(min_var->name); + if (it != replacements.end()) { + min_val = it->second; + } } - } - { - const auto &it = replacements.find(extent_var->name); - if (it != replacements.end()) { - extent_val = it->second; + { + const auto &it = replacements.find(extent_var->name); + if (it != replacements.end()) { + extent_val = it->second; + } + } + if (!min_val.defined() || !extent_val.defined()) { + return IRMutator::visit(op); + } + + Stmt body = mutate(op->body); + + string new_var = fused_name(op->name); + + ForType for_type = op->for_type; + DeviceAPI device_api = op->device_api; + if (is_const_one(extent_val)) { + // This is the child loop of a fused group. The real loop of the + // fused group is the loop of the parent function of the fused + // group. This child loop is just a scheduling point, and should + // never be a device transition, so we rewrite it to be a simple + // serial loop of extent 1." + for_type = ForType::Serial; + device_api = DeviceAPI::None; } + + Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), + Variable::make(Int(32), new_var + ".loop_extent"), + for_type, op->partition_policy, device_api, body); + + // Add let stmts defining the bound of the renamed for-loop. + stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); + stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); + stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); + // Replace any reference to the old loop name with the new one. + stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); + return stmt; + } else { + return IRMutator::visit(op); } - if (!min_val.defined() || !extent_val.defined()) { + } + + public: + explicit SubstituteFusedBounds(const map &r) + : replacements(r) { + } + } subs(replacements); + + return subs.mutate(s); +} + +// Add letstmts inside each parent loop that define the corresponding child loop +// vars as equal to it. Bounds inference might need a child loop var. +Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_aliases) { + if (!s.defined() || loop_var_aliases.empty()) { + return s; + } + + class AddLoopVarAliases : public IRMutator { + const map> &loop_var_aliases; + + using IRMutator::visit; + + Stmt visit(const For *op) override { + auto it = loop_var_aliases.find(op->name); + if (it == loop_var_aliases.end()) { return IRMutator::visit(op); } + Expr var = Variable::make(Int(32), op->name); Stmt body = mutate(op->body); - - size_t last_dot = op->name.rfind('.'); - internal_assert(last_dot != string::npos); - string new_var = op->name.substr(0, last_dot) + ".fused." + op->name.substr(last_dot + 1); - - ForType for_type = op->for_type; - DeviceAPI device_api = op->device_api; - if (is_const_one(extent_val)) { - // This is the child loop of a fused group. The real loop of the - // fused group is the loop of the parent function of the fused - // group. This child loop is just a scheduling point, and should - // never be a device transition, so we rewrite it to be a simple - // serial loop of extent 1." - for_type = ForType::Serial; - device_api = DeviceAPI::None; + for (const string &alias : it->second) { + body = LetStmt::make(alias, var, body); } - Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_extent"), - for_type, op->partition_policy, device_api, body); + return For::make(op->name, op->min, op->extent, op->for_type, + op->partition_policy, op->device_api, std::move(body)); + } - // Add let stmts defining the bound of the renamed for-loop. - stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); - stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); - // Replace any reference to the old loop name with the new one. - stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); - return stmt; - } else { - return IRMutator::visit(op); + public: + explicit AddLoopVarAliases(const map> &a) + : loop_var_aliases(a) { } - } -}; + } add_aliases(loop_var_aliases); -// The bounds of every loop exist in 'replacements' should be replaced. The -// loop is also renamed by adding '.fused' in the original name before the -// variable name. -Stmt substitute_fused_bounds(Stmt s, const map &replacements) { - if (!s.defined() || replacements.empty()) { - return s; - } else { - return SubstituteFusedBounds(replacements).mutate(s); - } + return add_aliases.mutate(s); } // Shift the iteration domain of a loop nest by some factor. @@ -1460,7 +1505,9 @@ class InjectFunctionRealization : public IRMutator { } Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update, - map &replacements, vector> &add_lets) { + map &replacements, + vector> &add_lets, + map> &aliases) { const vector &dims = def.schedule().dims(); // From inner to outer const LoopLevel &fuse_level = def.schedule().fuse_level().level; @@ -1499,6 +1546,10 @@ class InjectFunctionRealization : public IRMutator { replacements.emplace(var + ".loop_extent", make_const(Int(32), 1)); replacements.emplace(var + ".loop_min", val); replacements.emplace(var + ".loop_max", val); + + string var_fused = fused_name(var_orig); + aliases[var_fused].emplace(std::move(var_orig)); + aliases[var_fused].emplace(std::move(var)); } } @@ -1550,18 +1601,17 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of the parent fused loop (i.e. the first one to be // realized in the group) with union of the bounds of the fused group. - Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, const map &bounds) { - string prefix = f.name() + ".s0"; - const Definition &def = f.definition(); + Stmt replace_parent_bound_with_union_bound(const string &func, int stage, + const Definition &def, Stmt produce, + const map &bounds, + map &replacements) { - if (!def.defined()) { + if (def.schedule().fused_pairs().empty()) { return produce; } const vector &dims = def.schedule().dims(); // From inner to outer - map replacements; - vector dependence = collect_all_dependence(def); // Compute the union of the bounds of the fused loops. @@ -1582,6 +1632,8 @@ class InjectFunctionRealization : public IRMutator { // the parent, e.g. y.yi and yi. int dim2_idx = (int)(dims_2.size() - (dims.size() - i)); internal_assert(dim2_idx < (int)dims_2.size()); + string var_1 = func + ".s" + std::to_string(stage) + + "." + dims[i].var; string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; @@ -1592,7 +1644,6 @@ class InjectFunctionRealization : public IRMutator { Expr max_2 = bounds.find(var_2 + ".loop_max")->second; Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second; - string var_1 = prefix + "." + dims[i].var; internal_assert(bounds.count(var_1 + ".loop_min")); internal_assert(bounds.count(var_1 + ".loop_max")); internal_assert(bounds.count(var_1 + ".loop_extent")); @@ -1616,8 +1667,26 @@ class InjectFunctionRealization : public IRMutator { } } - // Now, replace the bounds of the parent fused loops with the union bounds. + // Now, replace the bounds of the parent fused loops with the union + // bounds. + for (const auto &spec : def.specializations()) { + produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements); + } + + return produce; + } + + Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, + const map &bounds) { + map replacements; + + int stage = 0; + produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements); + for (const Definition &def : f.updates()) { + produce = replace_parent_bound_with_union_bound(f.name(), stage++, def, produce, bounds, replacements); + } produce = substitute_fused_bounds(produce, replacements); + return produce; } @@ -1748,22 +1817,23 @@ class InjectFunctionRealization : public IRMutator { Stmt producer; map replacements; vector> add_lets; + map> aliases; for (const auto &func_stage : stage_order) { const auto &f = func_stage.first; if (f.has_extern_definition() && (func_stage.second == 0)) { - const Stmt &produceDef = Internal::build_extern_produce(env, f, target); - producer = inject_stmt(producer, produceDef, LoopLevel::inlined().lock()); + const Stmt &produce_def = Internal::build_extern_produce(env, f, target); + producer = inject_stmt(producer, produce_def, LoopLevel::inlined().lock()); continue; } string def_prefix = f.name() + ".s" + std::to_string(func_stage.second) + "."; const auto &def = (func_stage.second == 0) ? f.definition() : f.updates()[func_stage.second - 1]; - const Stmt &produceDef = build_produce_definition(f, def_prefix, def, func_stage.second > 0, - replacements, add_lets); - producer = inject_stmt(producer, produceDef, def.schedule().fuse_level().level); + const Stmt &produce_def = build_produce_definition(f, def_prefix, def, func_stage.second > 0, + replacements, add_lets, aliases); + producer = inject_stmt(producer, produce_def, def.schedule().fuse_level().level); } internal_assert(producer.defined()); @@ -1799,8 +1869,14 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of parent fused loop with union of bounds of // the fused loops. + Function group_parent = funcs.back(); producer = replace_parent_bound_with_union_bound(funcs.back(), producer, bounds); + // Define the old loop var names as equal to the corresponding parent + // fused loop var. Bounds inference might refer directly to the original + // loop vars. + producer = add_loop_var_aliases(producer, aliases); + // Add the producer nodes. for (const auto &i : funcs) { producer = ProducerConsumer::make_produce(i.name(), producer); diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index 053570a2f5c0..0152642028eb 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -2204,6 +2204,89 @@ int two_compute_at_test() { return 0; } +// Test for the issue described in https://github.com/halide/Halide/issues/8149. +int child_var_dependent_bounds_test() { + Func f{"f"}, g{"g"}; + Var x{"x"}, y{"y"}; + RDom r(0, 10, "r"); + + Func f_inter{"f_inter"}, g_inter{"g_inter"}; + + f_inter(x, y) = x; + f_inter(x, y) += 1; + f(x) = x; + f(x) += f_inter(x, r); + + g_inter(x, y) = x; + g_inter(x, y) += 1; + g(x) = x; + g(x) += g_inter(x, r); + + f_inter.compute_at(f, r); + g_inter.compute_at(f, r); + g.update().compute_with(f.update(), r); + f.update().unscheduled(); + + Pipeline p({f, g}); + + p.compile_jit(); + Buffer f_buf(10), g_buf(10); + + f_buf.set_min(2); + p.realize({f_buf, g_buf}); + f_buf.set_min(0); + + for (int i = 0; i < 10; i++) { + int correct_f = 10 + 11 * (i + 2); + int correct_g = 10 + 11 * i; + if (f_buf(i) != correct_f) { + printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f); + } + if (g_buf(i) != correct_g) { + printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_f); + } + } + + return 0; +} + +int overlapping_updates_test() { + Func f{"f"}, g{"g"}; + Var x{"x"}; + + f(x) = 0; + f(x) += x; + g(x) = 0; + g(x) += x; + + g.update().compute_with(f.update(), x); + f.update().unscheduled(); + + Pipeline p({f, g}); + + p.compile_jit(); + Buffer f_buf(10), g_buf(10); + + f_buf.set_min(2); + p.realize({f_buf, g_buf}); + f_buf.set_min(0); + + for (int i = 0; i < 10; i++) { + int correct_f = i + 2; + int correct_g = i; + if (f_buf(i) != correct_f) { + printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f); + return 1; + } + if (g_buf(i) != correct_g) { + printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_f); + return 1; + } + } + + return 0; +} + } // namespace int main(int argc, char **argv) { @@ -2247,7 +2330,9 @@ int main(int argc, char **argv) { {"different arg number compute_at test", different_arg_num_compute_at_test}, {"store_at different levels test", store_at_different_levels_test}, {"rvar bounds test", rvar_bounds_test}, - {"two_compute_at test", two_compute_at_test}, + {"two compute at test", two_compute_at_test}, + {"overlapping updates test", overlapping_updates_test}, + {"child var dependent bounds test", child_var_dependent_bounds_test}, }; using Sharder = Halide::Internal::Test::Sharder;