Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use credentials key in pool #430

Merged
merged 3 commits into from Sep 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -384,6 +384,7 @@ public static Builder newBuilder() {
private final String uri;
private final String credentialsUrl;
private final String oauthToken;
private final Credentials fixedCredentials;

private final boolean usePlainText;
private final String host;
Expand Down Expand Up @@ -413,6 +414,7 @@ private ConnectionOptions(Builder builder) {
builder.credentialsUrl != null ? builder.credentialsUrl : parseCredentials(builder.uri);
this.oauthToken =
builder.oauthToken != null ? builder.oauthToken : parseOAuthToken(builder.uri);
this.fixedCredentials = builder.credentials;
// Check that not both credentials and an OAuth token have been specified.
Preconditions.checkArgument(
(builder.credentials == null && this.credentialsUrl == null) || this.oauthToken == null,
Expand Down Expand Up @@ -441,11 +443,10 @@ private ConnectionOptions(Builder builder) {
this.credentials = NoCredentials.getInstance();
} else if (this.oauthToken != null) {
this.credentials = new GoogleCredentials(new AccessToken(oauthToken, null));
} else if (this.fixedCredentials != null) {
this.credentials = fixedCredentials;
} else {
this.credentials =
builder.credentials == null
? getCredentialsService().createCredentials(this.credentialsUrl)
: builder.credentials;
this.credentials = getCredentialsService().createCredentials(this.credentialsUrl);
}
String numChannelsValue = parseNumChannels(builder.uri);
if (numChannelsValue != null) {
Expand Down Expand Up @@ -593,6 +594,14 @@ public String getCredentialsUrl() {
return credentialsUrl;
}

String getOAuthToken() {
return this.oauthToken;
}

Credentials getFixedCredentials() {
return this.fixedCredentials;
}

/** The {@link SessionPoolOptions} of this {@link ConnectionOptions}. */
public SessionPoolOptions getSessionPoolOptions() {
return sessionPoolOptions;
Expand Down
Expand Up @@ -17,7 +17,6 @@
package com.google.cloud.spanner.connection;

import com.google.api.core.ApiFunction;
import com.google.auth.Credentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SessionPoolOptions;
Expand All @@ -28,8 +27,11 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import io.grpc.ManagedChannelBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -108,10 +110,38 @@ public void run() {
}
}

static class CredentialsKey {
static final Object DEFAULT_CREDENTIALS_KEY = new Object();
final Object key;

static CredentialsKey create(ConnectionOptions options) {
return new CredentialsKey(
Iterables.find(
Arrays.asList(
options.getOAuthToken(),
options.getFixedCredentials(),
options.getCredentialsUrl(),
DEFAULT_CREDENTIALS_KEY),
Predicates.notNull()));
}

private CredentialsKey(Object key) {
this.key = Preconditions.checkNotNull(key);
}

public int hashCode() {
return key.hashCode();
}

public boolean equals(Object o) {
return (o instanceof CredentialsKey && Objects.equals(((CredentialsKey) o).key, this.key));
}
}

static class SpannerPoolKey {
private final String host;
private final String projectId;
private final Credentials credentials;
private final CredentialsKey credentialsKey;
private final SessionPoolOptions sessionPoolOptions;
private final Integer numChannels;
private final boolean usePlainText;
Expand All @@ -124,7 +154,7 @@ private static SpannerPoolKey of(ConnectionOptions options) {
private SpannerPoolKey(ConnectionOptions options) {
this.host = options.getHost();
this.projectId = options.getProjectId();
this.credentials = options.getCredentials();
this.credentialsKey = CredentialsKey.create(options);
this.sessionPoolOptions = options.getSessionPoolOptions();
this.numChannels = options.getNumChannels();
this.usePlainText = options.isUsePlainText();
Expand All @@ -139,7 +169,7 @@ public boolean equals(Object o) {
SpannerPoolKey other = (SpannerPoolKey) o;
return Objects.equals(this.host, other.host)
&& Objects.equals(this.projectId, other.projectId)
&& Objects.equals(this.credentials, other.credentials)
&& Objects.equals(this.credentialsKey, other.credentialsKey)
&& Objects.equals(this.sessionPoolOptions, other.sessionPoolOptions)
&& Objects.equals(this.numChannels, other.numChannels)
&& Objects.equals(this.usePlainText, other.usePlainText)
Expand All @@ -151,7 +181,7 @@ public int hashCode() {
return Objects.hash(
this.host,
this.projectId,
this.credentials,
this.credentialsKey,
this.sessionPoolOptions,
this.numChannels,
this.usePlainText,
Expand Down Expand Up @@ -240,7 +270,7 @@ Spanner getSpanner(ConnectionOptions options, ConnectionImpl connection) {
if (spanners.get(key) != null) {
spanner = spanners.get(key);
} else {
spanner = createSpanner(key);
spanner = createSpanner(key, options);
spanners.put(key, spanner);
}
List<ConnectionImpl> registeredConnectionsForSpanner = connections.get(key);
Expand Down Expand Up @@ -279,13 +309,13 @@ public Thread newThread(Runnable r) {

@SuppressWarnings("rawtypes")
@VisibleForTesting
Spanner createSpanner(SpannerPoolKey key) {
Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) {
SpannerOptions.Builder builder = SpannerOptions.newBuilder();
builder
.setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN))
.setHost(key.host)
.setProjectId(key.projectId)
.setCredentials(key.credentials);
.setCredentials(options.getCredentials());
builder.setSessionPoolOption(key.sessionPoolOptions);
if (key.numChannels != null) {
builder.setNumChannels(key.numChannels);
Expand Down
Expand Up @@ -21,6 +21,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import com.google.api.client.util.BackOff;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
Expand All @@ -29,10 +30,12 @@
import com.google.rpc.RetryInfo;
import com.google.spanner.v1.PartialResultSet;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.ProtoUtils;
import io.opencensus.trace.EndSpanOptions;
import io.opencensus.trace.Span;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
Expand Down Expand Up @@ -79,6 +82,11 @@ static class RetryableException extends SpannerException {
// OK to instantiate SpannerException directly for this unit test.
super(DoNotConstructDirectly.ALLOWED, code, true, message, statusWithRetryInfo(code));
}

RetryableException(ErrorCode code, @Nullable String message, StatusRuntimeException cause) {
// OK to instantiate SpannerException directly for this unit test.
super(DoNotConstructDirectly.ALLOWED, code, true, message, cause);
}
}

static class NonRetryableException extends SpannerException {
Expand Down Expand Up @@ -220,6 +228,30 @@ public void restartWithHoldBackMidStream() {
.inOrder();
}

@Test
public void retryableErrorWithoutRetryInfo() throws IOException {
BackOff backOff = mock(BackOff.class);
Mockito.when(backOff.nextBackOffMillis()).thenReturn(1L);
Whitebox.setInternalState(this.resumableStreamIterator, "backOff", backOff);

ResultSetStream s1 = Mockito.mock(ResultSetStream.class);
Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1));
Mockito.when(s1.next())
.thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a"))
.thenThrow(
new RetryableException(
ErrorCode.UNAVAILABLE, "failed by test", Status.UNAVAILABLE.asRuntimeException()));

ResultSetStream s2 = Mockito.mock(ResultSetStream.class);
Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1")))
.thenReturn(new ResultSetIterator(s2));
Mockito.when(s2.next())
.thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b"))
.thenReturn(null);
assertThat(consume(resumableStreamIterator)).containsExactly("a", "b").inOrder();
verify(backOff).nextBackOffMillis();
}

@Test
public void nonRetryableError() {
ResultSetStream s1 = Mockito.mock(ResultSetStream.class);
Expand Down
Expand Up @@ -26,14 +26,14 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SessionPoolOptions;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.connection.ConnectionImpl.LeakedConnectionException;
import com.google.cloud.spanner.connection.SpannerPool.CheckAndCloseSpannersMode;
import com.google.cloud.spanner.connection.SpannerPool.SpannerPoolKey;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.logging.Handler;
Expand All @@ -51,13 +51,16 @@ public class SpannerPoolTest {
private ConnectionImpl connection1 = mock(ConnectionImpl.class);
private ConnectionImpl connection2 = mock(ConnectionImpl.class);
private ConnectionImpl connection3 = mock(ConnectionImpl.class);
private GoogleCredentials credentials1 = mock(GoogleCredentials.class);
private GoogleCredentials credentials2 = mock(GoogleCredentials.class);
private String credentials1 = "credentials1";
private String credentials2 = "credentials2";
private ConnectionOptions options1 = mock(ConnectionOptions.class);
private ConnectionOptions options2 = mock(ConnectionOptions.class);
private ConnectionOptions options3 = mock(ConnectionOptions.class);
private ConnectionOptions options4 = mock(ConnectionOptions.class);

private ConnectionOptions options5 = mock(ConnectionOptions.class);
private ConnectionOptions options6 = mock(ConnectionOptions.class);

private SpannerPool createSubjectAndMocks() {
return createSubjectAndMocks(0L);
}
Expand All @@ -66,21 +69,25 @@ private SpannerPool createSubjectAndMocks(long closeSpannerAfterMillisecondsUnus
SpannerPool pool =
new SpannerPool(closeSpannerAfterMillisecondsUnused) {
@Override
Spanner createSpanner(SpannerPoolKey key) {
Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) {
return mock(Spanner.class);
}
};

when(options1.getCredentials()).thenReturn(credentials1);
when(options1.getCredentialsUrl()).thenReturn(credentials1);
when(options1.getProjectId()).thenReturn("test-project-1");
when(options2.getCredentials()).thenReturn(credentials2);
when(options2.getCredentialsUrl()).thenReturn(credentials2);
when(options2.getProjectId()).thenReturn("test-project-1");

when(options3.getCredentials()).thenReturn(credentials1);
when(options3.getCredentialsUrl()).thenReturn(credentials1);
when(options3.getProjectId()).thenReturn("test-project-2");
when(options4.getCredentials()).thenReturn(credentials2);
when(options4.getCredentialsUrl()).thenReturn(credentials2);
when(options4.getProjectId()).thenReturn("test-project-2");

// ConnectionOptions with no specific credentials.
when(options5.getProjectId()).thenReturn("test-project-3");
when(options6.getProjectId()).thenReturn("test-project-3");

return pool;
}

Expand Down Expand Up @@ -108,6 +115,10 @@ public void testGetSpanner() {
spanner1 = pool.getSpanner(options4, connection1);
spanner2 = pool.getSpanner(options4, connection2);
assertThat(spanner1, is(equalTo(spanner2)));
// Options 5 and 6 both use default credentials.
spanner1 = pool.getSpanner(options5, connection1);
spanner2 = pool.getSpanner(options6, connection2);
assertThat(spanner1, is(equalTo(spanner2)));

// assert not equal
spanner1 = pool.getSpanner(options1, connection1);
Expand Down