diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriter.java index f9a117fccb..4338f5f598 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriter.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriter.java @@ -33,11 +33,13 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; +import org.json.JSONArray; /** * Writer that can help user to write data to BigQuery. This is a simplified version of the Write * API. For users writing with COMMITTED stream and don't care about row deduplication, it is - * recommended to use this Writer. + * recommended to use this Writer. The DirectWriter can be used to write both JSON and protobuf + * data. * *
{@code
  * DataProto data;
@@ -50,7 +52,9 @@
 public class DirectWriter {
   private static final Logger LOG = Logger.getLogger(DirectWriter.class.getName());
   private static WriterCache cache = null;
+  private static JsonWriterCache jsonCache = null;
   private static Lock cacheLock = new ReentrantLock();
+  private static Lock jsonCacheLock = new ReentrantLock();
 
   /**
    * Append rows to the given table.
@@ -103,10 +107,53 @@ public Long apply(Storage.AppendRowsResponse appendRowsResponse) {
         MoreExecutors.directExecutor());
   }
 
+  /**
+   * Append rows to the given table.
+   *
+   * @param tableName table name in the form of "projects/{pName}/datasets/{dName}/tables/{tName}"
+   * @param json A JSONArray
+   * @return A future that contains the offset at which the append happened. Only when the future
+   *     returns with valid offset, then the append actually happened.
+   * @throws IOException, InterruptedException, InvalidArgumentException,
+   *     Descriptors.DescriptorValidationException
+   */
+  public static ApiFuture append(String tableName, JSONArray json)
+      throws IOException, InterruptedException, InvalidArgumentException,
+          Descriptors.DescriptorValidationException {
+    Preconditions.checkNotNull(tableName, "TableName is null.");
+    Preconditions.checkNotNull(json, "JSONArray is null.");
+
+    if (json.length() == 0) {
+      throw new InvalidArgumentException(
+          new Exception("Empty JSONArrays are not allowed"),
+          GrpcStatusCode.of(Status.Code.INVALID_ARGUMENT),
+          false);
+    }
+    try {
+      jsonCacheLock.lock();
+      if (jsonCache == null) {
+        jsonCache = JsonWriterCache.getInstance();
+      }
+    } finally {
+      jsonCacheLock.unlock();
+    }
+    JsonStreamWriter writer = jsonCache.getTableWriter(tableName);
+    return ApiFutures.transform(
+        writer.append(json, /* offset = */ -1, /*allowUnknownFields = */ false),
+        new ApiFunction() {
+          @Override
+          public Long apply(Storage.AppendRowsResponse appendRowsResponse) {
+            return Long.valueOf(appendRowsResponse.getOffset());
+          }
+        },
+        MoreExecutors.directExecutor());
+  }
+
   @VisibleForTesting
   public static void testSetStub(
       BigQueryWriteClient stub, int maxTableEntry, SchemaCompatibility schemaCheck) {
     cache = WriterCache.getTestInstance(stub, maxTableEntry, schemaCheck);
+    jsonCache = JsonWriterCache.getTestInstance(stub, maxTableEntry);
   }
 
   /** Clears the underlying cache and all the transport connections. */
diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonStreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonStreamWriter.java
index ed8ee0f9fe..f0c63dd583 100644
--- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonStreamWriter.java
+++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonStreamWriter.java
@@ -260,6 +260,11 @@ public void close() {
     this.streamWriter.close();
   }
 
+  /** Returns if a stream has expired. */
+  public Boolean expired() {
+    return this.streamWriter.expired();
+  }
+
   private class JsonStreamWriterOnSchemaUpdateRunnable extends OnSchemaUpdateRunnable {
     private JsonStreamWriter jsonStreamWriter;
     /**
diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCache.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCache.java
new file mode 100644
index 0000000000..d9d22ac75a
--- /dev/null
+++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCache.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.bigquery.storage.v1alpha2;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.protobuf.Descriptors;
+import java.io.IOException;
+import java.util.concurrent.ConcurrentMap;
+import java.util.logging.Logger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * A cache of JsonStreamWriters that can be looked up by Table Name. The entries will expire after 5
+ * minutes if not used. Code sample: JsonWriterCache cache = JsonWriterCache.getInstance();
+ * JsonStreamWriter writer = cache.getWriter(); // Use... cache.returnWriter(writer);
+ */
+public class JsonWriterCache {
+  private static final Logger LOG = Logger.getLogger(JsonWriterCache.class.getName());
+
+  private static String tablePatternString = "(projects/[^/]+/datasets/[^/]+/tables/[^/]+)";
+  private static Pattern tablePattern = Pattern.compile(tablePatternString);
+
+  private static JsonWriterCache instance;
+  private Cache jsonWriterCache;
+
+  // Maximum number of tables to hold in the cache, once the maxium exceeded, the cache will be
+  // evicted based on least recent used.
+  private static final int MAX_TABLE_ENTRY = 100;
+  private static final int MAX_WRITERS_PER_TABLE = 1;
+
+  private final BigQueryWriteClient stub;
+
+  private JsonWriterCache(BigQueryWriteClient stub, int maxTableEntry) {
+    this.stub = stub;
+    jsonWriterCache =
+        CacheBuilder.newBuilder().maximumSize(maxTableEntry).build();
+  }
+
+  public static JsonWriterCache getInstance() throws IOException {
+    if (instance == null) {
+      BigQueryWriteSettings stubSettings = BigQueryWriteSettings.newBuilder().build();
+      BigQueryWriteClient stub = BigQueryWriteClient.create(stubSettings);
+      instance = new JsonWriterCache(stub, MAX_TABLE_ENTRY);
+    }
+    return instance;
+  }
+
+  /** Returns a cache with custom stub used by test. */
+  @VisibleForTesting
+  public static JsonWriterCache getTestInstance(BigQueryWriteClient stub, int maxTableEntry) {
+    Preconditions.checkNotNull(stub, "Stub is null.");
+    return new JsonWriterCache(stub, maxTableEntry);
+  }
+
+  private Stream.WriteStream CreateNewWriteStream(String tableName) {
+    Stream.WriteStream stream =
+        Stream.WriteStream.newBuilder().setType(Stream.WriteStream.Type.COMMITTED).build();
+    stream =
+        stub.createWriteStream(
+            Storage.CreateWriteStreamRequest.newBuilder()
+                .setParent(tableName)
+                .setWriteStream(stream)
+                .build());
+    LOG.info("Created write stream:" + stream.getName());
+    return stream;
+  }
+
+  JsonStreamWriter CreateNewWriter(Stream.WriteStream writeStream)
+      throws IllegalArgumentException, IOException, InterruptedException,
+          Descriptors.DescriptorValidationException {
+    return JsonStreamWriter.newBuilder(writeStream.getName(), writeStream.getTableSchema())
+        .setChannelProvider(stub.getSettings().getTransportChannelProvider())
+        .setCredentialsProvider(stub.getSettings().getCredentialsProvider())
+        .setExecutorProvider(stub.getSettings().getExecutorProvider())
+        .build();
+  }
+  /**
+   * Gets a writer for a given table with the given tableName
+   *
+   * @param tableName
+   * @return
+   * @throws Exception
+   */
+  public JsonStreamWriter getTableWriter(String tableName)
+      throws IllegalArgumentException, IOException, InterruptedException,
+          Descriptors.DescriptorValidationException {
+    Preconditions.checkNotNull(tableName, "TableName is null.");
+    Matcher matcher = tablePattern.matcher(tableName);
+    if (!matcher.matches()) {
+      throw new IllegalArgumentException("Invalid table name: " + tableName);
+    }
+
+    Stream.WriteStream writeStream = null;
+    JsonStreamWriter writer = null;
+
+    synchronized (this) {
+      writer = jsonWriterCache.getIfPresent(tableName);
+      if (writer != null) {
+        if (!writer.expired()) {
+          return writer;
+        } else {
+          writer.close();
+        }
+      }
+      writeStream = CreateNewWriteStream(tableName);
+      writer = CreateNewWriter(writeStream);
+      jsonWriterCache.put(tableName, writer);
+    }
+    return writer;
+  }
+
+  /** Clear the cache and close all the writers in the cache. */
+  public void clear() {
+    synchronized (this) {
+      ConcurrentMap map = jsonWriterCache.asMap();
+      for (String key : map.keySet()) {
+        JsonStreamWriter entry = jsonWriterCache.getIfPresent(key);
+        entry.close();
+      }
+      jsonWriterCache.cleanUp();
+    }
+  }
+
+  @VisibleForTesting
+  public long cachedTableCount() {
+    synchronized (jsonWriterCache) {
+      return jsonWriterCache.size();
+    }
+  }
+}
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriterTest.java
index f57ac92339..1e358f26ed 100644
--- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriterTest.java
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/DirectWriterTest.java
@@ -23,17 +23,29 @@
 import com.google.api.gax.grpc.testing.LocalChannelProvider;
 import com.google.api.gax.grpc.testing.MockGrpcService;
 import com.google.api.gax.grpc.testing.MockServiceHelper;
-import com.google.cloud.bigquery.storage.test.Test.*;
+import com.google.cloud.bigquery.storage.test.Test.AllSupportedTypes;
+import com.google.cloud.bigquery.storage.test.Test.FooType;
+import com.google.cloud.bigquery.storage.v1alpha2.Storage.AppendRowsRequest;
 import com.google.common.collect.Sets;
 import com.google.protobuf.AbstractMessage;
 import com.google.protobuf.Timestamp;
 import java.io.IOException;
-import java.util.*;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.logging.Logger;
-import org.junit.*;
+import org.json.JSONArray;
+import org.json.JSONObject;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.mockito.Mock;
@@ -53,6 +65,15 @@ public class DirectWriterTest {
   private BigQueryWriteClient client;
   private LocalChannelProvider channelProvider;
 
+  private final Table.TableFieldSchema FOO =
+      Table.TableFieldSchema.newBuilder()
+          .setType(Table.TableFieldSchema.Type.STRING)
+          .setMode(Table.TableFieldSchema.Mode.NULLABLE)
+          .setName("foo")
+          .build();
+  private final Table.TableSchema TABLE_SCHEMA =
+      Table.TableSchema.newBuilder().addFields(0, FOO).build();
+
   @Mock private static SchemaCompatibility schemaCheck;
 
   @BeforeClass
@@ -113,8 +134,83 @@ void WriterCreationResponseMock(String testStreamName, Set responseOffsets
     }
   }
 
+  /** Response mocks for create a new writer */
+  void JsonWriterCreationResponseMock(String testStreamName, Set responseOffsets) {
+    // Response from CreateWriteStream
+    Stream.WriteStream expectedResponse =
+        Stream.WriteStream.newBuilder()
+            .setName(testStreamName)
+            .setTableSchema(TABLE_SCHEMA)
+            .build();
+    mockBigQueryWrite.addResponse(expectedResponse);
+
+    // Response from GetWriteStream
+    Instant time = Instant.now();
+    Timestamp timestamp =
+        Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build();
+    Stream.WriteStream expectedResponse2 =
+        Stream.WriteStream.newBuilder()
+            .setName(testStreamName)
+            .setType(Stream.WriteStream.Type.COMMITTED)
+            .setCreateTime(timestamp)
+            .build();
+    mockBigQueryWrite.addResponse(expectedResponse2);
+
+    for (Long offset : responseOffsets) {
+      Storage.AppendRowsResponse response =
+          Storage.AppendRowsResponse.newBuilder().setOffset(offset).build();
+      mockBigQueryWrite.addResponse(response);
+    }
+  }
+
   @Test
-  public void testWriteSuccess() throws Exception {
+  public void testJsonWriteSuccess() throws Exception {
+    DirectWriter.testSetStub(client, 10, schemaCheck);
+    FooType m1 = FooType.newBuilder().setFoo("m1").build();
+    FooType m2 = FooType.newBuilder().setFoo("m2").build();
+    JSONObject m1_json = new JSONObject();
+    m1_json.put("foo", "m1");
+    JSONObject m2_json = new JSONObject();
+    m2_json.put("foo", "m2");
+    JSONArray jsonArr = new JSONArray();
+    jsonArr.put(m1_json);
+    jsonArr.put(m2_json);
+
+    JSONArray jsonArr2 = new JSONArray();
+    jsonArr2.put(m1_json);
+
+    JsonWriterCreationResponseMock(TEST_STREAM, Sets.newHashSet(Long.valueOf(0L)));
+    ApiFuture ret = DirectWriter.append(TEST_TABLE, jsonArr);
+    assertEquals(Long.valueOf(0L), ret.get());
+    List actualRequests = mockBigQueryWrite.getRequests();
+    Assert.assertEquals(3, actualRequests.size());
+    assertEquals(
+        TEST_TABLE, ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getParent());
+    assertEquals(
+        Stream.WriteStream.Type.COMMITTED,
+        ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getWriteStream().getType());
+    assertEquals(TEST_STREAM, ((Storage.GetWriteStreamRequest) actualRequests.get(1)).getName());
+    assertEquals(
+        m1.toByteString(),
+        ((AppendRowsRequest) actualRequests.get(2)).getProtoRows().getRows().getSerializedRows(0));
+    assertEquals(
+        m2.toByteString(),
+        ((AppendRowsRequest) actualRequests.get(2)).getProtoRows().getRows().getSerializedRows(1));
+
+    Storage.AppendRowsResponse response =
+        Storage.AppendRowsResponse.newBuilder().setOffset(2).build();
+    mockBigQueryWrite.addResponse(response);
+
+    ret = DirectWriter.append(TEST_TABLE, jsonArr2);
+    assertEquals(Long.valueOf(2L), ret.get());
+    assertEquals(
+        m1.toByteString(),
+        ((AppendRowsRequest) actualRequests.get(3)).getProtoRows().getRows().getSerializedRows(0));
+    DirectWriter.clearCache();
+  }
+
+  @Test
+  public void testProtobufWriteSuccess() throws Exception {
     DirectWriter.testSetStub(client, 10, schemaCheck);
     FooType m1 = FooType.newBuilder().setFoo("m1").build();
     FooType m2 = FooType.newBuilder().setFoo("m2").build();
@@ -203,6 +299,27 @@ public void testWriteBadTableName() throws Exception {
     DirectWriter.clearCache();
   }
 
+  @Test
+  public void testJsonWriteBadTableName() throws Exception {
+    DirectWriter.testSetStub(client, 10, schemaCheck);
+    JSONObject m1_json = new JSONObject();
+    m1_json.put("foo", "m1");
+    JSONObject m2_json = new JSONObject();
+    m2_json.put("foo", "m2");
+    final JSONArray jsonArr = new JSONArray();
+    jsonArr.put(m1_json);
+    jsonArr.put(m2_json);
+
+    try {
+      ApiFuture ret = DirectWriter.append("abc", jsonArr);
+      fail("should fail");
+    } catch (IllegalArgumentException expected) {
+      assertEquals("Invalid table name: abc", expected.getMessage());
+    }
+
+    DirectWriter.clearCache();
+  }
+
   @Test
   public void testConcurrentAccess() throws Exception {
     DirectWriter.testSetStub(client, 2, schemaCheck);
@@ -213,8 +330,8 @@ public void testConcurrentAccess() throws Exception {
             Long.valueOf(0L),
             Long.valueOf(2L),
             Long.valueOf(4L),
-            Long.valueOf(8L),
-            Long.valueOf(10L));
+            Long.valueOf(6L),
+            Long.valueOf(8L));
     // Make sure getting the same table writer in multiple thread only cause create to be called
     // once.
     WriterCreationResponseMock(TEST_STREAM, expectedOffset);
@@ -244,4 +361,119 @@ public void run() {
     }
     DirectWriter.clearCache();
   }
+
+  @Test
+  public void testJsonConcurrentAccess() throws Exception {
+    DirectWriter.testSetStub(client, 2, schemaCheck);
+    FooType m1 = FooType.newBuilder().setFoo("m1").build();
+    FooType m2 = FooType.newBuilder().setFoo("m2").build();
+    JSONObject m1_json = new JSONObject();
+    m1_json.put("foo", "m1");
+    JSONObject m2_json = new JSONObject();
+    m2_json.put("foo", "m2");
+    final JSONArray jsonArr = new JSONArray();
+    jsonArr.put(m1_json);
+    jsonArr.put(m2_json);
+
+    final Set expectedOffset =
+        Sets.newHashSet(
+            Long.valueOf(0L),
+            Long.valueOf(2L),
+            Long.valueOf(4L),
+            Long.valueOf(6L),
+            Long.valueOf(8L));
+    // Make sure getting the same table writer in multiple thread only cause create to be called
+    // once.
+    JsonWriterCreationResponseMock(TEST_STREAM, expectedOffset);
+    ExecutorService executor = Executors.newFixedThreadPool(5);
+    for (int i = 0; i < 5; i++) {
+      executor.execute(
+          new Runnable() {
+            @Override
+            public void run() {
+              try {
+                ApiFuture result = DirectWriter.append(TEST_TABLE, jsonArr);
+                synchronized (expectedOffset) {
+                  assertTrue(expectedOffset.remove(result.get()));
+                }
+              } catch (Exception e) {
+                fail(e.toString());
+              }
+            }
+          });
+    }
+    executor.shutdown();
+    try {
+      executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
+    } catch (InterruptedException e) {
+      LOG.info(e.toString());
+    }
+    DirectWriter.clearCache();
+  }
+
+  @Test
+  public void testJsonProtobufWrite() throws Exception {
+    DirectWriter.testSetStub(client, 10, schemaCheck);
+    FooType m1 = FooType.newBuilder().setFoo("m1").build();
+    FooType m2 = FooType.newBuilder().setFoo("m2").build();
+    JSONObject m1_json = new JSONObject();
+    m1_json.put("foo", "m1");
+    JSONObject m2_json = new JSONObject();
+    m2_json.put("foo", "m2");
+    JSONArray jsonArr = new JSONArray();
+    jsonArr.put(m1_json);
+    jsonArr.put(m2_json);
+
+    JSONArray jsonArr2 = new JSONArray();
+    jsonArr2.put(m1_json);
+
+    WriterCreationResponseMock(TEST_STREAM, Sets.newHashSet(Long.valueOf(0L)));
+
+    ApiFuture ret = DirectWriter.append(TEST_TABLE, Arrays.asList(m1, m2));
+    verify(schemaCheck).check(TEST_TABLE, FooType.getDescriptor());
+    assertEquals(Long.valueOf(0L), ret.get());
+    List actualRequests = mockBigQueryWrite.getRequests();
+    Assert.assertEquals(3, actualRequests.size());
+    assertEquals(
+        TEST_TABLE, ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getParent());
+    assertEquals(
+        Stream.WriteStream.Type.COMMITTED,
+        ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getWriteStream().getType());
+    assertEquals(TEST_STREAM, ((Storage.GetWriteStreamRequest) actualRequests.get(1)).getName());
+
+    Storage.AppendRowsRequest.ProtoData.Builder dataBuilder =
+        Storage.AppendRowsRequest.ProtoData.newBuilder();
+    dataBuilder.setWriterSchema(ProtoSchemaConverter.convert(FooType.getDescriptor()));
+    dataBuilder.setRows(
+        ProtoBufProto.ProtoRows.newBuilder()
+            .addSerializedRows(m1.toByteString())
+            .addSerializedRows(m2.toByteString())
+            .build());
+    Storage.AppendRowsRequest expectRequest =
+        Storage.AppendRowsRequest.newBuilder()
+            .setWriteStream(TEST_STREAM)
+            .setProtoRows(dataBuilder.build())
+            .build();
+    assertEquals(expectRequest.toString(), actualRequests.get(2).toString());
+
+    JsonWriterCreationResponseMock(TEST_STREAM, Sets.newHashSet(Long.valueOf(0L)));
+    ret = DirectWriter.append(TEST_TABLE, jsonArr);
+    assertEquals(Long.valueOf(0L), ret.get());
+    actualRequests = mockBigQueryWrite.getRequests();
+    Assert.assertEquals(6, actualRequests.size());
+    assertEquals(
+        TEST_TABLE, ((Storage.CreateWriteStreamRequest) actualRequests.get(3)).getParent());
+    assertEquals(
+        Stream.WriteStream.Type.COMMITTED,
+        ((Storage.CreateWriteStreamRequest) actualRequests.get(3)).getWriteStream().getType());
+    assertEquals(TEST_STREAM, ((Storage.GetWriteStreamRequest) actualRequests.get(4)).getName());
+    assertEquals(
+        m1.toByteString(),
+        ((AppendRowsRequest) actualRequests.get(5)).getProtoRows().getRows().getSerializedRows(0));
+    assertEquals(
+        m2.toByteString(),
+        ((AppendRowsRequest) actualRequests.get(5)).getProtoRows().getRows().getSerializedRows(1));
+
+    DirectWriter.clearCache();
+  }
 }
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCacheTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCacheTest.java
new file mode 100644
index 0000000000..5dd4ce820d
--- /dev/null
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/JsonWriterCacheTest.java
@@ -0,0 +1,255 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.bigquery.storage.v1alpha2;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.*;
+
+import com.google.api.gax.core.NoCredentialsProvider;
+import com.google.api.gax.grpc.testing.LocalChannelProvider;
+import com.google.api.gax.grpc.testing.MockGrpcService;
+import com.google.api.gax.grpc.testing.MockServiceHelper;
+import com.google.cloud.bigquery.storage.test.Test.*;
+import com.google.protobuf.AbstractMessage;
+import com.google.protobuf.Timestamp;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.logging.Logger;
+import org.junit.*;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.threeten.bp.Instant;
+import org.threeten.bp.temporal.ChronoUnit;
+
+@RunWith(JUnit4.class)
+public class JsonWriterCacheTest {
+  private static final Logger LOG = Logger.getLogger(JsonWriterCacheTest.class.getName());
+
+  private static final String TEST_TABLE = "projects/p/datasets/d/tables/t";
+  private static final String TEST_STREAM = "projects/p/datasets/d/tables/t/streams/s";
+  private static final String TEST_STREAM_2 = "projects/p/datasets/d/tables/t/streams/s2";
+  private static final String TEST_STREAM_3 = "projects/p/datasets/d/tables/t/streams/s3";
+  private static final String TEST_STREAM_4 = "projects/p/datasets/d/tables/t/streams/s4";
+  private static final String TEST_TABLE_2 = "projects/p/datasets/d/tables/t2";
+  private static final String TEST_STREAM_21 = "projects/p/datasets/d/tables/t2/streams/s1";
+  private static final String TEST_TABLE_3 = "projects/p/datasets/d/tables/t3";
+  private static final String TEST_STREAM_31 = "projects/p/datasets/d/tables/t3/streams/s1";
+
+  private static MockBigQueryWrite mockBigQueryWrite;
+  private static MockServiceHelper serviceHelper;
+  @Mock private static SchemaCompatibility mockSchemaCheck;
+  private BigQueryWriteClient client;
+  private LocalChannelProvider channelProvider;
+
+  private final Table.TableFieldSchema FOO =
+      Table.TableFieldSchema.newBuilder()
+          .setType(Table.TableFieldSchema.Type.STRING)
+          .setMode(Table.TableFieldSchema.Mode.NULLABLE)
+          .setName("foo")
+          .build();
+  private final Table.TableSchema TABLE_SCHEMA =
+      Table.TableSchema.newBuilder().addFields(0, FOO).build();
+
+  @BeforeClass
+  public static void startStaticServer() {
+    mockBigQueryWrite = new MockBigQueryWrite();
+    serviceHelper =
+        new MockServiceHelper(
+            UUID.randomUUID().toString(), Arrays.asList(mockBigQueryWrite));
+    serviceHelper.start();
+  }
+
+  @AfterClass
+  public static void stopServer() {
+    serviceHelper.stop();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    serviceHelper.reset();
+    channelProvider = serviceHelper.createChannelProvider();
+    BigQueryWriteSettings settings =
+        BigQueryWriteSettings.newBuilder()
+            .setTransportChannelProvider(channelProvider)
+            .setCredentialsProvider(NoCredentialsProvider.create())
+            .build();
+    client = BigQueryWriteClient.create(settings);
+    MockitoAnnotations.initMocks(this);
+  }
+
+  /** Response mocks for create a new writer */
+  void WriterCreationResponseMock(String testStreamName) {
+    // Response from CreateWriteStream
+    Stream.WriteStream expectedResponse =
+        Stream.WriteStream.newBuilder()
+            .setName(testStreamName)
+            .setTableSchema(TABLE_SCHEMA)
+            .build();
+    mockBigQueryWrite.addResponse(expectedResponse);
+
+    // Response from GetWriteStream
+    Instant time = Instant.now();
+    Timestamp timestamp =
+        Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build();
+    Stream.WriteStream expectedResponse2 =
+        Stream.WriteStream.newBuilder()
+            .setName(testStreamName)
+            .setType(Stream.WriteStream.Type.COMMITTED)
+            .setCreateTime(timestamp)
+            .build();
+    mockBigQueryWrite.addResponse(expectedResponse2);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    client.close();
+  }
+
+  @Test
+  public void testRejectBadTableName() throws Exception {
+    JsonWriterCache cache = JsonWriterCache.getTestInstance(client, 10);
+    try {
+      cache.getTableWriter("abc");
+      fail();
+    } catch (IllegalArgumentException expected) {
+      assertEquals(expected.getMessage(), "Invalid table name: abc");
+    }
+  }
+
+  @Test
+  public void testCreateNewWriter() throws Exception {
+    JsonWriterCache cache = JsonWriterCache.getTestInstance(client, 10);
+    WriterCreationResponseMock(TEST_STREAM);
+    JsonStreamWriter writer = cache.getTableWriter(TEST_TABLE);
+    List actualRequests = mockBigQueryWrite.getRequests();
+    assertEquals(2, actualRequests.size());
+    assertEquals(
+        TEST_TABLE, ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getParent());
+    assertEquals(
+        Stream.WriteStream.Type.COMMITTED,
+        ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getWriteStream().getType());
+    assertEquals(TEST_STREAM, ((Storage.GetWriteStreamRequest) actualRequests.get(1)).getName());
+
+    assertEquals(TEST_STREAM, writer.getStreamName());
+    assertEquals(1, cache.cachedTableCount());
+    cache.clear();
+  }
+
+  @Test
+  public void testWriterExpired() throws Exception {
+    JsonWriterCache cache = JsonWriterCache.getTestInstance(client, 10);
+    // Response from CreateWriteStream
+    Stream.WriteStream expectedResponse =
+        Stream.WriteStream.newBuilder().setName(TEST_STREAM).build();
+    mockBigQueryWrite.addResponse(expectedResponse);
+
+    // Response from GetWriteStream
+    Instant time = Instant.now().minus(2, ChronoUnit.DAYS);
+    Timestamp timestamp =
+        Timestamp.newBuilder().setSeconds(time.getEpochSecond()).setNanos(time.getNano()).build();
+    Stream.WriteStream expectedResponse2 =
+        Stream.WriteStream.newBuilder()
+            .setName(TEST_STREAM)
+            .setType(Stream.WriteStream.Type.COMMITTED)
+            .setCreateTime(timestamp)
+            .build();
+    mockBigQueryWrite.addResponse(expectedResponse2);
+
+    try {
+      JsonStreamWriter writer = cache.getTableWriter(TEST_TABLE);
+      fail("Should fail");
+    } catch (IllegalStateException e) {
+      assertEquals(
+          "Cannot write to a stream that is already expired: projects/p/datasets/d/tables/t/streams/s",
+          e.getMessage());
+    }
+    cache.clear();
+  }
+
+  @Test
+  public void testWriterWithDifferentTable() throws Exception {
+    JsonWriterCache cache = JsonWriterCache.getTestInstance(client, 2);
+    WriterCreationResponseMock(TEST_STREAM);
+    WriterCreationResponseMock(TEST_STREAM_21);
+    JsonStreamWriter writer1 = cache.getTableWriter(TEST_TABLE);
+    JsonStreamWriter writer2 = cache.getTableWriter(TEST_TABLE_2);
+
+    List actualRequests = mockBigQueryWrite.getRequests();
+    assertEquals(4, actualRequests.size());
+    assertEquals(
+        TEST_TABLE, ((Storage.CreateWriteStreamRequest) actualRequests.get(0)).getParent());
+    assertEquals(TEST_STREAM, ((Storage.GetWriteStreamRequest) actualRequests.get(1)).getName());
+    assertEquals(
+        TEST_TABLE_2, ((Storage.CreateWriteStreamRequest) actualRequests.get(2)).getParent());
+    Assert.assertEquals(
+        TEST_STREAM_21, ((Storage.GetWriteStreamRequest) actualRequests.get(3)).getName());
+    assertEquals(TEST_STREAM, writer1.getStreamName());
+    assertEquals(TEST_STREAM_21, writer2.getStreamName());
+    assertEquals(2, cache.cachedTableCount());
+
+    // Still able to get the FooType writer.
+    JsonStreamWriter writer3 = cache.getTableWriter(TEST_TABLE_2);
+    Assert.assertEquals(TEST_STREAM_21, writer3.getStreamName());
+
+    // Create a writer with a even new schema.
+    WriterCreationResponseMock(TEST_STREAM_31);
+    WriterCreationResponseMock(TEST_STREAM);
+    JsonStreamWriter writer4 = cache.getTableWriter(TEST_TABLE_3);
+    // This would cause a new stream to be created since the old entry is evicted.
+    JsonStreamWriter writer5 = cache.getTableWriter(TEST_TABLE);
+    assertEquals(TEST_STREAM_31, writer4.getStreamName());
+    assertEquals(TEST_STREAM, writer5.getStreamName());
+    assertEquals(2, cache.cachedTableCount());
+    cache.clear();
+  }
+
+  @Test
+  public void testConcurrentAccess() throws Exception {
+    final JsonWriterCache cache = JsonWriterCache.getTestInstance(client, 2);
+    // Make sure getting the same table writer in multiple thread only cause create to be called
+    // once.
+    WriterCreationResponseMock(TEST_STREAM);
+    ExecutorService executor = Executors.newFixedThreadPool(10);
+    for (int i = 0; i < 10; i++) {
+      executor.execute(
+          new Runnable() {
+            @Override
+            public void run() {
+              try {
+                assertTrue(cache.getTableWriter(TEST_TABLE) != null);
+              } catch (Exception e) {
+                fail(e.getMessage());
+              }
+            }
+          });
+    }
+    executor.shutdown();
+    try {
+      executor.awaitTermination(1, TimeUnit.MINUTES);
+    } catch (InterruptedException e) {
+      LOG.info(e.toString());
+    }
+  }
+}
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/WriterCacheTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/WriterCacheTest.java
index 450789da36..a427a5bbc3 100644
--- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/WriterCacheTest.java
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/WriterCacheTest.java
@@ -45,7 +45,7 @@
 
 @RunWith(JUnit4.class)
 public class WriterCacheTest {
-  private static final Logger LOG = Logger.getLogger(StreamWriterTest.class.getName());
+  private static final Logger LOG = Logger.getLogger(WriterCacheTest.class.getName());
 
   private static final String TEST_TABLE = "projects/p/datasets/d/tables/t";
   private static final String TEST_STREAM = "projects/p/datasets/d/tables/t/streams/s";
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/it/ITBigQueryWriteManualClientTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/it/ITBigQueryWriteManualClientTest.java
index 64e8a07f4f..c1d4ca98a3 100644
--- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/it/ITBigQueryWriteManualClientTest.java
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1alpha2/it/ITBigQueryWriteManualClientTest.java
@@ -581,7 +581,9 @@ public void run() {
   }
 
   @Test
-  public void testDirectWrite() throws IOException, InterruptedException, ExecutionException {
+  public void testDirectWrite()
+      throws IOException, InterruptedException, ExecutionException,
+          Descriptors.DescriptorValidationException {
     final FooType fa = FooType.newBuilder().setFoo("aaa").build();
     final FooType fb = FooType.newBuilder().setFoo("bbb").build();
     Set expectedOffset = new HashSet<>();
@@ -605,12 +607,45 @@ public Long call() throws IOException, InterruptedException, ExecutionException
       assertTrue(expectedOffset.remove(response.get()));
     }
     assertTrue(expectedOffset.isEmpty());
+
+    JSONObject a_json = new JSONObject();
+    a_json.put("foo", "aaa");
+    JSONObject b_json = new JSONObject();
+    b_json.put("foo", "bbb");
+    final JSONArray jsonArr = new JSONArray();
+    jsonArr.put(a_json);
+    jsonArr.put(b_json);
+
+    expectedOffset = new HashSet<>();
+    for (int i = 0; i < 10; i++) {
+      expectedOffset.add(Long.valueOf(i * 2));
+    }
+    executor = Executors.newFixedThreadPool(10);
+    responses = new ArrayList<>();
+    callable =
+        new Callable() {
+          @Override
+          public Long call()
+              throws IOException, InterruptedException, ExecutionException,
+                  Descriptors.DescriptorValidationException {
+            ApiFuture result = DirectWriter.append(tableId, jsonArr);
+            return result.get();
+          }
+        };
+    for (int i = 0; i < 10; i++) {
+      responses.add(executor.submit(callable));
+    }
+    for (Future response : responses) {
+      assertTrue(expectedOffset.remove(response.get()));
+    }
+    assertTrue(expectedOffset.isEmpty());
     executor.shutdown();
     try {
       executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
     } catch (InterruptedException e) {
       LOG.info(e.toString());
     }
+
     DirectWriter.clearCache();
   }