From b03f52d698bc6ec60cf8bab6c94104716d2648c3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 22 Apr 2024 03:32:31 -0700 Subject: [PATCH] [Mosaic] Add support for remote DMAs and semaphores in megacore mode The change to tpu.td is not backwards compatible, but I made it so using the newly added Mosaic stability layer. It's been a good exercise and it seems to be working just fine. Co-authored-by: Sharad Vikram PiperOrigin-RevId: 626978604 --- jax/_src/pallas/mosaic/lowering.py | 8 ++- jax/_src/pallas/mosaic/primitives.py | 6 +- jaxlib/mosaic/dialect/tpu/tpu.td | 12 ++-- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 16 +++-- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 70 +++++++++++++++++-- 5 files changed, 91 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 37890b757702..4f449bab64e8 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2230,12 +2230,14 @@ def _semaphore_signal_lowering_rule( args_tree, device_id_type: tpu_primitives.DeviceIdType, ): - sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args) + sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args) sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results + return tpu.SemaphoreSignalOp( + sem, value, device_id=device_id, core_id=core_index + ).results lowering_rules[tpu_primitives.semaphore_signal_p] = ( diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 8b4d8aae899e..18afd300b041 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -230,10 +230,11 @@ def semaphore_signal( *, device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None, device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, ): ref, indexers = _get_ref_and_indexers(sem_or_view) inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, indexers, inc, device_id] + args = [ref, indexers, inc, device_id, core_index] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_signal_p.bind( *flat_args, @@ -249,7 +250,7 @@ def _semaphore_signal_abstract_eval( device_id_type: DeviceIdType, ): del device_id_type - sem_aval, sem_indexers_avals, value_aval, device_id_avals = ( + sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = ( tree_util.tree_unflatten(args_tree, avals) ) check_sem_avals(sem_aval, sem_indexers_avals, "signal") @@ -274,6 +275,7 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, sem_indexers, value, device_ids, + _, ) = tree_util.tree_unflatten(tree, invars) out = pp.concat([ pp.text('semaphore_signal'), diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 461384b0273b..8b4353994dca 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -480,25 +480,27 @@ def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { let hasVerifier = 1; } -def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal"> { +def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { let arguments = (ins MemRefOf<[TPU_SemaphoreType]>:$semaphore, I32:$amount, - Optional:$device_id // For remote DMAs + Optional:$device_id, // For remote DMAs + Optional:$core_id // For megacore ); let assemblyFormat = [{ - $semaphore `,` $amount (`,` $device_id^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; } -def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [SameVariadicOperandSize]> { +def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { let arguments = (ins AnyMemRef:$source, Optional>:$source_semaphore, // For remote DMAs AnyMemRef:$target, MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, - Optional:$device_id // For remote DMAs + Optional:$device_id, // For remote DMAs + Optional:$core_id // For megacore ); let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 2bd6c4c10697..9ed7c43f1767 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -317,8 +317,7 @@ LogicalResult GetBarrierSemaphoreOp::verify() { LogicalResult SemaphoreSignalOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { - emitOpError("Semaphore reference must be rank 0"); - return failure(); + return emitOpError("Semaphore reference must be rank 0"); } return success(); } @@ -328,14 +327,19 @@ LogicalResult EnqueueDMAOp::verify() { if (source_sem) { auto source_sem_type = getMemRefType(getSourceSemaphore()); if (source_sem_type.getRank() != 0) { - emitOpError("DMA source semaphore reference must be rank 0"); - return failure(); + return emitOpError("DMA source semaphore reference must be rank 0"); } } auto target_sem_type = getMemRefType(getTargetSemaphore()); if (target_sem_type.getRank() != 0) { - emitOpError("DMA target semaphore must be rank 0"); - return failure(); + return emitOpError("DMA target semaphore must be rank 0"); + } + if (getDeviceId() || getCoreId()) { + if (!getSourceSemaphore()) { + return emitOpError( + "DMA source semaphore must be specified when " + "device_id or core_id is specified"); + } } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 1c4d5f6c323b..ac2389d6c238 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -27,11 +27,13 @@ limitations under the License. #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" +#include "mlir/include/mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { -#define GEN_PASS_DECL_MOSAICSERDEPASS #define GEN_PASS_DEF_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" @@ -39,7 +41,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 1; +constexpr int kVersion = 2; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -57,6 +59,55 @@ std::optional demangle(StringRef name) { return name.drop_front(kMangledDialect.size()); } +using rule_type = std::function; + +LogicalResult enqueue_dma_rule(Operation* op, int version) { + // Added AttrSizedOperandSegments and core_id in version 2. + if (version < 2) { + if (op->getNumOperands() == 3) { // Local DMA. + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 0, 1, 1, 0, 0})); + } else if (op->getNumOperands() == 5) { // Remote DMA. + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.enqueue_dma: ") + << op->getNumOperands(); + } + } + return success(); +} + +LogicalResult semaphore_signal_rule(Operation* op, int version) { + // Added AttrSizedOperandSegments and core_id in version 2. + if (version < 2) { + if (op->getNumOperands() == 2) { // Local signal. + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); + } else if (op->getNumOperands() == 3) { // Remote signal. + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + } + } + return success(); +} + +const llvm::StringMap& upgrade_rules() { + static auto rules = new llvm::StringMap{ + {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + }; + return *rules; +} + struct MosaicSerdePass : public impl::MosaicSerdePassBase { using Base::Base; @@ -68,6 +119,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { signalPassFailure(); return; } + int version = kVersion; if (serialize) { module->setAttr( kVersionAttrName, @@ -81,16 +133,17 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { signalPassFailure(); return; } - if (version_attr.getValue() != kVersion) { + if (version_attr.getInt() > kVersion) { module->emitError("Unsupported Mosaic version: ") - << version_attr.getValue().getSExtValue(); + << version_attr.getInt(); signalPassFailure(); return; } + version = version_attr.getInt(); module->removeAttr(kVersionAttrName); } std::string name_storage; - auto result = module.walk([this, &name_storage](Operation* op) { + auto result = module.walk([this, &name_storage, version](Operation* op) { if (isa(op)) { // Don't mangle the ModuleOp itself. return WalkResult::advance(); } @@ -111,6 +164,13 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { op->emitError("Operation not in a serialized form"); return WalkResult::interrupt(); } + // Upgrade the op to the current version, if needed. + if (const auto rule = upgrade_rules().find(new_name->getStringRef()); + rule != upgrade_rules().end()) { + if (rule->second(op, version).failed()) { + return WalkResult::interrupt(); + } + } } auto new_op = Operation::create( op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),