Skip to content

Commit

Permalink
[FLINK-35168][State] Basic State Iterator for async processing (#24690)
Browse files Browse the repository at this point in the history
This closes #24690
  • Loading branch information
Zakelly committed May 13, 2024
1 parent 4e6b420 commit 0158678
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.asyncprocessing;

import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.state.v2.StateFuture;
import org.apache.flink.api.common.state.v2.StateIterator;
import org.apache.flink.core.state.InternalStateFuture;
import org.apache.flink.core.state.StateFutureUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * A {@link StateIterator} implementation to facilitate async data load of iterator. Each state
 * backend could override this class to maintain more variables in need. Any subclass should
 * implement two methods, {@link #hasNext()} and {@link #nextPayloadForContinuousLoading()}.
 *
 * <p>The philosophy behind this class is to carry some already loaded elements and provide
 * iterating right on the task thread, and load following ones if needed (determined by {@link
 * #hasNext()}) by creating <b>another</b> iterating request. Thus, later it returns another
 * iterator instance, and we continue to apply the user iteration on that instance. The whole
 * elements will be iterated by recursive call of {@code onNext()}.
 *
 * @param <T> the type of the elements this iterator provides.
 */
@SuppressWarnings("rawtypes")
public abstract class AbstractStateIterator<T> implements StateIterator<T> {

    /** The state this iterator iterates on. */
    final State originalState;

    /** The request type that created this iterator. */
    final StateRequestType requestType;

    /** The controller that can receive further requests. */
    final StateRequestHandler stateHandler;

    /** The already loaded partial elements. A {@code null} value is treated as "no elements". */
    final Collection<T> cache;

    /**
     * Creates an iterator that carries a partial result.
     *
     * @param originalState the state this iterator iterates on.
     * @param requestType the request type that created this iterator.
     * @param stateHandler the handler that receives the follow-up loading requests.
     * @param partialResult the elements loaded so far; {@code null} is handled like an empty
     *     collection.
     */
    public AbstractStateIterator(
            State originalState,
            StateRequestType requestType,
            StateRequestHandler stateHandler,
            Collection<T> partialResult) {
        this.originalState = originalState;
        this.requestType = requestType;
        this.stateHandler = stateHandler;
        this.cache = partialResult;
    }

    /** Return whether this iterator has more elements to load besides current cache. */
    protected abstract boolean hasNext();

    /**
     * To perform following loading, build and get next payload for the next request. This will put
     * into {@link StateRequest#getPayload()}.
     *
     * @return the packed payload for next loading.
     */
    protected abstract Object nextPayloadForContinuousLoading();

    /**
     * Returns the request type that created this iterator.
     *
     * @return the original request type, as passed to the constructor.
     */
    protected StateRequestType getRequestType() {
        return requestType;
    }

    /**
     * Returns the already loaded elements in a null-safe way, so that every reader of the cache
     * agrees with {@link #isEmpty()} on how a {@code null} cache is interpreted.
     */
    private Collection<T> loadedElements() {
        return cache == null ? Collections.emptyList() : cache;
    }

    /**
     * Issues a {@link StateRequestType#ITERATOR_LOADING} request on the original state, carrying
     * the payload built by {@link #nextPayloadForContinuousLoading()}.
     *
     * @return the future of the request, expected to provide the iterator over following elements.
     */
    @SuppressWarnings("unchecked")
    private InternalStateFuture<StateIterator<T>> asyncNextLoad() {
        return stateHandler.handleRequest(
                originalState,
                StateRequestType.ITERATOR_LOADING,
                nextPayloadForContinuousLoading());
    }

    /**
     * Applies {@code iterating} to every loaded element and, if {@link #hasNext()} reports more
     * elements, requests the next iterator and continues the iteration on it.
     *
     * @param iterating the user function applied to each element, producing a future result.
     * @param <U> the type of the produced results.
     * @return a future of the results gathered for the loaded elements, followed by the results
     *     gathered from the subsequently loaded iterators.
     */
    @Override
    public <U> StateFuture<Collection<U>> onNext(Function<T, StateFuture<? extends U>> iterating) {
        // Public interface implementation, this is on task thread.
        // We perform the user code on cache, and create a new request and chain with it.
        if (isEmpty()) {
            return StateFutureUtils.completedFuture(Collections.emptyList());
        }
        // The cache may be null or empty while hasNext() is still true, so never touch the raw
        // field here.
        Collection<T> loaded = loadedElements();
        Collection<StateFuture<? extends U>> resultFutures = new ArrayList<>(loaded.size());
        for (T item : loaded) {
            resultFutures.add(iterating.apply(item));
        }

        if (!hasNext()) {
            return StateFutureUtils.combineAll(resultFutures);
        }
        return StateFutureUtils.combineAll(resultFutures)
                .thenCombine(
                        asyncNextLoad().thenCompose(itr -> itr.onNext(iterating)),
                        (current, following) -> {
                            // TODO optimization: Avoid results copy.
                            Collection<U> result =
                                    new ArrayList<>(current.size() + following.size());
                            result.addAll(current);
                            result.addAll(following);
                            return result;
                        });
    }

    /**
     * Applies {@code iterating} to every loaded element and, if {@link #hasNext()} reports more
     * elements, requests the next iterator and continues the iteration on it.
     *
     * @param iterating the user action applied to each element.
     * @return a future that is chained with the follow-up loading when there is one, or an already
     *     completed future otherwise.
     */
    @Override
    public StateFuture<Void> onNext(Consumer<T> iterating) {
        // Public interface implementation, this is on task thread.
        // We perform the user code on cache, and create a new request and chain with it.
        if (isEmpty()) {
            return StateFutureUtils.completedVoidFuture();
        }
        // Null-safe view, see loadedElements().
        for (T item : loadedElements()) {
            iterating.accept(item);
        }

        if (!hasNext()) {
            return StateFutureUtils.completedVoidFuture();
        }
        return asyncNextLoad().thenCompose(itr -> itr.onNext(iterating));
    }

    /**
     * Returns {@code true} only if there is neither a loaded element nor anything left to load.
     */
    @Override
    public boolean isEmpty() {
        return loadedElements().isEmpty() && !hasNext();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,8 @@ public enum StateRequestType {
* Check the existence of any key-value mapping within current partition, {@link
* MapState#asyncIsEmpty()}.
*/
MAP_IS_EMPTY
MAP_IS_EMPTY,

/** Continuously load elements for one iterator. */
ITERATOR_LOADING
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.asyncprocessing;

import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.state.v2.StateIterator;
import org.apache.flink.core.state.StateFutureUtils;
import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
import org.apache.flink.util.Preconditions;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * The tests for {@link AbstractStateIterator} which facilitate the basic partial loading of state
 * asynchronous iterators.
 */
public class AbstractStateIteratorTest {

    /** Total number of elements served by the test executor for one iteration. */
    private static final int ELEMENT_COUNT = 100;

    /** Maximum number of elements the test executor puts into a single iterator instance. */
    private static final int LOAD_STEP = 3;

    /**
     * Iterates with the {@code Consumer} flavor of {@code onNext} and verifies that all elements
     * are visited in ascending order.
     */
    @Test
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void testPartialLoading() {
        AsyncExecutionController aec = createAecWithCurrentContext();
        AtomicInteger processed = new AtomicInteger();

        // NOTE: the lambdas passed to onNext() must keep their block bodies; their shape decides
        // which onNext() overload (Consumer vs. Function) is selected.
        aec.handleRequest(null, StateRequestType.MAP_ITER, null)
                .thenAccept(
                        (iter) -> {
                            assertThat(iter).isInstanceOf(StateIterator.class);
                            ((StateIterator<Integer>) iter)
                                    .onNext(
                                            (item) -> {
                                                assertThat(item)
                                                        .isEqualTo(processed.getAndIncrement());
                                            })
                                    .thenAccept(
                                            (v) -> {
                                                assertThat(processed.get())
                                                        .isEqualTo(ELEMENT_COUNT);
                                            });
                        });
        aec.drainInflightRecords(0);
    }

    /**
     * Iterates with the {@code Function} flavor of {@code onNext} and verifies both the visiting
     * order and the collected results.
     */
    @Test
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void testPartialLoadingWithReturnValue() {
        AsyncExecutionController aec = createAecWithCurrentContext();
        AtomicInteger processed = new AtomicInteger();

        aec.handleRequest(null, StateRequestType.MAP_ITER, null)
                .thenAccept(
                        (iter) -> {
                            assertThat(iter).isInstanceOf(StateIterator.class);
                            ((StateIterator<Integer>) iter)
                                    .onNext(
                                            (item) -> {
                                                assertThat(item)
                                                        .isEqualTo(processed.getAndIncrement());
                                                return StateFutureUtils.completedFuture(
                                                        String.valueOf(item));
                                            })
                                    .thenAccept(
                                            (strings) -> {
                                                assertThat(processed.get())
                                                        .isEqualTo(ELEMENT_COUNT);
                                                // The loop below passes trivially on an empty or
                                                // truncated result, so pin the size explicitly.
                                                assertThat(strings).hasSize(ELEMENT_COUNT);
                                                int validate = 0;
                                                for (String item : strings) {
                                                    assertThat(item)
                                                            .isEqualTo(String.valueOf(validate++));
                                                }
                                            });
                        });
        aec.drainInflightRecords(0);
    }

    /**
     * Builds an {@link AsyncExecutionController} backed by a fresh {@link
     * TestIteratorStateExecutor} and sets a record context on it, so that the tests can issue
     * requests right away.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static AsyncExecutionController createAecWithCurrentContext() {
        TestIteratorStateExecutor stateExecutor =
                new TestIteratorStateExecutor(ELEMENT_COUNT, LOAD_STEP);
        AsyncExecutionController aec =
                new AsyncExecutionController(
                        new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1);
        stateExecutor.bindAec(aec);
        RecordContext<String> recordContext = aec.buildContext("1", "key1");
        aec.setCurrentContext(recordContext);
        return aec;
    }

    /**
     * A brief implementation of {@link StateExecutor}, to illustrate the interaction between AEC
     * and StateExecutor. It serves the integers {@code [0, limit)} in ascending order, at most
     * {@code step} of them per request.
     */
    @SuppressWarnings({"rawtypes"})
    static class TestIteratorStateExecutor implements StateExecutor {

        /** Exclusive upper bound of the served integers. */
        final int limit;

        /** Maximum number of integers served per request. */
        final int step;

        /** The controller handed to the created iterators for their follow-up requests. */
        AsyncExecutionController aec;

        /** The next integer to serve. */
        int current = 0;

        /** Number of requests handled by this executor so far. */
        AtomicInteger processedCount = new AtomicInteger(0);

        public TestIteratorStateExecutor(int limit, int step) {
            this.limit = limit;
            this.step = step;
        }

        /** Sets the controller that the created iterators use for follow-up requests. */
        public void bindAec(AsyncExecutionController aec) {
            this.aec = aec;
        }

        /**
         * Handles every request of the container right away: both the initial {@link
         * StateRequestType#MAP_ITER} request and the follow-up {@link
         * StateRequestType#ITERATOR_LOADING} requests are answered with a new {@link TestIterator}
         * carrying the next batch. Any other request type fails the test.
         *
         * @return an already completed future.
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public CompletableFuture<Void> executeBatchRequests(
                StateRequestContainer stateRequestContainer) {
            Preconditions.checkArgument(stateRequestContainer instanceof MockStateRequestContainer);
            CompletableFuture<Void> future = new CompletableFuture<>();
            for (StateRequest request :
                    ((MockStateRequestContainer) stateRequestContainer).getStateRequestList()) {
                if (request.getRequestType() == StateRequestType.MAP_ITER) {
                    completeWithNextBatch(request, request.getRequestType());
                } else if (request.getRequestType() == StateRequestType.ITERATOR_LOADING) {
                    assertThat(request.getPayload()).isInstanceOf(TestIterator.class);
                    TestIterator previous = (TestIterator) request.getPayload();
                    // The follow-up request must continue exactly where the last batch ended.
                    assertThat(previous.current).isEqualTo(current);
                    // The new iterator keeps the request type of the iterator it continues.
                    completeWithNextBatch(request, previous.getRequestType());
                } else {
                    fail("Unsupported request type " + request.getRequestType());
                }
                processedCount.incrementAndGet();
            }
            future.complete(null);
            return future;
        }

        /**
         * Loads up to {@code step} further integers and completes the future of the request with
         * an iterator that carries them.
         *
         * @param request the request to answer.
         * @param iteratorType the request type recorded in the created iterator.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private void completeWithNextBatch(StateRequest request, StateRequestType iteratorType) {
            Collection<Integer> results = new ArrayList<>(step);
            for (int i = 0; current < limit && i < step; i++) {
                results.add(current++);
            }
            request.getFuture()
                    .complete(
                            new TestIterator(
                                    request.getState(),
                                    iteratorType,
                                    aec,
                                    results,
                                    current,
                                    limit));
        }

        @Override
        public StateRequestContainer createStateRequestContainer() {
            return new MockStateRequestContainer();
        }

        @Override
        public void shutdown() {}

        /**
         * An iterator over a batch of integers. It remembers the executor position right after
         * its batch and uses itself as the payload of the follow-up request.
         */
        static class TestIterator extends AbstractStateIterator<Integer> {

            /** The executor position right after the batch carried by this iterator. */
            final int current;

            /** Exclusive upper bound of the served integers. */
            final int limit;

            public TestIterator(
                    State originalState,
                    StateRequestType requestType,
                    AsyncExecutionController aec,
                    Collection<Integer> partialResult,
                    int current,
                    int limit) {
                super(originalState, requestType, aec, partialResult);
                this.current = current;
                this.limit = limit;
            }

            @Override
            protected boolean hasNext() {
                return current < limit;
            }

            @Override
            protected Object nextPayloadForContinuousLoading() {
                return this;
            }
        }
    }
}

0 comments on commit 0158678

Please sign in to comment.