diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStats.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStats.java index 4169ac213..01fabe52c 100644 --- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStats.java +++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStats.java @@ -15,48 +15,50 @@ */ package com.google.cloud.bigtable.data.v2.stub; +import com.google.api.core.ApiClock; +import com.google.api.core.InternalApi; +import com.google.api.core.NanoClock; import com.google.api.gax.batching.FlowController; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; /** * Records stats used in dynamic flow control, the decaying average of recorded latencies and the * last timestamp when the thresholds in {@link FlowController} are updated. + * + *
Exponential decaying average = weightedSum / weightedCount, where
+ *   weightedSum(n) = weight(n) * value(n) + weightedSum(n - 1)
+ *   weightedCount(n) = weight(n) + weightedCount(n - 1),
+ * and weight(n) grows exponentially over elapsed time. Biased to the past 5 minutes.
  */
 final class DynamicFlowControlStats {
 
-  private static final double DEFAULT_DECAY_CONSTANT = 0.015; // Biased to the past 5 minutes
+  // Biased to the past 5 minutes (300 seconds), e^(-decay_constant * 300) = 0.01, decay_constant ~=
+  // 0.015
+  private static final double DEFAULT_DECAY_CONSTANT = 0.015;
+  // Update decay cycle start time every 15 minutes so the values won't be infinite
+  private static final long DECAY_CYCLE_SECOND = TimeUnit.MINUTES.toSeconds(15);
 
-  private AtomicLong lastAdjustedTimestampMs;
-  private DecayingAverage meanLatency;
+  private final AtomicLong lastAdjustedTimestampMs;
+  private final DecayingAverage meanLatency;
 
   DynamicFlowControlStats() {
-    this(DEFAULT_DECAY_CONSTANT);
+    this(DEFAULT_DECAY_CONSTANT, NanoClock.getDefaultClock());
   }
 
-  DynamicFlowControlStats(double decayConstant) {
+  @InternalApi("visible for testing")
+  DynamicFlowControlStats(double decayConstant, ApiClock clock) {
     this.lastAdjustedTimestampMs = new AtomicLong(0);
-    this.meanLatency = new DecayingAverage(decayConstant);
+    this.meanLatency = new DecayingAverage(decayConstant, clock);
   }
 
   void updateLatency(long latency) {
-    updateLatency(latency, System.currentTimeMillis());
-  }
-
-  @VisibleForTesting
-  void updateLatency(long latency, long timestampMs) {
-    meanLatency.update(latency, timestampMs);
+    meanLatency.update(latency);
   }
 
+  /** Return the mean calculated from the last update, will not decay over time. */
   double getMeanLatency() {
-    return getMeanLatency(System.currentTimeMillis());
-  }
-
-  @VisibleForTesting
-  double getMeanLatency(long timestampMs) {
-    return meanLatency.getMean(timestampMs);
+    return meanLatency.getMean();
   }
 
   public long getLastAdjustedTimestampMs() {
@@ -71,46 +73,45 @@ private class DecayingAverage {
     private double decayConstant;
     private double mean;
     private double weightedCount;
-    private AtomicLong lastUpdateTimeInSecond;
+    private long decayCycleStartEpoch;
+    private final ApiClock clock;
 
-    DecayingAverage(double decayConstant) {
+    DecayingAverage(double decayConstant, ApiClock clock) {
       this.decayConstant = decayConstant;
       this.mean = 0.0;
       this.weightedCount = 0.0;
-      this.lastUpdateTimeInSecond = new AtomicLong(0);
+      this.clock = clock;
+      this.decayCycleStartEpoch = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
     }
 
-    synchronized void update(long value, long timestampMs) {
-      long now = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
-      Preconditions.checkArgument(
-          now >= lastUpdateTimeInSecond.get(), "can't update an event in the past");
-      if (lastUpdateTimeInSecond.get() == 0) {
-        lastUpdateTimeInSecond.set(now);
-        mean = value;
-        weightedCount = 1;
-      } else {
-        long prev = lastUpdateTimeInSecond.getAndSet(now);
-        long elapsed = now - prev;
-        double alpha = getAlpha(elapsed);
-        // Exponential moving average = weightedSum / weightedCount, where
-        // weightedSum(n) = value + alpha * weightedSum(n - 1)
-        // weightedCount(n) = 1 + alpha * weightedCount(n - 1)
-        // Using weighted count in case the sum overflows
-        mean =
-            mean * ((weightedCount * alpha) / (weightedCount * alpha + 1))
-                + value / (weightedCount * alpha + 1);
-        weightedCount = weightedCount * alpha + 1;
-      }
+    synchronized void update(long value) {
+      long now = TimeUnit.MILLISECONDS.toSeconds(clock.millisTime());
+      double weight = getWeight(now);
+      // Using weighted count in case the sum overflows.
+      mean =
+          mean * (weightedCount / (weightedCount + weight))
+              + weight * value / (weightedCount + weight);
+      weightedCount += weight;
     }
 
-    double getMean(long timestampMs) {
-      long timestampSecs = TimeUnit.MILLISECONDS.toSeconds(timestampMs);
-      long elapsed = timestampSecs - lastUpdateTimeInSecond.get();
-      return mean * getAlpha(Math.max(0, elapsed));
+    double getMean() {
+      return mean;
     }
 
-    private double getAlpha(long elapsedSecond) {
-      return Math.exp(-decayConstant * elapsedSecond);
+    private double getWeight(long now) {
+      long elapsedSecond = now - decayCycleStartEpoch;
+      double weight = Math.exp(decayConstant * elapsedSecond);
+      // Decay mean, weightedCount and reset decay cycle start epoch every 15 minutes, so the
+      // values won't be infinite
+      if (elapsedSecond > DECAY_CYCLE_SECOND) {
+        mean /= weight;
+        weightedCount /= weight;
+        decayCycleStartEpoch = now;
+        // After resetting start time, weight = e^0 = 1
+        return 1;
+      } else {
+        return weight;
+      }
     }
   }
 }
diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStatsTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStatsTest.java
index 653489f33..2a407dda9 100644
--- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStatsTest.java
+++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/DynamicFlowControlStatsTest.java
@@ -17,6 +17,7 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import com.google.api.core.ApiClock;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
@@ -24,43 +25,48 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
 
 @RunWith(JUnit4.class)
 public class DynamicFlowControlStatsTest {
 
+  @Rule public final MockitoRule rule = MockitoJUnit.rule();
+
+  @Mock private ApiClock clock;
+
   @Test
   public void testUpdate() {
-    DynamicFlowControlStats stats = new DynamicFlowControlStats();
-    long now = System.currentTimeMillis();
 
-    stats.updateLatency(10, now);
-    assertThat(stats.getMeanLatency(now)).isEqualTo(10);
-
-    stats.updateLatency(10, now);
-    stats.updateLatency(10, now);
-    assertThat(stats.getMeanLatency(now)).isEqualTo(10);
+    Mockito.when(clock.millisTime()).thenReturn(0L);
+    DynamicFlowControlStats stats = new DynamicFlowControlStats(0.015, clock);
+    stats.updateLatency(10);
+    assertThat(stats.getMeanLatency()).isEqualTo(10);
+    stats.updateLatency(10);
+    stats.updateLatency(10);
+    assertThat(stats.getMeanLatency()).isEqualTo(10);
 
     // In five minutes the previous latency should be decayed to under 1. And the new average should
     // be very close to 20
-    long fiveMinutesLater = now + TimeUnit.MINUTES.toMillis(5);
-    assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(1);
-    stats.updateLatency(20, fiveMinutesLater);
-    assertThat(stats.getMeanLatency(fiveMinutesLater)).isGreaterThan(19);
-    assertThat(stats.getMeanLatency(fiveMinutesLater)).isLessThan(20);
-
-    long aDayLater = now + TimeUnit.HOURS.toMillis(24);
-    assertThat(stats.getMeanLatency(aDayLater)).isZero();
+    Mockito.when(clock.millisTime()).thenReturn(TimeUnit.MINUTES.toMillis(5));
+    stats.updateLatency(20);
+    assertThat(stats.getMeanLatency()).isGreaterThan(19);
+    assertThat(stats.getMeanLatency()).isLessThan(20);
 
-    long timestamp = aDayLater;
+    // After a day
+    long aDay = TimeUnit.DAYS.toMillis(1);
     for (int i = 0; i < 10; i++) {
-      timestamp += TimeUnit.SECONDS.toMillis(i);
-      stats.updateLatency(i, timestamp);
+      Mockito.when(clock.millisTime()).thenReturn(aDay + TimeUnit.SECONDS.toMillis(i));
+      stats.updateLatency(i);
     }
-    assertThat(stats.getMeanLatency(timestamp)).isGreaterThan(4.5);
-    assertThat(stats.getMeanLatency(timestamp)).isLessThan(6);
+    assertThat(stats.getMeanLatency()).isGreaterThan(4.5);
+    assertThat(stats.getMeanLatency()).isLessThan(6);
   }
 
   @Test(timeout = 1000)