Skip to content

Commit

Permalink
Use Map to manage names of MDC fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chavjoh committed Jan 20, 2024
1 parent 201b777 commit ee297dd
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -27,7 +27,7 @@ The logging of requests and responses is done through a filter that can be activ

It will add the following information to MDC for the request processing:

* Request identifier (random UUID)
* Request identifier (from X-Request-ID header or random UUID)
* Request HTTP method
* Request URI path relative to the base URI
* Resource class matched by the current request
Expand Down
53 changes: 53 additions & 0 deletions src/main/java/com/chavaillaz/jakarta/rs/LoggedField.java
@@ -0,0 +1,53 @@
package com.chavaillaz.jakarta.rs;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Stream;

/**
* List of context fields to be written in MDC.
*/
public enum LoggedField {

REQUEST_ID("request-id"),
REQUEST_METHOD("request-method"),
REQUEST_URI("request-uri"),
REQUEST_BODY("request-body"),
RESPONSE_BODY("response-body"),
RESPONSE_STATUS("response-status"),
RESOURCE_CLASS("resource-class"),
RESOURCE_METHOD("resource-method"),
DURATION("duration");

private final String defaultField;

/**
* Creates a new context field to be logged in MDC.
*
* @param defaultField The default MDC field name
*/
LoggedField(String defaultField) {
this.defaultField = defaultField;
}

/**
* Gets a {@link Map} with the enumeration name as key and the default field name as value.
*
* @return The corresponding {@link Map}
*/
public static Map<String, String> getDefaultFields() {
Map<String, String> map = new HashMap<>();
Stream.of(LoggedField.values()).forEach(entry -> map.put(entry.name(), entry.getDefaultField()));
return map;
}

/**
* Gets the default MDC field name to be used.
*
* @return The default name
*/
public String getDefaultField() {
return this.defaultField;
}

}
70 changes: 40 additions & 30 deletions src/main/java/com/chavaillaz/jakarta/rs/LoggedFilter.java
@@ -1,5 +1,15 @@
package com.chavaillaz.jakarta.rs;

import static com.chavaillaz.jakarta.rs.LoggedField.DURATION;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_BODY;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_ID;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_URI;
import static com.chavaillaz.jakarta.rs.LoggedField.RESOURCE_CLASS;
import static com.chavaillaz.jakarta.rs.LoggedField.RESOURCE_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedField.RESPONSE_BODY;
import static com.chavaillaz.jakarta.rs.LoggedField.RESPONSE_STATUS;
import static com.chavaillaz.jakarta.rs.LoggedField.getDefaultFields;
import static java.lang.String.valueOf;
import static java.lang.System.nanoTime;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand All @@ -13,6 +23,7 @@
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.Optional;

import jakarta.ws.rs.container.ContainerRequestContext;
Expand Down Expand Up @@ -56,16 +67,16 @@ public class LoggedFilter implements ContainerRequestFilter, ContainerResponseFi

protected static final Logger log = LoggerFactory.getLogger(LoggedFilter.class);

protected static final String DURATION = "duration";
protected static final String REQUEST_ID = "request-id";
/**
* Name of the property stored in container context to compute the duration time.
*/
protected static final String REQUEST_TIME = "request-time";
protected static final String REQUEST_METHOD = "request-method";
protected static final String REQUEST_URI = "request-uri";
protected static final String REQUEST_BODY = "request-body";
protected static final String RESPONSE_BODY = "response-body";
protected static final String RESPONSE_STATUS = "response-status";
protected static final String RESOURCE_CLASS = "resource-class";
protected static final String RESOURCE_METHOD = "resource-method";

/**
* Names of MDC fields to be used for all logged fields.
* Allows changes from children classes.
*/
protected final Map<String, String> mdcFields = getDefaultFields();

@Context
ResourceInfo resourceInfo;
Expand Down Expand Up @@ -93,14 +104,23 @@ protected static <A extends Annotation> Optional<A> getAnnotation(ResourceInfo r

/**
* Gets the annotation type used to activate this filter.
* Can be overridden for more specific annotation to be used when extending this filter.
*
* @return The annotation type
*/
protected Optional<Logged> getAnnotation() {
return getAnnotation(resourceInfo, Logged.class);
}

/**
* Puts a diagnostic context value identified by the given field into the current thread's context map.
*
* @param field The field for which put the given value
* @param value The value to put
*/
private void putMdc(LoggedField field, String value) {
MDC.put(mdcFields.get(field.name()), value);
}

/**
* Gets the request identifier that will be stored in MDC for the complete request processing.
* Returns the header value of {@code X-Request-ID} or a random UUID when not present.
Expand All @@ -117,25 +137,23 @@ protected String getRequestId(ContainerRequestContext requestContext) {

@Override
public void filter(ContainerRequestContext requestContext) {
String requestId = getRequestId(requestContext);
requestContext.setProperty(REQUEST_ID, requestId);
MDC.put(REQUEST_ID, requestId);
putMdc(REQUEST_ID, getRequestId(requestContext));

requestContext.setProperty(REQUEST_TIME, nanoTime());

Optional.of(requestContext.getUriInfo())
.map(UriInfo::getPath)
.ifPresent(path -> MDC.put(REQUEST_URI, path));
.ifPresent(path -> putMdc(REQUEST_URI, path));

MDC.put(REQUEST_METHOD, requestContext.getMethod());
putMdc(REQUEST_METHOD, requestContext.getMethod());

Optional.ofNullable(resourceInfo.getResourceClass())
.map(Class::getSimpleName)
.ifPresent(value -> MDC.put(RESOURCE_CLASS, value));
.ifPresent(value -> putMdc(RESOURCE_CLASS, value));

Optional.ofNullable(resourceInfo.getResourceMethod())
.map(Method::getName)
.ifPresent(value -> MDC.put(RESOURCE_METHOD, value));
.ifPresent(value -> putMdc(RESOURCE_METHOD, value));
}

@Override
Expand All @@ -145,9 +163,9 @@ public void filter(ContainerRequestContext requestContext, ContainerResponseCont
.map(Long::parseLong)
.orElse(nanoTime());
long duration = (nanoTime() - requestStartTime) / 1_000_000;
MDC.put(DURATION, valueOf(duration));
putMdc(DURATION, valueOf(duration));

MDC.put(RESPONSE_STATUS, valueOf(responseContext.getStatus()));
putMdc(RESPONSE_STATUS, valueOf(responseContext.getStatus()));

logRequestBody(requestContext);
logResponseBody(responseContext);
Expand All @@ -170,7 +188,7 @@ protected void logRequestBody(ContainerRequestContext requestContext) {
getAnnotation()
.map(Logged::requestBody)
.filter(loggingActivated -> loggingActivated && requestContext.hasEntity())
.ifPresent(logging -> MDC.put(REQUEST_BODY, getRequestBody(requestContext)));
.ifPresent(logging -> putMdc(REQUEST_BODY, getRequestBody(requestContext)));
}

/**
Expand All @@ -182,7 +200,7 @@ protected void logResponseBody(ContainerResponseContext responseContext) {
getAnnotation()
.map(Logged::responseBody)
.filter(loggingActivated -> loggingActivated && responseContext.hasEntity())
.ifPresent(logging -> MDC.put(RESPONSE_BODY, getResponseBody(responseContext)));
.ifPresent(logging -> putMdc(RESPONSE_BODY, getResponseBody(responseContext)));
}

/**
Expand Down Expand Up @@ -240,15 +258,7 @@ protected String getResponseBody(ContainerResponseContext responseContext) {
* and {@link #filter(ContainerRequestContext, ContainerResponseContext)}.
*/
protected void cleanupMdc() {
MDC.remove(DURATION);
MDC.remove(REQUEST_ID);
MDC.remove(REQUEST_BODY);
MDC.remove(REQUEST_URI);
MDC.remove(REQUEST_METHOD);
MDC.remove(RESPONSE_BODY);
MDC.remove(RESPONSE_STATUS);
MDC.remove(RESOURCE_CLASS);
MDC.remove(RESOURCE_METHOD);
mdcFields.values().forEach(MDC::remove);
}

}
58 changes: 33 additions & 25 deletions src/test/java/com/chavaillaz/jakarta/rs/LoggedFilterTest.java
@@ -1,14 +1,14 @@
package com.chavaillaz.jakarta.rs;

import static com.chavaillaz.jakarta.rs.LoggedFilter.DURATION;
import static com.chavaillaz.jakarta.rs.LoggedFilter.REQUEST_BODY;
import static com.chavaillaz.jakarta.rs.LoggedFilter.REQUEST_ID;
import static com.chavaillaz.jakarta.rs.LoggedFilter.REQUEST_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedFilter.REQUEST_URI;
import static com.chavaillaz.jakarta.rs.LoggedFilter.RESOURCE_CLASS;
import static com.chavaillaz.jakarta.rs.LoggedFilter.RESOURCE_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedFilter.RESPONSE_BODY;
import static com.chavaillaz.jakarta.rs.LoggedFilter.RESPONSE_STATUS;
import static com.chavaillaz.jakarta.rs.LoggedField.DURATION;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_BODY;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_ID;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedField.REQUEST_URI;
import static com.chavaillaz.jakarta.rs.LoggedField.RESOURCE_CLASS;
import static com.chavaillaz.jakarta.rs.LoggedField.RESOURCE_METHOD;
import static com.chavaillaz.jakarta.rs.LoggedField.RESPONSE_BODY;
import static com.chavaillaz.jakarta.rs.LoggedField.RESPONSE_STATUS;
import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static java.lang.Integer.parseInt;
Expand Down Expand Up @@ -91,11 +91,11 @@ void filterRequestCheckMdc() throws URISyntaxException {
requestLoggingFilter.filter(requestContext);

// Then
assertNotNull(MDC.get(REQUEST_ID));
assertEquals(requestContext.getUriInfo().getPath(), MDC.get(REQUEST_URI));
assertEquals(requestContext.getMethod(), MDC.get(REQUEST_METHOD));
assertEquals(getClass().getSimpleName(), MDC.get(RESOURCE_CLASS));
assertEquals("setupTest", MDC.get(RESOURCE_METHOD));
assertNotNull(getMdc(REQUEST_ID));
assertEquals(requestContext.getUriInfo().getPath(), getMdc(REQUEST_URI));
assertEquals(requestContext.getMethod(), getMdc(REQUEST_METHOD));
assertEquals(getClass().getSimpleName(), getMdc(RESOURCE_CLASS));
assertEquals("setupTest", getMdc(RESOURCE_METHOD));
}

@Test
Expand All @@ -111,15 +111,15 @@ void filterResponseCheckLog() throws Exception {
requestLoggingFilter.filter(requestContext, responseContext);

// Then
assertNotNull(getLoggedMdc(REQUEST_ID));
assertEquals(requestContext.getUriInfo().getPath(), getLoggedMdc(REQUEST_URI));
assertEquals(requestContext.getMethod(), getLoggedMdc(REQUEST_METHOD));
assertEquals(getClass().getSimpleName(), getLoggedMdc(RESOURCE_CLASS));
assertEquals("setupTest", getLoggedMdc(RESOURCE_METHOD));
assertEquals(responseContext.getHttpResponse().getStatus(), parseInt(getLoggedMdc(RESPONSE_STATUS)));
assertNotNull(getLoggedMdc(DURATION));
assertEquals(INPUT, getLoggedMdc(REQUEST_BODY));
assertEquals(OUTPUT, getLoggedMdc(RESPONSE_BODY));
assertNotNull(getMdcLogged(REQUEST_ID));
assertEquals(requestContext.getUriInfo().getPath(), getMdcLogged(REQUEST_URI));
assertEquals(requestContext.getMethod(), getMdcLogged(REQUEST_METHOD));
assertEquals(getClass().getSimpleName(), getMdcLogged(RESOURCE_CLASS));
assertEquals("setupTest", getMdcLogged(RESOURCE_METHOD));
assertEquals(responseContext.getHttpResponse().getStatus(), parseInt(getMdcLogged(RESPONSE_STATUS)));
assertNotNull(getMdcLogged(DURATION));
assertEquals(INPUT, getMdcLogged(REQUEST_BODY));
assertEquals(OUTPUT, getMdcLogged(RESPONSE_BODY));
}

private PreMatchContainerRequestContext getRequestContext() throws URISyntaxException {
Expand All @@ -139,11 +139,19 @@ private ContainerResponseContextImpl getResponseContext(PreMatchContainerRequest
return new ContainerResponseContextImpl(request.getHttpRequest(), httpResponse, builtResponse);
}

private String getLoggedMdc(String key) {
private String getMdc(LoggedField field) {
return MDC.get(getMdcField(field));
}

private String getMdcField(LoggedField field) {
return requestLoggingFilter.mdcFields.get(field.name());
}

private String getMdcLogged(LoggedField key) {
return LIST_APPENDER.getMessages().stream()
.filter(log -> log.getMessage().getFormattedMessage().startsWith("Processed"))
.map(LogEvent::getContextData)
.map(mdc -> mdc.getValue(key))
.map(mdc -> mdc.getValue(getMdcField(key)))
.map(Object::toString)
.findFirst()
.orElse(null);
Expand Down

0 comments on commit ee297dd

Please sign in to comment.