Skip to content

Commit

Permalink
fix: ManagedChannel shutdown issues (#700)
Browse files Browse the repository at this point in the history
* fix: ManagedChannel shutdown issues

ManagedChannel emits a warning when it is marked for finalization; calling close() in the finalize() method is insufficient to remove this warning. Add ManagedBacklogReaderFactory to tie ManagedChannel shutdown to the SDF finalize method

* fix: reformat
  • Loading branch information
dpcollins-google committed Jun 16, 2021
1 parent 2f55375 commit 2d0cbde
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 23 deletions.
@@ -0,0 +1,32 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.pubsublite.beam;

import java.io.Serializable;

/**
* A ManagedBacklogReaderFactory produces TopicBacklogReaders and tears down any produced readers
* when it is itself closed.
*
* <p>close() should never be called on produced readers.
*/
public interface ManagedBacklogReaderFactory extends AutoCloseable, Serializable {
TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition);

@Override
void close();
}
@@ -0,0 +1,67 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.pubsublite.beam;

import com.google.api.gax.rpc.ApiException;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.proto.ComputeMessageStatsResponse;
import java.util.HashMap;
import java.util.Map;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.sdk.transforms.SerializableFunction;

public class ManagedBacklogReaderFactoryImpl implements ManagedBacklogReaderFactory {
private final SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader;

@GuardedBy("this")
private final Map<SubscriptionPartition, TopicBacklogReader> readers = new HashMap<>();

ManagedBacklogReaderFactoryImpl(
SerializableFunction<SubscriptionPartition, TopicBacklogReader> newReader) {
this.newReader = newReader;
}

private static final class NonCloseableTopicBacklogReader implements TopicBacklogReader {
private final TopicBacklogReader underlying;

NonCloseableTopicBacklogReader(TopicBacklogReader underlying) {
this.underlying = underlying;
}

@Override
public ComputeMessageStatsResponse computeMessageStats(Offset offset) throws ApiException {
return underlying.computeMessageStats(offset);
}

@Override
public void close() {
throw new IllegalArgumentException(
"Cannot call close() on a reader returned from ManagedBacklogReaderFactory.");
}
}

@Override
public synchronized TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
return new NonCloseableTopicBacklogReader(
readers.computeIfAbsent(subscriptionPartition, newReader::apply));
}

@Override
public synchronized void close() {
readers.values().forEach(TopicBacklogReader::close);
}
}
Expand Up @@ -42,7 +42,7 @@
* would return ProcessContinuation.resume().
*/
class OffsetByteRangeTracker extends TrackerWithProgress {
private final TopicBacklogReader backlogReader;
private final TopicBacklogReader unownedBacklogReader;
private final Duration minTrackingTime;
private final long minBytesReceived;
private final Stopwatch stopwatch;
Expand All @@ -51,7 +51,7 @@ class OffsetByteRangeTracker extends TrackerWithProgress {

public OffsetByteRangeTracker(
OffsetByteRange range,
TopicBacklogReader backlogReader,
TopicBacklogReader unownedBacklogReader,
Stopwatch stopwatch,
Duration minTrackingTime,
long minBytesReceived) {
Expand All @@ -61,18 +61,13 @@ public OffsetByteRangeTracker(
checkArgument(
range.getByteCount() == 0L,
"May only construct OffsetByteRangeTracker with an unbounded range with no progress.");
this.backlogReader = backlogReader;
this.unownedBacklogReader = unownedBacklogReader;
this.minTrackingTime = minTrackingTime;
this.minBytesReceived = minBytesReceived;
this.stopwatch = stopwatch.reset().start();
this.range = range;
}

@Override
public void finalize() {
this.backlogReader.close();
}

@Override
public IsBounded isBounded() {
return IsBounded.UNBOUNDED;
Expand Down Expand Up @@ -170,7 +165,7 @@ public void checkDone() throws IllegalStateException {
@Override
public Progress getProgress() {
ComputeMessageStatsResponse stats =
this.backlogReader.computeMessageStats(Offset.of(nextOffset()));
this.unownedBacklogReader.computeMessageStats(Offset.of(nextOffset()));
return Progress.from(range.getByteCount(), stats.getMessageBytes());
}
}
Expand Up @@ -31,27 +31,35 @@

class PerSubscriptionPartitionSdf extends DoFn<SubscriptionPartition, SequencedMessage> {
private final Duration maxSleepTime;
private final ManagedBacklogReaderFactory backlogReaderFactory;
private final SubscriptionPartitionProcessorFactory processorFactory;
private final SerializableFunction<SubscriptionPartition, InitialOffsetReader>
offsetReaderFactory;
private final SerializableBiFunction<SubscriptionPartition, OffsetByteRange, TrackerWithProgress>
private final SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
trackerFactory;
private final SerializableFunction<SubscriptionPartition, Committer> committerFactory;

PerSubscriptionPartitionSdf(
Duration maxSleepTime,
ManagedBacklogReaderFactory backlogReaderFactory,
SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory,
SerializableBiFunction<SubscriptionPartition, OffsetByteRange, TrackerWithProgress>
SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress>
trackerFactory,
SubscriptionPartitionProcessorFactory processorFactory,
SerializableFunction<SubscriptionPartition, Committer> committerFactory) {
this.maxSleepTime = maxSleepTime;
this.backlogReaderFactory = backlogReaderFactory;
this.processorFactory = processorFactory;
this.offsetReaderFactory = offsetReaderFactory;
this.trackerFactory = trackerFactory;
this.committerFactory = committerFactory;
}

@Teardown
public void teardown() {
backlogReaderFactory.close();
}

@GetInitialWatermarkEstimatorState
public Instant getInitialWatermarkState() {
return Instant.EPOCH;
Expand Down Expand Up @@ -103,7 +111,7 @@ public OffsetByteRange getInitialRestriction(
@NewTracker
public TrackerWithProgress newTracker(
@Element SubscriptionPartition subscriptionPartition, @Restriction OffsetByteRange range) {
return trackerFactory.apply(subscriptionPartition, range);
return trackerFactory.apply(backlogReaderFactory.newReader(subscriptionPartition), range);
}

@GetSize
Expand Down
Expand Up @@ -85,12 +85,16 @@ private SubscriptionPartitionProcessor newPartitionProcessor(
options.flowControlSettings());
}

private TrackerWithProgress newRestrictionTracker(
SubscriptionPartition subscriptionPartition, OffsetByteRange initial) {
private TopicBacklogReader newBacklogReader(SubscriptionPartition subscriptionPartition) {
checkSubscription(subscriptionPartition);
return options.getBacklogReader(subscriptionPartition.partition());
}

private TrackerWithProgress newRestrictionTracker(
TopicBacklogReader backlogReader, OffsetByteRange initial) {
return new OffsetByteRangeTracker(
initial,
options.getBacklogReader(subscriptionPartition.partition()),
backlogReader,
Stopwatch.createUnstarted(),
options.minBundleTimeout(),
LongMath.saturatedMultiply(options.flowControlSettings().bytesOutstanding(), 10));
Expand Down Expand Up @@ -129,6 +133,7 @@ public PCollection<SequencedMessage> expand(PBegin input) {
new PerSubscriptionPartitionSdf(
// Ensure we read for at least 5 seconds more than the bundle timeout.
options.minBundleTimeout().plus(Duration.standardSeconds(5)),
new ManagedBacklogReaderFactoryImpl(this::newBacklogReader),
this::newInitialOffsetReader,
this::newRestrictionTracker,
this::newPartitionProcessor,
Expand Down
Expand Up @@ -48,7 +48,7 @@ public class OffsetByteRangeTrackerTest {
private static final double IGNORED_FRACTION = -10000000.0;
private static final long MIN_BYTES = 1000;
private static final OffsetRange RANGE = new OffsetRange(123L, Long.MAX_VALUE);
private final TopicBacklogReader reader = mock(TopicBacklogReader.class);
private final TopicBacklogReader unownedBacklogReader = mock(TopicBacklogReader.class);

@Spy Ticker ticker;
private OffsetByteRangeTracker tracker;
Expand All @@ -60,7 +60,7 @@ public void setUp() {
tracker =
new OffsetByteRangeTracker(
OffsetByteRange.of(RANGE, 0),
reader,
unownedBacklogReader,
Stopwatch.createUnstarted(ticker),
Duration.millis(500),
MIN_BYTES);
Expand All @@ -70,7 +70,7 @@ public void setUp() {
public void progressTracked() {
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(123), 10)));
assertTrue(tracker.tryClaim(OffsetByteProgress.of(Offset.of(124), 11)));
when(reader.computeMessageStats(Offset.of(125)))
when(unownedBacklogReader.computeMessageStats(Offset.of(125)))
.thenReturn(ComputeMessageStatsResponse.newBuilder().setMessageBytes(1000).build());
Progress progress = tracker.getProgress();
assertEquals(21, progress.getWorkCompleted(), .0001);
Expand All @@ -79,7 +79,7 @@ public void progressTracked() {

@Test
public void getProgressStatsFailure() {
when(reader.computeMessageStats(Offset.of(123)))
when(unownedBacklogReader.computeMessageStats(Offset.of(123)))
.thenThrow(new CheckedApiException(Code.INTERNAL).underlying);
assertThrows(ApiException.class, tracker::getProgress);
}
Expand Down
Expand Up @@ -74,9 +74,11 @@ public class PerSubscriptionPartitionSdfTest {

@Mock SerializableFunction<SubscriptionPartition, InitialOffsetReader> offsetReaderFactory;

@Mock ManagedBacklogReaderFactory backlogReaderFactory;
@Mock TopicBacklogReader backlogReader;

@Mock
SerializableBiFunction<SubscriptionPartition, OffsetByteRange, TrackerWithProgress>
trackerFactory;
SerializableBiFunction<TopicBacklogReader, OffsetByteRange, TrackerWithProgress> trackerFactory;

@Mock SubscriptionPartitionProcessorFactory processorFactory;
@Mock SerializableFunction<SubscriptionPartition, Committer> committerFactory;
Expand All @@ -100,9 +102,11 @@ public void setUp() {
when(trackerFactory.apply(any(), any())).thenReturn(tracker);
when(committerFactory.apply(any())).thenReturn(committer);
when(tracker.currentRestriction()).thenReturn(RESTRICTION);
when(backlogReaderFactory.newReader(any())).thenReturn(backlogReader);
sdf =
new PerSubscriptionPartitionSdf(
MAX_SLEEP_TIME,
backlogReaderFactory,
offsetReaderFactory,
trackerFactory,
processorFactory,
Expand All @@ -128,7 +132,13 @@ public void getInitialRestrictionReadFailure() {
@Test
public void newTrackerCallsFactory() {
assertSame(tracker, sdf.newTracker(PARTITION, RESTRICTION));
verify(trackerFactory).apply(PARTITION, RESTRICTION);
verify(trackerFactory).apply(backlogReader, RESTRICTION);
}

@Test
public void tearDownClosesBacklogReaderFactory() {
sdf.teardown();
verify(backlogReaderFactory).close();
}

@Test
Expand Down Expand Up @@ -162,13 +172,29 @@ public void process() throws Exception {
order2.verify(committer).awaitTerminated();
}

private static final class NoopManagedBacklogReaderFactory
implements ManagedBacklogReaderFactory {
@Override
public TopicBacklogReader newReader(SubscriptionPartition subscriptionPartition) {
return null;
}

@Override
public void close() {}
}

@Test
@SuppressWarnings("return.type.incompatible")
public void dofnIsSerializable() throws Exception {
ObjectOutputStream output = new ObjectOutputStream(new ByteArrayOutputStream());
output.writeObject(
new PerSubscriptionPartitionSdf(
MAX_SLEEP_TIME, x -> null, (x, y) -> null, (x, y, z) -> null, (x) -> null));
MAX_SLEEP_TIME,
new NoopManagedBacklogReaderFactory(),
x -> null,
(x, y) -> null,
(x, y, z) -> null,
(x) -> null));
}

@Test
Expand Down

0 comments on commit 2d0cbde

Please sign in to comment.