Skip to content

Commit

Permalink
fix: Fix the race condition in decay average (#850)
Browse files Browse the repository at this point in the history
* fix: Fix the race condition in decay average

* fix format

* fix

* remove initial condition

* update

* code review

* update

* use clock and don't decay mean

* merge getDecay and getWeight

* update

* update
  • Loading branch information
mutianf committed Jun 3, 2021
1 parent 32284d2 commit 66a9c9e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 71 deletions.
Expand Up @@ -15,48 +15,50 @@
*/
package com.google.cloud.bigtable.data.v2.stub;

import com.google.api.core.ApiClock;
import com.google.api.core.InternalApi;
import com.google.api.core.NanoClock;
import com.google.api.gax.batching.FlowController;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
* Records stats used in dynamic flow control, the decaying average of recorded latencies and the
* last timestamp when the thresholds in {@link FlowController} are updated.
*
* <pre>
* Exponential decaying average = weightedSum / weightedCount, where
* weightedSum(n) = weight(n) * value(n) + weightedSum(n - 1)
* weightedCount(n) = weight(n) + weightedCount(n - 1),
* and weight(n) grows exponentially over elapsed time. Biased to the past 5 minutes.
* </pre>
*/
final class DynamicFlowControlStats {

private static final double DEFAULT_DECAY_CONSTANT = 0.015; // Biased to the past 5 minutes
// Biased to the past 5 minutes (300 seconds), e^(-decay_constant * 300) = 0.01, decay_constant ~=
// 0.015
private static final double DEFAULT_DECAY_CONSTANT = 0.015;
// Update decay cycle start time every 15 minutes so the values won't be infinite
private static final long DECAY_CYCLE_SECOND = TimeUnit.MINUTES.toSeconds(15);

private AtomicLong lastAdjustedTimestampMs;
private DecayingAverage meanLatency;
private final AtomicLong lastAdjustedTimestampMs;
private final DecayingAverage meanLatency;

DynamicFlowControlStats() {
this(DEFAULT_DECAY_CONSTANT);
this(DEFAULT_DECAY_CONSTANT, NanoClock.getDefaultClock());
}

DynamicFlowControlStats(double decayConstant) {
@InternalApi("visible for testing")
DynamicFlowControlStats(double decayConstant, ApiClock clock) {
this.lastAdjustedTimestampMs = new AtomicLong(0);
this.meanLatency = new DecayingAverage(decayConstant);
this.meanLatency = new DecayingAverage(decayConstant, clock);
}

void updateLatency(long latency) {
updateLatency(latency, System.currentTimeMillis());
}

@VisibleForTesting
void updateLatency(long latency, long timestampMs) {
meanLatency.update(latency, timestampMs);
meanLatency.update(latency);
}

/** Returns the mean as of the last recorded latency; the value does not decay over time. */
double getMeanLatency() {
return getMeanLatency(System.currentTimeMillis());
}

@VisibleForTesting
double getMeanLatency(long timestampMs) {
return meanLatency.getMean(timestampMs);
return meanLatency.getMean();
}

public long getLastAdjustedTimestampMs() {
Expand All @@ -71,46 +73,45 @@ private class DecayingAverage {
private double decayConstant;
private double mean;
private double weightedCount;
private AtomicLong lastUpdateTimeInSecond;
private long decayCycleStartEpoch;
private final ApiClock clock;

DecayingAverage(double decayConstant) {
DecayingAverage(double decayConstant, ApiClock clock) {
this.decayConstant = decayConstant;
this.mean = 0.0;
this.weightedCount = 0.0;
this.lastUpdateTimeInSecond = new AtomicLong(0);
this.clock = clock;
this.decayCycleStartEpoch = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
}

synchronized void update(long value, long timestampMs) {
long now = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
Preconditions.checkArgument(
now >= lastUpdateTimeInSecond.get(), "can't update an event in the past");
if (lastUpdateTimeInSecond.get() == 0) {
lastUpdateTimeInSecond.set(now);
mean = value;
weightedCount = 1;
} else {
long prev = lastUpdateTimeInSecond.getAndSet(now);
long elapsed = now - prev;
double alpha = getAlpha(elapsed);
// Exponential moving average = weightedSum / weightedCount, where
// weightedSum(n) = value + alpha * weightedSum(n - 1)
// weightedCount(n) = 1 + alpha * weightedCount(n - 1)
// Using weighted count in case the sum overflows
mean =
mean * ((weightedCount * alpha) / (weightedCount * alpha + 1))
+ value / (weightedCount * alpha + 1);
weightedCount = weightedCount * alpha + 1;
}
/**
 * Folds a new sample into the running weighted mean.
 *
 * <p>The sample's weight comes from {@code getWeight} for the current clock time in seconds.
 * The mean is updated incrementally from the weighted count rather than from a weighted sum,
 * in case the sum overflows.
 *
 * @param value the sample to record
 */
synchronized void update(long value) {
  long nowSecond = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
  // getWeight may rescale mean and weightedCount, so both are read only after this call.
  double sampleWeight = getWeight(nowSecond);
  double totalWeight = weightedCount + sampleWeight;
  double previousShare = weightedCount / totalWeight;
  mean = mean * previousShare + sampleWeight * value / totalWeight;
  weightedCount = totalWeight;
}

double getMean(long timestampMs) {
long timestampSecs = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
long elapsed = timestampSecs - lastUpdateTimeInSecond.get();
return mean * getAlpha(Math.max(0, elapsed));
/**
 * Returns the mean as of the most recent update; the value is not decayed at read time.
 *
 * <p>Synchronized on the same monitor as {@code update}, so a reader never observes a stale
 * value or a partially applied update of the non-volatile {@code mean} field.
 */
synchronized double getMean() {
  return mean;
}

private double getAlpha(long elapsedSecond) {
return Math.exp(-decayConstant * elapsedSecond);
/**
 * Returns the weight for a sample recorded at {@code now}, restarting the decay cycle when it
 * has run for more than {@code DECAY_CYCLE_SECOND} seconds.
 *
 * <p>The weight is {@code e^(decayConstant * elapsed)}, where {@code elapsed} is the number of
 * seconds since the current decay cycle started. On a restart the accumulated weighted count
 * is rescaled to the new cycle start so the values won't grow to infinity.
 *
 * <p>Mutates {@code weightedCount} and {@code decayCycleStartEpoch}; it is called from the
 * synchronized {@code update} method.
 *
 * @param now current time in seconds
 * @return the weight to apply to a sample recorded at {@code now}
 */
private double getWeight(long now) {
  long elapsedSecond = now - decayCycleStartEpoch;
  double weight = Math.exp(decayConstant * elapsedSecond);
  if (elapsedSecond <= DECAY_CYCLE_SECOND) {
    return weight;
  }
  // Restart the decay cycle at `now`. Every previously recorded weight shrinks by the same
  // factor, so the weighted count is divided by it. The mean is weightedSum / weightedCount:
  // numerator and denominator shrink together, so the mean itself must stay unchanged.
  weightedCount /= weight;
  decayCycleStartEpoch = now;
  // After resetting start time, weight = e^0 = 1
  return 1;
}
}
}
Expand Up @@ -17,50 +17,56 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.api.core.ApiClock;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

@RunWith(JUnit4.class)
public class DynamicFlowControlStatsTest {

@Rule public final MockitoRule rule = MockitoJUnit.rule();

@Mock private ApiClock clock;

@Test
public void testUpdate() {
DynamicFlowControlStats stats = new DynamicFlowControlStats();
long now = System.currentTimeMillis();

stats.updateLatency(10, now);
assertThat(stats.getMeanLatency(now)).isEqualTo(10);

stats.updateLatency(10, now);
stats.updateLatency(10, now);
assertThat(stats.getMeanLatency(now)).isEqualTo(10);
Mockito.when(clock.millisTime()).thenReturn(0L);
DynamicFlowControlStats stats = new DynamicFlowControlStats(0.015, clock);
stats.updateLatency(10);
assertThat(stats.getMeanLatency()).isEqualTo(10);
stats.updateLatency(10);
stats.updateLatency(10);
assertThat(stats.getMeanLatency()).isEqualTo(10);

// Five minutes later a new sample outweighs each earlier one by e^(0.015 * 300), about 90x, so
// the new average should be very close to 20
long fiveMinutesLater = now + TimeUnit.MINUTES.toMillis(5);
assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(1);
stats.updateLatency(20, fiveMinutesLater);
assertThat(stats.getMeanLatency(fiveMinutesLater)).isGreaterThan(19);
assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(20);

long aDayLater = now + TimeUnit.HOURS.toMillis(24);
assertThat(stats.getMeanLatency(aDayLater)).isZero();
Mockito.when(clock.millisTime()).thenReturn(TimeUnit.MINUTES.toMillis(5));
stats.updateLatency(20);
assertThat(stats.getMeanLatency()).isGreaterThan(19);
assertThat(stats.getMeanLatency()).isLessThan(20);

long timestamp = aDayLater;
// After a day
long aDay = TimeUnit.DAYS.toMillis(1);
for (int i = 0; i < 10; i++) {
timestamp += TimeUnit.SECONDS.toMillis(i);
stats.updateLatency(i, timestamp);
Mockito.when(clock.millisTime()).thenReturn(aDay + TimeUnit.SECONDS.toMillis(i));
stats.updateLatency(i);
}
assertThat(stats.getMeanLatency(timestamp)).isGreaterThan(4.5);
assertThat(stats.getMeanLatency(timestamp)).isLessThan(6);
assertThat(stats.getMeanLatency()).isGreaterThan(4.5);
assertThat(stats.getMeanLatency()).isLessThan(6);
}

@Test(timeout = 1000)
Expand Down

0 comments on commit 66a9c9e

Please sign in to comment.