diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java index 0b1bc8d21b..379459884c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java @@ -384,6 +384,7 @@ public static Builder newBuilder() { private final String uri; private final String credentialsUrl; private final String oauthToken; + private final Credentials fixedCredentials; private final boolean usePlainText; private final String host; @@ -413,6 +414,7 @@ private ConnectionOptions(Builder builder) { builder.credentialsUrl != null ? builder.credentialsUrl : parseCredentials(builder.uri); this.oauthToken = builder.oauthToken != null ? builder.oauthToken : parseOAuthToken(builder.uri); + this.fixedCredentials = builder.credentials; // Check that not both credentials and an OAuth token have been specified. Preconditions.checkArgument( (builder.credentials == null && this.credentialsUrl == null) || this.oauthToken == null, @@ -441,11 +443,10 @@ private ConnectionOptions(Builder builder) { this.credentials = NoCredentials.getInstance(); } else if (this.oauthToken != null) { this.credentials = new GoogleCredentials(new AccessToken(oauthToken, null)); + } else if (this.fixedCredentials != null) { + this.credentials = fixedCredentials; } else { - this.credentials = - builder.credentials == null - ? getCredentialsService().createCredentials(this.credentialsUrl) - : builder.credentials; + this.credentials = getCredentialsService().createCredentials(this.credentialsUrl); } String numChannelsValue = parseNumChannels(builder.uri); if (numChannelsValue != null) { @@ -593,6 +594,14 @@ public String getCredentialsUrl() { return credentialsUrl; } + String getOAuthToken() { + return this.oauthToken; + } + + Credentials getFixedCredentials() { + return this.fixedCredentials; + } + /** The {@link SessionPoolOptions} of this {@link ConnectionOptions}. */ public SessionPoolOptions getSessionPoolOptions() { return sessionPoolOptions; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java index 7116bc17f3..ecf13cd399 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java @@ -17,7 +17,6 @@ package com.google.cloud.spanner.connection; import com.google.api.core.ApiFunction; -import com.google.auth.Credentials; import com.google.cloud.NoCredentials; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; @@ -28,8 +27,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; import io.grpc.ManagedChannelBuilder; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -108,10 +110,38 @@ public void run() { } } + static class CredentialsKey { + static final Object DEFAULT_CREDENTIALS_KEY = new Object(); + final Object key; + + static CredentialsKey create(ConnectionOptions options) { + return new CredentialsKey( + Iterables.find( + Arrays.asList( + options.getOAuthToken(), + options.getFixedCredentials(), + options.getCredentialsUrl(), + DEFAULT_CREDENTIALS_KEY), + Predicates.notNull())); + } + + private CredentialsKey(Object key) { + this.key = Preconditions.checkNotNull(key); + } + + public int hashCode() { + return key.hashCode(); + } + + public boolean equals(Object o) { + return (o instanceof CredentialsKey && Objects.equals(((CredentialsKey) o).key, this.key)); + } + } + static class SpannerPoolKey { private final String host; private final String projectId; - private final Credentials credentials; + private final CredentialsKey credentialsKey; private final SessionPoolOptions sessionPoolOptions; private final Integer numChannels; private final boolean usePlainText; @@ -124,7 +154,7 @@ private static SpannerPoolKey of(ConnectionOptions options) { private SpannerPoolKey(ConnectionOptions options) { this.host = options.getHost(); this.projectId = options.getProjectId(); - this.credentials = options.getCredentials(); + this.credentialsKey = CredentialsKey.create(options); this.sessionPoolOptions = options.getSessionPoolOptions(); this.numChannels = options.getNumChannels(); this.usePlainText = options.isUsePlainText(); @@ -139,7 +169,7 @@ public boolean equals(Object o) { SpannerPoolKey other = (SpannerPoolKey) o; return Objects.equals(this.host, other.host) && Objects.equals(this.projectId, other.projectId) - && Objects.equals(this.credentials, other.credentials) + && Objects.equals(this.credentialsKey, other.credentialsKey) && Objects.equals(this.sessionPoolOptions, other.sessionPoolOptions) && Objects.equals(this.numChannels, other.numChannels) && Objects.equals(this.usePlainText, other.usePlainText) @@ -151,7 +181,7 @@ public int hashCode() { return Objects.hash( this.host, this.projectId, - this.credentials, + this.credentialsKey, this.sessionPoolOptions, this.numChannels, this.usePlainText, @@ -240,7 +270,7 @@ Spanner getSpanner(ConnectionOptions options, ConnectionImpl connection) { if (spanners.get(key) != null) { spanner = spanners.get(key); } else { - spanner = createSpanner(key); + spanner = createSpanner(key, options); spanners.put(key, spanner); } List registeredConnectionsForSpanner = connections.get(key); @@ -279,13 +309,13 @@ public Thread newThread(Runnable r) { @SuppressWarnings("rawtypes") @VisibleForTesting - Spanner createSpanner(SpannerPoolKey key) { + Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { SpannerOptions.Builder builder = SpannerOptions.newBuilder(); builder .setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN)) .setHost(key.host) .setProjectId(key.projectId) - .setCredentials(key.credentials); + .setCredentials(options.getCredentials()); builder.setSessionPoolOption(key.sessionPoolOptions); if (key.numChannels != null) { builder.setNumChannels(key.numChannels); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index 6c387f0d48..ef744d31a1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import com.google.api.client.util.BackOff; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Lists; import com.google.protobuf.ByteString; @@ -29,10 +30,12 @@ import com.google.rpc.RetryInfo; import com.google.spanner.v1.PartialResultSet; import io.grpc.Metadata; +import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.protobuf.ProtoUtils; import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.Span; +import java.io.IOException; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedList; @@ -79,6 +82,11 @@ static class RetryableException extends SpannerException { // OK to instantiate SpannerException directly for this unit test. super(DoNotConstructDirectly.ALLOWED, code, true, message, statusWithRetryInfo(code)); } + + RetryableException(ErrorCode code, @Nullable String message, StatusRuntimeException cause) { + // OK to instantiate SpannerException directly for this unit test. + super(DoNotConstructDirectly.ALLOWED, code, true, message, cause); + } } static class NonRetryableException extends SpannerException { @@ -220,6 +228,30 @@ public void restartWithHoldBackMidStream() { .inOrder(); } + @Test + public void retryableErrorWithoutRetryInfo() throws IOException { + BackOff backOff = mock(BackOff.class); + Mockito.when(backOff.nextBackOffMillis()).thenReturn(1L); + Whitebox.setInternalState(this.resumableStreamIterator, "backOff", backOff); + + ResultSetStream s1 = Mockito.mock(ResultSetStream.class); + Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(s1.next()) + .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) + .thenThrow( + new RetryableException( + ErrorCode.UNAVAILABLE, "failed by test", Status.UNAVAILABLE.asRuntimeException())); + + ResultSetStream s2 = Mockito.mock(ResultSetStream.class); + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"))) + .thenReturn(new ResultSetIterator(s2)); + Mockito.when(s2.next()) + .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) + .thenReturn(null); + assertThat(consume(resumableStreamIterator)).containsExactly("a", "b").inOrder(); + verify(backOff).nextBackOffMillis(); + } + @Test public void nonRetryableError() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SpannerPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SpannerPoolTest.java index c0145203ce..3c0e9cf160 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SpannerPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SpannerPoolTest.java @@ -26,7 +26,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.NoCredentials; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; @@ -34,6 +33,7 @@ import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.connection.ConnectionImpl.LeakedConnectionException; import com.google.cloud.spanner.connection.SpannerPool.CheckAndCloseSpannersMode; +import com.google.cloud.spanner.connection.SpannerPool.SpannerPoolKey; import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.util.logging.Handler; @@ -51,13 +51,16 @@ public class SpannerPoolTest { private ConnectionImpl connection1 = mock(ConnectionImpl.class); private ConnectionImpl connection2 = mock(ConnectionImpl.class); private ConnectionImpl connection3 = mock(ConnectionImpl.class); - private GoogleCredentials credentials1 = mock(GoogleCredentials.class); - private GoogleCredentials credentials2 = mock(GoogleCredentials.class); + private String credentials1 = "credentials1"; + private String credentials2 = "credentials2"; private ConnectionOptions options1 = mock(ConnectionOptions.class); private ConnectionOptions options2 = mock(ConnectionOptions.class); private ConnectionOptions options3 = mock(ConnectionOptions.class); private ConnectionOptions options4 = mock(ConnectionOptions.class); + private ConnectionOptions options5 = mock(ConnectionOptions.class); + private ConnectionOptions options6 = mock(ConnectionOptions.class); + private SpannerPool createSubjectAndMocks() { return createSubjectAndMocks(0L); } @@ -66,21 +69,25 @@ private SpannerPool createSubjectAndMocks(long closeSpannerAfterMillisecondsUnus SpannerPool pool = new SpannerPool(closeSpannerAfterMillisecondsUnused) { @Override - Spanner createSpanner(SpannerPoolKey key) { + Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { return mock(Spanner.class); } }; - when(options1.getCredentials()).thenReturn(credentials1); + when(options1.getCredentialsUrl()).thenReturn(credentials1); when(options1.getProjectId()).thenReturn("test-project-1"); - when(options2.getCredentials()).thenReturn(credentials2); + when(options2.getCredentialsUrl()).thenReturn(credentials2); when(options2.getProjectId()).thenReturn("test-project-1"); - when(options3.getCredentials()).thenReturn(credentials1); + when(options3.getCredentialsUrl()).thenReturn(credentials1); when(options3.getProjectId()).thenReturn("test-project-2"); - when(options4.getCredentials()).thenReturn(credentials2); + when(options4.getCredentialsUrl()).thenReturn(credentials2); when(options4.getProjectId()).thenReturn("test-project-2"); + // ConnectionOptions with no specific credentials. + when(options5.getProjectId()).thenReturn("test-project-3"); + when(options6.getProjectId()).thenReturn("test-project-3"); + return pool; } @@ -108,6 +115,10 @@ public void testGetSpanner() { spanner1 = pool.getSpanner(options4, connection1); spanner2 = pool.getSpanner(options4, connection2); assertThat(spanner1, is(equalTo(spanner2))); + // Options 5 and 6 both use default credentials. + spanner1 = pool.getSpanner(options5, connection1); + spanner2 = pool.getSpanner(options6, connection2); + assertThat(spanner1, is(equalTo(spanner2))); // assert not equal spanner1 = pool.getSpanner(options1, connection1);