diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c3605ce92f38..a9094ec807b6 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -193,7 +193,7 @@ xla::Array concatenate(const ArrayRef> arrays, } xla::Array res(dims); int64_t offset = 0; - for (xla::Array const& arr : arrays) { + for (xla::Array const &arr : arrays) { arr.Each([&](const absl::Span idx, const Value v) { SmallVector res_idx(toArrayRef(idx)); res_idx[axis] += offset; @@ -249,7 +249,7 @@ bool incrementIndex(const MutableArrayRef idx, } bool sliceIsEmpty(const absl::Span starts, - const absl::Span limits) { + const absl::Span limits) { for (auto [s, l] : llvm::zip_equal(starts, limits)) { CHECK_LE(s, l); if (s == l) { @@ -282,9 +282,19 @@ void updateSliceFromRange(xla::Array &arr, Range data, return; } SmallVector idx(toArrayRef(starts)); + auto in_bounds = [&] { + for (int64_t i = 0; i < idx.size(); ++i) { + if (idx[i] >= limits[i]) { + return false; + } + } + return true; + }; auto data_it = data.begin(); do { - arr(idx) = *data_it; + if (in_bounds()) { + arr(idx) = *data_it; + } ++data_it; } while (incrementSliceIndex(idx, starts, limits)); CHECK(data_it == data.end()); @@ -1307,7 +1317,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, } tpu::LoadOp load_op = cast(op); if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, - VectorLayout::ImplicitDim::kNone)) { + VectorLayout::ImplicitDim::kNone)) { return op.emitOpError("Invalid output layout for ") << load_op->getName(); } FAILUREOR_ASSIGN_OR_RETURN( @@ -1863,7 +1873,7 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(op.getNumResults(), 1); TPU_ASSERT_EQ_OP(layouts_in.size(), 1); TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - if (layouts_in[0] !=layouts_out[0]) { + if (layouts_in[0] != layouts_out[0]) { return op.emitOpError("Expected same input and output layout"); } OpBuilder builder(&op); @@ -2622,8 +2632,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, const LayoutOffsets offsets_in = layout_in.offsets(); const LayoutOffsets offsets_out = layout_out.offsets(); if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError( - "Not implemented: Changing tiling mid-broadcast"); + return op.emitOpError("Not implemented: Changing tiling mid-broadcast"); } auto tiling = layout_in.tiling(); @@ -2745,8 +2754,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, VectorType::get(ctx.target_shape, builder.getI32Type()); auto idx_const = builder.create( broadcast_op.getLoc(), idx_ty, - DenseElementsAttr::get(idx_ty, - builder.getI32IntegerAttr(offset))); + DenseElementsAttr::get(idx_ty, builder.getI32IntegerAttr(offset))); int64_t sublanes_per_tile = layout_in.sublanesPerTile(ctx.target_shape); DenseI32ArrayAttr sublane_pattern; if (num_tiles != 1) { @@ -3687,11 +3695,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, "Not implemented: Non-native or offset layout unsupported"); } const int64_t transpose_unit_size = ctx.target_shape[1]; - for (const int64_t s : src_ty.getShape().take_back(2)) { - if (s % transpose_unit_size != 0) { - return transpose_op->emitOpError("Not implemented: Padded transpose"); - } - } if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) { return transpose_op->emitOpError( "Not implemented: TPUs before v4 only support 32-bit transposes"); @@ -3730,8 +3733,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, src_slice_ends.append(incremented_batch_idx.begin(), incremented_batch_idx.end()); src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end}); - xla::Array src_tile_vregs = - src_vregs.Slice(src_slice_starts, src_slice_ends); + xla::Array src_tile_vregs = src_vregs.Slice( + src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true); // Drop leading singleton (batch) dimensions to have a shape that conforms // with the vreg array shape specified by layout_in, as expected by assemble src_tile_vregs.Reshape( @@ -3762,12 +3765,12 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const ArrayRef batch_sizes = dst_ty.getShape().take_front(num_batch_dims); SmallVector batch_idx(num_batch_dims); + const int64_t tile_rows = + xla::CeilOfRatio(*(src_ty.getShape().end() - 2), transpose_unit_size); + const int64_t num_col_tiles = + xla::CeilOfRatio(*(src_ty.getShape().end() - 1), transpose_unit_size); do { - const int64_t tile_rows = - *(src_ty.getShape().end() - 2) / transpose_unit_size; for (int64_t src_row = 0; src_row < tile_rows; ++src_row) { - const int64_t num_col_tiles = - *(src_ty.getShape().end() - 1) / transpose_unit_size; if (can_batch) { const int64_t num_batch_tiles = num_col_tiles / 2; for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) { @@ -4307,7 +4310,7 @@ FailureOr> relayout( *(src_tiles.dimensions().end() - 2) == 1)) && dst.offsets()[1] == 0 && src.tiling() == std::array{1, 128} && dst.tiling() == std::array{8, 128}) { - xla::Array src_tiles_retiled( + xla::Array src_tiles_retiled( dst.tileArrayShape(vty.getShape(), target_shape)); src_tiles_retiled.Each([&](absl::Span idx, Value *tile) { for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) { @@ -4466,8 +4469,8 @@ FailureOr> relayout( v.getLoc(), bits_vreg_ty, DenseElementsAttr::get(bits_vreg_ty, shift_bits)); dst_tiles.Each([&](absl::Span /*idx*/, Value *tile) { - auto bit_tile = - builder.create(v.getLoc(), bits_vreg_ty, *tile); + auto bit_tile = builder.create( + v.getLoc(), bits_vreg_ty, *tile); Operation *shift_tile; if (subelem_diff > 0) { shift_tile = @@ -4479,7 +4482,7 @@ FailureOr> relayout( } *tile = builder .create(v.getLoc(), tile->getType(), - shift_tile->getResult(0)) + shift_tile->getResult(0)) .getResult(); return absl::OkStatus(); }); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 74b51a9e6589..67ebca2ebc4c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -898,8 +898,9 @@ class VectorLayoutInferer { if (some_layout->tiling()[0] == 1) { offsets[0] = std::nullopt; } - *some_layout = VectorLayout(some_layout->bitwidth(), offsets, - default_tiling_, some_layout->implicit_dim()); + *some_layout = + VectorLayout(some_layout->bitwidth(), offsets, default_tiling_, + some_layout->implicit_dim()); } auto &layout = *some_layout; if (layout.implicit_dim() != ImplicitDim::kNone) { @@ -1410,44 +1411,32 @@ class VectorLayoutInferer { LogicalResult infer(vector::TransposeOp op) { auto permutation = op.getPermutation(); + TPU_CHECK_OP(permutation.size() > 1, + "Vector and scalar transpose should be a no-op and removed"); + auto some_layout = getLayout(op.getVector()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; auto src_ty = op.getSourceVectorType(); TPU_CHECK_OP(permutation.size() == src_ty.getRank(), "Transpose permutation has incorrect rank"); - if (layout.implicit_dim() == ImplicitDim::kNone) { - TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}), - "Padded transposes unsupported"); - auto xlu_width = target_shape_[1]; - for (int64_t s : src_ty.getShape().take_back(2)) { - TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported"); - } - for (auto dim : permutation.drop_back(2)) { - TPU_CHECK_OP( - dim < src_ty.getRank() - 2, - "Unsupported transpose permutation - minor dims into major"); - } - for (auto dim : permutation.take_back(2)) { - TPU_CHECK_OP( - dim >= src_ty.getRank() - 2, - "Unsupported transpose permutation - major dims into minor"); - } - Layout required_layout = some_layout; - if (permutation.size() < 2) { - return failure(); - } - // Require native tiling if we're going to use the XLU. - if (permutation[permutation.size() - 1] == permutation.size() - 2) { - auto native_tiling = nativeTiling(layout.bitwidth()); - required_layout = VectorLayout(layout.bitwidth(), layout.offsets(), - native_tiling, ImplicitDim::kNone); - } - setLayout(op, required_layout, required_layout); - return success(); - } - op.emitOpError("Unsupported transpose"); - return failure(); + for (auto dim : permutation.drop_back(2)) { + TPU_CHECK_OP(dim < src_ty.getRank() - 2, + "Unsupported transpose permutation - minor dims into major"); + } + for (auto dim : permutation.take_back(2)) { + TPU_CHECK_OP(dim >= src_ty.getRank() - 2, + "Unsupported transpose permutation - major dims into minor"); + } + Layout required_layout = some_layout; + // Require native tiling if we're going to use the XLU. + if (permutation[permutation.size() - 1] == permutation.size() - 2) { + auto native_tiling = nativeTiling(layout.bitwidth()); + required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0}, + native_tiling, ImplicitDim::kNone); + } + setLayout(op, required_layout, required_layout); + return success(); } LogicalResult inferExt(Operation *op) {