Skip to content

Commit

Permalink
Fix type error in VectorizeLoops (#8055)
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams authored and steven-johnson committed Feb 1, 2024
1 parent 2111594 commit 3577f88
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Interval bounds_of_lanes(const Expr &e) {
Interval ia = bounds_of_lanes(not_->a);
return {!ia.max, !ia.min};
} else if (const Ramp *r = e.as<Ramp>()) {
Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
Expr last_lane_idx = make_const(r->base.type().element_of(), r->lanes - 1);
Interval ib = bounds_of_lanes(r->base);
const Broadcast *b = as_scalar_broadcast(r->stride);
Expr stride = b ? b->value : r->stride;
Expand Down Expand Up @@ -875,6 +875,7 @@ class VectorSubs : public IRMutator {
// generating a scalar condition that checks if
// the least-true lane is true.
Expr all_true = bounds_of_lanes(likely->args[0]).min;
internal_assert(all_true.type() == Bool());
// Wrap it in the same flavor of likely
all_true = Call::make(Bool(), likely->name,
{all_true}, Call::PureIntrinsic);
Expand Down
68 changes: 68 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,74 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8054
{
ImageParam input(Float(32), 2, "input");
const float r_sigma = 0.1;
const int s_sigma = 8;
Func bilateral_grid{"bilateral_grid"};

Var x("x"), y("y"), z("z"), c("c");

// Add a boundary condition
Func clamped = Halide::BoundaryConditions::repeat_edge(input);

// Construct the bilateral grid
RDom r(0, s_sigma, 0, s_sigma);
Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2);
val = clamp(val, 0.0f, 1.0f);

Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f);

Func histogram("histogram");
histogram(x, y, z, c) = 0.0f;
histogram(x, y, zi, c) += mux(c, {val, 1.0f});

// Blur the grid using a five-tap filter
Func blurx("blurx"), blury("blury"), blurz("blurz");
blurz(x, y, z, c) = (histogram(x, y, z - 2, c) +
histogram(x, y, z - 1, c) * 4 +
histogram(x, y, z, c) * 6 +
histogram(x, y, z + 1, c) * 4 +
histogram(x, y, z + 2, c));
blurx(x, y, z, c) = (blurz(x - 2, y, z, c) +
blurz(x - 1, y, z, c) * 4 +
blurz(x, y, z, c) * 6 +
blurz(x + 1, y, z, c) * 4 +
blurz(x + 2, y, z, c));
blury(x, y, z, c) = (blurx(x, y - 2, z, c) +
blurx(x, y - 1, z, c) * 4 +
blurx(x, y, z, c) * 6 +
blurx(x, y + 1, z, c) * 4 +
blurx(x, y + 2, z, c));

// Take trilinear samples to compute the output
val = clamp(input(x, y), 0.0f, 1.0f);
Expr zv = val * (1.0f / r_sigma);
zi = cast<int>(zv);
Expr zf = zv - zi;
Expr xf = cast<float>(x % s_sigma) / s_sigma;
Expr yf = cast<float>(y % s_sigma) / s_sigma;
Expr xi = x / s_sigma;
Expr yi = y / s_sigma;
Func interpolated("interpolated");
interpolated(x, y, c) =
lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf),
lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf),
lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf),
lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf),
zf);

// Normalize
bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1);
Pipeline p({bilateral_grid});

Var v6, zo, vzi;

blury.compute_root().split(x, x, v6, 6, TailStrategy::GuardWithIf).split(z, zo, vzi, 8, TailStrategy::GuardWithIf).reorder(y, x, c, vzi, zo, v6).vectorize(vzi).vectorize(v6);
p.compile_to_module({input}, "bilateral_grid", {Target("host")});
}

printf("Success!\n");
return 0;
}

0 comments on commit 3577f88

Please sign in to comment.