[XLA] give Array an option to slice out of bounds.
PiperOrigin-RevId: 625331947
blakehechtman authored and jax authors committed Apr 27, 2024
1 parent 755f350 commit 687447c
Showing 2 changed files with 50 additions and 58 deletions.
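The headline change gives xla::Array::Slice an out_of_bounds_ok option, so slice limits may extend past the array bounds instead of failing a bounds check; Mosaic's TPU transpose lowering uses it below to handle padded shapes. A minimal sketch of the call shape, inferred from the call sites in this diff rather than from xla/array.h itself (the helper name slice_padded is hypothetical; elements outside the source are unspecified and the callers here never read them):

#include "xla/array.h"

// Sketch only: assumes a 4x4 input. Before this commit, limits[i] had to
// satisfy limits[i] <= a.dim(i); with out_of_bounds_ok=true the requested
// region may be larger than the source.
xla::Array<int> slice_padded(const xla::Array<int> &a) {
  // Round the minor dimension up from 4 to 8; the extra columns are
  // unspecified and must be treated as don't-care padding.
  return a.Slice(/*starts=*/{0, 0}, /*limits=*/{4, 8},
                 /*out_of_bounds_ok=*/true);
}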
51 changes: 27 additions & 24 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -193,7 +193,7 @@ xla::Array<Value> concatenate(const ArrayRef<xla::Array<Value>> arrays,
}
xla::Array<Value> res(dims);
int64_t offset = 0;
-  for (xla::Array<Value> const& arr : arrays) {
+  for (xla::Array<Value> const &arr : arrays) {
arr.Each([&](const absl::Span<const int64_t> idx, const Value v) {
SmallVector<int64_t> res_idx(toArrayRef(idx));
res_idx[axis] += offset;
@@ -249,7 +249,7 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
}

bool sliceIsEmpty(const absl::Span<const int64_t> starts,
-                 const absl::Span<const int64_t> limits) {
+                  const absl::Span<const int64_t> limits) {
for (auto [s, l] : llvm::zip_equal(starts, limits)) {
CHECK_LE(s, l);
if (s == l) {
@@ -282,9 +282,19 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
return;
}
SmallVector<int64_t> idx(toArrayRef(starts));
+  auto in_bounds = [&] {
+    for (int64_t i = 0; i < idx.size(); ++i) {
+      if (idx[i] >= arr.dim(i)) {
+        return false;
+      }
+    }
+    return true;
+  };
auto data_it = data.begin();
do {
-    arr(idx) = *data_it;
+    if (in_bounds()) {
+      arr(idx) = *data_it;
+    }
++data_it;
} while (incrementSliceIndex(idx, starts, limits));
CHECK(data_it == data.end());
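With this guard, updateSliceFromRange tolerates limits that extend past arr: out-of-range positions are skipped while the data iterator still advances, so the trailing CHECK still sees a fully consumed range. A one-dimensional sketch of the same pattern, under hypothetical names (not the Mosaic code):

#include <cstdint>
#include <vector>

// 1-D reduction of the guarded-write loop above: write only in-bounds
// positions, but consume input data for every position in [start, limit).
void update_guarded(std::vector<int> &arr, const std::vector<int> &data,
                    int64_t start, int64_t limit) {
  int64_t data_i = 0;
  for (int64_t i = start; i < limit; ++i) {
    if (i < static_cast<int64_t>(arr.size())) {  // the in_bounds() check
      arr[i] = data[data_i];
    }
    ++data_i;  // advances even when the write is skipped
  }
}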
@@ -1307,7 +1317,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
}
tpu::LoadOp load_op = cast<tpu::LoadOp>(op);
if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
-                                 VectorLayout::ImplicitDim::kNone)) {
+                                  VectorLayout::ImplicitDim::kNone)) {
return op.emitOpError("Invalid output layout for ") << load_op->getName();
}
FAILUREOR_ASSIGN_OR_RETURN(
@@ -1863,7 +1873,7 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
-  if (layouts_in[0] !=layouts_out[0]) {
+  if (layouts_in[0] != layouts_out[0]) {
return op.emitOpError("Expected same input and output layout");
}
OpBuilder builder(&op);
@@ -2622,8 +2632,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const LayoutOffsets offsets_in = layout_in.offsets();
const LayoutOffsets offsets_out = layout_out.offsets();
if (layout_in.tiling() != layout_out.tiling()) {
-    return op.emitOpError(
-        "Not implemented: Changing tiling mid-broadcast");
+    return op.emitOpError("Not implemented: Changing tiling mid-broadcast");
}
auto tiling = layout_in.tiling();

@@ -2745,8 +2754,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
VectorType::get(ctx.target_shape, builder.getI32Type());
auto idx_const = builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), idx_ty,
-        DenseElementsAttr::get(idx_ty,
-                               builder.getI32IntegerAttr(offset)));
+        DenseElementsAttr::get(idx_ty, builder.getI32IntegerAttr(offset)));
int64_t sublanes_per_tile = layout_in.sublanesPerTile(ctx.target_shape);
DenseI32ArrayAttr sublane_pattern;
if (num_tiles != 1) {
@@ -3687,11 +3695,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
"Not implemented: Non-native or offset layout unsupported");
}
const int64_t transpose_unit_size = ctx.target_shape[1];
-  for (const int64_t s : src_ty.getShape().take_back(2)) {
-    if (s % transpose_unit_size != 0) {
-      return transpose_op->emitOpError("Not implemented: Padded transpose");
-    }
-  }
if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) {
return transpose_op->emitOpError(
"Not implemented: TPUs before v4 only support 32-bit transposes");
@@ -3730,8 +3733,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
src_slice_ends.append(incremented_batch_idx.begin(),
incremented_batch_idx.end());
src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end});
-    xla::Array<Value> src_tile_vregs =
-        src_vregs.Slice(src_slice_starts, src_slice_ends);
+    xla::Array<Value> src_tile_vregs = src_vregs.Slice(
+        src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true);
// Drop leading singleton (batch) dimensions to have a shape that conforms
// with the vreg array shape specified by layout_in, as expected by assemble
src_tile_vregs.Reshape(
@@ -3762,12 +3765,12 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<int64_t> batch_sizes =
dst_ty.getShape().take_front(num_batch_dims);
SmallVector<int64_t> batch_idx(num_batch_dims);
+  const int64_t tile_rows =
+      xla::CeilOfRatio(*(src_ty.getShape().end() - 2), transpose_unit_size);
+  const int64_t num_col_tiles =
+      xla::CeilOfRatio(*(src_ty.getShape().end() - 1), transpose_unit_size);
do {
-    const int64_t tile_rows =
-        *(src_ty.getShape().end() - 2) / transpose_unit_size;
for (int64_t src_row = 0; src_row < tile_rows; ++src_row) {
-      const int64_t num_col_tiles =
-          *(src_ty.getShape().end() - 1) / transpose_unit_size;
if (can_batch) {
const int64_t num_batch_tiles = num_col_tiles / 2;
for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) {
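Hoisting the tile counts out of the loop and switching from exact division to xla::CeilOfRatio (declared in xla/util.h) is the other half of padded-transpose support: partial tiles are now counted instead of truncated away. A worked example, assuming transpose_unit_size is 128, i.e. ctx.target_shape[1] on current TPUs:

// For a 96 x 256 source, exact division yielded 96 / 128 == 0 row tiles,
// so shapes that were not multiples of the transpose unit could not be
// lowered (and were rejected by the now-deleted "Padded transpose" check).
// Ceiling division counts the partial tile, whose missing vregs the
// out-of-bounds Slice above provides:
static_assert(96 / 128 == 0);              // old tile_rows
static_assert((96 + 128 - 1) / 128 == 1);  // xla::CeilOfRatio(96, 128)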
@@ -4307,7 +4310,7 @@ FailureOr<TypedValue<VectorType>> relayout(
*(src_tiles.dimensions().end() - 2) == 1)) &&
dst.offsets()[1] == 0 && src.tiling() == std::array<int64_t, 2>{1, 128} &&
dst.tiling() == std::array<int64_t, 2>{8, 128}) {
-    xla::Array<Value> src_tiles_retiled(
+    xla::Array<Value> src_tiles_retiled(
dst.tileArrayShape(vty.getShape(), target_shape));
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) {
@@ -4466,8 +4469,8 @@ FailureOr<TypedValue<VectorType>> relayout(
v.getLoc(), bits_vreg_ty,
DenseElementsAttr::get(bits_vreg_ty, shift_bits));
dst_tiles.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
-      auto bit_tile =
-          builder.create<tpu::BitcastVregOp>(v.getLoc(), bits_vreg_ty, *tile);
+      auto bit_tile = builder.create<tpu::BitcastVregOp>(
+          v.getLoc(), bits_vreg_ty, *tile);
Operation *shift_tile;
if (subelem_diff > 0) {
shift_tile =
@@ -4479,7 +4482,7 @@ FailureOr<TypedValue<VectorType>> relayout(
}
*tile = builder
.create<tpu::BitcastVregOp>(v.getLoc(), tile->getType(),
-                                       shift_tile->getResult(0))
+                                            shift_tile->getResult(0))
.getResult();
return absl::OkStatus();
});
57 changes: 23 additions & 34 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -898,8 +898,9 @@ class VectorLayoutInferer {
if (some_layout->tiling()[0] == 1) {
offsets[0] = std::nullopt;
}
-    *some_layout = VectorLayout(some_layout->bitwidth(), offsets,
-                                default_tiling_, some_layout->implicit_dim());
+    *some_layout =
+        VectorLayout(some_layout->bitwidth(), offsets, default_tiling_,
+                     some_layout->implicit_dim());
}
auto &layout = *some_layout;
if (layout.implicit_dim() != ImplicitDim::kNone) {
@@ -1410,44 +1411,32 @@ class VectorLayoutInferer {

LogicalResult infer(vector::TransposeOp op) {
auto permutation = op.getPermutation();
-    TPU_CHECK_OP(permutation.size() > 1,
-                 "Vector and scalar transpose should be a no-op and removed");
-
auto some_layout = getLayout(op.getVector());
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
auto src_ty = op.getSourceVectorType();
TPU_CHECK_OP(permutation.size() == src_ty.getRank(),
"Transpose permutation has incorrect rank");
-    if (layout.implicit_dim() == ImplicitDim::kNone) {
-      TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}),
-                   "Padded transposes unsupported");
-      auto xlu_width = target_shape_[1];
-      for (int64_t s : src_ty.getShape().take_back(2)) {
-        TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported");
-      }
-      for (auto dim : permutation.drop_back(2)) {
-        TPU_CHECK_OP(
-            dim < src_ty.getRank() - 2,
-            "Unsupported transpose permutation - minor dims into major");
-      }
-      for (auto dim : permutation.take_back(2)) {
-        TPU_CHECK_OP(
-            dim >= src_ty.getRank() - 2,
-            "Unsupported transpose permutation - major dims into minor");
-      }
-      Layout required_layout = some_layout;
+    if (permutation.size() < 2) {
+      return failure();
+    }
-      // Require native tiling if we're going to use the XLU.
-      if (permutation[permutation.size() - 1] == permutation.size() - 2) {
-        auto native_tiling = nativeTiling(layout.bitwidth());
-        required_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
-                                       native_tiling, ImplicitDim::kNone);
-      }
-      setLayout(op, required_layout, required_layout);
-      return success();
-    }
-    op.emitOpError("Unsupported transpose");
-    return failure();
+    for (auto dim : permutation.drop_back(2)) {
+      TPU_CHECK_OP(dim < src_ty.getRank() - 2,
+                   "Unsupported transpose permutation - minor dims into major");
+    }
+    for (auto dim : permutation.take_back(2)) {
+      TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
+                   "Unsupported transpose permutation - major dims into minor");
+    }
+    Layout required_layout = some_layout;
+    // Require native tiling if we're going to use the XLU.
+    if (permutation[permutation.size() - 1] == permutation.size() - 2) {
+      auto native_tiling = nativeTiling(layout.bitwidth());
+      required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0},
+                                     native_tiling, ImplicitDim::kNone);
+    }
+    setLayout(op, required_layout, required_layout);
+    return success();
}
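The inference-side counterpart: instead of rejecting padded transposes (the offset and divisibility checks above are deleted) or anything outside ImplicitDim::kNone, the pass now only validates the permutation and, when the two minor dimensions trade places, requires native tiling with zero offsets, leaving the padding to the out-of-bounds slicing in apply_vector_layout. A standalone sketch of that gating condition (hypothetical helper, not part of the pass):

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"

// Mirrors `permutation[permutation.size() - 1] == permutation.size() - 2`
// above: the XLU (transpose unit) is needed exactly when the last two
// dimensions swap, e.g. {0, 2, 1}; a pure batch permutation like {0, 1, 2}
// passes the existing layout through unchanged.
bool swapsMinorDims(llvm::ArrayRef<int64_t> permutation) {
  return permutation[permutation.size() - 1] ==
         static_cast<int64_t>(permutation.size()) - 2;
}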

LogicalResult inferExt(Operation *op) {
