Skip to content

Commit

Permalink
Merge pull request tensorflow#1760 from ROCmSoftwarePlatform/develop-…
Browse files Browse the repository at this point in the history
…upstream-sync-220711

Develop upstream sync 220711
  • Loading branch information
rsanthanam-amd committed Jul 11, 2022
2 parents e32c7ad + 45d7533 commit fab10a1
Show file tree
Hide file tree
Showing 441 changed files with 10,564 additions and 4,655 deletions.
20 changes: 0 additions & 20 deletions .bazelrc
Expand Up @@ -663,23 +663,3 @@ build:ubsan --linkopt -lubsan
# Disable TFRT integration for now unless --config=tfrt is specified.
build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
build:tfrt --deleted_packages=

# Experimental configuration for building XLA GPU lowering to TFRT.
build:experimental_enable_xlir --config=tfrt
build:experimental_enable_xlir --@tf_runtime//:enable_gpu
build:experimental_enable_xlir --@rules_cuda//cuda:cuda_runtime=//tensorflow/compiler/xla/service/gpu:cuda_runtime_for_xlir
build:experimental_enable_xlir --nocheck_visibility
build:experimental_enable_xlir --incompatible_strict_action_env
build:experimental_enable_xlir --config=monolithic
build:experimental_enable_xlir --//tensorflow/compiler/xla/service/gpu:enable_xlir

# bazel test --config=experimental_enable_bef_thunk \
# //tensorflow/compiler/xla/service/gpu:bef_thunk_tests
build:experimental_enable_bef_thunk --config=experimental_enable_xlir
test:experimental_enable_bef_thunk --test_env=XLA_FLAGS=--xla_gpu_bef_thunk

# bazel test --config=experimental_enable_bef_executable \
# //tensorflow/compiler/xla/service/gpu:bef_executable_tests
build:experimental_enable_bef_executable --config=experimental_enable_xlir
test:experimental_enable_bef_executable --test_env=XLA_FLAGS=--xla_gpu_bef_executable

7 changes: 5 additions & 2 deletions .github/workflows/trusted-partners.yml
Expand Up @@ -41,10 +41,13 @@ jobs:
const domain = await script.get_email_domain({github, username});
switch(domain) {
case "intel.com":
console.log(await script.filter({github, context}));
console.log(await script.filter({github, context, domain}));
break;
case "apple.com":
console.log(await script.filter({github, context}));
console.log(await script.filter({github, context, domain}));
break;
case "nvidia.com":
console.log(await script.filter({github, context, domain}));
break;
case "google.com":
console.log("Googler. No action necessary");
Expand Down
21 changes: 18 additions & 3 deletions .github/workflows/trusted_partners.js
Expand Up @@ -49,7 +49,7 @@ const get_email_domain = async ({github, username}) => {
context has the commit message details in the payload
domain is the email domain of the PR author's account (e.g. "intel.com"), as returned by get_email_domain
@return {string} Returns the message with labels attached and assignees added
*/
const filter_action = async ({github, context}) => {
const filter_action = async ({github, context, domain}) => {
const labels = ['kokoro:force-run', 'ready to pull'];

let assignees = [];
Expand All @@ -58,11 +58,26 @@ const filter_action = async ({github, context}) => {
if (title && title.toLowerCase().includes("onednn"))
assignees = onednn_assignees;
const intel_windows_assignees = ['nitins17', 'learning-to-play'];
if (title && title.toLowerCase().includes("intel") && title.toLowerCase().includes("windows"))
if (title && title.toLowerCase().includes('intel') &&
title.toLowerCase().includes('windows') && domain.includes('intel.com'))
assignees = intel_windows_assignees;
const apple_silicon_assignees = ['penpornk', 'nitins17'];
if (title && title.toLowerCase().includes("apple") && title.toLowerCase().includes("silicon"))
if (title && title.toLowerCase().includes('apple') &&
title.toLowerCase().includes('silicon') && domain.includes('apple.com'))
assignees = apple_silicon_assignees;
if (title && title.toLowerCase().includes('nvidia') &&
domain.includes('nvidia.com')) {
if (title.toLowerCase().includes('jax')) {
assignees.push('hawkinsp', 'yashk2810', 'skye');
}
if (title.toLowerCase().includes('xla') ||
title.toLowerCase().includes('gpu')) {
assignees.push('cheshire', 'gcforster', 'reedwm', 'chsigg');
}
if (title.toLowerCase().includes('tf')) {
assignees.push('rohan100jain', 'bfontain', 'penpornk');
}
}

const resp_label = await github.rest.issues.addLabels({
issue_number: context.issue.number,
Expand Down
4 changes: 4 additions & 0 deletions RELEASE.md
Expand Up @@ -6,6 +6,10 @@

* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* Causal attention in `keras.layers.Attention` and
`keras.layers.AdditiveAttention` is now specified in the `call()` method
via the `use_causal_mask` argument (rather than in the constructor),
for consistency with other layers.
* Some files in `tensorflow/python/training` have been moved to
`tensorflow/python/tracking` and `tensorflow/python/checkpoint`. Please
update your imports accordingly, the old files will be removed in Release
Expand Down
1 change: 1 addition & 0 deletions tensorflow/BUILD
Expand Up @@ -897,6 +897,7 @@ config_setting(
package_group(
name = "internal",
packages = [
"//devtools/python/indexer/...",
"//learning/brain/keras/...",
"//learning/brain/mlir/...",
"//learning/brain/tfrt/...",
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/c/experimental/ops/BUILD
Expand Up @@ -29,6 +29,26 @@ cc_library(
],
)

cc_library(
name = "io_ops",
srcs = [
"io_ops.cc",
],
hdrs = [
"io_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:tracing_utils",
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
],
)

cc_library(
name = "math_ops",
srcs = [
Expand Down Expand Up @@ -99,6 +119,7 @@ cc_library(
name = "ops",
hdrs = [
"array_ops.h",
"io_ops.h",
"math_ops.h",
"nn_ops.h",
"resource_variable_ops.h",
Expand All @@ -108,6 +129,7 @@ cc_library(
],
deps = [
":array_ops",
":io_ops",
":math_ops",
":nn_ops",
":resource_variable_ops",
Expand All @@ -122,6 +144,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_ops.h",
"io_ops.h",
"math_ops.h",
"nn_ops.h",
"resource_variable_ops.h",
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util.cc
Expand Up @@ -488,6 +488,14 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false;
}

if (!op_filter_.allow_where_op && node.type_string() == "Where") {
absl::string_view uncompilable_reason = "Where op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}

if (!op_filter_.allow_unique_op && node.type_string() == "Unique") {
absl::string_view uncompilable_reason = "Unique op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util.h
Expand Up @@ -136,6 +136,9 @@ class RecursiveCompilabilityChecker {
// Whether to allow the compilation of CollectiveReduceV2Op.
bool allow_collective_reduce_v2 = true;

// Whether to allow the compilation of WhereOp.
bool allow_where_op = true;

// Whether to allow the compilation of UniqueOp. Compilation of the UniqueOp
// generates output with bounded dynamic shape that may cause failures with
// auto clustering.
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/jit/flags.cc
Expand Up @@ -114,6 +114,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
Flag("tf_xla_cluster_exclude_ops",
&mark_for_compilation_flags->tf_xla_cluster_exclude_ops,
"(experimental) "
"Exclude the operations from auto-clustering. "
"If multiple, separate them with commas."
" Where, Some_other_ops"),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/flags.h
Expand Up @@ -61,6 +61,9 @@ struct MarkForCompilationPassFlags {
// If non-empty, limit XLA clustering to the following TF operations.
string tf_xla_ops_to_cluster;

// If non-empty, a comma-separated list of operations to exclude from XLA auto-clustering.
string tf_xla_cluster_exclude_ops;

// Dump graphs during XLA compilation.
bool tf_xla_clustering_debug;

Expand Down
31 changes: 31 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Expand Up @@ -1189,6 +1189,24 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
return true;
}

absl::flat_hash_set<string> GetOrCreateClusterExcludeList() {
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
absl::flat_hash_set<string> excludelist;
for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) {
if (!s.empty()) {
excludelist.insert(string(s));
}
}
if (VLOG_IS_ON(2) && !excludelist.empty()) {
std::vector<string> vexcludelist(excludelist.begin(), excludelist.end());
absl::c_sort(vexcludelist);
VLOG(2) << "XLA clustering will exclude following TF operations from auto "
"clustering: "
<< absl::StrJoin(vexcludelist, " ");
}
return excludelist;
}

absl::flat_hash_set<string> GetOrCreateAllowlist() {
absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
tensorflow::GetAllowlistTable();
Expand Down Expand Up @@ -1289,12 +1307,25 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}

auto cluster_exclude_op_list = GetOrCreateClusterExcludeList();
RecursiveCompilabilityChecker::OperationFilter filter =
CreateOperationFilter(*registration);
filter.require_always_compilable = true;
filter.allow_string_consts = false;
filter.allow_collective_reduce_v2 = false;
filter.allow_unique_op = false;
filter.allow_where_op = true;

for (const auto& s : cluster_exclude_op_list) {
if (s == "Where") {
filter.allow_where_op = false;
} else {
return errors::InvalidArgument(
"The operation '", s,
"' passed to --tf_xla_cluster_exclude_ops is not supported by "
"XLA.");
}
}

RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Expand Up @@ -196,6 +196,24 @@ TEST(XlaCompilationTest, StringUnsupported) {
EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, WhereUnsupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_INT32)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Where", a, builder.opts().WithName("B"));
ops::BinaryOp("Gather", b, a, builder.opts().WithName("C"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}

TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(!clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
Expand Down
68 changes: 65 additions & 3 deletions tensorflow/compiler/mlir/hlo/BUILD
Expand Up @@ -1365,6 +1365,7 @@ cc_library(
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
Expand Down Expand Up @@ -2193,6 +2194,7 @@ cc_library(
deps = [
":bufferize_pass",
":gml_st_tiling",
":gml_st_vectorization",
":legalize_mhlo_to_gml",
":legalize_to_linalg",
"@llvm-project//mlir:BufferizationTransforms",
Expand Down Expand Up @@ -2394,9 +2396,10 @@ gentbl_cc_library(
td_file = "include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td",
td_srcs = [
"include/mlir-hlo/Dialect/gml_st/IR/gml_st_extension_ops.td",
"include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td",
"include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td",
"include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td",
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td",
"include/mlir-hlo/Dialect/gml_st/transforms/fusion_interface.td",
],
deps = [":gml_st_ops_td_files"],
)
Expand All @@ -2411,6 +2414,7 @@ cc_library(
],
includes = ["include"],
deps = [
":compose_set_interface",
":fusion_interface",
":gml_st_ops_inc_gen",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -2474,14 +2478,49 @@ cc_library(
deps = [
":fusion_interface",
":gml_st",
":hlo",
":map_mhlo_to_scalar_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorUtils",
],
)

gentbl_cc_library(
name = "compose_set_interface_inc_gen",
compatible_with = get_compatible_with_cloud(),
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-interface-decls"],
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h.inc",
),
(
["-gen-op-interface-defs"],
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.td",
deps = ["@llvm-project//mlir:OpBaseTdFiles"],
)

cc_library(
name = "compose_set_interface",
srcs = [
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.cc.inc",
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h.inc",
"lib/Dialect/gml_st/transforms/compose_set_interface.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/gml_st/transforms/compose_set_interface.h",
],
includes = ["include"],
deps = [
":compose_set_interface_inc_gen",
"@llvm-project//mlir:IR",
],
)

Expand Down Expand Up @@ -2584,6 +2623,7 @@ cc_library(
],
includes = ["include"],
deps = [
":compose_set_interface",
":gml_st",
":gml_st_passes_inc_gen",
"@llvm-project//llvm:Support",
Expand All @@ -2595,6 +2635,28 @@ cc_library(
],
)

cc_library(
name = "gml_st_vectorization",
srcs = [
"include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h",
"lib/Dialect/gml_st/transforms/vectorization.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/gml_st/transforms/passes.h",
],
includes = ["include"],
deps = [
":gml_st",
":gml_st_passes_inc_gen",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:VectorDialect",
],
)

cc_library(
name = "legalize_mhlo_to_gml",
srcs = [
Expand Down

0 comments on commit fab10a1

Please sign in to comment.