Skip to content

Commit

Permalink
Add folding for shlo div op and start a pass that canonicalizes and folds shlo within the odml converter.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 629851095
  • Loading branch information
LukeBoyer authored and tensorflower-gardener committed May 1, 2024
1 parent d89ceef commit 0ad69af
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 11 deletions.
42 changes: 31 additions & 11 deletions tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD
Expand Up @@ -17,13 +17,14 @@ package_group(

tf_cc_binary(
name = "odml-converter",
testonly = True,
srcs = ["odml_converter_main.cc"],
visibility = [
"//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:__subpackages__",
"//third_party/odml/infra:__subpackages__",
], # Prototype phase.
deps = [
":all_passes",
":shlo_simplify",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass",
Expand All @@ -37,6 +38,35 @@ tf_cc_binary(
],
)

# Pass library for the shlo-simplify pass; alwayslink so the static pass
# registration in shlo_simplify.cc is retained by the linker even though
# no symbol is referenced directly.
cc_library(
name = "shlo_simplify",
srcs = [
"transforms/shlo_simplify.cc",
],
hdrs = ["passes.h"],
deps = [
":folders",
":passes_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)

# Constant-folding rewrite patterns for StableHLO ops, kept separate from
# the pass library so they can be reused by other passes.
cc_library(
name = "folders",
srcs = ["folders.cc"],
hdrs = ["folders.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@stablehlo//:stablehlo_ops",
],
)

gentbl_cc_library(
name = "passes_inc_gen",
tbl_outs = [
Expand All @@ -52,13 +82,3 @@ gentbl_cc_library(
td_file = "passes.td",
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "all_passes",
hdrs = ["passes.h"],
deps = [":passes_inc_gen"],
)

exports_files([
"run_lit.sh",
])
116 changes: 116 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc
@@ -0,0 +1,116 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <optional>
#include <vector>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo

namespace mlir::odml {

namespace {

// Helper class for parsing operands to a foldable operation.
class FoldAdaptor {
public:
// Returns std::nullopt if the operation cannot be folded.
static std::optional<FoldAdaptor> Create(Operation* operation) {
auto foldable_opr = [](Value val) -> bool {
return !llvm::isa<BlockArgument>(val) &&
llvm::isa<stablehlo::ConstantOp>(val.getDefiningOp());
};
if (!llvm::all_of(operation->getOperands(), foldable_opr)) {
return std::nullopt;
}
return FoldAdaptor(operation);
}

// Gets a list of ElementsAttr behind each constant operand.
llvm::SmallVector<ElementsAttr> OperandData() {
llvm::SmallVector<ElementsAttr> res;
res.reserve(operation_->getNumOperands());
for (auto opr : operation_->getOperands()) {
auto op = llvm::dyn_cast<stablehlo::ConstantOp>(opr.getDefiningOp());
res.push_back(op.getValue());
}
return res;
}

// Gets a pointer to the operation to be folded.
Operation* Op() { return operation_; }

private:
explicit FoldAdaptor(Operation* operation) : operation_(operation) {}
Operation* const operation_;
};

// Tries to fold stablehlo::DivOp. Datatype must be floating point. Currently
// only supports splat values for the left hand side.
// Tries to fold stablehlo::DivOp elementwise. Datatype must be floating
// point. Splat constants are supported on either side: a splat LHS is
// broadcast against the RHS elements, and a splat RHS is handled by the
// generic elementwise path (ElementsAttr::getValues iterates logical
// elements, so a splat simply yields the same divisor each time).
// Divisions by zero are never folded so runtime inf/nan semantics are kept.
static LogicalResult FoldDivOp(stablehlo::DivOp op, PatternRewriter& rewriter) {
  auto adaptor = FoldAdaptor::Create(op);
  if (!adaptor.has_value()) return failure();
  // Only floating point division is folded; other element types would need
  // different semantics (e.g. integer truncation).
  if (!llvm::isa<FloatType>(op.getType().getElementType())) {
    return failure();
  }
  auto const_oprs = adaptor.value().OperandData();

  std::vector<APFloat> res;
  res.reserve(const_oprs[1].getNumElements());

  if (const_oprs[0].isSplat()) {
    // Splat LHS: divide the single value by every RHS element.
    const APFloat lhs = const_oprs[0].getSplatValue<APFloat>();
    for (const auto rhs : const_oprs[1].getValues<APFloat>()) {
      if (rhs.isZero()) {
        return failure();
      }
      res.push_back(lhs / rhs);
    }
  } else {
    // General elementwise case (also covers a splat RHS).
    for (const auto [lhs, rhs] :
         llvm::zip(const_oprs[0].getValues<APFloat>(),
                   const_oprs[1].getValues<APFloat>())) {
      if (rhs.isZero()) {
        return failure();
      }
      res.push_back(lhs / rhs);
    }
  }

  // Operand and result types of stablehlo.divide agree, so the LHS type is
  // the result type.
  auto res_attr = DenseElementsAttr::get(
      llvm::cast<RankedTensorType>(const_oprs[0].getType()), res);
  rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(adaptor.value().Op(),
                                                     res_attr);
  return success();
}
} // namespace

// Registers every folding pattern into `patternSet`. The elevated benefit
// gives these patterns precedence over any others added to the same set.
void PopulateFolderPatterns(RewritePatternSet& patternSet) {
  patternSet.add(FoldDivOp, /*benefit=*/10);
}

} // namespace mlir::odml
26 changes: 26 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h
@@ -0,0 +1,26 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_

namespace mlir::odml {

// Populates the pattern set with all folding patterns. These patterns
// are intended to have precedence over any other patterns added to the set.
void PopulateFolderPatterns(RewritePatternSet &patternSet);

} // namespace mlir::odml

#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_
Expand Up @@ -16,8 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_

#include <memory>

#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::odml {

// Creates an instance of the shlo-simplify pass, which applies
// odml-internal canonicalization and folding patterns to StableHLO ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateSHLOSimplifyPass();

#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc"

Expand Down
11 changes: 11 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.td
Expand Up @@ -15,3 +15,14 @@ limitations under the License.

include "mlir/Pass/PassBase.td"

def SHLOSimplifyPass: Pass<"shlo-simplify", "ModuleOp"> {
  let summary = "Apply internal canonicalizations and foldings.";
  let description = [{
    Applies canonicalization and constant-folding patterns to StableHLO ops
    that are not provided by the upstream StableHLO dialect, e.g. folding
    `stablehlo.divide` when both operands are float constants.
  }];
  let constructor = "CreateSHLOSimplifyPass()";
}
Expand Up @@ -21,5 +21,6 @@ filegroup(
data = [
"//tensorflow/compiler/mlir/lite/stablehlo/odml_converter:odml-converter",
"@llvm-project//llvm:FileCheck",
"@llvm-project//mlir:run_lit.sh",
],
)
@@ -0,0 +1,37 @@
// RUN: odml-converter --shlo-simplify %s -split-input-file | FileCheck %s

// Elementwise divide of two dense float constants folds to a constant
// ([2/4, 3/6] == [0.5, 0.5], printed as a splat).
func.func @foldDiv() -> tensor<2xf32> {
%0 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32>
%2 = stablehlo.divide %0, %1 : tensor<2xf32>
return %2 : tensor<2xf32>
}

// CHECK-LABEL: foldDiv
// CHECK: stablehlo.constant dense<5.000000e-01> : tensor<2xf32>

// -----

// A splat LHS is broadcast against the dense RHS: [2/4, 2/6].
func.func @foldDivLHSSplat() -> tensor<2xf32> {
%0 = stablehlo.constant dense<2.0> : tensor<2xf32>
%1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32>
%2 = stablehlo.divide %0, %1 : tensor<2xf32>
return %2 : tensor<2xf32>
}

// CHECK-LABEL: foldDivLHSSplat
// CHECK: stablehlo.constant dense<[5.000000e-01, 0.333333343]> : tensor<2xf32>

// -----

// Folding also works for f64 element types.
func.func @foldDivF64() -> tensor<2xf64> {
%0 = stablehlo.constant dense<[2.0, 3.0]> : tensor<2xf64>
%1 = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf64>
%2 = stablehlo.divide %0, %1 : tensor<2xf64>
return %2 : tensor<2xf64>
}

// CHECK-LABEL: foldDivF64
// CHECK: stablehlo.constant dense<5.000000e-01> : tensor<2xf64>


@@ -0,0 +1,54 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>

#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h"

namespace mlir {
namespace odml {
namespace {

#define GEN_PASS_DEF_SHLOSIMPLIFYPASS
#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc"

// Performs misc odml "cleanup" on shlo dialect. This is a functional standin
// for canonicalization and folding which is not offered directly by the
// shlo implementation.
// Performs misc odml "cleanup" on the shlo dialect. This is a functional
// stand-in for canonicalization and folding, which the shlo implementation
// does not offer directly.
class SHLOSimplifyPass : public impl::SHLOSimplifyPassBase<SHLOSimplifyPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SHLOSimplifyPass)

  void runOnOperation() override {
    RewritePatternSet folder_patterns(&getContext());
    PopulateFolderPatterns(folder_patterns);
    // Best effort: a non-converging greedy rewrite is not a pass failure.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(folder_patterns));
  }
};

} // namespace

// Factory for the shlo-simplify pass; referenced by the `constructor`
// field of SHLOSimplifyPass in passes.td.
std::unique_ptr<OperationPass<ModuleOp>> CreateSHLOSimplifyPass() {
return std::make_unique<SHLOSimplifyPass>();
}

} // namespace odml
} // namespace mlir

0 comments on commit 0ad69af

Please sign in to comment.