Skip to content

Commit

Permalink
[Mosaic] Add support for remote DMAs and semaphores in megacore mode
Browse files Browse the repository at this point in the history
The change to tpu.td is not backwards compatible on its own, but I made it
compatible by using the newly added Mosaic stability layer. It's been a good
exercise and it seems to be working just fine.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 626978604
  • Loading branch information
2 people authored and jax authors committed May 2, 2024
1 parent b40a310 commit b03f52d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 21 deletions.
8 changes: 5 additions & 3 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -2230,12 +2230,14 @@ def _semaphore_signal_lowering_rule(
args_tree,
device_id_type: tpu_primitives.DeviceIdType,
):
sem_aval, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
sem, indexers, value, device_id = tree_util.tree_unflatten(args_tree, args)
sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args)
sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers)
if device_id is not None:
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
return tpu.SemaphoreSignalOp(sem, value, device_id=device_id).results
return tpu.SemaphoreSignalOp(
sem, value, device_id=device_id, core_id=core_index
).results


lowering_rules[tpu_primitives.semaphore_signal_p] = (
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/pallas/mosaic/primitives.py
Expand Up @@ -230,10 +230,11 @@ def semaphore_signal(
*,
device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None,
device_id_type: DeviceIdType = DeviceIdType.MESH,
core_index: int | jax.Array | None = None,
):
ref, indexers = _get_ref_and_indexers(sem_or_view)
inc = jnp.asarray(inc, dtype=jnp.int32)
args = [ref, indexers, inc, device_id]
args = [ref, indexers, inc, device_id, core_index]
flat_args, args_tree = tree_util.tree_flatten(args)
semaphore_signal_p.bind(
*flat_args,
Expand All @@ -249,7 +250,7 @@ def _semaphore_signal_abstract_eval(
device_id_type: DeviceIdType,
):
del device_id_type
sem_aval, sem_indexers_avals, value_aval, device_id_avals = (
sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = (
tree_util.tree_unflatten(args_tree, avals)
)
check_sem_avals(sem_aval, sem_indexers_avals, "signal")
Expand All @@ -274,6 +275,7 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn,
sem_indexers,
value,
device_ids,
_,
) = tree_util.tree_unflatten(tree, invars)
out = pp.concat([
pp.text('semaphore_signal'),
Expand Down
12 changes: 7 additions & 5 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Expand Up @@ -480,25 +480,27 @@ def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> {
let hasVerifier = 1;
}

def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal"> {
def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> {
let arguments = (ins
MemRefOf<[TPU_SemaphoreType]>:$semaphore,
I32:$amount,
Optional<I32>:$device_id // For remote DMAs
Optional<I32>:$device_id, // For remote DMAs
Optional<I32>:$core_id // For megacore
);
let assemblyFormat = [{
$semaphore `,` $amount (`,` $device_id^)? attr-dict `:` type($semaphore)
$semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore)
}];
let hasVerifier = 1;
}

def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [SameVariadicOperandSize]> {
def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
let arguments = (ins
AnyMemRef:$source,
Optional<MemRefOf<[TPU_DMASemaphoreType]>>:$source_semaphore, // For remote DMAs
AnyMemRef:$target,
MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore,
Optional<I32>:$device_id // For remote DMAs
Optional<I32>:$device_id, // For remote DMAs
Optional<I32>:$core_id // For megacore
);
let hasVerifier = 1;
}
Expand Down
16 changes: 10 additions & 6 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Expand Up @@ -317,8 +317,7 @@ LogicalResult GetBarrierSemaphoreOp::verify() {
LogicalResult SemaphoreSignalOp::verify() {
auto sem_type = getMemRefType(getSemaphore());
if (sem_type.getRank() != 0) {
emitOpError("Semaphore reference must be rank 0");
return failure();
return emitOpError("Semaphore reference must be rank 0");
}
return success();
}
Expand All @@ -328,14 +327,19 @@ LogicalResult EnqueueDMAOp::verify() {
if (source_sem) {
auto source_sem_type = getMemRefType(getSourceSemaphore());
if (source_sem_type.getRank() != 0) {
emitOpError("DMA source semaphore reference must be rank 0");
return failure();
return emitOpError("DMA source semaphore reference must be rank 0");
}
}
auto target_sem_type = getMemRefType(getTargetSemaphore());
if (target_sem_type.getRank() != 0) {
emitOpError("DMA target semaphore must be rank 0");
return failure();
return emitOpError("DMA target semaphore must be rank 0");
}
if (getDeviceId() || getCoreId()) {
if (!getSourceSemaphore()) {
return emitOpError(
"DMA source semaphore must be specified when "
"device_id or core_id is specified");
}
}
return success();
}
Expand Down
70 changes: 65 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/serde.cc
Expand Up @@ -27,19 +27,21 @@ limitations under the License.
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

namespace mlir::tpu {

#define GEN_PASS_DECL_MOSAICSERDEPASS
#define GEN_PASS_DEF_MOSAICSERDEPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

namespace {

constexpr std::string_view kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
constexpr int kVersion = 1;
constexpr int kVersion = 2;

StringRef mangle(StringRef name, std::string* storage) {
storage->clear();
Expand All @@ -57,6 +59,55 @@ std::optional<StringRef> demangle(StringRef name) {
return name.drop_front(kMangledDialect.size());
}

using rule_type = std::function<LogicalResult(Operation*, int)>;

// Upgrade rule for serialized tpu.enqueue_dma ops.
//
// AttrSizedOperandSegments and core_id were added in version 2, so for ops
// from an earlier version the operand segment sizes attribute is rebuilt from
// the operand count. Segment order follows the op's argument list: source,
// source_semaphore, target, target_semaphore, device_id, core_id.
//
// Returns failure (after emitting an error on `op`) when a pre-version-2 op
// has an operand count other than 3 or 5; otherwise returns success.
LogicalResult enqueue_dma_rule(Operation* op, int version) {
  // Nothing to upgrade from version 2 onwards.
  if (version >= 2) {
    return success();
  }

  const auto operand_count = op->getNumOperands();
  const bool local_dma = operand_count == 3;
  const bool remote_dma = operand_count == 5;
  if (!local_dma && !remote_dma) {
    return op->emitError("Unexpected operand count in tpu.enqueue_dma: ")
           << operand_count;
  }

  const auto segment_attr_name =
      OpTrait::AttrSizedOperandSegments<EnqueueDMAOp>::getOperandSegmentSizeAttr();
  if (local_dma) {
    // source, target, target_semaphore.
    op->setAttr(segment_attr_name,
                mlir::DenseI32ArrayAttr::get(op->getContext(),
                                             {1, 0, 1, 1, 0, 0}));
  } else {
    // All operands except core_id.
    op->setAttr(segment_attr_name,
                mlir::DenseI32ArrayAttr::get(op->getContext(),
                                             {1, 1, 1, 1, 1, 0}));
  }
  return success();
}

// Upgrade rule for serialized tpu.sem_signal ops.
//
// AttrSizedOperandSegments and core_id were added in version 2, so for ops
// from an earlier version the operand segment sizes attribute is rebuilt from
// the operand count. Segment order follows the op's argument list: semaphore,
// amount, device_id, core_id.
//
// Returns failure (after emitting an error on `op`) when a pre-version-2 op
// has an operand count other than 2 or 3; otherwise returns success.
LogicalResult semaphore_signal_rule(Operation* op, int version) {
  // Added AttrSizedOperandSegments and core_id in version 2.
  if (version < 2) {
    if (op->getNumOperands() == 2) {  // Local signal.
      op->setAttr(OpTrait::AttrSizedOperandSegments<
                      SemaphoreSignalOp>::getOperandSegmentSizeAttr(),
                  mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0}));
    } else if (op->getNumOperands() == 3) {  // Remote signal.
      op->setAttr(OpTrait::AttrSizedOperandSegments<
                      SemaphoreSignalOp>::getOperandSegmentSizeAttr(),
                  mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0}));
    } else {
      // Name the op by its real mnemonic and report the offending count,
      // matching the diagnostic emitted for tpu.enqueue_dma.
      return op->emitError("Unexpected operand count in tpu.sem_signal: ")
             << op->getNumOperands();
    }
  }
  return success();
}

// Returns the table of per-op upgrade rules, keyed by operation name. The
// deserialization walk looks up each demangled op name here and, on a hit,
// invokes the rule with the op and the module's serialized version.
const llvm::StringMap<rule_type>& upgrade_rules() {
  // Allocated once on first use and intentionally never deleted.
  static const llvm::StringMap<rule_type>* const kRules =
      new llvm::StringMap<rule_type>{
          {EnqueueDMAOp::getOperationName(), enqueue_dma_rule},
          {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule},
      };
  return *kRules;
}

struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
using Base::Base;

Expand All @@ -68,6 +119,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
signalPassFailure();
return;
}
int version = kVersion;
if (serialize) {
module->setAttr(
kVersionAttrName,
Expand All @@ -81,16 +133,17 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
signalPassFailure();
return;
}
if (version_attr.getValue() != kVersion) {
if (version_attr.getInt() > kVersion) {
module->emitError("Unsupported Mosaic version: ")
<< version_attr.getValue().getSExtValue();
<< version_attr.getInt();
signalPassFailure();
return;
}
version = version_attr.getInt();
module->removeAttr(kVersionAttrName);
}
std::string name_storage;
auto result = module.walk([this, &name_storage](Operation* op) {
auto result = module.walk([this, &name_storage, version](Operation* op) {
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
return WalkResult::advance();
}
Expand All @@ -111,6 +164,13 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
op->emitError("Operation not in a serialized form");
return WalkResult::interrupt();
}
// Upgrade the op to the current version, if needed.
if (const auto rule = upgrade_rules().find(new_name->getStringRef());
rule != upgrade_rules().end()) {
if (rule->second(op, version).failed()) {
return WalkResult::interrupt();
}
}
}
auto new_op = Operation::create(
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
Expand Down

0 comments on commit b03f52d

Please sign in to comment.