diff --git a/clirr-ignored-differences.xml b/clirr-ignored-differences.xml
index 1aa41e4f..6aa9dcf9 100644
--- a/clirr-ignored-differences.xml
+++ b/clirr-ignored-differences.xml
@@ -12,4 +12,29 @@
    <method>*</method>
    <to>*</to>
  </difference>
-</differences>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/LimitingHeadOffsetReader</className>
+  </difference>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/MultiPartitionCommitter*</className>
+  </difference>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/PartitionSubscriberFactory</className>
+  </difference>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/PerTopicHeadOffsetReader</className>
+  </difference>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/PslCredentialsProvider</className>
+  </difference>
+  <difference>
+    <differenceType>8001</differenceType>
+    <className>com/google/cloud/pubsublite/spark/PslDataSourceOptions*</className>
+  </difference>
+</differences>
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index c377b164..74123ab4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -113,6 +113,11 @@
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.scala-lang.modules</groupId>
+      <artifactId>scala-java8-compat_2.11</artifactId>
+      <version>0.9.1</version>
+    </dependency>
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/Constants.java b/src/main/java/com/google/cloud/pubsublite/spark/Constants.java
index cac4337a..9ad29b23 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/Constants.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/Constants.java
@@ -17,7 +17,12 @@
package com.google.cloud.pubsublite.spark;
import com.google.cloud.pubsublite.internal.wire.PubsubContext;
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -26,22 +31,33 @@ public class Constants {
public static long DEFAULT_BYTES_OUTSTANDING = 50_000_000;
public static long DEFAULT_MESSAGES_OUTSTANDING = Long.MAX_VALUE;
public static long DEFAULT_MAX_MESSAGES_PER_BATCH = Long.MAX_VALUE;
+
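+ // Spark-side types of the user-writable publish fields; DEFAULT_SCHEMA below
+ // reuses these entries so the read and write schemas stay consistent.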
+ public static ArrayType ATTRIBUTES_PER_KEY_DATATYPE =
+ DataTypes.createArrayType(DataTypes.BinaryType);
+ public static MapType ATTRIBUTES_DATATYPE =
+ DataTypes.createMapType(DataTypes.StringType, ATTRIBUTES_PER_KEY_DATATYPE);
+ public static Map<String, DataType> PUBLISH_FIELD_TYPES =
+ ImmutableMap.of(
+ "key", DataTypes.BinaryType,
+ "data", DataTypes.BinaryType,
+ "attributes", ATTRIBUTES_DATATYPE,
+ "event_timestamp", DataTypes.TimestampType);
public static StructType DEFAULT_SCHEMA =
new StructType(
new StructField[] {
new StructField("subscription", DataTypes.StringType, false, Metadata.empty()),
new StructField("partition", DataTypes.LongType, false, Metadata.empty()),
new StructField("offset", DataTypes.LongType, false, Metadata.empty()),
- new StructField("key", DataTypes.BinaryType, false, Metadata.empty()),
- new StructField("data", DataTypes.BinaryType, false, Metadata.empty()),
+ new StructField("key", PUBLISH_FIELD_TYPES.get("key"), false, Metadata.empty()),
+ new StructField("data", PUBLISH_FIELD_TYPES.get("data"), false, Metadata.empty()),
new StructField("publish_timestamp", DataTypes.TimestampType, false, Metadata.empty()),
- new StructField("event_timestamp", DataTypes.TimestampType, true, Metadata.empty()),
new StructField(
- "attributes",
- DataTypes.createMapType(
- DataTypes.StringType, DataTypes.createArrayType(DataTypes.BinaryType)),
+ "event_timestamp",
+ PUBLISH_FIELD_TYPES.get("event_timestamp"),
true,
- Metadata.empty())
+ Metadata.empty()),
+ new StructField(
+ "attributes", PUBLISH_FIELD_TYPES.get("attributes"), true, Metadata.empty())
});
public static final PubsubContext.Framework FRAMEWORK = PubsubContext.Framework.of("SPARK");
@@ -52,6 +68,7 @@ public class Constants {
"pubsublite.flowcontrol.byteoutstandingperpartition";
public static String MESSAGES_OUTSTANDING_CONFIG_KEY =
"pubsublite.flowcontrol.messageoutstandingperparition";
+ public static String TOPIC_CONFIG_KEY = "pubsublite.topic";
public static String SUBSCRIPTION_CONFIG_KEY = "pubsublite.subscription";
public static String CREDENTIALS_KEY_CONFIG_KEY = "gcp.credentials.key";
}
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslContinuousReader.java b/src/main/java/com/google/cloud/pubsublite/spark/PslContinuousReader.java
index 65953031..ad2ca3da 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/PslContinuousReader.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslContinuousReader.java
@@ -22,6 +22,9 @@
import com.google.cloud.pubsublite.cloudpubsub.FlowControlSettings;
import com.google.cloud.pubsublite.internal.CursorClient;
import com.google.cloud.pubsublite.internal.wire.SubscriberFactory;
+import com.google.cloud.pubsublite.spark.internal.MultiPartitionCommitter;
+import com.google.cloud.pubsublite.spark.internal.PartitionCountReader;
+import com.google.cloud.pubsublite.spark.internal.PartitionSubscriberFactory;
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Arrays;
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslDataSource.java b/src/main/java/com/google/cloud/pubsublite/spark/PslDataSource.java
index 08a96ee8..2ef2535d 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/PslDataSource.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslDataSource.java
@@ -23,6 +23,9 @@
import com.google.cloud.pubsublite.AdminClient;
import com.google.cloud.pubsublite.SubscriptionPath;
import com.google.cloud.pubsublite.TopicPath;
+import com.google.cloud.pubsublite.spark.internal.CachedPartitionCountReader;
+import com.google.cloud.pubsublite.spark.internal.LimitingHeadOffsetReader;
+import com.google.cloud.pubsublite.spark.internal.PartitionCountReader;
import java.util.Objects;
import java.util.Optional;
import org.apache.spark.sql.sources.DataSourceRegister;
@@ -30,13 +33,20 @@
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.MicroBatchReadSupport;
+import org.apache.spark.sql.sources.v2.StreamWriteSupport;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader;
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter;
+import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.types.StructType;
@AutoService(DataSourceRegister.class)
public final class PslDataSource
- implements DataSourceV2, ContinuousReadSupport, MicroBatchReadSupport, DataSourceRegister {
+ implements DataSourceV2,
+ ContinuousReadSupport,
+ MicroBatchReadSupport,
+ StreamWriteSupport,
+ DataSourceRegister {
@Override
public String shortName() {
@@ -51,23 +61,24 @@ public ContinuousReader createContinuousReader(
"PubSub Lite uses fixed schema and custom schema is not allowed");
}
- PslDataSourceOptions pslDataSourceOptions =
- PslDataSourceOptions.fromSparkDataSourceOptions(options);
- SubscriptionPath subscriptionPath = pslDataSourceOptions.subscriptionPath();
+ PslReadDataSourceOptions pslReadDataSourceOptions =
+ PslReadDataSourceOptions.fromSparkDataSourceOptions(options);
+ SubscriptionPath subscriptionPath = pslReadDataSourceOptions.subscriptionPath();
TopicPath topicPath;
- try (AdminClient adminClient = pslDataSourceOptions.newAdminClient()) {
+ try (AdminClient adminClient = pslReadDataSourceOptions.newAdminClient()) {
topicPath = TopicPath.parse(adminClient.getSubscription(subscriptionPath).get().getTopic());
} catch (Throwable t) {
throw toCanonical(t).underlying;
}
PartitionCountReader partitionCountReader =
- new CachedPartitionCountReader(pslDataSourceOptions.newAdminClient(), topicPath);
+ new CachedPartitionCountReader(pslReadDataSourceOptions.newAdminClient(), topicPath);
return new PslContinuousReader(
- pslDataSourceOptions.newCursorClient(),
- pslDataSourceOptions.newMultiPartitionCommitter(partitionCountReader.getPartitionCount()),
- pslDataSourceOptions.getSubscriberFactory(),
+ pslReadDataSourceOptions.newCursorClient(),
+ pslReadDataSourceOptions.newMultiPartitionCommitter(
+ partitionCountReader.getPartitionCount()),
+ pslReadDataSourceOptions.getSubscriberFactory(),
subscriptionPath,
- Objects.requireNonNull(pslDataSourceOptions.flowControlSettings()),
+ Objects.requireNonNull(pslReadDataSourceOptions.flowControlSettings()),
partitionCountReader);
}
@@ -79,28 +90,38 @@ public MicroBatchReader createMicroBatchReader(
"PubSub Lite uses fixed schema and custom schema is not allowed");
}
- PslDataSourceOptions pslDataSourceOptions =
- PslDataSourceOptions.fromSparkDataSourceOptions(options);
- SubscriptionPath subscriptionPath = pslDataSourceOptions.subscriptionPath();
+ PslReadDataSourceOptions pslReadDataSourceOptions =
+ PslReadDataSourceOptions.fromSparkDataSourceOptions(options);
+ SubscriptionPath subscriptionPath = pslReadDataSourceOptions.subscriptionPath();
TopicPath topicPath;
- try (AdminClient adminClient = pslDataSourceOptions.newAdminClient()) {
+ try (AdminClient adminClient = pslReadDataSourceOptions.newAdminClient()) {
topicPath = TopicPath.parse(adminClient.getSubscription(subscriptionPath).get().getTopic());
} catch (Throwable t) {
throw toCanonical(t).underlying;
}
PartitionCountReader partitionCountReader =
- new CachedPartitionCountReader(pslDataSourceOptions.newAdminClient(), topicPath);
+ new CachedPartitionCountReader(pslReadDataSourceOptions.newAdminClient(), topicPath);
return new PslMicroBatchReader(
- pslDataSourceOptions.newCursorClient(),
- pslDataSourceOptions.newMultiPartitionCommitter(partitionCountReader.getPartitionCount()),
- pslDataSourceOptions.getSubscriberFactory(),
+ pslReadDataSourceOptions.newCursorClient(),
+ pslReadDataSourceOptions.newMultiPartitionCommitter(
+ partitionCountReader.getPartitionCount()),
+ pslReadDataSourceOptions.getSubscriberFactory(),
new LimitingHeadOffsetReader(
- pslDataSourceOptions.newTopicStatsClient(),
+ pslReadDataSourceOptions.newTopicStatsClient(),
topicPath,
partitionCountReader,
Ticker.systemTicker()),
subscriptionPath,
- Objects.requireNonNull(pslDataSourceOptions.flowControlSettings()),
- pslDataSourceOptions.maxMessagesPerBatch());
+ Objects.requireNonNull(pslReadDataSourceOptions.flowControlSettings()),
+ pslReadDataSourceOptions.maxMessagesPerBatch());
+ }
+
+ @Override
+ public StreamWriter createStreamWriter(
+ String queryId, StructType schema, OutputMode mode, DataSourceOptions options) {
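+ // Validate the incoming schema before constructing the writer.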
+ PslSparkUtils.verifyWriteInputSchema(schema);
+ PslWriteDataSourceOptions pslWriteDataSourceOptions =
+ PslWriteDataSourceOptions.fromSparkDataSourceOptions(options);
+ return new PslStreamWriter(schema, pslWriteDataSourceOptions);
}
}
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriter.java b/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriter.java
new file mode 100644
index 00000000..631fb2d3
--- /dev/null
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriter.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.pubsublite.spark;
+
+import com.google.api.core.ApiFuture;
+import com.google.api.core.ApiService;
+import com.google.cloud.pubsublite.MessageMetadata;
+import com.google.cloud.pubsublite.internal.Publisher;
+import com.google.cloud.pubsublite.spark.internal.PublisherFactory;
+import com.google.common.flogger.GoogleLogger;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.ExecutionException;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.sources.v2.writer.DataWriter;
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
+import org.apache.spark.sql.types.StructType;
+
+public class PslDataWriter implements DataWriter<InternalRow> {
+
+ private static final GoogleLogger log = GoogleLogger.forEnclosingClass();
+
+ private final long partitionId, taskId, epochId;
+ private final StructType inputSchema;
+ private final PublisherFactory publisherFactory;
+
+ @GuardedBy("this")
+ private Optional<Publisher<MessageMetadata>> publisher = Optional.empty();
+
+ @GuardedBy("this")
+ private final List<ApiFuture<MessageMetadata>> futures = new ArrayList<>();
+
+ public PslDataWriter(
+ long partitionId,
+ long taskId,
+ long epochId,
+ StructType schema,
+ PublisherFactory publisherFactory) {
+ this.partitionId = partitionId;
+ this.taskId = taskId;
+ this.epochId = epochId;
+ this.inputSchema = schema;
+ this.publisherFactory = publisherFactory;
+ }
+
+ @Override
+ public synchronized void write(InternalRow record) {
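+ // A publisher that has failed leaves the RUNNING state, so a fresh one is
+ // created lazily on the next write.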
+ if (!publisher.isPresent() || publisher.get().state() != ApiService.State.RUNNING) {
+ publisher = Optional.of(publisherFactory.newPublisher());
+ }
+ futures.add(
+ publisher
+ .get()
+ .publish(Objects.requireNonNull(PslSparkUtils.toPubSubMessage(inputSchema, record))));
+ }
+
+ @Override
+ public synchronized WriterCommitMessage commit() throws IOException {
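+ // Block until every outstanding publish completes; any failure drops the
+ // publisher and aborts the commit.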
+ for (ApiFuture<MessageMetadata> f : futures) {
+ try {
+ f.get();
+ } catch (InterruptedException | ExecutionException e) {
+ publisher = Optional.empty();
+ throw new IOException(e);
+ }
+ }
+ log.atInfo().log(
+ "All writes for partitionId:%d, taskId:%d, epochId:%d succeeded, committing...",
+ partitionId, taskId, epochId);
+ return PslWriterCommitMessage.create(futures.size());
+ }
+
+ @Override
+ public synchronized void abort() {
+ log.atWarning().log(
+ "One or more writes for partitionId:%d, taskId:%d, epochId:%d failed, aborted.",
+ partitionId, taskId, epochId);
+ }
+}
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriterFactory.java b/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriterFactory.java
new file mode 100644
index 00000000..12d95921
--- /dev/null
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslDataWriterFactory.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.pubsublite.spark;
+
+import com.google.cloud.pubsublite.spark.internal.CachedPublishers;
+import com.google.cloud.pubsublite.spark.internal.PublisherFactory;
+import java.io.Serializable;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.sources.v2.writer.DataWriter;
+import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
+import org.apache.spark.sql.types.StructType;
+
+public class PslDataWriterFactory implements Serializable, DataWriterFactory<InternalRow> {
+ private static final long serialVersionUID = -6904546364310978844L;
+
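+ // Static so that writer tasks in the same executor JVM reuse publishers
+ // rather than opening one connection per task.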
+ private static final CachedPublishers CACHED_PUBLISHERS = new CachedPublishers();
+
+ private final StructType inputSchema;
+ private final PslWriteDataSourceOptions writeOptions;
+
+ public PslDataWriterFactory(StructType inputSchema, PslWriteDataSourceOptions writeOptions) {
+ this.inputSchema = inputSchema;
+ this.writeOptions = writeOptions;
+ }
+
+ @Override
+ public DataWriter<InternalRow> createDataWriter(int partitionId, long taskId, long epochId) {
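+ // Publishers are looked up lazily on the executor; only the lightweight
+ // options object is serialized with this factory.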
+ PublisherFactory pf = () -> CACHED_PUBLISHERS.getOrCreate(writeOptions);
+ return new PslDataWriter(partitionId, taskId, epochId, inputSchema, pf);
+ }
+}
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslMicroBatchReader.java b/src/main/java/com/google/cloud/pubsublite/spark/PslMicroBatchReader.java
index b2a346c0..a0f0dfee 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/PslMicroBatchReader.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslMicroBatchReader.java
@@ -24,6 +24,9 @@
import com.google.cloud.pubsublite.cloudpubsub.FlowControlSettings;
import com.google.cloud.pubsublite.internal.CursorClient;
import com.google.cloud.pubsublite.internal.wire.SubscriberFactory;
+import com.google.cloud.pubsublite.spark.internal.MultiPartitionCommitter;
+import com.google.cloud.pubsublite.spark.internal.PartitionSubscriberFactory;
+import com.google.cloud.pubsublite.spark.internal.PerTopicHeadOffsetReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslDataSourceOptions.java b/src/main/java/com/google/cloud/pubsublite/spark/PslReadDataSourceOptions.java
similarity index 92%
rename from src/main/java/com/google/cloud/pubsublite/spark/PslDataSourceOptions.java
rename to src/main/java/com/google/cloud/pubsublite/spark/PslReadDataSourceOptions.java
index 380e022a..f5987788 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/PslDataSourceOptions.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslReadDataSourceOptions.java
@@ -33,6 +33,10 @@
import com.google.cloud.pubsublite.internal.wire.RoutingMetadata;
import com.google.cloud.pubsublite.internal.wire.ServiceClients;
import com.google.cloud.pubsublite.internal.wire.SubscriberBuilder;
+import com.google.cloud.pubsublite.spark.internal.MultiPartitionCommitter;
+import com.google.cloud.pubsublite.spark.internal.MultiPartitionCommitterImpl;
+import com.google.cloud.pubsublite.spark.internal.PartitionSubscriberFactory;
+import com.google.cloud.pubsublite.spark.internal.PslCredentialsProvider;
import com.google.cloud.pubsublite.v1.AdminServiceClient;
import com.google.cloud.pubsublite.v1.AdminServiceSettings;
import com.google.cloud.pubsublite.v1.CursorServiceClient;
@@ -47,7 +51,7 @@
import org.apache.spark.sql.sources.v2.DataSourceOptions;
@AutoValue
-public abstract class PslDataSourceOptions implements Serializable {
+public abstract class PslReadDataSourceOptions implements Serializable {
private static final long serialVersionUID = 2680059304693561607L;
@Nullable
@@ -60,7 +64,7 @@ public abstract class PslDataSourceOptions implements Serializable {
public abstract long maxMessagesPerBatch();
public static Builder builder() {
- return new AutoValue_PslDataSourceOptions.Builder()
+ return new AutoValue_PslReadDataSourceOptions.Builder()
.setCredentialsKey(null)
.setMaxMessagesPerBatch(Constants.DEFAULT_MAX_MESSAGES_PER_BATCH)
.setFlowControlSettings(
@@ -70,7 +74,7 @@ public static Builder builder() {
.build());
}
- public static PslDataSourceOptions fromSparkDataSourceOptions(DataSourceOptions options) {
+ public static PslReadDataSourceOptions fromSparkDataSourceOptions(DataSourceOptions options) {
if (!options.get(Constants.SUBSCRIPTION_CONFIG_KEY).isPresent()) {
throw new IllegalArgumentException(Constants.SUBSCRIPTION_CONFIG_KEY + " is required.");
}
@@ -115,7 +119,7 @@ public abstract static class Builder {
public abstract Builder setFlowControlSettings(FlowControlSettings flowControlSettings);
- public abstract PslDataSourceOptions build();
+ public abstract PslReadDataSourceOptions build();
}
MultiPartitionCommitter newMultiPartitionCommitter(long topicPartitionCount) {
@@ -135,7 +139,7 @@ PartitionSubscriberFactory getSubscriberFactory() {
PubsubContext context = PubsubContext.of(Constants.FRAMEWORK);
SubscriberServiceSettings.Builder settingsBuilder =
SubscriberServiceSettings.newBuilder()
- .setCredentialsProvider(new PslCredentialsProvider(this));
+ .setCredentialsProvider(new PslCredentialsProvider(credentialsKey()));
ServiceClients.addDefaultMetadata(
context, RoutingMetadata.of(this.subscriptionPath(), partition), settingsBuilder);
try {
@@ -161,7 +165,7 @@ private CursorServiceClient newCursorServiceClient() {
addDefaultSettings(
this.subscriptionPath().location().region(),
CursorServiceSettings.newBuilder()
- .setCredentialsProvider(new PslCredentialsProvider(this))));
+ .setCredentialsProvider(new PslCredentialsProvider(credentialsKey()))));
} catch (IOException e) {
throw new IllegalStateException("Unable to create CursorServiceClient.");
}
@@ -181,7 +185,7 @@ private AdminServiceClient newAdminServiceClient() {
addDefaultSettings(
this.subscriptionPath().location().region(),
AdminServiceSettings.newBuilder()
- .setCredentialsProvider(new PslCredentialsProvider(this))));
+ .setCredentialsProvider(new PslCredentialsProvider(credentialsKey()))));
} catch (IOException e) {
throw new IllegalStateException("Unable to create AdminServiceClient.");
}
@@ -201,7 +205,7 @@ private TopicStatsServiceClient newTopicStatsServiceClient() {
addDefaultSettings(
this.subscriptionPath().location().region(),
TopicStatsServiceSettings.newBuilder()
- .setCredentialsProvider(new PslCredentialsProvider(this))));
+ .setCredentialsProvider(new PslCredentialsProvider(credentialsKey()))));
} catch (IOException e) {
throw new IllegalStateException("Unable to create TopicStatsServiceClient.");
}
diff --git a/src/main/java/com/google/cloud/pubsublite/spark/PslSparkUtils.java b/src/main/java/com/google/cloud/pubsublite/spark/PslSparkUtils.java
index 1d54fe19..2510315a 100644
--- a/src/main/java/com/google/cloud/pubsublite/spark/PslSparkUtils.java
+++ b/src/main/java/com/google/cloud/pubsublite/spark/PslSparkUtils.java
@@ -19,12 +19,16 @@
import static com.google.common.base.Preconditions.checkArgument;
import static scala.collection.JavaConverters.asScalaBufferConverter;
+import com.google.cloud.pubsublite.Message;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.Partition;
import com.google.cloud.pubsublite.SequencedMessage;
import com.google.cloud.pubsublite.SubscriptionPath;
import com.google.cloud.pubsublite.internal.CursorClient;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
+import com.google.common.flogger.GoogleLogger;
import com.google.common.math.LongMath;
import com.google.protobuf.ByteString;
import com.google.protobuf.util.Timestamps;
@@ -34,15 +38,29 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.ByteArray;
import org.apache.spark.unsafe.types.UTF8String;
+import scala.Option;
+import scala.compat.java8.functionConverterImpls.FromJavaBiConsumer;
public class PslSparkUtils {
- private static ArrayBasedMapData convertAttributesToSparkMap(
+
+ private static final GoogleLogger log = GoogleLogger.forEnclosingClass();
+
+ @VisibleForTesting
+ public static ArrayBasedMapData convertAttributesToSparkMap(
ListMultimap<String, ByteString> attributeMap) {
List<UTF8String> keyList = new ArrayList<>();
@@ -83,6 +101,97 @@ public static InternalRow toInternalRow(
return InternalRow.apply(asScalaBufferConverter(list).asScala());
}
+ @SuppressWarnings("unchecked")
+ private static <T> void extractVal(
+ StructType inputSchema,
+ InternalRow row,
+ String fieldName,
+ DataType expectedDataType,
+ Consumer<T> consumer) {
+ Option