diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Filters.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Filters.java index bd310cd2c..4c8260854 100644 --- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Filters.java +++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/models/Filters.java @@ -25,6 +25,9 @@ import com.google.cloud.bigtable.data.v2.models.Range.AbstractTimestampRange; import com.google.common.base.Preconditions; import com.google.protobuf.ByteString; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.Serializable; import javax.annotation.Nonnull; @@ -200,7 +203,19 @@ public Filter label(@Nonnull String label) { // Implementations of target specific filters. /** DSL for adding filters to a chain. */ public static final class ChainFilter implements Filter { - private RowFilter.Chain.Builder builder; + private static final long serialVersionUID = -6756759448656768478L; + private transient RowFilter.Chain.Builder builder; + + private void writeObject(ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); + s.writeObject(builder.build()); + } + + private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); + RowFilter.Chain chain = (RowFilter.Chain) s.readObject(); + this.builder = chain.toBuilder(); + } private ChainFilter() { this.builder = RowFilter.Chain.newBuilder(); @@ -241,7 +256,19 @@ public ChainFilter clone() { /** DSL for adding filters to the interleave list. */ public static final class InterleaveFilter implements Filter { - private RowFilter.Interleave.Builder builder; + private static final long serialVersionUID = -6356151037337889421L; + private transient RowFilter.Interleave.Builder builder; + + private void writeObject(ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); + s.writeObject(builder.build()); + } + + private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); + RowFilter.Interleave interleave = (RowFilter.Interleave) s.readObject(); + this.builder = interleave.toBuilder(); + } private InterleaveFilter() { builder = RowFilter.Interleave.newBuilder(); @@ -281,7 +308,19 @@ public InterleaveFilter clone() { /** DSL for configuring a conditional filter. */ public static final class ConditionFilter implements Filter { - private RowFilter.Condition.Builder builder; + private static final long serialVersionUID = -2720899822014446776L; + private transient RowFilter.Condition.Builder builder; + + private void writeObject(ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); + s.writeObject(builder.build()); + } + + private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); + RowFilter.Condition condition = (RowFilter.Condition) s.readObject(); + this.builder = condition.toBuilder(); + } private ConditionFilter(@Nonnull Filter predicate) { Preconditions.checkNotNull(predicate); @@ -323,7 +362,9 @@ public ConditionFilter clone() { } } - public static final class KeyFilter { + public static final class KeyFilter implements Serializable { + private static final long serialVersionUID = 5137765114285539458L; + private KeyFilter() {} /** @@ -383,7 +424,9 @@ public Filter sample(double probability) { } } - public static final class FamilyFilter { + public static final class FamilyFilter implements Serializable { + private static final long serialVersionUID = -4470936841191831553L; + private FamilyFilter() {} /** @@ -405,7 +448,9 @@ public Filter exactMatch(@Nonnull String value) { } } - public static final class QualifierFilter { + public static final class QualifierFilter implements Serializable { + private static final long serialVersionUID = -1274850022909506559L; + private QualifierFilter() {} /** @@ -459,7 +504,8 @@ public QualifierRangeFilter rangeWithinFamily(@Nonnull String family) { /** Matches only cells from columns within the given range. */ public static final class QualifierRangeFilter - extends AbstractByteStringRange implements Filter, Serializable { + extends AbstractByteStringRange implements Filter { + private static final long serialVersionUID = -1909319911147913630L; private final String family; private QualifierRangeFilter(String family) { @@ -505,7 +551,9 @@ public QualifierRangeFilter clone() { } } - public static final class TimestampFilter { + public static final class TimestampFilter implements Serializable { + private static final long serialVersionUID = 5284219722591464991L; + private TimestampFilter() {} /** @@ -529,7 +577,9 @@ public TimestampRangeFilter exact(Long exactTimestamp) { /** Matches only cells with microsecond timestamps within the given range. */ public static final class TimestampRangeFilter - extends AbstractTimestampRange implements Filter, Serializable { + extends AbstractTimestampRange implements Filter { + private static final long serialVersionUID = 8410980338603335276L; + private TimestampRangeFilter() {} @InternalApi @@ -571,7 +621,9 @@ public TimestampRangeFilter clone() { } } - public static final class ValueFilter { + public static final class ValueFilter implements Serializable { + private static final long serialVersionUID = 6722715229238811179L; + private ValueFilter() {} /** @@ -628,7 +680,9 @@ public Filter strip() { /** Matches only cells with values that fall within the given value range. */ public static final class ValueRangeFilter extends AbstractByteStringRange - implements Filter, Serializable { + implements Filter { + private static final long serialVersionUID = -2452360677825047088L; + private ValueRangeFilter() {} @InternalApi @@ -668,7 +722,9 @@ public ValueRangeFilter clone() { } } - public static final class OffsetFilter { + public static final class OffsetFilter implements Serializable { + private static final long serialVersionUID = 3228791236971884041L; + private OffsetFilter() {} /** @@ -681,7 +737,9 @@ public Filter cellsPerRow(int count) { } } - public static final class LimitFilter { + public static final class LimitFilter implements Serializable { + private static final long serialVersionUID = -794915549003008940L; + private LimitFilter() {} /** @@ -705,6 +763,7 @@ public Filter cellsPerColumn(int count) { } private static final class SimpleFilter implements Filter { + private static final long serialVersionUID = 3595911451325189833L; private final RowFilter proto; private SimpleFilter(@Nonnull RowFilter proto) { @@ -729,7 +788,7 @@ public SimpleFilter clone() { } @InternalExtensionOnly - public interface Filter extends Cloneable { + public interface Filter extends Cloneable, Serializable { @InternalApi RowFilter toProto(); } diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/FiltersTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/FiltersTest.java index ad7902525..e5fcd133f 100644 --- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/FiltersTest.java +++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/models/FiltersTest.java @@ -17,6 +17,8 @@ import static com.google.cloud.bigtable.data.v2.models.Filters.FILTERS; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; import com.google.bigtable.v2.ColumnRange; import com.google.bigtable.v2.RowFilter; @@ -26,6 +28,15 @@ import com.google.bigtable.v2.TimestampRange; import com.google.bigtable.v2.ValueRange; import com.google.protobuf.ByteString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -528,4 +539,213 @@ public void labelTest() { assertThat(actualFilter).isEqualTo(expectedFilter); } + + @Test + public void serializationTest() throws InvocationTargetException, IllegalAccessException { + // checks that the all objects returned by the all methods of the Filters class + // can be serialized/deserialized. + + for (Method m : Filters.class.getDeclaredMethods()) { + String name = m.getName(); + if (Modifier.isPublic(m.getModifiers())) { + switch (name) { + case "condition": + checkSerialization( + name, + FILTERS + .condition( + FILTERS + .chain() + .filter(FILTERS.qualifier().exactMatch("data_plan_10gb")) + .filter(FILTERS.value().exactMatch("true"))) + .then(FILTERS.label("passed-filter")) + .otherwise(FILTERS.label("filtered-out"))); + break; + case "label": + checkSerialization(name, FILTERS.label("label")); + break; + case "fromProto": + checkSerialization(name, FILTERS.label("label").toProto()); + break; + default: + checkSerialization(name, m.invoke(FILTERS)); + } + } + } + } + + private static void checkSerialization(String name, Object filter) { + try { + Object deserialized = serializeDeserialize(filter); + checkClassDeclaresSerialVersionUid(filter.getClass()); + if (filter instanceof Filters.Filter) { + checkFilters(name, (Filters.Filter) filter, (Filters.Filter) deserialized); + } else if (filter instanceof RowFilter) { + assertWithMessage("'" + name + "' deserialized filter differs") + .that(filter) + .isEqualTo(deserialized); + } else { + Class cls = filter.getClass(); + checkClassDoesNotContainNonStaticFields(cls, cls.getFields()); + checkClassDoesNotContainNonStaticFields(cls, cls.getDeclaredFields()); + checkSpawnedFilters(name, cls, filter, deserialized); + } + } catch (IOException | ClassNotFoundException e) { + fail(name + ": " + e); + } + } + + private static void checkFilters( + String name, Filters.Filter original, Filters.Filter deserialized) { + RowFilter protoBefore = ((Filters.Filter) original).toProto(); + RowFilter protoAfter = ((Filters.Filter) deserialized).toProto(); + assertWithMessage("'" + name + "' filter protoBuf mismatches after deserialization") + .that(protoBefore) + .isEqualTo(protoAfter); + } + + private static void checkSpawnedFilters( + String name, Class cls, Object original, Object deserialized) { + + int numberOfMethods = 0; + for (Method m : cls.getDeclaredMethods()) { + if (Modifier.isPublic(m.getModifiers())) { + numberOfMethods++; + } + } + ByteString re = ByteString.copyFromUtf8("some\\[0\\-9\\]regex"); + + switch (name) { + case "family": + { + Filters.FamilyFilter f1 = (Filters.FamilyFilter) original; + Filters.FamilyFilter f2 = (Filters.FamilyFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(2); + checkFilters(name + "/exactMatch", f1.exactMatch("abc"), f2.exactMatch("abc")); + checkFilters(name + "/regex", f1.regex("*"), f2.regex("*")); + + break; + } + case "qualifier": + { + Filters.QualifierFilter f1 = (Filters.QualifierFilter) original; + Filters.QualifierFilter f2 = (Filters.QualifierFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(5); + checkFilters(name + "/exactMatch", f1.exactMatch("abc"), f2.exactMatch("abc")); + checkFilters(name + "/exactMatch(ByteString)", f1.exactMatch(re), f2.exactMatch(re)); + checkFilters(name + "/regex", f1.regex("*"), f2.regex("*")); + checkFilters(name + "/regex(ByteString)", f1.regex(re), f2.regex(re)); + checkFilters( + name + "/rangeWithinFamily", + f1.rangeWithinFamily("family"), + f2.rangeWithinFamily("family")); + + break; + } + case "limit": + { + Filters.LimitFilter f1 = (Filters.LimitFilter) original; + Filters.LimitFilter f2 = (Filters.LimitFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(2); + checkFilters( + name + "/cellsPerColumn", f1.cellsPerColumn(100500), f2.cellsPerColumn(100500)); + checkFilters(name + "/cellsPerRow", f1.cellsPerRow(-10), f2.cellsPerRow(-10)); + + break; + } + case "value": + { + Filters.ValueFilter f1 = (Filters.ValueFilter) original; + Filters.ValueFilter f2 = (Filters.ValueFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(6); + checkFilters(name + "/exactMatch", f1.exactMatch("x"), f2.exactMatch("x")); + checkFilters(name + "/exactMatch(ByteString)", f1.exactMatch(re), f2.exactMatch(re)); + checkFilters(name + "/range", f1.range(), f2.range()); + checkFilters(name + "/regex", f1.regex("*"), f2.regex("*")); + checkFilters(name + "/regex(ByteString)", f1.regex(re), f2.regex(re)); + checkFilters(name + "/strip", f1.strip(), f2.strip()); + + break; + } + case "offset": + { + Filters.OffsetFilter f1 = (Filters.OffsetFilter) original; + Filters.OffsetFilter f2 = (Filters.OffsetFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(1); + checkFilters(name + "/cellsPerRow", f1.cellsPerRow(100500), f2.cellsPerRow(100500)); + + break; + } + case "key": + { + Filters.KeyFilter f1 = (Filters.KeyFilter) original; + Filters.KeyFilter f2 = (Filters.KeyFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(5); + checkFilters(name + "/exactMatch", f1.exactMatch("a"), f2.exactMatch("a")); + checkFilters(name + "/exactMatch(ByteString)", f1.exactMatch(re), f2.exactMatch(re)); + checkFilters(name + "/regex", f1.regex("a"), f2.regex("a")); + checkFilters(name + "/regex(ByteString)", f1.regex(re), f2.regex(re)); + checkFilters(name + "/sample", f1.sample(0.1), f2.sample(0.1)); + + break; + } + case "timestamp": + { + Filters.TimestampFilter f1 = (Filters.TimestampFilter) original; + Filters.TimestampFilter f2 = (Filters.TimestampFilter) deserialized; + + assertThat(numberOfMethods).isEqualTo(2); + checkFilters(name + "/exact", f1.exact(100500L), f2.exact(100500L)); + checkFilters(name + "/range", f1.range(), f2.range()); + + break; + } + default: + fail("Untested filter: " + name); + } + } + + private static void checkClassDeclaresSerialVersionUid(Class cls) { + String uid = "serialVersionUID"; + for (Field field : cls.getDeclaredFields()) { + if (field.getName() == uid) { + int modifiers = field.getModifiers(); + assertWithMessage(field + " is not static").that(Modifier.isStatic(modifiers)).isTrue(); + assertWithMessage(field + " is not final").that(Modifier.isFinal(modifiers)).isTrue(); + assertWithMessage(field + " is not private").that(Modifier.isPrivate(modifiers)).isTrue(); + assertWithMessage(field + " must be long") + .that(field.getType().getSimpleName()) + .isEqualTo("long"); + return; + } + } + fail(cls + " does not declare serialVersionUID"); + } + + private static void checkClassDoesNotContainNonStaticFields(Class cls, Field[] fields) { + for (Field field : fields) { + assertWithMessage(cls + " has a non-static field '" + field + "'") + .that(Modifier.isStatic(field.getModifiers())) + .isTrue(); + } + } + + private static Object serializeDeserialize(Object obj) + throws IOException, ClassNotFoundException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (ObjectOutputStream outStream = new ObjectOutputStream(bos)) { + outStream.writeObject(obj); + } + + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + try (ObjectInputStream inStream = new ObjectInputStream(bis)) { + return inStream.readObject(); + } + } }