Skip to content

Commit

Permalink
Fix the bug to remove interal key when receiving the 100-continue hea…
Browse files Browse the repository at this point in the history
…der (#2781)

Co-authored-by: Sophie Guo <sopguo@sopguo-mn2.linkedin.biz>
  • Loading branch information
SophieGuo410 and Sophie Guo committed May 8, 2024
1 parent 2cef2c6 commit eae3750
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ public interface RestRequest extends ReadableStreamChannel {
*/
Object setArg(String key, Object value);

/**
* Remove one argument from the key-value map.
* @param key The key of the argument.
*/
void removeArg(String key);

/**
* If this request was over HTTPS, gets the {@link SSLSession} associated with the request.
* @return The {@link SSLSession} for the request and response, or {@code null} if SSL was not used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ public void handlePost(RestRequest restRequest, RestResponseChannel restResponse
public void handlePut(RestRequest restRequest, RestResponseChannel restResponseChannel) {
if (shouldProceed(restRequest, restResponseChannel)) {
try {
//set the internal key during preProcessRequest.
restRequest.setArg(InternalKeys.SEND_TRACKING_INFO, "true");
if (CONTINUE.equals(restRequest.getArgs().get(EXPECT))) {
restResponseChannel.setStatus(ResponseStatus.Continue);
handleResponse(restRequest, restResponseChannel, null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.Collections;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
Expand All @@ -63,6 +64,10 @@ class AmbrySecurityService implements SecurityService {
static final Set<String> OPERATIONS = Collections.unmodifiableSet(
Utils.getStaticFieldValuesAsStrings(Operations.class)
.collect(Collectors.toCollection(() -> new TreeSet<>(String.CASE_INSENSITIVE_ORDER))));

static final List<String> INTERNAL_KEYS = Collections.unmodifiableList(
Utils.getStaticFieldValuesAsStrings(InternalKeys.class).collect(Collectors.toList()));

private final FrontendConfig frontendConfig;
private final FrontendMetrics frontendMetrics;
private final UrlSigningService urlSigningService;
Expand Down Expand Up @@ -94,6 +99,10 @@ public void preProcessRequest(RestRequest restRequest, Callback<Void> callback)
} else if (restRequest.getArgs().containsKey(InternalKeys.KEEP_ALIVE_ON_ERROR_HINT)) {
exception = new RestServiceException(InternalKeys.KEEP_ALIVE_ON_ERROR_HINT + " is not allowed in the request",
RestServiceErrorCode.BadRequest);
//we can't check all internal keys due to we have test in processRequestTest to test preProcessRequest without handler.
} else if (restRequest.getArgs().containsKey(InternalKeys.SEND_TRACKING_INFO)) {
exception = new RestServiceException(InternalKeys.SEND_TRACKING_INFO + " is not allowed in the request",
RestServiceErrorCode.BadRequest);
}
restRequest.setArg(InternalKeys.SEND_TRACKING_INFO, frontendConfig.attachTrackingInfo);
if (exception == null && urlSigningService.isRequestSigned(restRequest)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4228,6 +4228,9 @@ public Object setArg(String key, Object value) {
throw new IllegalStateException("Not implemented");
}

@Override
public void removeArg(String key) { throw new IllegalStateException("Not implemented"); }

@Override
public SSLSession getSSLSession() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@
import io.netty.handler.timeout.IdleStateEvent;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.github.ambry.rest.RestUtils.*;
import static com.github.ambry.rest.RestUtils.Headers.*;
import static com.github.ambry.rest.RestUtils.InternalKeys.*;


/**
Expand Down Expand Up @@ -379,6 +383,7 @@ private boolean handleContent(HttpContent httpContent) throws RestServiceExcepti
if (success && (!isPutOrPost || isMultipart || hasContinue)) {
if (hasContinue) {
request.setArg(EXPECT, "");
removeInternalKeyFromRequest();
responseChannel = new NettyResponseChannel(ctx, nettyMetrics, performanceConfig, nettyConfig);
// FIXME: The request could be accepted as ctor arg to NettyResponseChannel to avoid null pointers
responseChannel.setRequest(request);
Expand All @@ -396,6 +401,22 @@ private boolean handleContent(HttpContent httpContent) throws RestServiceExcepti
return success;
}

/**
* Remove the internal key after we send back 100-continue back to customer.
*/
private void removeInternalKeyFromRequest() {
Set<String> internalKeysNeedToBeRemoved = new HashSet<>();
for (Map.Entry<String, Object> requestArg : request.getArgs().entrySet()) {
String requestKey = requestArg.getKey();
if (requestKey.startsWith(KEY_PREFIX)) {
internalKeysNeedToBeRemoved.add(requestKey);
}
}
for (String internalKeyNeedToBeRemoved : internalKeysNeedToBeRemoved) {
request.removeArg(internalKeyNeedToBeRemoved);
}
}

/**
* Resets the state of the processor in preparation for the next request.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ public Object setArg(String key, Object value) {
return allArgs.put(key, value);
}

@Override
public void removeArg(String key) { allArgs.remove(key); }

@Override
public SSLSession getSSLSession() {
return sslSession;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public Object setArg(String key, Object value) {
return restRequest.setArg(key, value);
}

@Override
public void removeArg(String key) { restRequest.removeArg(key); }

@Override
public SSLSession getSSLSession() {
return restRequest.getSSLSession();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,9 @@ public Object setArg(String key, Object value) {
return null;
}

@Override
public void removeArg(String key) { throw new IllegalStateException("Not implemented"); }

@Override
public SSLSession getSSLSession() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class MockRestRequest implements RestRequest {
* List of "events" (function calls) that can occur inside MockRestRequest.
*/
public enum Event {
GetRestMethod, GetPath, GetUri, GetArgs, SetArgs, GetSize, ReadInto, IsOpen, Close, GetMetricsTracker
GetRestMethod, GetPath, GetUri, GetArgs, SetArgs, RemoveArgs, GetSize, ReadInto, IsOpen, Close, GetMetricsTracker
}

/**
Expand Down Expand Up @@ -183,6 +183,12 @@ public Object setArg(String key, Object value) {
return args.put(key, value);
}

@Override
public void removeArg(String key) {
onEventComplete(Event.RemoveArgs);
args.remove(key);
}

@Override
public SSLSession getSSLSession() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ public Object setArg(String key, Object value) {
return args.put(key, value);
}

@Override
public void removeArg(String key) { args.remove(key); }

@Override
public SSLSession getSSLSession() {
return null;
Expand Down

0 comments on commit eae3750

Please sign in to comment.