Skip to content

Commit

Permalink
fix: allows user-agent header with header provider (#871)
Browse files Browse the repository at this point in the history
* fix: allows user-agent header with header provider

A bug was introduced, where if the caller tried to set a custom user
agent with a header provider an exception would be thrown (for duplicate
keys). Here, we merge the user agent set by the client along with the
one set by the library, instead of throwing such exception.

* test: adds test for default user agent

Tests if the default user agent is present in the user-agent header set
in the GapicSpannerRpc class.
  • Loading branch information
thiagotnunes committed Feb 17, 2021
1 parent ab14a5e commit 3de7e2a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 12 deletions.
Expand Up @@ -77,7 +77,6 @@
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.RateLimiter;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand Down Expand Up @@ -161,6 +160,7 @@
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -244,6 +244,8 @@ private void awaitTermination() throws InterruptedException {
private static final int GRPC_KEEPALIVE_SECONDS = 2 * 60;
private static final String USER_AGENT_KEY = "user-agent";
private static final String CLIENT_LIBRARY_LANGUAGE = "spanner-java";
public static final String DEFAULT_USER_AGENT =
CLIENT_LIBRARY_LANGUAGE + "/" + GaxProperties.getLibraryVersion(GapicSpannerRpc.class);

private final ManagedInstantiatingExecutorProvider executorProvider;
private boolean rpcIsClosed;
Expand Down Expand Up @@ -305,18 +307,11 @@ public GapicSpannerRpc(final SpannerOptions options) {
GaxGrpcProperties.getGrpcTokenName(), GaxGrpcProperties.getGrpcVersion())
.build();

HeaderProvider mergedHeaderProvider = options.getMergedHeaderProvider(internalHeaderProvider);
Map<String, String> headersWithUserAgent =
ImmutableMap.<String, String>builder()
.put(
USER_AGENT_KEY,
CLIENT_LIBRARY_LANGUAGE
+ "/"
+ GaxProperties.getLibraryVersion(GapicSpannerRpc.class))
.putAll(mergedHeaderProvider.getHeaders())
.build();
final HeaderProvider mergedHeaderProvider =
options.getMergedHeaderProvider(internalHeaderProvider);
final HeaderProvider headerProviderWithUserAgent =
FixedHeaderProvider.create(headersWithUserAgent);
headerProviderWithUserAgentFrom(mergedHeaderProvider);

this.metadataProvider =
SpannerMetadataProvider.create(
headerProviderWithUserAgent.getHeaders(),
Expand Down Expand Up @@ -494,6 +489,16 @@ public <RequestT, ResponseT> UnaryCallable<RequestT, ResponseT> createUnaryCalla
}
}

private static HeaderProvider headerProviderWithUserAgentFrom(HeaderProvider headerProvider) {
final Map<String, String> headersWithUserAgent = new HashMap<>(headerProvider.getHeaders());
final String userAgent = headersWithUserAgent.get(USER_AGENT_KEY);
headersWithUserAgent.put(
USER_AGENT_KEY,
userAgent == null ? DEFAULT_USER_AGENT : userAgent + " " + DEFAULT_USER_AGENT);

return FixedHeaderProvider.create(headersWithUserAgent);
}

private static void checkEmulatorConnection(
SpannerOptions options,
TransportChannelProvider channelProvider,
Expand Down
Expand Up @@ -24,7 +24,9 @@
import static org.junit.Assume.assumeTrue;

import com.google.api.core.ApiFunction;
import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.cloud.spanner.DatabaseAdminClient;
Expand Down Expand Up @@ -151,6 +153,8 @@ public class GapicSpannerRpcTest {
private Server server;
private InetSocketAddress address;
private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>();
private Metadata seenHeaders;
private String defaultUserAgent;

@BeforeClass
public static void checkNotEmulator() {
Expand All @@ -161,6 +165,7 @@ public static void checkNotEmulator() {

@Before
public void startServer() throws IOException {
defaultUserAgent = "spanner-java/" + GaxProperties.getLibraryVersion(GapicSpannerRpc.class);
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1AND2, SELECT1_RESULTSET));
Expand All @@ -183,6 +188,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
seenHeaders = headers;
String auth =
headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER));
assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN);
Expand Down Expand Up @@ -502,6 +508,46 @@ public void testAdminRequestsLimitExceededRetryAlgorithm() {
assertThat(alg.shouldRetry(new Exception("random exception"), null)).isFalse();
}

@Test
public void testDefaultUserAgent() {
final SpannerOptions options = createSpannerOptions();
final Spanner spanner = options.getService();
final DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]"));

try (final ResultSet rs = databaseClient.singleUse().executeQuery(SELECT1AND2)) {
rs.next();
}

assertThat(seenHeaders.get(Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER)))
.contains(defaultUserAgent);
}

@Test
public void testCustomUserAgent() {
final HeaderProvider userAgentHeaderProvider =
new HeaderProvider() {
@Override
public Map<String, String> getHeaders() {
final Map<String, String> headers = new HashMap<>();
headers.put("user-agent", "test-agent");
return headers;
}
};
final SpannerOptions options =
createSpannerOptions().toBuilder().setHeaderProvider(userAgentHeaderProvider).build();
final Spanner spanner = options.getService();
final DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]"));

try (final ResultSet rs = databaseClient.singleUse().executeQuery(SELECT1AND2)) {
rs.next();
}

assertThat(seenHeaders.get(Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER)))
.contains("test-agent " + defaultUserAgent);
}

@SuppressWarnings("rawtypes")
private SpannerOptions createSpannerOptions() {
String endpoint = address.getHostString() + ":" + server.getPort();
Expand Down

0 comments on commit 3de7e2a

Please sign in to comment.