Skip to content

Commit

Permalink
[pjrt] NFC: Always use concrete async value in PjRtFuture<T>
Browse files Browse the repository at this point in the history
We plan to make absl::StatusOr<T> an implicit payload of PjRtFuture<T> and for consistency always use AsyncValueRef with concrete payload instead of relying on AsyncValue error semantics.

PjRtFuture can't be constructed from async values passed from a user (only from a promise) so we can safely ignore the error bit as we never use it.

This is a first CL in preparation for making absl::StatusOr implicit in PjRtFuture<T>.

PiperOrigin-RevId: 628598778
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 27, 2024
1 parent f0dfc04 commit 7bb6eb1
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions third_party/xla/xla/pjrt/pjrt_future.h
Expand Up @@ -21,11 +21,9 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

#include "absl/base/optimization.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -199,16 +197,6 @@ class PjRtFutureBase : public PjRtFutureMoveControl<

explicit Promise(tsl::AsyncValueRef<T> ref) : ref_(std::move(ref)) {}

void SetStateConcrete() {
DCHECK(ref_) << "Promise must wrap an async value";
ref_.SetStateConcrete();
}

void SetError(absl::Status error) {
DCHECK(ref_) << "Promise must wrap an async value";
ref_.SetError(std::move(error));
}

template <typename... Args>
void emplace(Args&&... args) const {
DCHECK(ref_) << "Promise must wrap an async value";
Expand Down Expand Up @@ -250,6 +238,11 @@ class PjRtFutureBase : public PjRtFutureMoveControl<
on_block_start_(std::move(on_block_start)),
on_block_end_(std::move(on_block_end)) {}

PjRtFutureBase(T t, PjRtFutureHelpers::OnBlockStartFn on_block_start,
PjRtFutureHelpers::OnBlockEndFn on_block_end)
: PjRtFutureBase(tsl::MakeAvailableAsyncValueRef<T>(std::move(t)),
std::move(on_block_start), std::move(on_block_end)) {}

tsl::AsyncValuePtr<T> promise() const { return promise_.AsPtr(); }

PjRtFutureHelpers::ProfilingKeys OnBlockStart() const {
Expand Down Expand Up @@ -365,6 +358,7 @@ class PjRtFuture : public internal::PjRtFutureBase<T> {
// Blocks the calling thread until the future is ready, then returns the
// final value.
const T& Await() & {
CHECK(Base::IsValid());
Base::BlockUntilReady();
DCHECK(Base::promise().IsConcrete());
return *Base::promise();
Expand All @@ -373,6 +367,7 @@ class PjRtFuture : public internal::PjRtFutureBase<T> {
// Blocks the calling thread until the future is ready, then returns the
// final value.
const T& Await() const& {
CHECK(Base::IsValid());
Base::BlockUntilReady();
DCHECK(Base::promise().IsConcrete());
return *Base::promise();
Expand All @@ -381,6 +376,7 @@ class PjRtFuture : public internal::PjRtFutureBase<T> {
// Blocks the calling thread until the future is ready, then returns the
// final value.
std::conditional_t<Base::is_unique(), T, const T&> Await() && {
CHECK(Base::IsValid());
Base::BlockUntilReady();
DCHECK(Base::promise().IsConcrete());

Expand Down Expand Up @@ -444,8 +440,8 @@ class PjRtFuture : public internal::PjRtFutureBase<T> {
//
// See PjRtFuture<T> documentation above for more details.
template <>
class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
using Base = internal::PjRtFutureBase<std::nullopt_t>;
class PjRtFuture<void> : public internal::PjRtFutureBase<absl::Status> {
using Base = internal::PjRtFutureBase<absl::Status>;

public:
class Promise : public Base::Promise {
Expand All @@ -459,11 +455,7 @@ class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
// After Set is called, completion event will be delivered to waiters on the
// PjRtFuture constructed from a promise, via blocking or callbacks.
void Set(absl::Status status = absl::OkStatus()) {
if (ABSL_PREDICT_TRUE(status.ok())) {
Base::Promise::SetStateConcrete();
} else {
Base::Promise::SetError(std::move(status));
}
Base::Promise::emplace(std::move(status));
}
};

Expand All @@ -472,8 +464,7 @@ class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
//
// Used by clients that do not use TSL concurrency library.
static Promise CreatePromise() {
return Promise(
tsl::MakeConstructedAsyncValueRef<std::nullopt_t>(std::nullopt));
return Promise(tsl::MakeUnconstructedAsyncValueRef<absl::Status>());
}

PjRtFuture() = default;
Expand All @@ -482,10 +473,8 @@ class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
// is already successfully completed. Error means that future is already
// completed with an error.
explicit PjRtFuture(absl::Status status)
: Base(status.ok()
? tsl::MakeAvailableAsyncValueRef<std::nullopt_t>(std::nullopt)
: tsl::MakeErrorAsyncValueRef(std::move(status)),
/*on_block_start=*/nullptr, /*on_block_end=*/nullptr) {}
: Base(std::move(status), /*on_block_start=*/nullptr,
/*on_block_end=*/nullptr) {}

// Constructor for an unavailable PjRtFuture that will be resolved later by
// setting the promise completed.
Expand All @@ -503,8 +492,8 @@ class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
absl::Status Await() {
CHECK(Base::IsValid());
Base::BlockUntilReady();
return Base::promise().IsError() ? Base::promise().GetError()
: absl::OkStatus();
DCHECK(Base::promise().IsConcrete());
return *Base::promise();
}

// Registers callback to be called once the future is ready.
Expand All @@ -515,7 +504,11 @@ class PjRtFuture<void> : public internal::PjRtFutureBase<std::nullopt_t> {
// client-owned threadpool.
void OnReady(absl::AnyInvocable<void(absl::Status)> callback) const {
CHECK(Base::IsValid());
Base::promise().AndThen(std::move(callback));
Base::promise().AndThen(
[promise = Base::promise(), callback = std::move(callback)]() mutable {
DCHECK(promise.IsConcrete());
callback(promise.get());
});
}
};

Expand Down

0 comments on commit 7bb6eb1

Please sign in to comment.