diff --git a/.bazelrc b/.bazelrc index ebf8b415286c4f..445555f50ad090 100644 --- a/.bazelrc +++ b/.bazelrc @@ -663,23 +663,3 @@ build:ubsan --linkopt -lubsan # Disable TFRT integration for now unless --config=tfrt is specified. build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils build:tfrt --deleted_packages= - -# Experimental configuration for building XLA GPU lowering to TFRT. -build:experimental_enable_xlir --config=tfrt -build:experimental_enable_xlir --@tf_runtime//:enable_gpu -build:experimental_enable_xlir --@rules_cuda//cuda:cuda_runtime=//tensorflow/compiler/xla/service/gpu:cuda_runtime_for_xlir -build:experimental_enable_xlir --nocheck_visibility -build:experimental_enable_xlir --incompatible_strict_action_env -build:experimental_enable_xlir --config=monolithic -build:experimental_enable_xlir --//tensorflow/compiler/xla/service/gpu:enable_xlir - -# bazel test --config=experimental_enable_bef_thunk \ -# //tensorflow/compiler/xla/service/gpu:bef_thunk_tests -build:experimental_enable_bef_thunk --config=experimental_enable_xlir -test:experimental_enable_bef_thunk --test_env=XLA_FLAGS=--xla_gpu_bef_thunk - -# bazel test --config=experimental_enable_bef_executable \ -# //tensorflow/compiler/xla/service/gpu:bef_executable_tests -build:experimental_enable_bef_executable --config=experimental_enable_xlir -test:experimental_enable_bef_executable --test_env=XLA_FLAGS=--xla_gpu_bef_executable - diff --git a/.github/workflows/trusted-partners.yml b/.github/workflows/trusted-partners.yml index ef944aa029dccb..abf62dd2b8a2b1 100644 --- a/.github/workflows/trusted-partners.yml +++ b/.github/workflows/trusted-partners.yml @@ -41,10 +41,13 @@ jobs: const domain = await script.get_email_domain({github, username}); switch(domain) { case "intel.com": - console.log(await script.filter({github, context})); + console.log(await script.filter({github, context, domain})); break; case "apple.com": - console.log(await script.filter({github, context})); + console.log(await script.filter({github, context, domain})); + break; + case "nvidia.com": + console.log(await script.filter({github, context, domain})); break; case "google.com": console.log("Googler. No action necessary"); diff --git a/.github/workflows/trusted_partners.js b/.github/workflows/trusted_partners.js index ed622ed27f779a..09a2749e9d28e8 100644 --- a/.github/workflows/trusted_partners.js +++ b/.github/workflows/trusted_partners.js @@ -49,7 +49,7 @@ const get_email_domain = async ({github, username}) => { context has the commit message details in the payload @return {string} Returns the message with labels attached and assignees added */ -const filter_action = async ({github, context}) => { +const filter_action = async ({github, context, domain}) => { const labels = ['kokoro:force-run', 'ready to pull']; let assignees = []; @@ -58,11 +58,26 @@ const filter_action = async ({github, context}) => { if (title && title.toLowerCase().includes("onednn")) assignees = onednn_assignees; const intel_windows_assignees = ['nitins17', 'learning-to-play']; - if (title && title.toLowerCase().includes("intel") && title.toLowerCase().includes("windows")) + if (title && title.toLowerCase().includes('intel') && + title.toLowerCase().includes('windows') && domain.includes('intel.com')) assignees = intel_windows_assignees; const apple_silicon_assignees = ['penpornk', 'nitins17']; - if (title && title.toLowerCase().includes("apple") && title.toLowerCase().includes("silicon")) + if (title && title.toLowerCase().includes('apple') && + title.toLowerCase().includes('silicon') && domain.includes('apple.com')) assignees = apple_silicon_assignees; + if (title && title.toLowerCase().includes('nvidia') && + domain.includes('nvidia.com')) { + if (title.toLowerCase().includes('jax')) { + assignees.push('hawkinsp', 'yashk2810', 'skye'); + } + if (title.toLowerCase().includes('xla') || + title.toLowerCase().includes('gpu')) { + assignees.push('cheshire', 'gcforster', 'reedwm', 'chsigg'); + } + if (title.toLowerCase().includes('tf')) { + assignees.push('rohan100jain', 'bfontain', 'penpornk'); + } + } const resp_label = await github.rest.issues.addLabels({ issue_number: context.issue.number, diff --git a/RELEASE.md b/RELEASE.md index f8850ecbf35420..c67c3b99b64ed0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -6,6 +6,10 @@ * * +* Causal attention in `keras.layers.Attention` and + `keras.layers.AdditiveAttention` is now specified in the `call()` method + via the `use_causal_mask` argument (rather than in the constructor), + for consistency with other layers. * Some files in `tensorflow/python/training` have been moved to `tensorflow/python/tracking` and `tensorflow/python/checkpoint`. Please update your imports accordingly, the old files will be removed in Release diff --git a/tensorflow/BUILD b/tensorflow/BUILD index de9a6db92cd2f6..a56cbbe1a3602c 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -897,6 +897,7 @@ config_setting( package_group( name = "internal", packages = [ + "//devtools/python/indexer/...", "//learning/brain/keras/...", "//learning/brain/mlir/...", "//learning/brain/tfrt/...", diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index a24bb003e15ab1..e731fd3420bee5 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -29,6 +29,26 @@ cc_library( ], ) +cc_library( + name = "io_ops", + srcs = [ + "io_ops.cc", + ], + hdrs = [ + "io_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:tracing_utils", + "//tensorflow/core:framework", + "//tensorflow/core/platform:errors", + ], +) + cc_library( name = "math_ops", srcs = [ @@ -99,6 +119,7 @@ cc_library( name = "ops", hdrs = [ "array_ops.h", + "io_ops.h", "math_ops.h", "nn_ops.h", "resource_variable_ops.h", @@ -108,6 +129,7 @@ cc_library( ], deps = [ ":array_ops", + ":io_ops", ":math_ops", ":nn_ops", ":resource_variable_ops", @@ -122,6 +144,7 @@ filegroup( name = "pywrap_required_hdrs", srcs = [ "array_ops.h", + "io_ops.h", "math_ops.h", "nn_ops.h", "resource_variable_ops.h", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 12cbe6f7dec314..ef69c24fc4a9fb 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -488,6 +488,14 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (!op_filter_.allow_where_op && node.type_string() == "Where") { + absl::string_view uncompilable_reason = "Where op"; + MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, + encapsulating_function, uncompilable_nodes); + LogNotCompilable(node, uncompilable_reason); + return false; + } + if (!op_filter_.allow_unique_op && node.type_string() == "Unique") { absl::string_view uncompilable_reason = "Unique op"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 8435e8eea1af9d..2b31e575779bce 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -136,6 +136,9 @@ class RecursiveCompilabilityChecker { // Whether to allow the compilation of CollectiveReduceV2Op. bool allow_collective_reduce_v2 = true; + // Whether to allow the compilation of WhereOp. + bool allow_where_op = true; + // Whether to allow the compilation of UniqueOp. Compilation of the UniqueOp // generates output with bounded dynamic shape that may cause failures with // auto clustering. diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index c0cf4b4a579d54..be2ec0efc2d720 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -114,6 +114,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { " BN: TF FusedBatchNorm* operations." " FUSIBLE: All TF operations that XLA can fuse (All the above). " "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."), + Flag("tf_xla_cluster_exclude_ops", + &mark_for_compilation_flags->tf_xla_cluster_exclude_ops, + "(experimental) " + "Exclude the operations from auto-clustering. " + "If multiple, separate them with commas." + " Where, Some_other_ops"), Flag("tf_xla_clustering_debug", &mark_for_compilation_flags->tf_xla_clustering_debug, "Dump graphs during XLA compilation."), diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 13c56b08aaed3b..1cbfdb9caf5809 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -61,6 +61,9 @@ struct MarkForCompilationPassFlags { // If non-empty, limit XLA clustering to the following TF operations. string tf_xla_ops_to_cluster; + // If non-empty, remove following operations from XLA clustering excludelist. + string tf_xla_cluster_exclude_ops; + // Dump graphs during XLA compilation. bool tf_xla_clustering_debug; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 72c764dd7d5125..9a849a4d8500a1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1189,6 +1189,24 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { return true; } +absl::flat_hash_set GetOrCreateClusterExcludeList() { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + absl::flat_hash_set excludelist; + for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) { + if (!s.empty()) { + excludelist.insert(string(s)); + } + } + if (VLOG_IS_ON(2) && !excludelist.empty()) { + std::vector vexcludelist(excludelist.begin(), excludelist.end()); + absl::c_sort(vexcludelist); + VLOG(2) << "XLA clustering will exclude following TF operations from auto " + "clustering: " + << absl::StrJoin(vexcludelist, " "); + } + return excludelist; +} + absl::flat_hash_set GetOrCreateAllowlist() { absl::flat_hash_map>* allowlist_table = tensorflow::GetAllowlistTable(); @@ -1289,12 +1307,25 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } + auto cluster_exclude_op_list = GetOrCreateClusterExcludeList(); RecursiveCompilabilityChecker::OperationFilter filter = CreateOperationFilter(*registration); filter.require_always_compilable = true; filter.allow_string_consts = false; filter.allow_collective_reduce_v2 = false; filter.allow_unique_op = false; + filter.allow_where_op = true; + + for (const auto& s : cluster_exclude_op_list) { + if (s == "Where") { + filter.allow_where_op = false; + } else { + return errors::InvalidArgument( + "The operation '", s, + "' passed to --tf_xla_cluster_exclude_ops is not supported by " + "XLA."); + } + } RecursiveCompilabilityChecker checker( filter, DeviceType{registration->compilation_device_name}); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index aeecd9a3947d20..9c897fe1c61cba 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -196,6 +196,24 @@ TEST(XlaCompilationTest, StringUnsupported) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, WhereUnsupported) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_INT32) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Where", a, builder.opts().WithName("B")); + ops::BinaryOp("Gather", b, a, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_TRUE(!clusters.empty()); +} + TEST(XlaCompilationTest, HalfSupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); { diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 279fc93612e38f..aa65703fa1c8d8 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -1365,6 +1365,7 @@ cc_library( "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], @@ -2193,6 +2194,7 @@ cc_library( deps = [ ":bufferize_pass", ":gml_st_tiling", + ":gml_st_vectorization", ":legalize_mhlo_to_gml", ":legalize_to_linalg", "@llvm-project//mlir:BufferizationTransforms", @@ -2394,9 +2396,10 @@ gentbl_cc_library( td_file = "include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td", td_srcs = [ "include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td", - "include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td", "include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td", "include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td", + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td", + "include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td", ], deps = [":gml_st_ops_td_files"], ) @@ -2411,6 +2414,7 @@ cc_library( ], includes = ["include"], deps = [ + ":compose_set_interface", ":fusion_interface", ":gml_st_ops_inc_gen", "@llvm-project//llvm:Support", @@ -2474,14 +2478,49 @@ cc_library( deps = [ ":fusion_interface", ":gml_st", - ":hlo", - ":map_mhlo_to_scalar_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithmeticDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorUtils", + ], +) + +gentbl_cc_library( + name = "compose_set_interface_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td", + deps = ["@llvm-project//mlir:OpBaseTdFiles"], +) + +cc_library( + name = "compose_set_interface", + srcs = [ + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.cc.inc", + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h.inc", + "lib/Dialect/gml_st/transforms/compose_set_interface.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h", + ], + includes = ["include"], + deps = [ + ":compose_set_interface_inc_gen", + "@llvm-project//mlir:IR", ], ) @@ -2584,6 +2623,7 @@ cc_library( ], includes = ["include"], deps = [ + ":compose_set_interface", ":gml_st", ":gml_st_passes_inc_gen", "@llvm-project//llvm:Support", @@ -2595,6 +2635,28 @@ cc_library( ], ) +cc_library( + name = "gml_st_vectorization", + srcs = [ + "include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h", + "lib/Dialect/gml_st/transforms/vectorization.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/gml_st/transforms/passes.h", + ], + includes = ["include"], + deps = [ + ":gml_st", + ":gml_st_passes_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:VectorDialect", + ], +) + cc_library( name = "legalize_mhlo_to_gml", srcs = [ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo-c/Attributes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo-c/Attributes.h index e8499070f76338..388316738ab320 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo-c/Attributes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo-c/Attributes.h @@ -288,6 +288,23 @@ MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr); +// +// RngDistributionAttr. +// +// Creates a new RngDistribution attribute with the given 'distribution' string +// parameter. +MLIR_CAPI_EXPORTED MlirAttribute +mlirMhloRngDistributionAttrGet(MlirContext ctx, MlirStringRef distribution); + +// Returns true if the given attribute is a RngDistribution attribute. +MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsARngDistributionAttr( + MlirAttribute attr); + +// Returns the rng-distribution string associated with RngDistribution +// attribute. +MLIR_CAPI_EXPORTED MlirStringRef +mlirMhloRngDistributionAttrGetRngDistribution(MlirAttribute attr); + // // RngAlgorithmAttr. // diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td index dd5831a9d2a5ed..025cd7f52f74ea 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td @@ -20,20 +20,84 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td" include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td" -def GMLST_DynamicBroadcastInDimOp : GMLST_Op<"dynamic_broadcast_in_dim", [ - NoSideEffect, DeclareOpInterfaceMethods]> { +def GMLST_DynamicBroadcastInDimOp : GMLST_Op<"dynamic_broadcast_in_dim", + [NoSideEffect, + TypesMatchWith<"result and init types match", "init", "result", "$_self">, + DeclareOpInterfaceMethods]> { let summary = [{Destination-style twin for `mhlo.dynamic_broadcast_in_dim`}]; let arguments = (ins - AnyTensor:$init, AnyTensor:$operand, + AnyTensor:$init, I64ElementsAttr:$broadcast_dimensions, OptionalAttr:$known_expanding_dimensions, OptionalAttr:$known_nonexpanding_dimensions ); let results = (outs AnyTensor:$result); let assemblyFormat = [{ - $init `,` $operand `,` custom($broadcast_dimensions) - attr-dict `:` type($init) `,` type($operand) `->` type($result) + `ins` `(` $operand `:` type($operand) `)` + `outs` `(` $init `:` type($init) `)` + attr-dict + }]; + let hasVerifier = 0; +} + +def GMLST_GatherOp : GMLST_Op<"gather", [ + NoSideEffect, + TypesMatchWith<"result and init types match", "init", "result", "$_self"> +]> { + let summary = "Destination-style twin for `mhlo.gather`"; + let description = [{ + Does not currently support the full interface of mhlo.gather. In particular: + - index_vector_dim is start_indices.shape.rank - 1 + - slice_sizes is [1,1,...] + - offset_dims is [] + - collapsed_slice_dims is range(operand.shape.rank) + - start_index_map is range(slice_sizes.shape[index_vector_dim]) + }]; + let arguments = (ins + AnyTensor:$operand, + I64Tensor:$start_indices, + AnyTensor:$init + ); + let results = (outs AnyTensor:$result); + let assemblyFormat = [{ + `ins` `(` $operand `:` type($operand) `,` + $start_indices `:` type($start_indices) `)` + `outs` `(` $init `:` type($init) `)` + attr-dict + }]; + let hasVerifier = 0; +} + +def GMLST_ScatterOp : GMLST_Op<"scatter", [ + NoSideEffect, + TypesMatchWith<"result and init types match", "init", "result", "$_self">, + TypesMatchWith<"result and operand types match", "operand", "result", + "$_self">]> { + let summary = "Destination-style twin for `mhlo.scatter`"; + let description = [{ + Caveats: + - the variadic case is not supported. + - update_computation is sum. + - Only point updates are supported + - update_window_dims is [] + - inserted_window_dims is range(operand.shape.rank) + - scatter_dims_to_operand_dims is range(scatter_indices.shape.rank) + - index_vector_dim is scatter_indices.shape.rank-1 + }]; + let arguments = (ins + AnyTensor:$operand, + I64Tensor:$scatter_indices, + AnyTensor:$updates, + AnyTensor:$init + ); + let results = (outs AnyTensor:$result); + let assemblyFormat = [{ + `ins` `(` $operand `:` type($operand) `,` + $scatter_indices `:` type($scatter_indices) `,` + $updates `:` type($updates) `)` + `outs` `(` $init `:` type($init) `)` + attr-dict }]; let hasVerifier = 0; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h index 9233d65cbe0427..dcbfc75826902c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef MLIR_HLO_DIALECT_GML_ST_IR_GML_ST_OPS_H #define MLIR_HLO_DIALECT_GML_ST_IR_GML_ST_OPS_H +#include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h" #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td index 18bca6a68c81e2..a9f1fc67e7ea4a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td @@ -34,34 +34,37 @@ def GMLST_MaterializeOp : GMLST_Op<"materialize", let hasVerifier = 0; } -def GMLST_DimOp : GMLST_Op<"dim", []> { - let summary = [{Returns the size of a tile in a given dimension}]; +def GMLST_OffsetOp : GMLST_Op<"offset", []> { + let summary = [{Returns the offset of a tile in a given dimension}]; let arguments = (ins GMLST_TileType:$tile, Index:$dim); let results = (outs Index:$result); let assemblyFormat = [{ $tile `[` $dim `]` attr-dict `:` qualified(type($tile)) }]; let hasVerifier = 0; + let hasFolder = 1; } -def GMLST_OffsetOp : GMLST_Op<"offset", []> { - let summary = [{Returns the offset of a tile in a given dimension}]; +def GMLST_SizeOp : GMLST_Op<"size", []> { + let summary = [{Returns the size of a tile in a given dimension}]; let arguments = (ins GMLST_TileType:$tile, Index:$dim); let results = (outs Index:$result); let assemblyFormat = [{ $tile `[` $dim `]` attr-dict `:` qualified(type($tile)) }]; let hasVerifier = 0; + let hasFolder = 1; } -def GMLST_SizeOp : GMLST_Op<"size", []> { - let summary = [{Returns the size of a tile in a given dimension}]; +def GMLST_StrideOp : GMLST_Op<"stride", []> { + let summary = [{Returns the stride of a tile in a given dimension}]; let arguments = (ins GMLST_TileType:$tile, Index:$dim); let results = (outs Index:$result); let assemblyFormat = [{ $tile `[` $dim `]` attr-dict `:` qualified(type($tile)) }]; let hasVerifier = 0; + let hasFolder = 1; } class GMLST_LoopLikeOp traits = []> @@ -265,10 +268,9 @@ def GMLST_SetYieldOp : GMLST_Op<"set_yield", [NoSideEffect, ReturnLike, Variadic:$dsts, Variadic:$sets); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = [{ - attr-dict ($srcs^ `into` $dsts `[` $sets `]` - `:` type($srcs) `into` type($dsts) `[` type($sets) `]`)? - }]; + + + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td index 0d4eb54b721dc7..a13bef6f275423 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td" +include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td" // Base class of all subset types. class GMLST_Set : TypeDef { } @@ -37,7 +38,8 @@ def GMLST_TileType : GMLST_Set<"Tile"> { }]; } -def GMLST_PointType : GMLST_Set<"Point"> { +def GMLST_PointType : GMLST_Set<"Point">, + BuildableType<"$_builder.getType<::mlir::gml_st::PointType>()"> { let mnemonic = "point"; let summary = "Type that represents a point of an index space."; let assemblyFormat = ""; @@ -57,10 +59,16 @@ def GMLST_SpaceOp : GMLST_Op<"space", [NoSideEffect, custom($dynamic_sizes, $static_sizes) attr-dict `:` qualified(type($result)) }]; + let extraClassDeclaration = [{ + unsigned getNumDynamicEntriesUpToIdx(unsigned idx); + mlir::Value getDynamicSize(unsigned idx); + }]; let hasVerifier = 1; } -def GMLST_PointOp : GMLST_Op<"point", [NoSideEffect]> { +def GMLST_PointOp : GMLST_Op<"point", [ + NoSideEffect, + DeclareOpInterfaceMethods]> { let arguments = (ins AnySet:$superset, Variadic:$dynamic_indices, I64ArrayAttr:$static_indices); @@ -76,9 +84,12 @@ def GMLST_PointOp : GMLST_Op<"point", [NoSideEffect]> { let hasVerifier = 1; } -def GMLST_TileOp : GMLST_Op<"tile", [NoSideEffect, AttrSizedOperandSegments, +def GMLST_TileOp : GMLST_Op<"tile", [ + NoSideEffect, + AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins GMLST_TileType:$superset, Variadic:$offsets, Variadic:$sizes, @@ -130,4 +141,24 @@ def GMLST_CollapseTileOp : GMLST_Op<"collapse_tile", [NoSideEffect, let hasVerifier = 0; } +def GMLST_TransposeTileOp : GMLST_Op<"transpose_tile", [ + NoSideEffect, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Transposes a tile."; + let description = [{ + Transposes the argument tile by applying the permutation to the dimensions, + offsets and strides of the operand tile. + }]; + let arguments = (ins + GMLST_TileType:$superset, + DenseI64ArrayAttr:$permutation); + let results = (outs GMLST_TileType:$result); + let assemblyFormat = [{ + $superset `,` $permutation attr-dict `:` + qualified(type($superset)) `to` qualified(type($result)) + }]; + let hasVerifier = 1; +} + #endif // GML_ST_SET_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt index 7313b0f0cf1838..84cae84f7ee35c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt @@ -26,3 +26,8 @@ set(LLVM_TARGET_DEFINITIONS fusion_interface.td) mlir_tablegen(fusion_interface.h.inc -gen-op-interface-decls) mlir_tablegen(fusion_interface.cc.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRGmlStFusionInterfaceIncGen) + +set(LLVM_TARGET_DEFINITIONS compose_set_interface.td) +mlir_tablegen(compose_set_interface.h.inc -gen-op-interface-decls) +mlir_tablegen(compose_set_interface.cc.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRGmlStComposeSetInterfaceIncGen) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h new file mode 100644 index 00000000000000..091a5825cbc8b4 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h @@ -0,0 +1,25 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_COMPOSE_SET_INTERFACE_H +#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_COMPOSE_SET_INTERFACE_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h.inc" + +#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_COMPOSE_SET_INTERFACE_H diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td new file mode 100644 index 00000000000000..3694ac89f51c94 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef COMPOSE_SET_INTERFACE +#define COMPOSE_SET_INTERFACE + +include "mlir/IR/OpBase.td" + +def ComposeSetInterface : OpInterface<"ComposeSetInterface"> { + let description = [{ + This interface should be implemented by all set operations that can be + composed with their superset operand. + }]; + let cppNamespace = "::mlir::gml_st"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Returns a composition of this set with its superset operand.", + /*retTy=*/"::mlir::Value", + /*methodName=*/"compose", + /*args=*/(ins "OpBuilder&":$builder)>, + ]; +} + +#endif // COMPOSE_SET_INTERFACE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td index e5292537b2d3ce..0199f65d8bc61a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td @@ -18,7 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" -def FusionIterface : OpInterface<"FusionIterface"> { +def FusionInterface : OpInterface<"FusionInterface"> { let description = [{ TBD }]; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h index 0ba57051eff5ff..20d1d42b59b4da 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h @@ -27,6 +27,9 @@ class FuncOp; namespace scf { class SCFDialect; } // namespace scf +namespace vector { +class VectorDialect; +} // namespace vector } // namespace mlir #define GEN_PASS_CLASSES diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h index 26a4506e37d682..1de73507bff429 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h @@ -45,6 +45,10 @@ std::unique_ptr> createGmlStToScfPass(); // its body. std::unique_ptr> CreateTiledLoopBufferizePass(); +/// Pass to vectorize linalg.generic ops tiled to gml_st.parallel and gml_st.for +/// loops. +std::unique_ptr> createVectorizeGmlStLoopsPass(); + #define GEN_PASS_REGISTRATION #include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td index f8def99b8a3701..745ba18970fcf7 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td @@ -45,7 +45,17 @@ def GmlStToScf : Pass<"gml-st-to-scf", "mlir::func::FuncOp"> { let dependentDialects = ["::mlir::scf::SCFDialect"]; } -def TiledLoopBufferizePass : Pass<"gml-tiled-loop-bufferize", "mlir::func::FuncOp"> { +def TiledLoopBufferizePass : + Pass<"gml-tiled-loop-bufferize", "mlir::func::FuncOp"> { let summary = "Pass to bufferize linalg.tiled_loop with the ops inside it."; let constructor = "::mlir::gml_st::CreateTiledLoopBufferizePass()"; } + +def VectorizeGmlStLoopsPass : + Pass<"vectorize-gml-st-loops", "mlir::func::FuncOp"> { + let summary = + "Pass to vectorize linalg.generic ops tiled to gml_st.parallel and " # + "gml_st.for loops."; + let constructor = "::mlir::gml_st::createVectorizeGmlStLoopsPass()"; + let dependentDialects = ["::mlir::vector::VectorDialect"]; +} diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td index ea576447cbb282..22e9eda5219d30 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td @@ -152,7 +152,7 @@ def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperan See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. }]; } -def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> { +def LHLO_CosineOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> { let summary = "Cos operator"; let description = [{ Returns `Cos(operand)` element-wise. @@ -320,7 +320,7 @@ def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign"> { https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; } -def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer> { +def LHLO_SineOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer> { let summary = "Sin operator"; let description = [{ Returns `Sin(operand)` element-wise. @@ -851,7 +851,7 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []> { ); } -def LHLO_ConvOp : LHLO_Op<"convolution", []> { +def LHLO_ConvolutionOp : LHLO_Op<"convolution", []> { let summary = "Convolution operator"; let description = [{ Computes a convolution of the kind used in neural networks. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h index 4268e59a532a16..5653aa9e55fb61 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h @@ -51,10 +51,10 @@ MAP_HLO_TO_LHLO(ConstantOp); MAP_HLO_TO_LHLO(CompareOp); MAP_HLO_TO_LHLO(ComplexOp); MAP_HLO_TO_LHLO(ConcatenateOp); -MAP_HLO_TO_LHLO(ConvOp); +MAP_HLO_TO_LHLO(ConvolutionOp); MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); -MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(CosineOp); MAP_HLO_TO_LHLO(CustomCallOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); @@ -92,7 +92,7 @@ MAP_HLO_TO_LHLO(ShiftLeftOp); MAP_HLO_TO_LHLO(ShiftRightArithmeticOp); MAP_HLO_TO_LHLO(ShiftRightLogicalOp); MAP_HLO_TO_LHLO(SignOp); -MAP_HLO_TO_LHLO(SinOp); +MAP_HLO_TO_LHLO(SineOp); MAP_HLO_TO_LHLO(SliceOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h index 98f35554e6910f..2a46aa33826daa 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h @@ -49,10 +49,10 @@ MAP_LHLO_TO_HLO(ConstantOp); MAP_LHLO_TO_HLO(CompareOp); MAP_LHLO_TO_HLO(ComplexOp); MAP_LHLO_TO_HLO(ConcatenateOp); -MAP_LHLO_TO_HLO(ConvOp); +MAP_LHLO_TO_HLO(ConvolutionOp); MAP_LHLO_TO_HLO(ConvertOp); MAP_LHLO_TO_HLO(CopyOp); -MAP_LHLO_TO_HLO(CosOp); +MAP_LHLO_TO_HLO(CosineOp); MAP_LHLO_TO_HLO(CustomCallOp); MAP_LHLO_TO_HLO(DivOp); MAP_LHLO_TO_HLO(DotOp); @@ -89,7 +89,7 @@ MAP_LHLO_TO_HLO(ShiftLeftOp); MAP_LHLO_TO_HLO(ShiftRightArithmeticOp); MAP_LHLO_TO_HLO(ShiftRightLogicalOp); MAP_LHLO_TO_HLO(SignOp); -MAP_LHLO_TO_HLO(SinOp); +MAP_LHLO_TO_HLO(SineOp); MAP_LHLO_TO_HLO(SliceOp); MAP_LHLO_TO_HLO(SqrtOp); MAP_LHLO_TO_HLO(SubOp); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td index 62c48442bff232..091d246baace71 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -125,27 +125,12 @@ def LHLOGPU_ConvForwardFusedSideInputOp : // LMHLO ops representing other library functions. //===----------------------------------------------------------------------===// -// output = alpha * (lhs * rhs) -// Verify: beta = 0.0 +// c = alpha * (a @ b) + beta * c def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$output, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config, - F64Attr:$alpha_real, - F64Attr:$alpha_imag, - OptionalAttr:$algorithm); -} - -// output = alpha(lhs * rhs) + beta * bias -def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { - let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$bias, - Arg:$output, + Arg:$a, + Arg:$b, + Arg:$c, DotDimensionNumbers:$dot_dimension_numbers, HLO_PrecisionConfigAttr:$precision_config, F64Attr:$alpha_real, diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index ea99ef314fcb04..38ea32277b5623 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -219,7 +219,8 @@ def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; } -def HLO_CosOp: HLO_UnaryElementwiseOp<"cosine", + +def HLO_CosineOp: HLO_UnaryElementwiseOp<"cosine", [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { let summary = "Cos operator"; let description = [{ @@ -228,7 +229,10 @@ def HLO_CosOp: HLO_UnaryElementwiseOp<"cosine", See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; + + let hasCustomHLOConverter = 1; } + def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { let summary = "Exponential operator"; @@ -416,7 +420,8 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", }]; let hasFolder = 1; } -def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", + +def HLO_SineOp: HLO_UnaryElementwiseOp<"sine", [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { let summary = "Sin operator"; let description = [{ @@ -425,7 +430,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; + let hasCustomHLOConverter = 1; } + def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { let summary = "Square-root operator"; @@ -1196,7 +1203,7 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic_slice", let hasVerifier = 1; } -def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", +def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic_update_slice", [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>, AllShapesMatch<["operand", "result"]>]> { let summary = "Dynamic Update Slice operator"; @@ -1561,7 +1568,7 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", let hasVerifier = 1; } -def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> { +def HLO_ConvolutionOp : HLO_Op<"convolution", [NoSideEffect]> { let summary = "Convolution operator"; let description = [{ Computes a convolution of the kind used in neural networks. @@ -2134,7 +2141,7 @@ def HLO_TraceOp: HLO_Op<"trace", []> { def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose", [NoSideEffect, SameOperandsAndResultElementType, - InferTensorTypeWithReify]> { + DeclareOpInterfaceMethods]> { let summary = "Transpose operator"; let description = [{ Permutes the dimensions of `operand` according to the given `permutation`. @@ -2310,53 +2317,35 @@ def HLO_OptimizationBarrierOp : HLO_Op<"optimization_barrier", // MHLO RNG Operators. //===----------------------------------------------------------------------===// -def HLO_RngUniformOp : HLO_Op<"rng_uniform", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> { +def HLO_RngOp : HLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> { let summary = "RNG with uniform distribution."; let description = [{ Constructs an output of a given shape with random numbers generated - following the uniform distribution over the interval `[a,b)`. The parameters - and output element type have to be a boolean type, an integral type or a - floating point types, and the types have to be consistent. + following the given `rng_distribution` with two parameters: + `UNIFORM`: the uniform distribution over the interval `[a,b)`. The parameters + and output element type have to be a boolean type, an integral type or a + floating point types, and the types have to be consistent. - See https://www.tensorflow.org/xla/operation_semantics#rnguniform. + See https://www.tensorflow.org/xla/operation_semantics#rnguniform. + + `NORMAL`: the normal distribution with parameters `mu` (=`a`) and + `sigma` (=`b`). The parameters and output shape have to have a + floating point elemental type. The parameters furthermore have + to be scalar valued. + + See https://www.tensorflow.org/xla/operation_semantics#rngnormal. }]; let arguments = (ins 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a, 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b, - HLO_DimensionTensor:$shape + HLO_DimensionTensor:$shape, + HLO_RngDistributionAttr:$rng_distribution ); let results = (outs HLO_PredIntOrFpTensor:$result); let hasCustomHLOConverter = 1; - - let extraClassDeclaration = [{ - // Returns whether the return types are compatible. - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return succeeded(::mlir::verifyCompatibleShapes(l, r)); - } - }]; -} - -def HLO_RngNormalOp : HLO_Op<"rng_normal", [InferTensorTypeWithReify, AllElementTypesMatch<["mu", "sigma", "result"]>]> { - let summary = "RNG with normal distribution."; - let description = [{ - Constructs an output of a given shape with random numbers generated - following the normal distribution with parameters `mu` and `sigma`. The - parameters and output shape have to have a floating point elemental type. - The parameters furthermore have to be scalar valued. - - See https://www.tensorflow.org/xla/operation_semantics#rngnormal. - }]; - let arguments = (ins - 0DTensorOf<[HLO_Float]>:$mu, - 0DTensorOf<[HLO_Float]>:$sigma, - HLO_DimensionTensor:$shape - ); - - let results = (outs HLO_FpTensor:$result); - - let hasCustomHLOConverter = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ // Returns whether the return types are compatible. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td index 483008361033c7..7be32da34cb7fb 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td @@ -190,6 +190,23 @@ def HLO_FusionKind : I32EnumAttr<"FusionKind", "fusion kind", [ let cppNamespace = "::mlir::mhlo"; } +def HLO_RNG_DISTRIBUTION_UNIFORM : I32EnumAttrCase<"UNIFORM", 1>; +def HLO_RNG_DISTRIBUTION_NORMAL : I32EnumAttrCase<"NORMAL", 2>; + +def HLO_RNG_DISTRIBUTION : I32EnumAttr<"RngDistribution", + "XLA PRNG distribution to be used.", + [ + HLO_RNG_DISTRIBUTION_UNIFORM, + HLO_RNG_DISTRIBUTION_NORMAL + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def HLO_RngDistributionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def HLO_FusionKindAttr : EnumAttr; def HLO_RNG_ALGORITHM_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h index 6f05c914205e0a..152aa209b80556 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h @@ -37,6 +37,9 @@ class MemRefDialect; namespace tensor { class TensorDialect; } // namespace tensor +namespace shape { +class ShapeDialect; +} // namespace shape namespace mhlo { class MhloDialect; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h index 687b80beef1aca..cd040a8e56989a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -71,7 +71,7 @@ struct MhloToScalarOp { using UOp = ::mlir::math::CountLeadingZerosOp; }; template <> -struct MhloToScalarOp { +struct MhloToScalarOp { using FOp = ::mlir::math::CosOp; using COp = ::mlir::complex::CosOp; }; @@ -150,7 +150,7 @@ struct MhloToScalarOp { using COp = ::mlir::complex::SqrtOp; }; template <> -struct MhloToScalarOp { +struct MhloToScalarOp { using FOp = ::mlir::math::SinOp; using COp = ::mlir::complex::SinOp; }; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index 156dacbabdd86c..adf4a503d5a307 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -216,6 +216,8 @@ def GroupReductionDimensionsPass def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "func::FuncOp"> { let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; let constructor = "createTestUnfuseBatchNormPass()"; + + let dependentDialects = ["arith::ArithmeticDialect", "shape::ShapeDialect", "tensor::TensorDialect"]; } def ExpandHloTuplesPass : Pass<"expand-hlo-tuples", "ModuleOp"> { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index c99506125c4eef..40b4073e1f4c8b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -101,9 +101,22 @@ void populateDynamicShapeFusionPatterns(MLIRContext *context, RewritePatternSet *patterns); // Populate a collection of conversion patterns for un-fusing -// batch_norm_inference and batch_norm_training into constituent HLO ops. -void populateUnfuseBatchNormPatterns(MLIRContext *context, - RewritePatternSet *patterns); +// batch_norm_inference into constituent HLO ops. +void populateUnfuseBatchNormInferencePattern(MLIRContext *context, + RewritePatternSet *patterns); + +// Populate a collection of conversion patterns for un-fusing +// batch_norm_training into constituent HLO ops. +void populateUnfuseBatchNormTrainingPattern(MLIRContext *context, + RewritePatternSet *patterns); + +// Populate a collection of conversion patterns for un-fusing +// // batch_norm_inference and batch_norm_training into constituent HLO ops. +inline void populateUnfuseBatchNormPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + populateUnfuseBatchNormInferencePattern(context, patterns); + populateUnfuseBatchNormTrainingPattern(context, patterns); +} // Populates patterns that translate the trigonometric operations from the // standard dialect to approximations that do not use intrinsics. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h index ad820552b4b2e0..fa6f61b8849625 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h @@ -24,6 +24,9 @@ struct GmlStPipelineOptions : public mlir::PassPipelineOptions { ListOption tileSizes{ *this, "tile-sizes", llvm::cl::desc("Tiling sizes for the tiling pass")}; + Option fuse{*this, "fuse", + llvm::cl::desc("Fuse into GmlSt loop nests."), + llvm::cl::init(false)}; Option lowerToLoops{ *this, "lower-to-loops", llvm::cl::desc("Enable bufferization and lowering to SCF dialect for " diff --git a/tensorflow/compiler/mlir/hlo/lib/CAPI/Attributes.cc b/tensorflow/compiler/mlir/hlo/lib/CAPI/Attributes.cc index 74834d60c23997..61fb4ea14aba75 100644 --- a/tensorflow/compiler/mlir/hlo/lib/CAPI/Attributes.cc +++ b/tensorflow/compiler/mlir/hlo/lib/CAPI/Attributes.cc @@ -519,6 +519,29 @@ MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) { unwrap(attr).cast().getValue())); } +// +// RngDistributionAttr. +// + +MlirAttribute mlirMhloRngDistributionAttrGet(MlirContext ctx, + MlirStringRef distribution) { + llvm::Optional rngDistribution = + mlir::mhlo::symbolizeRngDistribution(unwrap(distribution)); + if (!rngDistribution) llvm_unreachable("Invalid rng-distribution specified."); + return wrap(mlir::mhlo::RngDistributionAttr::get(unwrap(ctx), + rngDistribution.getValue())); +} + +bool mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirStringRef mlirMhloRngDistributionAttrGetRngDistribution( + MlirAttribute attr) { + return wrap(mlir::mhlo::stringifyRngDistribution( + unwrap(attr).cast().getValue())); +} + // // RngAlgorithmAttr. // diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/CMakeLists.txt index 816b7ea2bf5220..d1b2a3231c7207 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_dialect_library(GmlStDialect DEPENDS MLIRgml_st_opsIncGen + MLIRGmlStComposeSetInterfaceIncGen MLIRGmlStFusionInterfaceIncGen LINK_LIBS PUBLIC diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc index 62df69fcb33798..548d98afd97c3c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc @@ -15,6 +15,9 @@ limitations under the License. #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" +#include + +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -1253,6 +1256,19 @@ LogicalResult SpaceOp::verify() { dynamic_sizes(), ShapedType::isDynamic); } +unsigned SpaceOp::getNumDynamicEntriesUpToIdx(unsigned idx) { + return std::count_if(static_sizes().begin(), static_sizes().begin() + idx, + [&](const mlir::Attribute size) { + return mlir::ShapedType::isDynamic( + size.cast().getInt()); + }); +} + +mlir::Value SpaceOp::getDynamicSize(unsigned idx) { + auto numDynamic = getNumDynamicEntriesUpToIdx(idx); + return dynamic_sizes()[numDynamic]; +} + //===----------------------------------------------------------------------===// // PointOp //===----------------------------------------------------------------------===// @@ -1276,10 +1292,8 @@ LogicalResult PointOp::verify() { // Check whether the known indices are in-bounds of known dimension sizes. for (auto dimAndIndex : llvm::zip(tileShape, static_indices())) { auto dimSize = std::get<0>(dimAndIndex); - auto index = std::get<1>(dimAndIndex) - .dyn_cast() - .getValue() - .getSExtValue(); + auto index = + std::get<1>(dimAndIndex).dyn_cast().getInt(); if (index == ShapedType::kDynamicStrideOrOffset) continue; if (index < 0) { return emitOpError("expected index = ") @@ -1293,6 +1307,7 @@ LogicalResult PointOp::verify() { } return success(); } + // //===----------------------------------------------------------------------===// // TileOp @@ -1381,6 +1396,169 @@ LogicalResult TileOp::verify() { return success(); } +namespace { + +OpFoldResult multiplyOperandsOrIntegers(OpBuilder &builder, Location loc, + OpFoldResult lhs, OpFoldResult rhs) { + // Both operands are static. + if (lhs.is() && rhs.is()) { + return builder.getI64IntegerAttr( + lhs.get().cast().getInt() * + rhs.get().cast().getInt()); + } + + // Exploit commutativity and move static operand to the left (if any). + if (rhs.is()) std::swap(lhs, rhs); + + // Create constant if needed. + if (lhs.is()) { + int64_t lhsInt = lhs.get().cast().getInt(); + + // Exploit static operand if possible. + if (lhsInt == 0) return lhs; + if (lhsInt == 1) return rhs; + + lhs = builder.create(loc, lhsInt).getResult(); + } + + // Multiply. + return builder.create(loc, lhs.get(), rhs.get()) + .getResult(); +} + +OpFoldResult addOperandsOrIntegers(OpBuilder &builder, Location loc, + OpFoldResult lhs, OpFoldResult rhs) { + // Both operands are static. + if (lhs.is() && rhs.is()) { + return builder.getI64IntegerAttr( + lhs.get().cast().getInt() + + rhs.get().cast().getInt()); + } + + // Exploit commutativity and move static operand to the left (if any). + if (rhs.is()) std::swap(lhs, rhs); + + // Create constant if needed. + if (lhs.is()) { + int64_t lhsInt = lhs.get().cast().getInt(); + + // Exploit static operand if possible. + if (lhsInt == 0) return rhs; + + lhs = builder.create(loc, lhsInt).getResult(); + } + + // Add. + return builder.create(loc, lhs.get(), rhs.get()) + .getResult(); +} + +// Compose offsets with newOffset = supersetOffset + supersetStride * offset. +SmallVector composeOffsets( + const llvm::SmallVectorImpl &supersetOffsets, + const llvm::SmallVectorImpl &supersetStrides, + const llvm::SmallVectorImpl &offsets, Location loc, + OpBuilder &builder) { + SmallVector composedOffsets; + for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) { + composedOffsets.push_back(addOperandsOrIntegers( + builder, loc, std::get<0>(it), + multiplyOperandsOrIntegers(builder, loc, std::get<1>(it), + std::get<2>(it)))); + } + return composedOffsets; +} + +// Compose strides with newStride = supersetStride * stride. +SmallVector composeStrides( + OpBuilder &builder, Location loc, + const llvm::SmallVectorImpl &supersetStrides, + const llvm::SmallVectorImpl &strides) { + SmallVector composedStrides; + for (auto it : llvm::zip(supersetStrides, strides)) { + composedStrides.push_back(multiplyOperandsOrIntegers( + builder, loc, std::get<0>(it), std::get<1>(it))); + } + return composedStrides; +} + +} // namespace + +Value TileOp::compose(OpBuilder &builder) { + auto supersetOp = llvm::dyn_cast_or_null(superset().getDefiningOp()); + if (!supersetOp) return {}; + + // Compose offsets with newOffset = supersetOffset + supersetStride * + // offset. + auto loc = getLoc(); + auto composedOffsets = decomposeMixedStridesOrOffsets( + builder, + composeOffsets(supersetOp.getMixedOffsets(), supersetOp.getMixedStrides(), + getMixedOffsets(), loc, builder)); + + // Compose strides with newStride = supersetStride * stride. + auto composedStrides = decomposeMixedStridesOrOffsets( + builder, composeStrides(builder, loc, supersetOp.getMixedStrides(), + getMixedStrides())); + + // Build the composed tile op. + return builder.create(loc, supersetOp.superset(), + composedOffsets.second, sizes(), + composedStrides.second, composedOffsets.first, + static_sizes(), composedStrides.first); +} + +//===----------------------------------------------------------------------===// +// PointOp +//===----------------------------------------------------------------------===// + +namespace { + +// TODO(frgossen): Move this upstream to the ViewLikeInterface +SmallVector getMixedImpl(ArrayAttr staticValues, + ValueRange dynamicValues, + const int64_t dynamicValuePlaceholder) { + int64_t idxDynamic = 0; + SmallVector result; + for (const auto &staticAttr : staticValues) { + int64_t staticInt = staticAttr.cast().getInt(); + if (staticInt == dynamicValuePlaceholder) { + result.push_back(dynamicValues[idxDynamic++]); + } else { + result.push_back(staticAttr); + } + } + return result; +} + +// TODO(frgossen): Move this upstream to the ViewLikeInterface +SmallVector getMixedStridesOrOffsets(ArrayAttr staticValues, + ValueRange dynamicValues) { + return getMixedImpl(staticValues, dynamicValues, + ShapedType::kDynamicStrideOrOffset); +} + +} // namespace + +Value PointOp::compose(OpBuilder &builder) { + auto supersetOp = llvm::dyn_cast_or_null(superset().getDefiningOp()); + if (!supersetOp) return {}; + + // Compose offsets with newOffset = supersetOffset + supersetStride * + // offset. + auto loc = getLoc(); + auto composedOffsets = decomposeMixedStridesOrOffsets( + builder, + composeOffsets( + supersetOp.getMixedOffsets(), supersetOp.getMixedStrides(), + getMixedStridesOrOffsets(static_indices(), dynamic_indices()), loc, + builder)); + + // Build the composed point op. + return builder.create(loc, supersetOp.superset(), + composedOffsets.second, composedOffsets.first); +} + //===----------------------------------------------------------------------===// // CollapseTileOp //===----------------------------------------------------------------------===// @@ -1406,12 +1584,191 @@ LogicalResult CollapseTileOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// TransposeTileOp +//===----------------------------------------------------------------------===// + +LogicalResult TransposeTileOp::inferReturnTypes( + MLIRContext *ctx, Optional /*loc*/, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // Get argument tile type. + TransposeTileOp::Adaptor adaptor(operands, attributes, regions); + auto argTy = adaptor.superset().getType().dyn_cast(); + if (!argTy) return failure(); + auto argShape = argTy.getShape(); + + // Derive result shape. + SmallVector shape = llvm::to_vector(llvm::map_range( + adaptor.permutation(), [&](const auto &d) { return argShape[d]; })); + + auto resultTy = TileType::get(ctx, shape); + inferredReturnTypes.push_back(resultTy); + return success(); +} + +Value TransposeTileOp::compose(OpBuilder &builder) { + // We can compose with a TileOp operand which has a SpaceOp operand, or + // compose with a SpaceOp operand. transpose_tile(tile(space, offsets, sizes, + // strides)) is replaced by tile(transpose(space), transpose(offsets), + // transpose(sizes), transpose(strides)). transpose_tile(space) is replaced by + // transpose(space). + Operation *definingOp = superset().getDefiningOp(); + auto spaceOp = llvm::dyn_cast_or_null(definingOp); + auto tileOp = llvm::dyn_cast_or_null(definingOp); + if (tileOp) { + spaceOp = + llvm::dyn_cast_or_null(tileOp.superset().getDefiningOp()); + } + if (!spaceOp) return {}; + + auto loc = getLoc(); + ArrayRef perm = permutation(); + int64_t rank = perm.size(); + + // Create a new space op that has the permutation applied. + SmallVector dynamicDims; + SmallVector staticDims; + SmallVector shape; + auto originalShape = spaceOp.getType().getShape(); + dynamicDims.reserve(spaceOp.dynamic_sizes().size()); + staticDims.reserve(rank); + shape.reserve(rank); + for (int64_t dim : perm) { + shape.push_back(originalShape[dim]); + staticDims.push_back(spaceOp.static_sizes()[dim]); + if (ShapedType::isDynamic(staticDims.back().cast().getInt())) { + dynamicDims.push_back(spaceOp.getDynamicSize(dim)); + } + } + auto spaceTy = builder.getType(shape); + Value newSpace = builder.create(loc, spaceTy, dynamicDims, + builder.getArrayAttr(staticDims)); + if (!tileOp) return newSpace; + + // Otherwise we need to apply the permutation to the 'tileOp' operand. + SmallVector inputTileOffsets, inputTileSizes, inputTileStrides; + SmallVector inputStaticOffsets, inputStaticSizes, inputStaticStrides; + inputStaticOffsets.reserve(rank); + inputStaticSizes.reserve(rank); + inputStaticStrides.reserve(rank); + inputTileOffsets.reserve(tileOp.offsets().size()); + inputTileSizes.reserve(tileOp.sizes().size()); + inputTileStrides.reserve(tileOp.strides().size()); + for (int64_t dim : perm) { + if (tileOp.isDynamicOffset(dim)) { + inputTileOffsets.push_back(tileOp.getDynamicOffset(dim)); + inputStaticOffsets.push_back(ShapedType::kDynamicStrideOrOffset); + } else { + inputStaticOffsets.push_back(tileOp.getStaticOffset(dim)); + } + if (tileOp.isDynamicSize(dim)) { + inputTileSizes.push_back(tileOp.getDynamicSize(dim)); + inputStaticSizes.push_back(ShapedType::kDynamicSize); + } else { + inputStaticSizes.push_back(tileOp.getStaticSize(dim)); + } + if (tileOp.isDynamicStride(dim)) { + inputTileStrides.push_back(tileOp.getDynamicStride(dim)); + inputStaticStrides.push_back(ShapedType::kDynamicStrideOrOffset); + } else { + inputStaticStrides.push_back(tileOp.getStaticStride(dim)); + } + } + + return builder.create(loc, getType(), newSpace, inputTileOffsets, + inputTileSizes, inputTileStrides, + builder.getI64ArrayAttr(inputStaticOffsets), + builder.getI64ArrayAttr(inputStaticSizes), + builder.getI64ArrayAttr(inputStaticStrides)); +} + +LogicalResult TransposeTileOp::verify() { + TileType type = getType(); + int64_t rank = type.getShape().size(); + // 'permutation' should have 'rank' elements. + if (permutation().size() != rank) { + return emitOpError("expected permutation attribute size = ") + << permutation().size() << " to match rank = " << rank; + } + // Verify that 'permutation' is in fact a permutation. + // Store where a certain number occurred. + SmallVector position(rank, -1); + for (const auto &it : llvm::enumerate(permutation())) { + int64_t dim = it.value(); + if (dim < 0 || dim >= rank) { + return emitOpError("permutation[") + << it.index() << "] = " << dim << " is outside of range [0, " + << rank - 1 << "]"; + } + if (position[dim] >= 0) { + return emitOpError( + "expected permutation attribute to contain no duplicate " + "values, but got ") + << dim << " at positions " << position[dim] << " and " + << it.index(); + } + position[dim] = it.index(); + } + return success(); +} + //===----------------------------------------------------------------------===// // SetYieldOp //===----------------------------------------------------------------------===// LogicalResult SetYieldOp::verify() { return success(); } +void SetYieldOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict(getOperation()->getAttrs()); + + for (auto zip : llvm::zip(srcs(), dsts(), sets())) { + Value src, dst, set; + std::tie(src, dst, set) = zip; + p << ' ' << src << " into " << dst << '[' << set << "] : " << src.getType() + << " into " << dst.getType() << '[' << set.getType() << ']'; + } +} + +ParseResult SetYieldOp::parse(OpAsmParser &parser, OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + SmallVector srcs, dsts, sets; + SmallVector srcTypes, dstTypes, setTypes; + + auto parseElt = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand src; + auto parseResult = parser.parseOptionalOperand(src, false); + + if (!parseResult.hasValue()) return success(); + srcs.push_back(src); + + if (parser.parseKeyword("into") || + parser.parseOperand(dsts.emplace_back()) || parser.parseLSquare() || + parser.parseOperand(sets.emplace_back()) || parser.parseRSquare()) + return failure(); + + if (parser.parseColon() || parser.parseType(srcTypes.emplace_back()) || + parser.parseKeyword("into") || + parser.parseType(dstTypes.emplace_back()) || parser.parseLSquare() || + parser.parseType(setTypes.emplace_back()) || parser.parseRSquare()) + return failure(); + return success(); + }; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElt)) + return failure(); + + if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(dsts, dstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(sets, setTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + return success(); +} + //===----------------------------------------------------------------------===// // DynamicBroadcastInDimOp //===----------------------------------------------------------------------===// @@ -1501,10 +1858,67 @@ Value DynamicBroadcastInDimOp::fuse(Location loc, Value set, auto tiledResultTy = RankedTensorType::get(tileTy.getShape(), resultTy.getElementType()); return builder.create( - loc, tiledResultTy, tiledInit, tiledOperand, broadcast_dimensions(), + loc, tiledResultTy, tiledOperand, tiledInit, broadcast_dimensions(), known_expanding_dimensionsAttr(), known_nonexpanding_dimensionsAttr()); } +//===----------------------------------------------------------------------===// +// OffsetOp +//===----------------------------------------------------------------------===// + +OpFoldResult OffsetOp::fold(ArrayRef operands) { + auto idxAttr = operands[1].dyn_cast_or_null(); + if (!idxAttr) return {}; + + if (auto tileOp = tile().getDefiningOp()) { + auto idx = idxAttr.getInt(); + if (tileOp.isDynamicOffset(idx)) return tileOp.getDynamicOffset(idx); + + Builder b(idxAttr.getContext()); + return b.getIndexAttr(tileOp.getStaticOffset(idx)); + } + // TODO(unknown): Handle space op, as well. + return {}; +} + +//===----------------------------------------------------------------------===// +// SizeOp +//===----------------------------------------------------------------------===// + +OpFoldResult SizeOp::fold(ArrayRef operands) { + auto idxAttr = operands[1].dyn_cast_or_null(); + if (!idxAttr) return {}; + + if (auto tileOp = tile().getDefiningOp()) { + auto idx = idxAttr.getInt(); + if (tileOp.isDynamicSize(idx)) return tileOp.getDynamicSize(idx); + + Builder b(idxAttr.getContext()); + return b.getIndexAttr(tileOp.getStaticSize(idx)); + } + // TODO(unknown): Handle space op, as well. + return {}; +} + +//===----------------------------------------------------------------------===// +// StrideOp +//===----------------------------------------------------------------------===// + +OpFoldResult StrideOp::fold(ArrayRef operands) { + auto idxAttr = operands[1].dyn_cast_or_null(); + if (!idxAttr) return {}; + + if (auto tileOp = tile().getDefiningOp()) { + auto idx = idxAttr.getInt(); + if (tileOp.isDynamicStride(idx)) return tileOp.getDynamicStride(idx); + + Builder b(idxAttr.getContext()); + return b.getIndexAttr(tileOp.getStaticStride(idx)); + } + // TODO(unknown): Handle space op, as well. + return {}; +} + } // namespace gml_st } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt index 1aed638b2895f2..c1b75db23aa678 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt @@ -42,6 +42,16 @@ add_mlir_library(GmlStFusionInterfaceImpl MLIRSupport ) +add_mlir_library(GmlStComposeSetInterface + compose_set_interface.cc + + LINK_LIBS PUBLIC + MLIRIR + + DEPENDS + MLIRGmlStComposeSetInterfaceIncGen +) + add_mlir_library(GmlStBufferizableOpInterface bufferizable_op_interface_impl.cc @@ -59,6 +69,7 @@ add_mlir_library(GmlStPasses gml_st_to_scf.cc legalize_mhlo_to_gml.cc tiling.cc + vectorization.cc DEPENDS MLIRGmlStPassIncGen @@ -67,11 +78,16 @@ add_mlir_library(GmlStPasses Core LINK_LIBS PUBLIC + GmlStComposeSetInterface GmlStFusionInterface GmlStFusionInterfaceImpl + MLIRFuncDialect MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms MLIRPass MLIRSupport + MLIRVectorDialect ) add_mlir_library(GmlStTransforms diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_interface.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_interface.cc new file mode 100644 index 00000000000000..37962013c21bb9 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_interface.cc @@ -0,0 +1,18 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h" + +#include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_ops.cc index 468f26d6b50632..2235e0bfe0d3e6 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_set_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" +#include "mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h" #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h" #include "mlir-hlo/Dialect/gml_st/transforms/passes.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" @@ -29,117 +30,17 @@ namespace mlir { namespace gml_st { namespace { -OpFoldResult multiplyOperandsOrIntegers(PatternRewriter& rewriter, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return rewriter.getI64IntegerAttr( - lhs.get().cast().getInt() * - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return lhs; - if (lhsInt == 1) return rhs; - - lhs = rewriter.create(loc, lhsInt).getResult(); - } - - // Multiply. - return rewriter.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -OpFoldResult addOperandsOrIntegers(PatternRewriter& rewriter, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return rewriter.getI64IntegerAttr( - lhs.get().cast().getInt() + - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return rhs; - - lhs = rewriter.create(loc, lhsInt).getResult(); - } - - // Add. - return rewriter.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -// Compose offsets with newOffset = supersetOffset + supersetStride * offset. -SmallVector composeOffsets( - const llvm::SmallVectorImpl& supersetOffsets, - const llvm::SmallVectorImpl& supersetStrides, - const llvm::SmallVectorImpl& offsets, Location loc, - PatternRewriter& rewriter) { - SmallVector composedOffsets; - for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) { - composedOffsets.push_back(addOperandsOrIntegers( - rewriter, loc, std::get<0>(it), - multiplyOperandsOrIntegers(rewriter, loc, std::get<1>(it), - std::get<2>(it)))); - } - return composedOffsets; -} - -// Compose strides with newStride = supersetStride * stride. -SmallVector composeStrides( - PatternRewriter& rewriter, Location loc, - const llvm::SmallVectorImpl& supersetStrides, - const llvm::SmallVectorImpl& strides) { - SmallVector composedStrides; - for (auto it : llvm::zip(supersetStrides, strides)) { - composedStrides.push_back(multiplyOperandsOrIntegers( - rewriter, loc, std::get<0>(it), std::get<1>(it))); - } - return composedStrides; -} +struct ComposeSetPattern + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern< + ComposeSetInterface>::OpInterfaceRewritePattern; -struct ComposeTilesPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TileOp op, + LogicalResult matchAndRewrite(ComposeSetInterface iface, PatternRewriter& rewriter) const override { - auto supersetOp = - llvm::dyn_cast_or_null(op.superset().getDefiningOp()); - if (!supersetOp) return failure(); - - // Compose offsets with newOffset = supersetOffset + supersetStride * - // offset. - auto loc = op.getLoc(); - auto composedOffsets = decomposeMixedStridesOrOffsets( - rewriter, composeOffsets(supersetOp.getMixedOffsets(), - supersetOp.getMixedStrides(), - op.getMixedOffsets(), loc, rewriter)); - - // Compose strides with newStride = supersetStride * stride. - auto composedStrides = decomposeMixedStridesOrOffsets( - rewriter, composeStrides(rewriter, loc, supersetOp.getMixedStrides(), - op.getMixedStrides())); - - // Build the composed tile op. - rewriter.replaceOpWithNewOp( - op, supersetOp.superset(), composedOffsets.second, op.sizes(), - composedStrides.second, composedOffsets.first, op.static_sizes(), - composedStrides.first); + Value composed = iface.compose(rewriter); + if (!composed) return failure(); + + rewriter.replaceOp(iface.getOperation(), composed); return success(); } }; @@ -151,10 +52,17 @@ class ComposeSetOpsPass : public ComposeSetOpsPassBase { void runOnOperation() final { MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); - patterns.insert(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + patterns.insert(ctx); + + // Apply patterns from the top down. This makes sure that we have already + // composed the operand of a tiling op. + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc index 27d114a67e5a54..c556821e599945 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion.cc @@ -36,7 +36,7 @@ struct FusionPattern : public OpRewritePattern { Operation* def = op.source().getDefiningOp(); if (!def) return failure(); - auto iface = llvm::dyn_cast(def); + auto iface = llvm::dyn_cast(def); if (!iface) return failure(); Value fused = iface.fuse(op.getLoc(), op.set(), rewriter); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface_impl.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface_impl.cc index 4dc0d787dc8825..c99163b53135a0 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface_impl.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/fusion_interface_impl.cc @@ -15,96 +15,159 @@ limitations under the License. #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h" +#include + +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" #include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" namespace mlir { namespace gml_st { namespace { -bool isElementwise(linalg::GenericOp genericOp) { - if (!genericOp.hasTensorSemantics()) return false; - if (genericOp.outputs().size() != 1) return false; - if (!llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { - return mlir::isParallelIterator(attr); +enum class LinalgGenericFusionKind { + FuseAsElementwise, + FuseAsTranspose, + None, +}; + +LinalgGenericFusionKind getLinalgGenericFusionKind( + linalg::GenericOp genericOp) { + // Only consider all-parallel `linalg.generic` ops with a unique result and + // tensor semantics for fusion. + if (!genericOp.hasTensorSemantics() || genericOp.outputs().size() != 1 || + llvm::any_of(genericOp.iterator_types(), [](Attribute attr) { + return !mlir::isParallelIterator(attr); + })) { + return LinalgGenericFusionKind::None; + } + + // Fuse as element-wise if all maps are identity maps. + if (llvm::all_of(genericOp.indexing_maps(), [](Attribute attr) { + return attr.cast().getAffineMap().isIdentity(); })) { - return false; + return LinalgGenericFusionKind::FuseAsElementwise; } - if (!llvm::all_of(genericOp.indexing_maps(), [](Attribute attr) { - return attr.cast().isIdentity(); + + // Fuse as transpose if all maps are permutation maps. + if (llvm::all_of(genericOp.indexing_maps(), [](Attribute attr) { + return attr.cast().getAffineMap().isPermutation(); })) { - return false; + return LinalgGenericFusionKind::FuseAsTranspose; } - return true; + + return LinalgGenericFusionKind::None; } -struct LingalgGenericFusionInterface - : public FusionIterface::ExternalModel { - Value fuse(Operation* op, Location loc, Value subset, - OpBuilder& builder) const { - auto genericOp = llvm::cast(op); +Value fuseAsElementwise(linalg::GenericOp genericOp, Location loc, Value subset, + OpBuilder& builder) { + assert(getLinalgGenericFusionKind(genericOp) == + LinalgGenericFusionKind::FuseAsElementwise && + "expect element-wise linalg.generic op"); + linalg::LinalgOp linalgOp = genericOp; + return llvm::TypeSwitch(subset.getType()) + .Case([&](TileType tileTy) -> Value { + // Create tiled op. + Value output = genericOp.outputs().front(); + auto outputTy = output.getType().cast(); + auto subResultTy = + RankedTensorType::get(tileTy.getShape(), outputTy.getElementType()); + SmallVector subOperands; + subOperands.reserve(genericOp.getNumInputs()); + for (auto input : genericOp.inputs()) { + subOperands.push_back( + builder.create(loc, input, subset)); + } + subOperands.push_back( + builder.create(loc, output, subset)); + Operation* tiledOp = + linalgOp.clone(builder, loc, subResultTy, subOperands); + + return tiledOp->getResults().front(); + }) + .Case([&](PointType) -> Value { + // Create scalar computation. + BlockAndValueMapping bvm; + Block* block = genericOp.getBody(); + for (auto it : llvm::zip(block->getArguments(), linalgOp.inputs())) { + bvm.map(std::get<0>(it), + builder.create(loc, std::get<1>(it), subset)); + } + for (auto& it : block->without_terminator()) builder.clone(it, bvm); + + auto innerResults = block->getTerminator()->getOperands(); + assert(innerResults.size() == 1 && "expect unique inner result"); + return bvm.lookup(innerResults.front()); + }) + .Default([](Type) -> Value { return {}; }); +} - // Supports only tile subsets. - auto tileTy = subset.getType().dyn_cast(); - if (!tileTy.isa()) return {}; - - // Supports only element-wise `linalg.generic` ops. - if (!isElementwise(genericOp)) return {}; - - // Create tiled op. - Value output = genericOp.outputs().front(); - auto outputTy = output.getType().cast(); - auto subResultTy = - RankedTensorType::get(tileTy.getShape(), outputTy.getElementType()); - SmallVector subOperands; - subOperands.reserve(genericOp.getNumInputs()); - for (auto input : genericOp.inputs()) { +Value fuseAsTranspose(linalg::GenericOp genericOp, Location loc, Value subset, + OpBuilder& builder) { + assert(getLinalgGenericFusionKind(genericOp) == + LinalgGenericFusionKind::FuseAsTranspose && + "expect transposing linalg.generic op"); + + auto tileTy = subset.getType().dyn_cast(); + if (!tileTy) return {}; + + // Create tiled op. + Value output = genericOp.outputs().front(); + auto outputTy = output.getType().cast(); + auto subResultTy = + RankedTensorType::get(tileTy.getShape(), outputTy.getElementType()); + SmallVector subOperands; + subOperands.reserve(genericOp.getNumInputs()); + for (const auto& inputAndMap : + llvm::zip(genericOp.inputs(), genericOp.getIndexingMaps())) { + Value input; + AffineMap map; + std::tie(input, map) = inputAndMap; + if (map.isIdentity()) { subOperands.push_back(builder.create(loc, input, subset)); + continue; } - subOperands.push_back(builder.create(loc, output, subset)); - linalg::LinalgOp linalgOp = genericOp; - Operation* tiledOp = linalgOp.clone(builder, loc, subResultTy, subOperands); - return tiledOp->getResults().front(); + assert(map.isPermutation()); + SmallVector permutation; + permutation.reserve(map.getNumResults()); + for (unsigned int r = 0, e = map.getNumResults(); r < e; ++r) { + permutation.push_back(map.getPermutedPosition(r)); + } + auto transposedTile = builder.create( + loc, subset, DenseI64ArrayAttr::get(builder.getContext(), permutation)); + subOperands.push_back( + builder.create(loc, input, transposedTile)); } -}; + // Materialize the tiled output. + subOperands.push_back(builder.create(loc, output, subset)); + linalg::LinalgOp linalgOp = genericOp; + Operation* tiledOp = linalgOp.clone(builder, loc, subResultTy, subOperands); + return tiledOp->getResults().front(); +} -template -struct ElementwiseFusionInterface - : public FusionIterface::ExternalModel, - OpTy> { +struct LinalgGenericFusionInterface + : public FusionInterface::ExternalModel { Value fuse(Operation* op, Location loc, Value subset, OpBuilder& builder) const { - // Supports tile and point subsets. - Type subsetTy = subset.getType(); - if (!subsetTy.isa()) return {}; - - // Expect ranked element-wise op. - auto cwiseOp = llvm::cast(op); - auto rankedTy = cwiseOp.getType().template dyn_cast(); - if (!rankedTy) return {}; - - // Materialize subsets for all arguments. - auto subsetArgs = llvm::to_vector( - llvm::map_range(cwiseOp->getOperands(), [&](const auto& arg) -> Value { - return builder.create(loc, arg, subset); - })); - - // Materialize elementwise op for subset. - return llvm::TypeSwitch(subsetTy) - .Case([&](TileType) -> Value { - return builder.create(loc, subsetArgs); - }) - .Case([&](PointType) -> Value { - return mhlo::MhloOpToStdScalarOp::mapOp( - cwiseOp, rankedTy.getElementType(), subsetArgs, &builder); - }) - .Default([](Type) -> Value { return {}; }); + auto genericOp = llvm::cast(op); + auto kind = getLinalgGenericFusionKind(genericOp); + + if (kind == LinalgGenericFusionKind::FuseAsElementwise) { + return fuseAsElementwise(genericOp, loc, subset, builder); + } + + if (kind == LinalgGenericFusionKind::FuseAsTranspose) { + return fuseAsTranspose(genericOp, loc, subset, builder); + } + + return {}; } }; @@ -113,18 +176,7 @@ struct ElementwiseFusionInterface void registerFusionInterfaceExternalModels(DialectRegistry& registry) { registry.insert(); registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect*) { - linalg::GenericOp::attachInterface(*ctx); - }); - - // TODO(frgossen): Update tests and remove these in favor of - // `linalg.generic`-based fusions. - registry.insert(); - registry.addExtension(+[](MLIRContext* ctx, mhlo::MhloDialect*) { - mhlo::AddOp::attachInterface>(*ctx); - mhlo::SubOp::attachInterface>(*ctx); - mhlo::CosOp::attachInterface>(*ctx); - mhlo::TanhOp::attachInterface>( - *ctx); + linalg::GenericOp::attachInterface(*ctx); }); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/legalize_mhlo_to_gml.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/legalize_mhlo_to_gml.cc index 5fba0f6591958b..ff6410d55a59a3 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/legalize_mhlo_to_gml.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/legalize_mhlo_to_gml.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h" @@ -56,13 +57,69 @@ struct DynamicBroadcastInDimOpPattern loc, dynamicDims, staticShapeInfo, resultTy.getElementType()); rewriter.replaceOpWithNewOp( - op, resultTy, initTensor, op.operand(), op.broadcast_dimensions(), + op, resultTy, op.operand(), initTensor, op.broadcast_dimensions(), op.known_expanding_dimensionsAttr(), op.known_nonexpanding_dimensionsAttr()); return success(); } }; +// Rewrites simple gather patterns (as checked below). +struct GatherPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::GatherOp op, + PatternRewriter& rewriter) const override { + auto startIndicesType = + op.start_indices().getType().dyn_cast(); + auto operandType = op.operand().getType().dyn_cast(); + + if (!startIndicesType || !operandType) return failure(); + + // index_vector_dim must be the last dimension of start_indices. + int indexVectorDim = op.dimension_numbers().getIndexVectorDim(); + if (startIndicesType.getRank() - 1 != indexVectorDim) return failure(); + + // All slice_sizes must be 1. + if (!llvm::all_of(op.slice_sizes(), [](auto size) { return size == 1; })) + return failure(); + + // offset_dims must be [] + if (!op.dimension_numbers().getOffsetDims().empty()) return failure(); + + // collapsed_slice_dims[] must be range(operand.rank) + auto collapsedSliceDims = op.dimension_numbers().getCollapsedSliceDims(); + if (!isIotaArray(collapsedSliceDims, operandType.getRank())) + return failure(); + + // start_index_map[] must be range(start_indices.shape[index_vector_dim]) + auto startIndexMap = op.dimension_numbers().getStartIndexMap(); + if (!isIotaArray(startIndexMap, + startIndicesType.getShape()[indexVectorDim])) + return failure(); + + // The shape of the result must be statically known. + if (op.getType().getNumDynamicDims() > 0) return failure(); + + auto loc = op.getLoc(); + auto initTensor = rewriter.create( + loc, mlir::ValueRange{}, op.getType().getShape(), + op.getType().getElementType()); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.operand(), op.start_indices(), initTensor); + return success(); + } + + private: + static bool isIotaArray(llvm::ArrayRef array, int expectedSize) { + if (array.size() != expectedSize) return false; + for (int i = 0, e = array.size(); i < e; ++i) { + if (i != array[i]) return false; + } + return true; + } +}; + class LegalizeMHLOToGMLPass : public LegalizeMHLOToGMLPassBase { void getDependentDialects(DialectRegistry& registry) const final { @@ -74,7 +131,7 @@ class LegalizeMHLOToGMLPass RewritePatternSet patterns(ctx); // List of patterns. - patterns.insert(ctx); + patterns.insert(ctx); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc index a13ada60c53f51..a62fe08a1e288e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/transforms.cc @@ -441,7 +441,7 @@ FailureOr tileToPoints(RewriterBase &b, loc, output.getType(), lowerBounds, upperBounds, steps, [&](OpBuilder &b, Location nestedLoc, ValueRange ivs) { Value point = b.create( - nestedLoc, b.getType(), space, ivs, + nestedLoc, space, ivs, b.getI64ArrayAttr(SmallVector( outputType.getRank(), ShapedType::kDynamicStrideOrOffset))); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/vectorization.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/vectorization.cc new file mode 100644 index 00000000000000..50d6e3084cef37 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/vectorization.cc @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" +#include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h" +#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace gml_st { + +struct VectorizeGmlStLoopsPass + : public VectorizeGmlStLoopsPassBase { + void runOnOperation() override { + auto funcOp = getOperation(); + // Vectorize linalg.generic operations inside gml_st.for and gml_st.parallel + // loops. + OpPassManager dynamicPM("func.func"); + linalg::CodegenStrategy strategy; + strategy.vectorize(linalg::GenericOp::getOperationName(), + [](mlir::Operation *op) { + auto generic = mlir::dyn_cast(op); + if (!generic) return failure(); + if (op->getParentOfType() || + op->getParentOfType()) { + return success(); + } + return failure(); + }); + strategy.configurePassPipeline(dynamicPM, funcOp.getContext()); + if (failed(runPipeline(dynamicPM, funcOp))) return signalPassFailure(); + } +}; + +std::unique_ptr> createVectorizeGmlStLoopsPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 1fac89bf1c686f..e2d11d57df1983 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -659,7 +659,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CopyOp) -INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DomainOp) @@ -687,7 +687,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp) -INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp) @@ -1803,6 +1803,13 @@ OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef operands) { auto operandShape = this->operand().getType().cast(); auto updateShape = this->update().getType().cast(); + // If any of the dimensions are length-0, the update does nothing. + for (auto dim : updateShape.getShape()) { + if (dim == 0) { + return this->operand(); + } + } + if (operandShape != updateShape || !operandShape.hasStaticShape()) { return {}; } @@ -1854,7 +1861,7 @@ LogicalResult CollectivePermuteOp::verify() { } //===----------------------------------------------------------------------===// -// ConvOp +// ConvolutionOp //===----------------------------------------------------------------------===// namespace { @@ -1867,7 +1874,7 @@ namespace { // Note that the spatial + non-spatial dimensions may not cover all the // dimensions in the range [0,num) because of the presence of 'unknown' // dimensions (ref. cl/415132294). -LogicalResult isSpatialDimensionsValid(ConvOp op) { +LogicalResult isSpatialDimensionsValid(ConvolutionOp op) { auto inputSpatialDimensions = op.dimension_numbers().getInputSpatialDimensions(); auto kernelSpatialDimensions = @@ -1955,7 +1962,7 @@ LogicalResult isSpatialDimensionsValid(ConvOp op) { // b % bgc == 0 // f % fgc == 0 and i = f / fgc // o (or f') % bgc == 0 and o (or f') % fgc == 0 -LogicalResult verifyConvolutionAttributes(ConvOp op) { +LogicalResult verifyConvolutionAttributes(ConvolutionOp op) { // P1. if (failed(isSpatialDimensionsValid(op))) return failure(); @@ -2037,12 +2044,12 @@ LogicalResult verifyConvolutionAttributes(ConvOp op) { return success(); } -// Infer the return-shape of ConvOp. +// Infer the return-shape of ConvolutionOp. // Precondition: -// 1. Input args to ConvOp 'op' are RankedTypes. +// 1. Input args to ConvolutionOp 'op' are RankedTypes. // 2. rank-of(input-type) == rank-of(output-type) -SmallVector inferConvOpReturnShape( - ConvOp op, const ArrayRef window) { +SmallVector inferConvolutionOpReturnShape( + ConvolutionOp op, const ArrayRef window) { // We keep the 'unknown' dimensions (cl/415132294) as it is in the // output-shape. To do that we initilize the output dimensions with the shape // of the return-type and updates only the spatial + non-spatial dimensions. @@ -2090,7 +2097,7 @@ SmallVector inferConvOpReturnShape( * P4. Verify the return shape. * TODO(b/232574102): Verify the element-type of return-value. */ -LogicalResult ConvOp::verify() { +LogicalResult ConvolutionOp::verify() { auto lhsType = lhs().getType().dyn_cast(); auto rhsType = rhs().getType().dyn_cast(); @@ -2142,7 +2149,7 @@ LogicalResult ConvOp::verify() { << numDims << "), but got " << actualReturnRankedType.getRank() << "."; - auto expectedReturnShape = inferConvOpReturnShape(*this, *windowOrErr); + auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr); auto expectedReturnType = RankedTensorType::get(expectedReturnShape, actualReturnElementType); if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType))) @@ -4945,31 +4952,23 @@ LogicalResult RngBitGeneratorOp::verify() { } //===----------------------------------------------------------------------===// -// RngNormalOp +// RngOp //===----------------------------------------------------------------------===// -LogicalResult RngNormalOp::inferReturnTypeComponents( - MLIRContext* context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl& inferredReturnShapes) { - return rngInferReturnTypeComponents(context, location, operands, attributes, - regions, inferredReturnShapes); -} - -LogicalResult RngNormalOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange operands, - SmallVectorImpl& reifiedReturnShapes) { - RngNormalOp::Adaptor adaptor(operands); - reifiedReturnShapes.push_back( - castToIndexTensor(builder, getLoc(), adaptor.shape())); - return success(); +LogicalResult RngOp::verify() { + auto dist = rng_distribution(); + if (dist == RngDistribution::UNIFORM) { + return success(); + } + auto muTy = a().getType().cast().getElementType(); + auto sigmaTy = b().getType().cast().getElementType(); + if (muTy.isa() && sigmaTy.isa()) { + return success(); + } + return emitOpError() << "mu and sigma must be floats"; } -//===----------------------------------------------------------------------===// -// RngUniformOp -//===----------------------------------------------------------------------===// - -LogicalResult RngUniformOp::inferReturnTypeComponents( +LogicalResult RngOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { @@ -4977,10 +4976,10 @@ LogicalResult RngUniformOp::inferReturnTypeComponents( regions, inferredReturnShapes); } -LogicalResult RngUniformOp::reifyReturnTypeShapes( +LogicalResult RngOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { - RngUniformOp::Adaptor adaptor(operands); + RngOp::Adaptor adaptor(operands); reifiedReturnShapes.push_back( castToIndexTensor(builder, getLoc(), adaptor.shape())); return success(); @@ -6697,10 +6696,10 @@ LogicalResult TransposeOp::reifyReturnTypeShapes( // Method for InferTypeOpInterface: infer the return type from the operand type // and the permutation. -LogicalResult TransposeOp::inferReturnTypeComponents( - MLIRContext* context, Optional loc, ValueShapeRange operands, +LogicalResult TransposeOp::inferReturnTypes( + MLIRContext* /*context*/, Optional loc, ValueRange operands, DictionaryAttr attributes, RegionRange, - SmallVectorImpl& inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { auto type = operands[0].getType(); auto rankedTy = type.dyn_cast(); if (!rankedTy) { @@ -6733,7 +6732,8 @@ LogicalResult TransposeOp::inferReturnTypeComponents( for (int64_t dim : permutation.getValues()) { resultShape.push_back(inputShape[dim]); } - inferredReturnTypes.emplace_back(resultShape, rankedTy.getElementType()); + inferredReturnTypes.emplace_back(RankedTensorType::get( + resultShape, rankedTy.getElementType(), rankedTy.getEncoding())); return success(); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 815a027557e52b..eb63dcdedc7b88 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -605,7 +605,7 @@ Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc, // Materialize reflection. Value reflectionDenom = rewriter.create( loc, - rewriter.create( + rewriter.create( loc, rewriter.create( loc, getConstantLike(rewriter, loc, M_PI, x), absFrac))); Value lgammaReflection = rewriter.create( @@ -779,8 +779,8 @@ Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc, // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) Value pi = getConstantLike(rewriter, loc, M_PI, x); Value piMulReducedX = rewriter.create(loc, pi, reducedX); - Value cos = rewriter.create(loc, piMulReducedX); - Value sin = rewriter.create(loc, piMulReducedX); + Value cos = rewriter.create(loc, piMulReducedX); + Value sin = rewriter.create(loc, piMulReducedX); Value reflection = rewriter.create( loc, digamma, rewriter.create( @@ -1207,8 +1207,8 @@ Value materializeTan(ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) { TanOp::Adaptor transformed(operands); return rewriter.create( - loc, rewriter.create(loc, transformed.operand()), - rewriter.create(loc, transformed.operand())); + loc, rewriter.create(loc, transformed.operand()), + rewriter.create(loc, transformed.operand())); } struct ConvertTanOp : public OpConversionPattern { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index cd4786e34308e4..755ce99b07d1aa 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -524,10 +524,10 @@ void populateHloToLhloConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, - HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, - HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -555,7 +555,7 @@ void populateHloToLhloConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, - HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 3671a30a496404..f21b8ece131de4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -303,17 +303,21 @@ static bool hasCanonicalDimensionNumbers( } //===----------------------------------------------------------------------===// -// mhlo.RngUniformOp conversion patterns. +// mhlo.RngOp conversion patterns. //===----------------------------------------------------------------------===// -// Pass to lower from rng_uniform to stateless uniform pseudo RNG with LCG +// Pass to lower from rng to stateless pseudo RNG with LCG // algorithm -struct RngUniformConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct RngUniformConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - mhlo::RngUniformOp op, OpAdaptor adaptor, + mhlo::RngOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + // We only handle uniform distributions + if (op.rng_distribution() != ::mlir::mhlo::RngDistribution::UNIFORM) { + return failure(); + } // TODO(raikonenfnu): Handle other element types as well. auto minTy = adaptor.getOperands()[0].getType().dyn_cast(); auto maxTy = adaptor.getOperands()[0].getType().dyn_cast(); @@ -2093,11 +2097,12 @@ Value applyConvolutionPadding(Location loc, Value input, /// Converts mhlo.conv operation to linalg named op. This only covers normal /// convolution cases. The op must have canonical dimension numbers. Depthwise /// convolution and pointwise convolution are not handled in the conversion. -struct NormalConvOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct NormalConvolutionOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - mhlo::ConvOp op, OpAdaptor adaptor, + mhlo::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure(); if (op.feature_group_count() != 1u) return failure(); @@ -2170,11 +2175,12 @@ struct NormalConvOpConversion : public OpConversionPattern { /// Converts mhlo.convolution operation to /// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or /// depthwise_conv_2d_input_nhwc_filter_hwc op. -struct DepthwiseConvOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct DepthwiseConvolutionOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - mhlo::ConvOp op, OpAdaptor adaptor, + mhlo::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { if (op.batch_group_count() != 1) return failure(); // Fall into the normal convolution cases. @@ -3294,7 +3300,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -3321,7 +3327,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -3334,8 +3340,8 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, DynamicSliceConverter, DynamicUpdateSliceConverter, TransposeConverter, - NormalConvOpConversion, - DepthwiseConvOpConversion, + NormalConvolutionOpConversion, + DepthwiseConvolutionOpConversion, GatherConversion, PadOpConversion, PadOpNegativePaddingConversion, diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index 05de35db0abfd3..210d0e616f7637 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -80,36 +80,36 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), // sin(a) * cosh(b) + icos(a) * sinh(b) // sinh(b) = (e^x - e^-x) / 2 // cosh(b) = (e^x + e^-x) / 2 -def : Pat<(HLO_SinOp HLO_ComplexTensor:$val), +def : Pat<(HLO_SineOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_DivOp (HLO_MulOp - (HLO_SinOp (HLO_RealOp:$real $val)), + (HLO_SineOp (HLO_RealOp:$real $val)), (HLO_AddOp (HLO_ExpOp:$exp (HLO_ImagOp:$imag $val)), (HLO_ExpOp:$nexp (HLO_NegOp $imag)))), (HLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), (HLO_DivOp (HLO_MulOp - (HLO_CosOp $real), + (HLO_CosineOp $real), (HLO_SubOp $exp, $nexp)), $two))>; // Can deconstruct cos(a + ib) as follows: // cos(a) * cosh(b) - isin(a) * sinh(b) // sinh(b) = (e^x - e^-x) / 2 // cosh(b) = (e^x + e^-x) / 2 -def : Pat<(HLO_CosOp HLO_ComplexTensor:$val), +def : Pat<(HLO_CosineOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_DivOp (HLO_MulOp - (HLO_CosOp (HLO_RealOp:$real $val)), + (HLO_CosineOp (HLO_RealOp:$real $val)), (HLO_AddOp (HLO_ExpOp:$exp (HLO_ImagOp:$imag $val)), (HLO_ExpOp:$nexp (HLO_NegOp $imag)))), (HLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), (HLO_DivOp (HLO_MulOp - (HLO_SinOp $real), + (HLO_SineOp $real), (HLO_SubOp $nexp, $exp)), $two))>; // Exponential can be lowered to an exponential on the real component and a @@ -123,9 +123,9 @@ class HLO_ComparisonDirectionValue : def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_MulOp - (HLO_CosOp (HLO_ImagOp:$imag $val)), + (HLO_CosineOp (HLO_ImagOp:$imag $val)), (HLO_ExpOp:$exp (HLO_RealOp:$real $val))), - (HLO_MulOp (HLO_SinOp $imag), $exp))>; + (HLO_MulOp (HLO_SineOp $imag), $exp))>; foreach pair = [[HLO_ComparisonDirectionValue<"NE">, HLO_OrOp], [HLO_ComparisonDirectionValue<"EQ">, HLO_AndOp]] in { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc index f18dc428b0c11a..4de3adedd3e9a1 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc @@ -36,13 +36,13 @@ namespace { // TODO(b/231360416): replace this list with "supports sparsity" trait? static bool canFuseWithSparseConvert(Operation *op) { return isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || + isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || isa(op) || + isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || isa(op) || - isa(op); + isa(op) || isa(op); } /// Fuses a sparse tensor type from a conversion into a mhlo operation diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 28d1e31a5c0ecd..11f49bad3611d4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -17,10 +17,12 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" @@ -37,8 +39,7 @@ namespace { Value broadcastToFeatureDim(Location loc, RankedTensorType resultType, Value value1d, Value shapeValue, int64_t featureDim, PatternRewriter& rewriter) { // NOLINT - Builder b(rewriter.getContext()); - auto dimsType = RankedTensorType::get({1}, b.getIntegerType(64)); + auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dimsType, {featureDim}); if (shapeValue) { return rewriter.createOrFold( @@ -49,25 +50,20 @@ Value broadcastToFeatureDim(Location loc, RankedTensorType resultType, dims); } -// Calculate the shape value of operand, assuming it is a dynamic shape with -// static rank. -Value calculateShapeValue(Location loc, Value operand, - PatternRewriter& rewriter) { // NOLINT +// Get the shape of operand, assuming it is a dynamic shape with static rank. +Value getShapeValue(Location loc, Value operand, + PatternRewriter &rewriter) { // NOLINT RankedTensorType resultType = operand.getType().dyn_cast(); - llvm::SmallVector shapeValues; - int64_t rank = resultType.getRank(); - shapeValues.reserve(rank); - for (int64_t i = 0; i < rank; ++i) { - shapeValues.push_back( - rewriter.create(loc, operand, i)); - } - return rewriter.create(loc, shapeValues); + return rewriter.create( + loc, + RankedTensorType::get({resultType.getRank()}, rewriter.getIndexType()), + operand); } -Value materializeEpsilon(Operation* op, FloatAttr epsilonAttr, FloatType fpType, +Value materializeEpsilon(Operation *op, FloatAttr epsilonAttr, FloatType fpType, Value broadcastTo, RankedTensorType broadcastToType, - PatternRewriter& rewriter) { // NOLINT - Builder b(rewriter.getContext()); + PatternRewriter &rewriter) { // NOLINT + ImplicitLocOpBuilder b(op->getLoc(), rewriter); if (epsilonAttr.getType() != fpType) { // Need to convert. bool losesInfo; @@ -89,18 +85,17 @@ Value materializeEpsilon(Operation* op, FloatAttr epsilonAttr, FloatType fpType, auto scalarType = RankedTensorType::get({}, fpType); auto epsilonTensorAttr = DenseElementsAttr::get(scalarType, {epsilonAttr.cast()}); - Value epsilon = - rewriter.create(op->getLoc(), epsilonTensorAttr); + Value epsilon = b.create(epsilonTensorAttr); auto dimsType = RankedTensorType::get({0}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dimsType, SmallVector{}); if (broadcastToType.hasStaticShape()) { - return rewriter.create( - op->getLoc(), broadcastToType, epsilon, /*broadcast_dims=*/dims); + return b.create(broadcastToType, epsilon, + /*broadcast_dims=*/dims); } - Value shapeValue = calculateShapeValue(op->getLoc(), broadcastTo, rewriter); - return rewriter.createOrFold( - op->getLoc(), broadcastToType, epsilon, shapeValue, - /*broadcast_dims=*/dims); + Value shapeValue = getShapeValue(op->getLoc(), broadcastTo, rewriter); + return b.createOrFold(broadcastToType, epsilon, + shapeValue, + /*broadcast_dims=*/dims); } class UnfuseBatchNormInferencePattern @@ -139,7 +134,7 @@ class UnfuseBatchNormInferencePattern // Broadcast all terms. Value shapeValue; if (!inputType.hasStaticShape()) { - shapeValue = calculateShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter); + shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter); } auto broadcastScale = broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.scale(), @@ -200,41 +195,31 @@ Value createReduce(Location loc, Value operand, Value zero, } // Calculate total reduce size, assuming it is a dynamic shape with static rank. -// Reduce from operand to operand[feature_index] -Value calculateReduceSize(Operation* op, Value operand, - RankedTensorType operandType, +// Reduce from operand to operand[feature_index]/scale +Value calculateReduceSize(Operation *op, Value operand, + RankedTensorType operandType, Value scale, RankedTensorType scaleType, int64_t featureIndex, - PatternRewriter& rewriter) { - Location loc = op->getLoc(); + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Type indexType = b.getIndexType(); if (!operandType.hasStaticShape()) { // the "operand" has dynamic shape with static rank - llvm::SmallVector reduceValues; - for (int64_t i = 0, e = operandType.getRank(); i < e; i++) { - if (i != featureIndex) { - reduceValues.push_back(rewriter.create(loc, operand, i)); - } - } - assert(!reduceValues.empty()); - Value reduceSize = reduceValues[0]; - for (size_t i = 1, e = reduceValues.size(); i < e; i++) { - reduceSize = - rewriter.create(loc, reduceSize, reduceValues[i]); - } - reduceSize = rewriter.create(loc, rewriter.getI64Type(), - reduceSize); - reduceSize = rewriter.create(loc, reduceSize); - reduceSize = rewriter.create( - loc, RankedTensorType::get({1}, operandType.getElementType()), - reduceSize); - reduceSize = rewriter.create( - loc, RankedTensorType::get({}, operandType.getElementType()), - reduceSize); - Value featureSize = - rewriter.create(loc, operand, featureIndex); - featureSize = rewriter.create(loc, featureSize); - - return rewriter.createOrFold( - loc, scaleType, reduceSize, featureSize, rewriter.getI64TensorAttr({})); + Value operandShape = getShapeValue(op->getLoc(), operand, rewriter); + Value scaleShape = getShapeValue(op->getLoc(), scale, rewriter); + Value operandTotalSize = + b.create(indexType, operandShape); + Value scaleTotalSize = + b.create(indexType, scaleShape); + Value reduceSize = + b.create(indexType, operandTotalSize, scaleTotalSize); + reduceSize = b.create(b.getI64Type(), reduceSize); + reduceSize = b.create(reduceSize); + reduceSize = b.create( + RankedTensorType::get({1}, operandType.getElementType()), reduceSize); + reduceSize = b.create( + RankedTensorType::get({}, operandType.getElementType()), reduceSize); + return b.createOrFold( + scaleType, reduceSize, scaleShape, b.getI64TensorAttr({})); } // the "operand" has static shape @@ -252,8 +237,8 @@ Value calculateReduceSize(Operation* op, Value operand, if (losesInfo) { op->emitWarning("Conversion of reduce_dims_size loses precision"); } - Value reduceSize = rewriter.create( - loc, DenseFPElementsAttr::get(scaleType, floatValue)); + Value reduceSize = b.create( + DenseFPElementsAttr::get(scaleType, floatValue)); return reduceSize; } @@ -298,7 +283,7 @@ class UnfuseBatchNormTrainingPattern // reduce size constant Value reduceSize = calculateReduceSize(bnOp.getOperation(), bnOp.operand(), operandType, - scaleType, featureIndex, rewriter); + bnOp.scale(), scaleType, featureIndex, rewriter); if (!reduceSize) { return failure(); } @@ -330,7 +315,7 @@ class UnfuseBatchNormTrainingPattern Value shapeValue; if (!operandType.hasStaticShape()) { - shapeValue = calculateShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter); + shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter); } // X - E[X] Value meanBroadcast = broadcastToFeatureDim( @@ -371,9 +356,13 @@ class UnfuseBatchNormTrainingPattern // In combination with marking such ops as illegal, this allows backends that // do not have special support for fused batchnorm to use simpler arithmetic // primitives. -void populateUnfuseBatchNormPatterns(MLIRContext* context, - RewritePatternSet* patterns) { +void populateUnfuseBatchNormInferencePattern(MLIRContext *context, + RewritePatternSet *patterns) { patterns->add(context); +} + +void populateUnfuseBatchNormTrainingPattern(MLIRContext *context, + RewritePatternSet *patterns) { patterns->add(context); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index b3931a9f514af4..59646d5b1011ba 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -18,8 +18,10 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -34,12 +36,10 @@ namespace { struct TestUnfuseBatchNormPass : public TestUnfuseBatchNormPassBase { - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateUnfuseBatchNormPatterns(&getContext(), &patterns); + populateUnfuseBatchNormInferencePattern(&getContext(), &patterns); + populateUnfuseBatchNormTrainingPattern(&getContext(), &patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Transforms/gml_st_pipeline.cc b/tensorflow/compiler/mlir/hlo/lib/Transforms/gml_st_pipeline.cc index b83c895aa4bfd3..a915aebd90aa52 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Transforms/gml_st_pipeline.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Transforms/gml_st_pipeline.cc @@ -36,6 +36,10 @@ void createGmlStPipeline(mlir::OpPassManager& pm, // Perform tiling, fusion, vectorization and other transformations. pm.addNestedPass(gml_st::createTilingPass(options.tileSizes)); + if (options.fuse) { + pm.addNestedPass(gml_st::createFusionPass()); + } + pm.addNestedPass(gml_st::createComposeSetOpsPass()); if (!options.lowerToLoops) return; @@ -48,6 +52,7 @@ void createGmlStPipeline(mlir::OpPassManager& pm, // Convert Linalg + GmlSt to SCF loops. pm.addNestedPass(createConvertLinalgToLoopsPass()); + pm.addNestedPass(gml_st::createVectorizeGmlStLoopsPass()); pm.addNestedPass(gml_st::createGmlStToScfPass()); } diff --git a/tensorflow/compiler/mlir/hlo/python/MlirHloModule.cc b/tensorflow/compiler/mlir/hlo/python/MlirHloModule.cc index ef79e7c85280f2..13e514f3c75882 100644 --- a/tensorflow/compiler/mlir/hlo/python/MlirHloModule.cc +++ b/tensorflow/compiler/mlir/hlo/python/MlirHloModule.cc @@ -425,6 +425,24 @@ PYBIND11_MODULE(_mlirHlo, m) { return toPyString(mlirMhloFusionKindAttrGetFusionKind(self)); }); + mlir::python::adaptors::mlir_attribute_subclass( + m, "RngDistributionAttr", mlirMhloAttributeIsARngDistributionAttr) + .def_classmethod( + "get", + [](py::object cls, const std::string &distribution, MlirContext ctx) { + return cls(mlirMhloRngDistributionAttrGet( + ctx, mlirStringRefCreate(distribution.c_str(), + distribution.size()))); + }, + py::arg("cls"), py::arg("rng_distribution"), + py::arg("context") = py::none(), + "Creates a RngDistribution attribute with the given rng " + "distribution.") + .def_property_readonly("rng_distribution", [](MlirAttribute self) { + auto distribution = mlirMhloRngDistributionAttrGetRngDistribution(self); + return py::str(distribution.data, distribution.length); + }); + mlir::python::adaptors::mlir_attribute_subclass( m, "RngAlgorithmAttr", mlirMhloAttributeIsARngAlgorithmAttr) .def_classmethod( diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/bufferization.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/bufferization.mlir index ab652712a64383..3c765718751a9b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/bufferization.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/bufferization.mlir @@ -1,5 +1,5 @@ // RUN: mlir-hlo-opt %s -test-gml-st-bufferization -canonicalize -cse \ -// RUN: -split-input-file | FileCheck %s --dump-input=always +// RUN: -split-input-file | FileCheck %s func.func @set_space(%input: tensor) -> tensor { %c0 = arith.constant 0 : index diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/compose_set_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/compose_set_ops.mlir index 395ce4975c43e3..2c43bee8169827 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/compose_set_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/compose_set_ops.mlir @@ -29,9 +29,9 @@ func.func @tile_of_tile(%arg : tensor, %i : index, %j : index, // CHECK-SAME: %[[ARG:.*]]: tensor<4096x2048xf32> func.func @tile_of_tile_of_tile_all_constant(%arg : tensor<4096x2048xf32>) -> tensor<128x64xf32> { - // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [18, 64] [128, 64] [4, 0] : !gml_st.tile<4096x2048> to !gml_st.tile<128x64> - // CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<4096x2048xf32>[!gml_st.tile<128x64>] + // CHECK: %[[SPACE:.*]] = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> + // CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [18, 64] [128, 64] [4, 0] : !gml_st.tile<4096x2048> to !gml_st.tile<128x64> + // CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<4096x2048xf32>[!gml_st.tile<128x64>] // CHECK: return %[[RES]] %s = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> %t = gml_st.tile %s [0, 32] [2048, 256] [1, 2] @@ -57,10 +57,10 @@ func.func @tile_chain_w_zeroes_and_ones(%arg : tensor<8192x4096x2048xf32>, // CHECK-DAG: %[[C32:.*]] = arith.constant 32 // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048> // CHECK-DAG: %[[TWO_K:.*]] = arith.muli %[[K]], %[[C2]] - // CHECK-DAG: %[[SIXTEEN_J:.*]] = arith.addi %[[J]], %[[C16]] + // CHECK-DAG: %[[SIXTEEN_PLUS_J:.*]] = arith.addi %[[J]], %[[C16]] // CHECK-DAG: %[[TWO_K_PLUS_32:.*]] = arith.addi %[[TWO_K]], %[[C32]] // CHECK-DAG: %[[C_TIMES_C2:.*]] = arith.muli %[[C]], %[[C2]] - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [0, %[[SIXTEEN_J]], %[[TWO_K_PLUS_32]]] [%[[M]], %[[N]], %[[O]]] [0, %[[B]], %[[C_TIMES_C2]]] : !gml_st.tile<8192x4096x2048> to !gml_st.tile + // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [0, %[[SIXTEEN_PLUS_J]], %[[TWO_K_PLUS_32]]] [%[[M]], %[[N]], %[[O]]] [0, %[[B]], %[[C_TIMES_C2]]] : !gml_st.tile<8192x4096x2048> to !gml_st.tile // CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<8192x4096x2048xf32>[!gml_st.tile] // CHECK: return %[[RES]] %space = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048> @@ -80,8 +80,8 @@ func.func @tile_chain_w_zeroes_and_ones(%arg : tensor<8192x4096x2048xf32>, func.func @tile_of_tile_arith_shortcuts_add(%arg : tensor<32x32x32xf32>, %i : index, %j : index) -> tensor<8x8x8xf32> { // CHECK-DAG: %[[SPACE:.*]] = gml_st.space - // CHECK-DAG: %[[IJ:.*]] = arith.addi %[[I]], %[[J]] - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[J]], %[[I]], %[[IJ]]] [8, 8, 8] [1, 1, 1] + // CHECK-DAG: %[[I_PLUS_J:.*]] = arith.addi %[[I]], %[[J]] + // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[J]], %[[I]], %[[I_PLUS_J]]] [8, 8, 8] [1, 1, 1] // CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] // CHECK: return %[[RES]] %space = gml_st.space [32, 32, 32] : !gml_st.tile<32x32x32> @@ -116,3 +116,132 @@ func.func @tile_of_tile_arith_shortcuts_mul(%arg : tensor<32x32x32x32x32xf32>, : tensor<32x32x32x32x32xf32>[!gml_st.tile<8x8x8x8x8>] func.return %result : tensor<8x8x8x8x8xf32> } + +// ----- + +// CHECK-LABEL: @point_of_tile +// CHECK-SAME: %[[ARG:.*]]: tensor, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index, %[[M:.*]]: index, %[[A:.*]]: index +func.func @point_of_tile(%arg : tensor, %i : index, %j : index, + %k : index, %m : index, %a : index) -> f32 { + // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [1024, %[[M]]] : !gml_st.tile<1024x?> + // CHECK-DAG: %[[AK:.*]] = arith.muli %[[A]], %[[K]] + // CHECK-DAG: %[[J_PLUS_AK:.*]] = arith.addi %[[J]], %[[AK]] + // CHECK-DAG: %[[POINT:.*]] = gml_st.point %[[SPACE]] [%[[I]], %[[J_PLUS_AK]]] : !gml_st.tile<1024x?> to !gml_st.point + // CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor[!gml_st.point] + // CHECK: return %[[RES]] + %space = gml_st.space [1024, %m] : !gml_st.tile<1024x?> + %tile = gml_st.tile %space [%i, %j] [4, 128] [2, %a] + : !gml_st.tile<1024x?> to !gml_st.tile<4x128> + %point_of_tile = gml_st.point %tile [0, %k] + : !gml_st.tile<4x128> to !gml_st.point + %result = gml_st.materialize %arg[%point_of_tile] + : tensor[!gml_st.point] + func.return %result : f32 +} + +// ----- + +// CHECK-LABEL: @point_of_tile_of_tile_all_constant +// CHECK-SAME: %[[ARG:.*]]: tensor<4096x2048xf32> +func.func @point_of_tile_of_tile_all_constant(%arg : tensor<4096x2048xf32>) + -> f32 { + // CHECK: %[[SPACE:.*]] = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> + // CHECK: %[[POINT:.*]] = gml_st.point %[[SPACE]] [18, 64] : !gml_st.tile<4096x2048> to !gml_st.point + // CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<4096x2048xf32>[!gml_st.point] + // CHECK: return %[[RES]] + %s = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> + %t = gml_st.tile %s [0, 32] [2048, 256] [1, 2] + : !gml_st.tile<4096x2048> to !gml_st.tile<2048x256> + %tt = gml_st.tile %t [2, 16] [256, 128] [4, 0] + : !gml_st.tile<2048x256> to !gml_st.tile<256x128> + %ptt = gml_st.point %tt [4, 8] : !gml_st.tile<256x128> to !gml_st.point + %res = gml_st.materialize %arg[%ptt] + : tensor<4096x2048xf32>[!gml_st.point] + func.return %res : f32 +} + +// ----- + +// CHECK-LABEL: @point_chain_w_zeroes_and_ones +// CHECK-SAME: %[[ARG:.*]]: tensor<8192x4096x2048xf32>, %[[I:.*]]: index, %[[J:.*]]: index, %[[K:.*]]: index +func.func @point_chain_w_zeroes_and_ones(%arg : tensor<8192x4096x2048xf32>, + %i : index, %j : index, %k : index) -> f32 { + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 + // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048> + // CHECK-DAG: %[[TWO_K:.*]] = arith.muli %[[K]], %[[C2]] + // CHECK-DAG: %[[SIXTEEN_PLUS_J:.*]] = arith.addi %[[J]], %[[C16]] + // CHECK-DAG: %[[TWO_K_PLUS_32:.*]] = arith.addi %[[TWO_K]], %[[C32]] + // CHECK-DAG: %[[POINT:.*]] = gml_st.point %[[SPACE]] [0, %[[SIXTEEN_PLUS_J]], %[[TWO_K_PLUS_32]]] : !gml_st.tile<8192x4096x2048> to !gml_st.point + // CHECK-DAG: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<8192x4096x2048xf32>[!gml_st.point] + // CHECK: return %[[RES]] + %space = gml_st.space [8192, 4096, 2048] : !gml_st.tile<8192x4096x2048> + %tile = gml_st.tile %space [0, 16, 32] [2048, 1024, 512] [0, 1, 2] + : !gml_st.tile<8192x4096x2048> to !gml_st.tile<2048x1024x512> + %point_of_tile = gml_st.point %tile [%i, %j, %k] + : !gml_st.tile<2048x1024x512> to !gml_st.point + %result = gml_st.materialize %arg[%point_of_tile] + : tensor<8192x4096x2048xf32>[!gml_st.point] + func.return %result : f32 +} + +// ----- + +// CHECK-LABEL: @point_of_transpose_tile_of_tile_all_constant +// CHECK-SAME: %[[ARG:.*]]: tensor<2048x4096xf32> +func.func @point_of_transpose_tile_of_tile_all_constant(%arg : tensor<2048x4096xf32>) + -> f32 { + // CHECK: %[[SPACE:.*]] = gml_st.space [2048, 4096] : !gml_st.tile<2048x4096> + // CHECK: %[[POINT:.*]] = gml_st.point %[[SPACE]] [40, 8] : !gml_st.tile<2048x4096> to !gml_st.point + // CHECK: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<2048x4096xf32>[!gml_st.point] + // CHECK: return %[[RES]] + %s = gml_st.space [4096, 2048] : !gml_st.tile<4096x2048> + %t = gml_st.tile %s [0, 32] [128, 256] [1, 2] + : !gml_st.tile<4096x2048> to !gml_st.tile<128x256> + %tt = gml_st.transpose_tile %t, [1, 0] + : !gml_st.tile<128x256> to !gml_st.tile<256x128> + %ptt = gml_st.point %tt [4, 8] : !gml_st.tile<256x128> to !gml_st.point + %res = gml_st.materialize %arg[%ptt] + : tensor<2048x4096xf32>[!gml_st.point] + func.return %res : f32 +} + +// ----- + +// CHECK-LABEL: @transpose_tile_of_transpose_tile_of_tile +// CHECK-SAME: %[[ARG:.*]]: tensor<10x?x5xf32>, %[[SIZE:.*]]: index +func.func @transpose_tile_of_transpose_tile_of_tile( + %arg : tensor<10x?x5xf32>, %size: index) -> tensor<4x?x5xf32> { +// CHECK: %[[SPACE:.*]] = gml_st.space [10, %[[SIZE]], 5] : !gml_st.tile<10x?x5> +// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [3, 0, 0] [4, %[[SIZE]], 5] [2, %[[SIZE]], 1] : !gml_st.tile<10x?x5> to !gml_st.tile<4x?x5> +// CHECK: %[[RES:.*]] = gml_st.materialize %arg0[%[[TILE]]] : tensor<10x?x5xf32>[!gml_st.tile<4x?x5>] +// CHECK: return %[[RES]] : tensor<4x?x5xf32> + %s = gml_st.space [%size, 5, 10] : !gml_st.tile + %t = gml_st.tile %s [0, 0, 3] [%size, 5, 4] [%size, 1, 2] + : !gml_st.tile to !gml_st.tile + %tt = gml_st.transpose_tile %t, [1, 0, 2] + : !gml_st.tile to !gml_st.tile<5x?x4> + %tt2 = gml_st.transpose_tile %tt, [2, 1, 0] + : !gml_st.tile<5x?x4> to !gml_st.tile<4x?x5> + %res = gml_st.materialize %arg[%tt2] + : tensor<10x?x5xf32>[!gml_st.tile<4x?x5>] + func.return %res : tensor<4x?x5xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_tile_of_space +// CHECK-SAME: %[[ARG:.*]]: tensor<5x10x?xf32>, %[[SIZE:.*]]: index +func.func @transpose_tile_of_space( + %arg : tensor<5x10x?xf32>, %size: index) -> tensor<5x10x?xf32> { +// CHECK: %[[SPACE:.*]] = gml_st.space [5, 10, %[[SIZE]]] : !gml_st.tile<5x10x?> +// CHECK: %[[RES:.*]] = gml_st.materialize %arg0[%[[SPACE]]] : tensor<5x10x?xf32>[!gml_st.tile<5x10x?>] +// CHECK: return %[[RES]] : tensor<5x10x?xf32> + %s = gml_st.space [%size, 5, 10] : !gml_st.tile + %tt = gml_st.transpose_tile %s, [1, 2, 0] + : !gml_st.tile to !gml_st.tile<5x10x?> + %res = gml_st.materialize %arg[%tt] + : tensor<5x10x?xf32>[!gml_st.tile<5x10x?>] + func.return %res : tensor<5x10x?xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir index 8b8b1f928135a0..4556cb9f3a2f4f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/fusion.mlir @@ -48,7 +48,10 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, // Check tiled broadcast. // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[RES_TILE]]] : tensor[!gml_st.tile<3x4x5>] // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[ARG_TILE]]] : tensor[!gml_st.tile] - // CHECK-DAG: %[[RES:.*]] = gml_st.dynamic_broadcast_in_dim %[[INIT_SUB]], %[[ARG_SUB]], [0, 2] : tensor<3x4x5xf32>, tensor -> tensor<3x4x5xf32> + // CHECK-NEXT: %[[RES:.*]] = gml_st.dynamic_broadcast_in_dim + // CHECK-SAME ins(%[[ARG_SUB]] : tensor) + // CHECK-SAME outs(%[[INIT_SUB]] : tensor<3x4x5xf32>) + // CHECK-SAME {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} // CHECK: return %[[RES]] : tensor<3x4x5xf32> %c0 = arith.constant 0 : index @@ -60,8 +63,8 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, %d1 = tensor.extract %shape[%c1] : tensor<3xindex> %d2 = tensor.extract %shape[%c2] : tensor<3xindex> %dst = linalg.init_tensor [%d0, %d1, %d2] : tensor - %bcast = gml_st.dynamic_broadcast_in_dim %dst, %arg, [0, 2] - : tensor, tensor -> tensor + %bcast = gml_st.dynamic_broadcast_in_dim ins(%arg: tensor) + outs(%dst: tensor) { broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> } // Materialze a tile. %space = gml_st.space [%d0, %d1, %d2] : !gml_st.tile @@ -75,122 +78,268 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, // ----- -// CHECK-LABEL: @add -// CHECK-SAME: %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile +// CHECK: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +#id_map = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: @add +// CHECK-SAME: %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile) func.func @add(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>, %tile: !gml_st.tile) -> tensor { - // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] : tensor<32x32xf32>[!gml_st.tile] - // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] : tensor<32x32xf32>[!gml_st.tile] - // CHECK-DAG: %[[RES:.*]] = mhlo.add %[[LHS_SUB]], %[[RHS_SUB]] : tensor - // CHECK: return %[[RES]] - %0 = mhlo.add %lhs, %rhs : tensor<32x32xf32> - %1 = gml_st.materialize %0[%tile] : tensor<32x32xf32>[!gml_st.tile] - func.return %1 : tensor + // CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [32, 32] + // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] + // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] + // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[RES:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]], + // CHECK-SAME: iterator_types = ["parallel", "parallel"] + // CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : tensor, tensor) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor) + // CHECK: ^bb0(%[[LHS_SCALAR:.*]]: f32, %[[RHS_SCALAR:.*]]: f32, %{{.*}}: f32): + // CHECK-DAG: %[[RES_SCALAR:.*]] = arith.addf %[[LHS_SCALAR]], %[[RHS_SCALAR]] + // CHECK: linalg.yield %[[RES_SCALAR]] + // CHECK: return %[[RES]] + %init = linalg.init_tensor [32, 32] : tensor<32x32xf32> + %linalg = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) + outs(%init : tensor<32x32xf32>) { + ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): + %add = arith.addf %lhs_scalar, %rhs_scalar : f32 + linalg.yield %add : f32 + } -> tensor<32x32xf32> + %result = gml_st.materialize %linalg[%tile] : tensor<32x32xf32>[!gml_st.tile] + return %result : tensor +} + +// ----- + +#id_map = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @add_two_users +// CHECK-SAME: %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile, %[[D0:.*]]: index, %[[D1:.*]]: index +func.func @add_two_users(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>, + %tile: !gml_st.tile, %d0: index, %d1: index) -> tensor { + // CHECK: %[[INIT:.*]] = linalg.init_tensor [32, 32] + // CHECK: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] + // CHECK: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] + // CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[GENERIC0:.*]] = linalg.generic + // CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : tensor, tensor) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor) + // CHECK: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] + // CHECK: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] + // CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[GENERIC1:.*]] = linalg.generic + // CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : tensor, tensor) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor) + // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]]] + // CHECK: %[[RES:.*]] = linalg.generic + // CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : tensor, tensor) + // CHECK-SAME: outs(%[[INIT]] : tensor) + // CHECK: return %[[RES]] + %init0 = linalg.init_tensor [32, 32] : tensor<32x32xf32> + %linalg0 = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) + outs(%init0 : tensor<32x32xf32>) { + ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): + %add = arith.addf %lhs_scalar, %rhs_scalar : f32 + linalg.yield %add : f32 + } -> tensor<32x32xf32> + %user0 = gml_st.materialize %linalg0[%tile] : tensor<32x32xf32>[!gml_st.tile] + %user1 = gml_st.materialize %linalg0[%tile] : tensor<32x32xf32>[!gml_st.tile] + %init1 = linalg.init_tensor [%d0, %d1] : tensor + %linalg1 = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%user0, %user1 : tensor, tensor) + outs(%init1 : tensor) { + ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): + %add = arith.addf %lhs_scalar, %rhs_scalar : f32 + linalg.yield %add : f32 + } -> tensor + func.return %linalg1 : tensor } // ----- -// CHECK-LABEL: @cos -// CHECK-SAME: %[[ARG:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile +// CHECK: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +#id_map = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: @cos +// CHECK-SAME: %[[ARG:.*]]: tensor<32x32xf32>, %[[TILE:.*]]: !gml_st.tile func.func @cos(%arg: tensor<32x32xf32>, %tile: !gml_st.tile) -> tensor { - // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] : tensor<32x32xf32>[!gml_st.tile] - // CHECK-DAG: %[[RES:.*]] = mhlo.cosine %[[ARG_SUB]] : tensor - // CHECK: return %[[RES]] - %0 = mhlo.cosine %arg : tensor<32x32xf32> - %1 = gml_st.materialize %0[%tile] : tensor<32x32xf32>[!gml_st.tile] - return %1 : tensor + // CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [32, 32] + // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] + // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[RES:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[ID_MAP]], #[[ID_MAP]]], + // CHECK-SAME: iterator_types = ["parallel", "parallel"] + // CHECK-SAME: ins(%[[ARG_SUB]] : tensor) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor) + // CHECK: ^bb0(%[[ARG_SCALAR:.*]]: f32, %{{.*}}: f32): + // CHECK-DAG: %[[RES_SCALAR:.*]] = math.cos %[[ARG_SCALAR]] + // CHECK: linalg.yield %[[RES_SCALAR]] + // CHECK: return %[[RES]] + %init = linalg.init_tensor [32, 32] : tensor<32x32xf32> + %linalg = linalg.generic { + indexing_maps = [#id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%arg : tensor<32x32xf32>) + outs(%init : tensor<32x32xf32>) { + ^bb0(%arg_scalar: f32, %_: f32): + %cos = math.cos %arg_scalar : f32 + linalg.yield %cos : f32 + } -> tensor<32x32xf32> + %result = gml_st.materialize %linalg[%tile] : tensor<32x32xf32>[!gml_st.tile] + return %result : tensor } // ----- +#id_map = affine_map<(d0, d1) -> (d0, d1)> + // CHECK-LABEL: @add_point // CHECK-SAME: %[[LHS:.*]]: tensor<32x32xf32>, %[[RHS:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point func.func @add_point(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>, %point: !gml_st.point) -> f32 { - // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[POINT]]] : tensor<32x32xf32>[!gml_st.point] - // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[POINT]]] : tensor<32x32xf32>[!gml_st.point] + // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[POINT]]] + // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[POINT]]] // CHECK-DAG: %[[RES:.*]] = arith.addf %[[LHS_SUB]], %[[RHS_SUB]] // CHECK: return %[[RES]] - %0 = mhlo.add %lhs, %rhs : tensor<32x32xf32> - %1 = gml_st.materialize %0[%point] : tensor<32x32xf32>[!gml_st.point] - func.return %1 : f32 + %init = linalg.init_tensor [32, 32] : tensor<32x32xf32> + %linalg = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) + outs(%init : tensor<32x32xf32>) { + ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): + %add = arith.addf %lhs_scalar, %rhs_scalar : f32 + linalg.yield %add : f32 + } -> tensor<32x32xf32> + %result = gml_st.materialize %linalg[%point] : tensor<32x32xf32>[!gml_st.point] + return %result : f32 } // ----- +#id_map = affine_map<(d0, d1) -> (d0, d1)> + // CHECK-LABEL: @cos_point // CHECK-SAME: %[[ARG:.*]]: tensor<32x32xf32>, %[[POINT:.*]]: !gml_st.point func.func @cos_point(%arg: tensor<32x32xf32>, %point: !gml_st.point) -> f32 { - // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] : tensor<32x32xf32>[!gml_st.point] + // CHECK-DAG: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[POINT]]] // CHECK-DAG: %[[RES:.*]] = math.cos %[[ARG_SUB]] // CHECK: return %[[RES]] - %0 = mhlo.cosine %arg : tensor<32x32xf32> - %1 = gml_st.materialize %0[%point] : tensor<32x32xf32>[!gml_st.point] - return %1 : f32 + %init = linalg.init_tensor [32, 32] : tensor<32x32xf32> + %linalg = linalg.generic { + indexing_maps = [#id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%arg : tensor<32x32xf32>) + outs(%init : tensor<32x32xf32>) { + ^bb0(%arg_scalar: f32, %_: f32): + %cos = math.cos %arg_scalar : f32 + linalg.yield %cos : f32 + } -> tensor<32x32xf32> + %result = gml_st.materialize %linalg[%point] : tensor<32x32xf32>[!gml_st.point] + return %result : f32 } // ----- -#cwise_trait = { - indexing_maps = [ - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)> - ], - iterator_types = ["parallel"] -} +// CHECK: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)> +#id_map = affine_map<(d0) -> (d0)> -// CHECK-LABEL: @fuse_into_ploop -// CHECK-SAME: %[[LHS:.*]]: tensor<8xf32>, %[[RHS:.*]]: tensor<8xf32>, %[[OUT:.*]]: tensor<8xf32> -func.func @fuse_into_ploop(%lhs : tensor<8xf32>, %rhs : tensor<8xf32>, - %out: tensor<8xf32>) -> tensor<8xf32> { - // CHECK-DAG: %[[C8:.*]] = arith.constant 8 - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8] : !gml_st.tile<8> - // CHECK: %[[RES:.*]] = gml_st.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]]) { - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[I]]] [4] [1] : !gml_st.tile<8> to !gml_st.tile<4> - // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] : tensor<8xf32>[!gml_st.tile<4>] - // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] : tensor<8xf32>[!gml_st.tile<4>] - // CHECK-DAG: %[[OUT_SUB:.*]] = gml_st.materialize %[[OUT]][%[[TILE]]] : tensor<8xf32>[!gml_st.tile<4>] - // CHECK-DAG: %[[TANH_SUB:.*]] = mhlo.tanh %[[LHS_SUB]] - // CHECK-DAG: %[[COS_SUB:.*]] = mhlo.cosine %[[RHS_SUB]] - // CHECK: %[[RES_SUB:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[TANH_SUB]], %[[COS_SUB]] : tensor<4xf32>, tensor<4xf32>) outs(%[[OUT_SUB]] : tensor<4xf32>) - // CHECK: ^bb0(%[[TANH_SCALAR:.*]]: f32, %[[COS_SCALAR:.*]]: f32, %{{.*}}: f32): - // CHECK-DAG: %[[RES_SCALAR:.*]] = arith.addf %[[TANH_SCALAR]], %[[COS_SCALAR]] : f32 - // CHECK: linalg.yield %[[RES_SCALAR]] - // CHECK: gml_st.set_yield %[[RES_SUB]] into %[[OUT]][%[[TILE]] - // CHECK-SAME: : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>] - // CHECK: return %[[RES]] - - %tanh = mhlo.tanh %lhs : tensor<8xf32> - %cos = mhlo.cosine %rhs : tensor<8xf32> - - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index +// CHECK: @fuse_into_ploop +// CHECK-SAME: %[[LHS:.*]]: tensor<8xf32>, %[[RHS:.*]]: tensor<8xf32> +func.func @fuse_into_ploop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>) + -> tensor<8xf32> { + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [8] + // CHECK-DAG: %[[SPACE:.*]] = gml_st.space [8] + // CHECK: %[[RESULT:.*]] = gml_st.parallel (%[[IV:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]]) + // CHECK-DAG: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[IV]]] [4] [1] + // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] + // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[TANH_SUB:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[ID_MAP]], #[[ID_MAP]]] + // CHECK-SAME: iterator_types = ["parallel"] + // CHECK-SAME: ins(%[[LHS_SUB]] : tensor<4xf32>) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor<4xf32>) + // CHECK: ^bb0(%[[LHS_SCALAR:.*]]: f32, %{{.*}}: f32): + // CHECK-DAG: %[[TANH_SCALAR:.*]] = math.tanh %[[LHS_SCALAR]] + // CHECK: linalg.yield %[[TANH_SCALAR]] + // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] + // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[COS_SUB:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[ID_MAP]], #[[ID_MAP]]] + // CHECK-SAME: iterator_types = ["parallel"] + // CHECK-SAME: ins(%[[RHS_SUB]] : tensor<4xf32>) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor<4xf32>) + // CHECK: ^bb0(%[[RHS_SCALAR:.*]]: f32, %{{.*}}: f32): + // CHECK-DAG: %[[COS_SCALAR:.*]] = math.cos %[[RHS_SCALAR]] + // CHECK: linalg.yield %[[COS_SCALAR]] + // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] + // CHECK: %[[RESULT_SUB:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]] + // CHECK-SAME: iterator_types = ["parallel"] + // CHECK-SAME: ins(%[[TANH_SUB]], %[[COS_SUB]] : tensor<4xf32>, tensor<4xf32>) + // CHECK-SAME: outs(%[[INIT_SUB]] : tensor<4xf32>) + // CHECK: ^bb0(%[[TANH_SCALAR:.*]]: f32, %[[COS_SCALAR:.*]]: f32, %{{.*}}: f32): + // CHECK-DAG: %[[RESULT_SCALAR:.*]] = arith.addf %[[TANH_SCALAR]], %[[COS_SCALAR]] + // CHECK: linalg.yield %[[RESULT_SCALAR]] + // CHECK: gml_st.set_yield %[[RESULT_SUB]] into %[[INIT]][%[[TILE]]] + // CHECK: return %[[RESULT]] %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %init = linalg.init_tensor [8] : tensor<8xf32> + %tanh = linalg.generic { + indexing_maps = [#id_map, #id_map], + iterator_types = ["parallel"]} + ins(%lhs : tensor<8xf32>) + outs(%init : tensor<8xf32>) { + ^bb0(%lhs_scalar: f32, %_: f32): + %tanh_scalar = math.tanh %lhs_scalar : f32 + linalg.yield %tanh_scalar : f32 + } -> tensor<8xf32> + %cos = linalg.generic { + indexing_maps = [#id_map, #id_map], + iterator_types = ["parallel"]} + ins(%rhs : tensor<8xf32>) + outs(%init : tensor<8xf32>) { + ^bb0(%rhs_scalar: f32, %_: f32): + %cos_scalar = math.cos %rhs_scalar : f32 + linalg.yield %cos_scalar : f32 + } -> tensor<8xf32> %space = gml_st.space [8] : !gml_st.tile<8> - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile %space [%i] [4] [1] : !gml_st.tile<8> to !gml_st.tile<4> + %result = gml_st.parallel (%iv) = (%c0) to (%c8) step (%c4) { + %tile = gml_st.tile %space [%iv] [4] [1] + : !gml_st.tile<8> to !gml_st.tile<4> %tanh_sub = gml_st.materialize %tanh[%tile] : tensor<8xf32>[!gml_st.tile<4>] %cos_sub = gml_st.materialize %cos[%tile] : tensor<8xf32>[!gml_st.tile<4>] - %out_sub = gml_st.materialize %out[%tile] + %init_sub = gml_st.materialize %init[%tile] : tensor<8xf32>[!gml_st.tile<4>] - - %result_sub = linalg.generic #cwise_trait + %result_sub = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel"]} ins(%tanh_sub, %cos_sub : tensor<4xf32>, tensor<4xf32>) - outs(%out_sub : tensor<4xf32>) { - ^bb(%l: f32, %r: f32, %o: f32) : - %s = arith.addf %l, %r : f32 - linalg.yield %s : f32 + outs(%init_sub : tensor<4xf32>) { + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + %tanh0 = arith.addf %arg4, %arg5 : f32 + linalg.yield %tanh0 : f32 } -> tensor<4xf32> - - gml_st.set_yield %result_sub into %out[%tile] : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>] + gml_st.set_yield %result_sub into %init[%tile] + : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>] } : tensor<8xf32> - func.return %sum : tensor<8xf32> + return %result : tensor<8xf32> } // ----- @@ -236,3 +385,33 @@ func.func @fuse_cwise_linalg_generic(%lhs: tensor, %4 = gml_st.materialize %3[%tile] : tensor[!gml_st.tile] return %4 : tensor } + +// ----- + +#id_map = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: @fuse_cwise_linalg_generic_at_point +// CHECK-SAME: %[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %[[POINT:.*]]: !gml_st.point +func.func @fuse_cwise_linalg_generic_at_point(%lhs: tensor, + %rhs: tensor, %point: !gml_st.point) -> f32 { + // CHECK-DAG: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[POINT]]] + // CHECK-DAG: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[POINT]]] + // CHECK-DAG: %[[RES:.*]] = arith.addf %[[LHS_SUB]], %[[RHS_SUB]] + // CHECK: return %[[RES]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tensor.dim %lhs, %c0 : tensor + %1 = tensor.dim %lhs, %c1 : tensor + %2 = linalg.init_tensor [%0, %1] : tensor + %3 = linalg.generic { + indexing_maps = [#id_map, #id_map, #id_map], + iterator_types = ["parallel", "parallel"]} + ins(%lhs, %rhs : tensor, tensor) + outs(%2 : tensor) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %5 = arith.addf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor + %4 = gml_st.materialize %3[%point] : tensor[!gml_st.point] + return %4 : f32 +} diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/invalid.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/invalid.mlir index b0a08c1f11aeb7..008c1d82a2309f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/invalid.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/invalid.mlir @@ -265,6 +265,33 @@ func.func @tile_op_offset_out_of_bounds_considering_size_and_stride(%i: index) { // ----- +func.func @transpose_tile_op_permutation_out_of_bounds() { + %0 = gml_st.space [64, 32] : !gml_st.tile<64x32> + // expected-error@+1 {{'gml_st.transpose_tile' op permutation[1] = 2 is outside of range [0, 1]}} + %1 = gml_st.transpose_tile %0, [0, 2] : !gml_st.tile<64x32> to !gml_st.tile<64x32> + func.return +} + +// ----- + +func.func @transpose_tile_op_permutation_wrong_size() { + %0 = gml_st.space [64, 32] : !gml_st.tile<64x32> + // expected-error@+1 {{'gml_st.transpose_tile' op expected permutation attribute size = 1 to match rank = 2}} + %1 = gml_st.transpose_tile %0, [0] : !gml_st.tile<64x32> to !gml_st.tile<64x32> + func.return +} + +// ----- + +func.func @transpose_tile_op_permutation_duplicate_value() { + %0 = gml_st.space [64, 32] : !gml_st.tile<64x32> + // expected-error@+1 {{'gml_st.transpose_tile' op expected permutation attribute to contain no duplicate values, but got 0 at positions 0 and 1}} + %1 = gml_st.transpose_tile %0, [0, 0] : !gml_st.tile<64x32> to !gml_st.tile<64x32> + func.return +} + +// ----- + func.func @for_loop_wrong_yield_target( %arg: tensor<8xf32>, %output: tensor) -> tensor { %c0 = arith.constant 0 : index diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/legalize_mhlo_to_gml.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/legalize_mhlo_to_gml.mlir index 91a66988f0c3d9..830a0194f57490 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/legalize_mhlo_to_gml.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/legalize_mhlo_to_gml.mlir @@ -1,7 +1,7 @@ // RUN: mlir-hlo-opt %s --legalize-mhlo-to-gml | FileCheck %s -// CHECK: @dynamic_broadcast_in_dim -// CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<3xindex> +// CHECK-LABEL: @dynamic_broadcast_in_dim +// CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<3xindex> func.func @dynamic_broadcast_in_dim(%arg : tensor, %shape : tensor<3xindex>) -> tensor { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 @@ -10,10 +10,53 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, %shape : tensor<3xin // CHECK-DAG: %[[SHAPE_D1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] // CHECK-DAG: %[[SHAPE_D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] // CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [%[[SHAPE_D0]], %[[SHAPE_D1]], %[[SHAPE_D2]]] : tensor - // CHECK-DAG: %[[BCAST:.*]] = gml_st.dynamic_broadcast_in_dim %[[INIT]], %[[ARG]], [0, 2] : tensor, tensor -> tensor + // CHECK-NEXT: %[[BCAST:.*]] = gml_st.dynamic_broadcast_in_dim + // CHECK-SAME: ins(%[[ARG]] : tensor) + // CHECK-SAME: outs(%[[INIT]] : tensor) + // CHECK-SAME: {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} // CHECK: return %[[BCAST]] - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) - { broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> } + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) + { broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> } : (tensor, tensor<3xindex>) -> tensor func.return %0 : tensor -} +} + +func.func @simple_gather(%operand : tensor<3x3xf32>, + %indices: tensor<3x2xi64>) -> tensor<3xf32> { + %0 = "mhlo.gather"(%operand, %indices) { + dimension_numbers = #mhlo.gather< + collapsed_slice_dims = [0, 1], + index_vector_dim = 1, + offset_dims = [], + start_index_map = [0, 1] + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 1]> : tensor<2xi64> + } : (tensor<3x3xf32>, tensor<3x2xi64>) -> tensor<3xf32> + func.return %0 : tensor<3xf32> +} + +// CHECK-LABEL: @simple_gather +// CHECK: %[[INIT:.*]] = linalg.init_tensor [3] : tensor<3xf32> +// CHECK: %[[GATHER:.*]] = gml_st.gather +// CHECK-SAME: ins(%arg0 : tensor<3x3xf32>, %arg1 : tensor<3x2xi64>) +// CHECK-SAME: outs(%[[INIT]] : tensor<3xf32>) +// CHECK: return %[[GATHER]] + +func.func @unsupported_gather(%operand : tensor<3x3xf32>, + %indices: tensor<3x2xi64>) -> tensor<3xf32> { + %0 = "mhlo.gather"(%operand, %indices) { + dimension_numbers = #mhlo.gather< + collapsed_slice_dims = [0, 1], + index_vector_dim = 1, + offset_dims = [], + start_index_map = [1, 0] + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 1]> : tensor<2xi64> + } : (tensor<3x3xf32>, tensor<3x2xi64>) -> tensor<3xf32> + func.return %0 : tensor<3xf32> +} + +// CHECK-LABEL: @unsupported_gather +// CHECK: mhlo.gather \ No newline at end of file diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/ops.mlir index 58319e91f6c9e8..3261422360c269 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/ops.mlir @@ -13,6 +13,8 @@ func.func @types() { %2 = gml_st.point %0 [42] : !gml_st.tile<64> to !gml_st.point // CHECK: %{{.*}} = gml_st.tile %[[ARG2]] [0, 0] [42, 16] [1, 1] : !gml_st.tile<64x32> to !gml_st.tile<42x16> %3 = gml_st.tile %1 [0, 0] [42, 16] [1, 1] : !gml_st.tile<64x32> to !gml_st.tile<42x16> + // CHECK: %{{.*}} = gml_st.transpose_tile %[[ARG2]], [1, 0] : !gml_st.tile<64x32> to !gml_st.tile<32x64> + %4 = gml_st.transpose_tile %1, [1, 0] : !gml_st.tile<64x32> to !gml_st.tile<32x64> func.return } @@ -309,3 +311,41 @@ func.func @for_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, func.return %sum : tensor<8xf32> } // CHECK-LABEL: func @for_loop + +func.func @dynamic_broadcast_in_dim(%arg: tensor, + %dst: tensor) { + %bcast = gml_st.dynamic_broadcast_in_dim + ins(%arg: tensor) + outs(%dst: tensor) { + broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> + } + func.return +} +// CHECK-LABEL: func @dynamic_broadcast_in_dim + +// ----- + +func.func @gather(%arg: tensor<100xf32>, + %indices: tensor<42x1xi64>, + %dst: tensor<42xf32>) -> tensor<42xf32> { + %gather = gml_st.gather + ins(%arg: tensor<100xf32>, %indices: tensor<42x1xi64>) + outs(%dst: tensor<42xf32>) + func.return %gather : tensor<42xf32> +} +// CHECK-LABEL: func @gather + +// ----- + +func.func @scatter(%arg: tensor<3x3xf32>, + %indices: tensor<2x2xi64>, + %updates: tensor<3xf32>, + %dst: tensor<3x3xf32>) -> tensor<3x3xf32> { + %scatter = gml_st.scatter + ins(%arg: tensor<3x3xf32>, + %indices: tensor<2x2xi64>, + %updates: tensor<3xf32>) + outs(%dst: tensor<3x3xf32>) + func.return %scatter : tensor<3x3xf32> +} +// CHECK-LABEL: func @scatter diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/vectorize.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/vectorize.mlir new file mode 100644 index 00000000000000..c798d4cc18d349 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/vectorize.mlir @@ -0,0 +1,74 @@ +// Test vectorization of gml_st.parallel and gml_st.for loops. +// RUN: mlir-hlo-opt %s --vectorize-gml-st-loops | \ +// RUN: FileCheck %s + +#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @parallel_with_tiles( +func.func @parallel_with_tiles( + %arg0: memref, %arg1: memref, %arg2: memref) + -> memref { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref + gml_st.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { + %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + linalg.generic {indexing_maps = [#map1, #map1, #map1], + iterator_types = ["parallel", "parallel"]} + ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) + outs(%6 : memref<4x1xf32, #map0>) { + ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): + %9 = arith.addf %arg5, %arg6 : f32 + linalg.yield %9 : f32 + } + gml_st.set_yield + } + func.return %arg2 : memref +} +// CHECK-NOT: linalg.generic +// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> +// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] + +// CHECK-LABEL: @for_with_tiles( +func.func @for_with_tiles( + %arg0: memref, %arg1: memref, %arg2: memref) + -> memref { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = memref.dim %arg0, %c0 : memref + %1 = memref.dim %arg0, %c1 : memref + gml_st.for (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { + %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] + : memref to memref<4x1xf32, #map0> + linalg.generic {indexing_maps = [#map1, #map1, #map1], + iterator_types = ["parallel", "parallel"]} + ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) + outs(%6 : memref<4x1xf32, #map0>) { + ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): + %9 = arith.addf %arg5, %arg6 : f32 + linalg.yield %9 : f32 + } + gml_st.set_yield + } + func.return %arg2 : memref +} +// CHECK-NOT: linalg.generic +// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> +// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir index a20b2b52713845..3c9e7d1713f182 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/lhlo_gpu/lhlo_gpu_ops.mlir @@ -197,6 +197,7 @@ func.func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5 >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 20, rhs_stride = 20, @@ -205,28 +206,6 @@ func.func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5 func.return } - -// CHECK-LABEL: func @gemm_bias -func.func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, - %bias: memref<5x5xf32>, %output:memref<5x5xf32>) { - "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [1,1], - rhs_batching_dimensions = [1,1], - lhs_contracting_dimensions = [1,1], - rhs_contracting_dimensions = [1,1] - >, - alpha_real = 0.5, - alpha_imag = 0.0, - beta = 1.0, - batch_size = 1, - lhs_stride = 20, - rhs_stride = 20, - algorithm = 0 - } : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> () - func.return -} - // CHECK-LABEL: func @cholesky func.func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { %scratch = memref.alloc() : memref<32xi8> diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir index e4b6db09db3d22..3ce46faa043256 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir @@ -370,21 +370,29 @@ func.func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } +// CHECK-LABEL: dynamic_update_slice_fold_length_0 +func.func @dynamic_update_slice_fold_length_0(%arg0: tensor<3x4xi64>, %arg1: tensor<3x0xi64>) -> tensor<3x4xi64> { + // CHECK: return %arg0 + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x0xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %1 : tensor<3x4xi64> +} + // CHECK-LABEL: dynamic_update_slice_identity_update func.func @dynamic_update_slice_identity_update(%arg0: tensor<3x4xi64>, %arg1: tensor<3x4xi64>) -> tensor<3x4xi64> { // CHECK: return %arg1 %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.dynamic-update-slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x4xi64>, tensor, tensor) -> tensor<3x4xi64> + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x4xi64>, tensor, tensor) -> tensor<3x4xi64> func.return %1 : tensor<3x4xi64> } // CHECK-LABEL: dynamic_update_slice_fold_fail_dynamic_shapes func.func @dynamic_update_slice_fold_fail_dynamic_shapes(%arg0: tensor, %arg1: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.dynamic-update-slice"(%arg0, %arg1, %0, %0) : (tensor, tensor, tensor, tensor) -> tensor + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor, tensor, tensor, tensor) -> tensor func.return %1 : tensor // CHECK: %[[CST:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[VAL:.*]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, %[[CST]], %[[CST]]) : (tensor, tensor, tensor, tensor) -> tensor + // CHECK: %[[VAL:.*]] = "mhlo.dynamic_update_slice"(%arg0, %arg1, %[[CST]], %[[CST]]) : (tensor, tensor, tensor, tensor) -> tensor // CHECK: return %[[VAL]] : tensor } diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir index 3abd1cfac24428..485c429f34be10 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -2990,7 +2990,7 @@ func.func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor, // ----- func.func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi32>, %c0: tensor) -> tensor<3x3xi32> { - %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + %0 = "mhlo.dynamic_update_slice"(%target, %update, %c0, %c0) : (tensor<3x3xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<3x3xi32> func.return %0 : tensor<3x3xi32> } @@ -3017,7 +3017,7 @@ func.func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi3 // ----- func.func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, %c0: tensor) -> tensor<3x3xui32> { - %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + %0 = "mhlo.dynamic_update_slice"(%target, %update, %c0, %c0) : (tensor<3x3xui32>, tensor<2x2xui32>, tensor, tensor) -> tensor<3x3xui32> func.return %0 : tensor<3x3xui32> } @@ -3049,7 +3049,7 @@ func.func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: ten func.func @dynamic_update_slice_float(%target: tensor<3x3xf32>, %update: tensor<2x2xf32>, %c0: tensor) -> tensor<3x3xf32> { - %0 = "mhlo.dynamic-update-slice"(%target, %update, %c0, %c0) + %0 = "mhlo.dynamic_update_slice"(%target, %update, %c0, %c0) : (tensor<3x3xf32>, tensor<2x2xf32>, tensor, tensor) -> tensor<3x3xf32> func.return %0 : tensor<3x3xf32> } @@ -4450,7 +4450,7 @@ func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, func.func @rng_uniform_1d(%min: tensor, %max: tensor) -> tensor<10xf32> { %shape = arith.constant dense<[10]> : tensor<1xi32> - %0 = "mhlo.rng_uniform"(%min, %max, %shape) : (tensor, tensor, tensor<1xi32>) -> tensor<10xf32> + %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi32>) -> tensor<10xf32> func.return %0 : tensor<10xf32> } // CHECK-LABEL: func @rng_uniform_1d @@ -4477,7 +4477,7 @@ func.func @rng_uniform_1d(%min: tensor, %max: tensor) -> tensor<10xf32 func.func @rng_uniform_2d(%min: tensor, %max: tensor) -> tensor<3x3xf32> { %shape = arith.constant dense<[3, 3]> : tensor<2xi32> - %0 = "mhlo.rng_uniform"(%min, %max, %shape) : (tensor, tensor, tensor<2xi32>) -> tensor<3x3xf32> + %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi32>) -> tensor<3x3xf32> func.return %0 : tensor<3x3xf32> } // CHECK-LABEL: func @rng_uniform_2d @@ -4509,7 +4509,7 @@ func.func @rng_uniform_2d(%min: tensor, %max: tensor) -> tensor<3x3xf3 func.func @rng_uniform_3d(%min: tensor, %max: tensor) -> tensor<2x2x2xf32> { %shape = arith.constant dense<[2, 2, 2]> : tensor<3xi32> - %0 = "mhlo.rng_uniform"(%min, %max, %shape) : (tensor, tensor, tensor<3xi32>) -> tensor<2x2x2xf32> + %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi32>) -> tensor<2x2x2xf32> func.return %0 : tensor<2x2x2xf32> } // CHECK-LABEL: func @rng_uniform_3d @@ -4545,7 +4545,7 @@ func.func @rng_uniform_3d(%min: tensor, %max: tensor) -> tensor<2x2x2x func.func @rng_uniform_dynamic_1d(%min: tensor, %max: tensor, %shape: tensor<1xi32>) -> tensor { - %0 = "mhlo.rng_uniform"(%min, %max, %shape) : (tensor, tensor, tensor<1xi32>) -> tensor + %0 = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi32>) -> tensor func.return %0 : tensor } // CHECK-LABEL: func @rng_uniform_dynamic_1d diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir index 0bed861a239f24..079f440858364f 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir @@ -150,21 +150,10 @@ func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex // ----- -// CHECK-LABEL: func @transpose -func.func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xindex> { - %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> - %1 = "mhlo_test.get_return_type_components"(%0) - : (tensor<2x1x4x3xi32>) -> tensor<2x1x4x3xindex> -// CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [2, 1, 4, 3], element_type0 = i32} : (tensor<2x1x4x3xi32>) -> tensor<2x1x4x3xindex> - func.return %1 : tensor<2x1x4x3xindex> -} - -// ----- - // CHECK-LABEL: @rng_normal func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<7xindex> { %0 = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - %1 = "mhlo.rng_normal"(%arg0, %arg1, %0) : (tensor, tensor, tensor<1xi64>) -> tensor<7xf32> + %1 = "mhlo.rng"(%arg0, %arg1, %0) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<1xi64>) -> tensor<7xf32> %2 = "mhlo_test.get_return_type_components"(%1) : (tensor<7xf32>) -> tensor<7xindex> // CHECK: %2 = "mhlo_test.return_type_components"(%1) {dims0 = [7], element_type0 = f32} : (tensor<7xf32>) -> tensor<7xindex> @@ -176,7 +165,7 @@ func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<7xindex> // CHECK-LABEL: func @rng_uniform func.func @rng_uniform(%a: tensor, %b: tensor) -> tensor<2x3x5xindex> { %0 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %1 = "mhlo.rng_uniform"(%a, %b, %0) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %1 = "mhlo.rng"(%a, %b, %0) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> %2 = "mhlo_test.get_return_type_components"(%1) : (tensor<2x3x5xf32>) -> tensor<2x3x5xindex> // CHECK: %2 = "mhlo_test.return_type_components"(%1) {dims0 = [2, 3, 5], element_type0 = f32} : (tensor<2x3x5xf32>) -> tensor<2x3x5xindex> diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir index 6363ad08ef7c35..46abd9919563b9 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/ops.mlir @@ -1199,7 +1199,7 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1 // CHECK-LABEL: func @rng_normal func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf32> { %cst = "mhlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64> - %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1207,7 +1207,7 @@ func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf3 // CHECK-LABEL: func @rng_normal_no_constant func.func @rng_normal_no_constant(%a: tensor, %b: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.rng_normal"(%a, %b, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } @@ -1215,7 +1215,7 @@ func.func @rng_normal_no_constant(%a: tensor, %b: tensor, %shape: tens // CHECK-LABEL: func @rng_normal_dynamic_dim func.func @rng_normal_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "mhlo.rng_normal"(%a, %b, %shape) : (tensor, tensor, tensor) -> tensor<*xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1224,7 +1224,7 @@ func.func @rng_normal_dynamic_dim(%a: tensor, %b: tensor, %shape: tens func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> // expected-error @+1 {{inferred type(s) 'tensor<7xf32>' are incompatible with return type(s) of operation 'tensor<12xf32>'}} - %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> + %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1xi64>) -> tensor<12xf32> func.return } @@ -1232,18 +1232,17 @@ func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#0 must be 0D tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } // ----- - func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #1 must be 0D tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> + // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1252,7 +1251,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32> func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} - %0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1260,8 +1259,8 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error @+1 {{operand #0 must be 0D tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "mhlo.rng_normal"(%arg0, %arg1, %cst) : (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> + // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> func.return } @@ -1270,7 +1269,7 @@ func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1278,7 +1277,7 @@ func.func @rng_uniform(%a: tensor, %b: tensor) -> tensor<2x3x5xf32> { // CHECK-LABEL: func @rng_uniform_no_constant func.func @rng_uniform_no_constant(%a: tensor, %b: tensor, %shape: tensor<3xi64>) -> tensor { - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } @@ -1286,7 +1285,7 @@ func.func @rng_uniform_no_constant(%a: tensor, %b: tensor, %shape: ten // CHECK-LABEL: func @rng_uniform_dynamic_dim func.func @rng_uniform_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor, tensor, tensor) -> tensor<*xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1294,7 +1293,7 @@ func.func @rng_uniform_dynamic_dim(%a: tensor, %b: tensor, %shape: ten func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %arg2: tensor<7xi64>) { // expected-error @+1 {{inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} - %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<7xi64>) -> tensor + %0 = "mhlo.rng"(%arg0, %arg1, %arg2) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<7xi64>) -> tensor func.return } @@ -1303,7 +1302,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %ar func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1313,7 +1312,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> ten func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1322,7 +1321,7 @@ func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> ten func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1331,7 +1330,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> t func.func @rng_uniform_invalid_type(%a: tensor>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} - %0 = "mhlo.rng_uniform"(%a, %b, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1631,7 +1630,7 @@ func.func @dynamic_slice_slice_size_too_large(%arg0: tensor<3x4xi32>, %arg1: ten // CHECK-LABEL: @dynamic_update_slice func.func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { - %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> func.return %0 : tensor<3x4xi64> } @@ -1639,7 +1638,7 @@ func.func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, func.func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { // expected-error@+1 {{operand #2 must be 0D tensor of 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> func.return %0 : tensor<3x4xi64> } @@ -1647,7 +1646,7 @@ func.func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: func.func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor, %start2: tensor, %start3: tensor) -> tensor<11x3x4xi32> { // expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}} - %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor, tensor, tensor) -> tensor<11x3x4xi32> + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor, tensor, tensor) -> tensor<11x3x4xi32> func.return %0 : tensor<11x3x4xi32> } diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_rewriting.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_rewriting.mlir index e431551ed545ef..54b27940389c03 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_rewriting.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/sparse_rewriting.mlir @@ -68,3 +68,13 @@ func.func @rewrite_convert_nop(%arg0: tensor<10x10xf64, #CSR>) -> tensor<10x10xf %2 = sparse_tensor.convert %1 : tensor<10x10xf64, #CSR> to tensor<10x10xf64, #CSR> return %2 : tensor<10x10xf64, #CSR> } + +// CHECK-LABEL: func @rewrite_transpose( +// CHECK-SAME: %[[ARG0:.*]]: tensor<100x200xf64, #{{.*}}>) -> tensor<200x100xf64, #{{.*}}> { +// CHECK: %[[VAL:.*]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #{{.*}}>) -> tensor<200x100xf64, #{{.*}}> +// CHECK-NEXT: return %[[VAL:.*]] : tensor<200x100xf64, #{{.*}}> +func.func @rewrite_transpose(%arg0: tensor<100x200xf64, #CSR>) -> tensor<200x100xf64, #CSR> { + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<100x200xf64, #CSR>) -> tensor<200x100xf64> + %1 = sparse_tensor.convert %0 : tensor<200x100xf64> to tensor<200x100xf64, #CSR> + return %1 : tensor<200x100xf64, #CSR> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir index 781d4125fe234f..de3ccf244fa22b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/unfuse_batch_norm.mlir @@ -208,25 +208,16 @@ func.func @batchNormInference_dynamic_shape( %x: tensor, %scale: tensor, %offset: tensor, %mean: tensor, %variance: tensor) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[VARIANCE]], %[[C0]] : tensor - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[VAR_SHAPE:.+]] = shape.shape_of %[[VARIANCE]] : tensor -> tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor - // CHECK-DAG: %[[INPUT_DIM_0:.+]] = tensor.dim %[[X]], %[[C0]] : tensor - // CHECK-DAG: %[[INPUT_DIM_1:.+]] = tensor.dim %[[X]], %[[C1]] : tensor - // CHECK-DAG: %[[INPUT_DIM_2:.+]] = tensor.dim %[[X]], %[[C2]] : tensor - // CHECK-DAG: %[[INPUT_DIM_3:.+]] = tensor.dim %[[X]], %[[C3]] : tensor - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor -> tensor<4xindex> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor @@ -239,11 +230,45 @@ func.func @batchNormInference_dynamic_shape( } // ----- -// TODO(qingyunqu): complete this testcase // CHECK-LABEL: @batchNormTraining_dynamic_shape +// Validate that dynamic shapes are handled properly. +// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] func.func @batchNormTraining_dynamic_shape( %x: tensor, %scale: tensor, %offset: tensor) -> (tensor, tensor, tensor) { + // CHECK-DAG: %[[ZERO:.+]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor + // CHECK-DAG: %[[SCALE_SHAPE:.+]] = shape.shape_of %[[SCALE]] : tensor -> tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[SCALE_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor -> tensor<4xindex> + // CHECK-DAG: %[[X_SIZE:.+]] = shape.num_elements %[[X_SHAPE]] : tensor<4xindex> -> index + // CHECK-DAG: %[[SCALE_SIZE:.+]] = shape.num_elements %[[SCALE_SHAPE]] : tensor<1xindex> -> index + // CHECK-DAG: %[[REDUCE_SIZE:.+]] = shape.div %[[X_SIZE]], %[[SCALE_SIZE]] : index, index -> index + // CHECK-DAG: %[[INDEX_CAST:.+]] = arith.index_cast %[[REDUCE_SIZE]] : index to i64 + // CHECK-DAG: %[[REDUCE_SIZE_TENSOR:.+]] = tensor.from_elements %[[INDEX_CAST]] : tensor<1xi64> + // CHECK-DAG: %[[REDUCE_SIZE_TENSOR_FP:.+]] = mhlo.convert(%[[REDUCE_SIZE_TENSOR]]) : (tensor<1xi64>) -> tensor<1xf32> + // CHECK-DAG: %[[REDUCE_SIZE_RESHAPE:.+]] = "mhlo.reshape"(%[[REDUCE_SIZE_TENSOR_FP]]) : (tensor<1xf32>) -> tensor + // CHECK-DAG: %[[REDUCE_SIZE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[REDUCE_SIZE_RESHAPE]], %[[SCALE_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[X_SUM:.+]] = mhlo.reduce(%[[X]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0, 1, 3] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[X2:.+]] = mhlo.multiply %[[X]], %[[X]] : tensor + // CHECK-DAG: %[[X2_SUM:.+]] = mhlo.reduce(%[[X2]] init: %[[ZERO]]) applies mhlo.add across dimensions = [0, 1, 3] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[EX:.+]] = mhlo.divide %[[X_SUM]], %[[REDUCE_SIZE_BCAST]] : tensor + // CHECK-DAG: %[[EX2:.+]] = mhlo.divide %[[X2_SUM]], %[[REDUCE_SIZE_BCAST]] : tensor + // CHECK-DAG: %[[EX_2:.+]] = mhlo.multiply %[[EX]], %[[EX]] : tensor + // CHECK-DAG: %[[VARX:.+]] = mhlo.subtract %[[EX2]], %[[EX_2]] : tensor + // CHECK-DAG: %[[VARX_EPS:.+]] = mhlo.add %[[VARX]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[STDX:.+]] = mhlo.sqrt %[[VARX_EPS]] : tensor + // CHECK-DAG: %[[EX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EX]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_SUB_EX:.+]] = mhlo.subtract %[[X]], %[[EX_BCAST]] : tensor + // CHECK-DAG: %[[STDX_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDX]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_CENTOR:.+]] = mhlo.divide %[[X_SUB_EX]], %[[STDX_BCAST]] : tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTOR]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[X_SHAPE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_SCALED]], %[[OFFSET_BCAST]] : tensor + // CHECK-DAG: return %[[RESULT]], %[[EX]], %[[VARX]] : tensor, tensor, tensor %0:3 = "mhlo.batch_norm_training"(%x, %scale, %offset) {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir similarity index 99% rename from tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir rename to tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir index c4a2eaf1c9a9cd..1b5d1dcf9fee3d 100644 --- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/conv_op_verifier.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_conv_op.mlir @@ -691,7 +691,7 @@ func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, // ----- -// The following tests checks the inferred output-type of ConvOp. We +// The following tests checks the inferred output-type of ConvolutionOp. We // deliberately put an invalid output-type in these tests so that the // inffered-type can be highlighted in the error message. diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/reduce_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/reduce_op_verifier.mlir rename to tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/reduce_window_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/reduce_window_op_verifier.mlir rename to tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/scatter_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/scatter_op_verifier.mlir rename to tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/select_and_scatter_op_verifier.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/select_and_scatter_op_verifier.mlir rename to tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir diff --git a/tensorflow/compiler/mlir/hlo/tests/gml_st_pipeline.mlir b/tensorflow/compiler/mlir/hlo/tests/gml_st_pipeline.mlir index cebc57e8cf7596..8ca9ebfbd27d51 100644 --- a/tensorflow/compiler/mlir/hlo/tests/gml_st_pipeline.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/gml_st_pipeline.mlir @@ -1,34 +1,85 @@ // RUN: mlir-hlo-opt --split-input-file %s \ -// RUN: --gml-st-pipeline="tile-sizes=256" \ -// RUN: | FileCheck --dump-input=always %s +// RUN: --gml-st-pipeline="tile-sizes=64,4 fuse" \ +// RUN: | FileCheck %s + +// TODO(akuegel): Also run with the option lower-to-loops. This fails currently +// due to not having a bufferization for gml_st.dynamic_broadcast_in_dim. // CHECK-LABEL: func @log( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2048xf32>) -func.func @log(%arg0: tensor<2048xf32>) -> tensor<2048xf32> { - %0 = mhlo.log %arg0 : tensor<2048xf32> - return %0 : tensor<2048xf32> +// CHECK-SAME: %[[ARG0:.*]]: tensor<512x4xf32>) +func.func @log(%arg0: tensor<512x4xf32>) -> tensor<512x4xf32> { + %0 = mhlo.log %arg0 : tensor<512x4xf32> + return %0 : tensor<512x4xf32> } -// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [2048] : tensor<2048xf32> -// CHECK: %[[SPACE:.*]] = gml_st.space [2048] : !gml_st.tile<2048> -// CHECK: %[[RESULT:.*]] = gml_st.parallel (%[[IV:.*]]) = (%[[C0]]) -// CHECK: to (%[[C2048]]) step (%[[C256]]) -// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[IV]]] [256] [1] : -// CHECK: !gml_st.tile<2048> to !gml_st.tile<256> +// CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[INIT:.*]] = linalg.init_tensor [512, 4] : tensor<512x4xf32> +// CHECK: %[[SPACE:.*]] = gml_st.space [512, 4] : !gml_st.tile<512x4> +// CHECK: %[[RESULT:.*]] = gml_st.parallel (%[[IV:.*]], %[[IV2:.*]]) = +// CHECK: (%[[C0]], %[[C0]]) to (%[[C512]], %[[C4]]) +// CHECK: step (%[[C64]], %[[C4]]) +// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[IV]], %[[IV2]]] +// CHECK: [64, 4] [1, 1] : +// CHECK: !gml_st.tile<512x4> to !gml_st.tile<64x4> // CHECK: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] : -// CHECK: tensor<2048xf32>[!gml_st.tile<256>] +// CHECK: tensor<512x4xf32>[!gml_st.tile<64x4>] // CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] : -// CHECK: tensor<2048xf32>[!gml_st.tile<256>] +// CHECK: tensor<512x4xf32>[!gml_st.tile<64x4>] // CHECK: %[[LINALG_OP:.*]] = linalg.generic -// CHECK: ins(%[[ARG_SUB]] : tensor<256xf32>) -// CHECK: outs(%[[INIT_SUB:.*]] : tensor<256xf32>) +// CHECK: ins(%[[ARG_SUB]] : tensor<64x4xf32>) +// CHECK: outs(%[[INIT_SUB:.*]] : tensor<64x4xf32>) // CHECK: %[[LOG:.*]] = math.log %{{.*}} : f32 // CHECK: linalg.yield %[[LOG]] : f32 // CHECK: gml_st.set_yield %[[LINALG_OP]] into %[[INIT]][%[[TILE]]] : -// CHECK: tensor<256xf32> into tensor<2048xf32>[!gml_st.tile<256>] -// CHECK: return %[[RESULT]] : tensor<2048xf32> +// CHECK: tensor<64x4xf32> into tensor<512x4xf32>[!gml_st.tile<64x4>] +// CHECK: return %[[RESULT]] : tensor<512x4xf32> + +// ----- + +// CHECK-LABEL: func @transposed_log( +// CHECK-SAME: %[[ARG0:.*]]: tensor<20x64xf32>) +func.func @transposed_log(%arg0: tensor<20x64xf32>) -> tensor<64x20xf32> { + %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : + (tensor<20x64xf32>) -> tensor<64x20xf32> + %1 = mhlo.log %0 : tensor<64x20xf32> + return %1 : tensor<64x20xf32> +} +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[INIT:.*]] = linalg.init_tensor [64, 20] : tensor<64x20xf32> +// CHECK: %[[INIT2:.*]] = linalg.init_tensor [64, 20] : tensor<64x20xf32> +// CHECK: %[[SPACE:.*]] = gml_st.space [64, 20] : !gml_st.tile<64x20> +// CHECK: %[[RESULT:.*]] = gml_st.parallel (%[[IV:.*]], %[[IV2:.*]]) = +// CHECK: (%[[C0]], %[[C0]]) to (%[[C64]], %[[C20]]) +// CHECK: step (%[[C64]], %[[C4]]) +// CHECK: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[IV]], %[[IV2]]] +// CHECK: [64, 4] [1, 1] : +// CHECK: !gml_st.tile<64x20> to !gml_st.tile<64x4> +// CHECK: %[[SPACE2:.*]] = gml_st.space [20, 64] : !gml_st.tile<20x64> +// CHECK: %[[TILE2:.*]] = gml_st.tile %[[SPACE2]] [%[[IV2]], %[[IV]]] +// CHECK: [4, 64] [1, 1] : +// CHECK: !gml_st.tile<20x64> to !gml_st.tile<4x64> +// CHECK: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG0]][%[[TILE2]]] : +// CHECK: tensor<20x64xf32>[!gml_st.tile<4x64>] +// CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] : +// CHECK: tensor<64x20xf32>[!gml_st.tile<64x4>] +// CHECK: %[[LINALG_OP:.*]] = linalg.generic +// CHECK: ins(%[[ARG_SUB]] : tensor<4x64xf32>) +// CHECK: outs(%[[INIT_SUB:.*]] : tensor<64x4xf32>) +// CHECK: %[[TRANSPOSE_SUB:.*]] = gml_st.materialize %[[INIT2]][%[[TILE]]] +// CHECK: : tensor<64x20xf32>[!gml_st.tile<64x4>] +// CHECK: %[[LOG_RES:.*]] = linalg.generic +// CHECK: ins(%[[LINALG_OP]] : tensor<64x4xf32>) +// CHECK: outs(%[[TRANSPOSE_SUB:.*]] : tensor<64x4xf32>) +// CHECK: %[[LOG:.*]] = math.log %{{.*}} : f32 +// CHECK: linalg.yield %[[LOG]] : f32 +// CHECK: gml_st.set_yield %[[LOG_RES]] into %[[INIT2]][%[[TILE]]] : +// CHECK: tensor<64x4xf32> into tensor<64x20xf32>[!gml_st.tile<64x4>] +// CHECK: return %[[RESULT]] : tensor<64x20xf32> // ----- diff --git a/tensorflow/compiler/mlir/hlo/tests/python/attributes.py b/tensorflow/compiler/mlir/hlo/tests/python/attributes.py index 1dbfbcab1efd97..fc73bbd979c7b2 100644 --- a/tensorflow/compiler/mlir/hlo/tests/python/attributes.py +++ b/tensorflow/compiler/mlir/hlo/tests/python/attributes.py @@ -184,6 +184,16 @@ def test_fusion_kind(): assert attr.fusion_kind == "kLoop" +@run +def test_rng_distribution(): + """Check that RngDistribution attribute is available and usable.""" + + attr = RngDistributionAttr.get("UNIFORM") + assert attr is not None + assert str(attr) == ("#mhlo.rng_distribution") + assert attr.rng_distribution == "UNIFORM" + + @run def test_rng_algorithm(): """Check that RngAlgorithm attribute is available and usable.""" diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index a0c161e8671a35..712734c2083aa4 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,6 +31,7 @@ package_group( # Allow visibility from the mlir language server. "//learning/brain/mlir/mlir_lsp_server/...", "//research/language_modeling/sentence_explorer/ondevice/...", + "//learning/brain/research/babelfish/inference/speech_tflite/mlir/...", ], ) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index f9c181474ef8fa..372dfc9311681e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4838,10 +4838,9 @@ def TFL_UnsortedSegmentProdOp: TFL_Op<"unsorted_segment_prod", [ let arguments = (ins TFL_TensorOf<[F32, I32]>:$input, TFL_I32Tensor:$segment_ids, - I32Attr:$num_segments + TFL_I32Tensor:$num_segments ); let results = (outs TFL_TensorOf<[F32, I32]>:$output); - let hasOptions = 1; } def TFL_YieldOp : Op, %arg1: tensor<8xi32>) -> %0 = "tf.UnsortedSegmentProd"(%arg0, %arg1, %num_segments) : (tensor<8xf32>, tensor<8xi32>, tensor) -> tensor<8xf32> func.return %0 : tensor<8xf32> // CHECK-LABEL: unsorted_segment_prod - // CHECK: [[BCT:%.*]] = "tfl.unsorted_segment_prod"(%arg0, %arg1) {num_segments = 8 : i32} : (tensor<8xf32>, tensor<8xi32>) -> tensor<8xf32> + // CHECK: [[BCT:%.*]] = "tfl.unsorted_segment_prod"(%arg0, %arg1, %cst) : (tensor<8xf32>, tensor<8xi32>, tensor) -> tensor<8xf32> // CHECK: return [[BCT]] : tensor<8xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unsorted_segment_prod.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unsorted_segment_prod.mlir index bef3580671b113..68eecd54ae3fd4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unsorted_segment_prod.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unsorted_segment_prod.mlir @@ -1,7 +1,7 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s -func.func @main(tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> { -^bb0(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>): +func.func @main(tensor<8xi32>, tensor<8xi32>, tensor) -> tensor<8xi32> { +^bb0(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>, %arg2: tensor): // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { @@ -27,23 +27,27 @@ func.func @main(tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 8 ], +// CHECK-NEXT: shape: [ ], // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "arg2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 8 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 4, // CHECK-NEXT: name: "tfl.unsorted_segment_prod", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: } ], -// CHECK-NEXT: inputs: [ 0, 1 ], -// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: inputs: [ 0, 1, 2 ], +// CHECK-NEXT: outputs: [ 3 ], // CHECK-NEXT: operators: [ { -// CHECK-NEXT: inputs: [ 0, 1 ], -// CHECK-NEXT: outputs: [ 2 ] -// CHECK-NEXT: builtin_options_type: UnsortedSegmentProdOptions, -// CHECK-NEXT: builtin_options: { -// CHECK-NEXT: num_segments: 8 -// CHECK-NEXT: } +// CHECK-NEXT: inputs: [ 0, 1, 2 ], +// CHECK-NEXT: outputs: [ 3 ] // CHECK-NEXT: } ] // CHECK-NEXT: name: "main" // CHECK-NEXT: } ], @@ -57,14 +61,16 @@ func.func @main(tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> { // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { // CHECK-NEXT: data: [ 50, 46, 49, 48, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: } ], // CHECK-NEXT: metadata: [ { // CHECK-NEXT: name: "min_runtime_version", -// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: buffer: 5 // CHECK-NEXT: } ] // CHECK-NEXT: signature_defs: [ ] // CHECK-NEXT: } - %0 = "tfl.unsorted_segment_prod"(%arg0, %arg1) {num_segments = 8 : i32} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> + %0 = "tfl.unsorted_segment_prod"(%arg0, %arg1, %arg2) : (tensor<8xi32>, tensor<8xi32>, tensor) -> tensor<8xi32> func.return %0 : tensor<8xi32> } diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 8e6da848fdd4cb..e1e4036a43e938 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -152,3 +152,32 @@ func.func @RemoveLeadingQdq(%arg0: tensor<4xf32>, %arg1: tensor) -> (tensor // QDQ-NEXT: %[[split:.*]]:4 = "tfl.split"(%arg1, %arg0) {num_splits = 4 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) // QDQ-NEXT: return %[[split]]#0 : tensor<2xf32> } + +// CHECK-LABEL: FoldTranspose +func.func @FoldTranspose(%arg0: tensor<1x10x20x3xf32>) -> tensor<1x20x40x16xf32> { + %cst = arith.constant dense<[1, 20, 40, 16]> : tensor<4xi32> + %cst_0 = arith.constant dense<[2, 0, 1, 3]> : tensor<4xi32> + %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> + %1 = "tfl.pseudo_qconst"() {qtype = tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>>, value = dense<"0x0303040002010303FFFFFD0304020401FF0000FEFF0003FF01FD0203FF0202FEFE0003010201FD04FE0402030303000202FD0100FDFE0402FEFEFE01020101FD0204FEFDFC03FFFE0101FDFE02040002FDFFFE03FFFE0201FEFDFF00FFFDFEFD030201FD01FC01FF010003FF0401FCFD0101FC0000FE03FEFE010102000002FE02030100FE00FEFDFD0003FD000303000103FE01FF02000002FF0101FDFDFF02FFFF00000203FF0003030302FDFF03FFFF030001020102FD04FE0104FE030401030102FEFCFEFD03FD03FD000102FE02020001020000FE030202030103FFFC01FC000302000304FCFF03FD04FC00010400010100030303FC02FCFEFE01000303000100010003FE000303010301010102FEFC01FD020301FFFDFFFCFDFEFCFE030001FDFCFE000202FE020300FD00FD02FF0001FF0002FF01FD010102FDFE04FCFE0000FD01000101FF0402FF020103FC020301FF03010204FDFFFE0202FF0302FF02FFFF01FF01FF04FD0002FF00FC00FC0101010404FE03040300000301FD0001FE04FF040103FF01FD0301FF0002040403FF03FE04FDFD0103FCFE01FDFCFF03FC010200FDFE020200FF00FFFC03FE"> : tensor<3x3x16x3xi8>} : () -> tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>> + %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x10x20x3x!quant.uniform>} : (tensor<1x10x20x3xf32>) -> tensor<1x10x20x3x!quant.uniform> + %3 = "tfl.transpose"(%1, %cst_0) : (tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>>, tensor<4xi32>) -> tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>> + %4 = "tfl.transpose_conv"(%cst, %3, %2, %0) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, tensor<1x10x20x3x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x20x40x16x!quant.uniform> + %5 = "tfl.dequantize"(%4) : (tensor<1x20x40x16x!quant.uniform>) -> tensor<1x20x40x16xf32> + return %5 : tensor<1x20x40x16xf32> + + // CHECK-NOT: "tfl.transpose" + // CHECK: "tfl.pseudo_qconst"() {qtype = tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, value = dense<"0x03030402FD010302010103FE0301020001010001FD02030101FE0400020100FDFEFD01FC01FF02FEFCFE000303FCFE00FF0301FF04010303FF0402FE01FF01000002FD03FD03FC020202FE0204FD03FF01FFFD03FEFE010003FFFF010103FD00FCFEFE020300FFFE02FD03010402040201010401FCFDFDFF0102FE010003FD00FD02FF03FF000201FF00FD0204FD010102FFFF02020003000102FF0002FF0204040300FEFFFEFDFCFC000000000201020000010001FF00FFFF01FF03FE0003FF03FFFEFE03FE03FF0000FE0303FE0002FF01FF01FF04FDFD01FD020101FDFE0101030303020203030301FD010104FD000103FC03FF02FE020402000002FDFF0103FF03010102FDFE02FF00FE01FD02FEFE0002FD02FE0203FFFFFC01FC0102FE04FCFEFC00FCFCFF03000301FFFE03030100030001000302FC01FD0000FD010101FC01020201FDFFFE02FE00FE0201020003040203010100010404FE00FDFE04FE0401FEFDFDFD00FD04FEFCFF03FFFDFF01FF04030403020200020303FF00FF03FD000104FEFD04FCFCFDFE02FF02000003FF00FF030002FDFEFD030300030401000104FCFE030103FC01FD00FC03FE"> : tensor<16x3x3x3xi8>} : () -> tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>> + // CHECK-NEXT: "tfl.transpose_conv" +} + +// CHECK-LABEL: FoldReshape +func.func @FoldReshape(%arg0: tensor<4xi32>, %arg1: tensor<1x48x80x16x!quant.uniform>, %arg2: tensor<1x!quant.uniform>) -> tensor<1x96x160x1x!quant.uniform> { + %cst = arith.constant dense<[1, 2, 2, 16]> : tensor<4xi32> + %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26]], [[47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]]], [[[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52]], [[40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<2x2x1x16xi8>} : () -> tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>> + %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<4xi32>) -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> + %2 = "tfl.transpose_conv"(%arg0, %1, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<1x48x80x16x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x96x160x1x!quant.uniform> + return %2 : tensor<1x96x160x1x!quant.uniform> + // CHECK-NOT: "tfl.reshape" + // CHECK{LITERAL}: "tfl.pseudo_qconst"() {qtype = tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26], [47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]], [[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52], [40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<1x2x2x16xi8>} : () -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> + // CHECK-NEXT: "tfl.transpose_conv" +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 5aa35097e41f5f..30deb87da10aeb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -221,10 +221,9 @@ def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>; def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids), (TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>; def LegalizeUnsortedSegmentProd : - Pat<(TF_UnsortedSegmentProdOp $data, $segment_ids, - (Arith_ConstantOp ElementsAttr:$num_segments)), + Pat<(TF_UnsortedSegmentProdOp $data, $segment_ids, $num_segments), (TFL_UnsortedSegmentProdOp $data, (CreateTFCastToInt32Op $segment_ids), - ExtractSingleElementAsInt32:$num_segments)>; + (CreateTFCastToInt32Op $num_segments))>; def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 5d9b184cbd2d18..ad145c1f04214f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -188,6 +189,145 @@ struct RemoveVolatileOps : public OpRewritePattern { } }; +// Fold the constant quantized Transpose ops. +struct FoldTransposeOp : public OpRewritePattern { + explicit FoldTransposeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + // Computes the permutation of a constant `input_tensor` according to `perm`. + // The function recursively traverses the dimensions of the output tensor in + // a row-major order and writes the value in the output tensor into + // `new_values`. + void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, + ArrayRef output_shape, int num_dimensions, + int output_axis, std::vector* input_indices, + std::vector* new_values) const { + // Refer to the implementation of `Transpose` function in + // tensorflow/lite/kernels/internal/reference/reference_ops.h + assert(output_axis < num_dimensions); + const int input_axis = perm[output_axis]; + for (int i = 0; i < output_shape[output_axis]; ++i) { + // Update the input indices on `input_axis`. + assert(input_axis < input_indices->size()); + input_indices->operator[](input_axis) = static_cast(i); + // Write the value from `input_tensor` if it is the last axis or + // recurse into the next axis. + const bool is_last_axis = output_axis == num_dimensions - 1; + if (is_last_axis) { + new_values->push_back( + input_tensor.getValues()[*input_indices]); + } else { + ComputePermutation(input_tensor, perm, output_shape, num_dimensions, + output_axis + 1, input_indices, new_values); + } + } + } + + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter& rewriter) const override { + Operation* def_op = op.input().getDefiningOp(); + auto qconst_op = llvm::dyn_cast_or_null(def_op); + if (qconst_op == nullptr) return failure(); + + DenseIntElementsAttr perm_tensor; + if (!matchPattern(op.perm(), m_Constant(&perm_tensor))) return failure(); + + if (!(getElementTypeOrSelf(op.output().getType())) + .isa()) + return failure(); + + ElementsAttr input_tensor = qconst_op.value(); + + assert(perm_tensor.getType().getRank() == 1); + const int num_dimensions = input_tensor.getType().getRank(); + assert(perm_tensor.getType().getNumElements() == num_dimensions); + + ArrayRef input_shape = input_tensor.getType().getShape(); + auto output_type = op.output().getType().cast(); + + SmallVector perm; + SmallVector output_shape; + for (int i = 0; i < num_dimensions; ++i) { + perm.push_back(perm_tensor.getValues()[i].getInt()); + output_shape.push_back(input_shape[perm[i]]); + + // Check that the derived output shape matches the static shape. + assert(!output_type.hasStaticShape() || + output_type.getShape()[i] == output_shape[i]); + } + + std::vector new_values; + new_values.reserve(input_tensor.getType().getNumElements()); + std::vector input_indices(num_dimensions); + ComputePermutation(input_tensor, perm, output_shape, num_dimensions, + /*output_axis=*/0, &input_indices, &new_values); + auto result_type = + RankedTensorType::get(output_shape, output_type.getElementType()); + auto values_type = RankedTensorType::get( + output_shape, output_type.getElementType() + .cast() + .getStorageType()); + rewriter.replaceOpWithNewOp( + op, TypeAttr::get(result_type), + DenseIntElementsAttr::get(values_type, new_values)); + return success(); + } +}; + +// Fold constant quantized Reshape ops. +struct FoldReshapeOp : public OpRewritePattern { + // Does not take ownership of context, which must refer to a valid value that + // outlives this object. + explicit FoldReshapeOp(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(ReshapeOp op, + PatternRewriter& rewriter) const override { + Operation* def_op = op.input().getDefiningOp(); + auto qconst_op = llvm::dyn_cast_or_null(def_op); + if (qconst_op == nullptr) return failure(); + + auto dense_elements = + qconst_op.value().dyn_cast_or_null(); + if (dense_elements == nullptr) return failure(); + + // Handle per tensor cases only. + if (!(getElementTypeOrSelf(op.getType())) + .isa()) { + return failure(); + } + + // Remove identity reshape with both static result and input shape. + auto result_type = op.getType().cast(); + auto input_type = op.input().getType().cast(); + + // Constant folding + // If the result type isn't static, tries to derive the result type from + // the #2 operand. + if (!result_type.hasStaticShape()) { + DenseIntElementsAttr shape_elements; + if (!matchPattern(op.shape(), m_Constant(&shape_elements))) + return failure(); + + SmallVector shape_data; + for (const APInt& it : shape_elements.getValues()) { + shape_data.push_back(it.getSExtValue()); + } + result_type = + RankedTensorType::get(shape_data, input_type.getElementType()); + } + auto values_type = RankedTensorType::get( + result_type.getShape(), result_type.getElementType() + .cast() + .getStorageType()); + + DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type); + rewriter.replaceOpWithNewOp(op, TypeAttr::get(result_type), + reshaped_elements); + return success(); + } +}; + // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling // output. template @@ -251,7 +391,8 @@ void PostQuantizePass::runOnOperation() { RewritePatternSet phase_2_patterns(&getContext()); TFL::populateWithGenerated(phase_2_patterns); phase_2_patterns.add, - RemoveVolatileOps>(ctx); + RemoveVolatileOps, + FoldTransposeOp, FoldReshapeOp>(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns)); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 30905b05152833..92fc65b6040b51 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -66,10 +66,11 @@ td_library( srcs = [ "passes/lift_quantizable_spots_as_functions.td", "passes/lift_quantizable_spots_as_functions_drq.td", + "passes/optimize.td", "passes/prepare_lifting.td", "passes/prepare_quantize.td", "passes/quantize_composite_functions.td", - "passes/tf_quant_ops.td", + "passes/replace_cast_hacks_with_tf_xla_ops.td", "passes/utils.td", ], compatible_with = get_compatible_with_cloud(), @@ -172,6 +173,20 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "optimize_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "passes/optimize.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/optimize.td", + deps = [":quant_td_files"], +) + cc_library( name = "tf_quant_ops", srcs = [ @@ -202,6 +217,20 @@ cc_library( ], ) +gentbl_cc_library( + name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "passes/replace_cast_hacks_with_tf_xla_ops.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/replace_cast_hacks_with_tf_xla_ops.td", + deps = [":quant_td_files"], +) + cc_library( name = "passes", srcs = [ @@ -216,6 +245,8 @@ cc_library( "passes/lift_quantizable_spots_as_functions.inc", "passes/lift_quantizable_spots_as_functions_drq.cc", "passes/lift_quantizable_spots_as_functions_drq.inc", + "passes/optimize.cc", + "passes/optimize.inc", "passes/post_quantize.cc", "passes/prepare_lifting.cc", "passes/prepare_lifting.inc", @@ -226,6 +257,8 @@ cc_library( "passes/quantize_composite_functions.cc", "passes/quantize_composite_functions.inc", "passes/quantized_function_library.h", + "passes/replace_cast_hacks_with_tf_xla_ops.cc", + "passes/replace_cast_hacks_with_tf_xla_ops.inc", ], hdrs = [ "passes/passes.h", @@ -246,12 +279,14 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/ir/importexport:mangling", "//tensorflow/core/platform:env", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:path", + "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:quantization_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc new file mode 100644 index 00000000000000..509d5f46ab997f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.cc @@ -0,0 +1,66 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant { +namespace { + +// Applies optimization after quantization. +class OptimizePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "quant-optimize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Applies optimization after quantization"; + } + + void runOnOperation() override; +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.inc" + +void OptimizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + auto func = getOperation(); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // namespace + +std::unique_ptr> CreateOptimizePass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td new file mode 100644 index 00000000000000..29b5bb149f8078 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td @@ -0,0 +1,63 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" + +// Remove redundant `CastOp` to int8 if the input is properly clipped. +def RemoveRedundantCastOps : Pat< + (TF_CastOp:$root_cast + (TF_CastOp:$i8_cast + (TF_ClipByValueOp:$clip $input, $min_value, $max_value), + ConstBoolAttrFalse:$truncate2), + ConstBoolAttrFalse:$truncate1), + (CreateOpWithOutputType<"TF::CastOp"> + (GetValueType $root_cast), $clip, ConstBoolAttrFalse), + [(TensorOf<[I8]> $i8_cast), + (TensorOf<[I32]> $clip), + (IsIntSplatValueEqual<"int32_t", "-128"> $min_value), + (IsIntSplatValueEqual<"int32_t", "127"> $max_value)]>; + +// This pattern optimizes: +// (x + cst1) + cst2 -> x + cst +// (x - cst1) - cst2 -> x - cst +// Where: cst = cst1 + cst2 +foreach BinaryOp = [TF_AddV2Op, TF_SubOp] in { + def OptimizeConsecutive#BinaryOp : Pat< + (BinaryOp + (BinaryOp $x, (TF_ConstOp:$cst1 $cst1_value)), + (TF_ConstOp:$cst2 $cst2_value)), + (BinaryOp + $x, (TF_AddV2Op $cst1, $cst2))>; +} + +// This pattern optimizes: +// (x + cst1) - cst2 -> x - cst +// (x - cst1) + cst2 -> x + cst +// Where: cst = cst2 - cst1 +foreach BinaryOpPair = [[TF_AddV2Op, TF_SubOp], + [TF_SubOp, TF_AddV2Op]] in { + def OptimizeConsecutive#BinaryOpPair[0]#BinaryOpPair[1] : Pat< + (BinaryOpPair[0] + (BinaryOpPair[1] $x, (TF_ConstOp:$cst1 $cst1_value)), + (TF_ConstOp:$cst2 $cst2_value)), + (BinaryOpPair[0] + $x, (TF_SubOp $cst1, $cst2))>; +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 7b40dc2dda39a0..af2780be30603c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -98,6 +98,15 @@ std::unique_ptr> CreatePostQuantizePass(); std::unique_ptr> CreateConvertTFQuantOpsToMHLOPass(); +// Applies optimization patterns after quantization. +std::unique_ptr> CreateOptimizePass(); + +// Creates an instance of the ReplaceCastHacksWithTFXLAOpsPass, which will +// replace mixed-type convolution and matmul cast hacks by XLA Conv2DOp and +// MatmulOp. +std::unique_ptr> +CreateReplaceCastHacksWithTFXLAOpsPass(); + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc new file mode 100644 index 00000000000000..882295de3a670f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -0,0 +1,263 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/lite/kernels/padding.h" + +namespace mlir::quant { +namespace { + +// Replaces mixed-type Conv and Matmul cast hacks with TF XLA ops. +// TODO(b/228403741): Support conversion for dynamic-shaped TF ops. +class ReplaceCastHacksWithTFXLAOpsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceCastHacksWithTFXLAOpsPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "quant-replace-cast-hacks-with-tf-xla-ops"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Replace mixed-type Conv and Matmul cast hacks with TF XLA ops."; + } + + void runOnOperation() override; +}; + +// Generates params for the XLA Convolution op. +void PrepareXlaConvParams(OpBuilder &builder, Location loc, ArrayAttr strides, + ArrayAttr dilations, int feature_group_cnt, + Value &window_strides, Value &lhs_dilation, + Value &rhs_dilation, Value &feature_group_count) { + const int stride_h = strides[1].cast().getInt(); + const int stride_w = strides[2].cast().getInt(); + window_strides = + Create1DConstValue(builder, loc, {stride_h, stride_w}); + + const int dilation_h = dilations[1].cast().getInt(); + const int dilation_w = dilations[2].cast().getInt(); + lhs_dilation = Create1DConstValue(builder, loc, {1, 1}); + rhs_dilation = + Create1DConstValue(builder, loc, {dilation_h, dilation_w}); + + feature_group_count = + CreateScalarConstValue(builder, loc, feature_group_cnt); +} + +// For non-zero padding (input_zp != 0), adds Pad op before convolution. +Value CalculatePaddingAndPadIfNeeded( + OpBuilder &builder, Location loc, Value input, Value filter, + int8_t input_zp_value, ArrayAttr strides, ArrayAttr dilations, + StringAttr conv_padding, ArrayAttr explicit_paddings, Value &padding) { + ShapedType input_shape = input.getType().template cast(); + ShapedType filter_shape = filter.getType().template cast(); + + int padding_h_before, padding_h_after, padding_w_before, padding_w_after; + if (conv_padding.strref().equals("EXPLICIT")) { + if (explicit_paddings.size() != 8) { + mlir::emitError(loc, + "explicit_paddings are expected to be 8-element arrays"); + return {}; + } + padding_h_before = explicit_paddings[2].cast().getInt(); + padding_h_after = explicit_paddings[3].cast().getInt(); + padding_w_before = explicit_paddings[4].cast().getInt(); + padding_w_after = explicit_paddings[5].cast().getInt(); + } else { + TfLitePadding tflite_padding = conv_padding.strref().equals("VALID") + ? kTfLitePaddingValid + : kTfLitePaddingSame; + int output_height, output_width; + const int stride_h = strides[1].cast().getInt(); + const int stride_w = strides[2].cast().getInt(); + const int dilation_h = dilations[1].cast().getInt(); + const int dilation_w = dilations[2].cast().getInt(); + TfLitePaddingValues padding_values = tflite::ComputePaddingHeightWidth( + stride_h, stride_w, dilation_h, dilation_w, + /*in_height=*/input_shape.getDimSize(1), + /*in_width=*/input_shape.getDimSize(2), + /*filter_height=*/filter_shape.getDimSize(0), + /*filter_width=*/filter_shape.getDimSize(1), tflite_padding, + &output_height, &output_width); + padding_h_before = padding_values.height; + padding_h_after = padding_values.height + padding_values.height_offset; + padding_w_before = padding_values.width; + padding_w_after = padding_values.width + padding_values.width_offset; + } + + if (conv_padding.strref().equals("VALID") || input_zp_value == 0 || + (padding_h_before == 0 && padding_h_after == 0 && padding_w_before == 0 && + padding_w_after == 0)) { + padding = CreateConstValue( + builder, loc, {2, 2}, + {padding_h_before, padding_h_after, padding_w_before, padding_w_after}); + return input; + } + padding = CreateConstValue(builder, loc, {2, 2}, {0, 0, 0, 0}); + + Value temp_padding = + CreateConstValue(builder, loc, {4, 2}, + {0, 0, padding_h_before, padding_h_after, + padding_w_before, padding_w_after, 0, 0}); + SmallVector output_shape(input_shape.getShape().begin(), + input_shape.getShape().end()); + output_shape[1] += padding_h_before + padding_h_after; + output_shape[2] += padding_w_before + padding_w_after; + return builder.create( + loc, RankedTensorType::get(output_shape, builder.getI8Type()), input, + temp_padding, + CreateScalarConstValue(builder, loc, input_zp_value)); +} + +// Calculates zero-point offset by reducing weights and multiply it with zp. +Value CalculateZeroPointOffset( + OpBuilder &builder, Location loc, Value filter, int8_t input_zp, + int output_dim, const SmallVector &weight_non_output_indices) { + Value reduction_indices_value = + Create1DConstValue(builder, loc, weight_non_output_indices); + Value zp = CreateScalarConstValue(builder, loc, input_zp); + + auto zp_mul_output_type = + RankedTensorType::get({output_dim}, builder.getIntegerType(32)); + auto reduced = builder.create( + loc, zp_mul_output_type, filter, reduction_indices_value, + /*keep_dims=*/builder.getBoolAttr(false)); + return builder.create(loc, zp, reduced).getResult(); +} + +// Helper function to create a XlaConvV2Op for Conv2DOp and DepthwiseConv2DOp. +Value CreateXLAConvOp( + OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, + Value conv_output, ArrayAttr strides, ArrayAttr dilations, + StringAttr conv_padding, ArrayAttr explicit_paddings, int feature_group_cnt, + const SmallVector &filter_non_output_indices, + const xla::ConvolutionDimensionNumbers &dimension_numbers) { + int32_t input_zp_value; + if (!GetSplatValue(input_zp, input_zp_value)) { + mlir::emitError( + loc, "zero point is expected to be a constant with a single value"); + return {}; + } + if (strides.size() != 4 || dilations.size() != 4) { + mlir::emitError( + loc, "strides and dilations are expected to be 4-element arrays"); + return {}; + } + ShapedType input_shape = input.getType().template cast(); + ShapedType filter_shape = filter.getType().template cast(); + if (!input_shape.hasRank() || input_shape.getRank() != 4 || + !filter_shape.hasRank() || filter_shape.getRank() != 4) { + mlir::emitError(loc, "input and filter are expected to be 4D tensors"); + return {}; + } + + Value padding, window_strides, lhs_dilation, rhs_dilation, + feature_group_count; + PrepareXlaConvParams(builder, loc, strides, dilations, feature_group_cnt, + /*window_strides=*/window_strides, + /*lhs_dilation=*/lhs_dilation, + /*rhs_dilation=*/rhs_dilation, + /*feature_group_count=*/feature_group_count); + + input = CalculatePaddingAndPadIfNeeded( + builder, loc, input, filter, input_zp_value, strides, dilations, + conv_padding, explicit_paddings, padding); + auto filter_type = filter.getType().dyn_cast(); + Value filter_i8 = builder.create( + loc, filter_type.clone(builder.getIntegerType(8)), filter); + Value xla_conv_output = + builder + .create( + loc, /*output=*/conv_output.getType(), + /*lhs=*/input, + /*rhs=*/filter_i8, window_strides, padding, lhs_dilation, + rhs_dilation, feature_group_count, + builder.getStringAttr(dimension_numbers.SerializeAsString()), + /*precision_config=*/builder.getStringAttr("")) + .output(); + if (input_zp_value == 0) return xla_conv_output; + + Value zp_offset = CalculateZeroPointOffset( + builder, loc, /*filter=*/filter, /*input_zp=*/input_zp_value, + /*output_dim=*/filter_shape.getDimSize(3), + /*weight_non_output_indices=*/filter_non_output_indices); + return builder.create(loc, xla_conv_output, zp_offset); +} + +// Creates a XlaConvV2Op from TF Conv2DOp and returns its output. +Value CreateXLAConvOpFromTFConv2DOp(OpBuilder &builder, Location loc, + Value input, Value filter, Value input_zp, + Value conv_output, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding, + ArrayAttr explicit_paddings) { + const int feature_group_cnt = 1; + SmallVector filter_non_output_indices = {0, 1, 2}; + xla::ConvolutionDimensionNumbers dnums; + // Input: [N, H, W, C]. + dnums.set_input_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + // Kernel: [K, K, I, O]. + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output: [N, H, W, C]. + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(3); + dnums.add_output_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(2); + return CreateXLAConvOp(builder, loc, input, filter, input_zp, conv_output, + strides, dilations, conv_padding, explicit_paddings, + feature_group_cnt, filter_non_output_indices, dnums); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.inc" + +void ReplaceCastHacksWithTFXLAOpsPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // namespace + +std::unique_ptr> +CreateReplaceCastHacksWithTFXLAOpsPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td new file mode 100644 index 00000000000000..7471c263a102dc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td @@ -0,0 +1,39 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" + +def CreateXLAConvOpFromTFConv2DOp : NativeCodeCall< + "CreateXLAConvOpFromTFConv2DOp($_builder, $_loc, $0...)">; + +def ConvertTFConv2DToXLAConvOp : Pat< + (TF_Conv2DOp:$conv + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_ConstOp:$filter $filter_value), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (CreateXLAConvOpFromTFConv2DOp + $input, $filter, $input_zp, $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt8ElementType $input), + (IsInt32ElementType $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"1, 2, 3"> $input)], + (addBenefit 10)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc index 76be9b6dfd21ee..288d327be6e586 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc @@ -36,6 +36,23 @@ bool HasQuantizedTensors(Operation* op) { return false; } +bool HasStaticShape(Value value) { + auto shaped_type = value.getType().dyn_cast(); + if (!shaped_type) return false; + + return shaped_type.hasStaticShape(); +} + +bool HasStaticShapeAtDims(Value value, llvm::ArrayRef dims) { + auto shaped_type = value.getType().dyn_cast(); + if (!shaped_type) return false; + + for (auto dim : dims) { + if (shaped_type.isDynamicDim(dim)) return false; + } + return true; +} + Type CloneTypeWithNewElementType(Type old_type, Type element_type) { if (!old_type.isa()) return {}; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h index 068bca754077fb..c80fcb9559c6e6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project @@ -30,6 +31,12 @@ namespace quant { // Returns true if the op has any quantized tensors as input or output. bool HasQuantizedTensors(Operation *op); +// Returns true if the value has static shape. +bool HasStaticShape(Value value); + +// Returns true if the value has static shape at given dims. +bool HasStaticShapeAtDims(Value value, llvm::ArrayRef dims); + enum class QuantizationMethod { kQuantizationAwareTraining, kPostTrainingQuantization, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td index e909a7245dbecb..441501fb8c542a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td @@ -39,6 +39,21 @@ def HasOneUse : Constraint>; // Gets the type of a value. def GetValueType : NativeCodeCall<"$0.getType()">; +// Checks if the value has the type of int8. +def IsInt8ElementType : Constraint< + CPred<"getElementTypeOrSelf($0).isInteger(8)">>; + +// Checks if the value has the type of int32. +def IsInt32ElementType : Constraint< + CPred<"getElementTypeOrSelf($0).isInteger(32)">>; + +// Checks if the value has static shape. +def HasStaticShapeConstraint : Constraint>; + +// Checks if the value has static shape at given dims. +class HasStaticShapeAtDimsConstraint : Constraint< + CPred<"HasStaticShapeAtDims($0, {"# dims #"})">>; + // The rewrite rule cannot replace a value with itself, so we work around // by cloning the root op to replicate that value. The old op will get folded. def CloningOpResult : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index e08d0af200e583..1819bf12d0891e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -150,6 +150,7 @@ absl::StatusOr QuantizeQATModel(absl::string_view saved_model_path, mlir::quant::QuantizationMethod::kQuantizationAwareTraining)); pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::quant::CreateOptimizePass()); pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); pm.addNestedPass( @@ -292,6 +293,8 @@ absl::StatusOr QuantizePTQModelPostCalibration( mlir::quant::QuantizationMethod::kPostTrainingQuantization)); pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::quant::CreateOptimizePass()); + pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); pm.addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir new file mode 100644 index 00000000000000..6da979db7cca8d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/optimize.mlir @@ -0,0 +1,124 @@ +// RUN: tf-quant-opt %s -quant-optimize -allow-unregistered-dialect | FileCheck %s + +func.func @remove_redundant_cast(%arg0: tensor<1x100x100x1xf32>) -> (tensor<1x96x96x1xf32>) { + %cst = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0235294122> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<0.00708661414> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_2 = "tf.Const"() {value = dense<1.799000e+03> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_3 = "tf.Const"() {value = dense<[[[[1.400000e+01]], [[-2.800000e+01]], [[4.200000e+01]]], [[[-5.600000e+01]], [[7.100000e+01]], [[-8.500000e+01]]], [[[9.900000e+01]], [[-1.130000e+02]], [[1.270000e+02]]]]> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32> + %cst_4 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<0.00118110236> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_6 = "tf.Const"() {value = dense<1.079500e+04> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_7 = "tf.Const"() {value = dense<0.00392156886> : tensor} : () -> tensor + %cst_8 = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %cst_9 = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_7) : (tensor<1x100x100x1xf32>, tensor) -> tensor<1x100x100x1xf32> + %1 = "tf.Round"(%0) : (tensor<1x100x100x1xf32>) -> tensor<1x100x100x1xf32> + %2 = "tf.Cast"(%1) : (tensor<1x100x100x1xf32>) -> tensor<1x100x100x1xi32> + %3 = "tf.AddV2"(%2, %cst) : (tensor<1x100x100x1xi32>, tensor) -> tensor<1x100x100x1xi32> + + %4 = "tf.ClipByValue"(%3, %cst, %cst_9) : (tensor<1x100x100x1xi32>, tensor, tensor) -> tensor<1x100x100x1xi32> + %5 = "tf.Cast"(%4) {Truncate = false} : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xi8> + %6 = "tf.Cast"(%5) {Truncate = false} : (tensor<1x100x100x1xi8>) -> tensor<1x100x100x1xf32> + + %7 = "tf.Sub"(%6, %cst_4) : (tensor<1x100x100x1xf32>, tensor) -> tensor<1x100x100x1xf32> + %8 = "tf.Conv2D"(%7, %cst_3) {dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x100x100x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x98x98x1xf32> + %9 = "tf.AddV2"(%8, %cst_6) : (tensor<1x98x98x1xf32>, tensor<1xf32>) -> tensor<1x98x98x1xf32> + %10 = "tf.Mul"(%9, %cst_5) : (tensor<1x98x98x1xf32>, tensor<1xf32>) -> tensor<1x98x98x1xf32> + %11 = "tf.AddV2"(%10, %cst_8) : (tensor<1x98x98x1xf32>, tensor) -> tensor<1x98x98x1xf32> + %12 = "tf.Floor"(%11) : (tensor<1x98x98x1xf32>) -> tensor<1x98x98x1xf32> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<1x98x98x1xf32>) -> tensor<1x98x98x1xi32> + %14 = "tf.AddV2"(%13, %cst) : (tensor<1x98x98x1xi32>, tensor) -> tensor<1x98x98x1xi32> + + %15 = "tf.ClipByValue"(%14, %cst, %cst_9) : (tensor<1x98x98x1xi32>, tensor, tensor) -> tensor<1x98x98x1xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xi8> + %17 = "tf.Cast"(%16) {Truncate = false} : (tensor<1x98x98x1xi8>) -> tensor<1x98x98x1xf32> + + %18 = "tf.Sub"(%17, %cst_4) : (tensor<1x98x98x1xf32>, tensor) -> tensor<1x98x98x1xf32> + %19 = "tf.Conv2D"(%18, %cst_3) {dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x98x98x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x96x96x1xf32> + %20 = "tf.AddV2"(%19, %cst_2) : (tensor<1x96x96x1xf32>, tensor<1xf32>) -> tensor<1x96x96x1xf32> + %21 = "tf.Mul"(%20, %cst_1) : (tensor<1x96x96x1xf32>, tensor<1xf32>) -> tensor<1x96x96x1xf32> + %22 = "tf.AddV2"(%21, %cst_8) : (tensor<1x96x96x1xf32>, tensor) -> tensor<1x96x96x1xf32> + %23 = "tf.Floor"(%22) : (tensor<1x96x96x1xf32>) -> tensor<1x96x96x1xf32> + %24 = "tf.Cast"(%23) {Truncate = false} : (tensor<1x96x96x1xf32>) -> tensor<1x96x96x1xi32> + %25 = "tf.AddV2"(%24, %cst) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> + + %26 = "tf.ClipByValue"(%25, %cst, %cst_9) : (tensor<1x96x96x1xi32>, tensor, tensor) -> tensor<1x96x96x1xi32> + %27 = "tf.Cast"(%26) {Truncate = false} : (tensor<1x96x96x1xi32>) -> tensor<1x96x96x1xi8> + %28 = "tf.Cast"(%27) : (tensor<1x96x96x1xi8>) -> tensor<1x96x96x1xi32> + + %29 = "tf.Sub"(%28, %cst) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> + %30 = "tf.Cast"(%29) : (tensor<1x96x96x1xi32>) -> tensor<1x96x96x1xf32> + %31 = "tf.Mul"(%30, %cst_0) : (tensor<1x96x96x1xf32>, tensor) -> tensor<1x96x96x1xf32> + return %31 : tensor<1x96x96x1xf32> + +// CHECK-LABEL: func.func @remove_redundant_cast + +// CHECK: %[[CLIPBYVALUE_0:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x100x100x1xi32>, tensor, tensor) -> tensor<1x100x100x1xi32> +// CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[CLIPBYVALUE_0]]) {Truncate = false} : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xf32> + +// CHECK: %[[CLIPBYVALUE_1:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x98x98x1xi32>, tensor, tensor) -> tensor<1x98x98x1xi32> +// CHECK: %[[CAST_3:.*]] = "tf.Cast"(%[[CLIPBYVALUE_1]]) {Truncate = false} : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xf32> + +// CHECK: %[[CLIPBYVALUE_2:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x96x96x1xi32>, tensor, tensor) -> tensor<1x96x96x1xi32> +// CHECK: %[[SUB_2:.*]] = "tf.Sub"(%[[CLIPBYVALUE_2]], {{.*}}) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> +} + +func.func @consecutive_add_add(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_add_add + +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-30> : tensor} : () -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[ADD]] : tensor +} + +func.func @consecutive_add_sub(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Sub"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_add_sub + +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-6> : tensor} : () -> tensor +// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[SUB]] : tensor +} + +func.func @consecutive_sub_add(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.Sub"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_sub_add + +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-6> : tensor} : () -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[ADD]] : tensor +} + +func.func @consecutive_sub_sub(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.Sub"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Sub"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_sub_sub + +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-30> : tensor} : () -> tensor +// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[SUB]] : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir index 282d172373499d..fbc3639cd8ff47 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir @@ -120,3 +120,42 @@ module { // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" // CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} } + +// ----- + +module { + func.func @conv_with_maxpool(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quant.qcast"(%cst) : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>> + %1 = "quant.dcast"(%0) : (tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>>) -> tensor<*xf32> + %2 = "quant.qcast"(%arg0) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3x!quant.uniform> + %3 = "quant.dcast"(%2) : (tensor<1x2x2x3x!quant.uniform>) -> tensor<*xf32> + %4 = "tf.PartitionedCall"(%3, %1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "quant.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %6 = "quant.dcast"(%5) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %7 = "tf.AvgPool"(%6) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32> + func.return %7 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv_with_maxpool +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0 +// CHECK-SAME: f = @quantize_i8 +// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]] +// CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor, tensor) -> tensor<*xi8> +// CHECK: %[[cast_1:.*]] = "tf.Cast"(%[[conv_quant]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> +// CHECK: %[[avgpool:.*]] = "tf.AvgPool"(%[[cast_1]]) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[cast_2]] +// CHECK-SAME: f = @dequantize_i8 +// CHECK: return %[[dequantize]] + +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir new file mode 100644 index 00000000000000..eb73941c0acfa6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir @@ -0,0 +1,90 @@ +// Copyright 2022 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: tf-quant-opt %s -split-input-file -inline -quant-replace-cast-hacks-with-tf-xla-ops | FileCheck %s + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1177 : i32}} { + func.func @conv_with_bias_and_relu(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { + %cst = "tf.Const"() {value = dense<[162, 160]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() {value = dense<[[[[-85, 72], [23, -103], [-29, -96]], [[-128, -83], [81, -57], [67, 119]], [[44, 10], [-90, -107], [77, 122]]], [[[18, 61], [127, -20], [-107, 119]], [[12, -66], [-98, 15], [124, 9]], [[68, 119], [20, -52], [48, 123]]]]> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> + %cst_1 = "tf.Const"() {value = dense<0.587548196> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<18.1044273> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.0748551115> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.0439809859> : tensor} : () -> tensor + %cst_7 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %0 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<1x3x4x3xf32>, tensor, tensor) -> tensor<1x3x4x3xi8> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst, %cst_1, %cst_2, %cst_4, %cst_5, %cst_6, %cst_7, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu_fn_0} : (tensor<1x3x4x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x3x2x2xi8> + %2 = "tf.PartitionedCall"(%1, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} : (tensor<1x3x2x2xi8>, tensor, tensor) -> tensor<1x3x2x2xf32> + return %2 : tensor<1x3x2x2xf32> + } + func.func private @quantize_i8(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<1x3x4x3xi8> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1x3x4x3xf32>, tensor) -> tensor<1x3x4x3xf32> + %1 = "tf.Round"(%0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %2 = "tf.Cast"(%1) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xi32> + %3 = "tf.AddV2"(%2, %arg2) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3xi8> + return %4 : tensor<1x3x4x3xi8> + } + func.func private @dequantize_i8(%arg0: tensor<1x3x2x2xi8>, %arg1: tensor, %arg2: tensor) -> tensor<1x3x2x2xf32> { + %0 = "tf.Cast"(%arg0) : (tensor<1x3x2x2xi8>) -> tensor<1x3x2x2xi32> + %1 = "tf.Sub"(%0, %arg2) : (tensor<1x3x2x2xi32>, tensor) -> tensor<1x3x2x2xi32> + %2 = "tf.Cast"(%1) : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> + %3 = "tf.Mul"(%2, %arg1) : (tensor<1x3x2x2xf32>, tensor) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> + } + func.func private @quantized_conv2d_with_bias_and_relu_fn_0(%arg0: tensor<1x3x4x3xi8>, %arg1: tensor<2x3x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor) -> tensor<1x3x2x2xi8> { + %cst = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x3x4x3xi8>) -> tensor<1x3x4x3xi32> + %1 = "tf.Sub"(%0, %arg4) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %2 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x3x3x2xi8>) -> tensor<2x3x3x2xi32> + %3 = "tf.Sub"(%2, %arg6) : (tensor<2x3x3x2xi32>, tensor) -> tensor<2x3x3x2xi32> + %4 = "tf.Conv2D"(%1, %3) {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xi32>, tensor<2x3x3x2xi32>) -> tensor<1x3x2x2xi32> + %5 = "tf.AddV2"(%4, %arg2) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> + %6 = "tf.Mul"(%arg3, %arg5) : (tensor, tensor) -> tensor + %7 = "tf.Div"(%6, %arg9) : (tensor, tensor) -> tensor + %8 = "tf.Cast"(%5) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Mul"(%7, %8) : (tensor, tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %10 = "tf.Round"(%9) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %11 = "tf.Cast"(%10) {Truncate = false} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi32> + %12 = "tf.AddV2"(%11, %arg10) : (tensor<1x3x2x2xi32>, tensor) -> tensor<1x3x2x2xi32> + %13 = "tf.Maximum"(%cst_0, %arg10) : (tensor, tensor) -> tensor + %14 = "tf.ClipByValue"(%12, %13, %cst) : (tensor<1x3x2x2xi32>, tensor, tensor) -> tensor<1x3x2x2xi32> + %15 = "tf.Cast"(%14) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xi8> + return %15 : tensor<1x3x2x2xi8> + } + +// CHECK-LABEL: func @conv_with_bias_and_relu +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> +// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() +// CHECK-SAME{LITERAL}: {value = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> +// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor +// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<[-24320, -25984]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() {value = dense<[162, 160]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[PADV2_0:.*]] = "tf.PadV2"({{.*}}, %[[CONST_4]], %[[CONST_5]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor) -> tensor<1x4x5x3xi8> +// CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_6]], %[[CONST_0]], %[[CONST_3]], %[[CONST_1]], %[[CONST_1]], %[[CONST_2]]) +// CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<1x3x2x2xi32> +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> +// CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUB_0]], %[[CONST_8]]) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> +} + +// ----- diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/quant_spec.cc index 2d703c6e09891c..95bad8ee78d5fa 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/quant_spec.cc @@ -27,6 +27,7 @@ std::unique_ptr GetTfQuantScaleSpec(Operation* op) { if (llvm::isa< // clang-format off // go/keep-sorted start + TF::AvgPoolOp, TF::ConcatV2Op, TF::IdentityOp, TF::MaxPoolOp, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index a6d02d161a8f9f..a55c9e725f9a75 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2274,7 +2274,7 @@ Mutually reduces multiple tensors of identical type and shape. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", []> { +def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", [TF_CollectiveReduceOrderingEffect]> { let summary = [{ Mutually reduces multiple tensors of identical type and shape. }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index e0776ac946ecbb..b4d0a4d4e4c7a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -170,6 +170,7 @@ def TF_RecvResource : TF_ResourceBase<"Recv">; def TF_TPUExecuteResource : TF_ResourceBase<"TPUExecute">; def TF_RandomGeneratorResource : TF_ResourceBase<"RandomGenerator">; def TF_XlaHostComputeResource : TF_ResourceBase<"XlaHostCompute">; +def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">; // Fake resource, see `TF_MustExecute` below. def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">; @@ -234,6 +235,9 @@ def TF_TPUExecuteSideEffect : MemoryEffects<[MemWrite]>; def TF_RandomGeneratorSideEffect : MemoryEffects<[MemWrite]>; +// Special effect for keeping `CollectiveReduce` ops in order. +def TF_CollectiveReduceOrderingEffect : MemoryEffects<[MemWrite]>; + // Trait for enforcing that a side-effecting op is executed, even if it would be // considered dead by MLIR (see b/195782952). // The trait is implemented as a write effect for a fake resource which is diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 0b71989b784741..9299dad8474fe2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -94,6 +94,11 @@ struct MustExecute : public ::mlir::SideEffects::Resource::Base { StringRef getName() final { return "MustExecute"; } }; +struct CollectiveReduceOrdering + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "CollectiveReduceOrdering"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index aeb6a09b4a5d37..b116d49474b1a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1485,7 +1485,7 @@ func.func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { // CHECK-SAME: %[[VAL_1:.*]]: tensor<256xf32>) -> tensor<1xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64> // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<1> : tensor<1xi64> // CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> // CHECK: return %[[VAL_6]] : tensor<1xf32> @@ -1502,7 +1502,7 @@ func.func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) - // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> // CHECK: %[[VAL_4:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64> // CHECK: %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_4]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32> -// CHECK: %[[VAL_6:.*]] = "tf.BatchMatMulV2"(%[[VAL_3]], %[[VAL_5]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_5]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_7:.*]] = arith.constant dense<> : tensor<0xi64> // CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor // CHECK: return %[[VAL_8]] : tensor @@ -1515,7 +1515,7 @@ func.func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> // CHECK-LABEL: func @convert_dot_2d_2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x256xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { -// CHECK: %[[VAL_2:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_2:.*]] = "tf.BatchMatMulV3"(%[[VAL_0]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> // CHECK: return %[[VAL_2]] : tensor<1x1xf32> // CHECK: } func.func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { @@ -1558,7 +1558,7 @@ func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x // CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_6]]) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32> // CHECK: %[[VAL_8:.*]] = arith.constant dense<[3, 12, 4]> : tensor<3xi64> // CHECK: %[[VAL_9:.*]] = "tf.Reshape"(%[[VAL_5]], %[[VAL_8]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32> -// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_9]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[VAL_10:.*]] = "tf.BatchMatMulV3"(%[[VAL_7]], %[[VAL_9]]) {adj_x = false, adj_y = false} : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> // CHECK: %[[VAL_11:.*]] = arith.constant dense<[3, 5, 1, 4]> : tensor<4xi64> // CHECK: %[[VAL_12:.*]] = "tf.Reshape"(%[[VAL_10]], %[[VAL_11]]) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32> // CHECK: return %[[VAL_12]] : tensor<3x5x1x4xf32> @@ -1581,7 +1581,7 @@ func.func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4 // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant dense<[1, 1024]> : tensor<2xi64> // CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<1x1024xf32> -// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV2"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : {{.*}} -> tensor<1x1024xf32> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : {{.*}} -> tensor<1x1024xf32> // CHECK: %[[VAL_5:.*]] = arith.constant dense<[1, 1, 1024]> : tensor<3xi64> // CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : {{.*}} -> tensor<1x1x1024xf32> // CHECK: return %[[VAL_6]] : tensor<1x1x1024xf32> @@ -1599,6 +1599,149 @@ func.func @convert_dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tens func.return %0 : tensor<1x1x1024xf32> } +// CHECK-LABEL: func @convert_dot_general_int8( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x8xi8>) -> tensor<8xi32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<[1, 256]> : tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xi8>, tensor<2xi64>) -> tensor<1x256xi8> +// CHECK: %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32> +// CHECK: %[[VAL_5:.*]] = arith.constant dense<8> : tensor<1xi64> +// CHECK: %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x8xi32>, tensor<1xi64>) -> tensor<8xi32> +// CHECK: return %[[VAL_6]] : tensor<8xi32> +// CHECK: } +func.func @convert_dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi8>) -> tensor<8xi32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [0], + rhs_contracting_dimensions = [0]>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">] + } : (tensor<256xi8>, tensor<256x8xi8>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +// CHECK-LABEL: func.func @convert_conv1d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16> +// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16> +// CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16> +// CHECK: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16> +// CHECK: return %[[VAL_14]] : tensor<16x32x256xbf16> +// CHECK: } +func.func @convert_conv1d(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> + func.return %0 : tensor<16x32x256xbf16> +} + +// CHECK-LABEL: func.func @convert_conv1d_non_canonical_dimension_numbers( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[32, 16, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<32x16x256xbf16>, tensor<4xi64>) -> tensor<32x16x256x1xbf16> +// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[1, 0, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<32x16x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK: %[[VAL_6:.*]] = arith.constant dense<[256, 1, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<256x1x256xbf16>, tensor<4xi64>) -> tensor<256x1x256x1xbf16> +// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[1, 3, 2, 0]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<256x1x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16> +// CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16> +// CHECK-DAG: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<256x16x32x1xbf16> +// CHECK: %[[VAL_13:.*]] = arith.constant dense<[256, 16, 32]> : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<256x16x32x1xbf16>, tensor<3xi64>) -> tensor<256x16x32xbf16> +// CHECK: return %[[VAL_14]] : tensor<256x16x32xbf16> +// CHECK: } +func.func @convert_conv1d_non_canonical_dimension_numbers(%arg0: tensor<32x16x256xbf16>, %arg1: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[0, b, f]x[o, 0, i]->[f, b, 0]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<32x16x256xbf16>, tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> + func.return %0 : tensor<256x16x32xbf16> +} + +// CHECK-LABEL: func.func @no_convert_conv1d_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x?x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {stride = [1], pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<16x?x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> +// CHECK: return %[[VAL_2]] : tensor<16x?x256xbf16> +// CHECK: } +func.func @no_convert_conv1d_dynamic(%arg0: tensor<16x?x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x?x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> + func.return %0 : tensor<16x?x256xbf16> +} + +// CHECK-LABEL: func.func @no_convert_conv1d_feature_group_gt_1( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {stride = [1], pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<16x32x256xbf16>, tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> +// CHECK: return %[[VAL_2]] : tensor<16x32x128xbf16> +// CHECK: } +func.func @no_convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 2 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xbf16>, tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> + func.return %0 : tensor<16x32x128xbf16> +} + +// CHECK-LABEL: func.func @no_convert_conv1d_missing_windows_strides( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> +// CHECK: return %[[VAL_2]] : tensor<16x32x256xbf16> +// CHECK: } +func.func @no_convert_conv1d_missing_windows_strides(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">], + rhs_dilation = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> + func.return %0 : tensor<16x32x256xbf16> +} + // CHECK-LABEL: func @convert_conv2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { @@ -2972,7 +3115,7 @@ func.func @if(%arg0: tensor) -> (tensor) { // CHECK: return %[[RESULT]] : tensor<28x1x100xf32> // CHECK: } func.func @convert_dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<28x1x100xf32> { - %0 = "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor, tensor, tensor) -> tensor<28x1x100xf32> + %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor, tensor, tensor) -> tensor<28x1x100xf32> func.return %0 : tensor<28x1x100xf32> } @@ -3254,7 +3397,7 @@ func.func @const_quant() -> tensor<512x1x!quant.uniform, %arg1: tensor<256x!quant.uniform>) -> tensor<1xf32> { %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]} : (tensor<1x256xf32>, tensor<256x!quant.uniform>) -> tensor<1xf32> func.return %0 : tensor<1xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index cd5589dba39405..ef395810a59211 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1811,6 +1811,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr %cst_2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %cst_3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %0 = tf_executor.graph { + // CHECK: "tf.XlaConv"(%arg0, %arg1, %cst, %cst_0, %cst_1, %cst_2, %cst_3) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_CPU:0", dimension_numbers = "\18\012\01\02@\01P\01Z\01\02b\01\02", precision_config = "\0A\02\01\01"} : (tensor<*xf32>, tensor<*xf32>, tensor<1xi32>, tensor<1x2xi32>, tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<*xf32> %outputs, %control = tf_executor.island wraps "tf.XlaConv"(%arg0, %arg1, %cst, %cst_0, %cst_1, %cst_2, %cst_3) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_CPU:0", dimension_numbers = "\18\012\01\02@\01P\01Z\01\02b\01\02", precision_config = "\0A\02\01\01"} : (tensor<*xf32>, tensor<*xf32>, tensor<1xi32>, tensor<1x2xi32>, tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<*xf32> tf_executor.fetch %outputs : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 525449e1717093..2937c6b37aecc5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -98,12 +98,304 @@ LogicalResult GetConstantSplatValue(Value value, SplatValueType &splat_value) { return success(); } -class ConvertConvOp : public OpConversionPattern { +struct PermutationAndShape { + DenseIntElementsAttr permutation; + ShapedType shape; +}; + +// Returns a DenseIntElementsAttr for a permutation and the shape after +// applying the permutation to a given shape through a transpose. +PermutationAndShape GetPermutationAndTransposedShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter &rewriter) { + assert(permutation_array.size() == input_type.getRank()); + llvm::SmallVector transposed_shape(permutation_array.size()); + for (int64_t i = 0; i < permutation_array.size(); ++i) { + transposed_shape[i] = input_type.getDimSize(permutation_array[i]); + } + auto transposed_type = + RankedTensorType::get(transposed_shape, input_type.getElementType()); + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()), + permutation_array); + return {permutation, transposed_type}; +} + +// Returns the inverse permutation array for a permutation array. +llvm::SmallVector GetInversePermutationArray( + llvm::ArrayRef permutation_array) { + llvm::SmallVector inverse_permutation_array( + permutation_array.size()); + const auto permutation_array_size = permutation_array.size(); + for (int64_t i = 0; i < permutation_array_size; ++i) { + inverse_permutation_array[permutation_array[i]] = i; + } + return inverse_permutation_array; +} + +// Returns the DenseIntElementsAttr for an inverse permutation given a +// permutation_array. +DenseIntElementsAttr GetInversePermutation( + llvm::ArrayRef permutation_array, + ConversionPatternRewriter &rewriter) { + SmallVector inverse_permutation_array = + GetInversePermutationArray(permutation_array); + return DenseIntElementsAttr::get( + RankedTensorType::get(inverse_permutation_array.size(), + rewriter.getI64Type()), + inverse_permutation_array); +} + +// Returns a DenseIntElementsAttr for an inverse permutation and the shape after +// applying the inverse permutation to a given shape through a transpose. +PermutationAndShape GetInversePermutationAndShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter &rewriter) { + SmallVector inverse_permutation_array = + GetInversePermutationArray(permutation_array); + return GetPermutationAndTransposedShape(inverse_permutation_array, input_type, + rewriter); +} + +// Common functionality for ConvertConvOp classes. +template +struct ConvertNdConvOp { + bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) const { + if (!conv_op.lhs().getType().cast().hasStaticShape() || + !conv_op.rhs().getType().cast().hasStaticShape() || + !conv_op.getType().cast().hasStaticShape()) + return false; + + // All ones in "lhs_dilation" means this "mhlo.conv" op should be + // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp". + if (conv_op.lhs_dilation().hasValue()) { + auto lhs_dilation = conv_op.lhs_dilation().getValue(); + if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue() != 1) + return false; + } + + if (!conv_op.window_strides().hasValue() || conv_op.window_strides() + .getValue() + .getType() + .cast() + .getRank() != 1) + return false; + + auto num_spatial_dims = + conv_op.dimension_numbers().getInputSpatialDimensions().size(); + // TODO(b/158636600): Currently we don't support 3D Convolution. + if (num_spatial_dims != SupportedSpatialDims) return false; + + return true; + } +}; + +// Convert a 1-D convolution into a 2-D convolution (which TF supports) so that +// it can be rewritten by the pattern `Convert2DConvOp`. +class Convert1DConvOp : public OpConversionPattern, + ConvertNdConvOp<1> { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // + // Check that input is a supported 1d convolution. + // + + if (!IsSupportedConvOp(conv_op) || conv_op->getNumResults() != 1) + return rewriter.notifyMatchFailure(conv_op, "unsupported conv op."); + + const mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers(); + + // Group convolution is not supported yet. + const int64_t input_feature_dimension = dnums.getInputFeatureDimension(); + const int64_t input_channels = + conv_op.lhs().getType().cast().getDimSize( + input_feature_dimension); + const int64_t feature_group_count = conv_op.feature_group_count(); + if (feature_group_count != 1 && feature_group_count != input_channels) + return rewriter.notifyMatchFailure(conv_op, + "Group convolution is not supported,"); + + // + // Transpose and reshape the input and kernel + // + + // Reshape input image to add a new spatial dimension. + auto image_type = conv_op.lhs().getType().cast(); + SmallVector image_2d_shape(image_type.getShape().begin(), + image_type.getShape().end()); + image_2d_shape.push_back(1); + auto image_2d_type = + RankedTensorType::get(image_2d_shape, image_type.getElementType()); + auto image_2d_op = rewriter.create( + conv_op.getLoc(), image_2d_type, conv_op.lhs()); + + // Transpose image to get it into NWHC form (where H is the added dim). + SmallVector image_permutation = { + dnums.getInputBatchDimension(), dnums.getInputSpatialDimensions()[0], + 3, // The trailing dim that we added. + dnums.getInputFeatureDimension()}; + auto image_permutation_and_shape = GetPermutationAndTransposedShape( + image_permutation, image_2d_type, rewriter); + auto transposed_image_2d_op = rewriter.create( + conv_op.getLoc(), image_permutation_and_shape.shape, + image_2d_op->getResult(0), image_permutation_and_shape.permutation); + + // Reshape kernel to add a new spatial dimension. + auto kernel_type = conv_op.rhs().getType().cast(); + SmallVector kernel_2d_shape; + for (int64_t dim : kernel_type.getShape()) { + kernel_2d_shape.push_back(dim); + } + kernel_2d_shape.push_back(1); + auto kernel_2d_type = + RankedTensorType::get(kernel_2d_shape, kernel_type.getElementType()); + auto kernel_2d_op = rewriter.create( + conv_op.getLoc(), kernel_2d_type, conv_op.rhs()); + + // Transpose kernel to get it into WHIO form (where H is the added dim). + SmallVector kernel_permutation = { + dnums.getKernelSpatialDimensions()[0], + 3, // The trailing dim that we added. + dnums.getKernelInputFeatureDimension(), + dnums.getKernelOutputFeatureDimension()}; + auto kernel_permutation_and_shape = GetPermutationAndTransposedShape( + kernel_permutation, kernel_2d_type, rewriter); + auto transposed_kernel_2d_op = rewriter.create( + conv_op.getLoc(), kernel_permutation_and_shape.shape, + kernel_2d_op->getResult(0), kernel_permutation_and_shape.permutation); + + // + // Create 2d equivalents for 1d convolution attributes. + // + + // Window Strides + SmallVector window_strides_2d_array; + for (const auto v : conv_op.window_strides()->getValues()) { + window_strides_2d_array.emplace_back(v); + } + window_strides_2d_array.push_back(1); + auto window_strides_2d = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), + window_strides_2d_array); + + // Padding + SmallVector padding_2d_array; + for (const auto v : conv_op.padding().getValue().getValues()) { + padding_2d_array.emplace_back(v); + } + // The newly added spatial dimension requires zero left and right padding. + padding_2d_array.push_back(0); + padding_2d_array.push_back(0); + auto padding_2d = DenseIntElementsAttr::get( + RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array); + + // LHS dilation + SmallVector lhs_dilation_array_2d; + for (const auto v : + conv_op.lhs_dilation().getValue().getValues()) { + lhs_dilation_array_2d.emplace_back(v); + } + lhs_dilation_array_2d.push_back(1); + auto lhs_dilation_2d = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), + lhs_dilation_array_2d); + + // RHS dilation + SmallVector rhs_dilation_array_2d; + for (const auto v : + conv_op.rhs_dilation().getValue().getValues()) { + rhs_dilation_array_2d.emplace_back(v); + } + rhs_dilation_array_2d.push_back(1); + auto rhs_dilation_2d = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), + rhs_dilation_array_2d); + + // Window reversal is unsupported. + if (conv_op.window_reversal().hasValue() && + conv_op.window_reversal()->getValues()[0] == true) + return failure(); + auto window_reversal_2d = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), + SmallVector({0, 0})); + + // Precision config + if (!conv_op.precision_config().hasValue()) return failure(); + + // Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC + auto dnums_2d = + mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(), + /*inputBatchDimension=*/0, + /*inputFeatureDimension=*/3, + /*inputSpatialDimensions=*/{1, 2}, + /*kernelInputDimension=*/2, + /*kernelOutputDimension=*/3, + /*kernelSpatialDimensions=*/{0, 1}, + /*outputBatchDimension=*/0, + /*outputFeatureDimension=*/3, + /*outputSpatialDimensions=*/{1, 2}); + // + // Generate a 2-D convolution + // + + // Determine the 2-D convolution output shape. + auto output_type = conv_op->getResult(0).getType().cast(); + SmallVector output_2d_shape; + for (int64_t dim : output_type.getShape()) { + output_2d_shape.push_back(dim); + } + output_2d_shape.push_back(1); + auto output_2d_type = + RankedTensorType::get(output_2d_shape, output_type.getElementType()); + SmallVector output_permutation = { + dnums.getOutputBatchDimension(), dnums.getOutputSpatialDimensions()[0], + 3, // The trailing dim that we added. + dnums.getOutputFeatureDimension()}; + auto transposed_output_2d_shape = + GetPermutationAndTransposedShape(output_permutation, output_2d_type, + rewriter) + .shape; + + auto conv2d_op = rewriter.create( + conv_op.getLoc(), transposed_output_2d_shape, + transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(), + window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d, + window_reversal_2d, dnums_2d, conv_op.feature_group_count(), + conv_op.batch_group_count(), *conv_op.precision_config()); + + OpResult conv2d_output = conv2d_op->getResult(0); + auto conv2d_output_type = conv2d_output.getType().cast(); + + // + // Transpose and reshape the output + // + + // Since output is in NWHC form we need to undo the permutation we have + // affectively applied. + auto output_permutation_and_shape = GetInversePermutationAndShape( + output_permutation, conv2d_output_type, rewriter); + auto transposed_output_2d_op = rewriter.create( + conv_op.getLoc(), output_permutation_and_shape.shape, conv2d_output, + output_permutation_and_shape.permutation); + + // Drop the trailing spatial dimension from the output. + rewriter.replaceOpWithNewOp( + conv_op, output_type, transposed_output_2d_op.getResult()); + return success(); + } +}; + +class Convert2DConvOp : public OpConversionPattern, + ConvertNdConvOp<2> { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - mhlo::ConvOp conv_op, OpAdaptor adaptor, + mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { if (!IsSupportedConvOp(conv_op)) { return failure(); @@ -177,7 +469,7 @@ class ConvertConvOp : public OpConversionPattern { }; private: - bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims, + bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims, ArrayRef strides, ArrayRef dilation, ArrayRef padding_array) const { mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers(); @@ -310,7 +602,7 @@ class ConvertConvOp : public OpConversionPattern { start_attr, size_attr); } - void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef strides, + void CreateConvOp(mhlo::ConvolutionOp conv_op, ArrayRef strides, StringRef padding, ArrayRef explicit_padding, ArrayRef dilation, bool is_depthwise_conv, int input_channels, int num_spatial_dims, @@ -404,43 +696,15 @@ class ConvertConvOp : public OpConversionPattern { } rewriter.replaceOp(conv_op, {output}); } - - bool IsSupportedConvOp(mhlo::ConvOp conv_op) const { - if (!conv_op.lhs().getType().cast().hasStaticShape() || - !conv_op.rhs().getType().cast().hasStaticShape() || - !conv_op.getType().cast().hasStaticShape()) - return false; - - // All ones in "lhs_dilation" means this "mhlo.conv" op should be - // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp". - if (conv_op.lhs_dilation().hasValue()) { - auto lhs_dilation = conv_op.lhs_dilation().getValue(); - if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue() != 1) - return false; - } - - if (!conv_op.window_strides().hasValue() || conv_op.window_strides() - .getValue() - .getType() - .cast() - .getRank() != 1) - return false; - - auto num_spatial_dims = - conv_op.dimension_numbers().getInputSpatialDimensions().size(); - // TODO(b/158636600): Currently we don't support 3D Convolution. - if (num_spatial_dims != 2) return false; - - return true; - } }; -class ConvertNonTrivialConvOp : public OpConversionPattern { +class ConvertNonTrivialConvOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - mhlo::ConvOp conv_op, OpAdaptor adaptor, + mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { if (IsSupportedConvOp(conv_op, rewriter).failed()) { return rewriter.notifyMatchFailure( @@ -525,7 +789,7 @@ class ConvertNonTrivialConvOp : public OpConversionPattern { }; private: - bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims, + bool IsSamePadding(mhlo::ConvolutionOp conv_op, int num_spatial_dims, ArrayRef strides) const { for (auto i : llvm::seq(0, num_spatial_dims)) { int dim = i + 1; @@ -541,7 +805,7 @@ class ConvertNonTrivialConvOp : public OpConversionPattern { return true; } - LogicalResult IsSupportedConvOp(mhlo::ConvOp conv_op, + LogicalResult IsSupportedConvOp(mhlo::ConvolutionOp conv_op, ConversionPatternRewriter &rewriter) const { if (!conv_op.lhs().getType().cast().hasStaticShape() || !conv_op.rhs().getType().cast().hasStaticShape() || @@ -625,7 +889,7 @@ class ConvertNonTrivialConvOp : public OpConversionPattern { return success(); } - void CreateResizeBilinearOp(mhlo::ConvOp conv_op, + void CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op, llvm::ArrayRef output_sizes, bool align_corners, ConversionPatternRewriter &rewriter) const { @@ -645,7 +909,7 @@ class ConvertNonTrivialConvOp : public OpConversionPattern { rewriter.replaceOp(conv_op, {output}); } - LogicalResult MatchResizeOp(mhlo::ConvOp conv_op, bool &align_corners, + LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool &align_corners, llvm::SmallVector &output_sizes, ConversionPatternRewriter &rewriter) const { mhlo::ConvDimensionNumbersAttr dnums = conv_op.dimension_numbers(); @@ -1488,7 +1752,7 @@ Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs, lhs_dot_dimensions_info.FlattenedOutDimensionSize()}, llvm::ArrayRef{ rhs_dot_dimensions_info.FlattenedOutDimensionSize()}); - auto matmul = rewriter.create( + auto matmul = rewriter.create( loc, RankedTensorType::get(matmul_shape, result_type.getElementType()), lhs_flattend.getResult(), rhs_flattend.getResult()); auto reshaped = @@ -2492,42 +2756,6 @@ bool IsIotaAttr(ArrayRef arr, int64_t size) { return true; } -DenseIntElementsAttr GetInversePermutation( - llvm::ArrayRef permutation_array, - ConversionPatternRewriter &rewriter) { - llvm::SmallVector inverse_permutation_array( - permutation_array.size()); - const auto permutation_array_size = permutation_array.size(); - for (int64_t i = 0; i < permutation_array_size; ++i) { - inverse_permutation_array[permutation_array[i]] = i; - } - return DenseIntElementsAttr::get( - RankedTensorType::get(inverse_permutation_array.size(), - rewriter.getI64Type()), - inverse_permutation_array); -} - -struct PermutationAndShape { - DenseIntElementsAttr permutation; - ShapedType shape; -}; - -PermutationAndShape GetPermutationAndTransposedShape( - llvm::ArrayRef permutation_array, ShapedType input_type, - ConversionPatternRewriter &rewriter) { - assert(permutation_array.size() == input_type.getRank()); - llvm::SmallVector transposed_shape(permutation_array.size()); - for (int64_t i = 0; i < permutation_array.size(); ++i) { - transposed_shape[i] = input_type.getDimSize(permutation_array[i]); - } - auto transposed_type = - RankedTensorType::get(transposed_shape, input_type.getElementType()); - DenseIntElementsAttr permutation = DenseIntElementsAttr::get( - RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()), - permutation_array); - return {permutation, transposed_type}; -} - // Convert updates into canonical form as expected by tf.scatter ops. // // tf.scatter expects `update_window_dims` to be the trailing dimensions. @@ -3100,9 +3328,10 @@ void LegalizeHloToTf::runOnOperation() { void PopulateLegalizeHloToTfPatterns(RewritePatternSet *patterns, MLIRContext *context) { patterns->add< - ConvertAvgPoolOp, ConvertConvOp, ConvertNonTrivialConvOp, - ConvertDynamicSliceOp, ConvertDynamicUpdateSliceOp, ConvertGatherOp, - ConvertIfOp, ConvertMaxPoolOp, ConvertScatterAddOp, ConvertScatterMaxOp, + ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp, + ConvertNonTrivialConvOp, ConvertDynamicSliceOp, + ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp, + ConvertMaxPoolOp, ConvertScatterAddOp, ConvertScatterMaxOp, ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp, ConvertSliceOp, ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin, ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, ConvertReduceOpToTfAll, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index fe73a72859bd79..a114bbbd1e6baa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -116,7 +116,7 @@ def : Pat<(HLO_ConvertOp HLO_Tensor:$operand), foreach Mapping = [[HLO_AbsOp, TF_AbsOp], [HLO_BitcastConvertOp, TF_BitcastOp], [HLO_CeilOp, TF_CeilOp], - [HLO_CosOp, TF_CosOp], + [HLO_CosineOp, TF_CosOp], [HLO_ExpOp, TF_ExpOp], [HLO_Expm1Op, TF_Expm1Op], [HLO_FloorOp, TF_FloorOp], @@ -128,7 +128,7 @@ foreach Mapping = [[HLO_AbsOp, TF_AbsOp], [HLO_NegOp, TF_NegOp], [HLO_RealOp, TF_RealOp], [HLO_RsqrtOp, TF_RsqrtOp], - [HLO_SinOp, TF_SinOp], + [HLO_SineOp, TF_SinOp], [HLO_SignOp, TF_SignOp], [HLO_SqrtOp, TF_SqrtOp], [HLO_TanhOp, TF_TanhOp]] in diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 5ac18f0e4b0f0c..a8e2a9881591fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -186,10 +186,10 @@ void PrintPassPipeline(const mlir::PassManager& pass_manager, llvm::interleaveComma( pass_manager.getPasses(), passOS, [&](mlir::Pass& pass) { pass.printAsTextualPipeline(passOS); }); - os << "// configuration: -pass-pipeline='" << passOS.str() << "'"; - if (op->getContext()->isMultithreadingEnabled()) - os << " -mlir-disable-threading"; - os << " -verify-each"; + os << "{-# external_resources: { mlir_reproducer: { pipeline: \"" + << passOS.str() << "\", "; + os << "disable_threading: true, "; + os << "verify_each: true } } #-}"; os << "\n\n"; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 4e11d83de7b06c..d8cefd434f7cca 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -104,8 +104,8 @@ TEST(DumpCrashReproducerTest, Valid) { std::string expected_txt_module; { llvm::raw_string_ostream os(expected_txt_module); - os << "// configuration: -pass-pipeline='' -mlir-disable-threading " - "-verify-each\n\n"; + os << "{-# external_resources: { mlir_reproducer: { pipeline: \"\", " + "disable_threading: true, verify_each: true } } #-}\n\n"; module_ref->getOperation()->print(os, mlir::OpPrintingFlags().useLocalScope()); os.flush(); diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc index 6703609d32b0e0..683ae39ead8b34 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc @@ -748,6 +748,44 @@ class StridedSliceOpClusteringPolicy } }; +// -------------------------------------------------------------------------- // +// Gather Operations. +// -------------------------------------------------------------------------- // + +class GatherOpClusteringPolicy : public DefaultClusteringPolicy { + public: + GatherOpClusteringPolicy() + : DefaultClusteringPolicy(IsGatherOp(), ValueConstraint::kRank) {} + + private: + std::function IsGatherOp() { + return [](Operation* op) { + return mlir::isa(op); + }; + } +}; + +// -------------------------------------------------------------------------- // +// Scatter Operations. +// -------------------------------------------------------------------------- // + +class ScatterOpClusteringPolicy : public DefaultClusteringPolicy { + public: + ScatterOpClusteringPolicy() + : DefaultClusteringPolicy(IsScatterOp(), ValueConstraint::kRank) {} + + private: + std::function IsScatterOp() { + return [](Operation* op) { + return mlir::isa< + mlir::TF::ScatterNdOp, mlir::TF::TensorScatterAddOp, + mlir::TF::TensorScatterMaxOp, mlir::TF::TensorScatterMinOp, + mlir::TF::TensorScatterSubOp, mlir::TF::TensorScatterUpdateOp>(op); + }; + } +}; + } // namespace void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies, @@ -780,6 +818,11 @@ void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies, SqueezeOpClusteringPolicy>(); } + if (is_enabled(JitRtClusteringTier::kGatherScatter)) { + policies.Add(); + } + if (is_enabled(JitRtClusteringTier::kAll)) { policies.Add, %rhs: memref<4x4xf32>, // CHECK-SAME: algorithm = 13 : i64 // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 // CHECK-SAME: alpha_real = 1.000000e+00 : f64 + // CHECK-SAME: beta = 0.000000e+00 : f64 // CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64> // CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64> // CHECK-SAME: uid = 0 : i64 @@ -36,6 +37,7 @@ func.func @compute(%lhs: memref<4x4xf32>, %rhs: memref<4x4xf32>, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, + beta = 0.000000e+00 : f64, dot_dimension_numbers = #mhlo.dot, lhs_stride = 16 : i64, @@ -50,50 +52,3 @@ func.func @compute(%lhs: memref<4x4xf32>, %rhs: memref<4x4xf32>, // CHECK: func private @gemm(memref<4x4xf32>, memref<4x4xf32>, // CHECK-SAME: memref<4x4xf32>) // CHECK-SAME: attributes {rt.direct_custom_call = "xla.gpu.gemm"} - -// ----- - -// CHECK: @compute( -// CHECK: %[[LHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[RHS:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[OUT:[a-z0-9]+]]: memref<4x4xf32>, -// CHECK: %[[BIAS:[a-z0-9]+]]: memref<4x4xf32> -// CHECK: ) -func.func @compute(%lhs: memref<4x4xf32>, %rhs: memref<4x4xf32>, - %out: memref<4x4xf32>, %bias: memref<4x4xf32>) { - - // CHECK: call @gemm(%[[LHS]], %[[RHS]], %[[OUT]], %[[BIAS]]) - // CHECK-SAME: algorithm = 13 : i64 - // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 - // CHECK-SAME: alpha_real = 1.000000e+00 : f64 - // CHECK-SAME: beta = 1.000000e+00 : f64 - // CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64> - // CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64> - // CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64> - // CHECK-SAME: uid = 0 : i64 - // CHECK-SAME: (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, - // CHECK-SAME: memref<4x4xf32>) -> () - "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %out, %bias) - { - algorithm = 13 : i64, - alpha_imag = 0.000000e+00 : f64, - alpha_real = 1.000000e+00 : f64, - batch_size = 1 : i64, - beta = 1.000000e+00 : f64, - dot_dimension_numbers = #mhlo.dot, - lhs_stride = 16 : i64, - rhs_stride = 16 : i64 - } - : (memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () - - // CHECK-NEXT: return - func.return -} - -// CHECK: func private @gemm(memref<4x4xf32>, memref<4x4xf32>, -// CHECK-SAME: memref<4x4xf32>, memref<4x4xf32>) -// CHECK-SAME: attributes {rt.direct_custom_call = "xla.gpu.gemm.bias"} diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/async_conversion.mlir b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/async_conversion.mlir index b950c466c85aa7..d079fadb465808 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/async_conversion.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/async_conversion.mlir @@ -38,6 +38,7 @@ func.func @async(%memref: memref<4x4xf32>) { >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 @@ -73,6 +74,7 @@ func.func @async(%memref: memref<4x4xf32>) { >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 @@ -98,6 +100,7 @@ func.func @async(%memref: memref<4x4xf32>) { >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/blas.mlir b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/blas.mlir index 630b6330a18604..691ced396ecbfd 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/blas.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/blas.mlir @@ -38,6 +38,7 @@ func.func @gemm(%lhs: memref<3x4xf32>, %rhs: memref<4x5xf32>, %output:memref<3x5 >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 12, rhs_stride = 20 @@ -88,6 +89,7 @@ func.func @gemm_batch(%lhs: memref<42x3x4xf32>, %rhs: memref<4x5xf32>, %output:m >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 42, lhs_stride = 12, rhs_stride = 0 @@ -98,58 +100,6 @@ func.func @gemm_batch(%lhs: memref<42x3x4xf32>, %rhs: memref<4x5xf32>, %output:m "lmhlo.terminator"() : () -> () } -// CHECK: func @gemm_bias( -// CHECK-SAME: %arg0: !tfrt.chain, -// CHECK-SAME: %arg1: !tfrt_gpu.stream, -// CHECK-SAME: %arg2: !tfrt_gpu.buffer, -// CHECK-SAME: %arg3: !tfrt_gpu.buffer, -// CHECK-SAME: %arg4: !tfrt_gpu.buffer, -// CHECK-SAME: %arg5: !tfrt_gpu.buffer -// CHECK-SAME: ) -> !tfrt.chain -func.func @gemm_bias(%lhs: memref<3x4xf32>, %rhs: memref<4x5xf32>, - %bias: memref<3x5xf32>, %output:memref<3x5xf32>) { - // CHECK-NOT: cast - // CHECK-NOT: async.execute - - // CHECK: [[CHAIN0:%[0-9]+]] = tfrt_gpu.mem.copy %arg5, %arg4, %arg1, %arg0 - // CHECK-SAME: : !tfrt_gpu.buffer, !tfrt_gpu.buffer - - // CHECK-DAG: [[M:%[0-9]+]] = tfrt.constant.i32 3 - // CHECK-DAG: [[N:%[0-9]+]] = tfrt.constant.i32 5 - // CHECK-DAG: [[K:%[0-9]+]] = tfrt.constant.i32 4 - // CHECK-DAG: [[ALPHA:%[0-9]+]] = tfrt.constant.f32 5.000000e-01 - // CHECK-DAG: [[BETA:%[0-9]+]] = tfrt.constant.f32 1.000000e+00 - // CHECK: [[ALGO:%[0-9]+]] = tfrt_gpu.blas.gemm.algo CUBLAS_GEMM_DEFAULT - // CHECK: [[CONTEXT:%[0-9]+]] = tfrt_gpu.stream.get_context %arg1 - // CHECK: [[HANDLE:%[0-9]+]] = tfrt.once @tfrt_gpu.blas.create{{.*}}([[CONTEXT]]) - - // CHECK: [[CHAIN1:%[0-9]+]] = tfrt_gpu.blas.gemm [[HANDLE]], %arg1 - // CHECK-SAME: CUBLAS_OP_N, CUBLAS_OP_N, [[N]], [[M]], [[K]], [[ALPHA]], - // CHECK-SAME: %arg3, CUDA_R_32F, [[N]], - // CHECK-SAME: %arg2, CUDA_R_32F, [[K]], [[BETA]], - // CHECK-SAME: %arg5, CUDA_R_32F, [[N]], - // CHECK-SAME: CUBLAS_COMPUTE_32F, [[ALGO]], [[CHAIN0]] - - "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [], - rhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [0] - >, - alpha_real = 0.5, - alpha_imag = 0.0, - beta = 1.0, - batch_size = 1, - lhs_stride = 12, - rhs_stride = 20 - } : (memref<3x4xf32>, memref<4x5xf32>, memref<3x5xf32>, memref<3x5xf32>) -> () - - // CHECK-NOT: cast - // CHECK: tfrt.return [[CHAIN1]] : !tfrt.chain - "lmhlo.terminator"() : () -> () -} - // CHECK: func @triangular_solve( // CHECK-SAME: %arg0: !tfrt.chain, // CHECK-SAME: %arg1: !tfrt_gpu.stream, diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/common.mlir b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/common.mlir index f053ae8bcd2bff..4875fe1087b4ac 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/common.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt/common.mlir @@ -27,6 +27,7 @@ func.func @view(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<100 >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 20, rhs_stride = 20 @@ -61,6 +62,7 @@ func.func @reinterpret_cast(%lhs: memref<5x4xf32, affine_map<(d0, d1) -> (d0 + d >, alpha_real = 0.5, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 20, rhs_stride = 20 @@ -91,6 +93,7 @@ func.func @two_ops(%memref: memref<4x4xf32>) { >, alpha_real = 3.14159274, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 @@ -107,6 +110,7 @@ func.func @two_ops(%memref: memref<4x4xf32>) { >, alpha_real = 2.71828175, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 @@ -136,6 +140,7 @@ func.func @return(%memref: memref<4x4xf32>) -> memref<4x4xf32> { >, alpha_real = 1.0, alpha_imag = 0.0, + beta = 0.0, batch_size = 1, lhs_stride = 16, rhs_stride = 16 diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc index 841bbe7ca53c1a..16bc1f05aed571 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gemm_pattern.cc @@ -89,9 +89,6 @@ void MakeBlasGemmCompatible(int64_t& m, int64_t& n, MatrixDescriptor& lhs, } } -Value GetBias(lmhlo_gpu::GEMMOpAdaptor op) { return nullptr; } -Value GetBias(lmhlo_gpu::GEMM_BiasOpAdaptor op) { return op.getBias(); } - // Match GEMM auto-tuning, see ComputationTypeFromPrimitive() Type MlirComputationType(Type element_type, ConversionPatternRewriter& rewriter) { @@ -107,8 +104,8 @@ Type MlirComputationType(Type element_type, } // Gets the platform specific Gemm algorithm value. -template -tfrt::gpu::wrapper::BlasGemmAlgo GetBlasGemmAlgoOrDefault(GemmOp op) { +tfrt::gpu::wrapper::BlasGemmAlgo GetBlasGemmAlgoOrDefault( + lmhlo_gpu::GEMMOp op) { if (!op.getAlgorithm().hasValue()) return kBlasGemmDefaultAlgo; return {static_cast(op.getAlgorithm().getValue()), kGpuTargetPlatform}; } @@ -121,18 +118,14 @@ tfrt::gpu::wrapper::BlasOperation MatrixTransposeToBlasOperation( // Create all the Ops necessary for the GEMM operation, including the GEMM // operation itself. -template -Value CreateTfrtOps(GemmOp op, typename GemmOp::Adaptor adaptor, Value chain, - Value stream, mlir::Type input_type, mlir::Type output_type, - int64_t batch_size, int64_t m, int64_t n, int64_t k, - const MatrixDescriptor& lhs, const MatrixDescriptor& rhs, - const MatrixDescriptor& output, xla::complex128 alpha, - double beta, ConversionPatternRewriter& rewriter) { +Value CreateTfrtOps(lmhlo_gpu::GEMMOp op, lmhlo_gpu::GEMMOp::Adaptor adaptor, + Value chain, Value stream, mlir::Type input_type, + mlir::Type output_type, int64_t batch_size, int64_t m, + int64_t n, int64_t k, const MatrixDescriptor& lhs, + const MatrixDescriptor& rhs, const MatrixDescriptor& output, + xla::complex128 alpha, double beta, + ConversionPatternRewriter& rewriter) { auto loc = op.getLoc(); - if (auto bias = GetBias(adaptor)) { - chain = rewriter.create(loc, adaptor.getOutput(), - bias, stream, chain); - } const Type mlir_compute_type = MlirComputationType(output_type, rewriter); @@ -198,16 +191,15 @@ Value CreateTfrtOps(GemmOp op, typename GemmOp::Adaptor adaptor, Value chain, .getResult(); } -template -FailureOr GemmOpConversionRewrite(GemmOp op, - typename GemmOp::Adaptor adaptor, +FailureOr GemmOpConversionRewrite(lmhlo_gpu::GEMMOp op, + lmhlo_gpu::GEMMOp::Adaptor adaptor, Value chain, Value stream, ConversionPatternRewriter& rewriter) { auto get_element_type = [](Value value) { return value.getType().cast().getElementType(); }; - if (get_element_type(op.getLhs()) != get_element_type(op.getRhs())) { + if (get_element_type(op.getA()) != get_element_type(op.getB())) { return rewriter.notifyMatchFailure(op, "Input element type mismatch."); } @@ -221,29 +213,26 @@ FailureOr GemmOpConversionRewrite(GemmOp op, int64_t m = config->output_layout.num_rows; int64_t n = config->output_layout.num_cols; int64_t k = config->lhs_layout.num_cols; - MatrixDescriptor lhs = GetMatrixDesc(config->lhs_layout, adaptor.getLhs()); - MatrixDescriptor rhs = GetMatrixDesc(config->rhs_layout, adaptor.getRhs()); + MatrixDescriptor lhs = GetMatrixDesc(config->lhs_layout, adaptor.getA()); + MatrixDescriptor rhs = GetMatrixDesc(config->rhs_layout, adaptor.getB()); MatrixDescriptor output = - GetMatrixDesc(config->output_layout, adaptor.getOutput()); + GetMatrixDesc(config->output_layout, adaptor.getC()); int64_t batch_size = config->output_layout.batch_size; MakeBlasGemmCompatible(m, n, lhs, rhs, output); - return CreateTfrtOps(op, adaptor, chain, stream, - get_element_type(op.getLhs()), - get_element_type(op.getOutput()), batch_size, m, n, k, - lhs, rhs, output, config->alpha, config->beta, rewriter); + return CreateTfrtOps(op, adaptor, chain, stream, get_element_type(op.getA()), + get_element_type(op.getC()), batch_size, m, n, k, lhs, + rhs, output, config->alpha, config->beta, rewriter); } -template struct GemmRewritePattern - : tfrt::gpu::StreamifyOpConversionPattern { - using typename tfrt::gpu::StreamifyOpConversionPattern::OpAdaptor; + : tfrt::gpu::StreamifyOpConversionPattern { using tfrt::gpu::StreamifyOpConversionPattern< - GemmOpType>::StreamifyOpConversionPattern; + lmhlo_gpu::GEMMOp>::StreamifyOpConversionPattern; FailureOr matchAndRewriteOp( - GemmOpType op, OpAdaptor adaptor, Value chain, Value stream, - ConversionPatternRewriter& rewriter) const override { + lmhlo_gpu::GEMMOp op, lmhlo_gpu::GEMMOp::Adaptor adaptor, Value chain, + Value stream, ConversionPatternRewriter& rewriter) const override { auto result = GemmOpConversionRewrite(op, adaptor, chain, stream, rewriter); if (succeeded(result)) rewriter.eraseOp(op); return result; @@ -254,9 +243,7 @@ struct GemmRewritePattern void populateGemmConversionPattern(RewritePatternSet& patterns, TypeConverter& converter) { - patterns.add, - GemmRewritePattern>( - converter, patterns.getContext()); + patterns.add(converter, patterns.getContext()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc index a2e8c5483f2de0..4cfa1defa16e36 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc @@ -106,7 +106,6 @@ using mlir::lmhlo_gpu::ConvForwardFusedOp; using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; using mlir::lmhlo_gpu::ConvForwardOp; using mlir::lmhlo_gpu::ConvolutionBackendConfigAttr; -using mlir::lmhlo_gpu::GEMM_BiasOp; using mlir::lmhlo_gpu::GEMMOp; using mlir::memref::AllocaOp; using mlir::memref::GetGlobalOp; @@ -360,24 +359,15 @@ class GemmUidGenerator { std::atomic cnt_; }; -template -class GemmLowering : public OpRewritePattern { +class GemmOpLowering : public OpRewritePattern { private: static StringRef CustomCallTarget(GEMMOp) { return "xla.gpu.gemm"; } - static StringRef CustomCallTarget(GEMM_BiasOp) { return "xla.gpu.gemm.bias"; } - - static void SetOptionalAttrs(ImplicitLocOpBuilder& b, GEMMOp op, - CallOp call) {} - static void SetOptionalAttrs(ImplicitLocOpBuilder& b, GEMM_BiasOp op, - CallOp call) { - call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); - } public: - GemmLowering(MLIRContext* ctx, GemmUidGenerator& uid) - : OpRewritePattern(ctx), uid_(uid) {} + GemmOpLowering(MLIRContext* ctx, GemmUidGenerator& uid) + : OpRewritePattern(ctx), uid_(uid) {} - LogicalResult matchAndRewrite(Gemm op, + LogicalResult matchAndRewrite(GEMMOp op, PatternRewriter& rewriter) const override { MLIRContext* ctx = this->getContext(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -408,12 +398,13 @@ class GemmLowering : public OpRewritePattern { call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); // Copy backend specific attributes. - call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); + auto algorithm_attr = op.getAlgorithm() + ? op.getAlgorithmAttr() + : b.getI64IntegerAttr(se::blas::kDefaultGemmAlgo); + call->setAttr(b.getStringAttr("algorithm"), algorithm_attr); call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); - - // Set optional arguments that are defined only for some Gemm ops. - SetOptionalAttrs(b, op, call); + call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); // TODO(ezhulenev): Once cutom calls support passing structured attributes // we should be able to pass `mhlo.dot` attribute directly. @@ -438,16 +429,6 @@ class GemmLowering : public OpRewritePattern { GemmUidGenerator& uid_; }; -class GemmOpLowering : public GemmLowering { - public: - using GemmLowering::GemmLowering; -}; - -class GemmBiasOpLowering : public GemmLowering { - public: - using GemmLowering::GemmLowering; -}; - // -------------------------------------------------------------------------- // template @@ -1517,7 +1498,7 @@ void ConvertLmhloGpuToJitRtPass::runOnOperation() { // Each unique Gemm operation in the module will get assigned a uid. GemmUidGenerator gemm_uid; - patterns.insert(ctx, gemm_uid); + patterns.insert(ctx, gemm_uid); // Assign shared unique id to each unique pair of async start-done operations, // all other collective operations will get assigned uid. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_tfrt_gpu.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_tfrt_gpu.cc index 6cf7b145fa5f14..dfbb9614d17a4a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_tfrt_gpu.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_tfrt_gpu.cc @@ -32,18 +32,18 @@ namespace tensorflow { void populateLmhloToTfrtGpuPasses(mlir::OpPassManager &pm) { pm.addPass(tensorflow::createConvertLmhloToGpuBranchPass()); - pm.addPass(tfrt::gpu::CreateStreamifyOpsPass< - lmhlo::AllGatherOp, lmhlo::AllReduceOp, lmhlo::ReduceScatterOp, - lmhlo::AllToAllOp, lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp, - lmhlo::TriangularSolveOp, lmhlo::ReplicaIdOp, lmhlo::PartitionIdOp, - lmhlo::InfeedOp, lmhlo::OutfeedOp, lmhlo::FftOp, - lmhlo_gpu::ConvForwardOp, lmhlo_gpu::ConvBackwardInputOp, - lmhlo_gpu::ConvBackwardFilterOp, lmhlo_gpu::ConvForwardFusedOp, - lmhlo_gpu::ConvForwardFusedSideInputOp, lmhlo_gpu::GEMMOp, - lmhlo_gpu::GEMM_BiasOp, lmhlo_gpu::CholeskyOp, - lmhlo_gpu::AllReduceStartOp, lmhlo_gpu::AllReduceDoneOp, - mlir::func::CallOp, mlir::memref::LoadOp, tfrt::compiler::CallOp, - tfrt::compiler::WhileOp>()); + pm.addPass( + tfrt::gpu::CreateStreamifyOpsPass< + lmhlo::AllGatherOp, lmhlo::AllReduceOp, lmhlo::ReduceScatterOp, + lmhlo::AllToAllOp, lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp, + lmhlo::TriangularSolveOp, lmhlo::ReplicaIdOp, lmhlo::PartitionIdOp, + lmhlo::InfeedOp, lmhlo::OutfeedOp, lmhlo::FftOp, + lmhlo_gpu::ConvForwardOp, lmhlo_gpu::ConvBackwardInputOp, + lmhlo_gpu::ConvBackwardFilterOp, lmhlo_gpu::ConvForwardFusedOp, + lmhlo_gpu::ConvForwardFusedSideInputOp, lmhlo_gpu::GEMMOp, + lmhlo_gpu::CholeskyOp, lmhlo_gpu::AllReduceStartOp, + lmhlo_gpu::AllReduceDoneOp, mlir::func::CallOp, mlir::memref::LoadOp, + tfrt::compiler::CallOp, tfrt::compiler::WhileOp>()); pm.addPass(tensorflow::createConvertLmhloToGpuPass()); pm.addPass(mlir::createGpuAsyncRegionPass()); tfrt::gpu::PopulateGpuToTfrtGpuPasses(pm); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 11717f2d70e1c8..90dd9ca8d3a864 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -102,6 +102,7 @@ Optional TFAllocOp::buildClone(OpBuilder &builder, Value alloc) { //===----------------------------------------------------------------------===// // JITExecuteOp //===----------------------------------------------------------------------===// + Optional JITExecuteOp::buildDealloc(OpBuilder &builder, Value alloc) { auto funcop = alloc.getParentRegion()->getParentOfType(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 55f9f6cd298663..517c6f68ed522f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -304,10 +304,10 @@ def TFFramework_JITCompileFromStrOp : TFFramework_Op<"jit_compile_from_str", //===----------------------------------------------------------------------===// def TFFramework_JITExecuteOp : TFFramework_Op<"jit_execute", [ - AttrSizedOperandSegments, + AttrSizedOperandSegments, MemoryEffects<[MemAlloc]>, DeclareOpInterfaceMethods]> { + ["buildDealloc", "buildClone"]>]> { let summary = "Executes a JIT-compiled function through the TF framework"; let description = [{ The op takes an optional TF context, so that it can be added at a later @@ -325,17 +325,12 @@ def TFFramework_JITExecuteOp : TFFramework_Op<"jit_execute", [ Variadic>:$operands ); let results = (outs - Variadic>:$results - ); - - let builders = [ - OpBuilder<(ins "Optional":$ctx, "Value":$callable, - "ValueRange":$operands)> - ]; + Res , "", + [MemAlloc]>:$result); let assemblyFormat = [{ (`ctx` `(` $ctx^ `)`)? $callable `(` $operands `)` attr-dict - `:` type($operands) `->` type($results) + `:` type($operands) `->` type($result) }]; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_deallocation.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_deallocation.mlir new file mode 100644 index 00000000000000..0f0f92f90a2b93 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/buffer_deallocation.mlir @@ -0,0 +1,31 @@ +// RUN: kernel-gen-opt %s --buffer-deallocation | FileCheck %s + +// CHECK-LABEL: @jit_execute_allocation +// CHECK-SAME: %[[CTX:.*]]: !tf_framework.op_kernel_context, %[[ARG:.*]]: memref<*xf32>, %[[PRED:.*]]: i1, %[[CALLABLE:.*]]: !tf_framework.jit_callable, %[[SIZE:.*]]: index, %[[SHAPE:.*]]: memref +func.func @jit_execute_allocation(%ctx: !tf_framework.op_kernel_context, + %arg: memref<*xf32>, %pred: i1, %callable: !tf_framework.jit_callable, + %size: index, %shape: memref) -> memref<*xf32> { + // CHECK: %[[RES:.*]] = scf.if %[[PRED]] + // CHECK: %[[JIT_EXECUTE:.*]] = tf_framework.jit_execute ctx(%[[CTX]]) %[[CALLABLE]](%[[ARG]]) + // CHECK: %[[INNER_RES:.*]] = bufferization.clone %[[JIT_EXECUTE]] + // CHECK: tf_framework.dealloc(%[[CTX]], %[[JIT_EXECUTE]]) + // CHECK: scf.yield %[[INNER_RES]] + // CHECK: else + // CHECK: %[[ALLOC:.*]] = tf_framework.alloc(%[[CTX]], %[[SIZE]]) + // CHECK: %[[RESHAPE:.*]] = memref.reshape %[[ALLOC]](%[[SHAPE]]) + // CHECK: %[[INNER_RES:.*]] = bufferization.clone %[[RESHAPE]] + // CHECK: tf_framework.dealloc(%[[CTX]], %[[ALLOC]]) + // CHECK: scf.yield %[[INNER_RES]] + // CHECK: return %[[RES]] + %res = scf.if %pred -> (memref<*xf32>) { + %inner_res = tf_framework.jit_execute ctx(%ctx) %callable(%arg) + : memref<*xf32> -> memref<*xf32> + scf.yield %inner_res : memref<*xf32> + } else { + %alloc = tf_framework.alloc(%ctx, %size) : memref + %inner_res = memref.reshape %alloc(%shape) + : (memref, memref) -> memref<*xf32> + scf.yield %inner_res : memref<*xf32> + } + return %res : memref<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir index 5935df60d3d248..f85658e6cea79d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir @@ -3,9 +3,9 @@ // RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false" | \ // RUN: FileCheck %s -// CHECK-LABEL: @unary_tanh_rint +// CHECK-LABEL: @unary_tanh // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @unary_tanh_rint(%arg : tensor<*xf32>) -> tensor<*xf32> { +func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str // CHECK-SAME: " // CHECK-SAME: module { @@ -13,8 +13,7 @@ func.func @unary_tanh_rint(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK-SAME: attributes {tf_entry} // CHECK-SAME: { // CHECK-SAME: %0 = \22tf.Tanh\22(%arg0) - // CHECK-SAME: %1 = \22tf.Rint\22(%0) - // CHECK-SAME: return %1 + // CHECK-SAME: return %0 // CHECK-SAME: } // CHECK-SAME: } // CHECK-SAME: " @@ -28,8 +27,7 @@ func.func @unary_tanh_rint(%arg : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG]]) // CHECK: return %[[RES]] %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> - %1 = "tf.Rint"(%0) : (tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + func.return %0 : tensor<*xf32> } // ----- diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index d54a92fcda93d0..5f321907b0baf7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -43,13 +43,9 @@ struct BufferizeJITExecuteOp LogicalResult matchAndRewrite( tf_framework::JITExecuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector result_types; - if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), - result_types))) { - return failure(); - } + Type result_ty = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op, result_types, adaptor.getOperands(), op->getAttrs()); + op, result_ty, adaptor.getOperands(), op->getAttrs()); return success(); } }; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc index 228860e129ce1c..f83072f4df1d1b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -148,7 +148,7 @@ struct JITExecuteOpConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { llvm::Optional ctx = FindOpKernelContext(op); if (!ctx) return failure(); - rewriter.replaceOpWithNewOp(op, op.getResultTypes(), *ctx, + rewriter.replaceOpWithNewOp(op, op.result().getType(), *ctx, op.callable(), op.operands()); return success(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index c01d8fa2b23de2..47ccc08253e1b7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -322,11 +322,10 @@ class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern { LogicalResult matchAndRewrite( JITExecuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // The TF context must be known for a successful lowering. Also, we support - // only one result. - if (adaptor.ctx() == nullptr || op.operands().empty() || - op.getNumResults() != 1) + // The TF context must be known for a successful lowering. + if (adaptor.ctx() == nullptr || op.operands().empty()) { return failure(); + } // Allocate result on stack. auto loc = op.getLoc(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc index 8436b639106393..542f25731203a6 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc @@ -52,104 +52,130 @@ namespace { constexpr int64_t i32BitLimit = 4294967296; using shape::ShapeOfOp; -bool IsTFOperation(Operation *op) { - return op != nullptr && - op->getDialect() == - op->getContext()->getLoadedDialect(); +bool IsSingleResultTFOperation(Operation *op) { + assert(op != nullptr && "expect op"); + if (op->getDialect() != + op->getContext()->getLoadedDialect()) + return false; + if (op->getNumResults() != 1) return false; + return true; } bool IsUnaryTFOperation(Operation *op) { - return IsTFOperation(op) && op->getNumOperands() == 1; + return IsSingleResultTFOperation(op) && op->getNumOperands() == 1; } -struct ModuleParameters { - llvm::ArrayRef tile_sizes; - llvm::ArrayRef unroll_factors; - int64_t max_supported_rank; - bool index_64bit; - bool cpu_codegen; -}; - struct TFToJITInvocationsPattern : public RewritePattern { explicit TFToJITInvocationsPattern(MLIRContext *ctx) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - // Apply to all TF ops except those that are already in a JIT-compiled - // region. - if (!IsTFOperation(op) || op->getParentOfType()) + // Apply to all single result TF ops except those that are already in a + // JIT-compiled region. + if (!IsSingleResultTFOperation(op) || + op->getParentOfType()) return failure(); - // Find last TF op. - while (IsTFOperation(op->getNextNode())) op = op->getNextNode(); - - // Find JIT compile region operands and results. - SmallVector cluster; - llvm::SmallPtrSet operand_set, result_set; - Operation *it = op; - while (IsTFOperation(it)) { - // Find results that escape the JIT compile region. - for (auto &use : it->getUses()) { - if (!llvm::is_contained(cluster, use.getOwner())) - result_set.insert(use.get()); - } - - // Update JIT region operands and results. - for (Value v : it->getResults()) operand_set.erase(v); - for (Value v : it->getOperands()) operand_set.insert(v); - - cluster.push_back(it); - it = it->getPrevNode(); - } - - // Introduce order to the operands and results. - auto operands = llvm::to_vector<16>(operand_set); - auto results = llvm::to_vector<16>(result_set); - auto operand_types = llvm::to_vector<16>( - llvm::map_range(operands, [](Value v) { return v.getType(); })); - auto result_types = llvm::to_vector<16>( - llvm::map_range(results, [](Value v) { return v.getType(); })); + Location loc = op->getLoc(); + Value op_result = op->getResults().front(); // Create the JIT compile op. - auto loc = op->getLoc(); auto jit_compile_op = rewriter.create( - loc, rewriter.getType(), llvm::None); + loc, rewriter.getType(), + /*ctx=*/llvm::None); - // Move the TF operations into the new op's body. - BlockAndValueMapping bvm; + // Move the TF operation into the body. { OpBuilder::InsertionGuard guard(rewriter); - Block *block = - rewriter.createBlock(&jit_compile_op.body(), {}, operand_types, - SmallVector(operands.size(), loc)); - for (auto it : llvm::zip(operands, block->getArguments())) + llvm::SmallVector locs(op->getNumOperands(), loc); + Block *block = rewriter.createBlock(&jit_compile_op.body(), {}, + op->getOperandTypes(), locs); + + // Map operands. + BlockAndValueMapping bvm; + for (auto it : llvm::zip(op->getOperands(), block->getArguments())) bvm.map(std::get<0>(it), std::get<1>(it)); + rewriter.setInsertionPointToStart(block); - for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm); - auto mapped_results = llvm::to_vector<16>( - llvm::map_range(results, [&](Value v) { return bvm.lookup(v); })); - rewriter.create(loc, TypeRange{}, - mapped_results); + rewriter.clone(*op, bvm); + rewriter.create(loc, + bvm.lookup(op_result)); } // Create JIT execute op. - auto jit_execute_op = rewriter.create( - loc, result_types, Value(), jit_compile_op.result(), operands); - - // Replace old TF ops with the new results. - for (auto it : llvm::zip(results, jit_execute_op.results())) - bvm.map(std::get<0>(it), std::get<1>(it)); - for (Operation *it : cluster) { - if (it->getUses().empty()) { - rewriter.eraseOp(it); - continue; - } - auto replacements = llvm::to_vector<16>(llvm::map_range( - it->getResults(), [&](Value v) { return bvm.lookup(v); })); - rewriter.replaceOp(it, replacements); + rewriter.replaceOpWithNewOp( + op, op_result.getType(), /*ctx=*/Value(), jit_compile_op.result(), + op->getOperands()); + return success(); + } +}; + +struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern { + explicit TFToI64JITInvocationForLargeTensorsPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!IsUnaryTFOperation(op) || + !llvm::isa(op->getParentOp())) { + return failure(); } + auto results = llvm::to_vector<16>(op->getResults()); + auto operand_types = llvm::to_vector<16>(llvm::map_range( + op->getOperands(), [](Value v) { return v.getType(); })); + auto result_types = llvm::to_vector<16>( + llvm::map_range(results, [](Value v) { return v.getType(); })); + + // Create the JIT compile op. + auto loc = op->getLoc(); + Value shape_size_limit = + rewriter.create(loc, i32BitLimit); + auto arg = op->getOperands().front(); + auto shape = rewriter.create(loc, arg); + auto num_elems = rewriter.create(loc, shape); + Value coniditon_check_main = rewriter.create( + loc, arith::CmpIPredicate::sgt, num_elems, shape_size_limit); + + Value conditional_path = + rewriter + .create( + loc, op->getResultTypes(), coniditon_check_main, + [&](OpBuilder &b, Location l) { + auto jit_compile_op = + rewriter.create( + loc, + rewriter.getType(), + llvm::None); + BlockAndValueMapping bvm; + { + OpBuilder::InsertionGuard guard(rewriter); + Block *block = rewriter.createBlock( + &jit_compile_op.body(), {}, operand_types, + SmallVector(operand_types.size(), loc)); + for (auto it : + llvm::zip(op->getOperands(), block->getArguments())) + bvm.map(std::get<0>(it), std::get<1>(it)); + rewriter.setInsertionPointToStart(block); + rewriter.clone(*op, bvm); + auto new_op = rewriter.clone(*op, bvm); + rewriter.create( + loc, TypeRange{}, new_op->getResults()); + } + auto jit_execute_op = + rewriter.create( + loc, result_types, Value(), jit_compile_op.result(), + op->getOperands()); + b.create(l, jit_execute_op.result()); + }, + [&](OpBuilder &b, Location l) { + auto new_op = rewriter.clone(*op); + b.create(l, new_op->getResult(0)); + }) + .getResult(0); + + rewriter.replaceOp(op, conditional_path); return success(); } }; @@ -259,74 +285,6 @@ struct TFToJITInvocationPass } }; -struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern { - explicit TFToI64JITInvocationForLargeTensorsPattern(MLIRContext *ctx) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (!IsUnaryTFOperation(op) || - !llvm::isa(op->getParentOp())) { - return failure(); - } - - auto results = llvm::to_vector<16>(op->getResults()); - auto operand_types = llvm::to_vector<16>(llvm::map_range( - op->getOperands(), [](Value v) { return v.getType(); })); - auto result_types = llvm::to_vector<16>( - llvm::map_range(results, [](Value v) { return v.getType(); })); - - // Create the JIT compile op. - auto loc = op->getLoc(); - Value shape_size_limit = - rewriter.create(loc, i32BitLimit); - auto arg = op->getOperands().front(); - auto shape = rewriter.create(loc, arg); - auto num_elems = rewriter.create(loc, shape); - Value coniditon_check_main = rewriter.create( - loc, arith::CmpIPredicate::sgt, num_elems, shape_size_limit); - - Value conditional_path = - rewriter - .create( - loc, op->getResultTypes(), coniditon_check_main, - [&](OpBuilder &b, Location l) { - auto jit_compile_op = - rewriter.create( - loc, - rewriter.getType(), - llvm::None); - BlockAndValueMapping bvm; - { - OpBuilder::InsertionGuard guard(rewriter); - Block *block = rewriter.createBlock( - &jit_compile_op.body(), {}, operand_types, - SmallVector(operand_types.size(), loc)); - for (auto it : - llvm::zip(op->getOperands(), block->getArguments())) - bvm.map(std::get<0>(it), std::get<1>(it)); - rewriter.setInsertionPointToStart(block); - rewriter.clone(*op, bvm); - auto new_op = rewriter.clone(*op, bvm); - rewriter.create( - loc, TypeRange{}, new_op->getResults()); - } - auto jit_execute_op = - rewriter.create( - loc, result_types, Value(), jit_compile_op.result(), - op->getOperands()); - b.create(l, jit_execute_op.results()); - }, - [&](OpBuilder &b, Location l) { - auto new_op = rewriter.clone(*op); - b.create(l, new_op->getResult(0)); - }) - .getResult(0); - - rewriter.replaceOp(op, conditional_path); - return success(); - } -}; } // namespace void PopulateTFToJITInvocationPatterns( diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index ef36785515da8d..57853898123d2e 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -1059,14 +1059,16 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( switch (instruction->random_distribution()) { case xla::RNG_UNIFORM: return func_builder - ->create(loc, result_type, operands[0], - operands[1], shape) + ->create( + loc, result_type, operands[0], operands[1], shape, + ::mlir::mhlo::RngDistribution::UNIFORM) .getOperation(); case xla::RNG_NORMAL: return func_builder - ->create(loc, result_type, operands[0], - operands[1], shape) + ->create(loc, result_type, operands[0], + operands[1], shape, + ::mlir::mhlo::RngDistribution::NORMAL) .getOperation(); default: @@ -1272,7 +1274,8 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertPrecisionConfig(&instruction->precision_config(), builder_))); return func_builder - ->create(loc, result_type, operands, attributes) + ->create(loc, result_type, operands, + attributes) .getOperation(); } @@ -1418,7 +1421,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kCeil, CeilOp); NO_ATTRIBUTE_CASE(kClamp, ClampOp); NO_ATTRIBUTE_CASE(kComplex, ComplexOp); - NO_ATTRIBUTE_CASE(kCos, CosOp); + NO_ATTRIBUTE_CASE(kCos, CosineOp); NO_ATTRIBUTE_CASE(kDivide, DivOp); NO_ATTRIBUTE_CASE(kExp, ExpOp); NO_ATTRIBUTE_CASE(kExpm1, Expm1Op); @@ -1452,7 +1455,7 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp); NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp); NO_ATTRIBUTE_CASE(kSign, SignOp); - NO_ATTRIBUTE_CASE(kSin, SinOp); + NO_ATTRIBUTE_CASE(kSin, SineOp); NO_ATTRIBUTE_CASE(kSqrt, SqrtOp); NO_ATTRIBUTE_CASE(kSubtract, SubOp); NO_ATTRIBUTE_CASE(kTanh, TanhOp); diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 17521a0dffcee6..f7852b028fa32c 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -351,7 +351,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { return xla::HloOpcode::kClamp; } else if (isa(op)) { return xla::HloOpcode::kConcatenate; - } else if (isa(op)) { + } else if (isa(op)) { return xla::HloOpcode::kConvolution; } else if (isa(op)) { return xla::HloOpcode::kSort; @@ -371,7 +371,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { return xla::HloOpcode::kCeil; } else if (isa(op)) { return xla::HloOpcode::kClz; - } else if (isa(op)) { + } else if (isa(op)) { return xla::HloOpcode::kCos; } else if (isa(op)) { return xla::HloOpcode::kExp; @@ -404,7 +404,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) { return xla::HloOpcode::kRsqrt; } else if (isa(op)) { return xla::HloOpcode::kSign; - } else if (isa(op)) { + } else if (isa(op)) { return xla::HloOpcode::kSin; } else if (isa(op)) { return xla::HloOpcode::kSqrt; diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index aba67f944e7e8c..568ea7f76fa2c5 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -107,7 +108,7 @@ StatusOr MlirHloBuilder::ConvGeneralDilatedInternal( mlir::ArrayAttr config_attr; if (precision_config) config_attr = ConvertPrecisionConfig(precision_config, &builder_); - auto op = builder_.create( + auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), GetI64ElementsAttr(window_strides, &builder_), ConvertPadding(padding, &builder_), @@ -402,24 +403,30 @@ StatusOr MlirHloBuilder::SetDimensionSizeInternal(const Shape& shape, StatusOr MlirHloBuilder::RngOpInternal( RandomDistribution distribution, absl::Span parameters, const Shape& shape) { - // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform - // and RngNormal can be mapped to the new op. - std::string op_name; + mlir::mhlo::RngDistributionAttr attr; if (distribution == xla::RandomDistribution::RNG_UNIFORM) { - op_name = "mhlo.rng_uniform"; + attr = mlir::mhlo::RngDistributionAttr::get( + builder_.getContext(), mlir::mhlo::RngDistribution::UNIFORM); } else { TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL) << "Unexpected distribution: " << distribution; - op_name = "mhlo.rng_normal"; + attr = mlir::mhlo::RngDistributionAttr::get( + builder_.getContext(), mlir::mhlo::RngDistribution::NORMAL); } + llvm::SmallVector attributes = { + builder_.getNamedAttr("rng_distribution", attr)}; if (shape.is_dynamic()) return Unimplemented("RngOp with dynamic dims not supported"); - llvm::SmallVector operands; - operands.append(parameters.begin(), parameters.end()); - operands.push_back( - ConstantLiteral(LiteralUtil::CreateR1(shape.dimensions()))); - return CreateOp(op_name, shape, operands); + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + + auto op = builder_.create( + loc_, ty, GetValue(parameters[0]), GetValue(parameters[1]), + GetValue( + ConstantLiteral(LiteralUtil::CreateR1(shape.dimensions()))), + attr); + return MakeXlaOp(op.getResult()); } StatusOr MlirHloBuilder::RngBitGeneratorInternal( diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index beea18abb29719..9bcdd697ff3ff8 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -811,6 +811,17 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(CosineOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp arg; + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op))) + return mlir::failure(); + auto xla_result = xla::Cos(Unwrap(arg)); + value_map[result] = xla_result; + return mlir::success(); +} + LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp lhs, rhs; @@ -1031,7 +1042,7 @@ LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp lhs, rhs; if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); @@ -1425,24 +1436,20 @@ LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) { return mlir::success(); } -LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) { - auto& value_map = *ctx.values; - xla::XlaOp mu, sigma; - if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure(); - if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure(); - - value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType())); - return success(); -} - -LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(RngOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp a, b; if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure(); if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure(); - value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType())); - return success(); + if (op.rng_distribution() == RngDistribution::UNIFORM) { + value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType())); + return success(); + } else if (op.rng_distribution() == RngDistribution::NORMAL) { + value_map[op] = xla::RngNormal(a, b, xla::TypeToShape(op.getType())); + return success(); + } + return failure(); } LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) { @@ -1528,6 +1535,17 @@ LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) { return success(); } +mlir::LogicalResult ExportXlaOp(mlir::mhlo::SineOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + xla::XlaOp arg; + if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op))) + return mlir::failure(); + auto xla_result = xla::Sin(Unwrap(arg)); + value_map[result] = xla_result; + return mlir::success(); +} + LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 4278c1f78a0f9a..b29961c83ab341 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -131,6 +131,7 @@ HloModule Gemm // CHECK-SAME: algorithm = 7 : i64 // CHECK-SAME: alpha_imag = 0.000000e+00 : f64 // CHECK-SAME: alpha_real = 1.000000e+00 : f64 +// CHECK-SAME: beta = 0.000000e+00 : f64 // CHECK-NOT: lhs_batching_dimensions // CHECK-NOT: rhs_batching_dimensions // CHECK-SAME: lhs_contracting_dimensions = [1] @@ -147,55 +148,6 @@ ENTRY main { // ----- -HloModule GemmBias - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.gemm_bias" -// CHECK-SAME: algorithm = 0 : i64 -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 1.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK-SAME: precision_config = [#mhlo<"precision HIGH">, #mhlo<"precision HIGHEST">] -// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>) -ENTRY main { - %A = f32[1,1]{1,0} parameter(0) - %B = f32[1,4]{1,0} parameter(1) - %C = f32[1,4]{1,0} parameter(2) - ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C), - custom_call_target="__cublas$gemm", - backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"HIGH\",\"HIGHEST\"]},\"selected_algorithm\":\"0\"}" -} - -// ----- - -HloModule GemmBias - -// CHECK-LABEL: func @main -// CHECK: "lmhlo_gpu.gemm_bias" -// CHECK-SAME: algorithm = 0 : i64 -// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 -// CHECK-SAME: alpha_real = 1.000000e+00 : f64 -// CHECK-SAME: beta = 1.000000e+00 : f64 -// CHECK-NOT: lhs_batching_dimensions -// CHECK-NOT: rhs_batching_dimensions -// CHECK-SAME: lhs_contracting_dimensions = [1] -// CHECK-SAME: rhs_contracting_dimensions = [0] -// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>) -ENTRY main { - %A = f32[1,1]{1,0} parameter(0) - %B = f32[1,4]{1,0} parameter(1) - %C = f32[1,4]{1,0} parameter(2) - ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C), - custom_call_target="__cublas$gemm", - backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"selected_algorithm\":\"0\"}" -} - -// ----- - HloModule AllReduce // Test all-reduce diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 1978fa646cc828..47766f93c38482 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -852,8 +852,8 @@ func.func @main(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> +// CHECK: "mhlo.dynamic_update_slice"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xf32> { - %0 = "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> func.return %0 : tensor<4x4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir index 6869e052355a5a..c8e754385dc828 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir @@ -4040,7 +4040,7 @@ func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] func.return %0 : tensor<12x?x64xf32> @@ -4053,7 +4053,7 @@ func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] func.return %0 : tensor<12x?x64xf32> @@ -5130,7 +5130,7 @@ func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) + // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) {rng_distribution = #mhlo.rng_distribution} // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ({ // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {compare_type = #mhlo<"comparison_type TOTALORDER">, comparison_direction = #mhlo<"comparison_direction LT">} @@ -5144,9 +5144,9 @@ func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-LABEL: @random_shuffle_1D_10240 func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { - // CHECK: mhlo.rng_uniform + // CHECK: mhlo.rng{{.*UNIFORM.*}} // CHECK: mhlo.sort - // CHECK: mhlo.rng_uniform + // CHECK: mhlo.rng{{.*UNIFORM.*}} // CHECK: mhlo.sort %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) func.return %0: tensor<10240xf32> @@ -5162,7 +5162,7 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) + // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) {rng_distribution = #mhlo.rng_distribution} // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor @@ -5175,8 +5175,8 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK: [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) {slice_sizes = dense<1> : tensor} - // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic_update_slice"([[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic_update_slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG]], [[ONE]] // CHECK: "mhlo.return"([[NEW_IV]], [[ITER_ARG1]], [[INDICES2]]) @@ -5654,7 +5654,7 @@ func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %a // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) - // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic_update_slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> // CHECK: return [[UPDATE]] @@ -5675,9 +5675,9 @@ func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf3 // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]]) // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]]) - // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic_update_slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic_update_slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic_update_slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> @@ -5692,7 +5692,7 @@ func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> + // CHECK: [[DUS:%.+]] = "mhlo.dynamic_update_slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> func.return %0 : tensor<4x16xf32> @@ -5704,7 +5704,7 @@ func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK: [[DUS:%.+]] = "mhlo.dynamic_update_slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir index 785392ad1cafe1..4d611a592e3ccc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir @@ -48,7 +48,7 @@ func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] func.return %0 : tensor<12x?x64xf32> @@ -62,7 +62,7 @@ func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64 // CHECK-NEXT : %1 = mhlo.constant dense<0.000000e+00> : tensor // CHECK-NEXT : %2 = mhlo.constant dense<1.000000e+00> : tensor // CHECK-NEXT : %3 = mhlo.convert(%0) : (tensor<4xi32>) -> tensor<4xi64> - // CHECK-NEXT : %4 = "mhlo.rng_uniform"(%1, %2, %3) : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + // CHECK-NEXT : %4 = "mhlo.rng"(%1, %2, %3) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> // CHECK: return %4 : tensor<32x12x12x64xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 31510194a0556d..063992bd295e7a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -175,7 +175,7 @@ func.func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> // CHECK: %[[DIM1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor - // CHECK: "mhlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) + // CHECK: "mhlo.dynamic_update_slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> func.return %0: tensor<3x4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 21bd4f23ce1f6e..3d24a46211e772 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -4163,7 +4163,7 @@ func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] func.return %0 : tensor<12x?x64xf32> @@ -4176,7 +4176,7 @@ func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[CONV:.*]] = mhlo.convert(%arg0) : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> // CHECK: return %[[F32]] func.return %0 : tensor<12x?x64xf32> @@ -5297,7 +5297,7 @@ func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) + // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) {rng_distribution = #mhlo.rng_distribution} // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ({ // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {compare_type = #mhlo<"comparison_type TOTALORDER">, comparison_direction = #mhlo<"comparison_direction LT">} @@ -5311,9 +5311,9 @@ func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK-LABEL: @random_shuffle_1D_10240 func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { - // CHECK: mhlo.rng_uniform + // CHECK: mhlo.rng{{.*UNIFORM.*}} // CHECK: mhlo.sort - // CHECK: mhlo.rng_uniform + // CHECK: mhlo.rng{{.*UNIFORM.*}} // CHECK: mhlo.sort %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) func.return %0: tensor<10240xf32> @@ -5329,7 +5329,7 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) + // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) {rng_distribution = #mhlo.rng_distribution} // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor @@ -5342,8 +5342,8 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) {slice_sizes = dense<1> : tensor} : (tensor<4xi32>, tensor) -> tensor<1xi32> // CHECK: [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) {slice_sizes = dense<1> : tensor} - // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic_update_slice"([[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> + // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic_update_slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]] // CHECK: "mhlo.return"([[NEW_IV]], [[ITER_ARG1]], [[INDICES2]]) @@ -5822,7 +5822,7 @@ func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %a // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) - // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic_update_slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> // CHECK: return [[UPDATE]] @@ -5843,9 +5843,9 @@ func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf3 // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]]) // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]]) - // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) - // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic_update_slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic_update_slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic_update_slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> @@ -5860,7 +5860,7 @@ func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> + // CHECK: [[DUS:%.+]] = "mhlo.dynamic_update_slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> func.return %0 : tensor<4x16xf32> @@ -5872,7 +5872,7 @@ func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK: [[DUS:%.+]] = "mhlo.dynamic_update_slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> // CHECK: return [[DUS]] %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 86f5271d8d18a7..793199ca98ec1f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1022,7 +1022,7 @@ func.func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { // CHECK: HloModule func.func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1038,7 +1038,7 @@ func.func @main() -> tensor<2x3x5xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = mhlo.constant dense<1.000000e+00> : tensor %2 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - %3 = "mhlo.rng_uniform"(%0, %1, %2) : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %3 : tensor<2x3x5xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 8bf5b9b391f8fc..f037249cb22483 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -487,7 +487,7 @@ add { %Arg_2.3 = s32[] parameter(2) %Arg_3.4 = s32[] parameter(3) - // CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) } @@ -497,7 +497,7 @@ add { %Arg_1.2 = f32[2] parameter(1) %Arg_2.3 = s32[] parameter(2) - // CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) } @@ -838,7 +838,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "mhlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]]) + // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution} ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal } @@ -848,7 +848,7 @@ add { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) // CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // CHECK: "mhlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]]) + // CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution} ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index eb5eb00ec00ffa..83537c433d895e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -478,9 +478,9 @@ static Value ApplyReduction(Location loc, Value input, // Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` // 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). -static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements, - int lower_limit, int upper_limit, - OpBuilder *builder) { +static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { auto shape_tensor = builder->create( loc, GetI64ElementsAttr({num_elements}, builder)); @@ -489,7 +489,8 @@ static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements, auto upper = builder->create( loc, builder->getI32IntegerAttr(upper_limit)); - return builder->create(loc, lower, upper, shape_tensor); + return builder->create(loc, lower, upper, shape_tensor, + ::mlir::mhlo::RngDistribution::UNIFORM); } using WhileBodyFnType = llvm::function_ref { NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::makeArrayRef(attrs)); + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::makeArrayRef(attrs)); return success(); } }; @@ -5040,7 +5041,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { // activation gradients // = gradients (with padding and dilation) mirrored_weights - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.out_backprop(), filter, /*window_strides=*/ GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, @@ -5246,7 +5247,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { const int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.input(), op.out_backprop(), /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), /*padding=*/paddings_attr, /*lhs_dilation=*/ diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 64ecaae1602938..2f7f7bbbf1f7d6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -646,7 +646,7 @@ foreach Mapping = [ [TF_CoshOp, CHLO_CoshOp], [TF_ComplexAbsOp, HLO_AbsOp], [TF_ConjOp, CHLO_ConjOp], - [TF_CosOp, HLO_CosOp], + [TF_CosOp, HLO_CosineOp], [TF_DigammaOp, CHLO_DigammaOp], [TF_ExpOp, HLO_ExpOp], [TF_Expm1Op, HLO_Expm1Op], @@ -666,7 +666,7 @@ foreach Mapping = [ [TF_RsqrtOp, HLO_RsqrtOp], [TF_SigmoidOp, HLO_LogisticOp], [TF_SinhOp, CHLO_SinhOp], - [TF_SinOp, HLO_SinOp], + [TF_SinOp, HLO_SineOp], [TF_SqrtOp, HLO_SqrtOp], [TF_TanhOp, HLO_TanhOp], [TF_TanOp, CHLO_TanOp] @@ -716,19 +716,30 @@ def : Pat<(TF_BitcastOp:$res HLO_Tensor:$arg), //===----------------------------------------------------------------------===// // Random ops. //===----------------------------------------------------------------------===// - -foreach srcDstOpPair = [[TF_RandomUniformOp, HLO_RngUniformOp], - [TF_RandomStandardNormalOp, HLO_RngNormalOp]] in { // TODO(b/148269299): handle random number generator seeds/states correctly. -def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2), - (srcDstOpPair[1] + +class HLO_RngDistributionValue : + ConstantAttr; + +def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), + (HLO_RngOp (HLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)), (HLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)), - (CastValueToI64 $old, $shape)), + (CastValueToI64 $old, $shape), + HLO_RngDistributionValue<"UNIFORM">), + [(IsShapedTensor $shape)]>; + +def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), + (HLO_RngOp + (HLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)), + (HLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)), + (CastValueToI64 $old, $shape), + HLO_RngDistributionValue<"NORMAL">), [(IsShapedTensor $shape)]>; -} //===----------------------------------------------------------------------===// // Sigmoid grad op. diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 53d8ab75d40ea2..fd0ea14d01fe10 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -788,44 +788,41 @@ StatusOr LhloDialectEmitter::EmitGemm( auto const config, custom_call->backend_config()); - auto set_common_attributes = [&](auto op) -> Operation* { - auto arrayref = [](absl::Span array) { - return llvm::ArrayRef{array.data(), array.size()}; - }; - auto hlo_dims = config.dot_dimension_numbers(); - auto mlir_dims = mhlo::DotDimensionNumbersAttr::get( - builder_.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()), - arrayref(hlo_dims.rhs_batch_dimensions()), - arrayref(hlo_dims.lhs_contracting_dimensions()), - arrayref(hlo_dims.rhs_contracting_dimensions())); - op.setDotDimensionNumbersAttr(mlir_dims); - op.setAlphaRealAttr(builder_.getF64FloatAttr(config.alpha_real())); - op.setAlphaImagAttr(builder_.getF64FloatAttr(config.alpha_imag())); - if (config.algorithm_case() == - xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { - op.setAlgorithmAttr( - builder_.getI64IntegerAttr(config.selected_algorithm())); - } - op.setPrecisionConfigAttr( - xla::ConvertPrecisionConfig(&config.precision_config(), &builder_)); - return op.getOperation(); - }; - if (custom_call->operand_count() == 2) { - TF_ASSIGN_OR_RETURN(auto gemm, - CreateOpWithoutAttrs(custom_call)); - return set_common_attributes(gemm); + TF_RET_CHECK(config.beta() == 0.); + } else if (custom_call->operand_count() != 3) { + return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); } - if (custom_call->operand_count() == 3) { - TF_ASSIGN_OR_RETURN( - auto gemm_bias, - CreateOpWithoutAttrs(custom_call)); - gemm_bias.setBetaAttr(builder_.getF64FloatAttr(config.beta())); - return set_common_attributes(gemm_bias); - } + // GEMM may have two or three operands. However, in the three operand case, + // the third operand is updated in-place, so we treat that as an output here. + TF_ASSIGN_OR_RETURN( + lmhlo_gpu::GEMMOp op, + CreateOpWithoutAttrs(custom_call, + /*num_operands=*/2)); - return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); + auto arrayref = [](absl::Span array) { + return llvm::ArrayRef{array.data(), array.size()}; + }; + + auto hlo_dims = config.dot_dimension_numbers(); + auto mlir_dims = mhlo::DotDimensionNumbersAttr::get( + builder_.getContext(), arrayref(hlo_dims.lhs_batch_dimensions()), + arrayref(hlo_dims.rhs_batch_dimensions()), + arrayref(hlo_dims.lhs_contracting_dimensions()), + arrayref(hlo_dims.rhs_contracting_dimensions())); + op.setDotDimensionNumbersAttr(mlir_dims); + op.setAlphaRealAttr(builder_.getF64FloatAttr(config.alpha_real())); + op.setAlphaImagAttr(builder_.getF64FloatAttr(config.alpha_imag())); + op.setBetaAttr(builder_.getF64FloatAttr(config.beta())); + if (config.algorithm_case() == + xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { + op.setAlgorithmAttr( + builder_.getI64IntegerAttr(config.selected_algorithm())); + } + op.setPrecisionConfigAttr( + xla::ConvertPrecisionConfig(&config.precision_config(), &builder_)); + return op.getOperation(); } static StatusOr GetLHLOActivation( diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 4cd2c92c016a5a..e05518ceac5f01 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -120,13 +120,13 @@ def fn(): w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]])) w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]])) - w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]])) + w2 = w1.write(2, convert([[8.0, 9.0], [124.0, 125.0]])) return w2.concat() self.assertAllEqual( convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], - [8.0, 9.0], [204.0, 205.0]]), + [8.0, 9.0], [124.0, 125.0]]), self.evaluate(xla.compile(fn)[0])) @test_util.disable_control_flow_v2("b/122315751 (concat)") diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 39e80cd1f5b452..228052964d3677 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -206,6 +206,7 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core/platform:logging", + "//tensorflow/core/profiler/lib:annotated_traceme", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -417,6 +418,7 @@ tf_cuda_library( "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core/platform:status", + "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core:stream_executor_headers_lib", ] + if_tensorrt([":tensorrt_lib"]), ) diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc index 85546f22022c33..251ac0ce1a4ba2 100644 --- a/tensorflow/compiler/tf2tensorrt/common/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc @@ -20,7 +20,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "third_party/tensorrt/NvInferPlugin.h" + #endif namespace tensorflow { @@ -58,6 +60,8 @@ namespace tensorrt { Status GetTrtBindingIndex(const char* tensor_name, int profile_index, const nvinfer1::ICudaEngine* cuda_engine, int* binding_index) { + tensorflow::profiler::TraceMe activity( + "GetTrtBindingIndex", tensorflow::profiler::TraceMeLevel::kInfo); // If the engine has been built for K profiles, the first getNbBindings() / K // bindings are used by profile number 0, the following getNbBindings() / K // bindings are used by profile number 1 etc. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 457d373d93223f..453209a4484e68 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3951,13 +3951,26 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) { return dtype == DT_INT32 ? static_cast(value) : value; }; + // A function that builds the next lexicographically greater configuration + // for the current one. The configuration is described as a (0,1)-vector + // config, where config[i] is 0 or 1 when the i-th parameter is passed as + // a weight or tensor, respectively. The function returns TRUE if such + // a configuration is built, or FALSE otherwise. + auto nextTensorWeigtConfiguration = [this](std::vector& config) { + for (int i = config.size(); i-- > 0;) { + if (config[i] = 1 - config[i]) return true; + } + return false; + }; + auto set_parameters = [this](const std::array& name, const std::array, 3>& value, const std::array& type, - bool all_tensors = false, int shape_idx = -1) { + const std::vector& config, + int shape_idx = -1) { Reset(); for (int i = 0; i < 3; i++) { - if (all_tensors) { + if (config[i]) { std::vector partial_shape_dims = {}; // The correct partial shape will be provided // (a) for all parameters, when shape_idx > 3 @@ -3992,109 +4005,135 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) { ops::Placeholder(s.WithOpName(param_name[1]), param_type[1]), ops::Placeholder(s.WithOpName(param_name[2]), param_type[2])); - const NodeDef& node_def = range.operation.node()->def(); + const NodeDef& ndef = range.operation.node()->def(); const std::vector param_types{DT_FLOAT, DT_HALF, DT_INT32}; // ConverterRange is not implemented for Implicite batch mode. + std::vector config(3, 0); if (trt_mode_ == TrtTestMode::kImplicitBatch) { - for (bool all_tensors : {false, true}) { - set_parameters(param_name, param_value, param_type, all_tensors); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + do { + set_parameters(param_name, param_value, param_type, config); + RunValidationAndConversion(ndef, error::UNIMPLEMENTED, "Conversion for Range is not implemented in " "implicit batch mode"); - } + } while (nextTensorWeigtConfiguration(config)); + return; } - const std::string expected_msg = convert_range_expected_msg(node_def); - { - // We expect that all three (start, limit and delta) are passed as weights - // OR tensors and we reject parameters, if it's not true. - Reset(); - // Passing (start, limit) as weights - for (int i = 0; i < 2; i++) { - AddTestWeights(param_name[i], {1}, param_value[i], param_type[i]); - } - // ... and delta as a tensor - AddTestTensor(param_name[2], {1}, param_type[2], param_value[2]); - - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - expected_msg + "passed as weights OR tensors"); - } + const std::string expect_msg = convert_range_expected_msg(ndef); + bool all_weights = true; + do { + for (auto limit_type : param_types) { + param_type[1] = limit_type; + for (auto delta_type : param_types) { + param_type[2] = delta_type; - nvinfer1::DataType trt_type; - TF_ASSERT_OK(TfTypeToTrtType(tf_type_, &trt_type)); - const std::string expected = DebugString(trt_type); + const auto all_integers = start_type == DT_INT32 && + limit_type == DT_INT32 && + delta_type == DT_INT32; + + if (all_weights || all_integers && !config[2]) { + // Reject invalid parameters if delta = 0 and it's passed as a weight. + param_value[2] = {0}; + set_parameters(param_name, param_value, param_type, config); + RunValidationAndConversion( + ndef, error::INVALID_ARGUMENT, + "The delta parameter of Range operation cannot be equal to 0"); + + if (!all_weights && !config[2]) { + param_value[2] = {-1}; + set_parameters(param_name, param_value, param_type, config); + const string err = StrCat( + "The delta parameter of Range operation " + "cannot be negative, when one of (start, limit) is passed as " + "a tensor, but got ", + param_value[2][0]); + RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, err); + } + } - // Reject invalid parameters if delta = 0 (for weights only). - for (auto limit_type : param_types) { - param_type[1] = limit_type; - for (auto delta_type : param_types) { - param_type[2] = delta_type; - param_value[2] = {0}; + if (all_weights) { + // Reject invalid parameters preventing the limit from + // being reached for fixed values of start and delta. + for (int j = 0; j <= 1; j++) { + param_value[j] = {get_casted_value(start, tf_type_)}; + param_value[1 - j] = {get_casted_value(limit, limit_type)}; + param_value[2] = {(2 * j - 1) * + get_casted_value(delta, delta_type)}; + set_parameters(param_name, param_value, param_type, config); + const auto error = convert_range_error_msg( + param_value[0][0], param_value[1][0], param_value[2][0]); + RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, error); + } + } - set_parameters(param_name, param_value, param_type); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "The delta parameter of Range operation cannot be equal to 0"); - - // Reject invalid parameters preventing the limit from - // being reached for fixed values of start and delta. - for (int j = 0; j <= 1; j++) { - param_value[j] = {get_casted_value(start, tf_type_)}; - param_value[1 - j] = {get_casted_value(limit, limit_type)}; - param_value[2] = {(2 * j - 1) * get_casted_value(delta, delta_type)}; - set_parameters(param_name, param_value, param_type); - const auto error = convert_range_error_msg( - param_value[0][0], param_value[1][0], param_value[2][0]); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, error); + param_value[0] = {start}; + param_value[2] = {delta}; + if (all_integers) { + if (trt_mode_ == TrtTestMode::kDynamicShape) { + // Wrong dimension for the parameter passed as a tensor. + for (int j = 0; j < 3; j++) { + if (!config[j]) continue; + + const string err = + StrCat("Dimension for '", param_name[j], + "' of Range operator should be equal to 1"); + set_parameters(param_name, param_value, param_type, config, j); + RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, err); + } + } + } else { + if (!all_weights) { + // The following test should fail, when + // (a) at least one parameter is passed as a tensor; + // (b) at least one parameter is not of type DT_INT32. + set_parameters(param_name, param_value, param_type, config); + RunValidationAndConversion(ndef, error::UNIMPLEMENTED, expect_msg); + } + } } + } + // All other configs will be set so that at least one parameter + // will be passed as a tensor + all_weights = false; + } while (nextTensorWeigtConfiguration(config)); - param_value[0] = {start}; - // When passed as tensors, all parameters should be of DT_INT32 type. - if (start_type == DT_INT32 && limit_type == DT_INT32 && - delta_type == DT_INT32) { - if (trt_mode_ == TrtTestMode::kDynamicShape) { - // Wrong dimension for one of parameters. - for (int j = 0; j < 3; j++) { - const string err = - StrCat("Dimension for '", param_name[j], - "' of Range operator should be equal to 1"); - set_parameters(param_name, param_value, param_type, true, j); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, err); + nvinfer1::DataType trt_type; + TF_ASSERT_OK(TfTypeToTrtType(DT_BOOL, &trt_type)); + const std::string error_msg = + "Unsupported data type " + DebugString(trt_type) + " used for '"; + do { + for (auto limit_type : param_types) { + param_type[1] = limit_type; + for (auto delta_type : param_types) { + param_type[2] = delta_type; + + for (int i = 0; i < 3; i++) { + if (!config[i]) { + const auto saved_type = param_type[i]; + param_type[i] = DT_BOOL; + set_parameters(param_name, param_value, param_type, config); + param_type[i] = saved_type; + RunValidationAndConversion(ndef, error::INVALID_ARGUMENT, + error_msg + param_name[i] + "'"); } } - } else { - // When at least one parameter is set as non-integer tensors, - // the following test should fail. - set_parameters(param_name, param_value, param_type, true); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - expected_msg + "tensors"); } } - } + } while (nextTensorWeigtConfiguration(config)); // The tests that pass all checks in ConvertRange::Validate(). const Status status = Status::OK(); const std::vector int_type{DT_INT32}; - for (bool all_tensors : {false, true}) { - // For now when (start, limit, delta) are passed as tensors - // these tensors should be of DT_INT32 type. - int partial_shape_idx = -1; - if (all_tensors) { - if (start_type != DT_INT32) { - continue; - } - if (trt_mode_ == TrtTestMode::kDynamicShape) { - // The correct partial shape will be provided for all parameters - partial_shape_idx = 3; - } - } - - // For now only parameters of DT_INT32 type could be used when - // they are pased as tensors. - const auto& types = all_tensors ? int_type : param_types; - const auto jEnd = all_tensors ? 0 : 1; + int partial_shape_idx = -1; + all_weights = true; + do { + // For now when at least one of (start, limit, delta) is passed as a tensor + // (a) all these parameters should be of DT_INT32 type; + // (b) only positive delta could be used. + const auto& types = all_weights ? param_types : int_type; + const auto jEnd = all_weights ? 1 : 0; for (auto limit_type : types) { param_type[1] = limit_type; for (auto delta_type : types) { @@ -4120,15 +4159,24 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertRange) { value += delta_curr; } - set_parameters(param_name, param_value, param_type, all_tensors, + set_parameters(param_name, param_value, param_type, config, partial_shape_idx); const std::vector output_dims = {num_values}; - TestOpConverter("my_range", node_def, output_dims, status, status, + TestOpConverter("my_range", ndef, output_dims, status, status, ElementsAreArray(expected_output)); } } } - } + + if (all_weights) { + if (start_type != DT_INT32) break; + if (trt_mode_ == TrtTestMode::kDynamicShape) partial_shape_idx = 3; + + // All other configs will be set so that at least one parameter + // will be passed as a tensor + all_weights = false; + } + } while (nextTensorWeigtConfiguration(config)); } TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertLikeOps) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc index 8851c935e337af..4c848e8a87be81 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc @@ -132,58 +132,77 @@ class ConvertRange : public ConvertFillBase { const auto& inputs = params.inputs; const auto& node_def = params.node_def; - if (!all_same_types(inputs)) { - return errors::InvalidArgument(convert_range_expected_msg(node_def), - "passed as weights OR tensors"); + float param[3]; + all_weights_ = all_integers_ = true; + for (int i = 0; i < 3; i++) { + const auto& input = inputs.at(i); + all_integers_ &= input.TrtDType() == nvinfer1::DataType::kINT32; + if (input.is_weights()) { + switch (input.TrtDType()) { + case nvinfer1::DataType::kFLOAT: + param[i] = get_input_param(input); + break; + case nvinfer1::DataType::kHALF: + param[i] = get_input_param(input); + break; + case nvinfer1::DataType::kINT32: + param[i] = get_input_param(input); + break; + default: + return errors::InvalidArgument( + "Unsupported data type ", DebugString(input.TrtDType()), + " used for '", InputSpec()[i].name, "'"); + } + } else { + all_weights_ = false; + } } - if (!all_weights_) { - if (!all_integers(inputs)) { - return errors::Unimplemented(convert_range_expected_msg(node_def), - "tensors"); + if (!(all_weights_ || all_integers_)) { + // As of 06/03/2022, when at least one of the (start, limit, delta) + // is passed as a tensor, they must all be of type kINT32 + return errors::Unimplemented(convert_range_expected_msg(node_def)); + } + + if (inputs.at(2).is_weights()) { + if ((delta_ = param[2]) == 0) { + return errors::InvalidArgument("The delta parameter of ", node_def.op(), + " operation cannot be equal to 0"); } - for (int i = 0; i < 3; i++) { - const auto& dims = inputs.at(i).GetTrtDims(); - if (dims.nbDims != 1 || dims.d[0] != 1) { - return errors::InvalidArgument("Dimension for '", InputSpec()[i].name, - "' of ", node_def.op(), " operator ", - "should be equal to 1"); - } + if (!all_weights_ && delta_ < 0) { + return errors::InvalidArgument( + "The delta parameter of Range operation " + "cannot be negative, when one of (start, limit) is passed as " + "a tensor, but got ", + delta_); } - return Status::OK(); } - float param[3]; for (int i = 0; i < 3; i++) { const auto& input = inputs.at(i); - switch (input.TrtDType()) { - case nvinfer1::DataType::kFLOAT: - param[i] = get_input_param(input); - break; - case nvinfer1::DataType::kHALF: - param[i] = get_input_param(input); - break; - default: // nvinfer1::DataType::kINT32: - param[i] = get_input_param(input); + const auto& dims = input.GetTrtDims(); + if (dims.nbDims != 1 || dims.d[0] != 1) { + return errors::InvalidArgument("Dimension for '", InputSpec()[i].name, + "' of ", node_def.op(), " operator ", + "should be equal to 1"); } } - if ((delta_ = param[2]) == 0) { - return errors::InvalidArgument("The delta parameter of ", node_def.op(), - " operation cannot be equal to 0"); - } + if (all_weights_) { + const auto num_intervals_float = + (param[1] - (start_ = param[0])) / delta_; + if (num_intervals_float < 0) { + const auto error = convert_range_error_msg(start_, param[1], delta_); + return errors::InvalidArgument(error); + } - const auto num_intervals_float = (param[1] - (start_ = param[0])) / delta_; - if (num_intervals_float < 0) { - const auto error = convert_range_error_msg(start_, param[1], delta_); - return errors::InvalidArgument(error); + num_values_ = static_cast(num_intervals_float); + if (start_ + delta_ * num_values_ != param[1]) { + num_values_++; + } } - num_values_ = static_cast(num_intervals_float); - if (start_ + delta_ * num_values_ != param[1]) { - num_values_++; - } return Status::OK(); } @@ -192,7 +211,6 @@ class ConvertRange : public ConvertFillBase { const auto& inputs = params.inputs; const TRT_TensorOrWeights& input = inputs.at(0); TRT_TensorOrWeights value_input; - nvinfer1::Dims trt_dims{1}; auto builder = TRTNetworkBuilder::Create(params.converter->network(), params.weight_store); @@ -201,14 +219,19 @@ class ConvertRange : public ConvertFillBase { ITensorProxyPtr beta_tensor = nullptr; ITensorProxyPtr scalar_tensor = nullptr; if (!all_weights_) { + ITensorProxyPtr tensors[3]; + for (int i = 0; i < 3; i++) { + TF_RETURN_IF_ERROR( + builder->get_tensor4TensorOrWeights(inputs.at(i), tensors + i)); + } + StatusOr num = - builder->Sub(/*limit*/ inputs.at(1).tensor()->trt_tensor(), - /*start*/ inputs.at(0).tensor()->trt_tensor()); + builder->Sub(/*limit*/ tensors[1]->trt_tensor(), + /*start*/ tensors[0]->trt_tensor()); TRT_ENSURE_PTR_OK(num); - beta_tensor = params.inputs.at(2).tensor(); StatusOr ceil_div = builder->FloorDiv( - (*num)->getOutput(0), beta_tensor->trt_tensor() /*delta*/); + (*num)->getOutput(0), (beta_tensor = tensors[2])->trt_tensor()); TRT_ENSURE_PTR_OK(ceil_div); dims_input_tensor = (*ceil_div)->getOutput(0); dims_input_tensor->setType(nvinfer1::DataType::kINT32); @@ -241,7 +264,7 @@ class ConvertRange : public ConvertFillBase { trt_dims, scalar_tensor, beta_tensor, delta_); ITensorProxyPtr output_tensor = (*layer)->getOutput(0); - if (all_integers(inputs)) { + if (all_integers_) { output_tensor->setType(nvinfer1::DataType::kINT32); } @@ -255,31 +278,11 @@ class ConvertRange : public ConvertFillBase { return static_cast(*input.weights().GetPointer()); } - bool all_integers(const std::vector& inputs) const { - for (int i = 0; i < 3; i++) { - if (inputs.at(i).TrtDType() != nvinfer1::DataType::kINT32) { - return false; - } - } - return true; - } - - bool all_same_types(const std::vector& inputs) { - auto i = inputs.size(); - const bool is_weight = inputs.at(--i).is_weights(); - while (i--) { - if (inputs.at(i).is_weights() != is_weight) { - return all_weights_ = false; - } - } - all_weights_ = is_weight; - return true; - } - float start_; float delta_; int num_values_; bool all_weights_; + bool all_integers_; }; std::string convert_range_error_msg(float start, float limit, float delta) { @@ -291,8 +294,9 @@ std::string convert_range_error_msg(float start, float limit, float delta) { } std::string convert_range_expected_msg(const NodeDef& node_def) { - return "All parameters (start, limit, delta) of " + node_def.op() + - " operation in " + node_def.name() + " are expected to be "; + return "When at least one of parameters (start, limit, delta) of " + + node_def.op() + " operation in " + node_def.name() + + " is passed as a tensor, they must all be of type kINT32"; } REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction(), "Fill"); diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h index 37015593e77a77..04fb1845affd08 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h @@ -341,6 +341,18 @@ class TRTNetworkBuilder { return const_layer; } + Status get_tensor4TensorOrWeights(const TRT_TensorOrWeights& input, + ITensorProxyPtr* pTensor) { + if (input.is_weights()) { + StatusOr const_layer = WeightsToConstant( + input.weights().GetTrtWeights(), input.GetTrtDims()); + if (!const_layer.status().ok()) return const_layer.status(); + *pTensor = (*const_layer)->getOutput(0); + } else { + *pTensor = input.tensor(); + } + return Status::OK(); + } // Creates a nvinfer1::Weights object containing a single scalar. template ::value>::type* = nullptr> diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 0826d932eb26e2..3765f35c14427e 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -97,8 +97,11 @@ class ContextDeviceMemory { "Out of GPU memory for execution context"); } } - execution_context_->setDeviceMemory(device_memory_); - + { + tensorflow::profiler::TraceMe activity( + "setDeviceMemory", tensorflow::profiler::TraceMeLevel::kInfo); + execution_context_->setDeviceMemory(device_memory_); + } return Status::OK(); } @@ -967,6 +970,9 @@ Status TRTEngineOp::ExecuteTrtEngine( ContextDeviceMemory context_device_memory; if (!has_device_memory) { + tensorflow::profiler::TraceMe activity( + "TRTEngineOp::AllocateDeviceMemory", + tensorflow::profiler::TraceMeLevel::kInfo); // Allocate device memory for the TensorRT engine execution. The device // memory will be released when context_device_memory goes out of scope. TF_RETURN_IF_ERROR(context_device_memory.AllocateDeviceMemory( @@ -979,6 +985,9 @@ Status TRTEngineOp::ExecuteTrtEngine( Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx, TRTEngineCacheResource** cache_res) { + tensorflow::profiler::TraceMe activity( + "TRTEngineOp::GetEngineCachResource", + tensorflow::profiler::TraceMeLevel::kInfo); // Canonicalize the op name by removing the scopes if any. This is mainly // because in TFv2, the function graph can be instantiated in various ways and // it'll insert scope names to the name of the TRTEngineOps, which will result @@ -1050,7 +1059,8 @@ StatusOr> TRTEngineOp::GetEngine( const std::vector& input_concrete_shapes, OpKernelContext* ctx, TRTEngineCacheResource* cache_res) { static EngineContext empty_context; - + tensorflow::profiler::TraceMe activity( + "TRTEngineOp::GetEngine", tensorflow::profiler::TraceMeLevel::kInfo); mutex_lock lock(engine_mutex_); // Using first input to get batch size is reliable - VerifyInputShapes() // guarantees that the first input is not a scalar. As such we can always use diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index fdb8004b027f36..fcfe1697522fa3 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -211,6 +211,17 @@ class SerializeTRTResource : public OpKernel { int num_serialized_engines = 0; if (save_gpu_specific_engines_) { + // If user requests TRT engines export, recursively create + // requisite directories. + const char* export_trt_engines_env = + getenv("TF_TRT_EXPORT_TRT_ENGINES_PATH"); + if (export_trt_engines_env) { + VLOG(1) << "Exporting TRT engines to directory: " + << export_trt_engines_env; + OP_REQUIRES_OK( + ctx, ctx->env()->RecursivelyCreateDir(export_trt_engines_env)); + } + for (const auto& pair : resource->cache_) { // Ignore engines that failed to build. const std::unique_ptr& engine = pair.second; @@ -228,6 +239,28 @@ class SerializeTRTResource : public OpKernel { engine_instance.set_serialized_engine(engine_data->data(), engine_data->size()); + if (export_trt_engines_env) { + const std::string engine_filename = + std::string(export_trt_engines_env) + "/" + resource_name; + std::unique_ptr engine_file; + OP_REQUIRES_OK( + ctx, ctx->env()->NewWritableFile(engine_filename, &engine_file)); + OP_REQUIRES_OK(ctx, engine_file->Append(StringPiece( + static_cast(engine_data->data()), + engine_data->size()))); + + const std::string dims_filename = + std::string(export_trt_engines_env) + "/dims-" + resource_name; + std::unique_ptr dims_file; + OP_REQUIRES_OK( + ctx, ctx->env()->NewWritableFile(dims_filename, &dims_file)); + + for (const TensorShape& shape : engine_input_shapes) { + OP_REQUIRES_OK(ctx, + dims_file->Append(StringPiece(shape.DebugString()))); + } + } + OP_REQUIRES_OK( ctx, writer->WriteRecord(engine_instance.SerializeAsString())); ++num_serialized_engines; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc index ac35939fea832c..b362331644e263 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/profiler/lib/traceme.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" @@ -60,6 +61,8 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, const nvinfer1::IExecutionContext* execution_context, int binding_index, bool use_implicit_batch, int batch_size, TensorShape& shape) { + tensorflow::profiler::TraceMe activity( + "getBindingDimensions", tensorflow::profiler::TraceMeLevel::kInfo); nvinfer1::Dims dims = use_implicit_batch ? cuda_engine->getBindingDimensions(binding_index) @@ -79,6 +82,8 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, Status SetupBindings(nvinfer1::ICudaEngine* cuda_engine, const Tensor& tensor, std::vector& buffers, int binding_index) { + tensorflow::profiler::TraceMe activity( + "SetBindingPointers", tensorflow::profiler::TraceMeLevel::kInfo); const auto dtype = cuda_engine->getBindingDataType(binding_index); VLOG(2) << "<<<<<<<<< SetupBindings with dtype = " << (int)dtype; switch (dtype) { @@ -114,6 +119,8 @@ Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine, int num_batch, const TrtShapeOptimizationProfile& profiles, OpKernelContext* ctx, const DataVec* input_vec) { + tensorflow::profiler::TraceMe activity( + "SetTrtEngineInputs", tensorflow::profiler::TraceMeLevel::kInfo); int n_inputs = ctx ? ctx->num_inputs() : (input_vec ? input_vec->size() : 0); // Setup engine inputs. for (int i = 0; i < n_inputs; i++) { @@ -150,6 +157,9 @@ Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine, i, binding_index, cuda_engine, execution_context)); if (cuda_engine->isExecutionBinding(binding_index)) { + tensorflow::profiler::TraceMe activity( + "SetTrtEngineInputs::setBindingDimensions", + tensorflow::profiler::TraceMeLevel::kInfo); nvinfer1::Dims trt_dims; auto adap = DimsAdapter::Create(input_shape); TRT_ENSURE_OK(adap); @@ -187,6 +197,8 @@ Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, int trt_profile_idx, std::vector& buffers, bool use_implicit_batch, int batch_size, OpKernelContext* ctx, DataVec* outputs) { + tensorflow::profiler::TraceMe activity( + "SetTrtEngineOutputs", tensorflow::profiler::TraceMeLevel::kInfo); // Either one of ctx or outpus should be specified int n_outputs = ctx ? ctx->num_outputs() : (outputs ? outputs->size() : 0); for (int i = 0; i < n_outputs; i++) { @@ -205,6 +217,8 @@ Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, // Allocate output tensor of TRTEngineOp. Tensor* output_tensor = nullptr; if (ctx) { + tensorflow::profiler::TraceMe activity( + "AllocateOutput", tensorflow::profiler::TraceMeLevel::kInfo); TF_RETURN_IF_ERROR(ctx->allocate_output(i, output_shape, &output_tensor)); } else { // This path is used for unit tests. The tensor is already allocated. @@ -231,6 +245,8 @@ Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, std::vector& buffers, cudaStream_t stream, bool use_implicit_batch, int batch_size) { + tensorflow::profiler::TraceMe activity( + "TrtEnqueue", tensorflow::profiler::TraceMeLevel::kInfo); bool ret = false; if (use_implicit_batch) { ret = execution_context->enqueue(batch_size, &buffers[0], stream, nullptr); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 0ba26e67c7b8a2..5c4a6c1fdd8fed 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -185,7 +185,6 @@ struct EngineContext { // latency. Since its value remains constant, we can cache it. size_t device_memory_size_; }; - // Contains the context required to build the calibration data. class CalibrationContext { public: diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc index 5f96b5f55be777..21f6be4a964561 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/profiler/lib/traceme.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -139,6 +140,9 @@ void TrtShapeOptimizationProfile::OptimalStrategy( // Collects the values of tensors that are ShapeTensorCompatible to. The values // are stored in the actual_shape_values_ member variable. Status TrtShapeOptimizationProfile::CollectShapeValues(OpKernelContext* ctx) { + tensorflow::profiler::TraceMe activity( + "TrtShapeOptimizationProfile::CollectShapeValues", + tensorflow::profiler::TraceMeLevel::kInfo); const cudaStream_t* stream = CHECK_NOTNULL( reinterpret_cast(ctx->op_device_context() ->stream() @@ -466,6 +470,9 @@ void TrtShapeOptimizationProfile::SetShapeTensorMask( int TrtShapeOptimizationProfile::GetProfileNumber( const std::vector& shapes) { + tensorflow::profiler::TraceMe activity( + "TrtShapeOptimizationProfile::GetProfileNumber", + tensorflow::profiler::TraceMeLevel::kInfo); if (!need_profiles_) return 0; // TODO(tfeher): Return the best profile not just the first compatible. for (int i = 0; i < profiles_.size(); i++) { @@ -509,6 +516,9 @@ Status TrtShapeOptimizationProfile::CreateExecutionContexts( Status TrtShapeOptimizationProfile::SetInputShapeBinding( int input_index, int binding_index, nvinfer1::ICudaEngine* cuda_engine, nvinfer1::IExecutionContext* exec_context) const { + tensorflow::profiler::TraceMe activity( + "TrtShapeOptimizationProfile::SetInputShapeBinding", + tensorflow::profiler::TraceMeLevel::kInfo); if (cuda_engine->isShapeBinding(binding_index)) { // Input shape binding data has to be in host memory. That is the reason // we can't use input_tensor.flat().data(). which contains the same diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index de98909212547d..f1626ec86dc565 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2576,8 +2576,8 @@ The arguments of scatter should follow these constraints: `inserted_window_dims.size`. - `scatter_dims_to_operand_dims.size` must be equal to - `scatter_indices`[`index_vector_dim`], and its values must be in the range - `[0, operand.rank)`. + `scatter_indices.shape.dims`[`index_vector_dim`], and its values must be in + the range `[0, operand.rank)`. For a given index `U` in each `updates` array, the corresponding index `I` in the corresponding `operands` array into which this update has to be applied is diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 0b7e6f30e57d69..b20fa4f07c501c 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -249,6 +249,7 @@ cc_library( "//tensorflow/stream_executor:stream", "//tensorflow/stream_executor/host:host_platform_id", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index 60ce3b97a0a2cb..b23f0d8ad0b838 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -156,8 +156,8 @@ typedef struct { PJRT_Device** addressable_devices; // out size_t num_addressable_devices; // out } PJRT_Client_AddressableDevices_Args; -const size_t PJRT_Client_AddressableDevices_Args_STRUCT_SIZE = - PJRT_STRUCT_SIZE(PJRT_Client_AddressableDevices_Args, addressable_devices); +const size_t PJRT_Client_AddressableDevices_Args_STRUCT_SIZE = PJRT_STRUCT_SIZE( + PJRT_Client_AddressableDevices_Args, num_addressable_devices); // Returns a list of devices that are addressable from the client. // Addressable devices are those that the client can issue commands to. @@ -332,48 +332,46 @@ typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); // -------------------------------- API access --------------------------------- -#define PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type +#define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type // Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed. typedef struct { size_t struct_size; void* priv; - PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy); - PJRT_API_STRUCT_FIELD(PJRT_Error_Message); - - PJRT_API_STRUCT_FIELD(PJRT_Client_Create); - PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy); - PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName); - PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex); - PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion); - PJRT_API_STRUCT_FIELD(PJRT_Client_Devices); - PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices); - - PJRT_API_STRUCT_FIELD(PJRT_Device_Id); - PJRT_API_STRUCT_FIELD(PJRT_Device_ProcessIndex); - PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable); - - PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy); - PJRT_API_STRUCT_FIELD(PJRT_Executable_Name); - PJRT_API_STRUCT_FIELD(PJRT_Executable_AddressableDevices); - PJRT_API_STRUCT_FIELD(PJRT_Executable_Delete); - PJRT_API_STRUCT_FIELD(PJRT_Executable_IsDeleted); - - PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete); - PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted); - PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu); + _PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); + + _PJRT_API_STRUCT_FIELD(PJRT_Client_Create); + _PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName); + _PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex); + _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion); + _PJRT_API_STRUCT_FIELD(PJRT_Client_Devices); + _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices); + + _PJRT_API_STRUCT_FIELD(PJRT_Device_Id); + _PJRT_API_STRUCT_FIELD(PJRT_Device_ProcessIndex); + _PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Name); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_AddressableDevices); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Delete); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_IsDeleted); + + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu); } PJRT_Api; const size_t PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_IsOnCpu); -#undef PJRT_API_STRUCT_FIELD +#undef _PJRT_API_STRUCT_FIELD #ifdef __cplusplus } #endif -#undef PJRT_API_STRUCT_FIELD - #endif // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_H_ diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.cc b/tensorflow/compiler/xla/pjrt/gpu_device.cc index 47e5946ed65e73..ea1dd44fac1d49 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_device.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/gpu_device.h" #include +#include #include #include +#include #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" @@ -186,9 +188,8 @@ void EnablePeerAccess(absl::Span executors) { StatusOr>> BuildLocalDeviceStates( LocalClient* xla_client, bool asynchronous) { std::vector> addressable_devices; - for (int i = 0; i < xla_client->device_count(); ++i) { - se::StreamExecutor* executor = - xla_client->backend().stream_executor(i).ValueOrDie(); + for (se::StreamExecutor* executor : + xla_client->backend().stream_executors()) { addressable_devices.push_back(std::make_unique( executor, xla_client, LocalDeviceState::kComputeSynchronized, /*max_inflight_computations=*/32, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 3be5f477d1f2fe..30f51abcf0be1a 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -443,11 +443,11 @@ class PjRtClient { // the client can issue commands to. virtual int addressable_device_count() const = 0; - // Return all devices in the entire computation, including addressable and + // Return all devices known to the client, including addressable and // non-addressable devices. virtual absl::Span devices() const = 0; - // Return only addressable devices. + // Return only addressable devices. The devices are in no particular order. virtual absl::Span addressable_devices() const = 0; // Lookup any PjRtDevice for a given PjRtDevice::id(). diff --git a/tensorflow/compiler/xla/pjrt/pjrt_future.h b/tensorflow/compiler/xla/pjrt/pjrt_future.h index e8ba1f45b63d8f..a7cabb9e7a7c41 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_future.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_future.h @@ -192,7 +192,7 @@ class PjRtFuture { // `IsReady()` was called. `IsReady()` will return immediately if a call to // `Await()` has already returned, or any callback passed to `OnReady` has // already been triggered. Otherwise IsReady() may block for the duration of a - // network message on some backends." + // network message on some backends. bool IsReady() { return promise_ref_.IsAvailable(); } // `IsKnownReady()` is guaranteed to return immediately. `IsKnownReady()` will // always return true if a call to `Await()` has already returned, or any diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 673eb0864115ce..77976d064d4445 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -74,6 +74,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -239,18 +240,18 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient( << "Duplicate device id: " << device->id(); if (device->IsAddressable()) { - int idx = device->local_hardware_id(); - if (idx >= addressable_devices_.size()) { - addressable_devices_.resize(idx + 1); - } - CHECK(addressable_devices_[idx] == nullptr) << idx; - addressable_devices_[idx] = device.get(); + addressable_devices_.push_back(device.get()); } device->SetClient(this); } - for (int idx = 0; idx < addressable_devices_.size(); ++idx) { - CHECK(addressable_devices_[idx] != nullptr) << idx; - } + // TODO(phawkins): we don't really promise anything about the order of + // these devices, but users may be depending on the current order. Sort into + // device ordinal order, which is the historical order these values have + // appeared. + absl::c_sort(addressable_devices_, + [](const PjRtDevice* a, const PjRtDevice* b) { + return a->local_hardware_id() < b->local_hardware_id(); + }); } StatusOr PjRtStreamExecutorClient::GetDefaultDeviceAssignment( diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index 1d76b35f478448..36e0b31417a6cc 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -265,7 +265,7 @@ class PjRtStreamExecutorClient : public PjRtClient { LocalDeviceState& device_state(int device_ordinal) const { return *tensorflow::down_cast( - addressable_devices_.at(device_ordinal)) + LookupAddressableDevice(device_ordinal).ValueOrDie()) ->local_device_state(); } LocalClient* client() const { return client_; } diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc index bcdb41ff91c8c0..51fed0481e44db 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -586,11 +586,10 @@ StatusOr> TfrtCpuClient::BufferFromHostLiteral( // It is OK to capture `buffer` pointer because the `output_buffer` can't be // deleted until all the usage holds have gone away. tfrt::EnqueueWork( - GetHostContext(), - [literal, av = avs[0].CopyRef(), - movable_device_buffer{device_buffer.ToClosure()}, shape]() mutable { + GetHostContext(), [literal, av = avs[0].CopyRef(), + db = std::move(device_buffer), shape]() mutable { tensorflow::profiler::TraceMe traceme("H2D Dispatch"); - TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); + TfrtCpuBuffer::ScopedHold device_buffer = std::move(db); const std::shared_ptr& b = device_buffer->Buffers()[0]; CHECK_EQ(literal.size_bytes(), b->size()); @@ -607,11 +606,10 @@ StatusOr> TfrtCpuClient::BufferFromHostLiteral( // It is OK to capture `buffer` pointer because the `output_buffer` can't // be deleted until all the usage holds have gone away. tfrt::EnqueueWork( - GetHostContext(), - [i, literal, av = avs[i].CopyRef(), shape, - movable_device_buffer{device_buffer.ToClosure()}]() mutable { + GetHostContext(), [i, literal, av = avs[i].CopyRef(), shape, + db = std::move(device_buffer)]() mutable { tensorflow::profiler::TraceMe traceme("H2D Dispatch"); - TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); + TfrtCpuBuffer::ScopedHold device_buffer = std::move(db); auto slice = LiteralSlice(literal, {i}); const std::shared_ptr& b = device_buffer->Buffers()[i]; @@ -641,6 +639,22 @@ TfrtCpuBuffer::ScopedHold::ScopedHold(ScopedHold&& other) other.SetState(kMoved); } +TfrtCpuBuffer::ScopedHold& TfrtCpuBuffer::ScopedHold::operator=( + ScopedHold&& other) { + if (ok()) { + parent_->DropHold(type_, buffer().get()); + } + parent_ = other.parent_; + type_ = other.type_; + state_ = other.state_; + status_ = std::move(other.status_); + buffer_ = std::move(other.buffer_); + // Preserve the invariant that status is invalid if buffer == nullptr. + other.SetState(kMoved); + + return *this; +} + void TfrtCpuBuffer::ScopedHold::Acquire( StatusOr>&& buffer_or) { CHECK(!ok()); @@ -656,14 +670,6 @@ void TfrtCpuBuffer::ScopedHold::Acquire( CHECK(!ok() || buffer_ != nullptr); } -TfrtCpuBuffer::ScopedHold::ForClosure TfrtCpuBuffer::ScopedHold::ToClosure() { - CHECK(ok()); - ForClosure for_closure(parent_, type_, state_, std::move(status_), - std::move(buffer_)); - SetState(kReleased); - return for_closure; -} - void TfrtCpuBuffer::ScopedHold::ConvertUsageHold( absl::Span> events) { CHECK(ok()); @@ -1045,11 +1051,11 @@ PjRtFuture TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal) { // parallel. EnqueueWorkWhenReady( host_ctx, device_buffer_wait_avs, - [this, movable_device_buffer{device_buffer.ToClosure()}, + [this, db = std::move(device_buffer), device_buffer_wait_avs = std::move(device_buffer_wait_avs_copy), - literal, ready_event = ready_event.CopyRef()] { + literal, ready_event = ready_event.CopyRef()]() mutable { tensorflow::profiler::TraceMe traceme("D2H Dispatch"); - TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); + TfrtCpuBuffer::ScopedHold device_buffer = std::move(db); // Errors in src buffer are surfaced to user. for (const auto& av : device_buffer_wait_avs) { if (auto* error = av->GetErrorIfPresent()) { diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h index 56f92ccf5ac7f2..015005dd9b4b8e 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -339,6 +339,7 @@ class TfrtCpuBuffer final : public PjRtBuffer { ~ScopedHold(); ScopedHold(ScopedHold&& other); + ScopedHold& operator=(ScopedHold&& other); ScopedHold(const ScopedHold&) = delete; ScopedHold& operator=(const ScopedHold&) = delete; @@ -390,22 +391,8 @@ class TfrtCpuBuffer final : public PjRtBuffer { friend class TfrtCpuClient; friend class TfrtCpuBuffer; - // Helper struct that makes it possible to move a ScopedHold through a - // closure. - using ForClosure = std::tuple>; - ScopedHold(TfrtCpuBuffer* parent, Type type) : parent_(parent), type_(type), state_(kUninitialized) {} - explicit ScopedHold(const ForClosure& closure_helper) - : parent_(std::get<0>(closure_helper)), - type_(std::get<1>(closure_helper)), - state_(std::get<2>(closure_helper)), - status_(std::get<3>(closure_helper)), - buffer_(std::get<4>(closure_helper)) { - // Check the buffer is not in an error state. - CHECK(status_.ok() && buffer_ != nullptr); - } // Sets buffer state. void SetState(State state) { state_ = state; } @@ -413,14 +400,9 @@ class TfrtCpuBuffer final : public PjRtBuffer { // Sets buffer_ and status_. Called by parent_ to initialize the hold. void Acquire( StatusOr>&& buffer_or); - // Releases the contents of *this, so *this can subsequently be - // deleted without releasing the parent's hold. Should be passed to the - // appropriate constructor of another ScopedHold, e.g., when a hold must be - // passed through a closure that is incompatible with std::move. - ForClosure ToClosure(); - - TfrtCpuBuffer* const parent_; - const Type type_; + + TfrtCpuBuffer* parent_; + Type type_; // There is an invariant that if ok() then buffer_ != nullptr. State state_; diff --git a/tensorflow/compiler/xla/python/callback.cc b/tensorflow/compiler/xla/python/callback.cc index d0af6ca74c9cde..29ffe7ec4f88e6 100644 --- a/tensorflow/compiler/xla/python/callback.cc +++ b/tensorflow/compiler/xla/python/callback.cc @@ -15,19 +15,23 @@ limitations under the License. #include "tensorflow/compiler/xla/python/callback.h" +#include +#include +#include #include #include #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/python/exceptions.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace py = pybind11; namespace xla { -void CpuCallback::Call(void* result, void** arg_ptrs, - XlaCustomCallStatus* status) { +void CpuCallback::PrepareAndCall(void* result, void** arg_ptrs, + XlaCustomCallStatus* status) { absl::Span inputs(arg_ptrs, args_.size()); absl::Span outputs(reinterpret_cast(result), results_.size()); @@ -43,26 +47,59 @@ void CpuCallback::Call(void* result, void** arg_ptrs, args[i].attr("flags").attr("writeable") = Py_False; } } - py::object result_tuple; + std::optional maybe_result_tuple = Call(args, status); + if (!maybe_result_tuple) { + // Python function errored so we return early. + return; + } + py::tuple result_tuple = maybe_result_tuple.value(); + for (size_t i = 0; i < results_.size(); ++i) { + py::object output = py::reinterpret_borrow( + PyTuple_GetItem(result_tuple.ptr(), i)); + py::array array = py::cast(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::StatusOr> plan = + transpose_cache_.GetOrCreate( + xla::primitive_util::ByteWidth(results_[i].type), dims, + results_[i].reversed_layout, + /*input_layout=*/xla::TransposePlan::Striding{strides}); + if (!plan.ok()) { + throw xla::XlaRuntimeError(plan.status().ToString()); + } + plan.ValueOrDie()->Execute(array.data(), outputs[i]); + } + } +} + +std::optional CpuCallback::Call(py::tuple args, + XlaCustomCallStatus* status) { + py::object result_object; try { - result_tuple = callable_(*py::reinterpret_borrow(args)); + result_object = callable_(*py::reinterpret_borrow(args)); } catch (py::error_already_set& e) { PyErr_Clear(); std::string error_message = e.what(); XlaCustomCallStatusSetFailure(status, error_message.c_str(), error_message.length()); - return; + return std::nullopt; } - if (!PyTuple_Check(result_tuple.ptr())) { + if (!PyTuple_Check(result_object.ptr())) { throw xla::XlaRuntimeError( absl::StrFormat("CPU callback expected a tuple result, got %s", - static_cast(py::repr(result_tuple)))); + static_cast(py::repr(result_object)))); } - if (PyTuple_Size(result_tuple.ptr()) != results_.size()) { + if (PyTuple_Size(result_object.ptr()) != results_.size()) { throw xla::XlaRuntimeError( absl::StrFormat("CPU callback expected a tuple with %d results, got %d", - results_.size(), PyTuple_Size(result_tuple.ptr()))); + results_.size(), PyTuple_Size(result_object.ptr()))); } + py::tuple result_tuple = py::cast(result_object); for (size_t i = 0; i < results_.size(); ++i) { py::object output = py::reinterpret_borrow( PyTuple_GetItem(result_tuple.ptr(), i)); @@ -86,29 +123,15 @@ void CpuCallback::Call(void* result, void** arg_ptrs, i, absl::StrJoin(results_[i].expected_dims, ","), absl::StrJoin(dims, ","))); } - absl::Span strides( - reinterpret_cast(array.strides()), array.ndim()); - if (strides == results_[i].expected_strides) { - std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); - } else { - xla::StatusOr> plan = - transpose_cache_.GetOrCreate( - xla::primitive_util::ByteWidth(results_[i].type), dims, - results_[i].reversed_layout, - /*input_layout=*/xla::TransposePlan::Striding{strides}); - if (!plan.ok()) { - throw xla::XlaRuntimeError(plan.status().ToString()); - } - plan.ValueOrDie()->Execute(array.data(), outputs[i]); - } } + return result_tuple; } void XlaPythonCpuCallback(void* output, void** inputs, XlaCustomCallStatus* status) { CpuCallback* callback = absl::bit_cast(*static_cast(inputs[0])); - callback->Call(output, inputs + 1, status); + callback->PrepareAndCall(output, inputs + 1, status); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/callback.h b/tensorflow/compiler/xla/python/callback.h index fc2b3d5ff52d10..32078d6b244983 100644 --- a/tensorflow/compiler/xla/python/callback.h +++ b/tensorflow/compiler/xla/python/callback.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_ +#include #include #include "pybind11/pybind11.h" @@ -61,7 +62,12 @@ class CpuCallback { const std::vector& results() const { return results_; } size_t num_results() const { return results_.size(); } - void Call(void* result, void** arg_ptrs, XlaCustomCallStatus* status); + xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } + + void PrepareAndCall(void* result, void** arg_ptrs, + XlaCustomCallStatus* status); + std::optional Call(pybind11::tuple args, + XlaCustomCallStatus* status); private: pybind11::function callable_; diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 0fbc867c6c40f4..11965dad1adbd0 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -577,7 +577,7 @@ StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( host_callback->callback = [callback = std::move(callback)](void** outputs, void** inputs) { - callback->Call(outputs, inputs, /*status=*/nullptr); + callback->PrepareAndCall(outputs, inputs, /*status=*/nullptr); }; py::capsule callback_capsule( diff --git a/tensorflow/compiler/xla/python/py_client_gpu.cc b/tensorflow/compiler/xla/python/py_client_gpu.cc index 88db733a8a64c9..4cb38cf44946e4 100644 --- a/tensorflow/compiler/xla/python/py_client_gpu.cc +++ b/tensorflow/compiler/xla/python/py_client_gpu.cc @@ -21,6 +21,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #endif +#include "pybind11/pybind11.h" #include "tensorflow/compiler/xla/python/callback.h" #include "tensorflow/compiler/xla/python/exceptions.h" @@ -38,6 +39,8 @@ limitations under the License. #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice #endif +namespace py = pybind11; + namespace xla { void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, @@ -53,44 +56,78 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, CpuCallback* callback = absl::bit_cast(static_cast(descriptor)); size_t arity = callback->num_args(); - size_t num_results = callback->num_results(); - std::vector host_input_buffers; - std::vector host_output_buffers; + std::vector host_input_buffers(arity); // Copy input GPU buffers to host for (size_t i = 0; i < arity; ++i) { CpuCallback::Arg arg = callback->args()[i]; + if (arg.type == TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } void* buf = new char[arg.size_in_bytes]; - host_input_buffers.push_back(buf); - gpuMemcpyAsync(host_input_buffers[i], buffers[i], arg.size_in_bytes, - gpuMemcpyDeviceToHost, stream); - } - // TODO(sharadmv): we allocate space for host buffers but the callback will - // return NumPy arrays which wrap host buffers. We could reuse those instead. - // Allocate space for output buffers on host - for (size_t i = 0; i < num_results; ++i) { - CpuCallback::Result result = callback->results()[i]; - void* buf = new char[result.size_in_bytes]; - host_output_buffers.push_back(buf); + host_input_buffers[i] = buf; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes, gpuMemcpyDeviceToHost, + stream); } gpuStreamSynchronize(stream); - void* host_output_buffer = host_output_buffers.data(); - callback->Call(host_output_buffer, host_input_buffers.data(), status); - // Copy host output buffers back to device - for (size_t i = 0; i < num_results; ++i) { + py::gil_scoped_acquire gil; + py::tuple host_input_arrays(arity); + for (size_t i = 0; i < arity; ++i) { + CpuCallback::Arg arg = callback->args()[i]; + if (arg.type == TOKEN) { + host_input_arrays[i] = py::none(); + continue; + } + py::capsule base(host_input_buffers[i], + [](void* ptr) { delete[] static_cast(ptr); }); + host_input_arrays[i] = + py::array(arg.dtype, arg.dims, arg.strides, + const_cast(host_input_buffers[i]), /*base=*/base); + host_input_arrays[i].attr("flags").attr("writeable") = Py_False; + } + std::optional maybe_result_tuple = + callback->Call(host_input_arrays, status); + if (!maybe_result_tuple) { + return; + } + py::tuple result_tuple = maybe_result_tuple.value(); + std::vector temp_buffers; + for (size_t i = 0; i < callback->results().size(); ++i) { CpuCallback::Result result = callback->results()[i]; - gpuMemcpyAsync(buffers[arity + i], host_output_buffers[i], - result.size_in_bytes, gpuMemcpyHostToDevice, stream); + if (result.type == TOKEN) { + continue; + } + py::object output = py::reinterpret_borrow( + PyTuple_GetItem(result_tuple.ptr(), i)); + py::array array = py::cast(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == result.expected_strides) { + gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + } else { + void* temp = new char[result.size_in_bytes]; + temp_buffers.push_back(temp); + xla::StatusOr> plan = + callback->transpose_cache().GetOrCreate( + xla::primitive_util::ByteWidth(result.type), dims, + result.reversed_layout, + /*input_layout=*/xla::TransposePlan::Striding{strides}); + if (!plan.ok()) { + throw xla::XlaRuntimeError(plan.status().ToString()); + } + plan.ValueOrDie()->Execute(array.data(), temp); + gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes, + gpuMemcpyHostToDevice, stream); + } } - // We need to synchronize here to ensure that host buffers are alive while - // the async copy is happening. + py::gil_scoped_release release; gpuStreamSynchronize(stream); - // Free host output buffers - for (size_t i = 0; i < num_results; ++i) { - delete[] static_cast(host_output_buffers[i]); - } - // Free host input buffers - for (size_t i = 0; i < arity; ++i) { - delete[] static_cast(host_input_buffers[i]); + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); } } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index bca963d5c8402e..139f130d89189d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -46,7 +46,7 @@ _version = 77 # Version number for MLIR:Python components. -mlir_api_version = 22 +mlir_api_version = 28 xla_platform_names = { 'cpu': 'Host', @@ -65,7 +65,8 @@ def make_cpu_client(*, use_tfrt: bool = True) -> ...: return _xla.get_cpu_client(asynchronous=True) -def make_gpu_client(distributed_client=None, node_id=0, platform_name=None): +def make_gpu_client(distributed_client=None, node_id=0, platform_name=None, + allowed_devices=None): """Returns a GPU client. BFC allocator is used by default.""" allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION') @@ -92,7 +93,8 @@ def make_gpu_client(distributed_client=None, node_id=0, platform_name=None): allocator_config=config, distributed_client=distributed_client, node_id=node_id, - platform_name=platform_name) + platform_name=platform_name, + allowed_devices=allowed_devices) def make_tpu_client(): diff --git a/tensorflow/compiler/xla/python/xla_client.pyi b/tensorflow/compiler/xla/python/xla_client.pyi index 67a7f4ab8d929a..2780b70865f4e7 100644 --- a/tensorflow/compiler/xla/python/xla_client.pyi +++ b/tensorflow/compiler/xla/python/xla_client.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import numpy @@ -71,7 +71,9 @@ def make_cpu_client(*, use_tfrt: bool = ...) -> Client: def make_gpu_client( distributed_client: Optional[DistributedRuntimeClient] = ..., - node_id: int = ...) -> Client: + node_id: int = ..., + platform_name: Optional[str] = ..., + allowed_devices: Optional[Set[int]] = ...) -> Client: ... diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index e6e27ab4ef00ca..8a28e04ffc1798 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4868,6 +4868,58 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( return ReplaceInstruction(dynamic_slice, new_broadcast); } + HloInstruction *reshape, *reshape_operand; + if (Match(operand, m::Reshape(&reshape, m::Op(&reshape_operand))) && + reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value() && + !options_.is_layout_sensitive()) { + int64_t slice_dim = 0; + HloInstruction* zero = MakeScalarLike(dynamic_slice->mutable_operand(1), 0); + std::vector starts; + starts.reserve(reshape_operand->shape().rank()); + std::vector slice_sizes; + slice_sizes.reserve(reshape_operand->shape().rank()); + for (int64_t dim = 0; dim < reshape_operand->shape().rank(); ++dim) { + if (reshape_operand->shape().dimensions(dim) == 1) { + starts.push_back(zero); + slice_sizes.push_back(1); + continue; + } + while (dynamic_slice->operand(0)->shape().dimensions(slice_dim) == 1) { + ++slice_dim; + } + starts.push_back(dynamic_slice->mutable_operand(1 + slice_dim)); + slice_sizes.push_back(dynamic_slice->slice_sizes(slice_dim)); + ++slice_dim; + } + HloInstruction* new_dynamic_slice = + dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(dynamic_slice->shape().element_type(), + slice_sizes), + reshape_operand, starts, slice_sizes)); + return ReplaceWithNewInstruction( + dynamic_slice, HloInstruction::CreateReshape(dynamic_slice->shape(), + new_dynamic_slice)); + } + + HloInstruction *transpose, *transpose_operand; + if (Match(operand, m::Transpose(&transpose, m::Op(&transpose_operand))) && + !options_.is_layout_sensitive()) { + auto output_to_input = InversePermutation(transpose->dimensions()); + HloInstruction* new_slice = + dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::PermuteDimensions(output_to_input, + dynamic_slice->shape()), + transpose_operand, + Permute(absl::MakeSpan(dynamic_slice->operands().begin() + 1, + dynamic_slice->operands().end()), + output_to_input), + Permute(dynamic_slice->dynamic_slice_sizes(), output_to_input))); + return ReplaceWithNewInstruction( + dynamic_slice, + HloInstruction::CreateTranspose(dynamic_slice->shape(), new_slice, + transpose->dimensions())); + } + // Convert a dynamic slice into a slice if all offsets are constant and the // operand is not constant. if (operand->opcode() != HloOpcode::kConstant && @@ -6088,9 +6140,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (options_.is_layout_sensitive() && - options_.replace_transpose_with_bitcast() && - TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index f85bb7730c8279..6675bbf5e21320 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -172,16 +172,6 @@ class AlgebraicSimplifierOptions { bool enable_sink_broadcast() const { return enable_sink_broadcast_; } - // TODO(b/228984373): Remove this option once BitcastDecomposer lands and - // sticks. - void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) { - replace_transpose_with_bitcast_ = replace_transpose_with_bitcast; - } - - bool replace_transpose_with_bitcast() const { - return replace_transpose_with_bitcast_; - } - // If true, min(x, NaN) = NaN. If false, min(x, NaN) = x. // // TODO(b/209827141): Remove this and make minmax_propagate_nan uncondtionally @@ -215,7 +205,6 @@ class AlgebraicSimplifierOptions { bool enable_reduce_of_reshape_{true}; bool enable_negative_padding_replacement_{true}; bool enable_sink_broadcast_{true}; - bool replace_transpose_with_bitcast_{true}; int64_t very_small_gather_size_{4}; bool minmax_propagate_nan_{true}; Metadata metadata_; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index b7023e4396ee97..af647bf2945490 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3102,13 +3102,6 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); - // Don't replace transposes with bitcasts. - options.set_replace_transpose_with_bitcast(false); - AlgebraicSimplifier simplifier_no_replace(options); - ASSERT_FALSE(simplifier_no_replace.Run(m.get()).ValueOrDie()); - - // Replace transposes with bitcasts if possible. - options.set_replace_transpose_with_bitcast(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -5275,6 +5268,60 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfNonCanonicalBatchDotCantSimplify) { EXPECT_FALSE(changed); } +TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTranspose) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[12,10,8] parameter(0) + i0 = s32[] parameter(1) + i1 = s32[] parameter(2) + i2 = s32[] parameter(3) + transpose = f32[12,8,10] transpose(param), dimensions={0,2,1} + ROOT slice = f32[2,3,5] dynamic-slice(transpose, i0, i1, i2), + dynamic_slice_sizes={2,3,5} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Transpose( + m::DynamicSlice(m::Parameter(0), m::Parameter(1), + m::Parameter(3), m::Parameter(2))))); +} + +TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTrivialReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[12,10,1,8] parameter(0) + i0 = s32[] parameter(1) + i1 = s32[] parameter(2) + i2 = s32[] parameter(3) + z = s32[] constant(0) + reshape = f32[1,12,10,8] reshape(param) + ROOT slice = f32[1,2,3,5] dynamic-slice(reshape, z, i0, i1, i2), + dynamic_slice_sizes={1,2,3,5} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(false); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::DynamicSlice( + m::Parameter(0), m::Parameter(1), m::Parameter(2), + m::Constant(), m::Parameter(3))))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { const char* hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 4c22a71e40cade..bb44eee1617c7b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" @@ -459,7 +460,6 @@ class CpuCollectivePermuteRendezvous for (int p_idx = 0; p_idx < participants_.size(); p_idx++) { replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx; } - for (auto& p : participants_) { for (int dest_replica : p.replica_ids_to_copy_to) { auto& dest_p = participants_[xla::FindOrDie( @@ -835,17 +835,20 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute( absl::string_view source_target_pairs_serialized( static_cast(source_target_pairs), source_target_pairs_size); auto pairs = absl::StrSplit(source_target_pairs_serialized, ','); - int32_t replica_id = + const xla::DeviceAssignment::LogicalID logical_id = run_options->device_assignment() - ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal)) + ->LogicalIdForDevice(xla::GlobalDeviceId(device_ordinal)) .ValueOrDie(); + int32_t logical_device_id = + channel_id_present ? logical_id.computation_id : logical_id.replica_id; + std::vector copy_to; for (auto& p : pairs) { std::vector mapping = absl::StrSplit(p, '='); CHECK_EQ(mapping.size(), 2); int from = std::stoi(mapping[0]); int to = std::stoi(mapping[1]); - if (from == replica_id) { + if (from == logical_device_id) { copy_to.push_back(to); } } @@ -855,7 +858,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute( CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal, run_options->stream()); - participant.replica_id = replica_id; + participant.replica_id = logical_device_id; participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size); participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size); participant.replica_ids_to_copy_to = copy_to; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 956928d91f3387..9fae96d89e77ee 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1084,7 +1084,6 @@ cc_library( ":ir_emission_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo_gpu", "//tensorflow/compiler/xla:shape_util", @@ -2690,6 +2689,7 @@ cc_library( "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/core:lib_proto_parsing", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2725,160 +2725,6 @@ test_suite( ]), ) -# These tests are intended to be run with --test_env=XLA_FLAGS=--xla_gpu_bef_executable -# See tap/tensorflow.xla_gpu_tfrt_executable. -test_suite( - name = "bef_executable_tests", - tests = [ - "//tensorflow/compiler/tests:fft_test_gpu", - "//tensorflow/compiler/xla/service/gpu:cudnn_fused_conv_rewriter_test", - "//tensorflow/compiler/xla/service/gpu:custom_call_test", - "//tensorflow/compiler/xla/service/gpu/tests:add_preds.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:all_reduce.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:bef_executable_test_gpu", - "//tensorflow/compiler/xla/service/gpu/tests:concat.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:constant.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:copy.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:copy_nested.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:dynamic_update_slice_inplace.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:element_wise_row_vectorization.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:element_wise_row_vectorization_test", - "//tensorflow/compiler/xla/service/gpu/tests:fused_scatter.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:fused_slice.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:fused_slice_different_operands.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:fusion.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:gemm_broadcast_folding_rewrite_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_alignment_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_atomic_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_copy_alone_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_copy_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_ftz_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_fusion_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_index_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_infeed_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_input_fusible_slice_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_kernel_tiling_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_ldg_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_noalias_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_reduce_scatter_creator_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_spmd_e2e_compile_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_too_many_blocks_test", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_unrolling_test", - "//tensorflow/compiler/xla/service/gpu/tests:kernel_launch_test", - "//tensorflow/compiler/xla/service/gpu/tests:launch_dimensions.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:mlir_fft_test", - "//tensorflow/compiler/xla/service/gpu/tests:mlir_gemm_test", - "//tensorflow/compiler/xla/service/gpu/tests:mlir_gpu_compile_test", - "//tensorflow/compiler/xla/service/gpu/tests:pad_to_static.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:parallel_reduction_test", - "//tensorflow/compiler/xla/service/gpu/tests:pred_arithmetic_test", - "//tensorflow/compiler/xla/service/gpu/tests:reduce_unnested.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:reduction_degenerate_dim_remover_test", - "//tensorflow/compiler/xla/service/gpu/tests:reduction_dimension_grouper_test", - "//tensorflow/compiler/xla/service/gpu/tests:reduction_layout_normalizer_test", - "//tensorflow/compiler/xla/service/gpu/tests:reduction_vectorization_sm_all.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:reduction_vectorization_test", - "//tensorflow/compiler/xla/service/gpu/tests:rng_get_and_update_state.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:scatter.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:select_and_scatter.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:slice_to_dynamic.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:sorting.hlo.test", - "//tensorflow/compiler/xla/service/gpu/tests:sorting_test", - "//tensorflow/compiler/xla/service/gpu/tests:tree_reduction_rewriter_test", - "//tensorflow/compiler/xla/tests:array_elementwise_ops_test_gpu", - "//tensorflow/compiler/xla/tests:axpy_simple_test_gpu", - "//tensorflow/compiler/xla/tests:bad_rng_shape_validation_test_gpu", - "//tensorflow/compiler/xla/tests:batch_normalization_test_gpu", - "//tensorflow/compiler/xla/tests:bfloat16_test_gpu", - "//tensorflow/compiler/xla/tests:binop_scaling_test_gpu", - "//tensorflow/compiler/xla/tests:bitcast_convert_test_gpu", - "//tensorflow/compiler/xla/tests:broadcast_simple_test_gpu", - "//tensorflow/compiler/xla/tests:broadcast_test_gpu", - "//tensorflow/compiler/xla/tests:call_test_gpu", - "//tensorflow/compiler/xla/tests:check_execution_arity_test_gpu", - "//tensorflow/compiler/xla/tests:cholesky_test_gpu", - "//tensorflow/compiler/xla/tests:client_test_gpu", - "//tensorflow/compiler/xla/tests:compilation_cache_test_gpu", - "//tensorflow/compiler/xla/tests:compute_constant_test_gpu", - "//tensorflow/compiler/xla/tests:concat_test_gpu", - "//tensorflow/compiler/xla/tests:constant_reduction_function_test_gpu", - "//tensorflow/compiler/xla/tests:constants_test_gpu", - "//tensorflow/compiler/xla/tests:convert_test_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_1d_autotune_disabled_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_autotune_disabled_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_cudnn_frontend_disabled_gpu", - "//tensorflow/compiler/xla/tests:copy_test_gpu", - "//tensorflow/compiler/xla/tests:cpu_gpu_fusion_test_gpu", - "//tensorflow/compiler/xla/tests:deallocation_test_gpu", - "//tensorflow/compiler/xla/tests:deconstruct_tuple_test_gpu", - "//tensorflow/compiler/xla/tests:deep_graph_test_gpu", - "//tensorflow/compiler/xla/tests:dot_operation_single_threaded_runtime_test_gpu", - "//tensorflow/compiler/xla/tests:dot_operation_test_autotune_disabled_gpu", - "//tensorflow/compiler/xla/tests:dot_operation_test_gpu", - "//tensorflow/compiler/xla/tests:dynamic_ops_test_gpu", - "//tensorflow/compiler/xla/tests:execution_profile_test_gpu", - "//tensorflow/compiler/xla/tests:execution_profile_test_with_xla_hlo_profile_gpu", - "//tensorflow/compiler/xla/tests:exhaustive_binary_16_bit_test_gpu", - "//tensorflow/compiler/xla/tests:exhaustive_binary_test_f32_f64_gpu", - "//tensorflow/compiler/xla/tests:exhaustive_unary_test_complex_gpu", - "//tensorflow/compiler/xla/tests:exhaustive_unary_test_f32_or_smaller_gpu", - "//tensorflow/compiler/xla/tests:exhaustive_unary_test_f64_gpu", - "//tensorflow/compiler/xla/tests:floor_ceil_test_gpu", - "//tensorflow/compiler/xla/tests:fmax_fmin_test_gpu", - "//tensorflow/compiler/xla/tests:gather_operation_test_gpu", - "//tensorflow/compiler/xla/tests:get_dimension_size_test_gpu", - "//tensorflow/compiler/xla/tests:half_test_gpu", - "//tensorflow/compiler/xla/tests:iota_test_gpu", - "//tensorflow/compiler/xla/tests:local_client_allocation_test_gpu", - "//tensorflow/compiler/xla/tests:local_client_execute_test_gpu", - "//tensorflow/compiler/xla/tests:log_test_gpu", - "//tensorflow/compiler/xla/tests:map_test_gpu", - "//tensorflow/compiler/xla/tests:matrix_ops_simple_test_gpu", - "//tensorflow/compiler/xla/tests:multidimensional_slice_test_gpu", - "//tensorflow/compiler/xla/tests:multioutput_fusion_test_gpu", - "//tensorflow/compiler/xla/tests:pad_test_gpu", - "//tensorflow/compiler/xla/tests:params_test_gpu", - "//tensorflow/compiler/xla/tests:pred_test_gpu", - "//tensorflow/compiler/xla/tests:prng_test_gpu", - "//tensorflow/compiler/xla/tests:ptxas_bug_120501638_gpu", - "//tensorflow/compiler/xla/tests:query_inferred_shape_test_gpu", - "//tensorflow/compiler/xla/tests:reduce_hlo_test_gpu", - "//tensorflow/compiler/xla/tests:reduce_precision_test_gpu", - "//tensorflow/compiler/xla/tests:replay_test_gpu", - "//tensorflow/compiler/xla/tests:reshape_motion_test_gpu", - "//tensorflow/compiler/xla/tests:reshape_test_gpu", - "//tensorflow/compiler/xla/tests:reverse_test_gpu", - "//tensorflow/compiler/xla/tests:round_trip_packed_literal_test_gpu", - "//tensorflow/compiler/xla/tests:round_trip_transfer_test_gpu", - "//tensorflow/compiler/xla/tests:sample_text_test_gpu", - "//tensorflow/compiler/xla/tests:scalar_computations_test_gpu", - "//tensorflow/compiler/xla/tests:scatter_test_gpu", - "//tensorflow/compiler/xla/tests:select_test_gpu", - "//tensorflow/compiler/xla/tests:slice_test_gpu", - "//tensorflow/compiler/xla/tests:transfer_manager_test_gpu", - "//tensorflow/compiler/xla/tests:transpose_test_gpu", - "//tensorflow/compiler/xla/tests:triangular_solve_test_gpu", - "//tensorflow/compiler/xla/tests:tuple_test_gpu", - "//tensorflow/compiler/xla/tests:unary_op_test_gpu", - "//tensorflow/compiler/xla/tests:value_inference_test_gpu", - "//tensorflow/compiler/xla/tests:vector_ops_reduce_test_gpu", - "//tensorflow/compiler/xla/tests:vector_ops_simple_test_gpu", - "//tensorflow/compiler/xla/tests:xla_hlo_profile_test_gpu", - ] + if_google([ - # Currently fails in OSS. - "//tensorflow/compiler/xla/tests:conv_depthwise_backprop_filter_test_gpu", - "//tensorflow/compiler/xla/tests:conv_depthwise_test_gpu", - "//tensorflow/compiler/xla/tests:convolution_dimension_numbers_test_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_1d_gpu_alternative_layout_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_1d_no_vmodule_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_gpu", - "//tensorflow/compiler/xla/tests:convolution_test_gpu_alternative_layout_gpu", - "//tensorflow/compiler/xla/tests:convolution_variants_test_gpu", - "//tensorflow/compiler/xla/tests:grouped_convolution_test_gpu", - "//tensorflow/python/kernel_tests/signal:fft_ops_test_xla_gpu", - ]), -) - # These tests are intended to be run with --test_env=XLA_FLAGS=--xla_gpu_jitrt_executable # See tap/tensorflow.xla_gpu_jitrt. test_suite( @@ -2975,6 +2821,7 @@ test_suite( "//tensorflow/compiler/xla/tests:deconstruct_tuple_test_gpu", "//tensorflow/compiler/xla/tests:deep_graph_test_gpu", "//tensorflow/compiler/xla/tests:dot_operation_single_threaded_runtime_test_gpu", + "//tensorflow/compiler/xla/tests:dot_operation_test_autotune_disabled_gpu", "//tensorflow/compiler/xla/tests:dot_operation_test_gpu", "//tensorflow/compiler/xla/tests:dynamic_ops_test_gpu", "//tensorflow/compiler/xla/tests:execution_profile_test_gpu", diff --git a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc index 255fcb31728165..ea8781d098c867 100644 --- a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc @@ -161,7 +161,7 @@ ConvertToBef(mlir::ModuleOp module, tfrt::HostContext* host) { } static StatusOr GetThunkKind(mlir::Operation* op) { - if (mlir::isa(op)) { + if (mlir::isa(op)) { return Thunk::Kind::kGemm; } if (mlir::isa(op)) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.cc index 917bb4ffb5a4d4..0231a46a731a2f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" @@ -53,13 +54,22 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { bcast->operand(0)->shape().dimensions_size()); int num_batch_dims = dim_nums->lhs_batch_dimensions_size(); + const tensorflow::protobuf::RepeatedField &batch_dimensions = + (bcast_operand_index == 1) ? dim_nums->rhs_batch_dimensions() + : dim_nums->lhs_batch_dimensions(); // This optimization is only valid if the set of broadcasted dimensions // is exactly the set of batch dimensions. First, check that all newly - // broadcast dimensions have been inserted on the left. + // broadcast dimensions have been inserted on the left i.e. all new + // dimensions must be in [0, num_bcast_dims) or equivalently all original + // dimensions are >= num_bcast_dims. for (int64_t bcast_dim : bcast->dimensions()) { if (bcast_dim < num_bcast_dims) { return OkStatus(); } + // bcast_dim should not be in batch_dimensions. + if (absl::c_linear_search(batch_dimensions, bcast_dim)) { + return OkStatus(); + } } // Then check that all batch dimensions are being broadcast, and that diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 806c488184470b..c2ec7a19e1fd9c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -253,6 +253,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { MaybeConstantFoldBias(bias), }); TF_RETURN_IF_ERROR(gemm_call->set_backend_config(config)); + // Force bias input to alias with output, as GEMM operates in-place. + xla::Cast(gemm_call.get()) + ->set_output_to_operand_aliasing({{{}, {2, {}}}}); TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get())); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(gemm_call))); return OkStatus(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 4930fa7da3cd37..eecdc35ba1e43d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1110,71 +1110,25 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { } Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { - auto make_bef_thunk = - [&](auto op, std::optional bias = - std::nullopt) -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.getLhs())); - TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.getRhs())); - TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.getOutput())); - std::vector buffers = {lhs, rhs}; - if (bias.has_value()) { - buffers.push_back(bias.value()); - } - buffers.push_back(output); - return CreateBefThunk(GetThunkInfo(op), op, std::move(buffers)); - }; + TF_ASSIGN_OR_RETURN(auto thunk, [&]() -> StatusOr> { + auto gemm = mlir::dyn_cast(op); + TF_RET_CHECK(gemm != nullptr); - auto make_gemm_thunk = - [&](auto op, std::optional gemm_bias_beta = - std::nullopt) -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.getLhs())); - TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.getRhs())); - TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.getOutput())); + TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA())); + TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB())); + TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC())); + + if (IsBefThunkEnabled(hlo_module_config_)) { + return CreateBefThunk(GetThunkInfo(op), op, {a, b, c}); + } bool use_cublaslt = hlo_module_config_.debug_options().xla_gpu_enable_cublaslt(); - TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(op, use_cublaslt)); + TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm, use_cublaslt)); return std::unique_ptr( - new GemmThunk(GetThunkInfo(op), std::move(config), lhs, rhs, output)); - }; - - TF_ASSIGN_OR_RETURN(auto thunk, [&]() -> StatusOr> { - if (auto gemm = mlir::dyn_cast(op)) { - if (IsBefThunkEnabled(hlo_module_config_)) return make_bef_thunk(gemm); - return make_gemm_thunk(gemm); - } - - if (auto gemm = mlir::dyn_cast(op)) { - double gemm_bias_beta = gemm.getBeta().convertToDouble(); - TF_ASSIGN_OR_RETURN(auto bias, GetAllocationSlice(gemm.getBias())); - TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(gemm.getOutput())); - - if (IsBefThunkEnabled(hlo_module_config_)) - return make_bef_thunk(gemm, bias); - - // The bias is passed inside the output buffer. If those buffers are - // shared we can just use it, otherwise copy the bias values into the - // output buffer first. - if (bias == output) { - return make_gemm_thunk(gemm, gemm_bias_beta); - } - - ThunkSequence thunks; - thunks.push_back(std::make_unique( - Thunk::ThunkInfo(), - /*source_buffer=*/bias, - /*destination_buffer=*/output, - /*mem_size=*/ - ShapeUtil::ByteSizeOf(GetShape(gemm.getOutput())))); - TF_ASSIGN_OR_RETURN(auto thunk, make_gemm_thunk(gemm, gemm_bias_beta)); - thunks.push_back(std::move(thunk)); - return std::unique_ptr( - new SequentialThunk(GetThunkInfo(op), std::move(thunks))); - } - - return tensorflow::errors::Internal("Unexpected op."); + new GemmThunk(GetThunkInfo(op), std::move(config), a, b, c)); }()); AddThunkToThunkSequence(std::move(thunk)); @@ -5668,7 +5622,7 @@ Status IrEmitterUnnested::EmitOp(mlir::Operation* op) { return EmitCustomCallThunk(op); } - if (mlir::isa(op)) { + if (mlir::isa(op)) { return EmitGemmThunk(op); } diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc index 5015aac3985331..97d78bfd10040c 100644 --- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc +++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc @@ -259,13 +259,12 @@ static Shape ToShape(const jitrt::StridedMemrefView& memref) { static StatusOr GetGemmConfig( const DebugOptions* debug_options, const jitrt::StridedMemrefView& lhs, const jitrt::StridedMemrefView& rhs, const jitrt::StridedMemrefView& out, - int64_t algorithm, double alpha_real, double alpha_imag, + int64_t algorithm, double alpha_real, double alpha_imag, double beta, ArrayRef lhs_batch, ArrayRef lhs_contract, - ArrayRef rhs_batch, ArrayRef rhs_contract, - llvm::Optional beta = llvm::None) { + ArrayRef rhs_batch, ArrayRef rhs_contract) { return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs), rhs_batch, rhs_contract, ToShape(out), alpha_real, - alpha_imag, beta.getValueOr(0.0), algorithm, + alpha_imag, beta, algorithm, se::blas::kDefaultComputePrecision, debug_options->xla_gpu_enable_cublaslt()); } @@ -426,7 +425,7 @@ struct Gemm { const DebugOptions* debug_options, JitRtGemmConfigCache* configs, jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs, jitrt::StridedMemrefView out, int64_t algorithm, double alpha_real, - double alpha_imag, ArrayRef lhs_batch, + double alpha_imag, double beta, ArrayRef lhs_batch, ArrayRef lhs_contract, ArrayRef rhs_batch, ArrayRef rhs_contract, int64_t uid) const; @@ -439,7 +438,7 @@ LogicalResult Gemm::operator()( const DebugOptions* debug_options, JitRtGemmConfigCache* configs, jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs, jitrt::StridedMemrefView out, int64_t algorithm, double alpha_real, - double alpha_imag, ArrayRef lhs_batch, + double alpha_imag, double beta, ArrayRef lhs_batch, ArrayRef lhs_contract, ArrayRef rhs_batch, ArrayRef rhs_contract, int64_t uid) const { se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs); @@ -453,8 +452,8 @@ LogicalResult Gemm::operator()( const GemmConfig* config = configs->Get(uid); if (config == nullptr) { auto cfg = GetGemmConfig(debug_options, lhs, rhs, out, algorithm, - alpha_real, alpha_imag, lhs_batch, lhs_contract, - rhs_batch, rhs_contract); + alpha_real, alpha_imag, beta, lhs_batch, + lhs_contract, rhs_batch, rhs_contract); if (!cfg.ok()) return failure(); config = configs->Set(uid, std::move(*cfg)); } @@ -476,7 +475,7 @@ LogicalResult Gemm::operator()( run_options->device_ordinal(), run_options->allocator()); return RunBlasLtMatmul(matmul_plan, {alpha_real, alpha_imag}, lhs_data, - rhs_data, /*beta=*/0., output_data, stream, + rhs_data, beta, output_data, stream, scratch_allocator); } return RunGemm(*config, lhs_data, rhs_data, output_data, stream); @@ -499,114 +498,13 @@ static bool Gemm(runtime::KernelContext* ctx, void** args, void** attrs) { .Attr("algorithm") .Attr("alpha_real") .Attr("alpha_imag") - .Attr>("lhs_batching_dimensions") - .Attr>("lhs_contracting_dimensions") - .Attr>("rhs_batching_dimensions") - .Attr>("rhs_contracting_dimensions") - .Attr("uid") - .To(Gemm::Handler()) - .release(); - - return succeeded(Executable::Call(ctx, *handler, args, attrs)); -} - -// -------------------------------------------------------------------------- // - -namespace { -struct GemmBias { - LLVM_ATTRIBUTE_ALWAYS_INLINE - LogicalResult operator()( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, JitRtGemmConfigCache* configs, - jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs, - jitrt::StridedMemrefView bias, jitrt::StridedMemrefView out, - int64_t algorithm, double alpha_real, double alpha_imag, double beta, - ArrayRef lhs_batch, ArrayRef lhs_contract, - ArrayRef rhs_batch, ArrayRef rhs_contract, - int64_t uid) const; - static GemmBias Handler() { return GemmBias(); } -}; -} // namespace - -LogicalResult GemmBias::operator()( - const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options, JitRtGemmConfigCache* configs, - jitrt::StridedMemrefView lhs, jitrt::StridedMemrefView rhs, - jitrt::StridedMemrefView bias, jitrt::StridedMemrefView out, - int64_t algorithm, double alpha_real, double alpha_imag, double beta, - ArrayRef lhs_batch, ArrayRef lhs_contract, - ArrayRef rhs_batch, ArrayRef rhs_contract, - int64_t uid) const { - se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs); - se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs); - se::DeviceMemoryBase bias_data = GetDeviceAddress(bias); - se::DeviceMemoryBase output_data = GetDeviceAddress(out); - - VLOG(3) << "Running GEMM + Bias [beta=" << beta << "]"; - se::Stream* stream = run_options->stream(); - - // Find the gemm config for this instance of operation based on uid. - const GemmConfig* config = configs->Get(uid); - if (config == nullptr) { - auto cfg = GetGemmConfig(debug_options, lhs, rhs, out, algorithm, - alpha_real, alpha_imag, lhs_batch, lhs_contract, - rhs_batch, rhs_contract, beta); - if (!cfg.ok()) return failure(); - config = configs->Set(uid, std::move(*cfg)); - } - - // Copy bias to the output buffer of they are different. - if (out.data != bias.data) - stream->ThenMemcpy(&output_data, bias_data, bias_data.size()); - - Status executed = [&] { - if (config->use_cublaslt) { - TF_ASSIGN_OR_RETURN(MatmulPlanParams matmul_plan_params, - GetBlasLtMatmulPlanParams(*config)); - - // TODO(cjfj): Cache the plan. - se::cuda::BlasLt::MatmulPlan matmul_plan; - TF_RETURN_IF_ERROR(matmul_plan.init(matmul_plan_params.params)); - - if (matmul_plan_params.must_swap_operands) { - std::swap(lhs_data, rhs_data); - } - - se::OwningScratchAllocator<> scratch_allocator( - run_options->device_ordinal(), run_options->allocator()); - - return RunBlasLtMatmul(matmul_plan, {alpha_real, alpha_imag}, lhs_data, - rhs_data, beta, output_data, stream, - scratch_allocator); - } - return RunGemm(*config, lhs_data, rhs_data, output_data, stream); - }(); - - if (!executed.ok()) return failure(); - - return success(); -} - -static bool GemmBias(runtime::KernelContext* ctx, void** args, void** attrs) { - static auto* handler = - CustomCall::Bind("xla.gpu.gemm.bias") - .UserData() - .UserData() - .UserData() - .Arg() // lhs - .Arg() // rhs - .Arg() // bias - .Arg() // out - .Attr("algorithm") - .Attr("alpha_real") - .Attr("alpha_imag") .Attr("beta") .Attr>("lhs_batching_dimensions") .Attr>("lhs_contracting_dimensions") .Attr>("rhs_batching_dimensions") .Attr>("rhs_contracting_dimensions") .Attr("uid") - .To(GemmBias::Handler()) + .To(Gemm::Handler()) .release(); return succeeded(Executable::Call(ctx, *handler, args, attrs)); @@ -2079,7 +1977,6 @@ DirectCustomCallLibrary JitRtGpuCustomCalls() { lib.Insert("xla.gpu.collective_permute", &xla::gpu::CollectivePermute); lib.Insert("xla.gpu.func.launch", &xla::gpu::LaunchFunc); lib.Insert("xla.gpu.gemm", &xla::gpu::Gemm); - lib.Insert("xla.gpu.gemm.bias", &xla::gpu::GemmBias); auto conv = [](StringRef name) { return ("xla.gpu.conv." + name).str(); }; lib.Insert(conv("forward"), &ConvFn); diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index d6f677f5c57cce..04cdd909915f06 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/types/span.h" -#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -362,42 +361,33 @@ bool IsBlasPlansCompatibleType(PrimitiveType type) { se::blas::kDefaultComputePrecision, use_cublaslt); } -/*static*/ StatusOr GemmConfig::For(mlir::Operation* op, +/*static*/ StatusOr GemmConfig::For(mlir::lmhlo_gpu::GEMMOp op, bool use_cublaslt) { - auto get_config = [&](auto op, llvm::APFloat beta) { - mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); - - std::optional algorithm; - if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); - - int64_t compute_precision = 0; // Default - if (op.getPrecisionConfig().hasValue()) { - auto precision_config = op.getPrecisionConfig(); - for (auto attr : precision_config.getValue()) { - int64_t value = static_cast( - attr.template cast().getValue()); - if (value > compute_precision) { - compute_precision = value; - } + mlir::mhlo::DotDimensionNumbersAttr dot_dims = op.getDotDimensionNumbers(); + + std::optional algorithm; + if (op.getAlgorithm()) algorithm = *op.getAlgorithm(); + + int64_t compute_precision = 0; // Default + if (op.getPrecisionConfig().hasValue()) { + auto precision_config = op.getPrecisionConfig(); + for (auto attr : precision_config.getValue()) { + int64_t value = static_cast( + attr.template cast().getValue()); + if (value > compute_precision) { + compute_precision = value; } } + } - return GemmConfig::For( - GetShape(op.getLhs()), dot_dims.getLhsBatchingDimensions(), - dot_dims.getLhsContractingDimensions(), GetShape(op.getRhs()), - dot_dims.getRhsBatchingDimensions(), - dot_dims.getRhsContractingDimensions(), GetShape(op.getOutput()), - op.getAlphaReal().convertToDouble(), - op.getAlphaImag().convertToDouble(), beta.convertToDouble(), algorithm, - compute_precision, use_cublaslt); - }; - - if (auto gemm = mlir::dyn_cast(op)) - return get_config(gemm, llvm::APFloat(0.)); - - auto gemm = mlir::dyn_cast(op); - TF_RET_CHECK(gemm != nullptr); - return get_config(gemm, gemm.getBeta()); + return GemmConfig::For( + GetShape(op.getA()), dot_dims.getLhsBatchingDimensions(), + dot_dims.getLhsContractingDimensions(), GetShape(op.getB()), + dot_dims.getRhsBatchingDimensions(), + dot_dims.getRhsContractingDimensions(), GetShape(op.getC()), + op.getAlphaReal().convertToDouble(), op.getAlphaImag().convertToDouble(), + op.getBeta().convertToDouble(), algorithm, compute_precision, + use_cublaslt); } namespace { @@ -416,6 +406,26 @@ bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, return swap_operands; } +StatusOr GetBlasComputationType( + PrimitiveType dtype) { + switch (dtype) { + case F16: // fall-through + case BF16: + // Accumulate in f32 precision. + return se::blas::ComputationType::kF32; + case F32: // fall-through + case C64: + return se::blas::ComputationType::kTF32AsF32; + case F64: // fall-through + case C128: + return se::blas::ComputationType::kF64; + case S32: + return se::blas::ComputationType::kI32; + default: + return InternalError("unsupported type"); + } +} + se::blas::Transpose AsBlasTranspose(MatrixLayout::Order order) { // BLAS is column-major by default. return (order == MatrixLayout::Order::kColumnMajor) @@ -433,29 +443,6 @@ se::blas::MatrixDescriptor GetMatrixDesc(const MatrixLayout& layout, }; } -// Converts from an XLA PrimitiveType to a blas::ComputationType, which is -// used to specify the precision with which matmul computations should be -// performed, separately from the precision of the inputs and result. -std::optional ComputationTypeFromPrimitive( - PrimitiveType type) { - switch (type) { - case F16: // Use F32 computation for higher precision. - case BF16: - case F32: - return se::blas::ComputationType::kF32; - case F64: - return se::blas::ComputationType::kF64; - case C64: - return se::blas::ComputationType::kComplexF32; - case C128: - return se::blas::ComputationType::kComplexF64; - case S32: - return se::blas::ComputationType::kI32; - default: - return std::nullopt; - } -} - template Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, const se::blas::MatrixDescriptor& lhs, @@ -466,8 +453,8 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, se::blas::ProfileResult* profile_result) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); PrimitiveType output_type = primitive_util::NativeToPrimitiveType(); - se::blas::ComputationType computation_type = - *ComputationTypeFromPrimitive(output_type); + TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type, + GetBlasComputationType(output_type)); se::DeviceMemory output_data(output.data); if (batch_size != 1) { @@ -616,25 +603,6 @@ StatusOr AsBlasDataType(PrimitiveType dtype) { } } -StatusOr AsBlasComputationType(PrimitiveType dtype) { - switch (dtype) { - case F16: - return se::blas::ComputationType::kF16; - case BF16: - return se::blas::ComputationType::kBF16AsF32; - case F32: - return se::blas::ComputationType::kF32; - case F64: - return se::blas::ComputationType::kF64; - case C64: - return se::blas::ComputationType::kComplexF32; - case C128: - return se::blas::ComputationType::kComplexF64; - default: - return InternalError("unsupported type"); - } -} - template Status DoGemmLt(const se::cuda::BlasLt::MatmulPlan& plan, Input alpha, se::DeviceMemoryBase lhs_buffer, @@ -696,7 +664,7 @@ StatusOr GetBlasLtMatmulPlanParams(const GemmConfig& config) { TF_ASSIGN_OR_RETURN(se::blas::DataType dtype, AsBlasDataType(output_layout.dtype)); TF_ASSIGN_OR_RETURN(se::blas::ComputationType computation_type, - AsBlasComputationType(output_layout.dtype)); + GetBlasComputationType(output_layout.dtype)); se::cuda::BlasLt::MatmulPlanParams params{ /*ab_type=*/dtype, diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index b0e8397bbeb10c..f79fad17f07f2e 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" -#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" @@ -86,7 +86,8 @@ StatusOr CanFoldTransposeOperandIntoDot(const HloInstruction& dot, struct GemmConfig { static StatusOr For(const HloInstruction* gemm); - static StatusOr For(mlir::Operation* op, bool use_cublaslt); + static StatusOr For(mlir::lmhlo_gpu::GEMMOp op, + bool use_cublaslt); static StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index fecc4b5d5ad0d4..df40eebe558eab 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -248,6 +248,8 @@ tf_cc_test( ":gpu_codegen_test", "//tensorflow/compiler/xla:error_spec", "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/gpu:gemm_broadcast_folding_rewriter", + "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc index 56d65211ae2489..00cf7aa306c961 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/core/platform/test.h" @@ -71,6 +73,54 @@ ENTRY AddDotsFunc { ; CHECK-NEXT: ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[\"0\"]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}" )"); } + +TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) { + const char* hlo_text = R"( +HloModule LHSBatchDimNonZero + +ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] { + %Arg_1 = f32[4,3]{1,0} parameter(0) + %Arg_2 = f32[4,7,3]{2,1,0} parameter(1) + %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2} + ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + // Use GemmRewriter to generate cublasGemm call. + GemmRewriter gemm_rewriter; + TF_ASSERT_OK_AND_ASSIGN(bool changed, + this->RunHloPass(&gemm_rewriter, module.get())); + EXPECT_TRUE(changed); + GemmBroadcastFoldingRewriter pass; + TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) { + const char* hlo_text = R"( +HloModule RHSBatchDimNonZero + +ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] { + %Arg_1 = f32[4,3]{1,0} parameter(0) + %Arg_2 = f32[4,7,3]{2,1,0} parameter(1) + %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2} + ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter gemm_rewriter; + TF_ASSERT_OK_AND_ASSIGN(bool changed, + this->RunHloPass(&gemm_rewriter, module.get())); + EXPECT_TRUE(changed); + GemmBroadcastFoldingRewriter pass; + TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 0d60526ae9313f..d10a82a4b4554c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -493,7 +493,8 @@ ENTRY AddDotsFunc { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,2]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}" +; CHECK-NEXT: [[P2_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[P2]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{{{}: \(2, {}\)}}}, backend_config="{\"alpha_real\":3,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}" )"); } @@ -782,7 +783,8 @@ ENTRY BF16GemmWithBias { ; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), custom_call_target="__cublas$gemm", backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}" +; CHECK-NEXT: [[P2_COPY:%[^ ]+]] = bf16[8,8]{1,0} copy([[P2]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2_COPY]]), custom_call_target="__cublas$gemm", output_to_operand_aliasing={{{{}: \(2, {}\)}}}, backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"selected_algorithm\":\"{{-?[0-9]+}}\"}" )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc index 8dcac0f8de5926..1d91c38a799a44 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gemm_test.cc @@ -36,7 +36,7 @@ class GemmTest : public MlirGpuTestBase { %arg2: memref<2x2xf32> {lmhlo.output_index = dense<[0]> : tensor<1xindex>}) attributes { result_xla_shape = "(f32[4]) " } { - "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot)", + "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot)", matmul_options, R"(} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.terminator"() : () -> () diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 757d6e2d6f27ab..ec39a9b634fa98 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -404,10 +404,11 @@ enum class EvalErrorDetail : uint32_t { kDynamicValueDependence = 0, }; -Status MakeEvalErrorDueToParamOrInfeed() { +Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) { Status error = tensorflow::errors::FailedPrecondition( - "Failed to evaluate instruction since it depends on infeed or " - "parameters to its parent computation."); + "Failed to evaluate instruction (", eval_instruction.name(), + ") since it depends on infeed or parameters to its parent computation (", + eval_instruction.parent()->name(), ")."); std::string error_payload; error_payload.resize(sizeof(EvalErrorDetail)); absl::little_endian::Store32( @@ -821,7 +822,7 @@ StatusOr HloEvaluator::Evaluate( } } if (!result.IsKnown()) { - return MakeEvalErrorDueToParamOrInfeed(); + return MakeEvalErrorDueToParamOrInfeed(*computation.root_instruction()); } return result.Clone(); } @@ -839,7 +840,7 @@ StatusOr HloEvaluator::Evaluate( recursively_evaluate_nonconstant_operands)); const Literal& result = GetEvaluatedLiteralFor(instruction); if (!result.IsKnown()) { - return MakeEvalErrorDueToParamOrInfeed(); + return MakeEvalErrorDueToParamOrInfeed(*instruction); } return result.Clone(); } @@ -3840,7 +3841,7 @@ Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { } Status HloEvaluator::Preprocess(HloInstruction* hlo) { - VLOG(2) << "About to visit HLO: " << hlo->ToString(); + VLOG(3) << "About to visit HLO: " << hlo->ToString(); if (!enable_partial_evaluation_) { for (HloInstruction* operand : hlo->mutable_operands()) { if (!IsAlreadyEvaluated(operand) || @@ -3855,7 +3856,7 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { } Status HloEvaluator::Postprocess(HloInstruction* hlo) { - VLOG(2) << "Finished visiting " << hlo->ToString() + VLOG(3) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); // Out of convenience the literal may have been produced with a different // layout. Relayout as indicated by the HLO instruction. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index cc7194dfa9a57a..d95716e399ecd7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -931,8 +931,9 @@ std::string HloDotDumper::GetInstructionNodeInlinedOperands( } } else if (operand->opcode() == HloOpcode::kGetTupleElement) { operand_str = - StrFormat("tuple-element %d of %s", operand->tuple_index(), - operand->operand(0)->name()); + StrFormat("tuple-element %d of %s %s", operand->tuple_index(), + operand->operand(0)->name(), + ShapeUtil::HumanStringWithLayout(operand->shape())); } else { operand_str = operand->name(); } @@ -954,9 +955,10 @@ std::string HloDotDumper::GetInstructionNodeInlinedOperands( instr->parent()->FusionInstruction()->operand( instr->parameter_number()); if (param_input->opcode() == HloOpcode::kGetTupleElement) { - lines.push_back(StrFormat("tuple-element %d of %s", - param_input->tuple_index(), - param_input->operand(0)->name())); + lines.push_back( + StrFormat("tuple-element %d of %s %s", param_input->tuple_index(), + param_input->operand(0)->name(), + ShapeUtil::HumanStringWithLayout(param_input->shape()))); } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 78b573ef284e1a..f9a0e11640a56f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -4870,6 +4870,13 @@ void HloInstruction::set_async_thread_name( Cast(this)->set_async_thread_name(async_thread_name); } +void HloInstruction::set_called_computations_thread_name( + const std::optional& async_thread_name, + bool skip_async_thread_name_overwrite) { + Cast(this)->RecursivelySetComputationsThreadName( + async_thread_name, skip_async_thread_name_overwrite); +} + bool HloInstruction::is_cross_program_prefetch() const { return Cast(this)->is_cross_program_prefetch(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1d1110f434ec45..fd9b4da3121692 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -2139,6 +2139,12 @@ class HloInstruction { void set_async_thread_name( const std::optional& async_thread_name); + // Delegates to + // HloCallableInstruction::RecursivelySetComputationsThreadName(). + void set_called_computations_thread_name( + const std::optional& async_thread_name, + bool skip_async_thread_name_overwrite); + // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). bool is_cross_program_prefetch() const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 7ca3a7d64ce70d..4ea40d204c40a0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1798,8 +1798,13 @@ TEST_F(HloInstructionTest, CanonicalStringificationFusion) { auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); + constexpr char kParallelThreadName[] = "parallel_thread"; + computation->SetThreadName(kParallelThreadName); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); + fusion->set_called_computations_thread_name( + kParallelThreadName, + /*skip_async_thread_name_overwrite*/ false); const std::string expected_fusion = R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= @@ -1808,7 +1813,7 @@ TEST_F(HloInstructionTest, CanonicalStringificationFusion) { tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; +}, thread_name="parallel_thread")"; EXPECT_EQ(fusion->ToString(options), expected_fusion); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 19825a1d960ba5..051baf188ead7a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -80,6 +80,28 @@ std::string PrecisionConfigToString(const PrecisionConfig& precision_config) { }), "}"); } + +void SetThreadName(HloComputation* called_computation, + const std::optional& thread_name, + bool skip_async_thread_name_overwrite) { + called_computation->SetThreadName(thread_name); + for (HloInstruction* instr : called_computation->instructions()) { + if (instr->IsAsynchronous()) { + if (!skip_async_thread_name_overwrite) { + // Set async instruction thread name and also recursively set async + // computations. + instr->set_async_thread_name(thread_name); + } + continue; + } + for (HloComputation* nested_called_computation : + instr->called_computations()) { + SetThreadName(nested_called_computation, thread_name, + skip_async_thread_name_overwrite); + } + } +} + } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -328,21 +350,8 @@ void HloAsyncInstruction::set_async_group_id( void HloAsyncInstruction::set_async_thread_name( const std::optional& async_thread_name) { async_thread_name_ = async_thread_name; - // Recursively sets all called computation to have same thread name. - std::function)> - set_computation_thread_name = - [&](HloComputation* called_computation, - std::optional async_thread_name) { - called_computation->SetThreadName(async_thread_name); - for (HloInstruction* instr : called_computation->instructions()) { - for (HloComputation* nested_called_computation : - instr->called_computations()) { - set_computation_thread_name(nested_called_computation, - async_thread_name); - } - } - }; - set_computation_thread_name(async_wrapped_computation(), async_thread_name); + SetThreadName(async_wrapped_computation(), async_thread_name, + /*skip_async_thread_name_overwrite=*/false); } HloInstructionProto HloAsyncInstruction::ToProto() const { @@ -1740,6 +1749,14 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( return clone; } +void HloCallableInstruction::RecursivelySetComputationsThreadName( + std::optional thread_name, + bool skip_async_thread_name_overwrite) { + for (HloComputation* comp : called_computations()) { + SetThreadName(comp, thread_name, skip_async_thread_name_overwrite); + } +} + HloFusionInstruction::HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 309cfd6b766a0c..fd2224c2467917 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -983,6 +983,13 @@ class HloCallableInstruction : public HloInstruction { HloInstruction* called_computation_root() const; + // Recursively sets all nested called computation to have thread name as + // `thread_name`. if `skip_async_thread_name_overwrite` is true, skip + // overwrite async instruction and its comptuations thread name overwriting. + void RecursivelySetComputationsThreadName( + std::optional thread_name, + bool skip_async_thread_name_overwrite); + protected: // Returns the default called computation name. virtual std::string default_called_computation_name() const = 0; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 527813137e3420..119e40234638c1 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -573,7 +573,7 @@ class MemoryUsageTracker { PickRematerializationCandidates( const InstructionList& instruction_list, int64_t memory_limit_bytes, absl::flat_hash_map* rematerializable_map, - int min_block_size, int max_block_size); + int min_block_size, int max_block_size, int64_t peak_memory_bytes); // Returns whether the given instruction has been placed (BeginInstruction // has been called with 'instruction' as the argument). @@ -603,6 +603,8 @@ class MemoryUsageTracker { return size; } + const HloComputation* computation() const { return computation_; } + // Check invariants of the data structure. This is expensive to call. bool Check() const; @@ -1384,7 +1386,7 @@ std::tuple, RematStrategy, int> MemoryUsageTracker::PickRematerializationCandidates( const InstructionList& instruction_list, int64_t memory_limit_bytes, absl::flat_hash_map* rematerializable_map, - int min_block_size, int max_block_size) { + int min_block_size, int max_block_size, int64_t peak_memory_bytes) { std::vector best_items; int64_t best_cost = 0; RematStrategy best_strategy; @@ -1431,8 +1433,15 @@ MemoryUsageTracker::PickRematerializationCandidates( GetCompactShape(item->instruction).ValueOrDie(); const int64_t memory_reduced = MemoryReducedIfCompressed(item, compact_shape); + // Since the compressed and uncompressed buffers need to be alive + // while performing the compression/uncompression, only perform + // the compression if the sum of the two sizes is less than the + // peak memory. + const int64_t size = size_function_(item->instruction->shape()); + const int64_t reduced_size = size_function_(compact_shape); effort++; - if (memory_reduced > 0) { + if (memory_reduced > 0 && + size + reduced_size < peak_memory_bytes) { const int64_t cost = memory_limit_bytes / memory_reduced; if (best_items.empty() || cost < best_cost) { VLOG(3) << "candidate " << candidate->name() << "(" @@ -1796,7 +1805,9 @@ StatusOr RematerializeBestBlock( std::tie(best_items, best_strategy, effort) = memory_tracker->PickRematerializationCandidates( *instruction_list, memory_limit_bytes, rematerializable_map, - min_block_size, max_block_size); + min_block_size, max_block_size, + rematerialization->ComputationPeakMemory( + memory_tracker->computation())); InstructionsAdded num_instructions_added; num_instructions_added.remat_count = best_items.size(); num_instructions_added.effort = effort; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2619a4479462c5..0203b234e7a55d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -105,6 +105,11 @@ class HloRematerialization : public HloModulePass { // Get the next available channel id and increment count. int64_t NextChannelId() { return next_channel_id_++; } + // Get the peak memory for the computation. + int64_t ComputationPeakMemory(const HloComputation* computation) const { + return computation_peak_memory_.at(computation); + } + // Runs rematerialization on the given module. Returns whether the module was // changed. Requires that the module has a schedule set // (HloModule::has_schedule() is true) before running. Returns whether any diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f296a883cd9bbc..3f85a7ee33fe09 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -582,6 +582,9 @@ class CompressingRematerializationTest : public RematerializationTestBase { // Swap the layout of the two most-minor dimensions if the second-minor // dimension is bigger than the most-minor dimension. static StatusOr ChooseCompactLayoutForShape(const Shape& shape) { + if (shape.rank() != 2) { + return shape; + } Shape result = shape; Layout layout = result.layout(); int64_t most_minor_index = layout.minor_to_major()[0]; @@ -697,6 +700,44 @@ ENTRY %entry { op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); } +// Test a pathological case where the peak memory is largely due to a single +// tensor (broadcast.0) and compressing it would actually increase the peak +// memory. +TEST_F(CompressingRematerializationTest, AvoidPathologicalCompress) { + const std::string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %constant = f32[] constant(0) + %broadcast.0 = f32[63,60]{1,0} broadcast(f32[] %param.0), dimensions={} + %broadcast.1 = f32[16,64]{1,0} broadcast(f32[] %param.0), dimensions={} + %reduce.0 = f32[] reduce(%broadcast.1, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.1 = f32[] reduce(%broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/16 * 1024, module.get())); + EXPECT_FALSE(changed); + HloInstruction* broadcast = + module->entry_computation()->GetInstructionWithName("broadcast.0"); + HloInstruction* reduce = + module->entry_computation()->GetInstructionWithName("reduce.1"); + EXPECT_THAT(reduce, op::Reduce(broadcast, op::Constant())); +} + TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) { const std::string& hlo_string = R"( HloModule fusion, is_scheduled=true diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 09154500d95c54..13b77462561c5c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -120,6 +121,27 @@ int64_t GetSubgroupSize(HloCollectiveInstruction* hlo, } } +Status CheckNestedComputationThreadNameEqual(const HloComputation* comp, + bool skip_nested_async_op_check) { + std::optional thread_name = comp->thread_name(); + for (const HloInstruction* instr : comp->instructions()) { + if (skip_nested_async_op_check && instr->IsAsynchronous()) { + continue; + } + for (const HloComputation* cmp : instr->called_computations()) { + if (cmp->thread_name() != thread_name) { + return InternalError( + "Nested computations expects same computation's thread name (%s vs " + "%s).", + thread_name ? absl::StrCat(*thread_name) : "none", + cmp->thread_name() ? absl::StrCat(*cmp->thread_name()) : "none"); + } + TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual( + cmp, skip_nested_async_op_check)); + } + } + return Status::OK(); +} } // namespace Status ShapeVerifier::Preprocess(HloInstruction* hlo) { @@ -1382,11 +1404,56 @@ Status CheckAsyncOpComputationShapes(const HloInstruction* async_op, } return Status::OK(); } + +Status CheckAsyncOpComputationThreadName(const HloInstruction* async_op) { + std::optional async_thread_name = + async_op->async_thread_name(); + if (async_thread_name != + async_op->async_wrapped_computation()->thread_name()) { + return InternalError( + "async-start expects same async thread name as wrapped computation's " + "thread name (%s vs %s).", + async_thread_name ? absl::StrCat(*async_thread_name) : "none", + async_op->async_wrapped_computation()->thread_name() + ? absl::StrCat( + *async_op->async_wrapped_computation()->thread_name()) + : "none"); + } + return CheckNestedComputationThreadNameEqual( + async_op->async_wrapped_computation(), + /*skip_nested_async_op_check=*/false); +} + +// TODO(b/229887502): apply CheckCallableInstructionThreadName to all +// CallableInstructions verifier. +Status CheckCallableInstructionThreadName(const HloInstruction* instruction, + bool skip_nested_async_op_check) { + for (const HloComputation* computation : instruction->called_computations()) { + if (instruction->parent() != nullptr) { + if (instruction->parent()->thread_name() != computation->thread_name()) { + return InternalError( + "callable instruction %s expects parent computation thread name " + "same as called computation's thread name (%s vs %s).", + instruction->ToString(), + instruction->parent()->thread_name() + ? absl::StrCat(*instruction->parent()->thread_name()) + : "none", + computation->thread_name() + ? absl::StrCat(*computation->thread_name()) + : "none"); + } + } + TF_RETURN_IF_ERROR(CheckNestedComputationThreadNameEqual( + computation, skip_nested_async_op_check)); + } + return Status::OK(); +} } // namespace Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { TF_RETURN_IF_ERROR( CheckAsyncOpComputationShapes(async_start, async_start->shape())); + TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start)); const Shape& param_shape = async_start->shape().tuple_shapes(0); for (int i = 0; i < async_start->operand_count(); ++i) { if (param_shape.tuple_shapes(i) != async_start->operand(i)->shape()) { @@ -1402,6 +1469,7 @@ Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { } Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) { + TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update)); if (async_update->operand(0)->shape() != async_update->shape()) { return InternalError( "The %s expects the shape of operand and output to match (%s vs %s).", @@ -1415,6 +1483,7 @@ Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) { } Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) { + TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done)); TF_RETURN_IF_ERROR(CheckAsyncOpComputationShapes( async_done, async_done->operand(0)->shape())); const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1); @@ -2294,6 +2363,8 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction*) override { return OkStatus(); } Status HandleFusion(HloInstruction* fusion) override { + TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName( + fusion, /*skip_nested_async_op_check*/ false)); return CheckFusionInstruction(fusion); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 21e4b8357a0c78..c00b39b1645160 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1863,6 +1863,56 @@ TEST_F(HloVerifierTest, FusionShapeVerifier) { HasSubstr("Fused computation shape")); } +TEST_F(HloVerifierTest, FusionThreadVerifier) { + const char* const kModuleStr = R"( + HloModule test + + fused_computation { + ROOT p0 = f32[8,12] parameter(0) + }, thread_name="parallel_thread" + + ENTRY entry { + p0 = f32[8,12] parameter(0) + ROOT out = f32[8,12] fusion(p0), kind=kInput, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("expects parent computation thread name same as called " + "computation's thread name")); +} + +TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) { + const char* const kModuleStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + }, thread_name="parallel_thread" + + fused_computation { + p0 = f32[8,12] parameter(0) + p1 = f32[8,12] parameter(1) + crs0 = f32[8,12] all-reduce(p1), replica_groups={}, to_apply=add + ROOT result = add(p0, crs0) + } + + ENTRY entry { + p0 = f32[8,12] parameter(0) + p1 = f32[8,12] parameter(1) + ROOT out = f32[8,12] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT( + verifier().Run(module.get()).status().error_message(), + HasSubstr("Nested computations expects same computation's thread name")); +} + TEST_F(HloVerifierTest, AllReduceVerifier) { const char* const kModuleStr = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 6f3c958cdc586e..ffba488252e9ac 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index fb33a489901728..b36817190e0a25 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -479,13 +479,16 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { // If not replicated yet, first replicate and then reshard to use one of the // two implementations below. if (!sharding().IsReplicated()) { - LOG(ERROR) << "[spmd] Involuntary full rematerialization. The compiled was " - "not able to go from sharding " - << sharding().ToString(/*include_metadata=*/true) << " to " - << target.ToString(/*include_metadata=*/true) - << " without doing a full rematerialization of the tensor. You " - "probably want to enrich the sharding annotations to prevent " - "this from happening."; + if (!target.IsReplicated()) { + LOG(ERROR) + << "[spmd] Involuntary full rematerialization. The compiled was " + "not able to go from sharding " + << sharding().ToString(/*include_metadata=*/true) << " to " + << target.ToString(/*include_metadata=*/true) + << " without doing a full rematerialization of the tensor. You " + "probably want to enrich the sharding annotations to prevent " + "this from happening."; + } return Replicate().Reshard(target); } diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc index 53382bb39353ac..29e7b1d1b10826 100644 --- a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc @@ -447,7 +447,6 @@ MovableAllReduceContext IsAllReduceMovable( [&is_value_replicated_within_replica_group]( const HloInstruction& dynamic_update_slice) -> bool { for (int i = 2; i < dynamic_update_slice.operand_count(); ++i) { - LOG(INFO) << " operand: " << dynamic_update_slice.operand(i)->ToString(); if (!is_value_replicated_within_replica_group( *dynamic_update_slice.operand(i), {})) { return false; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index a570174bf5fb93..cc33e9e217ff27 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -446,9 +446,10 @@ constexpr inline int Log2Ceiling(T x) { return x == 0 ? -1 : absl::bit_width(x - 1); } -// Returns the value with every bit except the lower 'width' bits set to zero. +// Returns `value` with the low `width` bits set and the remaining bits set to +// zero. template -constexpr inline T ClearUpperBits(T value, int width) { +constexpr inline T KeepLowerBits(T value, int width) { return value & LsbMask(width); } diff --git a/tensorflow/core/distributed_runtime/preemption/BUILD b/tensorflow/core/distributed_runtime/preemption/BUILD index 64cd0a47551170..86e0071586484c 100644 --- a/tensorflow/core/distributed_runtime/preemption/BUILD +++ b/tensorflow/core/distributed_runtime/preemption/BUILD @@ -77,9 +77,6 @@ tf_cc_test( name = "preemption_notifier_test", size = "small", srcs = ["preemption_notifier_test.cc"], - tags = [ - "no_windows", # TODO(b/236272855): Fails on Windows RBE (C++) - ], deps = [ ":preemption_notifier", "//tensorflow/core:test", diff --git a/tensorflow/core/distributed_runtime/preemption/preemption_notifier.cc b/tensorflow/core/distributed_runtime/preemption/preemption_notifier.cc index 1922985b773e68..1f92da8e902163 100644 --- a/tensorflow/core/distributed_runtime/preemption/preemption_notifier.cc +++ b/tensorflow/core/distributed_runtime/preemption/preemption_notifier.cc @@ -66,7 +66,7 @@ void SigtermNotifier::StartListenerThread() { // 1) Cancel any pending callbacks and blocking WillBePreemptedAt() // calls. NotifyRegisteredListeners( - errors::Cancelled("Preemption notifier is shutting down.")); + errors::Cancelled("Preemption notifier is being deleted.")); // 2) Exit listener thread. return; } diff --git a/tensorflow/core/distributed_runtime/preemption/preemption_notifier_test.cc b/tensorflow/core/distributed_runtime/preemption/preemption_notifier_test.cc index b94d67e95c2629..07e16b02282576 100644 --- a/tensorflow/core/distributed_runtime/preemption/preemption_notifier_test.cc +++ b/tensorflow/core/distributed_runtime/preemption/preemption_notifier_test.cc @@ -46,7 +46,7 @@ TEST(PreemptNotifierTest, WillBePreemptedAt) { // Make sure that preempt time is approximately correct. absl::Duration time_diff = preempt_time - start_time; // Signal was raised 1 second after start time. - EXPECT_GT(time_diff, absl::Seconds(1)); + EXPECT_GT(time_diff, absl::Seconds(1.0)); // Listen to signal once per second, so we should catch within 2 seconds. EXPECT_LT(time_diff, absl::Seconds(3)); } diff --git a/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.cc b/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.cc index 4474dae5c0b042..d959fa055bd405 100644 --- a/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.cc @@ -121,10 +121,15 @@ Status PreemptionSyncManagerImpl::Initialize( preemption_notifier_->WillBePreemptedAtAsync( [agent = agent_, task_name](StatusOr death_time) { if (!death_time.ok()) { - // This usually happens when the preemption notifier dtor is called - // and blocking calls are cancelled. - LOG(ERROR) << "Error from preemption notifier: " - << death_time.status(); + // The preemption notifier invokes callback with Cancelled error when + // its being destructed. + if (errors::IsCancelled(death_time.status())) { + LOG(INFO) << "Preemption sync protocol cancelled by notifier: " + << death_time.status(); + } else { + LOG(ERROR) << "Error from preemption notifier: " + << death_time.status(); + } return; } // Notify coordination service about preemption notice. diff --git a/tensorflow/core/example/BUILD b/tensorflow/core/example/BUILD index a844c7c39e8934..7e0764450c780e 100644 --- a/tensorflow/core/example/BUILD +++ b/tensorflow/core/example/BUILD @@ -51,8 +51,6 @@ cc_library( ":example_protos_cc", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:stringpiece", - "//tensorflow/core/platform:types", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 7ec6f3d375a6c9..04da4dcd0d003c 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -69,6 +69,12 @@ limitations under the License. // } // } // } +// For string-valued features, note that the Append... and Set... functions +// support absl::string_view containers. This allows you to copy existing +// buffers into a Feature with only one copy: +// std::vector image; +// image.push_back(image_buffer); // No copy. +// SetFeatureValues(image, "image", &example); // Copy. // // Functions exposed by this library: // HasFeature<[FeatureType]>(key, proto) -> bool @@ -116,16 +122,15 @@ limitations under the License. #include #include +#include #include #include -#include "absl/base/macros.h" #include "absl/strings/string_view.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -347,8 +352,12 @@ void AppendFeatureValues(const ContainerType& container, Feature* feature) { typename std::iterator_traits::value_type>::Type; auto* values = GetFeatureValues(feature); internal::ReserveIfSizeAvailable(container, *values); - std::copy(container.begin(), container.end(), - protobuf::RepeatedFieldBackInserter(values)); + // This is equivalent to std::copy into `values` with a + // RepeatedFieldBackInserter, the difference is RFBI isn't compatible with + // types that we want to convert (e.g. absl::string_view -> std::string). + for (const auto& elt : container) { + *values->Add() = elt; + } } // Copies elements from the range, defined by [first, last) into the feature diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc index f4fd06ea3cbd26..8500517d1a0bf3 100644 --- a/tensorflow/core/example/feature_util_test.cc +++ b/tensorflow/core/example/feature_util_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/example/feature_util.h" +#include #include #include "absl/strings/string_view.h" @@ -323,6 +324,21 @@ TEST(SetFeatureValuesTest, FloatValuesUsingInitializerList) { EXPECT_NEAR(30.3, tag_ro.Get(2), kTolerance); } +TEST(SetFeatureValuesTest, ContainerOfStringView) { + Example example; + + std::vector values = {"hello", "world"}; + std::vector values_string_view(values.begin(), + values.end()); + + SetFeatureValues(values_string_view, "tag", &example); + + auto tag_ro = GetFeatureValues("tag", example); + ASSERT_EQ(tag_ro.size(), 2); + EXPECT_EQ(tag_ro.Get(0), "hello"); + EXPECT_EQ(tag_ro.Get(1), "world"); +} + TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) { Example example; diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index fe16a6998a3031..a14c81fa4e1f90 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -562,5 +562,22 @@ void UpdateEagerClientErrorCounter(const string& error_source, eager_client_error_counter->GetCell(error_source, error_type)->IncrementBy(1); } +void UpdateTfMlirBridgeGraphAnalysisPerOp( + const std::string& op_name, const std::string& construction_context, + bool is_single_core_inference_mode, const std::string& unsupported_reason, + bool has_unsupported_features) { + static auto* metric = monitoring::Counter<5>::New( + "/tensorflow/core/tf_mlir_bridge_graph_analysis_per_op", + "Tracks processing state per op in first phase of mlir bridge", "op_name", + "construction_context", "is_single_core_inference_mode", + "unsupported_reason", "has_unsupported_features"); + + metric + ->GetCell(op_name, construction_context, + is_single_core_inference_mode ? "Yes" : "No", + unsupported_reason, has_unsupported_features ? "Yes" : "No") + ->IncrementBy(1); +} + } // namespace metrics } // namespace tensorflow diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index 343eb74ed57514..64086daa853372 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -212,6 +212,21 @@ void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& device_type, bool fallback_enabled, const std::string& result); +// Records the activity per op using the +// tf_metadata.tf_mlir_bridge_graph_analysis_per_op. +// op_name: the name of op. +// construction_context: eager, session, Not tracked. +// is_single_core_inference_mode: true, false. +// unsupported_reason: the reason why the graph is not supported in MLIR-based +// bridge, like invalid graph, has unsupported ops, etc. +// has_unsupported_features: true indicates MLIR-based bridge is disabled, +// false indicates MLIR-based bridge is enabled. + +void UpdateTfMlirBridgeGraphAnalysisPerOp( + const std::string& op_name, const std::string& construction_context, + bool is_single_core_inference_mode, const std::string& unsupported_reason, + bool has_unsupported_features); + // Convenience class allowing RAII style of reporting for a monitoring::Counter. template class ScopedCounter final { diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD index a9e57d6e17c737..c1f57472d6f792 100644 --- a/tensorflow/core/function/trace_type/BUILD +++ b/tensorflow/core/function/trace_type/BUILD @@ -19,6 +19,7 @@ pytype_strict_library( visibility = ["//tensorflow:internal"], deps = [ ":default_types", + ":serialization", ":util", "//tensorflow/python/types", ], diff --git a/tensorflow/core/function/trace_type/__init__.py b/tensorflow/core/function/trace_type/__init__.py index 93eed8dd60bf33..472f512e91364b 100644 --- a/tensorflow/core/function/trace_type/__init__.py +++ b/tensorflow/core/function/trace_type/__init__.py @@ -25,8 +25,11 @@ Other implementations of TraceType include tf.TypeSpec and its subclasses. """ - +from tensorflow.core.function.trace_type.serialization import deserialize +from tensorflow.core.function.trace_type.serialization import register_serializable +from tensorflow.core.function.trace_type.serialization import Serializable +from tensorflow.core.function.trace_type.serialization import serialize +from tensorflow.core.function.trace_type.serialization import SerializedTraceType from tensorflow.core.function.trace_type.trace_type_builder import from_object from tensorflow.core.function.trace_type.trace_type_builder import InternalTracingContext from tensorflow.core.function.trace_type.trace_type_builder import WeakrefDeletionObserver - diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py index cafab0cfc3644f..99e12dd456f897 100644 --- a/tensorflow/core/function/trace_type/default_types.py +++ b/tensorflow/core/function/trace_type/default_types.py @@ -603,3 +603,11 @@ def __hash__(self) -> int: def __repr__(self): return (f"{self.__class__.__name__}(base={self.base!r}, " f"identifier={self.identifier!r})") + +serialization.register_serializable(Literal) +serialization.register_serializable(Tuple) +serialization.register_serializable(List) +serialization.register_serializable(NamedTuple) +serialization.register_serializable(Attrs) +serialization.register_serializable(Dict) +serialization.register_serializable(Reference) diff --git a/tensorflow/core/function/trace_type/serialization.py b/tensorflow/core/function/trace_type/serialization.py index 7a8cb3b2b148c6..a1943e828a0602 100644 --- a/tensorflow/core/function/trace_type/serialization.py +++ b/tensorflow/core/function/trace_type/serialization.py @@ -28,18 +28,6 @@ class Serializable(metaclass=abc.ABCMeta): """TraceTypes implementing this additional interface are portable.""" - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS: - raise ValueError( - "Existing Python class " + - PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ + - " already has " + cls.experimental_type_proto().__name__ + - " as its associated proto representation. Please ensure " + - cls.__name__ + " has a unique proto representation.") - - PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls - @classmethod @abc.abstractmethod def experimental_type_proto(cls) -> Type[message.Message]: @@ -58,6 +46,25 @@ def experimental_as_proto(self) -> message.Message: raise NotImplementedError +def register_serializable(cls: Type[Serializable]): + """Registers a Python class to support serialization. + + Only register standard TF types. Custom types should NOT be registered. + + Args: + cls: Python class to register. + """ + if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS: + raise ValueError( + "Existing Python class " + + PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ + + " already has " + cls.experimental_type_proto().__name__ + + " as its associated proto representation. Please ensure " + + cls.__name__ + " has a unique proto representation.") + + PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls + + def serialize(to_serialize: Serializable) -> SerializedTraceType: """Converts Serializable to a proto SerializedTraceType.""" diff --git a/tensorflow/core/function/trace_type/serialization_test.py b/tensorflow/core/function/trace_type/serialization_test.py index 4afbd013c56929..ab7a229442d56f 100644 --- a/tensorflow/core/function/trace_type/serialization_test.py +++ b/tensorflow/core/function/trace_type/serialization_test.py @@ -39,6 +39,9 @@ def experimental_as_proto(self): return proto +serialization.register_serializable(MyCustomClass) + + class MyCompositeClass(serialization.Serializable): def __init__(self, *elements): @@ -62,6 +65,9 @@ def experimental_as_proto(self): return proto +serialization.register_serializable(MyCompositeClass) + + class SerializeTest(test.TestCase): def testCustomClassSerialization(self): @@ -124,25 +130,26 @@ def testCompositeClassDeserialization(self): self.assertEqual(deserialized.elements[2].name, "name_3") def testNonUniqueProto(self): + class ClassThatReusesProto(serialization.Serializable): + + @classmethod + def experimental_type_proto(cls): + return serialization_test_pb2.MyCustomRepresentation + + @classmethod + def experimental_from_proto(cls, proto): + raise NotImplementedError + + def experimental_as_proto(self): + raise NotImplementedError + with self.assertRaisesRegex( ValueError, ("Existing Python class MyCustomClass already has " "MyCustomRepresentation as its associated proto representation. " "Please ensure ClassThatReusesProto has a unique proto representation." )): - - class ClassThatReusesProto(serialization.Serializable): # pylint: disable=unused-variable - - @classmethod - def experimental_type_proto(cls): - return serialization_test_pb2.MyCustomRepresentation - - @classmethod - def experimental_from_proto(cls, proto): - raise NotImplementedError - - def experimental_as_proto(self): - raise NotImplementedError + serialization.register_serializable(ClassThatReusesProto) def testWrongProto(self): diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index b0b75e321c82a7..2680c19016406e 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -307,9 +307,7 @@ bool IsCpuCompatibleDataType(const NodeDef* contraction, bool IsGpuCompatibleDataType(const NodeDef* contraction, const string& type_attr = "T") { DataType dtype = GetDataTypeFromAttr(*contraction, type_attr); - if (IsConv2D(*contraction)) { - return dtype == DT_FLOAT; - } else if (IsMatMul(*contraction)) { + if (IsConv2D(*contraction) || IsMatMul(*contraction)) { return dtype == DT_FLOAT || dtype == DT_HALF; } else { return false; @@ -429,10 +427,10 @@ bool IsGpuCompatible(const RemapperContext& ctx, // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc), // and significantly slower in large scale benchmarks. bool is_spatial_conv = Rank(filter_shape) == 4 && // + IsKnown(filter_shape.dim(0)) && // IsKnown(filter_shape.dim(1)) && // - IsKnown(filter_shape.dim(2)) && // - filter_shape.dim(1).size() != 1 && // - filter_shape.dim(2).size() != 1; + filter_shape.dim(0).size() != 1 && // + filter_shape.dim(1).size() != 1; return is_spatial_conv && IsGpuCompatibleConv2D(ctx, &contraction_node); } else if (IsMatMul(contraction_node)) { @@ -3129,7 +3127,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) { const auto is_relu_biasadd_conv_candidate = [&]() -> bool { if (!IsRelu(*node_def)) return false; - if (GetDataTypeFromAttr(*node_def, "T") != DT_FLOAT) return false; + DataType act_dtype = GetDataTypeFromAttr(*node_def, "T"); + if (act_dtype != DT_FLOAT && act_dtype != DT_HALF) return false; if (node_view->NumRegularFanins() < 1) return false; const auto& relu_fanin_0 = node_view->GetRegularFanin(0); @@ -3138,8 +3137,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) { if (!IsBiasAdd(*relu_fanin_0_node_def) && !IsAdd(*relu_fanin_0_node_def)) return false; - if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT) - return false; + DataType biasadd_dtype = GetDataTypeFromAttr(*relu_fanin_0_node_def, "T"); + if (biasadd_dtype != DT_FLOAT && biasadd_dtype != DT_HALF) return false; if (relu_fanin_0_node_view->NumRegularFanins() < 1) return false; @@ -3149,8 +3148,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) { if (!IsConv2D(*biasadd_fanin_0_node_def) && !IsConv3D(*biasadd_fanin_0_node_def)) return false; - if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT) - return false; + DataType conv_dtype = GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T"); + if (conv_dtype != DT_FLOAT && conv_dtype != DT_HALF) return false; return true; }; diff --git a/tensorflow/core/kernels/conv_ops_fused_half.cc b/tensorflow/core/kernels/conv_ops_fused_half.cc index 5086b2b6f1b908..2945fcf530ac9a 100644 --- a/tensorflow/core/kernels/conv_ops_fused_half.cc +++ b/tensorflow/core/kernels/conv_ops_fused_half.cc @@ -27,6 +27,8 @@ namespace functor { DECLARE_FUNCTOR_GPU_SPEC(Eigen::half); } // namespace functor +TF_CALL_half(REGISTER_FUSED_GPU_CONV2D); + #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h index 19580a105478c8..76a3589178903a 100644 --- a/tensorflow/core/kernels/conv_ops_fused_impl.h +++ b/tensorflow/core/kernels/conv_ops_fused_impl.h @@ -220,6 +220,9 @@ struct LaunchFusedConv2DOp { OP_REQUIRES(context, params.data_format == FORMAT_NHWC, errors::Unimplemented("Fused conv implementation only supports " "NHWC tensor format for now.")); + OP_REQUIRES(context, DataTypeToEnum::value != DT_HALF, + errors::Unimplemented("Fused conv implementation with half " + "precision is not supported on CPU.")); BiasAddArgs bias_add_args; if (BiasAddArgs::IsSupported(fusion)) { @@ -420,7 +423,10 @@ struct LaunchFusedConv2DOp { in_cols = new_in_cols; } - if (params.data_format == FORMAT_NHWC) { + const bool compute_in_nhwc = DataTypeToEnum::value == DT_HALF && + stream->GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::VOLTA); + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { // Convert the input tensor from NHWC to NCHW. TensorShape nchw_shape = ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths); @@ -452,23 +458,37 @@ struct LaunchFusedConv2DOp { LOG(FATAL) << "Unsupported fusion type"; // Crash OK } + const TensorFormat compute_data_format = + compute_in_nhwc ? FORMAT_NHWC : FORMAT_NCHW; + constexpr auto kComputeInNHWC = + std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, + se::dnn::FilterLayout::kOutputYXInput); + constexpr auto kComputeInNCHW = + std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, + se::dnn::FilterLayout::kOutputInputYX); + se::dnn::DataLayout compute_data_layout; + se::dnn::FilterLayout filter_layout; + std::tie(compute_data_layout, filter_layout) = + compute_in_nhwc ? kComputeInNHWC : kComputeInNCHW; + se::dnn::BatchDescriptor input_desc; input_desc.set_count(in_batch) .set_feature_map_count(in_depths) .set_height(in_rows) .set_width(in_cols) - .set_layout(se::dnn::DataLayout::kBatchDepthYX); + .set_layout(compute_data_layout); se::dnn::FilterDescriptor filter_desc; filter_desc.set_input_filter_height(patch_rows) .set_input_filter_width(patch_cols) .set_input_feature_map_count(patch_depths) - .set_output_feature_map_count(filter.dim_size(3)); + .set_output_feature_map_count(filter.dim_size(3)) + .set_layout(filter_layout); se::dnn::BatchDescriptor bias_desc; bias_desc.set_count(1) .set_height(1) .set_width(1) .set_feature_map_count(out_depths) - .set_layout(se::dnn::DataLayout::kBatchDepthYX); + .set_layout(compute_data_layout); se::dnn::ConvolutionDescriptor conv_desc; conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows) .set_horizontal_dilation_rate(dimensions.dilation_cols) @@ -482,22 +502,38 @@ struct LaunchFusedConv2DOp { .set_height(out_rows) .set_width(out_cols) .set_feature_map_count(out_depths) - .set_layout(se::dnn::DataLayout::kBatchDepthYX); + .set_layout(compute_data_layout); Tensor transformed_filter; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataTypeToEnum::value, - TensorShape({filter.dim_size(3), filter.dim_size(2), - filter.dim_size(0), filter.dim_size(1)}), - &transformed_filter)); - functor::TransformFilter()( - context->eigen_device(), FORMAT_OIHW, - To32Bit(filter.tensor()), - To32Bit(transformed_filter.tensor())); + const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status { + VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO) + << " to " << ToString(dst_format); + + TensorShape dst_shape = + dst_format == FORMAT_OIHW + ? TensorShape({filter.dim_size(3), filter.dim_size(2), + filter.dim_size(0), filter.dim_size(1)}) + : TensorShape({filter.dim_size(3), filter.dim_size(0), + filter.dim_size(1), filter.dim_size(2)}); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, dst_shape, &transformed_filter)); + functor::TransformFilter()( + context->eigen_device(), dst_format, + To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + return OkStatus(); + }; + + if (compute_in_nhwc) { + OP_REQUIRES_OK(context, transform_filter(FORMAT_OHWI)); + } else { + OP_REQUIRES_OK(context, transform_filter(FORMAT_OIHW)); + } Tensor transformed_output; - if (params.data_format == FORMAT_NHWC) { + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { // Only allocate temporary memory when a layout transformation is needed. OP_REQUIRES_OK(context, context->allocate_temp( @@ -533,7 +569,7 @@ struct LaunchFusedConv2DOp { in_depths, // in_depths {{in_rows, // in_rows in_cols}}, // in_cols - FORMAT_NCHW, // compute_data_format + compute_data_format, // compute_data_format out_depths, // out_depths {{patch_rows, // filter_rows patch_cols, // filter_cols @@ -616,7 +652,7 @@ struct LaunchFusedConv2DOp { OP_REQUIRES_OK(context, cudnn_launch_status); // Convert the output tensor back from NCHW to NHWC. - if (params.data_format == FORMAT_NHWC) { + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { functor::NCHWToNHWC()( context->eigen_device(), const_cast(transformed_output).tensor(), diff --git a/tensorflow/core/kernels/conv_ops_gpu.cc b/tensorflow/core/kernels/conv_ops_gpu.cc index aa6936e6ff8dba..24a73c6db74283 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cc @@ -60,7 +60,7 @@ StatusOr> AutotuneConvImpl( ? static_cast(&rz_scratch_allocator) : static_cast(&scratch_allocator); - SE_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); + TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); se::dnn::ProfileResult profile_result; Status cudnn_launch_status = actually_do_autotune @@ -157,7 +157,7 @@ StatusOr> AutotuneFusedConv( std::vector> runners; auto element_type = se::dnn::ToDataType::value; - SE_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners( + TF_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners( CudnnUseFrontend(), se::dnn::ConvolutionKind::FORWARD, element_type, element_type, element_type, conv_scale, side_input_scale, stream, input_desc, filter_desc, bias_desc, output_desc, conv_desc, @@ -173,7 +173,7 @@ StatusOr> AutotuneFusedConv( side_input_ptr, bias_ptr, output_ptr_rz); }; - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto results, AutotuneConvImpl(ctx, runners, cudnn_use_autotune, launch_func, scratch_size_limit, rz_allocator)); @@ -207,13 +207,13 @@ StatusOr> AutotuneFusedConv( << params.ToString(); std::vector> fallback_runners; - SE_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners( + TF_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners( CudnnUseFrontend(), se::dnn::ConvolutionKind::FORWARD, element_type, element_type, element_type, conv_scale, side_input_scale, stream, input_desc, filter_desc, bias_desc, output_desc, conv_desc, /*use_fallback=*/true, activation_mode, &fallback_runners)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto fallback_results, AutotuneConvImpl(ctx, fallback_runners, cudnn_use_autotune, launch_func, scratch_size_limit, rz_allocator)); @@ -271,6 +271,24 @@ template StatusOr> AutotuneFusedConv( se::DeviceMemory bias_ptr, se::DeviceMemory side_input_ptr, int64_t scratch_size_limit); +template StatusOr> +AutotuneFusedConv( + bool cudnn_use_autotune, + AutotuneMap>* + autotune_map, + const ConvParameters& params, OpKernelContext* ctx, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& bias_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, + const se::dnn::ActivationMode activation_mode, double conv_scale, + double side_input_scale, se::DeviceMemory input_ptr, + se::DeviceMemory filter_ptr, + se::DeviceMemory output_ptr, + se::DeviceMemory bias_ptr, + se::DeviceMemory side_input_ptr, int64_t scratch_size_limit); + template StatusOr> AutotuneUnfusedConv( bool cudnn_use_autotune, @@ -331,7 +349,7 @@ StatusOr> AutotuneUnfusedConv( return (*runner)(stream, profile_result, scratch, input_ptr, filter_ptr, output_ptr); }; - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto results, AutotuneConvImpl(ctx, runners, cudnn_use_autotune, launch_func, scratch_size_limit, rz_allocator)); @@ -353,7 +371,7 @@ StatusOr> AutotuneUnfusedConv( } if (!CudnnUseFrontend() || found_working_engine) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( autotune_entry, BestCudnnConvAlgorithm(results, std::move(runners))); } else { @@ -368,7 +386,7 @@ StatusOr> AutotuneUnfusedConv( output_ptr, conv_desc, /*use_fallback=*/true, &rz_allocator, &fallback_runners)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto fallback_results, AutotuneConvImpl(ctx, fallback_runners, cudnn_use_autotune, launch_func, scratch_size_limit, rz_allocator)); @@ -378,7 +396,7 @@ StatusOr> AutotuneUnfusedConv( output_desc, conv_desc, stream->parent(), fallback_results); - SE_ASSIGN_OR_RETURN(autotune_entry, + TF_ASSIGN_OR_RETURN(autotune_entry, BestCudnnConvAlgorithm( fallback_results, std::move(fallback_runners))); } @@ -432,7 +450,7 @@ StatusOr> AutotuneUnfusedConv( filter_ptr, output_ptr, input_desc, filter_desc, output_desc, conv_desc, stream->parent(), results); - SE_ASSIGN_OR_RETURN(auto algo_desc, BestCudnnConvAlgorithm(results)); + TF_ASSIGN_OR_RETURN(auto algo_desc, BestCudnnConvAlgorithm(results)); autotune_entry = AutotuneEntry(algo_desc); #endif diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 71b5cf809587e3..1b1122ff0c05ed 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -469,7 +469,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(3) << "Calling GetNext in data service dataset's iterator."; mutex_lock l(mu_); EnsureThreadsStarted(ctx); - Result result; + std::shared_ptr result; do { while (!ResultReady() && !Finished() && !cancelled_ && status_.ok()) { VLOG(3) << "Blocking in GetNext: " << DebugString(); @@ -488,20 +488,24 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(3) << "Returning from GetNext with end_of_sequence"; return OkStatus(); } + if (!ResultReady()) { + return errors::Internal( + "Expected a result to be ready, but none were."); + } result = PopNextResult(); worker_thread_cv_.notify_one(); - } while (result.skip); + } while (result->skip); - *end_of_sequence = result.end_of_sequence; + *end_of_sequence = result->end_of_sequence; if (!*end_of_sequence) { VLOG(1) << "Returning the next element from data service dataset's " - << "Iterator: task " << result.task_id << ", element " - << result.element_index; + << "Iterator: task " << result->task_id << ", element " + << result->element_index; if (StrictRoundRobin()) { VLOG(1) << "Consumer " << dataset()->consumer_index_.value() << ": Result " << get_next_index_++; } - out_tensors->swap(result.element); + out_tensors->swap(result->element); } return OkStatus(); } @@ -896,7 +900,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(1) << "Starting worker thread"; std::shared_ptr task_to_process; while (true) { - Result* result; + std::shared_ptr result; { mutex_lock l(mu_); if (task_to_process) { @@ -922,21 +926,18 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { ++outstanding_requests_; if (StrictRoundRobin()) { // Reserve a spot in the results_ queue. - results_.emplace(); - result = &results_.back(); + results_.push(std::make_shared()); + result = results_.back(); + } else { + // The result will be added to results_ when it's ready. + result = std::make_shared(); } VLOG(3) << "Processing task " << task_to_process->info.task_id(); } int64_t deadline_micros = kint64max; - Status s; - if (StrictRoundRobin()) { - s = GetElementTraced(task_to_process.get(), deadline_micros, - /*enqueue_result=*/false, *result); - } else { - Result r; - s = GetElementTraced(task_to_process.get(), deadline_micros, - /*enqueue_result=*/true, r); - } + Status s = + GetElementTraced(task_to_process.get(), deadline_micros, + /*enqueue_result=*/!StrictRoundRobin(), result); if (!s.ok()) { mutex_lock l(mu_); VLOG(1) << "Failed to get element from worker " @@ -1035,30 +1036,30 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { void ProcessGetElementResponse(bool enqueue_result, GetElementResult& get_element_result, - Result& result, Task& task) { + std::shared_ptr result, Task& task) { mutex_lock l(mu_); - result.ready = true; - result.end_of_sequence = get_element_result.end_of_sequence; - result.skip = get_element_result.skip; + result->ready = true; + result->end_of_sequence = get_element_result.end_of_sequence; + result->skip = get_element_result.skip; if (!get_element_result.end_of_sequence && !get_element_result.skip) { task.skipped_previous_round = false; - result.element = std::move(get_element_result.components); - result.element_index = get_element_result.element_index; - result.task_id = task.info.task_id(); + result->element = std::move(get_element_result.components); + result->element_index = get_element_result.element_index; + result->task_id = task.info.task_id(); } else if (get_element_result.skip) { task.skipped_previous_round = true; } else { task.end_of_sequence = true; finished_tasks_++; } - if (enqueue_result && !result.end_of_sequence) { + if (enqueue_result && !result->end_of_sequence) { results_.push(std::move(result)); } get_next_cv_.notify_all(); } Status GetElementTraced(Task* task, int64_t deadline_micros, - bool enqueue_result, Result& result) + bool enqueue_result, std::shared_ptr result) TF_LOCKS_EXCLUDED(mu_) { VLOG(3) << "Getting an element for task id " << task->info.task_id(); tensorflow::profiler::TraceMe activity( @@ -1114,7 +1115,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { } Status GetElement(Task* task, int64_t deadline_micros, bool enqueue_result, - Result& result) TF_LOCKS_EXCLUDED(mu_) { + std::shared_ptr result) TF_LOCKS_EXCLUDED(mu_) { GetElementResult get_element_result; for (int num_retries = 0;; ++num_retries) { Status s = TryGetElement(*task, get_element_result); @@ -1134,9 +1135,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { return s; } if (StrictRoundRobin() && num_retries > 0) { - TF_RETURN_IF_ERROR(MaybeRemoveTask(*task, deadline_micros, result)); + TF_RETURN_IF_ERROR(MaybeRemoveTask(*task, deadline_micros, *result)); mutex_lock l(mu_); - if (result.skip) { + if (result->skip) { return OkStatus(); } } @@ -1155,11 +1156,11 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { } bool ResultReady() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return !results_.empty() && results_.front().ready; + return !results_.empty() && results_.front()->ready; } - Result PopNextResult() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - Result result = std::move(results_.front()); + std::shared_ptr PopNextResult() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr result = results_.front(); results_.pop(); return result; } @@ -1173,7 +1174,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { "results_ { size: $0 front.ready: $1 } iteration_finished_: $2 " "tasks { size: $3 } finished_tasks_: $4 " "num_running_worker_threads_: $5", - results_.size(), !results_.empty() && results_.front().ready, + results_.size(), !results_.empty() && results_.front()->ready, iteration_finished_, tasks_.size(), finished_tasks_, num_running_worker_threads_); } @@ -1226,7 +1227,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { // their `Result::ready` field false until their data has been retrieved // from a worker. When not doing round-robin reads, results are only added // to the queue after they are ready, to avoid head-of-line blocking. - std::queue results_ TF_GUARDED_BY(mu_); + std::queue> results_ TF_GUARDED_BY(mu_); bool initialized_ = false; // Set once in Initialize(). diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index b5e87a61774645..1dcc7ae80b6bc9 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_shape_util.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/conv_ops.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/padding.h" @@ -300,6 +302,23 @@ class DepthwiseConv2dNativeOp : public BinaryOp { OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, /*num_dims=*/4, data_format_)); + // CPU/GPU kernel currently ignores dilations, so all must be 1. + std::vector dilations; + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations)); + bool unit_dilations = true; + for (int32_t dilation : dilations) { + if (dilation != 1) { + unit_dilations = false; + } + } + OP_REQUIRES(context, unit_dilations, + errors::Unimplemented( + "Current kernel implementation does not support " + "dilations, received [", + Eigen::Map>( + dilations.data(), dilations.size()), + "]")); + cudnn_use_autotune_ = CudnnUseAutotune(); dtype_ = DataTypeToEnum::value; #if CUDNN_VERSION >= 8000 @@ -474,7 +493,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp { bool use_cudnn_grouped_conv_; private: - std::vector strides_; + std::vector strides_; Padding padding_; std::vector explicit_paddings_; TensorFormat data_format_; diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 4dad433ff1bbd3..091b18107aa786 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -21,6 +21,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include +#include #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -425,7 +426,7 @@ struct LaunchBatchMatMul { OP_REQUIRES(context, status_or_computation_type.ok(), errors::Internal("Unsupported dtype for batched matmul.")); se::blas::ComputationType computation_type = - status_or_computation_type.ConsumeValueOrDie(); + std::move(status_or_computation_type).value(); se::cuda::BlasLt::MatmulPlanParams matmul_params{ /*ab_type=*/blas_dtype, @@ -453,7 +454,7 @@ struct LaunchBatchMatMul { GetPlanAndAlgorithms(stream, matmul_params, max_algorithm_count); OP_REQUIRES_OK(context, plan_and_algorithms_or.status()); const auto* plan_and_algorithms = - plan_and_algorithms_or.ConsumeValueOrDie(); + std::move(plan_and_algorithms_or).value(); const auto& plan = plan_and_algorithms->plan; const auto& algorithms = plan_and_algorithms->algorithms; diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc index 1bcbd786eb4d65..4828726a8c82e5 100644 --- a/tensorflow/core/kernels/matmul_util.cc +++ b/tensorflow/core/kernels/matmul_util.cc @@ -60,21 +60,19 @@ StatusOr GetBlasComputationType( const DataType& dtype) { using se::blas::ComputationType; static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input(); - bool allow_tf32 = tensor_float_32_execution_enabled(); - ComputationType f32_type = - allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32; switch (dtype) { case DT_HALF: + return use_f32_for_f16_computation ? ComputationType::kF32 + : ComputationType::kF16; case DT_BFLOAT16: - return use_f32_for_f16_computation ? f32_type : ComputationType::kF16; - case DT_FLOAT: - return f32_type; - case DT_DOUBLE: - return ComputationType::kF64; + return ComputationType::kF32; + case DT_FLOAT: // fall-through case DT_COMPLEX64: - return f32_type; + return tensor_float_32_execution_enabled() ? ComputationType::kTF32AsF32 + : ComputationType::kF32; + case DT_DOUBLE: // fall-through case DT_COMPLEX128: - return ComputationType::kComplexF64; + return ComputationType::kF64; default: return errors::Internal("Unsupported dtype for Blas Plans."); } diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 88b9e357363e3a..a34d7897dd2cff 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -439,7 +439,7 @@ REGISTER_OP("_FusedConv2D") .Input("filter: T") .Input("args: num_args * T") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {half, float, double}") .Attr("num_args: int >= 0") .Attr("strides: list(int)") .Attr(GetPaddingAttrStringWithExplicit()) diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc index 701f4cc1598b9c..83b9c745cbf25f 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.cc +++ b/tensorflow/core/profiler/utils/xplane_utils.cc @@ -497,6 +497,13 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { }); }); + // TODO(b/238349654): Remove when XPlane better XPlane Comparison mechanism + // exists. + aggregated_plane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kMinDurationPs)); + aggregated_plane.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kSelfDurationPs)); + for (const auto& [line_id, stat_by_event] : stats) { XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line_id); for (const auto& [event_id, event_stat] : stat_by_event) { diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 16550a6f6b6315..37d33787d0c758 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1182 // Updated: 2022/7/4 +#define TF_GRAPH_DEF_VERSION 1189 // Updated: 2022/7/11 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index f9521f7e2a4016..c62cbe19cf39fc 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -87,12 +87,12 @@ KernelFallbackCompatRequestState::KernelFallbackCompatRequestState( DCHECK(resource_array_); DCHECK(rendezvous_); - // TODO(tfrt-devs): Support customizing non-CPU devices. - auto* device = device_manager_->HostCPU(); if (user_intra_op_threadpool != nullptr) { - custom_device_ = tensorflow::RenamedDevice::NewRenamedDevice( - device->name(), device, /*owns_underlying=*/false, - /*isolate_session_state=*/false, user_intra_op_threadpool); + for (auto* device : device_manager_->ListDevices()) { + custom_device_[device] = tensorflow::RenamedDevice::NewRenamedDevice( + device->name(), device, /*owns_underlying=*/false, + /*isolate_session_state=*/false, user_intra_op_threadpool); + } } if (model_metadata.has_value()) { session_metadata_ = *model_metadata; diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index 71e646be2a85a3..6a5280cc2975f0 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/device.h" @@ -92,9 +93,13 @@ class KernelFallbackCompatRequestState { const absl::optional& model_metadata, const tensorflow::ProcessFunctionLibraryRuntime* pflr); - // Returns the user-specified custom device for this request. It is currently - // only used for configure per-request intra op threadpool. - tensorflow::Device* custom_device() const { return custom_device_.get(); } + // Returns the user-specified custom device corresponding to the given device. + // It is currently only used for configure per-request intra op threadpool. + tensorflow::Device* custom_device(const tensorflow::Device* device) const { + auto it = custom_device_.find(device); + if (it == custom_device_.end()) return nullptr; + return it->second.get(); + } ScopedStepContainer* step_container() const { return step_container_.get(); } @@ -136,7 +141,9 @@ class KernelFallbackCompatRequestState { // Below are resources needed by current tensorflow. std::function)>* runner_ = nullptr; ::tfrt::OwnedOrUnownedPtr step_container_; - std::unique_ptr custom_device_; + absl::flat_hash_map> + custom_device_; std::unique_ptr collective_executor_handle_; CollectiveExecutor* collective_executor_ = nullptr; core::RefCountPtr rendezvous_; diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.cc index d4c96f65789b8a..10e4a8771e6faa 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.cc @@ -56,10 +56,11 @@ tensorflow::Device* GetDeviceFromFallbackState( // // The device handling is similar to TF1 code in the below link: // http://cs/?q=f:common_runtime%2Fexecutor.cc:692%20package:piper&rcl=351575626 - if (auto* custom_device = fallback_request_state.custom_device()) { + auto* device = kernel_runner.device(); + if (auto* custom_device = fallback_request_state.custom_device(device)) { return custom_device; } - return kernel_runner.device(); + return device; } } // namespace tfd diff --git a/tensorflow/core/tfrt/fallback/BUILD b/tensorflow/core/tfrt/fallback/BUILD index eba793b0f7018d..7d89e62d48dd2b 100644 --- a/tensorflow/core/tfrt/fallback/BUILD +++ b/tensorflow/core/tfrt/fallback/BUILD @@ -81,6 +81,16 @@ cc_library( ]), ) +cc_library( + name = "cost_recorder", + hdrs = ["cost_recorder.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@tf_runtime//:hostcontext", + ], +) + tf_cuda_cc_test( name = "op_kernel_runner_test", size = "small", diff --git a/tensorflow/core/tfrt/fallback/cost_recorder.h b/tensorflow/core/tfrt/fallback/cost_recorder.h new file mode 100644 index 00000000000000..ed428c422fcbf1 --- /dev/null +++ b/tensorflow/core/tfrt/fallback/cost_recorder.h @@ -0,0 +1,49 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines a recorder for op cost measurement + +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tfrt/host_context/shared_context.h" // from @tf_runtime + +namespace tfrt { +class HostContext; +} // namespace tfrt + +namespace tensorflow { +namespace tfrt_stub { +class CostRecorder : public tfrt::SharedContext { + public: + explicit CostRecorder(tfrt::HostContext* host) {} + + // TODO(xiangll): This is used for cost measurement only. Clean up after the + // measurement is done. + void RecordCost(const absl::string_view op_name, + const uint64_t run_duration) { + cost_per_op_map_[op_name] = run_duration; + } + + private: + // Map op name to op run duration in terms of microseconds. + absl::flat_hash_map cost_per_op_map_; +}; +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 641e4d0fca7c26..f3bf16c9d03171 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -424,29 +424,46 @@ GraphExecutor::ImportAndCompileClientGraph( runtime(), tpu_model_resource_, options_.compile_options.tpu_target); // Step 1 of loading: Import the client graph from proto to an MLIR module. + auto import_start_time = absl::Now(); mlir::MLIRContext context; ASSIGN_OR_RETURN_IN_IMPORT( auto module, ImportClientGraphToMlirModule(client_graph, &context)); + auto import_duration = absl::Now() - import_start_time; + LOG(INFO) << "TFRT finished importing client graph (" << &client_graph + << "). Took " << absl::ToInt64Milliseconds(import_duration) + << " ms. Client graph name: " << client_graph.name; // Step 2 of loading: Compile the MLIR module from TF dialect to TFRT dialect // (in BEF). + auto compile_start_time = absl::Now(); ASSIGN_OR_RETURN_IN_COMPILE(loaded_client_graph->bef, CompileMlirModuleToBef(module.get())); + auto compile_duration = absl::Now() - compile_start_time; + LOG(INFO) << "TFRT finished compiling client graph (" << &client_graph + << "). Took " << absl::ToInt64Milliseconds(compile_duration) + << " ms. Client graph name: " << client_graph.name; return loaded_client_graph; } StatusOr> GraphExecutor::LoadClientGraph(const GraphExecutor::ClientGraph& client_graph) { + LOG(INFO) << "TFRT loading client graph (" << &client_graph << ") " + << client_graph.name; TF_ASSIGN_OR_RETURN(auto loaded_client_graph, ImportAndCompileClientGraph(client_graph)); // Step 3 of loading: Initialize runtime states using special BEF functions. + auto init_start_time = absl::Now(); ASSIGN_OR_RETURN_IN_INIT( loaded_client_graph->bef_file, tfrt::CreateBefFileFromBefBuffer(runtime(), loaded_client_graph->bef)); RETURN_IF_ERROR_IN_INIT(InitBef(loaded_client_graph->bef_file.get(), loaded_client_graph->resource_context.get())); + auto init_duration = absl::Now() - init_start_time; + LOG(INFO) << "TFRT finished initializing client graph (" << &client_graph + << "). Took " << absl::ToInt64Milliseconds(init_duration) + << " ms. Client graph name: " << client_graph.name; return loaded_client_graph; } @@ -467,6 +484,16 @@ GraphExecutor::ImportClientGraphToMlirModule( auto optimized_graph, graph_execution_state_->CreateOptimizedGraph(graph_import_config)); + LOG(INFO) << "TFRT import client graph (" << &client_graph + << "): Functionalization took " + << absl::ToInt64Milliseconds( + optimized_graph.functionalization_duration) + << " ms. Client graph name: " << client_graph.name; + LOG(INFO) << "TFRT import client graph (" << &client_graph + << "): Grappler took " + << absl::ToInt64Milliseconds(optimized_graph.grappler_duration) + << " ms. Client graph name: " << client_graph.name; + // Convert the optimized graph to an MLIR module. return tensorflow::ConvertGraphToMlir( *optimized_graph.graph, /*debug_info=*/{}, diff --git a/tensorflow/core/tfrt/runtime/runtime.cc b/tensorflow/core/tfrt/runtime/runtime.cc index 59c5501ca67228..43d4e2bb43d07b 100644 --- a/tensorflow/core/tfrt/runtime/runtime.cc +++ b/tensorflow/core/tfrt/runtime/runtime.cc @@ -146,14 +146,6 @@ std::unique_ptr Runtime::Create( new Runtime(std::move(expected_core_runtime.get()), work_queue_ptr)); } -// TODO(b/196962112): Remove this overload. -std::unique_ptr Runtime::Create() { - static constexpr int kDefaultNumInterOpThreads = 4; - // Let system pick the number of intra op threads. - static constexpr int kDefaultNumIntraOpThreads = 0; - return Runtime::Create(kDefaultNumInterOpThreads, kDefaultNumIntraOpThreads); -} - std::unique_ptr Runtime::Create(int num_inter_op_threads, int num_intra_op_threads) { if (num_intra_op_threads <= 0) diff --git a/tensorflow/core/tfrt/runtime/runtime.h b/tensorflow/core/tfrt/runtime/runtime.h index 2c09be9799b5bb..4b98f2c0c8aa3b 100644 --- a/tensorflow/core/tfrt/runtime/runtime.h +++ b/tensorflow/core/tfrt/runtime/runtime.h @@ -39,9 +39,6 @@ namespace tfrt_stub { // tensorflow::experimental::cc::Runtime when it lands. class Runtime { public: - ABSL_DEPRECATED("Use other Create() methods instead.") - static std::unique_ptr Create(); - // Creates a runtime instance with specified threading configuration. Returns // null upon creation error. static std::unique_ptr Create(int num_inter_op_threads, diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index c54eaaa02aafbe..16827ae7da922b 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -35,7 +35,6 @@ cc_library( tags = ["no_oss"], deps = [ "@com_google_absl//absl/container:flat_hash_map", -# "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 1dddea571f33bd..5eb703b406310c 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -16,12 +16,10 @@ limitations under the License. #include #include -#include #include #include #include -#include "absl/log/check.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -43,7 +41,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" @@ -645,36 +642,37 @@ tensorflow::Status SavedModelImpl::Run( TF_RET_CHECK(outputs) << "outputs must be provided"; outputs->clear(); - if (!options_.enable_lazy_loading) { - auto sig_iter = signatures_.find(name); - TF_RET_CHECK(sig_iter != signatures_.end()) - << "failed to find signature " << name << " in the graph"; - if (run_options.validate_input_specs) { - TF_RETURN_IF_ERROR(IsInputSpecsCorrect(name, sig_iter->second, inputs)); - } - std::vector captures; - for (const auto& capture : sig_iter->second.captures) { - captures.push_back(capture); - } - const tfrt::Function* func = - bef_file_->GetFunction({name.data(), name.size()}); - DCHECK(func); - return GraphExecutionRunOnFunction( - options_.graph_execution_options, run_options, name, *func, inputs, - captures, outputs, resource_context_.get(), runtime(), *fallback_state_, - req_deadline_tracker_); + auto sig_iter = signatures_.find(name); + TF_RET_CHECK(sig_iter != signatures_.end()) + << "failed to find signature " << name << " in the graph"; + if (run_options.validate_input_specs) { + TF_RETURN_IF_ERROR(IsInputSpecsCorrect(name, sig_iter->second, inputs)); + } + std::vector captures; + for (const auto& capture : sig_iter->second.captures) { + captures.push_back(capture); } - // If lazy loading is enabled, no signature is loaded into `bef_file_`, - // invoke `RunMultipleSignatures()` and delegate to `graph_executor_` with - // lazy loading work and execution work. - std::vector inputs_vector(inputs.begin(), inputs.end()); - std::vector> multi_outputs; - TF_RETURN_IF_ERROR(RunMultipleSignatures(run_options, {std::string(name)}, - {inputs_vector}, &multi_outputs)); - DCHECK_EQ(multi_outputs.size(), 1); - *outputs = std::move(multi_outputs[0]); - return OkStatus(); + const tfrt::Function* func; + tfrt::ResourceContext* resource_context; + if (options_.enable_lazy_loading) { + // If lazy loading is enabled, no signature is loaded into `bef_file_`, so + // we need to find the BEF from the cache or create one. + TF_ASSIGN_OR_RETURN(const LoadingResult& loading_result, + GetOrCreateLoadingResult({std::string(name)})); + func = loading_result.bef_file->GetFunction( + tensorflow::kImportModelDefaultGraphFuncName); + resource_context = loading_result.resource_context.get(); + } else { + func = bef_file_->GetFunction({name.data(), name.size()}); + resource_context = resource_context_.get(); + } + DCHECK(func); + + return GraphExecutionRunOnFunction(options_.graph_execution_options, + run_options, name, *func, inputs, captures, + outputs, resource_context, runtime(), + *fallback_state_, req_deadline_tracker_); } struct SavedModelImpl::JoinedSignature { @@ -840,5 +838,140 @@ tensorflow::Status SavedModelImpl::RunByTensorNames( target_node_names, outputs); } +namespace { + +using JoinedSignature = SavedModelImpl::JoinedSignature; + +// Returns a joined signature with the signatures in `names`. For inputs, as +// their corresponding nodes may overlap, we deduplicate them by the nodes so +// the order of inputs for the joined signature would be different from the +// original order. For outputs, overlapping is fine so we only flatten it in the +// original order. +StatusOr JoinSignatures( + absl::Span names, const SignatureMap& signature_map, + const tensorflow::protobuf::Map& + signature_def_map) { + // Join all the names, all the inputs, and all the outputs. + JoinedSignature joined_signature; + joined_signature.name = absl::StrJoin(names, kSignatureJoiningDelimiter); + + // Keep the feed tensor names visited. + absl::flat_hash_set visited_feed_tensor_names; + + for (const auto& name : names) { + const auto& signature_def = signature_def_map.at(name); + + // For inputs, we deduplicate possible overlapping feed nodes and create the + // new input array. + for (const auto& iter : signature_def.inputs()) { + const auto& tensor_info = iter.second; + + // Skip if this feed node is already visited. + if (visited_feed_tensor_names.contains(tensor_info.name())) continue; + + // Otherwise, we parse its tensor info and collect it for later + // compilation. + visited_feed_tensor_names.insert(tensor_info.name()); + + // TODO(b/184675681): Support other encoding cases. + // + // TODO(b/184679394): Add unit test for this check. + TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) + << "Only dense tensor is supported, but got encoding case " + << tensor_info.encoding_case(); + + VLOG(1) << "Importing Signature Input: input_key = " << iter.first + << ", tensor_info = " << tensor_info.DebugString(); + + tensorflow::ArrayInfo array_info; + array_info.imported_dtype = tensor_info.dtype(); + + if (tensor_info.has_tensor_shape()) { + array_info.shape = tensor_info.tensor_shape(); + } else { + // If there is no tensor shape in the tensor info, conservatively set + // unknown_rank to true. + array_info.shape.set_unknown_rank(true); + } + + joined_signature.input_nodes.insert( + std::pair(tensor_info.name(), + std::move(array_info))); + } + + // For outputs, we simply flatten them in the original order, as it is fine + // to have duplicated fetch nodes. + const internal::Signature& signature = signature_map.at(name); + for (const auto& output_key : signature.output_names) { + const auto& tensor_info = signature_def.outputs().at(output_key); + + VLOG(1) << "Importing Signature Output: output_key = " << output_key + << ", tensor_info = " << tensor_info.DebugString(); + + TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName) + << "Only dense tensor is supported, but got encoding case " + << tensor_info.encoding_case(); + + joined_signature.output_nodes.push_back(tensor_info.name()); + } + } + + return joined_signature; +} + +} // namespace + +// TODO(b/216379787): Reuse `GraphExecutor::LoadClientGraph()`. +StatusOr> +SavedModelImpl::LoadJoinedSignature(const JoinedSignature& joined_signature) { + // Step 1: Import the combined subgraph from proto to an MLIR module. + mlir::MLIRContext context; + ASSIGN_OR_RETURN_IN_IMPORT( + auto module, ImportSubgraph(&context, joined_signature.input_nodes, + joined_signature.output_nodes, + joined_signature.target_nodes)); + + // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF). + auto loading_result = std::make_unique(); + loading_result->name = joined_signature.name; + loading_result->resource_context = CreateResourceContext( + runtime(), tpu_model_resource_.get(), + options_.graph_execution_options.compile_options.tpu_target); + + RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef( + options_.graph_execution_options.compile_options, module.get(), + &loading_result->bef)); + + // Step 3: Initialize runtime states using special BEF functions. + ASSIGN_OR_RETURN_IN_INIT( + loading_result->bef_file, + tfrt::CreateBefFileFromBefBuffer( + *options_.graph_execution_options.runtime, loading_result->bef)); + RETURN_IF_ERROR_IN_INIT(RunInitializers( + /*initializers_and_signatures=*/{}, + options_.graph_execution_options.model_metadata, + loading_result->bef_file.get(), *options_.graph_execution_options.runtime, + loading_result->resource_context.get(), *fallback_state_)); + + // Store loading_result in cache. + const auto* loading_result_ptr = loading_result.get(); + loading_result_cache_[joined_signature.name] = std::move(loading_result); + return {*loading_result_ptr}; +} + +StatusOr> +SavedModelImpl::GetOrCreateLoadingResult(absl::Span names) { + const auto joined_name = absl::StrJoin(names, kSignatureJoiningDelimiter); + tensorflow::mutex_lock l(loading_result_cache_mu_); + const auto iter = loading_result_cache_.find(joined_name); + if (iter != loading_result_cache_.end()) return {*iter->second}; + + TF_ASSIGN_OR_RETURN( + const auto joined_signature, + JoinSignatures(names, signatures_, meta_graph_def_.signature_def())); + + return LoadJoinedSignature(joined_signature); +} + } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/saved_model/saved_model.h b/tensorflow/core/tfrt/saved_model/saved_model.h index 08837527a7545f..8b6dccfd0ded57 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.h +++ b/tensorflow/core/tfrt/saved_model/saved_model.h @@ -254,6 +254,14 @@ class SavedModelImpl final : public SavedModel { std::vector* outputs) override; private: + // The result of loading signature(s). + struct LoadingResult { + std::string name; + tfrt::BefBuffer bef; + tfrt::RCReference bef_file; + std::unique_ptr resource_context; + }; + // Imports a subgraph as an MLIR module with the specified `input_nodes`, // `output_nodes`. tensorflow::StatusOr> ImportSubgraph( @@ -262,6 +270,18 @@ class SavedModelImpl final : public SavedModel { const std::vector& output_nodes, const std::vector& target_nodes); + // Given the joined signature, loads the subgraph and returns loading result. + tensorflow::StatusOr< + std::reference_wrapper> + LoadJoinedSignature(const JoinedSignature& joined_signature) + TF_EXCLUSIVE_LOCKS_REQUIRED(loading_result_cache_mu_); + + // Returns the loading result given the signature names. + tensorflow::StatusOr< + std::reference_wrapper> + GetOrCreateLoadingResult(absl::Span names) + TF_LOCKS_EXCLUDED(loading_result_cache_mu_); + // Runs `func` with the given inputs, and outputs the result. tensorflow::Status RunInternal(const RunOptions& run_options, absl::string_view signature_name, @@ -288,6 +308,12 @@ class SavedModelImpl final : public SavedModel { // (TpuModelResource) to a general and plugable interface. std::unique_ptr tpu_model_resource_; std::unique_ptr resource_context_; + tensorflow::mutex loading_result_cache_mu_; + // For pointer stability of values in `absl::flat_hash_map<>`, additional + // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.) + absl::flat_hash_map> + loading_result_cache_ TF_GUARDED_BY(loading_result_cache_mu_); std::unique_ptr graph_executor_; }; diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc index 89e56030160fb4..bbeed78c874c24 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc @@ -512,7 +512,8 @@ TEST(SavedModelTest, RunOptionsWorkQueue) { std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( "tensorflow/core/tfrt/saved_model/tests/toy_v1"); - auto runtime = tensorflow::tfrt_stub::Runtime::Create(); + auto runtime = + tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4); auto options = DefaultSavedModelOptions(runtime.get()); options.graph_execution_options.compile_options.enable_native_ops = false; diff --git a/tensorflow/dtensor/python/BUILD b/tensorflow/dtensor/python/BUILD index 1f8ed659c6d0c6..ccf8e247f996f9 100644 --- a/tensorflow/dtensor/python/BUILD +++ b/tensorflow/dtensor/python/BUILD @@ -242,6 +242,7 @@ pytype_strict_library( name = "multi_client_util", srcs = ["multi_client_util.py"], deps = [ + ":api", "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", "//tensorflow/python/eager:context", diff --git a/tensorflow/dtensor/python/api.py b/tensorflow/dtensor/python/api.py index afa4c068fdb004..7ba72023fc3d48 100644 --- a/tensorflow/dtensor/python/api.py +++ b/tensorflow/dtensor/python/api.py @@ -488,7 +488,7 @@ def full_job_name(task_id: Optional[int] = None) -> str: task_id = client_id() # In local runs and unit tests, there should be exactly one client running # on one TF task. - if job_name() == "localhost" and task_id != 0: + if num_clients() == 1 and task_id != 0: raise ValueError(f"Unexpected task ID {task_id} in local runs") return f"{job_name()}/replica:0/task:{task_id}" diff --git a/tensorflow/dtensor/python/dtensor_device.py b/tensorflow/dtensor/python/dtensor_device.py index 9ba61256097f1f..83c9f9207018c0 100644 --- a/tensorflow/dtensor/python/dtensor_device.py +++ b/tensorflow/dtensor/python/dtensor_device.py @@ -61,7 +61,7 @@ def __init__(self, meshes: List[layout_lib.Mesh], is_async=True): if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes): raise TypeError( "Expected a flat list of Mesh objects, got {}".format(meshes)) - global _next_device_number, _next_device_number_lock + global _next_device_number ctx = context.context() with _next_device_number_lock: self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(), @@ -98,9 +98,7 @@ def _job_name(self): return os.environ.get(_DT_JOB_NAME, "localhost") def _full_job_name(self): - """Returns the fully qualified TF job name for this or another task.""" - if self._job_name() == "localhost": - return "localhost/replica:0/task:0" + """Returns the fully qualified TF job name for this task.""" return self._job_name() + "/replica:0/task:" + str(self._client_id()) def _create_host_array(self, shape, host_id): diff --git a/tensorflow/dtensor/python/mesh_util.py b/tensorflow/dtensor/python/mesh_util.py index d583bcb6ac8f1a..fb77c09500b63b 100644 --- a/tensorflow/dtensor/python/mesh_util.py +++ b/tensorflow/dtensor/python/mesh_util.py @@ -43,9 +43,6 @@ def _print_context(num_global_devices: int, num_clients: int, client_id: int, # pylint: enable=protected-access -_in_multi_client_mode = None - - @tf_export('experimental.dtensor.create_mesh', v1=[]) def create_mesh(mesh_dims: Optional[List[Tuple[str, int]]] = None, mesh_name: str = '', @@ -165,10 +162,9 @@ def create_distributed_mesh(mesh_dims: List[Tuple[str, int]], if num_clients <= 0: raise ValueError(f'num_clients ({num_clients}) must be > 0') - if _in_multi_client_mode is None and num_clients > 1: - raise ValueError( - 'Invalid multi-client topology, run dtensor.initialize_multi_client() first' - ) + if api.num_clients() > 1 and not multi_client_util.is_initialized(): + raise ValueError('Invalid multi-client topology, please run ' + 'dtensor.initialize_multi_client() first.') if client_id is None: client_id = api.client_id() @@ -263,24 +259,11 @@ def dtensor_initialize_multi_client( service to make sure that workers know the devices on each other, a prerequisite for data transfer through cross-worker rendezvous. """ - global _in_multi_client_mode assert context.executing_eagerly() - _in_multi_client_mode = api.job_name() != 'localhost' - - if not _in_multi_client_mode and api.num_clients() != 1: - raise ValueError( - 'DTENSOR_NUM_CLIENTS is set and not 1, while DTENSOR_JOB_NAME is ' - 'set to localhost for single client mode.') - # Collective GRPC servers are only necessary in multi-client setup. # Single clients can use local mode of collectives. - if _in_multi_client_mode: - if api.jobs() is None: - raise ValueError( - 'DTENSOR_JOBS environment variable is required when' - 'using multi-client to properly set up communications between servers' - ) + if api.num_clients() > 1: multi_client_util.initialize_multi_client_cluster( job_name=api.job_name(), dtensor_jobs=api.jobs(), diff --git a/tensorflow/dtensor/python/multi_client_util.py b/tensorflow/dtensor/python/multi_client_util.py index 4cf2a5bf3d4af6..b4f4b7929fa3c2 100644 --- a/tensorflow/dtensor/python/multi_client_util.py +++ b/tensorflow/dtensor/python/multi_client_util.py @@ -20,9 +20,12 @@ from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.dtensor.python import api from tensorflow.python.eager import context from tensorflow.python.platform import remote_utils +_is_multi_client_initialized = False + def initialize_multi_client_cluster(job_name: str, dtensor_jobs: List[str], @@ -52,8 +55,26 @@ def initialize_multi_client_cluster(job_name: str, Raises: RuntimeError: If running inside a tf.function. """ + global _is_multi_client_initialized assert context.executing_eagerly() + if _is_multi_client_initialized: + raise ValueError("Multi-client mode has already been initialized.") + + if api.num_clients() <= 1: + raise ValueError( + "DTENSOR_NUM_CLIENTS must be set greater than 1 for multi-client mode.") + + if not api.jobs() or len(api.jobs()) <= 1: + raise ValueError( + "DTENSOR_JOBS environment variable is required when using multi-client " + "mode to properly set up communications between servers.") + + if len(api.jobs()) != api.num_clients(): + raise ValueError( + "DTENSOR_JOBS environment variable must be configured with the same " + "number of items as DTENSOR_NUM_CLIENTS.") + if not collective_leader.startswith("/job:"): collective_leader = "/job:" + collective_leader @@ -85,3 +106,10 @@ def initialize_multi_client_cluster(job_name: str, logging.info("Enabling collectives with server_def: %s", server_def) context.context().enable_collective_ops(server_def) context.ensure_initialized() + + _is_multi_client_initialized = True + + +def is_initialized() -> bool: + """Returns whether multi-client mode has been initialized.""" + return _is_multi_client_initialized diff --git a/tensorflow/dtensor/python/tpu_util.py b/tensorflow/dtensor/python/tpu_util.py index 120a459810ff37..e17b9d14722d51 100644 --- a/tensorflow/dtensor/python/tpu_util.py +++ b/tensorflow/dtensor/python/tpu_util.py @@ -170,16 +170,9 @@ def dtensor_initialize_tpu_system(enable_coordination_service=False): # Reconfigure TensorFlow to use TFRT TPU runtime if requested. _configure_tpu_runtime() - in_multi_client_mode = api.job_name() != "localhost" - # Collective GRPC servers are only necessary in mutli-client setup. # Single clients can use local mode of collectives. - if in_multi_client_mode: - if api.jobs() is None: - raise ValueError( - "DTENSOR_JOBS environment variable is required when" - "using multi-client to properly set up communications between servers" - ) + if api.num_clients() > 1 and not multi_client_util.is_initialized(): multi_client_util.initialize_multi_client_cluster( job_name=api.job_name(), dtensor_jobs=api.jobs(), @@ -309,7 +302,7 @@ def _tpu_init_fn(): raise e # Optionally exchange heartbeats between workers every minute. - if in_multi_client_mode and api.heartbeat_enabled(): + if api.num_clients() > 1 and api.heartbeat_enabled(): logging.info( "Starting DTensor heartbeat service exchanging signals every 10 minutes" ) @@ -672,7 +665,6 @@ def create_tpu_mesh(mesh_dim_names: List[str], logging.info("Actual ring_axes: %s", ring_axes) # Validate ring_bounds values. - global _tpu_topology if _tpu_topology is None: raise ValueError( "Invalid TPU topology, run dtensor.initialize_tpu_system() first") @@ -726,7 +718,6 @@ def create_tpu_mesh(mesh_dim_names: List[str], # For this point on, change from List[CoreLocation] to List[List[int]] for # easier interaction with the C++ API. global_core_locations = [l.to_list() for l in global_core_locations] - global _dtensor_device if _dtensor_device is None: raise ValueError( "Invalid system device, run dtensor.initialize_tpu_system() first") diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 62a2af8bb07823..2f2795bb8b833c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -13901,7 +13901,7 @@ func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { // The attr `format` can be used to override the color format of the encoded // output. Values can be: // -// - `''`: Use a default format based on the number of channels in the image. +// - `”`: Use a default format based on the number of channels in the image. // - `grayscale`: Output a grayscale JPEG image. The `channels` dimension // of `image` must be 1. // - `rgb`: Output an RGB JPEG image. The `channels` dimension diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index d8fdc5023fac50..09609032765989 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -270,6 +270,7 @@ cc_library( name = "tensorflow_profiler_logger_shim", srcs = ["tensorflow_profiler_logger_shim.cc"], hdrs = ["tensorflow_profiler_logger.h"], + compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), deps = [ ":macros", @@ -284,6 +285,8 @@ cc_library( hdrs = ["tensorflow_profiler_logger.h"], copts = tflite_copts_warnings(), deps = [ + ":macros", + "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:kernel_util", @@ -539,6 +542,7 @@ cc_library( ":simple_memory_arena", ":stderr_reporter", ":string", + ":tensorflow_profiler_logger_shim", ":type_to_tflitetype", ":util", ":version", diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 354638f864e3d6..1ae1f852896731 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -48,6 +48,12 @@ def tflite_copts(): "//conditions:default": [ "-fno-exceptions", # Exceptions are unused in TFLite. ], + }) + select({ + "//tensorflow/lite:tflite_with_xnnpack_explicit_false": ["-DTFLITE_WITHOUT_XNNPACK"], + "//conditions:default": [], + }) + select({ + "//tensorflow/lite:tensorflow_profiler_config": ["-DTF_LITE_TENSORFLOW_PROFILER"], + "//conditions:default": [], }) return copts + tflite_copts_extra() diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 811f0a032937cd..0eb6b211cf54df 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -218,10 +218,7 @@ cc_library( "common.h", ], compatible_with = get_compatible_with_portable(), - copts = tflite_copts() + select({ - "//tensorflow/lite:tensorflow_profiler_config": ["-DTF_LITE_TENSORFLOW_PROFILER"], - "//conditions:default": [], - }), + copts = tflite_copts(), deps = [ ":c_api_types", ] + select({ diff --git a/tensorflow/lite/c/common.cc b/tensorflow/lite/c/common.cc index 8548424d108387..ae5c44b544b3d8 100644 --- a/tensorflow/lite/c/common.cc +++ b/tensorflow/lite/c/common.cc @@ -17,9 +17,6 @@ limitations under the License. #include "tensorflow/lite/c/c_api_types.h" #ifdef TF_LITE_TENSORFLOW_PROFILER -#include - -#include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/tensorflow_profiler_logger.h" #endif @@ -28,19 +25,6 @@ limitations under the License. #include #endif // TF_LITE_STATIC_MEMORY -#ifdef TF_LITE_TENSORFLOW_PROFILER -namespace tflite { -// Use weak symbols here (even though they are guarded by macros) to avoid -// build breakage when building a benchmark requires TFLite runs. The main -// benchmark library should have tensor_profiler_logger dependency. -TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor, - size_t num_bytes); - -TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor); -} // namespace tflite - -#endif // TF_LITE_TENSORFLOW_PROFILER - extern "C" { size_t TfLiteIntArrayGetSizeInBytes(int size) { diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 5175d903982d33..658eda099fa821 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -836,16 +836,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, *builtin_data = params.release(); return kTfLiteOk; } - case BuiltinOperator_UNSORTED_SEGMENT_PROD: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* unsorted_segment_prod_params = - op->builtin_options_as_UnsortedSegmentProdOptions()) { - params->num_segments = unsorted_segment_prod_params->num_segments(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } // Below are the ops with no builtin_data structure. // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are // ok for now, since there is no call implementation either. @@ -868,6 +858,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RANGE: case BuiltinOperator_SQUARED_DIFFERENCE: case BuiltinOperator_REVERSE_V2: + case BuiltinOperator_UNSORTED_SEGMENT_PROD: case BuiltinOperator_WHERE: case BuiltinOperator_RANK: case BuiltinOperator_NON_MAX_SUPPRESSION_V4: diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h index 8e8a00e9213b3c..74ff6f4e60ce83 100644 --- a/tensorflow/lite/core/api/profiler.h +++ b/tensorflow/lite/core/api/profiler.h @@ -71,7 +71,10 @@ class Profiler { // is useful when 'event_metadata's are not available when the event begins // or when one wants to overwrite the 'event_metadata's set at the beginning. virtual void EndEvent(uint32_t event_handle, int64_t event_metadata1, - int64_t event_metadata2) {} + int64_t event_metadata2) { + // By default discards the metadata. + EndEvent(event_handle); + } // Signals an end to the specified profile event. virtual void EndEvent(uint32_t event_handle) = 0; @@ -138,12 +141,17 @@ class ScopedDelegateOperatorProfile : public ScopedProfile { static_cast(node_index)) {} }; -class ScopedRuntimeInstrumentationProfile : public ScopedProfile { +// Similar to ScopedProfile but has extra event metadata for EndEvent. +class ScopedRuntimeInstrumentationProfile { public: ScopedRuntimeInstrumentationProfile(Profiler* profiler, const char* tag) - : ScopedProfile( - profiler, tag, - Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, -1) {} + : profiler_(profiler), event_handle_(0) { + if (profiler) { + event_handle_ = profiler_->BeginEvent( + tag, Profiler::EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT, + /*event_metadata=*/-1); + } + } void set_runtime_status(int64_t delegate_status, int64_t interpreter_status) { if (profiler_) { @@ -159,8 +167,10 @@ class ScopedRuntimeInstrumentationProfile : public ScopedProfile { } private: - int64_t delegate_status_; - int64_t interpreter_status_; + Profiler* profiler_ = nullptr; + uint32_t event_handle_ = 0; + int64_t delegate_status_ = 0; + int64_t interpreter_status_ = 0; }; } // namespace tflite diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 96e628dd1470a8..acfb2baacd35cd 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -49,6 +49,9 @@ limitations under the License. #else #include "tensorflow/lite/arena_planner.h" #endif +#ifdef TF_LITE_TENSORFLOW_PROFILER +#include "tensorflow/lite/tensorflow_profiler_logger.h" +#endif // TF_LITE_TENSORFLOW_PROFILER namespace tflite { @@ -1218,6 +1221,12 @@ TfLiteStatus Subgraph::Invoke() { const char* op_name = nullptr; if (profiler_) op_name = GetTFLiteOpName(registration); +#ifdef TF_LITE_TENSORFLOW_PROFILER + if (!op_name) { + op_name = GetTFLiteOpName(registration); + } + tflite::OnTfLiteOpInvoke(op_name, node_index); +#endif // TF_LITE_TENSORFLOW_PROFILER TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler_.get(), op_name, node_index); for (int i = 0; i < node.inputs->size; ++i) { diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc index 63bd72496fb32d..db02ea5349a4a9 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc @@ -454,10 +454,10 @@ absl::Status CLArguments::SetObjectRef(const std::string& name, absl::Status CLArguments::SetGPUResources( const std::string& name, const GPUResourcesWithValue& resources) { - for (const auto& r : resources.ints) { + for (const auto& r : resources.generic.ints) { RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second)); } - for (const auto& r : resources.floats) { + for (const auto& r : resources.generic.floats) { RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second)); } for (const auto& r : resources.buffers) { diff --git a/tensorflow/lite/delegates/gpu/cl/gpu_object.h b/tensorflow/lite/delegates/gpu/cl/gpu_object.h index 1c6764c0093f19..0e1c67b94b5383 100644 --- a/tensorflow/lite/delegates/gpu/cl/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/cl/gpu_object.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" @@ -32,14 +33,21 @@ namespace gpu { namespace cl { struct GPUResourcesWithValue { - std::vector> ints; - std::vector> floats; + GenericGPUResourcesWithValue generic; + std::vector> buffers; std::vector> images2d; std::vector> image2d_arrays; std::vector> images3d; std::vector> image_buffers; std::vector> custom_memories; + + void AddFloat(const std::string& name, float value) { + generic.AddFloat(name, value); + } + void AddInt(const std::string& name, int value) { + generic.AddInt(name, value); + } }; class GPUObject { diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index f5bfdd384eae6a..6dea1b3b8d0fbd 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -176,7 +176,7 @@ absl::Status GetBufferAsignment( const auto& t = gpu_model.tensors.at(usage.first); const auto& shape = t.GetBHWDCShape(); const auto& descriptor = t; - const size_t element_size = SizeOf(descriptor.data_type); + const size_t element_size = SizeOf(descriptor.GetDataType()); size_t buffer_size; if (descriptor.GetStorageType() == TensorStorageType::TEXTURE_2D || descriptor.GetStorageType() == TensorStorageType::SINGLE_TEXTURE_2D) { @@ -593,7 +593,7 @@ absl::Status InferenceContext::AllocateBufferBasedTensors( if (t.second.GetStorageType() == TensorStorageType::TEXTURE_2D || t.second.GetStorageType() == TensorStorageType::SINGLE_TEXTURE_2D) { const size_t bytes_per_pixel = - SizeOf(t.second.data_type) * + SizeOf(t.second.GetDataType()) * (t.second.GetStorageType() == TensorStorageType::TEXTURE_2D ? 4 : shape.c); diff --git a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc index b94c997657209b..089c592a4de6ef 100644 --- a/tensorflow/lite/delegates/gpu/cl/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/cl/linear_storage.cc @@ -62,7 +62,7 @@ absl::Status LinearStorage::GetGPUResources( "Expected TensorLinearDescriptor on input."); } - resources->ints.push_back({"length", depth_}); + resources->AddInt("length", depth_); if (storage_type_ == LinearStorageType::BUFFER) { resources->buffers.push_back({"buffer", memory_}); diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc index b0fa1637bb313c..359fdfa1f5da01 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h" @@ -44,7 +45,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape, case TensorStorageType::BUFFER: case TensorStorageType::IMAGE_BUFFER: { const size_t data_size = shape.b * shape.w * shape.h * shape.d * slices * - 4 * SizeOf(descriptor.data_type); + 4 * SizeOf(descriptor.GetDataType()); cl_int error_code; cl_mem memory = clCreateBuffer(context.context(), mem_flags, data_size, const_cast(data_ptr), &error_code); @@ -71,7 +72,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape, cl_image_format format; format.image_channel_order = CL_RGBA; format.image_channel_data_type = - DataTypeToChannelType(descriptor.data_type); + DataTypeToChannelType(descriptor.GetDataType()); cl_int error_code; cl_mem memory = @@ -101,7 +102,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape, cl_image_format format; format.image_channel_order = CL_RGBA; format.image_channel_data_type = - DataTypeToChannelType(descriptor.data_type); + DataTypeToChannelType(descriptor.GetDataType()); cl_int error_code; cl_mem memory = @@ -132,7 +133,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape, cl_image_format format; format.image_channel_order = CL_RGBA; format.image_channel_data_type = - DataTypeToChannelType(descriptor.data_type); + DataTypeToChannelType(descriptor.GetDataType()); cl_int error_code; cl_mem memory = @@ -166,10 +167,11 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape, desc.buffer = nullptr; cl_image_format format; - if (context.IsFloatTexture2DSupported(shape.c, descriptor.data_type)) { + if (context.IsFloatTexture2DSupported(shape.c, + descriptor.GetDataType())) { format.image_channel_order = ToChannelOrder(shape.c); format.image_channel_data_type = - DataTypeToChannelType(descriptor.data_type); + DataTypeToChannelType(descriptor.GetDataType()); } else { return absl::InvalidArgumentError(absl::StrCat( "This device doesn't support ", shape.c, "-channel textures.")); @@ -267,7 +269,7 @@ absl::Status CreateTensor(const CLContext& context, const BHWDC& shape, if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) { cl_mem image_memory; RETURN_IF_ERROR(CreateImageBufferFromBuffer( - context, memory, descriptor.data_type, + context, memory, descriptor.GetDataType(), shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4), &image_memory)); *result = Tensor(memory, memory_owner, image_memory, shape, descriptor); @@ -284,7 +286,7 @@ absl::Status CreateTensorShared(const CLContext& context, const BHWDC& shape, if (descriptor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) { cl_mem image_memory; RETURN_IF_ERROR(CreateImageBufferFromBuffer( - context, memory, descriptor.data_type, + context, memory, descriptor.GetDataType(), shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4), &image_memory)); *result = Tensor(memory, memory_owner, image_memory, shape, descriptor); @@ -407,24 +409,7 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr, if (!tensor_desc) { return absl::InvalidArgumentError("Expected TensorDescriptor on input."); } - resources->ints.push_back( - {"slice_stride", tensor_desc->GetSliceStrideSize(shape_)}); - if (descriptor_.HasAxis(Axis::WIDTH)) { - resources->ints.push_back({"width", tensor_desc->GetWidthSize(shape_)}); - } - if (descriptor_.HasAxis(Axis::HEIGHT)) { - resources->ints.push_back({"height", Height()}); - } - if (descriptor_.HasAxis(Axis::CHANNELS)) { - resources->ints.push_back({"slices", Slices()}); - resources->ints.push_back({"channels", Channels()}); - } - if (descriptor_.HasAxis(Axis::BATCH)) { - resources->ints.push_back({"batch", Batch()}); - } - if (descriptor_.HasAxis(Axis::DEPTH)) { - resources->ints.push_back({"depth", Depth()}); - } + tensor_desc->GetGpuResources(shape_, &resources->generic); if (descriptor_.GetStorageType() == TensorStorageType::BUFFER) { resources->buffers.push_back({"buffer", memory_}); @@ -433,8 +418,7 @@ absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor* obj_ptr, TensorStorageType::SINGLE_TEXTURE_2D) { if (obj_ptr->GetAccess() == AccessType::WRITE && tensor_desc->GetUseBufferForWriteOnlyTexture2d()) { - resources->ints.push_back( - {"aligned_texture_width", aligned_texture_width_}); + resources->AddInt("aligned_texture_width", aligned_texture_width_); resources->buffers.push_back({"buffer", memory_}); } else { cl_mem mem = buffer_based_ ? image_buffer_memory_ : memory_; @@ -524,7 +508,7 @@ int Tensor::GetAlignedChannels() const { } uint64_t Tensor::GetMemorySizeInBytes() const { - const int flt_size = SizeOf(descriptor_.data_type); + const int flt_size = SizeOf(descriptor_.GetDataType()); const int flt4_size = 4 * flt_size; switch (descriptor_.GetStorageType()) { case TensorStorageType::BUFFER: @@ -583,7 +567,7 @@ absl::Status Tensor::CreateFromDescriptor(const TensorDescriptor& desc, memory_ = memory.Release(); if (desc.GetStorageType() == TensorStorageType::IMAGE_BUFFER) { RETURN_IF_ERROR(CreateImageBufferFromBuffer( - *context, memory_, desc.data_type, + *context, memory_, desc.GetDataType(), shape_.b * shape_.w * shape_.h * shape_.d * DivideRoundUp(shape_.c, 4), &image_buffer_memory_)); } @@ -695,7 +679,7 @@ absl::Status CreateSharedImage2DBufferTensor(const CLContext& context, : 4; cl_mem image_memory; RETURN_IF_ERROR(CreateImage2DFromBuffer( - context, memory, descriptor.data_type, width, height, channels, + context, memory, descriptor.GetDataType(), width, height, channels, width_pixel_alignment, &image_memory)); *result = Tensor(memory, false, image_memory, shape, descriptor); result->aligned_texture_width_ = AlignByN(width, width_pixel_alignment); diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h index 0cfd5473535fe0..058e07bc260eac 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor.h @@ -69,7 +69,7 @@ class Tensor : public GPUObject, public GpuSpatialTensor { int Batch() const override { return shape_.b; } TensorDescriptor GetDescriptor() const override { return descriptor_; } - DataType GetDataType() const { return descriptor_.data_type; } + DataType GetDataType() const { return descriptor_.GetDataType(); } TensorStorageType GetStorageType() const { return descriptor_.GetStorageType(); } @@ -207,7 +207,7 @@ template absl::Status Tensor::WriteDataBHWDC(const T* in, CLCommandQueue* queue) { std::unique_ptr data_copy; data_copy.reset(new uint8_t[GetMemorySizeInBytes()]); - if (descriptor_.data_type == DataType::FLOAT16) { + if (descriptor_.GetDataType() == DataType::FLOAT16) { // rearrangement and conversion from float32 to float16 DataFromBHWDC(reinterpret_cast(in), shape_, descriptor_, reinterpret_cast(data_copy.get())); @@ -227,7 +227,7 @@ absl::Status Tensor::ReadDataBHWDC(T* out, CLCommandQueue* queue) const { RETURN_IF_ERROR(ReadData(data_copy.get(), queue)); - if (descriptor_.data_type == DataType::FLOAT16) { + if (descriptor_.GetDataType() == DataType::FLOAT16) { // rearrangement and conversion from float32 to float16 DataToBHWDC(reinterpret_cast(data_copy.get()), shape_, descriptor_, reinterpret_cast(out)); diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 2764f3b0d5eb4d..dbe91d7b94b6e9 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -774,9 +774,9 @@ bool GpuInfo::SupportsZeroClampForImages() const { } else if (IsApiOpenCl()) { return true; } else if (IsApiVulkan()) { - return true; + return false; } else if (IsApiOpenGl()) { - return opengl_info.IsApiOpenGl32OrAbove(); + return false; } else { return false; } diff --git a/tensorflow/lite/delegates/gpu/common/gpu_model.cc b/tensorflow/lite/delegates/gpu/common/gpu_model.cc index a21a8dda34cb97..f2b42d694a88dc 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_model.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_model.cc @@ -136,7 +136,7 @@ absl::Status CheckExternalTensorDescription(const GpuInfo& gpu_info, const TensorDescriptor& tensor_desc, const BHWC& shape, DataType data_type) { - if (tensor_desc.data_type != data_type) { + if (tensor_desc.GetDataType() != data_type) { return absl::InvalidArgumentError( "Global precision and precision of predefined/external tensors must be " "synchronized."); diff --git a/tensorflow/lite/delegates/gpu/common/selectors/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/BUILD index 149754badd9544..2d4311f28707ec 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/common/selectors/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", + "//tensorflow/lite/delegates/gpu/common/tasks:mean_stddev_normalization", "//tensorflow/lite/delegates/gpu/common/tasks/special:conv_pointwise", "//tensorflow/lite/delegates/gpu/common/tasks/special:depthwise_conv_plus_1x1_conv", "//tensorflow/lite/delegates/gpu/common/tasks/special:fc_fc_add", diff --git a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc index be78e16c78e04e..084da166a1dc14 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/operation_selector.cc @@ -167,7 +167,7 @@ absl::Status AddDynamicConv(ModelHints hints, const GpuInfo& gpu_info, gpu_subgraph->operations.push_back({}); auto& conv_op = gpu_subgraph->operations.back(); OperationDef conv_temp_def = op_def; - conv_temp_def.src_tensors[1] = {op_def.src_tensors[1].data_type, + conv_temp_def.src_tensors[1] = {op_def.src_tensors[1].GetDataType(), TensorStorageType::BUFFER, Layout::HWC}; WeightsDescription weights_desc; const BHWC weights_shape_bhwc(weights_shape.o, weights_shape.h, @@ -364,10 +364,10 @@ absl::Status GPUOperationFromNodePart0( const BHWC hwc_output_shape(1, dst_shape.b * dst_shape.h, dst_shape.w, dst_shape.c); TensorDescriptor hwc_input_desc = { - op_def.src_tensors[0].data_type, + op_def.src_tensors[0].GetDataType(), op_def.src_tensors[0].GetStorageType(), Layout::BHWC}; TensorDescriptor hwc_output_desc = { - op_def.dst_tensors[0].data_type, + op_def.dst_tensors[0].GetDataType(), op_def.dst_tensors[0].GetStorageType(), Layout::BHWC}; src_id = gpu_subgraph->AddTensor(hwc_input_shape, hwc_input_desc); dst_id = gpu_subgraph->AddTensor(hwc_output_shape, hwc_output_desc); @@ -390,7 +390,7 @@ absl::Status GPUOperationFromNodePart0( dst_shape, src_id, inputs[1]->id, dst_id, gpu_subgraph)); if (dst_shape.b != 1) { TensorDescriptor hwc_output_desc = { - op_def.dst_tensors[0].data_type, + op_def.dst_tensors[0].GetDataType(), op_def.dst_tensors[0].GetStorageType(), Layout::BHWC}; OperationDef reshape_output_def; @@ -497,7 +497,7 @@ absl::Status GPUOperationFromNodePart0( attr.weights.shape.w, attr.weights.shape.i); OperationDef conv_temp_def = op_def; conv_temp_def.src_tensors.push_back( - {op_def.src_tensors[0].data_type, TensorStorageType::BUFFER, + {op_def.src_tensors[0].GetDataType(), TensorStorageType::BUFFER, Layout::HWC}); *gpu_op = SelectConvolutionWithDynamicWeights( attr, weights_shape_bhwc, output_shape, gpu_info, conv_temp_def, @@ -612,7 +612,7 @@ absl::Status GPUOperationFromNodePart0( case OperationType::MAX_UNPOOLING_2D: { auto attr = absl::any_cast(node.operation.attributes); - *gpu_op = SelectMaxUnpooling(attr, op_def); + *gpu_op = SelectMaxUnpooling(attr, gpu_info, op_def); return absl::OkStatus(); } case OperationType::MEAN: { diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc index d594c3c7073fea..f95b37228997fb 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.cc @@ -77,8 +77,10 @@ std::unique_ptr SelectPooling(const Pooling2DAttributes& attr, } std::unique_ptr SelectMaxUnpooling( - const MaxUnpooling2DAttributes& attr, const OperationDef& op_def) { - return std::make_unique(CreateMaxUnpooling(op_def, attr)); + const MaxUnpooling2DAttributes& attr, const GpuInfo& gpu_info, + const OperationDef& op_def) { + return std::make_unique( + CreateMaxUnpooling(gpu_info, op_def, attr)); } void SelectAdd(const OperationDef& op_def, const std::vector& channels, diff --git a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h index ae2900c5f3fe3b..e4d7c36b6c173c 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/common/selectors/simple_selectors.h @@ -42,7 +42,8 @@ std::unique_ptr SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def); std::unique_ptr SelectMaxUnpooling( - const MaxUnpooling2DAttributes& attr, const OperationDef& op_def); + const MaxUnpooling2DAttributes& attr, const GpuInfo& gpu_info, + const OperationDef& op_def); void SelectAdd(const OperationDef& op_def, const std::vector& channels, int dst_channels, std::unique_ptr* ptr); diff --git a/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc index 11b82174b26653..bbf01af5d502fd 100644 --- a/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc +++ b/tensorflow/lite/delegates/gpu/common/selectors/special_selector.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" +#include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h" #include "tensorflow/lite/delegates/gpu/common/tasks/special/conv_pointwise.h" #include "tensorflow/lite/delegates/gpu/common/tasks/special/depthwise_conv_plus_1x1_conv.h" #include "tensorflow/lite/delegates/gpu/common/tasks/special/fc_fc_add.h" @@ -250,6 +251,169 @@ absl::Status TryFCFCAdd( consumed_nodes->insert(add_node->id); return absl::OkStatus(); } + +absl::Status CheckIfValidNodeOfType(const Node* node, + OperationType required_type) { + if (node == nullptr) { + return absl::NotFoundError("Invalid node."); + } + if (OperationTypeFromString(node->operation.type) != required_type) { + return absl::NotFoundError("Type mismatch."); + } + return absl::OkStatus(); +} + +absl::Status GetElementwiseScalarValue(const Node* node, float* result) { + auto attr = absl::any_cast(node->operation.attributes); + const float* value = absl::get_if(&attr.param); + if (!value) { + return absl::NotFoundError("Not a scalar value inside attributes."); + } + *result = *value; + return absl::OkStatus(); +} + +absl::Status GetNextSingleNode(const GraphFloat32& graph, const Node& node, + OperationType next_type, Node** next_node) { + auto consumers = graph.FindConsumers(graph.FindOutputs(node.id)[0]->id); + if (consumers.size() != 1) { + return absl::NotFoundError("Not a single consumer."); + } + RETURN_IF_ERROR(CheckIfValidNodeOfType(consumers[0], next_type)); + *next_node = consumers[0]; + return absl::OkStatus(); +} + +// MeanStdDevNormalization fusion works with this subgraph +// input +// / \ +// | mean +// \ / +// substraction +// / \ +// | | +// | pow +// | | +// | mean +// | | +// | add +// | | +// | rsqrt +// | | +// \ / +// multiplication +// | +// output +absl::Status TryMeanStdDevNormalization( + const GpuInfo& gpu_info, CalculationsPrecision precision, + const GraphFloat32& graph, NodeId first_node_id, + const std::map& tensor_descriptors, + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) { + Node* first_mean_node = graph.GetNode(first_node_id); + RETURN_IF_ERROR(CheckIfValidNodeOfType(first_mean_node, OperationType::MEAN)); + auto first_mean_attr = + absl::any_cast(first_mean_node->operation.attributes); + if (first_mean_attr.dims != std::set{Axis::CHANNELS}) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + Node* sub_node; + RETURN_IF_ERROR(GetNextSingleNode(graph, *first_mean_node, OperationType::SUB, + &sub_node)); + auto sub_inputs = graph.FindInputs(sub_node->id); + if (sub_inputs.size() != 2) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } else { + // checking structure + // input + // / \ + // | mean + // \ / + // substraction + Node* sub_first_parent = graph.FindProducer(sub_inputs[0]->id); + Node* sub_second_parent = graph.FindProducer(sub_inputs[1]->id); + if (sub_second_parent != first_mean_node) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + auto mean_inputs = graph.FindInputs(first_mean_node->id); + Node* mean_parent = graph.FindProducer(mean_inputs[0]->id); + if (mean_parent != sub_first_parent) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + } + auto sub_output = graph.FindOutputs(sub_node->id)[0]->id; + auto consumers = graph.FindConsumers(sub_output); + if (consumers.size() != 2) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + Node* pow_node = consumers[0]; + Node* sub_child_mul_node = consumers[1]; + if (!CheckIfValidNodeOfType(pow_node, OperationType::POW).ok()) { + pow_node = consumers[1]; + sub_child_mul_node = consumers[0]; + } + RETURN_IF_ERROR(CheckIfValidNodeOfType(pow_node, OperationType::POW)); + RETURN_IF_ERROR( + CheckIfValidNodeOfType(sub_child_mul_node, OperationType::MUL)); + float pow_value; + RETURN_IF_ERROR(GetElementwiseScalarValue(pow_node, &pow_value)); + if (pow_value != 2.0) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + Node* second_mean_node; + RETURN_IF_ERROR(GetNextSingleNode(graph, *pow_node, OperationType::MEAN, + &second_mean_node)); + auto second_mean_attr = + absl::any_cast(second_mean_node->operation.attributes); + if (second_mean_attr.dims != std::set{Axis::CHANNELS}) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + Node* add_node; + RETURN_IF_ERROR(GetNextSingleNode(graph, *second_mean_node, + OperationType::ADD, &add_node)); + float add_value; + RETURN_IF_ERROR(GetElementwiseScalarValue(add_node, &add_value)); + Node* rsqrt_node; + RETURN_IF_ERROR( + GetNextSingleNode(graph, *add_node, OperationType::RSQRT, &rsqrt_node)); + Node* mul_node; + RETURN_IF_ERROR( + GetNextSingleNode(graph, *rsqrt_node, OperationType::MUL, &mul_node)); + if (sub_child_mul_node != mul_node) { + return absl::NotFoundError("MeanStdDevNormalization not suitable."); + } + + OperationDef op_def; + op_def.precision = precision; + auto input_id = graph.FindInputs(first_mean_node->id)[0]->id; + auto it = tensor_descriptors.find(input_id); + if (it != tensor_descriptors.end()) { + op_def.src_tensors.push_back(it->second); + } + auto output_id = graph.FindInputs(mul_node->id)[0]->id; + it = tensor_descriptors.find(output_id); + if (it != tensor_descriptors.end()) { + op_def.dst_tensors.push_back(it->second); + } + + auto subgraph_inputs = graph.FindInputs(first_mean_node->id); + auto subgraph_outputs = graph.FindOutputs(mul_node->id); + std::unique_ptr* gpu_op = + InitSingleOpSubgraph(subgraph_inputs, subgraph_outputs, gpu_subgraph); + *gpu_op = + std::make_unique(CreateMeanStdDevNormalization( + op_def, gpu_info, subgraph_inputs[0]->tensor.shape, add_value, + /*two_step*/ false)); + + consumed_nodes->insert(first_mean_node->id); + consumed_nodes->insert(sub_node->id); + consumed_nodes->insert(pow_node->id); + consumed_nodes->insert(second_mean_node->id); + consumed_nodes->insert(add_node->id); + consumed_nodes->insert(rsqrt_node->id); + consumed_nodes->insert(mul_node->id); + + return absl::OkStatus(); +} } // namespace absl::Status GPUSubgraphFromGraph( @@ -277,6 +441,13 @@ absl::Status GPUSubgraphFromGraph( gpu_subgraph->operations[0].name = "slice_mul_mean_concat"; return absl::OkStatus(); } + if (TryMeanStdDevNormalization(gpu_info, precision, graph, first_node_id, + tensor_descriptors, consumed_nodes, + gpu_subgraph) + .ok()) { + gpu_subgraph->operations[0].name = "mean_stddev_normalization"; + return absl::OkStatus(); + } return absl::NotFoundError("No special combination."); } diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h b/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h index 13d64a807fb19f..c3919120465aa9 100644 --- a/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h @@ -154,6 +154,18 @@ struct GPUResources { } }; +struct GenericGPUResourcesWithValue { + std::vector> ints; + std::vector> floats; + + void AddFloat(const std::string& name, float value) { + floats.push_back({name, value}); + } + void AddInt(const std::string& name, int value) { + ints.push_back({name, value}); + } +}; + class GPUObjectDescriptor { public: GPUObjectDescriptor() = default; diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc index d836dfd4135889..26d92d66113bde 100644 --- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc @@ -83,7 +83,7 @@ DataType OperationDef::GetDataType() const { } DataType OperationDef::GetPrimaryDataType() const { - return src_tensors[0].data_type; + return src_tensors[0].GetDataType(); } TensorStorageType OperationDef::GetPrimaryStorageType() const { return src_tensors[0].GetStorageType(); diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc index 74a0deb76adaaa..2a33c5f0129cc7 100644 --- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc +++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.cc @@ -79,7 +79,7 @@ std::string GetWriteImageFromDataType(DataType data_type) { } } -std::string GetConvertionForImage(const GpuInfo& gpu_info, DataType src_type, +std::string GetConversionForImage(const GpuInfo& gpu_info, DataType src_type, DataType dst_type) { DataType interm_type = src_type; if (gpu_info.IsApiOpenCl()) { @@ -90,20 +90,20 @@ std::string GetConvertionForImage(const GpuInfo& gpu_info, DataType src_type, } else if (gpu_info.IsApiMetal()) { interm_type = ToMetalTextureType(src_type); } - return GetTypeConvertion(gpu_info, interm_type, dst_type, 4); + return GetTypeConversion(gpu_info, interm_type, dst_type, 4); } -std::string GetConvertion(const GpuInfo& gpu_info, +std::string GetConversion(const GpuInfo& gpu_info, TensorStorageType storage_type, DataType src_type, DataType dst_type) { if (storage_type == TensorStorageType::BUFFER) { - return GetTypeConvertion(gpu_info, src_type, dst_type, 4); + return GetTypeConversion(gpu_info, src_type, dst_type, 4); } else { - return GetConvertionForImage(gpu_info, src_type, dst_type); + return GetConversionForImage(gpu_info, src_type, dst_type); } } -void MayBeAddConvertion(const std::string& conversion, std::string* result) { +void MayBeAddConversion(const std::string& conversion, std::string* result) { if (!conversion.empty()) { *result = conversion + "(" + *result + ")"; } @@ -251,6 +251,27 @@ GPUResources TensorDescriptor::GetGPUResources(const GpuInfo& gpu_info) const { return resources; } +void TensorDescriptor::GetGpuResources( + const BHWDC& tensor_shape, GenericGPUResourcesWithValue* resources) const { + resources->AddInt("slice_stride", GetSliceStrideSize(tensor_shape)); + if (HasAxis(Axis::WIDTH)) { + resources->AddInt("width", GetWidthSize(tensor_shape)); + } + if (HasAxis(Axis::HEIGHT)) { + resources->AddInt("height", tensor_shape.h); + } + if (HasAxis(Axis::CHANNELS)) { + resources->AddInt("slices", DivideRoundUp(tensor_shape.c, 4)); + resources->AddInt("channels", tensor_shape.c); + } + if (HasAxis(Axis::BATCH)) { + resources->AddInt("batch", tensor_shape.b); + } + if (HasAxis(Axis::DEPTH)) { + resources->AddInt("depth", tensor_shape.d); + } +} + absl::Status TensorDescriptor::PerformConstExpr(const GpuInfo& gpu_info, const std::string& const_expr, std::string* result) const { @@ -610,7 +631,7 @@ std::string TensorDescriptor::Read( const GpuInfo& gpu_info, DataType read_as_type, const std::vector& coords) const { const std::string conversion = - GetConvertion(gpu_info, storage_type, data_type, read_as_type); + GetConversion(gpu_info, storage_type, data_type, read_as_type); if (gpu_info.IsApiOpenCl() && !(data_type == DataType::FLOAT16 && read_as_type == DataType::FLOAT32)) { read_as_type = data_type; @@ -626,7 +647,7 @@ std::string TensorDescriptor::Read( } else { result = absl::StrCat("buffer[", coords[0], "]"); } - MayBeAddConvertion(conversion, &result); + MayBeAddConversion(conversion, &result); return result; } case TensorStorageType::TEXTURE_2D: @@ -647,7 +668,7 @@ std::string TensorDescriptor::Read( result = "f16vec4(" + result + ")"; } } - MayBeAddConvertion(conversion, &result); + MayBeAddConversion(conversion, &result); return result; } case TensorStorageType::TEXTURE_3D: { @@ -668,7 +689,7 @@ std::string TensorDescriptor::Read( result = "f16vec4(" + result + ")"; } } - MayBeAddConvertion(conversion, &result); + MayBeAddConversion(conversion, &result); return result; } case TensorStorageType::TEXTURE_ARRAY: { @@ -689,7 +710,7 @@ std::string TensorDescriptor::Read( result = "f16vec4(" + result + ")"; } } - MayBeAddConvertion(conversion, &result); + MayBeAddConversion(conversion, &result); return result; } case TensorStorageType::IMAGE_BUFFER: { @@ -706,7 +727,7 @@ std::string TensorDescriptor::Read( result = "f16vec4(" + result + ")"; } } - MayBeAddConvertion(conversion, &result); + MayBeAddConversion(conversion, &result); return result; } case TensorStorageType::UNKNOWN: @@ -740,7 +761,7 @@ std::string TensorDescriptor::Write( std::string write_expr = var_name; if (write_type != write_required_type) { const std::string conversion = - GetTypeConvertion(gpu_info, write_type, write_required_type, 4); + GetTypeConversion(gpu_info, write_type, write_required_type, 4); if (!conversion.empty()) { write_expr = conversion + "(" + write_expr + ")"; } diff --git a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h index 609636b203586f..1d62ae8225bf52 100644 --- a/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h +++ b/tensorflow/lite/delegates/gpu/common/task/tensor_desc.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" @@ -58,6 +59,9 @@ struct TensorDescriptor : public GPUObjectDescriptor { bool operator!=(const TensorDescriptor& d) const { return !(*this == d); } + void GetGpuResources(const BHWDC& tensor_shape, + GenericGPUResourcesWithValue* resources) const; + absl::Status PerformConstExpr(const GpuInfo& gpu_info, const std::string& const_expr, std::string* result) const override; @@ -97,6 +101,7 @@ struct TensorDescriptor : public GPUObjectDescriptor { bool CanReadOutOfBorder(const Axis& axis) const; bool IsLinear() const; + DataType GetDataType() const { return data_type; } TensorStorageType GetStorageType() const { return storage_type; } // applicable only for types that: IsLinear -> true. @@ -119,8 +124,6 @@ struct TensorDescriptor : public GPUObjectDescriptor { absl::Status UpdateToSupportedStorageType(const GpuInfo& gpu_info, const BHWC& shape); - DataType data_type = DataType::UNKNOWN; - void SetUseBufferForWriteOnlyTexture2d(bool value) { use_buffer_for_write_only_2d_texture = value; } @@ -240,6 +243,7 @@ struct TensorDescriptor : public GPUObjectDescriptor { template void DownloadData(T* dst); + DataType data_type = DataType::UNKNOWN; TensorStorageType storage_type = TensorStorageType::UNKNOWN; // This field describes logical layout, actual(physical) GPU layout can be diff --git a/tensorflow/lite/delegates/gpu/common/task/util.cc b/tensorflow/lite/delegates/gpu/common/task/util.cc index 2800b27d4aae55..1d5917aad7233b 100644 --- a/tensorflow/lite/delegates/gpu/common/task/util.cc +++ b/tensorflow/lite/delegates/gpu/common/task/util.cc @@ -255,7 +255,7 @@ std::string GetOneValue(const GpuInfo& gpu_info, DataType data_type, } } -std::string GetTypeConvertion(const GpuInfo& gpu_info, DataType src_type, +std::string GetTypeConversion(const GpuInfo& gpu_info, DataType src_type, DataType dst_type, int vec_size) { if (src_type != dst_type) { if (gpu_info.IsApiOpenCl()) { diff --git a/tensorflow/lite/delegates/gpu/common/task/util.h b/tensorflow/lite/delegates/gpu/common/task/util.h index b7694cab96bd36..89b6243a3cdba1 100644 --- a/tensorflow/lite/delegates/gpu/common/task/util.h +++ b/tensorflow/lite/delegates/gpu/common/task/util.h @@ -70,7 +70,7 @@ std::string GetZeroValue(const GpuInfo& gpu_info, DataType data_type, std::string GetOneValue(const GpuInfo& gpu_info, DataType data_type, int vec_size); -std::string GetTypeConvertion(const GpuInfo& gpu_info, DataType src_type, +std::string GetTypeConversion(const GpuInfo& gpu_info, DataType src_type, DataType dst_type, int vec_size); } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/tasks/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/BUILD index 52e5f8fc62456b..d53da31c049ea7 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/BUILD +++ b/tensorflow/lite/delegates/gpu/common/tasks/BUILD @@ -638,7 +638,6 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", - "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/tasks/cast.cc b/tensorflow/lite/delegates/gpu/common/tasks/cast.cc index 18cfc7f13945ef..9521062bf0d33c 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/cast.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/cast.cc @@ -54,8 +54,8 @@ std::string GetCastKernelCode(const OperationDef& op_def, c += " args.src_tensor::type src_value = args.src_tensor.Read(" + coords + ");\n"; const std::string conversion = - GetTypeConvertion(gpu_info, op_def.src_tensors[0].data_type, - op_def.dst_tensors[0].data_type, 4); + GetTypeConversion(gpu_info, op_def.src_tensors[0].GetDataType(), + op_def.dst_tensors[0].GetDataType(), 4); if (conversion.empty()) { c += " args.dst_tensor::type result = src_value;\n"; } else { diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc index d50a323d7a1ccb..6c31087a1dc057 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_generic.cc @@ -295,7 +295,7 @@ void ConvGeneric::GenerateCode(const GpuInfo& gpu_info) { definition_.src_tensors.push_back(desc); for (int i = 0; i < 4; ++i) { Texture2DDescriptor desc; - desc.element_type = definition_.src_tensors[1 + i].data_type; + desc.element_type = definition_.src_tensors[1 + i].GetDataType(); const std::string name = "weights" + std::to_string(i); AddSrcTexture2D("weights" + std::to_string(i), desc); } diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc index b141bed99265d4..dffc25b4c1f91e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal.cc @@ -1082,7 +1082,7 @@ ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition, if (definition.src_tensors.size() == 2) { // dynamic weights BufferDescriptor weights_desc; - weights_desc.element_type = definition.src_tensors[1].data_type; + weights_desc.element_type = definition.src_tensors[1].GetDataType(); weights_desc.element_size = 4; weights_desc.memory_type = params.GetMemoryType(); desc.AddSrcBuffer("weights", weights_desc); @@ -1133,7 +1133,7 @@ ConvolutionMetal CreateConvolutionMetalBatchedMatMul( // dynamic weights BufferDescriptor weights_desc; - weights_desc.element_type = definition.src_tensors[1].data_type; + weights_desc.element_type = definition.src_tensors[1].GetDataType(); weights_desc.element_size = 4; weights_desc.memory_type = params.GetMemoryType(); desc.AddSrcBuffer("weights", weights_desc); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.cc index 4f43a9546d315f..e7b0de7ddb456e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.cc @@ -483,7 +483,7 @@ ConvolutionMetalSimd CreateConvolutionMetalSimd( if (definition.src_tensors.size() == 2) { // dynamic weights BufferDescriptor weights_desc; - weights_desc.element_type = definition.src_tensors[1].data_type; + weights_desc.element_type = definition.src_tensors[1].GetDataType(); weights_desc.element_size = 4; weights_desc.memory_type = mem_type; desc.AddSrcBuffer("weights", weights_desc); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc index 71ec1199dca190..c872d918dea03e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.cc @@ -140,14 +140,14 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode( if (weights_layout_ == WeightsLayout::kOSpatialIOGroupI4O4 || weights_layout_ == WeightsLayout::kOSpatialIOGroupO4I4) { BufferDescriptor desc; - desc.element_type = op_def.src_tensors[1].data_type; + desc.element_type = op_def.src_tensors[1].GetDataType(); desc.element_size = 16; desc.memory_type = MemoryType::GLOBAL; AddSrcBuffer("weights", desc); } else { for (int i = 0; i < 4; ++i) { Texture2DDescriptor desc; - desc.element_type = op_def.src_tensors[1 + i].data_type; + desc.element_type = op_def.src_tensors[1 + i].GetDataType(); const std::string name = "weights" + std::to_string(i); AddSrcTexture2D("weights" + std::to_string(i), desc); } diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3.cc index 55ef8c2eb20f52..83554da6afabe8 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3.cc @@ -78,7 +78,7 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( if (op_def.src_tensors.size() == 2) { // dynamic weights BufferDescriptor desc; - desc.element_type = op_def.src_tensors[1].data_type; + desc.element_type = op_def.src_tensors[1].GetDataType(); desc.element_size = 4; desc.memory_type = weights_upload_type == diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3_thin.cc b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3_thin.cc index 582a3174af71eb..2a387960d140ac 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3_thin.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_3x3_thin.cc @@ -84,7 +84,7 @@ std::string ConvolutionTransposed3x3Thin::GenerateConvolutionTransposedCode( if (op_def.src_tensors.size() == 2) { // dynamic weights BufferDescriptor desc; - desc.element_type = op_def.src_tensors[1].data_type; + desc.element_type = op_def.src_tensors[1].GetDataType(); desc.element_size = 4; desc.memory_type = MemoryType::CONSTANT; AddSrcBuffer("weights", desc); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_4x4.cc index b6e1bb2b2062e0..12d3074187b6a6 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed_4x4.cc @@ -95,7 +95,7 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( if (op_def.src_tensors.size() == 2) { // dynamic weights BufferDescriptor desc; - desc.element_type = op_def.src_tensors[1].data_type; + desc.element_type = op_def.src_tensors[1].GetDataType(); desc.element_size = 4; desc.memory_type = weights_upload_type == diff --git a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc index 03a7ba5d60edca..1d00b9e3df9f2f 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/elementwise.cc @@ -30,13 +30,20 @@ std::string GetOneInputCode(const GpuInfo& gpu_info, const OperationType& op_type, CalculationsPrecision precision, const std::string& input0) { + const bool use_native_opencl_functions = + gpu_info.IsApiOpenCl() && precision != CalculationsPrecision::F32 && + gpu_info.IsAdreno(); std::string result; switch (op_type) { case OperationType::ABS: result = "$0 = fabs($0);\n"; break; case OperationType::COS: - result = "$0 = cos($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_cos($0);\n"; + } else { + result = "$0 = cos($0);\n"; + } break; case OperationType::COPY: // No op as inout_value will be copied to dest automatically. @@ -58,7 +65,11 @@ std::string GetOneInputCode(const GpuInfo& gpu_info, } break; case OperationType::EXP: - result = "$0 = exp($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_exp($0);\n"; + } else { + result = "$0 = exp($0);\n"; + } break; case OperationType::FLOOR: result = "$0 = floor($0);\n"; @@ -70,16 +81,24 @@ std::string GetOneInputCode(const GpuInfo& gpu_info, "INIT_FLT4(1.0f));\n"; break; case OperationType::LOG: - result = "$0 = log($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_log($0);\n"; + } else { + result = "$0 = log($0);\n"; + } break; case OperationType::NEG: result = "$0 = -($0);\n"; break; case OperationType::RSQRT: - result = "$0 = rsqrt($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_rsqrt($0);\n"; + } else { + result = "$0 = rsqrt($0);\n"; + } break; case OperationType::SIGMOID: - if (gpu_info.IsApiOpenCl() && precision != CalculationsPrecision::F32) { + if (use_native_opencl_functions) { result = "$0 = convert_half4(native_recip(1.0f + " "native_exp(convert_float4(-$0))));\n"; @@ -88,16 +107,31 @@ std::string GetOneInputCode(const GpuInfo& gpu_info, } break; case OperationType::SIN: - result = "$0 = sin($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_sin($0);\n"; + } else { + result = "$0 = sin($0);\n"; + } break; case OperationType::SQRT: - result = "$0 = sqrt($0);\n"; + if (use_native_opencl_functions) { + result = "$0 = native_sqrt($0);\n"; + } else { + result = "$0 = sqrt($0);\n"; + } break; case OperationType::SQUARE: result = "$0 *= $0;\n"; break; case OperationType::TANH: - result = "$0 = tanh($0);\n"; + if (use_native_opencl_functions) { + result = " FLT4 exp_val = native_exp(INIT_FLT4(2.0f) * $0);\n"; + result += + "$0 = ((exp_val - INIT_FLT4(1.0f)) / (exp_val + " + "INIT_FLT4(1.0f)));\n"; + } else { + result = "$0 = tanh($0);\n"; + } break; default: return "Unknown operation type;\n"; diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc index 0c9564aecb7258..ddf74a20679ac2 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.cc @@ -17,38 +17,40 @@ limitations under the License. #include -#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" - namespace tflite { namespace gpu { - namespace { -std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def, - GPUOperation* op) { - auto src_desc = op_def.src_tensors[0]; - if (op_def.IsBatchSupported()) { - src_desc.SetStateVar("BatchedWidth", "true"); - } - op->AddSrcTensor("src_tensor", src_desc); - auto src_ind_desc = op_def.src_tensors[1]; - if (op_def.IsBatchSupported()) { - src_ind_desc.SetStateVar("BatchedWidth", "true"); - } - op->AddSrcTensor("src_indices", src_ind_desc); - auto dst_desc = op_def.dst_tensors[0]; - if (op_def.IsBatchSupported()) { - dst_desc.SetStateVar("BatchedWidth", "true"); +void AppendConditionally(const std::string& value, const std::string& delimeter, + std::string* result) { + if (!result->empty()) { + *result += delimeter; } - op->AddDstTensor("dst_tensor", dst_desc); + *result += value; +} + +std::string GetMaxUnpoolingKernelCode(const GpuInfo& gpu_info, + const OperationDef& op_def, + GPUOperation* op) { + op->AddSrcTensor("src_tensor", op_def.src_tensors[0]); + op->AddSrcTensor("src_indices", op_def.src_tensors[1]); + op->AddDstTensor("dst_tensor", op_def.dst_tensors[0]); std::string c; c += "MAIN_FUNCTION($0) {\n"; - c += " int X = GLOBAL_ID_0;\n"; + if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { + c += " int linear_id = GLOBAL_ID_0;\n"; + c += " int X = linear_id / args.dst_tensor.Batch();\n"; + c += " int B = linear_id % args.dst_tensor.Batch();\n"; + c += " args.src_tensor.SetBatchRef(B);\n"; + c += " args.src_indices.SetBatchRef(B);\n"; + c += " args.dst_tensor.SetBatchRef(B);\n"; + } else { + c += " int X = GLOBAL_ID_0;\n"; + } if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { c += " int linear_id_1 = GLOBAL_ID_1;\n"; c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n"; c += " int Z = linear_id_1 % args.dst_tensor.Depth();\n"; - c += " int src_z = (Z + args.padding_z) / args.stride_z;\n"; } else { c += " int Y = GLOBAL_ID_1;\n"; } @@ -57,72 +59,66 @@ std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def, "S >= args.dst_tensor.Slices()) { \n"; c += " return; \n"; c += " } \n"; - if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { - c += " int linear_id_0 = GLOBAL_ID_0;\n"; - c += " int X0 = linear_id_0 / args.dst_tensor.Batch();\n"; - c += " int B = linear_id_0 % args.dst_tensor.Batch();\n"; - c += " int src_x0 = (X0 + args.padding_x * args.dst_tensor.Batch()) / " - "args.stride_x;\n"; - c += " int src_x = src_x0 * args.dst_tensor.Batch() + B;\n"; - } else { - c += " int src_x = (X + args.padding_x) / args.stride_x;\n"; - } + c += " int src_x = (X + args.padding_x) / args.stride_x;\n"; + c += " int t_x = X - (src_x * args.stride_x - args.padding_x);\n"; c += " int src_y = (Y + args.padding_y) / args.stride_y;\n"; - std::string src_args = op_def.dst_tensors[0].HasAxis(Axis::DEPTH) - ? "src_x, src_y, src_z, S" - : "src_x, src_y, S"; - if (op_def.src_tensors[0].GetStorageType() == TensorStorageType::BUFFER) { - if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { - c += " bool outside = src_x < 0 || src_y < 0 || src_z < 0 || src_x >= " - "args.src_tensor.Width() || src_y >= args.src_tensor.Height() || " - "src_z >= args.src_tensor.Depth();\n"; - } else { - c += " bool outside = src_x < 0 || src_y < 0 || src_x >= " - "args.src_tensor.Width() || src_y >= args.src_tensor.Height();\n"; - } - c += " FLT4 src = INIT_FLT4(0.0f);\n"; - c += " int4 ind = INIT_INT4v4(0, 0, 0, 0);\n"; - c += " if (!outside) {\n"; - c += " src = args.src_tensor.Read(" + src_args + ");\n"; - c += " ind = args.src_indices.Read(" + src_args + ");\n"; - c += " }\n"; - } else { - c += " FLT4 src = args.src_tensor.Read(" + src_args + ");\n"; - c += " int4 ind = args.src_indices.Read(" + src_args + ");\n"; - } - if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) { - c += " int t_x = X0 - (src_x0 * args.stride_x - args.padding_x * " - "args.dst_tensor.Batch());\n"; - } else { - c += " int t_x = X - (src_x * args.stride_x - args.padding_x);\n"; - } c += " int t_y = Y - (src_y * args.stride_y - args.padding_y);\n"; if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { + c += " int src_z = (Z + args.padding_z) / args.stride_z;\n"; c += " int t_z = Z - (src_z * args.stride_z - args.padding_z);\n"; c += " int t_index = (t_y * args.kernel_size_x + t_x) * " "args.kernel_size_z + t_z;\n"; } else { c += " int t_index = t_y * args.kernel_size_x + t_x;\n"; } - c += " FLT4 result;\n"; - const std::string channels[] = {".x", ".y", ".z", ".w"}; - for (int i = 0; i < 4; ++i) { - const auto& s = channels[i]; - c += " result" + s + "= t_index == ind" + s + "? src" + s + - ": INIT_FLT(0.0f);\n"; + std::string inbounds_check; + if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info) || + !op_def.src_tensors[1].SupportsZeroClamp(Axis::WIDTH, gpu_info)) { + c += " bool inside_x = src_x >= 0 && src_x < args.src_tensor.Width();\n"; + c += " src_x = clamp(src_x, 0, args.src_tensor.Width() - 1);\n"; + AppendConditionally("inside_x", " && ", &inbounds_check); + } + if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info) || + !op_def.src_tensors[1].SupportsZeroClamp(Axis::HEIGHT, gpu_info)) { + c += " bool inside_y = src_y >= 0 && src_y < args.src_tensor.Height();\n"; + c += " src_y = clamp(src_y, 0, args.src_tensor.Height() - 1);\n"; + AppendConditionally("inside_y", " && ", &inbounds_check); } + if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { + if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::DEPTH, gpu_info) || + !op_def.src_tensors[1].SupportsZeroClamp(Axis::DEPTH, gpu_info)) { + c += " bool inside_z = src_z >= 0 && src_z < args.src_tensor.Depth();\n"; + c += " src_z = clamp(src_z, 0, args.src_tensor.Depth() - 1);\n"; + AppendConditionally("inside_z", " && ", &inbounds_check); + } + } + std::string src_args = op_def.dst_tensors[0].HasAxis(Axis::DEPTH) + ? "src_x, src_y, src_z, S" + : "src_x, src_y, S"; + c += + " args.src_tensor::type src = args.src_tensor.Read(" + src_args + ");\n"; + c += " int4 ind = args.src_indices.Read(" + src_args + ");\n"; + if (!inbounds_check.empty()) { + c += " src *= INIT_FLT(" + inbounds_check + ");\n"; + c += " ind *= INIT_INT(" + inbounds_check + ");\n"; + } + c += " args.src_tensor::type result;\n"; + c += " result.x = t_index == ind.x ? src.x : INIT_FLT(0.0f);\n"; + c += " result.y = t_index == ind.y ? src.y : INIT_FLT(0.0f);\n"; + c += " result.z = t_index == ind.z ? src.z : INIT_FLT(0.0f);\n"; + c += " result.w = t_index == ind.w ? src.w : INIT_FLT(0.0f);\n"; if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { c += " args.dst_tensor.Write(result, X, Y, Z, S);\n"; } else { c += " args.dst_tensor.Write(result, X, Y, S);\n"; } c += "}\n"; - return c; } } // namespace -GPUOperation CreateMaxUnpooling(const OperationDef& definition, +GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info, + const OperationDef& definition, const MaxUnpooling2DAttributes& attr) { GPUOperation op(definition); op.args_.AddInt("kernel_size_x", attr.kernel.w); @@ -131,12 +127,13 @@ GPUOperation CreateMaxUnpooling(const OperationDef& definition, op.args_.AddInt("kernel_size_y", attr.kernel.h); op.args_.AddInt("padding_y", attr.padding.appended.h); op.args_.AddInt("stride_y", attr.strides.h); - op.code_ = GetMaxUnpoolingKernelCode(definition, &op); + op.code_ = GetMaxUnpoolingKernelCode(gpu_info, definition, &op); op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; return op; } -GPUOperation CreateMaxUnpooling(const OperationDef& definition, +GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info, + const OperationDef& definition, const MaxUnpooling3DAttributes& attr) { GPUOperation op(definition); op.args_.AddInt("kernel_size_x", attr.kernel.w); @@ -148,7 +145,7 @@ GPUOperation CreateMaxUnpooling(const OperationDef& definition, op.args_.AddInt("kernel_size_z", attr.kernel.d); op.args_.AddInt("padding_z", attr.padding.appended.d); op.args_.AddInt("stride_z", attr.strides.d); - op.code_ = GetMaxUnpoolingKernelCode(definition, &op); + op.code_ = GetMaxUnpoolingKernelCode(gpu_info, definition, &op); op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; return op; } diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h index 6e90d372cb6d55..d6b0bd6fd0094e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling.h @@ -24,10 +24,12 @@ limitations under the License. namespace tflite { namespace gpu { -GPUOperation CreateMaxUnpooling(const OperationDef& definition, +GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info, + const OperationDef& definition, const MaxUnpooling2DAttributes& attr); -GPUOperation CreateMaxUnpooling(const OperationDef& definition, +GPUOperation CreateMaxUnpooling(const GpuInfo& gpu_info, + const OperationDef& definition, const MaxUnpooling3DAttributes& attr); } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc index d70ef1ab4e175e..43d2fddbc867d6 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/max_unpooling_test_util.cc @@ -50,7 +50,8 @@ absl::Status MaxUnpoolingTest(TestExecutionEnvironment* env) { op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); TensorFloat32 dst_tensor; - GPUOperation operation = CreateMaxUnpooling(op_def, attr); + GPUOperation operation = + CreateMaxUnpooling(env->GetGpuInfo(), op_def, attr); RETURN_IF_ERROR(env->ExecuteGPUOperation( {src_tensor, src_ind_tensor}, std::make_unique(std::move(operation)), diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.cc b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.cc index 7e07c361bb9156..c3847f976df232 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h" #include +#include #include #include "absl/strings/substitute.h" @@ -27,7 +28,8 @@ namespace gpu { namespace { std::string GetReduceCode(const std::string& src_value, - const std::string& dst_value, int3 work_group_size) { + const std::string& dst_value, int3 work_group_size, + bool two_step) { int reduction_size = work_group_size.z; std::string mem_name = work_group_size.x * work_group_size.y != 1 ? "shared_mem[LOCAL_ID_1][LOCAL_ID_0]" @@ -42,7 +44,9 @@ std::string GetReduceCode(const std::string& src_value, result += " " + dst_value + " += " + mem_name + "[" + std::to_string(i) + "];\n"; } - result += " LOCAL_MEM_BARRIER;\n"; + if (two_step) { + result += " LOCAL_MEM_BARRIER;\n"; + } result += " }\n"; return result; } else { @@ -93,7 +97,8 @@ std::string ZeroClampVec4Code(const std::string& slice_name, MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape, - float variance_bias) + float variance_bias, + bool two_step) : GPUOperation(definition) { const int tensor_slices = DivideRoundUp(shape.c, 4); int desired_work_group_size = gpu_info.GetMaxWorkGroupSizeForZ(); @@ -169,11 +174,12 @@ MeanStdDevNormalization::MeanStdDevNormalization(const OperationDef& definition, } } args_.AddFloat("variance_bias", variance_bias); - code_ = GetNormalizationCode(gpu_info, shape.c % 4 == 0); + args_.AddFloat("inv_ch_count", 1.0f / shape.c); + code_ = GetNormalizationCode(gpu_info, shape.c % 4 == 0, two_step); } std::string MeanStdDevNormalization::GetNormalizationCode( - const GpuInfo& gpu_info, bool channels_x4) { + const GpuInfo& gpu_info, bool channels_x4, bool two_step) { AddSrcTensor("src_tensor", definition_.src_tensors[0]); AddDstTensor("dst_tensor", definition_.dst_tensors[0]); @@ -185,12 +191,14 @@ std::string MeanStdDevNormalization::GetNormalizationCode( std::to_string(work_group_size_.z) + ")))\n"; } c += "MAIN_FUNCTION($0) {\n"; + std::string accum_type = two_step ? "float" : "float2"; if (work_group_size_.x * work_group_size_.y == 1) { - c += "__local float shared_mem[" + std::to_string(work_group_size_.z) + - "];\n"; + c += "__local " + accum_type + " shared_mem[" + + std::to_string(work_group_size_.z) + "];\n"; } else { - c += "__local float shared_mem[" + std::to_string(work_group_size_.x) + - "][" + std::to_string(work_group_size_.y) + "][" + + c += "__local " + accum_type + " shared_mem[" + + std::to_string(work_group_size_.x) + "][" + + std::to_string(work_group_size_.y) + "][" + std::to_string(work_group_size_.z) + "];\n"; } if (definition_.dst_tensors[0].HasAxis(Axis::BATCH)) { @@ -203,9 +211,10 @@ std::string MeanStdDevNormalization::GetNormalizationCode( c += " int X = GLOBAL_ID_0;\n"; } c += " int Y = GLOBAL_ID_1;\n"; + if (!two_step) { + c += " float4 private_sum4_sq = INIT_FLOAT4(0.0f);\n"; + } c += R"( - // Calculate the total sum of the input tensor. - // First, get a local sum of input[local_id_x + N*local_size_x] for all N. float4 private_sum4 = INIT_FLOAT4(0.0f); int local_id = LOCAL_ID_2; int reduction_group_size = GROUP_SIZE_2; @@ -216,17 +225,25 @@ std::string MeanStdDevNormalization::GetNormalizationCode( if (!channels_x4) { c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "t"); } - c += R"( - private_sum4 += t; + if (two_step) { + c += " private_sum4 += t;\n"; + c += " }\n"; + c += " float private_sum = dot(private_sum4, INIT_FLOAT4(1.0f));\n"; + c += " float sum;\n"; + } else { + c += " private_sum4 += t;\n"; + c += " private_sum4_sq += t * t;\n"; + c += " }\n"; + c += " float2 private_sum;\n"; + c += " private_sum.x = dot(private_sum4, INIT_FLOAT4(1.0f));\n"; + c += " private_sum.y = dot(private_sum4_sq, INIT_FLOAT4(1.0f));\n"; + c += " float2 sum;\n"; } - // Reduce the vector to a single float and do a workgroup reduce. - float private_sum = dot(private_sum4, INIT_FLOAT4(1.0f)); - float sum; -)"; - c += GetReduceCode("private_sum", "sum", work_group_size_); - c += R"( + c += GetReduceCode("private_sum", "sum", work_group_size_, two_step); + if (two_step) { + c += R"( // Calculate the mean - float mean = sum / INIT_FLOAT(args.src_tensor.Channels()); + float mean = sum * args.inv_ch_count; // Calculate the squared sum of the difference from the mean. float4 private_sum_diff_sq4 = INIT_FLOAT4(0.0f); for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) { @@ -234,23 +251,29 @@ std::string MeanStdDevNormalization::GetNormalizationCode( int y_clamped = min(Y, args.src_tensor.Height() - 1); float4 t = args.src_tensor.Read(x_clamped, y_clamped, S); float4 diff = t - mean;)"; - if (!channels_x4) { - c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "diff"); - } - c += R"( + if (!channels_x4) { + c += ZeroClampVec4Code("S", "args.src_tensor.Channels()", "diff"); + } + c += R"( private_sum_diff_sq4 += diff * diff; } // Reduce float private_sum_diff_sq = dot(private_sum_diff_sq4, INIT_FLOAT4(1.0f)); float sum_diff_sq; )"; - c += GetReduceCode("private_sum_diff_sq", "sum_diff_sq", work_group_size_); + c += GetReduceCode("private_sum_diff_sq", "sum_diff_sq", work_group_size_, + two_step); + c += " float variance = sum_diff_sq * args.inv_ch_count;\n"; + } else { + c += " float mean = sum.x * args.inv_ch_count;\n"; + c += " float mean_sq = sum.y * args.inv_ch_count;\n"; + c += " float variance = mean_sq - mean * mean;\n"; + } c += R"( // no more shared memory usage, 'useless' threads can exit now if (X >= args.dst_tensor.Width()) { return; } if (Y >= args.dst_tensor.Height()) { return; } // Calculate 1/stddev (with the 'regulazing constant' as in tensor_utils.cc) - float variance = sum_diff_sq / INIT_FLOAT(args.src_tensor.Channels()); float stddev_inv = rsqrt(variance + args.variance_bias); // Calculate (t-mean)/stddev for each element for (int S = local_id; S < args.src_tensor.Slices(); S += reduction_group_size) { @@ -273,8 +296,9 @@ int3 MeanStdDevNormalization::GetGridSize() const { MeanStdDevNormalization CreateMeanStdDevNormalization( const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape, - float variance_bias) { - return MeanStdDevNormalization(definition, gpu_info, shape, variance_bias); + float variance_bias, bool two_step) { + return MeanStdDevNormalization(definition, gpu_info, shape, variance_bias, + two_step); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h index c8f227b015fd9e..b1da62c410138e 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h +++ b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization.h @@ -31,7 +31,7 @@ class MeanStdDevNormalization : public GPUOperation { public: explicit MeanStdDevNormalization(const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape, - float variance_bias); + float variance_bias, bool two_step); void GetPossibleKernelWorkGroups( TuningType tuning_type, const GpuInfo& gpu_info, @@ -49,12 +49,15 @@ class MeanStdDevNormalization : public GPUOperation { MeanStdDevNormalization& operator=(const MeanStdDevNormalization&) = delete; private: - std::string GetNormalizationCode(const GpuInfo& gpu_info, bool channels_x4); + std::string GetNormalizationCode(const GpuInfo& gpu_info, bool channels_x4, + bool two_step); }; +// std dev can be calculated in single step, but two step algorithm can +// provide more stable and robust results MeanStdDevNormalization CreateMeanStdDevNormalization( const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape, - float variance_bias = 1.0e-8f); + float variance_bias = 1.0e-8f, bool two_step = true); } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization_test_util.cc b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization_test_util.cc index af1de5277d866a..28bc70be1fbe7a 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization_test_util.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization_test_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/tasks/mean_stddev_normalization_test_util.h" +#include #include #include @@ -63,6 +64,18 @@ absl::Status MeanStddevNormSeparateBatchesTest(float mean, float diff, } RETURN_IF_ERROR( PointWiseNear(expected_output, dst_tensor.data, tolerance)); + + TensorFloat32 dst_tensor_single_step; + auto operation_single_step = CreateMeanStdDevNormalization( + op_def, env->GetGpuInfo(), src_tensor.shape, + /*variance_bias*/ 1.0e-8f, /*two_step*/ false); + RETURN_IF_ERROR( + env->ExecuteGPUOperation({src_tensor}, + std::make_unique( + std::move(operation_single_step)), + BHWC(1, 1, 2, 4), &dst_tensor_single_step)); + RETURN_IF_ERROR(PointWiseNear(expected_output, + dst_tensor_single_step.data, tolerance)); } } return absl::OkStatus(); @@ -115,6 +128,19 @@ absl::Status MeanStddevNormalizationAllBatchesTest( }; RETURN_IF_ERROR(PointWiseNear(expected_output, dst_tensor.data, eps)) << "Failed using precision " << ToString(precision); + + TensorFloat32 dst_tensor_single_step; + auto operation_single_step = CreateMeanStdDevNormalization( + op_def, env->GetGpuInfo(), src_tensor.shape, + /*variance_bias*/ 1.0e-8f, /*two_step*/ false); + RETURN_IF_ERROR( + env->ExecuteGPUOperation({src_tensor}, + std::make_unique( + std::move(operation_single_step)), + BHWC(9, 1, 1, 4), &dst_tensor_single_step)); + RETURN_IF_ERROR( + PointWiseNear(expected_output, dst_tensor_single_step.data, eps)) + << "Failed using precision " << ToString(precision); } } return absl::OkStatus(); @@ -174,6 +200,21 @@ absl::Status MeanStddevNormalizationLargeVectorTest( } RETURN_IF_ERROR(PointWiseNear(expected_output, dst_tensor.data, eps)) << "Failed using precision " << ToString(precision); + + if (precision != CalculationsPrecision::F32) { + TensorFloat32 dst_tensor_single_step; + auto operation_single_step = CreateMeanStdDevNormalization( + op_def, env->GetGpuInfo(), src_tensor.shape, + /*variance_bias*/ 1.0e-8f, /*two_step*/ false); + RETURN_IF_ERROR(env->ExecuteGPUOperation( + {src_tensor}, + std::make_unique( + std::move(operation_single_step)), + BHWC(1, 1, 2, kVectorSize), &dst_tensor_single_step)); + RETURN_IF_ERROR( + PointWiseNear(expected_output, dst_tensor_single_step.data, eps)) + << "Failed using precision " << ToString(precision); + } } } return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc b/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc index 5b56410e316cb3..545411b4dc4330 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/reduce.cc @@ -212,7 +212,7 @@ std::string Reduce::GetReduceKernelCode(const OperationDef& op_def, } }; - auto accum_type = GetAccumType(op_def.src_tensors[0].data_type); + auto accum_type = GetAccumType(op_def.src_tensors[0].GetDataType()); const std::string accum_type_decl = GetTypeDeclaration(gpu_info, accum_type, 4); std::string read_as_template; @@ -454,8 +454,8 @@ std::string Reduce::GetReduceKernelCode(const OperationDef& op_def, c += " reducer.x = min(reducer.x, reducer.w);\n"; } } - const std::string conversion = GetTypeConvertion( - gpu_info, accum_type, op_def.src_tensors[0].data_type, 4); + const std::string conversion = GetTypeConversion( + gpu_info, accum_type, op_def.src_tensors[0].GetDataType(), 4); if (conversion.empty()) { c += " args.src_tensor::type result = reducer;\n"; } else { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 62c3ec3985467c..673502f2112de0 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -108,12 +108,94 @@ class MergeConvolutionWithAdd : public SequenceTransformation { } }; +void FuseAddWithConvolution2D(const ElementwiseAttributes& add_attr, + Convolution2DAttributes* attr) { + auto add = absl::get_if>(&add_attr.param); + auto add_scalar = absl::get_if(&add_attr.param); + if (attr->bias.data.empty()) { + attr->bias = MakeZeroTensor( + Linear(attr->weights.shape.o)); + } + for (int d = 0; d < attr->weights.shape.o; ++d) { + float sum = 0.0f; + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float add_value = add ? add->data[s] : *add_scalar; + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}}); + sum += add_value * attr->weights.data[index]; + } + } + } + attr->bias.data[d] += sum; + } +} + +class MergeAddWithConvolution : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + auto& conv_node = *sequence[1]; + if (graph->FindInputs(conv_node.id).size() != 1) { + return {TransformStatus::DECLINED, + "This fusion is only applicable to ops with one runtime input."}; + } + auto& add_node = *sequence[0]; + if (add_node.operation.type != ToString(OperationType::ADD)) { + return {TransformStatus::SKIPPED, ""}; + } + ElementwiseAttributes add_attr = + absl::any_cast(add_node.operation.attributes); + if (!absl::holds_alternative>( + add_attr.param) && + !absl::holds_alternative(add_attr.param)) { + return {TransformStatus::DECLINED, + "This fuse applicable only for broadcast or scalar addition."}; + } + + if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { + Convolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + if (conv_attr->groups != 1) { + return {TransformStatus::DECLINED, + "This fuse not applicable for grouped convolution."}; + } + if (conv_attr->padding.appended.w != 0 || + conv_attr->padding.appended.h != 0 || + conv_attr->padding.prepended.w != 0 || + conv_attr->padding.prepended.h != 0) { + return {TransformStatus::DECLINED, + "This fuse applicable only for convolution that do not read " + "out of bound elements."}; + } + FuseAddWithConvolution2D(add_attr, conv_attr); + } else { + return {TransformStatus::SKIPPED, ""}; + } + + absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove mul node after convolution: " + + std::string(status.message())}; + } + return {TransformStatus::APPLIED, ""}; + } +}; + } // namespace std::unique_ptr NewMergeConvolutionWithAdd() { return absl::make_unique(); } +std::unique_ptr NewMergeAddWithConvolution() { + return absl::make_unique(); +} + void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr, Convolution2DAttributes* attr) { FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h index 26f93dc3765ee1..7a7b05d710de52 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h @@ -29,6 +29,10 @@ namespace gpu { // convolution. std::unique_ptr NewMergeConvolutionWithAdd(); +// Fuse Add Scalar or Add Broadcast before Convolution2D into weights and biases +// of convolution. +std::unique_ptr NewMergeAddWithConvolution(); + // Modify Convolution2DAttributes so that after making convolution with // modified attributes we will have the same result as convolution // with old attributes and following add operation. diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc index 8374fd11f54623..ee41d5bd3ebda5 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc @@ -172,6 +172,64 @@ TEST(FuseAddAfterFullyConnectedTest, Smoke) { EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f})); } +TEST(MergeAddWithConvolutionTest, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 2); + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {1.0f, 2.0f}; + ElementwiseAttributes add_attr; + add_attr.param = add_tensor; + + Convolution2DAttributes conv_attr; + conv_attr.padding.prepended = HW(0, 0); + conv_attr.padding.appended = HW(0, 0); + conv_attr.strides = HW(1, 1); + conv_attr.dilations = HW(1, 1); + conv_attr.weights.shape = OHWI(2, 1, 2, 2); + conv_attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + conv_attr.bias.shape = Linear(2); + conv_attr.bias.data = {1.1f, 1.2f}; + + auto conv_node = graph.NewNode(); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv_node->operation.attributes = conv_attr; + auto add_node = graph.NewNode(); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = add_attr; + + ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok()); + + Value* output = nullptr; + ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 3, 2); + + Value* link1 = nullptr; + ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 2); + + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewMergeAddWithConvolution(); + ModelTransformer transformer(&graph); + transformer.Apply("merge_add_with_convolution", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); + EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); + + Convolution2DAttributes* conv_attr_new = + absl::any_cast( + &graph.nodes()[0]->operation.attributes); + + EXPECT_THAT(conv_attr_new->bias.data, + Pointwise(FloatNear(1e-6), {2.7f, 5.2f})); +} + } // namespace } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc index 74147a78fda973..8f044cf79e9ea0 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/model_transformations.cc @@ -63,6 +63,8 @@ bool ApplyGeneralTransformations(ModelTransformer* transformer) { NewMergeConvolutionWithMul().get()) && transformer->Apply("merge_convolution_with_add", NewMergeConvolutionWithAdd().get()) && + transformer->Apply("merge_add_with_convolution", + NewMergeAddWithConvolution().get()) && transformer->Apply("merge_mul_with_convolution", NewMergeMulWithConvolution().get()); } diff --git a/tensorflow/lite/delegates/gpu/metal/gpu_object.h b/tensorflow/lite/delegates/gpu/metal/gpu_object.h index 668e8eea0fa573..0c29f4b3881104 100644 --- a/tensorflow/lite/delegates/gpu/metal/gpu_object.h +++ b/tensorflow/lite/delegates/gpu/metal/gpu_object.h @@ -33,8 +33,8 @@ namespace gpu { namespace metal { struct GPUResourcesWithValue { - std::vector> ints; - std::vector> floats; + GenericGPUResourcesWithValue generic; + struct BufferParameter { id handle; uint64_t offset; @@ -44,6 +44,13 @@ struct GPUResourcesWithValue { std::vector>> image2d_arrays; std::vector>> images3d; std::vector>> image_buffers; + + void AddFloat(const std::string& name, float value) { + generic.AddFloat(name, value); + } + void AddInt(const std::string& name, int value) { + generic.AddInt(name, value); + } }; class GPUObject { diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.cc b/tensorflow/lite/delegates/gpu/metal/inference_context.cc index b1638524fea155..6f75bd45eea663 100644 --- a/tensorflow/lite/delegates/gpu/metal/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.cc @@ -477,11 +477,12 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) { const auto& t = tensors_descs_[usage.first]; const auto& shape = t.GetBHWDCShape(); const auto& descriptor = t; - const size_t element_size = SizeOf(descriptor.data_type); + const size_t element_size = SizeOf(descriptor.GetDataType()); size_t buffer_size; size_t row_bytes_alignment = [device->device() minimumLinearTextureAlignmentForPixelFormat:DataTypeToRGBAPixelFormat( - descriptor.data_type, + descriptor + .GetDataType(), false)]; if (descriptor.GetStorageType() == TensorStorageType::TEXTURE_2D) { min_common_alignment = @@ -575,7 +576,7 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) { TensorStorageType::SINGLE_TEXTURE_2D) { size_t row_bytes_alignment = [device->device() minimumLinearTextureAlignmentForPixelFormat: - DataTypeToRGBAPixelFormat(tensor_dummy.data_type, false)]; + DataTypeToRGBAPixelFormat(tensor_dummy.GetDataType(), false)]; RETURN_IF_ERROR(CreateSharedImage2DBufferTensor( base_buffer, tensor_dummy.GetBHWDCShape(), tensor_dummy, row_bytes_alignment, &shared_buffer_tensors_[tensor_index], diff --git a/tensorflow/lite/delegates/gpu/metal/linear_storage.cc b/tensorflow/lite/delegates/gpu/metal/linear_storage.cc index b109d4e51ac1bd..5f781ae39a4222 100644 --- a/tensorflow/lite/delegates/gpu/metal/linear_storage.cc +++ b/tensorflow/lite/delegates/gpu/metal/linear_storage.cc @@ -64,7 +64,7 @@ absl::Status LinearStorage::GetGPUResources( "Expected TensorLinearDescriptor on input."); } - resources->ints.push_back({"length", depth_}); + resources->AddInt("length", depth_); if (storage_type_ == LinearStorageType::BUFFER) { resources->buffers.push_back({"buffer", {buffer_, 0}}); diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc index bd4b201c7b58b6..68eb34a780177a 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc @@ -618,10 +618,10 @@ std::string MetalArguments::GetListOfArgs(int buffer_offset, absl::Status MetalArguments::SetGPUResources( const std::string& name, const GPUResourcesWithValue& resources) { - for (const auto& r : resources.ints) { + for (const auto& r : resources.generic.ints) { RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second)); } - for (const auto& r : resources.floats) { + for (const auto& r : resources.generic.floats) { RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second)); } for (const auto& r : resources.buffers) { diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc index 718608295c6de5..9006d177b39063 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc +++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.cc @@ -33,11 +33,11 @@ absl::Status CreateTextureBuffer(id buffer, uint64_t buffer_offset, if (@available(macOS 10.14, iOS 12.0, tvOS 12.0, *)) { const int slices = DivideRoundUp(shape.c, 4); const size_t flt4_count = shape.b * shape.w * shape.h * shape.d * slices; - const size_t data_size = flt4_count * 4 * SizeOf(descriptor.data_type); + const size_t data_size = flt4_count * 4 * SizeOf(descriptor.GetDataType()); MTLTextureDescriptor* texture_desc = [[MTLTextureDescriptor alloc] init]; texture_desc.width = flt4_count; texture_desc.pixelFormat = - DataTypeToRGBAPixelFormat(descriptor.data_type, false); + DataTypeToRGBAPixelFormat(descriptor.GetDataType(), false); texture_desc.textureType = MTLTextureTypeTextureBuffer; texture_desc.usage = MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite; texture_desc.storageMode = buffer.storageMode; @@ -64,7 +64,7 @@ absl::Status AllocateTensorMemory(id device, const BHWDC& shape, case TensorStorageType::BUFFER: case TensorStorageType::IMAGE_BUFFER: { const size_t data_size = shape.b * shape.w * shape.h * shape.d * slices * - 4 * SizeOf(descriptor.data_type); + 4 * SizeOf(descriptor.GetDataType()); if (data_ptr) { *buffer = [device newBufferWithBytes:data_ptr length:data_size @@ -85,7 +85,8 @@ absl::Status AllocateTensorMemory(id device, const BHWDC& shape, case TensorStorageType::TEXTURE_2D: { MTLTextureDescriptor* texture_desc = [MTLTextureDescriptor texture2DDescriptorWithPixelFormat:DataTypeToRGBAPixelFormat( - descriptor.data_type, false) + descriptor.GetDataType(), + false) width:shape.w * shape.b * shape.d height:shape.h * slices mipmapped:NO]; @@ -109,7 +110,7 @@ absl::Status AllocateTensorMemory(id device, const BHWDC& shape, texture_desc.height = shape.h; texture_desc.depth = slices * shape.d; texture_desc.pixelFormat = - DataTypeToRGBAPixelFormat(descriptor.data_type, false); + DataTypeToRGBAPixelFormat(descriptor.GetDataType(), false); texture_desc.textureType = MTLTextureType3D; texture_desc.usage = MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite; @@ -130,7 +131,7 @@ absl::Status AllocateTensorMemory(id device, const BHWDC& shape, texture_desc.height = shape.h; texture_desc.arrayLength = slices * shape.d; texture_desc.pixelFormat = - DataTypeToRGBAPixelFormat(descriptor.data_type, false); + DataTypeToRGBAPixelFormat(descriptor.GetDataType(), false); texture_desc.textureType = MTLTextureType2DArray; texture_desc.usage = MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite; @@ -258,32 +259,14 @@ absl::Status MetalSpatialTensor::GetGPUResources( if (!tensor_desc) { return absl::InvalidArgumentError("Expected TensorDescriptor on input."); } - resources->ints.push_back( - {"slice_stride", tensor_desc->GetSliceStrideSize(shape_)}); - if (descriptor_.HasAxis(Axis::WIDTH)) { - resources->ints.push_back({"width", tensor_desc->GetWidthSize(shape_)}); - } - if (descriptor_.HasAxis(Axis::HEIGHT)) { - resources->ints.push_back({"height", Height()}); - } - if (descriptor_.HasAxis(Axis::CHANNELS)) { - resources->ints.push_back({"slices", Slices()}); - resources->ints.push_back({"channels", Channels()}); - } - if (descriptor_.HasAxis(Axis::BATCH)) { - resources->ints.push_back({"batch", Batch()}); - } - if (descriptor_.HasAxis(Axis::DEPTH)) { - resources->ints.push_back({"depth", Depth()}); - } + tensor_desc->GetGpuResources(shape_, &resources->generic); if (descriptor_.GetStorageType() == TensorStorageType::BUFFER) { resources->buffers.push_back({"buffer", {memory_, buffer_offset_}}); } else if (descriptor_.GetStorageType() == TensorStorageType::TEXTURE_2D) { if (obj_ptr->GetAccess() == AccessType::WRITE && tensor_desc->GetUseBufferForWriteOnlyTexture2d()) { - resources->ints.push_back( - {"aligned_texture_width", aligned_texture_width_}); + resources->AddInt("aligned_texture_width", aligned_texture_width_); resources->buffers.push_back({"buffer", {memory_, buffer_offset_}}); } else { resources->images2d.push_back({"image2d", texture_mem_}); @@ -365,7 +348,7 @@ absl::Status MetalSpatialTensor::IsValid(const BHWDC& shape) const { } uint64_t MetalSpatialTensor::GetMemorySizeInBytes() const { - const int flt_size = SizeOf(descriptor_.data_type); + const int flt_size = SizeOf(descriptor_.GetDataType()); const int flt4_size = 4 * flt_size; switch (descriptor_.GetStorageType()) { case TensorStorageType::BUFFER: @@ -566,10 +549,10 @@ absl::Status CreateSharedImage2DBufferTensor(id buffer, texture_desc.mipmapLevelCount = 1; texture_desc.sampleCount = 1; texture_desc.pixelFormat = - DataTypeToRGBAPixelFormat(descriptor.data_type, false); + DataTypeToRGBAPixelFormat(descriptor.GetDataType(), false); texture_desc.usage = MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite; texture_desc.storageMode = buffer.storageMode; - const size_t pixel_size = channels * SizeOf(descriptor.data_type); + const size_t pixel_size = channels * SizeOf(descriptor.GetDataType()); const size_t bytes_per_row = width * pixel_size; const size_t bytes_per_row_aligned = AlignByN(bytes_per_row, row_bytes_alignment); diff --git a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h index d3aca0c5b452b8..08fc449d95f96c 100644 --- a/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h +++ b/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h @@ -62,7 +62,7 @@ class MetalSpatialTensor : public GPUObject, public GpuSpatialTensor { int Batch() const override { return shape_.b; } TensorDescriptor GetDescriptor() const override { return descriptor_; } - DataType GetDataType() const { return descriptor_.data_type; } + DataType GetDataType() const { return descriptor_.GetDataType(); } TensorStorageType GetStorageType() const { return descriptor_.GetStorageType(); } @@ -188,7 +188,7 @@ template absl::Status MetalSpatialTensor::WriteDataBHWDC(id device, const T* in) { std::unique_ptr data_copy; data_copy.reset(new uint8_t[GetMemorySizeInBytes()]); - if (descriptor_.data_type == DataType::FLOAT16) { + if (descriptor_.GetDataType() == DataType::FLOAT16) { // rearrangement and conversion from float32 to float16 DataFromBHWDC(reinterpret_cast(in), shape_, descriptor_, reinterpret_cast(data_copy.get())); @@ -207,7 +207,7 @@ absl::Status MetalSpatialTensor::ReadDataBHWDC(id device, T* out) con RETURN_IF_ERROR(ReadData(device, data_copy.get())); - if (descriptor_.data_type == DataType::FLOAT16) { + if (descriptor_.GetDataType() == DataType::FLOAT16) { // rearrangement and conversion from float32 to float16 DataToBHWDC(reinterpret_cast(data_copy.get()), shape_, descriptor_, reinterpret_cast(out)); diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD index aeff8a9996d4d4..ab74006b2aee4d 100644 --- a/tensorflow/lite/delegates/nnapi/BUILD +++ b/tensorflow/lite/delegates/nnapi/BUILD @@ -21,12 +21,14 @@ cc_library( ], "//conditions:default": [ "nnapi_delegate.cc", + "nnapi_delegate_c_api.cc", "quant_lstm_sup.h", "quant_lstm_sup.cc", ], }), hdrs = [ "nnapi_delegate.h", + "nnapi_delegate_c_api.h", "nnapi_delegate_kernel.h", "nnapi_delegate_plugin.h", ], @@ -56,6 +58,7 @@ cc_library( name = "nnapi_delegate", hdrs = [ "nnapi_delegate.h", + "nnapi_delegate_c_api.h", "nnapi_delegate_kernel.h", "nnapi_delegate_plugin.h", ], @@ -70,6 +73,8 @@ cc_library( ], ) +exports_files(["nnapi_delegate_c_api.h"]) + cc_library( name = "nnapi_delegate_verbose_validation", srcs = select({ @@ -177,6 +182,31 @@ cc_test( ], ) +cc_test( + name = "nnapi_delegate_c_api_test", + size = "small", + srcs = [ + "nnapi_delegate_c_api_test.cc", + ], + tags = [ + "no_windows", + "tflite_not_portable_ios", + ], + visibility = ["//visibility:private"], + deps = [ + ":nnapi_delegate", + ":nnapi_delegate_mock_test", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:deprecated_backends", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/nnapi:nnapi_implementation", + "//tensorflow/lite/nnapi:nnapi_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "nnapi_delegate_errno_test", size = "small", diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index c71fa2a9c20b1d..101b3827bb940d 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -413,6 +413,7 @@ SelectOpTest/.+,29 -SliceOpTest/SliceOpTest/SliceInt64/.+ -SliceOpTest/SliceOpTest/SliceBool/.+ -SliceOpTest/SliceOpTest/SliceInt16/.+ +-SliceOpTest/SliceOpTest/SliceInt64StaticOutput/.* # Only constant tensors SliceOpTest/SliceOpTest/.+/0,29 diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.cc new file mode 100644 index 00000000000000..052194d23ddf21 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.cc @@ -0,0 +1,69 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h" + +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/nnapi/sl/public/NeuralNetworksSupportLibraryImpl.h" + +TfLiteDelegate* TfLiteNnapiDelegateCreate( + const TfLiteNnapiDelegateOptions* options) { + tflite::StatefulNnApiDelegate::StatefulNnApiDelegate::Options + internal_options; + internal_options.execution_preference = + static_cast( + options->execution_preference); + internal_options.accelerator_name = options->accelerator_name; + internal_options.cache_dir = options->cache_dir; + internal_options.model_token = options->model_token; + internal_options.disallow_nnapi_cpu = options->disallow_nnapi_cpu; + internal_options.max_number_delegated_partitions = + options->max_number_delegated_partitions; + internal_options.allow_fp16 = options->allow_fp16; + + tflite::StatefulNnApiDelegate* delegate = nullptr; + if (options->nnapi_support_library_handle) { + delegate = new tflite::StatefulNnApiDelegate( + static_cast( + options->nnapi_support_library_handle), + internal_options); + } else { + delegate = new tflite::StatefulNnApiDelegate(internal_options); + } + return delegate; +} + +TfLiteNnapiDelegateOptions TfLiteNnapiDelegateOptionsDefault() { + TfLiteNnapiDelegateOptions result = {}; + tflite::StatefulNnApiDelegate::Options options; + result.execution_preference = + static_cast( + options.execution_preference); + result.accelerator_name = options.accelerator_name; + result.cache_dir = options.cache_dir; + result.model_token = options.model_token; + result.disallow_nnapi_cpu = options.disallow_nnapi_cpu; + result.max_number_delegated_partitions = + options.max_number_delegated_partitions; + result.allow_fp16 = options.allow_fp16; + result.nnapi_support_library_handle = nullptr; + return result; +} + +void TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate) { + if (delegate == nullptr) return; + delete static_cast(delegate); +} diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h new file mode 100644 index 00000000000000..bc178358394b54 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h @@ -0,0 +1,104 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_C_API_H_ +#define TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_C_API_H_ + +#include "tensorflow/lite/c/common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Use TfLiteNnapiDelegateOptionsDefault() for Default options. +// WARNING: This is an experimental API and subject to change. +struct TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions { + // Preferred Power/perf trade-off. For more details please see + // ANeuralNetworksCompilation_setPreference documentation in : + // https://developer.android.com/ndk/reference/group/neural-networks.html + enum ExecutionPreference { + kUndefined = -1, + kLowPower = 0, + kFastSingleAnswer = 1, + kSustainedSpeed = 2, + }; + + // Preferred Power/perf trade-off. Default to kUndefined. + ExecutionPreference execution_preference; + + // Selected NNAPI accelerator with nul-terminated name. + // Default to nullptr, which implies the NNAPI default behavior: NNAPI + // runtime is allowed to use all available accelerators. If the selected + // accelerator cannot be found, NNAPI will not be used. + // It is the caller's responsibility to ensure the string is valid for the + // duration of the Options object lifetime. + const char* accelerator_name; + + // The nul-terminated cache dir for NNAPI model. + // Default to nullptr, which implies the NNAPI will not try caching the + // compilation. + const char* cache_dir; + + // The unique nul-terminated token string for NNAPI model. + // Default to nullptr, which implies the NNAPI will not try caching the + // compilation. It is the caller's responsibility to ensure there is no + // clash of the tokens. + // NOTE: when using compilation caching, it is not recommended to use the + // same delegate instance for multiple models. + const char* model_token; + + // Whether to disallow NNAPI CPU usage. Default to 1 (true). Only effective on + // Android 10 and above. The NNAPI CPU typically performs less well than + // built-in TfLite kernels, but allowing CPU allows partial acceleration of + // models. If this is set to true, NNAPI is only used if the whole model is + // accelerated. + int disallow_nnapi_cpu; + + // Whether to allow fp32 compuation to be run in fp16. Default to 0 (false). + int allow_fp16; + + // Specifies the max number of partitions to delegate. A value <= 0 means + // no limit. Default to 3. + // If the delegation of the full set of supported nodes would generate a + // number of partition greater than this parameter, only + // of them will be actually accelerated. + // The selection is currently done sorting partitions in decreasing order + // of number of nodes and selecting them until the limit is reached. + int max_number_delegated_partitions; + + // The pointer to NNAPI support lib implementation. Default to nullptr. + // If specified, NNAPI delegate will use the support lib instead of NNAPI in + // Android OS. + void* nnapi_support_library_handle; +}; + +// Returns a delegate that uses NNAPI for ops execution. +// Must outlive the interpreter. +// WARNING: This is an experimental API and subject to change. +TfLiteDelegate* TFL_CAPI_EXPORT +TfLiteNnapiDelegateCreate(const TfLiteNnapiDelegateOptions* options); + +// Returns TfLiteNnapiDelegateOptions populated with default values. +// WARNING: This is an experimental API and subject to change. +TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions TfLiteNnapiDelegateOptionsDefault(); + +// Does any needed cleanup and deletes 'delegate'. +// WARNING: This is an experimental API and subject to change. +void TFL_CAPI_EXPORT TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_C_API_H_ diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc new file mode 100644 index 00000000000000..2376d78ec35372 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h" + +#include + +#include +#include + +#include +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/test_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class SingleOpModelWithNnapiDelegateCApi : public SingleOpModel { + public: + SingleOpModelWithNnapiDelegateCApi() { + options_ = TfLiteNnapiDelegateOptionsDefault(); + options_.disallow_nnapi_cpu = false; + } + + explicit SingleOpModelWithNnapiDelegateCApi( + const TfLiteNnapiDelegateOptions& options) { + options_ = options; + options_.disallow_nnapi_cpu = false; + } + + ~SingleOpModelWithNnapiDelegateCApi() { + if (nnapi_delegate_) { + TfLiteNnapiDelegateDelete(nnapi_delegate_); + } + nnapi_delegate_ = nullptr; + } + + protected: + void BuildInterpreterWithNNAPI(std::vector> input_shapes) { + if (nnapi_delegate_) { + TfLiteNnapiDelegateDelete(nnapi_delegate_); + } + nnapi_delegate_ = TfLiteNnapiDelegateCreate(&options_); + SetDelegate(nnapi_delegate_); + BuildInterpreter(input_shapes, /*num_threads=*/-1, options_.allow_fp16, + /*apply_delegate=*/true, /*allocate_and_delegate=*/true); + } + + private: + TfLiteNnapiDelegateOptions options_; + TfLiteDelegate* nnapi_delegate_ = nullptr; +}; + +class FloatAddOpModel : public SingleOpModelWithNnapiDelegateCApi { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + Init(input1, input2, output, activation_type); + } + + FloatAddOpModel(const TfLiteNnapiDelegateOptions& options, + const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) + : SingleOpModelWithNnapiDelegateCApi(options) { + Init(input1, input2, output, activation_type); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; + + private: + // Performs initialization logic shared across all constructors. + void Init(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)}); + } +}; + +// Basic test for the NNAPI delegate C APIs. +TEST(NNAPIDelegate, C_API) { + TfLiteNnapiDelegateOptions options = TfLiteNnapiDelegateOptionsDefault(); + options.execution_preference = + TfLiteNnapiDelegateOptions::ExecutionPreference::kLowPower; + + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Basic test for the NNAPI delegate C API with accelerator_name specified. +TEST(NNAPIDelegate, C_API_WithAcceleratorName) { + TfLiteNnapiDelegateOptions options = TfLiteNnapiDelegateOptionsDefault(); + options.execution_preference = + TfLiteNnapiDelegateOptions::ExecutionPreference::kLowPower; + options.accelerator_name = "nnapi-reference"; + + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Basic test for the NNAPI delegate C API with compilation caching enabled. +TEST(NNAPIDelegate, C_API_WithCompilationCaching) { + TfLiteNnapiDelegateOptions options = TfLiteNnapiDelegateOptionsDefault(); + options.execution_preference = + TfLiteNnapiDelegateOptions::ExecutionPreference::kLowPower; + options.cache_dir = "/data/local/tmp"; + options.model_token = "NNAPIDelegate.C_API_WithCompilationCaching"; + + // 1st run + { + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); + } + // 2nd run + { + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-1.0, 0.1, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.2, 0.2, 0.4, 0.2}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-0.8, 0.3, 1.1, 1.0})); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index c9740197a3a2e9..aee8af09ecf632 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -1987,9 +1987,6 @@ cc_test( cc_test( name = "weights_cache_test", srcs = ["weights_cache_test.cc"], - data = [ - "//tensorflow/lite:testdata/conv_huge_im2col.bin", - ], deps = [ ":conv_2d_tester", ":test_main", diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index d6e0b8edd7efb2..f8cf096d8710d2 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -454,9 +454,11 @@ cc_library( deps = [ ":call", ":decode_jpeg", + ":model_loader", ":status_codes", "//tensorflow/lite:framework", "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", "//tensorflow/lite/experimental/acceleration/configuration:delegate_registry", @@ -471,15 +473,19 @@ cc_library( hdrs = ["validator_runner.h"], deps = [ ":fb_storage", + ":model_loader", ":runner", ":status_codes", ":validator", + "@flatbuffers", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/nnapi/sl:nnapi_support_library", + # For NNAPI support library, the headears and source files are defined + # as two separate targets. We need to include both targets for NNAPI to + # be invoked. + "//tensorflow/lite/nnapi/sl:nnapi_support_library", # buildcleaner: keep "//tensorflow/lite/nnapi/sl:nnapi_support_library_headers", - "@flatbuffers", ], ) @@ -493,17 +499,20 @@ cc_library( srcs = ["validator_runner_entrypoint.cc"], deps = [ ":fb_storage", - ":runner", + ":model_loader", ":set_big_core_affinity_h", ":status_codes", ":validator", ":validator_runner", + "@com_google_absl//absl/strings", + "@flatbuffers", "//tensorflow/lite/core/api", "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", - "//tensorflow/lite/nnapi/sl:nnapi_support_library", + # For NNAPI support library, the headears and source files are defined + # as two separate targets. We need to include both targets for NNAPI to + # be invoked. + "//tensorflow/lite/nnapi/sl:nnapi_support_library", # buildcleaner: keep "//tensorflow/lite/nnapi/sl:nnapi_support_library_headers", - "@com_google_absl//absl/strings", - "@flatbuffers", ], ) @@ -709,6 +718,32 @@ cc_test( ], ) +cc_library( + name = "model_loader", + srcs = ["model_loader.cc"], + hdrs = ["model_loader.h"], + deps = [ + ":status_codes", + "//tensorflow/lite:allocation", + "//tensorflow/lite:model_builder", + "//tensorflow/lite:stderr_reporter", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "model_loader_test", + srcs = ["model_loader_test.cc"], + deps = [ + ":embedded_mobilenet_model", + ":mini_benchmark_test_helper", + ":model_loader", + ":status_codes", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + # # Test targets for separate process. # Unit tests using cc_test and turned into Android tests with tflite_portable_test_suite(). @@ -770,11 +805,13 @@ cc_binary( linkshared = True, deps = [ ":fb_storage", + ":model_loader", ":runner", ":status_codes", ":set_big_core_affinity", ":validator", ":validator_runner", + "@com_google_absl//absl/strings", "@flatbuffers", "//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin", "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", @@ -844,12 +881,16 @@ cc_test( ":embedded_mobilenet_validation_model", ":embedded_mobilenet_model", ":mini_benchmark_test_helper", + ":model_loader", ":status_codes", ":validator", "@com_google_googletest//:gtest_main", "@flatbuffers", + "//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto", "//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/experimental/acceleration/configuration:flatbuffer_to_proto", "//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin", + "//tensorflow/lite/experimental/acceleration/configuration:proto_to_flatbuffer", ] + select({ clean_dep("//tensorflow:android"): [ "//tensorflow/lite/experimental/acceleration/configuration:gpu_plugin", diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.cc new file mode 100644 index 00000000000000..f1c7d1cdf52488 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.cc @@ -0,0 +1,77 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" + +#include +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/stderr_reporter.h" + +namespace tflite { +namespace acceleration { + +std::unique_ptr ModelLoader::CreateFromFdOrPath( + absl::string_view fd_or_path) { + if (!absl::StartsWith(fd_or_path, "fd:")) { + return std::make_unique(fd_or_path); + } + + std::vector parts = absl::StrSplit(fd_or_path, ':'); + int model_fd; + size_t model_offset, model_size; + if (parts.size() != 4 || !absl::SimpleAtoi(parts[1], &model_fd) || + !absl::SimpleAtoi(parts[2], &model_offset) || + !absl::SimpleAtoi(parts[3], &model_size)) { + return nullptr; + } + return std::make_unique(model_fd, model_offset, model_size); +} + +MinibenchmarkStatus ModelLoader::Init() { + if (model_) { + // Already done. + return kMinibenchmarkSuccess; + } + if (model_path_.empty() && model_fd_ <= 0) { + return kMinibenchmarkPreconditionNotMet; + } + if (!model_path_.empty()) { + model_ = FlatBufferModel::VerifyAndBuildFromFile(model_path_.c_str()); + } else if (MMAPAllocation::IsSupported()) { + auto allocation = std::make_unique( + model_fd_, model_offset_, model_size_, tflite::DefaultErrorReporter()); + if (!allocation->valid()) { + return kMinibenchmarkModelReadFailed; + } + model_ = + FlatBufferModel::VerifyAndBuildFromAllocation(std::move(allocation)); + } else { + return kMinibenchmarkUnsupportedPlatform; + } + if (!model_) { + return kMinibenchmarkModelBuildFailed; + } + return kMinibenchmarkSuccess; +} + +} // namespace acceleration +} // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h new file mode 100644 index 00000000000000..ee2dfb0597ecf9 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h @@ -0,0 +1,78 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_ + +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" +#include "tensorflow/lite/model_builder.h" + +namespace tflite { +namespace acceleration { + +// Class to load the Model. +class ModelLoader { + public: + // Create the model loader from a model_path or a file descriptor. File + // descriptor path must be in the format of + // "fd:%model_fd%:%model_offset%:%model_size%". Return nullptr if the path + // starts with "fd:" but cannot be parsed with the given format. + static std::unique_ptr CreateFromFdOrPath( + absl::string_view fd_or_path); + + // Create the model loader from model_path. + explicit ModelLoader(absl::string_view model_path) + : model_path_(model_path) {} + +#ifndef _WIN32 + // Create the model loader from file descriptor. The model_fd only has to be + // valid for the duration of the constructor (it's dup'ed inside). This + // constructor is not available on Windows. + ModelLoader(int model_fd, size_t model_offset, size_t model_size) + : model_fd_(dup(model_fd)), + model_offset_(model_offset), + model_size_(model_size) {} +#endif // !_WIN32 + + ~ModelLoader() { + if (model_fd_ >= 0) { + close(model_fd_); + } + } + + // Return whether the model is loaded successfully. + MinibenchmarkStatus Init(); + + const FlatBufferModel* GetModel() const { return model_.get(); } + + private: + const std::string model_path_; + const int model_fd_ = -1; + const size_t model_offset_ = 0; + const size_t model_size_ = 0; + std::unique_ptr model_; +}; + +} // namespace acceleration + +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_ diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader_test.cc new file mode 100644 index 00000000000000..cdfbdf1cea9977 --- /dev/null +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" + +#include +#include + +#include +#include + +#include +#include +#include "absl/strings/str_format.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_model.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_test_helper.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" + +namespace tflite { +namespace acceleration { +namespace { + +class ModelLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + model_path_ = MiniBenchmarkTestHelper::DumpToTempFile( + "mobilenet_quant.tflite", + g_tflite_acceleration_embedded_mobilenet_model, + g_tflite_acceleration_embedded_mobilenet_model_len); + } + std::string model_path_; +}; + +TEST_F(ModelLoaderTest, CreateFromModelPath) { + std::unique_ptr model_loader = + ModelLoader::CreateFromFdOrPath(model_path_); + ASSERT_NE(model_loader, nullptr); + EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess); +} + +TEST_F(ModelLoaderTest, CreateFromFdPath) { + int fd = open(model_path_.c_str(), O_RDONLY); + ASSERT_GE(fd, 0); + struct stat stat_buf = {0}; + ASSERT_EQ(fstat(fd, &stat_buf), 0); + auto model_loader = std::make_unique(fd, 0, stat_buf.st_size); + close(fd); + + ASSERT_NE(model_loader, nullptr); + EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess); +} + +TEST_F(ModelLoaderTest, CreateFromFdOrModelPath) { + int fd = open(model_path_.c_str(), O_RDONLY); + ASSERT_GE(fd, 0); + struct stat stat_buf = {0}; + ASSERT_EQ(fstat(fd, &stat_buf), 0); + std::string path = absl::StrFormat("fd:%d:%zu:%zu", fd, 0, stat_buf.st_size); + auto model_loader = ModelLoader::CreateFromFdOrPath(path); + close(fd); + + ASSERT_NE(model_loader, nullptr); + EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess); +} + +TEST_F(ModelLoaderTest, InvalidFdPath) { + int fd = open(model_path_.c_str(), O_RDONLY); + ASSERT_GE(fd, 0); + struct stat stat_buf = {0}; + ASSERT_EQ(fstat(fd, &stat_buf), 0); + std::string path = absl::StrFormat("fd:%d:%zu", fd, 0); + auto model_loader = ModelLoader::CreateFromFdOrPath(path); + close(fd); + + EXPECT_EQ(model_loader, nullptr); +} + +} // namespace +} // namespace acceleration +} // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc index 3ba83641266805..fe587214da15df 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc @@ -105,8 +105,9 @@ class LocalizerValidationRegressionTest : public ::testing::Test { ASSERT_GE(fd, 0); struct stat stat_buf = {0}; ASSERT_EQ(fstat(fd, &stat_buf), 0); - auto validator = - std::make_unique(fd, 0, stat_buf.st_size, settings); + auto validator = std::make_unique( + std::make_unique(fd, /*offset=*/0, stat_buf.st_size), + settings); close(fd); Validator::Results results; @@ -130,10 +131,10 @@ class LocalizerValidationRegressionTest : public ::testing::Test { } std::cerr << "\n"; } - std::cerr << "Compilation time us " << results.compilation_time_us + std::cerr << "Delegate prep time us " << results.delegate_prep_time_us << std::endl; - RecordProperty(accelerator_name + " Compilation time us", - results.compilation_time_us); + RecordProperty(accelerator_name + " Delegate prep time us", + results.delegate_prep_time_us); std::cerr << "Execution time us"; int test_case = 0; int64_t total_execution_time_us = 0; diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h index ab732df1e17b52..c41808112a81c6 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h @@ -76,6 +76,7 @@ enum MinibenchmarkStatus { // Validator status codes. kMinibenchmarkDelegateNotSupported = 1000, kMinibenchmarkDelegatePluginNotFound = 1001, + kMinibenchmarkDelegateCreateFailed = 1013, kMinibenchmarkModelTooLarge = 1002, // Safety limit currently set at 100M. kMinibenchmarkSeekToModelOffsetFailed = 1003, kMinibenchmarkModelReadFailed = 1004, diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc index 402be9cf8e6ea9..c673f546fea84d 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.cc @@ -14,22 +14,33 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h" -#include +#include +#include #include +#include -#include +#include +#include +#include #include #include +#include #include "absl/container/flat_hash_set.h" #include "tensorflow/lite/core/api/profiler.h" +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/call_register.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" +#include "tensorflow/lite/mutable_op_resolver.h" #ifndef TEMP_FAILURE_RETRY #ifdef __ANDROID__ @@ -41,35 +52,6 @@ limitations under the License. namespace tflite { namespace acceleration { - -Validator::Validator(const std::string& model_path, - const ComputeSettings* compute_settings) - : model_path_(model_path), - compute_settings_(compute_settings), - delegate_(nullptr, [](TfLiteDelegate*) {}) {} - -Validator::Validator(int model_fd, size_t model_offset, size_t model_size, - const ComputeSettings* compute_settings) - : -#ifndef _WIN32 - model_fd_(dup(model_fd)), -#else // _WIN32 - model_fd_(-1), -#endif // !_WIN32 - model_offset_(model_offset), - model_size_(model_size), - compute_settings_(compute_settings), - delegate_(nullptr, [](TfLiteDelegate*) {}) { -} - -Validator::~Validator() { -#ifndef _WIN32 - if (model_fd_ >= 0) { - close(model_fd_); - } -#endif // !_WIN32 -} - namespace { std::unique_ptr LoadDelegatePlugin( const std::string& name, const tflite::TFLiteSettings& tflite_settings) { @@ -120,7 +102,6 @@ class ValidatorProfiler : public ::tflite::Profiler { } events_[event_handle - 1].end_time_us = ElapsedTimeMicros(); } - uint32_t handle_ = 0; private: std::vector events_; @@ -128,54 +109,21 @@ class ValidatorProfiler : public ::tflite::Profiler { } // namespace -MinibenchmarkStatus Validator::CheckModel(bool load_only) { - if (validation_entrypoint_) { - // Already done. - return kMinibenchmarkSuccess; - } - if (model_path_.empty() && model_fd_ <= 0) { +MinibenchmarkStatus Validator::CheckGoldenOutput() { + if (!interpreter_ || !model_loader_->GetModel()) { return kMinibenchmarkPreconditionNotMet; } - if (!model_path_.empty()) { - model_ = FlatBufferModel::VerifyAndBuildFromFile(model_path_.c_str()); - } else if (MMAPAllocation::IsSupported()) { - auto allocation = std::make_unique( - model_fd_, model_offset_, model_size_, tflite::DefaultErrorReporter()); - if (!allocation->valid()) { - return kMinibenchmarkModelReadFailed; - } - model_ = - FlatBufferModel::VerifyAndBuildFromAllocation(std::move(allocation)); - } else { - return kMinibenchmarkUnsupportedPlatform; - } - if (!model_) { - return kMinibenchmarkModelBuildFailed; - } - if (load_only) { + if (validation_entrypoint_) { + // Already done. return kMinibenchmarkSuccess; } - - if (compute_settings_->tflite_settings() && - compute_settings_->tflite_settings()->disable_default_delegates()) { - resolver_ = - ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates(); - } - resolver_.AddCustom("validation/call", - ::tflite::acceleration::ops::Register_CALL(), 1); - resolver_.AddCustom( - "validation/decode_jpeg", - ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); - - tflite::InterpreterBuilder(*model_, resolver_)(&interpreter_); - if (!interpreter_) { - return kMinibenchmarkInterpreterBuilderFailed; - } main_model_ = interpreter_->subgraph(0); + int validation_entrypoint_index = 0; for (int i = 0; i < interpreter_->subgraphs_size(); i++) { Subgraph* subgraph = interpreter_->subgraph(i); if (subgraph->GetName() == "VALIDATION:main") { validation_entrypoint_ = subgraph; + validation_entrypoint_index = i; break; } } @@ -190,38 +138,54 @@ MinibenchmarkStatus Validator::CheckModel(bool load_only) { return kMinibenchmarkValidationSubgraphHasTooFewOutputs; } + if (validation_entrypoint_->AllocateTensors() != kTfLiteOk) { + return kMinibenchmarkAllocateTensorsFailed; + } + // Check if we have validation data embedded or need to run CPU for it. If - // the data is embedded, there is already an allocation for it from the model. + // the data is embedded, there is already an allocation for it from the model, + // and we can skip running it on CPU. TfLiteTensor* first_input_tensor = validation_entrypoint_->tensor(validation_entrypoint_->inputs()[0]); - if (!first_input_tensor->allocation) { - // Run on CPU. - if (validation_entrypoint_->AllocateTensors() != kTfLiteOk) { - return kMinibenchmarkAllocateTensorsFailed; - } - // Set initial golden outputs to 0 to avoid accessing uninitialized memory. - // Last input is jpeg, skip. - for (int i = 0; i < validation_entrypoint_->inputs().size() - 1; i++) { - TfLiteTensor* input_tensor = - validation_entrypoint_->tensor(validation_entrypoint_->inputs()[i]); - memset(input_tensor->data.raw, 0, input_tensor->bytes); - } - TfLiteStatus status = validation_entrypoint_->Invoke(); - if (status != kTfLiteOk) { - return kMinibenchmarkInvokeFailed; - } - // Copy CPU outputs as golden. Last input is jpeg image data, skip. - for (int i = 0; i < validation_entrypoint_->inputs().size() - 1; i++) { - TfLiteTensor* input_tensor = - validation_entrypoint_->tensor(validation_entrypoint_->inputs()[i]); - TfLiteTensor* golden_tensor = - validation_entrypoint_->tensor(validation_entrypoint_->outputs()[i]); - if (input_tensor->bytes != golden_tensor->bytes) { - return kMinibenchmarkValidationSubgraphInputsDontMatchOutputs; - } - memcpy(input_tensor->data.raw, golden_tensor->data.raw, - input_tensor->bytes); + if (first_input_tensor->allocation) { + return kMinibenchmarkSuccess; + } + + // Create the interpreter to run on CPU. + tflite::InterpreterBuilder(*model_loader_->GetModel(), + *resolver_)(&golden_interpreter_); + if (!golden_interpreter_) { + return kMinibenchmarkInterpreterBuilderFailed; + } + Subgraph* golden_validation_entrypoint = + golden_interpreter_->subgraph(validation_entrypoint_index); + + // Run on CPU. + if (golden_validation_entrypoint->AllocateTensors() != kTfLiteOk) { + return kMinibenchmarkAllocateTensorsFailed; + } + // Set initial golden outputs to 0 to avoid accessing uninitialized memory. + // Last input is jpeg, skip. + for (int i = 0; i < golden_validation_entrypoint->inputs().size() - 1; i++) { + TfLiteTensor* input_tensor = golden_validation_entrypoint->tensor( + golden_validation_entrypoint->inputs()[i]); + memset(input_tensor->data.raw, 0, input_tensor->bytes); + } + + if (golden_validation_entrypoint->Invoke() != kTfLiteOk) { + return kMinibenchmarkInvokeFailed; + } + // Copy CPU outputs as golden. Last input is jpeg image data, skip. + for (int i = 0; i < validation_entrypoint_->inputs().size() - 1; i++) { + TfLiteTensor* input_tensor = + validation_entrypoint_->tensor(validation_entrypoint_->inputs()[i]); + TfLiteTensor* golden_tensor = golden_validation_entrypoint->tensor( + golden_validation_entrypoint->outputs()[i]); + if (input_tensor->bytes != golden_tensor->bytes) { + return kMinibenchmarkValidationSubgraphInputsDontMatchOutputs; } + memcpy(input_tensor->data.raw, golden_tensor->data.raw, + input_tensor->bytes); } return kMinibenchmarkSuccess; @@ -232,68 +196,85 @@ MinibenchmarkStatus Validator::LoadDelegate() { return kMinibenchmarkPreconditionNotMet; } + // Create delegate plugin and delegate. Delegate which_delegate = Delegate_NONE; if (compute_settings_->tflite_settings()) { which_delegate = compute_settings_->tflite_settings()->delegate(); } - if (which_delegate == Delegate_NNAPI) { - delegate_plugin_ = - LoadDelegatePlugin("Nnapi", *compute_settings_->tflite_settings()); - } else if (which_delegate == Delegate_GPU) { - delegate_plugin_ = - LoadDelegatePlugin("Gpu", *compute_settings_->tflite_settings()); - } else if (which_delegate == Delegate_XNNPACK) { - delegate_plugin_ = - LoadDelegatePlugin("XNNPack", *compute_settings_->tflite_settings()); - } else if (which_delegate == Delegate_NONE) { - return kMinibenchmarkSuccess; - } else { - return kMinibenchmarkDelegateNotSupported; + std::string delegate_name; + switch (which_delegate) { + case Delegate_NONE: + // Skip creating delegate if running on CPU. + return kMinibenchmarkSuccess; + case Delegate_NNAPI: + delegate_name = "Nnapi"; + break; + case Delegate_GPU: + delegate_name = "Gpu"; + break; + case Delegate_XNNPACK: + delegate_name = "XNNPack"; + break; + default: + return kMinibenchmarkDelegateNotSupported; } - if (!delegate_plugin_) { + + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Running mini-benchmark on %s", + delegate_name.c_str()); + if (!(delegate_plugin_ = LoadDelegatePlugin( + delegate_name, *compute_settings_->tflite_settings()))) { return kMinibenchmarkDelegatePluginNotFound; } - delegate_ = delegate_plugin_->Create(); - + if (!(delegate_ = delegate_plugin_->Create())) { + return kMinibenchmarkDelegateCreateFailed; + } return kMinibenchmarkSuccess; } -MinibenchmarkStatus Validator::ApplyComputeSettings( - int* delegate_error_out, int* delegated_kernels_out) { - if (!delegate_error_out) { +MinibenchmarkStatus Validator::CreateInterpreter(int* delegate_error_out, + int* delegated_kernels_out) { + if (!delegate_error_out || !delegated_kernels_out || + !model_loader_->GetModel()) { return kMinibenchmarkPreconditionNotMet; } *delegate_error_out = 0; - Delegate which_delegate = Delegate_NONE; - if (compute_settings_->tflite_settings()) { - which_delegate = compute_settings_->tflite_settings()->delegate(); + // Create interpreter with the delegate. + if (compute_settings_->tflite_settings() && + compute_settings_->tflite_settings()->disable_default_delegates()) { + resolver_ = std::make_unique< + ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>(); + } else { + resolver_ = std::make_unique<::tflite::ops::builtin::BuiltinOpResolver>(); } - std::string delegate; - if (which_delegate == Delegate_NONE) { - delegate = "CPU"; - } else if (which_delegate == Delegate_GPU) { - delegate = "GPU"; - } else if (which_delegate == Delegate_NNAPI) { - delegate = "NNAPI"; - } else if (which_delegate == Delegate_XNNPACK) { - delegate = "XNNPACK"; + resolver_->AddCustom("validation/call", + ::tflite::acceleration::ops::Register_CALL(), 1); + resolver_->AddCustom( + "validation/decode_jpeg", + ::tflite::acceleration::decode_jpeg_kernel::Register_DECODE_JPEG(), 1); + + tflite::InterpreterBuilder builder(*model_loader_->GetModel(), *resolver_); + // Add delegate if not running on CPU. + if (delegate_ != nullptr) { + builder.AddDelegate(delegate_.get()); } + TfLiteStatus status = builder(&interpreter_); + if (!interpreter_) { + // Return delegate error number if not null. + *delegate_error_out = + delegate_plugin_ ? delegate_plugin_->GetDelegateErrno(delegate_.get()) + : 0; - TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Running mini-benchmark on %s", - delegate.c_str()); - if (which_delegate == Delegate_NONE) { - return kMinibenchmarkSuccess; - } else if (!delegate_) { - return kMinibenchmarkPreconditionNotMet; + TFLITE_LOG_PROD(TFLITE_LOG_ERROR, + "Creating Interpreter failed with error code %d.", status); + return kMinibenchmarkInterpreterBuilderFailed; } - ValidatorProfiler profiler; - main_model_->SetProfiler(&profiler, 0); - TfLiteStatus status = interpreter_->ModifyGraphWithDelegate(delegate_.get()); // Check if the model is actually going to execute on the delegate. // For now just give a warning, with the exception of NNAPI SL mini benchmark. // Can consider changing to error in other contexts. // The logic is copy/pasted from benchmark_tflite_model.cc + // TODO(b/232085640): Replace this logic with Subgraph::IsFullyDelegated() + // after making that function public. absl::flat_hash_set checked_node_ids; int num_delegated_kernels = 0; for (int i = 0; i < interpreter_->execution_plan().size(); ++i) { @@ -317,46 +298,32 @@ MinibenchmarkStatus Validator::ApplyComputeSettings( num_delegated_kernels > 0 ? "partially" : "not"); } - main_model_->SetProfiler(nullptr, 0); - for (const auto& e : profiler.events()) { - if (e.tag == "ModifyGraphWithDelegate" && e.start_time_us != -1 && - e.end_time_us != -1) { - compilation_time_us_ = e.end_time_us - e.start_time_us; - TFLITE_LOG_PROD(TFLITE_LOG_INFO, " Compilation took %d us", - static_cast(compilation_time_us_)); - break; - } - } - if (status == kTfLiteOk) { - return kMinibenchmarkSuccess; - } else { - *delegate_error_out = delegate_plugin_->GetDelegateErrno(delegate_.get()); - return kMinibenchmarkModifyGraphWithDelegateFailed; - } + return kMinibenchmarkSuccess; } MinibenchmarkStatus Validator::RunValidation(Results* results_out) { if (!results_out) { return kMinibenchmarkPreconditionNotMet; } - // The lifetime of the delegate must be at least as long as the lifetime of - // any Interpreter. - MinibenchmarkStatus mb_status = LoadDelegate(); - if (mb_status != kMinibenchmarkSuccess) { - return mb_status; - } - mb_status = CheckModel(); - if (mb_status != kMinibenchmarkSuccess) { - return mb_status; - } - mb_status = ApplyComputeSettings(&results_out->delegate_error, - &results_out->delegated_kernels); - if (mb_status != kMinibenchmarkSuccess) { - return mb_status; + if (!model_loader_) { + return kMinibenchmarkModelReadFailed; } - if (validation_entrypoint_->AllocateTensors() != kTfLiteOk) { - return kMinibenchmarkAllocateTensorsFailed; + +#define MB_RETURN_IF_ERROR(s) \ + { \ + MinibenchmarkStatus c = (s); \ + if (c != kMinibenchmarkSuccess) return c; \ } + + MB_RETURN_IF_ERROR(model_loader_->Init()); + // The lifetime of the delegate must be at least as long as the lifetime of + // any Interpreter. + int64_t delegate_load_start_time_us = ElapsedTimeMicros(); + MB_RETURN_IF_ERROR(LoadDelegate()); + MB_RETURN_IF_ERROR(CreateInterpreter(&results_out->delegate_error, + &results_out->delegated_kernels)); + int64_t delegate_load_end_time_us = ElapsedTimeMicros(); + MB_RETURN_IF_ERROR(CheckGoldenOutput()); ValidatorProfiler profiler; main_model_->SetProfiler(&profiler, 0); TfLiteStatus status = validation_entrypoint_->Invoke(); @@ -392,7 +359,12 @@ MinibenchmarkStatus Validator::RunValidation(Results* results_out) { } TFLITE_LOG_PROD(TFLITE_LOG_INFO, " accuracy: %s", results_out->ok ? "ok" : "not ok"); - results_out->compilation_time_us = compilation_time_us_; + results_out->delegate_prep_time_us = + (delegate_load_end_time_us == -1 || delegate_load_start_time_us == -1) + ? -1 + : delegate_load_end_time_us - delegate_load_start_time_us; + TFLITE_LOG_PROD(TFLITE_LOG_INFO, " Delegate preparation took %d us", + static_cast(results_out->delegate_prep_time_us)); for (const auto& e : profiler.events()) { if (e.tag == "Invoke" && e.start_time_us != -1 && e.end_time_us != -1) { results_out->execution_time_us.push_back(e.end_time_us - e.start_time_us); @@ -400,6 +372,7 @@ MinibenchmarkStatus Validator::RunValidation(Results* results_out) { static_cast(e.end_time_us - e.start_time_us)); } } +#undef MB_RETURN_IF_ERROR return kMinibenchmarkSuccess; } diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h index e209bf864b322a..c6f7cff7c5af03 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h @@ -17,14 +17,19 @@ limitations under the License. #include #include +#include #include +#include +#include +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/mutable_op_resolver.h" namespace tflite { namespace acceleration { @@ -34,23 +39,14 @@ namespace acceleration { // // The API is split into multiple steps so that callers can construct detailed // telemetry from it. -// -// TODO(b/172541832): add wrapper for running in separate process and using a -// file to communicate results. class Validator { public: - // Construct Validator for given model path and compute settings. The + // Construct Validator for the given model and compute settings. The // compute_settings must be valid for the lifetime of the Validator instance. - Validator(const std::string& model_path, - const ComputeSettings* compute_settings); - // Construct Validator for given model file descriptor and compute settings. - // The model_fd only has to be valid for the duration of the constructor (it's - // dup'ed inside). The compute_settings must be valid for the lifetime of the - // Validator instance. - Validator(int model_fd, size_t model_offset, size_t model_size, - const ComputeSettings* compute_settings); - // Check that the model is valid for validation. - MinibenchmarkStatus CheckModel(bool load_only = false); + Validator(std::unique_ptr model_loader, + const ComputeSettings* compute_settings) + : model_loader_(std::move(model_loader)), + compute_settings_(compute_settings) {} // Results from validation. struct Results { @@ -58,9 +54,9 @@ class Validator { bool ok = false; // What are the metrics results, for telemetry. std::map> metrics; - // How long did compilation (ModifyGraphWithDelegate) take. -1 if running on - // CPU (or in rare cases when reading the system clock fails). - int64_t compilation_time_us; + // How long did loading the delegate and creating the interpreter take. -1 + // if failed. + int64_t delegate_prep_time_us = 0; // How long did execution (Invoke) take. (Empty in rare cases when reading // the system clock fails). std::vector execution_time_us; @@ -69,6 +65,7 @@ class Validator { // Number of delegated kernels int delegated_kernels = 0; }; + // Run the validation graph and return validation results. MinibenchmarkStatus RunValidation(Results* results_out); @@ -76,8 +73,6 @@ class Validator { static int64_t BootTimeMicros(); static int64_t WallTimeMicros(); - ~Validator(); - Validator(Validator&) = delete; Validator& operator=(Validator&) = delete; Validator(Validator&&) = delete; @@ -86,23 +81,31 @@ class Validator { private: // Load delegate plugin and create delegate. MinibenchmarkStatus LoadDelegate(); - // Apply the compute settings (typically applying a delegate to the - // interpreter). - MinibenchmarkStatus ApplyComputeSettings(int* delegate_error_out, - int* delegated_kernels_out); - - std::string model_path_; - int model_fd_ = -1; - size_t model_offset_, model_size_; + + // Create the interpreter with the delegate. Must be called after + // LoadDelegate(). + MinibenchmarkStatus CreateInterpreter(int* delegate_error_out, + int* delegated_kernels_out); + + // Check if the golden output exists. If not, run Model on CPU. + MinibenchmarkStatus CheckGoldenOutput(); + + std::unique_ptr model_loader_; const ComputeSettings* compute_settings_; + // Interpreter that runs on CPU. + std::unique_ptr golden_interpreter_; + // Interpreter that runs with delegate enabled, using the compute settings + // passed to the Validator constructor. std::unique_ptr interpreter_; - ::tflite::ops::builtin::BuiltinOpResolver resolver_; + // Op resolver used to create the interpreters. Depending on the + // compute_settings_, it may or may not include the default delegate. + std::unique_ptr<::tflite::MutableOpResolver> resolver_; std::unique_ptr model_; - ::tflite::delegates::TfLiteDelegatePtr delegate_; + ::tflite::delegates::TfLiteDelegatePtr delegate_ = + delegates::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); std::unique_ptr delegate_plugin_; Subgraph* validation_entrypoint_ = nullptr; Subgraph* main_model_ = nullptr; - int64_t compilation_time_us_ = -1; }; } // namespace acceleration diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.cc index d20cd60569a9a6..aa356379f332c9 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.cc @@ -16,6 +16,8 @@ limitations under the License. #include +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" + #ifndef _WIN32 #include #include @@ -51,7 +53,7 @@ ValidatorRunner::ValidatorRunner(const std::string& model_path, const NnApiSLDriverImplFL5* nnapi_sl, const std::string validation_function_name, ErrorReporter* error_reporter) - : model_path_(model_path), + : fd_or_model_path_(model_path), storage_path_(storage_path), data_directory_path_(data_directory_path), storage_(storage_path_, error_reporter), @@ -66,44 +68,32 @@ ValidatorRunner::ValidatorRunner(int model_fd, size_t model_offset, const NnApiSLDriverImplFL5* nnapi_sl, const std::string validation_function_name, ErrorReporter* error_reporter) - : -#ifndef _WIN32 - model_fd_(dup(model_fd)), -#else // _WIN32 - model_fd_(-1), -#endif // !_WIN32 - model_offset_(model_offset), - model_size_(model_size), - storage_path_(storage_path), + : storage_path_(storage_path), data_directory_path_(data_directory_path), storage_(storage_path_, error_reporter), validation_function_name_(validation_function_name), error_reporter_(error_reporter), nnapi_sl_(nnapi_sl) { + std::stringstream ss; + ss << "fd:" << model_fd << ":" << model_offset << ":" << model_size; + fd_or_model_path_ = ss.str(); } MinibenchmarkStatus ValidatorRunner::Init() { - flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(CreateComputeSettings(fbb, tflite::ExecutionPreference_ANY, - CreateTFLiteSettings(fbb))); - std::unique_ptr check_validator; - // We are not configuring the validator to use the NNAPI Support Library - // even if specified since we just want to check that the model can be loaded - // from disk and we are not interacting with NNAPI. - if (!model_path_.empty()) { - check_validator = std::make_unique( - model_path_, - flatbuffers::GetRoot(fbb.GetBufferPointer())); - } else { - check_validator = std::make_unique( - model_fd_, model_offset_, model_size_, - flatbuffers::GetRoot(fbb.GetBufferPointer())); + std::unique_ptr model_loader = + ModelLoader::CreateFromFdOrPath(fd_or_model_path_); + if (!model_loader) { + TF_LITE_REPORT_ERROR(error_reporter_, "Failed to parse model path %s", + fd_or_model_path_.c_str()); + return kMinibenchmarkPreconditionNotMet; } - MinibenchmarkStatus load_status = - check_validator->CheckModel(/* load_only */ true); + + // Check that the model can be loaded from disk. + MinibenchmarkStatus load_status = model_loader->Init(); if (load_status != kMinibenchmarkSuccess) { TF_LITE_REPORT_ERROR(error_reporter_, "Could not load model %s: %d", - model_path_.c_str(), static_cast(load_status)); + fd_or_model_path_.c_str(), + static_cast(load_status)); return load_status; } @@ -211,18 +201,9 @@ int ValidatorRunner::TriggerMissingValidation( return 0; } - std::string model_path; - if (!model_path_.empty()) { - model_path = model_path_; - } else { - std::stringstream ss; - ss << "fd:" << model_fd_ << ":" << model_offset_ << ":" << model_size_; - model_path = ss.str(); - } - // We purposefully detach the thread and have it own all the data. The // runner may potentially hang, so we can't wait for it to terminate. - std::thread detached_thread([model_path = model_path, + std::thread detached_thread([model_path = fd_or_model_path_, storage_path = storage_path_, data_directory_path = data_directory_path_, to_be_run, diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.h index d5005db5194bad..6f5a02f86b1cfb 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.h @@ -109,9 +109,7 @@ class ValidatorRunner { int64_t timeout_us = kDefaultEventTimeoutUs); private: - std::string model_path_; - int model_fd_ = -1; - size_t model_offset_, model_size_; + std::string fd_or_model_path_; std::string storage_path_; std::string data_directory_path_; FlatbufferStorage storage_; diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint.cc index 621285f58f85d9..c7b1ff01a138f3 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" #ifndef _WIN32 #include #include @@ -22,11 +25,12 @@ limitations under the License. #include #include #include // NOLINT: only used on Android, where std::thread is allowed +#include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/fb_storage.h" -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/runner.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/set_big_core_affinity.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h" @@ -35,6 +39,45 @@ limitations under the License. namespace tflite { namespace acceleration { +namespace { + +MinibenchmarkStatus RunValidator(absl::string_view model_path, + const std::string& nnapi_sl_path, + TFLiteSettingsT& tflite_settings, + Validator::Results& results) { + // Load NNAPI Support Library if specified. + std::unique_ptr nnapi_sl_handle; + if (tflite_settings.nnapi_settings && !nnapi_sl_path.empty()) { + // We are not calling dlclose, it will be done once the + // validator process ends. + nnapi_sl_handle = ::tflite::nnapi::loadNnApiSupportLibrary(nnapi_sl_path); + + if (!nnapi_sl_handle) { + return kMiniBenchmarkCannotLoadSupportLibrary; + } + + tflite_settings.nnapi_settings->support_library_handle = + reinterpret_cast(nnapi_sl_handle->getFL5()); + } + + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish( + CreateComputeSettings(fbb, ExecutionPreference_ANY, + CreateTFLiteSettings(fbb, &tflite_settings))); + std::unique_ptr model_loader = + ModelLoader::CreateFromFdOrPath(model_path); + if (!model_loader) { + return kMinibenchmarkPreconditionNotMet; + } + + auto validator = std::make_unique( + std::move(model_loader), + flatbuffers::GetRoot(fbb.GetBufferPointer())); + + return validator->RunValidation(&results); +} + +} // namespace extern "C" { int Java_org_tensorflow_lite_acceleration_validation_entrypoint(int argc, @@ -42,11 +85,11 @@ int Java_org_tensorflow_lite_acceleration_validation_entrypoint(int argc, if (argc < 6) return 1; // argv[1] is the helper binary name // argv[2] is the function name - std::string model_path = argv[3]; - std::string storage_path = argv[4]; + const std::string model_path = argv[3]; + const std::string storage_path = argv[4]; // argv[5] is data directory path. // argv[6] if present is the NNAPI SL path - std::string nnapi_sl_path = argc > 6 ? argv[6] : ""; + const std::string nnapi_sl_path = argc > 6 ? argv[6] : ""; FileLock lock(storage_path + ".child_lock"); if (!lock.TryLock()) { return kMinibenchmarkChildProcessAlreadyRunning; @@ -56,7 +99,6 @@ int Java_org_tensorflow_lite_acceleration_validation_entrypoint(int argc, if (status != kMinibenchmarkSuccess) { return status; } - status = kMinibenchmarkNoValidationRequestFound; TFLiteSettingsT tflite_settings; int32_t set_big_core_affinity_errno = SetBigCoresAffinity(); @@ -75,54 +117,22 @@ int Java_org_tensorflow_lite_acceleration_validation_entrypoint(int argc, Validator::BootTimeMicros(), Validator::WallTimeMicros())); } + status = kMinibenchmarkNoValidationRequestFound; for (int i = storage.Count() - 1; i >= 0; i--) { const BenchmarkEvent* event = storage.Get(i); if (event->event_type() == BenchmarkEventType_START) { event->tflite_settings()->UnPackTo(&tflite_settings); - std::unique_ptr - nnapi_sl_handle; - if (tflite_settings.nnapi_settings && !nnapi_sl_path.empty()) { - // We are not calling dlclose, it will be done once the - // validator process ends. - nnapi_sl_handle = - ::tflite::nnapi::loadNnApiSupportLibrary(nnapi_sl_path); - - if (!nnapi_sl_handle) { - status = kMiniBenchmarkCannotLoadSupportLibrary; - break; - } - - tflite_settings.nnapi_settings->support_library_handle = - reinterpret_cast(nnapi_sl_handle->getFL5()); - } - - flatbuffers::FlatBufferBuilder fbb; - fbb.Finish( - CreateComputeSettings(fbb, ExecutionPreference_ANY, - CreateTFLiteSettings(fbb, &tflite_settings))); - std::unique_ptr validator; - if (model_path.find("fd:") == 0) { // NOLINT - int model_fd, model_offset, model_size; - if (sscanf(model_path.c_str(), "fd:%d:%d:%d", &model_fd, &model_offset, - &model_size) != 3) { - status = kMinibenchmarkPreconditionNotMet; - } - validator = std::make_unique( - model_fd, model_offset, model_size, - flatbuffers::GetRoot(fbb.GetBufferPointer())); - } else { - validator = std::make_unique( - model_path, - flatbuffers::GetRoot(fbb.GetBufferPointer())); - } Validator::Results results; - status = validator->RunValidation(&results); + status = + RunValidator(model_path, nnapi_sl_path, tflite_settings, results); if (status != kMinibenchmarkSuccess) { break; } - fbb.Reset(); - std::vector initialization_times{results.compilation_time_us}; + + // If succeed, write MiniBenchmark metrics to file then return. + flatbuffers::FlatBufferBuilder fbb; + std::vector delegate_prep_time_us{results.delegate_prep_time_us}; std::vector> metrics; metrics.reserve(results.metrics.size()); for (const auto& name_and_values : results.metrics) { @@ -132,16 +142,17 @@ int Java_org_tensorflow_lite_acceleration_validation_entrypoint(int argc, } return storage.Append( &fbb, - CreateBenchmarkEvent( - fbb, CreateTFLiteSettings(fbb, &tflite_settings), - BenchmarkEventType_END, - CreateBenchmarkResult(fbb, fbb.CreateVector(initialization_times), - fbb.CreateVector(results.execution_time_us), - 0, results.ok, fbb.CreateVector(metrics)), - /* error */ 0, Validator::BootTimeMicros(), - Validator::WallTimeMicros())); + CreateBenchmarkEvent(fbb, CreateTFLiteSettings(fbb, &tflite_settings), + BenchmarkEventType_END, + CreateBenchmarkResult( + fbb, fbb.CreateVector(delegate_prep_time_us), + fbb.CreateVector(results.execution_time_us), + 0, results.ok, fbb.CreateVector(metrics)), + /* error */ 0, Validator::BootTimeMicros(), + Validator::WallTimeMicros())); } } + // Write error to file. flatbuffers::FlatBufferBuilder fbb; return storage.Append( &fbb, CreateBenchmarkEvent( diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint_test.cc index 2f7973761c938e..8522f7107bd7ce 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner_entrypoint_test.cc @@ -15,12 +15,13 @@ limitations under the License. #include #include +#include #include +#include #include #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/experimental/acceleration/compatibility/android_info.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator_runner.h" @@ -134,6 +135,50 @@ TEST_F(ValidatorRunnerEntryPointTest, CannotSetCpuAffinity) { EXPECT_EQ(10, event->error()->mini_benchmark_error_code()); } +TEST_F(ValidatorRunnerEntryPointTest, CannotLoadNnapi) { + // Write TFLiteSettings to storage_. + flatbuffers::FlatBufferBuilder fbb; + TFLiteSettingsT tflite_settings; + NNAPISettingsT nnapi_settings; + ASSERT_EQ( + storage_.Append( + &fbb, + CreateBenchmarkEvent( + fbb, + CreateTFLiteSettings(fbb, Delegate_NNAPI, + CreateNNAPISettings(fbb, &nnapi_settings)), + BenchmarkEventType_START, /* result */ 0, /* error */ 0, + Validator::BootTimeMicros(), Validator::WallTimeMicros())), + kMinibenchmarkSuccess); + // Prep argv. + std::vector args = { + "test", + "binary_name", + "Java_org_tensorflow_lite_acceleration_validation_entrypoint", + "model_path", + storage_path_, + "data_directory_path", + "nnapi_path"}; + std::vector> mutable_args(args.size()); + std::vector argv(args.size()); + for (int i = 0; i < mutable_args.size(); i++) { + mutable_args[i] = {args[i].data(), args[i].data() + args[i].size()}; + mutable_args[i].push_back('\0'); + argv[i] = mutable_args[i].data(); + } + EXPECT_EQ(kMinibenchmarkSuccess, + Java_org_tensorflow_lite_acceleration_validation_entrypoint( + 7, argv.data())); + + // Verify. + std::vector events = GetEvents(); + ASSERT_THAT(events, testing::SizeIs(2)); + const tflite::BenchmarkEvent* event = events[1]; + EXPECT_EQ(BenchmarkEventType_ERROR, event->event_type()); + EXPECT_EQ(kMiniBenchmarkCannotLoadSupportLibrary, + event->error()->exit_code()); +} + } // namespace } // namespace acceleration } // namespace tflite diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_test.cc index 1d88a85e5dc746..40a3fae853c9ba 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/validator_test.cc @@ -14,14 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/validator.h" +#include #include +#include #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" #include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_model.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_validation_model.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_test_helper.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h" // Note that these tests are not meant to be completely exhaustive, but to test @@ -34,21 +39,24 @@ namespace { class ValidatorTest : public ::testing::Test { protected: void SetUp() override { - validation_model_path_ = MiniBenchmarkTestHelper::DumpToTempFile( + std::string validation_model_path = MiniBenchmarkTestHelper::DumpToTempFile( "mobilenet_quant_with_validation.tflite", g_tflite_acceleration_embedded_mobilenet_validation_model, g_tflite_acceleration_embedded_mobilenet_validation_model_len); - ASSERT_TRUE(!validation_model_path_.empty()); + ASSERT_TRUE(!validation_model_path.empty()); + validation_model_loader_ = + std::make_unique(validation_model_path); - plain_model_path_ = MiniBenchmarkTestHelper::DumpToTempFile( + std::string plain_model_path = MiniBenchmarkTestHelper::DumpToTempFile( "mobilenet_quant.tflite", g_tflite_acceleration_embedded_mobilenet_model, g_tflite_acceleration_embedded_mobilenet_model_len); - ASSERT_TRUE(!plain_model_path_.empty()); + ASSERT_TRUE(!plain_model_path.empty()); + plain_model_loader_ = std::make_unique(plain_model_path); } - std::string validation_model_path_; - std::string plain_model_path_; + std::unique_ptr validation_model_loader_; + std::unique_ptr plain_model_loader_; }; TEST_F(ValidatorTest, HappyPath) { @@ -57,21 +65,32 @@ TEST_F(ValidatorTest, HappyPath) { const ComputeSettings* settings = flatbuffers::GetRoot(fbb.GetBufferPointer()); - Validator validator(validation_model_path_, settings); + Validator validator(std::move(validation_model_loader_), settings); Validator::Results results; EXPECT_EQ(validator.RunValidation(&results), kMinibenchmarkSuccess); EXPECT_TRUE(results.ok); EXPECT_EQ(results.delegate_error, 0); } +TEST_F(ValidatorTest, DelegateNotSupported) { + proto::ComputeSettings settings_proto; + settings_proto.mutable_tflite_settings()->set_delegate(proto::CORE_ML); + flatbuffers::FlatBufferBuilder fbb; + const ComputeSettings* settings = ConvertFromProto(settings_proto, &fbb); + + Validator validator(std::move(plain_model_loader_), settings); + Validator::Results results; + EXPECT_EQ(validator.RunValidation(&results), + kMinibenchmarkDelegateNotSupported); +} + TEST_F(ValidatorTest, NoValidationSubgraph) { flatbuffers::FlatBufferBuilder fbb; fbb.Finish(CreateComputeSettings(fbb)); const ComputeSettings* settings = flatbuffers::GetRoot(fbb.GetBufferPointer()); - Validator validator(plain_model_path_, settings); - EXPECT_EQ(validator.CheckModel(), kMinibenchmarkValidationSubgraphNotFound); + Validator validator(std::move(plain_model_loader_), settings); Validator::Results results; EXPECT_EQ(validator.RunValidation(&results), kMinibenchmarkValidationSubgraphNotFound); @@ -88,12 +107,22 @@ TEST_F(ValidatorTest, InvalidModel) { const ComputeSettings* settings = flatbuffers::GetRoot(fbb.GetBufferPointer()); - Validator validator(dump_path, settings); - EXPECT_EQ(validator.CheckModel(), kMinibenchmarkModelBuildFailed); + Validator validator(std::make_unique(dump_path), settings); Validator::Results results; EXPECT_EQ(validator.RunValidation(&results), kMinibenchmarkModelBuildFailed); } +TEST_F(ValidatorTest, EmptyModelLoader) { + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(CreateComputeSettings(fbb)); + const ComputeSettings* settings = + flatbuffers::GetRoot(fbb.GetBufferPointer()); + + Validator validator(nullptr, settings); + Validator::Results results; + EXPECT_EQ(validator.RunValidation(&results), kMinibenchmarkModelReadFailed); +} + } // namespace } // namespace acceleration } // namespace tflite diff --git a/tensorflow/lite/g3doc/android/tutorials/audio_classification.md b/tensorflow/lite/g3doc/android/tutorials/audio_classification.md index 9cf3325eb02c73..9478a685424d24 100644 --- a/tensorflow/lite/g3doc/android/tutorials/audio_classification.md +++ b/tensorflow/lite/g3doc/android/tutorials/audio_classification.md @@ -5,6 +5,7 @@ learning models to recognize sounds and spoken words in an Android app. Audio classification models like the ones shown in this tutorial can be used to detect activity, identify actions, or recognize voice commands. +![Audio recognition animated demo](https://storage.googleapis.com/download.tensorflow.org/tflite/examples/audio_classification.gif){: .attempt-right} This tutorial shows you how to download the example code, load the project into [Android Studio](https://developer.android.com/studio/), and explains key parts of the code example so you can start adding this functionality to your own app. diff --git a/tensorflow/lite/g3doc/performance/coreml_delegate.md b/tensorflow/lite/g3doc/performance/coreml_delegate.md index 91ae96ecbceaae..096a8693243524 100644 --- a/tensorflow/lite/g3doc/performance/coreml_delegate.md +++ b/tensorflow/lite/g3doc/performance/coreml_delegate.md @@ -62,7 +62,7 @@ TensorFlow Lite 2.4.0 release, this was the only option. } else { interpreter = try Interpreter(modelPath: modelPath) } -

+

Objective-C

@@ -92,7 +92,7 @@ TensorFlow Lite 2.4.0 release, this was the only option. if (error != nil) { /* Error handling... */ } // Run inference ... -

+

C (Until 2.3.0)

@@ -159,7 +159,7 @@ pass `TfLiteCoreMlDelegateAllDevices`. Following example shows how to do this: initWithOptions:coreMLOptions]; // Initialize interpreter with delegate -

+

C

@@ -191,7 +191,7 @@ performance benefits. Following example shows how to do this: let interpreter = try Interpreter(modelPath: modelPath, delegates: [delegate!]) -

+

Objective-C

diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index d4b032d9705e35..a2f3fc2f3eb881 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -133,6 +133,7 @@ aar_with_jni( "//tensorflow/lite/c:c_api_experimental.h", # TODO(b/175298345): Clean up and if possible remove common.h here. "//tensorflow/lite/c:common.h", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate_c_api.h", ], ) diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 46bdb386de16c7..82ebb099be570d 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -2187,6 +2187,7 @@ cc_test( ":test_main", ":test_util", "//tensorflow/lite:string", + "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 8851e377d7eb8c..095c75c626e96f 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -297,6 +297,7 @@ cc_library( "optimized/reduce.h", "optimized/resize_bilinear.h", "optimized/sparse_ops/fully_connected.h", + "reduce_common.h", ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), @@ -307,6 +308,7 @@ cc_library( ":cppmath", ":cpu_check", ":quantization_util", + ":reduce_utils", ":reference_base", ":strided_slice_logic", ":tensor", @@ -317,6 +319,7 @@ cc_library( "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:cpu_backend_gemm", "//tensorflow/lite/kernels:cpu_backend_threadpool", + "//tensorflow/lite/kernels:kernel_util", "//third_party/eigen3", "@gemmlowp//:fixedpoint", "@ruy//ruy/profiler:instrumentation", @@ -561,6 +564,7 @@ cc_library( "reference/process_broadcast_shapes.h", "reference/quantize.h", "reference/reduce.h", + "reduce_common.h", "reference/requantize.h", "reference/resize_bilinear.h", "reference/resize_nearest_neighbor.h", @@ -838,6 +842,15 @@ cc_library( ], ) +cc_library( + name = "reduce_utils", + hdrs = [ + "optimized/reduce_utils.h", + ], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), +) + # Audio support classes imported directly from TensorFlow. cc_library( name = "audio_utils", @@ -1220,6 +1233,16 @@ cc_test( ], ) +cc_test( + name = "reduce_utils_test", + srcs = ["optimized/reduce_utils_test.cc"], + deps = [ + ":common", + ":reduce_utils", + "@com_google_googletest//:gtest_main", + ], +) + filegroup( name = "optimized_op_headers", srcs = glob([ diff --git a/tensorflow/lite/kernels/internal/optimized/reduce.h b/tensorflow/lite/kernels/internal/optimized/reduce.h index 0c3e84f3206ca6..664272c73c5060 100644 --- a/tensorflow/lite/kernels/internal/optimized/reduce.h +++ b/tensorflow/lite/kernels/internal/optimized/reduce.h @@ -23,6 +23,8 @@ limitations under the License. #include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops_utils.h" +#include "tensorflow/lite/kernels/internal/optimized/reduce_utils.h" +#include "tensorflow/lite/kernels/internal/reduce_common.h" #include "tensorflow/lite/kernels/internal/reference/reduce.h" #include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -297,6 +299,425 @@ inline bool MeanGeneral( resolved_axis, temp_sum); } +template +struct SumOp { + inline T operator()(const T& a) const { return a; } + inline T operator()(const T& a, const T& b) const { return a + b; } + static constexpr T kNeutralElement = T(0); +}; + +template +struct CastSumOp { + inline U operator()(const T& a) const { return static_cast(a); } + inline U operator()(const U& a, const T& b) const { + return a + static_cast(b); + } + static constexpr U kNeutralElement = U(0); +}; + +template +struct ProdOp { + inline T operator()(const T& a) const { return a; } + inline T operator()(const T& a, const T& b) const { return a * b; } + static constexpr T kNeutralElement = T(1); +}; + +template +struct MaxOp { + inline T operator()(const T& a) const { return a; } + inline T operator()(const T& a, const T& b) const { return (a > b) ? a : b; } + static constexpr T kNeutralElement = std::numeric_limits::lowest(); +}; + +template +struct MinOp { + inline T operator()(const T& a) const { return a; } + inline T operator()(const T& a, const T& b) const { return (a < b) ? a : b; } + static constexpr T kNeutralElement = std::numeric_limits::max(); +}; + +struct AndOp { + inline bool operator()(bool a) const { return a; } + inline bool operator()(bool a, bool b) const { return a && b; } + static constexpr bool kNeutralElement = true; +}; + +struct OrOp { + inline bool operator()(bool a) const { return a; } + inline bool operator()(bool a, bool b) const { return a || b; } + static constexpr bool kNeutralElement = false; +}; + +// When the number of axis is zero, the reduction is simply a copy. +template +void ReduceIsCopy(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data) { + int num_elems = 1; + for (int i = 0; i < input_num_dims; ++i) { + num_elems *= input_dims[i]; + } + memcpy(output_data, input_data, num_elems * sizeof(T)); +} + +// Reduces the input over either odd or even dimensions using Op. +// One recursive call for each dimension is made. +// 'depth' is the depth of recursion. +// 'parity' indicates whether odd or even dimensions are being reduced. +// ReducerFirst is applied to the first element to be written to each output +// position. +// ReducerNext is applied to each subsequent element to be written to each +// output position. +template +inline std::pair ReduceImpl(const T* input_data, + const int* input_dims, U* output_data, + int depth, int parity, bool next, + const ReducerFirst& reducer_first, + const ReducerNext& reducer_next) { + // The output pointer is incremented conditionally depending on whether the + // odd or even dimension is being reduced. + // The input pointer is always incremented as each input is read once. + if (depth > 0) { + U* future_output = output_data; + bool update_output = (depth % 2) == parity; + for (int i = 0; i < input_dims[0]; ++i) { + if (i > 0 && !update_output) { + next = true; + } + std::tie(input_data, future_output) = + ReduceImpl(input_data, &input_dims[1], output_data, depth - 1, parity, + next, reducer_first, reducer_next); + if (update_output) { + output_data = future_output; + } + } + output_data = future_output; + } else { + // Reduce the final dimension. + if (parity) { + // Reduce the even dimension. The entire dimension is reduced into one + // value. + U res = next ? reducer_next(*output_data, *input_data++) + : reducer_first(*input_data++); + for (int i = 1; i < input_dims[0]; ++i) { + res = reducer_next(res, *input_data++); + } + *output_data++ = res; + } else { + // Reduce the odd dimension. Each input is accumulated into a separate + // output. + if (!next) { + for (int i = 0; i < input_dims[0]; ++i) { + U res = reducer_first(*input_data++); + *output_data++ = res; + } + } else { + for (int i = 0; i < input_dims[0]; ++i) { + U res = *output_data; + res = reducer_next(res, *input_data++); + *output_data++ = res; + } + } + } + } + return {input_data, output_data}; +} + +// A generic reduce method that can be used for reduce_sum, reduce_mean, etc. +// This method iterates through input data and reduce elements along the +// dimensions given in axis. ReducerFirst is used the first time each output +// element is written and ReducerNext is used for all subsequent writes. +template +inline bool Reduce(const In* input_data, const int* input_dims, + const int input_num_dims, const int* axis, + const int num_axis, Out* output_data, + const ReducerFirst& reducer_first, + const ReducerNext& reducer_next) { + const int parity = (axis[num_axis - 1] == input_num_dims - 1) ? 1 : 0; + ReduceImpl(input_data, input_dims, output_data, input_num_dims - 1, parity, + /*next=*/false, reducer_first, reducer_next); + return true; +} + +// Computes the mean or sum of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis for quantized values. +template +bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point, + float input_scale, const int* input_dims, + const int input_num_dims, T* output_data, + int32_t output_zero_point, float output_scale, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, + bool keep_dims, int* normalized_dims, + int* resolved_axis, U* temp_sum, bool compute_sum) { + ruy::profiler::ScopeLabel label(compute_sum ? "QuantizedSum" + : "QuantizedMean"); + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + + // Return early when input shape has zero dim. This is done after initializing + // data for output tensor because there are cases that the input tensor is + // empty but output tensor is not. In that case, output tensor should be + // filled with init_value. + for (int i = 0; i < input_num_dims; ++i) { + if (input_dims[i] == 0) return true; + } + + // Resolve axis. + int num_resolved_axis = 0; + int normalized_num_dims = 0; + if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions, + resolved_axis, num_resolved_axis, input_dims, + normalized_dims, normalized_num_dims)) { + return false; + } + + if (!Reduce, CastSumOp>( + input_data, normalized_dims, normalized_num_dims, resolved_axis, + num_resolved_axis, temp_sum, CastSumOp(), CastSumOp())) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) { + size_t current = static_cast(normalized_dims[resolved_axis[idx]]); + // Overflow prevention. + if (current > (std::numeric_limits::max() / num_elements_in_axis)) { + return false; + } + num_elements_in_axis *= current; + } + + if (num_elements_in_axis > 0) { + const float scale = input_scale / output_scale; + if (compute_sum) { + const float bias = -input_zero_point * scale * num_elements_in_axis; + for (size_t idx = 0; idx < num_outputs; ++idx) { + const U value = + static_cast(TfLiteRound(temp_sum[idx] * scale + bias)) + + output_zero_point; + output_data[idx] = static_cast(value); + } + } else { + const float bias = -input_zero_point * scale; + for (size_t idx = 0; idx < num_outputs; ++idx) { + float float_mean = static_cast(temp_sum[idx]) / + static_cast(num_elements_in_axis); + float result = TfLiteMin( + TfLiteRound(float_mean * scale + bias) + output_zero_point, + static_cast(std::numeric_limits::max())); + result = TfLiteMax(result, + static_cast(std::numeric_limits::min())); + output_data[idx] = static_cast(result); + } + } + } + return true; +} + +using ops::builtin::reduce::ReduceType; + +template +inline bool ReduceDispatcher(const T* input_data, const int* input_dims, + const int input_num_dims, const int* output_dims, + int output_num_dims, T* output_data, + const int* axis, const int64_t num_axis_dimensions, + ReduceType reduce_type) { + T init_value; + switch (reduce_type) { + case ReduceType::kProd: + init_value = ProdOp::kNeutralElement; + break; + case ReduceType::kSum: + init_value = SumOp::kNeutralElement; + break; + case ReduceType::kMin: + init_value = MinOp::kNeutralElement; + break; + case ReduceType::kMax: + init_value = MaxOp::kNeutralElement; + break; + default: + return false; + } + // Return early when input shape has zero dim. This is done after initializing + // data for output tensor because there are cases that the input tensor is + // empty but output tensor is not. In that case, output tensor should be + // filled with Op::kNeutralElement. + for (int i = 0; i < input_num_dims; ++i) { + if (input_dims[i] == 0) { + return reference_ops::InitTensorDataForReduce( + output_dims, output_num_dims, init_value, output_data); + } + } + + switch (reduce_type) { + case ReduceType::kProd: + return Reduce, ProdOp>( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, ProdOp(), ProdOp()); + case ReduceType::kSum: + return Reduce, SumOp>( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, SumOp(), SumOp()); + case ReduceType::kMin: + return Reduce, MinOp>( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, MinOp(), MinOp()); + case ReduceType::kMax: + return Reduce, MaxOp>( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, MaxOp(), MaxOp()); + default: + return false; + } +} + +template <> +inline bool ReduceDispatcher(const bool* input_data, + const int* input_dims, + const int input_num_dims, + const int* output_dims, int output_num_dims, + bool* output_data, const int* axis, + const int64_t num_axis_dimensions, + ReduceType reduce_type) { + bool init_value; + switch (reduce_type) { + case ReduceType::kAny: + init_value = OrOp::kNeutralElement; + break; + case ReduceType::kAll: + init_value = AndOp::kNeutralElement; + break; + default: + return false; + } + // Return early when input shape has zero dim. This is done after initializing + // data for output tensor because there are cases that the input tensor is + // empty but output tensor is not. In that case, output tensor should be + // filled with Op::kNeutralElement. + for (int i = 0; i < input_num_dims; ++i) { + if (input_dims[i] == 0) { + return reference_ops::InitTensorDataForReduce( + output_dims, output_num_dims, init_value, output_data); + } + } + switch (reduce_type) { + case ReduceType::kAll: + return Reduce( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, AndOp(), AndOp()); + case ReduceType::kAny: + return Reduce( + input_data, input_dims, input_num_dims, axis, num_axis_dimensions, + output_data, OrOp(), OrOp()); + default: + return false; + } +} + +// Calculate the reduced product by rescaling each multiplication step to +// avoid an overflow. +template +struct ReducerFirst { + explicit ReducerFirst(int input_zero_point_arg) + : input_zero_point(input_zero_point_arg) {} + int32_t operator()(T in) const { return in - input_zero_point; } + int input_zero_point; +}; + +template +struct ReducerNext { + ReducerNext(int32_t input_zero_point_arg, int32_t scaling_multiplier_arg, + int32_t scaling_shift_arg) + : input_zero_point(input_zero_point_arg), + scaling_multiplier(scaling_multiplier_arg), + scaling_shift(scaling_shift_arg) {} + int32_t operator()(int32_t current, T in) const { + const int64_t result = + static_cast(current) * (in - input_zero_point); + return MultiplyByQuantizedMultiplier(result, scaling_multiplier, + scaling_shift); + } + int32_t input_zero_point, scaling_multiplier, scaling_shift; +}; + +template +inline bool QuantizedReduceProd( + const T* input_data, int32_t input_zero_point, + const RuntimeShape& input_shape, T* output_data, int32_t output_zero_point, + const RuntimeShape& output_shape, const int* axis, + const int64_t num_axis_dimensions, int* resolved_axis, int* normalized_dims, + int32_t* temp_prod, int32_t scaling_multiplier, int scaling_shift) { + const int32_t kMinValue = std::numeric_limits::min(); + const int32_t kMaxValue = std::numeric_limits::max(); + + // Resolve axis. + int num_resolved_axis = 0; + int normalized_num_dims = 0; + if (!reduce_utils::ResolveAxis(input_shape.DimensionsCount(), axis, + num_axis_dimensions, resolved_axis, + num_resolved_axis, input_shape.DimsData(), + normalized_dims, normalized_num_dims)) { + return false; + } + + if (!Reduce, ReducerNext>( + input_data, normalized_dims, normalized_num_dims, resolved_axis, + num_resolved_axis, temp_prod, ReducerFirst(input_zero_point), + ReducerNext(input_zero_point, scaling_multiplier, + scaling_shift))) { + return false; + } + + for (int i = 0; i < output_shape.FlatSize(); i++) { + int32_t result = + MultiplyByQuantizedMultiplier(static_cast(temp_prod[i]), + scaling_multiplier, scaling_shift) + + output_zero_point; + result = std::min(std::max(result, kMinValue), kMaxValue); + output_data[i] = static_cast(result); + } + + return true; +} + +// Computes the generic value (i.e., sum/max/min/prod) of elements across +// dimensions given in axis. It needs to pass in init_value and reducer. +template +inline bool ReduceGeneric(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + int* resolved_axis, int* normalized_dims, + ReduceType reduce_type) { + int num_resolved_axis = 0; + int normalized_num_dims = 0; + if (!reduce_utils::ResolveAxis(input_num_dims, axis, num_axis_dimensions, + resolved_axis, num_resolved_axis, input_dims, + normalized_dims, normalized_num_dims)) { + return false; + } + if (num_resolved_axis == 0) { + optimized_ops::ReduceIsCopy(input_data, input_dims, input_num_dims, + output_data); + return true; + } + return ReduceDispatcher(input_data, normalized_dims, normalized_num_dims, + output_dims, output_num_dims, output_data, + resolved_axis, num_resolved_axis, reduce_type); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/reduce_utils.h b/tensorflow/lite/kernels/internal/optimized/reduce_utils.h new file mode 100644 index 00000000000000..51e30da014f162 --- /dev/null +++ b/tensorflow/lite/kernels/internal/optimized/reduce_utils.h @@ -0,0 +1,137 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_ + +#include + +#include + +namespace tflite { +namespace reduce_utils { + +inline void RemoveSize1Dims(int* shape_out, int& out_num_dims, int* axis_out, + int& out_num_axis) { + for (int64_t i = 0; i < out_num_dims;) { + if (shape_out[i] == 1) { + for (int64_t j = i + 1; j < out_num_dims; ++j) { + shape_out[j - 1] = shape_out[j]; + } + for (int64_t j = 0; j < out_num_axis; ++j) { + if (axis_out[j] == i) { + for (int64_t k = j + 1; k < out_num_axis; ++k) { + axis_out[k - 1] = axis_out[k]; + } + out_num_axis -= 1; + break; + } + } + for (int64_t j = 0; j < out_num_axis; ++j) { + if (axis_out[j] > i) { + axis_out[j] -= 1; + } + } + --out_num_dims; + } else { + ++i; + } + } +} + +// This method parses the input 'axis' to remove duplicates, handle negative +// values and remove redundant dimensions. It returns a valid 'axis_out' and +// 'shape_out' contains the flattened input shape. 'out_num_dims' contains the +// reduced number of dimensions. +inline bool ResolveAxis(const int num_dims, const int* axis, + const int64_t num_axis, int* axis_out, + int& out_num_axis, const int* shape_in, int* shape_out, + int& out_num_dims) { + // Short-circuit axis resolution for scalars; the axis will go unused. + if (num_dims == 0) { + out_num_axis = 0; + out_num_dims = 0; + return true; + } + out_num_axis = 0; + out_num_dims = num_dims; + // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 + for (int64_t idx = 0; idx < num_axis; ++idx) { + // Handle negative index. A positive index 'p_idx' can be represented as a + // negative index 'n_idx' as: n_idx = p_idx-num_dims + // eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */ + int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; + if (current < 0 || current >= num_dims) { + return false; + } + bool is_dup = false; + for (int j = 0; j < out_num_axis; ++j) { + if (axis_out[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + axis_out[out_num_axis] = current; + out_num_axis += 1; + } + } + // If two or more adjacent dimensions are either reduced + // over or not, then the second and subsequent dimensions may be flattened. + memcpy(shape_out, shape_in, num_dims * sizeof(int)); + std::sort(&axis_out[0], &axis_out[out_num_axis]); + + RemoveSize1Dims(shape_out, out_num_dims, axis_out, out_num_axis); + if (out_num_axis > 0) { + int64_t j = out_num_axis - 1; + // true if the previous index is present in axis_out. + bool previous_here = (axis_out[j] == out_num_dims - 1); + if (previous_here) { + j -= 1; + } + + for (int64_t i = out_num_dims - 2; i >= 0; --i) { + // true if the current index is present in axis_out. + bool current_here = j >= 0 ? (axis_out[j] == i) : false; + if (current_here == previous_here) { + shape_out[i] *= shape_out[i + 1]; + for (int64_t k = i + 1; k + 1 < out_num_dims; ++k) { + shape_out[k] = shape_out[k + 1]; + } + // All axis bigger than this need to be reduced by 1. + for (int64_t k = 0; k < out_num_axis; ++k) { + if (axis_out[k] > i) { + axis_out[k] -= 1; + } + } + if (current_here) { + for (int64_t k = j + 1; k + 1 < out_num_axis; ++k) { + axis_out[k] = axis_out[k + 1]; + } + out_num_axis -= 1; + } + out_num_dims -= 1; + } + if (current_here) { + j -= 1; + } + previous_here = current_here; + } + } + return true; +} +} // namespace reduce_utils +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/reduce_utils_test.cc b/tensorflow/lite/kernels/internal/optimized/reduce_utils_test.cc new file mode 100644 index 00000000000000..7b2bc26ea88c89 --- /dev/null +++ b/tensorflow/lite/kernels/internal/optimized/reduce_utils_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/optimized/reduce_utils.h" + +#include + +namespace tflite { +namespace reduce_utils { +namespace { + +using ::testing::ElementsAreArray; + +void TestFunction(const std::vector& axis_in, + const std::vector& shape_in, + const std::vector& expected_axis_out, + const std::vector& expected_shape_out) { + int num_dims = shape_in.size(); + int expected_out_num_dims = expected_shape_out.size(); + int actual_out_num_dims; + int expected_out_num_axis = expected_axis_out.size(); + int actual_out_num_axis; + std::vector actual_shape_out(num_dims); + std::vector actual_axis_out(num_dims); + ResolveAxis(shape_in.size(), axis_in.data(), axis_in.size(), + actual_axis_out.data(), actual_out_num_axis, shape_in.data(), + actual_shape_out.data(), actual_out_num_dims); + EXPECT_EQ(expected_out_num_dims, actual_out_num_dims); + EXPECT_EQ(expected_out_num_axis, actual_out_num_axis); + EXPECT_THAT(expected_shape_out, + ElementsAreArray(actual_shape_out.data(), expected_out_num_dims)); + EXPECT_THAT(expected_axis_out, + ElementsAreArray(actual_axis_out.data(), expected_out_num_axis)); +} + +TEST(ResolveAxisTest, Flatten_0_1_2) { + const std::vector axis_in = {0, 1, 2}; + const std::vector shape_in = {2, 3, 4, 5}; + const std::vector expected_shape_out{24, 5}; + const std::vector expected_axis_out{0}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, Flatten_0_1_2_3) { + const std::vector axis_in = {3, 2}; + const std::vector shape_in = {2, 3, 4, 5}; + const std::vector expected_shape_out{6, 20}; + const std::vector expected_axis_out{1}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, ZeroDims) { + const std::vector axis_in = {}; + const std::vector shape_in = {}; + const std::vector expected_shape_out{}; + const std::vector expected_axis_out{}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, DoNothing) { + const std::vector axis_in = {0}; + const std::vector shape_in = {4, 5}; + const std::vector expected_shape_out{4, 5}; + const std::vector expected_axis_out{0}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, NegativeAxis) { + const std::vector axis_in = {-2}; + const std::vector shape_in = {4, 3}; + const std::vector expected_shape_out{4, 3}; + const std::vector expected_axis_out{0}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, NegativeAxisFold) { + const std::vector axis_in = {-1}; + const std::vector shape_in = {4, 3, 5}; + const std::vector expected_shape_out{12, 5}; + const std::vector expected_axis_out{1}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, DuplicateAxis) { + const std::vector axis_in = {2, 1, 2, 1, 2, 1}; + const std::vector shape_in = {4, 3, 2}; + const std::vector expected_shape_out{4, 6}; + const std::vector expected_axis_out{1}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, DuplicateNegativeAxis) { + const std::vector axis_in = {2, -1, -2, -1, 2, 1}; + const std::vector shape_in = {4, 3, 2}; + const std::vector expected_shape_out{4, 6}; + const std::vector expected_axis_out{1}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, RemoveSize1Dim) { + const std::vector axis_in = {0}; + const std::vector shape_in = {1, 4, 3, 1}; + const std::vector expected_shape_out{4, 3}; + const std::vector expected_axis_out{}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, OneSize1DimToScalar) { + const std::vector axis_in = {0}; + const std::vector shape_in = {1}; + const std::vector expected_shape_out{}; + const std::vector expected_axis_out{}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +TEST(ResolveAxisTest, InterleavedSize1Dim) { + const std::vector axis_in = {1, 3}; + const std::vector shape_in = {1, 2, 1, 4, 1, 7}; + const std::vector expected_shape_out{8, 7}; + const std::vector expected_axis_out{0}; + TestFunction(axis_in, shape_in, expected_axis_out, expected_shape_out); +} + +} // namespace +} // namespace reduce_utils +} // namespace tflite diff --git a/tensorflow/lite/kernels/internal/reduce_common.h b/tensorflow/lite/kernels/internal/reduce_common.h new file mode 100644 index 00000000000000..948caa3ace9f8d --- /dev/null +++ b/tensorflow/lite/kernels/internal/reduce_common.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_REDUCE_COMMON_H_ +#define TENSORFLOW_LITE_KERNELS_REDUCE_COMMON_H_ + +namespace tflite { +namespace ops { +namespace builtin { +namespace reduce { + +enum ReduceType { + kSum, + kProd, + kMax, + kMin, + kAny, + kAll, +}; + +} // namespace reduce +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_REDUCE_COMMON_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 55b56a937d788f..a2bffa2c2c6425 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -1112,7 +1112,6 @@ inline void UnsortedSegmentProd(const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& segment_ids_shape, const int32_t* segment_ids_data, - const int32_t num_segments, const RuntimeShape& output_shape, T* output_data) { for (int i = 0; i < output_shape.FlatSize(); ++i) { @@ -1122,6 +1121,7 @@ inline void UnsortedSegmentProd(const RuntimeShape& input_shape, MatchingFlatSizeSkipDim(input_shape, 0, output_shape); for (int i = 0; i < input_shape.Dims(0); i++) { int output_index = segment_ids_data[i]; + if (output_index < 0) continue; for (int j = 0; j < segment_flat_size; ++j) { output_data[output_index * segment_flat_size + j] *= input_data[i * segment_flat_size + j]; diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc index 396bfe86d93877..ea0eebfb3a92cc 100644 --- a/tensorflow/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/reduce.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reduce_common.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" @@ -69,10 +70,10 @@ struct OpContext { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - // Creates two temp tensors to store index and axis for internal + // Creates three temp tensors to store index and axis for internal // implementation only. auto* op_data = new OpData(); - context->AddTensors(context, 3, &op_data->scratch_tensor_index); + context->AddTensors(context, 4, &op_data->scratch_tensor_index); return op_data; } @@ -164,13 +165,21 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { } } +// Resizes the temp tensor that stores normalized dims. +TfLiteStatus ResizeTempDims(TfLiteContext* context, OpContext* op_context, + TfLiteTensor* normalized_dims) { + TfLiteIntArray* dims_size = TfLiteIntArrayCreate(1); + dims_size->data[0] = (op_context->input->dims->size); + return context->ResizeTensor(context, normalized_dims, dims_size); +} + // Initializes temp tensors to store index and resolved axis. TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, OpContext* op_context) { // Creates a temp index to iterate through input data. OpData* op_data = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries = TfLiteIntArrayCreate(4); node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* scratch_tensor; TF_LITE_ENSURE_OK( @@ -215,6 +224,12 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, default: return kTfLiteError; } + // Creates a temp tensor to store normalized shape given input data. + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* normalized_dims; + TF_LITE_ENSURE_OK( + context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims)); + normalized_dims->type = kTfLiteInt32; return kTfLiteOk; } @@ -234,6 +249,17 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* resolved_axis; TF_LITE_ENSURE_OK( context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis)); + TfLiteTensor* normalized_dims; + TF_LITE_ENSURE_OK( + context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims)); + + if (!IsConstantTensor(op_context.input)) { + SetTensorToDynamic(normalized_dims); + } else { + normalized_dims->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, + ResizeTempDims(context, &op_context, normalized_dims)); + } // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); @@ -697,9 +723,9 @@ void ReduceAllDims(const T* input_data, const int* input_dims, // The underlying logic for Reduce Sum/Prod/Max/Min/Any template -TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node, - OpContext* op_context, T init_value, - T reducer(const T current, const T in)) { +TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node, + OpContext* op_context, KernelType kernel_type, + ReduceType reduce_type) { int64_t num_axis = NumElements(op_context->axis); TfLiteTensor* temp_index; TF_LITE_ENSURE_OK(context, @@ -722,90 +748,89 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_EQ(context, input->params.zero_point, op_context->output->params.zero_point); } - int num_resolved_axis = 0; - if (!tflite::reference_ops::ResolveAxis( - input->dims->size, GetTensorData(op_context->axis), num_axis, - GetTensorData(resolved_axis), &num_resolved_axis)) { - return kTfLiteError; - } - if (IsReduceAllDims(resolved_axis, num_resolved_axis, input->dims->size)) { - ReduceAllDims(GetTensorData(input), input->dims->data, input->dims->size, - GetTensorData(op_context->output), init_value, reducer, - context); - return kTfLiteOk; - } - TF_LITE_ENSURE( - context, - reference_ops::ReduceGeneric( - GetTensorData(input), input->dims->data, input->dims->size, - GetTensorData(op_context->output), op_context->output->dims->data, - op_context->output->dims->size, GetTensorData(op_context->axis), - num_axis, op_context->params->keep_dims, - GetTensorData(temp_index), GetTensorData(resolved_axis), - init_value, reducer)); - return kTfLiteOk; -} - -enum ReduceType { - kSum, - kProd, - kMax, - kMin, - kAny, - kAll, -}; - -// Eval for determined input type and reduce type. -template -TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node, - OpContext* op_context, ReduceType reduce_type) { - switch (reduce_type) { - case kSum: - return EvalLogic( - context, node, op_context, static_cast(0), - [](const T current, const T in) -> T { return in + current; }); - break; - case kProd: - return EvalLogic( - context, node, op_context, static_cast(1), - [](const T current, const T in) -> T { return in * current; }); - break; - case kMax: - return EvalLogic(context, node, op_context, - std::numeric_limits::lowest(), - [](const T current, const T in) -> T { - return (in > current) ? in : current; - }); - break; - case kMin: - return EvalLogic(context, node, op_context, - std::numeric_limits::max(), - [](const T current, const T in) -> T { - return (in < current) ? in : current; - }); - break; - default: - return kTfLiteError; - } -} + if (kernel_type == kReference) { + T init_value = 0; + T (*reducer)(const T current, const T in); + switch (reduce_type) { + case kSum: + reducer = [](const T current, const T in) -> T { return in + current; }; + init_value = T(0); + break; + case kProd: + init_value = static_cast(1); + reducer = [](const T current, const T in) -> T { return in * current; }; + break; + case kMax: + init_value = std::numeric_limits::lowest(); + reducer = [](const T current, const T in) -> T { + return (in > current) ? in : current; + }; + break; + case kMin: + init_value = std::numeric_limits::max(); + reducer = [](const T current, const T in) -> T { + return (in < current) ? in : current; + }; + break; + case kAny: + init_value = false; + reducer = [](const T current, const T in) -> T { + return in || current; + }; + break; + case kAll: + init_value = true; + reducer = [](const T current, const T in) -> T { + return in && current; + }; + break; + default: + TF_LITE_KERNEL_LOG(context, "Unsupported ReduceType: %d", reduce_type); + return kTfLiteError; + } -// Template specialization for bool type -template <> -TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node, - OpContext* op_context, ReduceType reduce_type) { - switch (reduce_type) { - case kAny: - return EvalLogic(context, node, op_context, false, - [](const bool current, const bool in) -> bool { - return in || current; - }); - case kAll: - return EvalLogic(context, node, op_context, true, - [](const bool current, const bool in) -> bool { - return in && current; - }); - default: - return kTfLiteError; + int num_resolved_axis = 0; + TF_LITE_ENSURE_MSG( + context, + tflite::reference_ops::ResolveAxis( + input->dims->size, GetTensorData(op_context->axis), num_axis, + GetTensorData(resolved_axis), &num_resolved_axis), + "Invalid axis index."); + + if (IsReduceAllDims(resolved_axis, num_resolved_axis, input->dims->size)) { + ReduceAllDims(GetTensorData(input), input->dims->data, + input->dims->size, GetTensorData(op_context->output), + init_value, reducer, context); + return kTfLiteOk; + } + TF_LITE_ENSURE( + context, + reference_ops::ReduceGeneric( + GetTensorData(input), input->dims->data, input->dims->size, + GetTensorData(op_context->output), + op_context->output->dims->data, op_context->output->dims->size, + GetTensorData(op_context->axis), num_axis, + op_context->params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis), init_value, reducer)); + return kTfLiteOk; + } else { + TfLiteTensor* normalized_dims; + TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3, + &normalized_dims)); + if (IsDynamicTensor(normalized_dims)) { + TF_LITE_ENSURE_OK(context, + ResizeTempDims(context, op_context, normalized_dims)); + } + TF_LITE_ENSURE( + context, + optimized_ops::ReduceGeneric( + GetTensorData(input), input->dims->data, input->dims->size, + GetTensorData(op_context->output), + op_context->output->dims->data, op_context->output->dims->size, + GetTensorData(op_context->axis), num_axis, + GetTensorData(resolved_axis), + GetTensorData(normalized_dims), reduce_type)); + return kTfLiteOk; } } @@ -813,37 +838,69 @@ TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node, // handle ReduceType. template TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) { - if (kernel_type != kReference) { - return kTfLiteOk; - } OpContext op_context(context, node); switch (op_context.input->type) { case kTfLiteFloat32: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteInt32: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteInt64: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteUInt8: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteInt8: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteInt16: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; case kTfLiteBool: - return EvalType(context, node, &op_context, reduce_type); + return EvalType(context, node, &op_context, kernel_type, + reduce_type); break; default: return kTfLiteError; } } +template +TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, OpContext* op_context, + int* temp_index, int* resolved_axis, + int* temp_sum, KernelType kernel_type, + bool compute_sum) { + int num_axis = static_cast(NumElements(op_context->axis)); + auto args = std::tuple( + GetTensorData(op_context->input), op_context->input->params.zero_point, + op_context->input->params.scale, &op_context->input->dims->data[0], + op_context->input->dims->size, GetTensorData(op_context->output), + op_context->output->params.zero_point, op_context->output->params.scale, + &op_context->output->dims->data[0], op_context->output->dims->size, + GetTensorData(op_context->axis), num_axis, + op_context->params->keep_dims, temp_index, resolved_axis, temp_sum, + compute_sum); + if (kernel_type == kReference) { + TF_LITE_ENSURE( + context, + std::apply(reference_ops::QuantizedMeanOrSum, args)); + } else { + TF_LITE_ENSURE( + context, + std::apply(optimized_ops::QuantizedMeanOrSum, args)); + } + return kTfLiteOk; +} + +template TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); ruy::profiler::ScopeLabel label("Sum"); @@ -857,7 +914,6 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { const bool need_rescale = (eight_bit_quantized && !same_scale); if (need_rescale) { // Rescaling 8bit reduce sum. - int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index; TF_LITE_ENSURE_OK( context, GetTemporarySafe(context, node, /*index=*/0, &temp_index)); @@ -877,47 +933,26 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { } if (input->type == kTfLiteUInt8) { - TF_LITE_ENSURE( - context, - reference_ops::QuantizedMeanOrSum<>( - GetTensorData(op_context.input), - op_context.input->params.zero_point, - op_context.input->params.scale, op_context.input->dims->data, - op_context.input->dims->size, - GetTensorData(op_context.output), - op_context.output->params.zero_point, - op_context.output->params.scale, op_context.output->dims->data, - op_context.output->dims->size, - GetTensorData(op_context.axis), num_axis, - op_context.params->keep_dims, GetTensorData(temp_index), - GetTensorData(resolved_axis), GetTensorData(temp_sum), - /*compute_sum=*/true)); - } - if (input->type == kTfLiteInt8) { - TF_LITE_ENSURE( - context, - reference_ops::QuantizedMeanOrSum<>( - GetTensorData(op_context.input), - op_context.input->params.zero_point, - op_context.input->params.scale, op_context.input->dims->data, - op_context.input->dims->size, - GetTensorData(op_context.output), - op_context.output->params.zero_point, - op_context.output->params.scale, op_context.output->dims->data, - op_context.output->dims->size, - GetTensorData(op_context.axis), num_axis, - op_context.params->keep_dims, GetTensorData(temp_index), - GetTensorData(resolved_axis), GetTensorData(temp_sum), - /*compute_sum=*/true)); + QuantizedMeanOrSum(context, &op_context, + GetTensorData(temp_index), + GetTensorData(resolved_axis), + GetTensorData(temp_sum), kernel_type, + /*compute_sum=*/true); + } else { + QuantizedMeanOrSum(context, &op_context, + GetTensorData(temp_index), + GetTensorData(resolved_axis), + GetTensorData(temp_sum), kernel_type, + /*compute_sum=*/true); } } else { - return EvalGeneric(context, node); + return EvalGeneric(context, node); } return kTfLiteOk; } -template +template TfLiteStatus EvalQuantizedProd(TfLiteContext* context, TfLiteNode* node, OpContext* op_context) { OpData* data = reinterpret_cast(node->user_data); @@ -932,7 +967,9 @@ TfLiteStatus EvalQuantizedProd(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* temp_prod; TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, &temp_prod)); - + TfLiteTensor* normalized_dims; + TF_LITE_ENSURE_OK( + context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims)); const TfLiteTensor* input = op_context->input; TfLiteTensor* output = op_context->output; @@ -941,6 +978,10 @@ TfLiteStatus EvalQuantizedProd(TfLiteContext* context, TfLiteNode* node, if (input->dims->data[i] == 0) return kTfLiteOk; } + if (IsDynamicTensor(normalized_dims)) { + TF_LITE_ENSURE_OK(context, + ResizeTempDims(context, op_context, normalized_dims)); + } // Resize the output tensor if the output tensor is dynamic. if (IsDynamicTensor(output)) { TF_LITE_ENSURE_OK(context, @@ -960,19 +1001,34 @@ TfLiteStatus EvalQuantizedProd(TfLiteContext* context, TfLiteNode* node, QuantizeMultiplier(scaling, &data->multiplier, &data->shift); } - TF_LITE_ENSURE( - context, - reference_ops::QuantizedReduceProd( - GetTensorData(input), input->params.zero_point, - GetTensorShape(input), GetTensorData(output), - output->params.zero_point, GetTensorShape(output), - GetTensorData(op_context->axis), num_axis, - op_context->params->keep_dims, GetTensorData(temp_index), - GetTensorData(resolved_axis), GetTensorData(temp_prod), - data->multiplier, data->shift)); - return kTfLiteOk; + if (kernel_type == kReference) { + TF_LITE_ENSURE( + context, + reference_ops::QuantizedReduceProd( + GetTensorData(input), input->params.zero_point, + GetTensorShape(input), GetTensorData(output), + output->params.zero_point, GetTensorShape(output), + GetTensorData(op_context->axis), num_axis, + op_context->params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis), GetTensorData(temp_prod), + data->multiplier, data->shift)); + return kTfLiteOk; + } else { + TF_LITE_ENSURE( + context, + optimized_ops::QuantizedReduceProd( + GetTensorData(input), input->params.zero_point, + GetTensorShape(input), GetTensorData(output), + output->params.zero_point, GetTensorShape(output), + GetTensorData(op_context->axis), num_axis, + GetTensorData(resolved_axis), + GetTensorData(normalized_dims), + GetTensorData(temp_prod), data->multiplier, data->shift)); + return kTfLiteOk; + } } +template TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); // As we need to support both quantized and non-quantized int8/int16 inputs, @@ -981,21 +1037,24 @@ TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) { // other non-quantized types). if (op_context.input->quantization.type != kTfLiteNoQuantization) { if (op_context.input->type == kTfLiteInt8) { - return EvalQuantizedProd(context, node, &op_context); + return EvalQuantizedProd(context, node, &op_context); } else if (op_context.input->type == kTfLiteInt16) { - return EvalQuantizedProd(context, node, &op_context); + return EvalQuantizedProd(context, node, + &op_context); } else { TF_LITE_KERNEL_LOG(context, "Unsupported quantized data type: %d", op_context.input->type); return kTfLiteError; } } else { - return EvalGeneric(context, node); + return EvalGeneric(context, node); } } } // namespace reduce +using ops::builtin::reduce::ReduceType; + TfLiteRegistration* Register_MEAN_OPT() { static TfLiteRegistration r = {reduce::Init, reduce::Free, reduce::PrepareMeanOrSum, @@ -1012,41 +1071,85 @@ TfLiteRegistration* Register_MEAN_REF() { TfLiteRegistration* Register_SUM_REF() { static TfLiteRegistration r = {reduce::Init, reduce::Free, - reduce::PrepareMeanOrSum, reduce::EvalSum}; + reduce::PrepareMeanOrSum, + reduce::EvalSum}; + return &r; +} + +TfLiteRegistration* Register_SUM_OPT() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMeanOrSum, + reduce::EvalSum}; return &r; } TfLiteRegistration* Register_REDUCE_PROD_REF() { static TfLiteRegistration r = {reduce::Init, reduce::Free, - reduce::PrepareProd, reduce::EvalProd}; + reduce::PrepareProd, + reduce::EvalProd}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_PROD_OPT() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareProd, + reduce::EvalProd}; return &r; } TfLiteRegistration* Register_REDUCE_MAX_REF() { static TfLiteRegistration r = { reduce::Init, reduce::Free, reduce::PrepareSimple, - reduce::EvalGeneric}; + reduce::EvalGeneric}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_MAX_OPT() { + static TfLiteRegistration r = { + reduce::Init, reduce::Free, reduce::PrepareSimple, + reduce::EvalGeneric}; return &r; } TfLiteRegistration* Register_REDUCE_MIN_REF() { static TfLiteRegistration r = { reduce::Init, reduce::Free, reduce::PrepareSimple, - reduce::EvalGeneric}; + reduce::EvalGeneric}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_MIN_OPT() { + static TfLiteRegistration r = { + reduce::Init, reduce::Free, reduce::PrepareSimple, + reduce::EvalGeneric}; return &r; } TfLiteRegistration* Register_REDUCE_ANY_REF() { static TfLiteRegistration r = { reduce::Init, reduce::Free, reduce::PrepareAllOrAny, - reduce::EvalGeneric}; + reduce::EvalGeneric}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_ANY_OPT() { + static TfLiteRegistration r = { + reduce::Init, reduce::Free, reduce::PrepareAllOrAny, + reduce::EvalGeneric}; return &r; } TfLiteRegistration* Register_REDUCE_ALL_REF() { static TfLiteRegistration r = { reduce::Init, reduce::Free, reduce::PrepareAllOrAny, - reduce::EvalGeneric}; + reduce::EvalGeneric}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_ALL_OPT() { + static TfLiteRegistration r = { + reduce::Init, reduce::Free, reduce::PrepareAllOrAny, + reduce::EvalGeneric}; return &r; } @@ -1058,14 +1161,14 @@ TfLiteRegistration* Register_MEAN() { #endif } -TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } +TfLiteRegistration* Register_SUM() { return Register_SUM_OPT(); } TfLiteRegistration* Register_REDUCE_PROD() { - return Register_REDUCE_PROD_REF(); + return Register_REDUCE_PROD_OPT(); } -TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); } -TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); } -TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); } -TfLiteRegistration* Register_REDUCE_ALL() { return Register_REDUCE_ALL_REF(); } +TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_OPT(); } +TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_OPT(); } +TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_OPT(); } +TfLiteRegistration* Register_REDUCE_ALL() { return Register_REDUCE_ALL_OPT(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc index 93b75003b552e5..fd0365149945ba 100644 --- a/tensorflow/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -135,6 +136,19 @@ TYPED_TEST(DynamicReductionIsCopyTestBool, ReduceIsCopy) { EXPECT_THAT(m.template GetOutput(), ElementsAreArray(data)); } +TEST(ConstFloatMeanOpTest, FoldFirstDim) { + int count = 1 * 2 * 2 * 3; + std::vector data(count); + std::iota(data.begin(), data.end(), 0); + SumOpConstModel m({TensorType_FLOAT32, {1, 2, 2, 3}}, + {TensorType_FLOAT32, {2, 2}}, {2}, {3, 0}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 12, 21, 30}))); +} + TEST(ConstFloatMeanOpTest, Flatten2ReduceDims) { std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, @@ -586,6 +600,67 @@ TEST(ConstUint8MeanOpTest, QuantizedKeepDims) { // Tests for reduce_sum +TEST(ConstFloatSumOpTest, Size1) { + std::vector data = {1.0}; + SumOpConstModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, {1}, + {0}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({1}))); +} + +TEST(ConstFloatSumOpTest, Size1Dims) { + std::vector data = {1.0, 2.0}; + SumOpConstModel m({TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {1}}, {1}, + {0}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3}))); +} + +TEST(ConstFloatSumOpTest, Size1Contiguous) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + SumOpConstModel m({TensorType_FLOAT32, {8, 1}}, {TensorType_FLOAT32, {8}}, + {1}, {1}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(data))); +} + +TEST(ConstFloatSumOpTest, Size1DisContiguous) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + SumOpConstModel m({TensorType_FLOAT32, {1, 8}}, {TensorType_FLOAT32, {1}}, + {1}, {1}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({36}))); +} + +TEST(ConstFloatSumOpTest, RedundantDimension) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + SumOpConstModel m({TensorType_FLOAT32, {1, 2, 4}}, {TensorType_FLOAT32, {2}}, + {1}, {1}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({6, 8, 10, 12}))); +} + +TEST(ConstFloatSumOpTest, AllSize1) { + std::vector data = {1.0}; + SumOpConstModel m({TensorType_FLOAT32, {1, 1, 1}}, {TensorType_FLOAT32, {1}}, + {1}, {1}, false); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({1}))); +} + TEST(ConstFloatSumOpTest, NotKeepDims) { std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, @@ -711,6 +786,28 @@ TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) { ElementsAreArray(ArrayFloatNear({1.2, 1.2}, kQuantizedTolerance))); } +TEST(ConstUint8SumOpTest, OffsetZeroPoint) { + float kQuantizedTolerance = GetTolerance(0.0, 2.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6, + 0.3, 0.1, 0.5, 0.2, 0.4, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2, 2}, -0.1, 1.0}, + {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {-1}, false); + m.QuantizeAndPopulate(m.Input(), data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.6, + 0.7, + 1.1, + 0.4, + 0.7, + 1.0, + }, + kQuantizedTolerance))); +} + TEST(ConstUint8SumOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; @@ -814,6 +911,59 @@ void ConstIntProdOpTestNotKeepDimsLarge() { {3.162341376e+11, 1.9619905536e+12}, kQuantizedTolerance))); } +template +void ConstIntProdOpTestDisContigReduction() { + const float input_min = (tensor_type == TensorType_INT16) ? -12.0 : 0.0; + const float input_max = 12.0; + const float output_min = (tensor_type == TensorType_INT16) ? -57600 : 0.0; + const float output_max = 57600; + + const std::vector data = { + 1.0, 2.0, 3.0, 4.0, 8.0, 7.0, 6.0, 5.0, 10.0, 9.0, 11.0, 12.0, + 1.0, 2.0, 3.0, 4.0, 8.0, 7.0, 6.0, 5.0, 10.0, 9.0, 11.0, 12.0}; + ProdOpConstModel m({tensor_type, {3, 2, 2, 2}, input_min, input_max}, + {tensor_type, {2}, output_min, output_max}, {2}, {1, 0}, + false); + m.QuantizeAndPopulate(m.Input(), data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + const int reduced_axis_size = 6; + const float kQuantizedStep = + GetTolerance(output_min, output_max); + const float kQuantizedTolerance = reduced_axis_size * 2 * kQuantizedStep; + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {6404.31, 15875.7, 39208.7, 57602.2}, kQuantizedTolerance))); +} + +template +void ConstIntProdOpTestContigReduction() { + const float input_min = (tensor_type == TensorType_INT16) ? -12.0 : 0.0; + const float input_max = 12.0; + const float output_min = (tensor_type == TensorType_INT16) ? -11880 : 0.0; + const float output_max = 11880; + + const std::vector data = { + 1.0, 8.0, 13.0, 4.0, 8.0, 7.0, 6.0, 5.0, 10.0, 9.0, 11.0, 12.0, + 1.0, 6.0, 9.0, 4.0, 8.0, 7.0, 6.0, 5.0, 10.0, 9.0, 11.0, 12.0}; + ProdOpConstModel m({tensor_type, {3, 2, 2, 2}, input_min, input_max}, + {tensor_type, {2}, output_min, output_max}, {2}, {2, 3}, + false); + m.QuantizeAndPopulate(m.Input(), data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + const int reduced_axis_size = 4; + const float kQuantizedStep = + GetTolerance(output_min, output_max); + const float kQuantizedTolerance = reduced_axis_size * 2 * kQuantizedStep; + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {383.951, 1680.1, 11879.3, 216.086, 1680.1, 11879.3}, + kQuantizedTolerance))); +} + TEST(ConstInt8ProdOpTest, NotKeepDimsLarge) { ConstIntProdOpTestNotKeepDimsLarge(); } @@ -822,6 +972,22 @@ TEST(ConstInt16ProdOpTest, NotKeepDimsLarge) { ConstIntProdOpTestNotKeepDimsLarge(); } +TEST(ConstInt8ProdOpTest, DisContigProdOpTest) { + ConstIntProdOpTestDisContigReduction(); +} + +TEST(ConstInt16ProdOpTest, DisContigProdOpTest) { + ConstIntProdOpTestDisContigReduction(); +} + +TEST(ConstInt8ProdOpTest, ContigProdOpTest) { + ConstIntProdOpTestContigReduction(); +} + +TEST(ConstInt16ProdOpTest, ContigProdOpTest) { + ConstIntProdOpTestContigReduction(); +} + TEST(ConstFloatProdOpTest, NotKeepDimsSmall) { const std::vector data = { -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, @@ -1222,6 +1388,21 @@ TEST(DynamicInt16MaxOpTest, Scalar) { // Tests for reduce_min +TEST(ConstFloatMinOpTest, DiscontiguousReduction) { + int count = 3 * 3 * 2 * 4; + std::vector data(count); + std::iota(data.begin(), data.end(), 0); + MinOpConstModel m({TensorType_FLOAT32, {3, 3, 2, 4}}, + {TensorType_FLOAT32, {4}}, {1}, {1}, true); + m.SetInput(data); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2, 4})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 27, + 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55}))); +} + TEST(ConstFloatMinOpTest, NotKeepDims) { std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc index ee812bf83ff5bf..95560d95a300bc 100644 --- a/tensorflow/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context_util.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" @@ -109,6 +110,13 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context, return context->ResizeTensor(context, output, output_shape); } +bool ShapeHasRank(const TfLiteIntArray* shape) { + // Note that we consider scalar as false here because there is + // no differentiation between scalar and dynamic properly supported. + if (shape == nullptr || shape->size == 0) return false; + return true; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -135,6 +143,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim, "Slice op only supports 1D-5D input arrays."); + // If the shape of output is fully specified then resize even if + // the input shape is not staticly defined. + if (!HasUnspecifiedDimension(output) && ShapeHasRank(output->dims)) { + return kTfLiteOk; + } // Postpone allocation of output if any of the indexing tensors is not // constant, or the input tensor has dynamic dimension. if (!(IsConstantTensor(begin) && IsConstantTensor(size)) || diff --git a/tensorflow/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc index ef5491484e52cd..a011efb28c615d 100644 --- a/tensorflow/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_type.h" @@ -42,7 +43,8 @@ class SliceOpModel : public SingleOpModel { std::initializer_list size_shape, std::initializer_list size_data, TensorType tensor_index_type, TensorType tensor_input_type, - TestType input_tensor_types) { + TestType input_tensor_types, + std::initializer_list output_shape = {}) { input_ = AddInput(tensor_input_type); if (input_tensor_types == TestType::kDynamic) { begin_ = AddInput(tensor_index_type); @@ -52,7 +54,7 @@ class SliceOpModel : public SingleOpModel { AddConstInput(GetTensorType(), begin_data, begin_shape); size_ = AddConstInput(GetTensorType(), size_data, size_shape); } - output_ = AddOutput(tensor_input_type); + output_ = AddOutput(TensorData(tensor_input_type, output_shape)); SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions, CreateSliceOptions(builder_).Union()); BuildInterpreter({input_shape, begin_shape, size_shape}); @@ -75,6 +77,10 @@ class SliceOpModel : public SingleOpModel { } std::vector GetOutputShape() { return GetTensorShape(output_); } + const TfLiteTensor* GetOutputTensor() { + return interpreter_->tensor(output_); + } + private: int input_; int begin_; @@ -273,6 +279,17 @@ TEST_P(SliceOpTest, SliceInt64) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); } +TEST_P(SliceOpTest, SliceInt64StaticOutput) { + SliceOpModel m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4}, + {2, 1, -1, 1}, TensorType_INT32, + TensorType_INT64, GetParam(), {2, 1, 3, 1}); + m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); + EXPECT_NE(m.GetOutputTensor()->allocation_type, kTfLiteDynamic); +} + TEST_P(SliceOpTest, SliceBool) { SliceOpModel m({2, 3}, {2}, {1, 0}, {2}, {-1, 2}, TensorType_INT32, TensorType_BOOL, GetParam()); diff --git a/tensorflow/lite/kernels/unsorted_segment_prod.cc b/tensorflow/lite/kernels/unsorted_segment_prod.cc index 8fc0acb93e9394..1eb4528a688293 100644 --- a/tensorflow/lite/kernels/unsorted_segment_prod.cc +++ b/tensorflow/lite/kernels/unsorted_segment_prod.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -27,14 +29,32 @@ namespace unsorted_segment_prod { static const int kInputDataTensor = 0; static const int kInputSegmentIdsTensor = 1; +static const int kInputNumSegmentsTensor = 2; static const int kOutputTensor = 0; TfLiteStatus ResizeOutputTensor(TfLiteContext* context, const TfLiteTensor* data, - const int num_segments, TfLiteTensor* output) { + const TfLiteTensor* segment_ids, + const TfLiteTensor* num_segments, + TfLiteTensor* output) { + // We take the first element in num_segments as the valid number of segments + // in the case where num_segments tensor is initialized with more than one + // elements + TF_LITE_ENSURE(context, (num_segments->dims->size == 1 && + num_segments->dims->data[0] == 1) || + num_segments->dims->size == 0); + int32_t output_dim = GetTensorData(num_segments)[0]; + const int segment_id_size = segment_ids->dims->data[0]; + TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]); + int max_index = -1; + for (int i = 0; i < segment_id_size; i++) { + max_index = std::max(GetTensorData(segment_ids)[i], max_index); + } + TF_LITE_ENSURE(context, max_index < output_dim); + const int data_rank = NumDimensions(data); TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data)); - output_shape->data[0] = num_segments; + output_shape->data[0] = output_dim; for (int i = 1; i < data_rank; ++i) { output_shape->data[i] = data->dims->data[i]; } @@ -42,7 +62,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* data; TF_LITE_ENSURE_OK(context, @@ -50,6 +70,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* segment_ids; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor, &segment_ids)); + const TfLiteTensor* num_segments; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputTensor, &output)); @@ -57,35 +81,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->type == kTfLiteInt32 || data->type == kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32); - if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) { + if (IsDynamicTensor(data) || !IsConstantTensor(segment_ids) || + !IsConstantTensor(num_segments)) { SetTensorToDynamic(output); return kTfLiteOk; } - - const auto no_segments = - reinterpret_cast( - node->builtin_data) - ->num_segments; - return ResizeOutputTensor(context, data, no_segments, output); + return ResizeOutputTensor(context, data, segment_ids, num_segments, output); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const auto* params = reinterpret_cast( - node->builtin_data); const TfLiteTensor* data; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputDataTensor, &data)); const TfLiteTensor* segment_ids; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor, &segment_ids)); + const TfLiteTensor* num_segments; + TF_LITE_ENSURE_OK( + context, + GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputTensor, &output)); if (IsDynamicTensor(output)) { - TF_LITE_ENSURE_OK( - context, - ResizeOutputTensor(context, data, params->num_segments, output)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids, + num_segments, output)); } TF_LITE_ENSURE_EQ(context, GetTensorShape(data).Dims(0), GetTensorShape(segment_ids).Dims(0)); @@ -94,8 +115,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { reference_ops::UnsortedSegmentProd( \ GetTensorShape(data), GetTensorData(data), \ GetTensorShape(segment_ids), GetTensorData(segment_ids), \ - params->num_segments, GetTensorShape(output), \ - GetTensorData(output)); + GetTensorShape(output), GetTensorData(output)); switch (data->type) { case kTfLiteInt32: TF_LITE_UNSORTED_SEGMENT_PROD(int32_t); diff --git a/tensorflow/lite/kernels/unsorted_segment_prod_test.cc b/tensorflow/lite/kernels/unsorted_segment_prod_test.cc index d892943d9a8974..45b86fa8989012 100644 --- a/tensorflow/lite/kernels/unsorted_segment_prod_test.cc +++ b/tensorflow/lite/kernels/unsorted_segment_prod_test.cc @@ -29,33 +29,50 @@ template class UnsortedSegmentProdOpModel : public SingleOpModel { public: UnsortedSegmentProdOpModel(const TensorData& data, - const TensorData& segment_ids, int num_segments) { + const TensorData& segment_ids, + const TensorData& num_segments) { data_id_ = AddInput(data); segment_ids_id_ = AddInput(segment_ids); + num_segments_id_ = AddInput(num_segments); output_id_ = AddOutput(data.type); - SetBuiltinOp( - BuiltinOperator_UNSORTED_SEGMENT_PROD, - BuiltinOptions_UnsortedSegmentProdOptions, - CreateUnsortedSegmentProdOptions(builder_, num_segments).Union()); - BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_)}); + SetBuiltinOp(BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOptions_UnsortedSegmentProdOptions, 0); + BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_), + GetShape(num_segments_id_)}); } int data() const { return data_id_; } int segment_ids() const { return segment_ids_id_; } + int num_segments() const { return num_segments_id_; } std::vector GetOutput() { return ExtractVector(output_id_); } std::vector GetOutputShape() { return GetTensorShape(output_id_); } protected: int data_id_; int segment_ids_id_; + int num_segments_id_; int output_id_; }; TEST(UnsortedSegmentProdOpModelTest, Int32Test_Simple) { UnsortedSegmentProdOpModel model({TensorType_INT32, {8}}, - {TensorType_INT32, {8}}, 8); + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); model.PopulateTensor(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); model.PopulateTensor(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor(model.num_segments(), {8}); + ASSERT_EQ(model.Invoke(), kTfLiteOk); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); +} + +TEST(UnsortedSegmentProdOpModelTest, TestSkipNegSegmentId) { + UnsortedSegmentProdOpModel model({TensorType_INT32, {8}}, + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 4, 3, 2, 1}); + model.PopulateTensor(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, -1}); + model.PopulateTensor(model.num_segments(), {8}); ASSERT_EQ(model.Invoke(), kTfLiteOk); EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 1, 1, 1, 1, 1, 96})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8})); @@ -63,10 +80,12 @@ TEST(UnsortedSegmentProdOpModelTest, Int32Test_Simple) { TEST(UnsortedSegmentProdOpModelTest, Int32Test_Simple2D) { UnsortedSegmentProdOpModel model({TensorType_INT32, {3, 4}}, - {TensorType_INT32, {3}}, 2); + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6, 7, 8, 4, 3, 2, 1}); model.PopulateTensor(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor(model.num_segments(), {2}); ASSERT_EQ(model.Invoke(), kTfLiteOk); EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 6, 4, 5, 6, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); @@ -74,10 +93,12 @@ TEST(UnsortedSegmentProdOpModelTest, Int32Test_Simple2D) { TEST(UnsortedSegmentProdOpModelTest, FloatTest_Simple) { UnsortedSegmentProdOpModel model({TensorType_FLOAT32, {8}}, - {TensorType_INT32, {8}}, 8); + {TensorType_INT32, {8}}, + {TensorType_INT32, {1}}); model.PopulateTensor(model.data(), {1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0}); model.PopulateTensor(model.segment_ids(), {1, 0, 1, 7, 7, 7, 7, 7}); + model.PopulateTensor(model.num_segments(), {8}); ASSERT_EQ(model.Invoke(), kTfLiteOk); EXPECT_THAT(model.GetOutput(), ElementsAreArray( @@ -87,10 +108,12 @@ TEST(UnsortedSegmentProdOpModelTest, FloatTest_Simple) { TEST(UnsortedSegmentProdOpModelTest, FloatTest_Simple2D) { UnsortedSegmentProdOpModel model({TensorType_FLOAT32, {3, 4}}, - {TensorType_INT32, {3}}, 2); + {TensorType_INT32, {3}}, + {TensorType_INT32, {1}}); model.PopulateTensor(model.data(), {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 4.0, 3.0, 2.0, 1.0}); model.PopulateTensor(model.segment_ids(), {0, 1, 0}); + model.PopulateTensor(model.num_segments(), {2}); ASSERT_EQ(model.Invoke(), kTfLiteOk); EXPECT_THAT(model.GetOutput(), ElementsAreArray( diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 05e6801803e7e3..eb6b609b0a6a51 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -1123,9 +1123,9 @@ table DynamicUpdateSliceOptions { } table UnsortedSegmentProdOptions { - num_segments:int; } + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index c541afc41bd81c..be48a319e9b2e9 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -11055,21 +11055,13 @@ flatbuffers::Offset CreateDynamicUpdateSliceOptions(f struct UnsortedSegmentProdOptionsT : public flatbuffers::NativeTable { typedef UnsortedSegmentProdOptions TableType; - int32_t num_segments = 0; }; struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef UnsortedSegmentProdOptionsT NativeTableType; typedef UnsortedSegmentProdOptionsBuilder Builder; - enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_NUM_SEGMENTS = 4 - }; - int32_t num_segments() const { - return GetField(VT_NUM_SEGMENTS, 0); - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NUM_SEGMENTS, 4) && verifier.EndTable(); } UnsortedSegmentProdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -11081,9 +11073,6 @@ struct UnsortedSegmentProdOptionsBuilder { typedef UnsortedSegmentProdOptions Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_num_segments(int32_t num_segments) { - fbb_.AddElement(UnsortedSegmentProdOptions::VT_NUM_SEGMENTS, num_segments, 0); - } explicit UnsortedSegmentProdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -11096,10 +11085,8 @@ struct UnsortedSegmentProdOptionsBuilder { }; inline flatbuffers::Offset CreateUnsortedSegmentProdOptions( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t num_segments = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { UnsortedSegmentProdOptionsBuilder builder_(_fbb); - builder_.add_num_segments(num_segments); return builder_.Finish(); } @@ -16267,7 +16254,6 @@ inline UnsortedSegmentProdOptionsT *UnsortedSegmentProdOptions::UnPack(const fla inline void UnsortedSegmentProdOptions::UnPackTo(UnsortedSegmentProdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = num_segments(); _o->num_segments = _e; } } inline flatbuffers::Offset UnsortedSegmentProdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -16278,10 +16264,8 @@ inline flatbuffers::Offset CreateUnsortedSegmentProd (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentProdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _num_segments = _o->num_segments; return tflite::CreateUnsortedSegmentProdOptions( - _fbb, - _num_segments); + _fbb); } inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/lite/tensorflow_profiler_logger.cc b/tensorflow/lite/tensorflow_profiler_logger.cc index a1b14f07231e80..81a16ab119c42a 100644 --- a/tensorflow/lite/tensorflow_profiler_logger.cc +++ b/tensorflow/lite/tensorflow_profiler_logger.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include #include +#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -31,7 +33,9 @@ struct Statistics { uint64_t total_bytes_allocated = 0LL; uint64_t peak_bytes_in_use = 0LL; }; -Statistics g_stat; +static Statistics g_stat; + +static char g_current_op_name[256]; // Adds memory trace information for TensorFlow profiler. // `is_allocating`: Whether memory is being allocated or deallocated. @@ -79,7 +83,14 @@ void AddTraceMe(bool is_allocating, TfLiteTensor* tensor, size_t allocation_bytes) { if (tensor == nullptr || allocation_bytes == 0) return; int64_t tensor_id = reinterpret_cast(tensor->data.raw); - std::string name = tensor->name ? tensor->name : ""; + std::string name; + if (g_current_op_name[0]) { + name = g_current_op_name; + } + if (tensor->name) { + name += ":"; + name += tensor->name; + } std::string dims = tensor->dims ? GetShapeDebugString(tensor->dims) : "[]"; int64_t requested_bytes = is_allocating ? allocation_bytes : 0; const std::string allocator_name = "_tflite_native_dynamic"; @@ -96,6 +107,14 @@ void AddTraceMe(bool is_allocating, TfLiteTensor* tensor, } // namespace +void OnTfLiteOpInvoke(const char* op_name, const int node_index) { + snprintf(g_current_op_name, sizeof(g_current_op_name), "%s_%d", op_name, + node_index); + // Updates TF's current annotation object by creating scoped annotation obj. + tensorflow::profiler::ScopedMemoryDebugAnnotation annotation( + g_current_op_name); +} + void OnTfLiteTensorAlloc(TfLiteTensor* tensor, size_t num_bytes) { AddTraceMe(/*is_allocating=*/true, tensor, num_bytes); } diff --git a/tensorflow/lite/tensorflow_profiler_logger.h b/tensorflow/lite/tensorflow_profiler_logger.h index 48d4c24fdd6cb4..89532a2080424a 100644 --- a/tensorflow/lite/tensorflow_profiler_logger.h +++ b/tensorflow/lite/tensorflow_profiler_logger.h @@ -19,13 +19,21 @@ limitations under the License. #include #include +#include "tensorflow/lite/core/macros.h" + struct TfLiteTensor; namespace tflite { +// Records an op invocation with `op_name` and `node_index`. +TFLITE_ATTRIBUTE_WEAK void OnTfLiteOpInvoke(const char* op_name, + const int node_index); + // Records an event of `num_bytes` of memory allocated for `tensor`. -void OnTfLiteTensorAlloc(size_t num_bytes, TfLiteTensor* tensor); +TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor, + size_t num_bytes); + // Records an event of memory deallocated for `tensor`. -void OnTfLiteTensorDealloc(TfLiteTensor* tensor); +TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor); } // namespace tflite #endif // TENSORFLOW_LITE_TENSORFLOW_PROFILER_LOGGER_H_ diff --git a/tensorflow/lite/tensorflow_profiler_logger_shim.cc b/tensorflow/lite/tensorflow_profiler_logger_shim.cc index 73459419f8ec4b..c447474f0e6c22 100644 --- a/tensorflow/lite/tensorflow_profiler_logger_shim.cc +++ b/tensorflow/lite/tensorflow_profiler_logger_shim.cc @@ -21,11 +21,18 @@ limitations under the License. // benchmark library should have tensor_profiler_logger dependency. // Strong symbol definitions can be found in tensorflow_profiler_logger.cc. +namespace tflite { + +TFLITE_ATTRIBUTE_WEAK void OnTfLiteOpInvoke(const char* op_name, + const int node_index) {} + // No-op for the weak symbol. Overridden by a strong symbol in // tensorflow_profiler_logger.cc. -TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(size_t num_bytes, - TfLiteTensor* tensor) {} +TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorAlloc(TfLiteTensor* tensor, + size_t num_bytes) {} // No-op for the weak symbol. Overridden by a strong symbol in // tensorflow_profiler_logger.cc. TFLITE_ATTRIBUTE_WEAK void OnTfLiteTensorDealloc(TfLiteTensor* tensor) {} + +} // namespace tflite diff --git a/tensorflow/lite/testing/build_def.bzl b/tensorflow/lite/testing/build_def.bzl index 8d613276d5c26e..d2d83ad0bcb3da 100644 --- a/tensorflow/lite/testing/build_def.bzl +++ b/tensorflow/lite/testing/build_def.bzl @@ -180,6 +180,7 @@ def generated_test_models(): "unique", "unpack", "unroll_batch_matmul", + "unsorted_segment_prod", "where", "where_v2", "while", @@ -267,6 +268,7 @@ def generated_test_models_failing(conversion_mode, delegate): "topk", "transpose", "unique", + "unsorted_segment_prod", "where", "where_v2", "while", diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 3ca1911b65a89f..d9f998cc2ca9c5 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -186,6 +186,7 @@ from tensorflow.lite.testing.op_tests.unique import make_unique_tests from tensorflow.lite.testing.op_tests.unpack import make_unpack_tests from tensorflow.lite.testing.op_tests.unroll_batch_matmul import make_unroll_batch_matmul_tests +from tensorflow.lite.testing.op_tests.unsorted_segment_prod import make_unsorted_segment_prod_tests from tensorflow.lite.testing.op_tests.where import make_where_tests from tensorflow.lite.testing.op_tests.where_v2 import make_where_v2_tests from tensorflow.lite.testing.op_tests.while_loop import make_while_tests diff --git a/tensorflow/lite/testing/op_tests/unsorted_segment_prod.py b/tensorflow/lite/testing/op_tests/unsorted_segment_prod.py new file mode 100644 index 00000000000000..c9ad2677750024 --- /dev/null +++ b/tensorflow/lite/testing/op_tests/unsorted_segment_prod.py @@ -0,0 +1,130 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test configs for unsorted_segment_prod.""" + +import tensorflow.compat.v1 as tf +from tensorflow.lite.testing.zip_test_utils import create_tensor_data +from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests +from tensorflow.lite.testing.zip_test_utils import register_make_test_function + + +@register_make_test_function() +def make_unsorted_segment_prod_tests(options): + """Make a set of tests for unsorted_segment_prod op.""" + test_parameters = [{ + "data": [[5]], + "segment_id": [[0, 1, 1, 0, 1]], + "num_segments": [2], + "dtype": [tf.int32, tf.float32], + "multi_node": [0] + }, { + "data": [[2, 3, 4], [2, 5, 2]], + "segment_id": [[0, 1]], + "num_segments": [2], + "dtype": [tf.int32, tf.float32], + "multi_node": [0] + }, { + "data": [[4]], + "segment_id": [[0, 0, 1, 8]], + "num_segments": [9], + "dtype": [tf.int32, tf.float32], + "multi_node": [0] + }, { + "data": [[4]], + "segment_id_shape": [[4]], + "segment_id_min": [0], + "segment_id_max": [1], + "num_segments": [2], + "dtype": [tf.int32, tf.float32], + "segment_id_2": [[0, 0]], + "num_segments_2": [1], + "multi_node": [1] + }] + + def build_graph_one_node(parameters): + data_tensor = tf.compat.v1.placeholder( + dtype=parameters["dtype"], name="data", shape=parameters["data"]) + segment_ids_tensor = tf.constant( + parameters["segment_id"], dtype=tf.int32, name="segment_ids") + num_segments_tensor = tf.constant( + parameters["num_segments"], + dtype=tf.int32, + shape=[], + name="num_segments") + output = tf.math.unsorted_segment_prod(data_tensor, segment_ids_tensor, + num_segments_tensor) + return [data_tensor], [output] + + +# test cases for handling dynamically shaped input tensor + def build_graph_multi_node(parameters): + data_tensor = tf.compat.v1.placeholder( + dtype=parameters["dtype"], name="data", shape=parameters["data"]) + segment_ids_tensor = tf.compat.v1.placeholder( + dtype=tf.int32, + name="segment_ids", + shape=parameters["segment_id_shape"]) + num_segments_tensor = tf.constant( + parameters["num_segments"], + dtype=tf.int32, + shape=[], + name="num_segments") + intermediate_tensor = tf.math.unsorted_segment_prod(data_tensor, + segment_ids_tensor, + num_segments_tensor) + segment_ids_tensor_2 = tf.constant( + parameters["segment_id_2"], dtype=tf.int32, name="segment_ids_2") + num_segments_tensor_2 = tf.constant( + parameters["num_segments_2"], + dtype=tf.int32, + shape=[], + name="num_segments_2") + output = tf.math.unsorted_segment_prod(intermediate_tensor, + segment_ids_tensor_2, + num_segments_tensor_2) + return [data_tensor, segment_ids_tensor], [output] + + def build_graph(parameters): + multi_node = parameters["multi_node"] + if multi_node: + return build_graph_multi_node(parameters) + + return build_graph_one_node(parameters) + + def build_inputs_one_node(parameters, sess, inputs, outputs): + data_value = create_tensor_data( + parameters["dtype"], shape=parameters["data"]) + return [data_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [data_value]))) + + def build_inputs_multi_node(parameters, sess, inputs, outputs): + data_value = create_tensor_data( + dtype=parameters["dtype"], shape=parameters["data"]) + segment_id_value = create_tensor_data( + dtype=tf.int32, + shape=parameters["segment_id_shape"], + min_value=parameters["segment_id_min"], + max_value=parameters["segment_id_max"]) + return [data_value, segment_id_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [data_value, segment_id_value]))) + + def build_inputs(parameters, sess, inputs, outputs): + multi_node = parameters["multi_node"] + if multi_node: + return build_inputs_multi_node(parameters, sess, inputs, outputs) + + return build_inputs_one_node(parameters, sess, inputs, outputs) + + make_zip_of_tests(options, test_parameters, build_graph, build_inputs) diff --git a/tensorflow/lite/tools/cmake/modules/eigen.cmake b/tensorflow/lite/tools/cmake/modules/eigen.cmake index 402592f7d84db1..cae290760f2294 100644 --- a/tensorflow/lite/tools/cmake/modules/eigen.cmake +++ b/tensorflow/lite/tools/cmake/modules/eigen.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( eigen GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git # Sync with tensorflow/third_party/eigen3/workspace.bzl - GIT_TAG b02c384ef4e8eba7b8bdef16f9dc6f8f4d6a6b2b + GIT_TAG 0e187141679fdb91da33249d18cb79a011c0e2ea # It's not currently (cmake 3.17) possible to shallow clone with a GIT TAG # as cmake attempts to git checkout the commit hash after the clone # which doesn't work as it's a shallow clone hence a different commit hash. diff --git a/tensorflow/lite/tools/delegates/compatibility/nnapi/BUILD b/tensorflow/lite/tools/delegates/compatibility/nnapi/BUILD new file mode 100644 index 00000000000000..98467c621b16bf --- /dev/null +++ b/tensorflow/lite/tools/delegates/compatibility/nnapi/BUILD @@ -0,0 +1,31 @@ +# BUILD rules for NNAPI delegate compatibility checking. + +cc_library( + name = "nnapi_compatibility_lib", + srcs = [ + "nnapi_compatibility_lib.cc", + ], + hdrs = [ + "nnapi_compatibility_lib.h", + ], + deps = [ + "//tensorflow/lite:framework_stable", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + ], +) + +cc_test( + name = "nnapi_compatibility_lib_test", + srcs = [ + "nnapi_compatibility_lib_test.cc", + ], + deps = [ + ":nnapi_compatibility_lib", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.cc b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.cc new file mode 100644 index 00000000000000..651e6d6195c54e --- /dev/null +++ b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.cc @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.h" + +#include +#include +#include + +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/minimal_logging.h" + +namespace tflite { +namespace tools { + +using ::tflite::delegate::nnapi::NNAPIValidationFailure; + +TfLiteStatus CheckCompatibility( + TfLiteContext* context, int32_t runtime_feature_level, + std::vector* supported_nodes, + std::map>* failures_by_node) { + if (!context) { + TFLITE_LOG_PROD_ONCE(TFLITE_LOG_ERROR, "Context is nullptr."); + return kTfLiteError; + } + // Gets execution plan. + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + + // Validates compatibility for each node. + for (int node_index : TfLiteIntArrayView(execution_plan)) { + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Node index: %d", node_index); + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + std::vector map_failures; + if (NNAPIDelegateKernel::Validate( + context, registration, runtime_feature_level, node, + /* is_accelerator_specified= */ true, + /* vendor_plugin= */ nullptr, &map_failures)) { + TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Built-in Code: %d", + registration->builtin_code); + if (supported_nodes) { + supported_nodes->push_back(node_index); + } + } else { + if (failures_by_node) { + (*failures_by_node)[node_index] = std::move(map_failures); + } + } + } + return kTfLiteOk; +} + +} // namespace tools +} // namespace tflite diff --git a/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.h b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.h new file mode 100644 index 00000000000000..2fb4d03f664266 --- /dev/null +++ b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.h @@ -0,0 +1,102 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TOOLS_DELEGATES_COMPATIBILITY_NNAPI_NNAPI_COMPATIBILITY_LIB_H_ +#define TENSORFLOW_LITE_TOOLS_DELEGATES_COMPATIBILITY_NNAPI_NNAPI_COMPATIBILITY_LIB_H_ + +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h" + +namespace tflite { +namespace tools { + +// Check if the given TFLite flatbuffer model is compatible with NNAPI delegate. +// WARNING: This is an experimental API and subject to change. +TfLiteStatus CheckCompatibility( + TfLiteContext* context, int32_t runtime_feature_level, + std::vector* supported_nodes, + std::map>* + failures_by_node); + +// This utility delegate is required because some TfLiteContext related +// functions are forbidden if not calling in delegate. +// WARNING: This is an experimental class and subject to change. +class CompatibilityCheckerDelegate : public TfLiteDelegate { + public: + explicit CompatibilityCheckerDelegate(int32_t runtime_feature_level) + : TfLiteDelegate(TfLiteDelegateCreate()), + runtime_feature_level_(runtime_feature_level), + supported_nodes_(), + failures_by_node_() { + Prepare = DoPrepare; + CopyFromBufferHandle = DoCopyFromBufferHandle; + CopyToBufferHandle = DoCopyToBufferHandle; + FreeBufferHandle = DoFreeBufferHandle; + data_ = &delegate_data_; + } + + std::vector GetSupportedNodes() { return supported_nodes_; } + std::map> + GetFailuresByNode() { + return failures_by_node_; + } + + protected: + static TfLiteStatus DoPrepare(TfLiteContext* context, + TfLiteDelegate* delegate) { + auto self = reinterpret_cast(delegate); + TF_LITE_ENSURE_OK(context, + CheckCompatibility(context, self->runtime_feature_level_, + &(self->supported_nodes_), + &(self->failures_by_node_))); + return kTfLiteOk; + } + + // This function is not expected to be called in this delegate. + static TfLiteStatus DoCopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + return kTfLiteError; + } + + // This function is not expected to be called in this delegate. + static TfLiteStatus DoCopyToBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + return kTfLiteError; + } + + // There is no buffer handle in this delegate. + static void DoFreeBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) {} + + private: + int delegate_data_; + int runtime_feature_level_; + std::vector supported_nodes_; + std::map> + failures_by_node_; +}; + +} // namespace tools +} // namespace tflite + +#endif // TENSORFLOW_LITE_TOOLS_DELEGATES_COMPATIBILITY_NNAPI_NNAPI_COMPATIBILITY_LIB_H_ diff --git a/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib_test.cc b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib_test.cc new file mode 100644 index 00000000000000..691768f8ef0de8 --- /dev/null +++ b/tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/delegates/compatibility/nnapi/nnapi_compatibility_lib.h" + +#include +#include +#include +#include + +#include +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace tools { + +namespace { + +class AddOpModel : public SingleOpModel { + public: + AddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type, + CompatibilityCheckerDelegate* checker_delegate) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + SetDelegate(checker_delegate); + // Builds interpreter and applies delegate. + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + protected: + int input1_; + int input2_; + int output_; +}; + +} // namespace + +TEST(NnapiDelegateCompabilityTest, InvalidInput) { + EXPECT_EQ(CheckCompatibility(nullptr, 0, nullptr, nullptr), kTfLiteError); +} + +TEST(NnapiDelegateCompabilityTest, CompatibleModel) { + CompatibilityCheckerDelegate checker_delegate( + tflite::delegate::nnapi::kMinSdkVersionForNNAPI13); + AddOpModel add_op_model( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, &checker_delegate); + EXPECT_EQ(checker_delegate.GetSupportedNodes().size(), 1); + EXPECT_EQ(checker_delegate.GetFailuresByNode().size(), 0); +} + +TEST(NnapiDelegateCompabilityTest, IncompatibleModel) { + CompatibilityCheckerDelegate checker_delegate( + tflite::delegate::nnapi::kMinSdkVersionForNNAPI13); + // No activation function is supported for INT32 tensor type. + AddOpModel add_op_model( + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}, ActivationFunctionType_RELU_N1_TO_1, + &checker_delegate); + EXPECT_EQ(checker_delegate.GetSupportedNodes().size(), 0); + EXPECT_EQ(checker_delegate.GetFailuresByNode().size(), 1); +} + +} // namespace tools +} // namespace tflite diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD index c2b12b7741dfc6..a3b0bc7bbaa8c9 100644 --- a/tensorflow/lite/tools/evaluation/BUILD +++ b/tensorflow/lite/tools/evaluation/BUILD @@ -55,6 +55,7 @@ cc_library( "//conditions:default": [], }) + select({ "//tensorflow:linux_s390x": [], + "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [], "//conditions:default": [ "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", ], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b39301cb36fbdd..b26f92c10d71c2 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -973,11 +973,21 @@ py_library( deps = [ ":array_ops", ":array_ops_gen", + ":control_flow_ops", + ":control_flow_util", ":math_ops", + ":math_ops_gen", + ":pywrap_tfe", + ":resource_variable_ops_gen", ":sparse_ops", "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", - "//tensorflow/python/framework", + "//tensorflow/python/client:pywrap_tf_session", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:sparse_tensor", + "//tensorflow/python/framework:tensor_util", ], ) diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 632563ec5d92e9..303cbb56c3f300 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -84,6 +84,7 @@ def loop_body(self_x): from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.types import distribute from tensorflow.python.util import nest +from tensorflow.python.util import variable_utils PYTHON_MAX_ITERATIONS = 100000000 # Fails in about one minute for empty loops. @@ -288,6 +289,19 @@ def _verify_tf_loop_vars(init_vars, try: nest.assert_same_structure(init, entry, expand_composites=True) + except (ValueError, TypeError): + # `Variable`s in `init` may be implicitly converted to `Tensor`s. Convert + # `ResourceVariable`s to Tensors so tf.nest.assert_same_structure + # won't break due to type spec mismatches between `ResourceVariable`s and + # `Tensor`s. + try: + init_tensors = variable_utils.convert_variables_to_tensors(init) + nest.assert_same_structure(init_tensors, entry, expand_composites=True) + except (ValueError, TypeError) as e: + raise TypeError("'{}' does not have the same nested structure after one" + ' iteration.\n\n{}'.format(name, e)) from e + + try: nest.assert_same_structure(entry, exit_, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError("'{}' does not have the same nested structure after one" @@ -355,10 +369,20 @@ def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names): for name, body_var, orelse_var in named_vars: try: nest.assert_same_structure(body_var, orelse_var, expand_composites=True) - except (ValueError, TypeError) as e: - raise TypeError( - "'{}' must have the same nested structure in the main and else" - ' branches:\n\n{}'.format(name, str(e))) from e + except (ValueError, TypeError): + # One branch of cond could be a `Tensor`, while the other branch could be + # a `ResourceVariable`. Convert `ResourceVariable`s to `Tensor`s so + # assert_same_structure won't fail. + try: + body_var_tensors = variable_utils.convert_variables_to_tensors(body_var) + orelse_var_tensors = variable_utils.convert_variables_to_tensors( + orelse_var) + nest.assert_same_structure(body_var_tensors, orelse_var_tensors, + expand_composites=True) + except (ValueError, TypeError) as e: + raise TypeError( + "'{}' must have the same nested structure in the main and else" + ' branches:\n\n{}'.format(name, str(e))) from e nest.map_structure( functools.partial(verify_single_cond_var, name), body_var, orelse_var) diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index a8888a1656cc58..0b407dfaf96350 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -65,9 +65,6 @@ py_test( srcs = ["ast_util_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - ], deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -80,9 +77,6 @@ py_test( srcs = ["cache_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - ], deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -95,9 +89,6 @@ py_test( srcs = ["cfg_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - ], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/checkpoint/BUILD b/tensorflow/python/checkpoint/BUILD index e63f6e02b3c8c3..d412003ab656aa 100644 --- a/tensorflow/python/checkpoint/BUILD +++ b/tensorflow/python/checkpoint/BUILD @@ -176,6 +176,7 @@ py_library( deps = [ ":util", "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", "//tensorflow/python:util", "//tensorflow/python/trackable:base", "//tensorflow/python/trackable:converter", diff --git a/tensorflow/python/checkpoint/checkpoint_view.py b/tensorflow/python/checkpoint/checkpoint_view.py index 7c11b46af77f72..577b8fff20e36c 100644 --- a/tensorflow/python/checkpoint/checkpoint_view.py +++ b/tensorflow/python/checkpoint/checkpoint_view.py @@ -17,7 +17,9 @@ from tensorflow.core.protobuf import trackable_object_graph_pb2 +from tensorflow.python.checkpoint import trackable_view from tensorflow.python.framework import errors_impl +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.trackable import base from tensorflow.python.training import py_checkpoint_reader @@ -75,3 +77,50 @@ def descendants(self): all_nodes.append(child.node_id) to_visit.append(child.node_id) return all_nodes + + def match(self, trackable_object): + """Returns all matching trackables between CheckpointView and Trackable. + + Args: + trackable_object: `Trackable` root. + + Returns: + Dictionary containing all overlapping trackables that maps `node_id` to + `Trackable`. + """ + if not isinstance(trackable_object, base.Trackable): + raise ValueError(f"Expected a Trackable, got {trackable_object} of type " + "{type(trackable_object)}.") + + overlapping_nodes = {} + # Root node is always matched. + overlapping_nodes[0] = trackable_object + + # Queue of tuples of node_id and trackable. + to_visit = collections.deque([(0, trackable_object)]) + visited = set() + view = trackable_view.TrackableView(trackable_object) + while to_visit: + current_node_id, current_trackable = to_visit.popleft() + trackable_children = view.children(current_trackable) + for child_name, child_node_id in self.children(current_node_id).items(): + if child_node_id in visited or child_node_id == 0: + continue + if child_name in trackable_children: + current_assignment = overlapping_nodes.get(child_node_id) + if current_assignment is None: + overlapping_nodes[child_node_id] = trackable_children[child_name] + to_visit.append((child_node_id, trackable_children[child_name])) + else: + # The object was already mapped for this checkpoint load, which + # means we don't need to do anything besides check that the mapping + # is consistent (if the dependency DAG is not a tree then there are + # multiple paths to the same object). + if current_assignment is not trackable_children[child_name]: + logging.warning( + "Inconsistent references when matching the checkpoint into " + "this object graph. The referenced objects are: " + f"({current_assignment} and " + f"{trackable_children[child_name]}).") + visited.add(current_node_id) + return overlapping_nodes diff --git a/tensorflow/python/checkpoint/checkpoint_view_test.py b/tensorflow/python/checkpoint/checkpoint_view_test.py index 2a434f49781aca..886a6ee19e9454 100644 --- a/tensorflow/python/checkpoint/checkpoint_view_test.py +++ b/tensorflow/python/checkpoint/checkpoint_view_test.py @@ -19,15 +19,14 @@ from tensorflow.python.checkpoint import checkpoint as trackable_utils from tensorflow.python.checkpoint import checkpoint_view from tensorflow.python.eager import test -from tensorflow.python.trackable import base +from tensorflow.python.trackable import autotrackable class CheckpointViewTest(test.TestCase): def test_children(self): - root = base.Trackable() - leaf = base.Trackable() - root._track_trackable(leaf, name="leaf") + root = autotrackable.AutoTrackable() + root.leaf = autotrackable.AutoTrackable() root_ckpt = trackable_utils.Checkpoint(root=root) root_save_path = root_ckpt.save( os.path.join(self.get_temp_dir(), "root_ckpt")) @@ -38,9 +37,8 @@ def test_children(self): self.assertEqual(1, node_id) def test_all_nodes(self): - root = base.Trackable() - leaf = base.Trackable() - root._track_trackable(leaf, name="leaf") + root = autotrackable.AutoTrackable() + root.leaf = autotrackable.AutoTrackable() root_ckpt = trackable_utils.Checkpoint(root=root) root_save_path = root_ckpt.save( os.path.join(self.get_temp_dir(), "root_ckpt")) @@ -49,5 +47,55 @@ def test_all_nodes(self): self.assertEqual(0, all_nodes[0]) self.assertEqual(1, all_nodes[1]) + def test_match(self): + root1 = autotrackable.AutoTrackable() + leaf1 = root1.leaf1 = autotrackable.AutoTrackable() + leaf2 = root1.leaf2 = autotrackable.AutoTrackable() + leaf1.leaf3 = autotrackable.AutoTrackable() + leaf1.leaf4 = autotrackable.AutoTrackable() + leaf2.leaf5 = autotrackable.AutoTrackable() + root_ckpt = trackable_utils.Checkpoint(root=root1) + root_save_path = root_ckpt.save( + os.path.join(self.get_temp_dir(), "root_ckpt")) + + root2 = autotrackable.AutoTrackable() + leaf11 = root2.leaf1 = autotrackable.AutoTrackable() + leaf12 = root2.leaf2 = autotrackable.AutoTrackable() + leaf13 = leaf11.leaf3 = autotrackable.AutoTrackable() + leaf15 = leaf12.leaf5 = autotrackable.AutoTrackable() + matching_nodes = checkpoint_view.CheckpointView(root_save_path).match(root2) + self.assertDictEqual(matching_nodes, { + 0: root2, + 1: leaf11, + 2: leaf12, + 4: leaf13, + 6: leaf15 + }) + + def test_match_overlapping_nodes(self): + root1 = autotrackable.AutoTrackable() + root1.a = root1.b = autotrackable.AutoTrackable() + root_ckpt = trackable_utils.Checkpoint(root=root1) + root_save_path = root_ckpt.save( + os.path.join(self.get_temp_dir(), "root_ckpt")) + + root2 = autotrackable.AutoTrackable() + a1 = root2.a = autotrackable.AutoTrackable() + root2.b = autotrackable.AutoTrackable() + with self.assertLogs(level="WARNING") as logs: + matching_nodes = checkpoint_view.CheckpointView(root_save_path).match( + root2) + self.assertDictEqual( + matching_nodes, + { + 0: root2, + 1: a1, + # Only the first element at the same position will be matched. + }) + expected_message = ( + "Inconsistent references when matching the checkpoint into this object" + " graph.") + self.assertIn(expected_message, logs.output[0]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index c96e706e9b1c52..83f709d69994b1 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2022, 7, 4) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2022, 7, 11) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index 3f89bd22607f0b..997c24937ff7ea 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -117,7 +117,7 @@ def ExpectedEnginesToBuild(self, run_params): } def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimentional input. self.DisableNonTrtOptimizers() diff --git a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py index baa6b98e9666ed..f87910ca8efa1e 100644 --- a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py @@ -105,7 +105,7 @@ def GetParams(self): [[4, 6680]]) def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimentional input. self.DisableNonTrtOptimizers() diff --git a/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py index 676b58f7c53df8..10075f38f48b4d 100644 --- a/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -66,7 +66,7 @@ def ExpectedEnginesToBuild(self, run_params): # TODO(b/176540862): remove this routine to disallow native segment execution # for TensorRT 7+. def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" gpus = config.list_physical_devices("GPU") diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py index 3112b968a78db1..f6e26ffac02f30 100644 --- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py @@ -77,7 +77,7 @@ def GetParams(self): expected_output_dims=expected_output_dims) def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimentional input. self.DisableNonTrtOptimizers() diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py index 0bbe99c2658816..21517e884f08d9 100644 --- a/tensorflow/python/compiler/tensorrt/test/int32_test.py +++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py @@ -43,7 +43,7 @@ def GetParams(self): return self.BuildParams(self.GraphFn, dtypes.int32, [[100, 4]], [[100, 10]]) def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimentional input. self.DisableNonTrtOptimizers() diff --git a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py index 1e298a7aa2f8ec..2966c7b409a0a6 100644 --- a/tensorflow/python/compiler/tensorrt/test/shape_output_test.py +++ b/tensorflow/python/compiler/tensorrt/test/shape_output_test.py @@ -32,7 +32,7 @@ class ShapeOutputTest(trt_test.TfTrtIntegrationTestBase): """Test shape value output with TF-TRT.""" def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() self.DisableNonTrtOptimizers() def GraphFn(self, x): diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 9af372f629cbfc..0d4ea67cbd28d0 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -152,7 +152,7 @@ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name def setUp(self): """Setup method.""" - super(TfTrtIntegrationTestBase, self).setUp() + super().setUp() warnings.simplefilter("always") if not is_tensorrt_enabled(): diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py index 3a352286a336b9..60a88271a05736 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py @@ -69,7 +69,7 @@ def ExpectedEnginesToBuild(self, run_params): # TODO(b/159459919): remove this routine to disallow native segment execution. def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py index d8b72f89f7c641..e4b87cf0247fa0 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py @@ -60,7 +60,7 @@ def ExpectedEnginesToBuild(self, run_params): # TODO(b/159459919): remove this routine to disallow native segment execution. def setUp(self): - super(trt_test.TfTrtIntegrationTestBase, self).setUp() # pylint: disable=bad-super-call + super().setUp() os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 221c0b2cca1517..62cbbe117e6302 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -465,7 +465,7 @@ tf_py_test( tf_py_test( name = "parse_example_dataset_test", - size = "small", + size = "medium", srcs = ["parse_example_dataset_test.py"], shard_count = 4, deps = [ diff --git a/tensorflow/python/data/experimental/kernel_tests/service/BUILD b/tensorflow/python/data/experimental/kernel_tests/service/BUILD index abef5610f59505..61b710c0e96b0a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/service/BUILD @@ -33,10 +33,7 @@ tf_py_test( name = "coordinated_read_ft_test", size = "medium", srcs = ["coordinated_read_ft_test.py"], - shard_count = 16, - tags = [ - "nomsan", # TODO(b/235140931): Test timing out. - ], + shard_count = 8, deps = [ ":test_base", "//tensorflow:tensorflow_py", @@ -53,10 +50,7 @@ tf_py_test( name = "coordinated_read_test", size = "medium", srcs = ["coordinated_read_test.py"], - shard_count = 16, - tags = [ - "nomsan", # TODO(b/236869501): Failing. - ], + shard_count = 8, deps = [ ":test_base", "//tensorflow:tensorflow_py", diff --git a/tensorflow/python/data/kernel_tests/prefetch_test.py b/tensorflow/python/data/kernel_tests/prefetch_test.py index 23705b7f5d47fe..16dc1aa4dbc4ee 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_test.py @@ -69,7 +69,7 @@ def map_py_fn(x): with self.cached_session() as sess: thread = self.checkedThread(self.assert_op_cancelled, args=(get_next(),)) thread.start() - time.sleep(0.5) + time.sleep(2) sess.close() thread.join() diff --git a/tensorflow/python/data/ops/options.py b/tensorflow/python/data/ops/options.py index c2fef8b603918f..06a7202187af09 100644 --- a/tensorflow/python/data/ops/options.py +++ b/tensorflow/python/data/ops/options.py @@ -267,7 +267,7 @@ class DistributeOptions(options_lib.OptionsBase): ```python options = tf.data.Options() - options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF dataset = dataset.with_options(options) ``` """ diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 816ce7db505875..4c02d5afe5c4fc 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -98,7 +98,6 @@ cuda_py_test( name = "device_util_test", srcs = ["device_util_test.py"], python_version = "PY3", - tags = ["no_oss_py2"], deps = [ ":combinations", ":device_util", @@ -440,6 +439,7 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", + "//tensorflow/python:control_flow_util", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", @@ -448,6 +448,7 @@ py_library( "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/distribute/v1:input_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/framework:device", "//tensorflow/python/tpu:tpu_strategy_util", "//tensorflow/python/util:tf_export", ], @@ -2286,6 +2287,7 @@ distribute_py_test( python_version = "PY3", tags = [ "multi_and_single_gpu", + "noasan", # TODO(b/237407459) "notpu", "notsan", # Tsan failure doesn't seem to be caused by TF. ], diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 21a84ec1b21568..0a23f14c77b12a 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -321,10 +321,14 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): _check_health_timeout = 10 def __init__(self, container_strategy, cluster_resolver, - communication_options): + communication_options, devices=None): if not isinstance(communication_options, collective_util.Options): raise ValueError("communication_options must be an instance of " "tf.distribute.experimental.CommunicationOptions") + if cluster_resolver and devices: + raise ValueError( + "cluster_resolver and devices cannot be set at the same time") + self._cluster_resolver = cluster_resolver or TFConfigClusterResolver() if not isinstance(self._cluster_resolver, ClusterResolver): raise ValueError("cluster_resolver must be an instance of " @@ -332,7 +336,7 @@ def __init__(self, container_strategy, cluster_resolver, distribute_lib.StrategyExtendedV1.__init__(self, container_strategy) self._communication_options = communication_options self._collective_key_base = container_strategy._collective_key_base # pylint: disable=protected-access - self._initialize_strategy(self._cluster_resolver) + self._initialize_strategy(self._cluster_resolver, devices=devices) self._cfer_fn_cache = weakref.WeakKeyDictionary() self.experimental_enable_get_next_as_optional = True assert isinstance(self._cross_device_ops, @@ -345,11 +349,13 @@ def _use_merge_call(self): ops.get_default_graph()) or not all( [_is_gpu_device(d) for d in self._devices]) - def _initialize_strategy(self, cluster_resolver): - if cluster_resolver.cluster_spec().as_dict(): - self._initialize_multi_worker(cluster_resolver) + def _initialize_strategy(self, cluster_resolver, devices): + # If devices are provided or cluster_spec is not specified, initialize + # single worker. Otherwise initialize multi workers. + if devices or not cluster_resolver.cluster_spec().as_dict(): + self._initialize_local(cluster_resolver, devices=devices) else: - self._initialize_local(cluster_resolver) + self._initialize_multi_worker(cluster_resolver) def _initialize_local_devices(self, cluster_resolver, worker_device): # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index 9fded908ab7a71..0acc0bcc28a124 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils @@ -65,12 +66,13 @@ collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental) +# TODO(b/231630416): Create more tests to cover the case that strategy uses +# different number of GPUs than the number of physical devices. def create_test_objects(cluster_spec=None, task_type=None, task_id=None, num_gpus=None, num_tpus=None): - sess_config = config_pb2.ConfigProto() if num_gpus is None: num_gpus = context.num_gpus() if num_tpus is None: @@ -92,9 +94,8 @@ def create_test_objects(cluster_spec=None, strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( cluster_resolver=cluster_resolver) - sess_config = strategy.update_config_proto(sess_config) - return strategy, target, sess_config + return strategy, target class CollectiveAllReduceStrategyTestBase( @@ -106,27 +107,45 @@ def setUp(self): CollectiveAllReduceStrategy._collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() - def _get_test_object(self, task_type, task_id, num_gpus=0, num_tpus=0): - strategy, target, session_config = create_test_objects( + def _get_test_object(self, + task_type, + task_id, + num_gpus=0, + num_tpus=0, + use_devices_arg=False): + strategy, target = create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id, num_gpus=num_gpus, num_tpus=num_tpus) - return strategy, target, session_config + + if use_devices_arg: + devices = ['GPU:%d' % i for i in range(num_gpus)] + # Temporary workaround to manually set the `_extended` field before device + # initialization is exposed as a public interface. + strategy._extended = CollectiveAllReduceExtended( + container_strategy=strategy, + cluster_resolver=None, + communication_options=collective_util.Options(), + devices=devices) + # Manually set the field since the workaround bypasses the base + # contructor, resulting in the absence of this field. + strategy._extended._retrace_functions_for_each_device = (num_gpus > 1) + + return strategy, target def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): - d, master_target, config = self._get_test_object(task_type, task_id, - num_gpus) + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=config, - target=master_target) as sess, \ - d.scope(): + self.cached_session(target=master_target) as sess, \ + distribution.scope(): initializer = functools.partial( init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32) kernel = variables.Variable( initial_value=initializer, - name='gpu_%d/kernel' % d.extended._num_devices_per_worker, + name='gpu_%d/kernel' % distribution.extended._num_devices_per_worker, trainable=True) def loss_fn(x): @@ -153,26 +172,27 @@ def update(v, g): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.extended.call_for_each_replica(grad_fn, args=[one]) + g_v = distribution.extended.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.extended.read_var(v) + fetched = distribution.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. - g = d.extended.reduce_to( + g = distribution.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.extended.update(v, update, args=(g,), group=False)): - after_list.append(d.extended.read_var(v)) + distribution.extended.update(v, update, args=(g,), + group=False)): + after_list.append(distribution.extended.read_var(v)) return before_list, after_list before_out, after_out = step() - if (d.extended._local_device_type == 'GPU' - and context.num_gpus() < d.extended._num_devices_per_worker): + if (distribution.extended._local_device_type == 'GPU' and + context.num_gpus() < distribution.extended._num_devices_per_worker): return True sess.run(variables.global_variables_initializer()) @@ -189,11 +209,10 @@ def step(): self.assertLess(error_after, error_before) def _test_variable_initialization(self, task_type, task_id, num_gpus): - distribution, master_target, config = self._get_test_object( - task_type, task_id, num_gpus) + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=config, - target=master_target) as sess, \ + self.cached_session(target=master_target) as sess, \ distribution.scope(): def model_fn(): @@ -223,14 +242,14 @@ def _test_input_fn_iterator(self, input_fn, expected_values, test_reinitialize=True, - ignore_order=False): - distribution, master_target, config = self._get_test_object( - task_type, task_id, num_gpus) + ignore_order=False, + use_devices_arg=False): + distribution, master_target = self._get_test_object( + task_type, task_id, num_gpus, use_devices_arg=use_devices_arg) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ - self.cached_session(config=config, - target=master_target) as sess: + self.cached_session(target=master_target) as sess: iterator = distribution.make_input_fn_iterator(input_fn) sess.run(iterator.initializer) @@ -276,7 +295,7 @@ def setUpClass(cls): @combinations.generate(combinations.combine(mode=['graph'])) def test_num_replicas_in_sync(self): - distribution, _, _ = create_test_objects( + distribution, _ = create_test_objects( cluster_spec=self._cluster_spec, task_type='worker', task_id=0, @@ -290,10 +309,8 @@ def test_num_replicas_in_sync(self): mode=['graph'], prefetch_to_device=[None, True])) def test_prefetch_to_device_dataset(self, prefetch_to_device): - distribution, _, _ = self._get_test_object( - task_type='worker', - task_id=0, - num_gpus=2) + distribution, _ = self._get_test_object( + task_type='worker', task_id=0, num_gpus=2) if prefetch_to_device is None: input_options = None else: @@ -314,10 +331,8 @@ def test_prefetch_to_device_dataset(self, prefetch_to_device): @combinations.generate(combinations.combine(mode=['graph'])) def test_prefetch_to_host_dataset(self): - distribution, _, _ = self._get_test_object( - task_type='worker', - task_id=0, - num_gpus=2) + distribution, _ = self._get_test_object( + task_type='worker', task_id=0, num_gpus=2) input_options = distribute_lib.InputOptions( experimental_fetch_to_device=False) dataset = dataset_ops.Dataset.range(100) @@ -383,7 +398,7 @@ def fn(): @combinations.generate(combinations.combine(mode=['graph'])) def testUpdateConfigProto(self): - strategy, _, _ = self._get_test_object( + strategy, _ = self._get_test_object( task_type='worker', task_id=1, num_gpus=2) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) @@ -435,38 +450,68 @@ class SingleWorkerCollectiveAllReduceStrategy( CollectiveAllReduceStrategyTestBase, strategy_test_lib.DistributionTestBase, strategy_test_lib.TwoDeviceDistributionTestBase, parameterized.TestCase): + @combinations.generate(combinations.combine(mode=['eager'])) + def testStrategyInitializationError(self): + with self.assertRaisesRegex( + ValueError, + 'cluster_resolver and devices cannot be set at the same time'): + _ = collective_all_reduce_strategy.CollectiveAllReduceExtended( + container_strategy=None, + cluster_resolver=multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0), + communication_options=collective_util.Options(), + devices=['GPU:0', 'GPU:1']) + @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=[0, 1, 2])) - def testMinimizeLoss(self, required_gpus): + combinations.combine( + mode=['graph', 'eager'], + required_gpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testMinimizeLoss(self, required_gpus, use_devices_arg): # Collective ops doesn't support strategy with one device. if context.executing_eagerly(): - strategy, _, _ = self._get_test_object(None, None, required_gpus) + strategy, _ = self._get_test_object( + None, None, required_gpus, use_devices_arg=use_devices_arg) self._test_minimize_loss_eager(strategy) else: self._test_minimize_loss_graph(None, None, required_gpus) @combinations.generate( - combinations.combine(mode=['eager'], required_gpus=[1, 2])) - def testNumReplicasInSync(self, required_gpus): - strategy, _, _ = self._get_test_object(None, None, required_gpus) + combinations.combine( + mode=['eager'], required_gpus=[1, 2], use_devices_arg=[True, False])) + def testNumReplicasInSync(self, required_gpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, required_gpus, use_devices_arg=use_devices_arg) self.assertEqual(required_gpus, strategy.num_replicas_in_sync) @combinations.generate( - combinations.combine(mode=['eager'], required_tpus=[0, 1, 2])) - def testMinimizeLossTPU(self, required_tpus): - strategy, _, _ = self._get_test_object(None, None, num_tpus=required_tpus) + combinations.combine( + mode=['eager'], + required_tpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testMinimizeLossTPU(self, required_tpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, num_tpus=required_tpus, use_devices_arg=use_devices_arg) self._test_minimize_loss_eager(strategy) @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=[0, 1, 2])) - def testCallAndMergeExceptions(self, required_gpus): - strategy, _, _ = self._get_test_object(None, None, num_gpus=required_gpus) + combinations.combine( + mode=['graph', 'eager'], + required_gpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testCallAndMergeExceptions(self, required_gpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) self._test_call_and_merge_exceptions(strategy) @combinations.generate( combinations.combine( - mode=['graph'], required_gpus=2, use_dataset=[True, False])) - def testMakeInputFnIterator(self, required_gpus, use_dataset): + mode=['graph'], + required_gpus=2, + use_dataset=[True, False], + use_devices_arg=[True, False])) + def testMakeInputFnIterator(self, required_gpus, use_dataset, + use_devices_arg): if use_dataset: fn = lambda: dataset_ops.Dataset.range(5 * required_gpus) else: @@ -494,68 +539,88 @@ def fn(): ignore_order=not use_dataset) @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=[0, 1, 2])) - def testReduceToCpu(self, required_gpus): - strategy, _, _ = self._get_test_object(None, None, required_gpus) + combinations.combine( + mode=['graph', 'eager'], + required_gpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testReduceToCpu(self, required_gpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, required_gpus, use_devices_arg=use_devices_arg) with strategy.scope(): result = strategy.extended.call_for_each_replica(_replica_id_f32) reduced = strategy.reduce(reduce_util.ReduceOp.SUM, result, axis=None) expected = sum(range(strategy.num_replicas_in_sync)) self.assertEqual(expected, self.evaluate(reduced)) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceSum(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceSum(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_sum(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceSumGradients(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceSumGradients(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_sum_gradients(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceSumGradientTape(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceSumGradientTape(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_sum_gradient_tape(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceMean(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceMean(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_mean(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceMeanGradients(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceMeanGradients(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_mean_gradients(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testAllReduceMeanGradientTape(self, required_gpus): - distribution, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testAllReduceMeanGradientTape(self, required_gpus, use_devices_arg): + distribution, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_all_reduce_mean_gradient_tape(distribution) - @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2)) - def testNumpyDataset(self, required_gpus): - strategy, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) + @combinations.generate( + combinations.combine( + mode=['graph'], required_gpus=2, use_devices_arg=[True, False])) + def testNumpyDataset(self, required_gpus, use_devices_arg): + strategy, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) self._test_numpy_dataset( - strategy, session=self.cached_session(config=config, target=target)) + strategy, session=self.cached_session(target=target)) @combinations.generate( - combinations.combine(mode=['eager'], required_gpus=2)) - def testReplicateDataset(self, required_gpus): - strategy, _, _ = self._get_test_object(None, None, num_gpus=required_gpus) + combinations.combine( + mode=['eager'], required_gpus=2, use_devices_arg=[True, False])) + def testReplicateDataset(self, required_gpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i, i + 1] for i in range(0, 10, 2)] input_fn = self._input_fn_to_test_input_context( @@ -565,23 +630,32 @@ def testReplicateDataset(self, required_gpus): expected_input_pipeline_id=0) self._test_input_fn_iterable(strategy, input_fn, expected_values) - @combinations.generate(combinations.combine(mode=['graph'])) - def testDeepCopy(self): - distribution, _, _ = self._get_test_object(None, None) + @combinations.generate( + combinations.combine(mode=['graph'], use_devices_arg=[True, False])) + def testDeepCopy(self, use_devices_arg): + distribution, _ = self._get_test_object( + None, None, use_devices_arg=use_devices_arg) copy.deepcopy(distribution) @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=[0, 1, 2])) - def testSummaryForReplicaZeroOnly(self, required_gpus): - strategy, target, config = self._get_test_object( - None, None, num_gpus=required_gpus) - with self.cached_session(config=config, target=target): + combinations.combine( + mode=['graph', 'eager'], + required_gpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testSummaryForReplicaZeroOnly(self, required_gpus, use_devices_arg): + strategy, target = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) + with self.cached_session(target=target): self._test_summary_for_replica_zero_only(strategy) @combinations.generate( - combinations.combine(mode=['graph', 'eager'], required_gpus=[0, 1, 2])) - def testTrainableVariables(self, required_gpus): - strategy, _, _ = self._get_test_object(None, None, num_gpus=required_gpus) + combinations.combine( + mode=['graph', 'eager'], + required_gpus=[0, 1, 2], + use_devices_arg=[True, False])) + def testTrainableVariables(self, required_gpus, use_devices_arg): + strategy, _ = self._get_test_object( + None, None, num_gpus=required_gpus, use_devices_arg=use_devices_arg) self._test_trainable_variable(strategy) @@ -589,6 +663,10 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine(mode=['eager'], required_gpus=1)) def testKeepLogicalDevice(self): + gpus = tf_config.list_physical_devices('GPU') + if len(gpus) > 1: + self.skipTest('Skip logical device test on multi GPUs, since partial GPU ' + 'virtualization is not permitted.') # Cannot change logical device after the context initialization. context._reset_context() # pylint: disable=protected-access cluster_spec = multi_worker_test_base.create_cluster_spec( @@ -597,7 +675,7 @@ def testKeepLogicalDevice(self): cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), task_type='worker', task_id=0) - gpus = tf_config.list_physical_devices('GPU') + logical_gpus = len(gpus) * 2 for i, device in enumerate(gpus): n = (i + 1) * logical_gpus // len(gpus) - i * logical_gpus // len(gpus) diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index c1638e9f86c549..8e6003fbc35435 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -1232,7 +1232,7 @@ def batch_reduce_sparse(): get_global_mpr(num_processes).run(replica_fn) - @combinations.generate(combinations.combine(num_processes=1, required_gpus=2)) + @combinations.generate(combinations.combine(num_processes=2, required_gpus=2)) def testNcclOrdering(self, num_processes, required_gpus): if num_processes != required_gpus: diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index c6e3bd1a952367..453f3bedc3b21b 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -252,7 +252,7 @@ def __init__(self, initializer tensor will be added to this map in addition to adding the assignment to the function. lifted_initializer_graph: FuncGraph to try to lift initializers to. - synchronization: Indicates when a distributed a variable will be + synchronization: Indicates when a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses @@ -1543,7 +1543,7 @@ def function(func=None, ## Using type annotations to improve performance - 'experimental_follow_type_hints` can be used along with type annotations to + `experimental_follow_type_hints` can be used along with type annotations to reduce retracing by automatically casting any Python values to `tf.Tensor` (something that is not done by default, unless you use input signatures). @@ -1568,7 +1568,7 @@ def function(func=None, Args: - func: the function to be compiled. If `func` is None, `tf.function` returns + func: The function to be compiled. If `func` is None, `tf.function` returns a decorator that can be invoked with a single argument - `func`. In other words, `tf.function(input_signature=...)(func)` is equivalent to `tf.function(func, input_signature=...)`. The former can be used as diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 9afc30cf3f3f5c..47317f87186c3c 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1265,6 +1265,7 @@ py_library( deps = [ ":dtypes", "//tensorflow/core:protos_all_py", + "//tensorflow/core/function/trace_type", "//tensorflow/python:tf2", "//tensorflow/python/eager:monitoring", "//tensorflow/python/util", @@ -1857,6 +1858,7 @@ tf_py_test( ":for_generated_wrappers", ":test_lib", "//tensorflow/core:protos_all_py", + "//tensorflow/core/function/trace_type", "//tensorflow/python:platform_test", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 6672299e09a58c..adc1fb2c9deac9 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -15,11 +15,12 @@ """Helper classes for tensor shape inference.""" import functools import operator -from typing import Optional, Sequence +from typing import Optional, Sequence, Type import six from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.function import trace_type from tensorflow.python import tf2 from tensorflow.python.eager import monitoring from tensorflow.python.platform import tf_logging as logging @@ -740,7 +741,7 @@ def as_dimension(value): @tf_export("TensorShape") -class TensorShape(trace.TraceType): +class TensorShape(trace.TraceType, trace_type.Serializable): """Represents the shape of a `Tensor`. A `TensorShape` represents a possibly-partial shape specification for a @@ -1227,6 +1228,21 @@ def most_specific_common_supertype( ] return TensorShape(dims) + @classmethod + def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]: + """Returns the type of proto associated with TensorShape serialization.""" + return tensor_shape_pb2.TensorShapeProto + + @classmethod + def experimental_from_proto( + cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape": + """Returns a TensorShape instance based on the serialized proto.""" + return TensorShape(proto) + + def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto: + """Returns a proto representation of the TensorShape instance.""" + return self.as_proto() + # TODO(b/216206374): Consider deprecation at TraceType release. def is_compatible_with(self, other): """Returns True iff `self` is compatible with `other`. @@ -1421,6 +1437,8 @@ def __reduce__(self): def __concat__(self, other): return self.concatenate(other) +trace_type.register_serializable(TensorShape) + def as_shape(shape): """Converts the given object to a TensorShape.""" diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index d112101fa5b45d..09f506acee5417 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Functional tests for shape inference helper classes.""" from absl.testing import parameterized from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.function import trace_type from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -30,14 +30,14 @@ def testDimension(self): self.assertEqual(12, dim.value) self.assertEqual(12, int(dim)) self.assertEqual(dim, tensor_shape.Dimension(12)) - self.assertEqual(tensor_shape.Dimension(15), - dim + tensor_shape.Dimension(3)) + self.assertEqual( + tensor_shape.Dimension(15), dim + tensor_shape.Dimension(3)) self.assertEqual(tensor_shape.Dimension(15), dim + 3) self.assertEqual(tensor_shape.Dimension(15), 3 + dim) self.assertEqual(tensor_shape.Dimension(9), dim - 3) self.assertEqual(tensor_shape.Dimension(1), 13 - dim) - self.assertEqual(tensor_shape.Dimension(24), - dim * tensor_shape.Dimension(2)) + self.assertEqual( + tensor_shape.Dimension(24), dim * tensor_shape.Dimension(2)) self.assertEqual(tensor_shape.Dimension(24), dim * 2) self.assertEqual(tensor_shape.Dimension(24), 2 * dim) self.assertEqual([4] * 12, [4] * dim) @@ -47,18 +47,18 @@ def testDimension(self): tensor_shape.Dimension(6), dim // tensor_shape.Dimension(2)) self.assertEqual(tensor_shape.Dimension(6), dim // 2) self.assertEqual(tensor_shape.Dimension(0), 2 // dim) - self.assertEqual(tensor_shape.Dimension(12), - dim.merge_with(tensor_shape.Dimension(12))) + self.assertEqual( + tensor_shape.Dimension(12), dim.merge_with(tensor_shape.Dimension(12))) self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12)) self.assertLess(tensor_shape.Dimension(12), tensor_shape.Dimension(13)) self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12)) self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(12)) self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(13)) self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12)) - self.assertGreaterEqual(tensor_shape.Dimension(12), - tensor_shape.Dimension(12)) - self.assertGreaterEqual(tensor_shape.Dimension(13), - tensor_shape.Dimension(12)) + self.assertGreaterEqual( + tensor_shape.Dimension(12), tensor_shape.Dimension(12)) + self.assertGreaterEqual( + tensor_shape.Dimension(13), tensor_shape.Dimension(12)) self.assertNotEqual(dim, (12,)) with self.assertRaises(ValueError): dim.merge_with(tensor_shape.Dimension(13)) @@ -67,15 +67,18 @@ def testUnknownDimension(self): dim = tensor_shape.Dimension(None) self.assertIsNone(dim.value) self.assertEqual(dim.value, tensor_shape.Dimension(None).value) - self.assertEqual(tensor_shape.Dimension(None).value, - (dim + tensor_shape.Dimension(None)).value) - self.assertEqual(tensor_shape.Dimension(None).value, - (dim * tensor_shape.Dimension(None)).value) + self.assertEqual( + tensor_shape.Dimension(None).value, + (dim + tensor_shape.Dimension(None)).value) + self.assertEqual( + tensor_shape.Dimension(None).value, + (dim * tensor_shape.Dimension(None)).value) self.assertEqual( tensor_shape.Dimension(None).value, (dim // tensor_shape.Dimension(None)).value) - self.assertEqual(tensor_shape.Dimension(None).value, - dim.merge_with(tensor_shape.Dimension(None)).value) + self.assertEqual( + tensor_shape.Dimension(None).value, + dim.merge_with(tensor_shape.Dimension(None)).value) self.assertIsNone( tensor_shape.Dimension(None) < tensor_shape.Dimension(None)) self.assertIsNone( @@ -100,10 +103,8 @@ def testKnownAndUnknownDimensions(self): tensor_shape.Dimension(None).value, (known // unknown).value) self.assertEqual( tensor_shape.Dimension(None).value, (unknown // known).value) - self.assertEqual( - tensor_shape.Dimension(12), known.merge_with(unknown)) - self.assertEqual( - tensor_shape.Dimension(12), unknown.merge_with(known)) + self.assertEqual(tensor_shape.Dimension(12), known.merge_with(unknown)) + self.assertEqual(tensor_shape.Dimension(12), unknown.merge_with(known)) self.assertIsNone(tensor_shape.Dimension(12) < tensor_shape.Dimension(None)) self.assertIsNone( tensor_shape.Dimension(12) <= tensor_shape.Dimension(None)) @@ -118,14 +119,16 @@ def testKnownAndUnknownDimensions(self): tensor_shape.Dimension(None) >= tensor_shape.Dimension(12)) def testAsDimension(self): - self.assertEqual(tensor_shape.Dimension(12), - tensor_shape.as_dimension(tensor_shape.Dimension(12))) + self.assertEqual( + tensor_shape.Dimension(12), + tensor_shape.as_dimension(tensor_shape.Dimension(12))) self.assertEqual(tensor_shape.Dimension(12), tensor_shape.as_dimension(12)) self.assertEqual( tensor_shape.Dimension(None).value, tensor_shape.as_dimension(tensor_shape.Dimension(None)).value) - self.assertEqual(tensor_shape.Dimension(None).value, - tensor_shape.as_dimension(None).value) + self.assertEqual( + tensor_shape.Dimension(None).value, + tensor_shape.as_dimension(None).value) def testEquality(self): self.assertEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(12)) @@ -233,6 +236,21 @@ def testDiv(self): _ = 6 / two +class SerilizationTest(test_util.TensorFlowTestCase): + + def testSerialization(self): + shape_1 = tensor_shape.TensorShape([1, 2, 3]) + shape_2 = tensor_shape.TensorShape([None, 2, None]) + shape_3 = tensor_shape.TensorShape(None) + + self.assertEqual( + trace_type.deserialize(trace_type.serialize(shape_1)), shape_1) + self.assertEqual( + trace_type.deserialize(trace_type.serialize(shape_2)), shape_2) + self.assertEqual( + trace_type.deserialize(trace_type.serialize(shape_3)), shape_3) + + class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase): def testUnknownShape(self): @@ -253,16 +271,21 @@ def testUnknownShape(self): pass def testFullyDefinedShape(self): - s = tensor_shape.TensorShape([tensor_shape.Dimension( - 3), tensor_shape.Dimension(4), tensor_shape.Dimension(7)]) + s = tensor_shape.TensorShape([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7) + ]) s.assert_is_fully_defined() self.assertEqual(s.rank, 3) self.assertLen(s, 3) self.assertTrue(s) s.assert_has_rank(3) - self.assertEqual([tensor_shape.Dimension(3), - tensor_shape.Dimension(4), - tensor_shape.Dimension(7)], s.dims) + self.assertEqual([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7) + ], s.dims) self.assertEqual(tensor_shape.Dimension(3), s[0]) self.assertEqual(tensor_shape.Dimension(4), s[1]) self.assertEqual(tensor_shape.Dimension(7), s[2]) @@ -273,8 +296,11 @@ def testFullyDefinedShape(self): assert tensor_shape.dimension_value(d1) == d2 def testPartiallyDefinedShape(self): - s = tensor_shape.TensorShape([tensor_shape.Dimension( - 3), tensor_shape.Dimension(None), tensor_shape.Dimension(7)]) + s = tensor_shape.TensorShape([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(None), + tensor_shape.Dimension(7) + ]) # pylint: disable=g-error-prone-assert-raises with self.assertRaisesRegex(ValueError, "Shape .+ is not fully defined"): s.assert_is_fully_defined() @@ -299,10 +325,16 @@ def testMergeFullShapes(self): tensor_shape.TensorShape([6, 3, 7])) def testMergePartialShapes(self): - s1 = tensor_shape.TensorShape([tensor_shape.Dimension( - 3), tensor_shape.Dimension(None), tensor_shape.Dimension(7)]) - s2 = tensor_shape.TensorShape([tensor_shape.Dimension( - None), tensor_shape.Dimension(4), tensor_shape.Dimension(7)]) + s1 = tensor_shape.TensorShape([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(None), + tensor_shape.Dimension(7) + ]) + s2 = tensor_shape.TensorShape([ + tensor_shape.Dimension(None), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7) + ]) self.assertEqual([3, 4, 7], s1.merge_with(s2).as_list()) def testMergeFullAndUnknownShape(self): @@ -319,30 +351,25 @@ def testSlice(self): self.assertEqual( tensor_shape.Dimension(None).value, tensor_shape.dimension_value(unknown[2])) - tensor_shape.TensorShape( - [None, None, None]).assert_is_compatible_with(unknown[1:4]) + tensor_shape.TensorShape([None, None, + None]).assert_is_compatible_with(unknown[1:4]) @parameterized.named_parameters( ("Concatenate", lambda x, y: x.concatenate(y)), - ("Add", lambda x, y: x + y), - ("RAdd", lambda x, y: y.__radd__(x))) + ("Add", lambda x, y: x + y), ("RAdd", lambda x, y: y.__radd__(x))) def testConcatenate(self, concatenate_fn): tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( concatenate_fn( - tensor_shape.TensorShape([1, 2]), - tensor_shape.TensorShape([3, 4]))) + tensor_shape.TensorShape([1, 2]), tensor_shape.TensorShape([3, 4]))) tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( concatenate_fn( - tensor_shape.TensorShape([1, 2]), - tensor_shape.TensorShape(None))) + tensor_shape.TensorShape([1, 2]), tensor_shape.TensorShape(None))) tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( concatenate_fn( - tensor_shape.TensorShape(None), - tensor_shape.TensorShape([3, 4]))) + tensor_shape.TensorShape(None), tensor_shape.TensorShape([3, 4]))) tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( concatenate_fn( - tensor_shape.TensorShape(None), - tensor_shape.TensorShape(None))) + tensor_shape.TensorShape(None), tensor_shape.TensorShape(None))) @parameterized.named_parameters( ("Concatenate", lambda x, y: x.concatenate(y)), @@ -350,21 +377,16 @@ def testConcatenate(self, concatenate_fn): def testConcatenateWithDimension(self, concatenate_fn): tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with( concatenate_fn( - tensor_shape.TensorShape([1, 2]), - tensor_shape.Dimension(3))) + tensor_shape.TensorShape([1, 2]), tensor_shape.Dimension(3))) - @parameterized.named_parameters( - ("List", [3, 4, 5]), - ("Tuple", (3, 4, 5))) + @parameterized.named_parameters(("List", [3, 4, 5]), ("Tuple", (3, 4, 5))) def testAdd_nonTensorShape(self, addend): two = tensor_shape.TensorShape([2]) result = two + addend self.assertIsInstance(result, tensor_shape.TensorShape) tensor_shape.TensorShape([2, 3, 4, 5]).assert_is_compatible_with(result) - @parameterized.named_parameters( - ("List", [2, 3, 4]), - ("Tuple", (2, 3, 4))) + @parameterized.named_parameters(("List", [2, 3, 4]), ("Tuple", (2, 3, 4))) def testRAdd_nonTensorShape(self, addend): five = tensor_shape.TensorShape([5]) result = addend + five @@ -398,20 +420,21 @@ def testTruedivFails(self): unknown / unknown # pylint: disable=pointless-statement def testConvertFromProto(self): + def make_tensor_shape_proto(shape): return tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape]) + proto = make_tensor_shape_proto([]) - self.assertEqual(tensor_shape.TensorShape([]), - tensor_shape.TensorShape(proto)) - self.assertEqual(tensor_shape.TensorShape([]), - tensor_shape.as_shape(proto)) + self.assertEqual( + tensor_shape.TensorShape([]), tensor_shape.TensorShape(proto)) + self.assertEqual(tensor_shape.TensorShape([]), tensor_shape.as_shape(proto)) proto = make_tensor_shape_proto([1, 37, 42]) - self.assertEqual(tensor_shape.TensorShape([1, 37, 42]), - tensor_shape.TensorShape(proto)) - self.assertEqual(tensor_shape.TensorShape([1, 37, 42]), - tensor_shape.as_shape(proto)) + self.assertEqual( + tensor_shape.TensorShape([1, 37, 42]), tensor_shape.TensorShape(proto)) + self.assertEqual( + tensor_shape.TensorShape([1, 37, 42]), tensor_shape.as_shape(proto)) partial_proto_shape = tensor_shape.as_shape( make_tensor_shape_proto([-1, 37, 42])) @@ -443,20 +466,26 @@ def testStr(self): def testAsProto(self): self.assertTrue(tensor_shape.unknown_shape().as_proto().unknown_rank) - self.assertFalse( - tensor_shape.unknown_shape(rank=3).as_proto().unknown_rank) + self.assertFalse(tensor_shape.unknown_shape(rank=3).as_proto().unknown_rank) self.assertFalse( tensor_shape.TensorShape([1, 2, 3]).as_proto().unknown_rank) self.assertFalse( tensor_shape.TensorShape([1, None, 3]).as_proto().unknown_rank) def testEquality(self): - s1 = tensor_shape.TensorShape([tensor_shape.Dimension( - 3), tensor_shape.Dimension(4), tensor_shape.Dimension(7)]) - s2 = tensor_shape.TensorShape([tensor_shape.Dimension( - 3), tensor_shape.Dimension(4), tensor_shape.Dimension(7)]) - s3 = tensor_shape.TensorShape([tensor_shape.Dimension(3), - tensor_shape.Dimension(4), None]) + s1 = tensor_shape.TensorShape([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7) + ]) + s2 = tensor_shape.TensorShape([ + tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7) + ]) + s3 = tensor_shape.TensorShape( + [tensor_shape.Dimension(3), + tensor_shape.Dimension(4), None]) self.assertEqual(s1, s2) self.assertEqual(s1, s2) @@ -479,8 +508,8 @@ def testAsList(self): "not defined on an unknown TensorShape"): tensor_shape.unknown_shape().as_list() self.assertAllEqual([None, None], tensor_shape.unknown_shape(2).as_list()) - self.assertAllEqual([2, None, 4], tensor_shape.TensorShape( - (2, None, 4)).as_list()) + self.assertAllEqual([2, None, 4], + tensor_shape.TensorShape((2, None, 4)).as_list()) def testReduce(self): shape = tensor_shape.TensorShape([2, 3]) diff --git a/tensorflow/python/framework/type_utils.py b/tensorflow/python/framework/type_utils.py index 22911a2e3867a1..0825797d38d90b 100644 --- a/tensorflow/python/framework/type_utils.py +++ b/tensorflow/python/framework/type_utils.py @@ -20,7 +20,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.framework import type_spec from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec -from tensorflow.python.ops.structured.structured_tensor import StructuredTensorSpec +from tensorflow.python.ops.structured.structured_tensor import StructuredTensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -114,7 +114,7 @@ def _specs_for_flat_tensors(element_spec): datasets and map_fn for ELEMENT_SPEC. The items in this list correspond to the items in `_flat_tensor_specs`. """ - if isinstance(element_spec, StructuredTensorSpec): + if isinstance(element_spec, StructuredTensor.Spec): specs = [] for _, field_spec in sorted( element_spec._field_specs.items(), key=lambda t: t[0]): # pylint: disable=protected-access diff --git a/tensorflow/python/grappler/remapper_test.py b/tensorflow/python/grappler/remapper_test.py index 0a97ddb29ffc89..055de481dca666 100644 --- a/tensorflow/python/grappler/remapper_test.py +++ b/tensorflow/python/grappler/remapper_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -74,6 +75,43 @@ def _maybe_skip(self, mode): if mode == 'mkl' and not test_util.IsMklEnabled(): self.skipTest('MKL is not enabled.') + def _VerifyValues(self, model_fn, use_low_precision, fused_op, epilog_ops): + run_options = config_pb2.RunOptions(output_partition_graphs=True) + metadata = config_pb2.RunMetadata() + # Compute reference value. + config = _get_config(remapping_on=False) + with session.Session(config=config) as sess: + sess.run(variables.global_variables_initializer()) + output_ref = sess.run( + model_fn, options=run_options, run_metadata=metadata) + # Compute output with fusion. + config = _get_config(remapping_on=True) + with session.Session(config=config) as sess: + sess.run(variables.global_variables_initializer()) + output_val = sess.run( + model_fn, options=run_options, run_metadata=metadata) + graph = metadata.partition_graphs[0] + + # Graph should contain fused op. + found_fused_op = False + for node in graph.node: + if node.op in fused_op: + fused_ops = node.attr['fused_ops'].list.s + ops_matched = len(fused_ops) >= 1 and len(fused_ops) == len(epilog_ops) + for op_a, op_b in zip(fused_ops, epilog_ops): + if op_a != op_b: + ops_matched = False + break + found_fused_op = ops_matched + break + self.assertTrue(found_fused_op) + + # Computed output value should be close to reference value. + tol = 1e-2 if use_low_precision else 1e-5 + self.assertAllClose(output_ref, output_val, atol=tol, rtol=tol) + + return graph + @parameterized.parameters(['cuda', 'mkl']) @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') @@ -81,8 +119,6 @@ def test_matmul_biasadd_gelu_fusion(self, mode): """Test MatMul+BiasAdd+Gelu fusion.""" self._maybe_skip(mode) is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU() - run_options = config_pb2.RunOptions(output_partition_graphs=True) - metadata = config_pb2.RunMetadata() m, n, k = (3, 3, 4) # Matrix dimensions for precision in ('float32', 'bfloat16'): @@ -109,33 +145,11 @@ def test_matmul_biasadd_gelu_fusion(self, mode): z = nn.bias_add(y, b) out = nn.gelu(z, approximate=approximate) - # Compute reference value. - config = _get_config(remapping_on=False) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val_ref = sess.run( - out, options=run_options, run_metadata=metadata) - # Compute output with fusion. - config = _get_config(remapping_on=True) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val = sess.run(out, options=run_options, run_metadata=metadata) - graph = metadata.partition_graphs[0] - - # Graph should contain fused op. - found_fused_op = False gelu_type = b'GeluApproximate' if approximate else b'GeluExact' - for node in graph.node: - if node.op in ('_MklNativeFusedMatMul', '_MklFusedMatMul'): - fused_ops = node.attr['fused_ops'].list.s - found_fused_op = len(fused_ops) == 2 and \ - fused_ops[0] == b'BiasAdd' and fused_ops[1] == gelu_type - break - self.assertTrue(found_fused_op) - - # Computed output value should be close to reference value. - tol = 1e-5 if precision == 'float32' else 1e-2 - self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) + epilog_ops = [b'BiasAdd', gelu_type] + fused_op = ['_MklNativeFusedMatMul', '_MklFusedMatMul'] + graph = self._VerifyValues(out, precision == 'bfloat16', fused_op, + epilog_ops) @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') @@ -143,43 +157,38 @@ def test_conv2d_biasadd_relu_fusion(self): """Test Conv2D+BiasAdd+Relu fusion.""" if not test_util.is_gpu_available(): self.skipTest('No GPU available') - run_options = config_pb2.RunOptions(output_partition_graphs=True) - metadata = config_pb2.RunMetadata() - - n, h, w, c = (5, 3, 3, 4) - - ops.reset_default_graph() - x = _input([n, c, h, w]) - w = _weight([2, 2, c, c]) - b = _bias([c]) - y = nn_ops.conv2d(x, w, strides=(1, 1), padding='SAME', data_format='NCHW') - z = nn.bias_add(y, b, data_format='NC..') - out = nn.relu(z) - - # Compute reference value. - config = _get_config(remapping_on=False) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val_ref = sess.run(out, options=run_options, run_metadata=metadata) - # Compute output with fusion. - config = _get_config(remapping_on=True) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val = sess.run(out, options=run_options, run_metadata=metadata) - graph = metadata.partition_graphs[0] - - # Graph should contain fused op. - found_fused_op = False - for node in graph.node: - if node.op == '_FusedConv2D': - found_fused_op = True - break - self.assertTrue(found_fused_op) - - # Computed output value should be close to reference value. - tol = 1e-5 - self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) + N, H, W, C = (5, 3, 3, 4) + + for precision in ('float16', 'float32'): + ops.reset_default_graph() + x_shape = [N, C, H, W] + x_format = 'NCHW' + b_format = 'NC..' + use_fp16 = precision == 'float16' + if use_fp16: + x_shape = [N, H, W, C] + x_format = 'NHWC' + b_format = 'N..C' + + x = _input(x_shape) + w = _weight([2, 2, C, C]) + b = _bias([C]) + + if use_fp16: + x = math_ops.cast(x, dtypes.float16) + w = math_ops.cast(w, dtypes.float16) + b = math_ops.cast(b, dtypes.float16) + + y = nn_ops.conv2d( + x, w, strides=(1, 1), padding='SAME', data_format=x_format) + z = nn.bias_add(y, b, data_format=b_format) + out = nn.relu(z) + out = array_ops.identity(out) + + epilog_ops = [b'BiasAdd', b'Relu'] + fused_op = ['_FusedConv2D'] + graph = self._VerifyValues(out, use_fp16, fused_op, epilog_ops) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py index a7f7b64c53feb6..c2e45ebdc54095 100644 --- a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py @@ -889,9 +889,11 @@ def testUpdateMinMax(self): constant_op.constant([1, 1, 1, 2, 1, 1, 1, 2])) def testUpdateMinMaxGradients(self): - with self.cached_session(): + + # Loop body as a function to avoid go/gpylint-faq#cell-var-from-loop. + def _TestFn(dtype): x = array_ops.ones([4], dtype=dtypes.float32) - indices = constant_op.constant([[1], [2], [3], [3]]) + indices = constant_op.constant([[1], [2], [3], [3]], dtype=dtype) updates = constant_op.constant([2.0, 0.5, 1.0, 1.0], dtype=dtypes.float32) theoretical, _ = gradient_checker_v2.compute_gradient( @@ -927,6 +929,10 @@ def testUpdateMinMaxGradients(self): dtype=dtypes.float32) self.assertAllClose(theoretical, manual, 5e-4, 5e-4) + with self.cached_session(): + for dtype in (dtypes.int32, dtypes.int64): + _TestFn(dtype) + def testTensorScatterUpdateWithForwarding(self): for dtype in (dtypes.int32, dtypes.float32): diff --git a/tensorflow/python/kernel_tests/strings_ops/BUILD b/tensorflow/python/kernel_tests/strings_ops/BUILD index 3600f39e07d5aa..e543653bf65347 100644 --- a/tensorflow/python/kernel_tests/strings_ops/BUILD +++ b/tensorflow/python/kernel_tests/strings_ops/BUILD @@ -5,10 +5,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") -package( - default_visibility = ["//tensorflow:internal"], - licenses = ["notice"], -) +package(licenses = ["notice"]) tf_py_test( name = "as_string_op_test", diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index b2dd67e69e9cfa..2f4031df13917b 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -465,7 +465,7 @@ PyArray_Descr CustomFloatTypeDescriptor::npy_descr = { /*kind=*/TypeDescriptor::kNpyDescrKind, /*type=*/TypeDescriptor::kNpyDescrType, /*byteorder=*/TypeDescriptor::kNpyDescrByteorder, - /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, + /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM, /*type_num=*/0, /*elsize=*/sizeof(T), /*alignment=*/alignof(T), @@ -484,7 +484,7 @@ template PyObject* NPyCustomFloat_GetItem(void* data, void* arr) { T x; memcpy(&x, data, sizeof(T)); - return PyCustomFloat_FromT(x).release(); + return PyFloat_FromDouble(static_cast(x)); } template diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index e0d442b57b1e07..2d91a4cc2a373f 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -177,6 +177,9 @@ def testRepr(self, float_type): self.assertEqual("%.6g" % float(float_type(value)), repr(float_type(value))) + def testItem(self, float_type): + self.assertIsInstance(float_type(0).item(), float) + def testHashZero(self, float_type): """Tests that negative zero and zero hash to the same value.""" self.assertEqual(hash(float_type(-0.0)), hash(float_type(0.0))) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index a4a303e47f943e..5fbd9c707c11ef 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -1169,8 +1169,8 @@ def _TensorScatterMinOrMaxGrad(op, grad): x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype) y_output = array_ops.gather_nd(output, indices) y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype) - ys_indicators = array_ops.scatter_nd(indices, y_indicators, - array_ops.shape(x)) + ys_indicators = array_ops.scatter_nd( + indices, y_indicators, array_ops.shape(x, out_type=indices.dtype)) indicators = x_indicators + ys_indicators # All elements are >= 1. # If there are multiple minimum or maximum elements then the gradient will be # divided between them. diff --git a/tensorflow/python/ops/array_grad_test.py b/tensorflow/python/ops/array_grad_test.py index b779ee4bd77eca..1b1e2d4257fe88 100644 --- a/tensorflow/python/ops/array_grad_test.py +++ b/tensorflow/python/ops/array_grad_test.py @@ -114,7 +114,6 @@ def f(x): self._testGrad(f, x) - @test_util.disable_xla("b/206689921") # XLA does not support DT_INT64 def test_broadcast_to_int64(self): x = constant_op.constant([1., 2., 3.], dtype=dtypes.float64) y = constant_op.constant([2, 3], dtype=dtypes.int64) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index a59b82bdbb0f22..c3aabe4f95806f 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -53,6 +53,7 @@ from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use +from tensorflow.python.util import variable_utils from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export @@ -1025,11 +1026,15 @@ def BuildCondBranch(self, fn): if original_result is None: return no_op(), None elif not isinstance(original_result, ops.Operation): + original_result = variable_utils.convert_variables_to_tensors( + original_result) original_result = nest.map_structure( array_ops.identity, original_result, expand_composites=True) if original_result is None: return None, None + original_result = variable_utils.convert_variables_to_tensors( + original_result) result = nest.map_structure( self._BuildCondTensor, original_result, expand_composites=True) if not isinstance(result, (list, _basetuple)): @@ -2189,7 +2194,7 @@ def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars, pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access body_result = body(*packed_vars_for_body) post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access - if not nest.is_nested_or_composite(body_result): + if not nest.is_nested(body_result): body_result = [body_result] if len(post_summaries) > len(pre_summaries): new_summaries = post_summaries[len(pre_summaries):] @@ -2206,6 +2211,7 @@ def map_fn(x): body_result = nest.map_structure( map_fn, body_result, expand_composites=True) + body_result = variable_utils.convert_variables_to_tensors(body_result) # Compare the structure types of input and output of body. # For backwards compatibility, the first layer is forced to a list # during this comparison, because inputs are typically lists and @@ -2698,6 +2704,8 @@ def while_loop(cond, if parallel_iterations < 1: raise TypeError("'parallel_iterations' must be a positive integer.") + loop_vars = variable_utils.convert_variables_to_tensors(loop_vars) + # Always enable control flow v2 if building a function, regardless of toggle. executing_eagerly = context.executing_eagerly() if (util.EnableControlFlowV2(ops.get_default_graph()) and @@ -2861,12 +2869,12 @@ def with_dependencies(dependencies, output_tensor, name=None): with ops.colocate_with(output_tensor): with ops.control_dependencies(dependencies): output_tensor = ops.convert_to_tensor_or_composite(output_tensor) - if isinstance(output_tensor, ops.Tensor): - return _Identity(output_tensor, name=name) - else: + if isinstance(output_tensor, indexed_slices.IndexedSlices): return indexed_slices.IndexedSlices( _Identity(output_tensor.values, name=name), output_tensor.indices, output_tensor.dense_shape) + else: + return _Identity(output_tensor, name=name) def _GroupControlDeps(dev, deps, name=None): diff --git a/tensorflow/python/ops/ragged/dynamic_ragged_shape.py b/tensorflow/python/ops/ragged/dynamic_ragged_shape.py index 72658394f01ea2..14fc34709178a4 100644 --- a/tensorflow/python/ops/ragged/dynamic_ragged_shape.py +++ b/tensorflow/python/ops/ragged/dynamic_ragged_shape.py @@ -191,6 +191,7 @@ class DynamicRaggedShape(extension_type.BatchableExtensionType): _inner_shape: ops.Tensor _static_inner_shape: tensor_shape.TensorShape __batch_encoder__ = _DynamicRaggedShapeBatchEncoder() + __name__ = "tf.DynamicRaggedShape" def __init__(self, row_partitions: Sequence[RowPartition], diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py index df240ba057c599..80bb87eda87ce6 100644 --- a/tensorflow/python/ops/structured/structured_array_ops.py +++ b/tensorflow/python/ops/structured/structured_array_ops.py @@ -22,6 +22,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.ragged import dynamic_ragged_shape from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.ops.structured.structured_tensor import StructuredTensor @@ -29,6 +30,22 @@ from tensorflow.python.util import dispatch +@dispatch.dispatch_for_api(array_ops.shape_v2) +def shape_v2(input: StructuredTensor, out_type=dtypes.int32, # pylint: disable=redefined-builtin + name=None) -> dynamic_ragged_shape.DynamicRaggedShape: + """Returns a DynamicRaggedShape containing the shape of the input.""" + del name + return input._ragged_shape.with_dtype(out_type) # pylint: disable=protected-access + + +@dispatch.dispatch_for_api(array_ops.shape) +def shape_v1(input: StructuredTensor, name=None, # pylint: disable=redefined-builtin + out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape: + """Returns a DynamicRaggedShape containing the shape of the input.""" + del name + return input._ragged_shape.with_dtype(out_type) # pylint: disable=protected-access + + @dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor) @deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim') def expand_dims(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin diff --git a/tensorflow/python/ops/structured/structured_array_ops_test.py b/tensorflow/python/ops/structured/structured_array_ops_test.py index e22787c80050c8..8c711df9bd165f 100644 --- a/tensorflow/python/ops/structured/structured_array_ops_test.py +++ b/tensorflow/python/ops/structured/structured_array_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import row_partition +from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape from tensorflow.python.ops.structured import structured_array_ops from tensorflow.python.ops.structured import structured_tensor from tensorflow.python.ops.structured.structured_tensor import StructuredTensor @@ -263,6 +264,22 @@ def testSizeObject(self, row_partitions, shape, dtype, expected): actual2 = array_ops.size_v2(st, out_type=dtype) self.assertAllEqual(actual2, expected) + def test_shape_v2(self): + rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2]) + st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2) + actual = array_ops.shape_v2(st, out_type=dtypes.int64) + actual_static_lengths = actual.static_lengths() + self.assertAllEqual([2, (1, 2)], actual_static_lengths) + + def test_shape(self): + rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2]) + st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2) + actual = array_ops.shape(st, out_type=dtypes.int64).static_lengths() + actual_v2 = array_ops.shape_v2(st, out_type=dtypes.int64).static_lengths() + expected = [2, (1, 2)] + self.assertAllEqual(expected, actual) + self.assertAllEqual(expected, actual_v2) + @parameterized.named_parameters([ dict( testcase_name="list_empty_2_1", @@ -919,14 +936,19 @@ def testStructuredTensorArrayRankOneKnownShape(self): result = structured_array_ops._structured_tensor_like(foo) self.assertAllEqual([{}, {}, {}, {}], result) + # Note that we have to be careful about whether the indices are int32 + # or int64. def testStructuredTensorArrayRankOneUnknownShape(self): """Fully test structured_tensor_array_like.""" @def_function.function def my_fun(my_shape): my_zeros = array_ops.zeros(my_shape) return structured_array_ops._structured_tensor_like(my_zeros) + result = my_fun(array_ops.constant(4)) - self.assertAllEqual([{}, {}, {}, {}], result) + shape = DynamicRaggedShape._from_inner_shape([4], dtype=dtypes.int32) + expected = StructuredTensor.from_shape(shape) + self.assertAllEqual(expected, result) def testStructuredTensorArrayRankTwoUnknownShape(self): """Fully test structured_tensor_array_like.""" diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py index 8011b27b298292..714b8abd7a75f7 100644 --- a/tensorflow/python/ops/structured/structured_tensor.py +++ b/tensorflow/python/ops/structured/structured_tensor.py @@ -14,15 +14,14 @@ # ============================================================================== """Structured Tensors.""" -import logging import re -from typing import Callable, Dict, List, Sequence, Tuple, Union +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec @@ -31,15 +30,15 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import dynamic_ragged_shape from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import row_partition as row_partition_lib from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.util import compat from tensorflow.python.util import nest -class StructuredTensor(composite_tensor.CompositeTensor): +class StructuredTensor(extension_type.BatchableExtensionType): """A multidimensional collection of structures with the same schema. A **`StructuredTensor`** is a multi-dimensional collection of ***structures*** @@ -83,7 +82,11 @@ class StructuredTensor(composite_tensor.CompositeTensor): A *field path* is a tuple of field names, specifying the path to a nested field. """ + _fields: Mapping[str, Union[ops.Tensor, ragged_tensor.RaggedTensor, + 'StructuredTensor', extension_type.ExtensionType]] + _ragged_shape: dynamic_ragged_shape.DynamicRaggedShape + __name__ = 'tf.StructuredTensor' #============================================================================= # Common Types #============================================================================= @@ -104,8 +107,13 @@ class StructuredTensor(composite_tensor.CompositeTensor): #============================================================================= # Constructor & Factory Methods #============================================================================= + def __init__(self, fields: Mapping[str, FieldValue], + ragged_shape: dynamic_ragged_shape.DynamicRaggedShape): + self._fields = fields + self._ragged_shape = ragged_shape - def __init__(self, fields, shape, nrows, row_partitions, internal=False): + @classmethod + def _old_init(cls, fields, shape, nrows, row_partitions, internal=False): """Private constructor -- use factory methods to create StructuredTensors. This constructor builds a `StructuredTensor` from the given attributes, @@ -118,21 +126,33 @@ def __init__(self, fields, shape, nrows, row_partitions, internal=False): shape: `tf.TensorShape` with statically known rank. nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`. row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`. - internal: Private key value, required to ensure that this private - constructor is *only* called from the factory methods. + internal: ignored argument. + Returns: + a StructuredTensor. """ - if internal is not _structured_tensor_factory_key: - raise ValueError('StructuredTensor constructor is private; please use ' - 'one of the factory methods instead (e.g., ' - 'StructuredTensor.from_fields())') assert isinstance(fields, dict), fields assert isinstance(shape, tensor_shape.TensorShape), shape assert nrows is None or isinstance(nrows, ops.Tensor), nrows - assert isinstance(row_partitions, tuple), row_partitions - self._fields = fields - self._shape = shape - self._nrows = nrows - self._row_partitions = row_partitions + assert row_partitions is None or isinstance(row_partitions, + tuple), row_partitions + return StructuredTensor( + fields=fields, + ragged_shape=_dynamic_ragged_shape_init(fields, shape, nrows, + row_partitions)) + + @classmethod + def from_shape( + cls, ragged_shape: dynamic_ragged_shape.DynamicRaggedShape + ) -> 'StructuredTensor': + """Creates a `StructuredTensor` with no fields and ragged_shape. + + Args: + ragged_shape: the shape of the structured tensor. + + Returns: + a StructuredTensor with no fields and ragged_shape. + """ + return StructuredTensor(fields={}, ragged_shape=ragged_shape) @classmethod def from_fields(cls, @@ -205,68 +225,26 @@ def from_fields(cls, fields = dict(fields) # Make a private copy. with ops.name_scope(None, 'StructuredTensor', fields.values()): + # TODO(martinz): Make this have better errors. + shape = _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions) + + # TODO(martinz): This may not need to be done if all fields are dense. + if shape.rank > 1: + shape = shape._with_num_row_partitions(shape.rank - 1) # Validate keys and convert field values to tensors. for key, value in fields.items(): if not isinstance(key, str): - raise TypeError('Unexpected type for key in `fields`: %r' % key) + + raise TypeError( + f'Unexpected type for key in `fields`: {key}') if not _FIELD_NAME_RE.match(key): raise ValueError('Field name %r is not currently allowed.' % key) fields[key] = _convert_to_structured_field_value(value) - # Determine dtype for row_partitions and nrows. - shape_dtype = _find_shape_dtype(fields, nrows, row_partitions) - if nrows is not None: - nrows = ops.convert_to_tensor(nrows, shape_dtype) - - # Get the static TensorShape for this StructuredTensor. - if rank > 0: - for key, value in fields.items(): - if not shape.is_compatible_with(value.shape[:rank]): - raise ValueError('Field {} has shape {}, which is incompatible ' - 'with the shape that was specified or inferred ' - 'from other fields: {}'.format( - key, value.shape[:rank], shape)) - shape = shape.merge_with(value.shape[:rank]) - - if rank == 1: - # Find a consistent value for `nrows`. - static_nrows = tensor_shape.dimension_at_index(shape, 0) - for value in fields.values(): - nrows, static_nrows = _merge_nrows(nrows, static_nrows, value, - shape_dtype, validate) - if nrows is None: - if static_nrows.value is None: - raise ValueError('nrows must be specified if rank==1 ' - 'and `fields` is empty.') - else: - nrows = constant_op.constant(static_nrows.value, shape_dtype) - - if rank > 1: - # Find a consistent list of RowPartitions. - for value in fields.values(): - row_partitions = _merge_row_partitions(row_partitions, value, rank, - shape_dtype, validate) - if row_partitions is None: - if not shape.is_fully_defined(): - raise ValueError('row_partitions must be specified if rank>1 ' - 'and `fields` is empty.') - else: - row_partitions = _row_partitions_for_uniform_shape( - np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype), - shape.rank) - assert len(row_partitions) == rank - 1 - nrows = row_partitions[0].nrows() - # Update all field values to use the shared RowPartition objects. fields = dict([(k, _replace_row_partitions(v, row_partitions)) for (k, v) in fields.items()]) - - return cls( - fields, - shape, - nrows, - row_partitions, - internal=_structured_tensor_factory_key) + return cls(fields=fields, ragged_shape=shape) @classmethod def from_fields_and_rank(cls, fields, rank, validate=False): @@ -305,8 +283,19 @@ def from_fields_and_rank(cls, fields, rank, validate=False): raise ValueError('rank must be an integer') if rank < 0: raise ValueError('rank must be nonnegative') - return StructuredTensor.from_fields(fields, shape=[None] * rank, - validate=validate) + fields = { + k: _convert_to_structured_field_value(v) for (k, v) in fields.items() + } + dtype = _find_shape_dtype(fields, None, None) + + shape = _shape_from_fields(fields, rank, dtype) + if rank > 1: + shape = shape._with_num_row_partitions(rank - 1) + new_rp = shape._row_partitions # pylint: disable=protected-access + fields = { + k: _replace_row_partitions(v, new_rp) for (k, v) in fields.items() + } + return StructuredTensor(fields=fields, ragged_shape=shape) def with_updates( self, @@ -558,7 +547,7 @@ def promote(self, source_path, new_name): @property def rank(self): """The rank of this StructuredTensor. Guaranteed not to be `None`.""" - return self._shape.rank + return self._ragged_shape.rank @property def shape(self): @@ -570,7 +559,13 @@ def shape(self): Returns: `tf.TensorShape` """ - return self._shape + return self._ragged_shape._to_tensor_shape() # pylint: disable=protected-access + + # TODO(martinz): for backwards compatibility + @property + def _row_partitions(self): + """Deprecated form of row_partitions.""" + return self.row_partitions # TODO(edloper): Make this a func instead of a property? Or make nrows # a property instead of a func? Seems like these should be consistent. @@ -628,7 +623,9 @@ def row_partitions(self): (or `0` if `self.rank < 2`) """ - return self._row_partitions + if self.rank < 2: + return () + return self._ragged_shape._as_row_partitions() # pylint:disable=protected-access def nrows(self): """The number of rows in this StructuredTensor (if rank>0). @@ -644,7 +641,9 @@ def nrows(self): Returns: A scalar integer `Tensor` (or `None` if `self.rank == 0`). """ - return self._nrows + if self.rank == 0: + return None + return self._ragged_shape[0] def _is_eager(self): """True if all fields are composed of eager tensors.""" @@ -738,7 +737,7 @@ def __getitem__(self, key): if not key: return self - if self._shape.rank == 0: + if self.rank == 0: return self._scalar_getitem(key) else: return self._tensor_getitem(key) @@ -748,7 +747,7 @@ def _scalar_getitem(self, key): key[0].stop is None and key[0].step is None): fields = dict((field_name, field_value.__getitem__(key[1:])) for (field_name, field_value) in self._fields.items()) - return StructuredTensor.from_fields(fields, self._shape) + return StructuredTensor.from_fields(fields, self.shape) elif not isinstance(key[0], compat.bytes_or_text_types): raise ValueError('Key for indexing a StructuredTensor must be a ' @@ -757,7 +756,7 @@ def _scalar_getitem(self, key): return self._fields[key[0]].__getitem__(key[1:]) def _tensor_getitem(self, key): - rank = self._shape.rank + rank = self.rank if len(key) <= rank: new_fields = dict((field_name, field_value.__getitem__(key)) for (field_name, field_value) in self._fields.items()) @@ -791,7 +790,7 @@ def __repr__(self): return ('' % (dict_repr, self._shape)) + ' shape=%s)>' % (dict_repr, self.shape)) #============================================================================= # Conversion @@ -846,12 +845,12 @@ def to_pyval(self): result[key] = value # If rank>0, then re-group each value from dict-of-list to list-of-dict. - if len(self._shape) > 0: # pylint: disable=g-explicit-length-test + if len(self.shape) > 0: # pylint: disable=g-explicit-length-test if not result: # special-case for StructuredTensors w/ no fields. return _empty_dict_pylist_from_row_partitions(self.row_partitions, self.nrows()) return _pyval_field_major_to_node_major( - list(result.keys()), list(result.values()), self._shape.rank) + list(result.keys()), list(result.values()), self.rank) else: return result @@ -872,7 +871,7 @@ def from_pyval(cls, pyval, typespec=None): Args: pyval: The nested Python structure that should be used to create the new `StructuredTensor`. - typespec: A `StructuredTensorSpec` specifying the expected type for each + typespec: A `StructuredTensor.Spec` specifying the expected type for each field. If not specified, then all nested dictionaries are turned into StructuredTensors, and all nested lists are turned into Tensors (if rank<2) or RaggedTensors (if rank>=2). @@ -890,7 +889,7 @@ def _from_pyval(cls, pyval, typespec, path_so_far): Args: pyval: The nested Python structure that should be used to create the new `StructuredTensor`. - typespec: A `StructuredTensorSpec` specifying the expected type for each + typespec: A `StructuredTensor.Spec` specifying the expected type for each field. If not specified, then all nested dictionaries are turned into StructuredTensors, and all nested lists are turned into Tensors (if rank<2) or RaggedTensors (if rank>=2). @@ -921,7 +920,7 @@ def _from_pydict(cls, pyval, typespec, path_so_far): else: spec_shape = typespec._shape # pylint: disable=protected-access field_specs = typespec._field_specs # pylint: disable=protected-access - if not (isinstance(typespec, StructuredTensorSpec) and + if not (isinstance(typespec, StructuredTensor.Spec) and spec_shape.rank == 0 and set(pyval) == set(field_specs)): raise ValueError('Value at %r does not match typespec: %r vs %r' % (path_so_far, pyval, typespec)) @@ -940,8 +939,8 @@ def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far): for (key, target) in fields.items(): fields[key] = cls._from_pyval(target, None, path_so_far + (key,)) else: - field_specs = typespec._field_specs # pylint: disable=protected-access - if ((not isinstance(typespec, StructuredTensorSpec)) or # pylint: disable=superfluous-parens + field_specs = typespec._fields # pylint: disable=protected-access + if ((not isinstance(typespec, StructuredTensor.Spec)) or # pylint: disable=superfluous-parens (set(fields) - set(field_specs))): raise ValueError('Value at %r does not match typespec: %r vs %r' % (path_so_far, pyval, typespec)) @@ -1009,7 +1008,7 @@ def _from_pylist_of_value(cls, pyval, typespec, path_so_far): inner_shape=typespec._shape[typespec._ragged_rank + 1:]) except Exception as exc: raise ValueError('Error parsing path %r' % (path_so_far,)) from exc - elif isinstance(typespec, StructuredTensorSpec): + elif isinstance(typespec, StructuredTensor.Spec): empty_rank = _pyval_empty_list_depth(pyval) if empty_rank is None: raise ValueError('Value at %r does not match typespec: %r vs %r' % @@ -1085,7 +1084,7 @@ def merge_dims(self, outer_axis, inner_axis): + shape=(3,))> Args: outer_axis: `int`: The first dimension in the range of dimensions to @@ -1114,168 +1113,31 @@ def merge_dims(self, outer_axis, inner_axis): 'inner_axis (%d)' % (outer_axis, inner_axis)) return _merge_dims(self, outer_axis, inner_axis) - #============================================================================= - # Composite Tensor - #============================================================================= - - @property - def _type_spec(self): - return StructuredTensorSpec.from_value(self) - - -@type_spec.register('tf.StructuredTensorSpec') -class StructuredTensorSpec(type_spec.BatchableTypeSpec): - """Type specification for `StructuredTensor`s.""" - - __slots__ = ['_shape', '_field_specs'] - - def __init__(self, shape, field_specs): - """Build a type specification for a StructuredTensor. - - Args: - shape: The shape of the StructuredTensor. shape.rank must not be None. - field_specs: A dictionary mapping from field name to TypeSpec, specifying - the tensor type used to encode each field. These TypeSpecs should - specify the type of the entire field (including outer dimensions which - correspond to `shape`). For example, if `shape=[2, 3]`, and field 'x' - contains an int32 vector of size `10` for each structure, then - `field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`. - """ - shape = tensor_shape.as_shape(shape) - - # Perform a few sanity checks on the inputs. - if shape.rank is None: - raise TypeError("StructuredTensor's shape must have known rank.") - if not isinstance(field_specs, dict): - raise TypeError('field_specs must be a dictionary.') - for key, value in field_specs.items(): - if not isinstance(key, str): - raise TypeError('field_specs must be a dictionary with string keys.') - if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec, - ragged_tensor.RaggedTensorSpec)): - raise TypeError('field_specs must be a dictionary with ' - 'TypeSpec values.') - - self._shape = shape - self._field_specs = dict(field_specs) - - @property - def shape(self): - return self._shape - - @property - def value_type(self): - return StructuredTensor - - def _to_components(self, value): - nrows = () if value.nrows() is None else value.nrows() - return (value._fields, nrows, value.row_partitions) - - def _from_components(self, components): - if isinstance(components, dict): - logging.warning('Loading deprecated encoding for StructuredTensorSpec.') - return StructuredTensor.from_fields(components, self._shape, - validate=False) - elif not isinstance(components[0], dict): - logging.warning('Loading deprecated encoding for StructuredTensorSpec.') - fields = {} - nrows, row_partitions = components - if isinstance(nrows, tuple) and not nrows: - nrows = None # empty rank-0 structured tensor - return StructuredTensor.from_fields(fields, self._shape, nrows=nrows, - row_partitions=row_partitions, - validate=False) - - (fields, nrows, row_partitions) = components - if isinstance(nrows, tuple) and not nrows: - nrows = None # empty rank-0 structured tensor - return StructuredTensor(fields, self._shape, nrows, row_partitions, - internal=_structured_tensor_factory_key) - - @property - def _component_specs(self): - if self._shape.rank == 0: - nrows_spec = () - else: - nrows_spec = tensor_spec.TensorSpec([], dtypes.int64) - - row_partition_specs = ((row_partition_lib.RowPartitionSpec(),) - * (self._shape.rank - 1)) - return (self._field_specs, nrows_spec, row_partition_specs) + class Spec: + """A spec for StructuredTensor.""" - @classmethod - def from_value(cls, value): - field_specs = dict((k, type_spec.type_spec_from_value(v)) - for (k, v) in value._fields.items()) - return cls(value.shape, field_specs) - - def _serialize(self): - return (self._shape, self._field_specs) - - def _batch(self, batch_size): - # pylint: disable=protected-access - return StructuredTensorSpec( - tensor_shape.TensorShape([batch_size]).concatenate(self._shape), - dict((k, v._batch(batch_size)) for (k, v) in self._field_specs.items())) - - def _unbatch(self): - # pylint: disable=protected-access - return StructuredTensorSpec( - self._shape[1:], - dict((k, v._unbatch()) for (k, v) in self._field_specs.items())) + def __validate__(self): + assert self._ragged_shape is not None - @property - def _flat_tensor_specs(self): - # pylint: disable=protected-access - result = [] - for _, field_spec in sorted(self._field_specs.items(), key=lambda t: t[0]): - result.extend(field_spec._flat_tensor_specs) - return result + # For backwards compatibility + @property + def _shape(self) -> tensor_shape.TensorShape: + return self._ragged_shape._to_tensor_shape() # pylint: disable=protected-access - def _to_tensor_list(self, value): - return self._to_tensor_list_internal(value, batched=False) + # For backwards compatibility + @property + def _field_specs(self) -> Dict[str, type_spec.TypeSpec]: + return self._fields - def _to_batched_tensor_list(self, value): - return self._to_tensor_list_internal(value, batched=True) + # For backwards compatibility + @property + def shape(self) -> tensor_shape.TensorShape: + return self._shape - def _from_compatible_tensor_list(self, tensor_list): - # pylint: disable=protected-access - fields = {} - pos = 0 - for field_name, field_spec in sorted( - self._field_specs.items(), key=lambda t: t[0]): - num_tensors_for_field = len(field_spec._flat_tensor_specs) - field_tensors = tensor_list[pos:pos + num_tensors_for_field] - fields[field_name] = field_spec._from_compatible_tensor_list( - field_tensors) - pos += num_tensors_for_field - return StructuredTensor.from_fields(fields, self._shape) - - def _to_tensor_list_internal(self, value, batched): - """Returns a dict whose entries are each field's (batched) tensor_list. - - If a field is a StructuredTensor, then its entry will be a dict, - recursively. - - Args: - value: A StructuredTensor (conforming to `self`). - batched: A boolean. if True, produce `batched_tensor_list` for each field - otherwise produce `tensor_list`. - - Returns: - A dict. - """ - result = [] - for field_name, field_spec in sorted( - self._field_specs.items(), key=lambda t: t[0]): - # pylint: disable=protected-access - field_value = value._fields[field_name] - if batched: - result.extend(field_spec._to_batched_tensor_list(field_value)) - else: - result.extend(field_spec._to_tensor_list(field_value)) - - return result + # For backwards compatibility + @property + def rank(self): + return self._ragged_shape.rank # Regular expression used to determine whether a string is a valid field name. @@ -1296,6 +1158,8 @@ def _convert_to_structured_field_value(value): return value elif ragged_tensor.is_ragged(value): return ragged_tensor.convert_to_tensor_or_ragged_tensor(value) + elif isinstance(value, extension_type.ExtensionType): + return value else: try: return ops.convert_to_tensor(value) @@ -1306,19 +1170,44 @@ def _convert_to_structured_field_value(value): def _find_shape_dtype(fields, nrows, row_partitions): """Return a consistent dtype for fields, nrows, & row_partitions.""" - shape_dtypes = set() - for value in fields.values(): + field_dtypes = dict() + for (key, value) in fields.items(): if isinstance(value, ragged_tensor.RaggedTensor): - shape_dtypes.add(value.row_splits.dtype) + field_dtypes[key] = value.row_splits.dtype elif isinstance(value, StructuredTensor) and value.rank > 0: - shape_dtypes.add(value.nrows().dtype) - if isinstance(nrows, ops.Tensor): - shape_dtypes.add(nrows.dtype) + field_dtypes[key] = value.nrows().dtype + + field_dtype = None + for value in field_dtypes.values(): + if field_dtype is None: + field_dtype = value + elif field_dtype != value: + raise ValueError('field values have incompatible row_partition dtypes. ' + + f'field_dtypes: {field_dtypes}') + + row_partition_dtype = None + row_partition_dtypes = [] if row_partitions is not None: - for partition in row_partitions: - shape_dtypes.add(partition.dtype) + row_partition_dtypes = [rp.dtype for rp in row_partitions] + for rp_dtype in row_partition_dtypes: + if row_partition_dtype is None: + row_partition_dtype = rp_dtype + elif row_partition_dtype != rp_dtype: + raise ValueError('row_partitions have incompatible dtypes with ' + f'themselves:{row_partition_dtypes}') + + nrows_dtype = None + if isinstance(nrows, ops.Tensor): + nrows_dtype = nrows.dtype + all_dtypes = filter(lambda x: x is not None, + [field_dtype, row_partition_dtype, nrows_dtype]) + shape_dtypes = set() + shape_dtypes.update(all_dtypes) if len(shape_dtypes) > 1: - raise ValueError('field values have incompatible row_partition dtypes.') + raise ValueError('row_partition dtypes are inconsistent: ' + + f'field_dtype:{field_dtype} ' + + f'row_partition_dtype:{row_partition_dtype} ' + + f'nrows_dtype:{nrows_dtype}') elif shape_dtypes: return shape_dtypes.pop() else: @@ -1359,7 +1248,7 @@ def _merge_nrows(nrows, static_nrows, value, dtype, validate): check_ops.assert_equal( nrows, value_nrows, message='fields have incompatible nrows') ], nrows) - return nrows, static_nrows.merge_with(static_value_nrows) + return nrows, static_nrows._merge_with(static_value_nrows) # pylint: disable=protected-access def _merge_row_partitions(row_partitions, value, rank, dtype, validate): @@ -1575,13 +1464,12 @@ def _replace_row_partitions(value, new_partitions): assert isinstance(value, StructuredTensor) new_fields = dict((k, _replace_row_partitions(v, new_partitions)) for (k, v) in value._fields.items()) - return StructuredTensor( + return StructuredTensor._old_init( # pylint: disable=protected-access fields=new_fields, shape=value.shape, nrows=value.nrows(), - row_partitions=new_partitions + - value.row_partitions[len(new_partitions):], - internal=_structured_tensor_factory_key) + row_partitions=tuple(new_partitions) + + tuple(value.row_partitions[len(new_partitions):])) def _partition_outer_dimension(value, row_partition): @@ -1628,11 +1516,10 @@ def _partition_outer_dimension(value, row_partition): ncols]).concatenate(value.shape[1:]) fields = dict((k, _partition_outer_dimension(v, row_partition)) for (k, v) in value._fields.items()) - return StructuredTensor( + return StructuredTensor._old_init( # pylint: disable=protected-access fields, shape, - row_partition.nrows(), (row_partition,) + value.row_partitions, - internal=_structured_tensor_factory_key) + row_partition.nrows(), (row_partition,) + value.row_partitions) def _merge_dims(value, outer_axis, inner_axis): @@ -1642,45 +1529,25 @@ def _merge_dims(value, outer_axis, inner_axis): return ragged_tensor.merge_dims(value, outer_axis, inner_axis) else: assert isinstance(value, StructuredTensor) - - # Build the new fields. fields = dict((k, _merge_dims(v, outer_axis, inner_axis)) for (k, v) in value._fields.items()) + ragged_shape = value._ragged_shape._merge_dims( # pylint: disable=protected-access + outer_axis, inner_axis) + return StructuredTensor(fields, ragged_shape) - # Build the new shape. - value_shape = value.shape - shape = (value_shape[:outer_axis] + [None] + value_shape[inner_axis + 1:]) - - # Build the new row_partitions & nrows - if outer_axis == 0: - if inner_axis == value.shape.rank - 1: - partitions = () - nrows = value.row_partitions[-1].nvals() - else: - partitions = value.row_partitions[inner_axis:] - nrows = partitions[0].nrows() - else: - # Use tf.gather to merge row_splits from the merged row partitions. - merged_splits = value.row_partitions[outer_axis - 1].row_splits() - for dim in range(outer_axis, inner_axis): - merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(), - merged_splits) - - partitions = ( - value.row_partitions[:outer_axis - 1] + - (RowPartition.from_row_splits(merged_splits),) + - value.row_partitions[inner_axis:]) - nrows = partitions[0].nrows() - return StructuredTensor( - fields, - shape, - nrows, - partitions, - internal=_structured_tensor_factory_key) +_structured_tensor_factory_key = object() # unique private object -_structured_tensor_factory_key = object() # unique private object +def _dynamic_ragged_shape_spec_from_spec( + spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec, + ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec, + tensor_spec.TensorSpec] +) -> dynamic_ragged_shape.DynamicRaggedShape.Spec: + if isinstance(spec, StructuredTensor.Spec): + return spec._ragged_shape # pylint: disable=protected-access + else: + return dynamic_ragged_shape.DynamicRaggedShape.Spec._from_spec(spec) # pylint: disable=protected-access def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]: @@ -1721,3 +1588,136 @@ def _merge_dims_generic(source, outer, inner): return source.merge_dims(outer, inner) else: return ragged_tensor.merge_dims(source, outer, inner) + + +def _dynamic_ragged_shape_from_tensor( + field, dtype=None) -> dynamic_ragged_shape.DynamicRaggedShape: + """Extension of DynamicRaggedShape.from_tensor to support StructuredTensor.""" + if isinstance(field, StructuredTensor): + return field._ragged_shape # pylint: disable=protected-access + shape = array_ops.shape_v2(field, out_type=dtype) + + if isinstance(shape, ops.Tensor): + return dynamic_ragged_shape.DynamicRaggedShape( + row_partitions=[], + inner_shape=shape) + elif isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape): + return shape + # TODO(martinz): add a test for the following line. + raise TypeError(f'Expected shape tf.shape({field}) to return a Tensor or a ' + f'DynamicRaggedShape. Instead, got: {shape}.') + + +def _merge_with_optional( + a: Optional[dynamic_ragged_shape.DynamicRaggedShape], + b: Optional[dynamic_ragged_shape.DynamicRaggedShape] + ) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]: + if a is None: + return b + if b is None: + return a + return a._merge_with(b) # pylint: disable=protected-access + + +def _shape_from_fields( + fields, rank: int, + dtype: dtypes.DType) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]: + """Given fields, rank, and dtype, create a shape.""" + + field_shape = None + for (k, field) in fields.items(): + try: + next_field_shape_raw = _dynamic_ragged_shape_from_tensor( + field, dtype=dtype) + next_field_shape = next_field_shape_raw[:rank] + field_shape = _merge_with_optional(field_shape, next_field_shape) + except Exception as err: + raise ValueError(f'Error in shape of {k}') from err + + return field_shape + + +# pylint:disable=protected-access +def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions): + """Produce a DynamicRaggedShape for StructuredTensor.""" + assert isinstance(fields, dict), fields + assert isinstance(shape, tensor_shape.TensorShape), shape + assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance( + nrows, int), nrows + assert row_partitions is None or isinstance(row_partitions, + tuple), row_partitions + rank = shape.rank + + if rank is None: + raise TypeError("StructuredTensor's shape must have known rank.") + + # TODO(martinz): figure out whether to validate. + dtype = _find_shape_dtype(fields, nrows, row_partitions) + result = None + if shape.is_fully_defined(): + result = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape( + shape.as_list(), dtype=dtype) + + if rank == 0: + return dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape( + array_ops.zeros((0,), dtype=dtype)) + + result = _merge_with_optional(result, _shape_from_fields(fields, rank, dtype)) + if rank == 1: + alt_value = tensor_shape.dimension_value(shape[0]) + if alt_value is not None: + nrows = alt_value + if nrows is not None: + result = _merge_with_optional( + result, + dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape( + [nrows], dtype=dtype)) + if result is None: + raise ValueError('Must specify `nrows`, a fully specified `shape`,' + + ' or have `fields` if `rank=1`') + + return result + + if row_partitions: + result = _merge_with_optional( + result, dynamic_ragged_shape.DynamicRaggedShape.from_row_partitions( + row_partitions, dtype=dtype)) + + if result is None: + raise ValueError('Must specify row_partitions, a fully specified shape, ' + + 'or have fields if rank > 1') + return result + + +# TODO(martinz): Drop this method or rename. +def StructuredTensorSpec(shape, field_specs): # pylint:disable=invalid-name + """A placeholder for the old StructuredTensorSpec.""" + if not isinstance(field_specs, dict): + raise TypeError('field_specs must be a dictionary.') + for k in field_specs.keys(): + if not isinstance(k, str): + raise TypeError('field_specs must be a dictionary with string keys.') + for v in field_specs.values(): + if not isinstance(v, type_spec.TypeSpec): + raise TypeError('field_specs must be a dictionary with TypeSpec values.') + + shape = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + tensor_shape.as_shape(shape), + 0, + dtypes.int32) + rank = shape.rank + if rank is None: + raise TypeError("StructuredTensor's shape must have known rank.") + for (k, v) in field_specs.items(): + field_shape_untruncated = _dynamic_ragged_shape_spec_from_spec(v) + if field_shape_untruncated is None: + raise ValueError(f'Cannot convert spec of {k}.') + untruncated_rank = field_shape_untruncated.rank + if (untruncated_rank is not None + and untruncated_rank < rank): + raise ValueError( + f'Rank of field {k} is {untruncated_rank},' + f' but must be at least {rank}.') + field_shape = field_shape_untruncated._truncate(rank) + shape = shape._merge_with(field_shape) + return StructuredTensor.Spec(_ragged_shape=shape, _fields=field_specs) diff --git a/tensorflow/python/ops/structured/structured_tensor_spec_test.py b/tensorflow/python/ops/structured/structured_tensor_spec_test.py index 78e8973aaff3f4..9cc949e9174e48 100644 --- a/tensorflow/python/ops/structured/structured_tensor_spec_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_spec_test.py @@ -20,10 +20,10 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.framework import type_spec +from tensorflow.python.ops.ragged import dynamic_ragged_shape from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import row_partition @@ -105,13 +105,20 @@ def testValueType(self): @parameterized.parameters([ (StructuredTensorSpec([1, 2, 3], {}), - (tensor_shape.TensorShape([1, 2, 3]), {})), + (('_fields', {}), + ('_ragged_shape', + dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [1, 2, 3], num_row_partitions=0, dtype=dtypes.int32)))), (StructuredTensorSpec([], {'a': T_1_2}), - (tensor_shape.TensorShape([]), {'a': T_1_2})), + (('_fields', {'a': T_1_2}), + ('_ragged_shape', + dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [], num_row_partitions=0, dtype=dtypes.int64)))), (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}), - (tensor_shape.TensorShape([1, 2]), {'a': T_1_2, 'b': R_1_N})), - (StructuredTensorSpec([], {'a': T_1_2}), - (tensor_shape.TensorShape([]), {'a': T_1_2})), + (('_fields', {'a': T_1_2, 'b': R_1_N}), + ('_ragged_shape', + dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [1, 2], num_row_partitions=1, dtype=dtypes.int64)))), ]) # pyformat: disable def testSerialize(self, spec, expected): serialization = spec._serialize() @@ -122,13 +129,15 @@ def testSerialize(self, spec, expected): @parameterized.parameters([ (StructuredTensorSpec([1, 2, 3], {}), - ({}, NROWS_SPEC, (PARTITION_SPEC, PARTITION_SPEC))), + (dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [1, 2, 3], num_row_partitions=0, dtype=dtypes.int32),)), (StructuredTensorSpec([], {'a': T_1_2}), - ({'a': T_1_2}, (), ())), + (tensor_spec.TensorSpec(shape=(1, 2), dtype=dtypes.float32, name=None), + dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [], num_row_partitions=0, dtype=dtypes.int64),)), (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}), - ({'a': T_1_2, 'b': R_1_N}, NROWS_SPEC, (PARTITION_SPEC,))), - (StructuredTensorSpec([], {'a': T_1_2}), - ({'a': T_1_2}, (), ())), + (T_1_2, R_1_N, dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( + [1, 2], num_row_partitions=1, dtype=dtypes.int64))), ]) # pyformat: disable def testComponentSpecs(self, spec, expected): self.assertEqual(spec._component_specs, expected) @@ -144,7 +153,7 @@ def testComponentSpecs(self, spec, expected): 'fields': dict( a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]), b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), - 'field_specs': dict(a=R_1_N, b=T_2_3), + 'field_specs': dict(a=R_2_N, b=T_2_3), }, ]) # pyformat: disable def testToFromComponents(self, shape, fields, field_specs): diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index e222147ad9b2cb..90664d9e284841 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -22,7 +22,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors +from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape @@ -40,6 +40,36 @@ from tensorflow.python.ops.structured import structured_tensor_dynamic from tensorflow.python.ops.structured.structured_tensor import StructuredTensor from tensorflow.python.platform import googletest +from tensorflow.python.util import dispatch + + +class _PrivateSpecialType(extension_type.ExtensionType): + ragged: ragged_tensor.RaggedTensor + + +@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateSpecialType) +def shape_v2_special(input: _PrivateSpecialType, out_type=dtypes.int32, # pylint: disable=redefined-builtin + name=None): + """Returns a DynamicRaggedShape containing the shape of the input.""" + del name + return array_ops.shape_v2(input.ragged, out_type) # pylint: disable=protected-access + + +class _PrivateBrokenType(extension_type.ExtensionType): + ragged: ragged_tensor.RaggedTensor + + +@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateBrokenType) +def shape_v2_broken(input: _PrivateBrokenType, out_type=dtypes.int32, # pylint: disable=redefined-builtin + name=None): + """Returns a DynamicRaggedShape containing the shape of the input.""" + del name + del input + del out_type + return { + "foo": "This is not a shape", + "bar": "But if I put a string here, it becomes a vector" + } # pylint: disable=g-long-lambda @@ -73,11 +103,6 @@ def _assertStructuredEqual(self, a, b, msg, check_shape): else: self.assertAllEqual(a_value, b_value, msg) - def testConstructorIsPrivate(self): - with self.assertRaisesRegex(ValueError, - "StructuredTensor constructor is private"): - structured_tensor.StructuredTensor({}, (), None, ()) - @parameterized.named_parameters([ # Scalar (rank=0) StructuredTensors. { @@ -584,8 +609,7 @@ def testFromFields(self, fields=dict(x=[1], y=[]), shape=[None], err=ValueError, - msg=r"Field . has shape .*, which is incompatible with the shape " - r"that was specified or inferred from other fields: .*"), + msg=r"Error in shape of y"), dict( fields={"": 5}, shape=[], @@ -603,8 +627,8 @@ def testFromFields(self, }, shape=[2, None], validate=True, - err=errors.InvalidArgumentError, - msg=r"incompatible row_splits", + err=ValueError, + msg=r"Error in shape of r2", ), dict( fields={}, @@ -628,13 +652,14 @@ def testFromFields(self, fields={}, shape=[None], err=ValueError, - msg="nrows must be specified if rank==1 and `fields` is empty."), + msg="Must specify `nrows`, a fully specified `shape`, " + "or have `fields` if `rank=1`"), dict( fields={}, shape=[None, None], err=ValueError, - msg="row_partitions must be specified if rank>1 and `fields` " - "is empty."), + msg="Must specify row_partitions, a fully specified shape, " + "or have fields if rank > 1"), dict( fields={}, shape=[None, None], @@ -642,7 +667,7 @@ def testFromFields(self, row_partitions=lambda: [row_partition.RowPartition.from_row_lengths([3, 4])], err=ValueError, - msg="field values have incompatible row_partition dtypes"), + msg="row_partition dtypes are inconsistent"), dict( fields=lambda: { "a": @@ -655,18 +680,21 @@ def testFromFields(self, shape=[None, None], err=ValueError, msg="field values have incompatible row_partition dtypes"), - dict( - fields=lambda: { - "a": - array_ops.placeholder_with_default(np.array([1, 2, 3]), None), - "b": - array_ops.placeholder_with_default(np.array([4, 5]), None) - }, - validate=True, - shape=[None], - err=(ValueError, errors.InvalidArgumentError), - msg="fields have incompatible nrows", - test_in_eager=False), + # Currently, this doesn't throw an error. + # That's fine. + # dict( + # fields=lambda: { + # "a": + # array_ops.placeholder_with_default(np.array([1, 2, 3]), + # None), + # "b": + # array_ops.placeholder_with_default(np.array([4, 5]), None) + # }, + # validate=True, + # shape=[None], + # err=(ValueError, errors.InvalidArgumentError), + # msg="Error in shape of b", + # test_in_eager=False), ]) def testFromFieldsErrors(self, fields, @@ -774,6 +802,24 @@ def testPartitionOuterDimension3(self): row_partition.RowPartition.from_row_splits([0, 1, 2])) self.assertEqual(3, struct_3.rank) + def testWithPrivateSpecialType(self): + rt = ragged_tensor.RaggedTensor.from_value_rowids( + array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) + pst = _PrivateSpecialType(rt) + pst_shape = array_ops.shape_v2(pst) + st = structured_tensor.StructuredTensor.from_fields_and_rank({"r": pst}, 1) + st_shape = st._ragged_shape + self.assertEqual(1, st.rank) + self.assertAllEqual(pst_shape[0], st_shape[0]) + + def testWithPrivateBrokenType(self): + rt = ragged_tensor.RaggedTensor.from_value_rowids( + array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) + pbt = _PrivateBrokenType(rt) + + with self.assertRaisesRegex(ValueError, "Error in shape of r"): + structured_tensor.StructuredTensor.from_fields_and_rank({"r": pbt}, 1) + def testPartitionOuterDimsErrors(self): st = StructuredTensor.from_fields({}) partition = row_partition.RowPartition.from_row_splits([0]) @@ -914,6 +960,13 @@ def testPyvalConversion(self, pyval, expected, type_spec=None): if context.executing_eagerly(): # to_pyval only available in eager. self.assertEqual(actual.to_pyval(), pyval) + def testStructuredTensorSpecFactory(self): + spec = structured_tensor.StructuredTensorSpec([], { + "a": tensor_spec.TensorSpec([], dtypes.int32), + "b": tensor_spec.TensorSpec([None], dtypes.int32), + "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)}) + self.assertEqual(spec.rank, 0) + @parameterized.named_parameters([ dict( testcase_name="NoFieldsRaggedRank0", @@ -1013,7 +1066,6 @@ def testToPyval(self, st, expected): pyval=[{"a": 1}, {"a": "c"}], type_spec=None, msg=r"Error parsing path \('a',\)"), - dict(testcase_name="TypeSpecMismatch_ListSparse", pyval=[1, 2], type_spec=sparse_tensor.SparseTensorSpec([None], dtypes.int32), diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index a9da0511d84d50..44d6ba0f21f36e 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -49,6 +49,7 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import object_identity +from tensorflow.python.util import variable_utils # pylint: disable=protected-access @@ -77,6 +78,7 @@ def while_loop(cond, return_same_structure=True, back_prop=True): """Like tf.while_loop, except emits a single While op.""" + loop_vars = variable_utils.convert_variables_to_tensors(loop_vars) # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars flat_orig_loop_vars = nest.flatten(orig_loop_vars, expand_composites=True) @@ -197,13 +199,20 @@ def wrapped_body(loop_counter, maximum_iterations_arg, *args): # structure of `loop_vars_signature`. outputs = body( *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)) - if not nest.is_nested_or_composite(outputs): + if not nest.is_nested(outputs): outputs = [outputs] - # Compare the structure of input and output of body converting the - # top-level tuples to list to be compatible with legacy while_loop. - nest.assert_same_structure(list(outputs), list(orig_loop_vars), - expand_composites=True) - + try: + # The legacy while_loop considers list and tuple to be the same + # structure. + nest.assert_same_structure(outputs, orig_loop_vars, check_types=False, + expand_composites=True) + except ValueError: + # Traditionally we consider variables and tensors to be the same + # structure. + vars1 = variable_utils.convert_variables_to_tensors(outputs) + vars2 = variable_utils.convert_variables_to_tensors(orig_loop_vars) + nest.assert_same_structure(vars1, vars2, check_types=False, + expand_composites=True) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index f41e1d2c652de1..f20a68e0a4f0b6 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -32,7 +32,6 @@ py_test( python_version = "PY3", srcs_version = "PY3", tags = [ - "no_oss_py2", "no_oss_py35", "no_pip", ], diff --git a/tensorflow/python/tpu/client/BUILD b/tensorflow/python/tpu/client/BUILD index bb95d4f605798c..b7b657fcfca412 100644 --- a/tensorflow/python/tpu/client/BUILD +++ b/tensorflow/python/tpu/client/BUILD @@ -40,9 +40,6 @@ tf_py_test( grpc_enabled = True, main = "client_test.py", python_version = "PY3", - tags = [ - "no_oss_py2", - ], deps = [ ":client", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 616e130a52b2fe..3d0963e0005529 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -1677,7 +1677,8 @@ def merge_caches_on_tpu(self, local_tpu_cache_tensor): local_tpu_cache_tensor.shape.as_list()) return tpu_ops.all_to_all( x, concat_dimension=0, split_dimension=0, - split_count=self._tt_config.num_replicas) + split_count=self._tt_config.num_replicas, + group_assignment=[list(range(self._tt_config.num_replicas))]) def aggregate_global_cache(self, global_tt_summary_cache): """Merges the given caches on tpu. diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index 41110281a01463..79f4dc90eccb44 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -370,6 +370,7 @@ cc_library( deps = [ ":platform", ":stream_executor_headers", + "//tensorflow/core/platform:errors", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index b5987ff201843d..01a0df74245e17 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -82,12 +82,12 @@ std::string ComputationTypeString(ComputationType ty) { return "f64"; case ComputationType::kI32: return "i32"; - case ComputationType::kComplexF32: - return "complex f32"; - case ComputationType::kComplexF64: - return "complex f64"; - default: - LOG(FATAL) << "Unknown ComputationType " << static_cast(ty); + case ComputationType::kF16AsF32: + return "f16 (w/ f32 accumulation)"; + case ComputationType::kBF16AsF32: + return "bf16 (w/ f32 accumulation)"; + case ComputationType::kTF32AsF32: + return "tf32 (w/ f32 accumulation)"; } } diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 118990f30efedb..206d67f0a989e6 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -100,53 +100,20 @@ std::string SideString(Side s); // the type of their inputs/outputs. This lets you e.g. multiply two matrices // of int8s using float32s to store the matmul's intermediate values. enum class ComputationType { - kF16, // 16-bit floating-point - kF32, // 32-bit floating-point - kF64, // 64-bit floating-point - kI32, // 32-bit integer - kComplexF32, // Complex number comprised of two f32s. - kComplexF64, // Complex number comprised of two f64s. - // The below values are only supported for BlasLt routines (both real and - // complex). They use float32 for accumulation but round the input mantissas - // to a smaller number of bits. - kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa - kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa + kF16, // 16-bit floating-point + kF32, // 32-bit floating-point + kF64, // 64-bit floating-point + kI32, // 32-bit integer + // The below values use float32 for accumulation, but allow the inputs and + // outputs to be downcast to a lower precision: + kF16AsF32, // Allow downcast to F16 precision. + kBF16AsF32, // Allow downcast to BF16 precision. + kTF32AsF32, // Allow downcast to TF32 precision. }; // Converts a ComputationType to a string. std::string ComputationTypeString(ComputationType ty); -template -struct ToComputationType; -template <> -struct ToComputationType { - static constexpr ComputationType value = ComputationType::kF32; -}; -template <> -struct ToComputationType { - static constexpr ComputationType value = ComputationType::kF64; -}; -template <> -struct ToComputationType { - static constexpr ComputationType value = ComputationType::kF16; -}; -template <> -struct ToComputationType { - static constexpr ComputationType value = ComputationType::kBF16AsF32; -}; -template <> -struct ToComputationType { - static constexpr ComputationType value = ComputationType::kI32; -}; -template <> -struct ToComputationType> { - static constexpr ComputationType value = ComputationType::kComplexF32; -}; -template <> -struct ToComputationType> { - static constexpr ComputationType value = ComputationType::kComplexF64; -}; - std::ostream &operator<<(std::ostream &os, ComputationType ty); using dnn::DataType; diff --git a/tensorflow/stream_executor/cuda/cublas_11_0.inc b/tensorflow/stream_executor/cuda/cublas_11_0.inc index d30927b1271291..ebc4ec296934be 100644 --- a/tensorflow/stream_executor/cuda/cublas_11_0.inc +++ b/tensorflow/stream_executor/cuda/cublas_11_0.inc @@ -5030,4 +5030,11 @@ void CUBLASWINAPI cublasZtrmm(char side, char uplo, char transa, char diag, return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } +CUBLASAPI const char* CUBLASWINAPI cublasGetStatusString(cublasStatus_t status) { + using FuncPtr = const char*(CUBLASWINAPI *)(cublasStatus_t); + static auto func_ptr = LoadSymbol("cublasGetStatusString"); + if (!func_ptr) LogFatalSymbolNotFound("cublasGetStatusString"); + return func_ptr(status); +} + } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index e0593cf9f12671..ab8fff9e1d7429 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -65,7 +65,6 @@ limitations under the License. #include "tensorflow/stream_executor/gpu/gpu_types.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -309,8 +308,7 @@ cublasSideMode_t CUDABlasSide(blas::Side side) { } // CUDADataType::type translates from a C++ type (e.g. float) to a -// cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a -// blas::ComputationType to a cudaDataType_t. +// cudaDataType_t (e.g. CUDA_R_32F). // // These are used to build the argument type and computation type args to // cublasGemmEx. @@ -372,28 +370,6 @@ struct CUDADataType> { static constexpr cudaDataType_t type = CUDA_C_8U; }; -cudaDataType_t CUDAComputationType(blas::ComputationType ty) { - switch (ty) { - case blas::ComputationType::kF16: - return CUDA_R_16F; - case blas::ComputationType::kF32: - return CUDA_R_32F; - case blas::ComputationType::kF64: - return CUDA_R_64F; - case blas::ComputationType::kI32: - return CUDA_R_32I; - case blas::ComputationType::kComplexF32: - return CUDA_C_32F; - case blas::ComputationType::kComplexF64: - return CUDA_C_64F; - case blas::ComputationType::kTF32AsF32: // fall-through - case blas::ComputationType::kBF16AsF32: - // These cases are currently only supported in the blasLt routines, which - // use CUBLASComputationType() instead. - LOG(FATAL) << "Invalid value of blas::ComputationType."; - } -} - } // namespace template @@ -2043,7 +2019,7 @@ port::Status CUDABlas::DoBlasGemmWithAlgorithm( AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha, a.opaque(), AsCudaDataType(type_a), lda, b.opaque(), AsCudaDataType(type_b), ldb, beta, c->opaque(), AsCudaDataType(type_c), - ldc, CUDAComputationType(computation_type), + ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result, stream)); @@ -2082,7 +2058,7 @@ port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, static_cast(alpha), a_matrix, CUDA_R_16BF, lda, b_matrix, CUDA_R_16BF, ldb, static_cast(beta), - c_matrix, CUDA_R_16BF, ldc, CUDAComputationType(computation_type), + c_matrix, CUDA_R_16BF, ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); } TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, @@ -2096,7 +2072,7 @@ port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type, ldb, stride_b, beta, c->opaque(), AsCudaDataType(type_c), ldc, stride_c, - batch_count, CUDAComputationType(computation_type), + batch_count, AsCublasComputeType(computation_type), static_cast(algorithm))); TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result, stream)); @@ -2230,21 +2206,21 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal( // Decide how to allocate device-side copy of pointers to matrices based on // whether a scratch allocator was passed. if (scratch_allocator != nullptr) { - SE_ASSIGN_OR_RETURN(DeviceMemory a_bytes, + TF_ASSIGN_OR_RETURN(DeviceMemory a_bytes, scratch_allocator->AllocateBytes(size)); - SE_ASSIGN_OR_RETURN(DeviceMemory b_bytes, + TF_ASSIGN_OR_RETURN(DeviceMemory b_bytes, scratch_allocator->AllocateBytes(size)); - SE_ASSIGN_OR_RETURN(DeviceMemory c_bytes, + TF_ASSIGN_OR_RETURN(DeviceMemory c_bytes, scratch_allocator->AllocateBytes(size)); a = DeviceMemory(a_bytes); b = DeviceMemory(b_bytes); c = DeviceMemory(c_bytes); } else { - SE_ASSIGN_OR_RETURN(a_temporary, + TF_ASSIGN_OR_RETURN(a_temporary, stream->AllocateTemporaryArray(batch_count)); - SE_ASSIGN_OR_RETURN(b_temporary, + TF_ASSIGN_OR_RETURN(b_temporary, stream->AllocateTemporaryArray(batch_count)); - SE_ASSIGN_OR_RETURN(c_temporary, + TF_ASSIGN_OR_RETURN(c_temporary, stream->AllocateTemporaryArray(batch_count)); a = DeviceMemory(*a_temporary->mutable_device_memory()); b = DeviceMemory(*b_temporary->mutable_device_memory()); diff --git a/tensorflow/stream_executor/cuda/cuda_blas_lt.cc b/tensorflow/stream_executor/cuda/cuda_blas_lt.cc index 0b1e6d35f61ee7..e0f00fd1696867 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas_lt.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas_lt.cc @@ -36,43 +36,12 @@ namespace stream_executor { namespace cuda { namespace { -blas::DataType GetScaleType(blas::DataType data_type, +blas::DataType GetScaleType(blas::DataType c_type, blas::ComputationType compute_type) { - switch (compute_type) { - case blas::ComputationType::kF16: - return blas::DataType::kHalf; - case blas::ComputationType::kF32: // fall-through - case blas::ComputationType::kTF32AsF32: // fall-through - case blas::ComputationType::kBF16AsF32: - return blas::DataType::kFloat; - case blas::ComputationType::kF64: - return blas::DataType::kDouble; - case blas::ComputationType::kComplexF32: - return blas::DataType::kComplexFloat; - case blas::ComputationType::kComplexF64: - return blas::DataType::kComplexDouble; - case blas::ComputationType::kI32: - return blas::DataType::kInt32; - } -} - -cublasComputeType_t AsCublasComputeType(blas::ComputationType type) { - switch (type) { - case blas::ComputationType::kF16: - return CUBLAS_COMPUTE_16F; - case blas::ComputationType::kF32: // fall-through - case blas::ComputationType::kComplexF32: - return CUBLAS_COMPUTE_32F; - case blas::ComputationType::kF64: // fall-through - case blas::ComputationType::kComplexF64: - return CUBLAS_COMPUTE_64F; - case blas::ComputationType::kI32: - return CUBLAS_COMPUTE_32I; - case blas::ComputationType::kTF32AsF32: - return CUBLAS_COMPUTE_32F_FAST_TF32; - case blas::ComputationType::kBF16AsF32: - return CUBLAS_COMPUTE_32F_FAST_16BF; - } + return ((compute_type == blas::ComputationType::kF32) && + (c_type != blas::DataType::kComplexFloat)) + ? blas::DataType::kFloat + : c_type; } cublasLtPointerMode_t AsCublasLtPointerMode(BlasLt::PointerMode pointer_mode) { @@ -195,13 +164,13 @@ port::StatusOr CreateCublasLtOperationDesc( computation_type, ") failed: ", ToString(status))); } BlasLt::UniqueOpDesc unique_desc(desc); - SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + TF_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, AsCublasLtPointerMode(pointer_mode))); - SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE, + TF_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE, AsCublasLtEpilogue(epilogue))); - SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, + TF_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, AsCublasOperation(transa))); - SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, + TF_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, AsCublasOperation(transb))); return unique_desc; } @@ -218,9 +187,9 @@ port::StatusOr CreateCublasLtLayoutDesc( absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status))); } BlasLt::UniqueLayoutDesc unique_desc(desc); - SE_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count)); - SE_RETURN_IF_ERROR(SetCublasLtAttr( + TF_RETURN_IF_ERROR(SetCublasLtAttr( desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride)); return unique_desc; } @@ -256,7 +225,7 @@ port::StatusOr CreateCublasLtMatmulPreference( ToString(status))); } BlasLt::UniqueMatmulPreference unique_preference(preference); - SE_RETURN_IF_ERROR(SetCublasLtAttr(preference, + TF_RETURN_IF_ERROR(SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, max_workspace_bytes)); @@ -270,25 +239,25 @@ port::StatusOr CreateCublasLtMatmulPreference( return (stride & -stride) * GetDataTypeSizeBytes(dtype); }; if (plan.params().stride_a) { - SE_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, (uint32)get_alignment_bytes(plan.params().stride_a, plan.params().ab_type))); } if (plan.params().stride_b) { - SE_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, (uint32)get_alignment_bytes(plan.params().stride_b, plan.params().ab_type))); } if (plan.params().stride_c) { - SE_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, (uint32)get_alignment_bytes(plan.params().stride_c, plan.params().c_type))); } if (plan.params().stride_c) { - SE_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, (uint32)get_alignment_bytes(plan.params().stride_c, plan.params().c_type))); @@ -299,7 +268,7 @@ port::StatusOr CreateCublasLtMatmulPreference( port::Status AllocateWorkspace(void **workspace, ScratchAllocator *scratch_allocator, size_t num_bytes) { - SE_ASSIGN_OR_RETURN(DeviceMemory workspace_bytes, + TF_ASSIGN_OR_RETURN(DeviceMemory workspace_bytes, scratch_allocator->AllocateBytes(num_bytes)); *workspace = (void *)gpu::GpuMemoryMutable(&workspace_bytes); return port::Status::OK(); @@ -329,7 +298,7 @@ int BlasLt::MatmulAlgorithm::algo_id() const { port::Status BlasLt::MatmulPlan::init(const MatmulPlanParams &p) { params_ = p; scale_type_ = GetScaleType(p.c_type, p.computation_type); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( op_desc_, CreateCublasLtOperationDesc( p.computation_type, GetScaleType(p.c_type, p.computation_type), @@ -338,34 +307,34 @@ port::Status BlasLt::MatmulPlan::init(const MatmulPlanParams &p) { uint64_t cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m; uint64_t rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n; uint64_t cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k; - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a, capped_batch_count())); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b, capped_batch_count())); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, capped_batch_count())); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, capped_batch_count())); remainder_batch_count_ = p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0; if (remainder_batch_count_) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( a_remainder_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a, remainder_batch_count_)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( b_remainder_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b, remainder_batch_count_)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( c_remainder_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, remainder_batch_count_)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( d_remainder_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, remainder_batch_count_)); @@ -382,7 +351,7 @@ bool BlasLt::MatmulPlan::SetBiasPointer(const void *bias) const { /*static*/ port::StatusOr BlasLt::CreateMatmulPlan( const BlasLt::MatmulPlanParams &p) { MatmulPlan cuda_plan; - SE_RETURN_IF_ERROR(cuda_plan.init(p)); + TF_RETURN_IF_ERROR(cuda_plan.init(p)); return std::move(cuda_plan); } @@ -391,7 +360,7 @@ BlasLt::GetMatmulAlgorithmsInternal(const BlasLt::MatmulPlan &plan, size_t max_workspace_size, int max_algorithm_count, bool for_remainder_batch) { - SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference, + TF_ASSIGN_OR_RETURN(UniqueMatmulPreference preference, CreateCublasLtMatmulPreference(plan, max_workspace_size)); std::vector results(max_algorithm_count); diff --git a/tensorflow/stream_executor/cuda/cuda_blas_utils.cc b/tensorflow/stream_executor/cuda/cuda_blas_utils.cc index 89d8408cafc585..b943b915eb8bb1 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas_utils.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas_utils.cc @@ -62,6 +62,25 @@ cudaDataType_t AsCudaDataType(blas::DataType type) { } } +cublasComputeType_t AsCublasComputeType(blas::ComputationType type) { + switch (type) { + case blas::ComputationType::kF16: + return CUBLAS_COMPUTE_16F; + case blas::ComputationType::kF32: + return CUBLAS_COMPUTE_32F; + case blas::ComputationType::kF64: + return CUBLAS_COMPUTE_64F; + case blas::ComputationType::kI32: + return CUBLAS_COMPUTE_32I; + case blas::ComputationType::kF16AsF32: + return CUBLAS_COMPUTE_32F_FAST_16F; + case blas::ComputationType::kBF16AsF32: + return CUBLAS_COMPUTE_32F_FAST_16BF; + case blas::ComputationType::kTF32AsF32: + return CUBLAS_COMPUTE_32F_FAST_TF32; + } +} + cublasOperation_t AsCublasOperation(blas::Transpose trans) { switch (trans) { case blas::Transpose::kNoTranspose: diff --git a/tensorflow/stream_executor/cuda/cuda_blas_utils.h b/tensorflow/stream_executor/cuda/cuda_blas_utils.h index 34a3c199fb8c51..6ab73329d6b59a 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas_utils.h +++ b/tensorflow/stream_executor/cuda/cuda_blas_utils.h @@ -19,12 +19,12 @@ limitations under the License. #include #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #define SE_CUBLAS_RETURN_IF_ERROR(expr) \ - SE_RETURN_IF_ERROR(::stream_executor::cuda::ToStatus(expr, #expr)) + TF_RETURN_IF_ERROR(::stream_executor::cuda::ToStatus(expr, #expr)) namespace stream_executor { namespace cuda { @@ -32,6 +32,7 @@ namespace cuda { const char* ToString(cublasStatus_t status); port::Status ToStatus(cublasStatus_t status, const char* prefix = "cublasLt"); cudaDataType_t AsCudaDataType(blas::DataType type); +cublasComputeType_t AsCublasComputeType(blas::ComputationType type); cublasOperation_t AsCublasOperation(blas::Transpose trans); } // namespace cuda diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 145447da361c71..9fac0b88fb2227 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -327,9 +327,9 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(std::optional algorithm) { } port::Status GetLoadedCudnnVersion(CudnnVersion* version) { - SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION)); - SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION)); - SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL)); + TF_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION)); + TF_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION)); + TF_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL)); return ::tensorflow::OkStatus(); } @@ -1107,7 +1107,7 @@ class CudnnDropoutDescriptor { size_t state_sizes_in_bytes = 0; RETURN_IF_CUDNN_ERROR( cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes)); - SE_ASSIGN_OR_RETURN(state_memory, + TF_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(state_sizes_in_bytes)); } RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor( @@ -1196,7 +1196,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( CudnnDropoutDescriptor dropout_desc, CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator)); @@ -1293,7 +1293,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and // cudnnRNNForward***Ex are removed from the codebase, since the new API // doesn't need param descriptors any more. - SE_ASSIGN_OR_RETURN(auto params_desc, + TF_ASSIGN_OR_RETURN(auto params_desc, CudnnRnnParamsDescriptor::Create( cudnn, input_size, data_type, rnn_desc.get(), rnn_mode, direction_mode, num_layers)); @@ -1873,7 +1873,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( RnnModelDims model_dims, ExtractAndCheckRnnForward( rnn_desc, input_desc, input_data, input_h_desc, input_h_data, @@ -1882,7 +1882,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been // deprecated. Instead, we use the cudnnRNNForward which requires the @@ -1907,11 +1907,11 @@ port::Status CudnnSupport::DoRnnForwardImpl( /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); if (workspace_size_in_bytes > 0) { - SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( + TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( workspace_size_in_bytes)); } if (reserve_space_size_in_bytes > 0) { - SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( reserve_space_size_in_bytes)); } @@ -1956,9 +1956,9 @@ port::Status CudnnSupport::DoRnnForwardImpl( return port::Status::OK(); } #endif - SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + TF_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator)) + workspace_allocator)); // query the reserve space size // allocate the reserve space @@ -1971,7 +1971,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( /*sizeInBytes=*/&reserve_space_size_in_bytes)); if (reserve_space_size_in_bytes > 0) { - SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( reserve_space_size_in_bytes)); } } @@ -2093,7 +2093,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( RnnModelDims model_dims, ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc, input_h_data, input_c_desc, input_c_data, @@ -2102,7 +2102,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been // deprecated. Instead, we use the cudnnRNNForward which requires the @@ -2118,7 +2118,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( /*workSpaceSize=*/&workspace_size_in_bytes, /*reserveSpaceSize=*/NULL)); if (workspace_size_in_bytes > 0) { - SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( + TF_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( workspace_size_in_bytes)); } @@ -2189,7 +2189,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( return port::Status::OK(); } #endif - SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + TF_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, workspace_allocator)); @@ -2349,7 +2349,7 @@ CudnnSupport::createRnnDescriptor( // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( CudnnRnnDescriptor rnn_desc, CudnnRnnDescriptor::Create( cudnn, num_layers, hidden_size, input_size, cell_size, batch_size, @@ -2365,7 +2365,7 @@ port::StatusOr> CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) { - SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, CudnnRnnSequenceTensorDescriptor::Create( parent_, max_seq_length, batch_size, data_size, ToCudnnDataType(data_type))); @@ -2378,7 +2378,7 @@ CudnnSupport::createRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) { - SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, CudnnRnnSequenceTensorDescriptor::Create( parent_, max_seq_length, batch_size, data_size, seq_lengths, time_major, ToCudnnDataType(data_type))); @@ -2992,7 +2992,7 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); bool use_tensor_ops; - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); @@ -3004,7 +3004,7 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( specify_workspace_limit ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0}) : int64_t{0}; - SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo, + TF_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo, GetCudnnConvolutionForwardAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); @@ -3032,10 +3032,10 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( "Returned status: ", scratch_or.status().ToString())); } - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); - SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( + TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator)); return *algo_desc; @@ -3054,7 +3054,7 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); bool use_tensor_ops; - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); @@ -3066,7 +3066,7 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( specify_workspace_limit ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0}) : int64_t{0}; - SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo, + TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo, GetCudnnConvolutionBackwardDataAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); @@ -3093,10 +3093,10 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); - SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator)); return *algo_desc; @@ -3115,7 +3115,7 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); bool use_tensor_ops; - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); @@ -3127,7 +3127,7 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( specify_workspace_limit ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64_t{0}) : int64_t{0}; - SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo, + TF_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo, GetCudnnConvolutionBackwardFilterAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); @@ -3157,10 +3157,10 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( scratch_or.status().ToString())); } - SE_ASSIGN_OR_RETURN(use_tensor_ops, + TF_ASSIGN_OR_RETURN(use_tensor_ops, UseTensorOps(stream, element_type, algo_desc)); conv.set_use_tensor_op_math(use_tensor_ops); - SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( + TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator)); return *algo_desc; @@ -3851,7 +3851,7 @@ port::Status CudnnSupport::DoPrepareForConvolution( switch (kind) { case dnn::ConvolutionKind::FORWARD: { - SE_ASSIGN_OR_RETURN(*algorithm_desc, + TF_ASSIGN_OR_RETURN(*algorithm_desc, GetCudnnConvolutionForwardAlgorithm( stream, cudnn, algorithm_config, input_nd, filter_nd, element_type, convolution_descriptor, @@ -3859,7 +3859,7 @@ port::Status CudnnSupport::DoPrepareForConvolution( break; } case dnn::ConvolutionKind::BACKWARD_DATA: { - SE_ASSIGN_OR_RETURN(*algorithm_desc, + TF_ASSIGN_OR_RETURN(*algorithm_desc, GetCudnnConvolutionBackwardDataAlgorithm( stream, cudnn, algorithm_config, input_nd, filter_nd, element_type, convolution_descriptor, @@ -3867,7 +3867,7 @@ port::Status CudnnSupport::DoPrepareForConvolution( break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - SE_ASSIGN_OR_RETURN(*algorithm_desc, + TF_ASSIGN_OR_RETURN(*algorithm_desc, GetCudnnConvolutionBackwardFilterAlgorithm( stream, cudnn, algorithm_config, input_nd, filter_nd, element_type, convolution_descriptor, @@ -3959,7 +3959,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { auto algo = MakeAlgorithmDesc(); // Check that the current stream supports tensor ops if they're requested. - SE_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status()); + TF_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status()); if (static_cast(parent_) != stream->parent()->implementation()) { @@ -4019,7 +4019,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { switch (kind_) { case dnn::ConvolutionKind::FORWARD: { - SE_RETURN_IF_ERROR(get_fwd_bugs()); + TF_RETURN_IF_ERROR(get_fwd_bugs()); RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward( cudnn.handle(), /*alpha=*/alpha, /*srcDesc=*/input_nd_.handle(), @@ -4032,7 +4032,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { break; } case dnn::ConvolutionKind::BACKWARD_DATA: { - SE_RETURN_IF_ERROR(get_bwd_data_bugs()); + TF_RETURN_IF_ERROR(get_bwd_data_bugs()); RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData( cudnn.handle(), /*alpha=*/alpha, @@ -4050,7 +4050,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { break; } case dnn::ConvolutionKind::BACKWARD_FILTER: { - SE_RETURN_IF_ERROR(get_bwd_filter_bugs()); + TF_RETURN_IF_ERROR(get_bwd_filter_bugs()); RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter( cudnn.handle(), /*alpha=*/alpha, @@ -4150,11 +4150,11 @@ port::Status CudnnSupport::DoConvolve( auto accumulator_type = GetConvAccumulatorType(element_type); CudnnConvolutionDescriptor conv(convolution_descriptor, ToCudnnDataType(accumulator_type)); - SE_ASSIGN_OR_RETURN(bool use_tensor_ops, + TF_ASSIGN_OR_RETURN(bool use_tensor_ops, UseTensorOps(stream, element_type, algorithm_desc)); conv.set_use_tensor_op_math(use_tensor_ops); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, CudnnLegacyConvRunner::Create( parent_, stream, cudnn_.get(), algorithm_desc, element_type, @@ -4196,7 +4196,7 @@ port::StatusOr> GetDescriptorAttribute( std::vector result(n); for (int i = 0; i < n; ++i) { - SE_ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type)); + TF_ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type)); } std::vector raw_ptrs; @@ -4218,7 +4218,7 @@ port::StatusOr> GetDescriptorAttribute( // them in the form of an AlgorithmDesc for use with RebuildExecutionPlan. port::StatusOr ExecutionPlanToAlgorithmDesc( const cudnn_frontend::ExecutionPlan& plan, size_t workspace_size) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto engine_cfgs, GetDescriptorAttribute(plan.get_raw_desc(), CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, @@ -4228,7 +4228,7 @@ port::StatusOr ExecutionPlanToAlgorithmDesc( "CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG had more than one element."); } - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto engines, GetDescriptorAttribute(engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_ENGINE, CUDNN_BACKEND_ENGINE_DESCRIPTOR)); @@ -4251,7 +4251,7 @@ port::StatusOr ExecutionPlanToAlgorithmDesc( // were filled. std::vector knobs(CUDNN_KNOB_TYPE_COUNTS); for (int i = 0; i < knobs.size(); ++i) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( knobs[i], CreateBackendDesc(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR)); } std::vector raw_knob_ptrs; @@ -4372,7 +4372,7 @@ class CudnnExecutionPlanRunner if (!timer->Stop(AsGpuStream(stream))) { return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - SE_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); + TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); profile_result->set_algorithm(desc); profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); profile_result->set_scratch_size(scratch_memory.size()); @@ -4643,7 +4643,7 @@ port::Status CudnnSupport::GetConvolveRunners( #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto op_graph, GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor, filter_descriptor, output_descriptor, @@ -4672,7 +4672,7 @@ CudnnSupport::ConvolveRunnerFromDesc( ToCudnnDataType(GetConvAccumulatorType(input_type))); conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, CudnnLegacyConvRunner::Create( parent_, stream, cudnn_.get(), algorithm_desc, input_type, @@ -4697,16 +4697,16 @@ CudnnSupport::ConvolveRunnerFromDesc( #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto op_graph, GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor, filter_descriptor, output_descriptor, convolution_descriptor, cudnn)); - SE_ASSIGN_OR_RETURN(auto execution_plan, + TF_ASSIGN_OR_RETURN(auto execution_plan, RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, CudnnExecutionPlanRunner::Create( parent_, cudnn_.get(), std::move(execution_plan), {'x', 'w', 'y'})); @@ -4941,7 +4941,7 @@ CudnnSupport::FusedConvolveRunnerFromDesc( CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max()); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, CudnnLegacyFusedConvRunner::Create( parent_, stream, cudnn_.get(), algorithm_desc, input_type, @@ -4954,17 +4954,17 @@ CudnnSupport::FusedConvolveRunnerFromDesc( #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_ASSIGN_OR_RETURN(auto op_graph, + TF_ASSIGN_OR_RETURN(auto op_graph, GetCudnnFusedOperationGraph( kind, input_type, bias_type, output_type, conv_scale, side_input_scale, input_descriptor, filter_descriptor, bias_descriptor, output_descriptor, convolution_descriptor, activation_mode, cudnn)); - SE_ASSIGN_OR_RETURN(auto execution_plan, + TF_ASSIGN_OR_RETURN(auto execution_plan, RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); - SE_ASSIGN_OR_RETURN(auto runner, + TF_ASSIGN_OR_RETURN(auto runner, CudnnExecutionPlanRunner::Create( parent_, cudnn_.get(), std::move(execution_plan), {'x', 'w', 'z', 'b', 'y'})); @@ -5351,11 +5351,11 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max()); if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( workspace, CreateBatchNormForwardWorkspace( stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor, - scale_offset_descriptor, workspace_allocator)) + scale_offset_descriptor, workspace_allocator)); if (is_training) { size_t reserve_space_size_in_bytes = 0; RETURN_IF_CUDNN_ERROR( @@ -5364,7 +5364,7 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( /*activationDesc=*/activation_desc.handle(), /*xDesc=*/x_descriptor.handle(), /*sizeInBytes=*/&reserve_space_size_in_bytes)); - SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( reserve_space_size_in_bytes)); } } @@ -5434,7 +5434,7 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( } #endif if (!called) { - SE_RETURN_IF_ERROR(check_no_side_input_or_activation()); + TF_RETURN_IF_ERROR(check_no_side_input_or_activation()); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), @@ -5444,7 +5444,7 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl( } } else { const void* maybe_inv_var = estimated_variance.opaque(); - SE_RETURN_IF_ERROR(check_no_side_input_or_activation()); + TF_RETURN_IF_ERROR(check_no_side_input_or_activation()); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), @@ -5541,11 +5541,11 @@ port::Status CudnnSupport::DoBatchNormalizationBackwardImpl( CudnnActivationDescriptor activation_desc( activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max()); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( DeviceMemory workspace, CreateBatchNormBackwardWorkspace( stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor, - scale_offset_descriptor, workspace_allocator)) + scale_offset_descriptor, workspace_allocator)); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx( /*handle=*/cudnn.handle(), /*mode=*/mode, @@ -5655,7 +5655,7 @@ port::Status CudnnSupport::DoFusedConvolve( dnn::AlgorithmDesc algo_desc; { auto cudnn = cudnn_->GetHandle(parent_, stream); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( algo_desc, GetCudnnConvolutionForwardAlgorithm( stream, cudnn, algorithm_config, conv_input_nd, filter, input_type, @@ -5674,7 +5674,7 @@ port::Status CudnnSupport::DoFusedConvolve( CudnnActivationDescriptor activation_desc( activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max()); - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, CudnnLegacyFusedConvRunner::Create( parent_, stream, cudnn_.get(), std::move(algo_desc), input_type, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 830e8e0af18f90..4e631beea8ed79 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -39,7 +39,6 @@ limitations under the License. #include "tensorflow/stream_executor/dnn.pb.h" #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc index a8543881d508a3..6c1a311225a561 100644 --- a/tensorflow/stream_executor/host/host_platform.cc +++ b/tensorflow/stream_executor/host/host_platform.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" namespace stream_executor { namespace host { diff --git a/tensorflow/stream_executor/lazy_op_runner.h b/tensorflow/stream_executor/lazy_op_runner.h index e59b9537992422..254e54957b19d8 100644 --- a/tensorflow/stream_executor/lazy_op_runner.h +++ b/tensorflow/stream_executor/lazy_op_runner.h @@ -56,7 +56,7 @@ class LazyOpRunner { if (!runner) { return port::InternalError("Null runner argument to FromOpRunner"); } - SE_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); + TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); // Private constructor cannot be called by make_unique :( return {std::unique_ptr( new LazyOpRunner(desc, std::move(runner)))}; @@ -80,7 +80,7 @@ class LazyOpRunner { typename Op::Config config, Stream* stream) { absl::MutexLock lock(&mu_); if (!runner_) { - SE_ASSIGN_OR_RETURN(runner_, Op::RunnerFromAlgorithmDesc( + TF_ASSIGN_OR_RETURN(runner_, Op::RunnerFromAlgorithmDesc( desc_, std::move(config), stream)); } return runner_.get(); diff --git a/tensorflow/stream_executor/lib/status_macros.h b/tensorflow/stream_executor/lib/status_macros.h deleted file mode 100644 index ff8f4a71c8d702..00000000000000 --- a/tensorflow/stream_executor/lib/status_macros.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Helper macros for dealing with the port::Status datatype. - -#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_ -#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_ - -// Early-returns the status if it is in error; otherwise, proceeds. -// -// The argument expression is guaranteed to be evaluated exactly once. -#define SE_RETURN_IF_ERROR(__status) \ - do { \ - auto status = __status; \ - if (!status.ok()) { \ - return status; \ - } \ - } while (false) - -// Identifier concatenation helper macros. -#define SE_MACRO_CONCAT_INNER(__x, __y) __x##__y -#define SE_MACRO_CONCAT(__x, __y) SE_MACRO_CONCAT_INNER(__x, __y) - -// Implementation of SE_ASSIGN_OR_RETURN that uses a unique temporary identifier -// for avoiding collision in the enclosing scope. -#define SE_ASSIGN_OR_RETURN_IMPL(__lhs, __rhs, __name) \ - auto __name = (__rhs); \ - if (!__name.ok()) { \ - return __name.status(); \ - } \ - __lhs = std::move(__name.ValueOrDie()); - -// Early-returns the status if it is in error; otherwise, assigns the -// right-hand-side expression to the left-hand-side expression. -// -// The right-hand-side expression is guaranteed to be evaluated exactly once. -#define SE_ASSIGN_OR_RETURN(__lhs, __rhs) \ - SE_ASSIGN_OR_RETURN_IMPL(__lhs, __rhs, \ - SE_MACRO_CONCAT(__status_or_value, __COUNTER__)) - -#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_ diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index 9202bbf62d62e5..12521d60285919 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" @@ -132,15 +133,15 @@ port::StatusOr MultiPlatformManagerImpl::PlatformWithName( if(lookup.ok()) platform = lookup.value(); else { - SE_ASSIGN_OR_RETURN(platform, LookupByNameLocked("rocm")); + TF_ASSIGN_OR_RETURN(platform, LookupByNameLocked("rocm")); } } else { if(target == "cuda_only") target = "cuda"; - SE_ASSIGN_OR_RETURN(platform, LookupByNameLocked(target)); + TF_ASSIGN_OR_RETURN(platform, LookupByNameLocked(target)); } if (initialize_platform && !platform->Initialized()) { - SE_RETURN_IF_ERROR(platform->Initialize({})); + TF_RETURN_IF_ERROR(platform->Initialize({})); } return platform; @@ -150,9 +151,9 @@ port::StatusOr MultiPlatformManagerImpl::PlatformWithId( const Platform::Id& id, bool initialize_platform) { absl::MutexLock lock(&mu_); - SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); if (initialize_platform && !platform->Initialized()) { - SE_RETURN_IF_ERROR(platform->Initialize({})); + TF_RETURN_IF_ERROR(platform->Initialize({})); } return platform; @@ -163,14 +164,14 @@ port::StatusOr MultiPlatformManagerImpl::InitializePlatformWithName( const std::map& options) { absl::MutexLock lock(&mu_); - SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target)); if (platform->Initialized()) { return port::Status( port::error::FAILED_PRECONDITION, absl::StrCat("platform \"", target, "\" is already initialized")); } - SE_RETURN_IF_ERROR(platform->Initialize(options)); + TF_RETURN_IF_ERROR(platform->Initialize(options)); return platform; } @@ -179,14 +180,14 @@ port::StatusOr MultiPlatformManagerImpl::InitializePlatformWithId( const Platform::Id& id, const std::map& options) { absl::MutexLock lock(&mu_); - SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); + TF_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id)); if (platform->Initialized()) { return port::Status( port::error::FAILED_PRECONDITION, absl::StrFormat("platform with id %p is already initialized", id)); } - SE_RETURN_IF_ERROR(platform->Initialize(options)); + TF_RETURN_IF_ERROR(platform->Initialize(options)); return platform; } @@ -212,7 +213,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter( Platform* platform = entry.second; if (filter(platform)) { if (initialize_platform && !platform->Initialized()) { - SE_RETURN_IF_ERROR(platform->Initialize({})); + TF_RETURN_IF_ERROR(platform->Initialize({})); } platforms.push_back(platform); } diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h index becaf76940b1c8..b0588de9fe9b5f 100644 --- a/tensorflow/stream_executor/platform.h +++ b/tensorflow/stream_executor/platform.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/stream_executor/device_description.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin.h" diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index 0409bf972fd0be..3c471260394696 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" @@ -1768,13 +1767,13 @@ port::Status ROCMBlas::AllocateStridedBuffer( } if (scratch_allocator != nullptr) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( DeviceMemory batch_matrix_bytes, scratch_allocator->AllocateBytes(matrix_batch_byte_size)); *device_memory = DeviceMemory(batch_matrix_bytes); } else { assert(temp_memory != nullptr); - SE_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray( + TF_ASSIGN_OR_RETURN(*temp_memory, stream->AllocateTemporaryArray( matrix_batch_byte_size)); *device_memory = DeviceMemory(*(*temp_memory)->mutable_device_memory()); diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index 338b2c7db5b86d..fc859cf36409b9 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -3172,7 +3172,7 @@ port::Status MIOpenSupport::DoConvolve( const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, dnn::ProfileResult* output_profile_result) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, ConvolveRunnerFromDesc(stream, algorithm_desc, kind, element_type, output_type, input_descriptor, filter_descriptor, @@ -3225,7 +3225,7 @@ port::Status MIOpenSupport::GetConvolveRunners( } for (const auto& profile_result : profile_results) { - SE_ASSIGN_OR_RETURN( + TF_ASSIGN_OR_RETURN( auto runner, ConvolveRunnerFromDesc( stream, profile_result.algorithm(), kind, input_type, output_type, input_descriptor, filter_descriptor, diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc index b17e7ba558235c..8e1660c52f1d49 100644 --- a/tensorflow/stream_executor/scratch_allocator.cc +++ b/tensorflow/stream_executor/scratch_allocator.cc @@ -24,7 +24,7 @@ namespace stream_executor { port::StatusOr> OneTimeScratchAllocator::AllocateBytes( int64_t byte_size) { CHECK(temporary_ == nullptr); - SE_ASSIGN_OR_RETURN(temporary_, + TF_ASSIGN_OR_RETURN(temporary_, stream_->AllocateTemporaryArray(byte_size)); return temporary_->device_memory(); } diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 610624345346e7..ca4d434f9565a6 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -89,6 +89,17 @@ struct NonDeduced { template using NonDeducedType = typename NonDeduced::type; +// Helper to return if `T` is the same type as `First` or any or `Rest`. +template +constexpr bool is_any_of() { + return false; +} + +template +constexpr bool is_any_of() { + return std::is_same_v || is_any_of(); +} + } // namespace detail // Convert a type to the corresponding QuantizedActivationMode. @@ -1201,24 +1212,18 @@ class Stream { const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, blas::ComputePrecision precision) { - static_assert(!std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert( + detail::is_any_of, std::complex>(), + "Input can be half, bf16, float, double, std::complex or " + "std::complex"); + static_assert(!std::is_same_v || + detail::is_any_of(), "If input is Eigen::half, constant has to be either " "Eigen::half or float"); static_assert( - std::is_same::value || - std::is_same::value, + detail::is_any_of(), "If input is not Eigen::half, constant and input types have to match"); - static_assert( - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same>::value || - std::is_same>::value, - "Input can be half, bf16, float, double, std::complex or " - "std::complex"); blas::BlasSupport *blas = parent()->AsBlas(); if (!blas) { return port::InternalError( @@ -1461,17 +1466,15 @@ class Stream { int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, int batch_count) { - static_assert(((std::is_same::value || - std::is_same::value) && - std::is_same::value) || - ((std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same>::value || - std::is_same>::value) && - std::is_same::value), - "Input or constant type mismatch"); + static_assert( + detail::is_any_of, std::complex>(), + "Unsupported input type"); + static_assert( + std::is_same_v || + (detail::is_any_of() && + std::is_same_v), + "Mismatched input and alpha/beta types"); blas::BlasSupport *blas = parent()->AsBlas(); if (!blas) { return port::InternalError( @@ -2127,40 +2130,48 @@ class Stream { friend class ocl::CLBlas; // for parent_. // Checks whether types match before a call to extended BLAS version. - template + template port::Status CheckTypesForExtendedBlas( blas::ComputationType computation_type) { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same>::value || - std::is_same>::value, - "The only buffer types supported are: Eigen::half, float, " - "double, int8, std::complex and std::complex"); static_assert( - std::is_same::value || - (std::is_same::value && - std::is_same::value), + detail::is_any_of, std::complex>(), + "The only buffer types supported are: Eigen::half, float, " + "double, int8, std::complex and std::complex"); + static_assert( + std::is_same_v || + (std::is_same_v && std::is_same_v), "Input and output buffer types should be the same unless input is " "int8 and output is int32"); - static_assert(std::is_same::value || - (std::is_same::value && - (std::is_same::value || - std::is_same::value)), - "Constant and output types should match"); - blas::ComputationType expected_computation_type = - blas::ToComputationType::value; - if (expected_computation_type != computation_type && - !(computation_type == blas::ComputationType::kF32 && - (expected_computation_type == blas::ComputationType::kF16 || - expected_computation_type == blas::ComputationType::kBF16AsF32))) { + static_assert( + std::is_same_v || + (std::is_same_v && + detail::is_any_of()), + "Mismatched alpha/beta and output types"); + + bool valid_computation_type = [computation_type] { + switch (computation_type) { + case blas::ComputationType::kF16: + return std::is_same_v; + case blas::ComputationType::kF32: + return detail::is_any_of>(); + case blas::ComputationType::kF64: + return detail::is_any_of>(); + case blas::ComputationType::kI32: + return std::is_same_v; + case blas::ComputationType::kF16AsF32: // fall-through + case blas::ComputationType::kBF16AsF32: // fall-through + case blas::ComputationType::kTF32AsF32: + return detail::is_any_of>(); + } + }(); + + if (!valid_computation_type) { return port::InternalError(absl::StrCat( - "Alpha/beta type and computation type have to match, got ", - blas::ComputationTypeString(computation_type), - " for computation type, expected: ", - blas::ComputationTypeString(expected_computation_type))); + "Invalid computation type ", + blas::ComputationTypeString(computation_type), " for output type: ", + blas::DataTypeString(blas::ToDataType::value))); } return ::tensorflow::OkStatus(); } diff --git a/tensorflow/stream_executor/tf_allocator_adapter.h b/tensorflow/stream_executor/tf_allocator_adapter.h index c562825e3e38e7..9c8fb3bb33f2a1 100644 --- a/tensorflow/stream_executor/tf_allocator_adapter.h +++ b/tensorflow/stream_executor/tf_allocator_adapter.h @@ -16,11 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_TF_ALLOCATOR_ADAPTER_H_ #define TENSORFLOW_STREAM_EXECUTOR_TF_ALLOCATOR_ADAPTER_H_ +#include +#include +#include + #include "tensorflow/core/framework/allocator.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace stream_executor { @@ -73,7 +79,13 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { : DeviceMemoryAllocator(platform) { tf_allocators_.reserve(tf_allocators.size()); for (AllocatorWithStream &p : tf_allocators) { - per_device_allocators_.emplace_back(p.first.get(), p.second); + int device_ordinal = p.second->parent()->device_ordinal(); + if (per_device_allocators_.size() <= device_ordinal) { + per_device_allocators_.resize(device_ordinal + 1); + } + CHECK(!per_device_allocators_[device_ordinal]); + per_device_allocators_[device_ordinal] = + std::make_unique(p.first.get(), p.second); tf_allocators_.push_back(std::move(p.first)); } } @@ -82,14 +94,14 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { bool retry_on_failure, int64_t memory_space) override { CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal].Allocate( + return per_device_allocators_[device_ordinal]->Allocate( device_ordinal, size, retry_on_failure, memory_space); } port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override { CHECK_LT(device_ordinal, per_device_allocators_.size()); - return per_device_allocators_[device_ordinal].Deallocate(device_ordinal, - mem); + return per_device_allocators_[device_ordinal]->Deallocate(device_ordinal, + mem); } // The Tensorflow BFC allocator used on GPU allows host-side deallocation @@ -102,11 +114,11 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { bool AllowsAsynchronousDeallocation() const override { return true; } port::StatusOr GetStream(int device_ordinal) override { - return per_device_allocators_[device_ordinal].GetStream(device_ordinal); + return per_device_allocators_[device_ordinal]->GetStream(device_ordinal); } private: - std::vector per_device_allocators_; + std::vector> per_device_allocators_; // The wrapped TF allocators backing per_device_allocators_ // (TfAllocatorAdapter does not take ownership of its underlying Allocator). std::vector> tf_allocators_; diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d5595689786a03..f35525c26920e0 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2428,10 +2428,6 @@ def pywrap_tensorflow_macro( # //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons. # 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test. def py_test(deps = [], data = [], kernels = [], exec_properties = None, **kwargs): - # Python version placeholder - if kwargs.get("python_version", None) == "PY3": - kwargs["tags"] = kwargs.get("tags", []) + ["no_oss_py2"] - if not exec_properties: exec_properties = tf_exec_properties(kwargs) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt index 88743205291da1..5f08af16b0a0dd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.TensorShape" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "dims" @@ -47,6 +48,18 @@ tf_class { name: "concatenate" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_as_proto" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_from_proto" + argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_type_proto" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "is_compatible_with" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt index 88743205291da1..5f08af16b0a0dd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.TensorShape" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member { name: "dims" @@ -47,6 +48,18 @@ tf_class { name: "concatenate" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_as_proto" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_from_proto" + argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_type_proto" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "is_compatible_with" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adadelta.pbtxt index 1816f45a8110da..de0f551d86eb55 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adadelta.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.95\', \'1e-07\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adadelta\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'grad\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adagrad.pbtxt index df9fdf7ccea5a7..2679e1b4a57422 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adagrad.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'epsilon\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.1\', \'1e-07\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adagrad\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'grad\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam-w.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam-w.pbtxt index 1a6393edf12d55..d319bebccb6a25 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam-w.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam-w.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'weight_decay\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.004\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'AdamW\'], " @@ -95,6 +91,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt index fc90fbf101c8b0..a546c6534a2699 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adam.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'amsgrad\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adam\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adamax.pbtxt index eefbfe44f61cf2..00c14520808994 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adamax.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-adamax.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Adamax\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-ftrl.pbtxt index e66a6d9bb8d69d..aa17836e01acd4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-ftrl.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-ftrl.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'l2_shrinkage_regularization_strength\', \'beta\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Ftrl\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-nadam.pbtxt index 41d8b7add64cc7..cdc19fa2b0117c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-nadam.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'Nadam\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-optimizer.pbtxt index 5fb6609a321211..bddc4be45b8a21 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-optimizer.pbtxt @@ -38,10 +38,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'name\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\'], " @@ -90,6 +86,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-r-m-sprop.pbtxt index ecd3ca463d447d..8a7b8e8d88ea41 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-r-m-sprop.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'momentum\', \'epsilon\', \'centered\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.0\', \'1e-07\', \'False\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'100\', \'True\', \'RMSprop\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-s-g-d.pbtxt index 3c31479cbb38b9..00d1fcbe790c14 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-s-g-d.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.experimental.-s-g-d.pbtxt @@ -39,10 +39,6 @@ tf_class { name: "trainable_variables" mtype: "" } - member { - name: "variables" - mtype: "" - } member_method { name: "__init__" argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'nesterov\', \'amsgrad\', \'clipnorm\', \'clipvalue\', \'global_clipnorm\', \'use_ema\', \'ema_momentum\', \'ema_overwrite_frequency\', \'jit_compile\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\', \'False\', \'0.99\', \'None\', \'True\', \'SGD\'], " @@ -91,6 +87,10 @@ tf_class { name: "update_step" argspec: "args=[\'self\', \'gradient\', \'variable\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "with_name_scope" argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 b/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 index c6aed56189742f..db392b157597d1 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.arm64 @@ -18,7 +18,7 @@ RUN yum -y check-update || true && \ COPY install/install_bazel.sh /install/ RUN /install/install_bazel.sh -ARG py_major_minor_version +ARG py_major_minor_version='3.10' ENV TF_PYTHON_VERSION=python${py_major_minor_version} ENV PYTHON_BIN_PATH=/usr/local/bin/${TF_PYTHON_VERSION} diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython deleted file mode 100644 index a7fe65a5856d09..00000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython +++ /dev/null @@ -1,101 +0,0 @@ -# Dockerfile to build a manylinux 2010 compliant cross-compiler. -# -# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible -# glibc (2.12) and system libstdc++ (4.4). -# -# To push a new version, run: -# $ docker build -f Dockerfile.rbe.cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython \ -# --tag "gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython" . -# $ docker push gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython - -FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 as devtoolset - -ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y \ - cpio \ - file \ - flex \ - g++ \ - make \ - patch \ - rpm2cpio \ - unar \ - wget \ - xz-utils \ - && \ - rm -rf /var/lib/apt/lists/* - -ADD devtoolset/fixlinks.sh fixlinks.sh -ADD devtoolset/build_devtoolset.sh build_devtoolset.sh -ADD devtoolset/rpm-patch.sh rpm-patch.sh - -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7. -RUN /build_devtoolset.sh devtoolset-7 /dt7 -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. -RUN /build_devtoolset.sh devtoolset-8 /dt8 - -# TODO(klimek): Split up into two different docker images. -FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 -COPY --from=devtoolset /dt7 /dt7 -COPY --from=devtoolset /dt8 /dt8 - -# Install TensorRT. -RUN echo \ - deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 / \ - > /etc/apt/sources.list.d/nvidia-ml.list \ - && \ - apt-get update && apt-get install -y \ - libnvinfer-dev=7.2.2-1+cuda11.1 \ - libnvinfer7=7.2.2-1+cuda11.1 \ - libnvinfer-plugin-dev=7.2.2-1+cuda11.1 \ - libnvinfer-plugin7=7.2.2-1+cuda11.1 \ - && \ - rm -rf /var/lib/apt/lists/* - -# Copy and run the install scripts. -ARG DEBIAN_FRONTEND=noninteractive - -COPY install/install_bootstrap_deb_packages.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh - -COPY install/install_deb_packages.sh /install/ -RUN /install/install_deb_packages.sh - -# Install additional packages needed for this image: -# - dependencies to build Python from source -# - patchelf, as it is required by auditwheel -RUN apt-get update && apt-get install -y \ - libbz2-dev \ - libffi-dev \ - libgdbm-dev \ - libncurses5-dev \ - libnss3-dev \ - libreadline-dev \ - libsqlite3-dev \ - patchelf \ - && \ - rm -rf /var/lib/apt/lists/* - -COPY install/install_bazel.sh /install/ -RUN /install/install_bazel.sh - -COPY install/build_and_install_python.sh /install/ -RUN /install/build_and_install_python.sh "3.7.7" -RUN /install/build_and_install_python.sh "3.8.2" -RUN /install/build_and_install_python.sh "3.9.4" -RUN /install/build_and_install_python.sh "3.10.0" - -COPY install/install_pip_packages_by_version.sh /install/ -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.8" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" - -ENV CLANG_VERSION="r91087153210132a4c2d3cf19a4526d8f395cb5a4" -COPY install/install_latest_clang.sh /install/ -RUN /install/install_latest_clang.sh - -# TensorRT 7 for CUDA 11.1 is compatible with CUDA 11.2, but requires -# libnvrtc.so.11.1. See https://github.com/NVIDIA/TensorRT/issues/1064. -# TODO(b/187962120): Remove when upgrading to TensorRT 8. -ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda-11.1/lib64" diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython deleted file mode 100644 index e86bc6268c8362..00000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython +++ /dev/null @@ -1,101 +0,0 @@ -# Dockerfile to build a manylinux 2010 compliant cross-compiler. -# -# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible -# glibc (2.12) and system libstdc++ (4.4). -# -# To push a new version, run: -# $ docker build -f Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython \ -# --tag "gcr.io/tensorflow-testing/nosla-cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython" . -# $ docker push gcr.io/tensorflow-testing/nosla-cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython - -FROM nvidia/cuda:11.2.1-cudnn8-devel-ubuntu18.04 as devtoolset - -ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y \ - cpio \ - file \ - flex \ - g++ \ - make \ - patch \ - rpm2cpio \ - unar \ - wget \ - xz-utils \ - && \ - rm -rf /var/lib/apt/lists/* - -ADD devtoolset/fixlinks.sh fixlinks.sh -ADD devtoolset/build_devtoolset.sh build_devtoolset.sh -ADD devtoolset/rpm-patch.sh rpm-patch.sh - -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7. -RUN /build_devtoolset.sh devtoolset-7 /dt7 -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. -RUN /build_devtoolset.sh devtoolset-8 /dt8 - -# TODO(klimek): Split up into two different docker images. -FROM nvidia/cuda:11.2.1-cudnn8-devel-ubuntu18.04 -COPY --from=devtoolset /dt7 /dt7 -COPY --from=devtoolset /dt8 /dt8 - -# Install TensorRT. -RUN echo \ - deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 / \ - > /etc/apt/sources.list.d/nvidia-ml.list \ - && \ - apt-get update && apt-get install -y \ - libnvinfer-dev=7.2.2-1+cuda11.1 \ - libnvinfer7=7.2.2-1+cuda11.1 \ - libnvinfer-plugin-dev=7.2.2-1+cuda11.1 \ - libnvinfer-plugin7=7.2.2-1+cuda11.1 \ - && \ - rm -rf /var/lib/apt/lists/* - -# Copy and run the install scripts. -ARG DEBIAN_FRONTEND=noninteractive - -COPY install/install_bootstrap_deb_packages.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh - -COPY install/install_deb_packages.sh /install/ -RUN /install/install_deb_packages.sh - -# Install additional packages needed for this image: -# - dependencies to build Python from source -# - patchelf, as it is required by auditwheel -RUN apt-get update && apt-get install -y \ - libbz2-dev \ - libffi-dev \ - libgdbm-dev \ - libncurses5-dev \ - libnss3-dev \ - libreadline-dev \ - libsqlite3-dev \ - patchelf \ - && \ - rm -rf /var/lib/apt/lists/* - -COPY install/install_bazel.sh /install/ -RUN /install/install_bazel.sh - -COPY install/build_and_install_python.sh /install/ -RUN /install/build_and_install_python.sh "3.7.7" -RUN /install/build_and_install_python.sh "3.8.2" -RUN /install/build_and_install_python.sh "3.9.4" -RUN /install/build_and_install_python.sh "3.10.0" - -COPY install/install_pip_packages_by_version.sh /install/ -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.8" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" - -ENV CLANG_VERSION="r91087153210132a4c2d3cf19a4526d8f395cb5a4" -COPY install/install_latest_clang.sh /install/ -RUN /install/install_latest_clang.sh - -# TensorRT 7 for CUDA 11.1 is compatible with CUDA 11.2, but requires -# libnvrtc.so.11.1. See https://github.com/NVIDIA/TensorRT/issues/1064. -# TODO(b/187962120): Remove when upgrading to TensorRT 8. -ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda-11.1/lib64" diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython deleted file mode 100644 index c0b70f7571e6f5..00000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython +++ /dev/null @@ -1,34 +0,0 @@ -# Dockerfile to build a manylinux 2010 compliant cross-compiler. -# -# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible -# glibc (2.12) and system libstdc++ (4.4). -# -# To push a new version, run: -# $ docker build -f Dockerfile.rbe.cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython \ -# --tag "gcr.io/tensorflow-testing/nosla-cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython" . -# $ docker push gcr.io/tensorflow-testing/nosla-cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython - -FROM gcr.io/tensorflow-testing/nosla-cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython - -RUN apt-get update -RUN apt-get remove -y --allow-change-held-packages libcudnn8 libnccl2 libnccl-dev -RUN apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \ - libcublas-11-4 \ - libcublas-dev-11-4 \ - cuda-nvml-dev-11.4 \ - cuda-command-line-tools-11.4 \ - cuda-libraries-dev-11.4 \ - cuda-minimal-build-11.4 \ - libcudnn8=8.0.5.39-1+cuda11.1 \ - libcudnn8-dev=8.0.5.39-1+cuda11.1 -RUN rm -f /usr/local/cuda -RUN ln -s /usr/local/cuda-11.4 /usr/local/cuda - -# Install TensorRT. -RUN apt-get update && \ - apt-get install -y \ - libnvinfer-dev=8.0.1-1+cuda11.3 \ - libnvinfer8=8.0.1-1+cuda11.3 \ - libnvinfer-plugin-dev=8.0.1-1+cuda11.3 \ - libnvinfer-plugin8=8.0.1-1+cuda11.3 && \ - rm -rf /var/lib/apt/lists/* diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython deleted file mode 100644 index 99f83ccdd3ef0b..00000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython +++ /dev/null @@ -1,92 +0,0 @@ -# Dockerfile to build a manylinux 2010 compliant cross-compiler. -# -# Builds a devtoolset gcc/libstdc++ that targets manylinux 2010 compatible -# glibc (2.12) and system libstdc++ (4.4). -# -# To push a new version, run: -# $ docker build -f Dockerfile.rbe.cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython \ -# --tag "gcr.io/tensorflow-testing/nosla-cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython" . -# $ docker push gcr.io/tensorflow-testing/nosla-cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython - -FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu18.04 as devtoolset - -ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y \ - cpio \ - file \ - flex \ - g++ \ - make \ - patch \ - rpm2cpio \ - unar \ - wget \ - xz-utils \ - && \ - rm -rf /var/lib/apt/lists/* - -ADD devtoolset/fixlinks.sh fixlinks.sh -ADD devtoolset/build_devtoolset.sh build_devtoolset.sh -ADD devtoolset/rpm-patch.sh rpm-patch.sh - -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-7 in /dt7. -RUN /build_devtoolset.sh devtoolset-7 /dt7 -# Set up a sysroot for glibc 2.12 / libstdc++ 4.4 / devtoolset-8 in /dt8. -RUN /build_devtoolset.sh devtoolset-8 /dt8 - -# TODO(klimek): Split up into two different docker images. -FROM nvidia/cuda:11.4.1-cudnn8-devel-ubuntu18.04 -COPY --from=devtoolset /dt7 /dt7 -COPY --from=devtoolset /dt8 /dt8 - -# Install TensorRT. -RUN apt-get update && \ - apt-get install -y \ - libnvinfer-dev=8.0.1-1+cuda11.3 \ - libnvinfer8=8.0.1-1+cuda11.3 \ - libnvinfer-plugin-dev=8.0.1-1+cuda11.3 \ - libnvinfer-plugin8=8.0.1-1+cuda11.3 && \ - rm -rf /var/lib/apt/lists/* - -# Copy and run the install scripts. -ARG DEBIAN_FRONTEND=noninteractive - -COPY install/install_bootstrap_deb_packages.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh - -COPY install/install_deb_packages.sh /install/ -RUN /install/install_deb_packages.sh - -# Install additional packages needed for this image: -# - dependencies to build Python from source -# - patchelf, as it is required by auditwheel -RUN apt-get update && apt-get install -y \ - libbz2-dev \ - libffi-dev \ - libgdbm-dev \ - libncurses5-dev \ - libnss3-dev \ - libreadline-dev \ - libsqlite3-dev \ - patchelf \ - && \ - rm -rf /var/lib/apt/lists/* - -COPY install/install_bazel.sh /install/ -RUN /install/install_bazel.sh - -COPY install/build_and_install_python.sh /install/ -RUN /install/build_and_install_python.sh "3.7.7" -RUN /install/build_and_install_python.sh "3.8.2" -RUN /install/build_and_install_python.sh "3.9.4" -RUN /install/build_and_install_python.sh "3.10.0" - -COPY install/install_pip_packages_by_version.sh /install/ -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.8" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.9" -RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.10" - -ENV CLANG_VERSION="r91087153210132a4c2d3cf19a4526d8f395cb5a4" -COPY install/install_latest_clang.sh /install/ -RUN /install/install_latest_clang.sh diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh index ca58a673f196b7..534189d32385a4 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh @@ -89,7 +89,7 @@ echo "" # execution in an MKL primitive. This reduces the effects of an oversubscription # of OpenMP threads caused by executing multiple tests concurrently. bazel test \ - --test_tag_filters=-no_oss,-no_oss_py2,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only \ + --test_tag_filters=-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only \ --test_lang_filters=cc,py \ -k \ --jobs=${N_JOBS} \ diff --git a/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh b/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh index 657e682ded0c88..8efa17887444cb 100644 --- a/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh +++ b/tensorflow/tools/ci_build/presubmit/macos/py37_cc/build.sh @@ -29,7 +29,7 @@ function run_build () { export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=$(which python3.7) yes "" | $PYTHON_BIN_PATH configure.py - tag_filters="-no_oss,-no_oss_py2,-gpu,-tpu,-benchmark-test,-nomac,-no_mac,-v1only" + tag_filters="-no_oss,-gpu,-tpu,-benchmark-test,-nomac,-no_mac,-v1only" # Get the default test targets for bazel. source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh index cc3a628ba54948..a7db71b718c3a9 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh @@ -59,8 +59,8 @@ py_ver=$(python -c 'import sys; print(str(sys.version_info.major)+str(sys.versio export TF_BUILD_FLAGS="--config=mkl_aarch64 --copt=-mtune=generic --copt=-march=armv8-a \ --copt=-O3 --copt=-fopenmp --copt=-flax-vector-conversions --linkopt=-lgomp" export TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ - --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --test_lang_filters=py \ - --define=no_tensorflow_py_deps=true --verbose_failures=true --test_keep_going" + --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=no_tensorflow_py_deps=true \ + --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium --verbose_failures=true --test_keep_going" export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/lite/... \ -//tensorflow/compiler/mlir/lite/tests:const-fold.mlir.test \ @@ -76,6 +76,12 @@ export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/python/eager:forwardprop_test \ -//tensorflow/python/framework:node_file_writer_test \ -//tensorflow/python/grappler:memory_optimizer_test \ + -//tensorflow/python/kernel_tests/array_ops:array_ops_test_cpu \ + -//tensorflow/python/kernel_tests/array_ops:concat_op_test_cpu \ + -//tensorflow/python/kernel_tests/array_ops:pad_op_test_cpu \ + -//tensorflow/python/kernel_tests/array_ops:slice_op_test_cpu \ + -//tensorflow/python/kernel_tests/array_ops:split_op_test_cpu \ + -//tensorflow/python/kernel_tests/control_flow:scan_ops_test_cpu \ -//tensorflow/python/kernel_tests/linalg:linear_operator_householder_test \ -//tensorflow/python/kernel_tests/linalg:linear_operator_inversion_test \ -//tensorflow/python/kernel_tests/linalg:linear_operator_block_diag_test \ @@ -89,7 +95,7 @@ export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/python/ops/parallel_for:math_test \ -//tensorflow/python/training:server_lib_test" export TF_PIP_TESTS="test_pip_virtualenv_clean" -export TF_TEST_FILTER_TAGS="-no_oss,-oss_serial,-no_oss_py${py_ver},-gpu,-tpu,-benchmark-test,-v1only,-no_aarch64,-requires-gpu" +export TF_TEST_FILTER_TAGS="-nopip,-no_pip,-no_oss,-oss_serial,-v1only,-requires-gpu,-gpu,-tpu,-benchmark-test,-no_aarch64" export TF_PIP_TEST_ROOT="pip_test" export TF_AUDITWHEEL_TARGET_PLAT="manylinux2014" diff --git a/tensorflow/tools/ci_build/release/requirements_mac.txt b/tensorflow/tools/ci_build/release/requirements_mac.txt index c51ffcb74023b1..42afc3e5486047 100644 --- a/tensorflow/tools/ci_build/release/requirements_mac.txt +++ b/tensorflow/tools/ci_build/release/requirements_mac.txt @@ -8,5 +8,5 @@ twine ~= 3.6.0 setuptools # Test dependencies which don't exist on Windows -jax ~= 0.2.26 -jaxlib ~= 0.1.75 +jax ~= 0.3.14 +jaxlib ~= 0.3.14 diff --git a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt index 30e4e1c5e49a10..a380f22dce5ed0 100644 --- a/tensorflow/tools/ci_build/release/requirements_ubuntu.txt +++ b/tensorflow/tools/ci_build/release/requirements_ubuntu.txt @@ -5,5 +5,5 @@ PyYAML ~= 6.0 # Test dependencies which don't exist on Windows -jax ~= 0.2.26 -jaxlib ~= 0.1.75 +jax ~= 0.3.14 +jaxlib ~= 0.3.14 diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index e37580bd87f147..5b2f166f553773 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -53,7 +53,6 @@ py_test( python_version = "PY3", shard_count = 4, tags = [ - "no_oss_py2", "no_pip", "no_rocm", # No need to rerun this test for ROCm config. "no_windows", # numpy prints differently on windows. @@ -105,7 +104,6 @@ py_test( main = "tf_doctest.py", python_version = "PY3", tags = [ - "no_oss_py2", "no_pip", "no_rocm", "no_windows", # numpy prints differently on windows. @@ -129,7 +127,6 @@ py_test( srcs = ["tf_doctest_test.py"], python_version = "PY3", tags = [ - "no_oss_py2", "no_pip", "noasan", "nomsan", diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index 6b25c857c5f6ba..7e2df99ae109dd 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -126,110 +126,6 @@ def initialize_rbe_configs(): python_install_path = "/usr/local", ) - tensorflow_rbe_config( - name = "ubuntu18.04-clang_manylinux2010-cuda11.1-cudnn8-tensorrt7.2", - compiler = "/clang_r969a51ff363263a3b5f2df55eba6b4d392bf30c0/bin/clang", - cuda_version = "11.1", - cudnn_version = "8", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - sysroot = "/dt7", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "11.1", - cudnn_version = "8", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2", - compiler = "/clang_r969a51ff363263a3b5f2df55eba6b4d392bf30c0/bin/clang", - cuda_version = "11.2", - cudnn_version = "8.1", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - sysroot = "/dt7", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "11.2", - cudnn_version = "8.1", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-clang_manylinux2010-cuda11.4-cudnn8.0.5-tensorrt7.2", - compiler = "/clang_r969a51ff363263a3b5f2df55eba6b4d392bf30c0/bin/clang", - cuda_version = "11.4", - cudnn_version = "8.0.5", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - sysroot = "/dt7", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda11.4-cudnn8.0.5-tensorrt7.2", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "11.4", - cudnn_version = "8.0.5", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-clang_manylinux2010-cuda11.4-cudnn8.2-tensorrt7.2", - compiler = "/clang_r969a51ff363263a3b5f2df55eba6b4d392bf30c0/bin/clang", - cuda_version = "11.4", - cudnn_version = "8.2", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - sysroot = "/dt7", - python_install_path = "/usr/local", - ) - - tensorflow_rbe_config( - name = "ubuntu18.04-gcc7_manylinux2010-cuda11.4-cudnn8.2-tensorrt7.2", - compiler = "/dt7/usr/bin/gcc", - compiler_prefix = "/usr/bin", - cuda_version = "11.4", - cudnn_version = "8.2", - os = "ubuntu18.04-manylinux2010-multipython", - python_versions = ["3.7", "3.8", "3.9", "3.10"], - tensorrt_install_path = "/usr", - tensorrt_version = "7.2", - python_install_path = "/usr/local", - ) - tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2", compiler = "/clang11/bin/clang", diff --git a/tensorflow/tools/toolchains/remote_config/containers.bzl b/tensorflow/tools/toolchains/remote_config/containers.bzl index 41943cca3588e3..a76c2e744d739e 100644 --- a/tensorflow/tools/toolchains/remote_config/containers.bzl +++ b/tensorflow/tools/toolchains/remote_config/containers.bzl @@ -7,12 +7,6 @@ container_digests = { # JAX manylinux2014 configs. "cuda11.1-cudnn8-ubuntu20.04-manylinux2014-multipython": "sha256:3764b49e64a16e5778995fc21a119ff0e364174ecbf461f741701b48e6d4f204", "cuda11.4-cudnn8.2-ubuntu20.04-manylinux2014-multipython": "sha256:6531a7ed4b1524e9f997ad10ad214af79e1042bc7ec6efe2f3e0692bafdb968f", - # TODO(yashkatariya): Remove manylinux2010 configs when 2014 is working. - # JAX uses these, TF used to use them too. Some might not be used anymore - "cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython": "sha256:bf24e58c0e18d60a99bee81c65d9f50b19548dec352404f0593ba5ea18c7e85c", - "cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython": "sha256:904ea6196b81fe67bf5a3c00d336b7c6f990d49291abd2c1dec0654ee7ac3041", - "cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython": "sha256:0777b477c37b003895713bd11e4e4db99329b7f03b77b130d49437881d71b795", - "cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython": "sha256:589c2fa98484dd83bcf0ffe371640a7c1a0c5e7299c0fc871c8820ddcbca2699", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:26720ebae4d6d12b1fca529616bfacfd0460990d4725af35e0f4af3c2422f227", @@ -83,34 +77,6 @@ containers = { "digest": container_digests["cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython"], }, - # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython. - "cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython": { - "registry": "gcr.io", - "repository": "tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython", - "digest": container_digests["cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython"], - }, - - # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython. - "cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython": { - "registry": "gcr.io", - "repository": "tensorflow-testing/nosla-cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython", - "digest": container_digests["cuda11.2-cudnn8.1-ubuntu18.04-manylinux2010-multipython"], - }, - - # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython. - "cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython": { - "registry": "gcr.io", - "repository": "tensorflow-testing/nosla-cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython", - "digest": container_digests["cuda11.4-cudnn8.0.5-ubuntu18.04-manylinux2010-multipython"], - }, - - # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython. - "cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython": { - "registry": "gcr.io", - "repository": "tensorflow-testing/nosla-cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython", - "digest": container_digests["cuda11.4-cudnn8.2-ubuntu18.04-manylinux2010-multipython"], - }, - # Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython. "cuda11.2-cudnn8.1-ubuntu20.04-manylinux2014-multipython": { "registry": "gcr.io", diff --git a/third_party/eigen3/workspace.bzl b/third_party/eigen3/workspace.bzl index 220f369844c791..1380722ee67340 100644 --- a/third_party/eigen3/workspace.bzl +++ b/third_party/eigen3/workspace.bzl @@ -7,8 +7,8 @@ def repo(): # Attention: tools parse and update these lines. # LINT.IfChange - EIGEN_COMMIT = "b02c384ef4e8eba7b8bdef16f9dc6f8f4d6a6b2b" - EIGEN_SHA256 = "515b3c266d798f3a112efe781dda0cf1aef7bd73f6864d8f4f16129310ae1fdf" + EIGEN_COMMIT = "0e187141679fdb91da33249d18cb79a011c0e2ea" + EIGEN_SHA256 = "52a7ef3ffe2b581973615b000657f456e2eab8e899fb863f456711feb790cb8c" # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/eigen.cmake) tf_http_archive( diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 083663dde43433..b2b5924d37eb93 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "1ecfc12b0c67de24c160dd126b873efa8e51c7c7" - LLVM_SHA256 = "0ccb19b092e6635906903c40b66dc32065a8a4e04c6040ab4bbc79806b52705b" + LLVM_COMMIT = "ebb78a95cede526ece9b904e9ba623d4b963df60" + LLVM_SHA256 = "001d3b87097a4120cc3ccdc33b1ad899385da50ff6dac2692712b09998a75b90" tf_http_archive( name = name, diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index b78366eec8063d..6c49523f81501a 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "7573508e0ed3e53e68716317ef2dc3710eaadd7b" - TFRT_SHA256 = "e709a26007e68d15940771ce8d4664f727885d4b59eded1ef685136e5e250565" + TFRT_COMMIT = "35c46f2fab174b3e7b474599da4e31713cc9a515" + TFRT_SHA256 = "3250df282ebad8161e8ce54f31bcad64ffb95b35442508f267f21c434c5f8171" tf_http_archive( name = "tf_runtime",