Skip to content

Commit

Permalink
Clean up HloInputOutputAliasConfig
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628601105
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 27, 2024
1 parent 7bb6eb1 commit c3325a8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/hlo/ir/BUILD
Expand Up @@ -93,6 +93,8 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
Expand Down
34 changes: 18 additions & 16 deletions third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.h
Expand Up @@ -24,9 +24,12 @@ limitations under the License.
#include <utility>

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"

Expand Down Expand Up @@ -72,13 +75,11 @@ class HloInputOutputAliasConfig {
explicit HloInputOutputAliasConfig(Shape output_shape)
: alias_(std::move(output_shape)) {}

virtual ~HloInputOutputAliasConfig() = default;

// Sets up alias config from `output_index` to `param_index` at
// `param_number`.
Status SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
const ShapeIndex& param_index,
AliasKind must_alias = kMayAlias);
absl::Status SetUpAlias(const ShapeIndex& output_index, int64_t param_number,
const ShapeIndex& param_index,
AliasKind must_alias = kMayAlias);

// Returns true if the given parameter is aliased with one of the output
// buffers.
Expand Down Expand Up @@ -120,16 +121,16 @@ class HloInputOutputAliasConfig {
// Iterates through each aliased output and input.
void ForEachAlias(AliasFn fn) const;

using AliasFnWithStatus =
absl::FunctionRef<Status(const ShapeIndex& output_index, const Alias&)>;
using AliasFnWithStatus = absl::FunctionRef<absl::Status(
const ShapeIndex& output_index, const Alias&)>;

// Verifies that the given config is valid for the given module.
// Specifically, the config's input and output should be in-bound and size of
// Specifically, the config's input and output should be in-bound and size ofF
// the aliased buffers should match.
Status Verify(const HloModule& module,
absl::FunctionRef<int64_t(const Shape&)> size_func) const;
absl::Status Verify(const HloModule& module,
absl::FunctionRef<int64_t(const Shape&)> size_func) const;

Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
absl::Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;

// Returns the shape of the output of the alias config.
const Shape& shape() const;
Expand Down Expand Up @@ -186,12 +187,13 @@ class HloBufferDonorConfig {
};

HloBufferDonorConfig() = default;
virtual ~HloBufferDonorConfig() = default;

// Register and unregister the parameter with `param_index` at `param_number`
// as a buffer donor.
Status AddBufferDonor(int64_t param_number, const ShapeIndex& param_index);
Status RemoveBufferDonor(int64_t param_number, const ShapeIndex& param_index);
absl::Status AddBufferDonor(int64_t param_number,
const ShapeIndex& param_index);
absl::Status RemoveBufferDonor(int64_t param_number,
const ShapeIndex& param_index);

// Returns true if the given parameter is registered as a buffer donor.
bool ParameterIsBufferDonor(int64_t param_number,
Expand All @@ -205,7 +207,7 @@ class HloBufferDonorConfig {
// Verifies that the given config is valid for the given module.
// The config's input should be in-bound and this config cannot overlap with
// the given module's input_output_alias_config.
Status Verify(const HloModule& module) const;
absl::Status Verify(const HloModule& module) const;

// Returns the registered buffer donors
const absl::btree_set<BufferDonor>& buffer_donor() const {
Expand Down

0 comments on commit c3325a8

Please sign in to comment.