Skip to content

Commit

Permalink
fix: use credentials key in pool (#430)
Browse files Browse the repository at this point in the history
* fix: use credentials key in pool

* fix: remove unused test class

* test: increase test coverage
  • Loading branch information
olavloite committed Sep 29, 2020
1 parent e620a15 commit 28103fb
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 20 deletions.
Expand Up @@ -384,6 +384,7 @@ public static Builder newBuilder() {
private final String uri;
private final String credentialsUrl;
private final String oauthToken;
private final Credentials fixedCredentials;

private final boolean usePlainText;
private final String host;
Expand Down Expand Up @@ -413,6 +414,7 @@ private ConnectionOptions(Builder builder) {
builder.credentialsUrl != null ? builder.credentialsUrl : parseCredentials(builder.uri);
this.oauthToken =
builder.oauthToken != null ? builder.oauthToken : parseOAuthToken(builder.uri);
this.fixedCredentials = builder.credentials;
// Check that not both credentials and an OAuth token have been specified.
Preconditions.checkArgument(
(builder.credentials == null && this.credentialsUrl == null) || this.oauthToken == null,
Expand Down Expand Up @@ -441,11 +443,10 @@ private ConnectionOptions(Builder builder) {
this.credentials = NoCredentials.getInstance();
} else if (this.oauthToken != null) {
this.credentials = new GoogleCredentials(new AccessToken(oauthToken, null));
} else if (this.fixedCredentials != null) {
this.credentials = fixedCredentials;
} else {
this.credentials =
builder.credentials == null
? getCredentialsService().createCredentials(this.credentialsUrl)
: builder.credentials;
this.credentials = getCredentialsService().createCredentials(this.credentialsUrl);
}
String numChannelsValue = parseNumChannels(builder.uri);
if (numChannelsValue != null) {
Expand Down Expand Up @@ -593,6 +594,14 @@ public String getCredentialsUrl() {
return credentialsUrl;
}

String getOAuthToken() {
return this.oauthToken;
}

Credentials getFixedCredentials() {
return this.fixedCredentials;
}

/** The {@link SessionPoolOptions} of this {@link ConnectionOptions}. */
public SessionPoolOptions getSessionPoolOptions() {
return sessionPoolOptions;
Expand Down
Expand Up @@ -17,7 +17,6 @@
package com.google.cloud.spanner.connection;

import com.google.api.core.ApiFunction;
import com.google.auth.Credentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SessionPoolOptions;
Expand All @@ -28,8 +27,11 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -108,10 +110,38 @@ public void run() {
}
}

static class CredentialsKey {
static final Object DEFAULT_CREDENTIALS_KEY = new Object();
final Object key;

static CredentialsKey create(ConnectionOptions options) {
return new CredentialsKey(
Iterables.find(
Arrays.asList(
options.getOAuthToken(),
options.getFixedCredentials(),
options.getCredentialsUrl(),
DEFAULT_CREDENTIALS_KEY),
Predicates.notNull()));
}

private CredentialsKey(Object key) {
this.key = Preconditions.checkNotNull(key);
}

public int hashCode() {
return key.hashCode();
}

public boolean equals(Object o) {
return (o instanceof CredentialsKey && Objects.equals(((CredentialsKey) o).key, this.key));
}
}

static class SpannerPoolKey {
private final String host;
private final String projectId;
private final Credentials credentials;
private final CredentialsKey credentialsKey;
private final SessionPoolOptions sessionPoolOptions;
private final Integer numChannels;
private final boolean usePlainText;
Expand All @@ -124,7 +154,7 @@ private static SpannerPoolKey of(ConnectionOptions options) {
private SpannerPoolKey(ConnectionOptions options) {
this.host = options.getHost();
this.projectId = options.getProjectId();
this.credentials = options.getCredentials();
this.credentialsKey = CredentialsKey.create(options);
this.sessionPoolOptions = options.getSessionPoolOptions();
this.numChannels = options.getNumChannels();
this.usePlainText = options.isUsePlainText();
Expand All @@ -139,7 +169,7 @@ public boolean equals(Object o) {
SpannerPoolKey other = (SpannerPoolKey) o;
return Objects.equals(this.host, other.host)
&& Objects.equals(this.projectId, other.projectId)
&& Objects.equals(this.credentials, other.credentials)
&& Objects.equals(this.credentialsKey, other.credentialsKey)
&& Objects.equals(this.sessionPoolOptions, other.sessionPoolOptions)
&& Objects.equals(this.numChannels, other.numChannels)
&& Objects.equals(this.usePlainText, other.usePlainText)
Expand All @@ -151,7 +181,7 @@ public int hashCode() {
return Objects.hash(
this.host,
this.projectId,
this.credentials,
this.credentialsKey,
this.sessionPoolOptions,
this.numChannels,
this.usePlainText,
Expand Down Expand Up @@ -240,7 +270,7 @@ Spanner getSpanner(ConnectionOptions options, ConnectionImpl connection) {
if (spanners.get(key) != null) {
spanner = spanners.get(key);
} else {
spanner = createSpanner(key);
spanner = createSpanner(key, options);
spanners.put(key, spanner);
}
List<ConnectionImpl> registeredConnectionsForSpanner = connections.get(key);
Expand Down Expand Up @@ -279,13 +309,13 @@ public Thread newThread(Runnable r) {

@SuppressWarnings("rawtypes")
@VisibleForTesting
Spanner createSpanner(SpannerPoolKey key) {
Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) {
SpannerOptions.Builder builder = SpannerOptions.newBuilder();
builder
.setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN))
.setHost(key.host)
.setProjectId(key.projectId)
.setCredentials(key.credentials);
.setCredentials(options.getCredentials());
builder.setSessionPoolOption(key.sessionPoolOptions);
if (key.numChannels != null) {
builder.setNumChannels(key.numChannels);
Expand Down
Expand Up @@ -21,6 +21,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import com.google.api.client.util.BackOff;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
Expand All @@ -29,10 +30,12 @@
import com.google.rpc.RetryInfo;
import com.google.spanner.v1.PartialResultSet;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.ProtoUtils;
import io.opencensus.trace.EndSpanOptions;
import io.opencensus.trace.Span;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
Expand Down Expand Up @@ -79,6 +82,11 @@ static class RetryableException extends SpannerException {
// OK to instantiate SpannerException directly for this unit test.
super(DoNotConstructDirectly.ALLOWED, code, true, message, statusWithRetryInfo(code));
}

RetryableException(ErrorCode code, @Nullable String message, StatusRuntimeException cause) {
// OK to instantiate SpannerException directly for this unit test.
super(DoNotConstructDirectly.ALLOWED, code, true, message, cause);
}
}

static class NonRetryableException extends SpannerException {
Expand Down Expand Up @@ -220,6 +228,30 @@ public void restartWithHoldBackMidStream() {
.inOrder();
}

@Test
public void retryableErrorWithoutRetryInfo() throws IOException {
BackOff backOff = mock(BackOff.class);
Mockito.when(backOff.nextBackOffMillis()).thenReturn(1L);
Whitebox.setInternalState(this.resumableStreamIterator, "backOff", backOff);

ResultSetStream s1 = Mockito.mock(ResultSetStream.class);
Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1));
Mockito.when(s1.next())
.thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a"))
.thenThrow(
new RetryableException(
ErrorCode.UNAVAILABLE, "failed by test", Status.UNAVAILABLE.asRuntimeException()));

ResultSetStream s2 = Mockito.mock(ResultSetStream.class);
Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1")))
.thenReturn(new ResultSetIterator(s2));
Mockito.when(s2.next())
.thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b"))
.thenReturn(null);
assertThat(consume(resumableStreamIterator)).containsExactly("a", "b").inOrder();
verify(backOff).nextBackOffMillis();
}

@Test
public void nonRetryableError() {
ResultSetStream s1 = Mockito.mock(ResultSetStream.class);
Expand Down
Expand Up @@ -26,14 +26,14 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SessionPoolOptions;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.connection.ConnectionImpl.LeakedConnectionException;
import com.google.cloud.spanner.connection.SpannerPool.CheckAndCloseSpannersMode;
import com.google.cloud.spanner.connection.SpannerPool.SpannerPoolKey;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.logging.Handler;
Expand All @@ -51,13 +51,16 @@ public class SpannerPoolTest {
private ConnectionImpl connection1 = mock(ConnectionImpl.class);
private ConnectionImpl connection2 = mock(ConnectionImpl.class);
private ConnectionImpl connection3 = mock(ConnectionImpl.class);
private GoogleCredentials credentials1 = mock(GoogleCredentials.class);
private GoogleCredentials credentials2 = mock(GoogleCredentials.class);
private String credentials1 = "credentials1";
private String credentials2 = "credentials2";
private ConnectionOptions options1 = mock(ConnectionOptions.class);
private ConnectionOptions options2 = mock(ConnectionOptions.class);
private ConnectionOptions options3 = mock(ConnectionOptions.class);
private ConnectionOptions options4 = mock(ConnectionOptions.class);

private ConnectionOptions options5 = mock(ConnectionOptions.class);
private ConnectionOptions options6 = mock(ConnectionOptions.class);

private SpannerPool createSubjectAndMocks() {
return createSubjectAndMocks(0L);
}
Expand All @@ -66,21 +69,25 @@ private SpannerPool createSubjectAndMocks(long closeSpannerAfterMillisecondsUnus
SpannerPool pool =
new SpannerPool(closeSpannerAfterMillisecondsUnused) {
@Override
Spanner createSpanner(SpannerPoolKey key) {
Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) {
return mock(Spanner.class);
}
};

when(options1.getCredentials()).thenReturn(credentials1);
when(options1.getCredentialsUrl()).thenReturn(credentials1);
when(options1.getProjectId()).thenReturn("test-project-1");
when(options2.getCredentials()).thenReturn(credentials2);
when(options2.getCredentialsUrl()).thenReturn(credentials2);
when(options2.getProjectId()).thenReturn("test-project-1");

when(options3.getCredentials()).thenReturn(credentials1);
when(options3.getCredentialsUrl()).thenReturn(credentials1);
when(options3.getProjectId()).thenReturn("test-project-2");
when(options4.getCredentials()).thenReturn(credentials2);
when(options4.getCredentialsUrl()).thenReturn(credentials2);
when(options4.getProjectId()).thenReturn("test-project-2");

// ConnectionOptions with no specific credentials.
when(options5.getProjectId()).thenReturn("test-project-3");
when(options6.getProjectId()).thenReturn("test-project-3");

return pool;
}

Expand Down Expand Up @@ -108,6 +115,10 @@ public void testGetSpanner() {
spanner1 = pool.getSpanner(options4, connection1);
spanner2 = pool.getSpanner(options4, connection2);
assertThat(spanner1, is(equalTo(spanner2)));
// Options 5 and 6 both use default credentials.
spanner1 = pool.getSpanner(options5, connection1);
spanner2 = pool.getSpanner(options6, connection2);
assertThat(spanner1, is(equalTo(spanner2)));

// assert not equal
spanner1 = pool.getSpanner(options1, connection1);
Expand Down

0 comments on commit 28103fb

Please sign in to comment.